In [4]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from typing import List
from tqdm import tqdm

%matplotlib inline

import catalyst 
import recbole

from typing import Dict, List, Tuple

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset

from catalyst import dl, metrics
from catalyst.contrib.datasets import MovieLens
from catalyst.utils import get_device, set_global_seed
from torch.nn.utils.rnn import pad_sequence 

import random

from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.layers import *


set_global_seed(100)

### ML_100k

In [None]:
df = pd.read_csv('data/ML_100k.csv')
df = df.rename(columns={'userId': 'user_id', 'movieId': 'item_id'})
df['timestamp'] = pd.to_datetime(df['timestamp'],unit='s')
df['weekday'] = pd.to_datetime(df.timestamp).dt.weekday
df['hour'] = pd.to_datetime(df.timestamp).dt.hour
df.head()

Unnamed: 0,user_id,item_id,rating,timestamp,weekday,hour
0,1,31,2.5,2009-12-14 02:52:24,0,2
1,1,1029,3.0,2009-12-14 02:52:59,0,2
2,1,1061,3.0,2009-12-14 02:53:02,0,2
3,1,1129,2.0,2009-12-14 02:53:05,0,2
4,1,1172,4.0,2009-12-14 02:53:25,0,2


In [None]:
splitter = RandomSplit(test_fraction=0.2)
train_df, valid_df, test_df = splitter(df)

In [None]:
train_grouped = train_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id, 
                                                                 x.timestamp,
                                                                 x.weekday,
                                                                 x.hour), key=lambda x: x[1])]
).reset_index()
train_grouped.rename({0:'train_interactions'}, axis=1, inplace=True)

valid_grouped = valid_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id,
                                                         x.timestamp,
                                                         x.weekday,
                                                         x.hour), key=lambda x: x[1])]
).reset_index()
valid_grouped.rename({0:'valid_interactions'}, axis=1, inplace=True)

test_grouped = test_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id,
                                                         x.timestamp,
                                                         x.weekday,
                                                         x.hour), key=lambda x: x[1])]
).reset_index()
test_grouped.rename({0:'test_interactions'}, axis=1, inplace=True)


train_grouped.head()

Unnamed: 0,user_id,train_interactions
0,1,"[(2294, 2009-12-14 02:51:48, 0, 2), (2455, 200..."
1,2,"[(150, 1996-06-21 11:09:55, 4, 11), (296, 1996..."
2,3,"[(355, 2011-02-28 02:53:09, 0, 2), (1271, 2011..."
3,4,"[(1210, 2000-02-05 19:25:14, 5, 19), (2734, 20..."
4,5,"[(1380, 2006-11-12 23:10:44, 6, 23), (1035, 20..."


In [None]:
joined = train_grouped.merge(valid_grouped).merge(test_grouped)
joined.head()

Unnamed: 0,user_id,train_interactions,valid_interactions,test_interactions
0,1,"[(2294, 2009-12-14 02:51:48, 0, 2), (2455, 200...","[(1287, 2009-12-14 02:53:07, 0, 2), (1953, 200...","[(2193, 2009-12-14 02:53:18, 0, 2), (2968, 200..."
1,2,"[(150, 1996-06-21 11:09:55, 4, 11), (296, 1996...","[(551, 1996-06-21 11:16:07, 4, 11), (273, 1996...","[(319, 1996-06-21 11:18:38, 4, 11), (485, 1996..."
2,3,"[(355, 2011-02-28 02:53:09, 0, 2), (1271, 2011...","[(2028, 2011-02-28 19:37:42, 0, 19), (110, 201...","[(1721, 2011-02-28 20:00:36, 0, 20), (377, 201..."
3,4,"[(1210, 2000-02-05 19:25:14, 5, 19), (2734, 20...","[(3169, 2000-02-07 10:21:44, 0, 10), (1298, 20...","[(1858, 2000-02-07 10:35:38, 0, 10), (3108, 20..."
4,5,"[(1380, 2006-11-12 23:10:44, 6, 23), (1035, 20...","[(1968, 2006-11-12 23:32:04, 6, 23), (30707, 2...","[(33679, 2006-11-12 23:35:17, 6, 23), (5995, 2..."


In [None]:
our_items = set()
for idx, row in tqdm(joined.iterrows()):
    for el in row.train_interactions:
        our_items.add(el[0])
        
len(our_items)

671it [00:00, 22009.34it/s]


6811

In [None]:
item2idx = {k: i for i, k in enumerate(our_items)}
idx2item = {i: k for k, i in item2idx.items()}

In [None]:
class MyDataset(Dataset):
    
    def __init__(self, ds, num_items, item2idx, phase='valid', N=200):
        super().__init__()
        self.ds = ds
        self.phase = phase
        self.n_items = num_items
        self.item2idx = item2idx
        self.N = N 
        
    def __len__(self):
        return len(self.ds)
    
    def __getitem__(self, idx):
        
        row = self.ds.iloc[idx]
        
        x_input = np.zeros(self.n_items+1)
        x_input[[self.item2idx[x[0]]+1 for x in row['train_interactions'] if x[0] in self.item2idx]] = 1
        
        days_of_weeks = [x[2] for x in row['train_interactions'] if x[0] in self.item2idx][-self.N+1:]
        
        hours = [x[3] for x in row['train_interactions'] if x[0] in self.item2idx][-self.N+1:]
        
        seq_input = [self.item2idx[x[0]]+1 for x in row['train_interactions'] if x[0] in self.item2idx][-self.N+1:]
        
        targets = np.zeros(self.n_items+1)
        
        dow_valid = row['valid_interactions'][0][2]
        dow_test = row['test_interactions'][0][2]
        
        hours_valid = row['valid_interactions'][0][3]
        hours_test = row['test_interactions'][0][3]
        
        if self.phase == 'train':
            return (seq_input, days_of_weeks, hours, dow_valid, hours_valid)
        elif self.phase == 'valid':
            targets[[self.item2idx[x[0]]+1 for x in row['valid_interactions'] if x[0] in self.item2idx]] = 1
        else:
            return (seq_input, days_of_weeks, hours, dow_test, hours_test)
            
        return (targets, seq_input, days_of_weeks, hours, dow_valid, hours_valid)


671 671


In [None]:
n_items = len(item2idx)

train = MyDataset(ds=joined,
                  num_items=n_items, 
                  item2idx=item2idx,
                  phase='train')

valid = MyDataset(ds=joined,
                  num_items=n_items,
                  item2idx=item2idx,
                  phase='valid')

print(len(train),len(valid))

In [None]:
def collate_fn_train(batch: List[Tuple[torch.Tensor]]) -> Dict[str, torch.Tensor]: 
    
    seq_i,days_of_weeks,hours,dow_valid,hours_valid = zip(*batch)
    seq_len = torch.Tensor([len(x) for x in seq_i])
    dow_valid = torch.Tensor([x for x in dow_valid])
    hours_valid = torch.Tensor([x for x in hours_valid])
    seq_i = pad_sequence([torch.Tensor(t) for t in seq_i]).T    
    days_of_weeks = pad_sequence([torch.Tensor(t) for t in days_of_weeks]).T
    hours = pad_sequence([torch.Tensor(t) for t in hours]).T
    
    return {'seq_i': seq_i, 
            'seq_len':seq_len,
            'dow': days_of_weeks,
            'hours': hours,
            'dow_valid': dow_valid,
            'hours_valid': hours_valid}


def collate_fn_valid(batch: List[Tuple[torch.Tensor]]) -> Dict[str, torch.Tensor]:
    
    y, seq_i, days_of_weeks, hours, dow_valid, hours_valid = zip(*batch)
    
    seq_len = torch.Tensor([len(x) for x in seq_i]).long()
    seq_i = pad_sequence([torch.Tensor(t) for t in seq_i]).T.long()
    days_of_weeks = pad_sequence([torch.Tensor(t) for t in days_of_weeks]).T.long()
    hours = pad_sequence([torch.Tensor(t) for t in hours]).T.long()
    dow_valid = torch.Tensor([x for x in dow_valid])
    hours_valid = torch.Tensor([x for x in hours_valid])
            
    targets = pad_sequence([torch.Tensor(t) for t in y]).T

    return {"targets": targets,
            'seq_i': seq_i,
            'seq_len':seq_len,
            'dow': days_of_weeks,
            'hours': hours,
            'dow_valid': dow_valid,
            'hours_valid': hours_valid}

In [None]:
loaders = {
        "train": DataLoader(train, batch_size=256, collate_fn=collate_fn_train),
        "valid": DataLoader(valid, batch_size=256, collate_fn=collate_fn_valid),
}

In [None]:

model = BERT4Rec(n_items=len(item2idx)+1, mask_ratio=0.2, hidden_size=32)

optimizer = optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = StepLR(optimizer, step_size=20, gamma=0.1)
engine = dl.DeviceEngine('cpu')
hparams = {
    "anneal_cap": 0.2,
    "total_anneal_steps": 6000,
}


callbacks = [
    dl.NDCGCallback("logits", "targets", [20]),
    dl.MAPCallback("logits", "targets", [10]),
    dl.OptimizerCallback("loss", accumulation_steps=1),
    dl.EarlyStoppingCallback(
        patience=7, loader_key="valid", metric_key="map10", minimize=False
    )
]


runner = RecSysRunner()
runner.train(
    model=model,
    optimizer=optimizer,
    engine=engine,
    hparams=hparams,
    scheduler=lr_scheduler,
    loaders=loaders,
    num_epochs=100,
    verbose=True,
    timeit=True,
    callbacks=callbacks,
)


1/100 * Epoch (train): 100%|██████████| 3/3 [00:08<00:00,  2.97s/it, _timer/_fps=57.961, _timer/batch_time=2.743, _timer/data_time=2.004, _timer/model_time=0.739, loss=8.801, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (1/100) loss: 8.817609209950385 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


1/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.78s/it, _timer/_fps=100.001, _timer/batch_time=1.590, _timer/data_time=0.025, _timer/model_time=1.565, loss=8.768, lr=1.000e-03, map10=0.015, momentum=0.900, ndcg20=0.006]


valid (1/100) loss: 8.753207117185863 | lr: 0.001 | map10: 0.021590022464600653 | map10/std: 0.007997730888560109 | momentum: 0.9 | ndcg20: 0.007608442112731285 | ndcg20/std: 0.002752989317687319
* Epoch (1/100) 


2/100 * Epoch (train): 100%|██████████| 3/3 [00:08<00:00,  2.90s/it, _timer/_fps=61.439, _timer/batch_time=2.588, _timer/data_time=1.747, _timer/model_time=0.841, loss=8.729, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (2/100) loss: 8.756602265973383 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


2/100 * Epoch (valid): 100%|██████████| 3/3 [00:07<00:00,  2.37s/it, _timer/_fps=77.305, _timer/batch_time=2.057, _timer/data_time=0.025, _timer/model_time=2.032, loss=8.691, lr=1.000e-03, map10=0.031, momentum=0.900, ndcg20=0.017] 


valid (2/100) loss: 8.684731032915868 | lr: 0.001 | map10: 0.04525850005392346 | map10/std: 0.011668634813275543 | momentum: 0.9 | ndcg20: 0.021061212621107363 | ndcg20/std: 0.0023411509782322016
* Epoch (2/100) 


3/100 * Epoch (train): 100%|██████████| 3/3 [00:09<00:00,  3.02s/it, _timer/_fps=57.393, _timer/batch_time=2.770, _timer/data_time=1.853, _timer/model_time=0.918, loss=8.655, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (3/100) loss: 8.677803350809612 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


3/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.78s/it, _timer/_fps=101.247, _timer/batch_time=1.570, _timer/data_time=0.023, _timer/model_time=1.548, loss=8.616, lr=1.000e-03, map10=0.058, momentum=0.900, ndcg20=0.028]


valid (3/100) loss: 8.60508069693598 | lr: 0.001 | map10: 0.06645712058785776 | map10/std: 0.005071305911044532 | momentum: 0.9 | ndcg20: 0.03114109756716495 | ndcg20/std: 0.0018334795452453626
* Epoch (3/100) 


4/100 * Epoch (train): 100%|██████████| 3/3 [00:10<00:00,  3.43s/it, _timer/_fps=55.465, _timer/batch_time=2.867, _timer/data_time=2.072, _timer/model_time=0.795, loss=8.572, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (4/100) loss: 8.596185731106116 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


4/100 * Epoch (valid): 100%|██████████| 3/3 [00:07<00:00,  2.39s/it, _timer/_fps=78.189, _timer/batch_time=2.034, _timer/data_time=0.021, _timer/model_time=2.012, loss=8.534, lr=1.000e-03, map10=0.060, momentum=0.900, ndcg20=0.032] 


valid (4/100) loss: 8.52427912540123 | lr: 0.001 | map10: 0.072846323407176 | map10/std: 0.008774154536734212 | momentum: 0.9 | ndcg20: 0.03613815557117256 | ndcg20/std: 0.0025304440669781055
* Epoch (4/100) 


5/100 * Epoch (train): 100%|██████████| 3/3 [00:09<00:00,  3.04s/it, _timer/_fps=61.436, _timer/batch_time=2.588, _timer/data_time=1.868, _timer/model_time=0.720, loss=8.501, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (5/100) loss: 8.520754741663016 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


5/100 * Epoch (valid): 100%|██████████| 3/3 [00:07<00:00,  2.64s/it, _timer/_fps=65.108, _timer/batch_time=2.442, _timer/data_time=0.025, _timer/model_time=2.417, loss=8.461, lr=1.000e-03, map10=0.068, momentum=0.900, ndcg20=0.036] 


valid (5/100) loss: 8.449948080605793 | lr: 0.001 | map10: 0.07473126789985222 | map10/std: 0.006301673491160665 | momentum: 0.9 | ndcg20: 0.04071570112301944 | ndcg20/std: 0.0025108930718226023
* Epoch (5/100) 


6/100 * Epoch (train): 100%|██████████| 3/3 [00:09<00:00,  3.33s/it, _timer/_fps=45.456, _timer/batch_time=3.498, _timer/data_time=2.562, _timer/model_time=0.935, loss=8.421, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (6/100) loss: 8.444773048651022 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


6/100 * Epoch (valid): 100%|██████████| 3/3 [00:06<00:00,  2.20s/it, _timer/_fps=101.703, _timer/batch_time=1.563, _timer/data_time=0.026, _timer/model_time=1.537, loss=8.380, lr=1.000e-03, map10=0.077, momentum=0.900, ndcg20=0.041]


valid (6/100) loss: 8.374501403149123 | lr: 0.001 | map10: 0.0815654624636589 | map10/std: 0.005330545477236489 | momentum: 0.9 | ndcg20: 0.0433162837865278 | ndcg20/std: 0.0013549975653603
* Epoch (6/100) 


7/100 * Epoch (train): 100%|██████████| 3/3 [00:09<00:00,  3.17s/it, _timer/_fps=64.150, _timer/batch_time=2.479, _timer/data_time=1.675, _timer/model_time=0.804, loss=8.353, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (7/100) loss: 8.373273036519034 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


7/100 * Epoch (valid): 100%|██████████| 3/3 [00:07<00:00,  2.56s/it, _timer/_fps=84.730, _timer/batch_time=1.877, _timer/data_time=0.026, _timer/model_time=1.850, loss=8.312, lr=1.000e-03, map10=0.082, momentum=0.900, ndcg20=0.043] 


valid (7/100) loss: 8.30372342196378 | lr: 0.001 | map10: 0.08584261080458339 | map10/std: 0.005402157853560334 | momentum: 0.9 | ndcg20: 0.04462463083370254 | ndcg20/std: 0.0006862902092591617
* Epoch (7/100) 


8/100 * Epoch (train): 100%|██████████| 3/3 [00:12<00:00,  4.31s/it, _timer/_fps=28.993, _timer/batch_time=5.484, _timer/data_time=4.695, _timer/model_time=0.789, loss=8.287, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (8/100) loss: 8.306354121906036 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


8/100 * Epoch (valid): 100%|██████████| 3/3 [00:04<00:00,  1.59s/it, _timer/_fps=146.661, _timer/batch_time=1.084, _timer/data_time=0.024, _timer/model_time=1.060, loss=8.248, lr=1.000e-03, map10=0.085, momentum=0.900, ndcg20=0.050]


valid (8/100) loss: 8.239468890046576 | lr: 0.001 | map10: 0.08704995013073909 | map10/std: 0.0035970515045673068 | momentum: 0.9 | ndcg20: 0.047170411256848434 | ndcg20/std: 0.0014098860411231595
* Epoch (8/100) 


9/100 * Epoch (train): 100%|██████████| 3/3 [00:07<00:00,  2.34s/it, _timer/_fps=67.306, _timer/batch_time=2.362, _timer/data_time=1.651, _timer/model_time=0.711, loss=8.240, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (9/100) loss: 8.242638823883011 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


9/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.67s/it, _timer/_fps=109.127, _timer/batch_time=1.457, _timer/data_time=0.028, _timer/model_time=1.429, loss=8.203, lr=1.000e-03, map10=0.083, momentum=0.900, ndcg20=0.052]


valid (9/100) loss: 8.178534249079743 | lr: 0.001 | map10: 0.09018121352183303 | map10/std: 0.005440950583483359 | momentum: 0.9 | ndcg20: 0.048943040267898856 | ndcg20/std: 0.0014375393099054184
* Epoch (9/100) 


10/100 * Epoch (train): 100%|██████████| 3/3 [00:08<00:00,  2.76s/it, _timer/_fps=68.930, _timer/batch_time=2.307, _timer/data_time=1.559, _timer/model_time=0.748, loss=8.150, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (10/100) loss: 8.172384490909947 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


10/100 * Epoch (valid): 100%|██████████| 3/3 [00:07<00:00,  2.36s/it, _timer/_fps=46.614, _timer/batch_time=3.411, _timer/data_time=0.024, _timer/model_time=3.387, loss=8.113, lr=1.000e-03, map10=0.086, momentum=0.900, ndcg20=0.053] 


valid (10/100) loss: 8.110826075876522 | lr: 0.001 | map10: 0.09404026759141604 | map10/std: 0.007498053773172233 | momentum: 0.9 | ndcg20: 0.05088454100196479 | ndcg20/std: 0.001256492006970793
* Epoch (10/100) 


11/100 * Epoch (train): 100%|██████████| 3/3 [00:08<00:00,  2.67s/it, _timer/_fps=71.972, _timer/batch_time=2.209, _timer/data_time=1.498, _timer/model_time=0.712, loss=8.102, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (11/100) loss: 8.114410070655243 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


11/100 * Epoch (valid): 100%|██████████| 3/3 [00:04<00:00,  1.63s/it, _timer/_fps=133.561, _timer/batch_time=1.190, _timer/data_time=0.024, _timer/model_time=1.167, loss=8.067, lr=1.000e-03, map10=0.087, momentum=0.900, ndcg20=0.053]


valid (11/100) loss: 8.056118183448666 | lr: 0.001 | map10: 0.09263906501311303 | map10/std: 0.005291488561238869 | momentum: 0.9 | ndcg20: 0.05074474925033737 | ndcg20/std: 0.001520022243304918
* Epoch (11/100) 


12/100 * Epoch (train): 100%|██████████| 3/3 [00:08<00:00,  2.97s/it, _timer/_fps=66.776, _timer/batch_time=2.381, _timer/data_time=1.654, _timer/model_time=0.727, loss=8.026, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (12/100) loss: 8.063485036130812 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


12/100 * Epoch (valid): 100%|██████████| 3/3 [00:04<00:00,  1.63s/it, _timer/_fps=126.292, _timer/batch_time=1.259, _timer/data_time=0.024, _timer/model_time=1.235, loss=7.991, lr=1.000e-03, map10=0.098, momentum=0.900, ndcg20=0.056]


valid (12/100) loss: 8.009205915533483 | lr: 0.001 | map10: 0.09466037186398414 | map10/std: 0.007400878436324016 | momentum: 0.9 | ndcg20: 0.052482302687873426 | ndcg20/std: 0.0021576760781137267
* Epoch (12/100) 


13/100 * Epoch (train): 100%|██████████| 3/3 [00:08<00:00,  2.89s/it, _timer/_fps=49.235, _timer/batch_time=3.229, _timer/data_time=2.386, _timer/model_time=0.843, loss=8.022, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (13/100) loss: 8.016261097926497 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


13/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.85s/it, _timer/_fps=140.615, _timer/batch_time=1.131, _timer/data_time=0.025, _timer/model_time=1.106, loss=7.992, lr=1.000e-03, map10=0.118, momentum=0.900, ndcg20=0.060]


valid (13/100) loss: 7.965589401085935 | lr: 0.001 | map10: 0.10195664985968708 | map10/std: 0.011850824120388098 | momentum: 0.9 | ndcg20: 0.05537028842773594 | ndcg20/std: 0.0029941983237731333
* Epoch (13/100) 


14/100 * Epoch (train): 100%|██████████| 3/3 [00:09<00:00,  3.09s/it, _timer/_fps=41.462, _timer/batch_time=3.835, _timer/data_time=3.090, _timer/model_time=0.745, loss=7.966, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (14/100) loss: 7.963323272642542 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


14/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.74s/it, _timer/_fps=130.563, _timer/batch_time=1.218, _timer/data_time=0.027, _timer/model_time=1.191, loss=7.938, lr=1.000e-03, map10=0.124, momentum=0.900, ndcg20=0.064]


valid (14/100) loss: 7.915953836568955 | lr: 0.001 | map10: 0.10411216519584243 | map10/std: 0.012682855759900816 | momentum: 0.9 | ndcg20: 0.05804786512464418 | ndcg20/std: 0.003644858194349861
* Epoch (14/100) 


15/100 * Epoch (train): 100%|██████████| 3/3 [00:08<00:00,  2.70s/it, _timer/_fps=67.786, _timer/batch_time=2.346, _timer/data_time=1.619, _timer/model_time=0.727, loss=7.919, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (15/100) loss: 7.924967044511957 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


15/100 * Epoch (valid): 100%|██████████| 3/3 [00:04<00:00,  1.53s/it, _timer/_fps=134.516, _timer/batch_time=1.182, _timer/data_time=0.026, _timer/model_time=1.156, loss=7.892, lr=1.000e-03, map10=0.129, momentum=0.900, ndcg20=0.067]


valid (15/100) loss: 7.881310275343598 | lr: 0.001 | map10: 0.10758360007836104 | map10/std: 0.013586573024021447 | momentum: 0.9 | ndcg20: 0.06041921281023992 | ndcg20/std: 0.0037447706559898325
* Epoch (15/100) 


16/100 * Epoch (train): 100%|██████████| 3/3 [00:07<00:00,  2.44s/it, _timer/_fps=61.509, _timer/batch_time=2.585, _timer/data_time=1.870, _timer/model_time=0.714, loss=7.910, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (16/100) loss: 7.892525502359814 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


16/100 * Epoch (valid): 100%|██████████| 3/3 [00:06<00:00,  2.18s/it, _timer/_fps=106.341, _timer/batch_time=1.495, _timer/data_time=0.026, _timer/model_time=1.469, loss=7.886, lr=1.000e-03, map10=0.131, momentum=0.900, ndcg20=0.073]


valid (16/100) loss: 7.852322188290683 | lr: 0.001 | map10: 0.10992585977569955 | map10/std: 0.013293541007723415 | momentum: 0.9 | ndcg20: 0.06671785212353694 | ndcg20/std: 0.0032647017533113313
* Epoch (16/100) 


17/100 * Epoch (train): 100%|██████████| 3/3 [00:07<00:00,  2.49s/it, _timer/_fps=70.209, _timer/batch_time=2.265, _timer/data_time=1.607, _timer/model_time=0.657, loss=7.890, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (17/100) loss: 7.871737447474291 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


17/100 * Epoch (valid): 100%|██████████| 3/3 [00:04<00:00,  1.46s/it, _timer/_fps=135.825, _timer/batch_time=1.171, _timer/data_time=0.025, _timer/model_time=1.146, loss=7.867, lr=1.000e-03, map10=0.137, momentum=0.900, ndcg20=0.075]


valid (17/100) loss: 7.835464201101425 | lr: 0.001 | map10: 0.1138191967980517 | map10/std: 0.014433903512020221 | momentum: 0.9 | ndcg20: 0.06928913487362968 | ndcg20/std: 0.0030831601104789703
* Epoch (17/100) 


18/100 * Epoch (train): 100%|██████████| 3/3 [00:07<00:00,  2.56s/it, _timer/_fps=69.148, _timer/batch_time=2.299, _timer/data_time=1.603, _timer/model_time=0.696, loss=7.836, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (18/100) loss: 7.840217973365158 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


18/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.80s/it, _timer/_fps=78.666, _timer/batch_time=2.021, _timer/data_time=0.025, _timer/model_time=1.996, loss=7.815, lr=1.000e-03, map10=0.133, momentum=0.900, ndcg20=0.075] 


valid (18/100) loss: 7.80764134701247 | lr: 0.001 | map10: 0.11565551705847199 | map10/std: 0.01116041827411875 | momentum: 0.9 | ndcg20: 0.07105345863032803 | ndcg20/std: 0.0019975078029173753
* Epoch (18/100) 


19/100 * Epoch (train): 100%|██████████| 3/3 [00:08<00:00,  2.76s/it, _timer/_fps=66.736, _timer/batch_time=2.383, _timer/data_time=1.585, _timer/model_time=0.797, loss=7.814, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (19/100) loss: 7.821316890318596 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


19/100 * Epoch (valid): 100%|██████████| 3/3 [00:04<00:00,  1.51s/it, _timer/_fps=130.747, _timer/batch_time=1.216, _timer/data_time=0.025, _timer/model_time=1.191, loss=7.795, lr=1.000e-03, map10=0.138, momentum=0.900, ndcg20=0.075]


valid (19/100) loss: 7.792151526438319 | lr: 0.001 | map10: 0.11889944588640586 | map10/std: 0.01287542365237018 | momentum: 0.9 | ndcg20: 0.07146340799580388 | ndcg20/std: 0.0023619957232551047
* Epoch (19/100) 


20/100 * Epoch (train): 100%|██████████| 3/3 [00:07<00:00,  2.56s/it, _timer/_fps=64.080, _timer/batch_time=2.481, _timer/data_time=1.781, _timer/model_time=0.700, loss=7.797, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (20/100) loss: 7.797453755948654 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


20/100 * Epoch (valid): 100%|██████████| 3/3 [00:04<00:00,  1.52s/it, _timer/_fps=145.574, _timer/batch_time=1.092, _timer/data_time=0.025, _timer/model_time=1.067, loss=7.779, lr=1.000e-03, map10=0.147, momentum=0.900, ndcg20=0.076]


valid (20/100) loss: 7.77090804445406 | lr: 0.001 | map10: 0.11948098864089891 | map10/std: 0.016352413400565955 | momentum: 0.9 | ndcg20: 0.07210340640672806 | ndcg20/std: 0.002467549253776049
* Epoch (20/100) 


21/100 * Epoch (train): 100%|██████████| 3/3 [00:07<00:00,  2.52s/it, _timer/_fps=57.732, _timer/batch_time=2.754, _timer/data_time=2.047, _timer/model_time=0.707, loss=7.742, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (21/100) loss: 7.7613828079178155 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


21/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.86s/it, _timer/_fps=122.503, _timer/batch_time=1.298, _timer/data_time=0.025, _timer/model_time=1.273, loss=7.725, lr=1.000e-03, map10=0.161, momentum=0.900, ndcg20=0.079]


valid (21/100) loss: 7.736930279369326 | lr: 0.001 | map10: 0.12458240575033577 | map10/std: 0.020880822921210487 | momentum: 0.9 | ndcg20: 0.07405138977327219 | ndcg20/std: 0.0030301451090562503
* Epoch (21/100) 


22/100 * Epoch (train): 100%|██████████| 3/3 [00:06<00:00,  2.31s/it, _timer/_fps=76.537, _timer/batch_time=2.077, _timer/data_time=1.356, _timer/model_time=0.722, loss=7.737, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (22/100) loss: 7.751056666168122 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


22/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.87s/it, _timer/_fps=82.875, _timer/batch_time=1.919, _timer/data_time=0.025, _timer/model_time=1.894, loss=7.721, lr=1.000e-03, map10=0.158, momentum=0.900, ndcg20=0.081] 


valid (22/100) loss: 7.728475257998607 | lr: 0.001 | map10: 0.12633293382990024 | map10/std: 0.019315198096533537 | momentum: 0.9 | ndcg20: 0.07526292799570877 | ndcg20/std: 0.003450875700125482
* Epoch (22/100) 


23/100 * Epoch (train): 100%|██████████| 3/3 [00:11<00:00,  3.82s/it, _timer/_fps=38.834, _timer/batch_time=4.094, _timer/data_time=2.909, _timer/model_time=1.186, loss=7.726, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (23/100) loss: 7.743804319071876 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


23/100 * Epoch (valid): 100%|██████████| 3/3 [00:06<00:00,  2.32s/it, _timer/_fps=94.043, _timer/batch_time=1.691, _timer/data_time=0.028, _timer/model_time=1.662, loss=7.711, lr=1.000e-03, map10=0.160, momentum=0.900, ndcg20=0.083] 


valid (23/100) loss: 7.723433957902934 | lr: 0.001 | map10: 0.13021736928496383 | map10/std: 0.017770992846335477 | momentum: 0.9 | ndcg20: 0.07728759311528213 | ndcg20/std: 0.00347430700904088
* Epoch (23/100) 


24/100 * Epoch (train): 100%|██████████| 3/3 [00:07<00:00,  2.46s/it, _timer/_fps=69.536, _timer/batch_time=2.287, _timer/data_time=1.498, _timer/model_time=0.789, loss=7.765, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (24/100) loss: 7.747204855195457 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


24/100 * Epoch (valid): 100%|██████████| 3/3 [00:04<00:00,  1.60s/it, _timer/_fps=145.486, _timer/batch_time=1.093, _timer/data_time=0.026, _timer/model_time=1.067, loss=7.753, lr=1.000e-03, map10=0.163, momentum=0.900, ndcg20=0.084]


valid (24/100) loss: 7.729626148184082 | lr: 0.001 | map10: 0.1310121151667238 | map10/std: 0.019227197428836704 | momentum: 0.9 | ndcg20: 0.07778253912614817 | ndcg20/std: 0.0035924198841688648
* Epoch (24/100) 


25/100 * Epoch (train): 100%|██████████| 3/3 [00:08<00:00,  2.87s/it, _timer/_fps=72.394, _timer/batch_time=2.196, _timer/data_time=1.486, _timer/model_time=0.710, loss=7.713, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (25/100) loss: 7.716267921174869 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


25/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.81s/it, _timer/_fps=113.994, _timer/batch_time=1.395, _timer/data_time=0.026, _timer/model_time=1.368, loss=7.701, lr=1.000e-03, map10=0.158, momentum=0.900, ndcg20=0.084]


valid (25/100) loss: 7.699730283990169 | lr: 0.001 | map10: 0.13027102801110457 | map10/std: 0.015662432181317746 | momentum: 0.9 | ndcg20: 0.07859879597777227 | ndcg20/std: 0.0033319309809916116
* Epoch (25/100) 


26/100 * Epoch (train): 100%|██████████| 3/3 [00:08<00:00,  2.69s/it, _timer/_fps=80.707, _timer/batch_time=1.970, _timer/data_time=1.262, _timer/model_time=0.708, loss=7.707, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (26/100) loss: 7.714665754719037 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


26/100 * Epoch (valid): 100%|██████████| 3/3 [00:04<00:00,  1.56s/it, _timer/_fps=140.163, _timer/batch_time=1.134, _timer/data_time=0.026, _timer/model_time=1.108, loss=7.696, lr=1.000e-03, map10=0.159, momentum=0.900, ndcg20=0.084]


valid (26/100) loss: 7.699418821917561 | lr: 0.001 | map10: 0.13048044782015322 | map10/std: 0.015833647984505806 | momentum: 0.9 | ndcg20: 0.07787766768395812 | ndcg20/std: 0.0036008502004445846
* Epoch (26/100) 


27/100 * Epoch (train): 100%|██████████| 3/3 [00:07<00:00,  2.65s/it, _timer/_fps=65.766, _timer/batch_time=2.418, _timer/data_time=1.697, _timer/model_time=0.720, loss=7.709, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (27/100) loss: 7.692101918460716 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


27/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.79s/it, _timer/_fps=117.103, _timer/batch_time=1.358, _timer/data_time=0.026, _timer/model_time=1.332, loss=7.699, lr=1.000e-03, map10=0.153, momentum=0.900, ndcg20=0.082]


valid (27/100) loss: 7.676646246817713 | lr: 0.001 | map10: 0.12876192201622907 | map10/std: 0.013731780105672352 | momentum: 0.9 | ndcg20: 0.07667427790946647 | ndcg20/std: 0.0037085096926519757
* Epoch (27/100) 


28/100 * Epoch (train): 100%|██████████| 3/3 [00:07<00:00,  2.59s/it, _timer/_fps=63.491, _timer/batch_time=2.504, _timer/data_time=1.829, _timer/model_time=0.675, loss=7.701, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (28/100) loss: 7.683037421742424 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


28/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.73s/it, _timer/_fps=139.992, _timer/batch_time=1.136, _timer/data_time=0.026, _timer/model_time=1.109, loss=7.690, lr=1.000e-03, map10=0.146, momentum=0.900, ndcg20=0.079]


valid (28/100) loss: 7.666361554368953 | lr: 0.001 | map10: 0.12272786998624417 | map10/std: 0.013197650968888358 | momentum: 0.9 | ndcg20: 0.07455859021262867 | ndcg20/std: 0.003833866128296454
* Epoch (28/100) 


29/100 * Epoch (train): 100%|██████████| 3/3 [00:07<00:00,  2.33s/it, _timer/_fps=63.068, _timer/batch_time=2.521, _timer/data_time=1.819, _timer/model_time=0.702, loss=7.759, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (29/100) loss: 7.7025760701800605 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


29/100 * Epoch (valid): 100%|██████████| 3/3 [00:04<00:00,  1.57s/it, _timer/_fps=131.948, _timer/batch_time=1.205, _timer/data_time=0.027, _timer/model_time=1.178, loss=7.747, lr=1.000e-03, map10=0.148, momentum=0.900, ndcg20=0.079]


valid (29/100) loss: 7.683494069714838 | lr: 0.001 | map10: 0.12120449114396568 | map10/std: 0.015638173516056623 | momentum: 0.9 | ndcg20: 0.07317504348946754 | ndcg20/std: 0.005633134537055015
* Epoch (29/100) 


30/100 * Epoch (train): 100%|██████████| 3/3 [00:09<00:00,  3.03s/it, _timer/_fps=72.541, _timer/batch_time=2.192, _timer/data_time=1.474, _timer/model_time=0.718, loss=7.688, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (30/100) loss: 7.67553146905231 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


30/100 * Epoch (valid): 100%|██████████| 3/3 [00:04<00:00,  1.52s/it, _timer/_fps=140.272, _timer/batch_time=1.134, _timer/data_time=0.025, _timer/model_time=1.108, loss=7.670, lr=1.000e-03, map10=0.147, momentum=0.900, ndcg20=0.079]


valid (30/100) loss: 7.650187534298165 | lr: 0.001 | map10: 0.11585952035094516 | map10/std: 0.017758133705273457 | momentum: 0.9 | ndcg20: 0.07079297593959931 | ndcg20/std: 0.005553716875240385
* Epoch (30/100) 


31/100 * Epoch (train): 100%|██████████| 3/3 [00:12<00:00,  4.04s/it, _timer/_fps=40.987, _timer/batch_time=3.879, _timer/data_time=2.874, _timer/model_time=1.006, loss=7.661, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (31/100) loss: 7.662603519712582 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


31/100 * Epoch (valid): 100%|██████████| 3/3 [00:05<00:00,  1.71s/it, _timer/_fps=132.248, _timer/batch_time=1.202, _timer/data_time=0.025, _timer/model_time=1.177, loss=7.640, lr=1.000e-03, map10=0.159, momentum=0.900, ndcg20=0.086]

valid (31/100) loss: 7.636357813407341 | lr: 0.001 | map10: 0.12341881009310084 | map10/std: 0.020148011938828616 | momentum: 0.9 | ndcg20: 0.07631984004409409 | ndcg20/std: 0.005762211191972171
* Epoch (31/100) 





In [None]:
test_runner = RecSysRunner(model=model)

In [None]:
test_dataset = MyDataset(ds=joined, num_items=n_items, phase='test',item2idx=item2idx)


inference_loader = DataLoader(test_dataset, 
                              batch_size=joined.shape[0]//100, 
                              collate_fn=collate_fn_train,)

preds = []

for prediction in tqdm(runner.predict_loader(loader=inference_loader)):
    preds.extend(prediction.detach().cpu().numpy().tolist())
    
print(len(preds))
assert len(preds) == joined.shape[0]

joined['preds_bert4rec'] = preds
joined['recs_bert4rec_10'] = joined['preds_bert4rec'].apply(lambda x: np.argsort(-np.array(x))[:10])
joined['recs_bert4rec_10'] = joined['recs_bert4rec_10'].apply(lambda x: [idx2item[t-1] for t in x])
joined['recs_bert4rec_5'] = joined['preds_bert4rec'].apply(lambda x: np.argsort(-np.array(x))[:5])
joined['recs_bert4rec_5'] = joined['recs_bert4rec_5'].apply(lambda x: [idx2item[t-1] for t in x])
joined.drop(['preds_bert4rec'],axis=1, inplace=True)
joined.head()

112it [00:08, 13.18it/s]


671


Unnamed: 0,user_id,train_interactions,valid_interactions,test_interactions,recs_bert4rec_10,recs_bert4rec_5
0,1,"[(2294, 2009-12-14 02:51:48, 0, 2), (2455, 200...","[(1287, 2009-12-14 02:53:07, 0, 2), (1953, 200...","[(2193, 2009-12-14 02:53:18, 0, 2), (2968, 200...","[593, 356, 296, 260, 480, 2571, 318, 110, 1270...","[593, 356, 296, 260, 480]"
1,2,"[(150, 1996-06-21 11:09:55, 4, 11), (296, 1996...","[(551, 1996-06-21 11:16:07, 4, 11), (273, 1996...","[(319, 1996-06-21 11:18:38, 4, 11), (485, 1996...","[260, 2571, 318, 1270, 2959, 608, 1196, 1, 344...","[260, 2571, 318, 1270, 2959]"
2,3,"[(355, 2011-02-28 02:53:09, 0, 2), (1271, 2011...","[(2028, 2011-02-28 19:37:42, 0, 19), (110, 201...","[(1721, 2011-02-28 20:00:36, 0, 20), (377, 201...","[260, 480, 2571, 110, 1270, 457, 47, 608, 1196...","[260, 480, 2571, 110, 1270]"
3,4,"[(1210, 2000-02-05 19:25:14, 5, 19), (2734, 20...","[(3169, 2000-02-07 10:21:44, 0, 10), (1298, 20...","[(1858, 2000-02-07 10:35:38, 0, 10), (3108, 20...","[593, 356, 2571, 318, 110, 2959, 457, 1, 608, 47]","[593, 356, 2571, 318, 110]"
4,5,"[(1380, 2006-11-12 23:10:44, 6, 23), (1035, 20...","[(1968, 2006-11-12 23:32:04, 6, 23), (30707, 2...","[(33679, 2006-11-12 23:35:17, 6, 23), (5995, 2...","[593, 296, 260, 480, 2571, 318, 110, 1270, 457...","[593, 296, 260, 480, 2571]"


In [None]:
evaluate_recommender(joined, model_preds='recs_bert4rec_10')

{'ndcg': 0.1280307170558313, 'recall': 0.03206971653677037}

In [None]:
evaluate_recommender(joined, model_preds='recs_bert4rec_5')

{'ndcg': 0.08535465507261165, 'recall': 0.019260518398816713}

### ML-1m

In [None]:
rnames = ['user_id', 'movie_id', 'rating', 'timestamp']
df = pd.read_table('data/ratings.dat', sep='::',header=None, names=rnames, engine='python')
df = df.rename(columns={'userId': 'user_id', 'movie_id': 'item_id'})
df['timestamp'] = pd.to_datetime(df['timestamp'],unit='s')
df['weekday'] = pd.to_datetime(df.timestamp).dt.weekday
df['hour'] = pd.to_datetime(df.timestamp).dt.hour
df.head()

Unnamed: 0,user_id,item_id,rating,timestamp,weekday,hour
0,1,1193,5,2000-12-31 22:12:40,6,22
1,1,661,3,2000-12-31 22:35:09,6,22
2,1,914,3,2000-12-31 22:32:48,6,22
3,1,3408,4,2000-12-31 22:04:35,6,22
4,1,2355,5,2001-01-06 23:38:11,5,23


In [None]:
splitter = RandomSplit(test_fraction=0.2)
train_df, valid_df, test_df = splitter(df)

In [None]:
train_grouped = train_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id, 
                                                                 x.timestamp,
                                                                 x.weekday,
                                                                 x.hour), key=lambda x: x[1])]
).reset_index()
train_grouped.rename({0:'train_interactions'}, axis=1, inplace=True)

valid_grouped = valid_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id,
                                                         x.timestamp,
                                                         x.weekday,
                                                         x.hour), key=lambda x: x[1])]
).reset_index()
valid_grouped.rename({0:'valid_interactions'}, axis=1, inplace=True)

test_grouped = test_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id,
                                                         x.timestamp,
                                                         x.weekday,
                                                         x.hour), key=lambda x: x[1])]
).reset_index()
test_grouped.rename({0:'test_interactions'}, axis=1, inplace=True)


train_grouped.head()

Unnamed: 0,user_id,train_interactions
0,1,"[(3186, 2000-12-31 22:00:19, 6, 22), (1270, 20..."
1,2,"[(1198, 2000-12-31 21:28:44, 6, 21), (1210, 20..."
2,3,"[(593, 2000-12-31 21:10:18, 6, 21), (2858, 200..."
3,4,"[(1210, 2000-12-31 20:18:44, 6, 20), (1097, 20..."
4,5,"[(2717, 2000-12-31 05:37:52, 6, 5), (908, 2000..."


In [None]:
joined = train_grouped.merge(valid_grouped).merge(test_grouped)
joined.head()

Unnamed: 0,user_id,train_interactions,valid_interactions,test_interactions
0,1,"[(3186, 2000-12-31 22:00:19, 6, 22), (1270, 20...","[(2791, 2000-12-31 22:36:28, 6, 22), (2321, 20...","[(2687, 2001-01-06 23:37:48, 5, 23), (745, 200..."
1,2,"[(1198, 2000-12-31 21:28:44, 6, 21), (1210, 20...","[(2028, 2000-12-31 21:56:13, 6, 21), (2571, 20...","[(1372, 2000-12-31 21:59:01, 6, 21), (1552, 20..."
2,3,"[(593, 2000-12-31 21:10:18, 6, 21), (2858, 200...","[(648, 2000-12-31 21:24:27, 6, 21), (2735, 200...","[(1270, 2000-12-31 21:30:31, 6, 21), (1079, 20..."
3,4,"[(1210, 2000-12-31 20:18:44, 6, 20), (1097, 20...","[(2947, 2000-12-31 20:23:50, 6, 20), (1214, 20...","[(1240, 2000-12-31 20:24:20, 6, 20), (2951, 20..."
4,5,"[(2717, 2000-12-31 05:37:52, 6, 5), (908, 2000...","[(2323, 2000-12-31 06:50:45, 6, 6), (272, 2000...","[(1715, 2000-12-31 06:58:11, 6, 6), (1653, 200..."


In [None]:
our_items = set()
for idx, row in tqdm(joined.iterrows()):
    for el in row.train_interactions:
        our_items.add(el[0])
        
len(our_items)

6040it [00:00, 40652.52it/s]


3636

In [None]:
item2idx = {k: i for i, k in enumerate(our_items)}
idx2item = {i: k for k, i in item2idx.items()}

In [None]:
n_items = len(item2idx)

train = MyDataset(ds=joined,
                  num_items=n_items, 
                  item2idx=item2idx,
                  phase='train')

valid = MyDataset(ds=joined,
                  num_items=n_items,
                  item2idx=item2idx,
                  phase='valid')

print(len(train),len(valid))

6040 6040


In [None]:
loaders = {
        "train": DataLoader(train, batch_size=256, collate_fn=collate_fn_train),
        "valid": DataLoader(valid, batch_size=256, collate_fn=collate_fn_valid),
}

In [None]:

model = BERT4Rec(n_items=len(item2idx)+1, mask_ratio=0.2, hidden_size=32)

optimizer = optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = StepLR(optimizer, step_size=20, gamma=0.1)
engine = dl.DeviceEngine('cpu')
hparams = {
    "anneal_cap": 0.2,
    "total_anneal_steps": 6000,
}


callbacks = [
    dl.NDCGCallback("logits", "targets", [20]),
    dl.MAPCallback("logits", "targets", [10]),
    dl.OptimizerCallback("loss", accumulation_steps=1),
    dl.EarlyStoppingCallback(
        patience=7, loader_key="valid", metric_key="map10", minimize=False
    )
]


runner = RecSysRunner()
runner.train(
    model=model,
    optimizer=optimizer,
    engine=engine,
    hparams=hparams,
    scheduler=lr_scheduler,
    loaders=loaders,
    num_epochs=100,
    verbose=True,
    timeit=True,
    callbacks=callbacks,
)


1/100 * Epoch (train): 100%|██████████| 24/24 [01:27<00:00,  3.64s/it, _timer/_fps=69.996, _timer/batch_time=2.172, _timer/data_time=1.459, _timer/model_time=0.712, loss=7.722, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (1/100) loss: 7.968992110119752 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


1/100 * Epoch (valid): 100%|██████████| 24/24 [00:56<00:00,  2.35s/it, _timer/_fps=90.903, _timer/batch_time=1.672, _timer/data_time=0.035, _timer/model_time=1.637, loss=7.689, lr=1.000e-03, map10=0.099, momentum=0.900, ndcg20=0.061] 


valid (1/100) loss: 7.696078093951901 | lr: 0.001 | map10: 0.11778197297394671 | map10/std: 0.014863163885628503 | momentum: 0.9 | ndcg20: 0.0633254435036751 | ndcg20/std: 0.0070999549995383925
* Epoch (1/100) 


2/100 * Epoch (train): 100%|██████████| 24/24 [01:23<00:00,  3.46s/it, _timer/_fps=74.131, _timer/batch_time=2.050, _timer/data_time=1.369, _timer/model_time=0.681, loss=7.427, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (2/100) loss: 7.577410570359388 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


2/100 * Epoch (valid): 100%|██████████| 24/24 [00:52<00:00,  2.20s/it, _timer/_fps=91.176, _timer/batch_time=1.667, _timer/data_time=0.049, _timer/model_time=1.619, loss=7.408, lr=1.000e-03, map10=0.114, momentum=0.900, ndcg20=0.069] 


valid (2/100) loss: 7.44668531228375 | lr: 0.001 | map10: 0.13358881045256232 | map10/std: 0.017118664769799057 | momentum: 0.9 | ndcg20: 0.07473827660675869 | ndcg20/std: 0.008000417844073668
* Epoch (2/100) 


3/100 * Epoch (train): 100%|██████████| 24/24 [01:04<00:00,  2.67s/it, _timer/_fps=47.509, _timer/batch_time=3.199, _timer/data_time=1.891, _timer/model_time=1.309, loss=7.342, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (3/100) loss: 7.417819742493283 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


3/100 * Epoch (valid): 100%|██████████| 24/24 [00:57<00:00,  2.40s/it, _timer/_fps=98.483, _timer/batch_time=1.543, _timer/data_time=0.028, _timer/model_time=1.516, loss=7.333, lr=1.000e-03, map10=0.115, momentum=0.900, ndcg20=0.069] 


valid (3/100) loss: 7.376139102076852 | lr: 0.001 | map10: 0.138499222340568 | map10/std: 0.01744302896699519 | momentum: 0.9 | ndcg20: 0.0791560530366487 | ndcg20/std: 0.007810760927903505
* Epoch (3/100) 


4/100 * Epoch (train): 100%|██████████| 24/24 [01:07<00:00,  2.82s/it, _timer/_fps=63.945, _timer/batch_time=2.377, _timer/data_time=1.694, _timer/model_time=0.683, loss=7.328, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (4/100) loss: 7.387026337756226 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


4/100 * Epoch (valid): 100%|██████████| 24/24 [00:45<00:00,  1.91s/it, _timer/_fps=116.360, _timer/batch_time=1.306, _timer/data_time=0.060, _timer/model_time=1.246, loss=7.321, lr=1.000e-03, map10=0.116, momentum=0.900, ndcg20=0.070]


valid (4/100) loss: 7.368003326693907 | lr: 0.001 | map10: 0.14277619045499146 | map10/std: 0.01827907416129996 | momentum: 0.9 | ndcg20: 0.0808225085009013 | ndcg20/std: 0.007908791033970772
* Epoch (4/100) 


5/100 * Epoch (train): 100%|██████████| 24/24 [01:04<00:00,  2.69s/it, _timer/_fps=62.362, _timer/batch_time=2.437, _timer/data_time=1.358, _timer/model_time=1.080, loss=7.334, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (5/100) loss: 7.366109337711965 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


5/100 * Epoch (valid): 100%|██████████| 24/24 [00:46<00:00,  1.96s/it, _timer/_fps=125.542, _timer/batch_time=1.211, _timer/data_time=0.048, _timer/model_time=1.163, loss=7.324, lr=1.000e-03, map10=0.111, momentum=0.900, ndcg20=0.067]


valid (5/100) loss: 7.338384742610502 | lr: 0.001 | map10: 0.14101563135519724 | map10/std: 0.016730697705583625 | momentum: 0.9 | ndcg20: 0.08011497274929326 | ndcg20/std: 0.007748707340547681
* Epoch (5/100) 


6/100 * Epoch (train): 100%|██████████| 24/24 [01:30<00:00,  3.75s/it, _timer/_fps=54.067, _timer/batch_time=2.811, _timer/data_time=1.809, _timer/model_time=1.003, loss=7.303, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (6/100) loss: 7.3454561751409875 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


6/100 * Epoch (valid): 100%|██████████| 24/24 [00:53<00:00,  2.22s/it, _timer/_fps=148.102, _timer/batch_time=1.026, _timer/data_time=0.041, _timer/model_time=0.985, loss=7.292, lr=1.000e-03, map10=0.117, momentum=0.900, ndcg20=0.067]


valid (6/100) loss: 7.3162160481838185 | lr: 0.001 | map10: 0.1402249941960076 | map10/std: 0.014694305850898981 | momentum: 0.9 | ndcg20: 0.08016644199162919 | ndcg20/std: 0.007482254828173791
* Epoch (6/100) 


7/100 * Epoch (train): 100%|██████████| 24/24 [00:57<00:00,  2.41s/it, _timer/_fps=90.924, _timer/batch_time=1.672, _timer/data_time=1.007, _timer/model_time=0.665, loss=7.290, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (7/100) loss: 7.326318434532115 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


7/100 * Epoch (valid): 100%|██████████| 24/24 [00:39<00:00,  1.66s/it, _timer/_fps=133.120, _timer/batch_time=1.142, _timer/data_time=0.058, _timer/model_time=1.084, loss=7.278, lr=1.000e-03, map10=0.119, momentum=0.900, ndcg20=0.069]


valid (7/100) loss: 7.296065888183795 | lr: 0.001 | map10: 0.1433888631841994 | map10/std: 0.016075402040686665 | momentum: 0.9 | ndcg20: 0.08125064933536859 | ndcg20/std: 0.007795165217683189
* Epoch (7/100) 


8/100 * Epoch (train): 100%|██████████| 24/24 [00:54<00:00,  2.27s/it, _timer/_fps=90.485, _timer/batch_time=1.680, _timer/data_time=1.031, _timer/model_time=0.649, loss=7.248, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (8/100) loss: 7.298139457828952 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


8/100 * Epoch (valid): 100%|██████████| 24/24 [00:38<00:00,  1.59s/it, _timer/_fps=159.936, _timer/batch_time=0.950, _timer/data_time=0.045, _timer/model_time=0.905, loss=7.228, lr=1.000e-03, map10=0.118, momentum=0.900, ndcg20=0.067]


valid (8/100) loss: 7.266701885880224 | lr: 0.001 | map10: 0.14374309970645713 | map10/std: 0.01802734557606383 | momentum: 0.9 | ndcg20: 0.08154814519629575 | ndcg20/std: 0.007931098961901728
* Epoch (8/100) 


9/100 * Epoch (train): 100%|██████████| 24/24 [00:54<00:00,  2.27s/it, _timer/_fps=91.718, _timer/batch_time=1.657, _timer/data_time=0.998, _timer/model_time=0.659, loss=7.225, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (9/100) loss: 7.287213707601788 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


9/100 * Epoch (valid): 100%|██████████| 24/24 [00:37<00:00,  1.58s/it, _timer/_fps=142.010, _timer/batch_time=1.070, _timer/data_time=0.042, _timer/model_time=1.029, loss=7.209, lr=1.000e-03, map10=0.112, momentum=0.900, ndcg20=0.067]


valid (9/100) loss: 7.255560097473346 | lr: 0.001 | map10: 0.14757276207998102 | map10/std: 0.01708933798534066 | momentum: 0.9 | ndcg20: 0.08324888626865995 | ndcg20/std: 0.00825387728796972
* Epoch (9/100) 


10/100 * Epoch (train): 100%|██████████| 24/24 [00:53<00:00,  2.23s/it, _timer/_fps=87.284, _timer/batch_time=1.741, _timer/data_time=0.960, _timer/model_time=0.781, loss=7.205, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (10/100) loss: 7.265993701859026 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


10/100 * Epoch (valid): 100%|██████████| 24/24 [00:38<00:00,  1.60s/it, _timer/_fps=131.755, _timer/batch_time=1.154, _timer/data_time=0.052, _timer/model_time=1.101, loss=7.193, lr=1.000e-03, map10=0.142, momentum=0.900, ndcg20=0.069]


valid (10/100) loss: 7.233913134265419 | lr: 0.001 | map10: 0.1529968572965521 | map10/std: 0.01683672716436295 | momentum: 0.9 | ndcg20: 0.08406947889667472 | ndcg20/std: 0.008478863738568506
* Epoch (10/100) 


11/100 * Epoch (train): 100%|██████████| 24/24 [00:56<00:00,  2.34s/it, _timer/_fps=89.945, _timer/batch_time=1.690, _timer/data_time=0.940, _timer/model_time=0.750, loss=7.167, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (11/100) loss: 7.245073545847508 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


11/100 * Epoch (valid): 100%|██████████| 24/24 [00:40<00:00,  1.69s/it, _timer/_fps=158.023, _timer/batch_time=0.962, _timer/data_time=0.034, _timer/model_time=0.928, loss=7.151, lr=1.000e-03, map10=0.147, momentum=0.900, ndcg20=0.070]


valid (11/100) loss: 7.213057524321095 | lr: 0.001 | map10: 0.15369764584184484 | map10/std: 0.01565346827855212 | momentum: 0.9 | ndcg20: 0.08401238052852897 | ndcg20/std: 0.007259509702008643
* Epoch (11/100) 


12/100 * Epoch (train): 100%|██████████| 24/24 [00:53<00:00,  2.23s/it, _timer/_fps=87.103, _timer/batch_time=1.745, _timer/data_time=1.095, _timer/model_time=0.650, loss=7.147, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (12/100) loss: 7.229004225825632 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


12/100 * Epoch (valid): 100%|██████████| 24/24 [00:39<00:00,  1.66s/it, _timer/_fps=119.932, _timer/batch_time=1.267, _timer/data_time=0.056, _timer/model_time=1.212, loss=7.133, lr=1.000e-03, map10=0.140, momentum=0.900, ndcg20=0.075]


valid (12/100) loss: 7.193159949068992 | lr: 0.001 | map10: 0.15492686051011872 | map10/std: 0.01691386646441662 | momentum: 0.9 | ndcg20: 0.08549870604908227 | ndcg20/std: 0.007956864186221638
* Epoch (12/100) 


13/100 * Epoch (train): 100%|██████████| 24/24 [00:59<00:00,  2.46s/it, _timer/_fps=83.747, _timer/batch_time=1.815, _timer/data_time=1.039, _timer/model_time=0.776, loss=7.157, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (13/100) loss: 7.210018116275206 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


13/100 * Epoch (valid): 100%|██████████| 24/24 [00:49<00:00,  2.06s/it, _timer/_fps=118.977, _timer/batch_time=1.278, _timer/data_time=0.042, _timer/model_time=1.236, loss=7.133, lr=1.000e-03, map10=0.108, momentum=0.900, ndcg20=0.068]


valid (13/100) loss: 7.166984260003298 | lr: 0.001 | map10: 0.15693908942653645 | map10/std: 0.01942349379738357 | momentum: 0.9 | ndcg20: 0.08575387442151443 | ndcg20/std: 0.007786602204776577
* Epoch (13/100) 


14/100 * Epoch (train): 100%|██████████| 24/24 [00:59<00:00,  2.47s/it, _timer/_fps=86.525, _timer/batch_time=1.757, _timer/data_time=1.072, _timer/model_time=0.684, loss=7.110, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (14/100) loss: 7.183091582999324 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


14/100 * Epoch (valid): 100%|██████████| 24/24 [00:52<00:00,  2.17s/it, _timer/_fps=110.745, _timer/batch_time=1.373, _timer/data_time=0.059, _timer/model_time=1.314, loss=7.084, lr=1.000e-03, map10=0.103, momentum=0.900, ndcg20=0.062]


valid (14/100) loss: 7.13479264745649 | lr: 0.001 | map10: 0.15654042466784157 | map10/std: 0.01937536501034222 | momentum: 0.9 | ndcg20: 0.08493426416488672 | ndcg20/std: 0.008237564436113964
* Epoch (14/100) 


15/100 * Epoch (train): 100%|██████████| 24/24 [01:02<00:00,  2.62s/it, _timer/_fps=89.710, _timer/batch_time=1.694, _timer/data_time=1.029, _timer/model_time=0.665, loss=7.103, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (15/100) loss: 7.159664724994178 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


15/100 * Epoch (valid): 100%|██████████| 24/24 [00:54<00:00,  2.29s/it, _timer/_fps=138.791, _timer/batch_time=1.095, _timer/data_time=0.053, _timer/model_time=1.042, loss=7.080, lr=1.000e-03, map10=0.105, momentum=0.900, ndcg20=0.063]


valid (15/100) loss: 7.109327103444284 | lr: 0.001 | map10: 0.15493882645834361 | map10/std: 0.018787622634696917 | momentum: 0.9 | ndcg20: 0.08545321798482479 | ndcg20/std: 0.00812456605316575
* Epoch (15/100) 


16/100 * Epoch (train): 100%|██████████| 24/24 [01:10<00:00,  2.92s/it, _timer/_fps=63.884, _timer/batch_time=2.379, _timer/data_time=1.495, _timer/model_time=0.884, loss=7.042, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (16/100) loss: 7.119824802322894 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


16/100 * Epoch (valid): 100%|██████████| 24/24 [00:55<00:00,  2.30s/it, _timer/_fps=115.419, _timer/batch_time=1.317, _timer/data_time=0.038, _timer/model_time=1.279, loss=7.016, lr=1.000e-03, map10=0.115, momentum=0.900, ndcg20=0.070]


valid (16/100) loss: 7.0683137792625175 | lr: 0.001 | map10: 0.1571009019550109 | map10/std: 0.019131392679905223 | momentum: 0.9 | ndcg20: 0.08872806360784745 | ndcg20/std: 0.008303695716012116
* Epoch (16/100) 


17/100 * Epoch (train): 100%|██████████| 24/24 [01:13<00:00,  3.05s/it, _timer/_fps=38.387, _timer/batch_time=3.960, _timer/data_time=2.839, _timer/model_time=1.120, loss=7.011, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (17/100) loss: 7.095871355044132 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


17/100 * Epoch (valid): 100%|██████████| 24/24 [00:52<00:00,  2.20s/it, _timer/_fps=99.719, _timer/batch_time=1.524, _timer/data_time=0.046, _timer/model_time=1.479, loss=6.985, lr=1.000e-03, map10=0.131, momentum=0.900, ndcg20=0.078] 


valid (17/100) loss: 7.0419729788571805 | lr: 0.001 | map10: 0.16369047784647406 | map10/std: 0.020902738227127217 | momentum: 0.9 | ndcg20: 0.092124769180421 | ndcg20/std: 0.009101143750758016
* Epoch (17/100) 


18/100 * Epoch (train): 100%|██████████| 24/24 [01:13<00:00,  3.08s/it, _timer/_fps=39.743, _timer/batch_time=3.825, _timer/data_time=2.788, _timer/model_time=1.037, loss=6.971, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (18/100) loss: 7.0615571975708 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


18/100 * Epoch (valid): 100%|██████████| 24/24 [00:54<00:00,  2.25s/it, _timer/_fps=98.925, _timer/batch_time=1.537, _timer/data_time=0.055, _timer/model_time=1.482, loss=6.935, lr=1.000e-03, map10=0.123, momentum=0.900, ndcg20=0.082] 


valid (18/100) loss: 7.000935981447334 | lr: 0.001 | map10: 0.16581868049719478 | map10/std: 0.021246749755743793 | momentum: 0.9 | ndcg20: 0.0937369241718425 | ndcg20/std: 0.00953617516723914
* Epoch (18/100) 


19/100 * Epoch (train): 100%|██████████| 24/24 [01:11<00:00,  2.99s/it, _timer/_fps=49.680, _timer/batch_time=3.060, _timer/data_time=1.402, _timer/model_time=1.658, loss=6.935, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (19/100) loss: 7.0227322035277915 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


19/100 * Epoch (valid): 100%|██████████| 24/24 [00:58<00:00,  2.44s/it, _timer/_fps=94.623, _timer/batch_time=1.606, _timer/data_time=0.043, _timer/model_time=1.563, loss=6.899, lr=1.000e-03, map10=0.142, momentum=0.900, ndcg20=0.087] 


valid (19/100) loss: 6.961496850196889 | lr: 0.001 | map10: 0.16946319029820675 | map10/std: 0.018224540854058766 | momentum: 0.9 | ndcg20: 0.09704142508917298 | ndcg20/std: 0.008182885760203304
* Epoch (19/100) 


20/100 * Epoch (train): 100%|██████████| 24/24 [01:15<00:00,  3.15s/it, _timer/_fps=45.630, _timer/batch_time=3.331, _timer/data_time=2.031, _timer/model_time=1.300, loss=6.891, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (20/100) loss: 6.987526513409141 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


20/100 * Epoch (valid): 100%|██████████| 24/24 [00:59<00:00,  2.48s/it, _timer/_fps=85.505, _timer/batch_time=1.778, _timer/data_time=0.042, _timer/model_time=1.735, loss=6.855, lr=1.000e-03, map10=0.139, momentum=0.900, ndcg20=0.086] 


valid (20/100) loss: 6.916438131142925 | lr: 0.001 | map10: 0.1667211383975894 | map10/std: 0.01930881050044727 | momentum: 0.9 | ndcg20: 0.0967261532185883 | ndcg20/std: 0.008701087772268912
* Epoch (20/100) 


21/100 * Epoch (train): 100%|██████████| 24/24 [01:08<00:00,  2.87s/it, _timer/_fps=43.838, _timer/batch_time=3.467, _timer/data_time=2.564, _timer/model_time=0.903, loss=6.827, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (21/100) loss: 6.9529601873940985 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


21/100 * Epoch (valid): 100%|██████████| 24/24 [00:48<00:00,  2.04s/it, _timer/_fps=113.668, _timer/batch_time=1.337, _timer/data_time=0.058, _timer/model_time=1.279, loss=6.783, lr=1.000e-03, map10=0.147, momentum=0.900, ndcg20=0.092]


valid (21/100) loss: 6.8901759791847885 | lr: 0.001 | map10: 0.17052794424508577 | map10/std: 0.018914324618814878 | momentum: 0.9 | ndcg20: 0.0996166905326559 | ndcg20/std: 0.00906277480312917
* Epoch (21/100) 


22/100 * Epoch (train): 100%|██████████| 24/24 [01:36<00:00,  4.01s/it, _timer/_fps=25.528, _timer/batch_time=5.954, _timer/data_time=4.267, _timer/model_time=1.687, loss=6.809, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]  


train (22/100) loss: 6.910055972724561 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


22/100 * Epoch (valid): 100%|██████████| 24/24 [01:00<00:00,  2.52s/it, _timer/_fps=74.522, _timer/batch_time=2.040, _timer/data_time=0.024, _timer/model_time=2.016, loss=6.760, lr=1.000e-03, map10=0.138, momentum=0.900, ndcg20=0.092] 


valid (22/100) loss: 6.838480527511496 | lr: 0.001 | map10: 0.17248369686256176 | map10/std: 0.019704656230571212 | momentum: 0.9 | ndcg20: 0.101501167126444 | ndcg20/std: 0.009196907976312748
* Epoch (22/100) 


23/100 * Epoch (train): 100%|██████████| 24/24 [02:11<00:00,  5.48s/it, _timer/_fps=21.468, _timer/batch_time=7.080, _timer/data_time=6.021, _timer/model_time=1.059, loss=6.792, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (23/100) loss: 6.879636921787894 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


23/100 * Epoch (valid): 100%|██████████| 24/24 [01:18<00:00,  3.28s/it, _timer/_fps=70.101, _timer/batch_time=2.168, _timer/data_time=0.030, _timer/model_time=2.139, loss=6.748, lr=1.000e-03, map10=0.139, momentum=0.900, ndcg20=0.090] 


valid (23/100) loss: 6.80198800295394 | lr: 0.001 | map10: 0.17701271014497766 | map10/std: 0.0195989743860648 | momentum: 0.9 | ndcg20: 0.10415454781213344 | ndcg20/std: 0.010252375447423884
* Epoch (23/100) 


24/100 * Epoch (train): 100%|██████████| 24/24 [01:18<00:00,  3.26s/it, _timer/_fps=57.794, _timer/batch_time=2.630, _timer/data_time=1.979, _timer/model_time=0.652, loss=6.766, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (24/100) loss: 6.848636629723555 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


24/100 * Epoch (valid): 100%|██████████| 24/24 [00:54<00:00,  2.25s/it, _timer/_fps=135.100, _timer/batch_time=1.125, _timer/data_time=0.049, _timer/model_time=1.076, loss=6.709, lr=1.000e-03, map10=0.151, momentum=0.900, ndcg20=0.093]


valid (24/100) loss: 6.765014719173607 | lr: 0.001 | map10: 0.17938289500230192 | map10/std: 0.01787269543769873 | momentum: 0.9 | ndcg20: 0.10777963879487373 | ndcg20/std: 0.00955463244128265
* Epoch (24/100) 


25/100 * Epoch (train): 100%|██████████| 24/24 [01:15<00:00,  3.15s/it, _timer/_fps=45.140, _timer/batch_time=3.367, _timer/data_time=2.691, _timer/model_time=0.677, loss=6.756, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (25/100) loss: 6.822920659046299 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


25/100 * Epoch (valid): 100%|██████████| 24/24 [00:51<00:00,  2.14s/it, _timer/_fps=109.294, _timer/batch_time=1.391, _timer/data_time=0.055, _timer/model_time=1.336, loss=6.705, lr=1.000e-03, map10=0.169, momentum=0.900, ndcg20=0.097]


valid (25/100) loss: 6.741386746412871 | lr: 0.001 | map10: 0.18198803323783622 | map10/std: 0.020491316711446942 | momentum: 0.9 | ndcg20: 0.10949627155105011 | ndcg20/std: 0.009485303134466378
* Epoch (25/100) 


26/100 * Epoch (train): 100%|██████████| 24/24 [01:18<00:00,  3.26s/it, _timer/_fps=42.137, _timer/batch_time=3.607, _timer/data_time=2.644, _timer/model_time=0.964, loss=6.715, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (26/100) loss: 6.79128344643195 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


26/100 * Epoch (valid): 100%|██████████| 24/24 [00:50<00:00,  2.10s/it, _timer/_fps=119.815, _timer/batch_time=1.269, _timer/data_time=0.025, _timer/model_time=1.244, loss=6.655, lr=1.000e-03, map10=0.154, momentum=0.900, ndcg20=0.096]


valid (26/100) loss: 6.7062338589043025 | lr: 0.001 | map10: 0.1858302708493164 | map10/std: 0.02253778451302508 | momentum: 0.9 | ndcg20: 0.11113376607563322 | ndcg20/std: 0.010497024460727805
* Epoch (26/100) 


27/100 * Epoch (train): 100%|██████████| 24/24 [01:02<00:00,  2.60s/it, _timer/_fps=59.268, _timer/batch_time=2.565, _timer/data_time=1.801, _timer/model_time=0.764, loss=6.705, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (27/100) loss: 6.7727412173290125 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


27/100 * Epoch (valid): 100%|██████████| 24/24 [00:48<00:00,  2.04s/it, _timer/_fps=142.955, _timer/batch_time=1.063, _timer/data_time=0.058, _timer/model_time=1.005, loss=6.634, lr=1.000e-03, map10=0.184, momentum=0.900, ndcg20=0.104]


valid (27/100) loss: 6.680901508457613 | lr: 0.001 | map10: 0.1924939314063811 | map10/std: 0.021898588864856818 | momentum: 0.9 | ndcg20: 0.11510710730063205 | ndcg20/std: 0.011241456130007444
* Epoch (27/100) 


28/100 * Epoch (train): 100%|██████████| 24/24 [01:00<00:00,  2.51s/it, _timer/_fps=89.143, _timer/batch_time=1.705, _timer/data_time=1.017, _timer/model_time=0.688, loss=6.669, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (28/100) loss: 6.748376546948163 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


28/100 * Epoch (valid): 100%|██████████| 24/24 [00:47<00:00,  1.99s/it, _timer/_fps=126.286, _timer/batch_time=1.204, _timer/data_time=0.035, _timer/model_time=1.169, loss=6.600, lr=1.000e-03, map10=0.191, momentum=0.900, ndcg20=0.107]


valid (28/100) loss: 6.657979586266523 | lr: 0.001 | map10: 0.189731111017284 | map10/std: 0.02265235869574167 | momentum: 0.9 | ndcg20: 0.11626115729477232 | ndcg20/std: 0.011332912087982917
* Epoch (28/100) 


29/100 * Epoch (train): 100%|██████████| 24/24 [00:59<00:00,  2.47s/it, _timer/_fps=67.182, _timer/batch_time=2.263, _timer/data_time=1.246, _timer/model_time=1.017, loss=6.626, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (29/100) loss: 6.729793691319344 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


29/100 * Epoch (valid): 100%|██████████| 24/24 [00:58<00:00,  2.45s/it, _timer/_fps=128.858, _timer/batch_time=1.180, _timer/data_time=0.054, _timer/model_time=1.126, loss=6.562, lr=1.000e-03, map10=0.172, momentum=0.900, ndcg20=0.103]


valid (29/100) loss: 6.634550567020645 | lr: 0.001 | map10: 0.1883321942872559 | map10/std: 0.020797366075634267 | momentum: 0.9 | ndcg20: 0.11488855081480859 | ndcg20/std: 0.011094841368910015
* Epoch (29/100) 


30/100 * Epoch (train): 100%|██████████| 24/24 [01:05<00:00,  2.71s/it, _timer/_fps=78.066, _timer/batch_time=1.947, _timer/data_time=1.189, _timer/model_time=0.758, loss=6.622, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (30/100) loss: 6.697151129766806 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


30/100 * Epoch (valid): 100%|██████████| 24/24 [00:49<00:00,  2.06s/it, _timer/_fps=153.597, _timer/batch_time=0.990, _timer/data_time=0.045, _timer/model_time=0.944, loss=6.565, lr=1.000e-03, map10=0.181, momentum=0.900, ndcg20=0.106]


valid (30/100) loss: 6.590060978693678 | lr: 0.001 | map10: 0.19366304825078573 | map10/std: 0.02483866694320072 | momentum: 0.9 | ndcg20: 0.11837045946065952 | ndcg20/std: 0.011592027519667118
* Epoch (30/100) 


31/100 * Epoch (train): 100%|██████████| 24/24 [01:06<00:00,  2.79s/it, _timer/_fps=81.198, _timer/batch_time=1.872, _timer/data_time=1.154, _timer/model_time=0.718, loss=6.581, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (31/100) loss: 6.6834424978849905 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


31/100 * Epoch (valid): 100%|██████████| 24/24 [00:51<00:00,  2.13s/it, _timer/_fps=107.932, _timer/batch_time=1.408, _timer/data_time=0.064, _timer/model_time=1.344, loss=6.504, lr=1.000e-03, map10=0.172, momentum=0.900, ndcg20=0.109]


valid (31/100) loss: 6.577102447661343 | lr: 0.001 | map10: 0.1949030445111508 | map10/std: 0.022845030236548474 | momentum: 0.9 | ndcg20: 0.11987244711806443 | ndcg20/std: 0.011931849134742118
* Epoch (31/100) 


32/100 * Epoch (train): 100%|██████████| 24/24 [01:04<00:00,  2.71s/it, _timer/_fps=54.748, _timer/batch_time=2.776, _timer/data_time=1.978, _timer/model_time=0.798, loss=6.544, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (32/100) loss: 6.653389913672643 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


32/100 * Epoch (valid): 100%|██████████| 24/24 [00:45<00:00,  1.90s/it, _timer/_fps=75.255, _timer/batch_time=2.020, _timer/data_time=0.049, _timer/model_time=1.971, loss=6.463, lr=1.000e-03, map10=0.171, momentum=0.900, ndcg20=0.108] 


valid (32/100) loss: 6.542886693588158 | lr: 0.001 | map10: 0.19771962852667496 | map10/std: 0.023799018106523635 | momentum: 0.9 | ndcg20: 0.12001581475237348 | ndcg20/std: 0.01276453574942295
* Epoch (32/100) 


33/100 * Epoch (train): 100%|██████████| 24/24 [01:02<00:00,  2.59s/it, _timer/_fps=74.506, _timer/batch_time=2.040, _timer/data_time=1.396, _timer/model_time=0.644, loss=6.557, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (33/100) loss: 6.638781085235394 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


33/100 * Epoch (valid): 100%|██████████| 24/24 [00:45<00:00,  1.89s/it, _timer/_fps=119.880, _timer/batch_time=1.268, _timer/data_time=0.032, _timer/model_time=1.235, loss=6.468, lr=1.000e-03, map10=0.184, momentum=0.900, ndcg20=0.110]


valid (33/100) loss: 6.517829044923088 | lr: 0.001 | map10: 0.19736822446845223 | map10/std: 0.02360102965918906 | momentum: 0.9 | ndcg20: 0.12104550612880692 | ndcg20/std: 0.013439396381597912
* Epoch (33/100) 


34/100 * Epoch (train): 100%|██████████| 24/24 [01:04<00:00,  2.69s/it, _timer/_fps=70.560, _timer/batch_time=2.154, _timer/data_time=1.454, _timer/model_time=0.700, loss=6.564, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (34/100) loss: 6.619838951439257 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


34/100 * Epoch (valid): 100%|██████████| 24/24 [00:47<00:00,  1.99s/it, _timer/_fps=129.338, _timer/batch_time=1.175, _timer/data_time=0.046, _timer/model_time=1.129, loss=6.469, lr=1.000e-03, map10=0.184, momentum=0.900, ndcg20=0.114]


valid (34/100) loss: 6.502798514966143 | lr: 0.001 | map10: 0.20269575912431373 | map10/std: 0.02672678608552474 | momentum: 0.9 | ndcg20: 0.12453843421296568 | ndcg20/std: 0.013742270528290976
* Epoch (34/100) 


35/100 * Epoch (train): 100%|██████████| 24/24 [01:06<00:00,  2.79s/it, _timer/_fps=75.661, _timer/batch_time=2.009, _timer/data_time=1.322, _timer/model_time=0.687, loss=6.541, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (35/100) loss: 6.601888338303724 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


35/100 * Epoch (valid): 100%|██████████| 24/24 [00:42<00:00,  1.76s/it, _timer/_fps=120.972, _timer/batch_time=1.256, _timer/data_time=0.045, _timer/model_time=1.212, loss=6.437, lr=1.000e-03, map10=0.201, momentum=0.900, ndcg20=0.121]


valid (35/100) loss: 6.478465913936792 | lr: 0.001 | map10: 0.20745413524425582 | map10/std: 0.025684418557995924 | momentum: 0.9 | ndcg20: 0.12764143558922195 | ndcg20/std: 0.013151161101910655
* Epoch (35/100) 


36/100 * Epoch (train): 100%|██████████| 24/24 [01:05<00:00,  2.73s/it, _timer/_fps=72.689, _timer/batch_time=2.091, _timer/data_time=1.395, _timer/model_time=0.696, loss=6.472, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (36/100) loss: 6.578578669819612 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


36/100 * Epoch (valid): 100%|██████████| 24/24 [00:48<00:00,  2.01s/it, _timer/_fps=150.960, _timer/batch_time=1.007, _timer/data_time=0.037, _timer/model_time=0.970, loss=6.368, lr=1.000e-03, map10=0.205, momentum=0.900, ndcg20=0.120]


valid (36/100) loss: 6.445095677407371 | lr: 0.001 | map10: 0.2066680617087724 | map10/std: 0.024758004975583316 | momentum: 0.9 | ndcg20: 0.12850222621137733 | ndcg20/std: 0.013823015861837257
* Epoch (36/100) 


37/100 * Epoch (train): 100%|██████████| 24/24 [01:04<00:00,  2.69s/it, _timer/_fps=77.592, _timer/batch_time=1.959, _timer/data_time=1.244, _timer/model_time=0.715, loss=6.504, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (37/100) loss: 6.561814404165508 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


37/100 * Epoch (valid): 100%|██████████| 24/24 [00:48<00:00,  2.04s/it, _timer/_fps=122.685, _timer/batch_time=1.239, _timer/data_time=0.036, _timer/model_time=1.203, loss=6.395, lr=1.000e-03, map10=0.198, momentum=0.900, ndcg20=0.121]


valid (37/100) loss: 6.426561844427854 | lr: 0.001 | map10: 0.20957206495550293 | map10/std: 0.02688632397830897 | momentum: 0.9 | ndcg20: 0.12978295797346445 | ndcg20/std: 0.013782442339010305
* Epoch (37/100) 


38/100 * Epoch (train): 100%|██████████| 24/24 [01:10<00:00,  2.94s/it, _timer/_fps=73.645, _timer/batch_time=2.064, _timer/data_time=1.377, _timer/model_time=0.687, loss=6.507, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (38/100) loss: 6.526572869787152 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


38/100 * Epoch (valid): 100%|██████████| 24/24 [00:49<00:00,  2.08s/it, _timer/_fps=114.709, _timer/batch_time=1.325, _timer/data_time=0.043, _timer/model_time=1.283, loss=6.395, lr=1.000e-03, map10=0.210, momentum=0.900, ndcg20=0.130]


valid (38/100) loss: 6.391361083100174 | lr: 0.001 | map10: 0.2170935104027489 | map10/std: 0.02738687346415108 | momentum: 0.9 | ndcg20: 0.13438393315732086 | ndcg20/std: 0.01417458399045985
* Epoch (38/100) 


39/100 * Epoch (train): 100%|██████████| 24/24 [01:11<00:00,  2.99s/it, _timer/_fps=74.618, _timer/batch_time=2.037, _timer/data_time=1.263, _timer/model_time=0.774, loss=6.444, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (39/100) loss: 6.519992444373124 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


39/100 * Epoch (valid): 100%|██████████| 24/24 [00:51<00:00,  2.14s/it, _timer/_fps=102.035, _timer/batch_time=1.490, _timer/data_time=0.035, _timer/model_time=1.454, loss=6.330, lr=1.000e-03, map10=0.208, momentum=0.900, ndcg20=0.127]


valid (39/100) loss: 6.378594615759439 | lr: 0.001 | map10: 0.21917219874479915 | map10/std: 0.025155079787516858 | momentum: 0.9 | ndcg20: 0.13514056405089545 | ndcg20/std: 0.013776267459495891
* Epoch (39/100) 


40/100 * Epoch (train): 100%|██████████| 24/24 [01:09<00:00,  2.90s/it, _timer/_fps=88.190, _timer/batch_time=1.724, _timer/data_time=1.042, _timer/model_time=0.681, loss=6.470, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (40/100) loss: 6.488068341893077 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


40/100 * Epoch (valid): 100%|██████████| 24/24 [00:55<00:00,  2.32s/it, _timer/_fps=95.626, _timer/batch_time=1.590, _timer/data_time=0.047, _timer/model_time=1.543, loss=6.345, lr=1.000e-03, map10=0.209, momentum=0.900, ndcg20=0.131] 


valid (40/100) loss: 6.345786936551529 | lr: 0.001 | map10: 0.22132438160725776 | map10/std: 0.024749441893409704 | momentum: 0.9 | ndcg20: 0.13672125596084342 | ndcg20/std: 0.013870128647222583
* Epoch (40/100) 


41/100 * Epoch (train): 100%|██████████| 24/24 [01:20<00:00,  3.37s/it, _timer/_fps=75.832, _timer/batch_time=2.004, _timer/data_time=1.305, _timer/model_time=0.700, loss=6.432, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (41/100) loss: 6.48143572144161 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


41/100 * Epoch (valid): 100%|██████████| 24/24 [00:59<00:00,  2.46s/it, _timer/_fps=96.539, _timer/batch_time=1.574, _timer/data_time=0.043, _timer/model_time=1.531, loss=6.317, lr=1.000e-03, map10=0.225, momentum=0.900, ndcg20=0.135] 


valid (41/100) loss: 6.333691952560121 | lr: 0.001 | map10: 0.23081112367822637 | map10/std: 0.02564751003293549 | momentum: 0.9 | ndcg20: 0.14232137369004305 | ndcg20/std: 0.014690524379859854
* Epoch (41/100) 


42/100 * Epoch (train): 100%|██████████| 24/24 [01:21<00:00,  3.38s/it, _timer/_fps=32.847, _timer/batch_time=4.627, _timer/data_time=3.410, _timer/model_time=1.217, loss=6.400, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (42/100) loss: 6.464614858058904 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


42/100 * Epoch (valid): 100%|██████████| 24/24 [00:48<00:00,  2.00s/it, _timer/_fps=108.669, _timer/batch_time=1.399, _timer/data_time=0.041, _timer/model_time=1.358, loss=6.254, lr=1.000e-03, map10=0.231, momentum=0.900, ndcg20=0.132]


valid (42/100) loss: 6.307091792529783 | lr: 0.001 | map10: 0.22901872925410996 | map10/std: 0.023618360093543095 | momentum: 0.9 | ndcg20: 0.14178586016032874 | ndcg20/std: 0.014603108893265578
* Epoch (42/100) 


43/100 * Epoch (train): 100%|██████████| 24/24 [01:30<00:00,  3.76s/it, _timer/_fps=43.599, _timer/batch_time=3.486, _timer/data_time=2.679, _timer/model_time=0.808, loss=6.362, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]  


train (43/100) loss: 6.453959694919207 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


43/100 * Epoch (valid): 100%|██████████| 24/24 [00:59<00:00,  2.47s/it, _timer/_fps=73.934, _timer/batch_time=2.056, _timer/data_time=0.048, _timer/model_time=2.008, loss=6.235, lr=1.000e-03, map10=0.244, momentum=0.900, ndcg20=0.140] 


valid (43/100) loss: 6.300668001490713 | lr: 0.001 | map10: 0.23334016219669623 | map10/std: 0.02674241708715889 | momentum: 0.9 | ndcg20: 0.14672488011666485 | ndcg20/std: 0.01579835716214185
* Epoch (43/100) 


44/100 * Epoch (train): 100%|██████████| 24/24 [00:59<00:00,  2.50s/it, _timer/_fps=89.985, _timer/batch_time=1.689, _timer/data_time=0.991, _timer/model_time=0.698, loss=6.330, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (44/100) loss: 6.420362042433379 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


44/100 * Epoch (valid): 100%|██████████| 24/24 [00:44<00:00,  1.83s/it, _timer/_fps=130.028, _timer/batch_time=1.169, _timer/data_time=0.023, _timer/model_time=1.146, loss=6.202, lr=1.000e-03, map10=0.229, momentum=0.900, ndcg20=0.137]


valid (44/100) loss: 6.2633563692206575 | lr: 0.001 | map10: 0.22804616257449653 | map10/std: 0.025345127468740725 | momentum: 0.9 | ndcg20: 0.14566025102375363 | ndcg20/std: 0.015303286631652762
* Epoch (44/100) 


45/100 * Epoch (train): 100%|██████████| 24/24 [00:59<00:00,  2.48s/it, _timer/_fps=71.052, _timer/batch_time=2.139, _timer/data_time=1.169, _timer/model_time=0.970, loss=6.336, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (45/100) loss: 6.408497537524496 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


45/100 * Epoch (valid): 100%|██████████| 24/24 [00:40<00:00,  1.68s/it, _timer/_fps=139.776, _timer/batch_time=1.087, _timer/data_time=0.054, _timer/model_time=1.033, loss=6.191, lr=1.000e-03, map10=0.218, momentum=0.900, ndcg20=0.133]


valid (45/100) loss: 6.243302079383901 | lr: 0.001 | map10: 0.23042306450029088 | map10/std: 0.02399405454463128 | momentum: 0.9 | ndcg20: 0.1476819961473642 | ndcg20/std: 0.01613933879992255
* Epoch (45/100) 


46/100 * Epoch (train): 100%|██████████| 24/24 [01:05<00:00,  2.71s/it, _timer/_fps=79.232, _timer/batch_time=1.918, _timer/data_time=1.130, _timer/model_time=0.789, loss=6.303, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (46/100) loss: 6.39986747905908 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


46/100 * Epoch (valid): 100%|██████████| 24/24 [00:42<00:00,  1.78s/it, _timer/_fps=115.671, _timer/batch_time=1.314, _timer/data_time=0.056, _timer/model_time=1.258, loss=6.176, lr=1.000e-03, map10=0.223, momentum=0.900, ndcg20=0.137]


valid (46/100) loss: 6.236429213214393 | lr: 0.001 | map10: 0.23344276948085682 | map10/std: 0.026314575729814313 | momentum: 0.9 | ndcg20: 0.14933022986974145 | ndcg20/std: 0.016506836362336207
* Epoch (46/100) 


47/100 * Epoch (train): 100%|██████████| 24/24 [01:15<00:00,  3.13s/it, _timer/_fps=71.284, _timer/batch_time=2.132, _timer/data_time=1.202, _timer/model_time=0.930, loss=6.324, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (47/100) loss: 6.384315533669579 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


47/100 * Epoch (valid): 100%|██████████| 24/24 [00:52<00:00,  2.19s/it, _timer/_fps=145.954, _timer/batch_time=1.041, _timer/data_time=0.042, _timer/model_time=0.999, loss=6.185, lr=1.000e-03, map10=0.218, momentum=0.900, ndcg20=0.137]


valid (47/100) loss: 6.2120098240328145 | lr: 0.001 | map10: 0.2353779412855376 | map10/std: 0.022933576119037277 | momentum: 0.9 | ndcg20: 0.15118353372772797 | ndcg20/std: 0.016290186580393926
* Epoch (47/100) 


48/100 * Epoch (train): 100%|██████████| 24/24 [00:58<00:00,  2.43s/it, _timer/_fps=94.604, _timer/batch_time=1.607, _timer/data_time=0.945, _timer/model_time=0.662, loss=6.285, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (48/100) loss: 6.370455470937767 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


48/100 * Epoch (valid): 100%|██████████| 24/24 [00:38<00:00,  1.60s/it, _timer/_fps=162.614, _timer/batch_time=0.935, _timer/data_time=0.037, _timer/model_time=0.898, loss=6.144, lr=1.000e-03, map10=0.230, momentum=0.900, ndcg20=0.144]


valid (48/100) loss: 6.197527838069082 | lr: 0.001 | map10: 0.23884614419858186 | map10/std: 0.023812270492349474 | momentum: 0.9 | ndcg20: 0.15430565807993044 | ndcg20/std: 0.016419111704202703
* Epoch (48/100) 


49/100 * Epoch (train): 100%|██████████| 24/24 [00:55<00:00,  2.32s/it, _timer/_fps=89.239, _timer/batch_time=1.703, _timer/data_time=0.929, _timer/model_time=0.774, loss=6.297, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (49/100) loss: 6.362612932723089 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


49/100 * Epoch (valid): 100%|██████████| 24/24 [00:37<00:00,  1.57s/it, _timer/_fps=151.919, _timer/batch_time=1.001, _timer/data_time=0.054, _timer/model_time=0.947, loss=6.158, lr=1.000e-03, map10=0.229, momentum=0.900, ndcg20=0.144]


valid (49/100) loss: 6.189583620488248 | lr: 0.001 | map10: 0.2382363862943965 | map10/std: 0.027189937791198604 | momentum: 0.9 | ndcg20: 0.15358493724800892 | ndcg20/std: 0.016949029023290454
* Epoch (49/100) 


50/100 * Epoch (train): 100%|██████████| 24/24 [00:55<00:00,  2.31s/it, _timer/_fps=78.989, _timer/batch_time=1.924, _timer/data_time=1.209, _timer/model_time=0.715, loss=6.312, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (50/100) loss: 6.346968823868708 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


50/100 * Epoch (valid): 100%|██████████| 24/24 [00:46<00:00,  1.93s/it, _timer/_fps=133.759, _timer/batch_time=1.136, _timer/data_time=0.051, _timer/model_time=1.085, loss=6.168, lr=1.000e-03, map10=0.236, momentum=0.900, ndcg20=0.144]


valid (50/100) loss: 6.173709265601556 | lr: 0.001 | map10: 0.23719917707490604 | map10/std: 0.023094777693779652 | momentum: 0.9 | ndcg20: 0.15342786916044374 | ndcg20/std: 0.01702856308944005
* Epoch (50/100) 


51/100 * Epoch (train): 100%|██████████| 24/24 [01:11<00:00,  2.96s/it, _timer/_fps=73.395, _timer/batch_time=2.071, _timer/data_time=1.320, _timer/model_time=0.751, loss=6.324, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (51/100) loss: 6.326021058669943 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


51/100 * Epoch (valid): 100%|██████████| 24/24 [00:47<00:00,  1.99s/it, _timer/_fps=137.002, _timer/batch_time=1.109, _timer/data_time=0.035, _timer/model_time=1.074, loss=6.188, lr=1.000e-03, map10=0.231, momentum=0.900, ndcg20=0.145]


valid (51/100) loss: 6.148965594942206 | lr: 0.001 | map10: 0.23819328959019773 | map10/std: 0.027236450454308607 | momentum: 0.9 | ndcg20: 0.1554729727127694 | ndcg20/std: 0.016874802941349118
* Epoch (51/100) 


52/100 * Epoch (train): 100%|██████████| 24/24 [01:10<00:00,  2.92s/it, _timer/_fps=80.260, _timer/batch_time=1.894, _timer/data_time=1.223, _timer/model_time=0.671, loss=6.295, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (52/100) loss: 6.327545840534943 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


52/100 * Epoch (valid): 100%|██████████| 24/24 [00:44<00:00,  1.84s/it, _timer/_fps=145.104, _timer/batch_time=1.048, _timer/data_time=0.050, _timer/model_time=0.998, loss=6.142, lr=1.000e-03, map10=0.232, momentum=0.900, ndcg20=0.147]


valid (52/100) loss: 6.14289610401684 | lr: 0.001 | map10: 0.23808181645064952 | map10/std: 0.027471384619788652 | momentum: 0.9 | ndcg20: 0.15548159144572074 | ndcg20/std: 0.0175061221480792
* Epoch (52/100) 


53/100 * Epoch (train): 100%|██████████| 24/24 [00:54<00:00,  2.27s/it, _timer/_fps=100.255, _timer/batch_time=1.516, _timer/data_time=0.885, _timer/model_time=0.631, loss=6.244, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (53/100) loss: 6.3128465058787775 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


53/100 * Epoch (valid): 100%|██████████| 24/24 [00:40<00:00,  1.68s/it, _timer/_fps=147.732, _timer/batch_time=1.029, _timer/data_time=0.044, _timer/model_time=0.985, loss=6.092, lr=1.000e-03, map10=0.250, momentum=0.900, ndcg20=0.152]


valid (53/100) loss: 6.133731497202488 | lr: 0.001 | map10: 0.23719367870431862 | map10/std: 0.02642576108992254 | momentum: 0.9 | ndcg20: 0.15404126809922272 | ndcg20/std: 0.01712050588648872
* Epoch (53/100) 


54/100 * Epoch (train): 100%|██████████| 24/24 [00:54<00:00,  2.25s/it, _timer/_fps=87.068, _timer/batch_time=1.746, _timer/data_time=1.068, _timer/model_time=0.677, loss=6.264, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (54/100) loss: 6.300816230900241 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


54/100 * Epoch (valid): 100%|██████████| 24/24 [00:42<00:00,  1.76s/it, _timer/_fps=136.202, _timer/batch_time=1.116, _timer/data_time=0.040, _timer/model_time=1.076, loss=6.089, lr=1.000e-03, map10=0.240, momentum=0.900, ndcg20=0.147]


valid (54/100) loss: 6.115810972807423 | lr: 0.001 | map10: 0.23737514068354046 | map10/std: 0.023480422168979775 | momentum: 0.9 | ndcg20: 0.15402954491163723 | ndcg20/std: 0.017437280896189143
* Epoch (54/100) 


55/100 * Epoch (train): 100%|██████████| 24/24 [00:53<00:00,  2.25s/it, _timer/_fps=93.128, _timer/batch_time=1.632, _timer/data_time=0.913, _timer/model_time=0.720, loss=6.220, lr=1.000e-03, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00] 


train (55/100) loss: 6.2903915954741425 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


55/100 * Epoch (valid): 100%|██████████| 24/24 [00:50<00:00,  2.11s/it, _timer/_fps=157.895, _timer/batch_time=0.963, _timer/data_time=0.037, _timer/model_time=0.926, loss=6.047, lr=1.000e-03, map10=0.239, momentum=0.900, ndcg20=0.152]

valid (55/100) loss: 6.105169727628594 | lr: 0.001 | map10: 0.23468514505206356 | map10/std: 0.021876374812646388 | momentum: 0.9 | ndcg20: 0.15346711534538016 | ndcg20/std: 0.016831209036916248
* Epoch (55/100) 





In [None]:
test_runner = RecSysRunner(model=model)

In [None]:
test_dataset = MyDataset(ds=joined, num_items=n_items, phase='test',item2idx=item2idx)


inference_loader = DataLoader(test_dataset, 
                              batch_size=joined.shape[0]//100, 
                              collate_fn=collate_fn_train,)

preds = []

for prediction in tqdm(runner.predict_loader(loader=inference_loader)):
    preds.extend(prediction.detach().cpu().numpy().tolist())
    
print(len(preds))
assert len(preds) == joined.shape[0]

joined['preds_bert4rec'] = preds
joined['recs_bert4rec_10'] = joined['preds_bert4rec'].apply(lambda x: np.argsort(-np.array(x))[:10])
joined['recs_bert4rec_10'] = joined['recs_bert4rec_10'].apply(lambda x: [idx2item[t-1] for t in x])
joined['recs_bert4rec_5'] = joined['preds_bert4rec'].apply(lambda x: np.argsort(-np.array(x))[:5])
joined['recs_bert4rec_5'] = joined['recs_bert4rec_5'].apply(lambda x: [idx2item[t-1] for t in x])
joined.drop(['preds_bert4rec'],axis=1, inplace=True)
joined.head()

101it [00:31,  3.19it/s]


6040


Unnamed: 0,user_id,train_interactions,valid_interactions,test_interactions,recs_bert4rec,recs_bert4rec_10,recs_bert4rec_5
0,1,"[(3186, 2000-12-31 22:00:19, 6, 22), (1270, 20...","[(2791, 2000-12-31 22:36:28, 6, 22), (2321, 20...","[(2687, 2001-01-06 23:37:48, 5, 23), (745, 200...","[2396, 593, 527, 1265, 34, 1617, 2858, 318, 1,...","[2396, 593, 527, 1265, 34, 1617, 2858, 318, 1,...","[2396, 593, 527, 1265, 34]"
1,2,"[(1198, 2000-12-31 21:28:44, 6, 21), (1210, 20...","[(2028, 2000-12-31 21:56:13, 6, 21), (2571, 20...","[(1372, 2000-12-31 21:59:01, 6, 21), (1552, 20...","[2289, 1641, 1394, 2918, 34, 2028, 1079, 1197,...","[2289, 1641, 1394, 2918, 34, 2028, 1079, 1197,...","[2289, 1641, 1394, 2918, 34]"
2,3,"[(593, 2000-12-31 21:10:18, 6, 21), (2858, 200...","[(648, 2000-12-31 21:24:27, 6, 21), (2735, 200...","[(1270, 2000-12-31 21:30:31, 6, 21), (1079, 20...","[10, 316, 329, 380, 780, 1527, 349, 1552, 165,...","[10, 316, 329, 380, 780, 1527, 349, 1552, 165,...","[10, 316, 329, 380, 780]"
3,4,"[(1210, 2000-12-31 20:18:44, 6, 20), (1097, 20...","[(2947, 2000-12-31 20:23:50, 6, 20), (1214, 20...","[(1240, 2000-12-31 20:24:20, 6, 20), (2951, 20...","[589, 1240, 2571, 1214, 1200, 541, 110, 3703, ...","[589, 1240, 2571, 1214, 1200, 541, 110, 3703, ...","[589, 1240, 2571, 1214, 1200]"
4,5,"[(2717, 2000-12-31 05:37:52, 6, 5), (908, 2000...","[(2323, 2000-12-31 06:50:45, 6, 6), (272, 2000...","[(1715, 2000-12-31 06:58:11, 6, 6), (1653, 200...","[25, 337, 1358, 778, 300, 1885, 1682, 529, 117...","[25, 337, 1358, 778, 300, 1885, 1682, 529, 117...","[25, 337, 1358, 778, 300]"


In [None]:
evaluate_recommender(joined, model_preds='recs_bert4rec_10')

{'ndcg': 0.14581013099020618, 'recall': 0.0337442523686209}

In [None]:
evaluate_recommender(joined, model_preds='recs_bert4rec_5')

{'ndcg': 0.08740025855788641, 'recall': 0.01756498438783866}

### ML-20m

In [11]:
df = pd.read_csv('ML_20m.csv')
df = df.rename(columns={'userId': 'user_id', 'movieId': 'item_id'})
df['timestamp'] = pd.to_datetime(df['timestamp'],unit='s')
df['weekday'] = pd.to_datetime(df.timestamp).dt.weekday
df['hour'] = pd.to_datetime(df.timestamp).dt.hour
df.head()

Unnamed: 0,user_id,item_id,rating,timestamp,weekday,hour
0,1,2,3.5,2005-04-02 23:53:47,5,23
1,1,29,3.5,2005-04-02 23:31:16,5,23
2,1,32,3.5,2005-04-02 23:33:39,5,23
3,1,47,3.5,2005-04-02 23:32:07,5,23
4,1,50,3.5,2005-04-02 23:29:40,5,23


In [12]:
splitter = RandomSplit(test_fraction=0.2)
train_df, valid_df, test_df = splitter(df)

In [13]:
train_grouped = train_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id, 
                                                                 x.timestamp,
                                                                 x.weekday,
                                                                 x.hour), key=lambda x: x[1])]
).reset_index()
train_grouped.rename({0:'train_interactions'}, axis=1, inplace=True)

valid_grouped = valid_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id,
                                                         x.timestamp,
                                                         x.weekday,
                                                         x.hour), key=lambda x: x[1])]
).reset_index()
valid_grouped.rename({0:'valid_interactions'}, axis=1, inplace=True)

test_grouped = test_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id,
                                                         x.timestamp,
                                                         x.weekday,
                                                         x.hour), key=lambda x: x[1])]
).reset_index()
test_grouped.rename({0:'test_interactions'}, axis=1, inplace=True)


train_grouped.head()

Unnamed: 0,user_id,train_interactions
0,1,"[(924, 2004-09-10 03:06:38, 4, 3), (919, 2004-..."
1,2,"[(62, 2000-11-21 15:29:58, 1, 15), (469, 2000-..."
2,3,"[(589, 1999-12-11 07:25:08, 5, 7), (1188, 1999..."
3,4,"[(380, 1996-08-24 09:27:05, 5, 9), (165, 1996-..."
4,5,"[(17, 1996-12-25 15:15:35, 2, 15), (62, 1996-1..."


In [14]:
joined = train_grouped.merge(valid_grouped).merge(test_grouped)
joined.head()

Unnamed: 0,user_id,train_interactions,valid_interactions,test_interactions
0,1,"[(924, 2004-09-10 03:06:38, 4, 3), (919, 2004-...","[(2947, 2005-04-02 23:46:20, 5, 23), (8961, 20...","[(7164, 2005-04-02 23:52:03, 5, 23), (2021, 20..."
1,2,"[(62, 2000-11-21 15:29:58, 1, 15), (469, 2000-...","[(3926, 2000-11-21 15:34:49, 1, 15), (1973, 20...","[(1969, 2000-11-21 15:36:09, 1, 15), (1970, 20..."
2,3,"[(589, 1999-12-11 07:25:08, 5, 7), (1188, 1999...","[(3070, 1999-12-11 13:42:09, 5, 13), (2872, 19...","[(1356, 1999-12-14 12:51:10, 1, 12), (1603, 19..."
3,4,"[(380, 1996-08-24 09:27:05, 5, 9), (165, 1996-...","[(370, 1996-08-24 09:34:03, 5, 9), (594, 1996-...","[(596, 1996-08-24 09:37:04, 5, 9), (531, 1996-..."
4,5,"[(17, 1996-12-25 15:15:35, 2, 15), (62, 1996-1...","[(720, 1996-12-25 15:27:30, 2, 15), (350, 1996...","[(376, 1996-12-26 16:24:34, 3, 16), (1079, 199..."


In [15]:
our_items = set()
for idx, row in tqdm(joined.iterrows()):
    for el in row.train_interactions:
        our_items.add(el[0])
        
len(our_items)

35378it [00:02, 15430.55it/s]


15736

In [16]:
item2idx = {k: i for i, k in enumerate(our_items)}
idx2item = {i: k for k, i in item2idx.items()}

In [19]:
n_items = len(item2idx)

train = MyDataset(ds=joined,
                  num_items=n_items, 
                  item2idx=item2idx,
                  phase='train')

valid = MyDataset(ds=joined,
                  num_items=n_items,
                  item2idx=item2idx,
                  phase='valid')

print(len(train),len(valid))

35378 35378


In [22]:
loaders = {
        "train": DataLoader(train, batch_size=256, collate_fn=collate_fn_train),
        "valid": DataLoader(valid, batch_size=256, collate_fn=collate_fn_valid),
}

In [23]:

model = BERT4Rec(n_items=len(item2idx)+1, mask_ratio=0.2, hidden_size=32)

optimizer = optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = StepLR(optimizer, step_size=20, gamma=0.1)
engine = dl.DeviceEngine('cpu')
hparams = {
    "anneal_cap": 0.2,
    "total_anneal_steps": 6000,
}


callbacks = [
    dl.NDCGCallback("logits", "targets", [20]),
    dl.MAPCallback("logits", "targets", [10]),
    dl.OptimizerCallback("loss", accumulation_steps=1),
    dl.EarlyStoppingCallback(
        patience=7, loader_key="valid", metric_key="map10", minimize=False
    )
]


runner = RecSysRunner()
runner.train(
    model=model,
    optimizer=optimizer,
    engine=engine,
    hparams=hparams,
    scheduler=lr_scheduler,
    loaders=loaders,
    num_epochs=100,
    verbose=True,
    timeit=True,
    callbacks=callbacks,
)


1/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (1/100) loss: 8.199406338800133 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


1/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (1/100) loss: 7.682598683312328 | lr: 0.001 | map10: 0.14191976870354386 | map10/std: 0.01575361833242672 | momentum: 0.9 | ndcg20: 0.08572437047985326 | ndcg20/std: 0.008670231579484424
* Epoch (1/100) 


2/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (2/100) loss: 7.629647422075419 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


2/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (2/100) loss: 7.5258317482708 | lr: 0.001 | map10: 0.12778205406014462 | map10/std: 0.01550662268094716 | momentum: 0.9 | ndcg20: 0.08467163797810172 | ndcg20/std: 0.00945349549002737
* Epoch (2/100) 


3/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (3/100) loss: 7.492806892706639 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


3/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (3/100) loss: 7.416311937928761 | lr: 0.001 | map10: 0.14081029171482448 | map10/std: 0.014939018859476922 | momentum: 0.9 | ndcg20: 0.09622599805570307 | ndcg20/std: 0.009073696049901124
* Epoch (3/100) 


4/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (4/100) loss: 7.396398810562903 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


4/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (4/100) loss: 7.3113045745279015 | lr: 0.001 | map10: 0.14700832447484938 | map10/std: 0.01578819750247036 | momentum: 0.9 | ndcg20: 0.09979267183903338 | ndcg20/std: 0.01030950953672326
* Epoch (4/100) 


5/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (5/100) loss: 7.286322449147114 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


5/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (5/100) loss: 7.154952674138753 | lr: 0.001 | map10: 0.1380325132555848 | map10/std: 0.015224603869856326 | momentum: 0.9 | ndcg20: 0.09854976731947508 | ndcg20/std: 0.010322254561977752
* Epoch (5/100) 


6/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (6/100) loss: 7.143988294163124 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


6/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (6/100) loss: 7.030342947413057 | lr: 0.001 | map10: 0.14092700181791015 | map10/std: 0.013648841192299477 | momentum: 0.9 | ndcg20: 0.10049508191740757 | ndcg20/std: 0.009824306507228328
* Epoch (6/100) 


7/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (7/100) loss: 7.072385969490319 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


7/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (7/100) loss: 6.978189536880774 | lr: 0.001 | map10: 0.14885271578038198 | map10/std: 0.014480676765747196 | momentum: 0.9 | ndcg20: 0.10505831232467659 | ndcg20/std: 0.010592648819460902
* Epoch (7/100) 


8/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (8/100) loss: 7.030814175177433 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


8/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (8/100) loss: 6.936297204726437 | lr: 0.001 | map10: 0.15003726027614894 | map10/std: 0.014352503905423196 | momentum: 0.9 | ndcg20: 0.10761511244039429 | ndcg20/std: 0.010698093552953253
* Epoch (8/100) 


9/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (9/100) loss: 6.993344450651997 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


9/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (9/100) loss: 6.895454564461963 | lr: 0.001 | map10: 0.14907500313855981 | map10/std: 0.014143507157198055 | momentum: 0.9 | ndcg20: 0.10717131264338871 | ndcg20/std: 0.010302677365106294
* Epoch (9/100) 


10/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (10/100) loss: 6.955351309533968 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


10/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (10/100) loss: 6.834163804598518 | lr: 0.001 | map10: 0.160605169878617 | map10/std: 0.014098922022016952 | momentum: 0.9 | ndcg20: 0.11470564899916524 | ndcg20/std: 0.010510206979895591
* Epoch (10/100) 


11/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (11/100) loss: 6.891892326579693 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


11/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (11/100) loss: 6.753638923724165 | lr: 0.001 | map10: 0.16453948007844207 | map10/std: 0.013599582453106739 | momentum: 0.9 | ndcg20: 0.11939710371486473 | ndcg20/std: 0.010366036565939138
* Epoch (11/100) 


12/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (12/100) loss: 6.831540041979648 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


12/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (12/100) loss: 6.698723413802856 | lr: 0.001 | map10: 0.1684666386391069 | map10/std: 0.013577235137896033 | momentum: 0.9 | ndcg20: 0.12465125226814094 | ndcg20/std: 0.011138254141166128
* Epoch (12/100) 


13/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (13/100) loss: 6.791094205251303 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


13/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (13/100) loss: 6.655030999757024 | lr: 0.001 | map10: 0.166687383231849 | map10/std: 0.014881246749916135 | momentum: 0.9 | ndcg20: 0.12402755084217854 | ndcg20/std: 0.011313803733527961
* Epoch (13/100) 


14/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (14/100) loss: 6.75542696583459 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


14/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (14/100) loss: 6.615751414148322 | lr: 0.001 | map10: 0.1681047940703714 | map10/std: 0.015617341003762686 | momentum: 0.9 | ndcg20: 0.12666080993679116 | ndcg20/std: 0.01200231678760594
* Epoch (14/100) 


15/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (15/100) loss: 6.7232299511182845 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


15/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (15/100) loss: 6.5792181730823 | lr: 0.001 | map10: 0.17287743899273536 | map10/std: 0.015737286388133837 | momentum: 0.9 | ndcg20: 0.12997857894900292 | ndcg20/std: 0.011999956417494817
* Epoch (15/100) 


16/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (16/100) loss: 6.692904755207785 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


16/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (16/100) loss: 6.547168281431606 | lr: 0.001 | map10: 0.1775350622519826 | map10/std: 0.018136461321336003 | momentum: 0.9 | ndcg20: 0.1336011232239509 | ndcg20/std: 0.012088150837909322
* Epoch (16/100) 


17/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (17/100) loss: 6.676468235809257 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


17/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (17/100) loss: 6.528582636042659 | lr: 0.001 | map10: 0.17941293700206098 | map10/std: 0.016107599464070366 | momentum: 0.9 | ndcg20: 0.1354509128645241 | ndcg20/std: 0.01204190750504897
* Epoch (17/100) 


18/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (18/100) loss: 6.657720440824611 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


18/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (18/100) loss: 6.503142389355389 | lr: 0.001 | map10: 0.18140379458866346 | map10/std: 0.016814676046959444 | momentum: 0.9 | ndcg20: 0.13699026387025853 | ndcg20/std: 0.012632142353130935
* Epoch (18/100) 


19/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (19/100) loss: 6.638687265418654 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


19/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (19/100) loss: 6.481275431289946 | lr: 0.001 | map10: 0.18372371870630763 | map10/std: 0.01608728854509195 | momentum: 0.9 | ndcg20: 0.13908868034484131 | ndcg20/std: 0.012203032424067763
* Epoch (19/100) 


20/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (20/100) loss: 6.615991835071487 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


20/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (20/100) loss: 6.452781643064988 | lr: 0.001 | map10: 0.1880985784674859 | map10/std: 0.016508094774266267 | momentum: 0.9 | ndcg20: 0.14212646381568858 | ndcg20/std: 0.012273467074750345
* Epoch (20/100) 


21/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (21/100) loss: 6.597632700056247 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


21/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (21/100) loss: 6.428646972635116 | lr: 0.001 | map10: 0.18940146437662542 | map10/std: 0.017023082441166456 | momentum: 0.9 | ndcg20: 0.14356374026416382 | ndcg20/std: 0.012368971053535087
* Epoch (21/100) 


22/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (22/100) loss: 6.5814957380685675 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


22/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (22/100) loss: 6.411451691964947 | lr: 0.001 | map10: 0.1908036282373341 | map10/std: 0.01772879936404066 | momentum: 0.9 | ndcg20: 0.1448422491633 | ndcg20/std: 0.012678857448266757
* Epoch (22/100) 


23/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (23/100) loss: 6.561501935618883 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


23/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (23/100) loss: 6.384669622279689 | lr: 0.001 | map10: 0.19586448312526827 | map10/std: 0.016927001528392552 | momentum: 0.9 | ndcg20: 0.1476551973492545 | ndcg20/std: 0.012264038193015999
* Epoch (23/100) 


24/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (24/100) loss: 6.5416288526003346 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


24/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (24/100) loss: 6.357510402625966 | lr: 0.001 | map10: 0.19877740344596728 | map10/std: 0.01614177022426309 | momentum: 0.9 | ndcg20: 0.15046020362148577 | ndcg20/std: 0.012611460918149554
* Epoch (24/100) 


25/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (25/100) loss: 6.523258812546533 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


25/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (25/100) loss: 6.337076922135312 | lr: 0.001 | map10: 0.20325261678095347 | map10/std: 0.018076633939596094 | momentum: 0.9 | ndcg20: 0.1528915453455178 | ndcg20/std: 0.01260779189660176
* Epoch (25/100) 


26/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (26/100) loss: 6.504245241215732 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


26/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (26/100) loss: 6.314009706662278 | lr: 0.001 | map10: 0.20594151436598537 | map10/std: 0.01676177783104779 | momentum: 0.9 | ndcg20: 0.15512762076390785 | ndcg20/std: 0.012872322359888723
* Epoch (26/100) 


27/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (27/100) loss: 6.478996366432598 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


27/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (27/100) loss: 6.28926095682714 | lr: 0.001 | map10: 0.2090031587994654 | map10/std: 0.01722292488241897 | momentum: 0.9 | ndcg20: 0.15627163965936314 | ndcg20/std: 0.013478921627246402
* Epoch (27/100) 


28/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (28/100) loss: 6.465106754495993 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


28/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (28/100) loss: 6.266905264369479 | lr: 0.001 | map10: 0.20895231625228314 | map10/std: 0.0179182243636931 | momentum: 0.9 | ndcg20: 0.15700302461774107 | ndcg20/std: 0.0135288863224916
* Epoch (28/100) 


29/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (29/100) loss: 6.447295569343789 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


29/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (29/100) loss: 6.247949485114596 | lr: 0.001 | map10: 0.20946994214951806 | map10/std: 0.018253285352907926 | momentum: 0.9 | ndcg20: 0.157781800765356 | ndcg20/std: 0.014001391647530621
* Epoch (29/100) 


30/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (30/100) loss: 6.432610278119066 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


30/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (30/100) loss: 6.232322297065791 | lr: 0.001 | map10: 0.2108993624116467 | map10/std: 0.01738789685083246 | momentum: 0.9 | ndcg20: 0.15765072482923054 | ndcg20/std: 0.013430261575696975
* Epoch (30/100) 


31/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (31/100) loss: 6.420808905628944 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


31/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (31/100) loss: 6.218308254218004 | lr: 0.001 | map10: 0.21340121298420187 | map10/std: 0.01785668013377716 | momentum: 0.9 | ndcg20: 0.1596032227000169 | ndcg20/std: 0.013832265134204944
* Epoch (31/100) 


32/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (32/100) loss: 6.41094511528 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


32/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (32/100) loss: 6.20633843089452 | lr: 0.001 | map10: 0.2164922488592298 | map10/std: 0.01663919736055531 | momentum: 0.9 | ndcg20: 0.1618615909311553 | ndcg20/std: 0.013702077920613419
* Epoch (32/100) 


33/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (33/100) loss: 6.4007411611795195 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


33/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (33/100) loss: 6.192901870085542 | lr: 0.001 | map10: 0.21520549638651823 | map10/std: 0.017208284076705032 | momentum: 0.9 | ndcg20: 0.16137451349179358 | ndcg20/std: 0.013344232532332313
* Epoch (33/100) 


34/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (34/100) loss: 6.388138827666211 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


34/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (34/100) loss: 6.18030996862844 | lr: 0.001 | map10: 0.21873685838719398 | map10/std: 0.017857228984457362 | momentum: 0.9 | ndcg20: 0.1630810588022052 | ndcg20/std: 0.013801447966402668
* Epoch (34/100) 


35/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (35/100) loss: 6.382894980751765 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


35/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (35/100) loss: 6.172174424460372 | lr: 0.001 | map10: 0.21647662953265928 | map10/std: 0.016956689596017184 | momentum: 0.9 | ndcg20: 0.1625237636955956 | ndcg20/std: 0.013460924252847768
* Epoch (35/100) 


36/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (36/100) loss: 6.376808235620901 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


36/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (36/100) loss: 6.164276891932979 | lr: 0.001 | map10: 0.21901204897682264 | map10/std: 0.017738865693879478 | momentum: 0.9 | ndcg20: 0.16440297030908602 | ndcg20/std: 0.013478062286875165
* Epoch (36/100) 


37/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (37/100) loss: 6.364740827407206 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


37/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (37/100) loss: 6.150240577805795 | lr: 0.001 | map10: 0.21945724920835893 | map10/std: 0.01807019153036206 | momentum: 0.9 | ndcg20: 0.16512690140910027 | ndcg20/std: 0.013980289559957792
* Epoch (37/100) 


38/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

train (38/100) loss: 6.360964708017167 | lr: 0.001 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


38/100 * Epoch (valid):   0%|          | 0/139 [00:00<?, ?it/s]

valid (38/100) loss: 6.14646630435702 | lr: 0.001 | map10: 0.22083035571207613 | map10/std: 0.0183051990267605 | momentum: 0.9 | ndcg20: 0.16575575793090702 | ndcg20/std: 0.013995228678057867
* Epoch (38/100) 


39/100 * Epoch (train):   0%|          | 0/139 [00:00<?, ?it/s]

Keyboard Interrupt


KeyboardInterrupt: ignored

In [43]:
test_dataset = MyDataset(ds=joined, num_items=n_items, phase='test',item2idx=item2idx)


inference_loader = DataLoader(test_dataset, 
                              batch_size=joined.shape[0]//100, 
                              collate_fn=collate_fn_train,)

preds = []

for prediction in tqdm(runner.predict_loader(loader=inference_loader, engine=dl.DeviceEngine('cpu'))):
    preds.extend(prediction.detach().cpu().numpy().tolist())
    
print(len(preds))
assert len(preds) == joined.shape[0]

joined['preds_bert4rec'] = preds
joined['recs_bert4rec_10'] = joined['preds_bert4rec'].apply(lambda x: np.argsort(-np.array(x))[:10])
joined['recs_bert4rec_10'] = joined['recs_bert4rec_10'].apply(lambda x: [idx2item[t-1] for t in x])
joined['recs_bert4rec_5'] = joined['preds_bert4rec'].apply(lambda x: np.argsort(-np.array(x))[:5])
joined['recs_bert4rec_5'] = joined['recs_bert4rec_5'].apply(lambda x: [idx2item[t-1] for t in x])
joined.drop(['preds_bert4rec'],axis=1, inplace=True)
joined.head()

101it [00:54,  1.86it/s]


35378


Unnamed: 0,user_id,train_interactions,valid_interactions,test_interactions,recs_bert4rec_10,recs_bert4rec_5
0,1,"[(924, 2004-09-10 03:06:38, 4, 3), (919, 2004-...","[(2947, 2005-04-02 23:46:20, 5, 23), (8961, 20...","[(7164, 2005-04-02 23:52:03, 5, 23), (2021, 20...","[1197, 1961, 1097, 1270, 1265, 1394, 1220, 608...","[1197, 1961, 1097, 1270, 1265]"
1,2,"[(62, 2000-11-21 15:29:58, 1, 15), (469, 2000-...","[(3926, 2000-11-21 15:34:49, 1, 15), (1973, 20...","[(1969, 2000-11-21 15:36:09, 1, 15), (1970, 20...","[2369, 2428, 2719, 2694, 2606, 3826, 2605, 299...","[2369, 2428, 2719, 2694, 2606]"
2,3,"[(589, 1999-12-11 07:25:08, 5, 7), (1188, 1999...","[(3070, 1999-12-11 13:42:09, 5, 13), (2872, 19...","[(1356, 1999-12-14 12:51:10, 1, 12), (1603, 19...","[1129, 2640, 1375, 3702, 2001, 2406, 2021, 264...","[1129, 2640, 1375, 3702, 2001]"
3,4,"[(380, 1996-08-24 09:27:05, 5, 9), (165, 1996-...","[(370, 1996-08-24 09:34:03, 5, 9), (594, 1996-...","[(596, 1996-08-24 09:37:04, 5, 9), (531, 1996-...","[95, 553, 442, 474, 2, 141, 62, 235, 11, 17]","[95, 553, 442, 474, 2]"
4,5,"[(17, 1996-12-25 15:15:35, 2, 15), (62, 1996-1...","[(720, 1996-12-25 15:27:30, 2, 15), (350, 1996...","[(376, 1996-12-26 16:24:34, 3, 16), (1079, 199...","[474, 348, 515, 508, 555, 357, 265, 272, 377, ...","[474, 348, 515, 508, 555]"


In [44]:
evaluate_recommender(joined, model_preds='recs_bert4rec_10')

{'ndcg': 0.16356511797620038, 'recall': 0.04399354663059405}

In [45]:
evaluate_recommender(joined, model_preds='recs_bert4rec_5')

{'ndcg': 0.09440482277692609, 'recall': 0.021120154268991623}