In [None]:
import os
import sys
sys.path.append("..")
from nbr.preparation import Preprocess, save_split, Corpus
from nbr.trainer import NBRTrainer
from nbr.model import RepurchaseModule
import torch
import random
import numpy as np
import optuna
import warnings
warnings.filterwarnings("ignore")

# TaFeng

Fix seed:

In [None]:
# One seed for every RNG the notebook relies on; the Optuna sampler cell
# further down reads `seed` as well.
seed = 10
random.seed(seed)
torch.manual_seed(seed)
# Kept last: it returns None, so the cell displays no output value.
np.random.seed(seed)

Read the interaction data and filter it: drop users with fewer than 5 transactions, users with a high purchase frequency, one-day users, and items with fewer than 10 transactions. Split: the train dataset contains all baskets except the last two, the validation dataset is the second-to-last basket, and the test dataset is the last basket:

In [None]:
# Dataset to process and the directory holding it; both names are reused by
# the Corpus cell below.
dataset_name = "ta_feng"
corpus_path = "./data/"

# Filtering thresholds. Per the description above these are presumably the
# minimum number of transactions per user and per item (confirm against
# Preprocess.load_data in nbr.preparation).
min_user_transactions, min_item_transactions = 5, 10

preprocessor = Preprocess(corpus_path, dataset_name)
preprocessor.load_data(min_user_transactions, min_item_transactions, filt=True)
# Persist the preprocessed data / split for this dataset.
save_split(corpus_path, dataset_name, preprocessor)

Before preprocessing: #users = 32266, #items = 23812, #clicks = 817741 (#illegal records = 0)
After preprocessing: #users = 7358, #items = 11202, #clicks = 368951
Saving dataset in ./data//data_ta_feng/...


In [None]:
# Create the Corpus for `dataset_name` under `corpus_path` and load its data
# (presumably the split written by `save_split` -- confirm in nbr.preparation).
# `corpus` is passed to NBRTrainer and supplies `n_items` /
# `total_avg_interval` to the model in later cells.
corpus = Corpus(corpus_path, dataset_name)
corpus.load_data()

Tune hyperparameters on the validation dataset:

In [None]:
# Trainer used by the hyperparameter search: a short budget of 3 epochs per
# trial and early_stop_num=None (presumably "no early stopping" -- confirm in
# nbr.trainer). The `objective` function below reads this global `trainer`.
trainer = NBRTrainer(
    topk=10,
    max_epochs=3,
    early_stop_num=None,
    corpus=corpus,
)

train dataset preparing...


100%|██████████| 7358/7358 [00:12<00:00, 608.62it/s]


dev dataset preparing...


100%|██████████| 7357/7357 [00:01<00:00, 4691.81it/s]


test dataset preparing...


100%|██████████| 7357/7357 [00:01<00:00, 4149.20it/s]


In [None]:
def objective(trial):
    """Optuna objective: train with one sampled configuration, score it on dev.

    Reads the notebook-level `corpus` and `trainer`, so it evaluates whichever
    dataset and trainer were set up by the cells above it.

    Args:
        trial: Optuna trial used to sample `batch_size`, `lr` and
            `l2_reg_coef`.

    Returns:
        The "ndcg" entry of the metrics returned by
        `trainer.evaluate(mode="dev")`.
    """
    # A new model instance is built for every trial, before any value is sampled.
    model = RepurchaseModule(
        item_num=corpus.n_items,
        avg_repeat_interval=corpus.total_avg_interval,
    )
    params = {
        "model": model,
        "batch_size": trial.suggest_categorical("batch_size", [32, 64, 128, 256]),
        # suggest_float(..., log=True) replaces the deprecated
        # suggest_loguniform: same log-uniform sampling over the same bounds.
        "lr": trial.suggest_float("lr", 1e-5, 1e-1, log=True),
        "l2_reg_coef": trial.suggest_float("l2_reg_coef", 1e-4, 1e-1, log=True),
    }

    trainer.init_hyperparams(**params)
    # evaluation_flg=False: train without the per-epoch dev evaluation; the
    # trial is scored once on the dev split afterwards.
    trainer.train(evaluation_flg=False)
    metrics = trainer.evaluate(mode="dev")
    return metrics["ndcg"]

In [None]:
# TPE sampler seeded with the notebook-level `seed`.
sampler = optuna.samplers.TPESampler(seed=seed)
# `objective` returns the dev NDCG, which should be as high as possible.
study = optuna.create_study(sampler=sampler, direction="maximize")
# 25 trials; each one trains a new model (see the log below).
study.optimize(objective, n_trials=25)

[32m[I 2023-05-02 04:47:18,891][0m A new study created in memory with name: no-name-c685093d-8d59-4fd3-b0c4-a2a7c6442905[0m


Epoch 1:


Batch loss = 0.647869: 100%|██████████| 8889/8889 [01:12<00:00, 122.22it/s]



Epoch 2:


Batch loss = 0.635774: 100%|██████████| 8889/8889 [01:09<00:00, 127.03it/s]



Epoch 3:


Batch loss = 0.630067: 100%|██████████| 8889/8889 [01:11<00:00, 125.13it/s]





100%|██████████| 7357/7357 [01:23<00:00, 87.84it/s] 
[32m[I 2023-05-02 04:52:22,060][0m Trial 0 finished with value: 0.09700288178218114 and parameters: {'batch_size': 32, 'lr': 0.0009863431872330064, 'l2_reg_coef': 0.0004724870791526793}. Best is trial 0 with value: 0.09700288178218114.[0m


Epoch 1:


Batch loss = 0.651307: 100%|██████████| 4445/4445 [00:40<00:00, 110.59it/s]


Epoch 2:



Batch loss = 0.648101: 100%|██████████| 4445/4445 [00:40<00:00, 110.88it/s]


Epoch 3:



Batch loss = 0.649857: 100%|██████████| 4445/4445 [00:40<00:00, 110.19it/s]





100%|██████████| 7357/7357 [01:23<00:00, 88.21it/s] 
[32m[I 2023-05-02 04:55:46,136][0m Trial 1 finished with value: 0.08918967584270882 and parameters: {'batch_size': 64, 'lr': 0.005513651007120869, 'l2_reg_coef': 0.07247363402746428}. Best is trial 0 with value: 0.09700288178218114.[0m


Epoch 1:


Batch loss = 0.671587: 100%|██████████| 2223/2223 [00:24<00:00, 89.45it/s]


Epoch 2:



Batch loss = 0.670729: 100%|██████████| 2223/2223 [00:24<00:00, 89.73it/s]


Epoch 3:



Batch loss = 0.669983: 100%|██████████| 2223/2223 [00:24<00:00, 91.08it/s]





100%|██████████| 7357/7357 [01:25<00:00, 86.12it/s] 
[32m[I 2023-05-02 04:58:25,639][0m Trial 2 finished with value: 0.08992751763914275 and parameters: {'batch_size': 128, 'lr': 0.007709412252614239, 'l2_reg_coef': 0.0007509797119626302}. Best is trial 0 with value: 0.09700288178218114.[0m


Epoch 1:


Batch loss = 0.65615: 100%|██████████| 8889/8889 [01:10<00:00, 125.57it/s]


Epoch 2:



Batch loss = 0.652119: 100%|██████████| 8889/8889 [01:10<00:00, 125.27it/s]


Epoch 3:



Batch loss = 0.647918: 100%|██████████| 8889/8889 [01:10<00:00, 125.36it/s]





100%|██████████| 7357/7357 [01:23<00:00, 88.27it/s] 
[32m[I 2023-05-02 05:03:21,690][0m Trial 3 finished with value: 0.10049843133741301 and parameters: {'batch_size': 32, 'lr': 0.0003114318604656744, 'l2_reg_coef': 0.010529332619998283}. Best is trial 3 with value: 0.10049843133741301.[0m


Epoch 1:


Batch loss = 0.675157: 100%|██████████| 2223/2223 [00:24<00:00, 89.39it/s] 


Epoch 2:



Batch loss = 0.671454: 100%|██████████| 2223/2223 [00:25<00:00, 88.75it/s]


Epoch 3:



Batch loss = 0.671184: 100%|██████████| 2223/2223 [00:24<00:00, 92.10it/s]





100%|██████████| 7357/7357 [01:25<00:00, 86.33it/s] 
[32m[I 2023-05-02 05:06:01,020][0m Trial 4 finished with value: 0.09235173391966142 and parameters: {'batch_size': 128, 'lr': 0.003995661855958764, 'l2_reg_coef': 0.006355019100735405}. Best is trial 3 with value: 0.10049843133741301.[0m


Epoch 1:


Batch loss = 0.682588: 100%|██████████| 2223/2223 [00:25<00:00, 88.34it/s]


Epoch 2:



Batch loss = 0.682546: 100%|██████████| 2223/2223 [00:24<00:00, 89.89it/s]


Epoch 3:



Batch loss = 0.682504: 100%|██████████| 2223/2223 [00:24<00:00, 91.17it/s]





100%|██████████| 7357/7357 [01:23<00:00, 88.22it/s] 
[32m[I 2023-05-02 05:08:38,746][0m Trial 5 finished with value: 0.10154978807087144 and parameters: {'batch_size': 128, 'lr': 2.3005803026525478e-05, 'l2_reg_coef': 0.0007981787657617986}. Best is trial 5 with value: 0.10154978807087144.[0m


Epoch 1:


Batch loss = 0.66428: 100%|██████████| 4445/4445 [00:40<00:00, 109.13it/s]



Epoch 2:


Batch loss = 0.656066: 100%|██████████| 4445/4445 [00:41<00:00, 106.59it/s]


Epoch 3:



Batch loss = 0.652434: 100%|██████████| 4445/4445 [00:40<00:00, 108.79it/s]





100%|██████████| 7357/7357 [01:23<00:00, 87.64it/s] 
[32m[I 2023-05-02 05:12:06,031][0m Trial 6 finished with value: 0.09602259832546668 and parameters: {'batch_size': 64, 'lr': 0.0015500461319089488, 'l2_reg_coef': 0.028698618231059004}. Best is trial 5 with value: 0.10154978807087144.[0m


Epoch 1:


Batch loss = 0.67152: 100%|██████████| 4445/4445 [00:41<00:00, 107.54it/s]


Epoch 2:



Batch loss = 0.670705: 100%|██████████| 4445/4445 [00:41<00:00, 107.99it/s]


Epoch 3:



Batch loss = 0.669868: 100%|██████████| 4445/4445 [00:41<00:00, 105.97it/s]





100%|██████████| 7357/7357 [01:23<00:00, 88.34it/s] 
[32m[I 2023-05-02 05:15:33,799][0m Trial 7 finished with value: 0.10259593364601936 and parameters: {'batch_size': 64, 'lr': 0.00015270273918806623, 'l2_reg_coef': 0.04485485309869719}. Best is trial 7 with value: 0.10259593364601936.[0m


Epoch 1:


Batch loss = 0.679854: 100%|██████████| 2223/2223 [00:25<00:00, 87.10it/s]


Epoch 2:



Batch loss = 0.685095: 100%|██████████| 2223/2223 [00:25<00:00, 87.30it/s]


Epoch 3:



Batch loss = 0.684987: 100%|██████████| 2223/2223 [00:24<00:00, 90.52it/s]





100%|██████████| 7357/7357 [01:27<00:00, 84.56it/s]
[32m[I 2023-05-02 05:18:16,398][0m Trial 8 finished with value: 0.08046695467579876 and parameters: {'batch_size': 128, 'lr': 0.01924964086209924, 'l2_reg_coef': 0.0002840900733622218}. Best is trial 7 with value: 0.10259593364601936.[0m


Epoch 1:


Batch loss = 0.681195: 100%|██████████| 2223/2223 [00:24<00:00, 89.29it/s]


Epoch 2:



Batch loss = 0.68091: 100%|██████████| 2223/2223 [00:25<00:00, 86.83it/s]


Epoch 3:



Batch loss = 0.679899: 100%|██████████| 2223/2223 [00:25<00:00, 86.97it/s]





100%|██████████| 7357/7357 [01:22<00:00, 88.94it/s]
[32m[I 2023-05-02 05:20:55,226][0m Trial 9 finished with value: 0.0798890797295511 and parameters: {'batch_size': 128, 'lr': 0.02016003934561091, 'l2_reg_coef': 0.0005677045861147308}. Best is trial 7 with value: 0.10259593364601936.[0m


Epoch 1:


Batch loss = 0.686229: 100%|██████████| 1112/1112 [00:17<00:00, 65.24it/s]


Epoch 2:



Batch loss = 0.692095: 100%|██████████| 1112/1112 [00:18<00:00, 60.12it/s]


Epoch 3:



Batch loss = 0.691905: 100%|██████████| 1112/1112 [00:17<00:00, 64.89it/s]





100%|██████████| 7357/7357 [01:23<00:00, 87.74it/s] 
[32m[I 2023-05-02 05:23:11,820][0m Trial 10 finished with value: 0.03936337581593457 and parameters: {'batch_size': 256, 'lr': 0.08097836606986637, 'l2_reg_coef': 0.09111479007859723}. Best is trial 7 with value: 0.10259593364601936.[0m


Epoch 1:


Batch loss = 0.67207: 100%|██████████| 4445/4445 [00:41<00:00, 105.98it/s]


Epoch 2:



Batch loss = 0.671985: 100%|██████████| 4445/4445 [00:41<00:00, 106.54it/s]



Epoch 3:


Batch loss = 0.6719: 100%|██████████| 4445/4445 [00:41<00:00, 106.85it/s]





100%|██████████| 7357/7357 [01:26<00:00, 84.70it/s]
[32m[I 2023-05-02 05:26:44,005][0m Trial 11 finished with value: 0.10142749646515184 and parameters: {'batch_size': 64, 'lr': 1.6462341078166405e-05, 'l2_reg_coef': 0.00213318999253824}. Best is trial 7 with value: 0.10259593364601936.[0m


Epoch 1:


Batch loss = 0.676577: 100%|██████████| 1112/1112 [00:17<00:00, 61.96it/s]


Epoch 2:



Batch loss = 0.676512: 100%|██████████| 1112/1112 [00:17<00:00, 63.85it/s]


Epoch 3:



Batch loss = 0.676447: 100%|██████████| 1112/1112 [00:17<00:00, 63.22it/s]





100%|██████████| 7357/7357 [01:25<00:00, 86.23it/s] 
[32m[I 2023-05-02 05:29:02,345][0m Trial 12 finished with value: 0.10156086183289299 and parameters: {'batch_size': 256, 'lr': 3.5825920825139016e-05, 'l2_reg_coef': 0.0018995211552374786}. Best is trial 7 with value: 0.10259593364601936.[0m


Epoch 1:


Batch loss = 0.676458: 100%|██████████| 1112/1112 [00:18<00:00, 58.78it/s]


Epoch 2:



Batch loss = 0.676298: 100%|██████████| 1112/1112 [00:18<00:00, 58.81it/s]


Epoch 3:



Batch loss = 0.676135: 100%|██████████| 1112/1112 [00:18<00:00, 61.16it/s]





100%|██████████| 7357/7357 [01:24<00:00, 86.59it/s] 
[32m[I 2023-05-02 05:31:23,380][0m Trial 13 finished with value: 0.10271194069806305 and parameters: {'batch_size': 256, 'lr': 8.841602602084763e-05, 'l2_reg_coef': 0.0001245715018977151}. Best is trial 13 with value: 0.10271194069806305.[0m


Epoch 1:


Batch loss = 0.676456: 100%|██████████| 1112/1112 [00:17<00:00, 63.03it/s]


Epoch 2:



Batch loss = 0.676252: 100%|██████████| 1112/1112 [00:17<00:00, 64.33it/s]


Epoch 3:



Batch loss = 0.676044: 100%|██████████| 1112/1112 [00:17<00:00, 62.03it/s]





100%|██████████| 7357/7357 [01:27<00:00, 84.12it/s]
[32m[I 2023-05-02 05:33:43,746][0m Trial 14 finished with value: 0.10280218642963897 and parameters: {'batch_size': 256, 'lr': 0.00011201144001505824, 'l2_reg_coef': 0.00011498224071460201}. Best is trial 14 with value: 0.10280218642963897.[0m


Epoch 1:


Batch loss = 0.67651: 100%|██████████| 1112/1112 [00:17<00:00, 63.58it/s]


Epoch 2:



Batch loss = 0.676393: 100%|██████████| 1112/1112 [00:17<00:00, 63.16it/s]


Epoch 3:



Batch loss = 0.676275: 100%|██████████| 1112/1112 [00:17<00:00, 64.01it/s]





100%|██████████| 7357/7357 [01:23<00:00, 88.03it/s] 
[32m[I 2023-05-02 05:35:59,843][0m Trial 15 finished with value: 0.10227424018202808 and parameters: {'batch_size': 256, 'lr': 6.453819918936324e-05, 'l2_reg_coef': 0.00010487598503315062}. Best is trial 14 with value: 0.10280218642963897.[0m


Epoch 1:


Batch loss = 0.676439: 100%|██████████| 1112/1112 [00:17<00:00, 62.49it/s]



Epoch 2:


Batch loss = 0.676234: 100%|██████████| 1112/1112 [00:17<00:00, 62.19it/s]


Epoch 3:



Batch loss = 0.676027: 100%|██████████| 1112/1112 [00:17<00:00, 64.07it/s]





100%|██████████| 7357/7357 [01:25<00:00, 85.82it/s] 
[32m[I 2023-05-02 05:38:18,666][0m Trial 16 finished with value: 0.10267258514405206 and parameters: {'batch_size': 256, 'lr': 0.00011222088443149928, 'l2_reg_coef': 0.00012231553421043192}. Best is trial 14 with value: 0.10280218642963897.[0m


Epoch 1:


Batch loss = 0.67657: 100%|██████████| 1112/1112 [00:17<00:00, 64.17it/s]


Epoch 2:



Batch loss = 0.676546: 100%|██████████| 1112/1112 [00:17<00:00, 63.98it/s]


Epoch 3:



Batch loss = 0.676522: 100%|██████████| 1112/1112 [00:18<00:00, 60.42it/s]





100%|██████████| 7357/7357 [01:22<00:00, 89.30it/s]
[32m[I 2023-05-02 05:40:34,215][0m Trial 17 finished with value: 0.10012122366554616 and parameters: {'batch_size': 256, 'lr': 1.329297135898151e-05, 'l2_reg_coef': 0.00027464696195951804}. Best is trial 14 with value: 0.10280218642963897.[0m


Epoch 1:


Batch loss = 0.676253: 100%|██████████| 1112/1112 [00:17<00:00, 62.36it/s]



Epoch 2:


Batch loss = 0.675721: 100%|██████████| 1112/1112 [00:17<00:00, 64.58it/s]


Epoch 3:



Batch loss = 0.67517: 100%|██████████| 1112/1112 [00:18<00:00, 59.73it/s]





100%|██████████| 7357/7357 [01:22<00:00, 89.21it/s] 
[32m[I 2023-05-02 05:42:50,420][0m Trial 18 finished with value: 0.10269078638784204 and parameters: {'batch_size': 256, 'lr': 0.00028465473268808, 'l2_reg_coef': 0.00017574392246954507}. Best is trial 14 with value: 0.10280218642963897.[0m


Epoch 1:


Batch loss = 0.676527: 100%|██████████| 1112/1112 [00:17<00:00, 64.41it/s]


Epoch 2:



Batch loss = 0.676427: 100%|██████████| 1112/1112 [00:17<00:00, 64.01it/s]


Epoch 3:



Batch loss = 0.676326: 100%|██████████| 1112/1112 [00:17<00:00, 63.20it/s]





100%|██████████| 7357/7357 [01:22<00:00, 88.96it/s] 
[32m[I 2023-05-02 05:45:05,412][0m Trial 19 finished with value: 0.10260000019982794 and parameters: {'batch_size': 256, 'lr': 5.5293550069357055e-05, 'l2_reg_coef': 0.00024650879640410303}. Best is trial 14 with value: 0.10280218642963897.[0m


Epoch 1:


Batch loss = 0.676005: 100%|██████████| 1112/1112 [00:17<00:00, 62.04it/s]


Epoch 2:



Batch loss = 0.675182: 100%|██████████| 1112/1112 [00:18<00:00, 58.84it/s]


Epoch 3:



Batch loss = 0.674318: 100%|██████████| 1112/1112 [00:17<00:00, 64.52it/s]





100%|██████████| 7357/7357 [01:22<00:00, 89.08it/s] 
[32m[I 2023-05-02 05:47:22,110][0m Trial 20 finished with value: 0.10197867365582104 and parameters: {'batch_size': 256, 'lr': 0.00043038130914649624, 'l2_reg_coef': 0.00011660553035887586}. Best is trial 14 with value: 0.10280218642963897.[0m


Epoch 1:


Batch loss = 0.676375: 100%|██████████| 1112/1112 [00:17<00:00, 64.87it/s]


Epoch 2:



Batch loss = 0.676086: 100%|██████████| 1112/1112 [00:17<00:00, 64.18it/s]



Epoch 3:


Batch loss = 0.67579: 100%|██████████| 1112/1112 [00:17<00:00, 63.25it/s]





100%|██████████| 7357/7357 [01:26<00:00, 85.38it/s]
[32m[I 2023-05-02 05:49:40,388][0m Trial 21 finished with value: 0.10253347231295941 and parameters: {'batch_size': 256, 'lr': 0.00015736290299939808, 'l2_reg_coef': 0.00017305673515704938}. Best is trial 14 with value: 0.10280218642963897.[0m


Epoch 1:


Batch loss = 0.676084: 100%|██████████| 1112/1112 [00:19<00:00, 56.69it/s]


Epoch 2:



Batch loss = 0.67543: 100%|██████████| 1112/1112 [00:18<00:00, 60.01it/s]


Epoch 3:



Batch loss = 0.674749: 100%|██████████| 1112/1112 [00:18<00:00, 60.20it/s]





100%|██████████| 7357/7357 [01:22<00:00, 89.47it/s] 
[32m[I 2023-05-02 05:51:59,314][0m Trial 22 finished with value: 0.10220487834396935 and parameters: {'batch_size': 256, 'lr': 0.000345926242363446, 'l2_reg_coef': 0.00018890834290077584}. Best is trial 14 with value: 0.10280218642963897.[0m


Epoch 1:


Batch loss = 0.676501: 100%|██████████| 1112/1112 [00:18<00:00, 61.70it/s]


Epoch 2:



Batch loss = 0.676377: 100%|██████████| 1112/1112 [00:18<00:00, 60.28it/s]



Epoch 3:


Batch loss = 0.676252: 100%|██████████| 1112/1112 [00:17<00:00, 62.08it/s]





100%|██████████| 7357/7357 [01:21<00:00, 90.40it/s] 
[32m[I 2023-05-02 05:54:15,146][0m Trial 23 finished with value: 0.10249558938077973 and parameters: {'batch_size': 256, 'lr': 6.819170153367563e-05, 'l2_reg_coef': 0.00010329670512797045}. Best is trial 14 with value: 0.10280218642963897.[0m


Epoch 1:


Batch loss = 0.659678: 100%|██████████| 8889/8889 [01:17<00:00, 114.83it/s]



Epoch 2:


Batch loss = 0.659555: 100%|██████████| 8889/8889 [01:17<00:00, 114.76it/s]



Epoch 3:


Batch loss = 0.659433: 100%|██████████| 8889/8889 [01:17<00:00, 114.38it/s]





100%|██████████| 7357/7357 [01:21<00:00, 90.05it/s]
[32m[I 2023-05-02 05:59:29,496][0m Trial 24 finished with value: 0.10161126104345573 and parameters: {'batch_size': 32, 'lr': 1.0410651356272892e-05, 'l2_reg_coef': 0.00038389638862750913}. Best is trial 14 with value: 0.10280218642963897.[0m


Evaluate the Repurchase Module on the test dataset (scores are calculated for several random seeds and then averaged):

In [None]:
# Trainer for the final runs: up to 20 epochs with early_stop_num=3
# (presumably the early-stopping patience on the dev metric -- confirm in
# nbr.trainer). Replaces the tuning trainer bound to `trainer` above.
trainer = NBRTrainer(
    topk=10,
    max_epochs=20,
    early_stop_num=3,
    corpus=corpus,
)

train dataset preparing...


100%|██████████| 7358/7358 [00:13<00:00, 547.96it/s]


dev dataset preparing...


100%|██████████| 7357/7357 [00:01<00:00, 3894.40it/s]


test dataset preparing...


100%|██████████| 7357/7357 [00:02<00:00, 3509.95it/s]


In [None]:
# One independent, initially empty list of test scores per metric; the seed
# loop below appends one value per run.
test_metrics = {metric: [] for metric in ("precision", "recall", "ndcg")}

In [None]:
# Start from empty result lists so that re-running this cell does not append a
# second copy of every score to `test_metrics`.
for scores in test_metrics.values():
    scores.clear()

# The loop variable is `run_seed`, not `seed`, so the loop no longer overwrites
# the notebook-level `seed` that the Optuna sampler cell reads.
for run_seed in range(5):
    print(f"\n___SEED___{run_seed}")
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    np.random.seed(run_seed)

    # Build the model after seeding, so any random initialization is drawn
    # under `run_seed`; the other hyperparameters are the study's best ones.
    params = {
        "model": RepurchaseModule(
            item_num=corpus.n_items,
            avg_repeat_interval=corpus.total_avg_interval
        ),
        "batch_size": study.best_params["batch_size"],
        "lr": study.best_params["lr"],
        "l2_reg_coef": study.best_params["l2_reg_coef"]
    }

    trainer.init_hyperparams(**params)
    trainer.train()

    metrics = trainer.evaluate(mode="test")

    # Record this run's test score for every tracked metric
    # (precision, recall, ndcg).
    for metric_name, scores in test_metrics.items():
        scores.append(metrics[metric_name])
    print(test_metrics)


___SEED___0
Epoch 1:


Batch loss = 0.676445: 100%|██████████| 1112/1112 [00:18<00:00, 59.75it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:25<00:00, 86.03it/s] 


 {'precision': 0.051610710887590054, 'recall': 0.11495819342791812, 'ndcg': 0.10141630054004878}
Epoch 2:



Batch loss = 0.676241: 100%|██████████| 1112/1112 [00:20<00:00, 53.25it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:24<00:00, 86.95it/s]


 {'precision': 0.05243985320103303, 'recall': 0.11605516319617802, 'ndcg': 0.102341271886976}
Epoch 3:



Batch loss = 0.676033: 100%|██████████| 1112/1112 [00:20<00:00, 53.35it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:21<00:00, 90.52it/s]


 {'precision': 0.05252140818268316, 'recall': 0.11632803205640391, 'ndcg': 0.10265317857092217}
Epoch 4:



Batch loss = 0.675823: 100%|██████████| 1112/1112 [00:18<00:00, 60.94it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:21<00:00, 90.31it/s]


 {'precision': 0.05256218567350822, 'recall': 0.11577157034670713, 'ndcg': 0.10254185335913882}
Epoch 5:



Batch loss = 0.67561: 100%|██████████| 1112/1112 [00:18<00:00, 61.25it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:21<00:00, 90.44it/s] 


 {'precision': 0.05261655566127498, 'recall': 0.11596437299237809, 'ndcg': 0.10267426685960523}
Epoch 6:



Batch loss = 0.675395: 100%|██████████| 1112/1112 [00:17<00:00, 61.79it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:21<00:00, 90.09it/s] 


 {'precision': 0.05272529563680848, 'recall': 0.11611096883113818, 'ndcg': 0.10271010502070328}
Epoch 7:



Batch loss = 0.675177: 100%|██████████| 1112/1112 [00:18<00:00, 60.71it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:21<00:00, 90.47it/s] 


 {'precision': 0.05273888813375017, 'recall': 0.11614950932017105, 'ndcg': 0.10272138781619343}
Epoch 8:



Batch loss = 0.674958: 100%|██████████| 1112/1112 [00:18<00:00, 61.13it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:21<00:00, 90.38it/s] 


 {'precision': 0.05256218567350822, 'recall': 0.11565081512439765, 'ndcg': 0.10236100914297794}
Epoch 9:



Batch loss = 0.674736: 100%|██████████| 1112/1112 [00:18<00:00, 60.81it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:21<00:00, 90.53it/s] 


 {'precision': 0.052575778170449915, 'recall': 0.11579728890064883, 'ndcg': 0.10237323254874256}
Epoch 10:



Batch loss = 0.674512: 100%|██████████| 1112/1112 [00:18<00:00, 61.16it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:21<00:00, 90.46it/s] 


 {'precision': 0.05249422318879978, 'recall': 0.11540889575383728, 'ndcg': 0.10213035573380333}



100%|██████████| 7357/7357 [01:21<00:00, 90.40it/s] 

{'precision': [0.05829821938290063], 'recall': [0.12708883872286808], 'ndcg': [0.11081432366201466]}

___SEED___1
Epoch 1:



Batch loss = 0.676552: 100%|██████████| 1112/1112 [00:18<00:00, 60.23it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:21<00:00, 89.85it/s] 


 {'precision': 0.05174663585700693, 'recall': 0.11486148358629557, 'ndcg': 0.10122152532073887}
Epoch 2:



Batch loss = 0.676348: 100%|██████████| 1112/1112 [00:18<00:00, 60.62it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:21<00:00, 90.67it/s] 


 {'precision': 0.052412668207149654, 'recall': 0.11575275542472294, 'ndcg': 0.10233579119314426}
Epoch 3:



Batch loss = 0.67614: 100%|██████████| 1112/1112 [00:18<00:00, 60.80it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:20<00:00, 91.87it/s] 


 {'precision': 0.0524670381949164, 'recall': 0.1157950923853339, 'ndcg': 0.10262691596884006}
Epoch 4:



Batch loss = 0.67593: 100%|██████████| 1112/1112 [00:17<00:00, 64.06it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.00it/s] 


 {'precision': 0.05261655566127498, 'recall': 0.11585518587443003, 'ndcg': 0.10271719699274086}
Epoch 5:



Batch loss = 0.675717: 100%|██████████| 1112/1112 [00:16<00:00, 66.58it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.49it/s] 


 {'precision': 0.052602963164333286, 'recall': 0.11580461945859506, 'ndcg': 0.1027230979178443}
Epoch 6:



Batch loss = 0.675502: 100%|██████████| 1112/1112 [00:16<00:00, 67.57it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:18<00:00, 93.57it/s]


 {'precision': 0.052575778170449915, 'recall': 0.11595689583341764, 'ndcg': 0.10266700421329292}
Epoch 7:



Batch loss = 0.675284: 100%|██████████| 1112/1112 [00:17<00:00, 63.35it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:17<00:00, 94.92it/s] 


 {'precision': 0.05265733315210005, 'recall': 0.11595759984041017, 'ndcg': 0.10264473753530909}
Epoch 8:



Batch loss = 0.675065: 100%|██████████| 1112/1112 [00:17<00:00, 64.81it/s]



Evaluation (dev):


100%|██████████| 7357/7357 [01:18<00:00, 93.52it/s] 


 {'precision': 0.052385483213266276, 'recall': 0.11527131088706567, 'ndcg': 0.10213824401120289}



100%|██████████| 7357/7357 [01:19<00:00, 92.03it/s] 

{'precision': [0.05829821938290063, 0.058107924425717], 'recall': [0.12708883872286808, 0.1267917249271807], 'ndcg': [0.11081432366201466, 0.11051284735987134]}

___SEED___2
Epoch 1:



Batch loss = 0.6765: 100%|██████████| 1112/1112 [00:16<00:00, 67.09it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.16it/s] 


 {'precision': 0.051801005844773686, 'recall': 0.11493381932394407, 'ndcg': 0.10165236537821108}
Epoch 2:



Batch loss = 0.676296: 100%|██████████| 1112/1112 [00:16<00:00, 67.28it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:18<00:00, 94.03it/s]


 {'precision': 0.05231752072855784, 'recall': 0.11594546643701245, 'ndcg': 0.10243461028958202}
Epoch 3:



Batch loss = 0.676089: 100%|██████████| 1112/1112 [00:17<00:00, 65.14it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:17<00:00, 94.45it/s] 


 {'precision': 0.05261655566127498, 'recall': 0.11657188450306172, 'ndcg': 0.10289739433966297}
Epoch 4:



Batch loss = 0.675878: 100%|██████████| 1112/1112 [00:16<00:00, 65.76it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.29it/s] 


 {'precision': 0.05253500067962485, 'recall': 0.1159154463523375, 'ndcg': 0.10278916300484287}
Epoch 5:



Batch loss = 0.675666: 100%|██████████| 1112/1112 [00:16<00:00, 65.82it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.35it/s] 


 {'precision': 0.052657333152100035, 'recall': 0.11628854075726729, 'ndcg': 0.10282694241229194}
Epoch 6:



Batch loss = 0.675451: 100%|██████████| 1112/1112 [00:17<00:00, 65.16it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.37it/s] 


 {'precision': 0.05271170313986679, 'recall': 0.11652222586454218, 'ndcg': 0.10289494852913228}



100%|██████████| 7357/7357 [01:18<00:00, 93.53it/s]

{'precision': [0.05829821938290063, 0.058107924425717, 0.058067146934891935], 'recall': [0.12708883872286808, 0.1267917249271807, 0.12742818952983803], 'ndcg': [0.11081432366201466, 0.11051284735987134, 0.11087908552800879]}

___SEED___3
Epoch 1:



Batch loss = 0.676474: 100%|██████████| 1112/1112 [00:17<00:00, 62.75it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:17<00:00, 94.59it/s] 


 {'precision': 0.05195052331113225, 'recall': 0.1156230928626154, 'ndcg': 0.10197960100791124}
Epoch 2:



Batch loss = 0.67627: 100%|██████████| 1112/1112 [00:16<00:00, 66.30it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.25it/s] 


 {'precision': 0.05231752072855783, 'recall': 0.11603974059049585, 'ndcg': 0.10256831109125367}
Epoch 3:



Batch loss = 0.676062: 100%|██████████| 1112/1112 [00:16<00:00, 66.47it/s]



Evaluation (dev):


100%|██████████| 7357/7357 [01:19<00:00, 92.30it/s] 


 {'precision': 0.05245344569797472, 'recall': 0.11591048006986918, 'ndcg': 0.1027257458142659}
Epoch 4:



Batch loss = 0.675851: 100%|██████████| 1112/1112 [00:17<00:00, 65.38it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.26it/s] 


 {'precision': 0.05250781568574146, 'recall': 0.11575319397368874, 'ndcg': 0.10271832963370045}
Epoch 5:



Batch loss = 0.675638: 100%|██████████| 1112/1112 [00:16<00:00, 66.21it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.49it/s] 


 {'precision': 0.052643740655158346, 'recall': 0.11614537193572409, 'ndcg': 0.10285922667664568}
Epoch 6:



Batch loss = 0.675423: 100%|██████████| 1112/1112 [00:16<00:00, 65.58it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.89it/s]


 {'precision': 0.05276607312763355, 'recall': 0.11655349943530233, 'ndcg': 0.10290240449312202}
Epoch 7:



Batch loss = 0.675206: 100%|██████████| 1112/1112 [00:17<00:00, 63.62it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:18<00:00, 93.23it/s]


 {'precision': 0.052779665624575235, 'recall': 0.11634911496685268, 'ndcg': 0.10273790106213881}
Epoch 8:



Batch loss = 0.674986: 100%|██████████| 1112/1112 [00:18<00:00, 61.22it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:17<00:00, 94.54it/s]


 {'precision': 0.05256218567350823, 'recall': 0.11578826871917293, 'ndcg': 0.10241205444898473}
Epoch 9:



Batch loss = 0.674764: 100%|██████████| 1112/1112 [00:17<00:00, 65.41it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.20it/s] 



 {'precision': 0.05249422318879978, 'recall': 0.11560075292633441, 'ndcg': 0.10212504761594665}


100%|██████████| 7357/7357 [01:20<00:00, 91.49it/s] 

{'precision': [0.05829821938290063, 0.058107924425717, 0.058067146934891935, 0.058175886910425446], 'recall': [0.12708883872286808, 0.1267917249271807, 0.12742818952983803, 0.12684088059047646], 'ndcg': [0.11081432366201466, 0.11051284735987134, 0.11087908552800879, 0.11067052948997728]}

___SEED___4
Epoch 1:



Batch loss = 0.676412: 100%|██████████| 1112/1112 [00:17<00:00, 64.38it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.24it/s] 


 {'precision': 0.0516786733722985, 'recall': 0.11499818536387797, 'ndcg': 0.10150591490811248}
Epoch 2:



Batch loss = 0.676208: 100%|██████████| 1112/1112 [00:17<00:00, 64.86it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.54it/s]


 {'precision': 0.0521272257713742, 'recall': 0.11563012852444533, 'ndcg': 0.10222992158679849}
Epoch 3:



Batch loss = 0.676: 100%|██████████| 1112/1112 [00:16<00:00, 65.53it/s]



Evaluation (dev):


100%|██████████| 7357/7357 [01:18<00:00, 94.24it/s] 


 {'precision': 0.05257577817044992, 'recall': 0.11628381424230773, 'ndcg': 0.10269160061530277}
Epoch 4:



Batch loss = 0.67579: 100%|██████████| 1112/1112 [00:17<00:00, 63.98it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:18<00:00, 93.21it/s] 


 {'precision': 0.05252140818268316, 'recall': 0.11570567777503674, 'ndcg': 0.10242643734778803}
Epoch 5:



Batch loss = 0.675577: 100%|██████████| 1112/1112 [00:17<00:00, 64.43it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.31it/s] 


 {'precision': 0.052752480630691864, 'recall': 0.1162870672163725, 'ndcg': 0.10264102934094418}
Epoch 6:



Batch loss = 0.675362: 100%|██████████| 1112/1112 [00:17<00:00, 63.17it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:20<00:00, 91.92it/s] 


 {'precision': 0.052752480630691864, 'recall': 0.11651367722218438, 'ndcg': 0.10271090534626744}
Epoch 7:



Batch loss = 0.675144: 100%|██████████| 1112/1112 [00:17<00:00, 63.89it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:20<00:00, 91.91it/s] 


 {'precision': 0.05271170313986679, 'recall': 0.11629664877014166, 'ndcg': 0.10254315420184261}
Epoch 8:



Batch loss = 0.674925: 100%|██████████| 1112/1112 [00:17<00:00, 63.91it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:20<00:00, 91.90it/s] 


 {'precision': 0.052602963164333286, 'recall': 0.11587463724567974, 'ndcg': 0.10225979437727453}
Epoch 9:



Batch loss = 0.674703: 100%|██████████| 1112/1112 [00:16<00:00, 66.08it/s]


Evaluation (dev):



100%|██████████| 7357/7357 [01:19<00:00, 92.05it/s]


 {'precision': 0.05252140818268316, 'recall': 0.11592458742058198, 'ndcg': 0.10216093113759561}



100%|██████████| 7357/7357 [01:18<00:00, 93.88it/s] 

{'precision': [0.05829821938290063, 0.058107924425717, 0.058067146934891935, 0.058175886910425446, 0.05808073943183363], 'recall': [0.12708883872286808, 0.1267917249271807, 0.12742818952983803, 0.12684088059047646, 0.12678311111502852], 'ndcg': [0.11081432366201466, 0.11051284735987134, 0.11087908552800879, 0.11067052948997728, 0.11057756628484917]}





In [None]:
# Mean of each test metric (precision, recall, ndcg) over the seeded runs.
{metric: np.mean(scores) for metric, scores in test_metrics.items()}

{'precision': 0.05814598341715373,
 'recall': 0.12698654897707834,
 'ndcg': 0.11069087046494426}

# TaoBao

Fix seed:

In [None]:
# Reset every RNG to the same seed before the TaoBao experiments; the Optuna
# sampler cell further down reads `seed` as well.
seed = 10
random.seed(seed)
torch.manual_seed(seed)
# Kept last: it returns None, so the cell displays no output value.
np.random.seed(seed)

Read the interaction data and filter it: drop users with fewer than 10 transactions, users with a high purchase frequency, one-day users, and items with fewer than 10 transactions. Split: the train dataset contains all baskets except the last two, the validation dataset is the second-to-last basket, and the test dataset is the last basket:

In [None]:
# Dataset to process and the directory holding it; both names are reused by
# the Corpus cell below.
dataset_name = "taobao"
corpus_path = "./data/"

# Filtering thresholds. Per the description above these are presumably the
# minimum number of transactions per user and per item (confirm against
# Preprocess.load_data in nbr.preparation).
min_user_transactions, min_item_transactions = 10, 10

preprocessor = Preprocess(corpus_path, dataset_name)
preprocessor.load_data(min_user_transactions, min_item_transactions, filt=True)
# Persist the preprocessed data / split for this dataset.
save_split(corpus_path, dataset_name, preprocessor)

Before preprocessing: #users = 672404, #items = 638962, #clicks = 2015807 (#illegal records = 0)
After preprocessing: #users = 10092, #items = 22286, #clicks = 67991
Saving dataset in ./data//data_taobao/...


In [None]:
# Rebind `corpus` to the TaoBao dataset: create the Corpus for `dataset_name`
# under `corpus_path` and load its data (presumably the split written by
# `save_split` -- confirm in nbr.preparation). Later cells pass it to
# NBRTrainer and read `n_items` / `total_avg_interval` from it.
corpus = Corpus(corpus_path, dataset_name)
corpus.load_data()

Tune hyperparameters on the validation dataset:

In [None]:
# Trainer used by the TaoBao hyperparameter search: 2 epochs per trial and
# early_stop_num=None (presumably "no early stopping" -- confirm in
# nbr.trainer). The `objective` function below reads this global `trainer`.
trainer = NBRTrainer(
    topk=10,
    max_epochs=2,
    early_stop_num=None,
    corpus=corpus,
)

train dataset preparing...


100%|██████████| 10092/10092 [00:36<00:00, 274.64it/s]


dev dataset preparing...


100%|██████████| 9307/9307 [00:00<00:00, 35741.19it/s]


test dataset preparing...


100%|██████████| 9307/9307 [00:00<00:00, 30135.90it/s]


In [None]:
def objective(trial):
    """Optuna objective: train one sampled configuration and score it on dev.

    Samples a batch size, a learning rate and an L2 coefficient from
    ``trial``, trains a newly constructed ``RepurchaseModule`` with the
    notebook-level ``trainer`` (``evaluation_flg=False``), and evaluates once
    in ``"dev"`` mode.

    Args:
        trial: The ``optuna`` trial used to sample the hyperparameters.

    Returns:
        The ``"ndcg"`` entry of the dev metrics; the study maximizes it.
    """
    # A new model instance per trial, built from the current `corpus`.
    model = RepurchaseModule(
        item_num=corpus.n_items,
        avg_repeat_interval=corpus.total_avg_interval,
    )

    # `suggest_float(..., log=True)` replaces the deprecated
    # `suggest_loguniform`: same log-uniform distribution and bounds. The
    # suggestions are requested in the original order (batch_size, lr,
    # l2_reg_coef) so a seeded sampler yields the same sequence of values.
    params = {
        "model": model,
        "batch_size": trial.suggest_categorical("batch_size", [32, 64, 128, 256]),
        "lr": trial.suggest_float("lr", 1e-5, 1e-1, log=True),
        "l2_reg_coef": trial.suggest_float("l2_reg_coef", 1e-4, 1e-1, log=True),
    }

    trainer.init_hyperparams(**params)
    trainer.train(evaluation_flg=False)
    return trainer.evaluate(mode="dev")["ndcg"]

In [None]:
# Seeded TPE search over 25 trials; the study keeps the configuration with
# the highest value returned by `objective`.
sampler = optuna.samplers.TPESampler(seed=seed)
study = optuna.create_study(sampler=sampler, direction="maximize")
study.optimize(func=objective, n_trials=25)

[32m[I 2023-05-04 07:47:45,056][0m A new study created in memory with name: no-name-b27b7ee1-c893-4196-8efc-eee20907a9b3[0m


Epoch 1:


Batch loss = 0.675072: 100%|██████████| 1522/1522 [00:12<00:00, 118.34it/s]


Epoch 2:



Batch loss = 0.674311: 100%|██████████| 1522/1522 [00:12<00:00, 117.82it/s]





100%|██████████| 9307/9307 [02:59<00:00, 51.86it/s]
[32m[I 2023-05-04 07:51:15,526][0m Trial 0 finished with value: 0.07239948067893617 and parameters: {'batch_size': 32, 'lr': 0.0009863431872330064, 'l2_reg_coef': 0.0004724870791526793}. Best is trial 0 with value: 0.07239948067893617.[0m


Epoch 1:


Batch loss = 0.651219: 100%|██████████| 761/761 [00:07<00:00, 100.37it/s]


Epoch 2:



Batch loss = 0.645342: 100%|██████████| 761/761 [00:08<00:00, 93.48it/s]





100%|██████████| 9307/9307 [02:55<00:00, 53.18it/s]
[32m[I 2023-05-04 07:54:26,293][0m Trial 1 finished with value: 0.0711869016490023 and parameters: {'batch_size': 64, 'lr': 0.005513651007120869, 'l2_reg_coef': 0.07247363402746428}. Best is trial 0 with value: 0.07239948067893617.[0m


Epoch 1:


Batch loss = 0.666126: 100%|██████████| 381/381 [00:04<00:00, 87.72it/s]


Epoch 2:



Batch loss = 0.661997: 100%|██████████| 381/381 [00:04<00:00, 87.62it/s]





100%|██████████| 9307/9307 [02:54<00:00, 53.48it/s]
[32m[I 2023-05-04 07:57:29,059][0m Trial 2 finished with value: 0.07168921553828643 and parameters: {'batch_size': 128, 'lr': 0.007709412252614239, 'l2_reg_coef': 0.0007509797119626302}. Best is trial 0 with value: 0.07239948067893617.[0m


Epoch 1:


Batch loss = 0.675156: 100%|██████████| 1522/1522 [00:13<00:00, 116.57it/s]


Epoch 2:



Batch loss = 0.674918: 100%|██████████| 1522/1522 [00:13<00:00, 115.67it/s]





100%|██████████| 9307/9307 [02:54<00:00, 53.44it/s]
[32m[I 2023-05-04 08:00:49,478][0m Trial 3 finished with value: 0.07221848767409397 and parameters: {'batch_size': 32, 'lr': 0.0003114318604656744, 'l2_reg_coef': 0.010529332619998283}. Best is trial 0 with value: 0.07239948067893617.[0m


Epoch 1:


Batch loss = 0.666747: 100%|██████████| 381/381 [00:04<00:00, 84.73it/s]


Epoch 2:



Batch loss = 0.664569: 100%|██████████| 381/381 [00:04<00:00, 87.86it/s]





100%|██████████| 9307/9307 [02:53<00:00, 53.77it/s]
[32m[I 2023-05-04 08:03:51,429][0m Trial 4 finished with value: 0.07157595242065973 and parameters: {'batch_size': 128, 'lr': 0.003995661855958764, 'l2_reg_coef': 0.006355019100735405}. Best is trial 0 with value: 0.07239948067893617.[0m


Epoch 1:


Batch loss = 0.667404: 100%|██████████| 381/381 [00:04<00:00, 79.47it/s]



Epoch 2:


Batch loss = 0.667391: 100%|██████████| 381/381 [00:04<00:00, 90.13it/s]





100%|██████████| 9307/9307 [02:53<00:00, 53.65it/s]
[32m[I 2023-05-04 08:06:53,962][0m Trial 5 finished with value: 0.07271572372008725 and parameters: {'batch_size': 128, 'lr': 2.3005803026525478e-05, 'l2_reg_coef': 0.0007981787657617986}. Best is trial 5 with value: 0.07271572372008725.[0m


Epoch 1:


Batch loss = 0.652869: 100%|██████████| 761/761 [00:08<00:00, 93.64it/s] 


Epoch 2:



Batch loss = 0.651179: 100%|██████████| 761/761 [00:07<00:00, 105.29it/s]





100%|██████████| 9307/9307 [02:54<00:00, 53.34it/s]
[32m[I 2023-05-04 08:10:03,839][0m Trial 6 finished with value: 0.07146811573355069 and parameters: {'batch_size': 64, 'lr': 0.0015500461319089488, 'l2_reg_coef': 0.028698618231059004}. Best is trial 5 with value: 0.07271572372008725.[0m


Epoch 1:


Batch loss = 0.653226: 100%|██████████| 761/761 [00:07<00:00, 100.52it/s]


Epoch 2:



Batch loss = 0.65306: 100%|██████████| 761/761 [00:08<00:00, 94.10it/s] 





100%|██████████| 9307/9307 [02:53<00:00, 53.70it/s]
[32m[I 2023-05-04 08:13:12,848][0m Trial 7 finished with value: 0.07308154846278746 and parameters: {'batch_size': 64, 'lr': 0.00015270273918806623, 'l2_reg_coef': 0.04485485309869719}. Best is trial 7 with value: 0.07308154846278746.[0m


Epoch 1:


Batch loss = 0.664666: 100%|██████████| 381/381 [00:04<00:00, 87.09it/s]


Epoch 2:



Batch loss = 0.655816: 100%|██████████| 381/381 [00:04<00:00, 77.67it/s]





100%|██████████| 9307/9307 [02:52<00:00, 54.03it/s]
[32m[I 2023-05-04 08:16:14,414][0m Trial 8 finished with value: 0.07072601012145072 and parameters: {'batch_size': 128, 'lr': 0.01924964086209924, 'l2_reg_coef': 0.0002840900733622218}. Best is trial 7 with value: 0.07308154846278746.[0m


Epoch 1:


Batch loss = 0.664545: 100%|██████████| 381/381 [00:04<00:00, 89.89it/s]


Epoch 2:



Batch loss = 0.655509: 100%|██████████| 381/381 [00:04<00:00, 93.05it/s]





100%|██████████| 9307/9307 [02:50<00:00, 54.69it/s]
[32m[I 2023-05-04 08:19:12,956][0m Trial 9 finished with value: 0.07090360526327591 and parameters: {'batch_size': 128, 'lr': 0.02016003934561091, 'l2_reg_coef': 0.0005677045861147308}. Best is trial 7 with value: 0.07308154846278746.[0m


Epoch 1:


Batch loss = 0.668759: 100%|██████████| 191/191 [00:03<00:00, 54.18it/s]


Epoch 2:



Batch loss = 0.693147: 100%|██████████| 191/191 [00:02<00:00, 66.89it/s]





100%|██████████| 9307/9307 [02:50<00:00, 54.53it/s]
[32m[I 2023-05-04 08:22:10,073][0m Trial 10 finished with value: 0.051101659849159146 and parameters: {'batch_size': 256, 'lr': 0.08097836606986637, 'l2_reg_coef': 0.09111479007859714}. Best is trial 7 with value: 0.07308154846278746.[0m


Epoch 1:


Batch loss = 0.653439: 100%|██████████| 761/761 [00:06<00:00, 118.10it/s]



Epoch 2:


Batch loss = 0.653421: 100%|██████████| 761/761 [00:07<00:00, 102.30it/s]





100%|██████████| 9307/9307 [02:49<00:00, 54.80it/s]
[32m[I 2023-05-04 08:25:13,838][0m Trial 11 finished with value: 0.07265914127954991 and parameters: {'batch_size': 64, 'lr': 1.6462341078166435e-05, 'l2_reg_coef': 0.00213318999253824}. Best is trial 7 with value: 0.07308154846278746.[0m


Epoch 1:


Batch loss = 0.67179: 100%|██████████| 191/191 [00:02<00:00, 69.85it/s]


Epoch 2:



Batch loss = 0.671777: 100%|██████████| 191/191 [00:02<00:00, 71.00it/s]





100%|██████████| 9307/9307 [02:49<00:00, 54.78it/s]
[32m[I 2023-05-04 08:28:09,217][0m Trial 12 finished with value: 0.07350233825331731 and parameters: {'batch_size': 256, 'lr': 3.5825920825139016e-05, 'l2_reg_coef': 0.0018995211552374786}. Best is trial 12 with value: 0.07350233825331731.[0m


Epoch 1:


Batch loss = 0.671811: 100%|██████████| 191/191 [00:03<00:00, 59.05it/s]


Epoch 2:



Batch loss = 0.671779: 100%|██████████| 191/191 [00:03<00:00, 61.17it/s]





100%|██████████| 9307/9307 [02:50<00:00, 54.59it/s]
[32m[I 2023-05-04 08:31:06,128][0m Trial 13 finished with value: 0.07259993383814604 and parameters: {'batch_size': 256, 'lr': 8.841602602084763e-05, 'l2_reg_coef': 0.00012457150189771487}. Best is trial 12 with value: 0.07350233825331731.[0m


Epoch 1:


Batch loss = 0.6718: 100%|██████████| 191/191 [00:03<00:00, 63.55it/s]


Epoch 2:



Batch loss = 0.67177: 100%|██████████| 191/191 [00:03<00:00, 62.57it/s]





100%|██████████| 9307/9307 [02:53<00:00, 53.78it/s]
[32m[I 2023-05-04 08:34:05,313][0m Trial 14 finished with value: 0.07313445659008751 and parameters: {'batch_size': 256, 'lr': 8.287350404318291e-05, 'l2_reg_coef': 0.0030768349797705157}. Best is trial 12 with value: 0.07350233825331731.[0m


Epoch 1:


Batch loss = 0.671752: 100%|██████████| 191/191 [00:03<00:00, 52.24it/s]


Epoch 2:



Batch loss = 0.671737: 100%|██████████| 191/191 [00:03<00:00, 60.80it/s]





100%|██████████| 9307/9307 [02:53<00:00, 53.79it/s]
[32m[I 2023-05-04 08:37:05,184][0m Trial 15 finished with value: 0.07285763824430949 and parameters: {'batch_size': 256, 'lr': 4.132118354060992e-05, 'l2_reg_coef': 0.0025726384630819538}. Best is trial 12 with value: 0.07350233825331731.[0m


Epoch 1:


Batch loss = 0.671833: 100%|██████████| 191/191 [00:03<00:00, 55.49it/s]


Epoch 2:



Batch loss = 0.671813: 100%|██████████| 191/191 [00:03<00:00, 51.58it/s]





100%|██████████| 9307/9307 [02:53<00:00, 53.80it/s]
[32m[I 2023-05-04 08:40:05,389][0m Trial 16 finished with value: 0.07307500644780318 and parameters: {'batch_size': 256, 'lr': 5.4191514815566476e-05, 'l2_reg_coef': 0.001649175318104535}. Best is trial 12 with value: 0.07350233825331731.[0m


Epoch 1:


Batch loss = 0.671842: 100%|██████████| 191/191 [00:03<00:00, 58.69it/s]



Epoch 2:


Batch loss = 0.671838: 100%|██████████| 191/191 [00:03<00:00, 59.65it/s]





100%|██████████| 9307/9307 [02:52<00:00, 53.91it/s]
[32m[I 2023-05-04 08:43:04,522][0m Trial 17 finished with value: 0.07258838178115584 and parameters: {'batch_size': 256, 'lr': 1.207464326417499e-05, 'l2_reg_coef': 0.005414487015641504}. Best is trial 12 with value: 0.07350233825331731.[0m


Epoch 1:


Batch loss = 0.671804: 100%|██████████| 191/191 [00:03<00:00, 51.41it/s]


Epoch 2:



Batch loss = 0.6717: 100%|██████████| 191/191 [00:03<00:00, 56.79it/s]





100%|██████████| 9307/9307 [02:52<00:00, 53.87it/s]
[32m[I 2023-05-04 08:46:04,422][0m Trial 18 finished with value: 0.07341394431374704 and parameters: {'batch_size': 256, 'lr': 0.00028465473268808, 'l2_reg_coef': 0.014292684391142474}. Best is trial 12 with value: 0.07350233825331731.[0m


Epoch 1:


Batch loss = 0.671828: 100%|██████████| 191/191 [00:03<00:00, 62.56it/s]



Epoch 2:


Batch loss = 0.671748: 100%|██████████| 191/191 [00:03<00:00, 59.83it/s]





100%|██████████| 9307/9307 [02:53<00:00, 53.50it/s]
[32m[I 2023-05-04 08:49:04,692][0m Trial 19 finished with value: 0.07354655117790239 and parameters: {'batch_size': 256, 'lr': 0.00022155020864083442, 'l2_reg_coef': 0.012687180197268989}. Best is trial 19 with value: 0.07354655117790239.[0m


Epoch 1:


Batch loss = 0.671738: 100%|██████████| 191/191 [00:03<00:00, 61.29it/s]


Epoch 2:



Batch loss = 0.671655: 100%|██████████| 191/191 [00:03<00:00, 61.49it/s]





100%|██████████| 9307/9307 [02:52<00:00, 53.82it/s]
[32m[I 2023-05-04 08:52:03,876][0m Trial 20 finished with value: 0.07276074816810305 and parameters: {'batch_size': 256, 'lr': 0.00022854512058308906, 'l2_reg_coef': 0.017787065467734897}. Best is trial 19 with value: 0.07354655117790239.[0m


Epoch 1:


Batch loss = 0.671766: 100%|██████████| 191/191 [00:03<00:00, 52.89it/s]


Epoch 2:



Batch loss = 0.671598: 100%|██████████| 191/191 [00:03<00:00, 59.81it/s]





100%|██████████| 9307/9307 [02:52<00:00, 53.83it/s]
[32m[I 2023-05-04 08:55:03,615][0m Trial 21 finished with value: 0.07259068432412971 and parameters: {'batch_size': 256, 'lr': 0.00046228514780085966, 'l2_reg_coef': 0.01457690241539881}. Best is trial 19 with value: 0.07354655117790239.[0m


Epoch 1:


Batch loss = 0.671789: 100%|██████████| 191/191 [00:03<00:00, 62.58it/s]


Epoch 2:



Batch loss = 0.671739: 100%|██████████| 191/191 [00:03<00:00, 62.47it/s]





100%|██████████| 9307/9307 [02:54<00:00, 53.30it/s]
[32m[I 2023-05-04 08:58:04,401][0m Trial 22 finished with value: 0.07332113749475758 and parameters: {'batch_size': 256, 'lr': 0.00013666957221426842, 'l2_reg_coef': 0.0071154069006925166}. Best is trial 19 with value: 0.07354655117790239.[0m


Epoch 1:


Batch loss = 0.67185: 100%|██████████| 191/191 [00:03<00:00, 63.45it/s]


Epoch 2:



Batch loss = 0.671837: 100%|██████████| 191/191 [00:03<00:00, 62.84it/s]





100%|██████████| 9307/9307 [02:52<00:00, 53.85it/s]
[32m[I 2023-05-04 09:01:03,338][0m Trial 23 finished with value: 0.07303027959852247 and parameters: {'batch_size': 256, 'lr': 3.549365225067545e-05, 'l2_reg_coef': 0.019760126678231054}. Best is trial 19 with value: 0.07354655117790239.[0m


Epoch 1:


Batch loss = 0.675092: 100%|██████████| 1522/1522 [00:12<00:00, 120.22it/s]



Epoch 2:


Batch loss = 0.674593: 100%|██████████| 1522/1522 [00:12<00:00, 118.20it/s]





100%|██████████| 9307/9307 [02:50<00:00, 54.46it/s]
[32m[I 2023-05-04 09:04:19,830][0m Trial 24 finished with value: 0.07196471216806259 and parameters: {'batch_size': 32, 'lr': 0.0006484226085149746, 'l2_reg_coef': 0.005136300153259163}. Best is trial 19 with value: 0.07354655117790239.[0m


Evaluate the Repurchase Module on the test set (scores are computed for several random seeds):

In [None]:
# New trainer for the final runs with the tuned hyperparameters: up to
# 20 epochs and early_stop_num=3. Constructing it prepares the train/dev/test
# datasets again (see the progress bars printed by this cell).
# NOTE(review): early_stop_num looks like an early-stopping patience on the
# dev metric (the runs below stop well before 20 epochs) -- confirm in
# nbr.trainer.
trainer = NBRTrainer(
    corpus=corpus,
    max_epochs=20,
    topk=10,
    early_stop_num=3
)

train dataset preparing...


100%|██████████| 10092/10092 [00:32<00:00, 310.95it/s]


dev dataset preparing...


100%|██████████| 9307/9307 [00:00<00:00, 32143.90it/s]


test dataset preparing...


100%|██████████| 9307/9307 [00:00<00:00, 24237.43it/s]


In [None]:
# One empty list per reported metric; the seed loop appends one test score
# per run to each list.
test_metrics = {
    metric_name: [] for metric_name in ("precision", "recall", "ndcg")
}

In [None]:
# Train and test the model once per random seed using the best hyperparameters
# found by the study, collecting the test metrics of every run.
# The loop variable is deliberately NOT called `seed`: that name holds the
# notebook-level seed used by the Optuna sampler, and overwriting it here
# would silently change the search if the study cell were re-run afterwards.
for run_seed in range(3):
    print(f"\n___SEED___{run_seed}")
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    np.random.seed(run_seed)

    best_params = study.best_params
    params = {
        # A new model per run so that no weights carry over between seeds.
        "model": RepurchaseModule(
            item_num=corpus.n_items,
            avg_repeat_interval=corpus.total_avg_interval
        ),
        "batch_size": best_params["batch_size"],
        "lr": best_params["lr"],
        "l2_reg_coef": best_params["l2_reg_coef"]
    }

    trainer.init_hyperparams(**params)
    trainer.train()

    metrics = trainer.evaluate(mode="test")

    test_metrics["precision"].append(metrics["precision"])
    test_metrics["recall"].append(metrics["recall"])
    test_metrics["ndcg"].append(metrics["ndcg"])
    # Show the scores accumulated so far after every run.
    print(test_metrics)


___SEED___0
Epoch 1:


Batch loss = 0.671768: 100%|██████████| 191/191 [00:03<00:00, 53.70it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:56<00:00, 52.63it/s]


 {'precision': 0.010454496615450738, 'recall': 0.09900612442247769, 'ndcg': 0.07304581542906191}
Epoch 2:



Batch loss = 0.671688: 100%|██████████| 191/191 [00:02<00:00, 65.33it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [03:00<00:00, 51.47it/s]


 {'precision': 0.010454496615450738, 'recall': 0.09900612442247769, 'ndcg': 0.07321221011898418}
Epoch 3:



Batch loss = 0.671608: 100%|██████████| 191/191 [00:03<00:00, 59.83it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:57<00:00, 52.57it/s]



 {'precision': 0.010454496615450738, 'recall': 0.09900612442247769, 'ndcg': 0.07311565832229666}
Epoch 4:


Batch loss = 0.67153: 100%|██████████| 191/191 [00:02<00:00, 66.92it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:58<00:00, 52.20it/s]


 {'precision': 0.010465241216288815, 'recall': 0.09911357043085849, 'ndcg': 0.07309558911493251}
Epoch 5:



Batch loss = 0.671451: 100%|██████████| 191/191 [00:02<00:00, 71.36it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:58<00:00, 52.14it/s]


 {'precision': 0.010465241216288815, 'recall': 0.09911357043085849, 'ndcg': 0.07295615659914542}



100%|██████████| 9307/9307 [02:58<00:00, 52.01it/s]


{'precision': [0.011582679703449016], 'recall': [0.11188711006052791], 'ndcg': [0.08028130504394505]}

___SEED___1
Epoch 1:


Batch loss = 0.671841: 100%|██████████| 191/191 [00:02<00:00, 67.82it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:58<00:00, 52.20it/s]


 {'precision': 0.010400773611260343, 'recall': 0.09866587872927188, 'ndcg': 0.07284651733953297}
Epoch 2:



Batch loss = 0.67176: 100%|██████████| 191/191 [00:02<00:00, 70.52it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:58<00:00, 52.26it/s]


 {'precision': 0.010400773611260343, 'recall': 0.09866587872927188, 'ndcg': 0.07289885588589323}
Epoch 3:



Batch loss = 0.671681: 100%|██████████| 191/191 [00:02<00:00, 70.38it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:57<00:00, 52.36it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09877332473765266, 'ndcg': 0.07300031661548595}
Epoch 4:



Batch loss = 0.671602: 100%|██████████| 191/191 [00:02<00:00, 70.64it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:57<00:00, 52.30it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09877332473765266, 'ndcg': 0.07286057881440128}
Epoch 5:



Batch loss = 0.671523: 100%|██████████| 191/191 [00:02<00:00, 70.85it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:57<00:00, 52.50it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09877332473765266, 'ndcg': 0.0728958180684239}
Epoch 6:



Batch loss = 0.671443: 100%|██████████| 191/191 [00:03<00:00, 62.27it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:56<00:00, 52.86it/s]


 {'precision': 0.010411518212098422, 'recall': 0.09877332473765266, 'ndcg': 0.07284848237489217}



100%|██████████| 9307/9307 [02:58<00:00, 52.15it/s]

{'precision': [0.011582679703449016, 0.011571935102610939], 'recall': [0.11188711006052791, 0.11177966405214712], 'ndcg': [0.08028130504394505, 0.07965814355041348]}

___SEED___2
Epoch 1:



Batch loss = 0.671782: 100%|██████████| 191/191 [00:02<00:00, 71.83it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:57<00:00, 52.36it/s]


 {'precision': 0.0104222628129365, 'recall': 0.09888077074603345, 'ndcg': 0.0728799049192976}
Epoch 2:



Batch loss = 0.671702: 100%|██████████| 191/191 [00:02<00:00, 70.81it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:57<00:00, 52.37it/s]


 {'precision': 0.0104222628129365, 'recall': 0.09888077074603345, 'ndcg': 0.07296206662664706}





Epoch 3:


Batch loss = 0.671623: 100%|██████████| 191/191 [00:02<00:00, 71.73it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:58<00:00, 52.23it/s]


 {'precision': 0.0104222628129365, 'recall': 0.09888077074603345, 'ndcg': 0.07264976347857466}
Epoch 4:



Batch loss = 0.671544: 100%|██████████| 191/191 [00:02<00:00, 70.06it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:59<00:00, 51.96it/s]


 {'precision': 0.010433007413774578, 'recall': 0.09898821675441424, 'ndcg': 0.07257201686937935}
Epoch 5:



Batch loss = 0.671464: 100%|██████████| 191/191 [00:02<00:00, 70.53it/s]


Evaluation (dev):



100%|██████████| 9307/9307 [02:58<00:00, 52.03it/s]


 {'precision': 0.010433007413774578, 'recall': 0.09898821675441424, 'ndcg': 0.07240678834666571}



100%|██████████| 9307/9307 [02:58<00:00, 52.27it/s]

{'precision': [0.011582679703449016, 0.011571935102610939, 0.011550445900934781], 'recall': [0.11188711006052791, 0.11177966405214712, 0.11156477203538555], 'ndcg': [0.08028130504394505, 0.07965814355041348, 0.08039730034099186]}





In [None]:
# Mean of each test metric across the seeded runs stored in `test_metrics`.
# Being the last expression of the cell, the resulting dict is displayed.
{
    metric_name: np.array(test_metrics[metric_name]).mean()
    for metric_name in ("precision", "recall", "ndcg")
}

{'precision': 0.011568353568998246,
 'recall': 0.11174384871602021,
 'ndcg': 0.0801122496451168}

# Dunnhumby

Fix the random seeds for reproducibility:

In [None]:
# Seed Python's `random`, PyTorch and NumPy with one shared value before the
# Dunnhumby experiments.
seed = 10
random.seed(seed)
torch.manual_seed(seed)
# np.random.seed stays last: it returns None, so the cell shows no output.
np.random.seed(seed)

Read the interaction data. Filtering removes users with fewer than 5 transactions, users with a high purchase frequency, one-day users, and items with fewer than 10 transactions. Split: the training set contains all baskets except the last two, the validation set is the second-to-last basket, and the test set is the last basket:

In [None]:
# Where the raw data lives and which dataset this section works on.
corpus_path, dataset_name = "./data/", "dunnhumby"

# Load and filter the raw interactions, then save the resulting split.
# The two positional thresholds match the text above (5 for users, 10 for
# items) -- NOTE(review): argument order inferred from that text; confirm
# against Preprocess.load_data.
preprocessor = Preprocess(corpus_path, dataset_name)
preprocessor.load_data(5, 10, filt=True)
save_split(corpus_path, dataset_name, preprocessor)

Before preprocessing: #users = 2500, #items = 92339, #clicks = 2595370 (#illegal records = 0)
After preprocessing: #users = 2358, #items = 26756, #clicks = 1976796
Saving dataset in ./data//data_dunnhumby/...


In [None]:
# Build the Corpus for the Dunnhumby dataset; it is passed to NBRTrainer and
# supplies `n_items` / `total_avg_interval` to the model further down.
# NOTE(review): presumably load_data() reads the files written by save_split
# in the previous cell -- confirm in nbr.preparation.
corpus = Corpus(corpus_path, dataset_name)
corpus.load_data()

Tune the hyperparameters on the validation dataset:

In [None]:
# Trainer used by the hyperparameter search below: every trial trains for a
# single epoch (one Dunnhumby epoch takes minutes, see the logs below).
# Constructing it also prepares the train/dev/test datasets.
# NOTE(review): early_stop_num=None presumably disables early stopping and
# topk=10 presumably sets the cutoff of the ranking metrics -- confirm in
# nbr.trainer.
trainer = NBRTrainer(
    corpus=corpus,
    max_epochs=1,
    topk=10,
    early_stop_num=None
)

train dataset preparing...


100%|██████████| 2358/2358 [00:12<00:00, 189.36it/s]


dev dataset preparing...


100%|██████████| 2357/2357 [00:11<00:00, 202.53it/s]


test dataset preparing...


100%|██████████| 2357/2357 [00:10<00:00, 229.09it/s]


In [None]:
def objective(trial):
    """Optuna objective: train one sampled configuration and score it on dev.

    Samples a batch size, a learning rate and an L2 coefficient from
    ``trial``, trains a newly constructed ``RepurchaseModule`` with the
    notebook-level ``trainer`` (``evaluation_flg=False``), and evaluates once
    in ``"dev"`` mode.

    Args:
        trial: The ``optuna`` trial used to sample the hyperparameters.

    Returns:
        The ``"ndcg"`` entry of the dev metrics; the study maximizes it.
    """
    # A new model instance per trial, built from the current `corpus`.
    model = RepurchaseModule(
        item_num=corpus.n_items,
        avg_repeat_interval=corpus.total_avg_interval,
    )

    # `suggest_float(..., log=True)` replaces the deprecated
    # `suggest_loguniform`: same log-uniform distribution and bounds. The
    # suggestions are requested in the original order (batch_size, lr,
    # l2_reg_coef) so a seeded sampler yields the same sequence of values.
    params = {
        "model": model,
        "batch_size": trial.suggest_categorical("batch_size", [32, 64, 128, 256]),
        "lr": trial.suggest_float("lr", 1e-5, 1e-1, log=True),
        "l2_reg_coef": trial.suggest_float("l2_reg_coef", 1e-4, 1e-1, log=True),
    }

    trainer.init_hyperparams(**params)
    trainer.train(evaluation_flg=False)
    return trainer.evaluate(mode="dev")["ndcg"]

In [None]:
# Seeded TPE search over 25 trials; the study keeps the configuration with
# the highest value returned by `objective`.
sampler = optuna.samplers.TPESampler(seed=seed)
study = optuna.create_study(sampler=sampler, direction="maximize")
study.optimize(func=objective, n_trials=25)

[32m[I 2023-05-02 07:42:54,375][0m A new study created in memory with name: no-name-c4aa80c3-e785-4650-b1b1-69099d384f7b[0m


Epoch 1:


Batch loss = 0.493123: 100%|██████████| 60237/60237 [08:09<00:00, 122.95it/s]





100%|██████████| 2357/2357 [01:14<00:00, 31.48it/s]
[32m[I 2023-05-02 07:52:24,161][0m Trial 0 finished with value: 0.1571827790007825 and parameters: {'batch_size': 32, 'lr': 0.0009863431872330064, 'l2_reg_coef': 0.0004724870791526793}. Best is trial 0 with value: 0.1571827790007825.[0m


Epoch 1:


Batch loss = 0.563522: 100%|██████████| 30119/30119 [04:40<00:00, 107.49it/s]





100%|██████████| 2357/2357 [01:16<00:00, 30.62it/s]
[32m[I 2023-05-02 07:58:21,364][0m Trial 1 finished with value: 0.14852841631118674 and parameters: {'batch_size': 64, 'lr': 0.005513651007120869, 'l2_reg_coef': 0.07247363402746428}. Best is trial 0 with value: 0.1571827790007825.[0m


Epoch 1:


Batch loss = 0.551826: 100%|██████████| 15060/15060 [02:51<00:00, 87.60it/s]





100%|██████████| 2357/2357 [01:13<00:00, 31.97it/s]
[32m[I 2023-05-02 08:02:27,034][0m Trial 2 finished with value: 0.14749004733340734 and parameters: {'batch_size': 128, 'lr': 0.007709412252614239, 'l2_reg_coef': 0.0007509797119626302}. Best is trial 0 with value: 0.1571827790007825.[0m


Epoch 1:


Batch loss = 0.502574: 100%|██████████| 60237/60237 [08:11<00:00, 122.58it/s]





100%|██████████| 2357/2357 [01:13<00:00, 32.19it/s]
[32m[I 2023-05-02 08:11:51,712][0m Trial 3 finished with value: 0.1624108396401722 and parameters: {'batch_size': 32, 'lr': 0.0003114318604656744, 'l2_reg_coef': 0.010529332619998283}. Best is trial 3 with value: 0.1624108396401722.[0m


Epoch 1:


Batch loss = 0.556344: 100%|██████████| 15060/15060 [02:51<00:00, 87.79it/s]





100%|██████████| 2357/2357 [01:14<00:00, 31.67it/s]
[32m[I 2023-05-02 08:15:57,709][0m Trial 4 finished with value: 0.15157858379978897 and parameters: {'batch_size': 128, 'lr': 0.003995661855958764, 'l2_reg_coef': 0.006355019100735405}. Best is trial 3 with value: 0.1624108396401722.[0m


Epoch 1:


Batch loss = 0.577382: 100%|██████████| 15060/15060 [02:51<00:00, 87.60it/s]





100%|██████████| 2357/2357 [01:13<00:00, 32.22it/s]
[32m[I 2023-05-02 08:20:02,817][0m Trial 5 finished with value: 0.1652210588839391 and parameters: {'batch_size': 128, 'lr': 2.3005803026525478e-05, 'l2_reg_coef': 0.0007981787657617986}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.560697: 100%|██████████| 30119/30119 [04:39<00:00, 107.82it/s]





100%|██████████| 2357/2357 [01:14<00:00, 31.49it/s]
[32m[I 2023-05-02 08:25:57,029][0m Trial 6 finished with value: 0.15577705769402608 and parameters: {'batch_size': 64, 'lr': 0.0015500461319089488, 'l2_reg_coef': 0.028698618231059004}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.57136: 100%|██████████| 30119/30119 [04:39<00:00, 107.70it/s]





100%|██████████| 2357/2357 [01:15<00:00, 31.37it/s]
[32m[I 2023-05-02 08:31:51,852][0m Trial 7 finished with value: 0.16293236117151633 and parameters: {'batch_size': 64, 'lr': 0.00015270273918806623, 'l2_reg_coef': 0.04485485309869719}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.602771: 100%|██████████| 15060/15060 [02:51<00:00, 87.94it/s]





100%|██████████| 2357/2357 [01:15<00:00, 31.34it/s]
[32m[I 2023-05-02 08:35:58,349][0m Trial 8 finished with value: 0.1352432273865811 and parameters: {'batch_size': 128, 'lr': 0.01924964086209924, 'l2_reg_coef': 0.0002840900733622218}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.611003: 100%|██████████| 15060/15060 [02:50<00:00, 88.10it/s]





100%|██████████| 2357/2357 [01:14<00:00, 31.62it/s]
[32m[I 2023-05-02 08:40:03,850][0m Trial 9 finished with value: 0.13467150296900154 and parameters: {'batch_size': 128, 'lr': 0.02016003934561091, 'l2_reg_coef': 0.0005677045861147308}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.647677: 100%|██████████| 7530/7530 [02:00<00:00, 62.61it/s]





100%|██████████| 2357/2357 [01:13<00:00, 32.04it/s]
[32m[I 2023-05-02 08:43:17,709][0m Trial 10 finished with value: 0.16399381305646155 and parameters: {'batch_size': 256, 'lr': 1.2690369436118601e-05, 'l2_reg_coef': 0.00011061586013725058}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.647681: 100%|██████████| 7530/7530 [01:59<00:00, 63.01it/s]





100%|██████████| 2357/2357 [01:13<00:00, 32.07it/s]
[32m[I 2023-05-02 08:46:30,734][0m Trial 11 finished with value: 0.16395721389824364 and parameters: {'batch_size': 256, 'lr': 1.079615624829323e-05, 'l2_reg_coef': 0.00010682158568187776}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.647661: 100%|██████████| 7530/7530 [01:59<00:00, 63.23it/s]





100%|██████████| 2357/2357 [01:15<00:00, 31.02it/s]
[32m[I 2023-05-02 08:49:45,837][0m Trial 12 finished with value: 0.16378767121430896 and parameters: {'batch_size': 256, 'lr': 1.229353134319667e-05, 'l2_reg_coef': 0.0013255706037010157}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.647337: 100%|██████████| 7530/7530 [01:59<00:00, 63.21it/s]





100%|██████████| 2357/2357 [01:13<00:00, 32.02it/s]
[32m[I 2023-05-02 08:52:58,618][0m Trial 13 finished with value: 0.16453337018372205 and parameters: {'batch_size': 256, 'lr': 4.801803589897407e-05, 'l2_reg_coef': 0.00010309814875559583}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.647284: 100%|██████████| 7530/7530 [02:00<00:00, 62.54it/s]





100%|██████████| 2357/2357 [01:12<00:00, 32.34it/s]
[32m[I 2023-05-02 08:56:11,950][0m Trial 14 finished with value: 0.164345240415356 and parameters: {'batch_size': 256, 'lr': 5.049228075508465e-05, 'l2_reg_coef': 0.0020814166171324814}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.681508: 100%|██████████| 7530/7530 [01:58<00:00, 63.74it/s]





100%|██████████| 2357/2357 [01:14<00:00, 31.63it/s]
[32m[I 2023-05-02 08:59:24,642][0m Trial 15 finished with value: 0.1044473268039806 and parameters: {'batch_size': 256, 'lr': 0.08529119723988914, 'l2_reg_coef': 0.00022599279633306386}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.575697: 100%|██████████| 15060/15060 [02:52<00:00, 87.39it/s]





100%|██████████| 2357/2357 [01:15<00:00, 31.07it/s]
[32m[I 2023-05-02 09:03:32,893][0m Trial 16 finished with value: 0.16497547918637034 and parameters: {'batch_size': 128, 'lr': 4.997613384219856e-05, 'l2_reg_coef': 0.0015714777859241755}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.575897: 100%|██████████| 15060/15060 [02:51<00:00, 87.60it/s]





100%|██████████| 2357/2357 [01:15<00:00, 31.30it/s]
[32m[I 2023-05-02 09:07:40,131][0m Trial 17 finished with value: 0.1644626432291221 and parameters: {'batch_size': 128, 'lr': 4.586629275612112e-05, 'l2_reg_coef': 0.0026871126228970215}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.569742: 100%|██████████| 15060/15060 [02:54<00:00, 86.54it/s]





100%|██████████| 2357/2357 [01:13<00:00, 32.15it/s]
[32m[I 2023-05-02 09:11:47,493][0m Trial 18 finished with value: 0.16309799557330557 and parameters: {'batch_size': 128, 'lr': 0.0002358952963635556, 'l2_reg_coef': 0.0012102329170348828}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.572629: 100%|██████████| 15060/15060 [02:55<00:00, 85.99it/s]





100%|██████████| 2357/2357 [01:13<00:00, 32.01it/s]
[32m[I 2023-05-02 09:15:56,299][0m Trial 19 finished with value: 0.1643147058263584 and parameters: {'batch_size': 128, 'lr': 0.00011925142452157299, 'l2_reg_coef': 0.004316101255062704}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.564856: 100%|██████████| 15060/15060 [02:55<00:00, 85.67it/s]





100%|██████████| 2357/2357 [01:13<00:00, 32.24it/s]
[32m[I 2023-05-02 09:20:05,215][0m Trial 20 finished with value: 0.1620976156998307 and parameters: {'batch_size': 128, 'lr': 0.0005669431099343077, 'l2_reg_coef': 0.0013992853424236254}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.51474: 100%|██████████| 60237/60237 [08:28<00:00, 118.50it/s]





100%|██████████| 2357/2357 [01:14<00:00, 31.81it/s]
[32m[I 2023-05-02 09:29:47,695][0m Trial 21 finished with value: 0.16406008571514372 and parameters: {'batch_size': 32, 'lr': 6.427474216450577e-05, 'l2_reg_coef': 0.00027675831786225333}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.647444: 100%|██████████| 7530/7530 [01:59<00:00, 62.93it/s]





100%|██████████| 2357/2357 [01:13<00:00, 32.15it/s]
[32m[I 2023-05-02 09:33:00,697][0m Trial 22 finished with value: 0.16486789126871104 and parameters: {'batch_size': 256, 'lr': 3.346140912096039e-05, 'l2_reg_coef': 0.0007564494144274886}. Best is trial 5 with value: 0.1652210588839391.[0m


Epoch 1:


Batch loss = 0.577396: 100%|██████████| 15060/15060 [02:54<00:00, 86.42it/s]





100%|██████████| 2357/2357 [01:12<00:00, 32.30it/s]
[32m[I 2023-05-02 09:37:07,975][0m Trial 23 finished with value: 0.16530139755253817 and parameters: {'batch_size': 128, 'lr': 2.0869566476632644e-05, 'l2_reg_coef': 0.0008476179290251597}. Best is trial 23 with value: 0.16530139755253817.[0m


Epoch 1:


Batch loss = 0.577329: 100%|██████████| 15060/15060 [02:53<00:00, 86.90it/s]





100%|██████████| 2357/2357 [01:12<00:00, 32.42it/s]
[32m[I 2023-05-02 09:41:14,003][0m Trial 24 finished with value: 0.16433026738946085 and parameters: {'batch_size': 128, 'lr': 2.3355591416389358e-05, 'l2_reg_coef': 0.002473103274964016}. Best is trial 23 with value: 0.16530139755253817.[0m


Evaluate the Repurchase Module on the test set (scores are computed for several random seeds):

In [None]:
# New trainer for the final runs with the tuned hyperparameters: up to
# 20 epochs and early_stop_num=3. Constructing it prepares the train/dev/test
# datasets again (see the progress bars printed by this cell).
# NOTE(review): early_stop_num is presumably an early-stopping patience on
# the dev metric -- confirm in nbr.trainer.
trainer = NBRTrainer(
    corpus=corpus,
    max_epochs=20,
    topk=10,
    early_stop_num=3
)

train dataset preparing...


100%|██████████| 2358/2358 [00:09<00:00, 241.97it/s]


dev dataset preparing...


100%|██████████| 2357/2357 [00:11<00:00, 212.88it/s]


test dataset preparing...


100%|██████████| 2357/2357 [00:10<00:00, 218.85it/s]


In [None]:
# Per-metric accumulators: the seed loop appends one test score per run.
test_metrics = {
    metric_name: [] for metric_name in ("precision", "recall", "ndcg")
}

In [None]:
# Train and test the model once per random seed using the best hyperparameters
# found by the study, collecting the test metrics of every run.
# The loop variable is deliberately NOT called `seed`: that name holds the
# notebook-level seed used by the Optuna sampler, and overwriting it here
# would silently change the search if the study cell were re-run afterwards.
for run_seed in range(3):
    print(f"\n___SEED___{run_seed}")
    torch.manual_seed(run_seed)
    random.seed(run_seed)
    np.random.seed(run_seed)

    best_params = study.best_params
    params = {
        # A new model per run so that no weights carry over between seeds.
        "model": RepurchaseModule(
            item_num=corpus.n_items,
            avg_repeat_interval=corpus.total_avg_interval
        ),
        "batch_size": best_params["batch_size"],
        "lr": best_params["lr"],
        "l2_reg_coef": best_params["l2_reg_coef"]
    }

    trainer.init_hyperparams(**params)
    trainer.train()

    metrics = trainer.evaluate(mode="test")

    test_metrics["precision"].append(metrics["precision"])
    test_metrics["recall"].append(metrics["recall"])
    test_metrics["ndcg"].append(metrics["ndcg"])
    # Print the scores accumulated so far, as the TaoBao loop does, so that
    # partial results survive an interrupted run.
    print(test_metrics)


___SEED___0
Epoch 1:


Batch loss = 0.577507: 100%|██████████| 15060/15060 [02:46<00:00, 90.62it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:10<00:00, 33.57it/s]


 {'precision': 0.11535850657615612, 'recall': 0.17654723061333896, 'ndcg': 0.16435737794573518}
Epoch 2:



Batch loss = 0.576102: 100%|██████████| 15060/15060 [02:43<00:00, 91.83it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:10<00:00, 33.46it/s]


 {'precision': 0.11591005515485789, 'recall': 0.1762658233669282, 'ndcg': 0.16437461229120745}
Epoch 3:



Batch loss = 0.574934: 100%|██████████| 15060/15060 [02:44<00:00, 91.39it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:10<00:00, 33.53it/s]


 {'precision': 0.11629189647857445, 'recall': 0.1767296561883852, 'ndcg': 0.16486385888226146}
Epoch 4:



Batch loss = 0.57395: 100%|██████████| 15060/15060 [02:46<00:00, 90.55it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:10<00:00, 33.60it/s]


 {'precision': 0.11633432329232074, 'recall': 0.17644111433967338, 'ndcg': 0.1648813214212546}
Epoch 5:



Batch loss = 0.573112: 100%|██████████| 15060/15060 [02:46<00:00, 90.63it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:10<00:00, 33.61it/s]


 {'precision': 0.11612218922358933, 'recall': 0.1760895660491336, 'ndcg': 0.16441869769084638}
Epoch 6:



Batch loss = 0.572391: 100%|██████████| 15060/15060 [02:46<00:00, 90.42it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:09<00:00, 33.74it/s]


 {'precision': 0.1158676283411116, 'recall': 0.1751983259060505, 'ndcg': 0.16430355773481312}
Epoch 7:



Batch loss = 0.571763: 100%|██████████| 15060/15060 [02:44<00:00, 91.32it/s] 


Evaluation (dev):



100%|██████████| 2357/2357 [01:09<00:00, 33.80it/s]


 {'precision': 0.11582520152736529, 'recall': 0.17432394715513913, 'ndcg': 0.16400502673968226}



100%|██████████| 2357/2357 [01:10<00:00, 33.50it/s]


___SEED___1
Epoch 1:



Batch loss = 0.577462: 100%|██████████| 15060/15060 [02:45<00:00, 91.08it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:09<00:00, 33.72it/s]


 {'precision': 0.11506151887993214, 'recall': 0.1765820492392197, 'ndcg': 0.16449431960363373}
Epoch 2:



Batch loss = 0.576074: 100%|██████████| 15060/15060 [02:46<00:00, 90.48it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:09<00:00, 33.89it/s]


 {'precision': 0.11540093338990241, 'recall': 0.17548989428823922, 'ndcg': 0.16405846192863924}
Epoch 3:



Batch loss = 0.574922: 100%|██████████| 15060/15060 [02:46<00:00, 90.42it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:09<00:00, 33.86it/s]


 {'precision': 0.11591005515485786, 'recall': 0.17585670966119427, 'ndcg': 0.1642078415204301}
Epoch 4:



Batch loss = 0.573952: 100%|██████████| 15060/15060 [02:45<00:00, 91.03it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:10<00:00, 33.57it/s]


 {'precision': 0.11595248196860417, 'recall': 0.1756660622423153, 'ndcg': 0.16442445542687847}



100%|██████████| 2357/2357 [01:10<00:00, 33.55it/s]


___SEED___2
Epoch 1:



Batch loss = 0.577384: 100%|██████████| 15060/15060 [02:46<00:00, 90.51it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:10<00:00, 33.52it/s]


 {'precision': 0.11523122613491728, 'recall': 0.17673341506202614, 'ndcg': 0.16472421538662033}
Epoch 2:



Batch loss = 0.575992: 100%|██████████| 15060/15060 [02:46<00:00, 90.43it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:09<00:00, 33.70it/s]


 {'precision': 0.11565549427238016, 'recall': 0.17654839324982205, 'ndcg': 0.16454621682819276}
Epoch 3:



Batch loss = 0.574835: 100%|██████████| 15060/15060 [02:48<00:00, 89.32it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:09<00:00, 33.69it/s]


 {'precision': 0.11582520152736529, 'recall': 0.17622632417097026, 'ndcg': 0.16452099030358078}
Epoch 4:



Batch loss = 0.573862: 100%|██████████| 15060/15060 [02:48<00:00, 89.52it/s]


Evaluation (dev):



100%|██████████| 2357/2357 [01:09<00:00, 33.70it/s]


 {'precision': 0.11591005515485789, 'recall': 0.17607234421162468, 'ndcg': 0.16460485467650643}



100%|██████████| 2357/2357 [01:10<00:00, 33.62it/s]


In [None]:
# Average each test metric over all seed runs; as the cell's last expression
# the resulting dict is what the notebook displays.
{
    metric_name: np.array(test_metrics[metric_name]).mean()
    for metric_name in ("precision", "recall", "ndcg")
}

{'precision': 0.11712629048225144,
 'recall': 0.16806432819513914,
 'ndcg': 0.163914761181492}