In [1]:
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
from glob import glob

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
print('pytorch version:', torch.__version__)
# Single device shared by the whole notebook: first CUDA GPU when present, CPU otherwise.
if torch.cuda.is_available():
    global_device = torch.device("cuda:0")
else:
    global_device = torch.device("cpu")
print('Device:', global_device)
from torch.utils.data import DataLoader, random_split, Dataset
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence, pad_sequence

from sklearn.model_selection import train_test_split
from sklearn.metrics import average_precision_score

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics.classification import MultilabelAveragePrecision

pytorch version: 2.2.0.dev20231027+cu121
Device: cuda:0


In [2]:
# Matplotlib settings
import matplotlib
import matplotlib as mp
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.collections import PolyCollection
from matplotlib.colors import colorConverter

# Font sizes reused by the rc settings below.
titlesize = 20
labelsize = 16
legendsize = labelsize
xticksize = 14
yticksize = xticksize

# Legend appearance.
matplotlib.rc(
    'legend',
    markerscale=1.5,     # the relative size of legend markers vs. original
    handletextpad=0.5,
    labelspacing=0.4,    # the vertical space between the legend entries in fraction of fontsize
    borderpad=0.5,       # border whitespace in fontsize units
    fontsize=legendsize,
)

# Global font: 12pt serif, preferring Times New Roman.
matplotlib.rc('font', size=12, family='serif', serif='Times New Roman')

# Axis labels/titles and tick labels.
matplotlib.rc('axes', labelsize=labelsize, titlesize=titlesize)
matplotlib.rc('xtick', labelsize=xticksize)
matplotlib.rc('ytick', labelsize=yticksize)

# Data loading

In [3]:
def _parse_tags(raw_tags):
    """Convert a comma-separated tag string (e.g. ``"3,17"``) into an int array."""
    return np.array([int(tag) for tag in raw_tags.split(',')])


# Labelled data: fixed 80/20 train/validation split (seeded for reproducibility).
df_trainval = pd.read_csv('data/train.csv')
df_train, df_val = train_test_split(df_trainval, test_size=0.2, random_state=42)
df_train['tags'] = df_train['tags'].apply(_parse_tags)
df_val['tags'] = df_val['tags'].apply(_parse_tags)

# Unlabelled tracks to predict for.
df_test = pd.read_csv('data/test.csv')

In [4]:
import os  # needed only here, for platform-independent path handling

global_idx2embeds = {}  # {track_idx: np.ndarray of per-frame embeddings, [n, 768] per the original note}
for npy_file in tqdm(glob('data/track_embeddings/*')):
    # glob returns OS-specific separators. Splitting on '\\' (as before) only
    # worked on Windows and raised IndexError elsewhere, so normalise the
    # separators and take the file name portably instead.
    file_name = os.path.basename(npy_file.replace('\\', '/'))
    track_idx = int(file_name.split('.')[0])
    embeds = np.load(npy_file)
    # The last two rows of every file are dropped.
    # NOTE(review): the reason is not visible here (presumably non-frame
    # summary rows) - confirm against the data description.
    global_idx2embeds[track_idx] = embeds[:-2]

100%|██████████| 76714/76714 [02:58<00:00, 429.38it/s] 


In [5]:
# Mean embedding of the training split: each track contributes the mean of its
# own frames, and those per-track means are then averaged with equal weight.
per_track_means = (
    global_idx2embeds[track_idx].mean(axis=0)
    for track_idx in tqdm(df_train['track'])
)
emb_train_mean = sum(per_track_means) / len(df_train['track'])
emb_train_mean.shape

  0%|          | 0/40907 [00:00<?, ?it/s]

100%|██████████| 40907/40907 [00:20<00:00, 2006.22it/s]


(768,)

In [7]:
class TrackDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding one track's embeddings (and labels) per item.

    Embeddings are looked up in the module-level ``global_idx2embeds`` dict by
    the value of the ``track`` column.

    Args:
        df_tags: DataFrame with a ``track`` column and, unless ``test`` is
            true, a ``tags`` column holding integer class indices per row.
        test: when true, items carry no labels.
        num_classes: length of the multi-hot label vector (default 256, the
            value that used to be hard-coded).
    """

    def __init__(self, df_tags, test=False, num_classes=256):
        self.df_tags = df_tags
        self.test = test
        self.num_classes = num_classes

    def __len__(self):
        """Number of rows (tracks) in the underlying frame."""
        return len(self.df_tags)

    def __getitem__(self, idx):
        """Return ``(track_idx, embeds)`` in test mode, else
        ``(track_idx, embeds, labels_onehot)`` with a multi-hot float vector."""
        # Fetch the row once; every positional lookup builds a new Series.
        row = self.df_tags.iloc[idx]
        track_idx = row['track']
        embeds = global_idx2embeds[track_idx]
        if self.test:
            return track_idx, embeds
        labels_onehot = np.zeros(self.num_classes)
        labels_onehot[row['tags']] = 1
        return track_idx, embeds, labels_onehot

In [8]:
cut_size = 64  # number of frames every track is cropped / cyclically padded to


def _fit_to_cut_size(track_embeds):
    """Return exactly ``cut_size`` leading rows of ``track_embeds`` as float32.

    Longer tracks are truncated; shorter ones are repeated cyclically until
    ``cut_size`` rows are available.

    Raises:
        ValueError: if the track has no rows (previously a ZeroDivisionError).
    """
    frames = torch.as_tensor(track_embeds, dtype=torch.float32)
    n_frames = frames.shape[0]
    if n_frames == 0:
        raise ValueError("cannot pad a track with zero embedding frames")
    if n_frames < cut_size:
        repeat_num = cut_size // n_frames + 1
        frames = torch.cat([frames] * repeat_num)
    return frames[:cut_size]


def _batch_track_idxs(batch):
    """Stack the track ids (first item of every sample) into a ``b x 1`` int32 tensor."""
    return torch.as_tensor(np.vstack([sample[0] for sample in batch]), dtype=torch.int32)


def _batch_embeddings(batch):
    """Stack the fitted embeddings, subtract the training mean, move to ``global_device``."""
    stacked = torch.stack([_fit_to_cut_size(sample[1]) for sample in batch])
    # Convert the numpy mean explicitly so the result is always a float32
    # tensor, whatever dtype the accumulated mean happens to have.
    mean = torch.as_tensor(emb_train_mean, dtype=torch.float32)
    return (stacked - mean).to(global_device)


def collate(batch):
    """Collate ``(track_idx, embeds, labels_onehot)`` samples.

    Returns:
        ``(track_idxs, embeds, labels_onehot)``: a ``b x 1`` int32 tensor,
        a ``b x cut_size x dim`` float32 tensor on ``global_device`` and a
        float32 label tensor on ``global_device``.
    """
    track_idxs = _batch_track_idxs(batch)
    embeds = _batch_embeddings(batch)
    labels_onehot = torch.as_tensor(
        np.vstack([sample[2] for sample in batch]), dtype=torch.float32
    ).to(global_device)
    return track_idxs, embeds, labels_onehot


def collate_test(batch):
    """Collate samples without using labels; only the first two items of each sample are read.

    Returns:
        ``(track_idxs, embeds)`` built exactly as in :func:`collate`.
    """
    return _batch_track_idxs(batch), _batch_embeddings(batch)

batch_size = 256

# One dataset per split; only the test split is built without labels.
dataset_train = TrackDataset(df_train)
dataset_val = TrackDataset(df_val)
dataset_test = TrackDataset(df_test, test=True)

# Only the training loader shuffles.
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate)
dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, collate_fn=collate)
# Validation data served through the label-free collate, for use with predict().
dataloader_val_astest = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, collate_fn=collate_test)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, collate_fn=collate_test)

In [9]:
class TrackToTags(pl.LightningModule):
    """Multi-label tag classifier over a fixed-length sequence of track embeddings.

    The ``frames x features`` input matrix is treated as a one-channel image
    and passed through two Conv2d/ReLU/MaxPool stages. Each of the 16 output
    channels is then read by its own GRU along the pooled frame axis, the GRU
    outputs are averaged over that axis, concatenated across channels and fed
    to a small MLP that emits one logit per tag.

    Args:
        num_classes: number of output logits.
        input_dim: size of a single input embedding (last axis of the input).
        hidden_dim: width of the hidden layer of the final MLP.
        pos_weights: optional per-class positive weights passed on to
            ``nn.BCEWithLogitsLoss(pos_weight=...)``; ``None`` keeps the
            unweighted loss.

    The defaults reproduce the previously hard-coded sizes (188-d GRU input,
    512-d hidden layer, 256 logits) and the module names are unchanged.

    Raises:
        ValueError: if ``input_dim`` is too small to survive the conv stack.
    """

    _CONV_CHANNELS = 16  # output channels of conv2 == number of GRUs
    _GRU_HIDDEN = 128    # hidden size of every per-channel GRU

    def __init__(self, num_classes=256, input_dim=768, hidden_dim=512, pos_weights=None):
        super().__init__()
        self.num_classes = num_classes

        # Feature width left after conv1 (k=7) -> pool/2 -> conv2 (k=6) -> pool/2.
        gru_input_dim = ((input_dim - 6) // 2 - 5) // 2
        if gru_input_dim < 1:
            raise ValueError(f"input_dim={input_dim} is too small for the convolutional stack")

        # Shape comments assume the default 64 x 768 input.
        self.conv_layers = torch.nn.Sequential()  # b x 1 x 64 x 768
        self.conv_layers.add_module('conv1', torch.nn.Conv2d(1, 4, kernel_size=7))  # b x 4 x 58 x 762
        self.conv_layers.add_module('relu1', torch.nn.ReLU())
        self.conv_layers.add_module('pool1', torch.nn.MaxPool2d(kernel_size=2))  # b x 4 x 29 x 381
        self.conv_layers.add_module(
            'conv2', torch.nn.Conv2d(4, self._CONV_CHANNELS, kernel_size=6))  # b x 16 x 24 x 376
        self.conv_layers.add_module('relu2', torch.nn.ReLU())
        self.conv_layers.add_module('pool2', torch.nn.MaxPool2d(kernel_size=2))  # b x 16 x 12 x 188

        # One GRU per conv channel. No explicit device placement here: the
        # module is moved as a whole (by Lightning or an explicit .to()).
        self.gru = torch.nn.ModuleList(
            nn.GRU(gru_input_dim, self._GRU_HIDDEN, batch_first=True,
                   bidirectional=False, num_layers=1)
            for _ in range(self._CONV_CHANNELS)
        )

        mlp_input_dim = self._CONV_CHANNELS * self._GRU_HIDDEN
        self.lin = nn.Sequential(
            nn.ReLU(),
            nn.LayerNorm(mlp_input_dim),
            nn.Linear(mlp_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

        if pos_weights is not None:
            pos_weights = torch.as_tensor(pos_weights, dtype=torch.float32)
        self.loss = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

        self.metric = average_precision_score
        # Per-batch validation outputs, gathered until the epoch-end hook scores them.
        self._transit_val = {'preds': [], 'labels': []}

    def forward(self, embeds):  # b x 64 x 768
        """Return raw logits of shape ``b x num_classes`` for a batch of embeddings."""
        out_conv = self.conv_layers(embeds.unsqueeze(1))  # b x 16 x 12 x 188
        # Run channel i through GRU i and stack the outputs; this replaces
        # in-place writes into an uninitialised buffer pinned to a global device.
        out_gru = torch.stack(
            [gru(out_conv[:, channel])[0] for channel, gru in enumerate(self.gru)],
            dim=1,
        )  # b x 16 x 12 x 128
        out_mean = out_gru.mean(dim=2)  # b x 16 x 128
        out_flat = out_mean.reshape(out_mean.shape[0], -1)  # b x 16 * 128
        return self.lin(out_flat)  # b x num_classes

    def training_step(self, batch, batch_idx):
        """Compute and log the BCE-with-logits loss for one training batch."""
        track_idxs, embeds_list, labels_onehot = batch
        pred_logits = self(embeds_list)
        loss = self.loss(pred_logits, labels_onehot)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and stash probabilities/labels for epoch-end scoring."""
        track_idxs, embeds_list, labels_onehot = batch
        pred_logits = self(embeds_list)
        loss = self.loss(pred_logits, labels_onehot)
        self.log("val_loss", loss, prog_bar=True)

        pred_probs = torch.sigmoid(pred_logits)
        self._transit_val['labels'].append(labels_onehot.int().cpu().numpy())
        self._transit_val['preds'].append(pred_probs.detach().cpu().numpy())

    def on_validation_epoch_end(self):
        """Score everything gathered this epoch with ``self.metric`` and reset the buffers."""
        preds = np.vstack(self._transit_val['preds'])
        labels = np.vstack(self._transit_val['labels'])
        ap = self.metric(labels, preds)
        self.log('val_ap', ap, prog_bar=True)
        self._transit_val['labels'] = []
        self._transit_val['preds'] = []

    def configure_optimizers(self):
        """Adam (lr=3e-4) with an exponential schedule (gamma=0.5) stepped every 10 epochs."""
        optimizer = torch.optim.Adam(self.parameters(), lr=3e-4)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.5)

        return {"optimizer": optimizer,
                "lr_scheduler": {
                                "scheduler": scheduler,
                                "interval": "epoch",
                                "frequency": 10
                                },
               }

In [10]:
# checkpoint_callback = ModelCheckpoint(dirpath='lightning_logs/cnnrnn_norm/',
#                                       filename='{epoch}-{val_loss:.3f}-{val_ap:.3f}', 
#                                       save_top_k=-1, 
#                                       monitor="val_ap", 
#                                       every_n_epochs=1)
# trainer = pl.Trainer(accelerator="gpu", devices=1, val_check_interval=1.0, 
#                      max_epochs=20, log_every_n_steps=100,
#                      callbacks=[checkpoint_callback])
# model = TrackToTags()
# trainer.fit(model, dataloader_train, dataloader_val)

In [11]:
# Checkpoint to restore; its file name records epoch 3, val_loss 0.048 and val_ap 0.199.
checkpoint_path = 'lightning_logs/cnnrnn_norm_2/epoch=3-val_loss=0.048-val_ap=0.199.ckpt'
model = TrackToTags.load_from_checkpoint(checkpoint_path)
# checkpoint_callback = ModelCheckpoint(dirpath='lightning_logs/cnnrnn_norm_2/',
#                                       filename='{epoch}-{val_loss:.3f}-{val_ap:.3f}', 
#                                       save_top_k=-1, 
#                                       monitor="val_ap", 
#                                       every_n_epochs=1)
# trainer = pl.Trainer(accelerator="gpu", devices=1, val_check_interval=1.0, 
#                      max_epochs=20, log_every_n_steps=100,
#                      callbacks=[checkpoint_callback])
# trainer.fit(model, dataloader_train, dataloader_val)

In [12]:
def predict(model, loader):
    """Run ``model`` over ``loader`` and collect sigmoid probabilities.

    The model is moved to ``global_device`` and switched to eval mode (it is
    left in that state). ``loader`` must yield ``(track_idx, embeds)`` pairs.

    Returns:
        ``(track_idxs, predictions)``: a flat array of track ids and a 2-D
        array with one row of per-class probabilities per track, both in
        loader order.
    """
    model.to(global_device)
    model.eval()
    idx_chunks, prob_chunks = [], []
    with torch.no_grad():
        for track_idx, embeds in loader:
            batch_probs = torch.sigmoid(model(embeds))
            prob_chunks.append(batch_probs.cpu().numpy())
            idx_chunks.append(track_idx.numpy())
    predictions = np.vstack(prob_chunks)
    track_idxs = np.vstack(idx_chunks).ravel()
    return track_idxs, predictions

In [29]:
# Ground-truth multi-hot labels, in the same order the unshuffled loader visits the tracks.
val_true = np.array([dataset_val[i][2] for i in range(len(dataset_val))])
track_idxs_val, val_pred = predict(model, dataloader_val_astest)
# Average precision over the whole validation split.
average_precision_score(val_true, val_pred)

0.19931837905484265

In [14]:
# Average precision of a single validation track (row 10), as a spot check.
sample_idx = 10
average_precision_score(val_true[sample_idx], val_pred[sample_idx])

0.55

In [30]:
# Zero out probabilities below 0.02; note this modifies val_pred in place.
low_confidence = val_pred < 0.02
val_pred[low_confidence] = 0

In [32]:
# Re-score the validation predictions in their current (thresholded) state.
average_precision_score(y_true=val_true, y_score=val_pred)

0.19060376888253802

In [26]:
# Per-class probabilities for every test track.
track_idxs, predictions = predict(model=model.to(global_device), loader=dataloader_test)

In [27]:
# One row per test track: its id and the class probabilities joined by commas.
prediction_rows = []
for track, probs in zip(track_idxs, predictions):
    prediction_rows.append({'track': track, 'prediction': ','.join(map(str, probs))})
predictions_df = pd.DataFrame(prediction_rows)
predictions_df.to_csv('results/prediction.csv', index=False)