# BERT4Rec Model Training and Evaluation

Bidirectional Encoder Representations from Transformers for Sequential Recommendation

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
os.chdir(project_root)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import yaml
import torch

# Fix for MPS (Apple Silicon) - disable MPS to avoid float64 issues
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
if torch.backends.mps.is_available():
    torch.set_default_device('cpu')

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

def seed_everything(seed):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

from rectools import Columns
from rectools.dataset import Dataset
from rectools.models import BERT4RecModel
from rectools.models.nn.item_net import IdEmbeddingsItemNet, CatFeaturesItemNet

from tecd_retail_recsys.data import DataPreprocessor
from tecd_retail_recsys.metrics import calculate_metrics

print(f"System version: {sys.version}")
print(f"Pandas version: {pd.__version__}")
print(f"Numpy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

System version: 3.11.14 (main, Oct 28 2025, 12:11:54) [Clang 20.1.4 ]
Pandas version: 2.3.3
Numpy version: 1.26.4
PyTorch version: 2.10.0


## Load Configuration

In [115]:
with open('configs/bert4rec.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Model parameters
N_BLOCKS = config['model']['n_blocks']
N_HEADS = config['model']['n_heads']
N_FACTORS = config['model']['n_factors']
DROPOUT_RATE = config['model']['dropout_rate']
MASK_PROB = config['model']['mask_prob']
SESSION_MAX_LEN = config['model']['session_max_len']
TRAIN_MIN_USER_INTERACTIONS = config['model']['train_min_user_interactions']
USE_POS_EMB = config['model']['use_pos_emb']
USE_KEY_PADDING_MASK = config['model']['use_key_padding_mask']

# Training parameters
BATCH_SIZE = config['train']['batch_size']
EPOCHS = config['train']['epochs']
LEARNING_RATE = config['train']['learning_rate']
LOSS = config['train']['loss']
N_NEGATIVES = config['train']['n_negatives']
GBCE_T = config['train']['gbce_t']
DETERMINISTIC = config['train']['deterministic']
VERBOSE = config['train']['verbose']
DATALOADER_NUM_WORKERS = config['train']['dataloader_num_workers']
TOP_K = config['train']['top_k']

# Info parameters
MODEL_DIR = config['info']['MODEL_DIR']
METRICS = config['info']['metrics']
SAVE_MODEL = config['info']['save_model']

SEED = 42

# Enable deterministic behaviour
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
try:
    torch.use_deterministic_algorithms(True, warn_only=True)
except TypeError:
    torch.use_deterministic_algorithms(True)

try:
    seed_everything(SEED, workers=True)
except TypeError:
    seed_everything(SEED)

os.makedirs(MODEL_DIR, exist_ok=True)

print(f"Model: BERT4Rec")
print(f"Number of transformer blocks: {N_BLOCKS}")
print(f"Number of attention heads: {N_HEADS}")
print(f"Latent factors (n_factors): {N_FACTORS}")
print(f"Dropout rate: {DROPOUT_RATE}")
print(f"Mask probability: {MASK_PROB}")
print(f"Session max length: {SESSION_MAX_LEN}")
print(f"Loss: {LOSS}")
print(f"Number of negatives: {N_NEGATIVES}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Epochs: {EPOCHS}")
print(f"Learning rate: {LEARNING_RATE}")

Model: BERT4Rec
Number of transformer blocks: 2
Number of attention heads: 4
Latent factors (n_factors): 256
Dropout rate: 0.2
Mask probability: 0.15
Session max length: 150
Loss: softmax
Number of negatives: 1
Batch size: 256
Epochs: 100
Learning rate: 0.001


## Data Preparation

In [4]:
dp = DataPreprocessor(
    day_begin=1082, 
    day_end=1308, 
    val_days=20, 
    test_days=20, 
    min_user_interactions=1, 
    min_item_interactions=20
)
train_df, val_df, test_df = dp.preprocess()

train_orig = train_df.copy()
val_orig = val_df.copy()

print(f"Train shape: {train_df.shape}")
print(f"Val shape: {val_df.shape}")
print(f"Test shape: {test_df.shape}")
print(f"Number of users: {train_df['user_id'].nunique()}")
print(f"Number of items: {train_df['item_id'].nunique()}")

Starting data preprocessing...
Loading events from t_ecd_small_partial/dataset/small/retail/events
Loaded 236,479,226 total events
Loading items data from t_ecd_small_partial/dataset/small/retail/items.pq
Loaded 250,171 items with features: ['item_id', 'item_brand_id', 'item_category', 'item_subcategory', 'item_price', 'item_embedding']
Merged item features. Data shape: (236479226, 12)
Filtered to 3,758,762 events with action_type='added-to-cart'
After filtering (min_user_interactions=1, min_item_interactions=20): 3,249,972 events, 84,944 users, 30,954 items
Created mappings: 84944 users, 30954 items
Temporal split - Train: days < 1269 (902,543 events), Val: days 1269-1288 (228,339 events), Test: days >= 1289 (223,395 events)
Users in each part (train, val, test) - 7425
Train shape: (902543, 12)
Val shape: (228339, 12)
Test shape: (223395, 12)
Number of users: 7425
Number of items: 30751


## Prepare RecTools Dataset

RecTools requires data in specific format with interactions and optional features

In [5]:
# Prepare interactions for training
interactions_train = train_df[['user_id', 'item_id', 'timestamp']].copy()
interactions_train.columns = [Columns.User, Columns.Item, Columns.Datetime]
interactions_train[Columns.Weight] = 1

print(f"Train interactions shape: {interactions_train.shape}")
interactions_train.head()

Train interactions shape: (902543, 4)


Unnamed: 0,user_id,item_id,datetime,weight
1252,79038,20358,93485160,1
1336,44584,23489,93485187,1
1453,12869,2908,93485221,1
2144,42145,18904,93485421,1
2189,15304,14462,93485437,1


In [6]:
# Prepare validation interactions
interactions_val = val_df[['user_id', 'item_id', 'timestamp']].copy()
interactions_val.columns = [Columns.User, Columns.Item, Columns.Datetime]
interactions_val[Columns.Weight] = 1

print(f"Val interactions shape: {interactions_val.shape}")
interactions_val.head()

Val interactions shape: (228339, 4)


Unnamed: 0,user_id,item_id,datetime,weight
173756285,40764,15800,109641615,1
173756680,52328,26142,109641663,1
173757269,21228,17027,109641733,1
173757315,29325,24210,109641741,1
173757516,22801,20587,109641767,1


In [74]:
from tecd_retail_recsys.data.bert4rec_dataset import BERT4RecDatasetBuilder

builder = BERT4RecDatasetBuilder(train_df)
dataset, item_net_config = builder.build_dataset(
    use_item_embeddings=False,
    use_price_features=True,
    use_temporal_features=False,
    n_factors=N_FACTORS
)


üèóÔ∏è  BERT4Rec Dataset Builder
‚úÖ Interactions: 902543 —Å—Ç—Ä–æ–∫
üì¶ –î–æ–±–∞–≤–ª–µ–Ω–∏–µ –±–∞–∑–æ–≤—ã—Ö item features...
  ‚úÖ Brand: 30751 items
  ‚úÖ Category: 30199 items
  ‚úÖ Subcategory: 30199 items
üí∞ –î–æ–±–∞–≤–ª–µ–Ω–∏–µ price features...
  ‚úÖ Price buckets: 30751 items, 10 categories
  ‚úÖ Price tier: 30751 items
  ‚úÖ Price in category: 29213 items

üì¶ –ò—Ç–æ–≥–æ item features: 181864 —Å—Ç—Ä–æ–∫
   –§–∏—á–∏: ['brand', 'category', 'subcategory', 'price_bucket', 'price_tier', 'price_in_category']
   –£–Ω–∏–∫–∞–ª—å–Ω—ã—Ö —Ç–æ–≤–∞—Ä–æ–≤: 30751

üî® –°–æ–∑–¥–∞–Ω–∏–µ RecTools Dataset...
‚úÖ Dataset: 7425 users, 30751 items

‚úÖ ItemNet: ID + Categorical
‚úÖ Dataset –≥–æ—Ç–æ–≤ –∫ –∏—Å–ø–æ–ª—å–∑–æ–≤–∞–Ω–∏—é!



## Model Training

In [55]:
# Create custom trainer function to force CPU usage (avoid MPS float64 issues)
def get_cpu_trainer(**kwargs):
    # Force CPU
    from pytorch_lightning import Trainer
    trainer_kwargs = {
        'accelerator': 'cpu',
        'devices': 1,
        'max_epochs': EPOCHS,
    }
    trainer_kwargs.update(kwargs)
    return Trainer(**trainer_kwargs)

In [126]:
model=BERT4RecModel(
    n_blocks=N_BLOCKS,
    n_heads=N_HEADS,
    n_factors=N_FACTORS,
    dropout_rate=DROPOUT_RATE,
    mask_prob=MASK_PROB,
    session_max_len=90,
    train_min_user_interactions=TRAIN_MIN_USER_INTERACTIONS,
    loss=LOSS,
    n_negatives=N_NEGATIVES,
    gbce_t=GBCE_T,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    epochs=10,
    deterministic=DETERMINISTIC,
    verbose=VERBOSE,
    dataloader_num_workers=DATALOADER_NUM_WORKERS,
    use_pos_emb=USE_POS_EMB,
    use_key_padding_mask=USE_KEY_PADDING_MASK,
    item_net_block_types=item_net_config['item_net_block_types'],  # With all features!
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [127]:
%%time
# Train the model
model.fit(dataset)

print("Training completed!")


  | Name        | Type                     | Params | Mode 
-----------------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 9.4 M  | train
-----------------------------------------------------------------
9.4 M     Trainable params
0         Non-trainable params
9.4 M     Total params
37.422    Total estimated model params size (MB)
40        Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


Training completed!
CPU times: user 11.4 s, sys: 3min 34s, total: 3min 46s
Wall time: 9min 15s


## Generate Recommendations

In [128]:
%%time
# Get all users from validation set
val_users = val_df['user_id'].unique()

recommendations = model.recommend(
    users=val_users,
    dataset=dataset,
    k=TOP_K,
    filter_viewed=False,
    on_unsupported_targets="ignore"
)

recs_grouped = recommendations.groupby('user_id', as_index=False)['item_id'].agg(list)
recs_grouped.columns = ['user_id', 'bert4rec_recs']
recs_grouped.head()

CPU times: user 11.3 s, sys: 979 ms, total: 12.3 s
Wall time: 6.75 s


Unnamed: 0,user_id,bert4rec_recs
0,11,"[4873, 16561, 27185, 27272, 27714, 20587, 8897..."
1,14,"[17934, 7982, 26537, 29228, 25997, 5631, 1245,..."
2,21,"[17934, 23993, 29377, 29228, 7982, 5631, 3155,..."
3,29,"[29980, 1809, 18227, 3025, 18836, 1678, 5631, ..."
4,39,"[432, 27446, 23932, 16749, 14597, 19586, 25723..."


## Evaluation

In [129]:
joined = dp.get_grouped_data(train_orig, val_orig, test_df)
joined['train_val_interactions'] = joined['train_interactions'] + joined['val_interactions']

evaluation_df = joined.merge(
    recs_grouped, 
    on='user_id', 
    how='left'
)

# Fill users without recommendations with empty lists
evaluation_df['bert4rec_recs'] = evaluation_df['bert4rec_recs'].apply(
    lambda x: x if isinstance(x, list) else []
)

print(f"Evaluation dataframe shape: {evaluation_df.shape}")
print(f"Users with recommendations: {(evaluation_df['bert4rec_recs'].str.len() > 0).sum()}")

Evaluation dataframe shape: (7425, 6)
Users with recommendations: 7425


In [130]:
# Calculate metrics
metrics_result = calculate_metrics(
    evaluation_df,
    train_col='train_interactions',
    gt_col='val_interactions',
    model_preds='bert4rec_recs',
    verbose=True
)

[Metrics debug] resolved gt_col='val_interactions' item_id_index=0
[Metrics debug] ratings_true shape: (228339, 3) ratings_pred shape: (742500, 3)
  ratings_true dtypes: {'user_id': dtype('int64'), 'item_id': dtype('int64')}
  ratings_pred dtypes: {'user_id': dtype('int64'), 'item_id': dtype('int64')}
  user_id=11 gt_count=22 pred_count=100 overlap=2
  user_id=14 gt_count=5 pred_count=100 overlap=0
    [ID spaces] gt sample=[9341, 16732, 17585, 28024, 30789] range=[9341, 30789] | rec sample=[21, 83, 394, 415, 567] range=[21, 30642]
  user_id=21 gt_count=47 pred_count=100 overlap=8

At k=10:
  MAP@10       = 0.1022
  NDCG@10      = 0.2649
  Precision@10 = 0.1239
  Recall@10    = 0.0380

At k=100:
  MAP@100       = 0.0392
  NDCG@100      = 0.1714
  Precision@100 = 0.0447
  Recall@100    = 0.1360

Other Metrics:
  MRR                 = 0.2452
  Catalog Coverage    = 0.2692
  Diversity     = 0.9879  [0=same recs for all, 1=unique recs]
  Novelty             = 0.8026
  Serendipity         =

## Save Model and Recommendations

In [70]:
if SAVE_MODEL:
    # Save model
    model_path = os.path.join(MODEL_DIR, "bert4rec_model")
    model.save(model_path)
    print(f"Model saved to {model_path}")
    
    # Save recommendations
    recs_path = os.path.join(MODEL_DIR, "recommendations.parquet")
    recs_grouped.to_parquet(recs_path, index=False)
    print(f"Recommendations saved to {recs_path}")
    
    # Save full recommendations with scores
    recs_full_path = os.path.join(MODEL_DIR, "recommendations_full.parquet")
    recommendations.to_parquet(recs_full_path, index=False)
    print(f"Full recommendations saved to {recs_full_path}")

Model saved to ./models/bert4rec/bert4rec_model
Recommendations saved to ./models/bert4rec/recommendations.parquet
Full recommendations saved to ./models/bert4rec/recommendations_full.parquet


## Best Model Inference

In [None]:
model_path = os.path.join(MODEL_DIR, "bert4rec_model_exp8.pkl")
model.save(model_path)

In [99]:
loaded_model = BERT4RecModel.load(model_path)

val_users = val_df['user_id'].unique()

recommendations = loaded_model.recommend(
    users=val_users,
    dataset=dataset,
    k=TOP_K,
    filter_viewed=False,
    on_unsupported_targets="ignore"
)

recs_grouped = recommendations.groupby('user_id', as_index=False)['item_id'].agg(list)
recs_grouped.columns = ['user_id', 'bert4rec_recs']

joined = dp.get_grouped_data(train_orig, val_orig, test_df)
joined['train_val_interactions'] = joined['train_interactions'] + joined['val_interactions']

evaluation_df = joined.merge(
    recs_grouped, 
    on='user_id', 
    how='left'
)

# Fill users without recommendations with empty lists
evaluation_df['bert4rec_recs'] = evaluation_df['bert4rec_recs'].apply(
    lambda x: x if isinstance(x, list) else []
)


# Calculate metrics
metrics_result = calculate_metrics(
    evaluation_df,
    train_col='train_interactions',
    gt_col='val_interactions',
    model_preds='bert4rec_recs',
    verbose=True
)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at /var/folders/5t/b05_gxx17hnftz_n3c82pt4h0000gn/T/tmp8i6kfqb7

  | Name        | Type                     | Params | Mode 
-----------------------------------------------------------------
0 | torch_model | TransformerTorchBackbone | 9.4 M  | train
-----------------------------------------------------------------
9.4 M     Trainable params
0         Non-trainable params
9.4 M     Total params
37.799    Total estimated model params size (MB)
40        Modules in train mode
0         Modules in eval mode
Restored all states from the checkpoint at /var/folders/5t/b05_gxx17hnftz_n3c82pt4h0000gn/T/tmp8i6kfqb7
`Trainer.fit` stopped: No training batches.


[Metrics debug] resolved gt_col='val_interactions' item_id_index=0
[Metrics debug] ratings_true shape: (228339, 3) ratings_pred shape: (742500, 3)
  ratings_true dtypes: {'user_id': dtype('int64'), 'item_id': dtype('int64')}
  ratings_pred dtypes: {'user_id': dtype('int64'), 'item_id': dtype('int64')}
  user_id=11 gt_count=22 pred_count=100 overlap=7
  user_id=14 gt_count=5 pred_count=100 overlap=0
    [ID spaces] gt sample=[9341, 16732, 17585, 28024, 30789] range=[9341, 30789] | rec sample=[153, 232, 1447, 1683, 1698] range=[153, 30913]
  user_id=21 gt_count=47 pred_count=100 overlap=13

At k=10:
  MAP@10       = 0.2782
  NDCG@10      = 0.5719
  Precision@10 = 0.2155
  Recall@10    = 0.0673

At k=100:
  MAP@100       = 0.0905
  NDCG@100      = 0.2774
  Precision@100 = 0.0599
  Recall@100    = 0.1773

Other Metrics:
  MRR                 = 0.3171
  Catalog Coverage    = 0.8947
  Diversity     = 0.9964  [0=same recs for all, 1=unique recs]
  Novelty             = 0.8951
  Serendipity   

<!DOCTYPE html>
<html>
<head>
    <style>
        table {
            border-collapse: collapse;
            width: 100%;
            font-family: Arial, sans-serif;
            margin: 20px 0;
        }
        th {
            background-color: #2196F3;
            color: white;
            padding: 12px;
            text-align: left;
            border: 1px solid #ddd;
            font-size: 13px;
        }
        td {
            padding: 10px;
            border: 1px solid #ddd;
            text-align: left;
            font-size: 12px;
        }
        tr:nth-child(even) {
            background-color: #f2f2f2;
        }
        tr:hover {
            background-color: #ddd;
        }
        .best {
            background-color: #c8e6c9 !important;
            font-weight: bold;
        }
        .worst {
            background-color: #ffcdd2 !important;
        }
        .good {
            background-color: #e8f5e9 !important;
        }
    </style>
</head>
<body>
    <h2>BERT4Rec: —ç–∫—Å–ø–µ—Ä–∏–º–µ–Ω—Ç—ã</h2>
    <table>
        <thead>
            <tr>
                <th>‚Ññ</th>
                <th>n_blocks</th>
                <th>n_heads</th>
                <th>n_factors</th>
                <th>dropout_rate</th>
                <th>mask_prob</th>
                <th>session_max_len</th>
                <th>batch_size</th>
                <th>learning_rate</th>
                <th>loss</th>
                <th>n_negatives</th>
                <th>epochs</th>
                <th>NDCG@100</th>
            </tr>
        </thead>
        <tbody>
            <tr>
                <td>1</td>
                <td>2</td>
                <td>4</td>
                <td>256</td>
                <td>0.2</td>
                <td>0.15</td>
                <td>50</td>
                <td>128</td>
                <td>0.001</td>
                <td>softmax</td>
                <td>1</td>
                <td>100</td>
                <td>0.2581</td>
            </tr>
            <tr class="worst">
                <td>2</td>
                <td>2</td>
                <td>4</td>
                <td>256</td>
                <td>0.2</td>
                <td>0.15</td>
                <td>50</td>
                <td>128</td>
                <td>0.001</td>
                <td>gBCE</td>
                <td>50</td>
                <td>100</td>
                <td>0.2190</td>
            </tr>
            <tr class="good">
                <td>3</td>
                <td>2</td>
                <td>4</td>
                <td>256</td>
                <td>0.2</td>
                <td>0.15</td>
                <td>100</td>
                <td>128</td>
                <td>0.001</td>
                <td>softmax</td>
                <td>1</td>
                <td>100</td>
                <td><strong>0.2726</strong></td>
            </tr>
            <tr>
                <td>4</td>
                <td>3</td>
                <td>8</td>
                <td>512</td>
                <td>0.2</td>
                <td>0.15</td>
                <td>100</td>
                <td>128</td>
                <td>0.001</td>
                <td>softmax</td>
                <td>1</td>
                <td>100</td>
                <td>0.2661</td>
            </tr>
            <tr>
                <td>5</td>
                <td>2</td>
                <td>4</td>
                <td>256</td>
                <td>0.2</td>
                <td>0.15</td>
                <td>100</td>
                <td>256</td>
                <td>0.001</td>
                <td>sampled_softmax</td>
                <td>100</td>
                <td>100</td>
                <td>0.2263</td>
            </tr>
            <tr>
                <td>6</td>
                <td>2</td>
                <td>4</td>
                <td>512</td>
                <td>0.2</td>
                <td>0.15</td>
                <td>100</td>
                <td>256</td>
                <td>0.001</td>
                <td>softmax</td>
                <td>1</td>
                <td>100</td>
                <td>0.2331</td>
            </tr>
            <tr>
                <td>7</td>
                <td>2</td>
                <td>4</td>
                <td>256</td>
                <td>0.1</td>
                <td>0.2</td>
                <td>100</td>
                <td>256</td>
                <td>0.0005</td>
                <td>softmax</td>
                <td>1</td>
                <td>150</td>
                <td>0.2480</td>
            </tr>
            <tr class="best">
                <td>8</td>
                <td>2</td>
                <td>4</td>
                <td>256</td>
                <td>0.2</td>
                <td>0.15</td>
                <td>130</td>
                <td>256</td>
                <td>0.001</td>
                <td>softmax</td>
                <td>1</td>
                <td>100</td>
                <td>0.2774</td>
            </tr>
        </tbody>
    </table>
    
</body>
</html>

–ù–∞–∏–ª—É—á—à–∞—è –∫–æ–Ω—Ñ–∏–≥—É—Ä–∞—Ü–∏—è —Å–º–æ–≥–ª–∞ –¥–æ—Å—Ç–∏—á—å NDCG@100=0.2774
