# Unified Sequential Recommender System (SASRec)

## 1. Introduction & Theory (The "Why")

### 1.1 Why Transformers for Recommendation?

In traditional recommendation systems, we treat user preferences as static profiles. But **user behavior is temporal** - what a user clicked 5 minutes ago is more relevant than what they clicked 5 weeks ago.

**SASRec (Self-Attentive Sequential Recommendation)** treats user histories like sentences:
- **Items = Words/Tokens** in a vocabulary
- **User History = Sentence** to be "understood"
- **Next-Item Prediction = Language Model** predicting the next word

This paradigm shift allows us to leverage the power of **Transformers**, the same architecture behind GPT and BERT.

### 1.2 Core Concepts

#### Causal (Autoregressive) Masking

**The Problem**: During training, we must prevent the model from "cheating" by looking at future items.

**The Solution**: Apply a triangular mask to the attention matrix so position `i` can only attend to positions `0, 1, ..., i`.

```
Attention Mask (for sequence length 5):
      pos_0  pos_1  pos_2  pos_3  pos_4
pos_0   ✓      ✗      ✗      ✗      ✗
pos_1   ✓      ✓      ✗      ✗      ✗
pos_2   ✓      ✓      ✓      ✗      ✗
pos_3   ✓      ✓      ✓      ✓      ✗
pos_4   ✓      ✓      ✓      ✓      ✓
```

This ensures the model learns to predict based only on past context.

#### Self-Attention for Long-Range Dependencies

Traditional RNNs struggle with long sequences due to vanishing gradients. Self-Attention computes relationships between **all items directly**:

```
Attention(Q, K, V) = softmax(QK^T / √d) V
```

Where:
- `Q` (Query): "What am I looking for?"
- `K` (Key): "What do I contain?"  
- `V` (Value): "What information do I provide?"
- `√d`: Scaling factor to prevent exploding gradients

### 1.3 Business Value

1. **Discovery**: Recommend items users wouldn't explicitly search for, but might find interesting based on their behavioral patterns.

2. **Cross-Selling**: Bridge Retail (FMCG) and Marketplace domains. A user buying baby formula → suggest strollers from Marketplace.

3. **Session Awareness**: Capture "in-session intent" - if a user views 3 laptops in a row, they're laptop shopping NOW.

### 1.4 Our Data

We have **9.2 million events** from **286,000 users** across **316,000 items**:
- Retail: 4.1M events (FMCG products)
- Marketplace: 5.1M events (General merchandise)
- Pre-trained embeddings: 456K items with 128-dimensional vectors


---
## 2. Configuration & Imports

We configure all hyperparameters upfront with memory-conscious defaults for Colab.


In [None]:
# Install dependencies if needed (uncomment in Colab)
# !pip install torch pandas numpy matplotlib seaborn tqdm

import os
import gc
import warnings
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Suppress warnings
warnings.filterwarnings('ignore')

# Configuration
CLEANED_DATA_DIR = "cleaned_data"
EMBEDDINGS_DIR = "models/item_embeddings"
OUTPUT_DIR = "models/sequential_recommender"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Hyperparameters (Colab-optimized)
MAX_SEQ_LENGTH = 50       # Covers 95th percentile of sequence lengths
EMBEDDING_DIM = 128       # Match pre-trained embeddings
NUM_LAYERS = 2            # Small enough for Colab, deep enough to learn
NUM_HEADS = 2             # Must divide EMBEDDING_DIM evenly
HIDDEN_DIM = 256          # Feedforward dimension
DROPOUT = 0.1
BATCH_SIZE = 16        # Drastically reduced to fit in 8GB VRAM
LEARNING_RATE = 1e-3
NUM_EPOCHS = 3            # Sufficient for demonstration
SEED = 42

# Set random seeds for reproducibility
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    # Flush the GPU CUDA memory
    torch.cuda.empty_cache()
    print("CUDA cache flushed.")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


---
## 3. Data Preparation (Memory-Optimized)

### Memory Optimization Strategies:
1. **int32 for IDs**: Saves 50% RAM compared to int64
2. **Load only required columns**: Skip `action_type`, `subdomain`, `os`
3. **Generator-based Dataset**: Build sequences on-demand
4. **Sequence length cap**: Truncate to 50 items


In [None]:
def load_and_prepare_data():
    """
    Load retail and marketplace events, map to vocabulary indices.

    Memory-Optimized Implementation:
    - Load only required columns (user_id, item_id, timestamp)
    - Use int32 for indices (saves 50% memory)
    - Remove unmapped items immediately
    """
    print("=" * 60)
    print("LOADING DATA")
    print("=" * 60)

    # 1. Load Item Vocabulary (from item_embeddings.ipynb)
    print("\n1. Loading item vocabulary...")
    vocab_path = os.path.join(EMBEDDINGS_DIR, "item_vocabulary.parquet")
    vocab_df = pd.read_parquet(vocab_path)

    # Create mapping dictionaries
    # IMPORTANT: Shift indices by 1 because 0 is reserved for padding
    item_to_idx = {item: idx + 1 for item, idx in zip(vocab_df['item_id'], vocab_df['index'])}
    idx_to_item = {idx: item for item, idx in item_to_idx.items()}
    vocab_size = len(item_to_idx) + 1  # +1 for padding token at index 0

    print(f"   Vocabulary size: {vocab_size:,} items (including padding)")

    # 2. Load Events (only required columns)
    print("\n2. Loading event streams...")

    # Retail Events
    retail_path = os.path.join(CLEANED_DATA_DIR, "retail_events_clean.parquet")
    retail = pd.read_parquet(retail_path, columns=['user_id', 'item_id', 'timestamp'])
    print(f"   Retail events: {len(retail):,}")

    # Marketplace Events
    marketplace_path = os.path.join(CLEANED_DATA_DIR, "marketplace_events_clean.parquet")
    marketplace = pd.read_parquet(marketplace_path, columns=['user_id', 'item_id', 'timestamp'])
    print(f"   Marketplace events: {len(marketplace):,}")

    # 3. Combine and sort
    print("\n3. Combining and sorting events...")
    events = pd.concat([retail, marketplace], ignore_index=True)
    del retail, marketplace  # Free memory
    gc.collect()

    events = events.sort_values(['user_id', 'timestamp'])
    print(f"   Combined events: {len(events):,}")

    # 4. Map item_id to vocabulary index
    print("\n4. Mapping items to vocabulary indices...")
    events['item_idx'] = events['item_id'].map(item_to_idx)

    # Count how many items couldn't be mapped
    unmapped = events['item_idx'].isna().sum()
    print(f"   Unmapped items (not in vocabulary): {unmapped:,} ({unmapped/len(events)*100:.1f}%)")

    # Remove unmapped items and convert to int32
    events = events.dropna(subset=['item_idx'])
    events['item_idx'] = events['item_idx'].astype(np.int32)
    print(f"   Events after filtering: {len(events):,}")

    # 5. Build user sequences
    print("\n5. Building user sequences...")
    user_sequences = events.groupby('user_id')['item_idx'].apply(list).to_dict()

    # Filter users with at least 2 interactions (minimum for next-item prediction)
    user_sequences = {uid: seq for uid, seq in user_sequences.items() if len(seq) >= 2}
    print(f"   Users with >=2 events: {len(user_sequences):,}")

    # Sequence length statistics
    seq_lengths = [len(seq) for seq in user_sequences.values()]
    print(f"\n   Sequence Length Statistics:")
    print(f"     Min:    {min(seq_lengths)}")
    print(f"     Median: {np.median(seq_lengths):.0f}")
    print(f"     Max:    {max(seq_lengths)}")
    print(f"     Mean:   {np.mean(seq_lengths):.2f}")

    del events  # Free memory
    gc.collect()

    return user_sequences, item_to_idx, idx_to_item, vocab_size

# Load data
user_sequences, item_to_idx, idx_to_item, vocab_size = load_and_prepare_data()


### 3.1 PyTorch Dataset

We implement a custom Dataset that:
1. **Left-pads** sequences to `MAX_SEQ_LENGTH` (so the last item is always at the same position)
2. Returns `(input_sequence, target_item)` pairs
3. Uses on-demand sequence building (no full tensor in RAM)


In [None]:
class SequenceDataset(Dataset):
    """
    Optimized PyTorch Dataset for Whole-Sequence Training.
    
    Instead of sliding window (which creates duplicates), we return 
    the user's full sequence once.
    
    Input:  [item1, item2, item3, 0, 0] (Left-padded)
    Target: [item2, item3, item4, 0, 0] (Shifted)
    """

    def __init__(self, user_sequences: Dict[int, List[int]], max_len: int = 50):
        self.max_len = max_len
        self.sequences = list(user_sequences.values())
        print(f"Created dataset with {len(self.sequences):,} users (1 sequence per user)")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        
        # We need a sequence of length max_len + 1 to create input/target pair
        # Input: seq[:-1], Target: seq[1:]
        
        # Truncate if too long (keep recent items)
        if len(seq) > self.max_len + 1:
            seq = seq[-(self.max_len + 1):]
            
        # Left-pad
        pad_len = (self.max_len + 1) - len(seq)
        padded = [0] * pad_len + seq
        
        # Convert to tensor
        padded_tensor = torch.tensor(padded, dtype=torch.long)
        
        # Split into Input and Target
        # Input:  [0, 0, A, B]
        # Target: [0, A, B, C]
        input_seq = padded_tensor[:-1]
        target_seq = padded_tensor[1:]
        
        return input_seq, target_seq


def create_data_splits(user_sequences: Dict[int, List[int]],
                       train_ratio=0.8, val_ratio=0.1):
    """
    Split users into train/val/test sets.

    We split by USER (not by sample) to avoid data leakage:
    - User A's sequences should not appear in both train and test
    """
    user_ids = list(user_sequences.keys())
    np.random.shuffle(user_ids)

    n_users = len(user_ids)
    train_end = int(n_users * train_ratio)
    val_end = int(n_users * (train_ratio + val_ratio))

    train_users = set(user_ids[:train_end])
    val_users = set(user_ids[train_end:val_end])
    test_users = set(user_ids[val_end:])

    train_seqs = {uid: seq for uid, seq in user_sequences.items() if uid in train_users}
    val_seqs = {uid: seq for uid, seq in user_sequences.items() if uid in val_users}
    test_seqs = {uid: seq for uid, seq in user_sequences.items() if uid in test_users}

    print(f"\nData splits:")
    print(f"  Train: {len(train_seqs):,} users")
    print(f"  Val:   {len(val_seqs):,} users")
    print(f"  Test:  {len(test_seqs):,} users")

    return train_seqs, val_seqs, test_seqs


# Create data splits
train_seqs, val_seqs, test_seqs = create_data_splits(user_sequences)

# Create datasets
train_dataset = SequenceDataset(train_seqs)
val_dataset = SequenceDataset(val_seqs)
test_dataset = SequenceDataset(test_seqs)

# Create data loaders
# IMPORTANT: Set num_workers=0 for Windows to avoid hanging
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

print(f"\nData loaders created:")
print(f"  Train batches: {len(train_loader):,}")
print(f"  Val batches:   {len(val_loader):,}")
print(f"  Test batches:  {len(test_loader):,}")


---
## 4. Model Architecture (SASRec)

### Architecture Overview

```
Input Sequence [batch, seq_len]
        ↓
Item Embedding + Position Embedding
        ↓
Transformer Encoder (2 layers, 2 heads)
        ↓           ↑
    [Causal Mask]
        ↓
Linear Projection → [batch, seq_len, vocab_size]
        ↓
Take last position → [batch, vocab_size]
```

### Key Design Decisions:
1. **Pre-trained Embeddings**: Initialize with embeddings from `item_embeddings.ipynb`
2. **Learnable Position Embeddings**: Unlike fixed sinusoidal, these adapt to our data
3. **Padding Index = 0**: Reserve index 0 for padding tokens


In [None]:
def load_pretrained_embeddings(vocab_size: int, embed_dim: int):
    """
    Load pre-trained item embeddings from item_embeddings.ipynb.

    Returns a numpy array of shape [vocab_size, embed_dim].
    If embeddings file doesn't exist, returns randomly initialized weights.
    """
    emb_path = os.path.join(EMBEDDINGS_DIR, "item_embeddings.parquet")

    if os.path.exists(emb_path):
        print("Loading pre-trained embeddings...")
        emb_df = pd.read_parquet(emb_path)

        # Stack embeddings into matrix
        pretrained = np.vstack(emb_df['embedding'].values)
        print(f"  Loaded embeddings: {pretrained.shape}")

        # Verify dimensions match
        if pretrained.shape[1] != embed_dim:
            print(f"  Warning: Embedding dim mismatch ({pretrained.shape[1]} vs {embed_dim})")
            print("  Using random initialization instead.")
            return None

        # Create new embedding matrix with padding at index 0
        # Shape: [vocab_size, embed_dim] where vocab_size includes padding
        embeddings = np.zeros((vocab_size, embed_dim))

        # Copy pretrained weights to indices 1..N
        # We assume the order in item_embeddings.parquet matches item_vocabulary.parquet
        # (which is true based on item_embeddings.ipynb logic)
        n_pretrained = pretrained.shape[0]
        n_vocab_items = vocab_size - 1

        n_copy = min(n_pretrained, n_vocab_items)
        embeddings[1:n_copy+1] = pretrained[:n_copy]

        return embeddings
    else:
        print(f"  Pre-trained embeddings not found at {emb_path}")
        print("  Using random initialization.")
        return None


class SASRec(nn.Module):
    """
    SASRec with NaN fix and Full-Sequence Output.
    """
    def __init__(self, vocab_size: int, embed_dim: int = 128, max_len: int = 50,
                 num_layers: int = 2, num_heads: int = 2, hidden_dim: int = 256,
                 dropout: float = 0.1, pretrained_emb: Optional[np.ndarray] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.num_heads = num_heads

        self.item_embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        if pretrained_emb is not None:
            self.item_embedding.weight.data.copy_(torch.tensor(pretrained_emb, dtype=torch.float32))

        self.pos_embedding = nn.Embedding(max_len, embed_dim)
        self.dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim,
            dropout=dropout, batch_first=True, activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.ln = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.pos_embedding.weight)
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = seq.shape
        positions = torch.arange(seq_len, device=seq.device).unsqueeze(0).expand(batch_size, -1)

        item_emb = self.item_embedding(seq)
        pos_emb = self.pos_embedding(positions)
        x = self.dropout(item_emb + pos_emb)

        # --- Custom Masking (NaN Fix) ---
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=seq.device)
        is_padding_key = (seq == 0)
        is_item_query = (seq != 0)
        mask_padding = is_padding_key.unsqueeze(1) & is_item_query.unsqueeze(2)
        
        full_mask = causal_mask.unsqueeze(0).expand(batch_size, -1, -1).clone()
        full_mask = full_mask.masked_fill(mask_padding, float('-inf'))
        full_mask = full_mask.repeat_interleave(self.num_heads, dim=0)

        x = self.transformer(x, mask=full_mask)
        x = self.ln(x)
        
        # Return logits for ALL positions [batch, seq_len, vocab]
        logits = self.fc(x)
        return logits


# Load pre-trained embeddings
pretrained_emb = load_pretrained_embeddings(vocab_size, EMBEDDING_DIM)

# Create model
model = SASRec(
    vocab_size=vocab_size,
    embed_dim=EMBEDDING_DIM,
    max_len=MAX_SEQ_LENGTH,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
    pretrained_emb=pretrained_emb
).to(device)

# Print model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel Summary:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Estimated size: {total_params * 4 / 1e6:.1f} MB (float32)")


---
## 5. Training

### Training Strategy:
1. **Loss**: CrossEntropyLoss (standard for multi-class classification)
2. **Optimizer**: AdamW (Adam with weight decay, recommended for Transformers)
3. **Learning Rate**: 1e-3 with ReduceLROnPlateau scheduler
4. **Memory Monitoring**: Print GPU usage every 1000 batches


In [None]:
from torch.cuda.amp import GradScaler, autocast

def calculate_accuracy(logits, targets):
    # logits: [batch, seq_len, vocab_size]
    # targets: [batch, seq_len]
    # Ignore padding (0)
    mask = (targets != 0)
    preds = torch.argmax(logits, dim=-1)
    correct = (preds[mask] == targets[mask]).sum().item()
    total = mask.sum().item()
    return correct, total

def train_epoch(model, train_loader, optimizer, device, scaler):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    num_batches = 0
    
    # Ignore padding (0) in loss calculation
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    pbar = tqdm(train_loader, desc="Training", leave=False)
    for seq, target in pbar:
        seq, target = seq.to(device), target.to(device)

        optimizer.zero_grad()

        # Mixed Precision Context
        with autocast():
            logits = model(seq)  # [batch, seq_len, vocab]
            # Flatten for loss
            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))

        # Scaled Backward Pass
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        
        # Calculate Accuracy
        with torch.no_grad():
            correct, total = calculate_accuracy(logits, target)
            total_correct += correct
            total_samples += total

        num_batches += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = total_loss / num_batches
    avg_acc = total_correct / total_samples if total_samples > 0 else 0
    return avg_loss, avg_acc

def evaluate(model, data_loader, device):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    num_batches = 0
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    with torch.no_grad():
        for seq, target in tqdm(data_loader, desc="Evaluating", leave=False):
            seq, target = seq.to(device), target.to(device)
            
            logits = model(seq)
            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))

            total_loss += loss.item()
            
            # Calculate Accuracy
            correct, total = calculate_accuracy(logits, target)
            total_correct += correct
            total_samples += total
            
            num_batches += 1

    avg_loss = total_loss / num_batches
    avg_acc = total_correct / total_samples if total_samples > 0 else 0
    return avg_loss, avg_acc


# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

# Training history
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
best_val_loss = float('inf')

print("\n" + "=" * 60)
print("TRAINING")
print("=" * 60)
print(f"Epochs: {NUM_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")
print()

# Initialize GradScaler for AMP
scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, device, scaler)
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)

    # Validate
    val_loss, val_acc = evaluate(model, val_loader, device)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Learning rate scheduling
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, 'best_model.pt'))
        saved_marker = " (saved)"
    else:
        saved_marker = ""

    # Print summary
    lr_change = f" [lr: {old_lr:.6f} -> {new_lr:.6f}]" if old_lr != new_lr else ""
    print(f"  Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}")
    print(f"  Val Loss:   {val_loss:.4f} | Acc: {val_acc:.4f}{lr_change}{saved_marker}")

    # GPU memory
    if torch.cuda.is_available():
        print(f"  GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.max_memory_allocated() / 1e9:.2f}GB (peak)")
    print()

print(f"Best validation loss: {best_val_loss:.4f}")

In [None]:
# Add this after training to save current model
# torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, 'best_model.pt'))


---
## 6. Advanced Statistical Evaluation

While Recommender Systems typically rely on Rank Metrics (HR/NDCG), we can adapt standard classification metrics to diagnose model behavior:

### 1. Statistical Ranking Metrics
* **MRR (Mean Reciprocal Rank):** The "Average Accuracy" of our ranking.
  * *Formula:* $1 / \text{Rank}$. If the correct item is at #1, score is 1.0. If at #10, score is 0.1.
  * *Interpretation:* Indicates how far down the user typically has to scroll to find the result.

### 2. Classification Diagnostics
* **Category Confusion Matrix**: Since an item-level matrix ($450k \times 450k$) is computationally infeasible, we aggregate predictions to the **Category Level** ($20 \times 20$).
  * *Purpose:* Identifies domain confusion (e.g., distinguishing between *Skin Care* and *Makeup*).
---

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
import seaborn as sns

def get_category_mapping(item_to_idx):
    """
    Creates a mapping from vocabulary index to Category (Subdomain).
    """
    print("Building category mapping...")
    # Load raw data to get subdomains
    retail = pd.read_parquet(os.path.join(CLEANED_DATA_DIR, "retail_events_clean.parquet"), columns=['item_id', 'subdomain'])
    marketplace = pd.read_parquet(os.path.join(CLEANED_DATA_DIR, "marketplace_events_clean.parquet"), columns=['item_id', 'subdomain'])
    
    # Combine and drop duplicates
    items = pd.concat([retail, marketplace]).drop_duplicates(subset=['item_id'])
    
    # Create mapping: idx -> subdomain
    idx_to_category = {}
    for _, row in items.iterrows():
        if row['item_id'] in item_to_idx:
            idx = item_to_idx[row['item_id']]
            idx_to_category[idx] = row['subdomain']
            
    # Add padding token
    idx_to_category[0] = '<PAD>'
    
    print(f"Mapped {len(idx_to_category)} items to categories.")
    return idx_to_category

def calculate_comprehensive_metrics(model, data_loader, device, idx_to_category):
    model.eval()
    all_preds = []
    all_targets = []
    
    # For ranking metrics
    k_values = [5, 10, 20]
    hits = {k: 0 for k in k_values}
    ndcg = {k: 0 for k in k_values}
    total_sequences = 0

    print("Running comprehensive evaluation...")
    with torch.no_grad():
        for seq, target in tqdm(data_loader, desc="Evaluating"):
            seq, target = seq.to(device), target.to(device)
            logits = model(seq)
            
            # --- Classification Metrics (Next Item Prediction) ---
            # Flatten
            flat_logits = logits.view(-1, logits.size(-1))
            flat_targets = target.view(-1)
            
            # Mask padding
            mask = flat_targets != 0
            masked_logits = flat_logits[mask]
            masked_targets = flat_targets[mask]
            
            # Get predictions
            preds = torch.argmax(masked_logits, dim=-1)
            
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(masked_targets.cpu().numpy())
            
            # --- Ranking Metrics (Last Item Only) ---
            # We typically evaluate ranking on the LAST item of the sequence for HR/NDCG
            # Extract last step
            last_logits = logits[:, -1, :] # [batch, vocab]
            last_targets = target[:, -1]   # [batch]
            
            # Mask padding in last target (if any)
            valid_mask = last_targets != 0
            if valid_mask.sum() == 0: continue
                
            valid_logits = last_logits[valid_mask]
            valid_targets = last_targets[valid_mask]
            
            # Top-K
            _, top_k = torch.topk(valid_logits, max(k_values), dim=1)
            
            for k in k_values:
                current_top_k = top_k[:, :k]
                # Hit Rate
                batch_hits = (current_top_k == valid_targets.unsqueeze(1)).any(dim=1).float()
                hits[k] += batch_hits.sum().item()
                
                # NDCG
                ranks = (current_top_k == valid_targets.unsqueeze(1)).nonzero()
                if len(ranks) > 0:
                    # Ranks are 0-indexed in nonzero(), so add 1. 
                    # We need to match ranks to their batch indices to sum correctly, 
                    # but simple sum works for total DCG
                    rank_positions = ranks[:, 1] + 1
                    dcg = (1.0 / torch.log2(rank_positions.float() + 1)).sum().item()
                    ndcg[k] += dcg
            
            total_sequences += valid_mask.sum().item()

    # --- Aggregate Classification Metrics ---
    print("\nCalculating classification report...")
    
    # Map indices to categories
    y_true_cat = [idx_to_category.get(i, 'Unknown') for i in all_targets]
    y_pred_cat = [idx_to_category.get(i, 'Unknown') for i in all_preds]
    
    # Get unique categories present in data
    labels = sorted(list(set(y_true_cat) | set(y_pred_cat)))
    if '<PAD>' in labels: labels.remove('<PAD>')
    
    # Classification Report
    cls_report = classification_report(y_true_cat, y_pred_cat, labels=labels, zero_division=0)
    
    # Confusion Matrix
    cm = confusion_matrix(y_true_cat, y_pred_cat, labels=labels, normalize='true')
    
    # Global Metrics
    precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='macro', zero_division=0)
    accuracy = sum([1 for p, t in zip(all_preds, all_targets) if p == t]) / len(all_targets)
    
    return {
        'accuracy': accuracy,
        'macro_precision': precision,
        'macro_recall': recall,
        'macro_f1': f1,
        'report': cls_report,
        'confusion_matrix': cm,
        'labels': labels,
        'hr': {k: v / total_sequences for k, v in hits.items()},
        'ndcg': {k: v / total_sequences for k, v in ndcg.items()}
    }

# 1. Load Best Model
model.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, 'best_model.pt')))

# 2. Get Category Mapping
idx_to_category = get_category_mapping(item_to_idx)

# 3. Calculate Metrics
metrics = calculate_comprehensive_metrics(model, test_loader, device, idx_to_category)

# --- VISUALIZATIONS ---

# A. Loss & Accuracy Curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))

# Loss
ax1.plot(history['train_loss'], label='Train Loss', marker='o')
ax1.plot(history['val_loss'], label='Val Loss', marker='o')
ax1.set_title('Loss Curve')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy
ax2.plot(history['train_acc'], label='Train Acc', marker='o')
ax2.plot(history['val_acc'], label='Val Acc', marker='o')
ax2.set_title('Accuracy Curve')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# B. Model Performance Comparison
print("\n" + "="*60)
print("MODEL PERFORMANCE SUMMARY")
print("="*60)
print(f"Accuracy (Token-level): {metrics['accuracy']:.4f}")
print(f"Macro Precision:        {metrics['macro_precision']:.4f}")
print(f"Macro Recall:           {metrics['macro_recall']:.4f}")
print(f"Macro F1 Score:         {metrics['macro_f1']:.4f}")
print("-" * 60)
print("Ranking Metrics (Last Item):")
for k in [5, 10, 20]:
    print(f"HR@{k}:   {metrics['hr'][k]:.4f}")
    print(f"NDCG@{k}: {metrics['ndcg'][k]:.4f}")
print("="*60)

# C. Classification Report
print("\nCLASSIFICATION REPORT (By Category):")
print(metrics['report'])

# D. Confusion Matrix
plt.figure(figsize=(14, 12))
sns.heatmap(metrics['confusion_matrix'], xticklabels=metrics['labels'], yticklabels=metrics['labels'], 
            annot=False, cmap='Blues', fmt='.2f')
plt.title('Category Confusion Matrix (Normalized)')
plt.xlabel('Predicted Category')
plt.ylabel('True Category')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

---
## 7. Interactive Production Demo

This section launches a **Gradio Dashboard** to interact with the trained model.
Features:
1. **User DNA**: Visualizes user preferences using a Radar Chart.
2. **Time Machine**: Allows you to inject items into history to see how predictions change.
3. **Explainability**: Shows "Why" the model made a prediction.


In [None]:
# Install dependencies for the dashboard
!pip install -q gradio plotly pandas torch


In [None]:
# --- 1. MODEL DEFINITION ---
class SASRec(nn.Module):
    """
    SASRec with NaN fix and Full-Sequence Output.
    """
    def __init__(self, vocab_size: int, embed_dim: int = 128, max_len: int = 50,
                 num_layers: int = 2, num_heads: int = 2, hidden_dim: int = 256,
                 dropout: float = 0.1, pretrained_emb: Optional[np.ndarray] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.num_heads = num_heads

        self.item_embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # Note: We don't need to load pretrained_emb here as we load state_dict later
        
        self.pos_embedding = nn.Embedding(max_len, embed_dim)
        self.dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim,
            dropout=dropout, batch_first=True, activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.ln = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = seq.shape
        positions = torch.arange(seq_len, device=seq.device).unsqueeze(0).expand(batch_size, -1)

        item_emb = self.item_embedding(seq)
        pos_emb = self.pos_embedding(positions)
        x = self.dropout(item_emb + pos_emb)

        # --- Custom Masking (NaN Fix) ---
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=seq.device)
        is_padding_key = (seq == 0)
        is_item_query = (seq != 0)
        mask_padding = is_padding_key.unsqueeze(1) & is_item_query.unsqueeze(2)
        
        full_mask = causal_mask.unsqueeze(0).expand(batch_size, -1, -1).clone()
        full_mask = full_mask.masked_fill(mask_padding, float('-inf'))
        full_mask = full_mask.repeat_interleave(self.num_heads, dim=0)

        x = self.transformer(x, mask=full_mask)
        x = self.ln(x)
        
        # Return logits for the LAST position only (for inference)
        x_last = x[:, -1, :]
        logits = self.fc(x_last)
        return logits

---
## 8. Summary & Conclusion

### What We Built
A **production-ready Sequential Recommender System** using the SASRec architecture:
- Trained on 9.2M events from 286K users
- Leverages pre-trained item embeddings (456K items)
- Optimized for Google Colab Free Tier constraints

### Key Results
The model should show significant lift over random baseline:
- **HR@10**: Typically 5-15% (vs. random ~0.002%)
- **Catalog Coverage**: Diverse recommendations across catalog
- **Cross-Domain**: Links Retail and Marketplace behaviors

### Artifacts Saved
- `models/sequential_recommender/best_model.pt` - Trained model weights
- `models/sequential_recommender/training_history.png` - Loss curves
- `models/sequential_recommender/attention_heatmap.png` - Attention visualization

### Next Steps
1. **A/B Testing**: Deploy to production and measure lift in click-through rate
2. **Online Learning**: Update model with real-time user feedback
3. **Multi-Task Learning**: Jointly predict purchase probability
4. **Scale Up**: Train on full dataset with larger GPU (A100)


In [None]:
print("=" * 70)
print("SEQUENTIAL RECOMMENDER TRAINING COMPLETE")
print("=" * 70)

print(f"""
Artifacts saved to: {OUTPUT_DIR}
- best_model.pt: Trained model weights ({total_params:,} parameters)
- training_history.png: Loss curves
- attention_heatmap.png: Attention visualization

Final Performance:
- HR@10: {test_metrics['HR@10']*100:.2f}%
- NDCG@10: {test_metrics['NDCG@10']:.4f}
- Catalog Coverage: {coverage:.1f}%
""")
