# Sequential Recommender Training Demo (BERT4Rec codebase, SASRec model on KION)



## 1. Import Required Libraries

In [1]:
from pathlib import Path

import torch

from options import args
from models import model_factory
from dataloaders import dataloader_factory
from trainers import trainer_factory
from utils import *

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
2025-05-22 10:29:47.633031: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-22 10:29:47.649458: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8463] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-22 10:29:47.654540: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## 2. Configure Training Parameters

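We start from the BERT4Rec template configuration, but train a SASRec model (`model_code = 'sasrec'`) on the KION dataset instead of MovieLens. All tunable settings are collected in the cell below.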

In [2]:
# Set basic configuration
args.mode = 'train'

# Dataset selection.
# The upstream template prompted for ml-1m / ml-20m; this notebook trains on KION instead.
args.dataset_code = 'kion'
args.min_rating = 0  # presumably a rating threshold (template used 4 for ml-20m); 0 keeps all
args.min_uc = 5      # presumably the minimum number of interactions per user — TODO confirm
args.min_sc = 0      # presumably the minimum number of interactions per item — TODO confirm
args.split = 'leave_one_out'

# Dataloader configuration
args.dataloader_code = 'bert'
BATCH_SIZE = 128
args.train_batch_size = BATCH_SIZE
args.val_batch_size = BATCH_SIZE
args.test_batch_size = BATCH_SIZE

# Negative sampling configuration.
# With 100 sampled test negatives, Recall@100 comes out ~1.0 (see the test output),
# so the large cut-offs in metric_ks carry little signal; judge on @1/@5/@10.
args.train_negative_sampler_code = 'random'
args.train_negative_sample_size = 0
args.train_negative_sampling_seed = 0
args.test_negative_sampler_code = 'random'
args.test_negative_sample_size = 100
args.test_negative_sampling_seed = 98765

# Training configuration
args.trainer_code = 'bert'
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.num_gpu = 1
args.device_idx = '0'
args.optimizer = 'Adam'
args.lr = 0.001
args.enable_lr_schedule = True
args.decay_step = 25
# NOTE(review): with gamma=1.0 the step decay never changes the LR, so the schedule is
# effectively disabled; and if it steps once per epoch, decay_step=25 > num_epochs=6
# would never trigger anyway. Confirm this is intended.
args.gamma = 1.0
args.num_epochs = 6
args.metric_ks = [1, 5, 10, 20, 50, 100]
args.best_metric = 'NDCG@10'  # validation metric used to pick the "best" model

# Model configuration
# NOTE(review): a SASRec model is trained with the 'bert' dataloader/trainer codes set above.
args.model_code = 'sasrec'
args.model_init_seed = 0
args.dropout_rate = 0.1
args.hidden_units = 256
args.mask_prob = 0.15
args.maxlen = 100
args.num_blocks = 2
args.num_heads = 4
# Presumably the 'bert' dataloader reads the bert_* names; mirror the generic values so both agree.
args.bert_max_len = args.maxlen
args.bert_mask_prob = args.mask_prob

## 3. Setup Training Environment

In [3]:
# Create the experiment folder and logging. setup_train returns the export directory;
# judging by the output below, it also prints the folder path and the full argument dump.
export_root = setup_train(args)
destination_msg = f"Training logs and model checkpoints will be saved to: {export_root}"
print(destination_msg)

Folder created: /home/user/MovieRecommender/experiments/test_2025-05-22_32
{'anneal_cap': 0.2,
 'bert_mask_prob': 0.15,
 'bert_max_len': 100,
 'best_metric': 'NDCG@10',
 'dae_dropout': 0.5,
 'dae_hidden_dim': 600,
 'dae_latent_dim': 200,
 'dae_num_hidden': 0,
 'dataloader_code': 'bert',
 'dataloader_random_seed': 0.0,
 'dataset_code': 'kion',
 'dataset_split_seed': 98765,
 'decay_step': 25,
 'device': 'cuda',
 'device_idx': '0',
 'dropout_rate': 0.1,
 'enable_lr_schedule': True,
 'eval_set_size': 500,
 'experiment_description': 'test',
 'experiment_dir': 'experiments',
 'find_best_beta': True,
 'gamma': 1.0,
 'hidden_units': 256,
 'log_period_as_iter': 12800,
 'lr': 0.001,
 'mask_prob': 0.15,
 'maxlen': 100,
 'metric_ks': [1,
               5,
               10,
               20,
               50,
               100],
 'min_rating': 0,
 'min_sc': 0,
 'min_uc': 5,
 'mode': 'train',
 'model_code': 'sasrec',
 'model_init_seed': 0,
 'num_blocks': 2,
 'num_epochs': 6,
 'num_gpu': 1,
 'num

## 4. Prepare Data

In [4]:
# Build the train / validation / test dataloaders from the dataset settings in args.
# With leave_one_out every split presumably holds one sequence per user, which would
# explain why all three report the same batch count in the output below.
train_loader, val_loader, test_loader = dataloader_factory(args)

print(f"Dataset: {args.dataset_code}")
for split_label, split_loader in (
    ("Training", train_loader),
    ("Validation", val_loader),
    ("Test", test_loader),
):
    print(f"{split_label} batches: {len(split_loader)}")

Already preprocessed. Skip preprocessing
Negatives samples exist. Loading.
Negatives samples exist. Loading.
Dataset: kion
Training batches: 2364
Validation batches: 2364
Test batches: 2364


## 5. Initialize Model

In [5]:
# Build the model chosen by the args above and report its total parameter count
# (trainable and frozen parameters alike).
model = model_factory(args)
num_params = sum(param.numel() for param in model.parameters())
print(f"Model initialized with {num_params} parameters")

Model initialized with 9549693 parameters


## 6. Initialize Trainer

In [6]:
# Build the trainer from args, wiring in the model, the three dataloaders and the
# export folder, then echo the key optimisation settings.
trainer = trainer_factory(args, model, train_loader, val_loader, test_loader, export_root)

trainer_summary = (
    ("Optimizer", args.optimizer),
    ("Learning rate", args.lr),
    ("Number of epochs", args.num_epochs),
    ("Best metric", args.best_metric),
)
print("Trainer initialized with the following configuration:")
for label, value in trainer_summary:
    print(f"- {label}: {value}")

Trainer initialized with the following configuration:
- Optimizer: Adam
- Learning rate: 0.001
- Number of epochs: 6
- Best metric: NDCG@10


## 7. Train Model and Save Weights

In [7]:
# Start training.
# Judging by the logged output below, a validation pass runs before the first epoch and
# after every epoch, and "Update Best NDCG@10 Model at N" is printed whenever
# args.best_metric improves. Each epoch (train + validation) takes ~3.5 minutes here.
trainer.train()

Val: N@1 0.011, N@5 0.029, N@10 0.048, R@1 0.011, R@5 0.048, R@10 0.108, M@1 0.011, M@5 0.023, M@10 0.031, V@1 0.000, V@5 0.000, V@10 0.000: 100%|██████████| 2364/2364 [01:42<00:00, 22.99it/s]


Update Best NDCG@10 Model at 1


Epoch 1, loss 7.111 : 100%|██████████| 2364/2364 [01:43<00:00, 22.87it/s]  
Val: N@1 0.466, N@5 0.624, N@10 0.661, R@1 0.466, R@5 0.763, R@10 0.878, M@1 0.466, M@5 0.577, M@10 0.593, V@1 0.000, V@5 0.000, V@10 0.000: 100%|██████████| 2364/2364 [01:43<00:00, 22.76it/s]


Update Best NDCG@10 Model at 1


Epoch 2, loss 6.498 : 100%|██████████| 2364/2364 [01:43<00:00, 22.84it/s]  
Val: N@1 0.493, N@5 0.652, N@10 0.687, R@1 0.493, R@5 0.793, R@10 0.900, M@1 0.493, M@5 0.606, M@10 0.620, V@1 0.000, V@5 0.000, V@10 0.000: 100%|██████████| 2364/2364 [01:42<00:00, 23.08it/s]


Update Best NDCG@10 Model at 2


Epoch 3, loss 6.229 : 100%|██████████| 2364/2364 [01:43<00:00, 22.83it/s]  
Val: N@1 0.513, N@5 0.673, N@10 0.705, R@1 0.513, R@5 0.812, R@10 0.910, M@1 0.513, M@5 0.626, M@10 0.639, V@1 0.000, V@5 0.000, V@10 0.000: 100%|██████████| 2364/2364 [01:43<00:00, 22.74it/s]


Update Best NDCG@10 Model at 3


Epoch 4, loss 6.121 : 100%|██████████| 2364/2364 [01:43<00:00, 22.81it/s]  
Val: N@1 0.521, N@5 0.682, N@10 0.713, R@1 0.521, R@5 0.821, R@10 0.916, M@1 0.521, M@5 0.635, M@10 0.648, V@1 0.000, V@5 0.000, V@10 0.000: 100%|██████████| 2364/2364 [01:42<00:00, 22.98it/s]


Update Best NDCG@10 Model at 4


Epoch 5, loss 6.048 : 100%|██████████| 2364/2364 [01:43<00:00, 22.81it/s]  
Val: N@1 0.531, N@5 0.691, N@10 0.720, R@1 0.531, R@5 0.829, R@10 0.920, M@1 0.531, M@5 0.645, M@10 0.657, V@1 0.000, V@5 0.000, V@10 0.000: 100%|██████████| 2364/2364 [01:43<00:00, 22.91it/s]


Update Best NDCG@10 Model at 5


Epoch 6, loss 5.998 : 100%|██████████| 2364/2364 [01:43<00:00, 22.84it/s]  
Val: N@1 0.534, N@5 0.694, N@10 0.724, R@1 0.534, R@5 0.833, R@10 0.924, M@1 0.534, M@5 0.648, M@10 0.660, V@1 0.000, V@5 0.000, V@10 0.000: 100%|██████████| 2364/2364 [01:43<00:00, 22.77it/s]


Update Best NDCG@10 Model at 6


In [8]:
# Persist the trained weights next to the experiment logs.
# These are the weights held in memory after trainer.train(). Presumably that is the
# final epoch (not a reloaded best checkpoint) — confirm in the trainer. In this run the
# final epoch (6) was also the best on validation NDCG@10.
# Use a new name instead of overwriting export_root, so this cell stays idempotent and
# export_root keeps the value setup_train returned.
export_dir = Path(export_root)
weights_path = export_dir / "bert4rec_weights.pth"

trained_model = trainer.model if hasattr(trainer, "model") else trainer.lightning_module
# Unwrap (Distributed)DataParallel so the saved keys carry no "module." prefix and can be
# loaded straight into an unwrapped model.
if isinstance(trained_model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
    trained_model = trained_model.module

# NOTE(review): the file name says "bert4rec" but args.model_code is 'sasrec'. The name is
# kept unchanged because downstream loaders may depend on it.
torch.save(trained_model.state_dict(), weights_path)
print(f"Weights saved to {weights_path}")

Weights saved to experiments/test_2025-05-22_32/bert4rec_weights.pth


## 8. Test Model (Optional)

In [9]:
# Ask user if they want to run test set evaluation
test_model = (input('Test model with test dataset? y/[n]: ') == 'y')
if test_model:
    trainer.test()

Test best model with test set!


Val: N@1 0.519, N@5 0.680, N@10 0.712, R@1 0.519, R@5 0.821, R@10 0.916, M@1 0.519, M@5 0.634, M@10 0.647, V@1 0.000, V@5 0.000, V@10 0.000: 100%|██████████| 2364/2364 [01:43<00:00, 22.93it/s]

{'Recall@100': 0.9999471235194586, 'NDCG@100': 0.731544487922123, 'MAP@100': 0.6516750656862549, 'Recall@50': 0.9963151702622673, 'NDCG@50': 0.7309481085304883, 'MAP@50': 0.6516203433023088, 'Recall@20': 0.973799703891709, 'NDCG@20': 0.7262994846464615, 'MAP@20': 0.6507979661175846, 'Recall@10': 0.9162839130701752, 'NDCG@10': 0.7115605161375604, 'MAP@10': 0.6466492197386504, 'Recall@5': 0.820784182264115, 'NDCG@5': 0.6804214571484455, 'MAP@5': 0.6336485505255346, 'Recall@1': 0.5192641636920863, 'NDCG@1': 0.5192641636920863, 'MAP@1': 0.5192641636920863}



