# Example of SasRec training/inference

In [1]:
from typing import Optional

import lightning as L
import pandas as pd

L.seed_everything(42)

Seed set to 42


42

## Preparing data
In this example, we will be using the MovieLens dataset, namely the 1m subset. It's demonstrated a simple case, so only item ids will be used as model input.

---
**NOTE**

Current implementation of SasRec is able to handle item and interactions features. It does not take into account user features. 

---

In [2]:
interactions = pd.read_csv("./data/ml1m_ratings.dat", sep="\t", names=["user_id", "item_id","rating","timestamp"])
interactions = interactions.drop(columns=["rating"])

In [3]:
interactions["timestamp"] = interactions["timestamp"].astype("int64")
interactions = interactions.sort_values(by="timestamp")
interactions["timestamp"] = interactions.groupby("user_id").cumcount()
interactions

Unnamed: 0,user_id,item_id,timestamp
1000138,6040,858,0
1000153,6040,2384,1
999873,6040,593,2
1000007,6040,1961,3
1000192,6040,2019,4
...,...,...,...
825793,4958,2399,446
825438,4958,1407,447
825724,4958,3264,448
825731,4958,2634,449


### Encode catagorical data.
To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorcial column in the dataset that will be used in model. Note that ids of users and ids of items are always used.

In [4]:
from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule

encoder = LabelEncoder(
    [
        LabelEncodingRule("user_id", default_value="last"),
        LabelEncodingRule("item_id", default_value="last"),
    ]
)
interactions = interactions.sort_values(by="item_id", ascending=True)
encoded_interactions = encoder.fit_transform(interactions)
encoded_interactions

Unnamed: 0,timestamp,user_id,item_id
0,12,2011,0
1,68,4078,0
2,67,4123,0
3,12,983,0
4,140,2270,0
...,...,...,...
1000204,14,855,3705
1000205,90,1700,3705
1000206,70,936,3705
1000207,25,360,3705


### Split interactions into the train, validation and test datasets using LastNSplitter
We use widespread splitting strategy Last-One-Out. We filter out cold items and users for simplicity.

In [5]:
from replay.splitters import LastNSplitter

splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
    drop_cold_users=True,
    drop_cold_items=True
)

test_events, test_gt = splitter.split(encoded_interactions)
validation_events, validation_gt = splitter.split(test_events)
train_events = validation_events

### Dataset preprocessing ("baking")
SasRec expects each user in the batch to provide their events in form of a sequence. For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay.

In [6]:
from replay.data.nn.utils import groupby_sequences


def bake_data(full_data):
    grouped_interactions = groupby_sequences(events=full_data, groupby_col="user_id", sort_col="timestamp")
    return grouped_interactions

In [7]:
train_events = bake_data(train_events)

validation_events = bake_data(validation_events)
validation_gt = bake_data(validation_gt)

test_events = bake_data(test_events)
test_gt = bake_data(test_gt)

train_events

Unnamed: 0,user_id,timestamp,item_id
0,0,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2969, 1574, 1178, 957, 2147, 1658, 3177, 1117..."
1,1,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1108, 1127, 1120, 2512, 1201, 2735, 1135, 110..."
2,2,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[579, 2651, 3301, 1788, 1781, 1327, 1174, 3429..."
3,3,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1120, 1025, 466, 3235, 3294, 1106, 253, 1108,..."
4,4,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2512, 858, 847, 346, 1158, 2007, 2651, 1050, ..."
...,...,...,...
6035,6035,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1574, 1703, 3206, 2183, 2235, 2480, 2375, 250..."
6036,6036,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1702, 672, 1175, 1848, 3275, 2932, 548, 802, ..."
6037,6037,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[3165, 859, 1120, 1965, 1288, 346, 1007, 1066,..."
6038,6038,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[107, 275, 1886, 1139, 869, 886, 2872, 2809, 2..."


To ensure we don't have unknown users in ground truth, we join validation events and validation ground truth (also join test events and test ground truth correspondingly) by user ids to leave only the common ones.  

In [8]:
def add_gt_to_events(events_df, gt_df):
    gt_to_join = gt_df[["user_id", "item_id"]].rename(columns={"item_id": "ground_truth"})

    events_df = events_df.merge(gt_to_join, on="user_id", how="inner")
    return events_df

validation_events = add_gt_to_events(validation_events, validation_gt)
test_events = add_gt_to_events(test_events, test_gt)

In [9]:
from pathlib import Path

data_dir = Path("temp/data/")
data_dir.mkdir(parents=True, exist_ok=True)

TRAIN_PATH = data_dir / "train.parquet"
VAL_PATH = data_dir / "val.parquet"
PREDICT_PATH = data_dir / "test.parquet"

ENCODER_PATH = data_dir / "encoder"

In [10]:
train_events.to_parquet(TRAIN_PATH)
validation_events.to_parquet(VAL_PATH)
test_events.to_parquet(PREDICT_PATH)

encoder.save(ENCODER_PATH)



# Prepare to model training
### Create the tensor schema
A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the NN models to correctly create embeddings for every source column. Note that user_id does not required in `TensorSchema`.

Note that **cardinality** is the number of unique values ​in the item catalog (vocabulary). **Padding value** is the next value after the last one.

In [11]:
from replay.data import FeatureHint, FeatureType
from replay.data.nn import TensorFeatureInfo, TensorSchema


EMBEDDING_DIM = 64

encoder = encoder.load(ENCODER_PATH)
NUM_UNIQUE_ITEMS = len(encoder.mapping["item_id"])

tensor_schema = TensorSchema(
    [
        TensorFeatureInfo(
            name="item_id",
            is_seq=True,
            padding_value=NUM_UNIQUE_ITEMS,
            cardinality=NUM_UNIQUE_ITEMS,
            embedding_dim=EMBEDDING_DIM,
            feature_type=FeatureType.CATEGORICAL,
            feature_hint=FeatureHint.ITEM_ID,
        )
    ]
)

In [12]:
import torch
import polars as pl

### Configure ParquetModule and transformation pipelines

The `ParquetModule` class enables training of models on large datasets by reading data in batch-wise way. This class initialized with **paths to every data split, a metadata dict containing information about shape and padding value of every column and a dict of transforms**. `ParquetModule`'s  "transform pipelines" are stage-specific modules implementing additional preprocessing to be performed on batch level right before the forward pass.  

For SasRec model, RePlay provides a function that generates a sequence of appropriate transforms for each data split named **make_default_sasrec_transforms**.

Internally this function creates the following transforms:
1) Training:
    1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).
    2. Rename features to match it with expected format by the model during training.
    3. Unsqueeze target (*positive_labels*) and it's padding mask (*target_padding_mask*) for getting required shape of this tensors for loss computation.
    4. Group input features to be embed in expected format.

2) Validation/Inference:
    1. Rename/group features to match it with expected format by the model during valdiation/inference.

If a different set of transforms is required, you can create them yourself and submit them to the ParquetModule in the form of a dictionary where the key is the name of the split, and the value is the list of transforms. Available transforms are in the replay/nn/transforms/.

**Note:** One of the transforms for the training data prepares the initial sequence for the task of Next Item Prediction so it shifts the sequence of items. For the final sequence length to be correct, you need to set shape of item_id in metadata as **model sequence length + shift**. Default shift value is 1.

In [13]:
import copy

import torch

from replay.data.nn import TensorSchema
from replay.nn.transform import *

In [14]:
MAX_SEQ_LEN = 50

def make_sasrec_transforms(
    tensor_schema: TensorSchema, query_column: str = "query_id", num_negative_samples: int = 128,
) -> dict[str, list[torch.nn.Module]]:
    item_column = tensor_schema.item_id_feature_name
    vocab_size = tensor_schema[item_column].cardinality
    train_transforms = [
        FrequencyNegativeSamplingTransform(vocab_size, num_negative_samples),
        NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),
        RenameTransform(
            {
                query_column: "query_id",
                f"{item_column}_mask": "padding_mask",
                "positive_labels_mask": "target_padding_mask",
            }
        ),
        UnsqueezeTransform("target_padding_mask", -1),
        UnsqueezeTransform("positive_labels", -1),
        GroupTransform({"feature_tensors": [item_column]}),
    ]

    val_transforms = [
        RenameTransform({query_column: "query_id", f"{item_column}_mask": "padding_mask"}),
        GroupTransform({"feature_tensors": [item_column]}),
    ]
    test_transforms = copy.deepcopy(val_transforms)

    predict_transforms = [
        RenameTransform({query_column: "query_id", f"{item_column}_mask": "padding_mask"}),
        CopyTransform({"item_id": "seen_ids"}),
        TrimTransform(seq_len=MAX_SEQ_LEN, feature_names=["item_id", "padding_mask"]),
        GroupTransform({"feature_tensors": ["item_id"]})
    ]

    transforms = {
        "train": train_transforms,
        "validate": val_transforms,
        "test": test_transforms,
        "predict": predict_transforms,
    }

    return transforms

In [15]:
transforms = make_sasrec_transforms(tensor_schema, query_column="user_id")

In [16]:
def create_meta(shape: int, gt_shape: Optional[int] = None):
    meta = {
        "user_id": {},
        "item_id": {"shape": shape, "padding": tensor_schema["item_id"].padding_value},
    }
    if gt_shape is not None:
        meta.update({"ground_truth": {"shape": gt_shape, "padding": -1}})

    return meta

train_metadata = {
    "train": create_meta(shape=MAX_SEQ_LEN+1),
    "validate": create_meta(shape=MAX_SEQ_LEN, gt_shape=1),
}

In [17]:
from replay.data.nn import ParquetModule

BATCH_SIZE = 32

parquet_module = ParquetModule(
    train_path=TRAIN_PATH,
    validate_path=VAL_PATH,
    batch_size=BATCH_SIZE,
    metadata=train_metadata,
    transforms=transforms,
)

  parquet_module = ParquetModule(


## Train model
### Create SasRec model instance and run the training stage using lightning
We may now train the model using the Lightning trainer class. 

RePlay's implementation of SasRec is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is now built by providing fully initialized components that makes the model more flexible and easier to extend.

#### Default Configuration

Default SasRec model may be created quickly via method **from_params**. Default model instance has CE loss, original SasRec transformer layes, and embeddings are aggregated via sum.

In [18]:
from replay.nn.sequential import SasRec
from typing import Literal
def make_sasrec(
    schema: TensorSchema,
    embedding_dim: int = 192,
    num_heads: int = 4,
    num_blocks: int = 2,
    max_sequence_length: int = 50,
    dropout: float = 0.3,
    excluded_features: Optional[list[str]] = None,
    categorical_list_feature_aggregation_method: Literal["sum", "mean", "max"] = "sum",
) -> SasRec:
    from replay.nn.sequential.sasrec import SasRecBody, SasRecTransformerLayer
    from replay.nn.agg import SumAggregator
    from replay.nn.embedding import SequenceEmbedding
    from replay.nn.loss import CE, CESampled
    from replay.nn.mask import DefaultAttentionMask
    from replay.nn.sequential.sasrec.agg import PositionAwareAggregator
    from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer
    excluded_features = [
        schema.query_id_feature_name,
        schema.timestamp_feature_name,
        *(excluded_features or []),
    ]
    excluded_features = list(set(excluded_features))
    body = SasRecBody(
        embedder=SequenceEmbedding(
            schema=schema,
            categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,
            excluded_features=excluded_features,
        ),
        embedding_aggregator=PositionAwareAggregator(
            embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),
            max_sequence_length=max_sequence_length,
            dropout=dropout,
        ),
        attn_mask_builder=DefaultAttentionMask(
            reference_feature_name=schema.item_id_feature_name,
            num_heads=num_heads,
        ),
        encoder=SasRecTransformerLayer(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            num_blocks=num_blocks,
            dropout=dropout,
            activation="relu",
        ),
        output_normalization=torch.nn.LayerNorm(embedding_dim),
    )
    padding_idx = schema.item_id_features.item().padding_value
    return SasRec(
        body=body,
        loss=CESampled(negative_labels_ignore_index=padding_idx),
    )

In [19]:
from replay.nn.sequential import SasRec

NUM_BLOCKS = 2
NUM_HEADS = 2
DROPOUT = 0.3

sasrec = make_sasrec(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
)

A universal PyTorch Lightning module is provided. It can work with any NN model.

In [20]:
from replay.nn.lightning.optimizer import OptimizerFactory
from replay.nn.lightning.scheduler import LRSchedulerFactory
from replay.nn.lightning import LightningModule

model = LightningModule(
    sasrec,
    optimizer_factory=OptimizerFactory(),
    lr_scheduler_factory=LRSchedulerFactory(),
)

To facilitate training, we add the following callbacks:
1) `ModelCheckpoint` - to save the best trained model based on its Recall metric. It's a default Lightning Callback.
1) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. It's a custom RePlay callback for computing recsys metrics on validation and test stages.


In [21]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from replay.nn.lightning.callback import ComputeMetricsCallback


checkpoint_callback = ModelCheckpoint(
    dirpath="sasrec/checkpoints",
    save_top_k=1,
    verbose=True,
    monitor="recall@10",
    mode="max",
)

validation_metrics_callback = ComputeMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=NUM_UNIQUE_ITEMS,
)

csv_logger = CSVLogger(save_dir="sasrec/logs/train", name="SasRec-example")

trainer = L.Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

trainer.fit(model, datamodule=parquet_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/nkulikov/RePlay/examples/sasrec/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode  | FLOPs
-------------------------------------------------
0 | model | SasRec | 291 K  | train | 0    
-------------------------------------------------
291 K     Trainable params
0         Non-trainable params
291 K     Total params
1.164     Total estimated model

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Training: |          | 0/? [00:00<?, ?it/s]

/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 24. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 189: 'recall@10' reached 0.01689 (best 0.01689), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1


k              1        10        20         5
map     0.002318  0.005405  0.006534  0.004410
ndcg    0.002318  0.008041  0.012255  0.005600
recall  0.002318  0.016890  0.033780  0.009273





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 378: 'recall@10' reached 0.02120 (best 0.02120), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1


k              1        10        20         5
map     0.002153  0.006546  0.007929  0.005200
ndcg    0.002153  0.009909  0.015046  0.006612
recall  0.002153  0.021196  0.041729  0.010929





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 567: 'recall@10' reached 0.04355 (best 0.04355), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1


k              1        10        20         5
map     0.005464  0.013997  0.016224  0.011288
ndcg    0.005464  0.020772  0.029111  0.014165
recall  0.005464  0.043550  0.077000  0.023017





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 756: 'recall@10' reached 0.05895 (best 0.05895), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=3-step=756.ckpt' as top 1


k              1        10        20         5
map     0.007948  0.019991  0.023329  0.016827
ndcg    0.007948  0.028989  0.041301  0.021203
recall  0.007948  0.058950  0.107965  0.034608





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 945: 'recall@10' reached 0.07783 (best 0.07783), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt' as top 1


k              1        10        20         5
map     0.010432  0.024561  0.028634  0.019697
ndcg    0.010432  0.036718  0.051848  0.024746
recall  0.010432  0.077827  0.138268  0.040404





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 1134: 'recall@10' reached 0.08677 (best 0.08677), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=5-step=1134.ckpt' as top 1


k              1        10        20         5
map     0.008942  0.026988  0.031543  0.021858
ndcg    0.008942  0.040761  0.057457  0.028208
recall  0.008942  0.086769  0.153005  0.047690





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 1323: 'recall@10' reached 0.10432 (best 0.10432), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=6-step=1323.ckpt' as top 1


k              1        10        20         5
map     0.010267  0.030852  0.035828  0.024212
ndcg    0.010267  0.047686  0.066146  0.031392
recall  0.010267  0.104322  0.178010  0.053486





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 1512: 'recall@10' reached 0.10946 (best 0.10946), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=7-step=1512.ckpt' as top 1


k              1        10        20         5
map     0.011426  0.033499  0.038715  0.026793
ndcg    0.011426  0.050930  0.070236  0.034412
recall  0.011426  0.109455  0.186455  0.057791





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 1701: 'recall@10' reached 0.12254 (best 0.12254), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=8-step=1701.ckpt' as top 1


k              1        10        20         5
map     0.010929  0.036465  0.042245  0.028962
ndcg    0.010929  0.056240  0.077519  0.037795
recall  0.010929  0.122537  0.207154  0.064911





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 1890: 'recall@10' reached 0.12568 (best 0.12568), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=9-step=1890.ckpt' as top 1


k              1        10        20         5
map     0.013578  0.038311  0.044473  0.030303
ndcg    0.013578  0.058303  0.080993  0.038598
recall  0.013578  0.125683  0.215930  0.064083





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 2079: 'recall@10' reached 0.12916 (best 0.12916), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=10-step=2079.ckpt' as top 1


k              1        10        20         5
map     0.014738  0.041131  0.047581  0.033736
ndcg    0.014738  0.061412  0.085261  0.043231
recall  0.014738  0.129160  0.224209  0.072363





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 2268: 'recall@10' reached 0.13761 (best 0.13761), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=11-step=2268.ckpt' as top 1


k              1        10        20         5
map     0.016559  0.044533  0.050794  0.036789
ndcg    0.016559  0.065973  0.089088  0.046908
recall  0.016559  0.137606  0.229674  0.077993





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 2457: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014903  0.042670  0.048923  0.034702
ndcg    0.014903  0.064088  0.087250  0.044486
recall  0.014903  0.135784  0.228183  0.074516





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 2646: 'recall@10' reached 0.14936 (best 0.14936), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=13-step=2646.ckpt' as top 1


k             1        10        20         5
map     0.01689  0.046893  0.052981  0.038058
ndcg    0.01689  0.070470  0.092956  0.048899
recall  0.01689  0.149362  0.238947  0.082298





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14, global step 2835: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015234  0.045580  0.052094  0.037258
ndcg    0.015234  0.068618  0.092661  0.048346
recall  0.015234  0.145388  0.241100  0.082464





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15, global step 3024: 'recall@10' reached 0.15185 (best 0.15185), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=15-step=3024.ckpt' as top 1


k              1        10        20         5
map     0.017056  0.048151  0.054541  0.039700
ndcg    0.017056  0.072074  0.095709  0.051391
recall  0.017056  0.151846  0.246067  0.087432





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16, global step 3213: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014075  0.045723  0.052304  0.036598
ndcg    0.014075  0.070079  0.094422  0.047776
recall  0.014075  0.151515  0.248551  0.082133





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17, global step 3402: 'recall@10' reached 0.15416 (best 0.15416), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=17-step=3402.ckpt' as top 1


k             1        10        20         5
map     0.01689  0.049262  0.055872  0.041064
ndcg    0.01689  0.073510  0.097806  0.053362
recall  0.01689  0.154165  0.250704  0.091240





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18, global step 3591: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01954  0.049558  0.056960  0.040730
ndcg    0.01954  0.073217  0.100567  0.051715
recall  0.01954  0.152343  0.261302  0.085610





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19, global step 3780: 'recall@10' was not in top 1


k            1        10        20         5
map     0.0154  0.046710  0.053697  0.038113
ndcg    0.0154  0.070913  0.096729  0.049645
recall  0.0154  0.151846  0.254678  0.085113





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 20, global step 3969: 'recall@10' reached 0.16013 (best 0.16013), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=20-step=3969.ckpt' as top 1


k             1        10        20         5
map     0.01689  0.050035  0.057518  0.040683
ndcg    0.01689  0.075383  0.102806  0.052415
recall  0.01689  0.160126  0.268919  0.088425





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 21, global step 4158: 'recall@10' reached 0.16576 (best 0.16576), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=21-step=4158.ckpt' as top 1


k              1        10        20         5
map     0.019043  0.052303  0.059402  0.042731
ndcg    0.019043  0.078404  0.104491  0.054935
recall  0.019043  0.165756  0.269415  0.092565





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 22, global step 4347: 'recall@10' reached 0.17205 (best 0.17205), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=22-step=4347.ckpt' as top 1


k              1        10        20         5
map     0.016725  0.050426  0.057645  0.039444
ndcg    0.016725  0.078293  0.104894  0.051463
recall  0.016725  0.172048  0.277861  0.088591





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23, global step 4536: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.049394  0.057051  0.039582
ndcg    0.016725  0.074930  0.103146  0.050777
recall  0.016725  0.160623  0.272893  0.085113





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 24, global step 4725: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015234  0.050192  0.057222  0.039877
ndcg    0.015234  0.077541  0.103436  0.052290
recall  0.015234  0.169068  0.272065  0.090412





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 25, global step 4914: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019705  0.051104  0.058872  0.041737
ndcg    0.019705  0.076333  0.104849  0.053286
recall  0.019705  0.160954  0.274218  0.088922





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26, global step 5103: 'recall@10' was not in top 1


k              1        10        20         5
map     0.022686  0.057251  0.064570  0.047787
ndcg    0.022686  0.083397  0.110614  0.060192
recall  0.022686  0.170558  0.279351  0.098195





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 27, global step 5292: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.051860  0.059395  0.042587
ndcg    0.016725  0.078895  0.106353  0.055913
recall  0.016725  0.169233  0.277861  0.097036





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 28, global step 5481: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019705  0.055177  0.062723  0.045352
ndcg    0.019705  0.082132  0.109988  0.058282
recall  0.019705  0.171883  0.282828  0.098029





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 29, global step 5670: 'recall@10' reached 0.17304 (best 0.17304), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=29-step=5670.ckpt' as top 1


k              1        10        20         5
map     0.017221  0.054327  0.061539  0.044516
ndcg    0.017221  0.081798  0.108197  0.057930
recall  0.017221  0.173042  0.277695  0.099023





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 30, global step 5859: 'recall@10' reached 0.18033 (best 0.18033), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=30-step=5859.ckpt' as top 1


k              1        10        20         5
map     0.018712  0.057236  0.064234  0.047196
ndcg    0.018712  0.085694  0.111577  0.061125
recall  0.018712  0.180328  0.283491  0.103825





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 31, global step 6048: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018877  0.056004  0.063683  0.045501
ndcg    0.018877  0.084294  0.112758  0.058508
recall  0.018877  0.178838  0.292433  0.098361





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 32, global step 6237: 'recall@10' reached 0.18281 (best 0.18281), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=32-step=6237.ckpt' as top 1


k              1        10        20         5
map     0.018712  0.056550  0.063924  0.046012
ndcg    0.018712  0.085613  0.112790  0.059607
recall  0.018712  0.182812  0.290942  0.101341





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 33, global step 6426: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019043  0.053122  0.060799  0.041651
ndcg    0.019043  0.080918  0.109027  0.052968
recall  0.019043  0.174532  0.285975  0.087763





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 34, global step 6615: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019871  0.054600  0.062228  0.044428
ndcg    0.019871  0.081519  0.109343  0.056693
recall  0.019871  0.171552  0.281669  0.094386





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 35, global step 6804: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.053350  0.061230  0.043567
ndcg    0.017056  0.081096  0.110064  0.057007
recall  0.017056  0.173704  0.288790  0.098361





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 36, global step 6993: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.053178  0.060830  0.043125
ndcg    0.016725  0.081206  0.109333  0.056627
recall  0.016725  0.174698  0.286471  0.098195





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 37, global step 7182: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018215  0.053892  0.061749  0.043757
ndcg    0.018215  0.081818  0.110890  0.057120
recall  0.018215  0.175029  0.290942  0.098361





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 38, global step 7371: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018381  0.053963  0.062036  0.043832
ndcg    0.018381  0.081539  0.111188  0.056625
recall  0.018381  0.173704  0.291439  0.095877





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 39, global step 7560: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.052240  0.060233  0.042283
ndcg    0.015897  0.079727  0.109054  0.055365
recall  0.015897  0.171386  0.287796  0.095546





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 40, global step 7749: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020202  0.056008  0.064165  0.045844
ndcg    0.020202  0.084239  0.114400  0.059107
recall  0.020202  0.178838  0.299056  0.100017





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 41, global step 7938: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018877  0.055982  0.063657  0.045609
ndcg    0.018877  0.084580  0.112599  0.059116
recall  0.018877  0.180162  0.291108  0.100679





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 42, global step 8127: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.055021  0.062857  0.044516
ndcg    0.017056  0.084316  0.113171  0.058592
recall  0.017056  0.182149  0.296903  0.102004





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 43, global step 8316: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019374  0.056517  0.064224  0.046263
ndcg    0.019374  0.084975  0.113202  0.060004
recall  0.019374  0.179831  0.291770  0.102335





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 44, global step 8505: 'recall@10' reached 0.18381 (best 0.18381), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=44-step=8505.ckpt' as top 1


k              1        10        20         5
map     0.020699  0.058128  0.065556  0.047375
ndcg    0.020699  0.087042  0.114454  0.060630
recall  0.020699  0.183805  0.292929  0.101341





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 45, global step 8694: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020699  0.058643  0.065861  0.048535
ndcg    0.020699  0.086859  0.113508  0.062100
recall  0.020699  0.180825  0.286968  0.103660





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 46, global step 8883: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020036  0.057529  0.065423  0.047345
ndcg    0.020036  0.086086  0.115398  0.061246
recall  0.020036  0.181156  0.298228  0.103991





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 47, global step 9072: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.054992  0.062715  0.044754
ndcg    0.016228  0.083818  0.112415  0.058864
recall  0.016228  0.179666  0.293757  0.102169





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 48, global step 9261: 'recall@10' reached 0.18414 (best 0.18414), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=48-step=9261.ckpt' as top 1


k              1        10        20         5
map     0.021527  0.059328  0.066937  0.049020
ndcg    0.021527  0.088110  0.116131  0.062837
recall  0.021527  0.184136  0.295579  0.105315





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 49, global step 9450: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.053869  0.061966  0.043387
ndcg    0.017221  0.081858  0.111644  0.056252
recall  0.017221  0.175360  0.293757  0.095711





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 50, global step 9639: 'recall@10' was not in top 1


k              1        10        20         5
map     0.023017  0.058821  0.066449  0.047991
ndcg    0.023017  0.087435  0.115368  0.060816
recall  0.023017  0.183474  0.294254  0.100348





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 51, global step 9828: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.053710  0.061701  0.043203
ndcg    0.015897  0.082981  0.112526  0.057285
recall  0.015897  0.180659  0.298394  0.100679





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 52, global step 10017: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.055277  0.062919  0.043973
ndcg    0.017553  0.084866  0.113038  0.056990
recall  0.017553  0.184136  0.296241  0.096870





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 53, global step 10206: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016062  0.052917  0.060724  0.042811
ndcg    0.016062  0.080923  0.109554  0.056131
recall  0.016062  0.174367  0.287962  0.097036





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 54, global step 10395: 'recall@10' reached 0.18513 (best 0.18513), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=54-step=10395.ckpt' as top 1


k              1        10        20         5
map     0.015565  0.053525  0.061330  0.042678
ndcg    0.015565  0.083778  0.112594  0.057021
recall  0.015565  0.185130  0.299884  0.101341





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 55, global step 10584: 'recall@10' reached 0.18546 (best 0.18546), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=55-step=10584.ckpt' as top 1


k              1        10        20         5
map     0.017056  0.055735  0.063609  0.044850
ndcg    0.017056  0.085600  0.114804  0.058863
recall  0.017056  0.185461  0.302037  0.102004





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 56, global step 10773: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.057733  0.066138  0.048027
ndcg    0.018049  0.086808  0.117857  0.062923
recall  0.018049  0.183308  0.307004  0.108627





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 57, global step 10962: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015565  0.053924  0.061893  0.043238
ndcg    0.015565  0.083421  0.112880  0.057390
recall  0.015565  0.181818  0.299222  0.101010





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 58, global step 11151: 'recall@10' reached 0.18679 (best 0.18679), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=58-step=11151.ckpt' as top 1


k              1        10        20         5
map     0.016725  0.054989  0.062809  0.043647
ndcg    0.016725  0.085319  0.114043  0.057688
recall  0.016725  0.186786  0.300878  0.101010





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 59, global step 11340: 'recall@10' reached 0.18695 (best 0.18695), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=59-step=11340.ckpt' as top 1


k              1        10        20         5
map     0.018381  0.056803  0.064605  0.046073
ndcg    0.018381  0.086787  0.115598  0.060458
recall  0.018381  0.186951  0.301706  0.104819





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 60, global step 11529: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015565  0.053940  0.062171  0.043570
ndcg    0.015565  0.083125  0.113408  0.057472
recall  0.015565  0.180659  0.301043  0.100182





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 61, global step 11718: 'recall@10' was not in top 1


k              1        10        20         5
map     0.021361  0.058044  0.065866  0.046984
ndcg    0.021361  0.087392  0.116117  0.060230
recall  0.021361  0.185792  0.299884  0.101010





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 62, global step 11907: 'recall@10' was not in top 1


k              1        10        20         5
map     0.021196  0.056440  0.064452  0.045962
ndcg    0.021196  0.083948  0.113545  0.058342
recall  0.021196  0.176023  0.293923  0.096374





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 63, global step 12096: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018381  0.056605  0.064777  0.045863
ndcg    0.018381  0.085908  0.116060  0.059571
recall  0.018381  0.183805  0.303858  0.101672





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 64, global step 12285: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01954  0.057959  0.065834  0.047221
ndcg    0.01954  0.087203  0.116108  0.060962
recall  0.01954  0.184799  0.299553  0.103163





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 65, global step 12474: 'recall@10' reached 0.18910 (best 0.18910), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=65-step=12474.ckpt' as top 1


k             1        10        20         5
map     0.02252  0.059894  0.067805  0.048634
ndcg    0.02252  0.089587  0.118476  0.062037
recall  0.02252  0.189104  0.303527  0.103328





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 66, global step 12663: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020533  0.058758  0.066800  0.048179
ndcg    0.020533  0.087998  0.117686  0.062044
recall  0.020533  0.185627  0.303858  0.104653





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 67, global step 12852: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018381  0.055564  0.063544  0.044880
ndcg    0.018381  0.084470  0.113939  0.058286
recall  0.018381  0.181156  0.298559  0.099520





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 68, global step 13041: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015565  0.055300  0.063368  0.044315
ndcg    0.015565  0.085740  0.115490  0.058905
recall  0.015565  0.187283  0.305680  0.103825





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 69, global step 13230: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.056350  0.064507  0.045620
ndcg    0.017221  0.086516  0.116595  0.060251
recall  0.017221  0.187117  0.306839  0.105315





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 70, global step 13419: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020036  0.058527  0.066809  0.048104
ndcg    0.020036  0.087695  0.118265  0.062147
recall  0.020036  0.184964  0.306673  0.105315





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 71, global step 13608: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.055676  0.063948  0.045043
ndcg    0.016559  0.085518  0.115629  0.059197
recall  0.016559  0.185296  0.304355  0.102666





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 72, global step 13797: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.053222  0.061590  0.042438
ndcg    0.017221  0.082195  0.112787  0.055720
recall  0.017221  0.179334  0.300546  0.096705





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 73, global step 13986: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.055275  0.063481  0.044516
ndcg    0.016228  0.085346  0.115596  0.059018
recall  0.016228  0.185627  0.306011  0.103660





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 74, global step 14175: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017387  0.055404  0.063874  0.044718
ndcg    0.017387  0.084490  0.115429  0.058333
recall  0.017387  0.181653  0.304189  0.100182





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 75, global step 14364: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.054881  0.063186  0.044384
ndcg    0.017553  0.083540  0.114051  0.057982
recall  0.017553  0.179169  0.300381  0.099851





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 76, global step 14553: 'recall@10' reached 0.19142 (best 0.19142), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=76-step=14553.ckpt' as top 1


k              1        10        20         5
map     0.017056  0.058309  0.066365  0.047610
ndcg    0.017056  0.089063  0.118659  0.062800
recall  0.017056  0.191422  0.308992  0.109455





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 77, global step 14742: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018546  0.058176  0.066286  0.047707
ndcg    0.018546  0.087879  0.117542  0.062003
recall  0.018546  0.186951  0.304521  0.105812





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 78, global step 14931: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.056675  0.064859  0.045573
ndcg    0.016559  0.086941  0.116920  0.059677
recall  0.016559  0.187945  0.306839  0.102832





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 79, global step 15120: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019374  0.058110  0.066168  0.047000
ndcg    0.019374  0.087994  0.117428  0.060973
recall  0.019374  0.187779  0.304355  0.103991





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 80, global step 15309: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018215  0.057930  0.066746  0.047293
ndcg    0.018215  0.087057  0.119280  0.061293
recall  0.018215  0.183805  0.311475  0.104156





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 81, global step 15498: 'recall@10' reached 0.19424 (best 0.19424), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=81-step=15498.ckpt' as top 1


k              1        10        20         5
map     0.018877  0.059787  0.067629  0.048289
ndcg    0.018877  0.090781  0.119724  0.062647
recall  0.018877  0.194237  0.309488  0.106640





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 82, global step 15687: 'recall@10' was not in top 1


k              1        10        20         5
map     0.021196  0.059943  0.068264  0.049316
ndcg    0.021196  0.089258  0.119739  0.063349
recall  0.021196  0.186951  0.307832  0.106475





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 83, global step 15876: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.056730  0.065060  0.045973
ndcg    0.017553  0.086469  0.116985  0.060022
recall  0.017553  0.185792  0.306839  0.103163





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 84, global step 16065: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.057324  0.065272  0.046619
ndcg    0.018049  0.087470  0.116848  0.061166
recall  0.018049  0.188111  0.305183  0.105978





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 85, global step 16254: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.056925  0.065265  0.046040
ndcg    0.018049  0.086084  0.116780  0.059502
recall  0.018049  0.183308  0.305349  0.100679





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 86, global step 16443: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.056484  0.064352  0.045529
ndcg    0.017221  0.086933  0.115803  0.060068
recall  0.017221  0.188607  0.303196  0.104819





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 87, global step 16632: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.055217  0.062926  0.043630
ndcg    0.015897  0.086042  0.114300  0.057662
recall  0.015897  0.189270  0.301374  0.100845





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 88, global step 16821: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.056031  0.064152  0.045239
ndcg    0.015897  0.086680  0.116439  0.060292
recall  0.015897  0.188773  0.306839  0.106640





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 89, global step 17010: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015565  0.055549  0.063511  0.045121
ndcg    0.015565  0.085551  0.115085  0.060090
recall  0.015565  0.185296  0.303196  0.106143





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 90, global step 17199: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015069  0.055357  0.062889  0.044240
ndcg    0.015069  0.086089  0.113872  0.058869
recall  0.015069  0.188607  0.299222  0.103825





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 91, global step 17388: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014572  0.055767  0.063740  0.044398
ndcg    0.014572  0.087027  0.116359  0.059272
recall  0.014572  0.191257  0.307832  0.104984





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 92, global step 17577: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016393  0.055183  0.063564  0.044133
ndcg    0.016393  0.085214  0.115960  0.058282
recall  0.016393  0.185461  0.307501  0.101838





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 93, global step 17766: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015069  0.055829  0.064239  0.045104
ndcg    0.015069  0.086437  0.117191  0.060273
recall  0.015069  0.188276  0.310151  0.106971





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 94, global step 17955: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.055857  0.063979  0.044569
ndcg    0.015731  0.087068  0.116884  0.059363
recall  0.015731  0.191422  0.309820  0.104984





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 95, global step 18144: 'recall@10' reached 0.19540 (best 0.19540), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=95-step=18144-v1.ckpt' as top 1


k              1        10        20         5
map     0.014572  0.057496  0.065425  0.046534
ndcg    0.014572  0.089360  0.118578  0.062330
recall  0.014572  0.195397  0.311641  0.110780





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 96, global step 18333: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019043  0.059901  0.067771  0.048811
ndcg    0.019043  0.091040  0.119906  0.063627
recall  0.019043  0.195065  0.309654  0.109124





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 97, global step 18522: 'recall@10' was not in top 1


k              1        10        20         5
map     0.021196  0.060362  0.068267  0.049520
ndcg    0.021196  0.090188  0.119075  0.063780
recall  0.021196  0.189601  0.304024  0.107634





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 98, global step 18711: 'recall@10' reached 0.19589 (best 0.19589), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=98-step=18711.ckpt' as top 1


k              1        10        20         5
map     0.018546  0.059986  0.067449  0.048454
ndcg    0.018546  0.091327  0.118587  0.063096
recall  0.018546  0.195893  0.303858  0.107965





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 99, global step 18900: 'recall@10' was not in top 1
`Trainer.fit` stopped: `max_epochs=100` reached.


k              1        10        20         5
map     0.016559  0.056867  0.065123  0.046219
ndcg    0.016559  0.086804  0.117006  0.060686
recall  0.016559  0.186455  0.306177  0.104984



Now we can get the best model path stored in the checkpoint callback.

In [22]:
best_model_path = checkpoint_callback.best_model_path
best_model_path

'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=98-step=18711.ckpt'

## Inference

To obtain model scores, we will load the weights from the best checkpoint. To do this, we use the `LightningModule`, provide there the path to the checkpoint and the model instance.

In [23]:
import replay

sasrec = make_sasrec(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
    excluded_features=None,
)

torch.serialization.add_safe_globals([
    replay.nn.lightning.optimizer.OptimizerFactory,
    replay.nn.lightning.scheduler.LRSchedulerFactory,
])

best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)
best_model.eval();

Configure `ParquetModule` for inference

In [24]:
item_id_len = pl.from_arrow(test_events)["item_id"].list.len().max()

inference_metadata = {"predict": {
        "user_id": {},
        "item_id": {
            "shape": item_id_len,
            "padding": tensor_schema["item_id"].padding_value
        }
    }
}

parquet_module = ParquetModule(
    predict_path=PREDICT_PATH,
    batch_size=BATCH_SIZE,
    metadata=inference_metadata,
    transforms=transforms,
)

inference_metadata

  parquet_module = ParquetModule(


{'predict': {'user_id': {}, 'item_id': {'shape': 2313, 'padding': 3706}}}

During inference, we can use `TopItemsCallback`. Such callback allows you to get scores for each user throughout the entire catalog and get recommendations in the form of ids of items with the highest score values.


Recommendations can be fetched in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame or raw PyTorch tensors. Each of the types corresponds a callback. In this example, we'll be using the `PandasTopItemsCallback`.

In [25]:
from replay.nn.lightning.callback import PandasTopItemsCallback
from replay.nn.lightning.postprocessor import SeenItemsFilter

post_processors = [
    SeenItemsFilter(
        item_count = NUM_UNIQUE_ITEMS,
        seen_items_column = "seen_ids",
    )
]

csv_logger = CSVLogger(save_dir="sasrec/logs/test", name="SasRec-example")

TOPK = [1, 5, 10, 20]

pandas_prediction_callback = PandasTopItemsCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=post_processors
)

trainer = L.Trainer(callbacks=[pandas_prediction_callback], logger=csv_logger, inference_mode=True)
trainer.predict(best_model, datamodule=parquet_module, return_predictions=False)

pandas_res = pandas_prediction_callback.get_result()

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Predicting: |          | 0/? [00:00<?, ?it/s]



In [26]:
pandas_res

Unnamed: 0,user_id,item_id,score
0,0,354,10.922463
0,0,1551,10.199773
0,0,1900,9.862277
0,0,1897,9.715552
0,0,1908,9.511134
...,...,...,...
6037,6039,2013,5.697724
6037,6039,1545,5.67787
6037,6039,1384,5.658081
6037,6039,15,5.636842


### Calculating metrics

*test_gt* is already encoded, so we can use it for computing metrics.

In [27]:
from replay.metrics import MAP, OfflineMetrics, Precision, Recall
from replay.metrics.torch_metrics_builder import metrics_to_df

In [28]:
result_metrics = OfflineMetrics(
    [Recall(TOPK), Precision(TOPK), MAP(TOPK)],
    query_column="user_id",
    rating_column="score",
)(pandas_res, test_gt.explode("item_id"))

In [29]:
metrics_to_df(result_metrics)

k,1,10,20,5
MAP,0.060782,0.111647,0.119129,0.100179
Precision,0.060782,0.026449,0.018707,0.035343
Recall,0.060782,0.264492,0.374131,0.176714


Let's call the `inverse_transform` encoder's function to get the final dataframe with recommendations

In [30]:
encoder.inverse_transform(pandas_res)

Unnamed: 0,user_id,item_id,score
0,1,364,10.922463
0,1,1688,10.199773
0,1,2081,9.862277
0,1,2078,9.715552
0,1,2089,9.511134
...,...,...,...
6037,6040,2194,5.697724
6037,6040,1682,5.67787
6037,6040,1500,5.658081
6037,6040,16,5.636842
