<a href="https://colab.research.google.com/github/akzholba/RecSysCompetition/blob/bert4rec/bert/Bert4rec_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Example of the Bert4Rec training and inference stages
Note that all the given examples can be run without using PySpark, using only Pandas

In [1]:
!pip install lightning

Collecting lightning
  Downloading lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.11.8-py3-none-any.whl.metadata (5.2 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.5.1-py3-none-any.whl.metadata (20 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Downloading lightning-2.4.0-py3-none-any.whl (810 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m811.0/811.0 kB[0m [31m19.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.8-py3-none-any.whl (26 kB)
Downloading torchmetrics-1.5.1-py3-none-any.whl (890 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m890.6/890.6 kB[0m [31m31.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pytorch_lightning-2.4.0-py3-none-any.whl (815 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torch

In [3]:
from schema import (
    FeatureHint,
    FeatureInfo,
    FeatureSchema,
    FeatureSource,
    FeatureType,
)

from optimizer_utils import FatOptimizerFactory

from bert4rec import Bert4Rec, Bert4RecModel

from dataset import (
    Bert4RecPredictionDataset,
    Bert4RecTrainingDataset,
    Bert4RecValidationDataset,
    Bert4RecPredictionBatch,
)

import pandas as pd

## Getting a spark session

In [4]:
from data import Dataset, get_spark_session

spark_session = get_spark_session()



## Prepare data
### Load raw movielens-1M interactions, item features and user features.
In the current implementation, the SASRec does not take into account the features of items or users. They are only used to get a complete list of users and items.

In [5]:
# !pip install rs-datasets

Collecting rs-datasets
  Downloading rs_datasets-0.5.1-py3-none-any.whl.metadata (2.6 kB)
Collecting datatable (from rs-datasets)
  Downloading datatable-1.1.0-cp310-cp310-manylinux_2_35_x86_64.whl.metadata (1.8 kB)
Collecting py7zr (from rs-datasets)
  Downloading py7zr-0.22.0-py3-none-any.whl.metadata (16 kB)
Collecting texttable (from py7zr->rs-datasets)
  Downloading texttable-1.7.0-py2.py3-none-any.whl.metadata (9.8 kB)
Collecting pycryptodomex>=3.16.0 (from py7zr->rs-datasets)
  Downloading pycryptodomex-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Collecting pyzstd>=0.15.9 (from py7zr->rs-datasets)
  Downloading pyzstd-0.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.4 kB)
Collecting pyppmd<1.2.0,>=1.1.0 (from py7zr->rs-datasets)
  Downloading pyppmd-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.7 kB)
Collecting pybcj<1.1.0,>=1.0.0 (from py7zr->rs-datasets)
  Downloading pybcj-1.0.2-

In [6]:
# from rs_datasets import MovieLens

In [7]:
# movielens = MovieLens("1m")
# interactions = movielens.ratings
# user_features = movielens.users
# item_features = movielens.items

In [8]:
# interactions.head()

In [9]:
# user_features.head()

In [10]:
# item_features.head()

In [11]:
from sklearn.preprocessing import LabelEncoder
import pandas as pd
le = LabelEncoder()

interactions = pd.read_csv('./data/events.csv')
interactions[interactions['user_id']==0]

item_features = pd.read_csv('./data/item_features.csv')

user_features = pd.read_csv('./data/user_features.csv')
user_features['gender'] = le.fit_transform(user_features['gender'])

Removing duplicates in the timestamp column without changing the original items order where timestamp is the same

In [12]:
interactions["timestamp"] = interactions["timestamp"].astype("int64")
interactions = interactions.sort_values(by="timestamp")
interactions["timestamp"] = interactions.groupby("user_id").cumcount()
interactions

Unnamed: 0,user_id,item_id,rating,timestamp
0,0,1505,4,0
511292,3433,3022,3,0
511349,3434,3673,2,0
511721,3435,358,1,0
511769,3436,1279,4,0
...,...,...,...,...
336760,2257,3452,3,2051
336761,2257,2749,3,2052
336762,2257,2623,3,2053
336763,2257,1175,4,2054


### Split interactions into the train, validation and test datasets using LastNSplitter

In [13]:
from last_n_splitter import LastNSplitter

splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
)

raw_test_events, raw_test_gt = splitter.split(interactions)
raw_validation_events, raw_validation_gt = splitter.split(raw_test_events)
raw_train_events = raw_validation_events

### Prepare FeatureSchema required to create Dataset

In [14]:
def prepare_feature_schema(is_ground_truth: bool) -> FeatureSchema:
    base_features = FeatureSchema(
        [
            FeatureInfo(
                column="user_id",
                feature_hint=FeatureHint.QUERY_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
            FeatureInfo(
                column="item_id",
                feature_hint=FeatureHint.ITEM_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
        ]
    )
    if is_ground_truth:
        return base_features

    all_features = base_features + FeatureSchema(
        [
            FeatureInfo(
                column="timestamp",
                feature_type=FeatureType.NUMERICAL,
                feature_hint=FeatureHint.TIMESTAMP,
            ),
        ]
    )
    return all_features

### Create Dataset for the training stage

In [15]:
train_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_train_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)

### Create Datasets (events and ground_truth) for the validation stage

In [16]:
validation_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_validation_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
validation_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_validation_gt,
    check_consistency=True,
    categorical_encoded=False,
)

### Create Datasets (events and ground_truth) for the testing stage

In [17]:
test_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_test_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
test_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_test_gt,
    check_consistency=True,
    categorical_encoded=False,
)

### Create the tensor schema
A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model

In [18]:
from schema import (
    TensorFeatureSource,
    TensorSchema,
    TensorFeatureInfo)

ITEM_FEATURE_NAME = "item_id_seq"

tensor_schema = TensorSchema(
    TensorFeatureInfo(
        name=ITEM_FEATURE_NAME,
        is_seq=True,
        feature_type=FeatureType.CATEGORICAL,
        feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, train_dataset.feature_schema.item_id_column)],
        feature_hint=FeatureHint.ITEM_ID,
        embedding_dim=300,
    )
)

### Create sequential datasets using SequenceTokenizer
The SequentialDataset internally store data in the form of sequences of items sorted by increasing interaction time (timestamp). A SequenceTokenizer is used to convert to this format. In addition, the SequenceTokenizer encodes all categorical columns from the source dataset and stores mapping inside itself.
SequentialDataset.keep_common_query_ids is used to leave only sequences from the same users

In [19]:
# !pip install pyspark



In [20]:
from sequence_tokenizer import SequenceTokenizer
from sequential_dataset import SequentialDataset

tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset)

sequential_train_dataset = tokenizer.transform(train_dataset)

sequential_validation_dataset = tokenizer.transform(validation_dataset)
sequential_validation_gt = tokenizer.transform(validation_gt, [tensor_schema.item_id_feature_name])

sequential_validation_dataset, sequential_validation_gt = SequentialDataset.keep_common_query_ids(
    sequential_validation_dataset, sequential_validation_gt
)

In [21]:
test_query_ids = test_gt.query_ids
test_query_ids_np = tokenizer.query_id_encoder.transform(test_query_ids)["user_id"].values
sequential_test_dataset = tokenizer.transform(test_dataset).filter_by_query_id(test_query_ids_np)

You can get the user and item mapping and inverse mapping as follows

In [22]:
print(tokenizer.query_id_encoder.mapping, tokenizer.query_id_encoder.inverse_mapping)
print(tokenizer.item_id_encoder.mapping, tokenizer.item_id_encoder.inverse_mapping)

{'user_id': {4855: 0, 4065: 1, 3331: 2, 5373: 3, 2032: 4, 5875: 5, 3984: 6, 4062: 7, 5117: 8, 5822: 9, 174: 10, 5188: 11, 595: 12, 2538: 13, 5031: 14, 4765: 15, 1819: 16, 3970: 17, 568: 18, 4007: 19, 2641: 20, 2646: 21, 3839: 22, 3263: 23, 281: 24, 2009: 25, 5836: 26, 1581: 27, 679: 28, 3634: 29, 2401: 30, 2184: 31, 5532: 32, 3638: 33, 4159: 34, 1770: 35, 3754: 36, 637: 37, 1452: 38, 5412: 39, 5345: 40, 3078: 41, 4772: 42, 3484: 43, 1064: 44, 2812: 45, 3120: 46, 4295: 47, 491: 48, 3283: 49, 5595: 50, 622: 51, 4428: 52, 1570: 53, 4561: 54, 3927: 55, 127: 56, 1950: 57, 1877: 58, 2285: 59, 656: 60, 462: 61, 4055: 62, 4477: 63, 2148: 64, 1582: 65, 272: 66, 3556: 67, 883: 68, 5295: 69, 3223: 70, 4070: 71, 3: 72, 5314: 73, 4225: 74, 1341: 75, 5909: 76, 1413: 77, 4463: 78, 3900: 79, 4426: 80, 811: 81, 3491: 82, 5118: 83, 2018: 84, 1308: 85, 4379: 86, 4351: 87, 2995: 88, 3680: 89, 1336: 90, 3758: 91, 1286: 92, 5003: 93, 3574: 94, 1703: 95, 1855: 96, 32: 97, 5901: 98, 5207: 99, 1516: 100, 5457:

## Train model
### Create Bert4Rec model instance and run the training stage using lightning
After each epoch validation metrics are shown. You can change the list of validation metrics in ValidationMetricsCallback
The model is determined to be the best and is saved if the metric updates its maximum during validation (see the ModelCheckpoint)

In [57]:
from validation_callback import ValidationMetricsCallback
from postprocessors import RemoveSeenItems

MAX_SEQ_LEN = 100
BATCH_SIZE = 512
NUM_WORKERS = 4

model = Bert4Rec(
    tensor_schema,
    block_count=2,
    head_count=4,
    max_seq_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout_rate=0.5,
    optimizer_factory=FatOptimizerFactory(learning_rate=0.001),
)
checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints",
    save_top_k=1,
    verbose=True,
    # if you use multiple dataloaders, then add the serial number of the dataloader to the suffix of the metric name.
    # For example,"recall@10/dataloader_idx_0"
    monitor="recall@10",
    mode="max",
)

validation_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset.item_count,
    postprocessors=[RemoveSeenItems(sequential_validation_dataset)]
)

csv_logger = CSVLogger(save_dir=".logs/train", name="Bert4Rec_example")

trainer = L.Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

train_dataloader = DataLoader(
    dataset=Bert4RecTrainingDataset(
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

validation_dataloader = DataLoader(
    dataset=Bert4RecValidationDataset(
        sequential_validation_dataset,
        sequential_validation_gt,
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=validation_dataloader,
)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /content/.checkpoints exists and is not empty.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name   | Type             | Params | Mode 
----------------------------------------------------
0 | _model | Bert4RecModel    | 4.4 M  | train
1 | _loss  | CrossEntropyLoss | 0      | train
----------------------------------------------------
4.4 M     Trainable params
0         Non-trainable params
4.4 M     Total p

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (12) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 0, global step 12: 'recall@10' reached 0.03394 (best 0.03394), saving model to '/content/.checkpoints/epoch=0-step=12.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 0, global step 12: 'recall@10' reached 0.03394 (best 0.03394), saving model to '/content/.checkpoints/epoch=0-step=12.ckpt' as top 1


k              1        10        20         5
map     0.003477  0.010257  0.012229  0.008242
ndcg    0.003477  0.015700  0.022976  0.010729
recall  0.003477  0.033940  0.062914  0.018377





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 1, global step 24: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 1, global step 24: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003311  0.009583  0.011511  0.007693
ndcg    0.003311  0.014605  0.021827  0.009949
recall  0.003311  0.031457  0.060430  0.016887





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 2, global step 36: 'recall@10' reached 0.03659 (best 0.03659), saving model to '/content/.checkpoints/epoch=2-step=36-v1.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 2, global step 36: 'recall@10' reached 0.03659 (best 0.03659), saving model to '/content/.checkpoints/epoch=2-step=36-v1.ckpt' as top 1


k              1        10        20         5
map     0.003642  0.010534  0.012517  0.008162
ndcg    0.003642  0.016487  0.023861  0.010688
recall  0.003642  0.036589  0.066060  0.018543





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 3, global step 48: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 3, global step 48: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003642  0.009799  0.011883  0.007795
ndcg    0.003642  0.014708  0.022400  0.009716
recall  0.003642  0.031291  0.061921  0.015563





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 4, global step 60: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 4, global step 60: 'recall@10' was not in top 1


k              1        10        20         5
map     0.002815  0.009482  0.011357  0.007116
ndcg    0.002815  0.015433  0.022369  0.009608
recall  0.002815  0.035596  0.063245  0.017384





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 5, global step 72: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 5, global step 72: 'recall@10' was not in top 1


k             1        10        20         5
map     0.00298  0.009066  0.011054  0.006995
ndcg    0.00298  0.014297  0.021676  0.009239
recall  0.00298  0.031954  0.061424  0.016225





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 6, global step 84: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 6, global step 84: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003146  0.009909  0.011882  0.007936
ndcg    0.003146  0.015227  0.022421  0.010486
recall  0.003146  0.032947  0.061424  0.018377





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 7, global step 96: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 7, global step 96: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003311  0.010514  0.012373  0.008206
ndcg    0.003311  0.016209  0.023126  0.010694
recall  0.003311  0.035265  0.062914  0.018377





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 8, global step 108: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 8, global step 108: 'recall@10' was not in top 1


k             1        10        20         5
map     0.00298  0.009614  0.011629  0.007127
ndcg    0.00298  0.015572  0.023012  0.009313
recall  0.00298  0.035927  0.065563  0.016060





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 9, global step 120: 'recall@10' reached 0.04089 (best 0.04089), saving model to '/content/.checkpoints/epoch=9-step=120.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 9, global step 120: 'recall@10' reached 0.04089 (best 0.04089), saving model to '/content/.checkpoints/epoch=9-step=120.ckpt' as top 1


k              1        10        20         5
map     0.005298  0.013637  0.015750  0.011498
ndcg    0.005298  0.019926  0.027763  0.014611
recall  0.005298  0.040894  0.072185  0.024172





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 10, global step 132: 'recall@10' reached 0.04652 (best 0.04652), saving model to '/content/.checkpoints/epoch=10-step=132.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 10, global step 132: 'recall@10' reached 0.04652 (best 0.04652), saving model to '/content/.checkpoints/epoch=10-step=132.ckpt' as top 1


k              1        10        20         5
map     0.006954  0.016936  0.019193  0.014605
ndcg    0.006954  0.023820  0.032223  0.018158
recall  0.006954  0.046523  0.080132  0.028974





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 11, global step 144: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 11, global step 144: 'recall@10' was not in top 1


k              1        10        20         5
map     0.006291  0.015650  0.018238  0.013226
ndcg    0.006291  0.022585  0.032172  0.016607
recall  0.006291  0.045695  0.083940  0.026987





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 12, global step 156: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 12, global step 156: 'recall@10' was not in top 1


k              1        10        20         5
map     0.005464  0.014786  0.017163  0.012467
ndcg    0.005464  0.021456  0.030227  0.015803
recall  0.005464  0.043543  0.078477  0.025993





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 13, global step 168: 'recall@10' reached 0.04785 (best 0.04785), saving model to '/content/.checkpoints/epoch=13-step=168.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 13, global step 168: 'recall@10' reached 0.04785 (best 0.04785), saving model to '/content/.checkpoints/epoch=13-step=168.ckpt' as top 1


k              1        10        20         5
map     0.006788  0.016017  0.018415  0.013193
ndcg    0.006788  0.023298  0.032150  0.016231
recall  0.006788  0.047848  0.083113  0.025497





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 14, global step 180: 'recall@10' reached 0.05017 (best 0.05017), saving model to '/content/.checkpoints/epoch=14-step=180.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 14, global step 180: 'recall@10' reached 0.05017 (best 0.05017), saving model to '/content/.checkpoints/epoch=14-step=180.ckpt' as top 1


k              1        10        20         5
map     0.007781  0.016753  0.019296  0.013551
ndcg    0.007781  0.024372  0.033901  0.016515
recall  0.007781  0.050166  0.088411  0.025662





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 15, global step 192: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 15, global step 192: 'recall@10' was not in top 1


k              1        10        20         5
map     0.006291  0.016008  0.018854  0.012873
ndcg    0.006291  0.023662  0.034234  0.016060
recall  0.006291  0.049338  0.091556  0.025828





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 16, global step 204: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 16, global step 204: 'recall@10' was not in top 1


k              1        10        20         5
map     0.005795  0.015428  0.018110  0.012368
ndcg    0.005795  0.023095  0.033185  0.015478
recall  0.005795  0.049007  0.089570  0.025000





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 17, global step 216: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 17, global step 216: 'recall@10' was not in top 1


k             1        10        20         5
map     0.00745  0.016754  0.019930  0.013896
ndcg    0.00745  0.024045  0.035823  0.017036
recall  0.00745  0.048510  0.095530  0.026656





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 18, global step 228: 'recall@10' reached 0.05149 (best 0.05149), saving model to '/content/.checkpoints/epoch=18-step=228.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 18, global step 228: 'recall@10' reached 0.05149 (best 0.05149), saving model to '/content/.checkpoints/epoch=18-step=228.ckpt' as top 1


k              1        10        20         5
map     0.007285  0.016795  0.019716  0.013609
ndcg    0.007285  0.024710  0.035485  0.016809
recall  0.007285  0.051490  0.094371  0.026656





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 19, global step 240: 'recall@10' reached 0.05364 (best 0.05364), saving model to '/content/.checkpoints/epoch=19-step=240.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 19, global step 240: 'recall@10' reached 0.05364 (best 0.05364), saving model to '/content/.checkpoints/epoch=19-step=240.ckpt' as top 1


k              1        10        20         5
map     0.007119  0.017257  0.020485  0.013802
ndcg    0.007119  0.025578  0.037507  0.017124
recall  0.007119  0.053642  0.101159  0.027318





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 20, global step 252: 'recall@10' reached 0.05596 (best 0.05596), saving model to '/content/.checkpoints/epoch=20-step=252.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 20, global step 252: 'recall@10' reached 0.05596 (best 0.05596), saving model to '/content/.checkpoints/epoch=20-step=252.ckpt' as top 1


k              1        10        20         5
map     0.006954  0.018019  0.021404  0.014589
ndcg    0.006954  0.026734  0.039187  0.018337
recall  0.006954  0.055960  0.105464  0.029801





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 21, global step 264: 'recall@10' reached 0.05861 (best 0.05861), saving model to '/content/.checkpoints/epoch=21-step=264.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 21, global step 264: 'recall@10' reached 0.05861 (best 0.05861), saving model to '/content/.checkpoints/epoch=21-step=264.ckpt' as top 1


k              1        10        20         5
map     0.007616  0.018819  0.022054  0.015339
ndcg    0.007616  0.027916  0.039798  0.019218
recall  0.007616  0.058609  0.105795  0.031126





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 22, global step 276: 'recall@10' reached 0.06672 (best 0.06672), saving model to '/content/.checkpoints/epoch=22-step=276.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 22, global step 276: 'recall@10' reached 0.06672 (best 0.06672), saving model to '/content/.checkpoints/epoch=22-step=276.ckpt' as top 1


k              1        10        20         5
map     0.010265  0.023302  0.026749  0.019763
ndcg    0.010265  0.033302  0.046107  0.024545
recall  0.010265  0.066722  0.117881  0.039238





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 23, global step 288: 'recall@10' reached 0.07268 (best 0.07268), saving model to '/content/.checkpoints/epoch=23-step=288.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 23, global step 288: 'recall@10' reached 0.07268 (best 0.07268), saving model to '/content/.checkpoints/epoch=23-step=288.ckpt' as top 1


k              1        10        20         5
map     0.008113  0.022565  0.025876  0.018452
ndcg    0.008113  0.034074  0.046627  0.023860
recall  0.008113  0.072682  0.123344  0.040563





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 24, global step 300: 'recall@10' reached 0.08295 (best 0.08295), saving model to '/content/.checkpoints/epoch=24-step=300.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 24, global step 300: 'recall@10' reached 0.08295 (best 0.08295), saving model to '/content/.checkpoints/epoch=24-step=300.ckpt' as top 1


k              1        10        20         5
map     0.010762  0.027138  0.031657  0.022379
ndcg    0.010762  0.039962  0.056611  0.028225
recall  0.010762  0.082947  0.149172  0.046192





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 25, global step 312: 'recall@10' reached 0.09238 (best 0.09238), saving model to '/content/.checkpoints/epoch=25-step=312.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 25, global step 312: 'recall@10' reached 0.09238 (best 0.09238), saving model to '/content/.checkpoints/epoch=25-step=312.ckpt' as top 1


k              1        10        20         5
map     0.012417  0.030223  0.035188  0.024639
ndcg    0.012417  0.044469  0.062662  0.030739
recall  0.012417  0.092384  0.164570  0.049503





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 26, global step 324: 'recall@10' reached 0.10795 (best 0.10795), saving model to '/content/.checkpoints/epoch=26-step=324.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 26, global step 324: 'recall@10' reached 0.10795 (best 0.10795), saving model to '/content/.checkpoints/epoch=26-step=324.ckpt' as top 1


k              1        10        20         5
map     0.012583  0.033972  0.039283  0.027643
ndcg    0.012583  0.050959  0.070713  0.035356
recall  0.012583  0.107947  0.186921  0.059106





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 27, global step 336: 'recall@10' reached 0.12351 (best 0.12351), saving model to '/content/.checkpoints/epoch=27-step=336.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 27, global step 336: 'recall@10' reached 0.12351 (best 0.12351), saving model to '/content/.checkpoints/epoch=27-step=336.ckpt' as top 1


k              1        10        20         5
map     0.016887  0.040727  0.046278  0.033568
ndcg    0.016887  0.059728  0.080167  0.042070
recall  0.016887  0.123510  0.204801  0.068212





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 28, global step 348: 'recall@10' reached 0.12881 (best 0.12881), saving model to '/content/.checkpoints/epoch=28-step=348.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 28, global step 348: 'recall@10' reached 0.12881 (best 0.12881), saving model to '/content/.checkpoints/epoch=28-step=348.ckpt' as top 1


k              1        10        20         5
map     0.017881  0.042225  0.047943  0.034421
ndcg    0.017881  0.062076  0.083416  0.042986
recall  0.017881  0.128808  0.214238  0.069371





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 29, global step 360: 'recall@10' reached 0.13675 (best 0.13675), saving model to '/content/.checkpoints/epoch=29-step=360.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 29, global step 360: 'recall@10' reached 0.13675 (best 0.13675), saving model to '/content/.checkpoints/epoch=29-step=360.ckpt' as top 1


k              1        10        20         5
map     0.019702  0.046202  0.052138  0.038513
ndcg    0.019702  0.067011  0.088979  0.048097
recall  0.019702  0.136755  0.224338  0.077649





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 30, global step 372: 'recall@10' reached 0.14089 (best 0.14089), saving model to '/content/.checkpoints/epoch=30-step=372.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 30, global step 372: 'recall@10' reached 0.14089 (best 0.14089), saving model to '/content/.checkpoints/epoch=30-step=372.ckpt' as top 1


k              1        10        20         5
map     0.019536  0.048794  0.054651  0.041614
ndcg    0.019536  0.070125  0.091801  0.052507
recall  0.019536  0.140894  0.227318  0.085927





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 31, global step 384: 'recall@10' reached 0.14719 (best 0.14719), saving model to '/content/.checkpoints/epoch=31-step=384.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 31, global step 384: 'recall@10' reached 0.14719 (best 0.14719), saving model to '/content/.checkpoints/epoch=31-step=384.ckpt' as top 1


k             1        10        20         5
map     0.02053  0.050200  0.056117  0.042177
ndcg    0.02053  0.072596  0.094402  0.053005
recall  0.02053  0.147185  0.233940  0.086258





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 32, global step 396: 'recall@10' reached 0.14868 (best 0.14868), saving model to '/content/.checkpoints/epoch=32-step=396.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 32, global step 396: 'recall@10' reached 0.14868 (best 0.14868), saving model to '/content/.checkpoints/epoch=32-step=396.ckpt' as top 1


k              1        10        20         5
map     0.023841  0.053670  0.060232  0.045808
ndcg    0.023841  0.075600  0.099889  0.056200
recall  0.023841  0.148676  0.245530  0.087914





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 33, global step 408: 'recall@10' reached 0.15149 (best 0.15149), saving model to '/content/.checkpoints/epoch=33-step=408.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 33, global step 408: 'recall@10' reached 0.15149 (best 0.15149), saving model to '/content/.checkpoints/epoch=33-step=408.ckpt' as top 1


k              1        10        20         5
map     0.021192  0.052384  0.058750  0.044208
ndcg    0.021192  0.075297  0.098679  0.055236
recall  0.021192  0.151490  0.244371  0.088907





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 34, global step 420: 'recall@10' reached 0.15944 (best 0.15944), saving model to '/content/.checkpoints/epoch=34-step=420.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 34, global step 420: 'recall@10' reached 0.15944 (best 0.15944), saving model to '/content/.checkpoints/epoch=34-step=420.ckpt' as top 1


k             1        10        20         5
map     0.02351  0.055581  0.062038  0.046973
ndcg    0.02351  0.079549  0.103319  0.058398
recall  0.02351  0.159437  0.253974  0.093377





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 35, global step 432: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 35, global step 432: 'recall@10' was not in top 1


k              1        10        20         5
map     0.022517  0.053701  0.060264  0.045306
ndcg    0.022517  0.077091  0.101080  0.056593
recall  0.022517  0.154967  0.250000  0.091225





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 36, global step 444: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 36, global step 444: 'recall@10' was not in top 1


k              1        10        20         5
map     0.025497  0.056972  0.063285  0.048767
ndcg    0.025497  0.080541  0.103828  0.060450
recall  0.025497  0.158940  0.251656  0.096358





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 37, global step 456: 'recall@10' reached 0.16275 (best 0.16275), saving model to '/content/.checkpoints/epoch=37-step=456.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 37, global step 456: 'recall@10' reached 0.16275 (best 0.16275), saving model to '/content/.checkpoints/epoch=37-step=456.ckpt' as top 1


k             1        10        20         5
map     0.02649  0.058761  0.064569  0.050270
ndcg    0.02649  0.082821  0.104352  0.062195
recall  0.02649  0.162748  0.248675  0.098841





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 38, global step 468: 'recall@10' reached 0.16374 (best 0.16374), saving model to '/content/.checkpoints/epoch=38-step=468.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 38, global step 468: 'recall@10' reached 0.16374 (best 0.16374), saving model to '/content/.checkpoints/epoch=38-step=468.ckpt' as top 1


k              1        10        20         5
map     0.027152  0.059842  0.066366  0.050944
ndcg    0.027152  0.083839  0.107723  0.062143
recall  0.027152  0.163742  0.258444  0.096358





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 39, global step 480: 'recall@10' reached 0.16772 (best 0.16772), saving model to '/content/.checkpoints/epoch=39-step=480.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 39, global step 480: 'recall@10' reached 0.16772 (best 0.16772), saving model to '/content/.checkpoints/epoch=39-step=480.ckpt' as top 1


k              1        10        20         5
map     0.026987  0.060661  0.067392  0.051948
ndcg    0.026987  0.085406  0.110287  0.064025
recall  0.026987  0.167715  0.266887  0.100993





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 40, global step 492: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 40, global step 492: 'recall@10' was not in top 1


k             1        10        20         5
map     0.02947  0.061450  0.068016  0.052839
ndcg    0.02947  0.085290  0.109280  0.064135
recall  0.02947  0.164735  0.259768  0.098675





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 41, global step 504: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 41, global step 504: 'recall@10' was not in top 1


k              1        10        20         5
map     0.027483  0.060552  0.066717  0.051791
ndcg    0.027483  0.085100  0.107961  0.063732
recall  0.027483  0.166722  0.257947  0.100331





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 42, global step 516: 'recall@10' reached 0.17384 (best 0.17384), saving model to '/content/.checkpoints/epoch=42-step=516.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 42, global step 516: 'recall@10' reached 0.17384 (best 0.17384), saving model to '/content/.checkpoints/epoch=42-step=516.ckpt' as top 1


k              1        10        20         5
map     0.028311  0.063831  0.070491  0.055414
ndcg    0.028311  0.089373  0.114033  0.068883
recall  0.028311  0.173841  0.272185  0.110265





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 43, global step 528: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 43, global step 528: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032119  0.065550  0.071983  0.056653
ndcg    0.032119  0.090314  0.113918  0.068778
recall  0.032119  0.172517  0.266225  0.105960





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 44, global step 540: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 44, global step 540: 'recall@10' was not in top 1


k              1        10        20         5
map     0.027152  0.062764  0.069803  0.054249
ndcg    0.027152  0.088251  0.114013  0.067381
recall  0.027152  0.172682  0.274834  0.107616





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 45, global step 552: 'recall@10' reached 0.17417 (best 0.17417), saving model to '/content/.checkpoints/epoch=45-step=552.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 45, global step 552: 'recall@10' reached 0.17417 (best 0.17417), saving model to '/content/.checkpoints/epoch=45-step=552.ckpt' as top 1


k              1        10        20         5
map     0.029636  0.064171  0.070960  0.055709
ndcg    0.029636  0.089620  0.114657  0.068740
recall  0.029636  0.174172  0.273841  0.108775





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 46, global step 564: 'recall@10' reached 0.17583 (best 0.17583), saving model to '/content/.checkpoints/epoch=46-step=564.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 46, global step 564: 'recall@10' reached 0.17583 (best 0.17583), saving model to '/content/.checkpoints/epoch=46-step=564.ckpt' as top 1


k              1        10        20         5
map     0.028311  0.064285  0.071103  0.055963
ndcg    0.028311  0.090177  0.115560  0.069783
recall  0.028311  0.175828  0.277318  0.112252





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 47, global step 576: 'recall@10' reached 0.18245 (best 0.18245), saving model to '/content/.checkpoints/epoch=47-step=576.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 47, global step 576: 'recall@10' reached 0.18245 (best 0.18245), saving model to '/content/.checkpoints/epoch=47-step=576.ckpt' as top 1


k              1        10        20         5
map     0.028477  0.065226  0.071609  0.056347
ndcg    0.028477  0.092289  0.115944  0.070049
recall  0.028477  0.182450  0.276821  0.112086





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 48, global step 588: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 48, global step 588: 'recall@10' was not in top 1


k              1        10        20         5
map     0.027318  0.065220  0.072458  0.056029
ndcg    0.027318  0.091667  0.118341  0.069273
recall  0.027318  0.179139  0.285265  0.109603





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 49, global step 600: 'recall@10' reached 0.18526 (best 0.18526), saving model to '/content/.checkpoints/epoch=49-step=600.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 49, global step 600: 'recall@10' reached 0.18526 (best 0.18526), saving model to '/content/.checkpoints/epoch=49-step=600.ckpt' as top 1


k              1        10        20         5
map     0.028311  0.065522  0.072615  0.055844
ndcg    0.028311  0.093166  0.119194  0.069310
recall  0.028311  0.185265  0.288576  0.110596





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 50, global step 612: 'recall@10' reached 0.19652 (best 0.19652), saving model to '/content/.checkpoints/epoch=50-step=612.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 50, global step 612: 'recall@10' reached 0.19652 (best 0.19652), saving model to '/content/.checkpoints/epoch=50-step=612.ckpt' as top 1


k              1        10        20         5
map     0.032947  0.072515  0.079553  0.062740
ndcg    0.032947  0.101197  0.127095  0.077088
recall  0.032947  0.196523  0.299503  0.121026





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 51, global step 624: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 51, global step 624: 'recall@10' was not in top 1


k              1        10        20         5
map     0.030132  0.066134  0.072914  0.056256
ndcg    0.030132  0.092948  0.117952  0.068859
recall  0.030132  0.182285  0.281788  0.107450





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 52, global step 636: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 52, global step 636: 'recall@10' was not in top 1


k              1        10        20         5
map     0.026821  0.063964  0.071335  0.054879
ndcg    0.026821  0.090899  0.118015  0.068515
recall  0.026821  0.180298  0.288079  0.110265





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 53, global step 648: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 53, global step 648: 'recall@10' was not in top 1


k              1        10        20         5
map     0.030629  0.067030  0.073796  0.057820
ndcg    0.030629  0.093221  0.118180  0.070743
recall  0.030629  0.180132  0.279470  0.110265





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 54, global step 660: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 54, global step 660: 'recall@10' was not in top 1


k              1        10        20         5
map     0.027483  0.066825  0.073360  0.058041
ndcg    0.027483  0.094129  0.118219  0.072548
recall  0.027483  0.184272  0.280132  0.116887





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 55, global step 672: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 55, global step 672: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032781  0.068261  0.075051  0.059294
ndcg    0.032781  0.094637  0.119485  0.072654
recall  0.032781  0.182285  0.280795  0.113742





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 56, global step 684: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 56, global step 684: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033444  0.071957  0.079102  0.062116
ndcg    0.033444  0.100068  0.126272  0.076045
recall  0.033444  0.193377  0.297351  0.118709





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 57, global step 696: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 57, global step 696: 'recall@10' was not in top 1


k              1        10        20         5
map     0.030298  0.068415  0.075686  0.059332
ndcg    0.030298  0.095359  0.122240  0.073197
recall  0.030298  0.184437  0.291556  0.115563





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 58, global step 708: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 58, global step 708: 'recall@10' was not in top 1


k             1        10        20         5
map     0.02947  0.069227  0.076382  0.059437
ndcg    0.02947  0.097909  0.124096  0.074286
recall  0.02947  0.192715  0.296523  0.119868





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 59, global step 720: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 59, global step 720: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032285  0.070921  0.078490  0.061220
ndcg    0.032285  0.098664  0.126357  0.074908
recall  0.032285  0.190728  0.300497  0.116722





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 60, global step 732: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 60, global step 732: 'recall@10' was not in top 1


k              1        10        20         5
map     0.027815  0.064878  0.072127  0.056112
ndcg    0.027815  0.091231  0.117644  0.069700
recall  0.027815  0.178477  0.282947  0.111258





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 61, global step 744: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 61, global step 744: 'recall@10' was not in top 1


k              1        10        20         5
map     0.030795  0.069907  0.076928  0.060226
ndcg    0.030795  0.097911  0.123681  0.074274
recall  0.030795  0.190728  0.293046  0.117219





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 62, global step 756: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 62, global step 756: 'recall@10' was not in top 1


k             1        10        20         5
map     0.03394  0.070738  0.077727  0.061129
ndcg    0.03394  0.098046  0.123829  0.074695
recall  0.03394  0.188742  0.291391  0.116391





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 63, global step 768: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 63, global step 768: 'recall@10' was not in top 1


k             1        10        20         5
map     0.02947  0.067225  0.074357  0.057770
ndcg    0.02947  0.094595  0.120692  0.071590
recall  0.02947  0.185265  0.288742  0.113907





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 64, global step 780: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 64, global step 780: 'recall@10' was not in top 1


k             1        10        20         5
map     0.03245  0.073006  0.080364  0.063590
ndcg    0.03245  0.101425  0.128524  0.078364
recall  0.03245  0.195364  0.303146  0.123510





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 65, global step 792: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 65, global step 792: 'recall@10' was not in top 1


k              1        10        20         5
map     0.035265  0.074889  0.082380  0.065155
ndcg    0.035265  0.103028  0.130734  0.079305
recall  0.035265  0.196192  0.306623  0.122517





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 66, global step 804: 'recall@10' reached 0.19917 (best 0.19917), saving model to '/content/.checkpoints/epoch=66-step=804.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 66, global step 804: 'recall@10' reached 0.19917 (best 0.19917), saving model to '/content/.checkpoints/epoch=66-step=804.ckpt' as top 1


k              1        10        20         5
map     0.031954  0.072694  0.079668  0.062508
ndcg    0.031954  0.101965  0.127513  0.076969
recall  0.031954  0.199172  0.300497  0.121192





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 67, global step 816: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 67, global step 816: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032947  0.073013  0.080088  0.063413
ndcg    0.032947  0.101713  0.127742  0.078091
recall  0.032947  0.196854  0.300331  0.123013





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 68, global step 828: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 68, global step 828: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031954  0.071385  0.078689  0.061545
ndcg    0.031954  0.100099  0.126880  0.075860
recall  0.031954  0.195530  0.301821  0.119702





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 69, global step 840: 'recall@10' reached 0.20646 (best 0.20646), saving model to '/content/.checkpoints/epoch=69-step=840.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 69, global step 840: 'recall@10' reached 0.20646 (best 0.20646), saving model to '/content/.checkpoints/epoch=69-step=840.ckpt' as top 1


k              1        10        20         5
map     0.034934  0.076867  0.083636  0.066609
ndcg    0.034934  0.106867  0.131798  0.081583
recall  0.034934  0.206457  0.305629  0.127318





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 70, global step 852: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 70, global step 852: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031954  0.072679  0.079507  0.063044
ndcg    0.031954  0.102005  0.127180  0.078315
recall  0.031954  0.199172  0.299338  0.125166





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 71, global step 864: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 71, global step 864: 'recall@10' was not in top 1


k             1        10        20         5
map     0.03096  0.073594  0.080283  0.063714
ndcg    0.03096  0.103789  0.128464  0.079386
recall  0.03096  0.203808  0.301987  0.127318





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 72, global step 876: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 72, global step 876: 'recall@10' was not in top 1


k              1        10        20         5
map     0.030298  0.068850  0.075784  0.059067
ndcg    0.030298  0.096775  0.122297  0.072780
recall  0.030298  0.189570  0.291060  0.114735





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 73, global step 888: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 73, global step 888: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031623  0.070786  0.077836  0.061391
ndcg    0.031623  0.099327  0.125438  0.076161
recall  0.031623  0.194040  0.298179  0.121523





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 74, global step 900: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 74, global step 900: 'recall@10' was not in top 1


k              1        10        20        5
map     0.034437  0.075199  0.082532  0.06489
ndcg    0.034437  0.104751  0.131791  0.07935
recall  0.034437  0.202980  0.310596  0.12351





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 75, global step 912: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 75, global step 912: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031457  0.071862  0.078907  0.062023
ndcg    0.031457  0.100979  0.126804  0.076807
recall  0.031457  0.197517  0.300000  0.122020





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 76, global step 924: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 76, global step 924: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033444  0.073704  0.081060  0.063968
ndcg    0.033444  0.102919  0.129987  0.079132
recall  0.033444  0.199669  0.307285  0.125662





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 77, global step 936: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 77, global step 936: 'recall@10' was not in top 1


k              1        10        20         5
map     0.030795  0.070998  0.078160  0.061393
ndcg    0.030795  0.099778  0.126165  0.076305
recall  0.030795  0.195033  0.300000  0.122020





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 78, global step 948: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 78, global step 948: 'recall@10' was not in top 1


k             1        10        20         5
map     0.03096  0.070940  0.078427  0.061272
ndcg    0.03096  0.100057  0.127618  0.076242
recall  0.03096  0.196689  0.306291  0.122185





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 79, global step 960: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 79, global step 960: 'recall@10' was not in top 1


k              1        10        20         5
map     0.035265  0.075406  0.082382  0.065304
ndcg    0.035265  0.105344  0.131003  0.080433
recall  0.035265  0.204967  0.306954  0.126987





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 80, global step 972: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 80, global step 972: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033444  0.073124  0.080464  0.063256
ndcg    0.033444  0.102561  0.129610  0.078324
recall  0.033444  0.200331  0.307947  0.124669





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 81, global step 984: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 81, global step 984: 'recall@10' was not in top 1


k              1        10        20         5
map     0.034272  0.074526  0.082042  0.064514
ndcg    0.034272  0.103525  0.131373  0.079395
recall  0.034272  0.199338  0.310430  0.125000





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 82, global step 996: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 82, global step 996: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033609  0.073946  0.081096  0.063767
ndcg    0.033609  0.103505  0.129713  0.078548
recall  0.033609  0.201656  0.305629  0.123841





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 83, global step 1008: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 83, global step 1008: 'recall@10' was not in top 1


k              1        10        20         5
map     0.030132  0.072919  0.080579  0.063408
ndcg    0.030132  0.102750  0.130774  0.079592
recall  0.030132  0.200993  0.312086  0.129139





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 84, global step 1020: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 84, global step 1020: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031126  0.069984  0.077177  0.060328
ndcg    0.031126  0.098309  0.124736  0.074663
recall  0.031126  0.192219  0.297185  0.118543





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 85, global step 1032: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 85, global step 1032: 'recall@10' was not in top 1


k              1        10        20         5
map     0.034934  0.076412  0.083825  0.066534
ndcg    0.034934  0.106204  0.133427  0.081832
recall  0.034934  0.204967  0.313079  0.128642





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 86, global step 1044: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 86, global step 1044: 'recall@10' was not in top 1


k              1        10        20         5
map     0.035762  0.075687  0.083056  0.065877
ndcg    0.035762  0.104794  0.131766  0.080713
recall  0.035762  0.201325  0.308278  0.126159





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 87, global step 1056: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 87, global step 1056: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032781  0.074128  0.081950  0.064238
ndcg    0.032781  0.103114  0.131892  0.078856
recall  0.032781  0.199007  0.313411  0.123344





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 88, global step 1068: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 88, global step 1068: 'recall@10' was not in top 1


k              1        10        20         5
map     0.034272  0.076411  0.084349  0.066211
ndcg    0.034272  0.106474  0.135560  0.081656
recall  0.034272  0.205960  0.321358  0.128974





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 89, global step 1080: 'recall@10' reached 0.20877 (best 0.20877), saving model to '/content/.checkpoints/epoch=89-step=1080.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 89, global step 1080: 'recall@10' reached 0.20877 (best 0.20877), saving model to '/content/.checkpoints/epoch=89-step=1080.ckpt' as top 1


k              1        10        20         5
map     0.035927  0.077727  0.085557  0.067536
ndcg    0.035927  0.108067  0.136883  0.082942
recall  0.035927  0.208775  0.323344  0.130132





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 90, global step 1092: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 90, global step 1092: 'recall@10' was not in top 1


k              1        10        20         5
map     0.034272  0.076718  0.083935  0.066945
ndcg    0.034272  0.106557  0.133012  0.082637
recall  0.034272  0.205132  0.310099  0.130629





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 91, global step 1104: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 91, global step 1104: 'recall@10' was not in top 1


k             1        10        20         5
map     0.03096  0.073350  0.081247  0.063079
ndcg    0.03096  0.103318  0.132276  0.078271
recall  0.03096  0.202483  0.317384  0.124669





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 92, global step 1116: 'recall@10' reached 0.21490 (best 0.21490), saving model to '/content/.checkpoints/epoch=92-step=1116.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 92, global step 1116: 'recall@10' reached 0.21490 (best 0.21490), saving model to '/content/.checkpoints/epoch=92-step=1116.ckpt' as top 1


k              1        10        20         5
map     0.036258  0.078973  0.086872  0.068469
ndcg    0.036258  0.110405  0.139366  0.084401
recall  0.036258  0.214901  0.329801  0.133278





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 93, global step 1128: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 93, global step 1128: 'recall@10' was not in top 1


k              1        10        20         5
map     0.035265  0.076457  0.084103  0.065855
ndcg    0.035265  0.106589  0.134518  0.080391
recall  0.035265  0.206954  0.317550  0.124834





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 94, global step 1140: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 94, global step 1140: 'recall@10' was not in top 1


k              1        10        20         5
map     0.036093  0.078338  0.086325  0.068557
ndcg    0.036093  0.108361  0.137841  0.084401
recall  0.036093  0.207616  0.325000  0.132947





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 95, global step 1152: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 95, global step 1152: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033278  0.073037  0.080835  0.062613
ndcg    0.033278  0.102702  0.131368  0.077125
recall  0.033278  0.201490  0.315397  0.121689





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 96, global step 1164: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 96, global step 1164: 'recall@10' was not in top 1


k              1        10        20         5
map     0.037914  0.079142  0.086639  0.068419
ndcg    0.037914  0.109629  0.137287  0.083563
recall  0.037914  0.210927  0.321027  0.130132





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 97, global step 1176: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 97, global step 1176: 'recall@10' was not in top 1


k              1        10        20         5
map     0.036093  0.075790  0.083469  0.065717
ndcg    0.036093  0.105406  0.133675  0.080524
recall  0.036093  0.203974  0.316391  0.125993





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 98, global step 1188: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 98, global step 1188: 'recall@10' was not in top 1


k              1        10        20         5
map     0.036093  0.079806  0.087114  0.069498
ndcg    0.036093  0.110394  0.137400  0.085089
recall  0.036093  0.211589  0.319205  0.132616





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 99, global step 1200: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 99, global step 1200: 'recall@10' was not in top 1
INFO: `Trainer.fit` stopped: `max_epochs=100` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


k              1        10        20         5
map     0.036589  0.078300  0.085788  0.067373
ndcg    0.036589  0.108879  0.136390  0.082299
recall  0.036589  0.210430  0.319702  0.127980



The path to the best model is saved inside checkpoint_callback

In [58]:
best_model = Bert4Rec.load_from_checkpoint(checkpoint_callback.best_model_path)

In [88]:
checkpoint_callback.best_model_path

'/content/.checkpoints/epoch=92-step=1116.ckpt'

In [90]:
ls

base_splitter.py          optimizer_utils.py       session_handler.py
Bert4rec_Model.ipynb      output.csv               spark_utils.py
bert4rec.py               postprocessors.py        submission_BERT100epoch.csv
[0m[01;34mdata[0m/                     prediction_callbacks.py  submission.csv
data.py                   [01;34m__pycache__[0m/             top_10_films.csv
dataset_label_encoder.py  recommendations.csv      torch_metrics_builder.py
dataset.py                [01;34msample_data[0m/             torch_sequential_dataset.py
label_encoder.py          schema.py                typesBert.py
last_n_splitter.py        sequence_tokenizer.py    utils.py
model.py                  sequential_dataset.py    validation_callback.py


## Inference stage
### Prepare Dataloader and logger

In [67]:
prediction_dataloader = DataLoader(
    dataset=Bert4RecPredictionDataset(
        sequential_test_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

csv_logger = CSVLogger(save_dir=".logs/test", name="Bert4Rec_example")



### Run inference
You can get the recommendations in three formats: PySpark DataFrame, Pandas DataFrame, PyTorch tensors. Each of the types corresponds a callback
You can filter the results using postprocessors strategy. For example the RemoveSeenItems postprocessor is filtering out the items that already have been seen in test dataset
You don't need to use all three callbacks. This is shown only for example

Also, you can get user embeddings, that were used to perform predictions, using `get_query_embedding` method inside Bert4RecModel or `QueryEmbeddingsPredictionCallback` for lightning module.

In [68]:
# from replay.metrics import OfflineMetrics, Recall, Precision, MAP, NDCG, HitRate, MRR
# from replay.metrics.torch_metrics_builder import metrics_to_df

from prediction_callbacks import (
    SparkPredictionCallback,
    PandasPredictionCallback,
    TorchPredictionCallback,
    QueryEmbeddingsPredictionCallback,
)

TOPK = [1, 10, 20, 100]

postprocessors = [RemoveSeenItems(sequential_test_dataset)]

spark_prediction_callback = SparkPredictionCallback(
    spark_session=spark_session,
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

pandas_prediction_callback = PandasPredictionCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

torch_prediction_callback = TorchPredictionCallback(
    top_k=max(TOPK),
    postprocessors=postprocessors,
)

query_embeddings_callback = QueryEmbeddingsPredictionCallback()

trainer = L.Trainer(
    callbacks=[
        spark_prediction_callback,
        pandas_prediction_callback,
        torch_prediction_callback,
        query_embeddings_callback,
    ],
    logger=csv_logger,
    inference_mode=True
)
trainer.predict(best_model, dataloaders=prediction_dataloader, return_predictions=False)

spark_res = spark_prediction_callback.get_result()
pandas_res = pandas_prediction_callback.get_result()
torch_user_ids, torch_item_ids, torch_scores = torch_prediction_callback.get_result()
user_embeddings = query_embeddings_callback.get_result()

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [69]:
spark_res.show()

+-------+-------+-----------------+
|user_id|item_id|            score|
+-------+-------+-----------------+
|    790|   3421|9.312716484069824|
|    790|   1422|8.847206115722656|
|    790|    434|8.750907897949219|
|    790|    708|8.747827529907227|
|    790|   3025| 8.62032413482666|
|    790|   2543|8.587140083312988|
|    790|   1461|  8.5817232131958|
|    790|   2003|8.497234344482422|
|    790|   1332|8.429722785949707|
|    790|   1287|8.354843139648438|
|    790|    566|8.291472434997559|
|    790|    827|8.272329330444336|
|    790|   2138|8.268227577209473|
|    790|   3460|8.233558654785156|
|    790|   2108|8.218437194824219|
|    790|    980|8.189330101013184|
|    790|   2251|8.142723083496094|
|    790|   1250|8.138209342956543|
|    790|   1951|7.928654193878174|
|    790|   1128|7.888408660888672|
+-------+-------+-----------------+
only showing top 20 rows



In [70]:
pandas_res

Unnamed: 0,user_id,item_id,score
0,790,3421,9.312716
0,790,1422,8.847206
0,790,434,8.750908
0,790,708,8.747828
0,790,3025,8.620324
...,...,...,...
6039,4520,950,4.898004
6039,4520,1906,4.879602
6039,4520,2011,4.870603
6039,4520,960,4.862675


In [71]:
pandas_res.to_csv('output.csv')

In [72]:
print(torch_user_ids[0], torch_item_ids[0], torch_scores[0])

tensor(790) tensor([3421, 1422,  434,  708, 3025, 2543, 1461, 2003, 1332, 1287,  566,  827,
        2138, 3460, 2108,  980, 2251, 1250, 1951, 1128, 1259, 1959, 2143, 1493,
        3699, 3029, 1672, 2146,  874, 2801, 2551, 1329,  904, 1323,  178, 2403,
         516,   74,  446, 3411, 3167, 3393, 1303, 3214,  351, 3481, 1939,  881,
         644, 2456, 1101, 1052, 1014, 3209, 2555, 2593, 2145, 1824, 1905, 2051,
          32, 2397, 3657, 2159, 2687, 1744,  554,  128, 1160, 2604, 1705,  672,
        2702, 1171,  485, 2213,  129, 3103,  994, 3569, 2827,  163, 2925,  411,
        3097, 2275, 1377,  703, 3119, 3007, 3026, 2418, 1920, 1479, 2690, 2339,
        1965,  550, 3056, 1155]) tensor([9.3127, 8.8472, 8.7509, 8.7478, 8.6203, 8.5871, 8.5817, 8.4972, 8.4297,
        8.3548, 8.2915, 8.2723, 8.2682, 8.2336, 8.2184, 8.1893, 8.1427, 8.1382,
        7.9287, 7.8884, 7.8662, 7.8638, 7.8538, 7.8401, 7.7869, 7.7594, 7.7348,
        7.7185, 7.6971, 7.6893, 7.6600, 7.6417, 7.6156, 7.5982, 7.5643, 7.5

Suppose we want to get the recomendations in PySpark format.
Let's get the inverse representation of labels using inverse_transform method.

Note that the reverse representation can only be obtained for PySpark and Pandas formats. When working with PyTorch tensors, the reverse representation must be done manually

In [73]:
recommendations = tokenizer.query_and_item_id_encoder.inverse_transform(spark_res)



In [74]:
recommendations.show()

+-----------------+-------+-------+
|            score|user_id|item_id|
+-----------------+-------+-------+
|8.291472434997559|      0|    566|
|8.138209342956543|      0|   1250|
|7.928654193878174|      0|   1951|
|8.847206115722656|      0|   1422|
|8.272329330444336|      0|    827|
|8.189330101013184|      0|    980|
|8.142723083496094|      0|   2251|
|7.888408660888672|      0|   1128|
|9.312716484069824|      0|   3421|
|8.747827529907227|      0|    708|
|  8.5817232131958|      0|   1461|
|8.354843139648438|      0|   1287|
|8.233558654785156|      0|   3460|
|8.218437194824219|      0|   2108|
| 8.62032413482666|      0|   3025|
|7.866247653961182|      0|   1259|
|8.587140083312988|      0|   2543|
|8.429722785949707|      0|   1332|
|8.268227577209473|      0|   2138|
|8.750907897949219|      0|    434|
+-----------------+-------+-------+
only showing top 20 rows



In [75]:
# Преобразуем Spark DataFrame в Pandas DataFrame
pandas_df2 = recommendations.toPandas()

# Сохраняем результаты в CSV-файл
pandas_df2.to_csv('recommendations.csv', index=False)

### Calculating metrics

In [76]:
init_args = {"query_column": "user_id", "rating_column": "score"}

In [77]:
def get_top_n(user_item_ratings, model_name, n=100):
    '''Функция возвращает топ-n фильмов для каждого пользователя'''

    # Сортируем данные по убыванию предсказанной оценки
    top_n = user_item_ratings.sort_values(model_name, ascending=False)

    # Оставляем только первые n строк для каждого пользователя
    top_n = top_n.groupby('user_id').head(n).reset_index(drop=True)

    return top_n

In [78]:
pandas_df2

Unnamed: 0,score,user_id,item_id
0,4.368713,3,94
1,4.090495,3,467
2,4.018270,3,1371
3,3.834505,3,2947
4,3.299708,3,1304
...,...,...,...
603995,6.219978,6027,793
603996,6.108172,6027,3684
603997,5.933334,6027,3476
603998,5.698685,6027,3396


In [79]:
top_10_films = get_top_n(pandas_df2, 'score', n=10)[['user_id', 'item_id']]

# Экспортируем результат в CSV-файл
top_10_films.to_csv('top_10_films.csv', index=False)

In [80]:
top_10_films.user_id.nunique()

6040

In [81]:
def format_for_submission(df):
    # Группируем строки по user_id и соединяем item_id через пробел
    submission = (
        df
        .groupby('user_id')['item_id']
        .apply(lambda x: ' '.join(x.astype(str)))
        .reset_index()
    )

    return submission

In [82]:
submission = format_for_submission(top_10_films)
submission

Unnamed: 0,user_id,item_id
0,0,3421 1422 434 708 3025 2543 1461 2003 1332 1287
1,1,1246 232 350 1459 2476 2730 3656 3101 1822 452
2,2,2774 382 234 2311 1371 1687 1560 221 2428 1781
3,3,3390 3562 20 1814 361 94 1456 1382 467 1371
4,4,983 1160 394 672 755 743 2205 2338 3030 2159
...,...,...
6035,6035,2800 1439 1296 3665 1403 3668 892 3216 1011 2366
6036,6036,401 1859 3692 1912 829 3105 3418 2470 494 1946
6037,6037,1375 1747 2664 133 1439 12 1379 1102 2502 1296
6038,6038,1355 3343 2231 2632 1514 3278 450 2305 3605 3139


In [86]:
submission.to_csv('./data/submission_BERT100epoch.csv', index=False)

# То что ниже - не трогал

In [42]:
result_metrics = OfflineMetrics(
    [Recall(TOPK), Precision(TOPK), MAP(TOPK), NDCG(TOPK), MRR(TOPK), HitRate(TOPK)], **init_args
)(recommendations.toPandas(), raw_test_gt)

NameError: name 'OfflineMetrics' is not defined

In [None]:
metrics_to_df(result_metrics)

### User embeddings

Got 6040 x 300 user embeddings, because among all 12 batches:

11 batches contains 512 samples

1 batch contains 408 left samples

11 * 512 + 408 == 6040

In [None]:
user_embeddings

In [None]:
user_embeddings.shape

You can access user embeddings directly with `Bert4RecModel` class

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

core_model = Bert4RecModel(
    tensor_schema,
    num_blocks=2,
    num_heads=4,
    max_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout=0.5
)
core_model.eval()
core_model = core_model.to(device)

# Get first batch of data
data = next(iter(prediction_dataloader))
tensor_map, padding_mask, tokens_mask = data.features, data.padding_mask, data.tokens_mask

# Ensure everything is on the same device
padding_mask = padding_mask.to(device)
tokens_mask = tokens_mask.to(device)
tensor_map["item_id_seq"] = tensor_map["item_id_seq"].to(device)

# Get user embeddings
user_embeddings_batch = core_model.get_query_embeddings(tensor_map, padding_mask, tokens_mask)
user_embeddings_batch

In [None]:
user_embeddings_batch.shape

## Example of launching an inference for a single user without using a trainer (in order to speed up)
An example for the production of an online script

Let's assume that the user's sequence consisted of a sequence of items [1, 2, 3, 4, 5].
Сreate a padding mask and tokens mask corresponding to the sequence of items.

It is important to take only the latest MAX_SEQ_LEN or less items.

In [None]:
item_sequence = torch.arange(1, 5).unsqueeze(0)[:, -MAX_SEQ_LEN:]
padding_mask = torch.ones_like(item_sequence, dtype=torch.bool)
tokens_mask = padding_mask.roll(-1, dims=0)
tokens_mask[-1, ...] = 0
sequence_item_count = item_sequence.shape[1]

### Wrapping created tensors in the Bert4RecPredictionBatch entity

In [None]:
batch = Bert4RecPredictionBatch(
    query_id=torch.arange(0, item_sequence.shape[0], 1).long(),
    padding_mask=padding_mask,
    features={ITEM_FEATURE_NAME: item_sequence.long()},
    tokens_mask=tokens_mask
)

### Run predict step of the Bert4Rec and get scores from the model

In [None]:
with torch.no_grad():
    scores = best_model.predict_step(batch, 0)
scores

### Getting five items with the highest score

In [None]:
torch.topk(scores, k=5).indices