# SASRec (Self-Attentive Sequential Recommendation)

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from tecd_retail_recsys.data import DataPreprocessor
from tecd_retail_recsys.metrics import calculate_metrics
from tecd_retail_recsys.models import SASRec

In [15]:
# Step 1: run the `prepare` stage of run_sasrec.py, which preprocesses the
# raw events for SASRec (see the logged output of this cell).
!python3 run_sasrec.py prepare

STEP 1: Preparing data for SASRec
Running data preprocessing...
Starting data preprocessing...
Loading events from t_ecd_small_partial/dataset/small/retail/events
Loaded 236,479,226 total events
Loading items data from t_ecd_small_partial/dataset/small/retail/items.pq
Loaded 250,171 items with features: ['item_id', 'item_brand_id', 'item_category', 'item_subcategory', 'item_price', 'item_embedding']
Merged item features. Data shape: (236479226, 12)
Filtered to 3,758,762 events with action_type='added-to-cart'
After filtering (min_user_interactions=1, min_item_interactions=20): 3,249,972 events, 84,944 users, 30,954 items
Created mappings: 84944 users, 30954 items
Temporal split - Train: days < 1269 (902,543 events), Val: days 1269-1288 (228,339 events), Test: days >= 1289 (223,395 events)
Users in each part (train, val, test) - 7425
Grouping data by users...
Saving grouped data to processed_data/...
Saved data for 84944 users and 30954 items
Train: 7425 users
Val: 7425 users
Test: 7425

In [21]:
# Train SASRec through the package's training script as experiment `sasrec_v2`.
# NOTE(review): --max_seq_len=64 is repeated when the model is constructed for
# inference further down; presumably the two values must agree -- confirm.
!python3 tecd_retail_recsys/models/sasrec/train.py --exp_name=sasrec_v2 --num_epochs=100 --device=mps --batch_size=128 --max_seq_len=64 --embedding_dim=128

[2026-02-17 15:03:32] [DEBUG]: Loading preprocessed data...
[2026-02-17 15:03:32] [DEBUG]: Loaded data: 7425 users, 30954 items
[2026-02-17 15:03:32] [DEBUG]: Preprocessing data for SASRec...
[2026-02-17 15:03:32] [DEBUG]: Preprocessing data has finished!
[2026-02-17 15:03:32] [DEBUG]: Start training...
[2026-02-17 15:03:32] [DEBUG]: Start epoch 1
[2026-02-17 15:03:36] [DEBUG]: Epoch 1 completed. Average loss: 0.6270
[2026-02-17 15:03:36] [DEBUG]: Start epoch 2
[2026-02-17 15:03:40] [DEBUG]: Epoch 2 completed. Average loss: 0.5474
[2026-02-17 15:03:40] [DEBUG]: Start epoch 3
[2026-02-17 15:03:44] [DEBUG]: Epoch 3 completed. Average loss: 0.5256
[2026-02-17 15:03:44] [DEBUG]: Start epoch 4
[2026-02-17 15:03:48] [DEBUG]: Epoch 4 completed. Average loss: 0.4886
[2026-02-17 15:03:48] [DEBUG]: Start epoch 5
[2026-02-17 15:03:52] [DEBUG]: Epoch 5 completed. Average loss: 0.4396
[2026-02-17 15:03:52] [DEBUG]: Start epoch 6
[2026-02-17 15:03:56] [DEBUG]: Epoch 6 completed. Average loss: 0.3900

In [None]:
# Inference-time settings for the SASRec wrapper, gathered in one place so
# they are easy to compare with the training command's flags.
sasrec_init_kwargs = {
    # presumably the best-epoch weights saved by the `sasrec_v2` run -- confirm
    'checkpoint_path': 'checkpoints/sasrec_v2_best_state.pth',
    # same sequence length as the --max_seq_len flag of the training command
    'max_seq_len': 64,
    'device': 'cpu',
    'batch_size': 128,
}
sasrec = SASRec(**sasrec_init_kwargs)

# Populate `sasrec.model` before it is inspected or used for prediction.
sasrec.load_model()

In [23]:
# Inspect the loaded network: as the cell's last expression, its module tree
# is displayed as the cell output.
sasrec.model

SASRecEncoder(
  (_item_embeddings): Embedding(30955, 128, padding_idx=30954)
  (_position_embeddings): Embedding(64, 128)
  (_layernorm): LayerNorm((128,), eps=1e-09, elementwise_affine=True)
  (_dropout): Dropout(p=0.2, inplace=False)
  (_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-09, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-09, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
        (activation): GELU(approximate='none')
      )
    )
  )
)

In [24]:
# Rebuild the temporal train/val/test split inside the notebook. The day
# window and filter thresholds below produce the same split that the
# `prepare` step logged (same user/item counts and split days).
dp = DataPreprocessor(
    day_begin=1082,
    day_end=1308,
    val_days=20,
    test_days=20,
    min_user_interactions=1,
    min_item_interactions=20,
)
train_df, val_df, test_df = dp.preprocess()

# Group the three splits into a single frame (7425 rows in the recorded run).
joined = dp.get_grouped_data(train_df, val_df, test_df)

# presumably each cell holds a per-user sequence, so `+` concatenates the
# train and val histories element-wise -- confirm against get_grouped_data
joined['train_val_interactions'] = joined['train_interactions'] + joined['val_interactions']
print(joined.shape)

Starting data preprocessing...
Loading events from t_ecd_small_partial/dataset/small/retail/events
Loaded 236,479,226 total events
Loading items data from t_ecd_small_partial/dataset/small/retail/items.pq
Loaded 250,171 items with features: ['item_id', 'item_brand_id', 'item_category', 'item_subcategory', 'item_price', 'item_embedding']
Merged item features. Data shape: (236479226, 12)
Filtered to 3,758,762 events with action_type='added-to-cart'
After filtering (min_user_interactions=1, min_item_interactions=20): 3,249,972 events, 84,944 users, 30,954 items
Created mappings: 84944 users, 30954 items
Temporal split - Train: days < 1269 (902,543 events), Val: days 1269-1288 (228,339 events), Test: days >= 1289 (223,395 events)
Users in each part (train, val, test) - 7425
(7425, 5)


In [25]:
# Generate recommendations (same procedure as for the other models)
# NOTE(review): assigning a pd.Series to a DataFrame column aligns on index,
# so this presumes `predictions` lines up with `joined`'s index (e.g. a list
# in row order with a default RangeIndex) -- confirm what `predict` returns.
predictions = sasrec.predict(joined, topn=100, return_scores=False)
joined['sasrec_recs'] = pd.Series(predictions)

In [26]:
# Score the SASRec recommendations against the validation interactions.
# `calculate_metrics` is already imported in the notebook's first cell, so it
# is not re-imported here.
metrics = calculate_metrics(
    joined,
    model_preds='sasrec_recs',       # column holding the recommendation lists
    gt_col='val_interactions',       # ground truth: validation-period interactions
    train_col='train_interactions',  # training history column passed to the metric code
    verbose=True,                    # prints the per-k metric report shown in the output
)


At k=10:
  MAP@10       = 0.0587
  NDCG@10      = 0.1720
  Precision@10 = 0.0696
  Recall@10    = 0.0226

At k=100:
  MAP@100       = 0.0229
  NDCG@100      = 0.1162
  Precision@100 = 0.0292
  Recall@100    = 0.0895

Other Metrics:
  MRR                 = 0.1758
  Catalog Coverage    = 0.6993
  Diversity     = 0.9953  [0=same recs for all, 1=unique recs]
  Novelty             = 0.8733
  Serendipity         = 0.0362
