In [135]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from glob import glob
from torch.nn.utils.rnn import pad_sequence

In [35]:
DEVICE = 'cpu'
NUM_TAGS = 256

In [3]:
df_train = pd.read_csv('train.csv')
df_test = pd.read_csv('test.csv')

In [9]:
df_train.shape

(51134, 2)

In [10]:
df_test.shape

(25580, 1)

In [93]:
track_idx2embeds = {}
for fn in tqdm(glob('track_embeddings/*')):
    track_idx = int(fn.split('/')[1].split('.')[0])
    embeds = np.load(fn)
    track_idx2embeds[track_idx] = np.array(embeds)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 76714/76714 [01:18<00:00, 976.76it/s]


In [95]:
track_idx2embeds[531]

array([[ 0.45148772,  0.76740855,  1.3221693 , ...,  1.5691732 ,
         1.1770165 ,  0.5565597 ],
       [ 0.30100685,  0.79274833,  1.1078403 , ...,  0.8639254 ,
         0.968358  ,  0.28113842],
       [-0.03283544,  0.9865178 ,  1.2763743 , ...,  1.3526566 ,
         1.1101085 ,  0.7033231 ],
       ...,
       [ 0.23617499,  0.67322105,  1.0676353 , ...,  1.0679574 ,
         0.8110492 ,  0.64638203],
       [ 0.06359969,  0.19562058,  0.21925768, ...,  0.509882  ,
         0.1549287 ,  0.14764403],
       [ 1.3971076 ,  1.784399  ,  2.1582959 , ...,  2.2253149 ,
         2.3128386 ,  2.2868373 ]], dtype=float32)

In [171]:
class TaggingDataset(Dataset):
    def __init__(self, df, testing=False):
        self.df = df
        self.testing = testing
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        track_idx = row.track
        embeds = track_idx2embeds[track_idx]
        if self.testing:
            return track_idx, embeds
        tags = [int(x) for x in row.tags.split(',')]
        target = np.zeros(NUM_TAGS)
        target[tags] = 1
        return track_idx, embeds, target


In [172]:
train_dataset = TaggingDataset(df_train)
test_dataset = TaggingDataset(df_test, True)

In [173]:
class Network(nn.Module):
    def __init__(self, num_classes = NUM_TAGS, input_dim = 768, hidden_dim = 512):
        super().__init__()
        self.num_classes = num_classes
        self.bn = nn.LayerNorm(hidden_dim)
        self.projector =  nn.Linear(input_dim, hidden_dim)
        self.lin = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim)
        )
        self.fc = nn.Linear(hidden_dim, num_classes)
        

    def forward(self, embeds):
        x = [self.projector(x) for x in embeds]
        x = [v.mean(0).unsqueeze(0) for v in x]
        x = self.bn(torch.cat(x, dim = 0))
        x = self.lin(x)
        outs = self.fc(x)
        return outs

In [174]:
class Network(nn.Module):
    def __init__(self, num_classes = NUM_TAGS, input_dim = 768, hidden_dim = 256):
        super().__init__()
        self.rnn = nn.LSTM(input_dim, hidden_dim, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, embeds):
        rnn_output, (hn, cn) = self.rnn(embeds)
        outs = self.fc(hn[0, :, :])
        return outs

In [None]:
class Network(nn.Module):
    def __init__(self, num_classes = NUM_TAGS, input_dim = 768, hidden_dim = 256):
        super().__init__()
        self.rnn = nn.LSTM(input_dim, hidden_dim, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, embeds):
        rnn_output, (hn, cn) = self.rnn(embeds)
        outs = self.fc(hn[0, :, :])
        return outs

In [175]:
def train_epoch(model, loader, criterion, optimizer):
    model.train()
    running_loss = None
    alpha = 0.8
    for iteration, data in enumerate(loader):
        optimizer.zero_grad()
        track_idxs, embeds, target = data
        embeds = pad_sequence([x.to(DEVICE) for x in embeds], batch_first=True)
        target = target.to(DEVICE)
        pred_logits = model(embeds)
        pred_probs = torch.sigmoid(pred_logits)
        ce_loss = criterion(pred_logits, target)
        ce_loss.backward()
        optimizer.step()
        
        if running_loss is None:
            running_loss = ce_loss.item()
        else:
            running_loss = alpha * ce_loss.item() + (1 - alpha) * ce_loss.item()
        if iteration % 100 == 0:
            print('   {} batch {} loss {}'.format(
                datetime.now(), iteration + 1, running_loss
            ))

In [176]:
def predict(model, loader):
    model.eval()
    track_idxs = []
    predictions = []
    with torch.no_grad():
        for data in loader:
            track_idx, embeds = data
            embeds =  pad_sequence([x.to(DEVICE) for x in embeds], batch_first=True)
            pred_logits = model(embeds)
            pred_probs = torch.sigmoid(pred_logits)
            predictions.append(pred_probs.cpu().numpy())
            track_idxs.append(track_idx.numpy())
    predictions = np.vstack(predictions)
    track_idxs = np.vstack(track_idxs).ravel()
    return track_idxs, predictions
            

In [177]:
def collate_fn(b):
    track_idxs = torch.from_numpy(np.vstack([x[0] for x in b]))
    embeds = [torch.from_numpy(x[1]) for x in b]
    targets = torch.from_numpy(np.vstack([x[2] for x in b]))
    return track_idxs, embeds, targets

def collate_fn_test(b):
    track_idxs = torch.from_numpy(np.vstack([x[0] for x in b]))
    embeds = [torch.from_numpy(x[1]) for x in b]
    return track_idxs, embeds

In [178]:
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn_test)

In [179]:
model = Network()
criterion = nn.CrossEntropyLoss()

epochs = 5
model = model.to(DEVICE)
criterion = criterion.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for epoch in tqdm(range(epochs)):
    train_epoch(model, train_dataloader, criterion, optimizer)


  0%|                                                                                                                                            | 0/5 [00:00<?, ?it/s]

   2023-10-29 23:25:21.081177 batch 1 loss 22.37625127285719
   2023-10-29 23:26:33.107934 batch 101 loss 18.489122908562422
   2023-10-29 23:27:21.331202 batch 201 loss 18.69697166979313
   2023-10-29 23:28:09.800322 batch 301 loss 20.655828643590212
   2023-10-29 23:29:01.405622 batch 401 loss 18.544672245159745
   2023-10-29 23:29:53.608091 batch 501 loss 19.37050810456276
   2023-10-29 23:30:49.323122 batch 601 loss 19.723336935043335
   2023-10-29 23:31:41.050994 batch 701 loss 17.52207120321691


 20%|██████████████████████████▏                                                                                                        | 1/5 [07:12<28:51, 432.97s/it]

   2023-10-29 23:32:34.084487 batch 1 loss 17.41731971502304
   2023-10-29 23:33:18.713145 batch 101 loss 19.449612522497773
   2023-10-29 23:34:03.521985 batch 201 loss 17.777461072430015
   2023-10-29 23:34:54.615307 batch 301 loss 21.173745607957244
   2023-10-29 23:35:47.964047 batch 401 loss 18.782363824546337
   2023-10-29 23:36:38.698915 batch 501 loss 16.86319096572697
   2023-10-29 23:37:28.118176 batch 601 loss 17.429485904052854
   2023-10-29 23:38:11.628434 batch 701 loss 16.10832716524601


 40%|████████████████████████████████████████████████████▍                                                                              | 2/5 [13:38<20:14, 404.83s/it]

   2023-10-29 23:38:59.230093 batch 1 loss 16.625599475577474
   2023-10-29 23:39:45.756129 batch 101 loss 16.385829858481884
   2023-10-29 23:40:29.041722 batch 201 loss 16.623455258086324
   2023-10-29 23:41:09.330977 batch 301 loss 17.42536965571344
   2023-10-29 23:41:53.901299 batch 401 loss 14.661055436357856
   2023-10-29 23:42:37.745581 batch 501 loss 18.29888349212706
   2023-10-29 23:43:21.853436 batch 601 loss 16.956267001107335
   2023-10-29 23:44:07.052272 batch 701 loss 17.286550596356392


 60%|██████████████████████████████████████████████████████████████████████████████▌                                                    | 3/5 [19:29<12:41, 380.64s/it]

   2023-10-29 23:44:50.947111 batch 1 loss 15.718635438010097
   2023-10-29 23:45:36.313228 batch 101 loss 15.25408748909831
   2023-10-29 23:46:16.877621 batch 201 loss 18.041512947529554
   2023-10-29 23:46:57.267487 batch 301 loss 15.738254209980369
   2023-10-29 23:47:38.131298 batch 401 loss 18.937168275937438
   2023-10-29 23:48:23.130824 batch 501 loss 16.4217742215842
   2023-10-29 23:49:09.566672 batch 601 loss 17.30804492533207
   2023-10-29 23:49:57.144972 batch 701 loss 17.25799271836877


 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████▊                          | 4/5 [25:18<06:08, 368.00s/it]

   2023-10-29 23:50:39.481890 batch 1 loss 14.801182946190238
   2023-10-29 23:51:21.254439 batch 101 loss 16.532756632193923
   2023-10-29 23:51:57.173591 batch 201 loss 14.760481188073754
   2023-10-29 23:52:39.710681 batch 301 loss 15.299261562526226
   2023-10-29 23:53:23.430927 batch 401 loss 15.52678732573986
   2023-10-29 23:54:08.262721 batch 501 loss 15.34805710054934
   2023-10-29 23:54:50.704279 batch 601 loss 13.333010986447334
   2023-10-29 23:55:34.232805 batch 701 loss 15.371110832318664


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [30:53<00:00, 370.61s/it]


In [180]:
track_idxs, predictions = predict(model, test_dataloader)

In [181]:
predictions_df = pd.DataFrame([
    {'track': track, 'prediction': ','.join([str(p) for p in probs])}
    for track, probs in zip(track_idxs, predictions)
])

In [182]:
predictions_df.head()

Unnamed: 0,track,prediction
0,17730,"0.9614,0.9692255,0.9273112,0.95447135,0.68378,..."
1,32460,"0.9363525,0.97831887,0.91003394,0.9414933,0.58..."
2,11288,"0.971713,0.9706743,0.9733383,0.9584738,0.64071..."
3,18523,"0.97833043,0.9606864,0.9839956,0.9512931,0.805..."
4,71342,"0.96631485,0.9280631,0.9739908,0.99105024,0.60..."


In [183]:
predictions_df.to_csv('prediction_lstm.csv', index=False)