In [1]:
import os
import sys
sys.path.append("..")
from nbr.preparation import Preprocess, save_split, Corpus
from nbr.trainer import NBRTrainer
from nbr.model import BPR, SLRC
import torch
import random
import numpy as np
import optuna
import warnings
warnings.filterwarnings("ignore")

# TaFeng

Fix seed:

In [4]:
# Notebook-level seed; it is also passed to the Optuna sampler further down.
seed = 10
# Seed Python's `random`, PyTorch and NumPy global generators with the same value.
# (np.random.seed stays last so the cell keeps displaying no output.)
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

Read the interaction data and filter it: drop users with fewer than 5 transactions, users with a high purchase frequency, one-day users, and items with fewer than 10 transactions. The training set contains all baskets except the last two, the validation set is the second-to-last basket, and the test set is the last basket:

In [5]:
# Location of the raw/processed data and the dataset to work on.
corpus_path = "./data/"
dataset_name = "ta_feng"

# Load and filter the raw interactions, then write the split to disk.
# NOTE(review): 5 and 10 are presumably the minimum number of transactions per
# user and per item described in the text above — confirm the argument order
# against Preprocess.load_data.
preprocessor = Preprocess(corpus_path, dataset_name)
preprocessor.load_data(5, 10, filt=True)
save_split(corpus_path, dataset_name, preprocessor)

Before preprocessing: #users = 32266, #items = 23812, #clicks = 817741 (#illegal records = 0)
After preprocessing: #users = 7358, #items = 11202, #clicks = 368951
Saving dataset in ./data//data_ta_feng/...


In [6]:
# Load the dataset for `dataset_name` from `corpus_path`
# (presumably the split written by save_split in the previous cell — confirm).
corpus = Corpus(corpus_path, dataset_name)
corpus.load_data()

Tune the hyperparameters on the validation dataset:

In [7]:
# Trainer used for the hyperparameter search: short runs of at most 5 epochs.
# NOTE(review): early_stop_num=None presumably disables early stopping and
# topk=10 is presumably the cutoff of the ranking metrics — confirm in NBRTrainer.
trainer = NBRTrainer(
    corpus=corpus,
    max_epochs=5,
    topk=10,
    early_stop_num=None
)

train dataset preparing...


100%|██████████| 7358/7358 [00:12<00:00, 598.50it/s]


dev dataset preparing...


100%|██████████| 7357/7357 [00:01<00:00, 4508.39it/s]


test dataset preparing...


100%|██████████| 7357/7357 [00:01<00:00, 3953.99it/s]


In [8]:
def objective(trial):
    """Optuna objective: train SLRC over a BPR base model and score it on the dev set.

    Relies on the notebook-level ``corpus`` (dataset sizes and average repeat
    interval) and ``trainer`` (training/evaluation loop) objects.

    Args:
        trial: ``optuna.trial.Trial`` used to sample ``emb_size``, ``batch_size``,
            ``lr`` and ``l2_reg_coef``.

    Returns:
        The ``"ndcg"`` entry of the metrics returned by
        ``trainer.evaluate(mode="dev")``.
    """
    params = {
        "model": SLRC(
            base_model_class=BPR,
            base_model_config={
                "emb_size": trial.suggest_categorical("emb_size", [32, 64, 128]),
                "user_num": corpus.n_users,
                "item_num": corpus.n_items,
                "click_num": corpus.n_clicks
            },
            item_num=corpus.n_items,
            avg_repeat_interval=corpus.total_avg_interval
        ),
        "batch_size": trial.suggest_categorical("batch_size", [32, 64, 128, 256]),
        # suggest_float(..., log=True) is the supported replacement for the
        # deprecated suggest_loguniform; the search space is unchanged.
        "lr": trial.suggest_float("lr", 1e-5, 1e-1, log=True),
        "l2_reg_coef": trial.suggest_float("l2_reg_coef", 1e-4, 1e-1, log=True)
    }

    trainer.init_hyperparams(**params)
    # Per-epoch evaluation is switched off; the model is scored once after training.
    trainer.train(evaluation_flg=False)
    metrics = trainer.evaluate(mode="dev")
    return metrics["ndcg"]

In [None]:
# TPE sampler seeded with the notebook-level `seed`.
sampler = optuna.samplers.TPESampler(seed=seed)
# Maximize, because `objective` returns the dev-set NDCG.
study = optuna.create_study(direction="maximize", sampler=sampler)
# Run 25 trials; each trial trains a model and evaluates it on the dev set.
study.optimize(objective, n_trials=25)

[32m[I 2023-04-12 09:29:35,501][0m A new study created in memory with name: no-name-d95d3175-61c7-422f-9fd9-bea34b30e576[0m


Epoch 1:


Batch loss = 0.615568: 100%|██████████| 8889/8889 [01:37<00:00, 91.11it/s] 



Epoch 2:


Batch loss = 0.885531: 100%|██████████| 8889/8889 [01:35<00:00, 93.01it/s]



Epoch 3:


Batch loss = 0.515037: 100%|██████████| 8889/8889 [01:36<00:00, 92.19it/s]



Epoch 4:


Batch loss = 0.444421: 100%|██████████| 8889/8889 [01:36<00:00, 91.71it/s]



Epoch 5:


Batch loss = 0.718573: 100%|██████████| 8889/8889 [01:36<00:00, 92.50it/s] 





100%|██████████| 7357/7357 [01:26<00:00, 84.93it/s]
[32m[I 2023-04-12 09:39:10,554][0m Trial 0 finished with value: 0.09845775866498242 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 0.011018509458263562, 'l2_reg_coef': 0.00032161219616553475}. Best is trial 0 with value: 0.09845775866498242.[0m


Epoch 1:


Batch loss = 0.709831: 100%|██████████| 2223/2223 [00:33<00:00, 66.48it/s]



Epoch 2:


Batch loss = 0.584393: 100%|██████████| 2223/2223 [00:31<00:00, 69.88it/s]



Epoch 3:


Batch loss = 0.201437: 100%|██████████| 2223/2223 [00:32<00:00, 68.39it/s]



Epoch 4:


Batch loss = 0.227203: 100%|██████████| 2223/2223 [00:33<00:00, 66.97it/s]



Epoch 5:


Batch loss = 0.170475: 100%|██████████| 2223/2223 [00:31<00:00, 70.49it/s]





100%|██████████| 7357/7357 [01:25<00:00, 85.56it/s] 
[32m[I 2023-04-12 09:43:19,287][0m Trial 1 finished with value: 0.09991268960072652 and parameters: {'emb_size': 128, 'batch_size': 128, 'lr': 0.007709412252614239, 'l2_reg_coef': 0.0007509797119626302}. Best is trial 1 with value: 0.09991268960072652.[0m


Epoch 1:


Batch loss = 0.675095: 100%|██████████| 2223/2223 [00:31<00:00, 70.18it/s]



Epoch 2:


Batch loss = 0.666087: 100%|██████████| 2223/2223 [00:33<00:00, 66.98it/s]



Epoch 3:


Batch loss = 0.621916: 100%|██████████| 2223/2223 [00:32<00:00, 68.41it/s]



Epoch 4:


Batch loss = 0.549419: 100%|██████████| 2223/2223 [00:32<00:00, 68.91it/s]



Epoch 5:


Batch loss = 0.464818: 100%|██████████| 2223/2223 [00:33<00:00, 66.91it/s]





100%|██████████| 7357/7357 [01:26<00:00, 85.23it/s] 
[32m[I 2023-04-12 09:47:28,580][0m Trial 2 finished with value: 0.10879014472525873 and parameters: {'emb_size': 32, 'batch_size': 128, 'lr': 0.0005445728346977897, 'l2_reg_coef': 0.0071334715811591475}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.682206: 100%|██████████| 2223/2223 [00:31<00:00, 70.62it/s]



Epoch 2:


Batch loss = 0.68173: 100%|██████████| 2223/2223 [00:32<00:00, 67.40it/s]



Epoch 3:


Batch loss = 0.6812: 100%|██████████| 2223/2223 [00:32<00:00, 69.22it/s]



Epoch 4:


Batch loss = 0.680727: 100%|██████████| 2223/2223 [00:32<00:00, 68.56it/s]



Epoch 5:


Batch loss = 0.68024: 100%|██████████| 2223/2223 [00:33<00:00, 67.28it/s]





100%|██████████| 7357/7357 [01:26<00:00, 85.38it/s] 
[32m[I 2023-04-12 09:51:36,949][0m Trial 3 finished with value: 0.10321574200866855 and parameters: {'emb_size': 64, 'batch_size': 128, 'lr': 2.3005803026525478e-05, 'l2_reg_coef': 0.0007981787657617986}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.833505: 100%|██████████| 2223/2223 [00:32<00:00, 68.90it/s]



Epoch 2:


Batch loss = 2.664169: 100%|██████████| 2223/2223 [00:31<00:00, 70.18it/s]



Epoch 3:


Batch loss = 3.135139: 100%|██████████| 2223/2223 [00:33<00:00, 67.36it/s]



Epoch 4:


Batch loss = 1.573329: 100%|██████████| 2223/2223 [00:31<00:00, 69.51it/s]



Epoch 5:


Batch loss = 1.743166: 100%|██████████| 2223/2223 [00:32<00:00, 68.87it/s]





100%|██████████| 7357/7357 [01:26<00:00, 85.29it/s]
[32m[I 2023-04-12 09:55:44,605][0m Trial 4 finished with value: 0.07475734171869129 and parameters: {'emb_size': 64, 'batch_size': 128, 'lr': 0.02675476925825261, 'l2_reg_coef': 0.0011349008421937211}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.857236: 100%|██████████| 2223/2223 [00:32<00:00, 67.89it/s]



Epoch 2:


Batch loss = 2.568643: 100%|██████████| 2223/2223 [00:31<00:00, 70.04it/s]



Epoch 3:


Batch loss = 1.462181: 100%|██████████| 2223/2223 [00:32<00:00, 67.47it/s]



Epoch 4:


Batch loss = 0.631801: 100%|██████████| 2223/2223 [00:31<00:00, 69.69it/s]



Epoch 5:


Batch loss = 1.10624: 100%|██████████| 2223/2223 [00:31<00:00, 70.44it/s]





100%|██████████| 7357/7357 [01:25<00:00, 86.22it/s] 
[32m[I 2023-04-12 09:59:51,091][0m Trial 5 finished with value: 0.08499414741183456 and parameters: {'emb_size': 128, 'batch_size': 128, 'lr': 0.01924964086209924, 'l2_reg_coef': 0.0002840900733622218}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 1.031451: 100%|██████████| 4445/4445 [00:55<00:00, 80.80it/s]



Epoch 2:


Batch loss = 11.290731: 100%|██████████| 4445/4445 [00:53<00:00, 82.61it/s]



Epoch 3:


Batch loss = 7.671795: 100%|██████████| 4445/4445 [00:54<00:00, 81.88it/s]



Epoch 4:


Batch loss = 11.68211: 100%|██████████| 4445/4445 [00:54<00:00, 81.98it/s]



Epoch 5:


Batch loss = 10.735609: 100%|██████████| 4445/4445 [00:54<00:00, 82.02it/s]





100%|██████████| 7357/7357 [01:28<00:00, 82.74it/s]
[32m[I 2023-04-12 10:05:51,784][0m Trial 6 finished with value: 0.04436801289454063 and parameters: {'emb_size': 128, 'batch_size': 64, 'lr': 0.040862698316257884, 'l2_reg_coef': 0.00401489180670258}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.676434: 100%|██████████| 1112/1112 [00:22<00:00, 50.22it/s]



Epoch 2:


Batch loss = 0.676185: 100%|██████████| 1112/1112 [00:21<00:00, 51.52it/s]



Epoch 3:


Batch loss = 0.675931: 100%|██████████| 1112/1112 [00:22<00:00, 49.61it/s]


Epoch 4:



Batch loss = 0.675676: 100%|██████████| 1112/1112 [00:22<00:00, 48.99it/s]



Epoch 5:


Batch loss = 0.675415: 100%|██████████| 1112/1112 [00:21<00:00, 51.09it/s]





100%|██████████| 7357/7357 [01:23<00:00, 88.10it/s] 
[32m[I 2023-04-12 10:09:06,017][0m Trial 7 finished with value: 0.10209699035910465 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 1.4448968183511732e-05, 'l2_reg_coef': 0.0019430167088096818}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.651664: 100%|██████████| 4445/4445 [00:52<00:00, 84.35it/s]



Epoch 2:


Batch loss = 0.634136: 100%|██████████| 4445/4445 [00:53<00:00, 82.37it/s]



Epoch 3:


Batch loss = 0.572745: 100%|██████████| 4445/4445 [00:52<00:00, 84.00it/s]



Epoch 4:


Batch loss = 0.491763: 100%|██████████| 4445/4445 [00:53<00:00, 82.73it/s]



Epoch 5:


Batch loss = 0.417729: 100%|██████████| 4445/4445 [00:52<00:00, 84.05it/s]





100%|██████████| 7357/7357 [01:23<00:00, 87.84it/s] 
[32m[I 2023-04-12 10:14:56,125][0m Trial 8 finished with value: 0.10864957962452208 and parameters: {'emb_size': 64, 'batch_size': 64, 'lr': 0.00046777453460422386, 'l2_reg_coef': 0.006557415353196482}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.669242: 100%|██████████| 2223/2223 [00:31<00:00, 69.68it/s]



Epoch 2:


Batch loss = 0.522703: 100%|██████████| 2223/2223 [00:32<00:00, 69.42it/s]



Epoch 3:


Batch loss = 0.360478: 100%|██████████| 2223/2223 [00:32<00:00, 69.30it/s]



Epoch 4:


Batch loss = 0.388069: 100%|██████████| 2223/2223 [00:32<00:00, 69.38it/s]



Epoch 5:


Batch loss = 0.259352: 100%|██████████| 2223/2223 [00:32<00:00, 68.09it/s]





100%|██████████| 7357/7357 [01:23<00:00, 88.23it/s] 
[32m[I 2023-04-12 10:19:00,365][0m Trial 9 finished with value: 0.1084243556132671 and parameters: {'emb_size': 64, 'batch_size': 128, 'lr': 0.0016994670732385824, 'l2_reg_coef': 0.0049718848264030035}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.61905: 100%|██████████| 8889/8889 [01:38<00:00, 90.34it/s]



Epoch 2:


Batch loss = 0.606768: 100%|██████████| 8889/8889 [01:38<00:00, 89.94it/s]



Epoch 3:


Batch loss = 0.596979: 100%|██████████| 8889/8889 [01:38<00:00, 90.07it/s]



Epoch 4:


Batch loss = 0.563625: 100%|██████████| 8889/8889 [01:38<00:00, 89.88it/s]



Epoch 5:


Batch loss = 0.542626: 100%|██████████| 8889/8889 [01:38<00:00, 90.13it/s]





100%|██████████| 7357/7357 [01:24<00:00, 87.33it/s]
[32m[I 2023-04-12 10:28:38,206][0m Trial 10 finished with value: 0.10657291800219418 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 0.0003497063710293824, 'l2_reg_coef': 0.03295582896915084}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.658695: 100%|██████████| 4445/4445 [00:54<00:00, 82.24it/s]


Epoch 2:



Batch loss = 0.647506: 100%|██████████| 4445/4445 [00:55<00:00, 80.54it/s]



Epoch 3:


Batch loss = 0.630427: 100%|██████████| 4445/4445 [00:55<00:00, 80.39it/s]



Epoch 4:


Batch loss = 0.604138: 100%|██████████| 4445/4445 [00:54<00:00, 81.04it/s]



Epoch 5:


Batch loss = 0.558073: 100%|██████████| 4445/4445 [00:55<00:00, 80.30it/s]





100%|██████████| 7357/7357 [01:25<00:00, 85.98it/s] 
[32m[I 2023-04-12 10:34:38,705][0m Trial 11 finished with value: 0.10732535706930774 and parameters: {'emb_size': 64, 'batch_size': 64, 'lr': 0.0002532791620038323, 'l2_reg_coef': 0.01313447029023291}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.642512: 100%|██████████| 4445/4445 [00:55<00:00, 79.81it/s]



Epoch 2:


Batch loss = 0.598326: 100%|██████████| 4445/4445 [00:55<00:00, 79.68it/s]



Epoch 3:


Batch loss = 0.506142: 100%|██████████| 4445/4445 [00:56<00:00, 79.20it/s]



Epoch 4:


Batch loss = 0.412063: 100%|██████████| 4445/4445 [00:55<00:00, 80.17it/s]


Epoch 5:



Batch loss = 0.37653: 100%|██████████| 4445/4445 [00:54<00:00, 81.18it/s]





100%|██████████| 7357/7357 [01:26<00:00, 85.47it/s]
[32m[I 2023-04-12 10:40:42,722][0m Trial 12 finished with value: 0.10764018620532262 and parameters: {'emb_size': 32, 'batch_size': 64, 'lr': 0.0011382008633000968, 'l2_reg_coef': 0.011537405259754371}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.675127: 100%|██████████| 1112/1112 [00:22<00:00, 48.98it/s]



Epoch 2:


Batch loss = 0.672954: 100%|██████████| 1112/1112 [00:22<00:00, 50.27it/s]



Epoch 3:


Batch loss = 0.670797: 100%|██████████| 1112/1112 [00:23<00:00, 48.28it/s]



Epoch 4:


Batch loss = 0.66842: 100%|██████████| 1112/1112 [00:21<00:00, 50.84it/s]


Epoch 5:



Batch loss = 0.665001: 100%|██████████| 1112/1112 [00:22<00:00, 48.96it/s]





100%|██████████| 7357/7357 [01:23<00:00, 87.69it/s] 
[32m[I 2023-04-12 10:43:59,257][0m Trial 13 finished with value: 0.10558491517497041 and parameters: {'emb_size': 64, 'batch_size': 256, 'lr': 0.00013084027307985277, 'l2_reg_coef': 0.057496589988963454}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.645548: 100%|██████████| 4445/4445 [00:54<00:00, 82.23it/s]



Epoch 2:


Batch loss = 0.539046: 100%|██████████| 4445/4445 [00:54<00:00, 81.92it/s]



Epoch 3:


Batch loss = 0.437771: 100%|██████████| 4445/4445 [00:54<00:00, 82.25it/s]


Epoch 4:



Batch loss = 0.340084: 100%|██████████| 4445/4445 [00:54<00:00, 82.24it/s]



Epoch 5:


Batch loss = 0.32388: 100%|██████████| 4445/4445 [00:54<00:00, 81.15it/s]





100%|██████████| 7357/7357 [01:22<00:00, 88.77it/s]
[32m[I 2023-04-12 10:49:53,447][0m Trial 14 finished with value: 0.10373790612427843 and parameters: {'emb_size': 32, 'batch_size': 64, 'lr': 0.0033353274115153057, 'l2_reg_coef': 0.008196335630364664}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.768393: 100%|██████████| 4445/4445 [00:54<00:00, 81.57it/s]



Epoch 2:


Batch loss = 22.522455: 100%|██████████| 4445/4445 [00:54<00:00, 82.29it/s]



Epoch 3:


Batch loss = 27.614145: 100%|██████████| 4445/4445 [00:54<00:00, 81.55it/s]



Epoch 4:


Batch loss = 38.399331: 100%|██████████| 4445/4445 [00:54<00:00, 82.25it/s]



Epoch 5:


Batch loss = 12.170757: 100%|██████████| 4445/4445 [00:54<00:00, 81.83it/s]





100%|██████████| 7357/7357 [01:23<00:00, 87.96it/s] 
[32m[I 2023-04-12 10:55:48,664][0m Trial 15 finished with value: 0.03820782438102939 and parameters: {'emb_size': 64, 'batch_size': 64, 'lr': 0.09195886437027284, 'l2_reg_coef': 0.00010314684086731772}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.616208: 100%|██████████| 8889/8889 [01:38<00:00, 90.64it/s]



Epoch 2:


Batch loss = 0.606625: 100%|██████████| 8889/8889 [01:37<00:00, 91.09it/s] 



Epoch 3:


Batch loss = 0.590604: 100%|██████████| 8889/8889 [01:38<00:00, 90.62it/s]



Epoch 4:


Batch loss = 0.548231: 100%|██████████| 8889/8889 [01:37<00:00, 91.24it/s]



Epoch 5:


Batch loss = 0.566369: 100%|██████████| 8889/8889 [01:37<00:00, 90.88it/s]





100%|██████████| 7357/7357 [01:23<00:00, 88.05it/s] 
[32m[I 2023-04-12 11:05:21,328][0m Trial 16 finished with value: 0.10614009419415746 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 0.0004159632358936015, 'l2_reg_coef': 0.019782709294638965}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.675782: 100%|██████████| 1112/1112 [00:22<00:00, 49.42it/s]


Epoch 2:



Batch loss = 0.674435: 100%|██████████| 1112/1112 [00:21<00:00, 51.33it/s]


Epoch 3:



Batch loss = 0.673201: 100%|██████████| 1112/1112 [00:23<00:00, 48.23it/s]


Epoch 4:



Batch loss = 0.671986: 100%|██████████| 1112/1112 [00:22<00:00, 49.68it/s]


Epoch 5:



Batch loss = 0.670762: 100%|██████████| 1112/1112 [00:23<00:00, 48.16it/s]





100%|██████████| 7357/7357 [01:25<00:00, 86.31it/s]
[32m[I 2023-04-12 11:08:39,386][0m Trial 17 finished with value: 0.10384149178501334 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 7.924239443135598e-05, 'l2_reg_coef': 0.09431328085393048}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.644798: 100%|██████████| 4445/4445 [00:53<00:00, 82.34it/s]



Epoch 2:


Batch loss = 0.582986: 100%|██████████| 4445/4445 [00:55<00:00, 79.82it/s]



Epoch 3:


Batch loss = 0.439004: 100%|██████████| 4445/4445 [00:58<00:00, 75.69it/s]



Epoch 4:


Batch loss = 0.39313: 100%|██████████| 4445/4445 [00:56<00:00, 78.05it/s]


Epoch 5:



Batch loss = 0.370454: 100%|██████████| 4445/4445 [00:57<00:00, 77.95it/s]





100%|██████████| 7357/7357 [01:26<00:00, 85.22it/s] 
[32m[I 2023-04-12 11:14:48,272][0m Trial 18 finished with value: 0.10846560648863778 and parameters: {'emb_size': 64, 'batch_size': 64, 'lr': 0.0008831680921026232, 'l2_reg_coef': 0.004540684678764227}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.671504: 100%|██████████| 2223/2223 [00:34<00:00, 65.22it/s]



Epoch 2:


Batch loss = 0.306908: 100%|██████████| 2223/2223 [00:33<00:00, 66.28it/s]



Epoch 3:


Batch loss = 0.264469: 100%|██████████| 2223/2223 [00:34<00:00, 65.30it/s]



Epoch 4:


Batch loss = 0.153553: 100%|██████████| 2223/2223 [00:33<00:00, 66.64it/s]



Epoch 5:


Batch loss = 0.117878: 100%|██████████| 2223/2223 [00:33<00:00, 66.39it/s]





100%|██████████| 7357/7357 [01:25<00:00, 85.68it/s]
[32m[I 2023-04-12 11:19:02,919][0m Trial 19 finished with value: 0.10802901909704897 and parameters: {'emb_size': 128, 'batch_size': 128, 'lr': 0.00345980916161105, 'l2_reg_coef': 0.0024704198299284556}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.669094: 100%|██████████| 4445/4445 [00:56<00:00, 78.90it/s]



Epoch 2:


Batch loss = 0.665725: 100%|██████████| 4445/4445 [00:55<00:00, 80.59it/s]



Epoch 3:


Batch loss = 0.662558: 100%|██████████| 4445/4445 [00:56<00:00, 79.01it/s]



Epoch 4:


Batch loss = 0.659532: 100%|██████████| 4445/4445 [00:56<00:00, 79.14it/s]



Epoch 5:


Batch loss = 0.656721: 100%|██████████| 4445/4445 [00:56<00:00, 79.20it/s]





100%|██████████| 7357/7357 [01:24<00:00, 87.53it/s] 
[32m[I 2023-04-12 11:25:07,170][0m Trial 20 finished with value: 0.10482662578760765 and parameters: {'emb_size': 32, 'batch_size': 64, 'lr': 5.162371108918804e-05, 'l2_reg_coef': 0.007497163703650436}. Best is trial 2 with value: 0.10879014472525873.[0m


Epoch 1:


Batch loss = 0.648384: 100%|██████████| 4445/4445 [00:56<00:00, 78.38it/s]



Epoch 2:


Batch loss = 0.60499: 100%|██████████| 4445/4445 [00:55<00:00, 80.55it/s]



Epoch 3:


Batch loss = 0.504455: 100%|██████████| 4445/4445 [00:55<00:00, 79.83it/s]



Epoch 4:


Batch loss = 0.462011: 100%|██████████| 4445/4445 [00:56<00:00, 77.98it/s]



Epoch 5:


Batch loss = 0.407935: 100%|██████████| 4445/4445 [00:56<00:00, 79.07it/s]





100%|██████████| 7357/7357 [01:24<00:00, 86.79it/s] 
[32m[I 2023-04-12 11:31:12,910][0m Trial 21 finished with value: 0.1093370103339033 and parameters: {'emb_size': 64, 'batch_size': 64, 'lr': 0.0006142297613045982, 'l2_reg_coef': 0.0047331742711911855}. Best is trial 21 with value: 0.1093370103339033.[0m


Epoch 1:


Batch loss = 0.649938: 100%|██████████| 4445/4445 [00:56<00:00, 78.84it/s]



Epoch 2:


Batch loss = 0.638338: 100%|██████████| 4445/4445 [00:56<00:00, 78.11it/s]



Epoch 3:


Batch loss = 0.583207: 100%|██████████| 4445/4445 [00:56<00:00, 78.91it/s]


Epoch 4:



Batch loss = 0.515884: 100%|██████████| 4445/4445 [00:55<00:00, 80.22it/s]



Epoch 5:


Batch loss = 0.469672: 100%|██████████| 4445/4445 [00:56<00:00, 78.27it/s]





100%|██████████| 7357/7357 [01:25<00:00, 85.63it/s]
[32m[I 2023-04-12 11:37:20,806][0m Trial 22 finished with value: 0.10709916061030145 and parameters: {'emb_size': 64, 'batch_size': 64, 'lr': 0.0005380601448932815, 'l2_reg_coef': 0.020736461016104896}. Best is trial 21 with value: 0.1093370103339033.[0m


Epoch 1:


Batch loss = 0.6611: 100%|██████████| 4445/4445 [00:57<00:00, 77.08it/s]



Epoch 2:


Batch loss = 0.650369: 100%|██████████| 4445/4445 [00:57<00:00, 76.98it/s]



Epoch 3:


Batch loss = 0.636753: 100%|██████████| 4445/4445 [00:57<00:00, 77.26it/s]



Epoch 4:


Batch loss = 0.61219: 100%|██████████| 4445/4445 [00:57<00:00, 77.59it/s]



Epoch 5:


Batch loss = 0.581824: 100%|██████████| 4445/4445 [00:57<00:00, 77.90it/s]





100%|██████████| 7357/7357 [01:26<00:00, 85.45it/s]
[32m[I 2023-04-12 11:43:34,369][0m Trial 23 finished with value: 0.1076340302967333 and parameters: {'emb_size': 64, 'batch_size': 64, 'lr': 0.00020289243236411598, 'l2_reg_coef': 0.00650703755714042}. Best is trial 21 with value: 0.1093370103339033.[0m


Epoch 1:


Batch loss = 0.647944: 100%|██████████| 4445/4445 [00:58<00:00, 76.32it/s]



Epoch 2:


Batch loss = 0.604464: 100%|██████████| 4445/4445 [00:58<00:00, 76.13it/s]



Epoch 3:


Batch loss = 0.482702: 100%|██████████| 4445/4445 [00:58<00:00, 76.57it/s]



Epoch 4:


Batch loss = 0.38924: 100%|██████████| 4445/4445 [00:57<00:00, 77.34it/s]



Epoch 5:


Batch loss = 0.363346: 100%|██████████| 4445/4445 [00:57<00:00, 77.60it/s]





100%|██████████| 7357/7357 [01:25<00:00, 85.93it/s]
[32m[I 2023-04-12 11:49:49,619][0m Trial 24 finished with value: 0.10904024662631921 and parameters: {'emb_size': 64, 'batch_size': 64, 'lr': 0.0006250653189707479, 'l2_reg_coef': 0.0028470312329004627}. Best is trial 21 with value: 0.1093370103339033.[0m


Evaluate SLRC on the test set (scores are computed for several random seeds and then averaged):

In [11]:
# Trainer for the final runs: up to 20 epochs with early_stop_num=3.
# NOTE(review): early_stop_num is presumably the number of epochs without dev
# improvement tolerated before stopping — confirm in NBRTrainer.
trainer = NBRTrainer(
    corpus=corpus,
    max_epochs=20,
    topk=10,
    early_stop_num=3
)

train dataset preparing...


100%|██████████| 7358/7358 [00:11<00:00, 641.34it/s]


dev dataset preparing...


100%|██████████| 7357/7357 [00:02<00:00, 3669.10it/s]


test dataset preparing...


100%|██████████| 7357/7357 [00:01<00:00, 4187.58it/s]


In [12]:
# Test scores collected per seed: one independent, initially empty list per metric.
test_metrics = {metric_name: [] for metric_name in ("precision", "recall", "ndcg")}

In [None]:
# Train SLRC with the best hyperparameters found by the study under five
# different seeds and record the test metrics of every run.

# Start from empty lists so that re-running this cell does not append a second
# batch of scores on top of the previous one.
for scores in test_metrics.values():
    scores.clear()

# The loop variable is `run_seed` rather than `seed` so that the notebook-level
# `seed` (used to seed the Optuna sampler) is not overwritten by this cell.
for run_seed in range(5):
    print(f"\n___SEED___{run_seed}")
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    np.random.seed(run_seed)

    # A fresh model is built for every seed from the tuned hyperparameters.
    params = {
        "model": SLRC(
            base_model_class=BPR,
            base_model_config={
                "emb_size": study.best_params["emb_size"],
                "user_num": corpus.n_users,
                "item_num": corpus.n_items,
                "click_num": corpus.n_clicks
            },
            item_num=corpus.n_items,
            avg_repeat_interval=corpus.total_avg_interval
        ),
        "batch_size": study.best_params["batch_size"],
        "lr": study.best_params["lr"],
        "l2_reg_coef": study.best_params["l2_reg_coef"]
    }

    trainer.init_hyperparams(**params)
    trainer.train()

    metrics = trainer.evaluate(mode="test")

    test_metrics["precision"].append(metrics["precision"])
    test_metrics["recall"].append(metrics["recall"])
    test_metrics["ndcg"].append(metrics["ndcg"])


___SEED___0
Epoch 1:


Batch loss = 0.647844: 100%|██████████| 4445/4445 [00:54<00:00, 81.41it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 86.86it/s]


 {'precision': 0.05306510806035069, 'recall': 0.11760657494409904, 'ndcg': 0.10338726165353773}
Epoch 2:



Batch loss = 0.613529: 100%|██████████| 4445/4445 [00:51<00:00, 86.26it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:23<00:00, 87.77it/s] 


 {'precision': 0.055280685061845865, 'recall': 0.12787036008393035, 'ndcg': 0.10752063171562534}
Epoch 3:



Batch loss = 0.484176: 100%|██████████| 4445/4445 [00:52<00:00, 84.15it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 87.27it/s]


 {'precision': 0.05630012233247248, 'recall': 0.13254850400181878, 'ndcg': 0.1095571150683061}
Epoch 4:



Batch loss = 0.412783: 100%|██████████| 4445/4445 [00:52<00:00, 84.40it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 87.03it/s]


 {'precision': 0.056069049884463776, 'recall': 0.13089943342672652, 'ndcg': 0.10860200062750996}
Epoch 5:



Batch loss = 0.40947: 100%|██████████| 4445/4445 [00:53<00:00, 83.44it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 86.79it/s]


 {'precision': 0.055579719994563, 'recall': 0.12869468613279975, 'ndcg': 0.1081829218646605}
Epoch 6:



Batch loss = 0.405149: 100%|██████████| 4445/4445 [00:52<00:00, 84.53it/s]



Evaluation (dev):


100%|██████████| 7357/7357 [01:24<00:00, 86.66it/s] 


 {'precision': 0.05598749490281364, 'recall': 0.13105695275419954, 'ndcg': 0.1091629061070286}



100%|██████████| 7357/7357 [01:25<00:00, 86.47it/s] 


___SEED___1
Epoch 1:



Batch loss = 0.648115: 100%|██████████| 4445/4445 [00:52<00:00, 84.91it/s]



Evaluation (dev):


100%|██████████| 7357/7357 [01:23<00:00, 88.15it/s]


 {'precision': 0.052983553078700556, 'recall': 0.11730884025233931, 'ndcg': 0.10326253450365228}





Epoch 2:


Batch loss = 0.618138: 100%|██████████| 4445/4445 [00:53<00:00, 83.14it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 87.00it/s]


 {'precision': 0.05518553758325404, 'recall': 0.12778208290447507, 'ndcg': 0.10736117864801158}
Epoch 3:



Batch loss = 0.49879: 100%|██████████| 4445/4445 [00:53<00:00, 82.73it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 86.69it/s]


 {'precision': 0.05597390240587196, 'recall': 0.13212535969929048, 'ndcg': 0.10932099089009956}
Epoch 4:



Batch loss = 0.498291: 100%|██████████| 4445/4445 [00:53<00:00, 82.76it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 86.75it/s] 


 {'precision': 0.056286529835530785, 'recall': 0.1320072938810533, 'ndcg': 0.10893477085005902}
Epoch 5:



Batch loss = 0.393214: 100%|██████████| 4445/4445 [00:52<00:00, 84.38it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 87.58it/s]


 {'precision': 0.05566127497621313, 'recall': 0.128561368615953, 'ndcg': 0.10802543395474196}
Epoch 6:



Batch loss = 0.372129: 100%|██████████| 4445/4445 [00:52<00:00, 84.28it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:25<00:00, 86.52it/s]


 {'precision': 0.05598749490281364, 'recall': 0.12938509283272898, 'ndcg': 0.1089098140710593}



100%|██████████| 7357/7357 [01:24<00:00, 86.56it/s]


___SEED___2
Epoch 1:



Batch loss = 0.64851: 100%|██████████| 4445/4445 [00:53<00:00, 83.76it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:23<00:00, 87.59it/s]


 {'precision': 0.052942775587875496, 'recall': 0.11726026876172108, 'ndcg': 0.10313874782031339}
Epoch 2:



Batch loss = 0.613944: 100%|██████████| 4445/4445 [00:54<00:00, 82.21it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:25<00:00, 86.49it/s] 


 {'precision': 0.0552263150740791, 'recall': 0.12784914315486998, 'ndcg': 0.10745646263383438}





Epoch 3:


Batch loss = 0.515254: 100%|██████████| 4445/4445 [00:53<00:00, 83.67it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 87.10it/s]


 {'precision': 0.05615060486611391, 'recall': 0.1321609177451479, 'ndcg': 0.10964041700697584}





Epoch 4:


Batch loss = 0.40877: 100%|██████████| 4445/4445 [00:53<00:00, 82.51it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 87.06it/s]


 {'precision': 0.05593312491504689, 'recall': 0.12980271658558626, 'ndcg': 0.10834286228903398}
Epoch 5:



Batch loss = 0.396547: 100%|██████████| 4445/4445 [00:53<00:00, 83.17it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 87.50it/s]



 {'precision': 0.056082642381405465, 'recall': 0.13033290741866962, 'ndcg': 0.10892434725283648}
Epoch 6:


Batch loss = 0.359622: 100%|██████████| 4445/4445 [00:54<00:00, 82.07it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 87.11it/s] 


 {'precision': 0.056259344841647414, 'recall': 0.13130733438514197, 'ndcg': 0.1093480871326376}



100%|██████████| 7357/7357 [01:24<00:00, 87.12it/s]



___SEED___3
Epoch 1:


Batch loss = 0.648213: 100%|██████████| 4445/4445 [00:53<00:00, 82.79it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 87.14it/s]


 {'precision': 0.05292918309093381, 'recall': 0.11722428625684035, 'ndcg': 0.10334206215700821}





Epoch 2:


Batch loss = 0.613215: 100%|██████████| 4445/4445 [00:54<00:00, 81.34it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 86.90it/s] 


 {'precision': 0.055307870055729236, 'recall': 0.12817648204870014, 'ndcg': 0.10749804528714656}
Epoch 3:



Batch loss = 0.522081: 100%|██████████| 4445/4445 [00:53<00:00, 82.88it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:25<00:00, 86.45it/s]


 {'precision': 0.056096234878347154, 'recall': 0.1320079945827257, 'ndcg': 0.10941848721221557}
Epoch 4:



Batch loss = 0.455008: 100%|██████████| 4445/4445 [00:54<00:00, 80.85it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:25<00:00, 86.48it/s]


 {'precision': 0.05605545738752209, 'recall': 0.13052547489044342, 'ndcg': 0.10846417290285777}





Epoch 5:


Batch loss = 0.462366: 100%|██████████| 4445/4445 [00:54<00:00, 81.78it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:26<00:00, 85.41it/s]


 {'precision': 0.05559331249150469, 'recall': 0.12858402374791061, 'ndcg': 0.10827293938903636}
Epoch 6:



Batch loss = 0.455538: 100%|██████████| 4445/4445 [00:55<00:00, 80.72it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:25<00:00, 85.93it/s] 


 {'precision': 0.056449639798831046, 'recall': 0.13171414304460102, 'ndcg': 0.10979318641736421}
Epoch 7:



Batch loss = 0.354771: 100%|██████████| 4445/4445 [00:54<00:00, 81.28it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:25<00:00, 86.44it/s]


 {'precision': 0.056123419872230525, 'recall': 0.13106735390259025, 'ndcg': 0.10969199779528051}
Epoch 8:



Batch loss = 0.431846: 100%|██████████| 4445/4445 [00:54<00:00, 81.07it/s]



Evaluation (dev):


100%|██████████| 7357/7357 [01:24<00:00, 86.98it/s]


 {'precision': 0.056272937338589096, 'recall': 0.13061539721897414, 'ndcg': 0.10898434919210259}
Epoch 9:



Batch loss = 0.355546: 100%|██████████| 4445/4445 [00:54<00:00, 81.27it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 87.20it/s]


 {'precision': 0.05598749490281364, 'recall': 0.1302838275905165, 'ndcg': 0.10907176493139453}



100%|██████████| 7357/7357 [01:24<00:00, 86.75it/s]


___SEED___4
Epoch 1:



Batch loss = 0.648428: 100%|██████████| 4445/4445 [00:54<00:00, 81.55it/s]



Evaluation (dev):


100%|██████████| 7357/7357 [01:24<00:00, 86.83it/s] 


 {'precision': 0.05283403561234199, 'recall': 0.11688602227934589, 'ndcg': 0.10315941097580893}
Epoch 2:



Batch loss = 0.617539: 100%|██████████| 4445/4445 [00:54<00:00, 81.41it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 86.69it/s]


 {'precision': 0.055280685061845865, 'recall': 0.12808553041197648, 'ndcg': 0.1075739078236309}
Epoch 3:



Batch loss = 0.526964: 100%|██████████| 4445/4445 [00:54<00:00, 81.00it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 86.82it/s] 


 {'precision': 0.05610982737528884, 'recall': 0.1318116758528083, 'ndcg': 0.10937556278952634}
Epoch 4:



Batch loss = 0.467776: 100%|██████████| 4445/4445 [00:55<00:00, 80.65it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 87.33it/s]


 {'precision': 0.05589234742422183, 'recall': 0.13093346527700872, 'ndcg': 0.10873333472870504}
Epoch 5:



Batch loss = 0.472686: 100%|██████████| 4445/4445 [00:55<00:00, 80.46it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:25<00:00, 86.18it/s]


 {'precision': 0.05564768247927144, 'recall': 0.12844266576522984, 'ndcg': 0.10794656347332765}
Epoch 6:



Batch loss = 0.3681: 100%|██████████| 4445/4445 [00:56<00:00, 79.26it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:26<00:00, 85.00it/s]


 {'precision': 0.05593312491504689, 'recall': 0.12913710209741236, 'ndcg': 0.10902443637702317}



100%|██████████| 7357/7357 [01:25<00:00, 85.88it/s]


In [None]:
# Mean of every test metric over the seeded runs (displayed as the cell output).
{
    metric_name: np.asarray(test_metrics[metric_name]).mean()
    for metric_name in ("precision", "recall", "ndcg")
}

{'precision': 0.06205790403697159,
 'recall': 0.14575089650712814,
 'ndcg': 0.11939254118955867}

# TaoBao

Fix seed:

In [14]:
# Notebook-level seed; it is also passed to the Optuna sampler further down.
seed = 10
# Seed Python's `random`, PyTorch and NumPy global generators with the same value.
# (np.random.seed stays last so the cell keeps displaying no output.)
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

Read the interaction data and filter it: drop users with fewer than 10 transactions, users with a high purchase frequency, one-day users, and items with fewer than 10 transactions. The training set contains all baskets except the last two, the validation set is the second-to-last basket, and the test set is the last basket:

In [15]:
# Location of the raw/processed data and the dataset to work on.
corpus_path = "./data/"
dataset_name = "taobao"

# Load and filter the raw interactions, then write the split to disk.
# NOTE(review): 10 and 10 are presumably the minimum number of transactions per
# user and per item described in the text above — confirm the argument order
# against Preprocess.load_data.
preprocessor = Preprocess(corpus_path, dataset_name)
preprocessor.load_data(10, 10, filt=True)
save_split(corpus_path, dataset_name, preprocessor)

Before preprocessing: #users = 672404, #items = 638962, #clicks = 2015807 (#illegal records = 0)
After preprocessing: #users = 10092, #items = 22286, #clicks = 67991
Saving dataset in ./data//data_taobao/...


In [16]:
# Load the dataset for `dataset_name` from `corpus_path`
# (presumably the split written by save_split in the previous cell — confirm).
corpus = Corpus(corpus_path, dataset_name)
corpus.load_data()

Tune the hyperparameters on the validation dataset:

In [17]:
# Trainer used for the hyperparameter search: short runs of at most 2 epochs.
# NOTE(review): early_stop_num=None presumably disables early stopping and
# topk=10 is presumably the cutoff of the ranking metrics — confirm in NBRTrainer.
trainer = NBRTrainer(
    corpus=corpus,
    max_epochs=2,
    topk=10,
    early_stop_num=None
)

train dataset preparing...


100%|██████████| 10092/10092 [00:32<00:00, 305.85it/s]


dev dataset preparing...


100%|██████████| 9307/9307 [00:00<00:00, 28939.10it/s]


test dataset preparing...


100%|██████████| 9307/9307 [00:00<00:00, 16245.89it/s]


In [18]:
def objective(trial):
    """Optuna objective: train SLRC (BPR base model) with sampled hyperparameters.

    Relies on the notebook-level ``corpus`` and ``trainer`` objects.

    Args:
        trial: Optuna trial used to sample ``emb_size``, ``batch_size``, ``lr``
            and ``l2_reg_coef``.

    Returns:
        The ``ndcg`` value reported by ``trainer.evaluate(mode="dev")``.
    """
    params = {
        "model": SLRC(
            base_model_class=BPR,
            base_model_config={
                "emb_size": trial.suggest_categorical("emb_size", [32, 64, 128]),
                "user_num": corpus.n_users,
                "item_num": corpus.n_items,
                "click_num": corpus.n_clicks
            },
            item_num=corpus.n_items,
            avg_repeat_interval=corpus.total_avg_interval
        ),
        "batch_size": trial.suggest_categorical("batch_size", [32, 64, 128, 256]),
        # suggest_float(..., log=True) replaces the deprecated suggest_loguniform;
        # the bounds and the log-scale sampling are unchanged.
        "lr": trial.suggest_float("lr", 1e-5, 1e-1, log=True),
        "l2_reg_coef": trial.suggest_float("l2_reg_coef", 1e-4, 1e-1, log=True)
    }

    trainer.init_hyperparams(**params)
    # No per-epoch evaluation while tuning; the dev set is scored once after training.
    trainer.train(evaluation_flg=False)
    metrics = trainer.evaluate(mode="dev")
    return metrics["ndcg"]

In [None]:
# TPE sampler seeded with the notebook-level `seed`.
sampler = optuna.samplers.TPESampler(seed=seed)
# Maximize the value returned by `objective` (dev NDCG) over 25 trials.
study = optuna.create_study(direction="maximize", sampler=sampler)
study.optimize(objective, n_trials=25)

[32m[I 2023-04-21 04:44:33,608][0m A new study created in memory with name: no-name-db970ecf-aabd-4706-bcd7-62daebe89b2b[0m


Epoch 1:


Batch loss = 0.65354: 100%|██████████| 1522/1522 [00:14<00:00, 106.38it/s]



Epoch 2:


Batch loss = 0.394244: 100%|██████████| 1522/1522 [00:16<00:00, 94.53it/s] 





100%|██████████| 9307/9307 [02:59<00:00, 51.94it/s]
[32m[I 2023-04-21 04:48:08,754][0m Trial 0 finished with value: 0.06793688331488305 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 0.011018509458263562, 'l2_reg_coef': 0.00032161219616553475}. Best is trial 0 with value: 0.06793688331488305.[0m


Epoch 1:


Batch loss = 0.652984: 100%|██████████| 381/381 [00:04<00:00, 78.62it/s]



Epoch 2:


Batch loss = 0.337408: 100%|██████████| 381/381 [00:04<00:00, 86.01it/s]





100%|██████████| 9307/9307 [02:58<00:00, 52.00it/s]
[32m[I 2023-04-21 04:51:17,216][0m Trial 1 finished with value: 0.07072060675888366 and parameters: {'emb_size': 128, 'batch_size': 128, 'lr': 0.007709412252614239, 'l2_reg_coef': 0.0007509797119626302}. Best is trial 1 with value: 0.07072060675888366.[0m


Epoch 1:


Batch loss = 0.664989: 100%|██████████| 381/381 [00:04<00:00, 84.90it/s]



Epoch 2:


Batch loss = 0.661288: 100%|██████████| 381/381 [00:04<00:00, 80.71it/s]





100%|██████████| 9307/9307 [02:57<00:00, 52.54it/s]
[32m[I 2023-04-21 04:54:23,645][0m Trial 2 finished with value: 0.07247172830451426 and parameters: {'emb_size': 32, 'batch_size': 128, 'lr': 0.0005445728346977897, 'l2_reg_coef': 0.0071334715811591475}. Best is trial 2 with value: 0.07247172830451426.[0m


Epoch 1:


Batch loss = 0.667229: 100%|██████████| 381/381 [00:04<00:00, 87.34it/s]



Epoch 2:


Batch loss = 0.666966: 100%|██████████| 381/381 [00:05<00:00, 71.55it/s]





100%|██████████| 9307/9307 [02:55<00:00, 52.93it/s]
[32m[I 2023-04-21 04:57:29,277][0m Trial 3 finished with value: 0.07246088640925542 and parameters: {'emb_size': 64, 'batch_size': 128, 'lr': 2.3005803026525478e-05, 'l2_reg_coef': 0.0007981787657617986}. Best is trial 2 with value: 0.07247172830451426.[0m


Epoch 1:


Batch loss = 0.661801: 100%|██████████| 381/381 [00:04<00:00, 78.50it/s]



Epoch 2:


Batch loss = 0.571593: 100%|██████████| 381/381 [00:04<00:00, 76.66it/s]





100%|██████████| 9307/9307 [02:55<00:00, 53.01it/s]
[32m[I 2023-04-21 05:00:34,802][0m Trial 4 finished with value: 0.05515138070966771 and parameters: {'emb_size': 64, 'batch_size': 128, 'lr': 0.02675476925825261, 'l2_reg_coef': 0.0011349008421937211}. Best is trial 2 with value: 0.07247172830451426.[0m


Epoch 1:


Batch loss = 0.655495: 100%|██████████| 381/381 [00:05<00:00, 71.74it/s]



Epoch 2:


Batch loss = 0.277625: 100%|██████████| 381/381 [00:04<00:00, 82.74it/s]





100%|██████████| 9307/9307 [02:59<00:00, 51.91it/s]
[32m[I 2023-04-21 05:03:44,205][0m Trial 5 finished with value: 0.06261106175533525 and parameters: {'emb_size': 128, 'batch_size': 128, 'lr': 0.01924964086209924, 'l2_reg_coef': 0.0002840900733622218}. Best is trial 2 with value: 0.07247172830451426.[0m


Epoch 1:


Batch loss = 0.785018: 100%|██████████| 761/761 [00:07<00:00, 101.40it/s]



Epoch 2:


Batch loss = 2.272465: 100%|██████████| 761/761 [00:08<00:00, 90.77it/s] 





100%|██████████| 9307/9307 [02:57<00:00, 52.46it/s]
[32m[I 2023-04-21 05:06:57,679][0m Trial 6 finished with value: 0.04484123712660121 and parameters: {'emb_size': 128, 'batch_size': 64, 'lr': 0.040862698316257884, 'l2_reg_coef': 0.00401489180670258}. Best is trial 2 with value: 0.07247172830451426.[0m


Epoch 1:


Batch loss = 0.671875: 100%|██████████| 191/191 [00:03<00:00, 51.86it/s]



Epoch 2:


Batch loss = 0.671842: 100%|██████████| 191/191 [00:03<00:00, 56.56it/s]





100%|██████████| 9307/9307 [02:56<00:00, 52.70it/s]
[32m[I 2023-04-21 05:10:01,424][0m Trial 7 finished with value: 0.07267752272163056 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 1.4448968183511732e-05, 'l2_reg_coef': 0.0019430167088096818}. Best is trial 7 with value: 0.07267752272163056.[0m


Epoch 1:


Batch loss = 0.648111: 100%|██████████| 761/761 [00:08<00:00, 88.63it/s] 



Epoch 2:


Batch loss = 0.640805: 100%|██████████| 761/761 [00:07<00:00, 99.13it/s]





100%|██████████| 9307/9307 [02:58<00:00, 52.05it/s]
[32m[I 2023-04-21 05:13:16,629][0m Trial 8 finished with value: 0.07200619138735719 and parameters: {'emb_size': 64, 'batch_size': 64, 'lr': 0.00046777453460422386, 'l2_reg_coef': 0.006557415353196482}. Best is trial 7 with value: 0.07267752272163056.[0m


Epoch 1:


Batch loss = 0.661261: 100%|██████████| 381/381 [00:04<00:00, 84.19it/s]



Epoch 2:


Batch loss = 0.641369: 100%|██████████| 381/381 [00:05<00:00, 70.44it/s]





100%|██████████| 9307/9307 [02:56<00:00, 52.74it/s]
[32m[I 2023-04-21 05:16:23,160][0m Trial 9 finished with value: 0.0714778532484395 and parameters: {'emb_size': 64, 'batch_size': 128, 'lr': 0.0016994670732385824, 'l2_reg_coef': 0.0049718848264030035}. Best is trial 7 with value: 0.07267752272163056.[0m


Epoch 1:


Batch loss = 0.67189: 100%|██████████| 191/191 [00:03<00:00, 59.09it/s]



Epoch 2:


Batch loss = 0.671843: 100%|██████████| 191/191 [00:03<00:00, 50.41it/s]





100%|██████████| 9307/9307 [02:56<00:00, 52.61it/s]
[32m[I 2023-04-21 05:19:27,204][0m Trial 10 finished with value: 0.07391601599813025 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 1.0851391597925009e-05, 'l2_reg_coef': 0.03239377807560215}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.671948: 100%|██████████| 191/191 [00:03<00:00, 58.81it/s]



Epoch 2:


Batch loss = 0.671879: 100%|██████████| 191/191 [00:03<00:00, 50.57it/s]





100%|██████████| 9307/9307 [02:58<00:00, 52.23it/s]
[32m[I 2023-04-21 05:22:32,517][0m Trial 11 finished with value: 0.07291035712102076 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 1.2124719828634975e-05, 'l2_reg_coef': 0.050550116545435016}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.671902: 100%|██████████| 191/191 [00:03<00:00, 52.19it/s]



Epoch 2:


Batch loss = 0.671658: 100%|██████████| 191/191 [00:03<00:00, 54.56it/s]





100%|██████████| 9307/9307 [02:57<00:00, 52.49it/s]
[32m[I 2023-04-21 05:25:37,086][0m Trial 12 finished with value: 0.07245967463114009 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 5.325451324067815e-05, 'l2_reg_coef': 0.06646300167576691}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.671838: 100%|██████████| 191/191 [00:03<00:00, 53.37it/s]



Epoch 2:


Batch loss = 0.671473: 100%|██████████| 191/191 [00:03<00:00, 54.03it/s]





100%|██████████| 9307/9307 [02:56<00:00, 52.70it/s]
[32m[I 2023-04-21 05:28:40,897][0m Trial 13 finished with value: 0.07246362312614425 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 8.780163174287375e-05, 'l2_reg_coef': 0.07614832519319778}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.671843: 100%|██████████| 191/191 [00:03<00:00, 55.45it/s]



Epoch 2:


Batch loss = 0.671746: 100%|██████████| 191/191 [00:03<00:00, 52.93it/s]





100%|██████████| 9307/9307 [02:57<00:00, 52.42it/s]
[32m[I 2023-04-21 05:31:45,592][0m Trial 14 finished with value: 0.07258960218343756 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 1.6884474114797857e-05, 'l2_reg_coef': 0.024592541230861575}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.671738: 100%|██████████| 191/191 [00:03<00:00, 54.16it/s]



Epoch 2:


Batch loss = 0.671693: 100%|██████████| 191/191 [00:03<00:00, 53.40it/s]





100%|██████████| 9307/9307 [02:58<00:00, 52.15it/s]
[32m[I 2023-04-21 05:34:51,258][0m Trial 15 finished with value: 0.07281167301012749 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 1.1271125134170693e-05, 'l2_reg_coef': 0.023931408225856793}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.673716: 100%|██████████| 1522/1522 [00:14<00:00, 102.37it/s]



Epoch 2:


Batch loss = 0.671843: 100%|██████████| 1522/1522 [00:14<00:00, 102.50it/s]





100%|██████████| 9307/9307 [03:00<00:00, 51.59it/s]
[32m[I 2023-04-21 05:38:21,465][0m Trial 16 finished with value: 0.0733891628690314 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 9.644746321799824e-05, 'l2_reg_coef': 0.02063921219236069}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.673535: 100%|██████████| 1522/1522 [00:14<00:00, 102.96it/s]



Epoch 2:


Batch loss = 0.671449: 100%|██████████| 1522/1522 [00:15<00:00, 96.77it/s]





100%|██████████| 9307/9307 [02:57<00:00, 52.46it/s]
[32m[I 2023-04-21 05:41:49,470][0m Trial 17 finished with value: 0.07326286168724744 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 0.00010823366385418135, 'l2_reg_coef': 0.0138648409537114}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.674508: 100%|██████████| 1522/1522 [00:15<00:00, 98.05it/s] 



Epoch 2:


Batch loss = 0.673692: 100%|██████████| 1522/1522 [00:15<00:00, 100.01it/s]





100%|██████████| 9307/9307 [02:54<00:00, 53.33it/s]
[32m[I 2023-04-21 05:45:14,813][0m Trial 18 finished with value: 0.07256706517163526 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 4.119139118883863e-05, 'l2_reg_coef': 0.013884321021793676}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.673925: 100%|██████████| 1522/1522 [00:15<00:00, 97.10it/s]



Epoch 2:


Batch loss = 0.669578: 100%|██████████| 1522/1522 [00:15<00:00, 97.29it/s]





100%|██████████| 9307/9307 [02:55<00:00, 53.13it/s]
[32m[I 2023-04-21 05:48:41,501][0m Trial 19 finished with value: 0.07250096515202317 and parameters: {'emb_size': 128, 'batch_size': 32, 'lr': 0.00018694272291701352, 'l2_reg_coef': 0.09903906675123257}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.674533: 100%|██████████| 1522/1522 [00:15<00:00, 98.93it/s] 



Epoch 2:


Batch loss = 0.673871: 100%|██████████| 1522/1522 [00:15<00:00, 98.02it/s] 





100%|██████████| 9307/9307 [02:56<00:00, 52.85it/s]
[32m[I 2023-04-21 05:52:08,625][0m Trial 20 finished with value: 0.07284971813337461 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 3.5747956999090266e-05, 'l2_reg_coef': 0.00010894004329170929}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.673048: 100%|██████████| 1522/1522 [00:15<00:00, 97.83it/s]



Epoch 2:


Batch loss = 0.670431: 100%|██████████| 1522/1522 [00:15<00:00, 97.62it/s]





100%|██████████| 9307/9307 [02:55<00:00, 53.13it/s]
[32m[I 2023-04-21 05:55:35,049][0m Trial 21 finished with value: 0.07292137743284234 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 0.00013761879742568386, 'l2_reg_coef': 0.01331661311312066}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.673802: 100%|██████████| 1522/1522 [00:15<00:00, 99.51it/s]



Epoch 2:


Batch loss = 0.67262: 100%|██████████| 1522/1522 [00:15<00:00, 98.43it/s]





100%|██████████| 9307/9307 [02:54<00:00, 53.19it/s]
[32m[I 2023-04-21 05:59:00,861][0m Trial 22 finished with value: 0.07298298408525486 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 7.256754909172245e-05, 'l2_reg_coef': 0.039715085149017705}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.672652: 100%|██████████| 1522/1522 [00:15<00:00, 96.68it/s] 



Epoch 2:


Batch loss = 0.66935: 100%|██████████| 1522/1522 [00:15<00:00, 97.02it/s]





100%|██████████| 9307/9307 [02:55<00:00, 53.16it/s]
[32m[I 2023-04-21 06:02:27,478][0m Trial 23 finished with value: 0.0729921412743812 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 0.00018519570382444518, 'l2_reg_coef': 0.03159018570610155}. Best is trial 10 with value: 0.07391601599813025.[0m


Epoch 1:


Batch loss = 0.674537: 100%|██████████| 1522/1522 [00:15<00:00, 98.01it/s] 



Epoch 2:


Batch loss = 0.674023: 100%|██████████| 1522/1522 [00:15<00:00, 98.32it/s] 





100%|██████████| 9307/9307 [02:55<00:00, 52.97it/s]
[32m[I 2023-04-21 06:05:54,297][0m Trial 24 finished with value: 0.07317208512053477 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 3.0351973315936807e-05, 'l2_reg_coef': 0.014726211366418478}. Best is trial 10 with value: 0.07391601599813025.[0m


Test SLRC (calculate scores for different seeds):

In [21]:
# Trainer for the final seeded runs: up to 20 epochs.
# NOTE(review): early_stop_num=3 is presumably the early-stopping patience on the
# dev metrics — confirm in NBRTrainer.
trainer = NBRTrainer(
    corpus=corpus,
    max_epochs=20,
    topk=10,
    early_stop_num=3
)

train dataset preparing...


100%|██████████| 10092/10092 [00:37<00:00, 271.23it/s]


dev dataset preparing...


100%|██████████| 9307/9307 [00:00<00:00, 31350.15it/s]


test dataset preparing...


100%|██████████| 9307/9307 [00:00<00:00, 27222.16it/s]


In [22]:
# One empty list per reported metric; the seeded runs below append their test scores here.
test_metrics = {metric_name: [] for metric_name in ("precision", "recall", "ndcg")}

In [None]:
# Start from empty accumulators so that re-running this cell does not append
# duplicate scores to the lists created in the previous cell.
for collected in test_metrics.values():
    collected.clear()

# The loop variable is `run_seed` rather than `seed` so that the notebook-level
# `seed` (read by the Optuna sampler cell) is not overwritten by this loop.
for run_seed in range(5):
    print(f"\n___SEED___{run_seed}")
    # Re-seed every RNG source before the model is constructed for this run.
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    np.random.seed(run_seed)

    # Rebuild the model from scratch with the best hyperparameters found by the study.
    params = {
        "model": SLRC(
            base_model_class=BPR,
            base_model_config={
                "emb_size": study.best_params["emb_size"],
                "user_num": corpus.n_users,
                "item_num": corpus.n_items,
                "click_num": corpus.n_clicks
            },
            item_num=corpus.n_items,
            avg_repeat_interval=corpus.total_avg_interval
        ),
        "batch_size": study.best_params["batch_size"],
        "lr": study.best_params["lr"],
        "l2_reg_coef": study.best_params["l2_reg_coef"]
    }

    trainer.init_hyperparams(**params)
    trainer.train()

    metrics = trainer.evaluate(mode="test")

    # Keep this run's test scores; they are averaged in the next cell.
    test_metrics["precision"].append(metrics["precision"])
    test_metrics["recall"].append(metrics["recall"])
    test_metrics["ndcg"].append(metrics["ndcg"])


___SEED___0
Epoch 1:


Batch loss = 0.671844: 100%|██████████| 191/191 [00:03<00:00, 54.47it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:54<00:00, 53.33it/s]


 {'precision': 0.0104222628129365, 'recall': 0.09888077074603345, 'ndcg': 0.07300048930102894}
Epoch 2:



Batch loss = 0.671798: 100%|██████████| 191/191 [00:03<00:00, 61.59it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:52<00:00, 54.03it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09877332473765266, 'ndcg': 0.07296455313152723}
Epoch 3:



Batch loss = 0.671731: 100%|██████████| 191/191 [00:03<00:00, 61.95it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:53<00:00, 53.55it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09877332473765266, 'ndcg': 0.07291945580809885}
Epoch 4:



Batch loss = 0.671712: 100%|██████████| 191/191 [00:03<00:00, 52.72it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:50<00:00, 54.64it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09877332473765266, 'ndcg': 0.07289013117963543}



100%|██████████| 9307/9307 [02:50<00:00, 54.50it/s]



___SEED___1
Epoch 1:


Batch loss = 0.671883: 100%|██████████| 191/191 [00:03<00:00, 56.83it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:50<00:00, 54.73it/s]


 {'precision': 0.010465241216288815, 'recall': 0.09920310877117582, 'ndcg': 0.0734044488405234}
Epoch 2:



Batch loss = 0.671844: 100%|██████████| 191/191 [00:03<00:00, 53.63it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.83it/s]


 {'precision': 0.010465241216288815, 'recall': 0.09920310877117582, 'ndcg': 0.07334945301938596}
Epoch 3:



Batch loss = 0.671808: 100%|██████████| 191/191 [00:03<00:00, 55.80it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.88it/s]


 {'precision': 0.010465241216288815, 'recall': 0.09920310877117582, 'ndcg': 0.07344369861600566}
Epoch 4:



Batch loss = 0.671758: 100%|██████████| 191/191 [00:03<00:00, 55.67it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.90it/s]


 {'precision': 0.010465241216288815, 'recall': 0.09920310877117582, 'ndcg': 0.07339878776500793}
Epoch 5:



Batch loss = 0.671729: 100%|██████████| 191/191 [00:03<00:00, 55.14it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:48<00:00, 55.15it/s]


 {'precision': 0.010454496615450738, 'recall': 0.09909566276279502, 'ndcg': 0.07330910588535189}
Epoch 6:



Batch loss = 0.671676: 100%|██████████| 191/191 [00:03<00:00, 55.97it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.86it/s]


 {'precision': 0.010454496615450738, 'recall': 0.09909566276279502, 'ndcg': 0.07326640614399639}



100%|██████████| 9307/9307 [02:50<00:00, 54.59it/s]


___SEED___2
Epoch 1:



Batch loss = 0.671854: 100%|██████████| 191/191 [00:03<00:00, 60.35it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:50<00:00, 54.59it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09871960173346227, 'ndcg': 0.07298222158697909}
Epoch 2:



Batch loss = 0.671842: 100%|██████████| 191/191 [00:03<00:00, 62.80it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:50<00:00, 54.58it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09871960173346227, 'ndcg': 0.07303087620925015}





Epoch 3:


Batch loss = 0.671784: 100%|██████████| 191/191 [00:03<00:00, 57.40it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:50<00:00, 54.62it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09871960173346227, 'ndcg': 0.07306332070346207}
Epoch 4:



Batch loss = 0.671717: 100%|██████████| 191/191 [00:03<00:00, 55.29it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.88it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09871960173346227, 'ndcg': 0.07309419630228346}
Epoch 5:



Batch loss = 0.671687: 100%|██████████| 191/191 [00:03<00:00, 57.33it/s]



Evaluation (dev):


100%|██████████| 9307/9307 [02:49<00:00, 54.84it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09871960173346227, 'ndcg': 0.07313553776834017}
Epoch 6:



Batch loss = 0.671648: 100%|██████████| 191/191 [00:03<00:00, 55.04it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.92it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09871960173346227, 'ndcg': 0.07316701077021953}
Epoch 7:



Batch loss = 0.671609: 100%|██████████| 191/191 [00:03<00:00, 55.90it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:47<00:00, 55.53it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09871960173346227, 'ndcg': 0.07310306323887776}
Epoch 8:



Batch loss = 0.671517: 100%|██████████| 191/191 [00:02<00:00, 67.91it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:50<00:00, 54.67it/s]


 {'precision': 0.010400773611260341, 'recall': 0.09861215572508147, 'ndcg': 0.07308100387460682}
Epoch 9:



Batch loss = 0.671517: 100%|██████████| 191/191 [00:03<00:00, 60.84it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:51<00:00, 54.34it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09871960173346227, 'ndcg': 0.07311995343904591}



100%|██████████| 9307/9307 [02:51<00:00, 54.22it/s]


___SEED___3
Epoch 1:



Batch loss = 0.671816: 100%|██████████| 191/191 [00:03<00:00, 56.47it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.89it/s]


 {'precision': 0.010433007413774578, 'recall': 0.09898821675441424, 'ndcg': 0.07250471029556521}
Epoch 2:



Batch loss = 0.671804: 100%|██████████| 191/191 [00:03<00:00, 57.34it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.81it/s]


 {'precision': 0.0104222628129365, 'recall': 0.09888077074603345, 'ndcg': 0.07245213502729231}
Epoch 3:



Batch loss = 0.671757: 100%|██████████| 191/191 [00:03<00:00, 58.51it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.89it/s]


 {'precision': 0.0104222628129365, 'recall': 0.09888077074603345, 'ndcg': 0.07246064326169149}
Epoch 4:



Batch loss = 0.671701: 100%|██████████| 191/191 [00:03<00:00, 55.10it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:50<00:00, 54.47it/s]


 {'precision': 0.0104222628129365, 'recall': 0.09888077074603345, 'ndcg': 0.07247970196392889}



100%|██████████| 9307/9307 [02:52<00:00, 53.93it/s]


___SEED___4
Epoch 1:



Batch loss = 0.671825: 100%|██████████| 191/191 [00:03<00:00, 52.62it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:50<00:00, 54.59it/s]


 {'precision': 0.0104222628129365, 'recall': 0.09882704774184306, 'ndcg': 0.07251693470541855}
Epoch 2:



Batch loss = 0.671778: 100%|██████████| 191/191 [00:03<00:00, 50.48it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.75it/s]


 {'precision': 0.0104222628129365, 'recall': 0.09882704774184306, 'ndcg': 0.0725007522274374}
Epoch 3:



Batch loss = 0.671715: 100%|██████████| 191/191 [00:03<00:00, 50.53it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:48<00:00, 55.12it/s]


 {'precision': 0.0104222628129365, 'recall': 0.09882704774184306, 'ndcg': 0.07254511603932819}
Epoch 4:



Batch loss = 0.67168: 100%|██████████| 191/191 [00:03<00:00, 53.63it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.95it/s]



 {'precision': 0.010433007413774578, 'recall': 0.09888077074603345, 'ndcg': 0.07259953111472374}
Epoch 5:


Batch loss = 0.671653: 100%|██████████| 191/191 [00:03<00:00, 54.40it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.89it/s]


 {'precision': 0.010433007413774578, 'recall': 0.09888077074603345, 'ndcg': 0.072574361746052}
Epoch 6:



Batch loss = 0.671547: 100%|██████████| 191/191 [00:03<00:00, 54.79it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.85it/s]


 {'precision': 0.0104222628129365, 'recall': 0.09882704774184306, 'ndcg': 0.0726587889249773}
Epoch 7:



Batch loss = 0.671561: 100%|██████████| 191/191 [00:03<00:00, 53.85it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.85it/s]


 {'precision': 0.010433007413774578, 'recall': 0.09893449375022384, 'ndcg': 0.07263141609543815}
Epoch 8:



Batch loss = 0.671459: 100%|██████████| 191/191 [00:03<00:00, 52.31it/s]



Evaluation (dev):


100%|██████████| 9307/9307 [02:49<00:00, 54.76it/s]


 {'precision': 0.010433007413774578, 'recall': 0.09893449375022384, 'ndcg': 0.07260209661604644}
Epoch 9:



Batch loss = 0.671451: 100%|██████████| 191/191 [00:03<00:00, 49.65it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:49<00:00, 54.90it/s]


 {'precision': 0.010433007413774578, 'recall': 0.09893449375022384, 'ndcg': 0.07252946776558236}



100%|██████████| 9307/9307 [02:51<00:00, 54.34it/s]


In [None]:
# Mean of each test metric over the seeded runs accumulated in `test_metrics`.
{
    metric_name: np.mean(test_metrics[metric_name])
    for metric_name in ("precision", "recall", "ndcg")
}

{'precision': 0.011578381863113787,
 'recall': 0.11183338705633752,
 'ndcg': 0.08009273036056093}

# Dunnhumby

Fix seed:

In [24]:
# Notebook-level seed for this dataset section; it is also passed to the Optuna sampler below.
seed = 10
# Seed every RNG source used in this notebook: PyTorch, Python's `random`, and NumPy's global RNG.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

Read the interaction data (filtering out users with fewer than 5 transactions, users with a high purchase frequency, one-day users, and items with fewer than 10 transactions). Train dataset: all baskets except the last two; validation dataset: the second-to-last basket; test dataset: the last basket:

In [25]:
# Data directory and dataset identifier; both are reused when the Corpus is built below.
corpus_path = "./data/"
dataset_name = "dunnhumby"

preprocessor = Preprocess(corpus_path, dataset_name)
# Positional thresholds: per the description above, presumably 5 = min transactions per user
# and 10 = min transactions per item — TODO confirm the argument order against Preprocess.load_data.
preprocessor.load_data(5, 10, filt=True)
# Persist the split to disk (the log below reports the target directory).
save_split(corpus_path, dataset_name, preprocessor)

Before preprocessing: #users = 2500, #items = 92339, #clicks = 2595370 (#illegal records = 0)
After preprocessing: #users = 2358, #items = 26756, #clicks = 1976796
Saving dataset in ./data//data_dunnhumby/...


In [26]:
# Build the corpus for the selected dataset and load its data
# (presumably the split saved by the previous cell — verify against Corpus.load_data).
corpus = Corpus(corpus_path, dataset_name)
corpus.load_data()

Tune hyperparams on validation dataset:

In [27]:
# Trainer used for the hyperparameter search: a single epoch per trial
# (the log below shows one epoch taking several minutes on this dataset).
# NOTE(review): topk=10 is presumably the ranking cutoff for the metrics, and
# early_stop_num=None presumably disables early stopping — confirm in NBRTrainer.
trainer = NBRTrainer(
    corpus=corpus,
    max_epochs=1,
    topk=10,
    early_stop_num=None
)

train dataset preparing...


100%|██████████| 2358/2358 [00:08<00:00, 289.10it/s]


dev dataset preparing...


100%|██████████| 2357/2357 [00:12<00:00, 190.26it/s]


test dataset preparing...


100%|██████████| 2357/2357 [00:11<00:00, 197.05it/s]


In [28]:
def objective(trial):
    """Optuna objective: train SLRC (with a BPR base model) and score it on dev.

    Hyperparameters sampled from ``trial``: ``emb_size``, ``batch_size``
    (categorical), ``lr`` and ``l2_reg_coef`` (log-uniform). The model is
    trained with the module-level ``trainer`` on the module-level ``corpus``.

    Args:
        trial: The Optuna trial that supplies the hyperparameter suggestions.

    Returns:
        The ``"ndcg"`` entry of the metrics returned by
        ``trainer.evaluate(mode="dev")``.
    """
    # The suggest_* calls keep their original order (emb_size, batch_size, lr,
    # l2_reg_coef) so a seeded sampler reproduces the same trials.
    model = SLRC(
        base_model_class=BPR,
        base_model_config={
            "emb_size": trial.suggest_categorical("emb_size", [32, 64, 128]),
            "user_num": corpus.n_users,
            "item_num": corpus.n_items,
            "click_num": corpus.n_clicks,
        },
        item_num=corpus.n_items,
        avg_repeat_interval=corpus.total_avg_interval,
    )
    params = {
        "model": model,
        "batch_size": trial.suggest_categorical("batch_size", [32, 64, 128, 256]),
        # suggest_float(..., log=True) replaces the deprecated suggest_loguniform
        # and samples from the same log-uniform range.
        "lr": trial.suggest_float("lr", 1e-5, 1e-1, log=True),
        "l2_reg_coef": trial.suggest_float("l2_reg_coef", 1e-4, 1e-1, log=True),
    }

    trainer.init_hyperparams(**params)
    # NOTE(review): evaluation_flg=False presumably skips per-epoch dev
    # evaluation during tuning - confirm in NBRTrainer.train.
    trainer.train(evaluation_flg=False)
    return trainer.evaluate(mode="dev")["ndcg"]

In [None]:
# TPE search over `objective`; the sampler is seeded with the section's `seed`.
sampler = optuna.samplers.TPESampler(seed=seed)
study = optuna.create_study(
    sampler=sampler,
    direction="maximize",  # the objective returns dev NDCG, so higher is better
)
study.optimize(objective, n_trials=25)

[32m[I 2023-04-22 23:34:33,097][0m A new study created in memory with name: no-name-d4585873-f78e-43eb-af1f-f37c35d96eb3[0m


Epoch 1:


Batch loss = 0.447194: 100%|██████████| 60237/60237 [29:10<00:00, 34.41it/s]





100%|██████████| 2357/2357 [03:43<00:00, 10.56it/s]
[32m[I 2023-04-23 00:07:26,805][0m Trial 0 finished with value: 0.15148375345370244 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 0.011018509458263562, 'l2_reg_coef': 0.00032161219616553475}. Best is trial 0 with value: 0.15148375345370244.[0m


Epoch 1:


Batch loss = 0.405301: 100%|██████████| 15060/15060 [13:51<00:00, 18.11it/s]





100%|██████████| 2357/2357 [04:00<00:00,  9.78it/s]
[32m[I 2023-04-23 00:25:19,450][0m Trial 1 finished with value: 0.1558336896053822 and parameters: {'emb_size': 128, 'batch_size': 128, 'lr': 0.007709412252614239, 'l2_reg_coef': 0.0007509797119626302}. Best is trial 1 with value: 0.1558336896053822.[0m


Epoch 1:


Batch loss = 0.462205: 100%|██████████| 15060/15060 [06:53<00:00, 36.42it/s]





100%|██████████| 2357/2357 [03:41<00:00, 10.64it/s]
[32m[I 2023-04-23 00:35:54,530][0m Trial 2 finished with value: 0.16878730766837877 and parameters: {'emb_size': 32, 'batch_size': 128, 'lr': 0.0005445728346977897, 'l2_reg_coef': 0.0071334715811591475}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.554042: 100%|██████████| 15060/15060 [09:17<00:00, 27.02it/s]





100%|██████████| 2357/2357 [03:43<00:00, 10.55it/s]
[32m[I 2023-04-23 00:48:55,414][0m Trial 3 finished with value: 0.16731909583052043 and parameters: {'emb_size': 64, 'batch_size': 128, 'lr': 2.3005803026525478e-05, 'l2_reg_coef': 0.0007981787657617986}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.584676: 100%|██████████| 15060/15060 [09:23<00:00, 26.73it/s]





100%|██████████| 2357/2357 [03:42<00:00, 10.57it/s]
[32m[I 2023-04-23 01:02:01,856][0m Trial 4 finished with value: 0.14338914693825852 and parameters: {'emb_size': 64, 'batch_size': 128, 'lr': 0.02675476925825261, 'l2_reg_coef': 0.0011349008421937211}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.389378: 100%|██████████| 15060/15060 [14:04<00:00, 17.83it/s]





100%|██████████| 2357/2357 [03:53<00:00, 10.11it/s]
[32m[I 2023-04-23 01:19:59,580][0m Trial 5 finished with value: 0.14527801081160163 and parameters: {'emb_size': 128, 'batch_size': 128, 'lr': 0.01924964086209924, 'l2_reg_coef': 0.0002840900733622218}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 4.967166: 100%|██████████| 30119/30119 [32:20<00:00, 15.52it/s]





100%|██████████| 2357/2357 [04:05<00:00,  9.62it/s]
[32m[I 2023-04-23 01:56:25,059][0m Trial 6 finished with value: 0.14191140944288214 and parameters: {'emb_size': 128, 'batch_size': 64, 'lr': 0.040862698316257884, 'l2_reg_coef': 0.00401489180670258}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.643648: 100%|██████████| 7530/7530 [04:21<00:00, 28.80it/s]





100%|██████████| 2357/2357 [03:38<00:00, 10.81it/s]
[32m[I 2023-04-23 02:04:24,731][0m Trial 7 finished with value: 0.16546047357211222 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 1.4448968183511732e-05, 'l2_reg_coef': 0.0019430167088096818}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.469989: 100%|██████████| 30119/30119 [18:16<00:00, 27.48it/s]





100%|██████████| 2357/2357 [03:49<00:00, 10.29it/s]
[32m[I 2023-04-23 02:26:30,058][0m Trial 8 finished with value: 0.16719920643189984 and parameters: {'emb_size': 64, 'batch_size': 64, 'lr': 0.00046777453460422386, 'l2_reg_coef': 0.006557415353196482}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.430545: 100%|██████████| 15060/15060 [09:16<00:00, 27.08it/s]





100%|██████████| 2357/2357 [03:46<00:00, 10.43it/s]
[32m[I 2023-04-23 02:39:32,329][0m Trial 9 finished with value: 0.16323411407951588 and parameters: {'emb_size': 64, 'batch_size': 128, 'lr': 0.0016994670732385824, 'l2_reg_coef': 0.0049718848264030035}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.375644: 100%|██████████| 60237/60237 [27:37<00:00, 36.34it/s]





100%|██████████| 2357/2357 [03:42<00:00, 10.61it/s]
[32m[I 2023-04-23 03:10:52,373][0m Trial 10 finished with value: 0.16608992741008688 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 0.0003497063710293824, 'l2_reg_coef': 0.03295582896915084}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.540239: 100%|██████████| 15060/15060 [09:16<00:00, 27.05it/s]





100%|██████████| 2357/2357 [03:42<00:00, 10.60it/s]
[32m[I 2023-04-23 03:23:51,498][0m Trial 11 finished with value: 0.16695798274744186 and parameters: {'emb_size': 64, 'batch_size': 128, 'lr': 3.998493073453524e-05, 'l2_reg_coef': 0.017614887606723105}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.625493: 100%|██████████| 7530/7530 [04:10<00:00, 30.03it/s]





100%|██████████| 2357/2357 [03:39<00:00, 10.73it/s]
[32m[I 2023-04-23 03:31:42,033][0m Trial 12 finished with value: 0.16735885449804452 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 8.841602602084763e-05, 'l2_reg_coef': 0.00013387911588679}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.617606: 100%|██████████| 7530/7530 [04:13<00:00, 29.72it/s]





100%|██████████| 2357/2357 [03:36<00:00, 10.90it/s]
[32m[I 2023-04-23 03:39:31,826][0m Trial 13 finished with value: 0.16737866189009154 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 0.00013084027307985277, 'l2_reg_coef': 0.00010440483218503582}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.616269: 100%|██████████| 7530/7530 [04:09<00:00, 30.17it/s]





100%|██████████| 2357/2357 [03:38<00:00, 10.80it/s]
[32m[I 2023-04-23 03:47:19,697][0m Trial 14 finished with value: 0.1674056462220308 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 0.00014092997336747555, 'l2_reg_coef': 0.06753264773209731}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.562451: 100%|██████████| 7530/7530 [04:11<00:00, 29.90it/s]





100%|██████████| 2357/2357 [03:39<00:00, 10.73it/s]
[32m[I 2023-04-23 03:55:11,363][0m Trial 15 finished with value: 0.16833791091480246 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 0.0014371182317804566, 'l2_reg_coef': 0.07090946879725188}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.532959: 100%|██████████| 7530/7530 [04:13<00:00, 29.72it/s]





100%|██████████| 2357/2357 [03:41<00:00, 10.65it/s]
[32m[I 2023-04-23 04:03:06,014][0m Trial 16 finished with value: 0.1632992402609579 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 0.002805035462037658, 'l2_reg_coef': 0.01129236775747917}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.370375: 100%|██████████| 60237/60237 [27:22<00:00, 36.67it/s]





100%|██████████| 2357/2357 [03:44<00:00, 10.48it/s]
[32m[I 2023-04-23 04:34:13,456][0m Trial 17 finished with value: 0.1638806210410126 and parameters: {'emb_size': 32, 'batch_size': 32, 'lr': 0.0006770623196702809, 'l2_reg_coef': 0.09431328085393048}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 6.244652: 100%|██████████| 30119/30119 [14:08<00:00, 35.49it/s]





100%|██████████| 2357/2357 [03:41<00:00, 10.64it/s]
[32m[I 2023-04-23 04:52:03,890][0m Trial 18 finished with value: 0.13428536236601596 and parameters: {'emb_size': 32, 'batch_size': 64, 'lr': 0.09964990820951232, 'l2_reg_coef': 0.021085749388527922}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.54951: 100%|██████████| 7530/7530 [04:12<00:00, 29.86it/s]





100%|██████████| 2357/2357 [03:40<00:00, 10.70it/s]
[32m[I 2023-04-23 04:59:56,463][0m Trial 19 finished with value: 0.16426691437088156 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 0.00345980916161105, 'l2_reg_coef': 0.03730405760602197}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.552496: 100%|██████████| 7530/7530 [07:38<00:00, 16.42it/s]





100%|██████████| 2357/2357 [03:57<00:00,  9.93it/s]
[32m[I 2023-04-23 05:11:32,733][0m Trial 20 finished with value: 0.16321656711458246 and parameters: {'emb_size': 128, 'batch_size': 256, 'lr': 0.001169238449452281, 'l2_reg_coef': 0.007619546957676205}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.601636: 100%|██████████| 7530/7530 [04:13<00:00, 29.73it/s]





100%|██████████| 2357/2357 [03:37<00:00, 10.82it/s]
[32m[I 2023-04-23 05:19:23,871][0m Trial 21 finished with value: 0.16877969040597376 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 0.0002461103677396194, 'l2_reg_coef': 0.09676463125684026}. Best is trial 2 with value: 0.16878730766837877.[0m


Epoch 1:


Batch loss = 0.586116: 100%|██████████| 7530/7530 [04:13<00:00, 29.69it/s]





100%|██████████| 2357/2357 [03:39<00:00, 10.75it/s]
[32m[I 2023-04-23 05:27:16,795][0m Trial 22 finished with value: 0.16899702698754962 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 0.0004278422616204465, 'l2_reg_coef': 0.05650593239008871}. Best is trial 22 with value: 0.16899702698754962.[0m


Epoch 1:


Batch loss = 0.601167: 100%|██████████| 7530/7530 [04:11<00:00, 29.90it/s]





100%|██████████| 2357/2357 [03:39<00:00, 10.72it/s]
[32m[I 2023-04-23 05:35:08,729][0m Trial 23 finished with value: 0.16857128960145432 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 0.0002504075415812634, 'l2_reg_coef': 0.04256733701472282}. Best is trial 22 with value: 0.16899702698754962.[0m


Epoch 1:


Batch loss = 0.575684: 100%|██████████| 7530/7530 [04:21<00:00, 28.83it/s]





100%|██████████| 2357/2357 [03:43<00:00, 10.56it/s]
[32m[I 2023-04-23 05:43:13,211][0m Trial 24 finished with value: 0.1695119713150707 and parameters: {'emb_size': 32, 'batch_size': 256, 'lr': 0.0006366285017414498, 'l2_reg_coef': 0.09063752099202302}. Best is trial 24 with value: 0.1695119713150707.[0m


Evaluate SLRC on the test dataset (scores are calculated for several different seeds):

In [None]:
# Trainer used for the final runs: up to 20 epochs, top-10 metrics,
# early_stop_num=3.
trainer = NBRTrainer(corpus=corpus, topk=10, max_epochs=20, early_stop_num=3)

train dataset preparing...


100%|██████████| 2358/2358 [00:14<00:00, 163.95it/s]


dev dataset preparing...


100%|██████████| 2357/2357 [00:15<00:00, 153.46it/s]


test dataset preparing...


100%|██████████| 2357/2357 [00:15<00:00, 152.40it/s]


In [None]:
# One independent list per metric; each seeded run appends its test score.
test_metrics = {metric_name: [] for metric_name in ("precision", "recall", "ndcg")}

In [None]:
# Retrain SLRC with the best hyperparameters found by the study under five
# different seeds and record the test metrics of every run.
# The loop variable is `run_seed` rather than `seed` so that the module-level
# `seed` consumed by the TPESampler cell is not overwritten by this cell.
for run_seed in range(5):
    print(f"\n___SEED___{run_seed}")
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    np.random.seed(run_seed)

    best_params = study.best_params
    params = {
        "model": SLRC(
            base_model_class=BPR,
            base_model_config={
                "emb_size": best_params["emb_size"],
                "user_num": corpus.n_users,
                "item_num": corpus.n_items,
                "click_num": corpus.n_clicks,
            },
            item_num=corpus.n_items,
            avg_repeat_interval=corpus.total_avg_interval,
        ),
        "batch_size": best_params["batch_size"],
        "lr": best_params["lr"],
        "l2_reg_coef": best_params["l2_reg_coef"],
    }

    trainer.init_hyperparams(**params)
    trainer.train()

    metrics = trainer.evaluate(mode="test")

    # Append this run's score to the per-metric lists.
    for metric_name in ("precision", "recall", "ndcg"):
        test_metrics[metric_name].append(metrics[metric_name])


___SEED___0
Epoch 1:


Batch loss = 0.57551: 100%|██████████| 7530/7530 [04:19<00:00, 29.06it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:19<00:00, 11.79it/s]


 {'precision': 0.118879932117098, 'recall': 0.18208650031932486, 'ndcg': 0.16937464958948809}
Epoch 2:



Batch loss = 0.532531: 100%|██████████| 7530/7530 [04:16<00:00, 29.31it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:19<00:00, 11.80it/s]


 {'precision': 0.11726771319473907, 'recall': 0.182605322260548, 'ndcg': 0.16941483380978967}
Epoch 3:



Batch loss = 0.482401: 100%|██████████| 7530/7530 [04:18<00:00, 29.17it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:21<00:00, 11.71it/s]


 {'precision': 0.11697072549851507, 'recall': 0.18104452768462156, 'ndcg': 0.1678905745748544}
Epoch 4:



Batch loss = 0.478594: 100%|██████████| 7530/7530 [04:20<00:00, 28.95it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:19<00:00, 11.80it/s]


 {'precision': 0.11646160373355961, 'recall': 0.18022523522667563, 'ndcg': 0.16613224587069464}
Epoch 5:



Batch loss = 0.449429: 100%|██████████| 7530/7530 [04:15<00:00, 29.42it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:20<00:00, 11.74it/s]



 {'precision': 0.11641917691981331, 'recall': 0.17985273614361144, 'ndcg': 0.16548402706566653}


100%|██████████| 2357/2357 [03:21<00:00, 11.67it/s]



___SEED___1
Epoch 1:


Batch loss = 0.574728: 100%|██████████| 7530/7530 [04:14<00:00, 29.63it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:21<00:00, 11.69it/s]


 {'precision': 0.11909206618582946, 'recall': 0.1820022107620056, 'ndcg': 0.16958347360965687}
Epoch 2:



Batch loss = 0.52777: 100%|██████████| 7530/7530 [04:21<00:00, 28.82it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:21<00:00, 11.72it/s]


 {'precision': 0.11739499363597795, 'recall': 0.1820814317295862, 'ndcg': 0.1693973975085176}
Epoch 3:



Batch loss = 0.487147: 100%|██████████| 7530/7530 [04:17<00:00, 29.23it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:21<00:00, 11.67it/s]


 {'precision': 0.11692829868476878, 'recall': 0.1802864774089352, 'ndcg': 0.167748478666449}
Epoch 4:



Batch loss = 0.467395: 100%|██████████| 7530/7530 [04:19<00:00, 28.97it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:21<00:00, 11.70it/s]



 {'precision': 0.11629189647857446, 'recall': 0.17995990701576256, 'ndcg': 0.16594597091322003}


100%|██████████| 2357/2357 [03:22<00:00, 11.64it/s]



___SEED___2
Epoch 1:


Batch loss = 0.575429: 100%|██████████| 7530/7530 [04:12<00:00, 29.76it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:19<00:00, 11.84it/s]


 {'precision': 0.1189223589308443, 'recall': 0.18175430419669467, 'ndcg': 0.16961052771533622}





Epoch 2:


Batch loss = 0.535303: 100%|██████████| 7530/7530 [04:17<00:00, 29.23it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:20<00:00, 11.74it/s]


 {'precision': 0.11714043275350022, 'recall': 0.18195169099661504, 'ndcg': 0.16933952001306773}
Epoch 3:



Batch loss = 0.484216: 100%|██████████| 7530/7530 [04:19<00:00, 29.04it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:20<00:00, 11.77it/s]


 {'precision': 0.11667373780229105, 'recall': 0.18048165869995278, 'ndcg': 0.1673737234889153}
Epoch 4:



Batch loss = 0.46151: 100%|██████████| 7530/7530 [04:23<00:00, 28.57it/s]



Evaluation (dev):


100%|██████████| 2357/2357 [03:20<00:00, 11.74it/s]


 {'precision': 0.11646160373355961, 'recall': 0.18014316209181905, 'ndcg': 0.16597490172741558}



100%|██████████| 2357/2357 [03:23<00:00, 11.61it/s]


___SEED___3
Epoch 1:



Batch loss = 0.575868: 100%|██████████| 7530/7530 [04:10<00:00, 30.12it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:20<00:00, 11.77it/s]


 {'precision': 0.1189647857445906, 'recall': 0.1820888415025857, 'ndcg': 0.1694849567757149}
Epoch 2:



Batch loss = 0.53325: 100%|██████████| 7530/7530 [04:16<00:00, 29.30it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:20<00:00, 11.74it/s]



 {'precision': 0.11726771319473907, 'recall': 0.18168297205200976, 'ndcg': 0.16921977732694268}
Epoch 3:


Batch loss = 0.489312: 100%|██████████| 7530/7530 [04:16<00:00, 29.37it/s]



Evaluation (dev):


100%|██████████| 2357/2357 [03:18<00:00, 11.88it/s]


 {'precision': 0.11697072549851506, 'recall': 0.18080963720724855, 'ndcg': 0.16766592007451925}
Epoch 4:



Batch loss = 0.47149: 100%|██████████| 7530/7530 [04:15<00:00, 29.46it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:18<00:00, 11.86it/s]


 {'precision': 0.11633432329232077, 'recall': 0.18054860062257394, 'ndcg': 0.1663662974681971}



100%|██████████| 2357/2357 [03:22<00:00, 11.62it/s]


___SEED___4
Epoch 1:



Batch loss = 0.575093: 100%|██████████| 7530/7530 [04:14<00:00, 29.57it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:19<00:00, 11.81it/s]


 {'precision': 0.11904963937208317, 'recall': 0.1818893610416039, 'ndcg': 0.1694487064721814}
Epoch 2:



Batch loss = 0.52849: 100%|██████████| 7530/7530 [04:20<00:00, 28.93it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:21<00:00, 11.70it/s]


 {'precision': 0.11743742044972423, 'recall': 0.1816663917449361, 'ndcg': 0.16937594774304177}
Epoch 3:



Batch loss = 0.492156: 100%|██████████| 7530/7530 [04:18<00:00, 29.11it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:20<00:00, 11.75it/s]


 {'precision': 0.11701315231226135, 'recall': 0.18080669690286144, 'ndcg': 0.16788002873202354}
Epoch 4:



Batch loss = 0.449101: 100%|██████████| 7530/7530 [04:21<00:00, 28.81it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [03:21<00:00, 11.68it/s]


 {'precision': 0.11658888417479846, 'recall': 0.1798352159912159, 'ndcg': 0.16606146931135962}



100%|██████████| 2357/2357 [03:23<00:00, 11.61it/s]


In [None]:
# Average each test metric over the per-seed runs collected in `test_metrics`.
{
    metric_name: np.mean(test_metrics[metric_name])
    for metric_name in ("precision", "recall", "ndcg")
}

{'precision': 0.11917691981332204,
 'recall': 0.1727385625823469,
 'ndcg': 0.16745920832459504}