In [1]:
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torch

from replay.metrics import OfflineMetrics, Recall, Precision, MAP, NDCG, HitRate, MRR
from replay.metrics.torch_metrics_builder import metrics_to_df
from replay.splitters import LastNSplitter
from replay.utils import get_spark_session
from replay.data import (
    FeatureHint,
    FeatureInfo,
    FeatureSchema,
    FeatureSource,
    FeatureType,
    Dataset,
)
from replay.models.nn.optimizer_utils import FatOptimizerFactory
from replay.models.nn.sequential.callbacks import (
    ValidationMetricsCallback,
    SparkPredictionCallback,
    PandasPredictionCallback, 
    TorchPredictionCallback,
    QueryEmbeddingsPredictionCallback,
)
from replay.models.nn.sequential.postprocessors import RemoveSeenItems
from replay.data.nn import (
    SequenceTokenizer,
    SequentialDataset,
    TensorFeatureSource,
    TensorSchema,
    TensorFeatureInfo
)
from replay.models.nn.sequential import Bert4Rec
from replay.models.nn.sequential.bert4rec import (
    Bert4RecPredictionDataset,
    Bert4RecTrainingDataset,
    Bert4RecValidationDataset,
    Bert4RecPredictionBatch,
    Bert4RecModel
)

import pandas as pd

In [2]:
# Загрузка данных
user_features = pd.read_csv('data/user_features.csv')
item_features = pd.read_csv('data/item_features.csv')
interactions  = pd.read_csv('data/events.csv')

In [3]:
interactions["timestamp"] = interactions["timestamp"].astype("int64")
interactions = interactions.sort_values(by="timestamp")
interactions["timestamp"] = interactions.groupby("user_id").cumcount()
interactions

Unnamed: 0,user_id,item_id,rating,timestamp
0,0,1505,4,0
511292,3433,3022,3,0
511349,3434,3673,2,0
511721,3435,358,1,0
511769,3436,1279,4,0
...,...,...,...,...
336760,2257,3452,3,2051
336761,2257,2749,3,2052
336762,2257,2623,3,2053
336763,2257,1175,4,2054


In [4]:
splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
)

raw_test_events, raw_test_gt = splitter.split(interactions)
raw_validation_events, raw_validation_gt = splitter.split(raw_test_events)
raw_train_events = raw_validation_events

In [5]:
def prepare_feature_schema(is_ground_truth: bool) -> FeatureSchema:
    base_features = FeatureSchema(
        [
            FeatureInfo(
                column="user_id",
                feature_hint=FeatureHint.QUERY_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
            FeatureInfo(
                column="item_id",
                feature_hint=FeatureHint.ITEM_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
        ]
    )
    if is_ground_truth:
        return base_features

    all_features = base_features + FeatureSchema(
        [
            FeatureInfo(
                column="timestamp",
                feature_type=FeatureType.NUMERICAL,
                feature_hint=FeatureHint.TIMESTAMP,
            ),
        ]
    )
    return all_features

In [6]:
train_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_train_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)

In [7]:
validation_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_validation_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
validation_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_validation_gt,
    check_consistency=True,
    categorical_encoded=False,
)

In [8]:
test_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_test_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
test_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_test_gt,
    check_consistency=True,
    categorical_encoded=False,
)

In [9]:
ITEM_FEATURE_NAME = "item_id_seq"

tensor_schema = TensorSchema(
    TensorFeatureInfo(
        name=ITEM_FEATURE_NAME,
        is_seq=True,
        feature_type=FeatureType.CATEGORICAL,
        feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, train_dataset.feature_schema.item_id_column)],
        feature_hint=FeatureHint.ITEM_ID,
        embedding_dim=300,
    )
)

In [10]:
tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset)

sequential_train_dataset = tokenizer.transform(train_dataset)

sequential_validation_dataset = tokenizer.transform(validation_dataset)
sequential_validation_gt = tokenizer.transform(validation_gt, [tensor_schema.item_id_feature_name])

sequential_validation_dataset, sequential_validation_gt = SequentialDataset.keep_common_query_ids(
    sequential_validation_dataset, sequential_validation_gt
)

In [11]:
test_query_ids = test_gt.query_ids
test_query_ids_np = tokenizer.query_id_encoder.transform(test_query_ids)["user_id"].values
sequential_test_dataset = tokenizer.transform(test_dataset).filter_by_query_id(test_query_ids_np)

In [13]:
MAX_SEQ_LEN = 100
BATCH_SIZE = 1024
NUM_WORKERS = 4

model = Bert4Rec(
    tensor_schema,
    block_count=2,
    head_count=4,
    max_seq_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout_rate=0.5,
    optimizer_factory=FatOptimizerFactory(learning_rate=0.001),
)
checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints_bert4rec",
    save_top_k=1,
    verbose=True,
    # if you use multiple dataloaders, then add the serial number of the dataloader to the suffix of the metric name.
    # For example,"recall@10/dataloader_idx_0"
    monitor="recall@10",
    mode="max",
)

validation_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset.item_count,
    postprocessors=[RemoveSeenItems(sequential_validation_dataset)]
)

csv_logger = CSVLogger(save_dir=".logs/train", name="Bert4Rec_example")

trainer = L.Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

train_dataloader = DataLoader(
    dataset=Bert4RecTrainingDataset(
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

validation_dataloader = DataLoader(
    dataset=Bert4RecValidationDataset(
        sequential_validation_dataset,
        sequential_validation_gt,
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=validation_dataloader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type             | Params | Mode 
----------------------------------------------------
0 | _model | Bert4RecModel    | 4.4 M  | train
1 | _loss  | CrossEntropyLoss | 0      | train
----------------------------------------------------
4.4 M     Trainable params
0         Non-trainable params
4.4 M     Total params
17.702    Total estimated model params size (MB)
38        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\misha\Documents\PyProj\Recsys_2_new\.venv\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
C:\Users\misha\Documents\PyProj\Recsys_2_new\.venv\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.
C:\Users\misha\Documents\PyProj\Recsys_2_new\.venv\lib\site-packages\lightning\pytorch\loops\fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 6: 'recall@10' reached 0.02434 (best 0.02434), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=0-step=6.ckpt' as top 1


k              1        10        20         5
map     0.004305  0.008905  0.010396  0.007470
ndcg    0.004305  0.012454  0.017975  0.008975
recall  0.004305  0.024338  0.046358  0.013576



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 12: 'recall@10' reached 0.03659 (best 0.03659), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=1-step=12.ckpt' as top 1


k              1        10        20         5
map     0.003311  0.010563  0.012520  0.008107
ndcg    0.003311  0.016503  0.023789  0.010432
recall  0.003311  0.036589  0.065728  0.017550



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 18: 'recall@10' was not in top 1


k             1        10        20         5
map     0.00298  0.010117  0.012074  0.007779
ndcg    0.00298  0.016022  0.023362  0.010288
recall  0.00298  0.035927  0.065397  0.018046



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 24: 'recall@10' was not in top 1


k             1        10        20         5
map     0.00298  0.009979  0.012100  0.007552
ndcg    0.00298  0.015617  0.023519  0.009731
recall  0.00298  0.034603  0.066225  0.016391



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 30: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003974  0.010799  0.012928  0.008557
ndcg    0.003974  0.016476  0.024416  0.010965
recall  0.003974  0.035596  0.067384  0.018377



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 36: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003642  0.011246  0.013208  0.009150
ndcg    0.003642  0.016952  0.024216  0.011712
recall  0.003642  0.036093  0.065066  0.019536



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 42: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003146  0.010260  0.012406  0.008389
ndcg    0.003146  0.015670  0.023660  0.010968
recall  0.003146  0.033775  0.065728  0.018874



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 48: 'recall@10' was not in top 1


k              1        10        20         5
map     0.004305  0.010278  0.012390  0.008648
ndcg    0.004305  0.015183  0.023022  0.011101
recall  0.004305  0.031623  0.062914  0.018709



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 54: 'recall@10' was not in top 1


k              1        10        20         5
map     0.004305  0.010715  0.012774  0.008802
ndcg    0.004305  0.015707  0.023286  0.010998
recall  0.004305  0.032450  0.062583  0.017715



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 60: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003642  0.010083  0.012055  0.007688
ndcg    0.003642  0.015738  0.023099  0.009854
recall  0.003642  0.034934  0.064404  0.016556



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 66: 'recall@10' was not in top 1


k             1        10        20         5
map     0.00298  0.009333  0.011464  0.007597
ndcg    0.00298  0.014436  0.022464  0.010185
recall  0.00298  0.031457  0.063742  0.018212



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 72: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003642  0.009476  0.011567  0.007569
ndcg    0.003642  0.014399  0.022184  0.009723
recall  0.003642  0.030960  0.062086  0.016391



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 78: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003146  0.009487  0.011454  0.007260
ndcg    0.003146  0.014890  0.022160  0.009299
recall  0.003146  0.033278  0.062252  0.015563



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 84: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003642  0.010755  0.012732  0.008449
ndcg    0.003642  0.016447  0.023702  0.010777
recall  0.003642  0.035596  0.064404  0.017881



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14, global step 90: 'recall@10' reached 0.03725 (best 0.03725), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=14-step=90.ckpt' as top 1


k              1        10        20         5
map     0.003974  0.011566  0.013708  0.009434
ndcg    0.003974  0.017472  0.025428  0.012196
recall  0.003974  0.037252  0.069040  0.020695



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15, global step 96: 'recall@10' reached 0.04040 (best 0.04040), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=15-step=96.ckpt' as top 1


k              1        10        20         5
map     0.004636  0.012935  0.014892  0.010618
ndcg    0.004636  0.019259  0.026574  0.013546
recall  0.004636  0.040397  0.069702  0.022517



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16, global step 102: 'recall@10' reached 0.04288 (best 0.04288), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=16-step=102.ckpt' as top 1


k              1        10        20         5
map     0.003642  0.013085  0.015283  0.010477
ndcg    0.003642  0.019949  0.028086  0.013503
recall  0.003642  0.042881  0.075331  0.022682



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17, global step 108: 'recall@10' reached 0.04619 (best 0.04619), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=17-step=108.ckpt' as top 1


k              1        10        20         5
map     0.005132  0.014246  0.016482  0.011322
ndcg    0.005132  0.021559  0.029826  0.014391
recall  0.005132  0.046192  0.079139  0.023841



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18, global step 114: 'recall@10' was not in top 1


k              1        10        20         5
map     0.005795  0.015448  0.018170  0.013044
ndcg    0.005795  0.022526  0.032581  0.016663
recall  0.005795  0.046026  0.086093  0.027815



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19, global step 120: 'recall@10' reached 0.04719 (best 0.04719), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=19-step=120.ckpt' as top 1


k              1        10        20         5
map     0.006126  0.015701  0.018369  0.012881
ndcg    0.006126  0.022946  0.032963  0.016069
recall  0.006126  0.047185  0.087417  0.025828



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 20, global step 126: 'recall@10' was not in top 1


k              1        10        20         5
map     0.006623  0.016012  0.018751  0.013289
ndcg    0.006623  0.023185  0.033318  0.016500
recall  0.006623  0.047185  0.087583  0.026325



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 21, global step 132: 'recall@10' was not in top 1


k              1        10        20         5
map     0.005464  0.014830  0.017353  0.011984
ndcg    0.005464  0.022097  0.031409  0.015069
recall  0.005464  0.046523  0.083609  0.024503



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 22, global step 138: 'recall@10' reached 0.05017 (best 0.05017), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=22-step=138.ckpt' as top 1


k              1        10        20         5
map     0.006126  0.016084  0.018789  0.012930
ndcg    0.006126  0.023902  0.033883  0.016189
recall  0.006126  0.050166  0.089901  0.026159



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23, global step 144: 'recall@10' reached 0.05149 (best 0.05149), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=23-step=144.ckpt' as top 1


k              1        10        20         5
map     0.007781  0.016950  0.019704  0.013629
ndcg    0.007781  0.024818  0.035161  0.016654
recall  0.007781  0.051490  0.093046  0.025993



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 24, global step 150: 'recall@10' reached 0.05298 (best 0.05298), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=24-step=150.ckpt' as top 1


k              1        10        20         5
map     0.008609  0.018289  0.021136  0.014959
ndcg    0.008609  0.026238  0.036700  0.018192
recall  0.008609  0.052980  0.094536  0.028146



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 25, global step 156: 'recall@10' reached 0.05397 (best 0.05397), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=25-step=156.ckpt' as top 1


k              1        10        20         5
map     0.006623  0.017788  0.020636  0.014432
ndcg    0.006623  0.026121  0.036637  0.017984
recall  0.006623  0.053974  0.095861  0.028808



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26, global step 162: 'recall@10' was not in top 1


k              1        10        20         5
map     0.006788  0.016897  0.020032  0.013982
ndcg    0.006788  0.024711  0.036219  0.017424
recall  0.006788  0.050993  0.096689  0.027980



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 27, global step 168: 'recall@10' was not in top 1


k             1        10        20         5
map     0.00596  0.015574  0.018875  0.012875
ndcg    0.00596  0.023361  0.035542  0.016607
recall  0.00596  0.049503  0.098013  0.028146



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 28, global step 174: 'recall@10' reached 0.05513 (best 0.05513), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=28-step=174.ckpt' as top 1


k              1        10        20         5
map     0.006788  0.017255  0.020293  0.013750
ndcg    0.006788  0.025927  0.037422  0.017391
recall  0.006788  0.055132  0.101490  0.028642



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 29, global step 180: 'recall@10' was not in top 1


k              1        10        20         5
map     0.006291  0.017685  0.020664  0.014503
ndcg    0.006291  0.026245  0.037367  0.018475
recall  0.006291  0.054801  0.099338  0.030629



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 30, global step 186: 'recall@10' reached 0.05762 (best 0.05762), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=30-step=186.ckpt' as top 1


k              1        10        20         5
map     0.006291  0.018796  0.021848  0.015668
ndcg    0.006291  0.027795  0.039337  0.020203
recall  0.006291  0.057616  0.104139  0.034106



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 31, global step 192: 'recall@10' was not in top 1


k              1        10        20         5
map     0.006457  0.018133  0.021360  0.015039
ndcg    0.006457  0.026977  0.038903  0.019356
recall  0.006457  0.056457  0.103974  0.032616



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 32, global step 198: 'recall@10' reached 0.06192 (best 0.06192), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=32-step=198.ckpt' as top 1


k             1        10        20         5
map     0.00745  0.019391  0.022497  0.015613
ndcg    0.00745  0.029128  0.040734  0.019816
recall  0.00745  0.061921  0.108444  0.032781



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 33, global step 204: 'recall@10' was not in top 1


k              1        10        20         5
map     0.007616  0.019783  0.023130  0.016256
ndcg    0.007616  0.029247  0.041654  0.020596
recall  0.007616  0.060927  0.110430  0.033940



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 34, global step 210: 'recall@10' reached 0.06540 (best 0.06540), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=34-step=210.ckpt' as top 1


k              1        10        20         5
map     0.008113  0.021258  0.024673  0.017500
ndcg    0.008113  0.031422  0.044074  0.022193
recall  0.008113  0.065397  0.115894  0.036589



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 35, global step 216: 'recall@10' reached 0.06805 (best 0.06805), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=35-step=216.ckpt' as top 1


k              1        10        20         5
map     0.008609  0.022261  0.026144  0.018579
ndcg    0.008609  0.032804  0.047058  0.023682
recall  0.008609  0.068046  0.124669  0.039404



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 36, global step 222: 'recall@10' reached 0.08096 (best 0.08096), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=36-step=222.ckpt' as top 1


k              1        10        20         5
map     0.010265  0.026303  0.030216  0.021617
ndcg    0.010265  0.038849  0.053170  0.027217
recall  0.010265  0.080960  0.137748  0.044371



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 37, global step 228: 'recall@10' reached 0.09685 (best 0.09685), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=37-step=228.ckpt' as top 1


k              1        10        20         5
map     0.015232  0.033170  0.037912  0.027210
ndcg    0.015232  0.047751  0.065253  0.033175
recall  0.015232  0.096854  0.166556  0.051490



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 38, global step 234: 'recall@10' reached 0.10728 (best 0.10728), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=38-step=234.ckpt' as top 1


k              1        10        20         5
map     0.015066  0.036427  0.041790  0.030781
ndcg    0.015066  0.052742  0.072892  0.038670
recall  0.015066  0.107285  0.188245  0.062914



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 39, global step 240: 'recall@10' reached 0.11639 (best 0.11639), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=39-step=240.ckpt' as top 1


k              1        10        20         5
map     0.016225  0.040223  0.045989  0.034186
ndcg    0.016225  0.057838  0.079345  0.042999
recall  0.016225  0.116391  0.202483  0.070033



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 40, global step 246: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016391  0.038912  0.044809  0.032861
ndcg    0.016391  0.056225  0.077922  0.041216
recall  0.016391  0.114073  0.200331  0.066887



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 41, global step 252: 'recall@10' reached 0.12434 (best 0.12434), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=41-step=252.ckpt' as top 1


k              1        10        20         5
map     0.015728  0.041135  0.047059  0.034051
ndcg    0.015728  0.060314  0.082345  0.042918
recall  0.015728  0.124338  0.212417  0.070033



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 42, global step 258: 'recall@10' reached 0.13825 (best 0.13825), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=42-step=258.ckpt' as top 1


k              1        10        20         5
map     0.017715  0.046001  0.051389  0.038411
ndcg    0.017715  0.067284  0.087356  0.048625
recall  0.017715  0.138245  0.218543  0.079967



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 43, global step 264: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019702  0.047108  0.053037  0.039790
ndcg    0.019702  0.067580  0.089372  0.049665
recall  0.019702  0.135762  0.222351  0.079967



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 44, global step 270: 'recall@10' was not in top 1


k              1        10        20         5
map     0.021358  0.048162  0.053806  0.040510
ndcg    0.021358  0.068777  0.089607  0.049965
recall  0.021358  0.137748  0.220695  0.078974



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 45, global step 276: 'recall@10' reached 0.14040 (best 0.14040), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=45-step=276.ckpt' as top 1


k              1        10        20         5
map     0.020695  0.048036  0.054157  0.040254
ndcg    0.020695  0.069264  0.091680  0.049974
recall  0.020695  0.140397  0.229305  0.079801



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 46, global step 282: 'recall@10' reached 0.14967 (best 0.14967), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=46-step=282.ckpt' as top 1


k              1        10        20         5
map     0.022682  0.051436  0.057506  0.043275
ndcg    0.022682  0.074033  0.096415  0.053926
recall  0.022682  0.149669  0.238742  0.086755



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 47, global step 288: 'recall@10' reached 0.15166 (best 0.15166), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=47-step=288.ckpt' as top 1


k              1        10        20         5
map     0.022351  0.053749  0.060242  0.046228
ndcg    0.022351  0.076437  0.100362  0.057911
recall  0.022351  0.151656  0.246854  0.093709



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 48, global step 294: 'recall@10' reached 0.15215 (best 0.15215), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=48-step=294.ckpt' as top 1


k              1        10        20         5
map     0.024007  0.054021  0.060701  0.046115
ndcg    0.024007  0.076698  0.101241  0.057313
recall  0.024007  0.152152  0.249669  0.091722



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 49, global step 300: 'recall@10' reached 0.15315 (best 0.15315), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=49-step=300.ckpt' as top 1


k              1        10        20         5
map     0.024503  0.054111  0.060615  0.045723
ndcg    0.024503  0.076953  0.100895  0.056489
recall  0.024503  0.153146  0.248344  0.089570



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 50, global step 306: 'recall@10' reached 0.15944 (best 0.15944), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=50-step=306.ckpt' as top 1


k              1        10        20         5
map     0.024834  0.055542  0.061982  0.046396
ndcg    0.024834  0.079427  0.103152  0.057012
recall  0.024834  0.159437  0.253808  0.089570



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 51, global step 312: 'recall@10' was not in top 1


k              1        10        20         5
map     0.025993  0.057348  0.063897  0.049034
ndcg    0.025993  0.080787  0.104814  0.060464
recall  0.025993  0.158775  0.254139  0.095530



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 52, global step 318: 'recall@10' reached 0.16407 (best 0.16407), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=52-step=318.ckpt' as top 1


k              1        10        20         5
map     0.025497  0.059364  0.066055  0.051366
ndcg    0.025497  0.083641  0.108009  0.063947
recall  0.025497  0.164073  0.260430  0.102483



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 53, global step 324: 'recall@10' reached 0.16474 (best 0.16474), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=53-step=324.ckpt' as top 1


k              1        10        20         5
map     0.023841  0.057713  0.064089  0.049258
ndcg    0.023841  0.082520  0.106081  0.061899
recall  0.023841  0.164735  0.258609  0.100662



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 54, global step 330: 'recall@10' was not in top 1


k              1        10        20         5
map     0.025993  0.059407  0.065785  0.051471
ndcg    0.025993  0.083426  0.106824  0.064047
recall  0.025993  0.162914  0.255795  0.102649



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 55, global step 336: 'recall@10' was not in top 1


k              1        10        20         5
map     0.027152  0.060120  0.066824  0.051976
ndcg    0.027152  0.084184  0.108790  0.064185
recall  0.027152  0.164073  0.261755  0.101656



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 56, global step 342: 'recall@10' was not in top 1


k              1        10        20         5
map     0.027318  0.060767  0.067586  0.052586
ndcg    0.027318  0.084438  0.109374  0.064496
recall  0.027318  0.162748  0.261589  0.100828



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 57, global step 348: 'recall@10' reached 0.16523 (best 0.16523), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=57-step=348.ckpt' as top 1


k             1        10        20         5
map     0.02798  0.061098  0.067964  0.052712
ndcg    0.02798  0.085198  0.110528  0.064643
recall  0.02798  0.165232  0.266060  0.101159



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 58, global step 354: 'recall@10' was not in top 1


k              1        10        20         5
map     0.026987  0.060084  0.066753  0.051863
ndcg    0.026987  0.084080  0.108667  0.063880
recall  0.026987  0.163742  0.261589  0.100662



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 59, global step 360: 'recall@10' reached 0.17235 (best 0.17235), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=59-step=360.ckpt' as top 1


k              1        10        20         5
map     0.028974  0.063883  0.070578  0.055353
ndcg    0.028974  0.088986  0.113551  0.067912
recall  0.028974  0.172351  0.269868  0.106291



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 60, global step 366: 'recall@10' was not in top 1


k              1        10        20         5
map     0.024172  0.059570  0.065944  0.050281
ndcg    0.024172  0.085197  0.108669  0.062482
recall  0.024172  0.170364  0.263742  0.099669



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 61, global step 372: 'recall@10' was not in top 1


k              1        10        20         5
map     0.024669  0.059727  0.066585  0.050698
ndcg    0.024669  0.085591  0.110747  0.063388
recall  0.024669  0.171689  0.271523  0.102318



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 62, global step 378: 'recall@10' reached 0.17533 (best 0.17533), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=62-step=378.ckpt' as top 1


k              1        10        20         5
map     0.028146  0.063550  0.070299  0.054219
ndcg    0.028146  0.089356  0.114288  0.066410
recall  0.028146  0.175331  0.274669  0.103642



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 63, global step 384: 'recall@10' reached 0.17715 (best 0.17715), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=63-step=384.ckpt' as top 1


k              1        10        20         5
map     0.028146  0.063720  0.070396  0.054741
ndcg    0.028146  0.089951  0.114437  0.067767
recall  0.028146  0.177152  0.274338  0.107616



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 64, global step 390: 'recall@10' was not in top 1


k              1        10        20         5
map     0.025331  0.061949  0.068663  0.053372
ndcg    0.025331  0.088129  0.113018  0.067007
recall  0.025331  0.174834  0.274172  0.108775



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 65, global step 396: 'recall@10' was not in top 1


k              1        10        20         5
map     0.028146  0.062255  0.068913  0.053259
ndcg    0.028146  0.086920  0.111414  0.064908
recall  0.028146  0.168874  0.266225  0.100331



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 66, global step 402: 'recall@10' reached 0.17765 (best 0.17765), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=66-step=402.ckpt' as top 1


k              1        10        20         5
map     0.032285  0.066850  0.073912  0.058033
ndcg    0.032285  0.092463  0.118533  0.070753
recall  0.032285  0.177649  0.281457  0.109768



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 67, global step 408: 'recall@10' reached 0.18609 (best 0.18609), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=67-step=408.ckpt' as top 1


k              1        10        20         5
map     0.029139  0.066649  0.073614  0.056962
ndcg    0.029139  0.094285  0.120009  0.070575
recall  0.029139  0.186093  0.288576  0.112252



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 68, global step 414: 'recall@10' was not in top 1


k              1        10        20         5
map     0.026656  0.064389  0.071739  0.054774
ndcg    0.026656  0.091404  0.118409  0.067828
recall  0.026656  0.181126  0.288411  0.107616



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 69, global step 420: 'recall@10' was not in top 1


k              1        10        20         5
map     0.028311  0.065839  0.073199  0.056841
ndcg    0.028311  0.092702  0.119752  0.070654
recall  0.028311  0.181623  0.289073  0.112914



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 70, global step 426: 'recall@10' was not in top 1


k             1        10        20         5
map     0.02947  0.065629  0.072540  0.056115
ndcg    0.02947  0.092125  0.117542  0.068810
recall  0.02947  0.180298  0.281291  0.107616



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 71, global step 432: 'recall@10' reached 0.18891 (best 0.18891), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=71-step=432.ckpt' as top 1


k              1        10        20         5
map     0.030132  0.068408  0.075255  0.058640
ndcg    0.030132  0.096272  0.121412  0.072293
recall  0.030132  0.188907  0.288742  0.114073



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 72, global step 438: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031126  0.069304  0.076239  0.059616
ndcg    0.031126  0.096870  0.122424  0.073314
recall  0.031126  0.188245  0.289901  0.115232



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 73, global step 444: 'recall@10' was not in top 1


k              1        10        20         5
map     0.028311  0.066492  0.073100  0.057224
ndcg    0.028311  0.094802  0.119341  0.072047
recall  0.028311  0.188742  0.286755  0.117715



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 74, global step 450: 'recall@10' was not in top 1


k              1        10        20         5
map     0.029139  0.067546  0.074672  0.058306
ndcg    0.029139  0.095345  0.121575  0.072867
recall  0.029139  0.187252  0.291556  0.117550



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 75, global step 456: 'recall@10' reached 0.19454 (best 0.19454), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=75-step=456.ckpt' as top 1


k              1        10        20         5
map     0.028974  0.070595  0.077624  0.061156
ndcg    0.028974  0.099410  0.125218  0.076264
recall  0.028974  0.194536  0.297020  0.122351



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 76, global step 462: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032781  0.070597  0.077536  0.061076
ndcg    0.032781  0.098173  0.123565  0.074663
recall  0.032781  0.189901  0.290563  0.116225



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 77, global step 468: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032616  0.070643  0.078044  0.061639
ndcg    0.032616  0.097881  0.125256  0.075769
recall  0.032616  0.188079  0.297185  0.119040



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 78, global step 474: 'recall@10' reached 0.19735 (best 0.19735), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=78-step=474.ckpt' as top 1


k              1        10        20         5
map     0.032285  0.072277  0.079208  0.062017
ndcg    0.032285  0.101196  0.126744  0.075937
recall  0.032285  0.197351  0.299007  0.118377



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 79, global step 480: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031623  0.071474  0.078454  0.061595
ndcg    0.031623  0.100287  0.125898  0.076070
recall  0.031623  0.195861  0.297517  0.120364



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 80, global step 486: 'recall@10' was not in top 1


k              1        10        20         5
map     0.029801  0.068567  0.075963  0.059859
ndcg    0.029801  0.095838  0.122986  0.074458
recall  0.029801  0.185927  0.293709  0.119205



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 81, global step 492: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031291  0.069558  0.076835  0.059705
ndcg    0.031291  0.098028  0.124832  0.073682
recall  0.031291  0.192881  0.299503  0.116556



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 82, global step 498: 'recall@10' was not in top 1


k              1        10        20         5
map     0.030464  0.070406  0.077787  0.060626
ndcg    0.030464  0.099022  0.126043  0.075248
recall  0.030464  0.193709  0.300828  0.120033



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 83, global step 504: 'recall@10' was not in top 1


k              1        10        20         5
map     0.030298  0.069588  0.077309  0.059765
ndcg    0.030298  0.098151  0.126363  0.074022
recall  0.030298  0.193046  0.304801  0.117715



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 84, global step 510: 'recall@10' was not in top 1


k              1        10        20         5
map     0.028146  0.067260  0.074807  0.057848
ndcg    0.028146  0.095401  0.123042  0.072376
recall  0.028146  0.188576  0.298179  0.116887



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 85, global step 516: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032119  0.071243  0.078563  0.061874
ndcg    0.032119  0.099449  0.126500  0.076497
recall  0.032119  0.192881  0.300662  0.121358



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 86, global step 522: 'recall@10' reached 0.20000 (best 0.20000), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=86-step=522.ckpt' as top 1


k             1        10        20         5
map     0.03394  0.075467  0.083025  0.066090
ndcg    0.03394  0.104417  0.132018  0.081457
recall  0.03394  0.200000  0.309272  0.128477



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 87, global step 528: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031457  0.072520  0.079884  0.062950
ndcg    0.031457  0.101208  0.128233  0.077772
recall  0.031457  0.196026  0.303311  0.123013



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 88, global step 534: 'recall@10' reached 0.20298 (best 0.20298), saving model to 'C:\\Users\\misha\\Documents\\PyProj\\Recsys_2_new\\.checkpoints_bert4rec\\epoch=88-step=534.ckpt' as top 1


k              1        10        20         5
map     0.034603  0.074709  0.081936  0.064556
ndcg    0.034603  0.104384  0.130821  0.079480
recall  0.034603  0.202980  0.307781  0.125331



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 89, global step 540: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032119  0.071515  0.079195  0.061992
ndcg    0.032119  0.099659  0.127877  0.076303
recall  0.032119  0.192881  0.304967  0.120033



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 90, global step 546: 'recall@10' was not in top 1


k             1        10        20         5
map     0.03394  0.073260  0.081169  0.063551
ndcg    0.03394  0.101697  0.130749  0.077874
recall  0.03394  0.196026  0.311424  0.121689



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 91, global step 552: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031954  0.072249  0.080008  0.062472
ndcg    0.031954  0.101079  0.129617  0.077215
recall  0.031954  0.196523  0.309934  0.122351



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 92, global step 558: 'recall@10' was not in top 1


k              1        10        20         5
map     0.035265  0.076712  0.084342  0.067100
ndcg    0.035265  0.105768  0.133622  0.082164
recall  0.035265  0.201821  0.312086  0.128146



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 93, global step 564: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033609  0.075026  0.082626  0.065687
ndcg    0.033609  0.103925  0.131933  0.080879
recall  0.033609  0.199503  0.310927  0.127318



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 94, global step 570: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033444  0.074483  0.082191  0.064688
ndcg    0.033444  0.103530  0.131729  0.079650
recall  0.033444  0.199503  0.311258  0.125331



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 95, global step 576: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031126  0.071931  0.079539  0.062139
ndcg    0.031126  0.101044  0.129115  0.077100
recall  0.031126  0.197351  0.309106  0.122848



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 96, global step 582: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031623  0.070336  0.077616  0.060938
ndcg    0.031623  0.098203  0.125038  0.075162
recall  0.031623  0.190563  0.297351  0.118709



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 97, global step 588: 'recall@10' was not in top 1


k              1        10        20         5
map     0.034106  0.072583  0.080374  0.062969
ndcg    0.034106  0.100707  0.129476  0.076944
recall  0.034106  0.194205  0.308775  0.119702



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 98, global step 594: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032285  0.073943  0.081693  0.064291
ndcg    0.032285  0.103156  0.131817  0.079385
recall  0.032285  0.199834  0.314073  0.125497



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 99, global step 600: 'recall@10' was not in top 1
`Trainer.fit` stopped: `max_epochs=100` reached.


k              1        10        20         5
map     0.033444  0.074046  0.081561  0.063689
ndcg    0.033444  0.103725  0.131597  0.078373
recall  0.033444  0.202318  0.313576  0.123344



In [14]:
best_model = Bert4Rec.load_from_checkpoint(checkpoint_callback.best_model_path)

In [15]:
prediction_dataloader = DataLoader(
    dataset=Bert4RecPredictionDataset(
        sequential_test_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

csv_logger = CSVLogger(save_dir=".logs/test", name="Bert4Rec_example")

In [26]:
TOPK = [10]

postprocessors = [RemoveSeenItems(sequential_test_dataset)]

pandas_prediction_callback = PandasPredictionCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)


query_embeddings_callback = QueryEmbeddingsPredictionCallback()

trainer = L.Trainer(
    callbacks=[
        pandas_prediction_callback,
        query_embeddings_callback,
    ], 
    logger=csv_logger, 
    inference_mode=True
)
trainer.predict(best_model, dataloaders=prediction_dataloader, return_predictions=False)

pandas_res = pandas_prediction_callback.get_result()
user_embeddings = query_embeddings_callback.get_result()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
C:\Users\misha\Documents\PyProj\Recsys_2_new\.venv\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'predict_dataloader' to speed up the dataloader worker initialization.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [27]:
recommendations = tokenizer.query_and_item_id_encoder.inverse_transform(pandas_res)

In [28]:
init_args = {"query_column": "user_id", "rating_column": "score"}

In [29]:
result_metrics = OfflineMetrics(
    [Recall(TOPK), Precision(TOPK), MAP(TOPK), NDCG(TOPK), MRR(TOPK), HitRate(TOPK)], **init_args)(recommendations, raw_test_gt)

In [30]:
result_metrics

{'Recall@10': 0.1956953642384106,
 'Precision@10': 0.01956953642384106,
 'MAP@10': 0.07520885892988542,
 'NDCG@10': 0.1032067555833944,
 'MRR@10': 0.07520885892988542,
 'HitRate@10': 0.1956953642384106}

In [31]:
submission_df = recommendations.groupby('user_id')['item_id'].apply(lambda x: ' '.join(x.astype(str))).reset_index()
submission_df
submission_df.to_csv('test_sub_4.csv', index=False)