# Example of SasRec training/inference

In [1]:
from typing import Optional

import lightning as L
import pandas as pd

# Seed the random number generators through Lightning so that repeated runs are comparable.
L.seed_everything(42)

Seed set to 42


42

## Preparing data
In this example, we will be using the MovieLens dataset, namely the 1m subset. This is a deliberately simple case, so only item ids will be used as model input.

---
**NOTE**

The current implementation of SasRec is able to handle item and interaction features. It does not take user features into account.

---

In [2]:
# Load the raw ratings. The file is parsed as tab-separated and column names are supplied
# explicitly through `names`.
interactions = pd.read_csv("./data/ml1m_ratings.dat", sep="\t", names=["user_id", "item_id","rating","timestamp"])
# Ratings are not used in this example, so the column is dropped.
interactions = interactions.drop(columns=["rating"])

In [3]:
# Replace raw timestamps with a per-user ordinal position:
# sort all events chronologically, then number every user's events 0, 1, 2, ...
interactions["timestamp"] = interactions["timestamp"].astype("int64")
interactions = interactions.sort_values(by="timestamp")
# cumcount() numbers rows within each user group in their current (time-sorted) order.
interactions["timestamp"] = interactions.groupby("user_id").cumcount()
interactions

Unnamed: 0,user_id,item_id,timestamp
1000138,6040,858,0
1000153,6040,2384,1
999873,6040,593,2
1000007,6040,1961,3
1000192,6040,2019,4
...,...,...,...
825793,4958,2399,446
825438,4958,1407,447
825724,4958,3264,448
825731,4958,2634,449


### Encode categorical data
To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorical column in the dataset that will be used in the model. Note that ids of users and ids of items are always used.

In [4]:
from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule

# One encoding rule per categorical column used by the model (user and item ids).
# NOTE(review): default_value="last" presumably controls how labels unseen during fit are
# encoded - confirm against the LabelEncodingRule documentation.
encoder = LabelEncoder(
    [
        LabelEncodingRule("user_id", default_value="last"),
        LabelEncodingRule("item_id", default_value="last"),
    ]
)
# Rows are ordered by item_id before fitting.
# NOTE(review): presumably so that encoded item ids follow the original item-id order - confirm.
interactions = interactions.sort_values(by="item_id", ascending=True)
encoded_interactions = encoder.fit_transform(interactions)
encoded_interactions

Unnamed: 0,timestamp,user_id,item_id
0,12,2011,0
1,68,4078,0
2,67,4123,0
3,12,983,0
4,140,2270,0
...,...,...,...
1000204,14,855,3705
1000205,90,1700,3705
1000206,70,936,3705
1000207,25,360,3705


### Split interactions into the train, validation and test datasets using LastNSplitter
We use the widespread Last-One-Out splitting strategy. We filter out cold items and users for simplicity.

In [5]:
from replay.splitters import LastNSplitter

# Last-One-Out splitter: N=1 interaction per user is held out, and cold users/items
# are dropped from the held-out part.
splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
    drop_cold_users=True,
    drop_cold_items=True
)

# First split: the held-out part becomes the test ground truth.
test_events, test_gt = splitter.split(encoded_interactions)
# Second split, applied to the remaining test events: the held-out part becomes the
# validation ground truth.
validation_events, validation_gt = splitter.split(test_events)
# Training uses the same events that serve as the validation input.
train_events = validation_events

### Dataset preprocessing ("baking")
SasRec expects each user in the batch to provide their events in the form of a sequence. For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay.

In [6]:
from replay.data.nn.utils import groupby_sequences


def bake_data(full_data):
    """Collapse an interaction log into one row per user.

    Delegates to ``groupby_sequences``, grouping by ``user_id`` and ordering
    by ``timestamp``, and returns its result unchanged.
    """
    return groupby_sequences(
        events=full_data,
        groupby_col="user_id",
        sort_col="timestamp",
    )

In [7]:
# Convert every split from a flat interaction log into per-user sequences.
# NOTE(review): the names are rebound in place, so re-running this cell feeds already-grouped
# frames back into bake_data - re-run the split cell first.
train_events = bake_data(train_events)

validation_events = bake_data(validation_events)
validation_gt = bake_data(validation_gt)

test_events = bake_data(test_events)
test_gt = bake_data(test_gt)

train_events

Unnamed: 0,user_id,timestamp,item_id
0,0,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2969, 1574, 1178, 957, 2147, 1658, 3177, 1117..."
1,1,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1108, 1127, 1120, 2512, 1201, 2735, 1135, 110..."
2,2,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[579, 2651, 3301, 1788, 1781, 1327, 1174, 3429..."
3,3,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1120, 1025, 466, 3235, 3294, 1106, 253, 1108,..."
4,4,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2512, 858, 847, 346, 1158, 2007, 2651, 1050, ..."
...,...,...,...
6035,6035,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1574, 1703, 3206, 2183, 2235, 2480, 2375, 250..."
6036,6036,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1702, 672, 1175, 1848, 3275, 2932, 548, 802, ..."
6037,6037,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[3165, 859, 1120, 1965, 1288, 346, 1007, 1066,..."
6038,6038,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[107, 275, 1886, 1139, 869, 886, 2872, 2809, 2..."


To ensure we don't have unknown users in ground truth, we join validation events and validation ground truth (also join test events and test ground truth correspondingly) by user ids to leave only the common ones.  

In [8]:
def add_gt_to_events(events_df, gt_df):
    """Attach ground-truth items to the events of the same user.

    The ``item_id`` column of ``gt_df`` is exposed as ``ground_truth`` and
    inner-joined to ``events_df`` on ``user_id``, so users missing from either
    frame are dropped. Neither input frame is modified.
    """
    ground_truth = gt_df.loc[:, ["user_id", "item_id"]].rename(columns={"item_id": "ground_truth"})
    return events_df.merge(ground_truth, on="user_id", how="inner")

# Keep only users present in both the events and the ground truth of each evaluation split.
validation_events = add_gt_to_events(validation_events, validation_gt)
test_events = add_gt_to_events(test_events, test_gt)

In [9]:
from pathlib import Path

# Working directory for the prepared splits; created on demand, relative to the current directory.
data_dir = Path("temp/data/")
data_dir.mkdir(parents=True, exist_ok=True)

# One parquet file per split.
TRAIN_PATH = data_dir / "train.parquet"
VAL_PATH = data_dir / "val.parquet"
PREDICT_PATH = data_dir / "test.parquet"

# Location where the fitted label encoder is stored.
ENCODER_PATH = data_dir / "encoder"

In [10]:
# Persist the prepared splits so they can be read back from disk by path.
train_events.to_parquet(TRAIN_PATH)
validation_events.to_parquet(VAL_PATH)
test_events.to_parquet(PREDICT_PATH)

# Store the fitted encoder next to the data.
encoder.save(ENCODER_PATH)



# Prepare for model training
### Create the tensor schema
A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the NN models to correctly create embeddings for every source column. Note that `user_id` is not required in `TensorSchema`.

Note that **cardinality** is the number of unique values in the item catalog (vocabulary). **Padding value** is the next value after the last one.

In [11]:
from replay.data import FeatureHint, FeatureType
from replay.data.nn import TensorFeatureInfo, TensorSchema


# Embedding size of the item feature; the same constant is later passed to the model builder.
EMBEDDING_DIM = 64

# Reload the encoder saved earlier and take the catalog size from its item_id mapping.
encoder = encoder.load(ENCODER_PATH)
NUM_UNIQUE_ITEMS = len(encoder.mapping["item_id"])

# The schema describes a single sequential categorical feature: the item id.
tensor_schema = TensorSchema(
    [
        TensorFeatureInfo(
            name="item_id",
            is_seq=True,
            # Padding uses the next value after the last item index.
            padding_value=NUM_UNIQUE_ITEMS,
            cardinality=NUM_UNIQUE_ITEMS,
            embedding_dim=EMBEDDING_DIM,
            feature_type=FeatureType.CATEGORICAL,
            feature_hint=FeatureHint.ITEM_ID,
        )
    ]
)

In [12]:
import torch
import polars as pl

### Configure ParquetModule and transformation pipelines

The `ParquetModule` class enables training of models on large datasets by reading data in a batch-wise way. This class is initialized with **paths to every data split, a metadata dict containing information about the shape and padding value of every column, and a dict of transforms**. `ParquetModule`'s "transform pipelines" are stage-specific modules implementing additional preprocessing to be performed at the batch level right before the forward pass.

For SasRec model, RePlay provides a function that generates a sequence of appropriate transforms for each data split named **make_default_sasrec_transforms**.

Internally this function creates the following transforms:
1) Training:
    1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).
    2. Rename features to match the format expected by the model during training.
    3. Unsqueeze the target (*positive_labels*) and its padding mask (*target_padding_mask*) to get the shape of these tensors required for loss computation.
    4. Group the input features to be embedded into the expected format.

2) Validation/Inference:
    1. Rename/group features to match the format expected by the model during validation/inference.

If a different set of transforms is required, you can create them yourself and pass them to the `ParquetModule` in the form of a dictionary where the key is the name of the split and the value is the list of transforms. Available transforms are located in `replay/nn/transform/`.

**Note:** One of the transforms for the training data prepares the initial sequence for the task of Next Item Prediction so it shifts the sequence of items. For the final sequence length to be correct, you need to set shape of item_id in metadata as **model sequence length + shift**. Default shift value is 1.

In [13]:
import copy

import torch

from replay.data.nn import TensorSchema
from replay.nn.transform import *

In [14]:
# Sequence length shared by the model, the metadata shapes and the predict-time trimming.
MAX_SEQ_LEN = 50

def make_sasrec_transforms(
    tensor_schema: TensorSchema,
    query_column: str = "query_id",
    num_negative_samples: int = 128,
    max_sequence_length: Optional[int] = None,
) -> dict[str, list[torch.nn.Module]]:
    """Build the per-stage batch transform pipelines for SasRec.

    :param tensor_schema: schema providing the item feature name and its cardinality.
    :param query_column: name of the query (user) id column in the incoming batches;
        it is renamed to ``"query_id"`` in every pipeline.
    :param num_negative_samples: number of negatives passed to
        ``ThresholdNegativeSamplingTransform`` in the training pipeline.
    :param max_sequence_length: length passed to ``TrimTransform`` in the predict pipeline.
        ``None`` (default) falls back to the module-level ``MAX_SEQ_LEN`` at call time.
    :return: dict with the keys ``"train"``, ``"validate"``, ``"test"`` and ``"predict"``,
        each mapping to a list of transforms.
    """
    if max_sequence_length is None:
        max_sequence_length = MAX_SEQ_LEN

    item_column = tensor_schema.item_id_feature_name
    vocab_size = tensor_schema[item_column].cardinality
    item_mask_column = f"{item_column}_mask"

    train_transforms = [
        ThresholdNegativeSamplingTransform(vocab_size, num_negative_samples),
        # Shift the item sequence by one position to obtain next-item labels.
        NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),
        RenameTransform(
            {
                query_column: "query_id",
                item_mask_column: "padding_mask",
                "positive_labels_mask": "target_padding_mask",
            }
        ),
        UnsqueezeTransform("target_padding_mask", -1),
        UnsqueezeTransform("positive_labels", -1),
        GroupTransform({"feature_tensors": [item_column]}),
    ]

    val_transforms = [
        RenameTransform({query_column: "query_id", item_mask_column: "padding_mask"}),
        GroupTransform({"feature_tensors": [item_column]}),
    ]
    # Independent copy so the test stage does not share module instances with validation.
    test_transforms = copy.deepcopy(val_transforms)

    # The predict pipeline uses the schema's item column name, like the other pipelines,
    # instead of a hard-coded "item_id".
    predict_transforms = [
        RenameTransform({query_column: "query_id", item_mask_column: "padding_mask"}),
        # Keep an untrimmed copy of the item sequence under "seen_ids".
        CopyTransform({item_column: "seen_ids"}),
        TrimTransform(seq_len=max_sequence_length, feature_names=[item_column, "padding_mask"]),
        GroupTransform({"feature_tensors": [item_column]}),
    ]

    return {
        "train": train_transforms,
        "validate": val_transforms,
        "test": test_transforms,
        "predict": predict_transforms,
    }

In [15]:
# The query id column in the prepared parquet files is named "user_id".
transforms = make_sasrec_transforms(tensor_schema, query_column="user_id")

In [16]:
def create_meta(shape: int, gt_shape: Optional[int] = None):
    """Describe the columns of one data split for ``ParquetModule``.

    :param shape: shape recorded for the ``item_id`` column.
    :param gt_shape: shape recorded for the ``ground_truth`` column; when ``None``
        the column is left out of the metadata.
    :return: metadata dict keyed by column name. ``item_id`` is padded with the
        padding value taken from the global ``tensor_schema``; ``ground_truth``
        is padded with ``-1``.
    """
    item_padding = tensor_schema["item_id"].padding_value
    meta = {"user_id": {}, "item_id": {"shape": shape, "padding": item_padding}}
    if gt_shape is None:
        return meta

    meta["ground_truth"] = {"shape": gt_shape, "padding": -1}
    return meta

# The training sequences are one element longer than the model input because the training
# pipeline shifts the item sequence by one to build next-item labels.
train_metadata = {
    "train": create_meta(shape=MAX_SEQ_LEN+1),
    "validate": create_meta(shape=MAX_SEQ_LEN, gt_shape=1),
}

In [17]:
from replay.data.nn import ParquetModule

# Number of user sequences per batch.
BATCH_SIZE = 32

# Data module for the fit stage: only the train and validation splits are wired in here.
parquet_module = ParquetModule(
    train_path=TRAIN_PATH,
    validate_path=VAL_PATH,
    batch_size=BATCH_SIZE,
    metadata=train_metadata,
    transforms=transforms,
)

  parquet_module = ParquetModule(


## Train model
### Create SasRec model instance and run the training stage using lightning
We may now train the model using the Lightning trainer class. 

RePlay's implementation of SasRec is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is now built by providing fully initialized components that makes the model more flexible and easier to extend.

#### Default Configuration

A default SasRec model may be created quickly via the **from_params** method. The default model instance has CE loss and the original SasRec transformer layers, and its embeddings are aggregated via sum.

In [18]:
from replay.nn.sequential import SasRec
from typing import Literal
def make_sasrec(
    schema: TensorSchema,
    embedding_dim: int = 192,
    num_heads: int = 4,
    num_blocks: int = 2,
    max_sequence_length: int = 50,
    dropout: float = 0.3,
    excluded_features: Optional[list[str]] = None,
    categorical_list_feature_aggregation_method: Literal["sum", "mean", "max"] = "sum",
) -> SasRec:
    """Assemble a SasRec model from explicit building blocks, using the ``CESampled`` loss.

    :param schema: tensor schema describing the input features.
    :param embedding_dim: size passed to the aggregator, the transformer layers and the
        output ``LayerNorm``.
        NOTE(review): presumably must equal the embedding size declared in ``schema`` - confirm.
    :param num_heads: number of attention heads (used by the attention mask and the encoder).
    :param num_blocks: number of transformer blocks in the encoder.
    :param max_sequence_length: maximum sequence length passed to the position-aware aggregator.
    :param dropout: dropout rate shared by the aggregator and the encoder.
    :param excluded_features: extra feature names to exclude from embedding, in addition to
        the schema's query id and timestamp feature names.
    :param categorical_list_feature_aggregation_method: aggregation mode forwarded to
        ``SequenceEmbedding``.
    :return: a ``SasRec`` instance wrapping the assembled body and a ``CESampled`` loss.
    """
    # Each name is imported once and only if it is used below.
    from replay.nn.agg import SumAggregator
    from replay.nn.embedding import SequenceEmbedding
    from replay.nn.loss import CESampled
    from replay.nn.mask import DefaultAttentionMask
    from replay.nn.sequential.sasrec import SasRecBody, SasRecTransformerLayer
    from replay.nn.sequential.sasrec.agg import PositionAwareAggregator

    # Query id and timestamp features are always excluded; duplicates are removed.
    excluded_features = [
        schema.query_id_feature_name,
        schema.timestamp_feature_name,
        *(excluded_features or []),
    ]
    excluded_features = list(set(excluded_features))

    # Components are created in the same order as the body's arguments so that seeded
    # parameter initialisation is unaffected.
    embedder = SequenceEmbedding(
        schema=schema,
        categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,
        excluded_features=excluded_features,
    )
    embedding_aggregator = PositionAwareAggregator(
        embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),
        max_sequence_length=max_sequence_length,
        dropout=dropout,
    )
    attn_mask_builder = DefaultAttentionMask(
        reference_feature_name=schema.item_id_feature_name,
        num_heads=num_heads,
    )
    encoder = SasRecTransformerLayer(
        embedding_dim=embedding_dim,
        num_heads=num_heads,
        num_blocks=num_blocks,
        dropout=dropout,
        activation="relu",
    )
    body = SasRecBody(
        embedder=embedder,
        embedding_aggregator=embedding_aggregator,
        attn_mask_builder=attn_mask_builder,
        encoder=encoder,
        output_normalization=torch.nn.LayerNorm(embedding_dim),
    )

    # The item padding value is passed to the loss as the index to ignore among negative labels.
    padding_idx = schema.item_id_features.item().padding_value
    return SasRec(
        body=body,
        loss=CESampled(negative_labels_ignore_index=padding_idx),
    )

In [19]:
from replay.nn.sequential import SasRec

# Model hyperparameters for this example.
NUM_BLOCKS = 2
NUM_HEADS = 2
DROPOUT = 0.3

# embedding_dim and max_sequence_length reuse the constants used for the tensor schema
# and the data metadata.
sasrec = make_sasrec(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
)

A universal PyTorch Lightning module is provided. It can work with any NN model.

In [20]:
from replay.nn.lightning.optimizer import OptimizerFactory
from replay.nn.lightning.scheduler import LRSchedulerFactory
from replay.nn.lightning import LightningModule

# Wrap the network in RePlay's Lightning module; both factories are created with their defaults.
model = LightningModule(
    sasrec,
    optimizer_factory=OptimizerFactory(),
    lr_scheduler_factory=LRSchedulerFactory(),
)

To facilitate training, we add the following callbacks:
1) `ModelCheckpoint` - to save the best trained model based on its Recall metric. It's a default Lightning Callback.
2) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. It's a custom RePlay callback for computing recsys metrics on validation and test stages.


In [21]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from replay.nn.lightning.callback import ComputeMetricsCallback


# Keep only the single best checkpoint, ranked by the highest validation recall@10.
checkpoint_callback = ModelCheckpoint(
    dirpath="sasrec/checkpoints",
    save_top_k=1,
    verbose=True,
    monitor="recall@10",
    mode="max",
)

# Compute ranking metrics at several cut-offs; the monitored "recall@10" above corresponds
# to the "recall" metric at k=10.
validation_metrics_callback = ComputeMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=NUM_UNIQUE_ITEMS,
)

# Training logs are written as CSV files under sasrec/logs/train.
csv_logger = CSVLogger(save_dir="sasrec/logs/train", name="SasRec-example")

trainer = L.Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

# Runs training, with validation driven by the data module's validation split.
trainer.fit(model, datamodule=parquet_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/nkulikov/RePlay/examples/sasrec/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode  | FLOPs
-------------------------------------------------
0 | model | SasRec | 291 K  | train | 0    
-------------------------------------------------
291 K     Trainable params
0         Non-trainable params
291 K     Total params
1.164     Total estimated model

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Training: |          | 0/? [00:00<?, ?it/s]

/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 24. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 189: 'recall@10' reached 0.02600 (best 0.02600), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1


k              1        10        20         5
map     0.003809  0.008641  0.009838  0.007170
ndcg    0.003809  0.012624  0.017009  0.009015
recall  0.003809  0.025998  0.043385  0.014738





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 378: 'recall@10' reached 0.04123 (best 0.04123), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1


k              1        10        20         5
map     0.004471  0.013045  0.015481  0.010802
ndcg    0.004471  0.019540  0.028636  0.013965
recall  0.004471  0.041232  0.077662  0.023679





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 567: 'recall@10' reached 0.06160 (best 0.06160), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1


k              1        10        20         5
map     0.005796  0.018991  0.022229  0.015640
ndcg    0.005796  0.028848  0.040790  0.020623
recall  0.005796  0.061600  0.109124  0.035933





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 756: 'recall@10' reached 0.09306 (best 0.09306), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=3-step=756.ckpt' as top 1


k              1        10        20         5
map     0.011591  0.029450  0.033816  0.023668
ndcg    0.011591  0.044022  0.060207  0.029842
recall  0.011591  0.093062  0.157642  0.048849





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 945: 'recall@10' reached 0.10432 (best 0.10432), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt' as top 1


k              1        10        20         5
map     0.014406  0.034102  0.038966  0.027943
ndcg    0.014406  0.050178  0.068167  0.034932
recall  0.014406  0.104322  0.176023  0.056466





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 1134: 'recall@10' reached 0.11856 (best 0.11856), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=5-step=1134.ckpt' as top 1


k              1        10        20         5
map     0.015069  0.039028  0.044421  0.032660
ndcg    0.015069  0.057359  0.077222  0.041590
recall  0.015069  0.118563  0.197549  0.069051





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 1323: 'recall@10' reached 0.12850 (best 0.12850), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=6-step=1323.ckpt' as top 1


k              1        10        20         5
map     0.013578  0.039725  0.045572  0.032113
ndcg    0.013578  0.060170  0.081669  0.041615
recall  0.013578  0.128498  0.213943  0.070873





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 1512: 'recall@10' reached 0.13628 (best 0.13628), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=7-step=1512.ckpt' as top 1


k              1        10        20         5
map     0.013744  0.041970  0.048133  0.034056
ndcg    0.013744  0.063683  0.086377  0.044196
recall  0.013744  0.136281  0.226528  0.075344





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 1701: 'recall@10' reached 0.14290 (best 0.14290), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=8-step=1701.ckpt' as top 1


k            1        10        20         5
map     0.0154  0.043551  0.049658  0.034815
ndcg    0.0154  0.066376  0.089052  0.045156
recall  0.0154  0.142904  0.233482  0.077165





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 1890: 'recall@10' reached 0.14771 (best 0.14771), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=9-step=1890.ckpt' as top 1


k              1        10        20         5
map     0.015897  0.045990  0.052525  0.037691
ndcg    0.015897  0.069422  0.093374  0.049018
recall  0.015897  0.147707  0.242755  0.083954





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 2079: 'recall@10' reached 0.15665 (best 0.15665), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=10-step=2079.ckpt' as top 1


k              1        10        20         5
map     0.015069  0.047555  0.053994  0.038290
ndcg    0.015069  0.072657  0.096298  0.049866
recall  0.015069  0.156648  0.250538  0.085445





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 2268: 'recall@10' reached 0.15913 (best 0.15913), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=11-step=2268.ckpt' as top 1


k              1        10        20         5
map     0.015069  0.048480  0.055059  0.039286
ndcg    0.015069  0.073992  0.098522  0.051433
recall  0.015069  0.159132  0.257327  0.088756





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 2457: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016062  0.049002  0.055715  0.040219
ndcg    0.016062  0.074403  0.099077  0.052980
recall  0.016062  0.158967  0.256996  0.092399





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 2646: 'recall@10' reached 0.16261 (best 0.16261), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=13-step=2646.ckpt' as top 1


k              1        10        20         5
map     0.013413  0.048255  0.055170  0.039085
ndcg    0.013413  0.074641  0.100032  0.052075
recall  0.013413  0.162610  0.263454  0.092068





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14, global step 2835: 'recall@10' reached 0.16725 (best 0.16725), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=14-step=2835.ckpt' as top 1


k            1        10        20         5
map     0.0154  0.049686  0.056544  0.039457
ndcg    0.0154  0.076734  0.101941  0.051788
recall  0.0154  0.167246  0.267428  0.089750





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15, global step 3024: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014075  0.048757  0.056135  0.038823
ndcg    0.014075  0.075707  0.102860  0.051404
recall  0.014075  0.165756  0.273721  0.090081





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16, global step 3213: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014738  0.049146  0.056984  0.039479
ndcg    0.014738  0.075811  0.104528  0.052089
recall  0.014738  0.164928  0.278854  0.090909





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17, global step 3402: 'recall@10' reached 0.16923 (best 0.16923), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=17-step=3402.ckpt' as top 1


k              1        10        20         5
map     0.016393  0.052278  0.059972  0.042968
ndcg    0.016393  0.079308  0.107657  0.056539
recall  0.016393  0.169233  0.282000  0.098361





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18, global step 3591: 'recall@10' reached 0.17139 (best 0.17139), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=18-step=3591.ckpt' as top 1


k              1        10        20         5
map     0.014406  0.052097  0.059954  0.042593
ndcg    0.014406  0.079708  0.108702  0.056460
recall  0.014406  0.171386  0.286802  0.099023





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19, global step 3780: 'recall@10' reached 0.17453 (best 0.17453), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=19-step=3780.ckpt' as top 1


k             1        10        20         5
map     0.01689  0.052701  0.060324  0.042212
ndcg    0.01689  0.080767  0.108776  0.055247
recall  0.01689  0.174532  0.285809  0.095380





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 20, global step 3969: 'recall@10' reached 0.17619 (best 0.17619), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=20-step=3969.ckpt' as top 1


k              1        10        20         5
map     0.016228  0.052931  0.060778  0.042471
ndcg    0.016228  0.081350  0.110303  0.055893
recall  0.016228  0.176188  0.291439  0.097202





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 21, global step 4158: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.053332  0.061150  0.043012
ndcg    0.016725  0.081196  0.109833  0.055971
recall  0.016725  0.174201  0.287796  0.095711





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 22, global step 4347: 'recall@10' reached 0.17652 (best 0.17652), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=22-step=4347.ckpt' as top 1


k              1        10        20         5
map     0.016228  0.053808  0.061980  0.043641
ndcg    0.016228  0.082110  0.112207  0.057134
recall  0.016228  0.176519  0.296241  0.098526





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23, global step 4536: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.052708  0.060814  0.042110
ndcg    0.015731  0.081050  0.110902  0.054980
recall  0.015731  0.175857  0.294585  0.094386





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 24, global step 4725: 'recall@10' reached 0.18761 (best 0.18761), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=24-step=4725.ckpt' as top 1


k              1        10        20         5
map     0.014738  0.054548  0.062254  0.043492
ndcg    0.014738  0.085244  0.113799  0.058275
recall  0.014738  0.187614  0.301540  0.103825





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 25, global step 4914: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.054281  0.062168  0.043440
ndcg    0.017056  0.083965  0.112965  0.057493
recall  0.017056  0.183308  0.298559  0.101010





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26, global step 5103: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.055259  0.063404  0.045032
ndcg    0.016725  0.084200  0.114158  0.059252
recall  0.016725  0.180493  0.299553  0.102997





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 27, global step 5292: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.056451  0.064289  0.045410
ndcg    0.017056  0.086645  0.115474  0.059782
recall  0.017056  0.187283  0.301871  0.103991





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 28, global step 5481: 'recall@10' reached 0.18795 (best 0.18795), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=28-step=5481.ckpt' as top 1


k              1        10        20         5
map     0.018215  0.057782  0.065521  0.046923
ndcg    0.018215  0.087802  0.116287  0.061192
recall  0.018215  0.187945  0.301209  0.104984





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 29, global step 5670: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.057289  0.065069  0.046972
ndcg    0.018049  0.086925  0.115655  0.061922
recall  0.018049  0.185296  0.299718  0.107965





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 30, global step 5859: 'recall@10' reached 0.19059 (best 0.19059), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=30-step=5859.ckpt' as top 1


k             1        10        20         5
map     0.01689  0.057703  0.065550  0.046741
ndcg    0.01689  0.088384  0.117442  0.061577
recall  0.01689  0.190594  0.306508  0.107137





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 31, global step 6048: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020864  0.058413  0.066243  0.047638
ndcg    0.020864  0.087791  0.116748  0.061294
recall  0.020864  0.186124  0.301540  0.103328





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 32, global step 6237: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01689  0.055734  0.063912  0.044660
ndcg    0.01689  0.085570  0.115649  0.058616
recall  0.01689  0.185130  0.304686  0.101507





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 33, global step 6426: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015069  0.056112  0.064240  0.045145
ndcg    0.015069  0.086726  0.116827  0.060128
recall  0.015069  0.188442  0.308495  0.106143





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 34, global step 6615: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.056694  0.064550  0.045628
ndcg    0.016559  0.087253  0.116160  0.060150
recall  0.016559  0.189270  0.304189  0.104819





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 35, global step 6804: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.057012  0.065202  0.045725
ndcg    0.017553  0.086973  0.116986  0.059551
recall  0.017553  0.186951  0.306011  0.102004





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 36, global step 6993: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.056379  0.064274  0.044778
ndcg    0.015897  0.087131  0.116005  0.058833
recall  0.015897  0.189767  0.304189  0.101838





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 37, global step 7182: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014406  0.055410  0.063578  0.044444
ndcg    0.014406  0.086106  0.115953  0.059166
recall  0.014406  0.188442  0.306673  0.104322





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 38, global step 7371: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01689  0.055842  0.064284  0.044961
ndcg    0.01689  0.086090  0.117241  0.059424
recall  0.01689  0.187117  0.311144  0.103991





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 39, global step 7560: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.055978  0.064505  0.044952
ndcg    0.016559  0.086321  0.117774  0.059351
recall  0.016559  0.187614  0.312800  0.103660





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 40, global step 7749: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.055956  0.064227  0.044894
ndcg    0.015731  0.086860  0.117139  0.059826
recall  0.015731  0.189932  0.309985  0.105812





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 41, global step 7938: 'recall@10' reached 0.19424 (best 0.19424), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=41-step=7938.ckpt' as top 1


k              1        10        20         5
map     0.016725  0.058682  0.066594  0.047483
ndcg    0.016725  0.089991  0.119298  0.062600
recall  0.016725  0.194237  0.311144  0.108958





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 42, global step 8127: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018877  0.058953  0.066755  0.047949
ndcg    0.018877  0.089925  0.118736  0.063012
recall  0.018877  0.193244  0.307998  0.109455





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 43, global step 8316: 'recall@10' reached 0.19540 (best 0.19540), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=43-step=8316.ckpt' as top 1


k              1        10        20         5
map     0.016228  0.057911  0.065847  0.046236
ndcg    0.016228  0.089614  0.118732  0.061085
recall  0.016228  0.195397  0.310979  0.106640





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 44, global step 8505: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01689  0.057262  0.065342  0.045758
ndcg    0.01689  0.088125  0.117919  0.060112
recall  0.01689  0.191091  0.309654  0.104156





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 45, global step 8694: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017718  0.056795  0.065017  0.045548
ndcg    0.017718  0.087349  0.117596  0.059593
recall  0.017718  0.189767  0.309985  0.102832





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 46, global step 8883: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018877  0.058306  0.066419  0.046677
ndcg    0.018877  0.089214  0.118988  0.060662
recall  0.018877  0.192747  0.310979  0.103660





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 47, global step 9072: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019043  0.059471  0.067703  0.048537
ndcg    0.019043  0.089729  0.120014  0.063102
recall  0.019043  0.190429  0.310813  0.107799





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 48, global step 9261: 'recall@10' reached 0.19954 (best 0.19954), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=48-step=9261.ckpt' as top 1


k              1        10        20         5
map     0.018712  0.059989  0.067811  0.047690
ndcg    0.018712  0.092095  0.121105  0.062028
recall  0.018712  0.199536  0.315284  0.105978





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 49, global step 9450: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.057403  0.065244  0.046156
ndcg    0.015897  0.088994  0.117689  0.061298
recall  0.015897  0.194403  0.308164  0.107799





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 50, global step 9639: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.058007  0.066137  0.046639
ndcg    0.016559  0.089610  0.119465  0.061701
recall  0.016559  0.195065  0.313628  0.107965





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 51, global step 9828: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017884  0.060096  0.068180  0.048579
ndcg    0.017884  0.091844  0.121474  0.063693
recall  0.017884  0.197715  0.315284  0.110118





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 52, global step 10017: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018546  0.060044  0.067995  0.048615
ndcg    0.018546  0.091711  0.120849  0.063863
recall  0.018546  0.197218  0.312800  0.110780





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 53, global step 10206: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01689  0.056941  0.064776  0.045681
ndcg    0.01689  0.087978  0.117187  0.060459
recall  0.01689  0.191588  0.308495  0.105978





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 54, global step 10395: 'recall@10' was not in top 1


k              1        10        20         5
map     0.013082  0.055586  0.063704  0.044064
ndcg    0.013082  0.088118  0.117823  0.059769
recall  0.013082  0.196721  0.314456  0.108130





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 55, global step 10584: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.058587  0.066732  0.046975
ndcg    0.015731  0.090632  0.120670  0.062228
recall  0.015731  0.197384  0.316940  0.108958





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 56, global step 10773: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.060027  0.068222  0.048245
ndcg    0.018049  0.091745  0.121621  0.063010
recall  0.018049  0.197549  0.315781  0.108296





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 57, global step 10962: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01689  0.058853  0.067070  0.047533
ndcg    0.01689  0.090009  0.120080  0.062335
recall  0.01689  0.193741  0.312966  0.107634





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 58, global step 11151: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.058854  0.066957  0.047312
ndcg    0.017056  0.090743  0.120615  0.062533
recall  0.017056  0.197052  0.315946  0.109290





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 59, global step 11340: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014738  0.057655  0.065564  0.045846
ndcg    0.014738  0.090272  0.119410  0.061386
recall  0.014738  0.199040  0.314953  0.109124





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 60, global step 11529: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014075  0.056750  0.065363  0.045719
ndcg    0.014075  0.088539  0.120289  0.061469
recall  0.014075  0.194237  0.320583  0.109786





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 61, global step 11718: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016062  0.057128  0.065330  0.045341
ndcg    0.016062  0.088542  0.118705  0.059647
recall  0.016062  0.193575  0.313462  0.103494





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 62, global step 11907: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.057562  0.065397  0.046625
ndcg    0.015897  0.089010  0.118054  0.061937
recall  0.015897  0.193906  0.309820  0.108958





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 63, global step 12096: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016393  0.058018  0.066210  0.046465
ndcg    0.016393  0.089751  0.119847  0.061338
recall  0.016393  0.195728  0.315284  0.106971





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 64, global step 12285: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01954  0.060372  0.068813  0.048435
ndcg    0.01954  0.091389  0.122318  0.062281
recall  0.01954  0.195065  0.317768  0.104653





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 65, global step 12474: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019374  0.060983  0.069356  0.050050
ndcg    0.019374  0.091877  0.122525  0.065187
recall  0.019374  0.194569  0.316112  0.111608





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 66, global step 12663: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015565  0.059624  0.067524  0.047905
ndcg    0.015565  0.091913  0.121071  0.063243
recall  0.015565  0.199371  0.315450  0.110118





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 67, global step 12852: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.060001  0.068111  0.048617
ndcg    0.017221  0.091938  0.121630  0.063923
recall  0.017221  0.198377  0.316112  0.110780





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 68, global step 13041: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01689  0.059211  0.067240  0.047403
ndcg    0.01689  0.091487  0.120996  0.062539
recall  0.01689  0.199205  0.316443  0.108958





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 69, global step 13230: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015069  0.057983  0.066004  0.046432
ndcg    0.015069  0.090190  0.119443  0.061845
recall  0.015069  0.197549  0.313297  0.109124





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 70, global step 13419: 'recall@10' was not in top 1


k            1        10        20         5
map     0.0154  0.058463  0.066697  0.046241
ndcg    0.0154  0.091022  0.121225  0.061375
recall  0.0154  0.199536  0.319424  0.107799





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 71, global step 13608: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.058550  0.066770  0.046801
ndcg    0.016228  0.090833  0.121076  0.062164
recall  0.016228  0.198543  0.318761  0.109455





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 72, global step 13797: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.059072  0.067356  0.047847
ndcg    0.016725  0.090961  0.121313  0.063478
recall  0.016725  0.197052  0.317437  0.111442





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 73, global step 13986: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016062  0.057886  0.066128  0.046843
ndcg    0.016062  0.089835  0.120213  0.062757
recall  0.016062  0.196225  0.317105  0.111773





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 74, global step 14175: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014075  0.056452  0.064627  0.045272
ndcg    0.014075  0.088459  0.118374  0.060992
recall  0.014075  0.195065  0.313628  0.109290





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 75, global step 14364: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.057871  0.066263  0.046506
ndcg    0.015897  0.089796  0.120637  0.062177
recall  0.015897  0.196059  0.318596  0.110449





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 76, global step 14553: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.057843  0.066188  0.046128
ndcg    0.015731  0.089870  0.120403  0.061518
recall  0.015731  0.196556  0.317602  0.108958





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 77, global step 14742: 'recall@10' reached 0.20070 (best 0.20070), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=77-step=14742.ckpt' as top 1


k              1        10        20         5
map     0.015565  0.058428  0.066873  0.046520
ndcg    0.015565  0.091225  0.122243  0.062051
recall  0.015565  0.200695  0.323895  0.109786





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 78, global step 14931: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.058356  0.066861  0.047097
ndcg    0.015731  0.090471  0.121758  0.062987
recall  0.015731  0.197384  0.321742  0.111939





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 79, global step 15120: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015234  0.058665  0.066969  0.047533
ndcg    0.015234  0.090841  0.121459  0.063796
recall  0.015234  0.197549  0.319424  0.113761





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 80, global step 15309: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014241  0.057717  0.066112  0.046848
ndcg    0.014241  0.089724  0.120434  0.063072
recall  0.014241  0.196059  0.317768  0.112933





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 81, global step 15498: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.058715  0.067088  0.047574
ndcg    0.015897  0.090870  0.121661  0.063503
recall  0.015897  0.197880  0.320252  0.112436





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 82, global step 15687: 'recall@10' reached 0.20136 (best 0.20136), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=82-step=15687.ckpt' as top 1


k              1        10        20         5
map     0.018215  0.060399  0.068577  0.048294
ndcg    0.018215  0.092890  0.123219  0.063458
recall  0.018215  0.201358  0.322404  0.110118





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 83, global step 15876: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017884  0.060194  0.068438  0.048913
ndcg    0.017884  0.092251  0.122580  0.064651
recall  0.017884  0.199040  0.319589  0.113098





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 84, global step 16065: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014903  0.057568  0.065958  0.045501
ndcg    0.014903  0.090112  0.120842  0.060656
recall  0.014903  0.198708  0.320583  0.107137





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 85, global step 16254: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014075  0.056964  0.065360  0.045926
ndcg    0.014075  0.089438  0.120420  0.062302
recall  0.014075  0.197549  0.320914  0.112767





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 86, global step 16443: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01391  0.057816  0.065742  0.046271
ndcg    0.01391  0.090904  0.120013  0.062693
recall  0.01391  0.201027  0.316609  0.113264





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 87, global step 16632: 'recall@10' was not in top 1


k              1        10        20         5
map     0.013578  0.057373  0.065635  0.045333
ndcg    0.013578  0.090612  0.121046  0.061214
recall  0.013578  0.201358  0.322404  0.109952





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 88, global step 16821: 'recall@10' was not in top 1


k              1        10        20         5
map     0.013413  0.055390  0.064059  0.044022
ndcg    0.013413  0.087193  0.118988  0.059442
recall  0.013413  0.193078  0.319258  0.106806





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 89, global step 17010: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014406  0.056844  0.065403  0.045808
ndcg    0.014406  0.088346  0.119670  0.061334
recall  0.014406  0.193078  0.317271  0.108958





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 90, global step 17199: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015069  0.058196  0.066419  0.046318
ndcg    0.015069  0.091177  0.121450  0.062288
recall  0.015069  0.201027  0.321411  0.111442





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 91, global step 17388: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014572  0.056919  0.065539  0.044812
ndcg    0.014572  0.089893  0.121401  0.060241
recall  0.014572  0.200199  0.325054  0.107799





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 92, global step 17577: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01275  0.056105  0.064367  0.044282
ndcg    0.01275  0.089401  0.119750  0.060408
recall  0.01275  0.200530  0.321080  0.110118





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 93, global step 17766: 'recall@10' reached 0.20252 (best 0.20252), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=93-step=17766.ckpt' as top 1


k              1        10        20         5
map     0.015234  0.059347  0.067801  0.047287
ndcg    0.015234  0.092435  0.123298  0.063144
recall  0.015234  0.202517  0.324723  0.111773





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 94, global step 17955: 'recall@10' reached 0.20301 (best 0.20301), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=94-step=17955.ckpt' as top 1


k              1        10        20         5
map     0.012585  0.057640  0.065718  0.045631
ndcg    0.012585  0.091231  0.121076  0.061888
recall  0.012585  0.203014  0.321908  0.111773





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 95, global step 18144: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.060236  0.068600  0.048573
ndcg    0.017221  0.092803  0.123586  0.064344
recall  0.017221  0.201192  0.323564  0.112767





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 96, global step 18333: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01391  0.058430  0.067102  0.047061
ndcg    0.01391  0.091398  0.123278  0.063422
recall  0.01391  0.201027  0.327703  0.113595





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 97, global step 18522: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017718  0.059870  0.068399  0.048507
ndcg    0.017718  0.092097  0.123440  0.064068
recall  0.017718  0.199702  0.324226  0.111939





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 98, global step 18711: 'recall@10' reached 0.20351 (best 0.20351), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=98-step=18711-v1.ckpt' as top 1


k              1        10        20         5
map     0.016393  0.060000  0.068277  0.048098
ndcg    0.016393  0.093104  0.123555  0.063841
recall  0.016393  0.203511  0.324557  0.112105





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 99, global step 18900: 'recall@10' was not in top 1
`Trainer.fit` stopped: `max_epochs=100` reached.


k              1        10        20         5
map     0.015731  0.057897  0.066367  0.046677
ndcg    0.015731  0.089911  0.120960  0.062154
recall  0.015731  0.196721  0.319921  0.109621



Now we can get the best model path stored in the checkpoint callback.

In [22]:
# Path of the checkpoint that the checkpoint callback recorded as the best one.
best_model_path = checkpoint_callback.best_model_path
best_model_path

'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=98-step=18711-v1.ckpt'

## Inference

To obtain model scores, we will load the weights from the best checkpoint. To do this, we call `LightningModule.load_from_checkpoint`, passing it the path to the checkpoint and the model instance.

In [23]:
# Import the submodules explicitly: a bare `import replay` does not guarantee that
# `replay.nn.lightning.optimizer` / `.scheduler` are loaded and reachable as
# attributes of the package. These imports still bind the name `replay`.
import replay.nn.lightning.optimizer
import replay.nn.lightning.scheduler

# Build a fresh model instance from the notebook's hyperparameter constants;
# it is handed to `load_from_checkpoint` below together with the checkpoint path.
sasrec = make_sasrec(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
    excluded_features=None,
)

# Register the factory classes as safe globals for torch deserialization
# (presumably they are pickled inside the checkpoint -- confirm against the training cell).
torch.serialization.add_safe_globals([
    replay.nn.lightning.optimizer.OptimizerFactory,
    replay.nn.lightning.scheduler.LRSchedulerFactory,
])

best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)
# Switch to evaluation mode; the trailing `;` suppresses the module repr in the cell output.
best_model.eval();

Configure `ParquetModule` for inference

In [24]:
# Length of the longest `item_id` list in the test events; used as the sequence shape.
item_sequences = pl.from_arrow(test_events)["item_id"]
item_id_len = item_sequences.list.len().max()

# Shape and padding value for the `item_id` feature at prediction time.
item_id_spec = {
    "shape": item_id_len,
    "padding": tensor_schema["item_id"].padding_value,
}
inference_metadata = {
    "predict": {
        "user_id": {},
        "item_id": item_id_spec,
    },
}

parquet_module = ParquetModule(
    predict_path=PREDICT_PATH,
    batch_size=BATCH_SIZE,
    metadata=inference_metadata,
    transforms=transforms,
)

# Show the resulting metadata as the cell output.
inference_metadata

  parquet_module = ParquetModule(


{'predict': {'user_id': {}, 'item_id': {'shape': 2313, 'padding': 3706}}}

During inference, we can use `TopItemsCallback`. This callback lets you get scores for each user across the entire catalog and obtain recommendations as the ids of the items with the highest scores.


Recommendations can be fetched in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame or raw PyTorch tensors. Each format has a corresponding callback. In this example, we'll be using the `PandasTopItemsCallback`.

In [25]:
from replay.nn.lightning.callback import PandasTopItemsCallback
from replay.nn.lightning.postprocessor import SeenItemsFilter

# Post-processing step passed to the prediction callback; it reads the "seen_ids" column.
seen_items_filter = SeenItemsFilter(
    item_count=NUM_UNIQUE_ITEMS,
    seen_items_column="seen_ids",
)
post_processors = [seen_items_filter]

csv_logger = CSVLogger(save_dir="sasrec/logs/test", name="SasRec-example")

# Cutoffs used later for the metrics; the callback is asked for the largest one.
TOPK = [1, 5, 10, 20]

pandas_prediction_callback = PandasTopItemsCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=post_processors,
)

trainer = L.Trainer(
    callbacks=[pandas_prediction_callback],
    logger=csv_logger,
    inference_mode=True,
)
# Predictions are not returned by `predict`; they are read from the callback afterwards.
trainer.predict(best_model, datamodule=parquet_module, return_predictions=False)

pandas_res = pandas_prediction_callback.get_result()

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Predicting: |          | 0/? [00:00<?, ?it/s]



In [26]:
# Display the recommendations collected by the prediction callback.
pandas_res

Unnamed: 0,user_id,item_id,score
0,0,354,12.515966
0,0,1551,11.432517
0,0,2201,11.330788
0,0,1900,11.323733
0,0,1897,10.690818
...,...,...,...
6037,6039,1384,5.390629
6037,6039,346,5.335107
6037,6039,3043,5.333342
6037,6039,15,5.32231


### Calculating metrics

*test_gt* is already encoded, so we can use it for computing metrics.

In [27]:
from replay.metrics import MAP, OfflineMetrics, Precision, Recall
from replay.metrics.torch_metrics_builder import metrics_to_df

In [28]:
# Configure the metric evaluator for every cutoff in TOPK.
offline_metrics = OfflineMetrics(
    [Recall(TOPK), Precision(TOPK), MAP(TOPK)],
    query_column="user_id",
    rating_column="score",
)
# Explode the ground truth on "item_id" before passing it to the evaluator.
test_gt_exploded = test_gt.explode("item_id")
result_metrics = offline_metrics(pandas_res, test_gt_exploded)

In [29]:
# Render the computed metrics as a table for display.
metrics_to_df(result_metrics)

k,1,10,20,5
MAP,0.058629,0.113471,0.121447,0.101673
Precision,0.058629,0.027095,0.019328,0.03637
Recall,0.058629,0.270951,0.386552,0.181848


Let's call the encoder's `inverse_transform` method to get the final dataframe with recommendations.

In [30]:
# Map the encoded ids in the recommendations back to their original values.
encoder.inverse_transform(pandas_res)

Unnamed: 0,user_id,item_id,score
0,1,364,12.515966
0,1,1688,11.432517
0,1,2394,11.330788
0,1,2081,11.323733
0,1,2078,10.690818
...,...,...,...
6037,6040,1500,5.390629
6037,6040,356,5.335107
6037,6040,3267,5.333342
6037,6040,16,5.32231
