In [41]:
import numpy as np

train_data = np.load('train_data.npy', allow_pickle=True)
valid_data = np.load('valid_data.npy', allow_pickle=True)

#artists, duration_ms, popularity, release_date, tempo, key, liveness
x_train, y_train = train_data[..., :-1], train_data[..., -1]
x_valid, y_valid = valid_data[..., :-1], valid_data[..., -1]

print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)

(79999, 6) (79999,)
(10000, 6) (10000,)


In [42]:
x_train

array([[list([284, 6654]), 4.462883333333333, 0, 76, 75.657, 8],
       [list([10771]), 5.54, 37, 49, 106.01899999999999, 7],
       [list([13718]), 4.1217, 38, 46, 81.115, 11],
       ...,
       [list([2048]), 0.6673333333333333, 46, 51, 169.03900000000004, 2],
       [list([5073]), 3.421116666666667, 51, 8, 100.053, 1],
       [list([1520, 2019]), 3.5888833333333334, 18, 59, 120.365, 7]],
      dtype=object)

In [46]:
import torch
from torch import nn
import torch.nn.functional as F

class InputFeatures(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(18907, 8)
    
    def forward(self, x):
        result = []
        for i in x[..., 0]:
            result.append(self.emb(torch.LongTensor(i).cuda()).sum(0))
        result = torch.stack(result)
        return torch.cat([result, torch.FloatTensor(x[..., 1:].astype(np.float32)).cuda()], dim=1)

class Body(nn.Module):
    def __init__(self):
        super().__init__()
        self.f = nn.Sequential(
            nn.Linear(13, 32),
            nn.LayerNorm(32),
            nn.GELU(),
            nn.Linear(32, 32),
            nn.LayerNorm(32),
            nn.GELU(),
            nn.Linear(32, 32),
            nn.LayerNorm(32),
            nn.GELU(),
            nn.Linear(32, 1)
        )
    
    def forward(self, x):
        return self.f(x)
    
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.f = nn.Sequential(
            InputFeatures(),
            Body()
        )
    
    def forward(self, x):
        return self.f(x)

In [55]:
import random
from tqdm.auto import tqdm

batch_size = 256
epoch = 20

model = Model().cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)

for e in range(epoch):
    # train
    total_loss = 0
    model.train()
    for i in tqdm(range(0, len(x_train), batch_size)):
        x_batch = x_train[i:i+batch_size]
        y_batch = y_train[i:i+batch_size]

        out = model(x_batch)
        label = torch.FloatTensor(y_batch.astype(np.float32)).unsqueeze(-1).cuda()
        loss = F.mse_loss(out, label) + F.l1_loss(out, label) 
        total_loss += loss.item()
        opt.zero_grad()
        loss.backward()
        opt.step()

    print(f'Epoch [{e+1}/{epoch}], Train Loss: {total_loss/(i+1)}')

    index = list(range(len(x_train)))
    random.shuffle(index)
    x_train = x_train[index]
    y_train = y_train[index]

    # val
    total_loss = 0
    model.eval()
    for i in range(0, len(x_valid), batch_size):
        x_batch = x_valid[i:i+batch_size]
        y_batch = y_valid[i:i+batch_size]

        out = model(x_batch)
        label = torch.FloatTensor(y_batch.astype(np.float32)).unsqueeze(-1).cuda()
        loss = F.mse_loss(out, label)
        total_loss += loss.item()

    print(f'Epoch [{e+1}/{epoch}], Val Loss: {total_loss/(i+1)}')

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [1/20], Train Loss: 0.0005584206928961841
Epoch [1/20], Val Loss: 0.00011583569183997888


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [2/20], Train Loss: 0.0005351975438923852
Epoch [2/20], Val Loss: 0.00011749908824264257


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [3/20], Train Loss: 0.000533480548251798
Epoch [3/20], Val Loss: 0.00011165235042348169


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [4/20], Train Loss: 0.0005284712630281308
Epoch [4/20], Val Loss: 0.00011241330719018018


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [5/20], Train Loss: 0.0005239811385849339
Epoch [5/20], Val Loss: 0.00011054097432745774


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [6/20], Train Loss: 0.0005156669740105377
Epoch [6/20], Val Loss: 0.00011467440366282365


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [7/20], Train Loss: 0.0005104845669121716
Epoch [7/20], Val Loss: 0.00011116727873212884


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [8/20], Train Loss: 0.0005056549527758749
Epoch [8/20], Val Loss: 0.00010782908191218876


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [9/20], Train Loss: 0.0005007029068874017
Epoch [9/20], Val Loss: 0.0001132752168988907


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [10/20], Train Loss: 0.0004962957526947864
Epoch [10/20], Val Loss: 0.00011310666728658442


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [11/20], Train Loss: 0.0004928504743991916
Epoch [11/20], Val Loss: 0.0001056202546948484


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [12/20], Train Loss: 0.0004878881273762114
Epoch [12/20], Val Loss: 0.00010698334932789901


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [13/20], Train Loss: 0.0004842970371709008
Epoch [13/20], Val Loss: 0.00010868582163802611


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [14/20], Train Loss: 0.00048118951364736214
Epoch [14/20], Val Loss: 0.00010617737204552771


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [15/20], Train Loss: 0.0004765128350537864
Epoch [15/20], Val Loss: 0.00011077213707824439


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [16/20], Train Loss: 0.00047312009635001526
Epoch [16/20], Val Loss: 0.00010712661565424623


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [17/20], Train Loss: 0.0004700109187794877
Epoch [17/20], Val Loss: 0.00010709960077053079


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [18/20], Train Loss: 0.00046681122942945
Epoch [18/20], Val Loss: 0.00010901298373937607


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [19/20], Train Loss: 0.0004641759332429919
Epoch [19/20], Val Loss: 0.00010522610144610398


HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Epoch [20/20], Train Loss: 0.0004614852569850767
Epoch [20/20], Val Loss: 0.00010946322769732492


In [58]:
torch.save(model.state_dict(), 'model_weights.pth')