# Construct Training/Test Sequences

In [12]:
import os, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # add project root

from src.data_loader import prepare_movielens_sequences
from src.datasets import SequenceDataset, collate_fn

# --- Configuration -----------------------------------------------------------
# MAX_LEN is the single source of truth for sequence length: it is passed to the
# preprocessing call below and reused when building the datasets and the model.
MAX_LEN = 50
BATCH_SIZE = 256
MIN_USER_INTERACTIONS = 5  # forwarded as `min_user_interactions` to preprocessing

# --- Locate the raw ratings file ---------------------------------------------
# Paths are resolved relative to the parent of the current working directory
# (the project root when the notebook is run from its own folder).
data_dir = os.path.abspath(os.path.join(os.getcwd(), '..', 'data', 'raw'))
csv_path = os.path.join(data_dir, 'ml_32m_ratings.csv')
if not os.path.exists(csv_path):
    # Fail early with an explicit path instead of erroring inside the loader.
    raise FileNotFoundError(f"CSV file not found: {csv_path}")

# --- Load and preprocess -----------------------------------------------------
# Use the MAX_LEN constant (not a repeated literal) so preprocessing can never
# drift out of sync with the dataset/model sequence length.
bundle = prepare_movielens_sequences(
    ratings_path=csv_path,
    max_len=MAX_LEN,
    min_user_interactions=MIN_USER_INTERACTIONS,
)

# Unpack the entries of the returned bundle that later cells rely on.
train_sequences = bundle["train_sequences"]
test_sequences = bundle["test_sequences"]
num_items = bundle["num_items"]
mask_id = bundle["mask_id"]
pad_id = bundle["pad_id"]

# Preview the processed data: first two train/test sequences plus the stats entry.
print(train_sequences[:2])
print(test_sequences[:2])
print(bundle["stats"])

# Wrap the train/test sequences in datasets, then in batched loaders.
from torch.utils.data import DataLoader

# Both splits share the same maximum sequence length.
train_ds, test_ds = (
    SequenceDataset(split_sequences, max_len=MAX_LEN)
    for split_sequences in (train_sequences, test_sequences)
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader, test_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE, shuffle=do_shuffle, collate_fn=collate_fn)
    for split_ds, do_shuffle in ((train_ds, True), (test_ds, False))
)

[[303, 586, 785, 818, 1174, 1195, 80, 601, 1145, 1244, 2178, 36, 110, 1231, 1654, 1937, 17, 30, 1093, 1123, 1183, 1191, 1215, 1656, 1873, 2222, 25, 888, 903, 1155, 1263, 2430, 1019, 1323, 1369, 1631, 304, 1033, 1198, 2223, 299, 636, 951, 1728, 1864, 2239, 2621, 906, 1357, 1932], [293, 376, 152, 340, 345, 581, 229, 315, 588, 335, 184, 251, 453, 34, 223, 39, 360, 450, 352, 234, 280, 185, 373, 346, 496, 353, 535, 580, 48, 579, 590, 31, 192, 274, 546, 235, 578, 504, 217, 377, 220, 547, 206, 215, 302, 462, 587, 516, 609, 358]]
[[586, 785, 818, 1174, 1195, 80, 601, 1145, 1244, 2178, 36, 110, 1231, 1654, 1937, 17, 30, 1093, 1123, 1183, 1191, 1215, 1656, 1873, 2222, 25, 888, 903, 1155, 1263, 2430, 1019, 1323, 1369, 1631, 304, 1033, 1198, 2223, 299, 636, 951, 1728, 1864, 2239, 2621, 906, 1357, 1932, 2036], [376, 152, 340, 345, 581, 229, 315, 588, 335, 184, 251, 453, 34, 223, 39, 360, 450, 352, 234, 280, 185, 373, 346, 496, 353, 535, 580, 48, 579, 590, 31, 192, 274, 546, 235, 578, 504, 217, 377,

# Train BERT4Rec

In [None]:
from src.model import BERT4Rec
from src.trainer import train_one_epoch, evaluate_mlm_loss
import torch

MLM_PROB = 0.15  # forwarded as `mlm_prob` to both the training and evaluation helpers

# Resolve the checkpoint location and create its directory BEFORE training.
# torch.save does not create missing parent directories, so without this a
# missing `../model` folder would only surface after the whole run finished.
CKPT_PATH = os.path.abspath(os.path.join(os.getcwd(), '..', 'model/bert4rec_checkpoint.pth'))
os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)

# Model: run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERT4Rec(num_items=num_items, pad_id=pad_id, mask_id=mask_id, max_len=MAX_LEN).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

# Train a few epochs, reporting train and eval loss after each one.
num_epoch = 3
tr = None  # stays None if num_epoch is set to 0, so the save below cannot hit a NameError
for ep in range(num_epoch):
    tr = train_one_epoch(model, train_loader, opt, device, mlm_prob=MLM_PROB)
    ev = evaluate_mlm_loss(model, test_loader, device, mlm_prob=MLM_PROB)
    print(f"epoch {ep} | train_loss {tr['loss']:.4f} | eval_loss {ev['loss']:.4f}")

# Persist model and optimizer state so training can be resumed or the model reloaded.
torch.save({
    "epoch": num_epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": opt.state_dict(),
    # Whole result object returned by train_one_epoch for the last epoch
    # (it is indexed with ['loss'] above), not a bare float.
    "loss": tr,
}, CKPT_PATH)


train:  34%|███▍      | 268/785 [30:44<57:43,  6.70s/it]  