In [290]:
import os

from typing import List, Callable, Dict

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.nn.init import constant_, kaiming_normal_
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.nn.functional import softmax

from datetime import datetime
from tqdm.auto import tqdm
from itertools import chain
from loguru import logger

In [171]:
DEVICE = "cpu"

In [172]:
BASE_DIR = "/Users/artemvopilov/Programming/yandex_cup_2023"

In [173]:
DATA_DIR = f"{BASE_DIR}/data"

TRAIN_DF_PATH = f"{DATA_DIR}/train.csv"
TEST_DF_PATH = f"{DATA_DIR}/test.csv"

NORMED_EMBEDDINGS_DIR = f"{BASE_DIR}/normed_embeddings"
PCA_EMBEDDINGS_DIR = f"{BASE_DIR}/pca_embeddings"
VAE_EMBEDDINGS_DIR = f"{BASE_DIR}/vae_embeddings"
NORMED_LSTM_EMBEDDINGS_DIR = f"{BASE_DIR}/normed_lstm_embeddings"
VAE_LSTM_EMBEDDINGS_DIR = f"{BASE_DIR}/vae_lstm_embeddings"

### Read data

In [174]:
train_df = pd.read_csv(TRAIN_DF_PATH)
test_df = pd.read_csv(TEST_DF_PATH)

In [404]:
track_id_to_embedding = {}
for fn in tqdm(os.listdir(VAE_EMBEDDINGS_DIR)):
    fp = f"{VAE_EMBEDDINGS_DIR}/{fn}"

    track_id = fn.split('.')[0]
    embedding = np.mean(np.load(fp).astype(np.float32), axis=0)
    # embedding = np.load(fp).astype(np.float32)[-1]
    track_id_to_embedding[int(track_id)] = embedding

  0%|          | 0/76714 [00:00<?, ?it/s]

### Model

In [405]:
class Block(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, dropout_rate: float):
        super().__init__()

        self.block = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.BatchNorm1d(output_dim),
            nn.ReLU(),
            # nn.Dropout(dropout_rate),
        )

        self.block.apply(self._init_weight)

    @staticmethod
    def _init_weight(layer):
        if isinstance(layer, nn.Linear):
            nn.init.kaiming_normal_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        return self.block(x)


class TrackTower(nn.Module):
    def __init__(self, input_dims: List[int], output_dims: List[int], dropout_rates: List[float], result_dim: int):
        super().__init__()

        layers = []
        for input_dim, output_dim, dropout_rate in zip(input_dims, output_dims, dropout_rates):
            layers.append(Block(input_dim, output_dim, dropout_rate))

        layers.append(nn.Linear(output_dims[-1], result_dim))

        self.backbone = nn.Sequential(*layers)

    def forward(self, x):
        return self.backbone(x)


class TagTower(nn.Module):
    def __init__(self, num_tags: int, input_dims: List[int], output_dims: List[int], dropout_rates: List[float], result_dim: int):
        super().__init__()

        layers = []

        layers.append(nn.Embedding(num_tags, input_dims[0]))
        
        for input_dim, output_dim, dropout_rate in zip(input_dims, output_dims, dropout_rates):
            layers.append(Block(input_dim, output_dim, dropout_rate))

        layers.append(nn.Linear(output_dims[-1], result_dim))

        self.backbone = nn.Sequential(*layers)

    def forward(self, x):
        return self.backbone(x)


class DSSMModel(nn.Module):
    def __init__(self, track_model_config, tag_model_config, smoothing):
        super().__init__()

        self.track_model = TrackTower(**track_model_config)
        self.tag_model = TagTower(**tag_model_config)

        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        self.smoothing = smoothing

    def forward(self, track, tag):
        track_embedding = self.track_model(track)
        tag_embedding = self.tag_model(tag)
        return self.cos(track_embedding, tag_embedding) * self.smoothing  # , track_embedding, tag_embedding

### Dataset

In [406]:
class TracksDataset(Dataset):
    def __init__(
        self, 
        tracks_tags: pd.DataFrame, 
        tracks_to_embed: Dict[int, np.ndarray], 
        tracks_to_neg_tags: Dict[int, List[int]], 
        neg_samples: int
    ):
        self.tracks_tags = tracks_tags
        self.tracks_to_embed = tracks_to_embed
        self.tracks_to_neg_tags = tracks_to_neg_tags
        self.neg_samples = neg_samples
        
    def __len__(self):
        return len(self.tracks_tags)
        
    def __getitem__(self, ind: int):
        row = self.tracks_tags.iloc[ind]
        track = row['track']
        embed = self.tracks_to_embed[track]
        tag = row['tag']
        if self.tracks_to_neg_tags is not None:
            t_neg_tags = tracks_to_neg_tags[track]
            neg_tags = t_neg_tags if len(t_neg_tags) < self.neg_samples else np.random.choice(t_neg_tags, self.neg_samples, replace=False)
            return track, tag, embed, neg_tags
        return track, tag, embed, np.array([])

### Loss

In [407]:
class SSMLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pos, neg):
        all_cos = torch.hstack((pos.unsqueeze(1), neg))
        ssm = softmax(all_cos, dim=1)[:, 0]
        loss = -torch.mean(torch.log(ssm))
        return loss

### Trainer

In [408]:
def train(model: nn.Module, data_loader: DataLoader, loss_fn: Callable, optimizer: torch.optim) -> None:
    model.train()
    running_loss = None
    alpha = 0.8
    for iter, data in enumerate(tqdm(data_loader)):
        optimizer.zero_grad()

        track, tag, embed, neg_tags = data
        tag, embed, neg_tags = tag.to(DEVICE), embed.to(DEVICE), neg_tags.to(DEVICE)

        prep_embed = embed.repeat_interleave(neg_tags.shape[1], 0)
        prep_neg_tags = neg_tags.flatten(0, 1)
    
        pos_out = model(embed, tag)
        neg_out = model(prep_embed, prep_neg_tags).view(-1, neg_tags.shape[1])
        
        loss = loss_fn(pos_out, neg_out)

        loss.backward()
        optimizer.step()

        if running_loss is None:
            running_loss = loss.item()
        else:
            running_loss = alpha * loss.item() + (1 - alpha) * loss.item()
        if iter % 100 == 0:
            logger.info("{} batch {} loss {}".format(datetime.now(), iter + 1, running_loss))

### Train

In [409]:
train_df.head()

Unnamed: 0,track,tags
0,49734,56926325596
1,67845,692839145155
2,25302,62840116168
3,57796,28186
4,13676,623177


In [410]:
track_to_tags = {tr: list(map(int, t.split(','))) for tr, t in zip(train_df['track'].values, train_df['tags'].values)}
track_to_tags

{49734: [5, 6, 9, 26, 32, 55, 96],
 67845: [6, 9, 28, 39, 145, 155],
 25302: [0, 6, 28, 40, 116, 168],
 57796: [28, 186],
 13676: [6, 23, 177],
 29968: [43, 183, 252],
 38652: [0, 10, 48],
 23887: [35, 112, 191],
 44661: [0, 16],
 26449: [6, 9, 32, 85, 122],
 16511: [6, 145, 187, 241],
 32609: [0, 8, 40, 248],
 43932: [0, 1, 8, 12, 13],
 13941: [0, 7, 8, 38, 80],
 20065: [6, 145, 241],
 53370: [1, 5, 15, 35, 64, 70, 99, 165],
 62174: [1, 5, 104, 172],
 52322: [0, 2, 8, 32, 51],
 41853: [1, 15, 25, 71, 92, 99],
 58614: [0, 8, 30, 51],
 22115: [3, 35, 55, 73, 112, 146, 198],
 34257: [6, 122],
 63054: [0, 1, 2, 8, 128],
 850: [6, 145, 170],
 2980: [6, 215],
 31505: [0, 80, 100, 156],
 27354: [0, 4, 7, 16, 88],
 38840: [0, 2, 8, 9, 24, 40, 141],
 71885: [0, 4, 7, 8],
 7290: [0, 7, 57],
 5201: [6, 215],
 22749: [6, 158],
 23811: [2, 6, 9, 26, 32, 47, 103, 117, 151],
 2248: [0, 28, 182],
 57495: [0, 5, 8, 10, 80],
 36136: [0, 2, 8, 51],
 45537: [9, 45, 47],
 42745: [1, 5, 45, 75, 119],
 1621

In [411]:
train_tracks = sorted(track_to_tags.keys())
len(train_tracks)

51134

In [412]:
tracks = [[t] * len(track_to_tags[t]) for t in train_tracks]
tags = [track_to_tags[t] for t in train_tracks]

In [413]:
len(tracks), len(tags)

(51134, 51134)

In [414]:
tracks_tags_df = pd.DataFrame({
    'track': list(chain.from_iterable(tracks)),
    'tag': list(chain.from_iterable(tags))
})

In [415]:
tracks_tags_df.head()

Unnamed: 0,track,tag
0,0,1
1,0,21
2,0,71
3,2,1
4,2,5


In [416]:
tracks_tags_df.shape

(201562, 2)

In [417]:
%%time

tracks_to_neg_tags = {t: [tag for tag in range(256) if tag not in track_to_tags[t]] for t in train_tracks}

CPU times: user 2.95 s, sys: 841 ms, total: 3.79 s
Wall time: 3.99 s


In [418]:
len(tracks_to_neg_tags)

51134

In [419]:
train_dataset = TracksDataset(tracks_tags_df, track_id_to_embedding, tracks_to_neg_tags, 10)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

In [420]:
track_model_config = {
    'input_dims': [64],
    'output_dims': [64],
    'dropout_rates': [0],
    'result_dim': 32
}

tag_model_config = {
    'num_tags': 256,
    'input_dims': [64],
    'output_dims': [64],
    'dropout_rates': [0],
    'result_dim': 32
}

In [422]:
model = DSSMModel(track_model_config, tag_model_config, 1)
criterion = SSMLoss()

model = model.to(DEVICE)
criterion = criterion.to(DEVICE)
optimizer = Adam(model.parameters(), lr=0.0001)

for epoch in tqdm(range(10)):
    train(model, train_loader, criterion, optimizer)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1575 [00:00<?, ?it/s]

[32m2023-11-12 15:57:40.039[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:57:40.039696 batch 1 loss 2.4202327728271484[0m
[32m2023-11-12 15:57:41.893[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:57:41.893316 batch 101 loss 2.3626599311828613[0m
[32m2023-11-12 15:57:43.795[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:57:43.795793 batch 201 loss 2.2777791023254395[0m
[32m2023-11-12 15:57:45.760[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:57:45.760100 batch 301 loss 2.253126621246338[0m
[32m2023-11-12 15:57:47.789[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:57:47.789339 batch 401 loss 2.206102132797241[0m
[32m2023-11-12 15:57:49.612[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:57:49.612240 batch 501 loss 2.16738

  0%|          | 0/1575 [00:00<?, ?it/s]

[32m2023-11-12 15:58:10.095[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:10.095219 batch 1 loss 1.0983388423919678[0m
[32m2023-11-12 15:58:11.866[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:11.866258 batch 101 loss 1.018102765083313[0m
[32m2023-11-12 15:58:13.628[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:13.628750 batch 201 loss 1.0293123722076416[0m
[32m2023-11-12 15:58:15.498[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:15.498026 batch 301 loss 1.0108797550201416[0m
[32m2023-11-12 15:58:17.255[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:17.255368 batch 401 loss 1.0769505500793457[0m
[32m2023-11-12 15:58:19.010[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:19.010245 batch 501 loss 1.0339

  0%|          | 0/1575 [00:00<?, ?it/s]

[32m2023-11-12 15:58:40.121[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:40.121666 batch 1 loss 0.9742926955223083[0m
[32m2023-11-12 15:58:41.874[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:41.874335 batch 101 loss 1.019011378288269[0m
[32m2023-11-12 15:58:43.976[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:43.976008 batch 201 loss 1.0551791191101074[0m
[32m2023-11-12 15:58:45.850[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:45.850517 batch 301 loss 0.9952653050422668[0m
[32m2023-11-12 15:58:48.170[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:48.170509 batch 401 loss 0.9716109037399292[0m
[32m2023-11-12 15:58:50.271[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:58:50.271430 batch 501 loss 0.9832

  0%|          | 0/1575 [00:00<?, ?it/s]

[32m2023-11-12 15:59:10.396[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:10.396936 batch 1 loss 0.9788950085639954[0m
[32m2023-11-12 15:59:12.182[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:12.182156 batch 101 loss 0.9547321200370789[0m
[32m2023-11-12 15:59:13.996[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:13.996780 batch 201 loss 0.9788076877593994[0m
[32m2023-11-12 15:59:15.682[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:15.682107 batch 301 loss 1.0024566650390625[0m
[32m2023-11-12 15:59:17.443[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:17.443471 batch 401 loss 1.0023959875106812[0m
[32m2023-11-12 15:59:19.624[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:19.624928 batch 501 loss 1.086

  0%|          | 0/1575 [00:00<?, ?it/s]

[32m2023-11-12 15:59:41.099[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:41.099968 batch 1 loss 1.085944652557373[0m
[32m2023-11-12 15:59:43.059[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:43.059151 batch 101 loss 1.0377061367034912[0m
[32m2023-11-12 15:59:45.028[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:45.028874 batch 201 loss 1.0738645792007446[0m
[32m2023-11-12 15:59:47.047[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:47.047085 batch 301 loss 0.989523708820343[0m
[32m2023-11-12 15:59:49.083[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:49.083229 batch 401 loss 0.929476261138916[0m
[32m2023-11-12 15:59:51.264[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 15:59:51.264540 batch 501 loss 1.037630

  0%|          | 0/1575 [00:00<?, ?it/s]

[32m2023-11-12 16:00:13.087[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:13.087893 batch 1 loss 0.9649326205253601[0m
[32m2023-11-12 16:00:15.072[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:15.072298 batch 101 loss 0.9648279547691345[0m
[32m2023-11-12 16:00:17.288[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:17.288214 batch 201 loss 0.9648397564888[0m
[32m2023-11-12 16:00:19.396[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:19.395990 batch 301 loss 1.012904167175293[0m
[32m2023-11-12 16:00:21.397[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:21.397070 batch 401 loss 0.952953040599823[0m
[32m2023-11-12 16:00:23.765[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:23.765832 batch 501 loss 0.96462506

  0%|          | 0/1575 [00:00<?, ?it/s]

[32m2023-11-12 16:00:47.098[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:47.098077 batch 1 loss 1.0488288402557373[0m
[32m2023-11-12 16:00:49.027[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:49.027209 batch 101 loss 0.9766064882278442[0m
[32m2023-11-12 16:00:50.746[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:50.746088 batch 201 loss 1.0368183851242065[0m
[32m2023-11-12 16:00:52.517[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:52.517806 batch 301 loss 1.0729111433029175[0m
[32m2023-11-12 16:00:54.241[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:54.241502 batch 401 loss 1.036777138710022[0m
[32m2023-11-12 16:00:56.092[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:00:56.092037 batch 501 loss 0.9765

  0%|          | 0/1575 [00:00<?, ?it/s]

[32m2023-11-12 16:01:15.900[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:15.900951 batch 1 loss 1.0246689319610596[0m
[32m2023-11-12 16:01:17.580[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:17.580766 batch 101 loss 1.000569224357605[0m
[32m2023-11-12 16:01:19.270[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:19.270567 batch 201 loss 0.9403475522994995[0m
[32m2023-11-12 16:01:21.000[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:21.000746 batch 301 loss 0.9885236620903015[0m
[32m2023-11-12 16:01:22.827[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:22.827253 batch 401 loss 0.9764499068260193[0m
[32m2023-11-12 16:01:24.688[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:24.688101 batch 501 loss 0.9644

  0%|          | 0/1575 [00:00<?, ?it/s]

[32m2023-11-12 16:01:48.784[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:48.784880 batch 1 loss 1.0245931148529053[0m
[32m2023-11-12 16:01:50.497[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:50.497779 batch 101 loss 1.0004796981811523[0m
[32m2023-11-12 16:01:52.197[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:52.197197 batch 201 loss 1.0245722532272339[0m
[32m2023-11-12 16:01:53.897[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:53.897636 batch 301 loss 1.000478744506836[0m
[32m2023-11-12 16:01:55.665[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:55.665586 batch 401 loss 1.0245685577392578[0m
[32m2023-11-12 16:01:57.371[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:01:57.371707 batch 501 loss 0.9402

  0%|          | 0/1575 [00:00<?, ?it/s]

[32m2023-11-12 16:02:17.832[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:02:17.832302 batch 1 loss 0.9643689393997192[0m
[32m2023-11-12 16:02:19.564[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:02:19.564126 batch 101 loss 0.9763778448104858[0m
[32m2023-11-12 16:02:21.748[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:02:21.748363 batch 201 loss 0.9884465932846069[0m
[32m2023-11-12 16:02:23.658[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:02:23.658330 batch 301 loss 1.0004632472991943[0m
[32m2023-11-12 16:02:25.537[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:02:25.537022 batch 401 loss 0.9522987008094788[0m
[32m2023-11-12 16:02:27.561[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain[0m:[36m27[0m - [1m2023-11-12 16:02:27.561052 batch 501 loss 1.295

### Predict

In [423]:
all_tracks = list(track_id_to_embedding.keys())
len(all_tracks)

76714

In [424]:
all_tags = list(range(256))
len(all_tags)

256

In [425]:
all_tags

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,


In [426]:
track_embeddings = model.track_model(torch.from_numpy(np.array([track_id_to_embedding[t] for t in all_tracks]))).detach().cpu().numpy()
len(track_embeddings)

76714

In [427]:
tag_embeddins = model.tag_model(torch.from_numpy(np.array(all_tags))).detach().cpu().numpy()
len(tag_embeddins)

256

In [428]:
track_norms = np.linalg.norm(track_embeddings, axis=1)
track_norms.shape

(76714,)

In [429]:
tag_norms = np.linalg.norm(tag_embeddins, axis=1)
tag_norms.shape

(256,)

In [435]:
predictions = track_embeddings.dot(tag_embeddins.T)
# predictions = track_embeddings.dot(tag_embeddins.T) / track_norms.reshape(-1, 1)
# predictions = track_embeddings.dot(tag_embeddins.T) / track_norms.reshape(-1, 1) / tag_norms.reshape(1, -1)
predictions.shape

(76714, 256)

### Save predictions

In [436]:
predictions_df = pd.DataFrame([
    {'track': track, 'prediction': ','.join([str(p) for p in probs])}
    for track, probs in 
    zip(all_tracks, predictions)
])

In [437]:
predictions_df.head()

Unnamed: 0,track,prediction
0,531,"-3.3938477,-259.95837,-9.688505,-5.7735453,-15..."
1,33632,"-2.4650779,-188.81732,-7.037119,-4.1935334,-11..."
2,75667,"-5.233591,-400.87708,-14.940479,-8.903297,-23...."
3,65474,"-7.9688406,-610.3859,-22.748737,-13.556375,-36..."
4,23421,"-1.7090056,-130.90585,-4.878826,-2.9073808,-7...."


In [438]:
predictions_df.shape

(76714, 2)

In [439]:
predictions_df.to_csv('prediction_vae_dssm_dot.csv', index=False)