In [1]:
import json
import os
from typing import Dict, List

import numpy as np
import pandas as pd
import torch
from catalyst import dl, metrics
from catalyst.utils import set_global_seed
from torch import nn, optim
from torch.nn import functional as F
from torch.nn.init import xavier_normal_, constant_
from torch.utils.data import DataLoader, Dataset

In [2]:
class InteractionsDataset(Dataset):
    """User-by-track interaction matrix, served one user row at a time.

    The pickle at ``interactions_pickle_path`` is loaded with
    ``pd.read_pickle`` and must provide ``user`` and ``track`` columns holding
    non-negative integer ids. Each row of the frame contributes ``1.0`` to the
    cell ``(user, track)``; repeated pairs accumulate. The matrix size is
    inferred from the largest ids present.

    Indexing with an integer returns that user's dense 1-D float vector.
    """

    def __init__(self, interactions_pickle_path: str):
        # NOTE: unpickling can execute arbitrary code -- only load trusted files.
        data = pd.read_pickle(interactions_pickle_path)
        users = data['user'].to_numpy()
        tracks = data['track'].to_numpy()

        # COO layout: a (2, nnz) int64 index matrix plus one value per entry.
        indices = torch.from_numpy(np.stack((users, tracks)).astype("int64"))
        values = torch.ones(data.shape[0])

        # `torch.sparse.FloatTensor` is a deprecated legacy constructor;
        # `torch.sparse_coo_tensor` is the supported replacement. Coalescing
        # once up front merges duplicate (user, track) entries so that row
        # lookups operate on a canonical tensor.
        self.interactions = torch.sparse_coo_tensor(indices, values).coalesce()

    def __len__(self):
        # Number of user rows (largest user id + 1).
        return self.interactions.shape[0]

    def __getitem__(self, idx):
        # Select a single sparse row and densify only that row.
        return self.interactions[idx].to_dense()

In [3]:
def collate_fn(batch: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Stack per-sample tensors into an autoencoder batch.

    Args:
        batch: equally shaped tensors, one per sample.

    Returns:
        A dict with keys ``"inputs"`` and ``"targets"``. Both map to the same
        stacked tensor of shape ``(len(batch), ...)`` because the model is
        trained to reconstruct its own input.
    """
    # Stack once and share the result: the two entries are identical, so a
    # second stack would only duplicate memory. Consumers must not modify
    # either entry in place.
    stacked = torch.stack(batch)
    return {"inputs": stacked, "targets": stacked}

In [4]:
class MultiDAE(nn.Module):
    """Denoising autoencoder made of a single stack of linear layers.

    Args:
        p_dims: decoder layer widths, from the latent size to the output size.
        q_dims: encoder layer widths, from the input size to the latent size.
            When omitted, the encoder mirrors the decoder (``p_dims`` reversed).
        dropout: dropout probability applied to the normalised input.

    The forward pass normalises each input row, applies dropout, then runs
    every layer with ``tanh`` after all but the last one, so the output is
    the last layer's raw activations.
    """

    def __init__(self, p_dims, q_dims=None, dropout=0.5):
        super().__init__()
        self.p_dims = p_dims
        if not q_dims:
            # No explicit encoder given: mirror the decoder.
            self.q_dims = p_dims[::-1]
        else:
            assert q_dims[0] == p_dims[-1], "In and Out dimensions must equal to each other"
            assert q_dims[-1] == p_dims[0], "Latent dimension for p- and q- network mismatches."
            self.q_dims = q_dims

        # Encoder widths followed by decoder widths; the latent size is shared.
        self.dims = self.q_dims + self.p_dims[1:]
        linear_layers = [
            nn.Linear(n_in, n_out)
            for n_in, n_out in zip(self.dims, self.dims[1:])
        ]
        self.layers = nn.ModuleList(linear_layers)
        self.drop = nn.Dropout(dropout)

        self.init_weights()

    def forward(self, input):
        hidden = self.drop(F.normalize(input))

        final_idx = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            hidden = layer(hidden)
            # Every layer except the last is followed by a tanh non-linearity.
            if idx < final_idx:
                hidden = torch.tanh(hidden)
        return hidden

    def init_weights(self):
        """Re-initialise weights with Xavier-normal values and zero the biases."""
        for linear in self.layers:
            xavier_normal_(linear.weight.data)
            constant_(linear.bias.data, 0)

In [5]:
# Fix the global random seed (catalyst helper) before the dataset, model and
# loaders are created, so that repeated runs start from the same random state.
set_global_seed(42)

In [6]:
# Number of tracks recommended per user. The same value is passed to the
# ranking-metric callbacks as their top-k cutoff.
top_k = 50

In [7]:
# Training data: one dense interaction vector per user, 64 users per batch.
# The loader iterates in dataset order (no shuffling is requested).
train_dataset = InteractionsDataset("train_data.pkl")
loaders = dict(
    train=DataLoader(dataset=train_dataset, batch_size=64, collate_fn=collate_fn),
)

In [8]:
# The catalogue size equals the length of a single user's interaction vector.
item_num = train_dataset[0].shape[0]

# Decoder widths 50 -> 300 -> item_num; the encoder mirrors them.
model = MultiDAE(p_dims=[50, 300, item_num], dropout=0.5)

# The targets are the interaction vectors themselves (same shape as the
# logits), i.e. CrossEntropyLoss is used with per-class float targets.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
engine = dl.DeviceEngine()

In [9]:
# Ranking metrics computed at the `top_k` cutoff from the model logits and
# the batch targets, followed by the callback that performs the optimizer
# step on the "loss" value.
callbacks = [
    *(
        metric_callback("logits", "targets", [top_k])
        for metric_callback in (
            dl.NDCGCallback,
            dl.MAPCallback,
            dl.MRRCallback,
            dl.HitrateCallback,
        )
    ),
    dl.OptimizerCallback("loss", accumulation_steps=1),
]

In [10]:
# Batch-dict keys used by the runner: the model reads "inputs", its output is
# stored under "logits", and the loss between "logits" and "targets" is
# stored under "loss".
runner = dl.SupervisedRunner(
    input_key="inputs",
    output_key="logits",
    target_key="targets",
    loss_key="loss",
)

In [11]:
# Train for 30 epochs on the "train" loader; logs and checkpoints are written
# under ./logs.
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    engine=engine,
    loaders=loaders,
    callbacks=callbacks,
    num_epochs=30,
    logdir="./logs",
    verbose=True,
    timeit=False,
)

1/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (1/30) hitrate50: 0.027579400061257187 | hitrate50/std: 0.0077996005370578134 | loss: 1140.9674562500006 | loss/mean: 1140.9674562500006 | loss/std: 125.08365167589554 | lr: 0.001 | map50: 0.08966458641737701 | map50/std: 0.022428504986206237 | momentum: 0.9 | mrr50: 0.11615213832110169 | mrr50/std: 0.03335677848999456 | ndcg50: 0.043895978221297254 | ndcg50/std: 0.01140951853831152
* Epoch (1/30) 


2/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (2/30) hitrate50: 0.04689595823287962 | hitrate50/std: 0.015617750745945156 | loss: 1125.787861914063 | loss/mean: 1125.787861914063 | loss/std: 124.57489858696772 | lr: 0.001 | map50: 0.17120562121868133 | map50/std: 0.06382950139676609 | momentum: 0.9 | mrr50: 0.21209121668338773 | mrr50/std: 0.0820304246028037 | ndcg50: 0.08148553736209867 | ndcg50/std: 0.02866333384268424
* Epoch (2/30) 


3/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (3/30) hitrate50: 0.10105427539348602 | hitrate50/std: 0.01891093647845262 | loss: 1097.7523689453137 | loss/mean: 1097.7523689453137 | loss/std: 124.19372571390527 | lr: 0.001 | map50: 0.38443026752471915 | map50/std: 0.06624853613393249 | momentum: 0.9 | mrr50: 0.47961425848007194 | mrr50/std: 0.08602470812366386 | ndcg50: 0.1869787357330322 | ndcg50/std: 0.032896785407919155
* Epoch (3/30) 


4/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (4/30) hitrate50: 0.14065722823143006 | hitrate50/std: 0.016618148325030505 | loss: 1071.5928976562498 | loss/mean: 1071.5928976562498 | loss/std: 123.07782549386839 | lr: 0.001 | map50: 0.49875692796707166 | map50/std: 0.046242423666443704 | momentum: 0.9 | mrr50: 0.6286040649414061 | mrr50/std: 0.0624173880717181 | ndcg50: 0.25711763072013855 | ndcg50/std: 0.024629608172544778
* Epoch (4/30) 


5/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (5/30) hitrate50: 0.16741901040077214 | hitrate50/std: 0.01728842220099352 | loss: 1050.35723984375 | loss/mean: 1050.35723984375 | loss/std: 121.74462429258628 | lr: 0.001 | map50: 0.5578113931655881 | map50/std: 0.03742185765367028 | momentum: 0.9 | mrr50: 0.7183637093544006 | mrr50/std: 0.050069634551212494 | ndcg50: 0.3038013055801393 | ndcg50/std: 0.02244285597814498
* Epoch (5/30) 


6/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (6/30) hitrate50: 0.1859690439224243 | hitrate50/std: 0.018153580616857703 | loss: 1034.5259433593753 | loss/mean: 1034.5259433593753 | loss/std: 120.58229479603747 | lr: 0.001 | map50: 0.5856543742179868 | map50/std: 0.035868780742686415 | momentum: 0.9 | mrr50: 0.7579338784217832 | mrr50/std: 0.04875445052372338 | ndcg50: 0.3328938482284547 | ndcg50/std: 0.024036170236068487
* Epoch (6/30) 


7/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (7/30) hitrate50: 0.2020859476089478 | hitrate50/std: 0.019454376500731237 | loss: 1020.9335334960948 | loss/mean: 1020.9335334960948 | loss/std: 119.40780967929167 | lr: 0.001 | map50: 0.6141708169937135 | map50/std: 0.034132157875727015 | momentum: 0.9 | mrr50: 0.7973077083587644 | mrr50/std: 0.040821988110880114 | ndcg50: 0.36004363603591905 | ndcg50/std: 0.025441515842226706
* Epoch (7/30) 


8/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (8/30) hitrate50: 0.21768177604675293 | hitrate50/std: 0.020108093810423222 | loss: 1008.4102422851562 | loss/mean: 1008.4102422851562 | loss/std: 118.59621305827562 | lr: 0.001 | map50: 0.6380952081680298 | map50/std: 0.03667661362421913 | momentum: 0.9 | mrr50: 0.8294541931152343 | mrr50/std: 0.04624560318869547 | ndcg50: 0.385765837574005 | ndcg50/std: 0.027424623197935975
* Epoch (8/30) 


9/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (9/30) hitrate50: 0.22971068987846374 | hitrate50/std: 0.02053498966020865 | loss: 997.4605077148437 | loss/mean: 997.4605077148437 | loss/std: 117.51809015167746 | lr: 0.001 | map50: 0.6506968492507935 | map50/std: 0.03191671011946718 | momentum: 0.9 | mrr50: 0.8489924322128295 | mrr50/std: 0.03732331433408414 | ndcg50: 0.40436135530471806 | ndcg50/std: 0.026059909851600574
* Epoch (9/30) 


10/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (10/30) hitrate50: 0.24089306592941284 | hitrate50/std: 0.02123885253507516 | loss: 987.3702176757809 | loss/mean: 987.3702176757809 | loss/std: 116.46950640992542 | lr: 0.001 | map50: 0.6645388034820559 | map50/std: 0.03275918132520939 | momentum: 0.9 | mrr50: 0.8653368152618408 | mrr50/std: 0.03791574265457443 | ndcg50: 0.4226255654335023 | ndcg50/std: 0.02777604229080632
* Epoch (10/30) 


11/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (11/30) hitrate50: 0.2515826939582824 | hitrate50/std: 0.02225087913866842 | loss: 977.4666650390623 | loss/mean: 977.4666650390623 | loss/std: 115.52516718074737 | lr: 0.001 | map50: 0.6812776021957395 | map50/std: 0.02936964986506183 | momentum: 0.9 | mrr50: 0.8878063190460204 | mrr50/std: 0.03199165240937263 | ndcg50: 0.44209627876281743 | ndcg50/std: 0.027748565652244303
* Epoch (11/30) 


12/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (12/30) hitrate50: 0.26289318704605114 | hitrate50/std: 0.023160051759583015 | loss: 967.3842251953123 | loss/mean: 967.3842251953123 | loss/std: 114.44015201664887 | lr: 0.001 | map50: 0.6949353547096258 | map50/std: 0.026956225305414593 | momentum: 0.9 | mrr50: 0.8997934616088864 | mrr50/std: 0.032024746350901176 | ndcg50: 0.4603751597404479 | ndcg50/std: 0.02763827101186871
* Epoch (12/30) 


13/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (13/30) hitrate50: 0.2713773118972779 | hitrate50/std: 0.023042110541975226 | loss: 957.5248873046876 | loss/mean: 957.5248873046876 | loss/std: 113.12431729772739 | lr: 0.001 | map50: 0.7049735127449035 | map50/std: 0.026169966284375058 | momentum: 0.9 | mrr50: 0.9138364463806155 | mrr50/std: 0.027476582995896802 | ndcg50: 0.47559217548370364 | ndcg50/std: 0.026998032465606475
* Epoch (13/30) 


14/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (14/30) hitrate50: 0.28054342393875126 | hitrate50/std: 0.023693213905045276 | loss: 947.7783306640623 | loss/mean: 947.7783306640623 | loss/std: 112.04106138490671 | lr: 0.001 | map50: 0.7165094941139223 | map50/std: 0.023742250614349338 | momentum: 0.9 | mrr50: 0.9208717090606688 | mrr50/std: 0.02428094740116108 | ndcg50: 0.49135931472778327 | ndcg50/std: 0.02684619224823687
* Epoch (14/30) 


15/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (15/30) hitrate50: 0.29070295133590696 | hitrate50/std: 0.023776493596340736 | loss: 937.7025589843746 | loss/mean: 937.7025589843746 | loss/std: 110.70454957114497 | lr: 0.001 | map50: 0.7276797775268556 | map50/std: 0.023013820974452603 | momentum: 0.9 | mrr50: 0.9307418109893797 | mrr50/std: 0.02331005830695829 | ndcg50: 0.5091457164764407 | ndcg50/std: 0.025761811465964977
* Epoch (15/30) 


16/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (16/30) hitrate50: 0.2985096172332764 | hitrate50/std: 0.023669137850482688 | loss: 927.7874207031249 | loss/mean: 927.7874207031249 | loss/std: 109.52354364988803 | lr: 0.001 | map50: 0.7367875953674311 | map50/std: 0.022155162284200242 | momentum: 0.9 | mrr50: 0.9403590991973877 | mrr50/std: 0.022364848939908026 | ndcg50: 0.5239923107147223 | ndcg50/std: 0.025672510761799883
* Epoch (16/30) 


17/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (17/30) hitrate50: 0.30764557933807385 | hitrate50/std: 0.024294219801794702 | loss: 917.761469140625 | loss/mean: 917.761469140625 | loss/std: 108.44281017973306 | lr: 0.001 | map50: 0.745471927642822 | map50/std: 0.021125348320285668 | momentum: 0.9 | mrr50: 0.9460693946838377 | mrr50/std: 0.018670845380264308 | ndcg50: 0.539913481903076 | ndcg50/std: 0.024754215248959332
* Epoch (17/30) 


18/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (18/30) hitrate50: 0.3160683740615844 | hitrate50/std: 0.024477292584604127 | loss: 907.82615234375 | loss/mean: 907.82615234375 | loss/std: 107.18425088039956 | lr: 0.001 | map50: 0.7579029228210453 | map50/std: 0.021400564315841347 | momentum: 0.9 | mrr50: 0.9536810688018797 | mrr50/std: 0.01858183590285894 | ndcg50: 0.5562382693290706 | ndcg50/std: 0.02513663911779796
* Epoch (18/30) 


19/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (19/30) hitrate50: 0.3242670154571535 | hitrate50/std: 0.024257285400174983 | loss: 897.8245234374998 | loss/mean: 897.8245234374998 | loss/std: 105.75493294504304 | lr: 0.001 | map50: 0.7701637845993043 | map50/std: 0.019632843056253357 | momentum: 0.9 | mrr50: 0.9602862400054931 | mrr50/std: 0.02094114090693359 | ndcg50: 0.5722418687820432 | ndcg50/std: 0.02465444598242214
* Epoch (19/30) 


20/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (20/30) hitrate50: 0.33246371974945066 | hitrate50/std: 0.02483076563809562 | loss: 888.183141015625 | loss/mean: 888.183141015625 | loss/std: 104.4861836392722 | lr: 0.001 | map50: 0.7791687430381775 | map50/std: 0.018823854803900888 | momentum: 0.9 | mrr50: 0.9654107410430909 | mrr50/std: 0.016276703085693386 | ndcg50: 0.5874211407661438 | ndcg50/std: 0.0241819850548008
* Epoch (20/30) 


21/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (21/30) hitrate50: 0.34110703792572017 | hitrate50/std: 0.023927670041266686 | loss: 878.3637478515627 | loss/mean: 878.3637478515627 | loss/std: 103.05914032774112 | lr: 0.001 | map50: 0.7876048995018006 | map50/std: 0.01787285843593382 | momentum: 0.9 | mrr50: 0.9690209655761717 | mrr50/std: 0.016355816278160053 | ndcg50: 0.6026258697509762 | ndcg50/std: 0.023778684811726285
* Epoch (21/30) 


22/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (22/30) hitrate50: 0.349318071937561 | hitrate50/std: 0.023451064121703048 | loss: 869.2597296875 | loss/mean: 869.2597296875 | loss/std: 102.09547576585096 | lr: 0.001 | map50: 0.7977648229598999 | map50/std: 0.01667855060060493 | momentum: 0.9 | mrr50: 0.9727221855163571 | mrr50/std: 0.015796637518419785 | ndcg50: 0.617613687324524 | ndcg50/std: 0.02229068703322882
* Epoch (22/30) 


23/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (23/30) hitrate50: 0.3566292240142823 | hitrate50/std: 0.023715631619520082 | loss: 860.2761212890625 | loss/mean: 860.2761212890625 | loss/std: 100.60893984141288 | lr: 0.001 | map50: 0.8043560026168823 | map50/std: 0.016955732853890803 | momentum: 0.9 | mrr50: 0.9753417747497561 | mrr50/std: 0.014252010628594668 | ndcg50: 0.6305136432647704 | ndcg50/std: 0.02341855531044883
* Epoch (23/30) 


24/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (24/30) hitrate50: 0.3647362045288086 | hitrate50/std: 0.02417179778455743 | loss: 851.6922570312498 | loss/mean: 851.6922570312498 | loss/std: 99.58668554056075 | lr: 0.001 | map50: 0.8141245331764223 | map50/std: 0.015200538017747036 | momentum: 0.9 | mrr50: 0.9793611202239983 | mrr50/std: 0.013223368455269557 | ndcg50: 0.6448094667434692 | ndcg50/std: 0.02273844929746056
* Epoch (24/30) 


25/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (25/30) hitrate50: 0.37250186061859125 | hitrate50/std: 0.024229846454085775 | loss: 843.0534153320311 | loss/mean: 843.0534153320311 | loss/std: 98.49302568442504 | lr: 0.001 | map50: 0.8225351511001587 | map50/std: 0.014951131770188662 | momentum: 0.9 | mrr50: 0.97998790473938 | mrr50/std: 0.01274657461924762 | ndcg50: 0.6582554190635681 | ndcg50/std: 0.02260275751898121
* Epoch (25/30) 


26/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (26/30) hitrate50: 0.37833051691055275 | hitrate50/std: 0.02366955243198295 | loss: 834.8419101562502 | loss/mean: 834.8419101562502 | loss/std: 97.24366222631173 | lr: 0.001 | map50: 0.8276551159858702 | map50/std: 0.014992872698029014 | momentum: 0.9 | mrr50: 0.9812153301239012 | mrr50/std: 0.012930309747876073 | ndcg50: 0.6686336460113521 | ndcg50/std: 0.02219126131651802
* Epoch (26/30) 


27/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (27/30) hitrate50: 0.3842747652053836 | hitrate50/std: 0.023587258612798635 | loss: 827.1515617187496 | loss/mean: 827.1515617187496 | loss/std: 96.25849709390134 | lr: 0.001 | map50: 0.8361689429283141 | map50/std: 0.014854682681891038 | momentum: 0.9 | mrr50: 0.9829692928314203 | mrr50/std: 0.013554472156326419 | ndcg50: 0.6801507616043091 | ndcg50/std: 0.022313023721390816
* Epoch (27/30) 


28/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (28/30) hitrate50: 0.3918063077926633 | hitrate50/std: 0.02384109420571127 | loss: 819.5523734375001 | loss/mean: 819.5523734375001 | loss/std: 95.59280164467988 | lr: 0.001 | map50: 0.8424814487457277 | map50/std: 0.014963895949810067 | momentum: 0.9 | mrr50: 0.9852712375640865 | mrr50/std: 0.011753972723081381 | ndcg50: 0.6914725805282593 | ndcg50/std: 0.021015259986566774
* Epoch (28/30) 


29/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (29/30) hitrate50: 0.3985554338455199 | hitrate50/std: 0.02365801707628632 | loss: 812.0135927734373 | loss/mean: 812.0135927734373 | loss/std: 94.4849641290902 | lr: 0.001 | map50: 0.8469265758514404 | map50/std: 0.013835580788707568 | momentum: 0.9 | mrr50: 0.9853995681762691 | mrr50/std: 0.012053625599193162 | ndcg50: 0.7015929243087765 | ndcg50/std: 0.02172707027723074
* Epoch (29/30) 


30/30 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (30/30) hitrate50: 0.4046434797286987 | hitrate50/std: 0.02378385468192556 | loss: 804.8221525390619 | loss/mean: 804.8221525390619 | loss/std: 93.30984652992261 | lr: 0.001 | map50: 0.852378803730011 | map50/std: 0.013328375184972154 | momentum: 0.9 | mrr50: 0.986595006942749 | mrr50/std: 0.010674718362603947 | ndcg50: 0.7116232017517093 | ndcg50/std: 0.02157568697793173
* Epoch (30/30) 
Top best models:
logs\checkpoints/train.30.pth	30.0000


In [12]:
%%time
with open("recommendations_30.json", "w") as rf:    
    for batch, prediction in enumerate(runner.predict_loader(loader=loaders["train"])):
        preds = prediction["logits"].detach().cpu().numpy()
        for i, pred in enumerate(preds):
            user = loaders["train"].batch_size * batch + i
            pred_tracks = np.argsort(pred)[::-1][:top_k]
            
            recommendation = {
                "user": user,
                "tracks": pred_tracks.tolist(),
            }
            rf.write(json.dumps(recommendation) + "\n")

CPU times: total: 5min 45s
Wall time: 47.5 s
