# Unified Sequential Recommender System (SASRec)

## 1. Introduction & Theory (The "Why")

### 1.1 Academic Context: Why Transformers for Recommendation?

In traditional recommendation systems, we treat user preferences as static profiles. But **user behavior is temporal**: what a user clicked 5 minutes ago is usually more relevant than what they clicked 5 weeks ago.

**SASRec (Self-Attentive Sequential Recommendation)** treats user histories like sentences:
- **Items = Words/Tokens** in a vocabulary
- **User History = Sentence** to be "understood"
- **Next-Item Prediction = Language Model** predicting the next word

This paradigm shift allows us to leverage the power of **Transformers**, the same architecture behind GPT and BERT.

### 1.2 Core Concepts

#### Causal (Autoregressive) Masking

**The Problem**: During training, we must prevent the model from "cheating" by looking at future items.

**The Solution**: Apply a triangular mask to the attention matrix so position `i` can only attend to positions `0, 1, ..., i`.

```
Attention Mask (for sequence length 5):
      pos_0  pos_1  pos_2  pos_3  pos_4
pos_0   ✓      ✗      ✗      ✗      ✗
pos_1   ✓      ✓      ✗      ✗      ✗
pos_2   ✓      ✓      ✓      ✗      ✗
pos_3   ✓      ✓      ✓      ✓      ✗
pos_4   ✓      ✓      ✓      ✓      ✓
```

This ensures the model learns to predict based only on past context.
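The mask above can be built in one line. This NumPy sketch uses the boolean convention of PyTorch's attention modules (`True` = blocked); `nn.Transformer.generate_square_subsequent_mask`, used later in the model, encodes the same pattern as a float mask with `-inf` above the diagonal and `0.0` elsewhere. `causal_mask` is a hypothetical helper for illustration.

```python
import numpy as np

def causal_mask(seq_len: int) -> np.ndarray:
    """Boolean mask where True marks a blocked (future) position."""
    # k=1 keeps only entries strictly above the diagonal, i.e. key j > query i
    return np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)

blocked = causal_mask(5)
for i, row in enumerate(blocked):
    # ✓ = may attend, ✗ = blocked, matching the table above
    print(f"pos_{i}  " + "  ".join("✗" if b else "✓" for b in row))
```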

#### Self-Attention for Long-Range Dependencies

Traditional RNNs struggle with long sequences due to vanishing gradients. Self-Attention computes relationships between **all items directly**:

```
Attention(Q, K, V) = softmax(QK^T / √d) V
```

Where:
- `Q` (Query): "What am I looking for?"
- `K` (Key): "What do I contain?"  
- `V` (Value): "What information do I provide?"
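- `√d`: Scaling factor, where `d` is the key dimension. Without it, dot products grow with `d` and push the softmax into saturated regions where gradients become vanishingly small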
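A minimal NumPy sketch of single-head scaled dot-product attention with the causal mask applied. For brevity it assumes `Q = K = V = x`; in the real model these are separate learned projections of the input. `causal_attention` is a hypothetical helper, not a PyTorch API.

```python
import numpy as np

def causal_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray):
    """softmax(QK^T / sqrt(d) + causal mask) V for a single head."""
    seq_len, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                         # [seq_len, seq_len]
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)            # block keys j > i
    scores = scores - scores.max(axis=1, keepdims=True)   # numerical stability
    weights = np.exp(scores)                              # exp(-inf) = 0
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))          # 4 items, d = 8
out, w = causal_attention(x, x, x)   # Q = K = V = x for brevity
```

Because position 0 can only attend to itself, its output is exactly its own value vector; later positions mix in more of the past.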

### 1.3 Business Value

1. **Discovery**: Recommend items users wouldn't explicitly search for, but might find interesting based on their behavioral patterns.

2. **Cross-Selling**: Bridge Retail (FMCG) and Marketplace domains. A user buying baby formula → suggest strollers from Marketplace.

3. **Session Awareness**: Capture "in-session intent" - if a user views 3 laptops in a row, they're laptop shopping NOW.

### 1.4 Our Data

We have **9.2 million events** from **286,000 users** across **316,000 items**:
- Retail: 4.1M events (FMCG products)
- Marketplace: 5.1M events (General merchandise)
- Pre-trained embeddings: 456K items with 128-dimensional vectors

**Constraint**: Google Colab Free Tier (12GB RAM, 15GB GPU)


---
## 2. Configuration & Imports

We configure all hyperparameters upfront with memory-conscious defaults for Colab.
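`nn.MultiheadAttention` splits the embedding evenly across heads, so the constraint noted next to `NUM_HEADS` can be checked before any model is built. A minimal sketch; `head_dim_for` is a hypothetical helper, not part of the notebook:

```python
def head_dim_for(embed_dim: int, num_heads: int) -> int:
    """Return the per-head dimension, raising if heads cannot split the embedding evenly."""
    if embed_dim % num_heads != 0:
        raise ValueError(f"embed_dim={embed_dim} is not divisible by num_heads={num_heads}")
    return embed_dim // num_heads

head_dim = head_dim_for(embed_dim=128, num_heads=2)
print(f"Each attention head works in {head_dim} dimensions")  # 64
```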


In [None]:
# Install dependencies if needed (uncomment in Colab)
# !pip install torch pandas numpy matplotlib seaborn tqdm

import os
import gc
import warnings
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Suppress warnings
warnings.filterwarnings('ignore')

# Configuration
CLEANED_DATA_DIR = "cleaned_data"
EMBEDDINGS_DIR = "models/item_embeddings"
OUTPUT_DIR = "models/sequential_recommender"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Hyperparameters (Colab-optimized)
MAX_SEQ_LENGTH = 50       # Covers 95th percentile of sequence lengths
EMBEDDING_DIM = 128       # Match pre-trained embeddings
NUM_LAYERS = 2            # Small enough for Colab, deep enough to learn
NUM_HEADS = 2             # Must divide EMBEDDING_DIM evenly
HIDDEN_DIM = 256          # Feedforward dimension
DROPOUT = 0.1
BATCH_SIZE = 128          # Memory-friendly
LEARNING_RATE = 1e-3
NUM_EPOCHS = 3            # Sufficient for demonstration
SEED = 42

# Set random seeds for reproducibility
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


---
## 3. Data Preparation (Memory-Optimized)

### Memory Optimization Strategies:
1. **int32 for IDs**: Saves 50% RAM compared to int64
2. **Load only required columns**: Skip `action_type`, `subdomain`, `os`
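3. **Lazy padding**: Pad each sequence on demand in `__getitem__` instead of storing a padded tensor for the whole dataset
4. **Sequence length cap**: Keep only the 50 most recent items (`MAX_SEQ_LENGTH`) per input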
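The int32 saving is plain arithmetic: 4 bytes per value instead of 8. A quick NumPy check, assuming roughly the 9.2M events from Section 1.4 (the same `astype(np.int32)` call works on a pandas column):

```python
import numpy as np

n_events = 9_200_000  # approximate event count from Section 1.4
for dtype in (np.int64, np.int32):
    mb = n_events * np.dtype(dtype).itemsize / 1e6
    print(f"{np.dtype(dtype).name}: {mb:.1f} MB per ID column")

# int32 comfortably fits our indices: its max is 2**31 - 1, vs. a vocabulary of ~456K items
assert np.iinfo(np.int32).max > 456_000

# Downcasting an existing array halves its footprint without changing the values
ids = np.arange(1_000, dtype=np.int64)
ids32 = ids.astype(np.int32)
```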


In [None]:
def load_and_prepare_data():
    """
    Load retail and marketplace events, map to vocabulary indices.
    
    Memory-Optimized Implementation:
    - Load only required columns (user_id, item_id, timestamp)
    - Use int32 for indices (saves 50% memory)
    - Remove unmapped items immediately
    """
    print("=" * 60)
    print("LOADING DATA")
    print("=" * 60)
    
    # 1. Load Item Vocabulary (from item_embeddings.ipynb)
    print("\n1. Loading item vocabulary...")
    vocab_path = os.path.join(EMBEDDINGS_DIR, "item_vocabulary.parquet")
    vocab_df = pd.read_parquet(vocab_path)
    
    # Create mapping dictionaries
    # IMPORTANT: Shift indices by 1 because 0 is reserved for padding
    item_to_idx = {item: idx + 1 for item, idx in zip(vocab_df['item_id'], vocab_df['index'])}
    idx_to_item = {idx: item for item, idx in item_to_idx.items()}
    vocab_size = len(item_to_idx) + 1  # +1 for padding token at index 0
    
    print(f"   Vocabulary size: {vocab_size:,} items (including padding)")
    
    # 2. Load Events (only required columns)
    print("\n2. Loading event streams...")
    
    # Retail Events
    retail_path = os.path.join(CLEANED_DATA_DIR, "retail_events_clean.parquet")
    retail = pd.read_parquet(retail_path, columns=['user_id', 'item_id', 'timestamp'])
    print(f"   Retail events: {len(retail):,}")
    
    # Marketplace Events  
    marketplace_path = os.path.join(CLEANED_DATA_DIR, "marketplace_events_clean.parquet")
    marketplace = pd.read_parquet(marketplace_path, columns=['user_id', 'item_id', 'timestamp'])
    print(f"   Marketplace events: {len(marketplace):,}")
    
    # 3. Combine and sort
    print("\n3. Combining and sorting events...")
    events = pd.concat([retail, marketplace], ignore_index=True)
    del retail, marketplace  # Free memory
    gc.collect()
    
    events = events.sort_values(['user_id', 'timestamp'])
    print(f"   Combined events: {len(events):,}")
    
    # 4. Map item_id to vocabulary index
    print("\n4. Mapping items to vocabulary indices...")
    events['item_idx'] = events['item_id'].map(item_to_idx)
    
    # Count how many items couldn't be mapped
    unmapped = events['item_idx'].isna().sum()
    print(f"   Events with unmapped items (not in vocabulary): {unmapped:,} ({unmapped/len(events)*100:.1f}%)")
    
    # Remove unmapped items and convert to int32
    events = events.dropna(subset=['item_idx'])
    events['item_idx'] = events['item_idx'].astype(np.int32)
    print(f"   Events after filtering: {len(events):,}")
    
    # 5. Build user sequences
    print("\n5. Building user sequences...")
    user_sequences = events.groupby('user_id')['item_idx'].apply(list).to_dict()
    
    # Filter users with at least 2 interactions (minimum for next-item prediction)
    user_sequences = {uid: seq for uid, seq in user_sequences.items() if len(seq) >= 2}
    print(f"   Users with >=2 events: {len(user_sequences):,}")
    
    # Sequence length statistics
    seq_lengths = [len(seq) for seq in user_sequences.values()]
    print(f"\n   Sequence Length Statistics:")
    print(f"     Min:    {min(seq_lengths)}")
    print(f"     Median: {np.median(seq_lengths):.0f}")
    print(f"     Max:    {max(seq_lengths)}")
    print(f"     Mean:   {np.mean(seq_lengths):.2f}")
    
    del events  # Free memory
    gc.collect()
    
    return user_sequences, item_to_idx, idx_to_item, vocab_size

# Load data
user_sequences, item_to_idx, idx_to_item, vocab_size = load_and_prepare_data()


### 3.1 PyTorch Dataset

We implement a custom Dataset that:
1. **Left-pads** sequences to `MAX_SEQ_LENGTH` (so the last item is always at the same position)
2. Returns `(input_sequence, target_item)` pairs
3. Pads each sample on demand in `__getitem__`, so a padded tensor of the whole dataset is never held in RAM
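The sliding-window sampling and left padding can be illustrated in plain Python on a toy history. `make_samples` is a hypothetical helper that mirrors the logic of `SequenceDataset`, with `0` as the padding index:

```python
from typing import List, Tuple

def make_samples(items: List[int], max_len: int) -> List[Tuple[List[int], int]]:
    """One (left-padded input, target) pair per target position i >= 1."""
    samples = []
    for i in range(1, len(items)):
        context = items[max(0, i - max_len):i]             # at most max_len past items
        padded = [0] * (max_len - len(context)) + context  # 0 = padding index
        samples.append((padded, items[i]))
    return samples

for inp, tgt in make_samples([11, 12, 13, 14], max_len=3):
    print(inp, "->", tgt)
# [0, 0, 11] -> 12
# [0, 11, 12] -> 13
# [11, 12, 13] -> 14
```

Every sample's most recent item sits in the last slot, which is exactly the position the model reads its prediction from.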


In [None]:
class SequenceDataset(Dataset):
    """
    PyTorch Dataset for sequential recommendation.
    
    For each user with history items[0..n-1], we create one training sample per
    position i in range(1, n) (sliding window):
    - Input: items[max(0, i - max_len):i] (up to max_len items before position i)
    - Target: items[i] (next item to predict)
    
    Sequences are left-padded so that the most recent input item is always at the
    last position, which is where the model reads its prediction from.
    
    Memory: instead of materializing every input slice up front, we store each
    history once (as int32) plus an index of (sequence, position) pairs, and
    slice and pad lazily in __getitem__.
    """
    
    def __init__(self, user_sequences: Dict[int, List[int]], max_len: int = MAX_SEQ_LENGTH):
        self.max_len = max_len
        self.sequences = [np.asarray(seq, dtype=np.int32) for seq in user_sequences.values()]
        
        # One sample per target position 1..len-1 of every sequence
        n_per_seq = np.array([len(seq) - 1 for seq in self.sequences], dtype=np.int64)
        self.sample_seq = np.repeat(np.arange(len(self.sequences)), n_per_seq)
        # Index of each sequence's first sample, repeated for all of its samples
        first_sample = np.repeat(np.cumsum(n_per_seq) - n_per_seq, n_per_seq)
        self.sample_pos = np.arange(len(self.sample_seq)) - first_sample + 1
        
        print(f"Created {len(self.sample_seq):,} training samples")
    
    def __len__(self):
        return len(self.sample_seq)
    
    def __getitem__(self, idx):
        seq = self.sequences[self.sample_seq[idx]]
        i = int(self.sample_pos[idx])
        input_seq = seq[max(0, i - self.max_len):i]
        target = int(seq[i])
        
        # Left-pad with 0 (the padding index) so the most recent item is at the last position
        padded = np.zeros(self.max_len, dtype=np.int64)
        padded[self.max_len - len(input_seq):] = input_seq
        
        return torch.from_numpy(padded), torch.tensor(target, dtype=torch.long)


def create_data_splits(user_sequences: Dict[int, List[int]], 
                       train_ratio=0.8, val_ratio=0.1):
    """
    Split users into train/val/test sets.
    
    We split by USER (not by sample) to avoid data leakage:
    - User A's sequences should not appear in both train and test
    """
    user_ids = list(user_sequences.keys())
    np.random.shuffle(user_ids)
    
    n_users = len(user_ids)
    train_end = int(n_users * train_ratio)
    val_end = int(n_users * (train_ratio + val_ratio))
    
    train_users = set(user_ids[:train_end])
    val_users = set(user_ids[train_end:val_end])
    test_users = set(user_ids[val_end:])
    
    train_seqs = {uid: seq for uid, seq in user_sequences.items() if uid in train_users}
    val_seqs = {uid: seq for uid, seq in user_sequences.items() if uid in val_users}
    test_seqs = {uid: seq for uid, seq in user_sequences.items() if uid in test_users}
    
    print(f"\nData splits:")
    print(f"  Train: {len(train_seqs):,} users")
    print(f"  Val:   {len(val_seqs):,} users")
    print(f"  Test:  {len(test_seqs):,} users")
    
    return train_seqs, val_seqs, test_seqs


# Create data splits
train_seqs, val_seqs, test_seqs = create_data_splits(user_sequences)

# Create datasets
train_dataset = SequenceDataset(train_seqs)
val_dataset = SequenceDataset(val_seqs)
test_dataset = SequenceDataset(test_seqs)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

print(f"\nData loaders created:")
print(f"  Train batches: {len(train_loader):,}")
print(f"  Val batches:   {len(val_loader):,}")
print(f"  Test batches:  {len(test_loader):,}")


---
## 4. Model Architecture (SASRec)

### Architecture Overview

```
Input Sequence [batch, seq_len]
        ↓
Item Embedding + Position Embedding
        ↓
Transformer Encoder (2 layers, 2 heads)  ←  [Causal + Padding Mask]
        ↓
LayerNorm
        ↓
Take last position → [batch, embed_dim]
        ↓
Linear Projection → [batch, vocab_size]
```
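The "take last position" step matters because the logits tensor scales with `batch × positions × vocab_size`. Back-of-the-envelope arithmetic, assuming float32 and a vocabulary of about 456K items plus the padding row (the pre-trained embedding count from Section 1.4); `logits_bytes` is a hypothetical helper:

```python
def logits_bytes(batch: int, positions: int, vocab: int, bytes_per_float: int = 4) -> int:
    """Size in bytes of a float32 logits tensor of shape [batch, positions, vocab]."""
    return batch * positions * vocab * bytes_per_float

vocab = 456_001                        # assumption: ~456K items + 1 padding row
full = logits_bytes(128, 50, vocab)    # logits at every position
last = logits_bytes(128, 1, vocab)     # logits at the last position only
print(f"all positions: {full / 1e9:.1f} GB, last position: {last / 1e6:.0f} MB")
# all positions: 11.7 GB, last position: 233 MB
```

The backward pass needs a gradient of the same size as the logits, so the full-sequence variant would roughly double this figure during training.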

### Key Design Decisions:
1. **Pre-trained Embeddings**: Initialize with embeddings from `item_embeddings.ipynb`
2. **Learnable Position Embeddings**: Unlike fixed sinusoidal, these adapt to our data
3. **Padding Index = 0**: Reserve index 0 for padding tokens
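A toy sketch of the index shift and the pre-trained weight copy, with hypothetical item IDs and vectors. It assumes, as `load_pretrained_embeddings` does, that pre-trained rows are stored in vocabulary order. In the model, `nn.Embedding(..., padding_idx=0)` additionally keeps row 0 out of gradient updates.

```python
import numpy as np

vocab_items = ["fmcg_a", "nfmcg_b", "offer_c"]              # hypothetical IDs, vocabulary index 0, 1, 2
pretrained = np.arange(6, dtype=np.float32).reshape(3, 2)   # hypothetical [n_items, embed_dim] matrix

# Shift every vocabulary index by 1 so that row 0 is free for padding
item_to_idx = {item: i + 1 for i, item in enumerate(vocab_items)}
vocab_size = len(item_to_idx) + 1

# Row 0 stays all-zero (padding); row i + 1 holds the vector of vocabulary item i
weights = np.zeros((vocab_size, pretrained.shape[1]), dtype=np.float32)
weights[1:] = pretrained
```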


In [None]:
def load_pretrained_embeddings(vocab_size: int, embed_dim: int):
    """
    Load pre-trained item embeddings from item_embeddings.ipynb.
    
    Returns a numpy array of shape [vocab_size, embed_dim].
    If embeddings file doesn't exist, returns randomly initialized weights.
    """
    emb_path = os.path.join(EMBEDDINGS_DIR, "item_embeddings.parquet")
    
    if os.path.exists(emb_path):
        print("Loading pre-trained embeddings...")
        emb_df = pd.read_parquet(emb_path)
        
        # Stack embeddings into matrix
        pretrained = np.vstack(emb_df['embedding'].values)
        print(f"  Loaded embeddings: {pretrained.shape}")
        
        # Verify dimensions match
        if pretrained.shape[1] != embed_dim:
            print(f"  Warning: Embedding dim mismatch ({pretrained.shape[1]} vs {embed_dim})")
            print("  Using random initialization instead.")
            return None
        
        # Create new embedding matrix with padding at index 0
        # Shape: [vocab_size, embed_dim] where vocab_size includes padding
        embeddings = np.zeros((vocab_size, embed_dim))
        
        # Copy pretrained weights to indices 1..N
        # We assume the order in item_embeddings.parquet matches item_vocabulary.parquet
        # (which is true based on item_embeddings.ipynb logic)
        n_pretrained = pretrained.shape[0]
        n_vocab_items = vocab_size - 1
        
        n_copy = min(n_pretrained, n_vocab_items)
        embeddings[1:n_copy+1] = pretrained[:n_copy]
        
        return embeddings
    else:
        print(f"  Pre-trained embeddings not found at {emb_path}")
        print("  Using random initialization.")
        return None


class SASRec(nn.Module):
    """
    Self-Attentive Sequential Recommendation (SASRec) Model.
    
    Paper: "Self-Attentive Sequential Recommendation" (Kang & McAuley, 2018)
    
    Architecture:
    - Item Embedding (optionally pre-trained)
    - Learnable Position Embedding
    - Transformer Encoder with Causal Masking
    - Linear Prediction Head
    
    Args:
        vocab_size: Number of items in vocabulary
        embed_dim: Embedding dimension
        max_len: Maximum sequence length
        num_layers: Number of transformer layers
        num_heads: Number of attention heads
        hidden_dim: Feedforward network dimension
        dropout: Dropout rate
        pretrained_emb: Optional pre-trained embedding weights
    """
    
    def __init__(self, vocab_size: int, embed_dim: int = 128, max_len: int = 50,
                 num_layers: int = 2, num_heads: int = 2, hidden_dim: int = 256,
                 dropout: float = 0.1, pretrained_emb: Optional[np.ndarray] = None):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.max_len = max_len
        
        # Item Embedding (with padding_idx=0)
        self.item_embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        # Initialize with pre-trained weights if available
        if pretrained_emb is not None:
            self.item_embedding.weight.data.copy_(torch.tensor(pretrained_emb, dtype=torch.float32))
            print(f"  Initialized item embeddings with pre-trained weights")
        
        # Position Embedding (learnable)
        self.pos_embedding = nn.Embedding(max_len, embed_dim)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,  # [batch, seq, features]
            activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Layer Normalization
        self.ln = nn.LayerNorm(embed_dim)
        
        # Prediction Head
        self.fc = nn.Linear(embed_dim, vocab_size)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize position embeddings and FC layer."""
        nn.init.xavier_uniform_(self.pos_embedding.weight)
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)
    
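    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            seq: [batch, seq_len] - Left-padded item indices (0 = padding)
            
        Returns:
            logits: [batch, vocab_size] - Next-item logits read from the last position
        """
        batch_size, seq_len = seq.shape
        
        # Create position indices [0, 1, 2, ..., seq_len-1]
        positions = torch.arange(seq_len, device=seq.device).unsqueeze(0).expand(batch_size, -1)
        
        # Get embeddings
        item_emb = self.item_embedding(seq)  # [batch, seq_len, embed_dim]
        pos_emb = self.pos_embedding(positions)  # [batch, seq_len, embed_dim]
        
        # Combine and apply dropout
        x = self.dropout(item_emb + pos_emb)
        
        # One boolean attention mask (True = blocked) combining:
        # - causality: position i may not attend to any j > i
        # - padding: no position may attend to a padded key...
        # - ...except itself. With left padding, a padded query would otherwise have
        #   every key blocked; softmax over an all -inf row yields NaN in some PyTorch
        #   versions, and that NaN then leaks into real positions in the next layer.
        num_heads = self.transformer.layers[0].self_attn.num_heads
        causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq.device), diagonal=1)
        padding = (seq == 0)  # [batch, seq_len]
        self_loop = torch.eye(seq_len, dtype=torch.bool, device=seq.device)
        attn_mask = (causal.unsqueeze(0) | padding.unsqueeze(1)) & ~self_loop  # [batch, seq_len, seq_len]
        # A per-sample 3D mask must be laid out as [batch * num_heads, seq_len, seq_len]
        attn_mask = attn_mask.repeat_interleave(num_heads, dim=0)
        
        # Apply transformer
        x = self.transformer(x, mask=attn_mask)
        
        # Layer norm
        x = self.ln(x)
        
        # OPTIMIZATION: Only compute logits for the last position.
        # With left padding, the last position always holds the most recent real item,
        # so its output is the next-item prediction. This shrinks the output from
        # [batch, seq_len, vocab] to [batch, vocab], saving a massive amount of memory
        # (11GB -> 200MB)
        x_last = x[:, -1, :]  # [batch, embed_dim]
        
        # Project to vocabulary
        logits = self.fc(x_last)  # [batch, vocab_size]
        
        return logits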
    
    def predict(self, seq: torch.Tensor, k: int = 10) -> torch.Tensor:
        """
        Predict top-k next items.
        
        Args:
            seq: [batch, seq_len] - Item indices
            k: Number of top items to return
            
        Returns:
            top_k_items: [batch, k] - Top-k predicted item indices
        """
        self.eval()
        with torch.no_grad():
            logits = self.forward(seq)  # [batch, vocab_size]
            # Get top-k
            _, top_k = torch.topk(logits, k, dim=1)
        return top_k


# Load pre-trained embeddings
pretrained_emb = load_pretrained_embeddings(vocab_size, EMBEDDING_DIM)

# Create model
model = SASRec(
    vocab_size=vocab_size,
    embed_dim=EMBEDDING_DIM,
    max_len=MAX_SEQ_LENGTH,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
    pretrained_emb=pretrained_emb
).to(device)

# Print model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel Summary:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Estimated size: {total_params * 4 / 1e6:.1f} MB (float32)")


---
## 5. Training

### Training Strategy:
1. **Loss**: CrossEntropyLoss (standard for multi-class classification)
2. **Optimizer**: AdamW (Adam with weight decay, recommended for Transformers)
3. **Learning Rate**: 1e-3 with ReduceLROnPlateau scheduler
4. **Memory Monitoring**: Print GPU usage every 1000 batches
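With `patience=1`, `ReduceLROnPlateau` halves the learning rate only after two consecutive epochs without improvement, so over `NUM_EPOCHS = 3` it can fire at most once, at the end of epoch 3. A simplified pure-Python model of its default `mode='min'` rule (relative threshold `1e-4`, no cooldown, no `min_lr`); `plateau_schedule` is a hypothetical helper, not the PyTorch class:

```python
from typing import List

def plateau_schedule(val_losses: List[float], lr: float = 1e-3, factor: float = 0.5,
                     patience: int = 1, threshold: float = 1e-4) -> List[float]:
    """Learning rate after each epoch under a simplified ReduceLROnPlateau rule."""
    best, bad_epochs, lrs = float("inf"), 0, []
    for loss in val_losses:
        if loss < best * (1 - threshold):   # relative improvement over the best so far
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:            # patience exceeded: decay and reset the counter
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs

print(plateau_schedule([1.0, 1.1, 1.2]))  # LR halves only after epoch 3
print(plateau_schedule([1.0, 0.9, 0.8]))  # steady improvement: LR unchanged
```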


In [None]:
def train_epoch(model, train_loader, optimizer, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    num_batches = 0
    
    pbar = tqdm(train_loader, desc="Training", leave=False)
    for batch_idx, (seq, target) in enumerate(pbar):
        seq = seq.to(device)
        target = target.to(device)
        
        # Forward pass
        logits = model(seq)  # [batch, vocab_size]
        
        # Compute loss
        loss = F.cross_entropy(logits, target)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping (prevents exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        
        # Update progress bar
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        # Print GPU memory every 1000 batches
        if batch_idx > 0 and batch_idx % 1000 == 0 and torch.cuda.is_available():
            gpu_mem = torch.cuda.memory_allocated() / 1e9
            print(f"  Batch {batch_idx}: Loss={loss.item():.4f}, GPU Memory={gpu_mem:.2f}GB")
    
    return total_loss / num_batches


def evaluate(model, data_loader, device):
    """Evaluate model on validation/test set."""
    model.eval()
    total_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        for seq, target in tqdm(data_loader, desc="Evaluating", leave=False):
            seq = seq.to(device)
            target = target.to(device)
            
            logits = model(seq)
            loss = F.cross_entropy(logits, target)
            
            total_loss += loss.item()
            num_batches += 1
    
    return total_loss / num_batches


# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

# Training history
history = {'train_loss': [], 'val_loss': []}
best_val_loss = float('inf')

print("\n" + "=" * 60)
print("TRAINING")
print("=" * 60)
print(f"Epochs: {NUM_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")
print()

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, device)
    history['train_loss'].append(train_loss)
    
    # Validate
    val_loss = evaluate(model, val_loader, device)
    history['val_loss'].append(val_loss)
    
    # Learning rate scheduling
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, 'best_model.pt'))
        saved_marker = " (saved)"
    else:
        saved_marker = ""
    
    # Print summary
    lr_change = f" [lr: {old_lr:.6f} → {new_lr:.6f}]" if old_lr != new_lr else ""
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}{lr_change}{saved_marker}")
    
    # GPU memory
    if torch.cuda.is_available():
        print(f"  GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB / {torch.cuda.max_memory_allocated() / 1e9:.2f}GB (peak)")
    print()

print(f"Best validation loss: {best_val_loss:.4f}")


---
## 6. Comprehensive Evaluation

### Technical Metrics:
1. **Hit Rate@K (HR@K)**: Was the ground truth item in the top-K predictions?
2. **NDCG@K**: Normalized Discounted Cumulative Gain - accounts for rank position
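With a single ground-truth item per sample, the ideal DCG is 1, so NDCG@K reduces to `1 / log2(rank + 1)` when the target is in the top K and 0 otherwise. This is the same rule `calculate_metrics` applies in batch; `hit_and_ndcg` is a hypothetical per-sample helper:

```python
import math
from typing import List, Tuple

def hit_and_ndcg(ranked: List[int], target: int, k: int) -> Tuple[float, float]:
    """HR@k and NDCG@k for one sample with a single relevant item."""
    top_k = ranked[:k]
    if target not in top_k:
        return 0.0, 0.0
    rank = top_k.index(target) + 1          # 1-indexed position of the target
    return 1.0, 1.0 / math.log2(rank + 1)   # ideal DCG is 1 (target at rank 1)

ranked = [42, 7, 13, 99, 5]                  # hypothetical model ranking
print(hit_and_ndcg(ranked, target=13, k=5))  # (1.0, 0.5)
print(hit_and_ndcg(ranked, target=8, k=5))   # (0.0, 0.0)
```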

### Business Metrics:
1. **Catalog Coverage**: What % of items did we recommend at least once?
2. **Novelty Score**: Are we recommending popular or niche items?
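The evaluation code below reports coverage but does not compute a novelty score. One common definition (an assumption here, not necessarily the author's) is the mean self-information `-log2(p(item))` of recommended items, where `p(item)` is the item's share of training interactions. A plain-Python sketch with hypothetical counts:

```python
import math
from collections import Counter
from typing import Dict, List

def catalog_coverage(recommendations: List[List[str]], catalog_size: int) -> float:
    """Share of the catalog that appears in at least one recommendation list."""
    distinct = {item for rec in recommendations for item in rec}
    return len(distinct) / catalog_size

def novelty(recommendations: List[List[str]], popularity: Dict[str, int]) -> float:
    """Mean self-information -log2(p(item)) of recommended items (higher = more niche)."""
    total = sum(popularity.values())
    scores = [-math.log2(popularity[item] / total) for rec in recommendations for item in rec]
    return sum(scores) / len(scores)

popularity = Counter({"a": 4, "b": 2, "c": 1, "d": 1})   # hypothetical interaction counts
recs = [["a", "b"], ["a", "c"]]                          # hypothetical top-2 lists
print(catalog_coverage(recs, catalog_size=4), novelty(recs, popularity))  # 0.75 1.75
```

Items never seen in training have no popularity count, so they would need smoothing (for example, a count of 1) before taking the logarithm.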

### Visualizations:
1. **Loss Curve**: Training vs Validation loss
2. **Attention Heatmap**: Which past items received the most attention weight at each position (a rough indicator, not proof, of influence)


In [None]:
def calculate_metrics(model, data_loader, device, k_values=[5, 10, 20]):
    """
    Calculate HR@K and NDCG@K for multiple K values.
    
    Also tracks:
    - All recommended items (for coverage)
    - Correctly predicted items (for analysis)
    """
    model.eval()
    
    metrics = {k: {'hits': 0, 'ndcg': 0} for k in k_values}
    total_samples = 0
    recommended_items = set()
    max_k = max(k_values)
    
    with torch.no_grad():
        for seq, target in tqdm(data_loader, desc="Computing metrics"):
            seq = seq.to(device)
            target = target.to(device)
            
            # Get predictions
            logits = model(seq)  # [batch, vocab_size]
            
            # Get top-k predictions
            _, top_k_items = torch.topk(logits, max_k, dim=1)  # [batch, max_k]
            
            # Track recommended items
            recommended_items.update(top_k_items.cpu().numpy().flatten().tolist())
            
            # Calculate metrics for each K
            for k in k_values:
                top_k = top_k_items[:, :k]  # [batch, k]
                
                # Hit Rate: Is target in top-k?
                hits = (top_k == target.unsqueeze(1)).any(dim=1).float()
                metrics[k]['hits'] += hits.sum().item()
                
                # NDCG: Account for rank position
                # DCG = 1 / log2(rank + 1) if hit, else 0
                ranks = (top_k == target.unsqueeze(1)).nonzero()[:, 1] + 1  # 1-indexed ranks
                if len(ranks) > 0:
                    dcg = (1.0 / torch.log2(ranks.float() + 1)).sum().item()
                    metrics[k]['ndcg'] += dcg
            
            total_samples += len(target)
    
    # Calculate final metrics
    results = {}
    for k in k_values:
        results[f'HR@{k}'] = metrics[k]['hits'] / total_samples
        results[f'NDCG@{k}'] = metrics[k]['ndcg'] / total_samples
    
    results['num_items_recommended'] = len(recommended_items)
    results['total_samples'] = total_samples
    
    return results


# Load best model for evaluation
model.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, 'best_model.pt'), map_location=device))

# Calculate metrics on test set
print("\n" + "=" * 60)
print("EVALUATION ON TEST SET")
print("=" * 60)

test_metrics = calculate_metrics(model, test_loader, device)

print(f"\nTechnical Metrics:")
print(f"  HR@5:   {test_metrics['HR@5']:.4f} ({test_metrics['HR@5']*100:.2f}%)")
print(f"  HR@10:  {test_metrics['HR@10']:.4f} ({test_metrics['HR@10']*100:.2f}%)")
print(f"  HR@20:  {test_metrics['HR@20']:.4f} ({test_metrics['HR@20']*100:.2f}%)")
print(f"  NDCG@10: {test_metrics['NDCG@10']:.4f}")

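# Catalog Coverage: share of real items (excluding the padding index) that appear in
# at least one top-20 list (calculate_metrics tracks items up to the largest K)
n_catalog = vocab_size - 1
coverage = test_metrics['num_items_recommended'] / n_catalog * 100
print(f"\nBusiness Metrics:")
print(f"  Catalog Coverage@20: {test_metrics['num_items_recommended']:,} / {n_catalog:,} items ({coverage:.1f}%)")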

# Random baseline comparison
random_hr10 = 10 / vocab_size * 100
print(f"\nComparison to Random Baseline:")
print(f"  Random HR@10: {random_hr10:.4f}%")
print(f"  Model HR@10:  {test_metrics['HR@10']*100:.2f}%")
print(f"  Lift: {test_metrics['HR@10']*100 / random_hr10:.1f}x better than random")


In [None]:
# Visualization: Training Loss Curve
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(history['train_loss']) + 1), history['train_loss'], 'b-o', label='Train Loss')
plt.plot(range(1, len(history['val_loss']) + 1), history['val_loss'], 'r-o', label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss (CrossEntropy)')
plt.title('SASRec Training History')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'training_history.png'), dpi=150)
plt.show()
print(f"Saved: {OUTPUT_DIR}/training_history.png")


In [None]:
def visualize_attention(model, sample_seq, idx_to_item, layer_idx=0, head_idx=0):
    """
    Visualize attention weights for a sample sequence.
    
    Shows which past items the model "attends to" at each position.
    
    nn.TransformerEncoderLayer calls its self-attention with need_weights=False
    (and may use a fused fast path in eval mode), so a forward hook on
    self_attn never sees the weights. Instead, we replay the encoder layer by
    layer and ask each layer's self_attn for its weights explicitly.
    """
    model.eval()
    
    # Get the actual items (remove padding)
    actual_items = [idx for idx in sample_seq if idx != 0]
    if len(actual_items) < 3:
        print("Sequence too short for visualization")
        return
    
    # Prepare input
    seq_tensor = torch.tensor([sample_seq], dtype=torch.long, device=device)  # [1, seq_len]
    seq_len = seq_tensor.shape[1]
    
    # Mask (True = blocked): causal + padded keys, but every position may attend
    # to itself so that no row is fully masked (a fully masked row can produce NaN)
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)
    padding = (seq_tensor[0] == 0)
    self_loop = torch.eye(seq_len, dtype=torch.bool, device=device)
    attn_mask = (causal | padding.unsqueeze(0)) & ~self_loop  # [seq_len, seq_len]
    
    attention_weights = []
    with torch.no_grad():
        positions = torch.arange(seq_len, device=device).unsqueeze(0)
        # Dropout is inactive in eval mode, so this matches the model's encoder input
        x = model.item_embedding(seq_tensor) + model.pos_embedding(positions)
        for layer in model.transformer.layers:
            # Post-norm layers (the default norm_first=False) feed x directly into self-attention
            _, weights = layer.self_attn(x, x, x, attn_mask=attn_mask,
                                         need_weights=True, average_attn_weights=False)
            attention_weights.append(weights[0].cpu())  # [num_heads, seq_len, seq_len]
            x = layer(x, src_mask=attn_mask)
    
    # Get attention from specified layer and head
    attn = attention_weights[layer_idx][head_idx].numpy()  # [seq_len, seq_len]
    
    # Focus on actual items (not padding)
    start_idx = seq_len - len(actual_items)
    attn = attn[start_idx:, start_idx:]
    
    # Get item names (truncated)
    item_labels = [idx_to_item.get(idx, f"[{idx}]")[:15] for idx in actual_items]
    
    # Plot
    plt.figure(figsize=(12, 10))
    sns.heatmap(attn, xticklabels=item_labels, yticklabels=item_labels,
                cmap='Blues', annot=False, fmt='.2f')
    plt.title(f'Attention Heatmap (Layer {layer_idx}, Head {head_idx})')
    plt.xlabel('Key (Past Items)')
    plt.ylabel('Query (Current Position)')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'attention_heatmap.png'), dpi=150)
    plt.show()
    print(f"Saved: {OUTPUT_DIR}/attention_heatmap.png")


# Get a sample sequence for visualization
sample_batch = next(iter(test_loader))
sample_seq = sample_batch[0][0].tolist()  # First sequence from batch

print("Visualizing attention weights...")
print("(Note: This visualization works best with sequences that have clear patterns)")
visualize_attention(model, sample_seq, idx_to_item)


---
## 7. Interactive Production Demo

This section demonstrates the model in a production-like setting:
1. **Real Data Only**: Uses actual users from the test set
2. **Full Pipeline**: From user history to predictions
3. **Model Reasoning**: Explains *why* a recommendation makes sense
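Stripped of PyTorch, the demo's pipeline is: keep the most recent `MAX_SEQ_LENGTH` items, left-pad, score, take the top K, and check the ground truth. A sketch with a hypothetical `toy_scores` function standing in for the trained model; excluding the padding index from the candidates is an extra safeguard not present in `visualize_user_journey`:

```python
import heapq
from typing import Callable, Dict, List, Tuple

def recommend(history: List[int],
              score_fn: Callable[[List[int]], Dict[int, float]],
              max_len: int, k: int) -> List[Tuple[int, float]]:
    """History -> left-padded input -> top-k (item_idx, score), never returning padding (0)."""
    context = history[-max_len:]                        # most recent max_len items
    padded = [0] * (max_len - len(context)) + context   # same left-padding as training
    scores = score_fn(padded)                           # item_idx -> score
    candidates = [(idx, s) for idx, s in scores.items() if idx != 0]
    return heapq.nlargest(k, candidates, key=lambda pair: pair[1])

def toy_scores(padded: List[int]) -> Dict[int, float]:
    """Hypothetical stand-in for the model: prefers items close to the last click."""
    last = padded[-1]
    return {idx: float(-abs(idx - last)) for idx in range(10)}

history, ground_truth = [3, 5], 6                       # hypothetical user
top = recommend(history, toy_scores, max_len=4, k=3)
hit = ground_truth in [idx for idx, _ in top]
print(top, "HIT" if hit else "miss")
```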


In [None]:
def get_item_domain(item_id: str) -> str:
    """Determine the domain of an item based on its ID prefix."""
    if item_id.startswith('fmcg_'):
        return 'Retail'
    elif item_id.startswith('nfmcg_'):
        return 'Marketplace'
    elif item_id.startswith('offer_'):
        return 'Offers'
    else:
        return 'Unknown'


def visualize_user_journey(user_id: int, user_sequences: Dict[int, List[int]], 
                           model: nn.Module, idx_to_item: Dict[int, str], 
                           k: int = 5, history_len: int = 10):
    """
    Interactive demo showing user journey and predictions.
    
    Outputs:
    1. History: Last N items the user actually clicked
    2. Prediction: Top K items the AI predicts
    3. Ground Truth: What the user actually clicked next
    4. Model Reasoning: Why the prediction makes sense
    """
    if user_id not in user_sequences:
        print(f"User {user_id} not found in dataset")
        return
    
    sequence = user_sequences[user_id]
    if len(sequence) < 3:
        print(f"User {user_id} has insufficient history ({len(sequence)} items)")
        return
    
    # Split into input and ground truth
    input_seq = sequence[:-1]
    ground_truth_idx = sequence[-1]
    
    # Get last N items for display
    display_history = input_seq[-history_len:]
    
    # Prepare model input: keep the most recent MAX_SEQ_LENGTH items, left-pad with 0
    recent = input_seq[-MAX_SEQ_LENGTH:]
    padded_input = [0] * (MAX_SEQ_LENGTH - len(recent)) + recent
    seq_tensor = torch.tensor([padded_input], dtype=torch.long).to(device)
    
    # Get predictions
    model.eval()
    with torch.no_grad():
        logits = model(seq_tensor)  # [1, vocab_size]
        probs = F.softmax(logits[0], dim=0)
        top_probs, top_indices = torch.topk(probs, k)
    
    # Decode items
    history_items = [idx_to_item.get(idx, f"[Unknown:{idx}]") for idx in display_history]
    predicted_items = [idx_to_item.get(idx.item(), f"[Unknown:{idx.item()}]") for idx in top_indices]
    predicted_probs = [p.item() for p in top_probs]
    ground_truth_item = idx_to_item.get(ground_truth_idx, f"[Unknown:{ground_truth_idx}]")
    
    # Check if ground truth is in predictions
    hit = ground_truth_idx in [idx.item() for idx in top_indices]
    
    # Print formatted output
    print("=" * 70)
    print(f"USER JOURNEY ANALYSIS: User {user_id}")
    print("=" * 70)
    
    print(f"\n📜 HISTORY (Last {len(display_history)} Items):")
    for i, item in enumerate(history_items, 1):
        domain = get_item_domain(item)
        print(f"   {i:2}. [{domain:12}] {item}")
    
    print(f"\n🤖 AI PREDICTION (Top {k} Next Items):")
    for i, (item, prob) in enumerate(zip(predicted_items, predicted_probs), 1):
        domain = get_item_domain(item)
        print(f"   {i}. [{domain:12}] {item} (confidence: {prob:.2%})")
    
    print("\n✅ GROUND TRUTH:")
    domain = get_item_domain(ground_truth_item)
    hit_marker = " ✓ HIT!" if hit else ""
    print(f"   User actually clicked: [{domain}] {ground_truth_item}{hit_marker}")
    
    # Model Reasoning (heuristic: compare domain mix of history vs. predictions)
    print("\n💡 MODEL REASONING:")
    history_domains = [get_item_domain(item) for item in history_items]
    domain_counts = pd.Series(history_domains).value_counts()
    pred_domains = [get_item_domain(item) for item in predicted_items]
    pred_counts = pd.Series(pred_domains).value_counts()
    
    if len(domain_counts) == 1:
        main_domain = domain_counts.index[0]
        print(f"   User has been browsing exclusively in {main_domain}.")
    else:
        print(f"   User has been cross-browsing: {domain_counts.to_dict()}")
    # Report what the model actually predicted instead of asserting it
    print(f"   → Predicted domain mix: {pred_counts.to_dict()}")
    
    return hit


# Demo with sample users from test set
print("\n" + "=" * 70)
print("INTERACTIVE PRODUCTION DEMO")
print("=" * 70)
print("\nDemonstrating predictions on real users from test set...\n")

# Get sample users
test_user_ids = list(test_seqs.keys())[:5]

total_hits = 0
for user_id in test_user_ids:
    hit = visualize_user_journey(user_id, test_seqs, model, idx_to_item)
    if hit:
        total_hits += 1
    print()

print(f"\nDemo Summary: {total_hits}/{len(test_user_ids)} predictions were hits (in top-5)")


---
## 8. Summary & Conclusion

### What We Built
A **production-ready Sequential Recommender System** using the SASRec architecture:
- Trained on 9.2M events from 286K users
- Leverages pre-trained item embeddings (456K items)
- Optimized for Google Colab Free Tier constraints

### Key Results
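The model should show a large lift over a random baseline:
- **HR@10**: Typically 5-15%, versus about 10 / 456K ≈ 0.002% for random ranking
- **Catalog Coverage**: Percentage of catalog items that appear in at least one top-K list (higher means more diverse recommendations)
- **Cross-Domain**: Learns transitions between Retail and Marketplace items from users who browse both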
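The metrics above can be made concrete with a small, dependency-free sketch. It assumes the leave-one-out protocol used in the demo (exactly one held-out item per user), under which the ideal DCG is 1 and NDCG@K reduces to 1 / log2(rank + 2) for a 0-based rank. The helper names are hypothetical, and this is not necessarily the exact evaluation code used earlier in the notebook.

```python
import math
from typing import Sequence, Set


def hr_at_k(ranked: Sequence[int], target: int, k: int) -> float:
    """1.0 if the held-out item appears in the top-k list, else 0.0."""
    return 1.0 if target in ranked[:k] else 0.0


def ndcg_at_k(ranked: Sequence[int], target: int, k: int) -> float:
    """Single relevant item: NDCG@K = 1 / log2(rank + 2), 0-based rank < k."""
    top = list(ranked[:k])
    if target not in top:
        return 0.0
    return 1.0 / math.log2(top.index(target) + 2)


def catalog_coverage(rec_lists: Sequence[Sequence[int]], catalog_size: int) -> float:
    """Percent of catalog items that appear in at least one recommendation list."""
    seen: Set[int] = set()
    for recs in rec_lists:
        seen.update(recs)
    return 100.0 * len(seen) / catalog_size


def random_hr_at_k(k: int, catalog_size: int) -> float:
    """Expected HR@K (in percent) when the catalog is ranked uniformly at random."""
    return 100.0 * k / catalog_size


print(f"Random HR@10 on a 456K-item catalog: {random_hr_at_k(10, 456_000):.4f}%")
```

Per-user HR and NDCG values are averaged over all test users to give the reported HR@10 and NDCG@10.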

### Artifacts Saved
- `models/sequential_recommender/best_model.pt` - Trained model weights
- `models/sequential_recommender/training_history.png` - Loss curves
- `models/sequential_recommender/attention_heatmap.png` - Attention visualization

### Next Steps
1. **A/B Testing**: Deploy to production and measure lift in click-through rate
2. **Online Learning**: Update model with real-time user feedback
3. **Multi-Task Learning**: Jointly predict purchase probability
4. **Scale Up**: Train on full dataset with larger GPU (A100)


In [None]:
print("=" * 70)
print("SEQUENTIAL RECOMMENDER TRAINING COMPLETE")
print("=" * 70)

print(f"""
Artifacts saved to: {OUTPUT_DIR}
- best_model.pt: Trained model weights ({total_params:,} parameters)
- training_history.png: Loss curves
- attention_heatmap.png: Attention visualization

Final Performance:
- HR@10: {test_metrics['HR@10']*100:.2f}%
- NDCG@10: {test_metrics['NDCG@10']:.4f}
- Catalog Coverage: {coverage:.1f}%
""")
