# Example of SasRec training/inference

In [1]:
from typing import Optional

import lightning as L
import pandas as pd

# Seed the global random number generators through Lightning so runs are reproducible.
SEED = 42
L.seed_everything(SEED)

Seed set to 42


42

## Preparing data
In this example, we use the MovieLens dataset, namely the 1m subset. This is a simple demonstration, so only item ids are used as model input.

---
**NOTE**

The current implementation of SasRec can handle item and interaction features. It does not take user features into account.

---

In [2]:
# Load the tab-separated ML-1M ratings; ratings are not used by this example, so drop them.
interactions = (
    pd.read_csv(
        "./data/ml1m_ratings.dat",
        sep="\t",
        names=["user_id", "item_id", "rating", "timestamp"],
    )
    .drop(columns=["rating"])
)

In [3]:
# Replace raw timestamps with per-user ordinal positions (0, 1, 2, ...) in time order.
interactions["timestamp"] = interactions["timestamp"].astype("int64")
# Use a stable sort: interactions may share the same timestamp, and an unstable sort
# orders such ties in an implementation-defined way. That would make the ordinal
# positions (and the last-N splits built on them) depend on the library version.
# A stable sort keeps tied rows in their original file order.
interactions = interactions.sort_values(by="timestamp", kind="stable")
# cumcount numbers each user's rows in their current (chronological) order.
interactions["timestamp"] = interactions.groupby("user_id").cumcount()
interactions

Unnamed: 0,user_id,item_id,timestamp
1000138,6040,858,0
1000153,6040,2384,1
999873,6040,593,2
1000007,6040,1961,3
1000192,6040,2019,4
...,...,...,...
825793,4958,2399,446
825438,4958,1407,447
825724,4958,3264,448
825731,4958,2634,449


### Encode categorical data
To make all categorical data fit for training, it has to be encoded with the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorical column in the dataset that will be used by the model. Note that user ids and item ids are always used.

In [4]:
from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule

# Both id columns are label-encoded. default_value="last" configures how unseen labels
# are handled — see the LabelEncodingRule documentation for the exact semantics.
encoder = LabelEncoder(
    [LabelEncodingRule(column, default_value="last") for column in ("user_id", "item_id")]
)
# Sort by raw item id before fitting; presumably so that encoded item ids follow
# the order of the raw ids — confirm against LabelEncoder docs.
interactions = interactions.sort_values("item_id")
encoded_interactions = encoder.fit_transform(interactions)
encoded_interactions

Unnamed: 0,timestamp,user_id,item_id
0,12,2011,0
1,68,4078,0
2,67,4123,0
3,12,983,0
4,140,2270,0
...,...,...,...
1000204,14,855,3705
1000205,90,1700,3705
1000206,70,936,3705
1000207,25,360,3705


### Split interactions into the train, validation and test datasets using LastNSplitter
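We use the widespread leave-one-out splitting strategy: each user's last interaction becomes the test ground truth, and the last interaction of the remaining history becomes the validation ground truth. We filter out cold items and users for simplicity.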

In [5]:
from replay.splitters import LastNSplitter

# Hold out the last N=1 interaction of every user; cold users/items are dropped.
splitter = LastNSplitter(
    N=1,
    strategy="interactions",
    divide_column="user_id",
    query_column="user_id",
    drop_cold_items=True,
    drop_cold_users=True,
)

# First split: each user's last interaction is the test ground truth.
test_events, test_gt = splitter.split(encoded_interactions)
# Second split on the remaining history: the next-to-last interaction is the validation ground truth.
validation_events, validation_gt = splitter.split(test_events)
# Training sees exactly the history that precedes the validation target.
train_events = validation_events

### Dataset preprocessing ("baking")
SasRec expects each user in the batch to provide their events in the form of a sequence. For this reason, the event splits must be grouped into per-user sequences using the `groupby_sequences` function provided by RePlay.

In [6]:
from replay.data.nn.utils import groupby_sequences


def bake_data(full_data, groupby_col: str = "user_id", sort_col: str = "timestamp"):
    """Group flat interaction rows into one row of sequences per query.

    Args:
        full_data: Interaction frame with one row per event.
        groupby_col: Column identifying a sequence; the result has one row per
            distinct value. Defaults to ``"user_id"``.
        sort_col: Column used to order events inside each sequence.
            Defaults to ``"timestamp"``.

    Returns:
        The frame produced by ``groupby_sequences``, where the remaining columns
        hold per-query lists of values.
    """
    return groupby_sequences(events=full_data, groupby_col=groupby_col, sort_col=sort_col)

In [7]:
# Turn every split into per-user sequences: training history, plus events and
# ground truth for validation and test.
train_events = bake_data(train_events)
validation_events, validation_gt = bake_data(validation_events), bake_data(validation_gt)
test_events, test_gt = bake_data(test_events), bake_data(test_gt)

train_events

Unnamed: 0,user_id,timestamp,item_id
0,0,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2969, 1574, 1178, 957, 2147, 1658, 3177, 1117..."
1,1,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1108, 1127, 1120, 2512, 1201, 2735, 1135, 110..."
2,2,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[579, 2651, 3301, 1788, 1781, 1327, 1174, 3429..."
3,3,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1120, 1025, 466, 3235, 3294, 1106, 253, 1108,..."
4,4,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2512, 858, 847, 346, 1158, 2007, 2651, 1050, ..."
...,...,...,...
6035,6035,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1574, 1703, 3206, 2183, 2235, 2480, 2375, 250..."
6036,6036,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1702, 672, 1175, 1848, 3275, 2932, 548, 802, ..."
6037,6037,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[3165, 859, 1120, 1965, 1288, 346, 1007, 1066,..."
6038,6038,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[107, 275, 1886, 1139, 869, 886, 2872, 2809, 2..."


To ensure there are no unknown users in the ground truth, we join validation events with validation ground truth (and test events with test ground truth) by user id, keeping only the users present in both.

In [8]:
def add_gt_to_events(events_df, gt_df):
    """Attach each user's ground-truth items to their event sequence.

    Args:
        events_df: Events frame with a ``user_id`` column (one row per user after baking).
        gt_df: Ground-truth frame with ``user_id`` and ``item_id`` columns.

    Returns:
        ``events_df`` inner-joined on ``user_id`` with ``gt_df["item_id"]`` added as a
        ``ground_truth`` column; users missing from either frame are dropped.

    Raises:
        pandas.errors.MergeError: If ``gt_df`` contains the same ``user_id`` more than
            once, which would otherwise silently duplicate event rows.
    """
    gt_to_join = gt_df[["user_id", "item_id"]].rename(columns={"item_id": "ground_truth"})
    # many_to_one: the ground truth must hold at most one row per user (i.e. be baked).
    return events_df.merge(gt_to_join, on="user_id", how="inner", validate="many_to_one")

validation_events, test_events = (
    add_gt_to_events(validation_events, validation_gt),
    add_gt_to_events(test_events, test_gt),
)

In [9]:
from pathlib import Path

# Every intermediate artifact of this example lives under one scratch directory.
data_dir = Path("temp/data/")
data_dir.mkdir(exist_ok=True, parents=True)

# Baked splits, written as parquet files.
TRAIN_PATH = data_dir.joinpath("train.parquet")
VAL_PATH = data_dir.joinpath("val.parquet")
PREDICT_PATH = data_dir.joinpath("test.parquet")

# Location of the fitted label encoder.
ENCODER_PATH = data_dir.joinpath("encoder")

In [10]:
# Persist the baked splits and the fitted encoder.
for split_frame, split_path in (
    (train_events, TRAIN_PATH),
    (validation_events, VAL_PATH),
    (test_events, PREDICT_PATH),
):
    split_frame.to_parquet(split_path)

encoder.save(ENCODER_PATH)



# Prepare for model training
### Create the tensor schema
A schema describes the correspondence between the columns of the source dataset and the internal representation of tensors inside the model. NN models require it to correctly create embeddings for every source column. Note that user_id is not required in `TensorSchema`.

Note that **cardinality** is the number of unique values in the item catalog (vocabulary). **Padding value** is the next value after the last encoded item id; since encoded ids start at 0, it equals the cardinality.

In [11]:
from replay.data import FeatureHint, FeatureType
from replay.data.nn import TensorFeatureInfo, TensorSchema


from replay.preprocessing.label_encoder import LabelEncoder

EMBEDDING_DIM = 64

# Reload the fitted encoder from disk. `load` is called on the class rather than on
# a previously fitted instance, so this cell does not depend on the in-memory
# `encoder` from the encoding step.
encoder = LabelEncoder.load(ENCODER_PATH)
NUM_UNIQUE_ITEMS = len(encoder.mapping["item_id"])

# The item vocabulary has NUM_UNIQUE_ITEMS entries; padding uses the value right
# after the last encoded id so it never collides with a real item
# (assumes encoded ids are 0..NUM_UNIQUE_ITEMS-1).
tensor_schema = TensorSchema(
    [
        TensorFeatureInfo(
            name="item_id",
            is_seq=True,
            padding_value=NUM_UNIQUE_ITEMS,
            cardinality=NUM_UNIQUE_ITEMS,
            embedding_dim=EMBEDDING_DIM,
            feature_type=FeatureType.CATEGORICAL,
            feature_hint=FeatureHint.ITEM_ID,
        )
    ]
)

In [12]:
import torch
import polars as pl

### Configure ParquetModule and transformation pipelines

The `ParquetModule` class enables training of models on large datasets by reading data batch by batch. This class is initialized with **paths to every data split, a metadata dict containing the shape and padding value of every column, and a dict of transforms**. `ParquetModule`'s "transform pipelines" are stage-specific modules that perform additional preprocessing on the batch level right before the forward pass.

For the SasRec model, RePlay provides a function named **make_default_sasrec_transforms** that generates an appropriate sequence of transforms for each data split.

Internally this function creates the following transforms:
1) Training:
    1. Create a target containing the shifted item sequence, i.e. the next item at every position (for the next item prediction task).
    2. Rename features to match the format expected by the model during training.
    3. Unsqueeze the target (*positive_labels*) and its padding mask (*target_padding_mask*) to get the shapes required for loss computation.
    4. Group input features into the format expected by the embedder.

2) Validation/Inference:
    1. Rename/group features to match the format expected by the model during validation/inference.

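If a different set of transforms is required, you can create them yourself and pass them to the ParquetModule as a dictionary where the key is the name of the split and the value is the list of transforms. Available transforms are in the `replay.nn.transform` module. Below we define our own `make_sasrec_transforms`, which additionally draws uniform negative samples for training and, for prediction, copies the full item history to `seen_ids` before trimming the sequence to `MAX_SEQ_LEN`.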

**Note:** One of the transforms for the training data prepares the initial sequence for the Next Item Prediction task, so it shifts the sequence of items. For the final sequence length to be correct, you need to set the shape of item_id in the metadata to **model sequence length + shift**. The default shift value is 1.

In [13]:
import copy

import torch

from replay.data.nn import TensorSchema
from replay.nn.transform import *

In [16]:
MAX_SEQ_LEN = 50


def make_sasrec_transforms(
    tensor_schema: TensorSchema,
    query_column: str = "query_id",
    num_negative_samples: int = 128,
    max_seq_len: int = MAX_SEQ_LEN,
) -> dict[str, list[torch.nn.Module]]:
    """Build the per-stage batch transform pipelines for SasRec with sampled negatives.

    Args:
        tensor_schema: Schema whose item-id feature drives every pipeline; its
            cardinality is the vocabulary size used for negative sampling.
        query_column: Name of the query (user) column in the batches; it is
            renamed to ``"query_id"`` for the model.
        num_negative_samples: Number of negatives drawn by
            ``UniformNegativeSamplingTransform`` for training batches.
        max_seq_len: Number of items kept by ``TrimTransform`` in the predict
            pipeline. Defaults to ``MAX_SEQ_LEN``.

    Returns:
        Mapping from stage name (``"train"``, ``"validate"``, ``"test"``,
        ``"predict"``) to the list of transforms applied to each batch.
    """
    # Explicit imports (instead of the star import) make this function's dependencies visible.
    from replay.nn.transform import (
        CopyTransform,
        GroupTransform,
        NextTokenTransform,
        RenameTransform,
        TrimTransform,
        UniformNegativeSamplingTransform,
        UnsqueezeTransform,
    )

    item_column = tensor_schema.item_id_feature_name
    vocab_size = tensor_schema[item_column].cardinality
    item_mask_column = f"{item_column}_mask"

    train_transforms = [
        UniformNegativeSamplingTransform(vocab_size, num_negative_samples),
        # Shift the item sequence by one position to build next-item targets
        # ("positive_labels"); training metadata must therefore provide seq_len + 1 items.
        NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),
        RenameTransform(
            {
                query_column: "query_id",
                item_mask_column: "padding_mask",
                "positive_labels_mask": "target_padding_mask",
            }
        ),
        # Add a trailing dimension to the targets and their mask for loss computation.
        UnsqueezeTransform("target_padding_mask", -1),
        UnsqueezeTransform("positive_labels", -1),
        GroupTransform({"feature_tensors": [item_column]}),
    ]

    val_transforms = [
        RenameTransform({query_column: "query_id", item_mask_column: "padding_mask"}),
        GroupTransform({"feature_tensors": [item_column]}),
    ]
    test_transforms = copy.deepcopy(val_transforms)

    predict_transforms = [
        RenameTransform({query_column: "query_id", item_mask_column: "padding_mask"}),
        # Copy the untrimmed item history to "seen_ids" before trimming
        # (presumably used to filter already-seen items at inference).
        CopyTransform({item_column: "seen_ids"}),
        TrimTransform(seq_len=max_seq_len, feature_names=[item_column, "padding_mask"]),
        GroupTransform({"feature_tensors": [item_column]}),
    ]

    return {
        "train": train_transforms,
        "validate": val_transforms,
        "test": test_transforms,
        "predict": predict_transforms,
    }

In [17]:
transforms = make_sasrec_transforms(tensor_schema=tensor_schema, query_column="user_id")

In [18]:
def create_meta(shape: int, gt_shape: Optional[int] = None):
    """Describe the per-column shapes and padding values for one ParquetModule split.

    Args:
        shape: Length to which the ``item_id`` sequence is padded.
        gt_shape: Length of the ``ground_truth`` column, or ``None`` when the
            split has no ground truth (e.g. training).

    Returns:
        Dict keyed by column name. ``item_id`` is padded with the schema's padding
        value, ``ground_truth`` (if requested) with -1, and ``user_id`` has no entries.
    """
    item_padding = tensor_schema["item_id"].padding_value
    meta = {"user_id": {}, "item_id": {"shape": shape, "padding": item_padding}}
    if gt_shape is None:
        return meta
    meta["ground_truth"] = {"shape": gt_shape, "padding": -1}
    return meta

# Train sequences carry one extra item, consumed by the next-token shift (shift=1).
train_metadata = dict(
    train=create_meta(shape=MAX_SEQ_LEN + 1),
    validate=create_meta(shape=MAX_SEQ_LEN, gt_shape=1),
)

In [19]:
from replay.data.nn import ParquetModule

BATCH_SIZE = 32

# Only the train and validate splits are wired in for fitting.
parquet_module = ParquetModule(
    batch_size=BATCH_SIZE,
    metadata=train_metadata,
    transforms=transforms,
    train_path=TRAIN_PATH,
    validate_path=VAL_PATH,
)

  parquet_module = ParquetModule(


## Train model
### Create SasRec model instance and run the training stage using lightning
We may now train the model using the Lightning trainer class. 

RePlay's implementation of SasRec follows a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is built from fully initialized components, which makes the model more flexible and easier to extend.

#### Default Configuration

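A default SasRec model can be created quickly via the **from_params** method. The default model instance uses CE loss and the original SasRec transformer layers, and it aggregates embeddings via sum. In this example we instead assemble the blocks by hand in the `make_sasrec` helper below, which uses a sampled CE loss (`CESampled`) together with the negatives drawn by the training transforms.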

In [20]:
from replay.nn.sequential import SasRec
from typing import Literal
def make_sasrec(
    schema: TensorSchema,
    embedding_dim: int = 192,
    num_heads: int = 4,
    num_blocks: int = 2,
    max_sequence_length: int = 50,
    dropout: float = 0.3,
    excluded_features: Optional[list[str]] = None,
    categorical_list_feature_aggregation_method: Literal["sum", "mean", "max"] = "sum",
) -> SasRec:
    """Assemble a SasRec model from explicit blocks, trained with sampled cross-entropy.

    Args:
        schema: Tensor schema describing the model inputs; its item-id feature
            supplies the padding index that the loss ignores among negatives.
        embedding_dim: Hidden size of the aggregator, the transformer and the output
            LayerNorm. Presumably must match the schema features' ``embedding_dim``
            — TODO confirm.
        num_heads: Number of attention heads (also given to the attention-mask builder).
        num_blocks: Number of transformer blocks.
        max_sequence_length: Maximum sequence length passed to the position-aware aggregator.
        dropout: Dropout rate for the aggregator and the transformer.
        excluded_features: Extra feature names to leave out of the embedding, in
            addition to the schema's query-id and timestamp features.
        categorical_list_feature_aggregation_method: How ``SequenceEmbedding``
            aggregates categorical-list features.

    Returns:
        A ``SasRec`` model whose loss is ``CESampled``.
    """
    from replay.nn.agg import SumAggregator
    from replay.nn.embedding import SequenceEmbedding
    from replay.nn.loss import CESampled
    from replay.nn.mask import DefaultAttentionMask
    from replay.nn.sequential.sasrec import SasRecBody
    from replay.nn.sequential.sasrec.agg import PositionAwareAggregator
    from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer

    # Query-id and timestamp features are never embedded. The schema may not define
    # them (this example's schema only has item_id); presumably the properties then
    # return None, so such entries are dropped. De-duplicate in a deterministic order.
    requested_exclusions = [
        schema.query_id_feature_name,
        schema.timestamp_feature_name,
        *(excluded_features or []),
    ]
    excluded_features = list(
        dict.fromkeys(name for name in requested_exclusions if name is not None)
    )

    body = SasRecBody(
        embedder=SequenceEmbedding(
            schema=schema,
            categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,
            excluded_features=excluded_features,
        ),
        embedding_aggregator=PositionAwareAggregator(
            embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),
            max_sequence_length=max_sequence_length,
            dropout=dropout,
        ),
        attn_mask_builder=DefaultAttentionMask(
            reference_feature_name=schema.item_id_feature_name,
            num_heads=num_heads,
        ),
        encoder=SasRecTransformerLayer(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            num_blocks=num_blocks,
            dropout=dropout,
            activation="relu",
        ),
        output_normalization=torch.nn.LayerNorm(embedding_dim),
    )
    # Sampled negatives equal to the padding id must not contribute to the loss.
    padding_idx = schema.item_id_features.item().padding_value
    return SasRec(
        body=body,
        loss=CESampled(negative_labels_ignore_index=padding_idx),
    )

In [21]:
from replay.nn.sequential import SasRec

# Architecture hyperparameters for this example.
NUM_BLOCKS, NUM_HEADS, DROPOUT = 2, 2, 0.3

sasrec = make_sasrec(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    num_blocks=NUM_BLOCKS,
    num_heads=NUM_HEADS,
    max_sequence_length=MAX_SEQ_LEN,
    dropout=DROPOUT,
)

A universal PyTorch Lightning module is provided. It can work with any NN model.

In [22]:
from replay.nn.lightning.optimizer import OptimizerFactory
from replay.nn.lightning.scheduler import LRSchedulerFactory
from replay.nn.lightning import LightningModule

# Wrap the model with the default optimizer and learning-rate scheduler factories.
optimizer_factory = OptimizerFactory()
lr_scheduler_factory = LRSchedulerFactory()
model = LightningModule(
    sasrec,
    optimizer_factory=optimizer_factory,
    lr_scheduler_factory=lr_scheduler_factory,
)

To facilitate training, we add the following callbacks:
1) `ModelCheckpoint` - to save the best trained model based on its recall@10 metric. It's a default Lightning callback.
2) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. It's a custom RePlay callback for computing recsys metrics on validation and test stages.


In [23]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from replay.nn.lightning.callback import ComputeMetricsCallback


# Keep only the single best checkpoint, ranked by validation recall@10 (higher is better).
checkpoint_callback = ModelCheckpoint(
    dirpath="sasrec/checkpoints",
    monitor="recall@10",
    mode="max",
    save_top_k=1,
    verbose=True,
)

# Report map/ndcg/recall at several cutoffs after each validation epoch.
validation_metrics_callback = ComputeMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=NUM_UNIQUE_ITEMS,
)

csv_logger = CSVLogger(name="SasRec-example", save_dir="sasrec/logs/train")

trainer = L.Trainer(
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
    max_epochs=100,
)

trainer.fit(model, datamodule=parquet_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode  | FLOPs
-------------------------------------------------
0 | model | SasRec | 291 K  | train | 0    
-------------------------------------------------
291 K     Trainable params
0         Non-trainable params
291 K     Total params
1.164     Total estimated model params size (MB)
39        Modules in train mode
0         Modules in eval mode
0         Total Flops


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Training: |          | 0/? [00:00<?, ?it/s]

/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 24. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 189: 'recall@10' reached 0.01672 (best 0.01672), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1


k              1        10        20         5
map     0.001821  0.004612  0.005541  0.003356
ndcg    0.001821  0.007332  0.010639  0.004158
recall  0.001821  0.016725  0.029641  0.006624





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 378: 'recall@10' reached 0.02120 (best 0.02120), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1


k              1        10        20         5
map     0.002153  0.006318  0.007965  0.004968
ndcg    0.002153  0.009738  0.015896  0.006501
recall  0.002153  0.021196  0.045869  0.011260





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 567: 'recall@10' reached 0.03527 (best 0.03527), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1


k              1        10        20         5
map     0.003974  0.011496  0.013657  0.009411
ndcg    0.003974  0.016992  0.025222  0.011983
recall  0.003974  0.035271  0.068554  0.019871





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 756: 'recall@10' reached 0.05498 (best 0.05498), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=3-step=756.ckpt' as top 1


k              1        10        20         5
map     0.006789  0.018064  0.020656  0.014989
ndcg    0.006789  0.026587  0.036257  0.019078
recall  0.006789  0.054976  0.093724  0.031628





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 945: 'recall@10' reached 0.06590 (best 0.06590), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt' as top 1


k             1        10        20         5
map     0.00977  0.021511  0.025053  0.017536
ndcg    0.00977  0.031639  0.044647  0.021818
recall  0.00977  0.065905  0.117569  0.035105





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 1134: 'recall@10' reached 0.07567 (best 0.07567), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=5-step=1134.ckpt' as top 1


k             1        10        20         5
map     0.00828  0.023807  0.027919  0.019051
ndcg    0.00828  0.035731  0.051102  0.024120
recall  0.00828  0.075675  0.137274  0.039576





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 1323: 'recall@10' reached 0.09174 (best 0.09174), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=6-step=1323.ckpt' as top 1


k              1        10        20         5
map     0.010267  0.028863  0.033646  0.023646
ndcg    0.010267  0.043356  0.061162  0.030624
recall  0.010267  0.091737  0.162941  0.052161





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 1512: 'recall@10' reached 0.10267 (best 0.10267), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=7-step=1512.ckpt' as top 1


k              1        10        20         5
map     0.011591  0.032261  0.037579  0.026050
ndcg    0.011591  0.048441  0.068205  0.033239
recall  0.011591  0.102666  0.181653  0.055307





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 1701: 'recall@10' reached 0.11542 (best 0.11542), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=8-step=1701.ckpt' as top 1


k              1        10        20         5
map     0.013744  0.036616  0.042138  0.029635
ndcg    0.013744  0.054690  0.074906  0.037490
recall  0.013744  0.115416  0.195562  0.061600





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 1890: 'recall@10' reached 0.11724 (best 0.11724), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=9-step=1890.ckpt' as top 1


k              1        10        20         5
map     0.013413  0.037181  0.043070  0.030046
ndcg    0.013413  0.055596  0.077283  0.038208
recall  0.013413  0.117238  0.203511  0.063256





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 2079: 'recall@10' reached 0.13181 (best 0.13181), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=10-step=2079.ckpt' as top 1


k             1        10        20         5
map     0.01275  0.040170  0.045926  0.032458
ndcg    0.01275  0.061269  0.082412  0.042270
recall  0.01275  0.131810  0.215764  0.072363





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 2268: 'recall@10' reached 0.13247 (best 0.13247), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=11-step=2268.ckpt' as top 1


k              1        10        20         5
map     0.014903  0.042531  0.048871  0.035185
ndcg    0.014903  0.063251  0.086801  0.044995
recall  0.014903  0.132472  0.226528  0.075012





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 2457: 'recall@10' reached 0.13264 (best 0.13264), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=12-step=2457.ckpt' as top 1


k              1        10        20         5
map     0.014241  0.041087  0.047694  0.033104
ndcg    0.014241  0.062106  0.086651  0.042425
recall  0.014241  0.132638  0.230667  0.071038





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 2646: 'recall@10' reached 0.14506 (best 0.14506), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=13-step=2646.ckpt' as top 1


k              1        10        20         5
map     0.016062  0.045372  0.051773  0.036979
ndcg    0.016062  0.068309  0.091740  0.047681
recall  0.016062  0.145057  0.237953  0.080642





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14, global step 2835: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016062  0.045387  0.052134  0.037357
ndcg    0.016062  0.067864  0.092717  0.048324
recall  0.016062  0.142739  0.241596  0.082133





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15, global step 3024: 'recall@10' reached 0.14820 (best 0.14820), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=15-step=3024.ckpt' as top 1


k              1        10        20         5
map     0.016725  0.047788  0.053965  0.039670
ndcg    0.016725  0.071002  0.093764  0.051187
recall  0.016725  0.148203  0.238781  0.086604





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16, global step 3213: 'recall@10' reached 0.14986 (best 0.14986), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=16-step=3213.ckpt' as top 1


k              1        10        20         5
map     0.014572  0.045006  0.051598  0.036043
ndcg    0.014572  0.069084  0.093440  0.046948
recall  0.014572  0.149859  0.246895  0.080477





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17, global step 3402: 'recall@10' reached 0.15565 (best 0.15565), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=17-step=3402.ckpt' as top 1


k              1        10        20         5
map     0.017718  0.050106  0.056686  0.041599
ndcg    0.017718  0.074460  0.098556  0.053471
recall  0.017718  0.155655  0.251201  0.089916





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18, global step 3591: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020699  0.051415  0.058355  0.042913
ndcg    0.020699  0.075000  0.100671  0.054110
recall  0.020699  0.153833  0.256168  0.088591





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19, global step 3780: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01689  0.048864  0.055906  0.040572
ndcg    0.01689  0.072380  0.098151  0.051951
recall  0.01689  0.150687  0.252856  0.086769





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 20, global step 3969: 'recall@10' reached 0.16079 (best 0.16079), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=20-step=3969.ckpt' as top 1


k              1        10        20         5
map     0.015897  0.048521  0.055597  0.038541
ndcg    0.015897  0.074276  0.100118  0.049739
recall  0.015897  0.160788  0.263123  0.084120





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 21, global step 4158: 'recall@10' reached 0.16443 (best 0.16443), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=21-step=4158.ckpt' as top 1


k              1        10        20         5
map     0.017387  0.050823  0.057855  0.041370
ndcg    0.017387  0.076978  0.102714  0.053782
recall  0.017387  0.164431  0.266435  0.092068





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 22, global step 4347: 'recall@10' reached 0.16658 (best 0.16658), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=22-step=4347.ckpt' as top 1


k              1        10        20         5
map     0.017553  0.051560  0.059105  0.041450
ndcg    0.017553  0.078076  0.105823  0.053591
recall  0.017553  0.166584  0.276867  0.090909





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23, global step 4536: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.049276  0.056962  0.040059
ndcg    0.016725  0.074350  0.102687  0.051703
recall  0.016725  0.158139  0.270906  0.087432





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 24, global step 4725: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.049812  0.057295  0.039640
ndcg    0.015731  0.076213  0.103773  0.051387
recall  0.015731  0.164597  0.274218  0.087432





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 25, global step 4914: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020864  0.052779  0.060587  0.043743
ndcg    0.020864  0.077838  0.106745  0.055498
recall  0.020864  0.161782  0.277033  0.091737





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26, global step 5103: 'recall@10' reached 0.16774 (best 0.16774), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=26-step=5103.ckpt' as top 1


k              1        10        20         5
map     0.021196  0.055582  0.063022  0.046120
ndcg    0.021196  0.081459  0.108908  0.058285
recall  0.021196  0.167743  0.277033  0.095546





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 27, global step 5292: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.052605  0.059901  0.042987
ndcg    0.018049  0.079059  0.105941  0.055416
recall  0.018049  0.167412  0.274383  0.093559





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 28, global step 5481: 'recall@10' reached 0.17221 (best 0.17221), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=28-step=5481.ckpt' as top 1


k             1        10        20         5
map     0.01954  0.054475  0.061714  0.044651
ndcg    0.01954  0.081605  0.108225  0.057536
recall  0.01954  0.172214  0.278026  0.097202





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 29, global step 5670: 'recall@10' reached 0.17238 (best 0.17238), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=29-step=5670.ckpt' as top 1


k              1        10        20         5
map     0.019208  0.054595  0.061755  0.044731
ndcg    0.019208  0.081765  0.108175  0.057686
recall  0.019208  0.172380  0.277529  0.097533





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 30, global step 5859: 'recall@10' reached 0.17602 (best 0.17602), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=30-step=5859.ckpt' as top 1


k              1        10        20         5
map     0.018712  0.055652  0.063132  0.045593
ndcg    0.018712  0.083439  0.111297  0.058785
recall  0.018712  0.176023  0.287465  0.099189





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 31, global step 6048: 'recall@10' reached 0.17735 (best 0.17735), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=31-step=6048.ckpt' as top 1


k              1        10        20         5
map     0.020699  0.055747  0.063169  0.045518
ndcg    0.020699  0.083701  0.111019  0.058567
recall  0.020699  0.177347  0.285975  0.098857





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 32, global step 6237: 'recall@10' reached 0.17950 (best 0.17950), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=32-step=6237.ckpt' as top 1


k              1        10        20         5
map     0.017884  0.054545  0.062291  0.043498
ndcg    0.017884  0.083230  0.111726  0.056082
recall  0.017884  0.179500  0.292764  0.094718





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 33, global step 6426: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017884  0.052739  0.060368  0.042361
ndcg    0.017884  0.080163  0.108373  0.054720
recall  0.017884  0.172048  0.284484  0.092731





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 34, global step 6615: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020036  0.054358  0.062097  0.044072
ndcg    0.020036  0.081625  0.110002  0.056245
recall  0.020036  0.173207  0.285809  0.093724





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 35, global step 6804: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018877  0.055443  0.063010  0.045289
ndcg    0.018877  0.083255  0.111167  0.058415
recall  0.018877  0.176023  0.287134  0.098692





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 36, global step 6993: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.053940  0.061837  0.044116
ndcg    0.018049  0.081692  0.110537  0.057375
recall  0.018049  0.174532  0.288790  0.098195





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 37, global step 7182: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019043  0.055374  0.063095  0.045264
ndcg    0.019043  0.083240  0.111728  0.058378
recall  0.019043  0.176354  0.289783  0.098692





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 38, global step 7371: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019208  0.055258  0.063082  0.044638
ndcg    0.019208  0.083313  0.112068  0.057127
recall  0.019208  0.177347  0.291605  0.095380





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 39, global step 7560: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015565  0.052864  0.060610  0.042386
ndcg    0.015565  0.081476  0.109996  0.055713
recall  0.015565  0.177182  0.290611  0.096705





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 40, global step 7749: 'recall@10' reached 0.18116 (best 0.18116), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=40-step=7749.ckpt' as top 1


k             1        10        20         5
map     0.02103  0.056563  0.064399  0.045742
ndcg    0.02103  0.085155  0.113897  0.058552
recall  0.02103  0.181156  0.295248  0.098029





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 41, global step 7938: 'recall@10' reached 0.18314 (best 0.18314), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=41-step=7938.ckpt' as top 1


k              1        10        20         5
map     0.018546  0.056358  0.063802  0.045766
ndcg    0.018546  0.085569  0.113049  0.059571
recall  0.018546  0.183143  0.292598  0.102004





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 42, global step 8127: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017387  0.055724  0.063385  0.045043
ndcg    0.017387  0.084880  0.112993  0.058761
recall  0.017387  0.182149  0.293757  0.100845





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 43, global step 8316: 'recall@10' was not in top 1


k             1        10        20         5
map     0.02103  0.057552  0.065197  0.047334
ndcg    0.02103  0.085784  0.113905  0.060815
recall  0.02103  0.179997  0.291770  0.102335





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 44, global step 8505: 'recall@10' was not in top 1


k              1        10        20         5
map     0.021527  0.057786  0.065236  0.047086
ndcg    0.021527  0.086560  0.114050  0.060260
recall  0.021527  0.182977  0.292433  0.100845





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 45, global step 8694: 'recall@10' was not in top 1


k             1        10        20         5
map     0.02103  0.058859  0.066416  0.048670
ndcg    0.02103  0.086934  0.114670  0.062123
recall  0.02103  0.180328  0.290445  0.103328





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 46, global step 8883: 'recall@10' was not in top 1


k              1        10        20         5
map     0.021196  0.058640  0.066437  0.048281
ndcg    0.021196  0.086654  0.115355  0.061452
recall  0.021196  0.179997  0.294088  0.101838





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 47, global step 9072: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017884  0.055311  0.062937  0.044994
ndcg    0.017884  0.084131  0.112311  0.058815
recall  0.017884  0.180328  0.292598  0.101341





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 48, global step 9261: 'recall@10' was not in top 1


k             1        10        20         5
map     0.02252  0.059221  0.066950  0.048993
ndcg    0.02252  0.087484  0.115818  0.062538
recall  0.02252  0.181818  0.294254  0.104322





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 49, global step 9450: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019374  0.054636  0.062777  0.043766
ndcg    0.019374  0.083018  0.112939  0.056363
recall  0.019374  0.178341  0.297235  0.095214





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 50, global step 9639: 'recall@10' was not in top 1


k              1        10        20         5
map     0.021527  0.057515  0.065345  0.047411
ndcg    0.021527  0.085670  0.114708  0.060848
recall  0.021527  0.179831  0.295744  0.102335





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 51, global step 9828: 'recall@10' reached 0.18381 (best 0.18381), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=51-step=9828.ckpt' as top 1


k              1        10        20         5
map     0.016228  0.053971  0.061898  0.042791
ndcg    0.016228  0.083820  0.112816  0.056346
recall  0.016228  0.183805  0.298725  0.098029





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 52, global step 10017: 'recall@10' reached 0.18397 (best 0.18397), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=52-step=10017.ckpt' as top 1


k              1        10        20         5
map     0.017221  0.055084  0.062864  0.043782
ndcg    0.017221  0.084683  0.113217  0.056843
recall  0.017221  0.183971  0.297235  0.096870





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 53, global step 10206: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.051400  0.059755  0.040995
ndcg    0.016228  0.079003  0.109795  0.053598
recall  0.016228  0.171386  0.293923  0.092399





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 54, global step 10395: 'recall@10' reached 0.18977 (best 0.18977), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=54-step=10395.ckpt' as top 1


k              1        10        20         5
map     0.016062  0.055016  0.062803  0.043291
ndcg    0.016062  0.085978  0.114660  0.057208
recall  0.016062  0.189767  0.303858  0.100017





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 55, global step 10584: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.055603  0.063455  0.044312
ndcg    0.016559  0.086089  0.115264  0.058382
recall  0.016559  0.188111  0.304686  0.101672





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 56, global step 10773: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.056666  0.065019  0.046120
ndcg    0.015897  0.086568  0.117252  0.060768
recall  0.015897  0.185958  0.307832  0.105647





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 57, global step 10962: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014075  0.052912  0.060822  0.041563
ndcg    0.014075  0.083098  0.112322  0.055303
recall  0.014075  0.184136  0.300546  0.097533





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 58, global step 11151: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.056294  0.064118  0.045352
ndcg    0.017553  0.086382  0.115034  0.059558
recall  0.017553  0.186951  0.300546  0.103328





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 59, global step 11340: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020202  0.058796  0.066610  0.048250
ndcg    0.020202  0.088402  0.116893  0.062635
recall  0.020202  0.187117  0.299884  0.106971





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 60, global step 11529: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015565  0.055105  0.062987  0.043956
ndcg    0.015565  0.085279  0.114271  0.057872
recall  0.015565  0.186124  0.301374  0.100513





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 61, global step 11718: 'recall@10' reached 0.18993 (best 0.18993), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=61-step=11718.ckpt' as top 1


k              1        10        20         5
map     0.019705  0.058739  0.066515  0.046732
ndcg    0.019705  0.088924  0.117293  0.059843
recall  0.019705  0.189932  0.302202  0.100017





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 62, global step 11907: 'recall@10' was not in top 1


k             1        10        20         5
map     0.02103  0.057541  0.065188  0.046760
ndcg    0.02103  0.086474  0.114707  0.059940
recall  0.02103  0.183474  0.295910  0.100513





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 63, global step 12096: 'recall@10' was not in top 1


k            1        10        20         5
map     0.0154  0.054492  0.062930  0.044019
ndcg    0.0154  0.084395  0.115347  0.058577
recall  0.0154  0.184136  0.307004  0.103328





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 64, global step 12285: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.055696  0.063894  0.044972
ndcg    0.017056  0.085546  0.115785  0.059192
recall  0.017056  0.185296  0.305680  0.102997





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 65, global step 12474: 'recall@10' reached 0.19010 (best 0.19010), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=65-step=12474.ckpt' as top 1


k              1        10        20         5
map     0.020368  0.058863  0.067156  0.047817
ndcg    0.020368  0.089067  0.119288  0.061949
recall  0.020368  0.190098  0.309654  0.105481





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 66, global step 12663: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017387  0.056398  0.064511  0.045289
ndcg    0.017387  0.086971  0.116779  0.059761
recall  0.017387  0.189104  0.307501  0.104322





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 67, global step 12852: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.054063  0.062111  0.043294
ndcg    0.016559  0.083069  0.112713  0.056888
recall  0.016559  0.179831  0.297731  0.098692





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 68, global step 13041: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014738  0.054147  0.062518  0.043084
ndcg    0.014738  0.084522  0.115173  0.057486
recall  0.014738  0.185958  0.307501  0.101838





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 69, global step 13230: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014572  0.053841  0.062056  0.042375
ndcg    0.014572  0.084694  0.114874  0.056624
recall  0.014572  0.187945  0.307832  0.100513





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 70, global step 13419: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018215  0.057374  0.065360  0.045949
ndcg    0.018215  0.087195  0.116514  0.059416
recall  0.018215  0.186786  0.303196  0.100679





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 71, global step 13608: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015069  0.053430  0.062299  0.042800
ndcg    0.015069  0.082754  0.115075  0.056689
recall  0.015069  0.180659  0.308495  0.099354





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 72, global step 13797: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015069  0.051947  0.060668  0.041122
ndcg    0.015069  0.081285  0.113252  0.054642
recall  0.015069  0.179666  0.306508  0.096374





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 73, global step 13986: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.055226  0.063822  0.043970
ndcg    0.017056  0.084768  0.116083  0.057631
recall  0.017056  0.183308  0.307170  0.099685





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 74, global step 14175: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.054868  0.063210  0.043945
ndcg    0.016559  0.084754  0.115396  0.057854
recall  0.016559  0.184799  0.306508  0.100679





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 75, global step 14364: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018712  0.056079  0.064476  0.045388
ndcg    0.018712  0.085060  0.116212  0.058990
recall  0.018712  0.181818  0.306177  0.100845





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 76, global step 14553: 'recall@10' reached 0.19208 (best 0.19208), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=76-step=14553.ckpt' as top 1


k             1        10        20         5
map     0.01689  0.058268  0.066477  0.047345
ndcg    0.01689  0.089166  0.119288  0.062269
recall  0.01689  0.192085  0.311641  0.107965





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 77, global step 14742: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.056835  0.065016  0.045179
ndcg    0.017221  0.087301  0.117220  0.058947
recall  0.017221  0.189104  0.307667  0.101176





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 78, global step 14931: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019043  0.057327  0.065519  0.045703
ndcg    0.019043  0.087922  0.118015  0.059380
recall  0.019043  0.190594  0.310151  0.101507





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 79, global step 15120: 'recall@10' was not in top 1


k            1        10        20         5
map     0.0154  0.055800  0.064412  0.044053
ndcg    0.0154  0.086530  0.118361  0.057957
recall  0.0154  0.189104  0.315946  0.100513





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 80, global step 15309: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.057478  0.066083  0.046705
ndcg    0.017553  0.087633  0.119265  0.061103
recall  0.017553  0.188276  0.313959  0.105315





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 81, global step 15498: 'recall@10' reached 0.19242 (best 0.19242), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=81-step=15498.ckpt' as top 1


k              1        10        20         5
map     0.018546  0.058395  0.066565  0.047334
ndcg    0.018546  0.089227  0.119153  0.061846
recall  0.018546  0.192416  0.311144  0.106475





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 82, global step 15687: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018877  0.057494  0.066217  0.046859
ndcg    0.018877  0.087218  0.119189  0.061111
recall  0.018877  0.186455  0.313297  0.104984





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 83, global step 15876: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.055004  0.063232  0.043536
ndcg    0.015731  0.085670  0.115784  0.057410
recall  0.015731  0.188442  0.307832  0.100017





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 84, global step 16065: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.056402  0.064621  0.045176
ndcg    0.016559  0.087195  0.117405  0.059646
recall  0.016559  0.190098  0.310151  0.104156





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 85, global step 16254: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.056227  0.064425  0.045134
ndcg    0.018049  0.085875  0.115869  0.058678
recall  0.018049  0.185130  0.304024  0.100348





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 86, global step 16443: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.056478  0.064728  0.044767
ndcg    0.017056  0.087622  0.117876  0.058963
recall  0.017056  0.191919  0.311972  0.102666





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 87, global step 16632: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016393  0.055006  0.063239  0.043501
ndcg    0.016393  0.085265  0.115385  0.057249
recall  0.016393  0.186455  0.305845  0.099520





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 88, global step 16821: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.054969  0.063319  0.043793
ndcg    0.015731  0.085875  0.116473  0.058445
recall  0.015731  0.189270  0.310647  0.103660





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 89, global step 17010: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.056066  0.064136  0.045137
ndcg    0.016228  0.086528  0.116174  0.059815
recall  0.016228  0.188111  0.305845  0.104984





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 90, global step 17199: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018546  0.056280  0.064094  0.045786
ndcg    0.018546  0.085766  0.114293  0.059990
recall  0.018546  0.184302  0.297235  0.103825





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 91, global step 17388: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014738  0.055372  0.063358  0.044185
ndcg    0.014738  0.086260  0.115443  0.058977
recall  0.014738  0.189270  0.304852  0.104488





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 92, global step 17577: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014241  0.054296  0.062669  0.042915
ndcg    0.014241  0.085070  0.115748  0.057182
recall  0.014241  0.187945  0.309654  0.101010





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 93, global step 17766: 'recall@10' reached 0.19258 (best 0.19258), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=93-step=17766.ckpt' as top 1


k              1        10        20         5
map     0.014738  0.055029  0.063401  0.043363
ndcg    0.014738  0.086642  0.117437  0.057966
recall  0.014738  0.192582  0.314953  0.102997





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 94, global step 17955: 'recall@10' reached 0.19424 (best 0.19424), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=94-step=17955.ckpt' as top 1


k              1        10        20         5
map     0.016062  0.057125  0.065547  0.045703
ndcg    0.016062  0.088716  0.119766  0.060754
recall  0.016062  0.194237  0.317768  0.107137





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 95, global step 18144: 'recall@10' reached 0.19540 (best 0.19540), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=95-step=18144.ckpt' as top 1


k              1        10        20         5
map     0.014903  0.057034  0.065065  0.045899
ndcg    0.014903  0.088995  0.118507  0.061824
recall  0.014903  0.195397  0.312635  0.110946





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 96, global step 18333: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018712  0.059276  0.067432  0.048427
ndcg    0.018712  0.090034  0.119863  0.063524
recall  0.018712  0.192416  0.310647  0.109952





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 97, global step 18522: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.058283  0.066351  0.046426
ndcg    0.017221  0.089517  0.119293  0.060523
recall  0.017221  0.193906  0.312469  0.103660





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 98, global step 18711: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019705  0.059661  0.067876  0.048173
ndcg    0.019705  0.090404  0.120122  0.062469
recall  0.019705  0.193078  0.310151  0.106475





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 99, global step 18900: 'recall@10' was not in top 1
`Trainer.fit` stopped: `max_epochs=100` reached.


k              1        10        20         5
map     0.016062  0.055792  0.064140  0.044362
ndcg    0.016062  0.086294  0.117054  0.058292
recall  0.016062  0.188276  0.310647  0.101010



Now we can retrieve the path to the best checkpoint from the checkpoint callback.

In [24]:
# Path to the checkpoint with the best monitored metric (here `recall@10`).
best_model_path = checkpoint_callback.best_model_path
# `ModelCheckpoint.best_model_path` stays an empty string when no checkpoint was ever
# saved; fail here with a clear message instead of an obscure error in `load_from_checkpoint`.
if not best_model_path:
    raise RuntimeError(
        "The checkpoint callback did not save any checkpoint; "
        "check the monitored metric and the training run before running inference."
    )
best_model_path

'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=95-step=18144.ckpt'

## Inference

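To obtain model scores, we load the weights from the best checkpoint. To do this, we rebuild the model and pass it, together with the checkpoint path, to `LightningModule.load_from_checkpoint`. Before loading, the optimizer and LR-scheduler factory classes are registered with `torch.serialization.add_safe_globals`, so that PyTorch is allowed to unpickle them from the checkpoint.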

In [25]:
import replay

# Rebuild the network so the checkpoint weights have modules to load into.
# NOTE(review): these hyperparameters must match the ones used for training,
# otherwise the state dict will not fit — presumably the same constants; confirm.
sasrec = make_sasrec(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
    excluded_features=None,
)

# Classes pickled into the checkpoint (presumably as module hyperparameters) must be
# allow-listed so torch can unpickle them when the checkpoint is loaded.
checkpoint_safe_globals = [
    replay.nn.lightning.optimizer.OptimizerFactory,
    replay.nn.lightning.scheduler.LRSchedulerFactory,
]
torch.serialization.add_safe_globals(checkpoint_safe_globals)

best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)
best_model.eval();

Configure the `ParquetModule` for inference.

In [28]:
# Length of the longest item_id history in the test events; used as the target
# shape of the item_id sequences at predict time.
# NOTE(review): assumes the data at PREDICT_PATH is the same as `test_events` — confirm.
item_id_len = pl.from_arrow(test_events)["item_id"].list.len().max()
item_padding_value = tensor_schema["item_id"].padding_value

inference_metadata = {
    "predict": {
        "user_id": {},
        "item_id": {
            "shape": item_id_len,
            "padding": item_padding_value,
        },
    },
}

parquet_module = ParquetModule(
    predict_path=PREDICT_PATH,
    batch_size=BATCH_SIZE,
    metadata=inference_metadata,
    transforms=transforms,
)

inference_metadata

  parquet_module = ParquetModule(


{'predict': {'user_id': {}, 'item_id': {'shape': 2313, 'padding': 3706}}}

During inference, we can use a `TopItemsCallback`. This callback scores every item in the catalog for each user and returns, as recommendations, the ids of the items with the highest scores.


Recommendations can be fetched in four formats: PySpark DataFrame, pandas DataFrame, Polars DataFrame, or raw PyTorch tensors. Each format has its own callback. In this example, we use the `PandasTopItemsCallback`.

In [29]:
from replay.nn.lightning.callback import PandasTopItemsCallback
from replay.nn.lightning.postprocessor import SeenItemsFilter

# Cutoffs used both for the size of the recommendation lists and for the metrics below.
TOPK = [1, 5, 10, 20]

# Post-processing applied to the scores before ranking; presumably masks the items
# each user has already seen (read from the "seen_ids" column) — see SeenItemsFilter docs.
seen_items_filter = SeenItemsFilter(item_count=NUM_UNIQUE_ITEMS, seen_items_column="seen_ids")
post_processors = [seen_items_filter]

# Collects the top max(TOPK) items per user into a pandas DataFrame.
pandas_prediction_callback = PandasTopItemsCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=post_processors,
)

csv_logger = CSVLogger(save_dir="sasrec/logs/test", name="SasRec-example")

trainer = L.Trainer(
    callbacks=[pandas_prediction_callback],
    logger=csv_logger,
    inference_mode=True,
)
# Predictions are gathered by the callback, so the trainer does not need to return them.
trainer.predict(best_model, datamodule=parquet_module, return_predictions=False)

pandas_res = pandas_prediction_callback.get_result()

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Predicting: |          | 0/? [00:00<?, ?it/s]



In [30]:
# Top-`max(TOPK)` recommendations per user, still in encoded id space (one row per user-item pair).
pandas_res

Unnamed: 0,user_id,item_id,score
0,0,354,10.633862
0,0,1551,9.885081
0,0,1900,9.788022
0,0,1906,9.603766
0,0,1897,9.412257
...,...,...,...
6037,6039,2884,5.024795
6037,6039,284,5.020481
6037,6039,1502,4.998537
6037,6039,1245,4.976641


### Calculating metrics

*test_gt* is already encoded, so it can be used directly to compute metrics.

In [31]:
from replay.metrics import MAP, OfflineMetrics, Precision, Recall
from replay.metrics.torch_metrics_builder import metrics_to_df

In [32]:
# `item_id` in test_gt appears to be list-valued; explode it to one row per (user, item).
test_gt_flat = test_gt.explode("item_id")

# Recall, Precision and MAP at every cutoff in TOPK, matching rows on user_id and
# ranking recommendations by their "score" column.
offline_metric_list = [Recall(TOPK), Precision(TOPK), MAP(TOPK)]
offline_metrics = OfflineMetrics(
    offline_metric_list,
    query_column="user_id",
    rating_column="score",
)
result_metrics = offline_metrics(pandas_res, test_gt_flat)

In [33]:
# Tabulate the computed metrics as a DataFrame (rows: metrics, columns: cutoff k).
# NOTE(review): the k columns come out as 1, 10, 20, 5 — apparently not ordered numerically.
metrics_to_df(result_metrics)

k,1,10,20,5
MAP,0.057469,0.108167,0.115451,0.096022
Precision,0.057469,0.026267,0.018408,0.034283
Recall,0.057469,0.26267,0.368168,0.171414


Finally, we call the encoder's `inverse_transform` method to map the encoded user and item ids back to the original ones and get the final dataframe with recommendations.

In [34]:
# Map encoded user/item ids back to the original MovieLens ids; scores are left unchanged.
encoder.inverse_transform(pandas_res)

Unnamed: 0,user_id,item_id,score
0,1,364,10.633862
0,1,1688,9.885081
0,1,2081,9.788022
0,1,2087,9.603766
0,1,2078,9.412257
...,...,...,...
6037,6040,3100,5.024795
6037,6040,293,5.020481
6037,6040,1639,4.998537
6037,6040,1343,4.976641
