In [1]:
import torch.nn as nn
from torchmetrics.functional import retrieval_normalized_dcg
from sklearn.metrics import ndcg_score
import numpy as np
import torch
import requests
from dataclasses import dataclass
from collections import defaultdict
import random
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

import sys
sys.path.append('BERT4Rec-VAE-Pytorch')

from models.bert import BERTModel
from models.bert_modules.bert import BERT

import warnings
warnings.filterwarnings("ignore")

### Service functions

In [23]:
def equalize_seq(seq, max_length, end_code):
  
  if len(seq) >= max_length:
    return seq[:max_length]
  else:
    return (max_length - len(seq)) * [end_code] + seq


def recall_k(y_pred, y_true, k=10):
    
    top_k = torch.topk(y_pred, k).indices.tolist()
    result = [el[0] in el[1] for el in zip(y_true.tolist(), top_k)]
    
    return np.mean(result)


def ndcg_k(y_pred, y_true, k=10):
    
    top_k = torch.topk(y_pred, k).indices.tolist()
    results = []
    for true_label, preds in zip(y_true.tolist(), top_k):
        tl = [[el == true_label for el in preds]]
        preds = [[1 for _ in preds]]
        results.append(ndcg_score(tl, preds))
    
    return np.mean(results)


def ndcg_k_seq(y_pred_seq, y_true_seq, k=10):
    recs = []
    for y_pred, y_true in zip(y_pred_seq, y_true_seq):
        recs.append(ndcg_k(y_pred, y_true, k))
    return np.mean(recs)


def recall_k_seq(y_pred_seq, y_true_seq, k=10):
    recs = []
    for y_pred, y_true in zip(y_pred_seq, y_true_seq):
        recs.append(recall_k(y_pred, y_true, k))
    return np.mean(recs)




def recall_k_seq_last(y_pred_seq, y_true_seq, k=10):
    recs = []
    for y_pred, y_true in zip(y_pred_seq, y_true_seq):
        recs.append(recall_k(y_pred, y_true, k))
    return recs[-1]


def ndcg_k_seq_last(y_pred_seq, y_true_seq, k=10):
    recs = []
    for y_pred, y_true in zip(y_pred_seq, y_true_seq):
        recs.append(ndcg_k(y_pred, y_true, k))
    return recs[-1]

In [3]:
device = "cuda:0"

### Load data

In [4]:
data = (pd.read_csv('ml-1m/ratings.dat', 
                    sep='::',
                    header=None,
                    index_col=0,
                    names=['user_id', 'movie_id', 'rating', 'ts'], 
                    engine='python')
          .reset_index(drop=False)
       )

sequences = data.sort_values(by=['user_id', 'ts']).groupby('user_id')['movie_id'].agg(lambda x: list(x)).to_dict()
sequences = {u: s for u, s in sequences.items() if len(s) > 0}

mask_code = 0
max_length = 100
end_code = data['movie_id'].max() + 1

### Train / Valid / Test Split

In [5]:
train = defaultdict(list)
val = {}
test = {}

for user, sequence in sequences.items():

    train[user] = equalize_seq(sequence[:-1], max_length=max_length, end_code=end_code)

    if np.random.choice([0, 1]):
        val[user] = equalize_seq(sequence, max_length=max_length, end_code=end_code)
      
    else:
        test[user] = equalize_seq(sequence, max_length=max_length, end_code=end_code)

test_indexes = np.array(test.keys())
val_indexes = np.array(val.keys())

### Define Model

In [6]:
@dataclass
class BertConf:
    bert_max_len: int
    num_items: int
    bert_num_blocks: int
    bert_num_heads: int
    bert_hidden_units: int
    bert_dropout: float = 0.1
    model_init_seed: int = 42

conf = BertConf(bert_max_len=max_length, 
                num_items=end_code + 1, 
                bert_num_blocks=2, 
                bert_num_heads=2, 
                bert_hidden_units=100)

model = BERTModel(conf).to(device)

In [7]:
import sys
sys.path.append('BERT4Rec-VAE-Pytorch')

from models.bert import BERTModel
from models.bert_modules.bert import BERT

from torch import nn as nn

from models.bert_modules.embedding import BERTEmbedding
from models.bert_modules.transformer import TransformerBlock
from utils import fix_random_seed_as
from models.base import BaseModel

import torch.nn as nn


class SasRec(nn.Module):
    def __init__(self, args):
        super().__init__()

        fix_random_seed_as(args.model_init_seed)

        max_len = args.bert_max_len
        num_items = args.num_items
        n_layers = args.bert_num_blocks
        heads = args.bert_num_heads
        vocab_size = num_items + 2
        hidden = args.bert_hidden_units
        self.hidden = hidden
        dropout = args.bert_dropout

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=self.hidden, max_len=max_len, dropout=dropout)

        # multi-layers transformer blocks, deep network
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x):
      
        # mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        mask = torch.ones(x.size()[1], x.size()[1])
        mask = torch.tril(mask).unsqueeze(0).unsqueeze(1).repeat(x.size()[0], 1, 1, 1).to(device)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x.to(device))

        # running over multiple transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)

        return x

    def init_weights(self):
        pass


class SasRecModel(BaseModel):
    def __init__(self, args):
        super().__init__(args)
        self.sasrec = SasRec(args)
        self.out = nn.Linear(self.sasrec.hidden, args.num_items + 1)

    @classmethod
    def code(cls):
        return 'sasrec'

    def forward(self, x):
        x = self.sasrec(x)
        return self.out(x)

In [8]:
@dataclass
class SasRecConf:
    bert_max_len: int
    num_items: int
    bert_num_blocks: int
    bert_num_heads: int
    bert_hidden_units: int
    bert_dropout: float = 0.1
    model_init_seed: int = 42

conf = SasRecConf(bert_max_len=max_length, 
                  num_items=end_code + 1, 
                  bert_num_blocks=2, 
                  bert_num_heads=2, 
                  bert_hidden_units=100)

model = SasRecModel(conf).to(device)

### Define Dataloader

In [9]:
import torch
from torch.utils.data import Dataset, DataLoader

class TrainShiftedDataset(Dataset):
    def __init__(self, data, max_length, end_code=-1):
        self.data = data
        self.max_length = max_length
        self.end_code = end_code

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):

        sample = self.data[index]

        # Select random end of sequence
        idx = np.random.randint(len(sample))

        # Equalize sequence
        seq = equalize_seq(sample[:idx], self.max_length, self.end_code)
        seq = torch.tensor(seq)

        # Create X and y
        X = seq[:-1]
        y = seq[1:]
#         print(X.shape)
#         print(y.shape)

        return X, y
    

class ValidShiftedDataset(Dataset):
    def __init__(self, data, max_length, end_code=-1):
        self.data = data
        self.max_length = max_length
        self.end_code = end_code

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):

        sample = self.data[index]
        seq = equalize_seq(sample, self.max_length, self.end_code)
        seq = torch.tensor(seq)

        X = seq[:-1]
        y = seq[1:]

        return X, y

In [12]:
from tqdm import tqdm_notebook

### Train Model

In [14]:
dataset = TrainShiftedDataset(list(train.values()), max_length + 1, end_code)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)

val_dataset = ValidShiftedDataset(list(val.values()), max_length + 1, end_code)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=True)

# Using an Adam Optimizer with lr = 0.1
loss_function = torch.nn.CrossEntropyLoss(ignore_index=end_code, reduction='mean')
optimizer = torch.optim.Adam(model.parameters())

epochs = 5000
outputs = []
losses = []

counter = 0
recall_list = []
for epoch in range(epochs):
    for X, y in train_loader:

        optimizer.zero_grad()
        
        X = X.to(device)
        y = y.to(device)

        loss = loss_function(model.forward(X).view(-1, 3955), y.view(-1))

        losses.append(loss.item())
        loss.backward()
         
        optimizer.step()
        
        counter += 1
        if counter % 500 == 0:
            val_losses = []
            recall_batch = []
            ndcg_batch = []
            for X_val, y_val in tqdm_notebook(val_loader):
                
                X_val = X_val.to(device)
                y_val = y_val.to(device)
                
                y_pred = model.forward(X_val.long())
#                 print(y_pred.shape)
#                 print(y_val.shape)
                
                val_loss = loss_function(y_pred.view(-1, 3955), y_val.long().view(-1))
                val_losses.append(val_loss.item())
                
                recall_batch.append(recall_k_seq(y_pred, y_val, k=10))
#                 print("done")
#                 print(y_pred.shape)
#                 print(y_val.shape)
                ndcg_batch.append(ndcg_k_seq(y_pred, y_val))

#             print(counter, np.mean(recall_batch).round(3), np.round(loss.item(), 4), np.mean(val_losses).round(4))
            print("Epoch: ", counter)
            print("Recall@10: ", np.mean(recall_batch).round(3))
            print("NDCG: ", np.mean(ndcg_batch).round(3))
            print("Train Loss: ", np.round(loss.item(), 4))
            print("Val Loss: ", np.round(val_losses, 4))

  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  500
Recall@10:  0.169
NDCG:  0.077
Train Loss:  5.4897
Val Loss:  [5.6823 5.6948 5.7649 5.7182 5.7164 5.788  5.6832 5.6989 5.7924 5.724
 5.7121 5.6932]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  1000
Recall@10:  0.189
NDCG:  0.086
Train Loss:  5.3343
Val Loss:  [5.4845 5.4495 5.5367 5.587  5.5898 5.5735 5.4799 5.5981 5.5637 5.4716
 5.5179 5.6268]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  1500
Recall@10:  0.203
NDCG:  0.092
Train Loss:  5.1366
Val Loss:  [5.512  5.4752 5.4091 5.4089 5.3941 5.4028 5.4111 5.4688 5.4429 5.3817
 5.4344 5.3489]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  2000
Recall@10:  0.212
NDCG:  0.096
Train Loss:  5.0152
Val Loss:  [5.2607 5.3569 5.3261 5.4046 5.4737 5.3316 5.3685 5.3659 5.4077 5.3047
 5.2601 5.3829]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  2500
Recall@10:  0.22
NDCG:  0.1
Train Loss:  5.0842
Val Loss:  [5.3008 5.3149 5.2988 5.2708 5.2848 5.3755 5.2716 5.2223 5.3279 5.1836
 5.3938 5.2579]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  3000
Recall@10:  0.226
NDCG:  0.103
Train Loss:  4.9496
Val Loss:  [5.2408 5.2459 5.2908 5.2852 5.1338 5.2913 5.2179 5.3101 5.202  5.3004
 5.1814 5.2224]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  3500
Recall@10:  0.232
NDCG:  0.105
Train Loss:  4.888
Val Loss:  [5.3201 5.1869 5.2258 5.2111 5.2262 5.1649 5.1241 5.1376 5.204  5.1697
 5.2864 5.1586]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  4000
Recall@10:  0.237
NDCG:  0.108
Train Loss:  4.7493
Val Loss:  [5.1213 5.2627 5.1333 5.1804 5.0734 5.1795 5.1596 5.1817 5.1267 5.1925
 5.2073 5.189 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  4500
Recall@10:  0.241
NDCG:  0.109
Train Loss:  4.8426
Val Loss:  [5.1544 5.1796 5.1183 5.179  5.1604 5.1799 5.1906 5.1814 5.1438 5.0755
 5.1286 4.9739]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  5000
Recall@10:  0.243
NDCG:  0.111
Train Loss:  4.6569
Val Loss:  [5.1409 5.1137 5.158  5.0675 5.0806 5.0617 5.1467 5.1542 5.0832 5.1057
 5.1465 5.1131]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  5500
Recall@10:  0.248
NDCG:  0.112
Train Loss:  4.6456
Val Loss:  [5.0725 5.0854 5.0128 5.0967 5.1632 5.086  5.0786 5.0895 5.1276 5.1572
 5.0971 5.0124]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  6000
Recall@10:  0.25
NDCG:  0.114
Train Loss:  4.7972
Val Loss:  [5.1041 5.0858 5.0036 5.0781 5.0364 5.0265 5.005  5.1779 5.0946 4.9863
 5.1776 5.0589]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  6500
Recall@10:  0.254
NDCG:  0.115
Train Loss:  4.6798
Val Loss:  [5.0778 4.9811 5.0987 5.0022 5.0354 5.0832 5.0325 5.0613 5.1196 5.0767
 5.0219 4.9223]
Epoch:  7000
Recall@10:  0.256
NDCG:  0.116
Train Loss:  4.6516
Val Loss:  [4.9877 5.0043 5.0512 5.0795 5.0276 5.0339 5.025  5.0585 5.0433 5.0033
 5.0449 4.9358]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  7500
Recall@10:  0.258
NDCG:  0.117
Train Loss:  4.6424
Val Loss:  [5.0076 5.0696 5.0139 4.8889 5.0829 5.1198 5.0863 4.9648 4.9296 5.0021
 4.9413 5.1374]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  8000
Recall@10:  0.26
NDCG:  0.118
Train Loss:  4.5278
Val Loss:  [4.9772 5.0092 4.9122 4.9753 5.0041 5.0355 4.9444 5.0606 5.0003 5.034
 4.9797 5.0428]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  8500
Recall@10:  0.263
NDCG:  0.119
Train Loss:  4.5764
Val Loss:  [5.038  4.9812 4.9731 4.9635 5.0284 5.     5.0036 4.9697 4.9238 4.9574
 5.0102 4.8958]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  9000
Recall@10:  0.265
NDCG:  0.12
Train Loss:  4.5412
Val Loss:  [4.9232 5.0342 4.9433 4.9493 4.9867 4.9932 4.8834 5.0046 4.9755 4.9659
 4.9821 5.016 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  9500
Recall@10:  0.267
NDCG:  0.121
Train Loss:  4.526
Val Loss:  [4.9624 4.9493 5.0116 5.021  4.9804 4.9484 4.9151 4.9992 4.9517 4.9282
 4.9157 4.9031]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  10000
Recall@10:  0.267
NDCG:  0.121
Train Loss:  4.6133
Val Loss:  [4.9702 4.9902 4.9069 4.9655 4.8899 4.9152 4.9775 4.9818 4.9138 4.9794
 4.946  5.001 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  10500
Recall@10:  0.269
NDCG:  0.122
Train Loss:  4.5834
Val Loss:  [4.971  4.9943 4.9499 4.8538 4.852  4.956  4.9349 4.9964 4.9312 4.9485
 4.9397 4.9528]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  11000
Recall@10:  0.271
NDCG:  0.123
Train Loss:  4.5052
Val Loss:  [4.9278 4.8876 5.0306 4.9316 4.9275 4.9166 4.9151 4.9046 4.9164 5.0188
 4.9069 4.8825]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  11500
Recall@10:  0.272
NDCG:  0.124
Train Loss:  4.4948
Val Loss:  [4.9496 4.8921 4.9024 4.9637 4.9109 4.8556 4.8885 5.0195 4.89   4.9092
 5.013  4.8539]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  12000
Recall@10:  0.274
NDCG:  0.124
Train Loss:  4.4464
Val Loss:  [5.0429 4.7943 4.9051 4.9489 4.9199 4.8395 4.9377 4.9103 4.9325 4.9352
 4.8928 4.8845]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  12500
Recall@10:  0.274
NDCG:  0.124
Train Loss:  4.3509
Val Loss:  [4.9344 4.7507 4.8631 4.8895 4.905  4.9606 4.9065 5.0008 4.8944 4.8667
 5.0386 4.9192]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  13000
Recall@10:  0.275
NDCG:  0.125
Train Loss:  4.4072
Val Loss:  [4.9413 4.9111 4.9335 4.8235 4.9233 4.9037 4.8661 4.9068 4.8925 4.9584
 4.8307 4.8892]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  13500
Recall@10:  0.277
NDCG:  0.126
Train Loss:  4.4485
Val Loss:  [4.9063 4.8802 4.8807 4.897  4.8784 4.9296 4.9553 4.8858 4.8444 4.9335
 4.8807 4.8559]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  14000
Recall@10:  0.277
NDCG:  0.126
Train Loss:  4.388
Val Loss:  [4.9066 4.8936 4.831  4.8931 4.8709 4.8714 4.8593 4.9197 4.9235 4.8949
 4.8246 4.9474]
Epoch:  14500
Recall@10:  0.279
NDCG:  0.127
Train Loss:  4.3016
Val Loss:  [4.8866 4.9787 4.8553 4.8679 4.9499 4.8383 4.8104 4.922  4.8352 4.9496
 4.8792 4.74  ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  15000
Recall@10:  0.278
NDCG:  0.126
Train Loss:  4.321
Val Loss:  [4.8904 4.886  4.7859 4.9191 4.8632 4.8724 4.8426 4.8602 4.8368 4.9444
 4.9217 4.9416]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  15500
Recall@10:  0.28
NDCG:  0.127
Train Loss:  4.2738
Val Loss:  [4.9345 4.8818 4.8846 4.8489 4.9127 4.8553 4.8769 4.7927 4.8664 4.9628
 4.7952 4.858 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  16000
Recall@10:  0.281
NDCG:  0.128
Train Loss:  4.2535
Val Loss:  [4.86   4.8725 4.8781 4.8236 4.8594 4.8316 4.9235 4.8613 4.8485 4.8341
 4.8398 4.9236]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  16500
Recall@10:  0.281
NDCG:  0.128
Train Loss:  4.2517
Val Loss:  [4.8318 4.8894 4.8792 4.8015 4.7574 4.8527 4.8609 4.8449 4.9351 4.8414
 4.887  4.94  ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  17000
Recall@10:  0.283
NDCG:  0.129
Train Loss:  4.2327
Val Loss:  [4.7876 4.7998 4.8863 4.8655 4.8721 4.9491 4.8453 4.8466 4.8574 4.7731
 4.8974 4.8543]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  17500
Recall@10:  0.283
NDCG:  0.129
Train Loss:  4.2858
Val Loss:  [4.8791 4.8768 4.8385 4.8589 4.9017 4.8053 4.8118 4.9208 4.8185 4.7827
 4.9095 4.7514]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  18000
Recall@10:  0.284
NDCG:  0.129
Train Loss:  4.2859
Val Loss:  [4.7305 4.7717 4.8436 4.9064 4.9035 4.8401 4.9084 4.8273 4.8481 4.6936
 4.8732 4.9073]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  18500
Recall@10:  0.283
NDCG:  0.129
Train Loss:  4.4265
Val Loss:  [4.8031 4.8571 4.7847 4.7683 4.9014 4.8198 4.8984 4.873  4.7875 4.8515
 4.8655 4.9314]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  19000
Recall@10:  0.284
NDCG:  0.129
Train Loss:  4.3436
Val Loss:  [4.8325 4.8203 4.8389 4.8846 4.8205 4.9173 4.8483 4.7783 4.8763 4.7767
 4.8322 4.8111]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  19500
Recall@10:  0.286
NDCG:  0.13
Train Loss:  4.31
Val Loss:  [4.8719 4.7873 4.8933 4.8181 4.7779 4.8774 4.7624 4.8208 4.869  4.8324
 4.7788 4.9339]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  20000
Recall@10:  0.285
NDCG:  0.13
Train Loss:  4.3256
Val Loss:  [4.7995 4.8332 4.6921 4.837  4.9101 4.859  4.7592 4.8374 4.8336 4.8366
 4.8741 4.8573]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  20500
Recall@10:  0.287
NDCG:  0.13
Train Loss:  4.391
Val Loss:  [4.8169 4.8061 4.7751 4.9537 4.7502 4.9012 4.776  4.7824 4.959  4.7243
 4.8159 4.8673]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  21000
Recall@10:  0.288
NDCG:  0.131
Train Loss:  4.39
Val Loss:  [4.8137 4.8247 4.8491 4.8832 4.7617 4.7636 4.7956 4.8438 4.877  4.7724
 4.7629 4.8371]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  21500
Recall@10:  0.289
NDCG:  0.132
Train Loss:  4.3382
Val Loss:  [4.7664 4.8512 4.8015 4.8794 4.8786 4.9091 4.7186 4.7386 4.7717 4.8543
 4.8177 4.8354]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  22000
Recall@10:  0.289
NDCG:  0.131
Train Loss:  4.3176
Val Loss:  [4.8182 4.8203 4.8165 4.7681 4.8175 4.8276 4.827  4.8014 4.8368 4.7835
 4.8171 4.842 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  22500
Recall@10:  0.289
NDCG:  0.131
Train Loss:  4.2458
Val Loss:  [4.8582 4.8503 4.7592 4.7312 4.772  4.8517 4.8274 4.8528 4.8469 4.8224
 4.84   4.7689]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  23000
Recall@10:  0.288
NDCG:  0.131
Train Loss:  4.2502
Val Loss:  [4.8112 4.7571 4.8057 4.8599 4.8361 4.7924 4.7757 4.8938 4.8297 4.7853
 4.7801 4.8537]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  23500
Recall@10:  0.289
NDCG:  0.131
Train Loss:  4.2272
Val Loss:  [4.7994 4.7416 4.7556 4.8153 4.8759 4.9089 4.7854 4.7223 4.8683 4.7567
 4.8491 4.8918]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  24000
Recall@10:  0.291
NDCG:  0.132
Train Loss:  4.2835
Val Loss:  [4.9227 4.747  4.7492 4.6653 4.7395 4.7988 4.8489 4.7136 4.8365 4.8257
 4.9967 4.7702]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  24500
Recall@10:  0.289
NDCG:  0.131
Train Loss:  4.241
Val Loss:  [4.8016 4.7931 4.7754 4.7785 4.8201 4.8167 4.8338 4.7544 4.7752 4.8747
 4.7662 4.8869]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  25000
Recall@10:  0.291
NDCG:  0.132
Train Loss:  4.1559
Val Loss:  [4.8694 4.8016 4.758  4.8332 4.7804 4.8051 4.8023 4.7832 4.8257 4.792
 4.8159 4.6728]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  25500
Recall@10:  0.292
NDCG:  0.133
Train Loss:  4.3413
Val Loss:  [4.7639 4.7597 4.8691 4.7813 4.8042 4.8205 4.8227 4.7392 4.8272 4.7623
 4.7754 4.8056]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  26000
Recall@10:  0.292
NDCG:  0.132
Train Loss:  4.2318
Val Loss:  [4.8015 4.8136 4.7655 4.8403 4.8701 4.8061 4.7468 4.7868 4.7421 4.7805
 4.7942 4.7507]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  26500
Recall@10:  0.293
NDCG:  0.133
Train Loss:  4.1871
Val Loss:  [4.824  4.7214 4.7366 4.8336 4.7877 4.7684 4.8996 4.7968 4.8062 4.7616
 4.7776 4.7226]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  27000
Recall@10:  0.294
NDCG:  0.133
Train Loss:  4.2043
Val Loss:  [4.8478 4.7381 4.7605 4.8381 4.8239 4.8089 4.7524 4.8256 4.8373 4.7226
 4.7509 4.726 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  28000
Recall@10:  0.292
NDCG:  0.132
Train Loss:  4.3176
Val Loss:  [4.7977 4.8035 4.8059 4.736  4.829  4.8175 4.8274 4.8306 4.6818 4.7172
 4.7831 4.8028]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  28500
Recall@10:  0.294
NDCG:  0.134
Train Loss:  4.2071
Val Loss:  [4.7495 4.8011 4.8138 4.7288 4.7817 4.8432 4.8075 4.8188 4.6879 4.7406
 4.778  4.8178]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  29000
Recall@10:  0.293
NDCG:  0.133
Train Loss:  4.1646
Val Loss:  [4.6886 4.8205 4.8214 4.7003 4.7187 4.8142 4.7696 4.7918 4.8496 4.9192
 4.751  4.7543]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  29500
Recall@10:  0.294
NDCG:  0.134
Train Loss:  4.2417
Val Loss:  [4.8087 4.7756 4.8363 4.7756 4.846  4.7709 4.8111 4.7271 4.8475 4.7259
 4.7148 4.7305]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  30000
Recall@10:  0.294
NDCG:  0.134
Train Loss:  4.3492
Val Loss:  [4.7703 4.7838 4.7399 4.8377 4.7792 4.7678 4.7632 4.6711 4.8212 4.8097
 4.7701 4.8002]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  30500
Recall@10:  0.295
NDCG:  0.134
Train Loss:  4.2425
Val Loss:  [4.8047 4.7567 4.733  4.8161 4.7268 4.8014 4.8703 4.8113 4.7217 4.7412
 4.767  4.7584]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  31000
Recall@10:  0.294
NDCG:  0.134
Train Loss:  4.1187
Val Loss:  [4.7645 4.7981 4.7789 4.8573 4.7331 4.787  4.6494 4.7598 4.8485 4.7421
 4.8331 4.8055]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  31500
Recall@10:  0.295
NDCG:  0.134
Train Loss:  4.2192
Val Loss:  [4.8066 4.7934 4.7674 4.7246 4.6999 4.7911 4.8395 4.7869 4.7876 4.7975
 4.7459 4.752 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  32000
Recall@10:  0.294
NDCG:  0.133
Train Loss:  4.1329
Val Loss:  [4.8337 4.775  4.7487 4.8004 4.8459 4.793  4.7246 4.7292 4.7159 4.7887
 4.7705 4.821 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  32500
Recall@10:  0.296
NDCG:  0.134
Train Loss:  4.0865
Val Loss:  [4.7758 4.7667 4.754  4.7251 4.8252 4.7913 4.7275 4.792  4.793  4.7064
 4.7458 4.8539]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  33000
Recall@10:  0.296
NDCG:  0.135
Train Loss:  4.1638
Val Loss:  [4.7629 4.7915 4.6922 4.7424 4.8349 4.7415 4.8696 4.7496 4.7649 4.7573
 4.7306 4.7787]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  33500
Recall@10:  0.296
NDCG:  0.135
Train Loss:  4.231
Val Loss:  [4.7523 4.7257 4.7701 4.8124 4.8099 4.8193 4.7353 4.707  4.7349 4.779
 4.872  4.6226]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  34000
Recall@10:  0.296
NDCG:  0.134
Train Loss:  4.2214
Val Loss:  [4.7725 4.7599 4.7857 4.768  4.7612 4.7767 4.7436 4.7125 4.7357 4.7543
 4.8355 4.8256]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  34500
Recall@10:  0.296
NDCG:  0.135
Train Loss:  4.1975
Val Loss:  [4.7599 4.7897 4.7697 4.8063 4.7895 4.6977 4.818  4.7059 4.7561 4.7405
 4.7603 4.8491]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  35000
Recall@10:  0.297
NDCG:  0.135
Train Loss:  4.2529
Val Loss:  [4.7739 4.7707 4.7671 4.8758 4.7112 4.7595 4.7197 4.8408 4.8174 4.7617
 4.6552 4.7002]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  35500
Recall@10:  0.297
NDCG:  0.135
Train Loss:  4.2152
Val Loss:  [4.7152 4.8053 4.7019 4.7475 4.8214 4.8489 4.7283 4.7483 4.7562 4.7962
 4.7612 4.7559]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  36000
Recall@10:  0.298
NDCG:  0.135
Train Loss:  4.2155
Val Loss:  [4.7451 4.7575 4.7495 4.7267 4.8217 4.6819 4.7995 4.8141 4.8047 4.6822
 4.7783 4.7708]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  36500
Recall@10:  0.297
NDCG:  0.135
Train Loss:  3.9946
Val Loss:  [4.6658 4.6999 4.7561 4.8205 4.7831 4.7515 4.7236 4.7715 4.7795 4.7875
 4.7832 4.7834]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  37000
Recall@10:  0.297
NDCG:  0.135
Train Loss:  4.1569
Val Loss:  [4.8982 4.8035 4.6577 4.766  4.7707 4.7709 4.7311 4.6811 4.6947 4.7227
 4.8103 4.7636]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  37500
Recall@10:  0.298
NDCG:  0.135
Train Loss:  4.1782
Val Loss:  [4.6866 4.658  4.8329 4.7087 4.7138 4.805  4.7094 4.8052 4.7568 4.7658
 4.8686 4.7741]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  38000
Recall@10:  0.299
NDCG:  0.136
Train Loss:  4.263
Val Loss:  [4.6592 4.7188 4.7257 4.7991 4.8124 4.762  4.7693 4.8093 4.725  4.7633
 4.7847 4.7319]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  38500
Recall@10:  0.298
NDCG:  0.135
Train Loss:  4.1501
Val Loss:  [4.7795 4.7739 4.7767 4.7912 4.739  4.7029 4.6878 4.7161 4.7927 4.8016
 4.7498 4.7969]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  39000
Recall@10:  0.297
NDCG:  0.135
Train Loss:  4.0127
Val Loss:  [4.8203 4.7276 4.7798 4.8187 4.7405 4.7362 4.828  4.765  4.7417 4.7168
 4.7201 4.7152]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  39500
Recall@10:  0.297
NDCG:  0.135
Train Loss:  3.9981
Val Loss:  [4.702  4.7369 4.8355 4.7684 4.7379 4.772  4.7338 4.762  4.7544 4.6965
 4.7802 4.7999]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  40000
Recall@10:  0.298
NDCG:  0.136
Train Loss:  4.0318
Val Loss:  [4.8034 4.822  4.6388 4.6912 4.7454 4.7327 4.7578 4.7687 4.7454 4.6907
 4.8204 4.8603]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  40500
Recall@10:  0.3
NDCG:  0.136
Train Loss:  4.1362
Val Loss:  [4.839  4.8178 4.8135 4.757  4.7606 4.7014 4.7003 4.7281 4.6821 4.7817
 4.7046 4.7008]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  41000
Recall@10:  0.3
NDCG:  0.136
Train Loss:  4.2436
Val Loss:  [4.7681 4.7893 4.6587 4.787  4.7962 4.7499 4.7531 4.7678 4.7139 4.8201
 4.7627 4.6765]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  41500
Recall@10:  0.299
NDCG:  0.136
Train Loss:  4.2263
Val Loss:  [4.7589 4.8312 4.7116 4.649  4.7759 4.7861 4.7512 4.8404 4.757  4.6946
 4.7003 4.7489]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  42000
Recall@10:  0.3
NDCG:  0.136
Train Loss:  4.0612
Val Loss:  [4.7181 4.8222 4.7294 4.8265 4.7327 4.7157 4.7918 4.7207 4.7483 4.7261
 4.6677 4.7644]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  42500
Recall@10:  0.299
NDCG:  0.136
Train Loss:  4.1366
Val Loss:  [4.6472 4.7593 4.8068 4.7515 4.8153 4.7773 4.726  4.7503 4.7536 4.7476
 4.6707 4.7464]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  43000
Recall@10:  0.299
NDCG:  0.136
Train Loss:  4.0618
Val Loss:  [4.7591 4.7771 4.7524 4.6684 4.7318 4.8161 4.6912 4.8318 4.6764 4.6636
 4.8342 4.8365]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  43500
Recall@10:  0.299
NDCG:  0.136
Train Loss:  4.181
Val Loss:  [4.7225 4.7917 4.7642 4.7799 4.761  4.7223 4.7335 4.7439 4.7403 4.7365
 4.7029 4.8293]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  44000
Recall@10:  0.3
NDCG:  0.137
Train Loss:  4.1311
Val Loss:  [4.7365 4.7241 4.7632 4.7275 4.7534 4.7835 4.742  4.7409 4.7499 4.7623
 4.7839 4.6113]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  44500
Recall@10:  0.3
NDCG:  0.136
Train Loss:  4.0103
Val Loss:  [4.7982 4.735  4.8439 4.6722 4.7679 4.7142 4.7789 4.7351 4.708  4.7598
 4.6548 4.722 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  45000
Recall@10:  0.301
NDCG:  0.137
Train Loss:  4.1237
Val Loss:  [4.7308 4.6793 4.7385 4.7063 4.7216 4.7855 4.7581 4.7334 4.7752 4.7178
 4.7679 4.7496]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  45500
Recall@10:  0.302
NDCG:  0.137
Train Loss:  3.9807
Val Loss:  [4.7631 4.7559 4.7479 4.7128 4.7584 4.6484 4.7624 4.7443 4.6797 4.7906
 4.7725 4.6992]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  46000
Recall@10:  0.299
NDCG:  0.136
Train Loss:  4.1599
Val Loss:  [4.7114 4.6202 4.6852 4.8039 4.8287 4.6928 4.7866 4.7176 4.691  4.8204
 4.6999 4.8638]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  46500
Recall@10:  0.301
NDCG:  0.137
Train Loss:  4.1804
Val Loss:  [4.7148 4.7258 4.7362 4.7158 4.7201 4.7052 4.8031 4.7109 4.7669 4.7385
 4.7896 4.7548]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  47000
Recall@10:  0.301
NDCG:  0.137
Train Loss:  4.0551
Val Loss:  [4.7458 4.8187 4.7594 4.7989 4.7365 4.6848 4.7729 4.674  4.7136 4.712
 4.6626 4.8478]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  47500
Recall@10:  0.302
NDCG:  0.137
Train Loss:  4.1018
Val Loss:  [4.8003 4.7351 4.6907 4.769  4.662  4.6877 4.7823 4.7289 4.7413 4.7398
 4.6734 4.8381]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  48000
Recall@10:  0.301
NDCG:  0.137
Train Loss:  4.0744
Val Loss:  [4.6746 4.7903 4.7525 4.796  4.7011 4.6462 4.7934 4.7019 4.7304 4.7739
 4.6753 4.8611]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  48500
Recall@10:  0.301
NDCG:  0.137
Train Loss:  3.9531
Val Loss:  [4.7028 4.8152 4.8228 4.7305 4.7626 4.6886 4.7264 4.7674 4.8264 4.7196
 4.6426 4.7222]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  49000
Recall@10:  0.302
NDCG:  0.137
Train Loss:  4.0731
Val Loss:  [4.8001 4.7366 4.6822 4.7911 4.7889 4.7607 4.6989 4.7334 4.6433 4.6973
 4.6987 4.7267]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  49500
Recall@10:  0.3
NDCG:  0.136
Train Loss:  4.2283
Val Loss:  [4.8167 4.7808 4.7066 4.7514 4.646  4.7711 4.6929 4.7177 4.6683 4.7243
 4.7559 4.8442]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  50000
Recall@10:  0.301
NDCG:  0.137
Train Loss:  4.0374
Val Loss:  [4.7236 4.6568 4.8516 4.7363 4.7006 4.7615 4.6474 4.7951 4.7864 4.7541
 4.7012 4.7768]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  50500
Recall@10:  0.304
NDCG:  0.138
Train Loss:  3.9753
Val Loss:  [4.725  4.7451 4.6734 4.8133 4.7702 4.7089 4.6923 4.7682 4.6948 4.7219
 4.6917 4.7155]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  51000
Recall@10:  0.302
NDCG:  0.137
Train Loss:  4.1373
Val Loss:  [4.7327 4.7158 4.7995 4.6466 4.714  4.8477 4.7485 4.6825 4.7451 4.776
 4.6757 4.6997]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  51500
Recall@10:  0.302
NDCG:  0.137
Train Loss:  4.1865
Val Loss:  [4.7129 4.7403 4.7725 4.7642 4.7507 4.6497 4.7147 4.7829 4.7842 4.6442
 4.655  4.8414]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  52000
Recall@10:  0.301
NDCG:  0.137
Train Loss:  4.0964
Val Loss:  [4.8251 4.7484 4.7155 4.6765 4.7543 4.7898 4.679  4.7621 4.6767 4.6902
 4.5987 4.854 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  52500
Recall@10:  0.303
NDCG:  0.138
Train Loss:  4.0743
Val Loss:  [4.8085 4.7154 4.7607 4.7844 4.7841 4.7481 4.6862 4.6231 4.7336 4.6892
 4.7121 4.6862]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  53000
Recall@10:  0.302
NDCG:  0.137
Train Loss:  4.1244
Val Loss:  [4.704  4.8051 4.7079 4.7351 4.589  4.6222 4.7946 4.8287 4.6881 4.751
 4.7345 4.7976]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  53500
Recall@10:  0.302
NDCG:  0.137
Train Loss:  4.0371
Val Loss:  [4.7302 4.7005 4.7406 4.7716 4.7726 4.8006 4.7268 4.6836 4.6934 4.756
 4.7402 4.6882]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  54000
Recall@10:  0.303
NDCG:  0.138
Train Loss:  4.0585
Val Loss:  [4.7248 4.743  4.6481 4.8115 4.804  4.7158 4.7657 4.6798 4.7816 4.6747
 4.7192 4.6876]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  54500
Recall@10:  0.302
NDCG:  0.137
Train Loss:  4.0916
Val Loss:  [4.6358 4.7042 4.7499 4.7433 4.7158 4.7459 4.6545 4.7668 4.7359 4.7813
 4.7228 4.7886]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  55000
Recall@10:  0.303
NDCG:  0.137
Train Loss:  4.1922
Val Loss:  [4.7544 4.7815 4.7178 4.8298 4.726  4.778  4.767  4.7464 4.5914 4.716
 4.6751 4.6738]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  55500
Recall@10:  0.304
NDCG:  0.138
Train Loss:  4.0746
Val Loss:  [4.7271 4.765  4.7581 4.6936 4.787  4.7466 4.7173 4.689  4.6788 4.615
 4.7489 4.7808]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  56000
Recall@10:  0.304
NDCG:  0.138
Train Loss:  4.08
Val Loss:  [4.7165 4.6959 4.673  4.6729 4.7189 4.781  4.7585 4.7067 4.7158 4.8154
 4.7288 4.6197]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  56500
Recall@10:  0.303
NDCG:  0.138
Train Loss:  4.0161
Val Loss:  [4.7542 4.7151 4.8019 4.7575 4.7446 4.6328 4.8214 4.5881 4.7933 4.6789
 4.728  4.662 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  57000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  3.9925
Val Loss:  [4.6826 4.6956 4.777  4.7109 4.7277 4.7112 4.8193 4.713  4.8086 4.6947
 4.641  4.6268]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  57500
Recall@10:  0.305
NDCG:  0.138
Train Loss:  4.0323
Val Loss:  [4.7252 4.7483 4.7524 4.7478 4.6734 4.7573 4.7009 4.7433 4.6799 4.762
 4.6654 4.699 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  58000
Recall@10:  0.303
NDCG:  0.138
Train Loss:  4.051
Val Loss:  [4.8213 4.6327 4.7273 4.7252 4.6915 4.7588 4.781  4.7646 4.6538 4.7324
 4.6538 4.7386]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  58500
Recall@10:  0.304
NDCG:  0.138
Train Loss:  4.0537
Val Loss:  [4.77   4.6766 4.6454 4.7273 4.7615 4.6862 4.6886 4.7357 4.7047 4.7399
 4.7615 4.7445]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  59000
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.0773
Val Loss:  [4.7673 4.6903 4.7269 4.689  4.7039 4.7519 4.6772 4.6847 4.7152 4.6924
 4.8066 4.711 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  59500
Recall@10:  0.305
NDCG:  0.138
Train Loss:  4.1158
Val Loss:  [4.6986 4.744  4.6518 4.7641 4.6653 4.7567 4.6981 4.6494 4.7978 4.7186
 4.6782 4.7845]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  60000
Recall@10:  0.304
NDCG:  0.138
Train Loss:  4.1519
Val Loss:  [4.761  4.682  4.7434 4.7849 4.7    4.7658 4.6298 4.8694 4.7311 4.6939
 4.6658 4.7283]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  60500
Recall@10:  0.303
NDCG:  0.138
Train Loss:  4.075
Val Loss:  [4.7866 4.7669 4.745  4.7078 4.6328 4.6743 4.7542 4.7936 4.656  4.6759
 4.728  4.8305]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  61000
Recall@10:  0.304
NDCG:  0.138
Train Loss:  4.2302
Val Loss:  [4.6812 4.7039 4.6765 4.6947 4.7245 4.6224 4.768  4.7589 4.72   4.762
 4.6821 4.8289]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  62000
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.05
Val Loss:  [4.7713 4.7852 4.7141 4.7522 4.6846 4.6681 4.748  4.7866 4.6951 4.7159
 4.7307 4.5752]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  62500
Recall@10:  0.304
NDCG:  0.138
Train Loss:  4.131
Val Loss:  [4.6281 4.744  4.7066 4.7724 4.6841 4.6659 4.7438 4.7111 4.7127 4.7206
 4.7388 4.7785]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  63000
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.2529
Val Loss:  [4.7605 4.7486 4.7019 4.6693 4.7653 4.7582 4.7067 4.6803 4.6531 4.7192
 4.7046 4.7159]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  63500
Recall@10:  0.304
NDCG:  0.138
Train Loss:  4.0946
Val Loss:  [4.8252 4.7319 4.7328 4.7061 4.6831 4.7744 4.734  4.7119 4.6764 4.7091
 4.6714 4.6906]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  64000
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.0852
Val Loss:  [4.7427 4.7648 4.7001 4.6615 4.7592 4.6917 4.6652 4.7282 4.7365 4.7586
 4.7708 4.6581]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  64500
Recall@10:  0.305
NDCG:  0.138
Train Loss:  3.996
Val Loss:  [4.6828 4.7527 4.6817 4.7186 4.6583 4.6708 4.7342 4.7417 4.7617 4.7223
 4.7388 4.7324]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  65000
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.0086
Val Loss:  [4.7421 4.8295 4.7408 4.7673 4.6438 4.6469 4.7006 4.7331 4.6574 4.7397
 4.695  4.7154]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  65500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.0932
Val Loss:  [4.6058 4.7557 4.5833 4.7196 4.7181 4.8024 4.7003 4.7246 4.7763 4.6374
 4.7925 4.7341]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  66000
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.0799
Val Loss:  [4.7706 4.7145 4.7351 4.7367 4.6881 4.7745 4.7391 4.6535 4.6425 4.6968
 4.6347 4.8109]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  66500
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.1252
Val Loss:  [4.6918 4.7135 4.727  4.7904 4.6955 4.7223 4.7079 4.6882 4.7386 4.7691
 4.5687 4.752 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  67000
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.0798
Val Loss:  [4.6913 4.7936 4.7628 4.6828 4.658  4.693  4.7452 4.7006 4.6999 4.682
 4.7877 4.7237]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  67500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.108
Val Loss:  [4.763  4.6791 4.6807 4.745  4.636  4.7366 4.6787 4.6427 4.8041 4.7032
 4.7208 4.8115]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  68000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.0483
Val Loss:  [4.7126 4.6843 4.7075 4.5835 4.7709 4.6366 4.7179 4.7464 4.7528 4.76
 4.7936 4.6395]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  68500
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.0455
Val Loss:  [4.8069 4.7569 4.609  4.7566 4.7215 4.7283 4.6481 4.7773 4.6499 4.7847
 4.7143 4.6483]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  69000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.0557
Val Loss:  [4.7239 4.7231 4.7074 4.6585 4.7054 4.7854 4.7907 4.7823 4.6062 4.6306
 4.7062 4.7567]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  69500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.0304
Val Loss:  [4.6788 4.7037 4.7024 4.6498 4.7128 4.7488 4.7122 4.698  4.726  4.719
 4.6846 4.759 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  70000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.178
Val Loss:  [4.7131 4.6141 4.7533 4.6951 4.6755 4.6608 4.6728 4.8054 4.7254 4.7668
 4.7163 4.7474]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  70500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.0395
Val Loss:  [4.6365 4.7137 4.7032 4.7343 4.7821 4.7333 4.6659 4.7489 4.6871 4.7117
 4.7326 4.6828]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  71000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.1127
Val Loss:  [4.6136 4.7508 4.7246 4.7338 4.7797 4.6701 4.6762 4.6701 4.7314 4.6553
 4.7559 4.8023]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  71500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.0139
Val Loss:  [4.7286 4.7729 4.7475 4.7375 4.7705 4.7029 4.7247 4.7103 4.6335 4.6884
 4.6901 4.6231]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  72000
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.1111
Val Loss:  [4.6664 4.7717 4.7884 4.6435 4.6498 4.6396 4.8305 4.7211 4.6618 4.6812
 4.7205 4.8198]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  72500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.082
Val Loss:  [4.7365 4.6271 4.7013 4.7833 4.7294 4.77   4.6592 4.6337 4.7616 4.704
 4.7225 4.7221]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  73000
Recall@10:  0.307
NDCG:  0.14
Train Loss:  4.0196
Val Loss:  [4.6703 4.7307 4.699  4.7857 4.7479 4.7074 4.6118 4.7143 4.6665 4.7119
 4.7397 4.6472]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  73500
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.0969
Val Loss:  [4.7513 4.7357 4.722  4.7534 4.7322 4.7438 4.5638 4.7149 4.6824 4.7273
 4.709  4.7278]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  74000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  3.9985
Val Loss:  [4.7816 4.7461 4.7055 4.6926 4.6995 4.7143 4.7103 4.6604 4.653  4.7677
 4.6577 4.6809]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  74500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.1217
Val Loss:  [4.6659 4.8496 4.8311 4.7174 4.6743 4.6628 4.7949 4.6394 4.7554 4.6041
 4.6746 4.6834]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  75000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.1132
Val Loss:  [4.6427 4.7162 4.6998 4.7521 4.7162 4.8045 4.7506 4.6598 4.7047 4.6739
 4.7095 4.6714]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  75500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.1016
Val Loss:  [4.684  4.6379 4.7099 4.7339 4.7837 4.6884 4.737  4.7153 4.7168 4.6978
 4.7147 4.6502]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  76000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  3.9856
Val Loss:  [4.619  4.7963 4.6768 4.675  4.6676 4.7036 4.8037 4.733  4.7992 4.7218
 4.6581 4.6988]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  76500
Recall@10:  0.307
NDCG:  0.14
Train Loss:  4.0301
Val Loss:  [4.6868 4.7435 4.6933 4.7253 4.7412 4.6696 4.7235 4.776  4.7528 4.617
 4.7167 4.6997]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  77000
Recall@10:  0.307
NDCG:  0.139
Train Loss:  4.1614
Val Loss:  [4.6633 4.7314 4.7359 4.7062 4.7311 4.7898 4.6882 4.7796 4.6394 4.6898
 4.6776 4.7441]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  77500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.1143
Val Loss:  [4.7884 4.6434 4.6973 4.7337 4.6661 4.6303 4.7184 4.7385 4.7259 4.7737
 4.7278 4.6812]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  78500
Recall@10:  0.307
NDCG:  0.139
Train Loss:  3.9814
Val Loss:  [4.7429 4.6351 4.7065 4.6818 4.7497 4.7018 4.709  4.7027 4.7579 4.6563
 4.6855 4.7592]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  79000
Recall@10:  0.308
NDCG:  0.14
Train Loss:  3.9969
Val Loss:  [4.6961 4.6991 4.7196 4.6691 4.6888 4.694  4.6727 4.7713 4.6918 4.7347
 4.7498 4.7058]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  79500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  3.8535
Val Loss:  [4.7268 4.6205 4.6754 4.7473 4.7478 4.6757 4.7097 4.796  4.6663 4.655
 4.6898 4.8049]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  80000
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.128
Val Loss:  [4.7532 4.8037 4.6937 4.6725 4.7271 4.6621 4.7861 4.6385 4.7276 4.6973
 4.6213 4.6827]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  80500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.0402
Val Loss:  [4.7324 4.7647 4.6794 4.7336 4.7167 4.5876 4.7095 4.6835 4.7452 4.656
 4.7195 4.7827]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  81000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.0047
Val Loss:  [4.6015 4.8257 4.8361 4.6221 4.6912 4.6909 4.6913 4.7969 4.6485 4.7336
 4.6955 4.7257]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  81500
Recall@10:  0.308
NDCG:  0.14
Train Loss:  3.9645
Val Loss:  [4.7342 4.7352 4.6428 4.7336 4.6846 4.6742 4.6975 4.8144 4.7077 4.7178
 4.7258 4.513 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  82000
Recall@10:  0.307
NDCG:  0.14
Train Loss:  3.9709
Val Loss:  [4.7179 4.7642 4.6354 4.7371 4.6884 4.6845 4.7883 4.7035 4.6769 4.7097
 4.6    4.6816]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  82500
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.2136
Val Loss:  [4.7457 4.8034 4.6689 4.7643 4.7198 4.6222 4.7085 4.6949 4.7269 4.6842
 4.7403 4.6666]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  83000
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.0491
Val Loss:  [4.7375 4.6494 4.8116 4.6921 4.7505 4.7366 4.6499 4.7024 4.5973 4.6828
 4.7442 4.6825]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  83500
Recall@10:  0.307
NDCG:  0.14
Train Loss:  4.0502
Val Loss:  [4.7851 4.7345 4.6882 4.6351 4.6417 4.7365 4.8644 4.746  4.6391 4.6474
 4.6774 4.6758]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  84000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.0542
Val Loss:  [4.6842 4.6489 4.7608 4.7502 4.6802 4.6813 4.6716 4.6429 4.7824 4.7279
 4.7541 4.735 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  84500
Recall@10:  0.308
NDCG:  0.14
Train Loss:  3.9295
Val Loss:  [4.7569 4.6107 4.8002 4.6307 4.7542 4.643  4.7245 4.7255 4.7343 4.6867
 4.7029 4.6611]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  85000
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.1577
Val Loss:  [4.7002 4.6994 4.7636 4.698  4.6798 4.7579 4.7198 4.662  4.7032 4.6765
 4.6571 4.6573]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  85500
Recall@10:  0.305
NDCG:  0.139
Train Loss:  4.0901
Val Loss:  [4.7631 4.7015 4.7282 4.717  4.7457 4.6685 4.6713 4.7577 4.7316 4.6834
 4.6571 4.7193]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  86000
Recall@10:  0.308
NDCG:  0.14
Train Loss:  3.9303
Val Loss:  [4.7718 4.6504 4.7392 4.7142 4.6992 4.7716 4.7118 4.7599 4.6451 4.6563
 4.7036 4.6389]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  86500
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.0447
Val Loss:  [4.7534 4.7582 4.7703 4.7192 4.6686 4.6681 4.684  4.6821 4.626  4.7024
 4.628  4.737 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  87000
Recall@10:  0.308
NDCG:  0.14
Train Loss:  3.9961
Val Loss:  [4.7633 4.618  4.6654 4.7007 4.699  4.6761 4.7334 4.7105 4.6996 4.7767
 4.7434 4.7296]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  87500
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.1341
Val Loss:  [4.5956 4.7386 4.6856 4.6741 4.706  4.779  4.7033 4.7753 4.7218 4.6401
 4.6499 4.7729]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  88000
Recall@10:  0.307
NDCG:  0.14
Train Loss:  4.0721
Val Loss:  [4.7451 4.6979 4.7331 4.6402 4.7153 4.7615 4.6833 4.6455 4.7274 4.6849
 4.7037 4.7258]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  88500
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.1923
Val Loss:  [4.6816 4.7629 4.7226 4.6918 4.6881 4.7019 4.7059 4.6607 4.7384 4.7071
 4.7777 4.5831]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  89000
Recall@10:  0.309
NDCG:  0.14
Train Loss:  3.9305
Val Loss:  [4.6291 4.7487 4.6267 4.7537 4.7537 4.7449 4.7464 4.6894 4.6873 4.6962
 4.6355 4.7561]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  89500
Recall@10:  0.307
NDCG:  0.139
Train Loss:  4.0738
Val Loss:  [4.7549 4.7204 4.7213 4.7258 4.6039 4.6491 4.6764 4.7731 4.6648 4.6964
 4.7853 4.6829]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  90000
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.0761
Val Loss:  [4.6931 4.7233 4.6742 4.7686 4.6391 4.6592 4.7338 4.7319 4.7329 4.7534
 4.7011 4.6931]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  90500
Recall@10:  0.307
NDCG:  0.14
Train Loss:  4.1622
Val Loss:  [4.7369 4.612  4.7565 4.7251 4.6707 4.7628 4.6968 4.7389 4.6358 4.6722
 4.7793 4.643 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  91000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  4.1119
Val Loss:  [4.6481 4.7036 4.715  4.6901 4.79   4.705  4.7597 4.69   4.6779 4.6726
 4.6611 4.7841]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  91500
Recall@10:  0.308
NDCG:  0.14
Train Loss:  3.9939
Val Loss:  [4.7245 4.7164 4.6612 4.6701 4.6886 4.7462 4.729  4.673  4.7287 4.7127
 4.6886 4.7171]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  92000
Recall@10:  0.307
NDCG:  0.139
Train Loss:  3.9349
Val Loss:  [4.7547 4.6913 4.6481 4.694  4.7941 4.6351 4.7359 4.6884 4.7179 4.6609
 4.5987 4.8401]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  92500
Recall@10:  0.307
NDCG:  0.14
Train Loss:  3.9731
Val Loss:  [4.6964 4.7572 4.643  4.7548 4.6603 4.6424 4.7798 4.6603 4.718  4.6264
 4.6707 4.791 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  93000
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.0142
Val Loss:  [4.5793 4.7643 4.6594 4.7921 4.6927 4.6712 4.6134 4.7421 4.7771 4.7047
 4.7741 4.6606]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  93500
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.0124
Val Loss:  [4.7218 4.6781 4.7029 4.6122 4.7385 4.6584 4.6894 4.8343 4.6786 4.6904
 4.7347 4.6829]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  94000
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.0916
Val Loss:  [4.6904 4.7086 4.6894 4.7744 4.6372 4.6675 4.721  4.6562 4.713  4.759
 4.7182 4.7078]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  94500
Recall@10:  0.309
NDCG:  0.14
Train Loss:  3.9738
Val Loss:  [4.7411 4.6393 4.7742 4.7248 4.7423 4.7194 4.7198 4.7624 4.6787 4.6276
 4.6948 4.5682]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  95000
Recall@10:  0.306
NDCG:  0.139
Train Loss:  3.9941
Val Loss:  [4.6287 4.7604 4.7207 4.7536 4.7546 4.5917 4.6563 4.6606 4.6637 4.7461
 4.7621 4.7954]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  95500
Recall@10:  0.309
NDCG:  0.14
Train Loss:  4.0204
Val Loss:  [4.6486 4.6743 4.7244 4.7574 4.6894 4.8283 4.6785 4.6661 4.6772 4.6609
 4.7246 4.6615]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  96000
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.0239
Val Loss:  [4.7046 4.7406 4.65   4.7222 4.5145 4.7099 4.7995 4.7517 4.7015 4.634
 4.7861 4.6779]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  96500
Recall@10:  0.308
NDCG:  0.14
Train Loss:  4.0679
Val Loss:  [4.7199 4.6764 4.6409 4.653  4.7594 4.6506 4.7524 4.6239 4.7505 4.7131
 4.7452 4.775 ]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  97000
Recall@10:  0.31
NDCG:  0.141
Train Loss:  4.1157
Val Loss:  [4.7409 4.7855 4.6118 4.7659 4.7281 4.7068 4.7346 4.7577 4.7226 4.6403
 4.6959 4.5334]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  97500
Recall@10:  0.309
NDCG:  0.14
Train Loss:  3.9493
Val Loss:  [4.6355 4.6021 4.7205 4.7525 4.6455 4.7656 4.6584 4.7665 4.7243 4.7685
 4.6782 4.6413]


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch:  98000
Recall@10:  0.308
NDCG:  0.14
Train Loss:  3.9751
Val Loss:  [4.7877 4.648  4.667  4.6569 4.7485 4.8098 4.735  4.7844 4.6634 4.6263
 4.6864 4.6346]


  0%|          | 0/12 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [25]:
import gc
gc.collect()

8586

In [26]:
test_dataset = ValidShiftedDataset(list(test.values()), max_length + 1, end_code)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

recall_batch = []
ndcg_batch = []
recall_batch_last = []
recall_batch_last_n = []
ndcg_batch_last = []
for X_val, y_val in test_loader:

    X_val = X_val.to(device)
    y_val = y_val.to(device)

    y_pred = model.forward(X_val.long())

    val_loss = loss_function(y_pred.view(-1, 3955), y_val.long().view(-1))
    val_losses.append(val_loss.item())

    recall_batch.append(recall_k_seq(y_pred, y_val, k=1))
#                 print(y_pred.shape)
#                 print(y_val.shape)
    ndcg_batch.append(ndcg_k_seq(y_pred, y_val))
    recall_batch_last.append(recall_k_seq_last(y_pred, y_val, k=1))
    recall_batch_last_n.append(recall_k_seq_last(y_pred, y_val, k=10))
    ndcg_batch_last.append(ndcg_k_seq_last(y_pred, y_val))

#             print(counter, np.mean(recall_batch).round(3), np.round(loss.item(), 4), np.mean(val_losses).round(4))
print("Epoch: ", counter)
print("Recall@10: ", np.mean(recall_batch).round(3))
print("NDCG: ", np.mean(ndcg_batch).round(3))
print("Train Loss: ", np.round(loss.item(), 4))
print("Val Loss: ", np.round(val_losses, 4))

Epoch:  98500
Recall@10:  0.092
NDCG:  0.138
Train Loss:  3.92
Val Loss:  [4.6728 4.6405 4.7094 4.798  4.7703 4.6332 4.7308 4.8353 4.6827 4.8238
 4.2946 4.6055 4.6474 4.632  4.7355 4.7523 4.5849 4.5989 4.8591 4.7677
 4.7396 4.6966 4.6024 4.6339 4.6872 4.5297 4.8695 4.6202 4.7405 4.6792
 4.6016 4.8307 4.9578 4.7885 4.5176 4.5662 4.3075 4.8498 4.6536 4.9235
 4.6923 4.8991 4.8624 4.8197 4.6336 4.7685 4.6627 4.9122 5.0828 4.6217
 4.6903 4.6583 4.5473 4.7504 4.8575 4.9014 4.4955 4.9811 4.8286 4.4444
 4.584  4.7305 4.5577 4.5488 4.589  4.651  4.8433 4.7886 4.5364 5.0141
 4.7882 4.5102 4.6859 4.7054 4.7765 4.6968 4.4849 4.8468 4.666  4.5735
 4.5979 4.7585 4.6788 4.8862 4.6489 4.6163 4.6726 4.7143 4.843  4.8543
 4.5983 4.6586 4.7392 4.787  4.764  4.6393 4.7162 4.7799 4.7643 4.6617
 4.6654 4.5783 4.7894 4.7173 4.1557 4.7219 4.8763 4.6991 4.7657 4.6006
 4.6485 4.4606 4.6346 4.7261 4.5867 4.6432 4.643  4.7125 4.7216 4.8548
 4.743  4.6442 4.5904 4.7911 4.8246 4.7503 4.626  4.5921 4.5727 4.5728
 4.

In [27]:
print("Epoch: ", counter)
print("Recall@10: ", np.mean(recall_batch_last_n).round(3))
print("Recall@1: ", np.mean(recall_batch_last).round(3))
print("NDCG: ", np.mean(ndcg_batch_last).round(3))


Epoch:  98500
Recall@10:  0.294
Recall@1:  0.086
NDCG:  0.134


In [None]:
Epoch:  98500
Recall@10:  0.305
NDCG:  0.138

In [None]:
epochs = 15000
for epoch in range(epochs):
    for X, y in train_loader:

        optimizer.zero_grad()
        
        X = X.to('cuda:0')
        y = y.to('cuda:0')

        loss = loss_function(model.forward(X).view(-1, 3955), y.view(-1))

        losses.append(loss.item())
        loss.backward()
         
        optimizer.step()
        
        counter += 1
        if counter % 500 == 0:
            val_losses = []
            recall_batch = []
            ndcg_batch = []
            for X_val, y_val in val_loader:
                
                X_val = X_val.to('cuda:0')
                y_val = y_val.to('cuda:0')
                
                y_pred = model.forward(X_val.long())[:, -1, :]
                
                val_loss = loss_function(y_pred.view(-1, 3955), y_val.long().view(-1))
                val_losses.append(val_loss.item())
                
                recall_batch.append(recall_k(y_pred, y_val, k=10))
#                 print(y_pred.shape)
#                 print(y_val.shape)
                ndcg_batch.append(ndcg_k(y_pred, y_val))

#             print(counter, np.mean(recall_batch).round(3), np.round(loss.item(), 4), np.mean(val_losses).round(4))
            print("Epoch: ", counter)
            print("Recall@10: ", np.mean(recall_batch).round(3))
            print("NDCG: ", np.mean(ndcg_batch).round(3))
            print("Train Loss: ", np.round(loss.item(), 4))
            print("Val Loss: ", np.round(val_losses, 4))