# Example of SasRec training/inference

In [1]:
from typing import Optional

import lightning as L
import pandas as pd

L.seed_everything(42)

Seed set to 42


42

## Preparing data
In this example, we will be using the MovieLens dataset, namely the 1m subset. It's demonstrated a simple case, so only item ids will be used as model input.

---
**NOTE**

Current implementation of SasRec is able to handle item and interactions features. It does not take into account user features. 

---

In [2]:
interactions = pd.read_csv("./data/ml1m_ratings.dat", sep="\t", names=["user_id", "item_id","rating","timestamp"])
interactions = interactions.drop(columns=["rating"])

In [3]:
interactions["timestamp"] = interactions["timestamp"].astype("int64")
interactions = interactions.sort_values(by="timestamp")
interactions["timestamp"] = interactions.groupby("user_id").cumcount()
interactions

Unnamed: 0,user_id,item_id,timestamp
1000138,6040,858,0
1000153,6040,2384,1
999873,6040,593,2
1000007,6040,1961,3
1000192,6040,2019,4
...,...,...,...
825793,4958,2399,446
825438,4958,1407,447
825724,4958,3264,448
825731,4958,2634,449


### Encode catagorical data.
To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorcial column in the dataset that will be used in model. Note that ids of users and ids of items are always used.

In [4]:
from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule

encoder = LabelEncoder(
    [
        LabelEncodingRule("user_id", default_value="last"),
        LabelEncodingRule("item_id", default_value="last"),
    ]
)
interactions = interactions.sort_values(by="item_id", ascending=True)
encoded_interactions = encoder.fit_transform(interactions)
encoded_interactions

Unnamed: 0,timestamp,user_id,item_id
0,12,2011,0
1,68,4078,0
2,67,4123,0
3,12,983,0
4,140,2270,0
...,...,...,...
1000204,14,855,3705
1000205,90,1700,3705
1000206,70,936,3705
1000207,25,360,3705


### Split interactions into the train, validation and test datasets using LastNSplitter
We use widespread splitting strategy Last-One-Out. We filter out cold items and users for simplicity.

In [5]:
from replay.splitters import LastNSplitter

splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
    drop_cold_users=True,
    drop_cold_items=True
)

test_events, test_gt = splitter.split(encoded_interactions)
validation_events, validation_gt = splitter.split(test_events)
train_events = validation_events

### Dataset preprocessing ("baking")
SasRec expects each user in the batch to provide their events in form of a sequence. For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay.

In [6]:
from replay.data.nn.utils import groupby_sequences


def bake_data(full_data):
    grouped_interactions = groupby_sequences(events=full_data, groupby_col="user_id", sort_col="timestamp")
    return grouped_interactions

In [7]:
train_events = bake_data(train_events)

validation_events = bake_data(validation_events)
validation_gt = bake_data(validation_gt)

test_events = bake_data(test_events)
test_gt = bake_data(test_gt)

train_events

Unnamed: 0,user_id,timestamp,item_id
0,0,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2969, 1574, 1178, 957, 2147, 1658, 3177, 1117..."
1,1,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1108, 1127, 1120, 2512, 1201, 2735, 1135, 110..."
2,2,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[579, 2651, 3301, 1788, 1781, 1327, 1174, 3429..."
3,3,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1120, 1025, 466, 3235, 3294, 1106, 253, 1108,..."
4,4,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2512, 858, 847, 346, 1158, 2007, 2651, 1050, ..."
...,...,...,...
6035,6035,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1574, 1703, 3206, 2183, 2235, 2480, 2375, 250..."
6036,6036,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1702, 672, 1175, 1848, 3275, 2932, 548, 802, ..."
6037,6037,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[3165, 859, 1120, 1965, 1288, 346, 1007, 1066,..."
6038,6038,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[107, 275, 1886, 1139, 869, 886, 2872, 2809, 2..."


To ensure we don't have unknown users in ground truth, we join validation events and validation ground truth (also join test events and test ground truth correspondingly) by user ids to leave only the common ones.  

In [8]:
def add_gt_to_events(events_df, gt_df):
    gt_to_join = gt_df[["user_id", "item_id"]].rename(columns={"item_id": "ground_truth"})

    events_df = events_df.merge(gt_to_join, on="user_id", how="inner")
    return events_df

validation_events = add_gt_to_events(validation_events, validation_gt)
test_events = add_gt_to_events(test_events, test_gt)

In [9]:
from pathlib import Path

data_dir = Path("temp/data/")
data_dir.mkdir(parents=True, exist_ok=True)

TRAIN_PATH = data_dir / "train.parquet"
VAL_PATH = data_dir / "val.parquet"
PREDICT_PATH = data_dir / "test.parquet"

ENCODER_PATH = data_dir / "encoder"

In [10]:
train_events.to_parquet(TRAIN_PATH)
validation_events.to_parquet(VAL_PATH)
test_events.to_parquet(PREDICT_PATH)

encoder.save(ENCODER_PATH)



# Prepare to model training
### Create the tensor schema
A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the NN models to correctly create embeddings for every source column. Note that user_id does not required in `TensorSchema`.

Note that **cardinality** is the number of unique values ​in the item catalog (vocabulary). **Padding value** is the next value after the last one.

In [11]:
from replay.data import FeatureHint, FeatureType
from replay.data.nn import TensorFeatureInfo, TensorSchema


EMBEDDING_DIM = 64

encoder = encoder.load(ENCODER_PATH)
NUM_UNIQUE_ITEMS = len(encoder.mapping["item_id"])

tensor_schema = TensorSchema(
    [
        TensorFeatureInfo(
            name="item_id",
            is_seq=True,
            padding_value=NUM_UNIQUE_ITEMS,
            cardinality=NUM_UNIQUE_ITEMS,
            embedding_dim=EMBEDDING_DIM,
            feature_type=FeatureType.CATEGORICAL,
            feature_hint=FeatureHint.ITEM_ID,
        )
    ]
)

In [12]:
import torch
import polars as pl
freqs = pl.from_pandas(train_events)["item_id"].explode() \
    .value_counts(normalize = True).sort("item_id") \
    .with_columns(
        (pl.col("proportion").log() - pl.col("proportion").log().min()).clip(2, 10).alias("pmi")
    )
item_ids = freqs["item_id"].to_torch()
item_pmi = freqs["pmi"].to_torch()
item_probas = 2 * torch.ones(NUM_UNIQUE_ITEMS)
item_probas[item_ids] = item_pmi.float()
freqs, item_probas

(shape: (3_703, 3)
 ┌─────────┬────────────┬──────────┐
 │ item_id ┆ proportion ┆ pmi      │
 │ ---     ┆ ---        ┆ ---      │
 │ i64     ┆ f64        ┆ f64      │
 ╞═════════╪════════════╪══════════╡
 │ 0       ┆ 0.002082   ┆ 7.629004 │
 │ 1       ┆ 0.000699   ┆ 6.53814  │
 │ 2       ┆ 0.000466   ┆ 6.131226 │
 │ 3       ┆ 0.000172   ┆ 5.135798 │
 │ 4       ┆ 0.000292   ┆ 5.666427 │
 │ …       ┆ …          ┆ …        │
 │ 3701    ┆ 0.000818   ┆ 6.694562 │
 │ 3702    ┆ 0.000286   ┆ 5.645447 │
 │ 3703    ┆ 0.000053   ┆ 3.951244 │
 │ 3704    ┆ 0.000039   ┆ 3.663562 │
 │ 3705    ┆ 0.000373   ┆ 5.910797 │
 └─────────┴────────────┴──────────┘,
 tensor([7.6290, 6.5381, 6.1312,  ..., 3.9512, 3.6636, 5.9108]))

### Configure ParquetModule and transformation pipelines

The `ParquetModule` class enables training of models on large datasets by reading data in batch-wise way. This class initialized with **paths to every data split, a metadata dict containing information about shape and padding value of every column and a dict of transforms**. `ParquetModule`'s  "transform pipelines" are stage-specific modules implementing additional preprocessing to be performed on batch level right before the forward pass.  

For SasRec model, RePlay provides a function that generates a sequence of appropriate transforms for each data split named **make_default_sasrec_transforms**.

Internally this function creates the following transforms:
1) Training:
    1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).
    2. Rename features to match it with expected format by the model during training.
    3. Unsqueeze target (*positive_labels*) and it's padding mask (*target_padding_mask*) for getting required shape of this tensors for loss computation.
    4. Group input features to be embed in expected format.

2) Validation/Inference:
    1. Rename/group features to match it with expected format by the model during valdiation/inference.

If a different set of transforms is required, you can create them yourself and submit them to the ParquetModule in the form of a dictionary where the key is the name of the split, and the value is the list of transforms. Available transforms are in the replay/nn/transforms/.

**Note:** One of the transforms for the training data prepares the initial sequence for the task of Next Item Prediction so it shifts the sequence of items. For the final sequence length to be correct, you need to set shape of item_id in metadata as **model sequence length + shift**. Default shift value is 1.

In [13]:
import copy

import torch

from replay.data.nn import TensorSchema
from replay.nn.transform import GroupTransform, NextTokenTransform, RenameTransform, UnsqueezeTransform, UniformNegativeSamplingTransform

In [14]:
def make_sasrec_transforms(
    tensor_schema: TensorSchema, query_column: str = "query_id", num_negative_samples: int = 128,
) -> dict[str, list[torch.nn.Module]]:
    item_column = tensor_schema.item_id_feature_name
    vocab_size = tensor_schema[item_column].cardinality
    train_transforms = [
        UniformNegativeSamplingTransform(vocab_size, num_negative_samples, sample_distribution=item_probas),
        NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),
        RenameTransform(
            {
                query_column: "query_id",
                f"{item_column}_mask": "padding_mask",
                "positive_labels_mask": "target_padding_mask",
            }
        ),
        UnsqueezeTransform("target_padding_mask", -1),
        UnsqueezeTransform("positive_labels", -1),
        GroupTransform({"feature_tensors": [item_column]}),
    ]

    val_transforms = [
        RenameTransform({query_column: "query_id", f"{item_column}_mask": "padding_mask"}),
        GroupTransform({"feature_tensors": [item_column]}),
    ]
    test_transforms = copy.deepcopy(val_transforms)

    predict_transforms = copy.deepcopy(val_transforms)

    transforms = {
        "train": train_transforms,
        "validate": val_transforms,
        "test": test_transforms,
        "predict": predict_transforms,
    }

    return transforms

In [15]:
transforms = make_sasrec_transforms(tensor_schema, query_column="user_id")

In [16]:
MAX_SEQ_LEN = 50

def create_meta(shape: int, gt_shape: Optional[int] = None):
    meta = {
        "user_id": {},
        "item_id": {"shape": shape, "padding": tensor_schema["item_id"].padding_value},
    }
    if gt_shape is not None:
        meta.update({"ground_truth": {"shape": gt_shape, "padding": -1}})

    return meta

train_metadata = {
    "train": create_meta(shape=MAX_SEQ_LEN+1),
    "validate": create_meta(shape=MAX_SEQ_LEN, gt_shape=1),
}

In [17]:
from replay.data.nn import ParquetModule

BATCH_SIZE = 32

parquet_module = ParquetModule(
    train_path=TRAIN_PATH,
    validate_path=VAL_PATH,
    batch_size=BATCH_SIZE,
    metadata=train_metadata,
    transforms=transforms,
)

  parquet_module = ParquetModule(


## Train model
### Create SasRec model instance and run the training stage using lightning
We may now train the model using the Lightning trainer class. 

RePlay's implementation of SasRec is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is now built by providing fully initialized components that makes the model more flexible and easier to extend.

#### Default Configuration

Default SasRec model may be created quickly via method **from_params**. Default model instance has CE loss, original SasRec transformer layes, and embeddings are aggregated via sum.

In [18]:
from replay.nn.sequential import SasRec
from typing import Literal
def make_sasrec(
    schema: TensorSchema,
    embedding_dim: int = 192,
    num_heads: int = 4,
    num_blocks: int = 2,
    max_sequence_length: int = 50,
    dropout: float = 0.3,
    excluded_features: Optional[list[str]] = None,
    categorical_list_feature_aggregation_method: Literal["sum", "mean", "max"] = "sum",
) -> SasRec:
    from replay.nn.sequential.sasrec import SasRecBody, SasRecTransformerLayer
    from replay.nn.agg import SumAggregator
    from replay.nn.embedding import SequenceEmbedding
    from replay.nn.loss import CE, CESampled
    from replay.nn.mask import DefaultAttentionMask
    from replay.nn.sequential.sasrec.agg import PositionAwareAggregator
    from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer
    excluded_features = [
        schema.query_id_feature_name,
        schema.timestamp_feature_name,
        *(excluded_features or []),
    ]
    excluded_features = list(set(excluded_features))
    body = SasRecBody(
        embedder=SequenceEmbedding(
            schema=schema,
            categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,
            excluded_features=excluded_features,
        ),
        embedding_aggregator=PositionAwareAggregator(
            embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),
            max_sequence_length=max_sequence_length,
            dropout=dropout,
        ),
        attn_mask_builder=DefaultAttentionMask(
            reference_feature_name=schema.item_id_feature_name,
            num_heads=num_heads,
        ),
        encoder=SasRecTransformerLayer(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            num_blocks=num_blocks,
            dropout=dropout,
            activation="relu",
        ),
        output_normalization=torch.nn.LayerNorm(embedding_dim),
    )
    padding_idx = schema.item_id_features.item().padding_value
    return SasRec(
        body=body,
        loss=CESampled(negative_labels_ignore_index=padding_idx),
    )

In [19]:
from replay.nn.sequential import SasRec

NUM_BLOCKS = 2
NUM_HEADS = 2
DROPOUT = 0.3

sasrec = make_sasrec(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
)

A universal PyTorch Lightning module is provided. It can work with any NN model.

In [20]:
from replay.nn.lightning.optimizer import OptimizerFactory
from replay.nn.lightning.scheduler import LRSchedulerFactory
from replay.nn.lightning import LightningModule

model = LightningModule(
    sasrec,
    optimizer_factory=OptimizerFactory(),
    lr_scheduler_factory=LRSchedulerFactory(),
)

To facilitate training, we add the following callbacks:
1) `ModelCheckpoint` - to save the best trained model based on its Recall metric. It's a default Lightning Callback.
1) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. It's a custom RePlay callback for computing recsys metrics on validation and test stages.


In [21]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from replay.nn.lightning.callback import ComputeMetricsCallback


checkpoint_callback = ModelCheckpoint(
    dirpath="sasrec/checkpoints",
    save_top_k=1,
    verbose=True,
    monitor="recall@10",
    mode="max",
)

validation_metrics_callback = ComputeMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=NUM_UNIQUE_ITEMS,
)

csv_logger = CSVLogger(save_dir="sasrec/logs/train", name="SasRec-example")

trainer = L.Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

trainer.fit(model, datamodule=parquet_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode  | FLOPs
-------------------------------------------------
0 | model | SasRec | 291 K  | train | 0    
-------------------------------------------------
291 K     Trainable params
0         Non-trainable params
291 K     Total params
1.164     Total estimated model params size (MB)
39        Modules in train mode
0         Modules in eval mode
0         Total Flops


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Training: |          | 0/? [00:00<?, ?it/s]

/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 24. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 189: 'recall@10' reached 0.01755 (best 0.01755), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1


k             1        10        20         5
map     0.00149  0.004354  0.005648  0.002923
ndcg    0.00149  0.007317  0.012097  0.003811
recall  0.00149  0.017553  0.036595  0.006624





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 378: 'recall@10' reached 0.03312 (best 0.03312), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1


k              1        10        20         5
map     0.004471  0.011047  0.013398  0.009309
ndcg    0.004471  0.016111  0.024907  0.011701
recall  0.004471  0.033118  0.068389  0.019043





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 567: 'recall@10' reached 0.06938 (best 0.06938), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1


k              1        10        20         5
map     0.008445  0.022671  0.026696  0.018739
ndcg    0.008445  0.033424  0.048330  0.023739
recall  0.008445  0.069382  0.128829  0.039079





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 756: 'recall@10' reached 0.08958 (best 0.08958), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=3-step=756.ckpt' as top 1


k              1        10        20         5
map     0.012419  0.030050  0.034391  0.024905
ndcg    0.012419  0.043752  0.059822  0.031154
recall  0.012419  0.089584  0.153668  0.050339





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 945: 'recall@10' reached 0.10581 (best 0.10581), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt' as top 1


k              1        10        20         5
map     0.011923  0.032542  0.037953  0.026100
ndcg    0.011923  0.049330  0.068856  0.033423
recall  0.011923  0.105812  0.182646  0.055970





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 1134: 'recall@10' reached 0.10846 (best 0.10846), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=5-step=1134.ckpt' as top 1


k              1        10        20         5
map     0.011923  0.033958  0.039678  0.027869
ndcg    0.011923  0.051101  0.072251  0.036049
recall  0.011923  0.108462  0.192747  0.061268





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 1323: 'recall@10' reached 0.12336 (best 0.12336), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=6-step=1323.ckpt' as top 1


k              1        10        20         5
map     0.009935  0.035941  0.041918  0.028683
ndcg    0.009935  0.056053  0.078129  0.038102
recall  0.009935  0.123365  0.211293  0.067064





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 1512: 'recall@10' reached 0.12502 (best 0.12502), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=7-step=1512.ckpt' as top 1


k             1        10        20        5
map     0.01391  0.038724  0.044828  0.03120
ndcg    0.01391  0.058527  0.081092  0.04003
recall  0.01391  0.125021  0.214936  0.06723





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 1701: 'recall@10' reached 0.13727 (best 0.13727), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=8-step=1701.ckpt' as top 1


k              1        10        20         5
map     0.013744  0.042677  0.049120  0.034484
ndcg    0.013744  0.064467  0.088197  0.044322
recall  0.013744  0.137274  0.231661  0.074350





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 1890: 'recall@10' reached 0.14208 (best 0.14208), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=9-step=1890.ckpt' as top 1


k              1        10        20         5
map     0.015897  0.044402  0.050721  0.035997
ndcg    0.015897  0.066855  0.090127  0.046183
recall  0.015897  0.142077  0.234641  0.077496





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 2079: 'recall@10' reached 0.14655 (best 0.14655), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=10-step=2079.ckpt' as top 1


k              1        10        20         5
map     0.017387  0.047735  0.054142  0.039355
ndcg    0.017387  0.070563  0.094301  0.050199
recall  0.017387  0.146547  0.241265  0.083458





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 2268: 'recall@10' reached 0.14953 (best 0.14953), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=11-step=2268.ckpt' as top 1


k              1        10        20         5
map     0.014241  0.046795  0.053536  0.038458
ndcg    0.014241  0.070554  0.095317  0.050077
recall  0.014241  0.149528  0.247889  0.085610





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 2457: 'recall@10' reached 0.15069 (best 0.15069), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=12-step=2457.ckpt' as top 1


k              1        10        20         5
map     0.015897  0.047115  0.053834  0.038290
ndcg    0.015897  0.071000  0.095844  0.049467
recall  0.015897  0.150687  0.249710  0.083789





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 2646: 'recall@10' reached 0.15996 (best 0.15996), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=13-step=2646.ckpt' as top 1


k              1        10        20         5
map     0.014406  0.047850  0.054804  0.038086
ndcg    0.014406  0.073647  0.099305  0.049731
recall  0.014406  0.159960  0.262129  0.085445





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14, global step 2835: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017387  0.048839  0.055924  0.039427
ndcg    0.017387  0.073632  0.099649  0.050667
recall  0.017387  0.156648  0.259977  0.085279





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15, global step 3024: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.049977  0.057335  0.041544
ndcg    0.017056  0.074628  0.101700  0.053933
recall  0.017056  0.156648  0.264282  0.092068





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16, global step 3213: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014738  0.046148  0.053688  0.037034
ndcg    0.014738  0.070409  0.098089  0.048150
recall  0.014738  0.151515  0.261467  0.082298





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17, global step 3402: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.051081  0.058615  0.041858
ndcg    0.018049  0.076173  0.103824  0.053743
recall  0.018049  0.159795  0.269581  0.090247





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18, global step 3591: 'recall@10' reached 0.16427 (best 0.16427), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=18-step=3591.ckpt' as top 1


k              1        10        20         5
map     0.016559  0.050157  0.057845  0.039913
ndcg    0.016559  0.076399  0.104796  0.051498
recall  0.016559  0.164266  0.277364  0.087101





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19, global step 3780: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015069  0.047852  0.055595  0.038486
ndcg    0.015069  0.073606  0.101739  0.050562
recall  0.015069  0.159795  0.270906  0.087763





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 20, global step 3969: 'recall@10' reached 0.16990 (best 0.16990), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=20-step=3969.ckpt' as top 1


k             1        10        20         5
map     0.01689  0.052719  0.060211  0.042562
ndcg    0.01689  0.079738  0.107336  0.055022
recall  0.01689  0.169896  0.279682  0.093227





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 21, global step 4158: 'recall@10' reached 0.17470 (best 0.17470), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=21-step=4158.ckpt' as top 1


k              1        10        20         5
map     0.014572  0.052667  0.059904  0.042766
ndcg    0.014572  0.080907  0.107559  0.056764
recall  0.014572  0.174698  0.280676  0.099685





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 22, global step 4347: 'recall@10' reached 0.17851 (best 0.17851), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=22-step=4347.ckpt' as top 1


k              1        10        20         5
map     0.016725  0.053505  0.060784  0.042904
ndcg    0.016725  0.082262  0.108932  0.056255
recall  0.016725  0.178506  0.284319  0.097367





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23, global step 4536: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018712  0.054584  0.062143  0.043895
ndcg    0.018712  0.082750  0.110570  0.056651
recall  0.018712  0.177016  0.287630  0.095877





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 24, global step 4725: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015069  0.051355  0.058945  0.040123
ndcg    0.015069  0.080240  0.108070  0.052683
recall  0.015069  0.177182  0.287630  0.091240





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 25, global step 4914: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.054724  0.062662  0.044941
ndcg    0.018049  0.082718  0.111839  0.058530
recall  0.018049  0.176188  0.291770  0.100348





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26, global step 5103: 'recall@10' reached 0.17867 (best 0.17867), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=26-step=5103.ckpt' as top 1


k              1        10        20         5
map     0.017718  0.056037  0.063666  0.045982
ndcg    0.017718  0.084315  0.112529  0.059419
recall  0.017718  0.178672  0.291108  0.100513





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 27, global step 5292: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018381  0.054132  0.062130  0.043567
ndcg    0.018381  0.082126  0.111569  0.056214
recall  0.018381  0.175857  0.292929  0.095049





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 28, global step 5481: 'recall@10' reached 0.18298 (best 0.18298), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=28-step=5481.ckpt' as top 1


k              1        10        20         5
map     0.019871  0.057822  0.065481  0.047328
ndcg    0.019871  0.086660  0.114879  0.060851
recall  0.019871  0.182977  0.295248  0.102335





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 29, global step 5670: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018546  0.056366  0.063684  0.045788
ndcg    0.018546  0.085226  0.112250  0.059427
recall  0.018546  0.181487  0.289121  0.101341





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 30, global step 5859: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018877  0.055530  0.063362  0.045043
ndcg    0.018877  0.083873  0.112696  0.058292
recall  0.018877  0.178506  0.293095  0.099023





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 31, global step 6048: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020202  0.057919  0.065776  0.047867
ndcg    0.020202  0.085809  0.114771  0.061385
recall  0.020202  0.178506  0.293757  0.102832





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 32, global step 6237: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018712  0.055404  0.063721  0.045631
ndcg    0.018712  0.083331  0.113803  0.059593
recall  0.018712  0.176188  0.297069  0.102666





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 33, global step 6426: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018712  0.055181  0.063015  0.044491
ndcg    0.018712  0.084073  0.113004  0.057881
recall  0.018712  0.180825  0.296076  0.099189





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 34, global step 6615: 'recall@10' reached 0.18546 (best 0.18546), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=34-step=6615.ckpt' as top 1


k              1        10        20         5
map     0.018877  0.056103  0.063740  0.044745
ndcg    0.018877  0.085771  0.113855  0.057781
recall  0.018877  0.185461  0.297069  0.097864





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 35, global step 6804: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019208  0.056618  0.064280  0.045728
ndcg    0.019208  0.085789  0.113846  0.059080
recall  0.019208  0.183474  0.294751  0.100182





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 36, global step 6993: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016393  0.054577  0.062269  0.043763
ndcg    0.016393  0.084361  0.112624  0.057794
recall  0.016393  0.183971  0.296241  0.101010





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 37, global step 7182: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016393  0.054623  0.062542  0.044008
ndcg    0.016393  0.084262  0.113552  0.058081
recall  0.016393  0.183308  0.300050  0.101341





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 38, global step 7371: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.053871  0.061899  0.042778
ndcg    0.017056  0.083060  0.112349  0.056118
recall  0.017056  0.180659  0.296572  0.097202





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 39, global step 7560: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.053803  0.061568  0.043247
ndcg    0.016228  0.082360  0.111126  0.056549
recall  0.016228  0.177678  0.292433  0.097367





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 40, global step 7749: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018877  0.056208  0.064379  0.045742
ndcg    0.018877  0.085443  0.115541  0.059726
recall  0.018877  0.183143  0.302865  0.102832





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 41, global step 7938: 'recall@10' reached 0.19109 (best 0.19109), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=41-step=7938.ckpt' as top 1


k              1        10        20         5
map     0.017056  0.055815  0.063144  0.044605
ndcg    0.017056  0.086912  0.114060  0.059382
recall  0.017056  0.191091  0.299387  0.105150





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 42, global step 8127: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.056343  0.064084  0.045151
ndcg    0.017056  0.086308  0.114824  0.058828
recall  0.017056  0.186455  0.299884  0.100679





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 43, global step 8316: 'recall@10' was not in top 1


k             1        10        20         5
map     0.02103  0.059823  0.067389  0.049042
ndcg    0.02103  0.089856  0.117635  0.063325
recall  0.02103  0.190263  0.300546  0.107303





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 44, global step 8505: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018546  0.057635  0.065725  0.046928
ndcg    0.018546  0.086866  0.116505  0.060747
recall  0.018546  0.184302  0.301871  0.103163





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 45, global step 8694: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018215  0.056776  0.064943  0.046586
ndcg    0.018215  0.085914  0.116093  0.060818
recall  0.018215  0.182977  0.303196  0.104488





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 46, global step 8883: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017387  0.056014  0.063910  0.045220
ndcg    0.017387  0.085480  0.114599  0.059027
recall  0.017387  0.183805  0.299718  0.101341





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 47, global step 9072: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018215  0.055991  0.063732  0.044856
ndcg    0.018215  0.085672  0.113888  0.058311
recall  0.018215  0.185130  0.296738  0.099685





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 48, global step 9261: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018712  0.058238  0.066242  0.047232
ndcg    0.018712  0.088335  0.117842  0.061326
recall  0.018712  0.188773  0.306177  0.104488





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 49, global step 9450: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.054521  0.062767  0.043633
ndcg    0.017056  0.084018  0.114264  0.057402
recall  0.017056  0.182646  0.302699  0.099851





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 50, global step 9639: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.055448  0.063307  0.044972
ndcg    0.017221  0.084695  0.113494  0.058967
recall  0.017221  0.182315  0.296572  0.102004





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 51, global step 9828: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014903  0.053399  0.061760  0.042446
ndcg    0.014903  0.083264  0.113933  0.056776
recall  0.014903  0.182812  0.304521  0.101010





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 52, global step 10017: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.054588  0.062477  0.043553
ndcg    0.016228  0.085231  0.114261  0.057999
recall  0.016228  0.187945  0.303361  0.102666





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 53, global step 10206: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.053834  0.061977  0.042634
ndcg    0.017221  0.082927  0.112771  0.055686
recall  0.017221  0.180328  0.298725  0.095877





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 54, global step 10395: 'recall@10' was not in top 1


k              1        10        20         5
map     0.013744  0.055026  0.063155  0.043915
ndcg    0.013744  0.085961  0.115869  0.058841
recall  0.013744  0.188939  0.307832  0.104653





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 55, global step 10584: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019374  0.058221  0.066444  0.047516
ndcg    0.019374  0.088572  0.119039  0.062112
recall  0.019374  0.190098  0.311641  0.107137





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 56, global step 10773: 'recall@10' reached 0.19142 (best 0.19142), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=56-step=10773.ckpt' as top 1


k              1        10        20         5
map     0.019043  0.059331  0.067483  0.048648
ndcg    0.019043  0.089825  0.119791  0.063573
recall  0.019043  0.191422  0.310482  0.109455





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 57, global step 10962: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.057202  0.065344  0.046608
ndcg    0.017553  0.087107  0.117173  0.061106
recall  0.017553  0.186786  0.306508  0.105647





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 58, global step 11151: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.057300  0.065355  0.046710
ndcg    0.017553  0.087539  0.117083  0.061572
recall  0.017553  0.188276  0.305514  0.107303





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 59, global step 11340: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018381  0.057412  0.065888  0.047144
ndcg    0.018381  0.086715  0.117825  0.061478
recall  0.018381  0.184302  0.307832  0.105481





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 60, global step 11529: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.055367  0.063749  0.044585
ndcg    0.015897  0.084983  0.115704  0.058542
recall  0.015897  0.183805  0.305680  0.101341





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 61, global step 11718: 'recall@10' reached 0.19225 (best 0.19225), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=61-step=11718.ckpt' as top 1


k              1        10        20         5
map     0.020368  0.059716  0.067775  0.048065
ndcg    0.020368  0.090177  0.119751  0.061545
recall  0.020368  0.192250  0.309654  0.102832





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 62, global step 11907: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020036  0.058592  0.066610  0.047853
ndcg    0.020036  0.087794  0.117068  0.061487
recall  0.020036  0.185296  0.301209  0.103328





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 63, global step 12096: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.057304  0.065813  0.046556
ndcg    0.016559  0.087819  0.118995  0.061253
recall  0.016559  0.189601  0.313297  0.106309





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 64, global step 12285: 'recall@10' reached 0.19258 (best 0.19258), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=64-step=12285.ckpt' as top 1


k              1        10        20         5
map     0.016062  0.056793  0.064921  0.045159
ndcg    0.016062  0.088104  0.117958  0.059803
recall  0.016062  0.192582  0.311144  0.104819





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 65, global step 12474: 'recall@10' reached 0.19639 (best 0.19639), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=65-step=12474.ckpt' as top 1


k              1        10        20         5
map     0.020202  0.061173  0.069347  0.049815
ndcg    0.020202  0.092358  0.122488  0.064585
recall  0.020202  0.196390  0.316278  0.109952





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 66, global step 12663: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.056904  0.065218  0.045769
ndcg    0.015897  0.087837  0.118356  0.060758
recall  0.015897  0.190760  0.311972  0.106806





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 67, global step 12852: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018712  0.057493  0.065642  0.046672
ndcg    0.018712  0.087369  0.117440  0.060856
recall  0.018712  0.187117  0.306839  0.104488





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 68, global step 13041: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018546  0.058986  0.067023  0.047759
ndcg    0.018546  0.089659  0.119183  0.062237
recall  0.018546  0.191919  0.309157  0.106640





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 69, global step 13230: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016062  0.055224  0.063316  0.044378
ndcg    0.016062  0.085434  0.114937  0.058802
recall  0.016062  0.186289  0.303030  0.103163





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 70, global step 13419: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.057426  0.065350  0.046117
ndcg    0.016725  0.088901  0.118084  0.061085
recall  0.016725  0.194072  0.310151  0.107137





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 71, global step 13608: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01689  0.055737  0.064165  0.044569
ndcg    0.01689  0.086396  0.117312  0.059063
recall  0.01689  0.188939  0.311641  0.103825





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 72, global step 13797: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.053764  0.061875  0.042148
ndcg    0.015731  0.084228  0.113811  0.055760
recall  0.015731  0.186455  0.303527  0.097698





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 73, global step 13986: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018215  0.057706  0.065825  0.046801
ndcg    0.018215  0.088086  0.117842  0.061425
recall  0.018215  0.189435  0.307501  0.106475





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 74, global step 14175: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01391  0.053191  0.061477  0.041974
ndcg    0.01391  0.083436  0.113733  0.056058
recall  0.01391  0.184468  0.304521  0.099354





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 75, global step 14364: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020036  0.058526  0.066724  0.047433
ndcg    0.020036  0.088593  0.118394  0.061460
recall  0.020036  0.189104  0.306839  0.104653





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 76, global step 14553: 'recall@10' reached 0.19838 (best 0.19838), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=76-step=14553.ckpt' as top 1


k              1        10        20         5
map     0.016393  0.058749  0.066663  0.046964
ndcg    0.016393  0.090960  0.120020  0.062131
recall  0.016393  0.198377  0.313794  0.108627





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 77, global step 14742: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01689  0.058553  0.066504  0.046777
ndcg    0.01689  0.090448  0.119801  0.061786
recall  0.01689  0.196887  0.313794  0.107965





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 78, global step 14931: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01954  0.060754  0.068455  0.049666
ndcg    0.01954  0.092242  0.120873  0.065020
recall  0.01954  0.197218  0.311641  0.112270





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 79, global step 15120: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016062  0.057839  0.066157  0.046694
ndcg    0.016062  0.088868  0.119397  0.061774
recall  0.016062  0.191919  0.313131  0.107965





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 80, global step 15309: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018877  0.058518  0.066515  0.047262
ndcg    0.018877  0.089015  0.118381  0.061442
recall  0.018877  0.190926  0.307501  0.104984





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 81, global step 15498: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016393  0.057832  0.065657  0.046567
ndcg    0.016393  0.089653  0.118439  0.062130
recall  0.016393  0.195728  0.310151  0.110118





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 82, global step 15687: 'recall@10' reached 0.20152 (best 0.20152), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=82-step=15687.ckpt' as top 1


k             1        10        20         5
map     0.01689  0.058844  0.066714  0.046280
ndcg    0.01689  0.091601  0.120499  0.060642
recall  0.01689  0.201523  0.316278  0.104653





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 83, global step 15876: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.055820  0.064157  0.044226
ndcg    0.016228  0.086760  0.117449  0.058437
recall  0.016228  0.190263  0.312303  0.102169





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 84, global step 16065: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014738  0.056773  0.065146  0.046103
ndcg    0.014738  0.087763  0.118476  0.061809
recall  0.014738  0.190594  0.312469  0.110118





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 85, global step 16254: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.056851  0.065172  0.046426
ndcg    0.017056  0.087304  0.117972  0.061723
recall  0.017056  0.188773  0.310813  0.108958





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 86, global step 16443: 'recall@10' was not in top 1


k            1        10        20         5
map     0.0154  0.056174  0.064064  0.044582
ndcg    0.0154  0.087490  0.116468  0.059140
recall  0.0154  0.192085  0.307170  0.103825





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 87, global step 16632: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.057663  0.065807  0.047006
ndcg    0.016228  0.088800  0.118811  0.062827
recall  0.016228  0.192250  0.311641  0.111608





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 88, global step 16821: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.057026  0.065339  0.045430
ndcg    0.015731  0.089047  0.119903  0.060786
recall  0.015731  0.195893  0.319093  0.108130





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 89, global step 17010: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014738  0.055222  0.063457  0.044447
ndcg    0.014738  0.085939  0.116445  0.059596
recall  0.014738  0.188276  0.309985  0.106309





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 90, global step 17199: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.055745  0.063889  0.045314
ndcg    0.016228  0.085767  0.115914  0.060178
recall  0.016228  0.185792  0.306011  0.105978





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 91, global step 17388: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01275  0.055420  0.063688  0.043995
ndcg    0.01275  0.087574  0.118042  0.059700
recall  0.01275  0.194569  0.315781  0.107965





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 92, global step 17577: 'recall@10' was not in top 1


k              1        10        20         5
map     0.013744  0.056909  0.064797  0.045452
ndcg    0.013744  0.089048  0.117967  0.060958
recall  0.013744  0.196059  0.310813  0.108462





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 93, global step 17766: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.059338  0.067385  0.047792
ndcg    0.018049  0.091522  0.121026  0.063185
recall  0.018049  0.199040  0.316112  0.110614





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 94, global step 17955: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014241  0.056536  0.064899  0.045101
ndcg    0.014241  0.088517  0.119234  0.060670
recall  0.014241  0.195065  0.317105  0.108627





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 95, global step 18144: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01275  0.056052  0.064882  0.044988
ndcg    0.01275  0.087881  0.120126  0.061059
recall  0.01275  0.193410  0.321080  0.110449





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 96, global step 18333: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014572  0.057105  0.065156  0.045617
ndcg    0.014572  0.089551  0.119196  0.061175
recall  0.014572  0.197880  0.315781  0.108958





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 97, global step 18522: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.057380  0.065672  0.046233
ndcg    0.015731  0.089160  0.119384  0.061815
recall  0.015731  0.195065  0.314622  0.109786





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 98, global step 18711: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01689  0.058714  0.067289  0.047574
ndcg    0.01689  0.089875  0.121386  0.062509
recall  0.01689  0.193741  0.318927  0.108296





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 99, global step 18900: 'recall@10' was not in top 1
`Trainer.fit` stopped: `max_epochs=100` reached.


k              1        10        20         5
map     0.015234  0.056919  0.065023  0.045038
ndcg    0.015234  0.089387  0.119351  0.060335
recall  0.015234  0.197880  0.317271  0.107468



Now we can get the best model path stored in the checkpoint callback.

In [22]:
best_model_path = checkpoint_callback.best_model_path
best_model_path

'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=82-step=15687.ckpt'

## Inference

To obtain model scores, we will load the weights from the best checkpoint. To do this, we use the `LightningModule`, provide there the path to the checkpoint and the model instance.

In [26]:
import replay

sasrec = make_sasrec(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
    excluded_features=None,
)

torch.serialization.add_safe_globals([
    replay.nn.lightning.optimizer.OptimizerFactory,
    replay.nn.lightning.scheduler.LRSchedulerFactory,
])

best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)
best_model.eval();

Configure `ParquetModule` for inference

In [27]:
inference_metadata = {"predict": create_meta(shape=MAX_SEQ_LEN)}

parquet_module = ParquetModule(
    predict_path=PREDICT_PATH,
    batch_size=BATCH_SIZE,
    metadata=inference_metadata,
    transforms=transforms,
)

  parquet_module = ParquetModule(


During inference, we can use `TopItemsCallback`. Such callback allows you to get scores for each user throughout the entire catalog and get recommendations in the form of ids of items with the highest score values.


Recommendations can be fetched in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame or raw PyTorch tensors. Each of the types corresponds a callback. In this example, we'll be using the `PandasTopItemsCallback`.

In [28]:
from replay.nn.lightning.callback import PandasTopItemsCallback

csv_logger = CSVLogger(save_dir="sasrec/logs/test", name="SasRec-example")

TOPK = [1, 5, 10, 20]

pandas_prediction_callback = PandasTopItemsCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
)

trainer = L.Trainer(callbacks=[pandas_prediction_callback], logger=csv_logger, inference_mode=True)
trainer.predict(best_model, datamodule=parquet_module, return_predictions=False)

pandas_res = pandas_prediction_callback.get_result()

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Predicting: |          | 0/? [00:00<?, ?it/s]



In [29]:
pandas_res

Unnamed: 0,user_id,item_id,score
0,0,1439,11.456461
0,0,1727,11.241992
0,0,2483,11.100959
0,0,740,10.847163
0,0,354,10.603889
...,...,...,...
6037,6039,1741,5.019725
6037,6039,1656,5.002189
6037,6039,610,4.983106
6037,6039,367,4.927832


### Calculating metrics

*test_gt* is already encoded, so we can use it for computing metrics.

In [30]:
from replay.metrics import MAP, OfflineMetrics, Precision, Recall
from replay.metrics.torch_metrics_builder import metrics_to_df

In [31]:
result_metrics = OfflineMetrics(
    [Recall(TOPK), Precision(TOPK), MAP(TOPK)],
    query_column="user_id",
    rating_column="score",
)(pandas_res, test_gt.explode("item_id"))

In [32]:
metrics_to_df(result_metrics)

k,1,10,20,5
MAP,0.017721,0.057485,0.06543,0.047237
Precision,0.017721,0.018649,0.015079,0.02153
Recall,0.017721,0.186486,0.30159,0.107652


Let's call the `inverse_transform` encoder's function to get the final dataframe with recommendations

In [33]:
encoder.inverse_transform(pandas_res)

Unnamed: 0,user_id,item_id,score
0,1,1566,11.456461
0,1,1907,11.241992
0,1,2687,11.100959
0,1,783,10.847163
0,1,364,10.603889
...,...,...,...
6037,6040,1921,5.019725
6037,6040,1834,5.002189
6037,6040,628,4.983106
6037,6040,377,4.927832
