# Recommender Systems with Transformer Models - Model Development
Transformer-based model for sequential recommendation from the paper [Self-Attentive Sequential Recommendation (2018)](https://arxiv.org/abs/1808.09781).

* **Data:** build per-user, time-sorted history sequences (left-padded to max length in batch). For training, use history -> next item triples + one negative per triples.
* **Model:** Item embeddings + positional embeddings -> Transformer encoder with causal mask. Take final position representation as the user's state and score candidate items by dot-product.
* **Eval:** Identical to the baseline using LOO setting.

In [1]:
import os
import random
import numpy as np
import pandas as pd
import time
import gc
import matplotlib.pyplot as plt
from collections import defaultdict

#
os.environ["HF_HOME"] = "D:/Python Projects/recommendation_system"
os.environ["HF_DATASETS_CACHE"] = "D:/Python Projects/recommendation_system/recsys/data"
os.environ["TRANSFORMERS_CACHE"] = "D:/Python Projects/recommendation_system/recsys/models"

# os.environ["HF_HOME"] = "E:/Python Scripts/recsys"
# os.environ['HF_DATASETS_CACHE'] = "E:/Python Scripts/recsys/data"
# os.environ['TRANSFORMERS_CACHE'] = "E:/Python Scripts/recsys/models"

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from datasets import load_dataset, Features, Value
from tqdm import tqdm
from tensorboardX import SummaryWriter

In [2]:
SEED = 42
def set_seed(seed=SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

DEVICE: cuda


## Load Amazon Reviews 2023 dataset from Hugging Face

In [120]:
USE_HF = True
HF_DATASET = "McAuley-Lab/Amazon-Reviews-2023"
HF_DOMAIN = "Books"   # choose your domain

def load_amazon_reviews_fast(domain: str, max_items: int | None = None, seed: int = SEED) -> pd.DataFrame:
    ds = load_dataset(
        "McAuley-Lab/Amazon-Reviews-2023",
        f"raw_review_{domain}",
        split="full",
        trust_remote_code=True,
    )

    # Keep only needed columns early
    ds = ds.select_columns(["user_id", "parent_asin", "rating", "verified_purchase", "timestamp"])

    # Rename + cast in Arrow (vectorized; no Python loop)
    ds = ds.rename_columns({"user_id": "user", "parent_asin": "item"})
    ds = ds.cast(Features({
        "user": Value("string"),
        "item": Value("string"),
        "rating": Value("float32"),
        "verified_purchase": Value("bool"),
        "timestamp": Value("int64"),   # seconds or ms as stored
    }))

    # Fast random subset without building a huge pandas frame first
    if max_items is not None:
        k = min(max_items, ds.num_rows)
        ds = ds.shuffle(seed=seed).select(range(k))

    # Convert to pandas (Arrow zero-copy where possible)
    df = ds.to_pandas()
    df.insert(3, "domain", domain)
    return df

df = load_amazon_reviews_fast(HF_DOMAIN, max_items=6_000_000, seed=SEED)
print(f"Loaded {len(df)} rows from {HF_DOMAIN} domain.")

Loading dataset shards:   0%|          | 0/33 [00:00<?, ?it/s]

Loaded 6000000 rows from Books domain.


## Data preprocessing

In [316]:
# Make implicit dataset, filter users/items with less than k interactions
def preprocess_dataset(df, min_user_interactions=5, min_item_interactions=5, only_verified=False):
    if only_verified:
        df = df[df["verified_purchase"] == True].copy()
        print("After verified filtering:", len(df), "rows,", df["user"].nunique(), "users,", df["item"].nunique(), "items")

    # Make it implicit
    df["label"] = 1.0

    # Filter users and items with less than 5 interactions
    user_counts = df["user"].value_counts()
    item_counts = df["item"].value_counts()
    valid_users = user_counts[user_counts >= min_user_interactions].index
    valid_items = item_counts[item_counts >= min_item_interactions].index
    df = df[df["user"].isin(valid_users) & df["item"].isin(valid_items)].copy()

    print("After interactions filtering:", len(df), "rows,", df["user"].nunique(), "users,", df["item"].nunique(), "items")

    return df

def label_encoder(df, shift_item=False):
    df_enc = df.copy()
    user_enc = LabelEncoder()
    item_enc = LabelEncoder()
    df_enc["user_id"] = user_enc.fit_transform(df_enc["user"])
    df_enc["item_id"] = item_enc.fit_transform(df_enc["item"])
    if shift_item:
        df_enc["item_id"] = df_enc["item_id"] + 1
    return {"encoded_df": df_enc,
            "user_encoder": user_enc,
            "item_encoder": item_enc}

filtered_df = preprocess_dataset(df, min_user_interactions=20, min_item_interactions=20, only_verified=False)
le = label_encoder(filtered_df, shift_item=False)
df_encoded = le["encoded_df"]
user_enc = le["user_encoder"]
item_enc = le["item_encoder"]

After interactions filtering: 86885 rows, 10446 users, 22498 items


In [317]:
def calculate_data_sparsity(df):
    num_users = df["user"].nunique()
    num_items = df["item"].nunique()
    num_interactions = len(df)
    density = num_interactions / (num_users * num_items)
    sparsity = 1 - density

    print(f"Number of users: {num_users}")
    print(f"Number of items: {num_items}")
    print(f"Number of interactions: {num_interactions}")
    print(f"-" * 30)
    print(f"Interaction Matrix Density: {density:.6f}")
    print(f"Interaction Matrix Sparsity: {sparsity:.6f}")

calculate_data_sparsity(df_encoded)

Number of users: 10446
Number of items: 22498
Number of interactions: 86885
------------------------------
Interaction Matrix Density: 0.000370
Interaction Matrix Sparsity: 0.999630


In [318]:
def loo_split(df):
    df = df.sort_values(["user_id", "timestamp"])
    train_rows, val_rows, test_rows = [], [], []
    for uid, group in df.groupby("user_id", sort=False):
        g = group.sort_values("timestamp")
        if len(g) < 5:
            train_rows.append(g)
            continue
        g = g.sort_values("timestamp")
        test_rows.append(g.iloc[[-1]])  # Last interaction as test
        val_rows.append(g.iloc[[-2]])   # Second last interaction as validation
        train_rows.append(g.iloc[:-2])    # All but last two as training

    train_df = pd.concat(train_rows, ignore_index=True)
    val_df = pd.concat(val_rows, ignore_index=True)
    test_df = pd.concat(test_rows, ignore_index=True)

    return train_df, val_df, test_df

train_df, val_df, test_df = loo_split(df_encoded)
print(f"Train/Validation/Test split: {len(train_df)}, {len(val_df)}, {len(test_df)}")

Train/Validation/Test split: 74391, 6247, 6247


In [319]:
NUM_USERS = df_encoded["user_id"].max() + 1
NUM_ITEMS = df_encoded["item_id"].max() + 1

In [302]:
# # Rank the single positive against k negative samples for train and evaluation
# def build_pos_items_by_user(train_df):
#     pos_items_by_user = defaultdict(set)
#     for u, i in zip(train_df["user_id"].values, train_df["item_id"].values):
#         pos_items_by_user[u].add(int(i))
#     return pos_items_by_user
#
# pos_items_by_user = build_pos_items_by_user(train_df)

In [341]:
def create_user_sequences(df):
    df_sorted = df.sort_values(["user_id", "timestamp"])
    user_sequences = {}
    for uid, group in df_sorted.groupby("user_id"):
        items = group["item_id"].tolist()
        if len(items) >= 3:  # Need at least 3 for train/val/test split
            user_sequences[uid] = items

    print(f"Users with >=3 interactions: {len(user_sequences)}")
    return user_sequences

# train_users = set(train_df["user_id"].unique())
# val_users = set(val_df["user_id"].unique())
# test_users = set(test_df["user_id"].unique())

# Filter sequences to match splits
# full_seq = create_user_sequences(df_encoded)
# train_sequences = {u: seq for u, seq in full_seq.items() if u in train_users}
# val_sequences = {u: seq for u, seq in full_seq.items() if u in val_users}
# test_sequences = {u: seq for u, seq in full_seq.items() if u in test_users}

# print(f"\nDataset splits:")
# print(f"Train users: {len(train_sequences)}")
# print(f"Val users: {len(val_sequences)}")
# print(f"Test users: {len(test_sequences)}")

In [342]:
# def create_user_sequences(df):
#     return df.groupby("user_id")["item_id"].apply(list).to_dict()

full_seq = create_user_sequences(df_encoded.sort_values(["user_id", "timestamp"]))

Users with >=3 interactions: 8232


In [343]:
len_seq = []
for i in full_seq:
    len_seq.append(len(full_seq[i]))

len_seq.sort(reverse=False)

In [351]:
full_seq

{0: [15564, 7003, 7025, 7048, 6196],
 1: [17394,
  17461,
  15844,
  12963,
  18987,
  18983,
  19650,
  18498,
  19785,
  16295,
  16126,
  19490,
  20532,
  12820,
  16569,
  19902,
  20672,
  21147,
  17486,
  20796,
  16251,
  18199,
  17636,
  18249,
  19816,
  17391,
  17341,
  19815,
  19276,
  21463,
  19544,
  19023,
  21251,
  17473,
  18364,
  19317,
  16995,
  21818,
  21534,
  21527,
  21963,
  21250,
  21722,
  13327,
  21269,
  22029,
  21877,
  21871,
  15842,
  21705,
  22095],
 2: [7353, 13975, 5855, 6499, 8108, 7969],
 3: [9714, 608, 15023, 10347, 14593],
 5: [13660, 14533, 8914, 12385, 15431, 9449, 7256, 4053],
 6: [3413, 5724, 11894, 17832],
 8: [16477, 16686, 17609, 18512, 18442, 18589, 18465, 18699, 18663, 18313],
 9: [7962,
  14758,
  19196,
  8416,
  17837,
  625,
  12879,
  8439,
  7779,
  11466,
  10675,
  1726],
 10: [11953, 4512, 9556, 1974, 16241, 16555, 1553, 12490],
 11: [17179,
  3441,
  9132,
  14687,
  6488,
  10320,
  4077,
  3443,
  12648,
  2898,
 

## Prepare dataloader for SASRec

In [344]:
# class SASRecDataset(Dataset):
#     def __init__(self, user_sequences, num_items, max_seq_len=50, mode="train", neg_samples=1):
#         self.user_sequences = user_sequences
#         self.num_items = num_items
#         self.max_seq_len = max_seq_len
#         self.mode = mode
#         self.neg_samples = neg_samples
#         self.users = list(user_sequences.keys())
#
#         # Set of all items each user has interacted with
#         self.user_item_set = {u: set(items) for u, items in user_sequences.items()}
#
#     def __len__(self):
#         return len(self.users)
#
#     def __getitem__(self, idx):
#         user = self.users[idx]
#         seq = self.user_sequences[user]
#
#         if self.mode == "train":
#             # input_seq = seq[:-2] if len(seq) > 2 else seq[:-1]
#             # target = seq[-2] if len(seq) > 2 else seq[-1]
#             t = random.randint(1, len(seq) - 1)
#             input_seq = seq[:t]
#             target = seq[t]
#         elif self.mode == "val":
#             # For validation: use all but last item to predict second-to-last
#             input_seq = seq[:-2]
#             target = seq[-2]
#         else:  # test
#             # For test: use all but last to predict last
#             input_seq = seq[:-1]
#             target = seq[-1]
#
#         input_seq = self._pad(input_seq)
#         neg_items = self._sample_negatives(user)
#
#         return {
#             "user": user,
#             "input_seq": torch.tensor(input_seq, dtype=torch.long),
#             "target": torch.tensor(target, dtype=torch.long), # target is a single item
#             "neg_items": torch.tensor(neg_items, dtype=torch.long)
#         }
#
#     def _pad(self, seq):
#         # Pad or truncate sequence to max_seq_len
#         if len(seq) > self.max_seq_len:
#             seq = seq[-self.max_seq_len:]
#         elif len(seq) < self.max_seq_len:
#             seq = [0] * (self.max_seq_len - len(seq)) + seq
#         return seq
#
#     def _sample_negatives(self, user):
#         # Negative sampling: sample items not in user's history
#         neg_items = []
#         user_items = self.user_item_set[user]
#         while len(neg_items) < self.neg_samples:
#             neg = random.randint(1, self.num_items - 1)
#             if neg not in user_items:
#                 neg_items.append(neg)
#         return neg_items

In [345]:
training_instances = []
for u, seq in full_seq.items():
    train_seq = seq[:-2] if len(seq) > 2 else seq[:-1]
    for i in range(1, len(train_seq)):
        training_instances.append((u, train_seq[:i], train_seq[i]))

In [346]:
training_instances

[(0, [15564], 7003),
 (0, [15564, 7003], 7025),
 (1, [17394], 17461),
 (1, [17394, 17461], 15844),
 (1, [17394, 17461, 15844], 12963),
 (1, [17394, 17461, 15844, 12963], 18987),
 (1, [17394, 17461, 15844, 12963, 18987], 18983),
 (1, [17394, 17461, 15844, 12963, 18987, 18983], 19650),
 (1, [17394, 17461, 15844, 12963, 18987, 18983, 19650], 18498),
 (1, [17394, 17461, 15844, 12963, 18987, 18983, 19650, 18498], 19785),
 (1, [17394, 17461, 15844, 12963, 18987, 18983, 19650, 18498, 19785], 16295),
 (1,
  [17394, 17461, 15844, 12963, 18987, 18983, 19650, 18498, 19785, 16295],
  16126),
 (1,
  [17394,
   17461,
   15844,
   12963,
   18987,
   18983,
   19650,
   18498,
   19785,
   16295,
   16126],
  19490),
 (1,
  [17394,
   17461,
   15844,
   12963,
   18987,
   18983,
   19650,
   18498,
   19785,
   16295,
   16126,
   19490],
  20532),
 (1,
  [17394,
   17461,
   15844,
   12963,
   18987,
   18983,
   19650,
   18498,
   19785,
   16295,
   16126,
   19490,
   20532],
  12820),
 (1,


In [352]:
class SASRecDataset(Dataset):
    def __init__(self, user_sequences, num_items, max_seq_len=50, mode="train", neg_samples=1):
        self.user_sequences = user_sequences
        self.num_items = num_items
        self.max_seq_len = max_seq_len
        self.mode = mode
        self.neg_samples = neg_samples
        self.users = list(user_sequences.keys())

        # Set of all items each user has interacted with
        self.user_item_set = {u: set(items) for u, items in user_sequences.items()}

        # For training, generate all possible subsequence-target pairs
        if mode == "train":
            self.training_instances = []
            for u, seq in self.user_sequences.items():
                train_seq = seq[:-2] if len(seq) > 2 else seq[:-1]
                for i in range(1, len(train_seq)):
                    self.training_instances.append((u, train_seq[:i], train_seq[i]))

    def __len__(self):
        if self.mode == "train":
            return len(self.training_instances)
        else:
            return len(self.users)

    def __getitem__(self, idx):
        if self.mode == "train":
            user, input_seq, target = self.training_instances[idx]
        else:
            user = self.users[idx]
            seq = self.user_sequences[user]

            if self.mode == "val":
                input_seq = seq[:-2]
                target = seq[-2]
            else: # test
                input_seq = seq[:-1]
                target = seq[-1]

        # Pad sequence and create attention mask
        padded_seq, attn_mask = self._pad_with_mask(input_seq)
        neg_items = self._sample_negatives(user)

        return {
            "user": user,
            "input_seq": torch.tensor(padded_seq, dtype=torch.long),
            "attn_mask": torch.tensor(attn_mask, dtype=torch.bool),
            "target": torch.tensor(target, dtype=torch.long),
            "neg_items": torch.tensor(neg_items, dtype=torch.long)
        }

    def _pad_with_mask(self, seq):
        """Pad sequence and create attention mask (True for real items, False for padding)"""
        seq_len = len(seq)
        if seq_len > self.max_seq_len:
            # Truncate from the beginning (keep most recent items)
            seq = seq[-self.max_seq_len:]
            mask = [True] * self.max_seq_len
        elif seq_len < self.max_seq_len:
            # Left padding with zeros
            padding_len = self.max_seq_len - seq_len
            seq = [0] * padding_len + seq
            mask = [False] * padding_len + [True] * seq_len
        else:
            mask = [True] * self.max_seq_len
        return seq, mask

    def _sample_negatives(self, user):
        """Sample negative items not in user's interaction history"""
        neg_items = []
        user_items = self.user_item_set[user]

        while len(neg_items) < self.neg_samples:
            neg = random.randint(1, self.num_items - 1)
            if neg not in user_items:
                neg_items.append(neg)

        return neg_items

In [353]:
# train_dataset_sasrec = SASRecDataset(train_sequences, NUM_ITEMS, max_seq_len=50, mode="train", neg_samples=1)
# val_dataset_sasrec = SASRecDataset(val_sequences, NUM_ITEMS, max_seq_len=50, mode="val", neg_samples=99)
# test_dataset_sasrec = SASRecDataset(test_sequences, NUM_ITEMS, max_seq_len=50, mode="test",  neg_samples=99)
#
# train_loader_sasrec = DataLoader(train_dataset_sasrec, batch_size=4096, shuffle=True)
# val_loader_sasrec = DataLoader(val_dataset_sasrec, batch_size=4096, shuffle=False)
# test_loader_sasrec = DataLoader(test_dataset_sasrec, batch_size=4096, shuffle=False)

In [354]:
train_dataset_sasrec = SASRecDataset(full_seq, NUM_ITEMS, max_seq_len=50, mode="train", neg_samples=1)
val_dataset_sasrec = SASRecDataset(full_seq, NUM_ITEMS, max_seq_len=50, mode="val", neg_samples=99)
test_dataset_sasrec = SASRecDataset(full_seq, NUM_ITEMS, max_seq_len=50, mode="test", neg_samples=99)

train_loader_sasrec = DataLoader(train_dataset_sasrec, batch_size=2048, shuffle=True)
val_loader_sasrec = DataLoader(val_dataset_sasrec, batch_size=2048, shuffle=False)
test_loader_sasrec = DataLoader(test_dataset_sasrec, batch_size=2048, shuffle=False)

In [355]:
first = next(iter(train_loader_sasrec))
print("Shapes:",
      first["input_seq"].shape,   # expect [B, L]
      first["target"].shape,      # [B]
      first["neg_items"].shape,   # [B, N] (N==neg_samples or padded)
      first["user"].shape)        # [B]

Shapes: torch.Size([2048, 50]) torch.Size([2048]) torch.Size([2048, 1]) torch.Size([2048])


## Building the SASRec model

In [356]:
# # Building SASRec model
# class PointWiseFeedForward(nn.Module):
#     def __init__(self, hidden_dim, dropout=0.2):
#         super().__init__()
#         self.w1 = nn.Linear(hidden_dim, hidden_dim)
#         self.w2 = nn.Linear(hidden_dim, hidden_dim)
#         self.relu = nn.ReLU()
#         self.dropout = nn.Dropout(dropout)
#
#     def forward(self, x):
#         return self.w2(self.dropout(self.relu(self.w1(x))))
#
# class AttentionBlock(nn.Module):
#     def __init__(self, hidden_dim, num_heads, dropout=0.2):
#         super().__init__()
#
#         # Multi-head attention
#         self.attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
#
#         # Layer norms
#         self.ln1 = nn.LayerNorm(hidden_dim)
#         self.ln2 = nn.LayerNorm(hidden_dim)
#
#         # Feed-forward network
#         self.ffn = PointWiseFeedForward(hidden_dim, dropout)
#         self.dropout = nn.Dropout(dropout)
#
#     def forward(self, x, attn_mask=None):
#         # Self-attention with residual connection
#         attn_out, _ = self.attn(x, x, x, attn_mask=attn_mask)
#         x = self.ln1(x + self.dropout(attn_out))
#
#         # Feed-forward network with residual connection
#         ffn_out = self.ffn(x)
#         x = self.ln2(x + self.dropout(ffn_out))
#         return x

# class SASRec(nn.Module):
#     def __init__(self,
#                  num_items,
#                  hidden_dim=64,
#                  max_seq_len=50,
#                  num_blocks=2,
#                  num_heads=2,
#                  dropout=0.2):
#         super().__init__()
#
#         self.num_items = num_items
#         self.hidden_dim = hidden_dim
#         self.max_seq_len = max_seq_len
#
#         # Embedding layers
#         self.item_embed = nn.Embedding(num_items, hidden_dim, padding_idx=0)
#         self.positional_embed = nn.Embedding(max_seq_len, hidden_dim)
#         self.dropout = nn.Dropout(dropout)
#
#         # Stack of SASRec blocks
#         self.blocks = nn.ModuleList([
#             AttentionBlock(hidden_dim, num_heads, dropout) for _ in range(num_blocks)
#         ])
#
#         # Final layer norm
#         self.ln = nn.LayerNorm(hidden_dim)
#
#         # Initialize weights
#         self._reset_parameters()
#
#     def _reset_parameters(self):
#         nn.init.xavier_normal_(self.item_embed.weight[1:])  # Skip padding idx
#         nn.init.xavier_normal_(self.positional_embed.weight)
#
#     def forward(self, input_seq, candidate_items=None):
#         batch_size, seq_len = input_seq.shape
#         item_embeds = self.item_embed(input_seq)  # [B, L, D]
#
#         # Add positional embeddings
#         pos_ids = torch.arange(seq_len, device=input_seq.device).unsqueeze(0).expand(batch_size, seq_len)
#         pos_embeds = self.positional_embed(pos_ids)  # [B, L, D]
#         x = self.dropout(item_embeds + pos_embeds)
#
#         # Create causal attention mask
#         attn_mask = self._create_causal_mask(seq_len, input_seq.device)
#
#         # Pass through transformer blocks
#         for block in self.blocks:
#             x = block(x, attn_mask=attn_mask)
#
#         # Final layer norm
#         x = self.ln(x)  # [B, L, D]
#
#         # If candidate_items provided, score them
#         if candidate_items is not None:
#             # Get embeddings for candidate items
#             cand_emb = self.item_embed(candidate_items) # [B, N, D]
#             # Use last position's representation for scoring
#             last_hidden = x[:, -1, :].unsqueeze(1)  # [B, 1, D]
#             # Compute scores via dot product
#             scores = torch.matmul(last_hidden, cand_emb.transpose(1, 2)).squeeze(1) # [B, N]
#             return scores
#         return x
#
#     def _create_causal_mask(self, seq_len, device):
#         mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
#         mask = mask.masked_fill(mask == 1, float("-inf"))
#         return mask
#
#     def predict_next(self, input_seq):
#         # Get sequence representations
#         seq_repr = self.forward(input_seq)  # [B, L, D]
#         # Use last position for prediction
#         last_hidden = seq_repr[:, -1, :]  # [B, D]
#         # Score against all item embeddings
#         all_item_embeds = self.item_embed.weight  # [num_items, D]
#         scores = torch.matmul(last_hidden, all_item_embeds.T)  # [B, num_items]
#         return scores

In [383]:
class SASRecBlock(nn.Module):
    """Single SASRec attention block with proper layer normalization"""
    def __init__(self, hidden_dim, num_heads, dropout=0.2):
        super().__init__()

        # Multi-head self-attention
        self.attention = nn.MultiheadAttention(
            hidden_dim,
            num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Point-wise feed-forward network
        self.feed_forward = PointWiseFeedForward(hidden_dim, dropout)

        # Layer normalization
        self.norm1 = nn.LayerNorm(hidden_dim, eps=1e-8)
        self.norm2 = nn.LayerNorm(hidden_dim, eps=1e-8)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # Self-attention with residual connection
        attn_out, _ = self.attention(
            query=self.norm1(x),
            key=self.norm1(x),
            value=self.norm1(x),
            attn_mask=attn_mask
        )
        x = x + self.dropout(attn_out)

        # Feed-forward with residual connection
        ff_out = self.feed_forward(self.norm2(x))
        x = x + self.dropout(ff_out)

        return x


class PointWiseFeedForward(nn.Module):
    """Point-wise feed-forward network"""
    def __init__(self, hidden_dim, dropout=0.2):
        super().__init__()

        # According to the paper, use 1-D convolutional layers
        self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: [B, L, D]
        x = x.transpose(1, 2)  # [B, D, L]
        x = self.conv2(self.dropout(self.relu(self.conv1(x))))
        x = x.transpose(1, 2)  # [B, L, D]
        return x

class SASRec(nn.Module):
    """
    Corrected SASRec model with proper attention masking and initialization
    """
    def __init__(self,
                 num_items,
                 hidden_dim=50,
                 max_seq_len=50,
                 num_blocks=2,
                 num_heads=1,
                 dropout=0.2):
        super().__init__()

        self.num_items = num_items
        self.hidden_dim = hidden_dim
        self.max_seq_len = max_seq_len

        # Embedding layers
        self.item_embed = nn.Embedding(num_items, hidden_dim, padding_idx=0)
        self.positional_embed = nn.Embedding(max_seq_len, hidden_dim)

        # Dropout and layer norm
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_dim, eps=1e-8)

        # Self-attention blocks
        self.attention_blocks = nn.ModuleList([
            SASRecBlock(hidden_dim, num_heads, dropout)
            for _ in range(num_blocks)
        ])

        # Initialize weights according to the paper
        self._init_weights()

    def _init_weights(self):
        # Xavier uniform initialization as mentioned in the paper
        nn.init.xavier_uniform_(self.item_embed.weight[1:])  # Skip padding embedding
        nn.init.xavier_uniform_(self.positional_embed.weight)

    def forward(self, input_seq, attn_mask=None, candidate_items=None):
        batch_size, seq_len = input_seq.shape

        # Item embeddings
        item_embeds = self.item_embed(input_seq)  # [B, L, D]

        # Add positional embeddings
        pos_ids = torch.arange(seq_len, device=input_seq.device)
        pos_ids = pos_ids.unsqueeze(0).expand(batch_size, -1)
        pos_embeds = self.positional_embed(pos_ids)

        # Combine embeddings
        seq_embeds = item_embeds + pos_embeds
        seq_embeds = self.layer_norm(seq_embeds)
        seq_embeds = self.dropout(seq_embeds)

        # Create combined mask (causal + padding)
        timeline_mask = self._create_timeline_mask(seq_len, input_seq.device)

        # Apply padding mask if provided
        if attn_mask is not None:
            # Expand padding mask to match attention shape
            padding_mask = ~attn_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, L]
            padding_mask = padding_mask.expand(-1, seq_len, seq_len, -1)

            # Combine with causal mask
            timeline_mask = timeline_mask.unsqueeze(0) | padding_mask

        # Pass through self-attention blocks
        for block in self.attention_blocks:
            seq_embeds = block(seq_embeds, timeline_mask)

        # Output layer norm
        seq_output = self.layer_norm(seq_embeds)  # [B, L, D]

        # If candidate items provided, compute scores
        if candidate_items is not None:
            return self.compute_scores(seq_output, candidate_items)

        return seq_output

    def _create_timeline_mask(self, seq_len, device):
        """Create causal attention mask (True = mask out)"""
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        return mask.bool()

    def compute_scores(self, seq_output, candidate_items):
        """Compute scores for candidate items using the last hidden state"""
        last_hidden = seq_output[:, -1, :]  # [B, D]

        if len(candidate_items.shape) == 1:
            # Single item per user
            item_embeds = self.item_embed(candidate_items)  # [B, D]
            scores = (last_hidden * item_embeds).sum(dim=1)  # [B]
        else:
            # Multiple items per user
            item_embeds = self.item_embed(candidate_items)  # [B, N, D]
            scores = torch.bmm(item_embeds, last_hidden.unsqueeze(-1)).squeeze(-1)  # [B, N]

        return scores

    def predict_next(self, input_seq, attn_mask=None, k=10):
        """Predict top-k next items"""
        seq_output = self.forward(input_seq, attn_mask)
        last_hidden = seq_output[:, -1, :]  # [B, D]

        # Score all items
        item_embeds = self.item_embed.weight  # [num_items, D]
        scores = torch.matmul(last_hidden, item_embeds.T)  # [B, num_items]

        # Mask out padding token
        scores[:, 0] = -float('inf')

        # Get top-k items
        topk_scores, topk_items = torch.topk(scores, k, dim=1)

        return topk_items, topk_scores

## Training and evaluation functions

In [384]:
# def train_sasrec_epoch(model, train_loader, loss_fn, optimizer, device="cpu"):
#     model.train()
#     total_loss = 0.0
#     n_batches = 0
#
#     for batch in tqdm(train_loader, desc="Training"):
#         input_seq = batch["input_seq"].to(device)
#         pos_items = batch["target"].to(device)
#         neg_items = batch["neg_items"].to(device)
#
#         optimizer.zero_grad()
#
#         # Get predictions for last position
#         seq_output = model(input_seq)  # [B, L, D]
#         last_hidden = seq_output[:, -1, :]  # [B, D]
#
#         # Get embeddings for positive and negative items
#         pos_embeds = model.item_embed(pos_items)
#         neg_embeds = model.item_embed(neg_items)
#
#         # Compute logits
#         pos_logits = (last_hidden * pos_embeds).sum(dim=1)
#         neg_logits = torch.bmm(neg_embeds, last_hidden.unsqueeze(-1)).squeeze(-1)
#
#         # Binary cross-entropy loss with logits
#         pos_labels = torch.ones_like(pos_logits)
#         neg_labels = torch.zeros_like(neg_logits)
#
#         # Concatenate logits and labels
#         all_logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
#         all_labels = torch.cat([pos_labels.unsqueeze(1), neg_labels], dim=1)
#
#         loss = loss_fn(all_logits, all_labels)
#
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()
#         n_batches += 1
#
#     return total_loss / n_batches
#
# # Validation loss and ranking metrics
# @torch.no_grad()
# def evaluate_sasrec(model, eval_loader, loss_fn, k=10, device="cpu"):
#     model.eval()
#     total = 0
#     sum_hr = 0.0
#     sum_ndcg = 0.0
#     sum_prec = 0.0
#     sum_ap = 0.0
#
#     sum_val_loss = 0.0
#     n_loss_batches = 0
#
#     for batch in tqdm(eval_loader, desc="Evaluating"):
#         input_seq = batch["input_seq"].to(device)
#         target = batch["target"].to(device)
#         neg_items = batch["neg_items"].to(device)
#
#         batch_size = input_seq.size(0)
#
#         # Create candidate set: 1 positive + negatives
#         seq_output = model(input_seq)  # [B, L, D]
#         last_hidden = seq_output[:, -1, :]  # [B, D]
#         candidates = torch.cat([target.unsqueeze(1),neg_items], dim=1)  # [B, 1 + neg_samples]
#
#         # Get embeddings for all candidates
#         cand_emb = model.item_embed(candidates)  # [B, 1+neg_samples, D]
#
#         # Compute scores via dot product
#         scores = model(input_seq, candidate_items=candidates)
#         # scores = torch.bmm(cand_emb, last_hidden.unsqueeze(-1)).squeeze(-1)  # [B, 1+neg_samples]
#
#         # Loss calculation
#         pos_scores = scores[:, 0]
#         neg_scores = scores[:, 1:]
#         pos_labels = torch.ones_like(scores[:, 0])
#         neg_labels = torch.zeros_like(scores[:, 1:])
#         all_scores = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)
#         all_labels = torch.cat([pos_labels.unsqueeze(1), neg_labels], dim=1)
#
#         batch_loss = loss_fn(all_scores, all_labels)
#         sum_val_loss += batch_loss.item()
#         n_loss_batches += 1
#
#         # Calculate metrics
#         _, full_idx = torch.sort(scores, dim=1, descending=True)
#         rank = (full_idx == 0).nonzero(as_tuple=True)[1] + 1  # Rank of the positive item (1-based)
#
#         hit = (rank <= k).float()
#         ndcg = torch.where(rank <= k, 1.0 / torch.log2(rank.float() + 1), torch.zeros_like(hit))
#         precision = hit / float(k)
#         ap = torch.where(rank <= k, 1.0 / rank.float(), torch.zeros_like(hit))
#
#         sum_hr += hit.sum().item()
#         sum_ndcg += ndcg.sum().item()
#         sum_prec += precision.sum().item()
#         sum_ap += ap.sum().item()
#         total += batch_size
#
#     metrics = {
#         "HR@K": sum_hr / total if total else 0.0,
#         "NDCG@K": sum_ndcg / total if total else 0.0,
#         "Precision@K": sum_prec / total if total else 0.0,
#         "MAP@K": sum_ap / total if total else 0.0,
#         "Val loss": sum_val_loss / max(n_loss_batches, 1)
#     }
#
#     return metrics

In [385]:
def train_sasrec_epoch(model, train_loader, loss_fn, optimizer, device="cpu", use_bpr=True):
    """Training with BPR loss as in the original paper"""
    model.train()
    total_loss = 0.0
    n_batches = 0

    for batch in tqdm(train_loader, desc="Training"):
        input_seq = batch["input_seq"].to(device)
        attn_mask = batch["attn_mask"].to(device)
        pos_items = batch["target"].to(device)
        neg_items = batch["neg_items"].to(device)

        # Forward pass
        seq_output = model(input_seq, attn_mask)
        last_hidden = seq_output[:, -1, :]

        # Compute scores
        pos_embeds = model.item_embed(pos_items)
        neg_embeds = model.item_embed(neg_items.squeeze(1) if neg_items.dim() > 1 else neg_items)

        pos_scores = (last_hidden * pos_embeds).sum(dim=1)
        neg_scores = (last_hidden * neg_embeds).sum(dim=1)

        # Alternative: Binary cross-entropy
        pos_labels = torch.ones_like(pos_scores)
        neg_labels = torch.zeros_like(neg_scores)

        all_scores = torch.cat([pos_scores, neg_scores])
        all_labels = torch.cat([pos_labels, neg_labels])

        loss = loss_fn(all_scores, all_labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)  # Gradient clipping
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches

In [386]:
def sasrec_trainer(
        model,
        train_loader,
        eval_loader,
        epochs,
        loss_fn,
        optimizer,
        k=10,             # kept to log hyperparams
        device="cpu",
        save_dir="model"
    ):
    os.makedirs(save_dir, exist_ok=True)
    model.to(device)
    writer = SummaryWriter()

    train_losses, val_losses, val_metrics_log = [], [], []
    best_ndcg, best_epoch = 0.0, 0

    for epoch in range(epochs):
        t0 = time.time()

        # Train (batched)
        train_loss = train_sasrec_epoch(model, train_loader, loss_fn, optimizer, device=device)
        train_losses.append(train_loss)

        # Eval (batched)
        m = evaluate_sasrec(model, eval_loader, loss_fn, k=k, device=device)
        val_losses.append(m["Val loss"])
        val_metrics_log.append({k_: m[k_] for k_ in ["HR@K", "NDCG@K", "Precision@K", "MAP@K"]})

        # Checkpointing by NDCG
        if m["NDCG@K"] > best_ndcg:
            best_ndcg = m["NDCG@K"]
            best_epoch = epoch + 1
            torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
        torch.save(model.state_dict(), os.path.join(save_dir, "last_model.pth"))

        # TB logs
        writer.add_scalar("Loss/Train", train_loss, epoch)
        writer.add_scalar("Loss/Validation", m["Val loss"], epoch)
        writer.add_scalar(f"Metrics/Val_HR@{k}", m["HR@K"], epoch)
        writer.add_scalar(f"Metrics/Val_NDCG@{k}", m["NDCG@K"], epoch)
        writer.add_scalar(f"Metrics/Val_Precision@{k}", m["Precision@K"], epoch)
        writer.add_scalar(f"Metrics/Val_MAP@{k}", m["MAP@K"], epoch)

        print(
            f"Epoch {epoch+1}/{epochs}  "
            f"Train loss {train_loss:.4f}  "
            f"Val loss {m['Val loss']:.4f}  "
            f"HR@{k} {m['HR@K']:.4f}  "
            f"NDCG@{k} {m['NDCG@K']:.4f}  "
            f"Precision@{k} {m['Precision@K']:.4f}  "
            f"MAP@{k} {m['MAP@K']:.4f}  "
            f"{'(new best)' if m['NDCG@K'] == best_ndcg and best_epoch==epoch+1 else ''}  "
            f"Time {time.time()-t0:.2f}s"
        )

    # print("\n" + "="*50)
    print("\nTraining Complete.")
    print(f"Best epoch: {best_epoch} with NDCG@{k}: {best_ndcg:.4f}\n")
    # print("="*50)
    # print("\n")

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    writer.close()
    return train_losses, val_losses, val_metrics_log, best_ndcg

@torch.no_grad()
def evaluate_sasrec(model, eval_loader, k=10, device="cpu"):
    """Evaluation with ranking metrics"""
    model.eval()

    metrics = defaultdict(float)
    total = 0

    for batch in tqdm(eval_loader, desc="Evaluating"):
        input_seq = batch["input_seq"].to(device)
        attn_mask = batch["attn_mask"].to(device)
        target = batch["target"].to(device)
        neg_items = batch["neg_items"].to(device)

        batch_size = input_seq.size(0)

        # Create candidate set: 1 positive + negatives
        candidates = torch.cat([target.unsqueeze(1), neg_items], dim=1)

        # Get scores for all candidates
        scores = model(input_seq, attn_mask, candidate_items=candidates)

        # Rank items
        _, indices = torch.sort(scores, dim=1, descending=True)

        # Find rank of positive item (first item in candidates)
        ranks = (indices == 0).nonzero(as_tuple=True)[1] + 1

        # Calculate metrics
        for rank in ranks:
            rank = rank.item()

            # Hit Rate
            if rank <= k:
                metrics[f'HR@{k}'] += 1.0

                # NDCG
                metrics[f'NDCG@{k}'] += 1.0 / np.log2(rank + 1)

                # MRR
                metrics[f'MRR@{k}'] += 1.0 / rank

        total += batch_size

    # Average metrics
    for key in metrics:
        metrics[key] /= total

    return dict(metrics)

## Train model

In [388]:
# Hyperparameters from the original paper, except higher hidden_dim
sasrec = SASRec(
    num_items=NUM_ITEMS,
    hidden_dim=64,
    max_seq_len=50,
    num_blocks=2,
    num_heads=1,
    dropout=0.5
)
sasrec.to(DEVICE)

loss_fn_sasrec = nn.BCEWithLogitsLoss()
optimizer_sasrec = torch.optim.Adam(sasrec.parameters(), lr=1e-3, weight_decay=1e-6)

best_ndcg = 0
for epoch in range(10):
    # Train
    train_loss = train_sasrec_epoch(
        sasrec, train_loader_sasrec, loss_fn_sasrec, optimizer_sasrec, device=DEVICE, use_bpr=True
    )

    # Evaluate
    val_metrics = evaluate_sasrec(sasrec, val_loader_sasrec, k=10, device=DEVICE)

    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, "
          f"HR@10={val_metrics.get('HR@10', 0):.4f}, "
          f"NDCG@10={val_metrics.get('NDCG@10', 0):.4f}, "
          f"MRR@10={val_metrics.get('MRR@10', 0):.4f}")

    # Save best model
    if val_metrics.get('NDCG@10', 0) > best_ndcg:
        best_ndcg = val_metrics.get('NDCG@10', 0)
        # torch.save(model.state_dict(), 'best_sasrec.pth')

Training: 100%|██████████| 29/29 [00:04<00:00,  6.04it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.39it/s]


Epoch 1: Train Loss=nan, HR@10=1.0000, NDCG@10=1.0000, MRR@10=1.0000


Training: 100%|██████████| 29/29 [00:03<00:00,  7.47it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.55it/s]


Epoch 2: Train Loss=nan, HR@10=1.0000, NDCG@10=1.0000, MRR@10=1.0000


Training: 100%|██████████| 29/29 [00:04<00:00,  7.02it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.45it/s]


Epoch 3: Train Loss=nan, HR@10=1.0000, NDCG@10=1.0000, MRR@10=1.0000


Training: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.47it/s]


Epoch 4: Train Loss=nan, HR@10=1.0000, NDCG@10=1.0000, MRR@10=1.0000


Training: 100%|██████████| 29/29 [00:03<00:00,  7.40it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.19it/s]


Epoch 5: Train Loss=nan, HR@10=1.0000, NDCG@10=1.0000, MRR@10=1.0000


Training: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.46it/s]


Epoch 6: Train Loss=nan, HR@10=1.0000, NDCG@10=1.0000, MRR@10=1.0000


Training: 100%|██████████| 29/29 [00:04<00:00,  6.56it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.31it/s]


Epoch 7: Train Loss=nan, HR@10=1.0000, NDCG@10=1.0000, MRR@10=1.0000


Training: 100%|██████████| 29/29 [00:03<00:00,  7.28it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.32it/s]


Epoch 8: Train Loss=nan, HR@10=1.0000, NDCG@10=1.0000, MRR@10=1.0000


Training: 100%|██████████| 29/29 [00:04<00:00,  6.65it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.44it/s]


Epoch 9: Train Loss=nan, HR@10=1.0000, NDCG@10=1.0000, MRR@10=1.0000


Training: 100%|██████████| 29/29 [00:04<00:00,  6.96it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.29it/s]

Epoch 10: Train Loss=nan, HR@10=1.0000, NDCG@10=1.0000, MRR@10=1.0000



