# SASRec (Self-Attentive Sequential Recommendation)

In [1]:
import sys
import os

def find_project_root(start, marker='tecd_retail_recsys'):
    """Return the project root for a notebook started in ``start``.

    Args:
        start: Directory to resolve from (normally the current working directory).
        marker: Name of a directory expected directly under the project root.

    Returns:
        The absolute path of ``start`` if it already contains ``marker``,
        otherwise the absolute path of ``start``'s parent directory.
    """
    start = os.path.abspath(start)
    if os.path.isdir(os.path.join(start, marker)):
        return start
    return os.path.abspath(os.path.join(start, '..'))


# This cell changes the working directory, so blindly taking the parent of the
# cwd would climb one more level on every re-run. Checking for the package
# directory first keeps the cell idempotent.
project_root = find_project_root(os.getcwd())
if project_root not in sys.path:
    sys.path.insert(0, project_root)
os.chdir(project_root)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from tecd_retail_recsys.models import SASRec
from tecd_retail_recsys.metrics import calculate_metrics


In [8]:
from tecd_retail_recsys.data.prepare_sasrec_data import DataPreprocessor, save_grouped_data

# Settings for the SASRec data-preparation run, gathered in one mapping so the
# output directory is written once and reused when saving.
preprocessor_config = {
    'raw_data_path': 't_ecd_small_partial/dataset/small',
    'processed_data_dir': 'processed_data',
    'day_begin': 1082,
    'day_end': 1308,
    'min_user_interactions': 1,
    'min_item_interactions': 20,
    'val_days': 20,
    'test_days': 20,
    'users_limit': None,
}
preprocessor = DataPreprocessor(**preprocessor_config)

# Write the grouped per-user data to the same directory the preprocessor uses.
save_grouped_data(preprocessor, output_dir=preprocessor_config['processed_data_dir'])

Running data preprocessing...
Starting data preprocessing...
Loading events from t_ecd_small_partial/dataset/small/retail/events
Loaded 236,479,226 total events
Loading items data from t_ecd_small_partial/dataset/small/retail/items.pq
Loaded 250,171 items with features: ['item_id', 'item_brand_id', 'item_category', 'item_subcategory', 'item_price', 'item_embedding']
Merged item features. Data shape: (236479226, 12)
Filtered to 3,758,762 events with action_type='added-to-cart'
After filtering (min_user_interactions=1, min_item_interactions=20): 3,249,972 events, 84,944 users, 30,954 items
Created mappings: 84944 users, 30954 items
Temporal split - Train: days < 1269 (902,543 events), Val: days 1269-1288 (228,339 events), Test: days >= 1289 (223,395 events)
Users in each part (train, val, test) - 7425
Grouping data by users...
Saving grouped data to processed_data/...
Saved data for 84944 users and 30954 items
Train: 7425 users
Val: 7425 users
Test: 7425 users


In [10]:
# Launch the project's SASRec training script as a shell command; the script path
# is relative to the current working directory.
# NOTE(review): `python3` is whatever is first on PATH, which may differ from the
# kernel's interpreter -- confirm both use the same environment.
# NOTE(review): --device=mps is hard-coded; presumably this only works on machines
# that provide that backend -- confirm before running elsewhere.
!python3 tecd_retail_recsys/models/sasrec/train.py --exp_name=sasrec_v3 --num_epochs=200 --device=mps --batch_size=128 --max_seq_len=128 --embedding_dim=128

[2026-02-17 15:44:15] [DEBUG]: Loading preprocessed data...
[2026-02-17 15:44:15] [DEBUG]: Loaded data: 7425 users, 30954 items
[2026-02-17 15:44:15] [DEBUG]: Preprocessing data for SASRec...
[2026-02-17 15:44:15] [DEBUG]: Preprocessing data has finished!
[2026-02-17 15:44:16] [DEBUG]: Start training...
[2026-02-17 15:44:16] [DEBUG]: Start epoch 1
[2026-02-17 15:44:25] [DEBUG]: Epoch 1 completed. Average loss: 0.6220
[2026-02-17 15:44:25] [DEBUG]: Start epoch 2
[2026-02-17 15:44:34] [DEBUG]: Epoch 2 completed. Average loss: 0.5477
[2026-02-17 15:44:34] [DEBUG]: Start epoch 3
[2026-02-17 15:44:43] [DEBUG]: Epoch 3 completed. Average loss: 0.5306
[2026-02-17 15:44:43] [DEBUG]: Start epoch 4
[2026-02-17 15:44:52] [DEBUG]: Epoch 4 completed. Average loss: 0.4864
[2026-02-17 15:44:52] [DEBUG]: Start epoch 5
[2026-02-17 15:45:01] [DEBUG]: Epoch 5 completed. Average loss: 0.4307
[2026-02-17 15:45:01] [DEBUG]: Start epoch 6
[2026-02-17 15:45:11] [DEBUG]: Epoch 6 completed. Average loss: 0.3884

In [11]:
# Inference settings for the trained model. max_seq_len is the same value (128)
# that the training command in this notebook passes as --max_seq_len.
# NOTE(review): the checkpoint file is presumably produced by the sasrec_v3
# training run -- confirm the path.
sasrec_settings = {
    'checkpoint_path': 'checkpoints/sasrec_v3_best_state.pth',
    'max_seq_len': 128,
    'device': 'cpu',
    'batch_size': 256,
}
sasrec = SASRec(**sasrec_settings)

# Load the model from the checkpoint configured above.
sasrec.load_model()

In [12]:
# Show the loaded network's module tree as the cell output (sanity check of the
# architecture after load_model()).
sasrec.model

SASRecEncoder(
  (_item_embeddings): Embedding(30955, 128, padding_idx=30954)
  (_position_embeddings): Embedding(128, 128)
  (_layernorm): LayerNorm((128,), eps=1e-09, elementwise_affine=True)
  (_dropout): Dropout(p=0.2, inplace=False)
  (_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-09, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-09, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
        (activation): GELU(approximate='none')
      )
    )
  )
)

In [14]:
# Build the in-memory train/val/test splits and the grouped per-user table.
# NOTE(review): per the logged output, this call reloads and re-filters the full
# raw event set; it looks like the same work already ran when the grouped data
# was saved -- consider reusing or caching that result.
train_df, val_df, test_df = preprocessor.preprocess()
joined = preprocessor.get_grouped_data(train_df, val_df, test_df)
# Combine the train and validation interaction columns with `+`.
# Presumably this is per-user list concatenation -- TODO confirm the column type.
joined['train_val_interactions'] = joined['train_interactions'] + joined['val_interactions']
# Quick shape check of the grouped table after adding the combined column.
print(joined.shape)

Starting data preprocessing...
Loading events from t_ecd_small_partial/dataset/small/retail/events
Loaded 236,479,226 total events
Loading items data from t_ecd_small_partial/dataset/small/retail/items.pq
Loaded 250,171 items with features: ['item_id', 'item_brand_id', 'item_category', 'item_subcategory', 'item_price', 'item_embedding']
Merged item features. Data shape: (236479226, 12)
Filtered to 3,758,762 events with action_type='added-to-cart'
After filtering (min_user_interactions=1, min_item_interactions=20): 3,249,972 events, 84,944 users, 30,954 items
Created mappings: 84944 users, 30954 items
Temporal split - Train: days < 1269 (902,543 events), Val: days 1269-1288 (228,339 events), Test: days >= 1289 (223,395 events)
Users in each part (train, val, test) - 7425
(7425, 5)


In [15]:
# Generate top-100 recommendations (same procedure as for the other models).
predictions = sasrec.predict(joined, topn=100, return_scores=False)
# NOTE(review): assuming `joined` is a pandas DataFrame, the Series built here is
# aligned to joined's index on assignment; confirm that the order/keys of
# `predictions` match that index, otherwise unmatched rows become NaN.
joined['sasrec_recs'] = pd.Series(predictions)

In [None]:
from tecd_retail_recsys.metrics import calculate_metrics

# Which columns of `joined` play which role in the evaluation: the model's
# recommendations, the validation-period ground truth, and the training history.
metric_columns = {
    'model_preds': 'sasrec_recs',
    'gt_col': 'val_interactions',
    'train_col': 'train_interactions',
}

# Score the SASRec recommendations on the validation split and print the report.
sasrec_val_metrics = calculate_metrics(joined, verbose=True, **metric_columns)


At k=10:
  MAP@10       = 0.0689
  NDCG@10      = 0.1966
  Precision@10 = 0.0771
  Recall@10    = 0.0247

At k=100:
  MAP@100       = 0.0265
  NDCG@100      = 0.1303
  Precision@100 = 0.0330
  Recall@100    = 0.1003

Other Metrics:
  MRR                 = 0.1765
  Catalog Coverage    = 0.7044
  Diversity     = 0.9954  [0=same recs for all, 1=unique recs]
  Novelty             = 0.8812
  Serendipity         = 0.0363
