# Example of the SASRec training and inference stages
Note that all the given examples can be run without using PySpark, using only Pandas

In [1]:
# changing core directory
import os, sys
dir2 = os.path.abspath('')
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path:
    sys.path.append(dir1)
os.chdir('..')

In [2]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["KAGGLE_USERNAME"] = "recsysaccelerate"
os.environ["KAGGLE_KEY"] = "6363e91b656fea576c39e4f55dcc1d00"

In [3]:
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torch

from replay.preprocessing.filters import MinCountFilter
from replay.metrics import OfflineMetrics, Recall, Precision, MAP, NDCG, HitRate, MRR
from replay.metrics.torch_metrics_builder import metrics_to_df
from replay.splitters import LastNSplitter, TimeSplitter
from replay.utils import get_spark_session
from replay.data import (
    FeatureHint,
    FeatureInfo,
    FeatureSchema,
    FeatureSource,
    FeatureType,
    Dataset,
)
from replay.models.nn.optimizer_utils import FatOptimizerFactory
from replay.models.nn.sequential.callbacks import (
    ValidationMetricsCallback,
    SparkPredictionCallback,
    PandasPredictionCallback, 
    TorchPredictionCallback,
    QueryEmbeddingsPredictionCallback
)
from replay.models.nn.sequential.postprocessors import RemoveSeenItems
from replay.data.nn import (
    SequenceTokenizer,
    SequentialDataset,
    TensorFeatureSource,
    TensorSchema,
    TensorFeatureInfo
)
from replay.models.nn.sequential import SasRec
from replay.models.nn.sequential.sasrec import (
    SasRecPredictionDataset,
    SasRecTrainingDataset,
    SasRecValidationDataset,
    SasRecPredictionBatch,
    SasRecModel
)

import pandas as pd

In [4]:
import random
import numpy as np

SEED = 777

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

torch.backends.cudnn.deterministic=True

## Getting a spark session

In [5]:
spark_session = get_spark_session()

/usr/local/lib/python3.11/site-packages/pyspark/bin/load-spark-env.sh: line 68: ps: command not found
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/27 17:55:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/01/27 17:55:23 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
25/01/27 17:55:24 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
25/01/27 17:55:24 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


## Prepare data
### Load raw movielens-1M interactions, item features and user features.
In the current implementation, the SASRec does not take into account the features of items or users. They are only used to get a complete list of users and items.

In [6]:
!pip install rs-datasets

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [7]:
from rs_datasets import MovieLens

In [8]:
movielens = MovieLens("1m")
interactions = movielens.ratings
user_features = movielens.users
item_features = movielens.items

interactions = interactions[interactions['rating'] >= 4]

In [9]:
interactions.head()

Unnamed: 0,user_id,item_id,rating,timestamp
0,1,1193,5,978300760
3,1,3408,4,978300275
4,1,2355,5,978824291
6,1,1287,5,978302039
7,1,2804,5,978300719


In [10]:
user_features.head()

Unnamed: 0,user_id,gender,age,occupation,zip_code
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117
3,4,M,45,7,2460
4,5,M,25,20,55455


In [11]:
item_features.head()

Unnamed: 0,item_id,title,genres
0,1,Toy Story (1995),Animation|Children's|Comedy
1,2,Jumanji (1995),Adventure|Children's|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama
4,5,Father of the Bride Part II (1995),Comedy


In [12]:
interactions.describe()

Unnamed: 0,user_id,item_id,rating,timestamp
count,575281.0,575281.0,575281.0,575281.0
mean,3038.114852,1817.58707,4.39339,972012800.0
std,1733.654125,1083.844338,0.488503,11872690.0
min,1.0,1.0,4.0,956703900.0
25%,1533.0,1022.0,4.0,965263300.0
50%,3080.0,1648.0,4.0,972676000.0
75%,4505.0,2750.0,5.0,975161300.0
max,6040.0,3952.0,5.0,1046455000.0


In [13]:
interactions = MinCountFilter(
    num_entries=5,
    groupby_column='item_id',
).transform(interactions)

interactions = MinCountFilter(
    num_entries=3,
    groupby_column='user_id',
).transform(interactions)

In [14]:
interactions.describe()

Unnamed: 0,user_id,item_id,rating,timestamp
count,574380.0,574380.0,574380.0,574380.0
mean,3038.125194,1817.036676,4.393593,972006800.0
std,1733.79455,1083.628435,0.488547,11863810.0
min,1.0,1.0,4.0,956703900.0
25%,1533.0,1022.0,4.0,965262100.0
50%,3080.0,1645.0,4.0,972675200.0
75%,4505.0,2748.0,5.0,975149400.0
max,6040.0,3952.0,5.0,1046455000.0


Removing duplicates in the timestamp column without changing the original items order where timestamp is the same

In [15]:
# interactions["timestamp"] = interactions["timestamp"].astype("int64")
# interactions = interactions.sort_values(by="timestamp")
# interactions["timestamp"] = interactions.groupby("user_id").cumcount()
# interactions

### Split interactions into the train, validation and test datasets using LastNSplitter

In [16]:
# splitter = LastNSplitter(
#     N=1,
#     divide_column="user_id",
#     query_column="user_id",
#     strategy="interactions",
# )

# raw_test_events, raw_test_gt = splitter.split(interactions)
# raw_validation_events, raw_validation_gt = splitter.split(raw_test_events)
# raw_train_events = raw_validation_events


splitter = TimeSplitter(
    time_threshold=0.1,
    drop_cold_users=True,
    drop_cold_items=True,
    item_column='item_id',
    query_column='user_id',
    timestamp_column='timestamp',
)

# train_events, validation_events, validation_gt, test_events, test_gt = (
#     _split_data(splitter, interactions)
# )

raw_test_events, raw_test_gt = splitter.split(interactions)
raw_validation_events, raw_validation_gt = splitter.split(raw_test_events)
raw_train_events = raw_validation_events

raw_test_gt = raw_test_gt[raw_test_gt['item_id'].isin(raw_train_events['item_id'])]
raw_test_gt = raw_test_gt[raw_test_gt['user_id'].isin(raw_train_events['user_id'])]

In [17]:
raw_train_events['item_id'].nunique()

3120

In [18]:
def test_splitting(events, gt, name=''):
    if events['timestamp'].max() > gt['timestamp'].min():
        print("Problem with time points in", name)
    if len(set(gt['user_id'].unique().tolist()) - set(events['user_id'].unique().tolist())) > 0:
        print("Problem with cold users in", name)
    if len(set(gt['item_id'].unique().tolist()) - set(events['item_id'].unique().tolist())) > 0:
        print("Problem with cold items in", name)


test_splitting(raw_train_events, raw_test_gt, "train events, test gt")
test_splitting(raw_train_events, raw_validation_gt, "train events, valid gt")

### Prepare FeatureSchema required to create Dataset

In [19]:
def prepare_feature_schema(is_ground_truth: bool) -> FeatureSchema:
    base_features = FeatureSchema(
        [
            FeatureInfo(
                column="user_id",
                feature_hint=FeatureHint.QUERY_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
            FeatureInfo(
                column="item_id",
                feature_hint=FeatureHint.ITEM_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
        ]
    )
    if is_ground_truth:
        return base_features

    all_features = base_features + FeatureSchema(
        [
            FeatureInfo(
                column="timestamp",
                feature_type=FeatureType.NUMERICAL,
                feature_hint=FeatureHint.TIMESTAMP,
            ),
        ]
    )
    return all_features

### Create Dataset for the training stage

In [20]:
user_features = None
item_features = None

In [21]:
train_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_train_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)

### Create Datasets (events and ground_truth) for the validation stage

In [22]:
validation_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_validation_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
validation_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_validation_gt,
    check_consistency=True,
    categorical_encoded=False,
)

### Create Datasets (events and ground_truth) for the testing stage

In [23]:
test_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_test_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
test_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_test_gt,
    check_consistency=True,
    categorical_encoded=False,
)

### Create the tensor schema
A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model

In [24]:
ITEM_FEATURE_NAME = "item_id_seq"

tensor_schema = TensorSchema(
    TensorFeatureInfo(
        name=ITEM_FEATURE_NAME,
        is_seq=True,
        feature_type=FeatureType.CATEGORICAL,
        feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, train_dataset.feature_schema.item_id_column)],
        feature_hint=FeatureHint.ITEM_ID,
    )
)

### Create sequential datasets using SequenceTokenizer
The SequentialDataset internally store data in the form of sequences of items sorted by increasing interaction time (timestamp). A SequenceTokenizer is used to convert to this format. In addition, the SequenceTokenizer encodes all categorical columns from the source dataset and stores mapping inside itself.
SequentialDataset.keep_common_query_ids is used to leave only sequences from the same users

In [25]:
tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True, handle_unknown_rule="drop")
tokenizer.fit(train_dataset)

sequential_train_dataset = tokenizer.transform(train_dataset)

sequential_validation_dataset = tokenizer.transform(validation_dataset)
sequential_validation_gt = tokenizer.transform(validation_gt, [tensor_schema.item_id_feature_name])

sequential_validation_dataset, sequential_validation_gt = SequentialDataset.keep_common_query_ids(
    sequential_validation_dataset, sequential_validation_gt
)

In [26]:
test_query_ids = test_gt.query_ids
test_query_ids_np = tokenizer.query_id_encoder.transform(test_query_ids)["user_id"].values
sequential_test_dataset = tokenizer.transform(test_dataset).filter_by_query_id(test_query_ids_np)

In [27]:
print('SCHEMA CARDINALITY', tensor_schema.item_id_features.item().cardinality)

SCHEMA CARDINALITY 3120


You can get the user and item mapping and inverse mapping as follows

In [28]:
print(tokenizer.query_id_encoder.mapping, tokenizer.query_id_encoder.inverse_mapping)
print(tokenizer.item_id_encoder.mapping, tokenizer.item_id_encoder.inverse_mapping)

print(len(tokenizer.query_id_encoder.mapping["user_id"]), len(tokenizer.item_id_encoder.mapping['item_id']))

{'user_id': {6040: 0, 6039: 1, 6038: 2, 6037: 3, 6036: 4, 6035: 5, 6034: 6, 6033: 7, 6032: 8, 6031: 9, 6030: 10, 6029: 11, 6028: 12, 6027: 13, 6026: 14, 6025: 15, 6024: 16, 6023: 17, 6022: 18, 6021: 19, 6020: 20, 6019: 21, 6018: 22, 6017: 23, 6016: 24, 6015: 25, 6014: 26, 6013: 27, 6012: 28, 6011: 29, 6010: 30, 6009: 31, 6007: 32, 6008: 33, 6006: 34, 6005: 35, 6004: 36, 6003: 37, 6002: 38, 6001: 39, 6000: 40, 5999: 41, 5998: 42, 5997: 43, 5996: 44, 5995: 45, 5994: 46, 5993: 47, 5992: 48, 5991: 49, 5990: 50, 5989: 51, 5988: 52, 5987: 53, 5986: 54, 5984: 55, 5983: 56, 5982: 57, 5981: 58, 5979: 59, 5980: 60, 5977: 61, 5976: 62, 5975: 63, 5974: 64, 5973: 65, 5972: 66, 5971: 67, 5978: 68, 5970: 69, 5969: 70, 5968: 71, 5967: 72, 5966: 73, 5965: 74, 5964: 75, 5963: 76, 5962: 77, 5961: 78, 5960: 79, 5959: 80, 5958: 81, 5957: 82, 5956: 83, 5955: 84, 5954: 85, 5953: 86, 5952: 87, 5951: 88, 5950: 89, 5948: 90, 5947: 91, 5946: 92, 5945: 93, 5944: 94, 5943: 95, 5942: 96, 5941: 97, 5940: 98, 5939: 9

In [29]:
tokenizer.item_id_encoder.mapping['item_id'][1961]

3

## Train model
### Create SASRec model instance and run the training stage using lightning
After each epoch validation metrics are shown. You can change the list of validation metrics in ValidationMetricsCallback
The model is determined to be the best and is saved if the metric updates its maximum during validation (see the ModelCheckpoint)

In [30]:
from replay.models.nn.sequential import SasRec
from replay.models.nn.sequential.sasrec import (
    SasRecPredictionDataset,
    SasRecTrainingDataset,
    SasRecValidationDataset,
    SasRecPredictionBatch,
    SasRecModel
)

%load_ext autoreload
%autoreload 2

In [None]:
MAX_SEQ_LEN = 100
BATCH_SIZE = 512
NUM_WORKERS = 4

model = SasRec(
    tensor_schema,
    block_count=2,
    head_count=2,
    max_seq_len=MAX_SEQ_LEN,
    hidden_size=128,
    dropout_rate=0.2,
    optimizer_factory=FatOptimizerFactory(learning_rate=0.001),
    loss_type="CE",
    #loss_sample_count=200,
    n_buckets = int(2 * (96 * BATCH_SIZE) ** 0.5),
    bucket_size_x = int(2 * (96 * BATCH_SIZE) ** 0.5),
    bucket_size_y = 256,
    mix_x = True,
)

csv_logger = CSVLogger(save_dir=".logs/train", name="SASRec_example")

checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints",
    save_top_k=1,
    verbose=True,
    # if you use multiple dataloaders, then add the serial number of the dataloader to the suffix of the metric name.
    # For example,"recall@10/dataloader_idx_0"
    monitor="recall@10",
    mode="max",
)

validation_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset.item_count,
    postprocessors=[RemoveSeenItems(sequential_validation_dataset)]
)

trainer = L.Trainer(
    max_epochs=20,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
    devices=[0]
)

train_dataloader = DataLoader(
    dataset=SasRecTrainingDataset(
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

validation_dataloader = DataLoader(
    dataset=SasRecValidationDataset(
        sequential_validation_dataset,
        sequential_validation_gt,
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=validation_dataloader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX 6000 Ada Generation') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/usr/local/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /root/RePlay-Accelerated/.checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type             | Params
--------------------------------------------
0 | _model | SasRecModel      | 611 K 
1 | _loss  | CrossEntropyLoss | 0     
--------------------------------------------
611 K     Trainable params
0         Non-trainable

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

k              1        10        20         5
map     0.006696  0.003306  0.002282  0.005112
ndcg    0.006696  0.010118  0.010047  0.010166
recall  0.000105  0.001853  0.005198  0.000840



/usr/local/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (11) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 11: 'recall@10' reached 0.04044 (best 0.04044), saving model to '/root/RePlay-Accelerated/.checkpoints/epoch=0-step=11-v5.ckpt' as top 1


k              1        10        20         5
map     0.095982  0.036956  0.033033  0.047096
ndcg    0.095982  0.078761  0.086332  0.076952
recall  0.004679  0.040440  0.077076  0.018271



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 22: 'recall@10' reached 0.04322 (best 0.04322), saving model to '/root/RePlay-Accelerated/.checkpoints/epoch=1-step=22.ckpt' as top 1


k              1        10        20         5
map     0.089286  0.050531  0.045823  0.060119
ndcg    0.089286  0.093248  0.103265  0.093608
recall  0.004684  0.043218  0.086430  0.029833



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 33: 'recall@10' was not in top 1


k              1        10        20         5
map     0.109375  0.055342  0.045655  0.068616
ndcg    0.109375  0.101260  0.101850  0.105051
recall  0.004743  0.039543  0.074690  0.018832



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 44: 'recall@10' reached 0.04350 (best 0.04350), saving model to '/root/RePlay-Accelerated/.checkpoints/epoch=3-step=44.ckpt' as top 1


k              1        10        20         5
map     0.116071  0.057034  0.049169  0.070110
ndcg    0.116071  0.103348  0.107832  0.105550
recall  0.006214  0.043501  0.082212  0.021498



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 55: 'recall@10' was not in top 1


k              1        10        20         5
map     0.109375  0.057341  0.049678  0.069531
ndcg    0.109375  0.103229  0.106813  0.106499
recall  0.006921  0.043453  0.078440  0.030104



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 66: 'recall@10' reached 0.04638 (best 0.04638), saving model to '/root/RePlay-Accelerated/.checkpoints/epoch=5-step=66.ckpt' as top 1


k              1        10        20         5
map     0.111607  0.058468  0.051022  0.070521
ndcg    0.111607  0.105421  0.109582  0.104608
recall  0.009249  0.046380  0.087136  0.022285



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 77: 'recall@10' reached 0.05753 (best 0.05753), saving model to '/root/RePlay-Accelerated/.checkpoints/epoch=6-step=77-v2.ckpt' as top 1


k              1        10        20         5
map     0.109375  0.057464  0.049181  0.066877
ndcg    0.109375  0.103134  0.104689  0.098627
recall  0.007547  0.057527  0.090697  0.027195



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 88: 'recall@10' was not in top 1


k              1        10        20         5
map     0.098214  0.055525  0.048982  0.066236
ndcg    0.098214  0.099453  0.102820  0.100479
recall  0.005201  0.057475  0.092279  0.039061



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 99: 'recall@10' reached 0.06490 (best 0.06490), saving model to '/root/RePlay-Accelerated/.checkpoints/epoch=8-step=99.ckpt' as top 1


k             1        10        20         5
map     0.09375  0.056943  0.050902  0.063475
ndcg    0.09375  0.101555  0.105304  0.096446
recall  0.00908  0.064902  0.102572  0.037663



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 110: 'recall@10' reached 0.06850 (best 0.06850), saving model to '/root/RePlay-Accelerated/.checkpoints/epoch=9-step=110-v1.ckpt' as top 1


k              1        10        20         5
map     0.091518  0.056056  0.051681  0.063724
ndcg    0.091518  0.098843  0.102281  0.096318
recall  0.011978  0.068503  0.093092  0.042098



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 121: 'recall@10' was not in top 1


k              1        10        20         5
map     0.095982  0.054569  0.050457  0.061923
ndcg    0.095982  0.096253  0.101312  0.093349
recall  0.011280  0.066975  0.093046  0.040207



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 132: 'recall@10' was not in top 1


k              1        10        20         5
map     0.087054  0.049129  0.045612  0.055997
ndcg    0.087054  0.089698  0.094461  0.086609
recall  0.009969  0.061349  0.090519  0.038085



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 143: 'recall@10' was not in top 1


k              1        10        20         5
map     0.095982  0.047367  0.044209  0.054926
ndcg    0.095982  0.088443  0.095508  0.086795
recall  0.008243  0.056562  0.098036  0.037426



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 154: 'recall@10' was not in top 1


k              1        10        20         5
map     0.084821  0.046866  0.042789  0.052917
ndcg    0.084821  0.087881  0.093568  0.082287
recall  0.007410  0.060278  0.097294  0.032833



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14, global step 165: 'recall@10' was not in top 1


k              1        10        20         5
map     0.075893  0.047987  0.044806  0.054156
ndcg    0.075893  0.086810  0.093272  0.080617
recall  0.011244  0.062316  0.095711  0.034930



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15, global step 176: 'recall@10' was not in top 1


k              1        10        20         5
map     0.069196  0.045706  0.042408  0.052232
ndcg    0.069196  0.084517  0.090825  0.079590
recall  0.008019  0.063526  0.094869  0.033026



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16, global step 187: 'recall@10' was not in top 1


k              1        10        20         5
map     0.064732  0.045957  0.041965  0.051324
ndcg    0.064732  0.084143  0.089757  0.078330
recall  0.006904  0.060641  0.091488  0.033610



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17, global step 198: 'recall@10' was not in top 1


k              1        10        20         5
map     0.064732  0.043721  0.041299  0.049087
ndcg    0.064732  0.080056  0.088916  0.075495
recall  0.007592  0.056233  0.093627  0.031715



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18, global step 209: 'recall@10' was not in top 1


k              1        10        20         5
map     0.066964  0.045481  0.042078  0.049824
ndcg    0.066964  0.081529  0.090055  0.073251
recall  0.009820  0.058397  0.096188  0.026484



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19, global step 220: 'recall@10' was not in top 1
`Trainer.fit` stopped: `max_epochs=20` reached.


k              1        10        20         5
map     0.060268  0.043790  0.040391  0.047448
ndcg    0.060268  0.079434  0.085903  0.072065
recall  0.009205  0.055097  0.086607  0.029633



The path to the best model is saved inside checkpoint_callback

In [32]:
int(2 * (96 * BATCH_SIZE) ** 0.5)

443

In [33]:
best_model = SasRec.load_from_checkpoint(checkpoint_callback.best_model_path)

## Inference stage
### Prepare Dataloader and logger

In [34]:
prediction_dataloader = DataLoader(
    dataset=SasRecPredictionDataset(
        sequential_test_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

csv_logger = CSVLogger(save_dir=".logs/test", name="SASRec_example")

### Run inference
You can get the recommendations in three formats: PySpark DataFrame, Pandas DataFrame, PyTorch tensors. Each of the types corresponds a callback
You can filter the results using postprocessors strategy. For example the RemoveSeenItems postprocessor is filtering out the items that already have been seen in test dataset
You don't need to use all three callbacks. This is shown only for example

Also, you can get user embeddings, that were used to perform predictions, using `get_query_embedding` method inside SasRecModel or `QueryEmbeddingsPredictionCallback` for lightning module.

In [35]:
TOPK = [1, 10, 20, 100]

postprocessors = [RemoveSeenItems(sequential_test_dataset)]

spark_prediction_callback = SparkPredictionCallback(
    spark_session=spark_session,
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

pandas_prediction_callback = PandasPredictionCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

torch_prediction_callback = TorchPredictionCallback(
    top_k=max(TOPK),
    postprocessors=postprocessors,
)

query_embeddings_callback = QueryEmbeddingsPredictionCallback()

trainer = L.Trainer(
    callbacks=[
        spark_prediction_callback,
        pandas_prediction_callback,
        torch_prediction_callback,
        query_embeddings_callback
    ], 
    logger=csv_logger, 
    inference_mode=True
)
trainer.predict(best_model, dataloaders=prediction_dataloader, return_predictions=False)

spark_res = spark_prediction_callback.get_result()
pandas_res = pandas_prediction_callback.get_result()
torch_user_ids, torch_item_ids, torch_scores = torch_prediction_callback.get_result()
user_embeddings = query_embeddings_callback.get_result()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

25/01/27 17:55:52 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.
25/01/27 17:55:52 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.
25/01/27 17:55:52 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.


In [36]:
spark_res.show()

[Stage 0:>                                                          (0 + 1) / 1]

+-------+-------+------------------+
|user_id|item_id|             score|
+-------+-------+------------------+
|   5122|   1003|3.1973793506622314|
|   5122|    705|2.9980478286743164|
|   5122|    610| 2.967081308364868|
|   5122|   1699|2.9271583557128906|
|   5122|   1515|2.9029107093811035|
|   5122|    729|2.8825559616088867|
|   5122|    632|2.8579301834106445|
|   5122|    756|2.8310935497283936|
|   5122|    929|2.8010332584381104|
|   5122|   1569| 2.781919240951538|
|   5122|    442|2.7577714920043945|
|   5122|    982|2.7440521717071533|
|   5122|    928| 2.733532667160034|
|   5122|    749|2.7306714057922363|
|   5122|    663|2.7242941856384277|
|   5122|   1149|2.7208995819091797|
|   5122|   1931|2.7192840576171875|
|   5122|    680|2.6880593299865723|
|   5122|   1098|2.6808016300201416|
|   5122|   1615| 2.665447473526001|
+-------+-------+------------------+
only showing top 20 rows



                                                                                

In [37]:
pandas_res

Unnamed: 0,user_id,item_id,score
0,5122,1003,3.197379
0,5122,705,2.998048
0,5122,610,2.967081
0,5122,1699,2.927158
0,5122,1515,2.902911
...,...,...,...
998,89,277,1.891246
998,89,581,1.890986
998,89,598,1.890561
998,89,812,1.876812


In [38]:
print(torch_user_ids[0], torch_item_ids[0], torch_scores[0])

tensor(5122) tensor([1003,  705,  610, 1699, 1515,  729,  632,  756,  929, 1569,  442,  982,
         928,  749,  663, 1149, 1931,  680, 1098, 1615,  976,  606, 1505, 1511,
         959,  391,  629,  995,  662, 2109, 1933,  938,  738,  931, 1512,  753,
        1620,  380,  910, 1738, 1102,  731, 1700, 1233,  634,  837, 2143, 1084,
         679, 2117, 1131, 1426, 1127, 1650,  752,  728,  107,  665, 1139, 1513,
        1015, 1455,  671,  437, 2454, 1202,  310, 1940, 1835,  733,  776,  681,
        1603, 1539,  927, 1160,  102,  704, 1002, 1101,  863, 1018,  128, 1247,
        1601,  631,  707, 1736,  205,  524,  992, 1377, 2359, 1027, 1510, 1148,
        1568,  727,  254, 1096]) tensor([3.1974, 2.9980, 2.9671, 2.9272, 2.9029, 2.8826, 2.8579, 2.8311, 2.8010,
        2.7819, 2.7578, 2.7441, 2.7335, 2.7307, 2.7243, 2.7209, 2.7193, 2.6881,
        2.6808, 2.6654, 2.6430, 2.6403, 2.6007, 2.5821, 2.5531, 2.5447, 2.5431,
        2.5404, 2.5298, 2.5248, 2.5243, 2.5214, 2.5206, 2.5173, 2.5110, 2.

Suppose we want to get the recomendations in PySpark format. 
Let's get the inverse representation of labels using inverse_transform method.

Note that the reverse representation can only be obtained for PySpark and Pandas formats. When working with PyTorch tensors, the reverse representation must be done manually

In [39]:
recommendations = tokenizer.query_and_item_id_encoder.inverse_transform(pandas_res)

In [40]:
recommendations = tokenizer.query_and_item_id_encoder.inverse_transform(spark_res)

In [41]:
recommendations.show()



+------------------+-------+-------+
|             score|user_id|item_id|
+------------------+-------+-------+
|3.2674994468688965|   5872|   3300|
| 3.212432384490967|   5872|   1198|
| 3.212956666946411|   5872|   1617|
|  3.44307279586792|   5872|    593|
| 3.339540481567383|   5872|   3081|
| 4.250976085662842|   5872|   3555|
| 3.331231117248535|   5872|   3948|
| 3.638554573059082|   5872|   2628|
| 4.071298599243164|   5872|   3755|
|  4.79286003112793|   5872|   3753|
|  3.43864107131958|   5872|   2959|
| 3.222958564758301|   5872|   3176|
|3.2520573139190674|   5872|   3160|
|3.2733778953552246|   5872|    527|
|4.3250226974487305|   5872|   3624|
| 3.603006601333618|   5872|    110|
|3.6212527751922607|   5872|   3408|
| 4.582570552825928|   5872|   3623|
| 3.708123207092285|   5872|   3510|
| 3.668764352798462|   5872|    480|
+------------------+-------+-------+
only showing top 20 rows



                                                                                

### Calculating metrics

In [42]:
init_args = {"query_column": "user_id", "rating_column": "score"}

In [43]:
result_metrics = OfflineMetrics(
    [Recall(TOPK), Precision(TOPK), MAP(TOPK), NDCG(TOPK), MRR(TOPK), HitRate(TOPK)], **init_args
)(recommendations.toPandas(), raw_test_gt)

                                                                                

In [44]:
metrics_to_df(result_metrics)

k,1,10,100,20
HitRate,0.159159,0.513514,0.82983,0.627628
MAP,0.159159,0.068691,0.050072,0.055355
MRR,0.159159,0.255435,0.268785,0.263495
NDCG,0.159159,0.133537,0.16491,0.131021
Precision,0.159159,0.11982,0.073894,0.106907
Recall,0.008108,0.047615,0.229246,0.079231


### User embeddings

Got 6040 x 300 user embeddings, because among all 12 batches: 

11 batches contains 512 samples

1 batch contains 408 left samples

11 * 512 + 408 == 6040

In [45]:
user_embeddings

tensor([[-1.3588,  0.6115, -0.5583,  ...,  1.2862, -1.5047,  0.0346],
        [-3.8591,  1.4096,  1.4237,  ..., -0.6157, -2.2586,  0.0914],
        [-0.5016, -0.3925,  0.6858,  ..., -0.4272, -0.7225, -0.6103],
        ...,
        [-0.6518,  1.9461, -0.6579,  ...,  2.1331, -2.5580,  0.1644],
        [-3.2403,  1.4374,  1.9984,  ..., -1.9835, -0.4550,  0.1752],
        [ 1.2360,  0.9704, -0.7481,  ...,  1.6565, -1.7223, -0.3767]],
       device='cuda:0')

In [46]:
user_embeddings.shape

torch.Size([999, 128])

You can access user embeddings directly with `SasRecModel` class

In [47]:
device = "cuda" if torch.cuda.is_available() else "cpu"

core_model = SasRecModel(
    tensor_schema,
    num_blocks=2,
    num_heads=2,
    max_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout=0.5
)
core_model.eval()
core_model = core_model.to(device)

# Get first batch of data 
data = next(iter(prediction_dataloader))
tensor_map, padding_mask = data.features, data.padding_mask

# Ensure everything is on the same device
padding_mask = padding_mask.to(device)
tensor_map["item_id_seq"] = tensor_map["item_id_seq"].to(device)

# Get user embeddings
user_embeddings_batch = core_model.get_query_embeddings(tensor_map, padding_mask)
user_embeddings_batch

tensor([[-0.0135, -1.8645, -0.5616,  ..., -1.4365, -0.4063,  0.5429],
        [-0.5126,  1.0381,  0.4256,  ..., -0.9326, -0.6378,  0.1428],
        [ 1.5835,  0.1055,  0.1863,  ..., -2.0653, -0.3098, -0.5548],
        ...,
        [-0.9978, -1.2346,  0.3958,  ..., -1.7343,  1.3727, -0.2931],
        [ 1.0758, -2.0191, -0.0468,  ...,  0.2012, -0.7288, -0.3621],
        [-1.1641, -1.2004,  0.2267,  ..., -0.1463,  0.2319, -0.9344]],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [48]:
user_embeddings_batch.shape

torch.Size([512, 300])

### Item embeddings

`get_all_embeddings()` method in transformers can be used to get copies of all embeddings that are presented in model as a dict.

In [49]:
all_embeddings = best_model.get_all_embeddings()
all_embeddings

{'item_embedding': tensor([[-0.0345, -0.0091,  0.0657,  ...,  0.0364, -0.0691, -0.0106],
         [-0.0494, -0.0726,  0.0483,  ..., -0.0268, -0.0485, -0.0410],
         [ 0.0141, -0.0136, -0.0091,  ..., -0.0315,  0.0065, -0.0430],
         ...,
         [ 0.0829, -0.0413, -0.0633,  ...,  0.0596,  0.0755,  0.0625],
         [-0.0071, -0.0068,  0.0063,  ...,  0.0582,  0.0791,  0.0534],
         [ 0.0530, -0.0474, -0.0015,  ...,  0.0267,  0.0295,  0.0558]]),
 'positional_embedding': tensor([[ 0.0048, -0.1108,  0.0972,  ..., -0.1670,  0.0168,  0.0838],
         [-0.0271,  0.0556,  0.0412,  ...,  0.0855, -0.0119, -0.0321],
         [-0.0476,  0.0519,  0.0134,  ..., -0.0149, -0.0687, -0.0949],
         ...,
         [-0.1944,  0.0821, -0.1507,  ...,  0.0389,  0.0946,  0.1570],
         [-0.0951,  0.0570,  0.1439,  ...,  0.0934,  0.0340,  0.0887],
         [-0.0529,  0.0834,  0.0725,  ..., -0.0190,  0.0494, -0.0178]])}

You can access item embeddings from this dictionary

In [50]:
item_embeddings = all_embeddings["item_embedding"]
item_embeddings

tensor([[-0.0345, -0.0091,  0.0657,  ...,  0.0364, -0.0691, -0.0106],
        [-0.0494, -0.0726,  0.0483,  ..., -0.0268, -0.0485, -0.0410],
        [ 0.0141, -0.0136, -0.0091,  ..., -0.0315,  0.0065, -0.0430],
        ...,
        [ 0.0829, -0.0413, -0.0633,  ...,  0.0596,  0.0755,  0.0625],
        [-0.0071, -0.0068,  0.0063,  ...,  0.0582,  0.0791,  0.0534],
        [ 0.0530, -0.0474, -0.0015,  ...,  0.0267,  0.0295,  0.0558]])

Item embeddings shape is (N_ITEMS, HIDDEN_SIZE)

In [51]:
item_embeddings.shape

torch.Size([3120, 128])

Ensure we got correct dimension and ensure we got the copy of tensor

In [52]:
assert item_embeddings.shape[0] == len(tokenizer.item_id_encoder.mapping["item_id"])
assert id(item_embeddings) != id(best_model._model.item_embedder.item_emb.weight.data)

For example we observe one new item id in our training data. We can easily expand our item embedder by one element

In order to expand item embeddings by new size `set_item_embeddings_by_size` method is applied

In [53]:
best_model.set_item_embeddings_by_size(item_embeddings.shape[0] + 1)

Now our new item embeddings have one extra embedding

In [54]:
new_size = best_model.get_all_embeddings()["item_embedding"].shape[0]
old_size = item_embeddings.shape[0]

assert new_size == old_size + 1

Alternatively, we can pass our item embeddings that replace the existing ones by calling `set_item_embeddings_by_tensor`.

If tensor contains new items, they will be added to item embedder.

In [55]:
new_embeddings_weights = torch.rand((new_size + 1, 300))    # randint used for example only

best_model.set_item_embeddings_by_tensor(new_embeddings_weights)

ValueError: Input tensor second dimension doesn't match model hidden size

At the moment we expanded our item embeddings by one more item and replace weights by passing `new_embeddings_weights`

In [None]:
old_size = new_size
new_size = best_model.get_all_embeddings()["item_embedding"].shape[0]

assert new_size == old_size + 1

Similarly, we can append tensor for only new items with no replace for existing by calling `append_item_embeddings`

In [None]:
new_item_weights = torch.rand((1, 300))    # randint used for example only

best_model.append_item_embeddings(new_item_weights)

We passed one new example and its weights to item embeddings, thus expanded our vocabulary by one item again

In [None]:
old_size = new_size
new_size = best_model.get_all_embeddings()["item_embedding"].shape[0]

assert new_size == old_size + 1

## Example of launching an inference for a single user without using a trainer (in order to speed up)
An example for the production of an online script

Let's assume that the user's sequence consisted of a sequence of items [1, 2, 3, 4, 5]. 
Сreate a padding mask corresponding to the sequence of items.

It is important to take only the latest MAX_SEQ_LEN or less items.

In [None]:
item_sequence = torch.arange(1, 5).unsqueeze(0)[:, -MAX_SEQ_LEN:]
padding_mask = torch.ones_like(item_sequence, dtype=torch.bool)
sequence_item_count = item_sequence.shape[1]

### Wrapping created tensors in the SasRecPredictionBatch entity

In [None]:
batch = SasRecPredictionBatch(
    query_id=torch.arange(0, item_sequence.shape[0], 1).long(),
    padding_mask=padding_mask.bool(),
    features={ITEM_FEATURE_NAME: item_sequence.long()}
)

### Run predict step of the SasRec and get scores from the model

In [None]:
with torch.no_grad():
    scores = best_model.predict_step(batch, 0)
scores

tensor([[ 8.7962,  5.1799, -2.5661,  ...,  7.6662, -0.4550, -6.3377]])

### Getting three items with the highest score

In [None]:
torch.topk(scores, k=3).indices

tensor([[3249, 1365,  561]])