# SASRec (Self-Attentive Sequential Recommendation)

In [1]:
import sys
import os

def find_project_root(start_dir, package_name='tecd_retail_recsys'):
    """Locate the project root starting from ``start_dir``.

    Walks from ``start_dir`` upwards and returns the nearest directory
    (``start_dir`` included) that contains a ``package_name`` sub-directory.
    If no such ancestor exists, returns the parent of ``start_dir``, which is
    what this cell computed unconditionally before.

    Unlike a bare ``os.path.join(cwd, '..')``, the result is stable when the
    cell is re-run after ``os.chdir(project_root)`` has already happened.
    """
    start_dir = os.path.abspath(start_dir)
    candidate = start_dir
    while True:
        if os.path.isdir(os.path.join(candidate, package_name)):
            return candidate
        parent = os.path.dirname(candidate)
        if parent == candidate:
            # Reached the filesystem root without finding the package.
            return os.path.abspath(os.path.join(start_dir, '..'))
        candidate = parent


# Make the project package importable and run every later cell (relative data
# paths, the training script) from the project root.
project_root = find_project_root(os.getcwd())
if project_root not in sys.path:
    sys.path.insert(0, project_root)
os.chdir(project_root)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from tecd_retail_recsys.models import SASRec
from tecd_retail_recsys.metrics import calculate_metrics

In [None]:
from tecd_retail_recsys.data.prepare_sasrec_data import DataPreprocessor, save_grouped_data

# Single source of truth for where the preprocessed artefacts live; the same
# directory is passed to the preprocessor and to the export step below.
PROCESSED_DATA_DIR = 'processed_data'

# Build the preprocessor for the "small" dataset slice.
# NOTE(review): day_begin/day_end/val_days/test_days look like day offsets that
# define the time window and the trailing validation/test periods; confirm
# against DataPreprocessor.
preprocessor = DataPreprocessor(
    raw_data_path='t_ecd_small_partial/dataset/small',
    processed_data_dir=PROCESSED_DATA_DIR,
    day_begin=1082,
    day_end=1308,
    val_days=20,
    test_days=20,
    min_user_interactions=1,
    min_item_interactions=20,
    users_limit=None,
)

# Write the grouped data that the training script reads.
save_grouped_data(preprocessor, output_dir=PROCESSED_DATA_DIR)

In [None]:
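# Experiment 1: starting configuration (embedding_dim=128, 4 heads, 4 layers, dropout=0.3, lr=5e-4, batch_size=128, 300 epochs)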

In [23]:
!python3 tecd_retail_recsys/models/sasrec/train.py --exp_name=make_sasrec_great_again --num_epochs=300 --device=mps --batch_size=128 --max_seq_len=100 --embedding_dim=128 --num_heads=4 --num_layers=4 --learning_rate=5e-4 --dropout=0.3

[2026-02-22 15:18:12] [DEBUG]: Loading preprocessed data...
[2026-02-22 15:18:12] [DEBUG]: Loaded data: 7425 users, 30954 items
[2026-02-22 15:18:12] [DEBUG]: Preprocessing data for SASRec...
[2026-02-22 15:18:12] [DEBUG]: Preprocessing data has finished!
[2026-02-22 15:18:12] [DEBUG]: Start training...
[2026-02-22 15:18:12] [DEBUG]: Start epoch 1
[2026-02-22 15:18:20] [DEBUG]: Epoch 1 completed. Average loss: 0.6512
[2026-02-22 15:18:20] [DEBUG]: Start epoch 2
[2026-02-22 15:18:28] [DEBUG]: Epoch 2 completed. Average loss: 0.5553
[2026-02-22 15:18:28] [DEBUG]: Start epoch 3
[2026-02-22 15:18:35] [DEBUG]: Epoch 3 completed. Average loss: 0.5366
[2026-02-22 15:18:35] [DEBUG]: Start epoch 4
[2026-02-22 15:18:43] [DEBUG]: Epoch 4 completed. Average loss: 0.5262
[2026-02-22 15:18:43] [DEBUG]: Start epoch 5
[2026-02-22 15:18:51] [DEBUG]: Epoch 5 completed. Average loss: 0.5113
[2026-02-22 15:18:51] [DEBUG]: Start epoch 6
[2026-02-22 15:18:58] [DEBUG]: Epoch 6 completed. Average loss: 0.4875

In [None]:
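# Experiment 2: wider and deeper model (embedding_dim=256, 6 layers) with stronger dropout (0.4), lr=3e-4, batch_size=64, 100 epochs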

In [4]:
!python3 tecd_retail_recsys/models/sasrec/train.py --exp_name=make_sasrec_great_again_v2 --num_epochs=100 --device=mps --batch_size=64 --max_seq_len=100 --embedding_dim=256 --num_heads=4 --num_layers=6 --learning_rate=3e-4 --dropout=0.4

[2026-02-22 18:45:33] [DEBUG]: Loading preprocessed data...
[2026-02-22 18:45:34] [DEBUG]: Loaded data: 7425 users, 30954 items
[2026-02-22 18:45:34] [DEBUG]: Preprocessing data for SASRec...
[2026-02-22 18:45:34] [DEBUG]: Preprocessing data has finished!
[2026-02-22 18:45:34] [DEBUG]: Start training...
[2026-02-22 18:45:34] [DEBUG]: Start epoch 1
[2026-02-22 18:46:01] [DEBUG]: Epoch 1 completed. Average loss: 0.6175
[2026-02-22 18:46:01] [DEBUG]: Start epoch 2
[2026-02-22 18:46:27] [DEBUG]: Epoch 2 completed. Average loss: 0.5473
[2026-02-22 18:46:27] [DEBUG]: Start epoch 3
[2026-02-22 18:46:54] [DEBUG]: Epoch 3 completed. Average loss: 0.5325
[2026-02-22 18:46:54] [DEBUG]: Start epoch 4
[2026-02-22 18:47:19] [DEBUG]: Epoch 4 completed. Average loss: 0.5056
[2026-02-22 18:47:19] [DEBUG]: Start epoch 5
[2026-02-22 18:47:44] [DEBUG]: Epoch 5 completed. Average loss: 0.4697
[2026-02-22 18:47:44] [DEBUG]: Start epoch 6
[2026-02-22 18:48:09] [DEBUG]: Epoch 6 completed. Average loss: 0.4360

In [None]:
# Experiment 3: embedding_dim=192, 8 heads, 5 layers, lighter dropout (0.2), lr=2e-4, batch_size=64, 200 epochs

In [6]:
!python3 tecd_retail_recsys/models/sasrec/train.py --exp_name=make_sasrec_great_again_v3 --num_epochs=200 --device=mps --batch_size=64 --max_seq_len=100 --embedding_dim=192 --num_heads=8 --num_layers=5 --learning_rate=2e-4 --dropout=0.2

[2026-02-22 19:43:45] [DEBUG]: Loading preprocessed data...
[2026-02-22 19:43:46] [DEBUG]: Loaded data: 7425 users, 30954 items
[2026-02-22 19:43:46] [DEBUG]: Preprocessing data for SASRec...
[2026-02-22 19:43:46] [DEBUG]: Preprocessing data has finished!
[2026-02-22 19:43:46] [DEBUG]: Start training...
[2026-02-22 19:43:46] [DEBUG]: Start epoch 1
[2026-02-22 19:44:02] [DEBUG]: Epoch 1 completed. Average loss: 0.6496
[2026-02-22 19:44:02] [DEBUG]: Start epoch 2
[2026-02-22 19:44:18] [DEBUG]: Epoch 2 completed. Average loss: 0.5588
[2026-02-22 19:44:18] [DEBUG]: Start epoch 3
[2026-02-22 19:44:34] [DEBUG]: Epoch 3 completed. Average loss: 0.5387
[2026-02-22 19:44:34] [DEBUG]: Start epoch 4
[2026-02-22 19:44:50] [DEBUG]: Epoch 4 completed. Average loss: 0.5288
[2026-02-22 19:44:50] [DEBUG]: Start epoch 5
[2026-02-22 19:45:05] [DEBUG]: Epoch 5 completed. Average loss: 0.5151
[2026-02-22 19:45:05] [DEBUG]: Start epoch 6
[2026-02-22 19:45:21] [DEBUG]: Epoch 6 completed. Average loss: 0.4932

In [None]:
# Experiment 4: embedding_dim=256, 8 heads, 5 layers, dropout=0.25, lr=2e-4, batch_size=32, 250 epochs

In [3]:
!python3 tecd_retail_recsys/models/sasrec/train.py --exp_name=make_sasrec_great_again_v4 --num_epochs=250 --device=mps --batch_size=32 --max_seq_len=100 --embedding_dim=256 --num_heads=8 --num_layers=5 --learning_rate=2e-4 --dropout=0.25

[2026-02-22 21:51:00] [DEBUG]: Loading preprocessed data...
[2026-02-22 21:51:00] [DEBUG]: Loaded data: 7425 users, 30954 items
[2026-02-22 21:51:00] [DEBUG]: Preprocessing data for SASRec...
[2026-02-22 21:51:00] [DEBUG]: Preprocessing data has finished!
[2026-02-22 21:51:01] [DEBUG]: Start training...
[2026-02-22 21:51:01] [DEBUG]: Start epoch 1
[2026-02-22 21:51:40] [DEBUG]: Epoch 1 completed. Average loss: 0.6132
[2026-02-22 21:51:40] [DEBUG]: Start epoch 2
[2026-02-22 21:52:13] [DEBUG]: Epoch 2 completed. Average loss: 0.5467
[2026-02-22 21:52:13] [DEBUG]: Start epoch 3
[2026-02-22 21:52:44] [DEBUG]: Epoch 3 completed. Average loss: 0.5272
[2026-02-22 21:52:44] [DEBUG]: Start epoch 4
[2026-02-22 21:53:12] [DEBUG]: Epoch 4 completed. Average loss: 0.4919
[2026-02-22 21:53:12] [DEBUG]: Start epoch 5
[2026-02-22 21:53:40] [DEBUG]: Epoch 5 completed. Average loss: 0.4525
[2026-02-22 21:53:40] [DEBUG]: Start epoch 6
[2026-02-22 21:54:05] [DEBUG]: Epoch 6 completed. Average loss: 0.4162

In [25]:
# Wrap the checkpoint named after experiment 4 (exp_name=make_sasrec_great_again_v4)
# for inference on CPU.
# NOTE(review): the "_best_state" suffix suggests this is the best-epoch snapshot
# written by train.py; confirm against the training script.
sasrec = SASRec(
    checkpoint_path='checkpoints/make_sasrec_great_again_v4_best_state.pth',
    max_seq_len=100,  # same value as --max_seq_len passed to the training runs
    device='cpu',
    batch_size=128
)

# Load the checkpoint; afterwards the network is available as sasrec.model.
sasrec.load_model()

In [26]:
# Display the loaded network's module tree as a sanity check that the
# architecture matches the experiment-4 hyperparameters.
sasrec.model

SASRecEncoder(
  (_item_embeddings): Embedding(30955, 256, padding_idx=30954)
  (_position_embeddings): Embedding(100, 256)
  (_layernorm): LayerNorm((256,), eps=1e-09, elementwise_affine=True)
  (_dropout): Dropout(p=0.25, inplace=False)
  (_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-4): 5 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.25, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-09, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-09, elementwise_affine=True)
        (dropout1): Dropout(p=0.25, inplace=False)
        (dropout2): Dropout(p=0.25, inplace=False)
        (activation): GELU(approximate='none')
      )
    )
  )
)

In [None]:
# Rebuild the train/val/test splits with the preprocessor configured earlier
# and collect them into a single frame.
# NOTE(review): the printed shape (7425, 5) matches the 7425 users reported by
# the training log, so this is presumably one row per user; confirm.
train_df, val_df, test_df = preprocessor.preprocess()
joined = preprocessor.get_grouped_data(train_df, val_df, test_df)
# Combined train+val history per row.
# NOTE(review): assumes both columns hold Python lists so that `+` concatenates; confirm.
joined['train_val_interactions'] = joined['train_interactions'] + joined['val_interactions']
print(joined.shape)

(7425, 5)


In [28]:
# Top-100 recommendations for every row of `joined`.
predictions = sasrec.predict(joined, topn=100, return_scores=False)
# Attach the recommendations using joined's own index. A bare pd.Series(list)
# gets a fresh 0..n-1 index, and column assignment aligns on index labels, so
# any non-default index on `joined` would silently misplace rows or leave NaN.
# Passing the index also raises if a list result has the wrong length.
# NOTE(review): assumes a list result is ordered like the rows of `joined`; confirm.
joined['sasrec_recs'] = pd.Series(predictions, index=joined.index)

In [29]:
# Score the SASRec recommendations against the validation interactions.
# NOTE(review): the experiments are compared on the validation split only; the
# test split is not evaluated in the visible cells.
sasrec_val_metrics = calculate_metrics(
    joined,
    model_preds='sasrec_recs',          # column with the recommended item lists
    gt_col='val_interactions',          # ground truth to score against
    train_col='train_interactions',     # NOTE(review): presumably used for novelty/serendipity-style metrics; confirm
    verbose=True
)

[Metrics debug] resolved gt_col='val_interactions' item_id_index=0
[Metrics debug] ratings_true shape: (228339, 3) ratings_pred shape: (742500, 3)
  ratings_true dtypes: {'user_id': dtype('int64'), 'item_id': dtype('int64')}
  ratings_pred dtypes: {'user_id': dtype('int64'), 'item_id': dtype('int64')}
  user_id=11 gt_count=22 pred_count=100 overlap=5
  user_id=14 gt_count=5 pred_count=100 overlap=0
    [ID spaces] gt sample=[9341, 16732, 17585, 28024, 30789] range=[9341, 30789] | rec sample=[153, 536, 605, 650, 976] range=[153, 30913]
  user_id=21 gt_count=47 pred_count=100 overlap=12

At k=10:
  MAP@10       = 0.0909
  NDCG@10      = 0.2459
  Precision@10 = 0.0952
  Recall@10    = 0.0298

At k=100:
  MAP@100       = 0.0349
  NDCG@100      = 0.1590
  Precision@100 = 0.0397
  Recall@100    = 0.1191

Other Metrics:
  MRR                 = 0.1860
  Catalog Coverage    = 0.8146
  Diversity     = 0.9960  [0=same recs for all, 1=unique recs]
  Novelty             = 0.8814
  Serendipity      

<!DOCTYPE html>
<html>
<head>
    <style>
        table {
            border-collapse: collapse;
            width: 100%;
            font-family: Arial, sans-serif;
            margin: 20px 0;
        }
        th {
            background-color: #4CAF50;
            color: white;
            padding: 12px;
            text-align: left;
            border: 1px solid #ddd;
        }
        td {
            padding: 10px;
            border: 1px solid #ddd;
            text-align: left;
        }
        tr:nth-child(even) {
            background-color: #f2f2f2;
        }
        tr:hover {
            background-color: #ddd;
        }
        .best {
            background-color: #c8e6c9 !important;
            font-weight: bold;
        }
        .worst {
            background-color: #ffcdd2 !important;
        }
    </style>
</head>
<body>
    <h2>SASRec: эксперименты</h2>
    <table>
        <thead>
            <tr>
                <th>Номер эксперимента</th>
                <th>embedding_dim</th>
                <th>num_heads</th>
                <th>num_layers</th>
                <th>dropout</th>
                <th>max_seq_length</th>
                <th>epochs</th>
                <th>learning_rate</th>
                <th>batch_size</th>
                <th>avg_loss</th>
                <th>ndcg@100</th>
            </tr>
        </thead>
        <tbody>
            <tr>
                <td>1</td>
                <td>128</td>
                <td>4</td>
                <td>4</td>
                <td>0.3</td>
                <td>100</td>
                <td>300</td>
                <td>5e-4</td>
                <td>128</td>
                <td>0.0619</td>
                <td>0.1432</td>
            </tr>
            <tr class="worst">
                <td>2</td>
                <td>256</td>
                <td>4</td>
                <td>6</td>
                <td>0.4</td>
                <td>100</td>
                <td>100</td>
                <td>3e-4</td>
                <td>64</td>
                <td>0.0650</td>
                <td>0.1412</td>
            </tr>
            <tr>
                <td>3</td>
                <td>192</td>
                <td>8</td>
                <td>5</td>
                <td>0.2</td>
                <td>100</td>
                <td>200</td>
                <td>2e-4</td>
                <td>64</td>
                <td>0.0320</td>
                <td>0.1542</td>
            </tr>
            <tr class="best">
                <td>4</td>
                <td>256</td>
                <td>8</td>
                <td>5</td>
                <td>0.25</td>
                <td>100</td>
                <td>250</td>
                <td>2e-4</td>
                <td>32</td>
                <td>0.0266</td>
                <td>0.1590</td>
            </tr>
        </tbody>
    </table>
</body>
</html>


`Наилучшая конфигурация (эксперимент 4) позволила добиться NDCG@100 = 0.1590`