# Complete HistoryCentricModel Training Walkthrough

This notebook provides a **complete, self-contained walkthrough** of training the HistoryCentricModel for next-location prediction. The notebook is designed to be fully executable without any dependencies on external project scripts.

## Overview

This notebook walks through the entire training pipeline from start to finish:

1. **Data Loading**: Reading preprocessed pickle files containing trajectory sequences
2. **Model Architecture**: Implementing the HistoryCentricModel from scratch
3. **Training Loop**: Setting up optimizer, loss functions, and training procedures
4. **Evaluation**: Computing comprehensive metrics on test data
5. **Results**: Displaying final performance metrics

## Model Architecture: HistoryCentricModel

The **HistoryCentricModel** is designed based on a key insight: **83.81% of next locations are already in the visit history**. The model combines:

- **History-based scoring**: Uses recency (exponential decay) and frequency patterns from visit history
- **Learned patterns**: A compact transformer to learn temporal context and transition patterns
- **Ensemble approach**: Combines history scores with learned model outputs

The model is highly parameter-efficient (<500K parameters) while achieving strong performance.
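
The history-coverage figure quoted above can be measured directly from the samples. The following is a minimal sketch on toy data: the `history_coverage` helper and the three samples are illustrative assumptions, not part of the project code.

```python
def history_coverage(samples):
    """Return the fraction of samples whose target `Y` appears in the history `X`."""
    hits = sum(1 for s in samples if s["Y"] in s["X"])
    return hits / len(samples)

# Toy samples (hypothetical values) using the same `X` / `Y` keys as the dataset.
toy_samples = [
    {"X": [4, 9, 4, 2], "Y": 4},   # target seen before
    {"X": [7, 7, 3], "Y": 3},      # target seen before
    {"X": [1, 5, 6], "Y": 8},      # target is a new location
]

coverage = history_coverage(toy_samples)
print(f"{coverage:.2%} of toy targets are in the history")
```

Running the same check over the real training samples would be one way to reproduce the reported percentage.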

## Dataset

We use the **GeoLife** trajectory dataset, which contains:
- Location sequences from multiple users
- Temporal features (time of day, day of week, duration, time gaps)
- User identifiers
- Target: Next location to visit

Each sample includes:
- `X`: Location ID sequence
- `user_X`: User ID for each location  
- `weekday_X`: Day of week (0-6)
- `start_min_X`: Start time in minutes from midnight
- `dur_X`: Duration at each location (minutes)
- `diff`: Time gap indicator
- `Y`: Target next location


## 1. Setup and Imports

First, we import all necessary libraries. This notebook is self-contained and uses only widely available libraries: PyTorch, NumPy, scikit-learn, and tqdm.

In [None]:
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import math
import time
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score
import random

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))

## 2. Configuration Parameters

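We define all hyperparameters and paths in one place. These values match the production configuration from `configs/geolife_default.yaml`. Note that the `model` block is recorded for reference only: the `HistoryCentricModel` class in Section 6 hard-codes its dimensions and dropout, so editing `config['model']` alone does not change the architecture.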

### Key Parameters:
- **Data**: Paths to train/val/test pickle files, batch size, sequence length
- **Model**: Embedding dimensions, transformer architecture (d_model=80, 1 layer, 4 heads)
- **Training**: Learning rate, weight decay, epochs, early stopping
- **System**: Device (GPU/CPU), random seed for reproducibility
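
A quick consistency check on the model dimensions can catch configuration mistakes early. The sketch below works on a stand-alone copy of the settings; `check_model_dims` is a hypothetical helper, and it assumes the transformer input is the concatenation of the location, user, and temporal embeddings, as in the model defined later in this notebook.

```python
def check_model_dims(model_cfg):
    """Verify that the embedding widths add up to d_model and split evenly across heads."""
    fused = (
        model_cfg["loc_emb_dim"]
        + model_cfg["user_emb_dim"]
        + model_cfg["time_emb_dim"]
    )
    if fused != model_cfg["d_model"]:
        raise ValueError(f"embeddings sum to {fused}, expected {model_cfg['d_model']}")
    if model_cfg["d_model"] % model_cfg["nhead"] != 0:
        raise ValueError("d_model must be divisible by nhead")
    return model_cfg["d_model"] // model_cfg["nhead"]  # per-head dimension

# Stand-alone copy of the relevant settings from this notebook.
model_cfg = {"loc_emb_dim": 56, "user_emb_dim": 12, "time_emb_dim": 12,
             "d_model": 80, "nhead": 4}

head_dim = check_model_dims(model_cfg)
print("per-head dimension:", head_dim)  # 80 / 4 = 20
```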

In [None]:
# Configuration dictionary
config = {
    # Data paths and parameters
    'data': {
        'data_dir': '/content/expr_hrcl_next_pred_av5/data/geolife',
        'train_file': 'geolife_transformer_7_train.pk',
        'val_file': 'geolife_transformer_7_validation.pk',
        'test_file': 'geolife_transformer_7_test.pk',
        'num_locations': 1187,  # 1186 + 1 for padding
        'num_users': 46,  # 45 + 1 for padding
        'num_weekdays': 7,
        'max_seq_len': 60,
    },
    
    # Model architecture
    'model': {
        'name': 'HistoryCentricModel',
        'loc_emb_dim': 56,
        'user_emb_dim': 12,
        'weekday_emb_dim': 4,
        'time_emb_dim': 12,
        'd_model': 80,
        'nhead': 4,
        'num_layers': 1,
        'dim_feedforward': 160,
        'dropout': 0.35,
    },
    
    # Training parameters
    'training': {
        'batch_size': 96,
        'num_epochs': 120,
        'learning_rate': 0.0025,
        'weight_decay': 0.00008,
        'grad_clip': 1.0,
        'label_smoothing': 0.02,
        'early_stopping_patience': 20,
        'lr_scheduler_patience': 10,
        'lr_scheduler_factor': 0.6,
        'min_lr': 5e-7,
    },
    
    # System
    'system': {
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'num_workers': 2,
        'seed': 42,
    }
}

print("Configuration loaded successfully!")
print(f"Device: {config['system']['device']}")
print(f"Batch size: {config['training']['batch_size']}")
print(f"Max epochs: {config['training']['num_epochs']}")

## 3. Set Random Seed for Reproducibility

Setting random seeds makes runs as repeatable as possible, although some CUDA operations can remain nondeterministic. We set seeds for:
- Python's random module
- NumPy
- PyTorch (CPU and CUDA)
- CuDNN backend for deterministic behavior

In [None]:
def set_seed(seed):
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(config['system']['seed'])
print(f"Random seed set to {config['system']['seed']}")

## 4. Dataset Implementation

The `GeoLifeDataset` class handles loading and preparing trajectory data:

### Data Structure:
Each sample in the pickle file contains:
- **X**: Array of location IDs (sequence)
- **user_X**: User ID for each visit
- **weekday_X**: Day of week (0=Monday, 6=Sunday)
- **start_min_X**: Start time in minutes from midnight (0-1439)
- **dur_X**: Duration at each location in minutes
- **diff**: Time gap indicator between visits
- **Y**: Target next location (ground truth)

### Processing:
- Sequences longer than `max_seq_len` are truncated (keeping most recent)
- Each feature is converted to appropriate PyTorch tensor type
- Variable-length sequences are handled in the collate function
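
The truncation rule keeps the tail of the sequence, since the latest visits carry the most signal for the next location. A minimal sketch on a plain list (the `keep_most_recent` name and the IDs are illustrative assumptions):

```python
def keep_most_recent(seq, max_seq_len):
    """Truncate a sequence from the left so only the latest visits remain."""
    return seq[-max_seq_len:] if len(seq) > max_seq_len else seq

visits = [10, 11, 12, 13, 14]          # oldest -> newest (hypothetical IDs)
print(keep_most_recent(visits, 3))     # [12, 13, 14]
print(keep_most_recent(visits, 60))    # unchanged: [10, 11, 12, 13, 14]
```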

In [None]:
class GeoLifeDataset(Dataset):
    """Dataset for GeoLife trajectory sequences."""
    
    def __init__(self, data_path, max_seq_len=60):
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)
        self.max_seq_len = max_seq_len
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        
        # Extract features
        loc_seq = sample['X']
        user_seq = sample['user_X']
        weekday_seq = sample['weekday_X']
        start_min_seq = sample['start_min_X']
        dur_seq = sample['dur_X']
        diff_seq = sample['diff']
        target = sample['Y']
        
        # Truncate if too long (keep most recent)
        seq_len = len(loc_seq)
        if seq_len > self.max_seq_len:
            loc_seq = loc_seq[-self.max_seq_len:]
            user_seq = user_seq[-self.max_seq_len:]
            weekday_seq = weekday_seq[-self.max_seq_len:]
            start_min_seq = start_min_seq[-self.max_seq_len:]
            dur_seq = dur_seq[-self.max_seq_len:]
            diff_seq = diff_seq[-self.max_seq_len:]
            seq_len = self.max_seq_len
        
        return {
            'loc_seq': torch.LongTensor(loc_seq),
            'user_seq': torch.LongTensor(user_seq),
            'weekday_seq': torch.LongTensor(weekday_seq),
            'start_min_seq': torch.FloatTensor(start_min_seq),
            'dur_seq': torch.FloatTensor(dur_seq),
            'diff_seq': torch.LongTensor(diff_seq),
            'target': torch.LongTensor([target]),
            'seq_len': seq_len
        }

print("Dataset class defined!")

### Collate Function for Batching

Since sequences have variable lengths, we need a custom collate function:
- Finds the maximum sequence length in the batch
- Pads all sequences to this length with zeros
- Creates a boolean attention mask (`True` for real tokens, `False` for padding)
- Stacks all samples into batch tensors
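
The padding and masking steps can be illustrated without tensors. This is a pure-Python sketch of the same idea, not the collate function itself; `pad_batch` is a hypothetical name.

```python
def pad_batch(sequences, pad_value=0):
    """Right-pad sequences to the batch maximum and build a boolean validity mask."""
    max_len = max(len(s) for s in sequences)
    padded = [s + [pad_value] * (max_len - len(s)) for s in sequences]
    mask = [[i < len(s) for i in range(max_len)] for s in sequences]
    return padded, mask

padded, mask = pad_batch([[5, 3, 5], [7]])
print(padded)  # [[5, 3, 5], [7, 0, 0]]
print(mask)    # [[True, True, True], [True, False, False]]
```

Because location ID 0 is reserved for padding, padded positions never collide with a real location.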

In [None]:
def collate_fn(batch):
    """Custom collate function to handle variable-length sequences."""
    max_len = max(item['seq_len'] for item in batch)
    batch_size = len(batch)
    
    # Initialize padded tensors
    loc_seqs = torch.zeros(batch_size, max_len, dtype=torch.long)
    user_seqs = torch.zeros(batch_size, max_len, dtype=torch.long)
    weekday_seqs = torch.zeros(batch_size, max_len, dtype=torch.long)
    start_min_seqs = torch.zeros(batch_size, max_len, dtype=torch.float)
    dur_seqs = torch.zeros(batch_size, max_len, dtype=torch.float)
    diff_seqs = torch.zeros(batch_size, max_len, dtype=torch.long)
    targets = torch.zeros(batch_size, dtype=torch.long)
    seq_lens = torch.zeros(batch_size, dtype=torch.long)
    
    # Fill in the data
    for i, item in enumerate(batch):
        length = item['seq_len']
        loc_seqs[i, :length] = item['loc_seq']
        user_seqs[i, :length] = item['user_seq']
        weekday_seqs[i, :length] = item['weekday_seq']
        start_min_seqs[i, :length] = item['start_min_seq']
        dur_seqs[i, :length] = item['dur_seq']
        diff_seqs[i, :length] = item['diff_seq']
        targets[i] = item['target']
        seq_lens[i] = length
    
    # Create boolean attention mask (True for real tokens, False for padding)
    mask = torch.arange(max_len).unsqueeze(0) < seq_lens.unsqueeze(1)
    
    return {
        'loc_seq': loc_seqs,
        'user_seq': user_seqs,
        'weekday_seq': weekday_seqs,
        'start_min_seq': start_min_seqs,
        'dur_seq': dur_seqs,
        'diff_seq': diff_seqs,
        'target': targets,
        'mask': mask,
        'seq_len': seq_lens
    }

print("Collate function defined!")

## 5. Load Datasets

We create DataLoaders for train, validation, and test sets:
- **Training**: Shuffled for better learning
- **Validation/Test**: Not shuffled for consistent evaluation
- **pin_memory**: Faster GPU transfer if using CUDA

The datasets are loaded from preprocessed pickle files containing trajectory sequences.

In [None]:
import os

# Paths to data files
data_dir = config['data']['data_dir']
train_path = os.path.join(data_dir, config['data']['train_file'])
val_path = os.path.join(data_dir, config['data']['val_file'])
test_path = os.path.join(data_dir, config['data']['test_file'])

# Create datasets
print("Loading datasets...")
train_dataset = GeoLifeDataset(train_path, max_seq_len=config['data']['max_seq_len'])
val_dataset = GeoLifeDataset(val_path, max_seq_len=config['data']['max_seq_len'])
test_dataset = GeoLifeDataset(test_path, max_seq_len=config['data']['max_seq_len'])

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config['training']['batch_size'],
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=config['system']['num_workers'],
    pin_memory=config['system']['device'] == 'cuda'
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config['training']['batch_size'],
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=config['system']['num_workers'],
    pin_memory=config['system']['device'] == 'cuda'
)

test_loader = DataLoader(
    test_dataset,
    batch_size=config['training']['batch_size'],
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=config['system']['num_workers'],
    pin_memory=config['system']['device'] == 'cuda'
)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

# Inspect a sample batch
sample_batch = next(iter(train_loader))
print(f"\nSample batch shape:")
print(f"  loc_seq: {sample_batch['loc_seq'].shape}")
print(f"  user_seq: {sample_batch['user_seq'].shape}")
print(f"  target: {sample_batch['target'].shape}")
print(f"  mask: {sample_batch['mask'].shape}")

## 6. Model Implementation: HistoryCentricModel

The **HistoryCentricModel** is the core of our next-location prediction system. It combines:

### Architecture Components:

1. **Embeddings**:
   - Location embeddings (56-dim)
   - User embeddings (12-dim)
   - Temporal features encoded via sinusoidal functions

2. **Compact Transformer**:
   - Single layer with 4 attention heads
   - d_model = 80 (very compact)
   - Feedforward dimension = 160
   - Dropout = 0.35 for regularization

3. **History Scoring Mechanism**:
   - **Recency**: Exponential decay based on how recently a location was visited
   - **Frequency**: How often a location appears in the sequence
   - Learnable weights to balance these factors

4. **Ensemble Strategy**:
   - Combines history scores with learned transformer outputs
   - Learnable weight balances history vs learned patterns
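
To see why the temporal features are encoded with sine and cosine, consider the start time. Mapping minutes onto a circle keeps 23:59 next to 00:00, which a raw minute count would place far apart. A small sketch using only the standard library (`encode_time_of_day` and `dist` are illustrative names):

```python
import math

def encode_time_of_day(start_min):
    """Map minutes from midnight onto the unit circle so 23:59 and 00:00 are close."""
    angle = (start_min / 60.0 / 24.0) * 2 * math.pi
    return math.sin(angle), math.cos(angle)

def dist(a, b):
    """Euclidean distance between two encoded time points."""
    return math.hypot(a[0] - b[0], a[1] - b[1])

midnight = encode_time_of_day(0)
late = encode_time_of_day(1439)   # 23:59
noon = encode_time_of_day(720)

print(f"midnight vs 23:59: {dist(midnight, late):.4f}")
print(f"midnight vs noon:  {dist(midnight, noon):.4f}")
```

The same transformation is applied to the weekday, with a period of 7 instead of 24 hours.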

### Key Insight:
Since 83.81% of next locations are already in visit history, the model heavily weights historical patterns while still learning complex temporal dependencies.
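
The history score can be worked through by hand on a short sequence. The sketch below is a pure-Python illustration for a single unpadded sequence, using the initial values of the learnable parameters from the model code (decay 0.62, frequency weight 2.2, scale 11.0); the function name and the toy sequence are assumptions. Locations that never appear in the history receive a score of zero in the model and are omitted here.

```python
from collections import Counter

def history_scores(loc_seq, decay=0.62, freq_weight=2.2, scale=11.0):
    """Score each visited location by recency (exponential decay) plus frequency."""
    n = len(loc_seq)
    recency = {}
    for t, loc in enumerate(loc_seq):
        # Later visits overwrite earlier ones, so each location keeps the
        # weight of its most recent visit (the largest one when decay < 1).
        recency[loc] = decay ** (n - t - 1)
    counts = Counter(loc_seq)
    max_count = max(counts.values())
    return {
        loc: scale * (recency[loc] + freq_weight * counts[loc] / max_count)
        for loc in counts
    }

scores = history_scores([5, 3, 5, 7])   # oldest -> newest
ranking = sorted(scores, key=scores.get, reverse=True)
print(ranking)  # [5, 7, 3]
```

Location 5 ranks first because it is the most frequent, even though location 7 is the most recent visit. With these initial values, frequency outweighs recency.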

In [None]:
class HistoryCentricModel(nn.Module):
    """Model that heavily prioritizes locations from visit history."""
    
    def __init__(self, num_locations, num_users, max_seq_len=60):
        super().__init__()
        
        self.num_locations = num_locations
        self.d_model = 80  # Compact
        
        # Core embeddings
        self.loc_emb = nn.Embedding(num_locations, 56, padding_idx=0)
        self.user_emb = nn.Embedding(num_users, 12, padding_idx=0)
        
        # Compact temporal encoder
        self.temporal_proj = nn.Linear(6, 12)  # sin/cos time, dur, sin/cos wd, gap
        
        # Input fusion: 56 + 12 + 12 = 80
        self.input_norm = nn.LayerNorm(80)
        
        # Positional encoding
        pe = torch.zeros(max_seq_len, 80)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, 80, 2).float() * (-math.log(10000.0) / 80))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
        
        # Very compact transformer
        self.attn = nn.MultiheadAttention(80, 4, dropout=0.35, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(80, 160),
            nn.GELU(),
            nn.Dropout(0.35),
            nn.Linear(160, 80)
        )
        self.norm1 = nn.LayerNorm(80)
        self.norm2 = nn.LayerNorm(80)
        self.dropout = nn.Dropout(0.35)
        
        # Prediction head
        self.predictor = nn.Sequential(
            nn.Linear(80, 160),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(160, num_locations)
        )
        
        # History scoring parameters (learnable)
        self.recency_decay = nn.Parameter(torch.tensor(0.62))
        self.freq_weight = nn.Parameter(torch.tensor(2.2))
        self.history_scale = nn.Parameter(torch.tensor(11.0))
        self.model_weight = nn.Parameter(torch.tensor(0.22))
        
        self._init_weights()
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()
    
    def compute_history_scores(self, loc_seq, mask):
        """
        Compute history-based scores for all locations.
        
        Args:
            loc_seq: (batch_size, seq_len) - location sequence
            mask: (batch_size, seq_len) - attention mask
        
        Returns:
            history_scores: (batch_size, num_locations) - scores for each location
        """
        batch_size, seq_len = loc_seq.shape
        
        # Initialize score matrix
        recency_scores = torch.zeros(batch_size, self.num_locations, device=loc_seq.device)
        frequency_scores = torch.zeros(batch_size, self.num_locations, device=loc_seq.device)
        
        # Compute recency and frequency scores
        for t in range(seq_len):
            locs_t = loc_seq[:, t]  # (B,)
            valid_t = mask[:, t].float()  # (B,)
            
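            # Recency: exponential decay from the end of the padded batch.
            # Note: seq_len is the padded length, so sequences shorter than the
            # longest one in the batch start from a smaller recency weight.
            time_from_end = seq_len - t - 1
            recency_weight = torch.pow(self.recency_decay, time_from_end)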
            
            # Update recency scores (max over time for each location)
            indices = locs_t.unsqueeze(1)  # (B, 1)
            values = (recency_weight * valid_t).unsqueeze(1)  # (B, 1)
            
            # For each location, keep the maximum recency (most recent visit)
            current_scores = torch.zeros(batch_size, self.num_locations, device=loc_seq.device)
            current_scores.scatter_(1, indices, values)
            recency_scores = torch.maximum(recency_scores, current_scores)
            
            # Update frequency scores (sum over time)
            frequency_scores.scatter_add_(1, indices, valid_t.unsqueeze(1))
        
        # Normalize frequency scores
        max_freq = frequency_scores.max(dim=1, keepdim=True)[0].clamp(min=1.0)
        frequency_scores = frequency_scores / max_freq
        
        # Combine recency and frequency
        history_scores = recency_scores + self.freq_weight * frequency_scores
        history_scores = self.history_scale * history_scores
        
        return history_scores
    
    def forward(self, loc_seq, user_seq, weekday_seq, start_min_seq, dur_seq, diff_seq, mask):
        batch_size, seq_len = loc_seq.shape
        
        # === Compute history-based scores ===
        history_scores = self.compute_history_scores(loc_seq, mask)
        
        # === Learned model ===
        # Feature extraction
        loc_emb = self.loc_emb(loc_seq)
        user_emb = self.user_emb(user_seq)
        
        # Temporal features with sinusoidal encoding
        hours = start_min_seq / 60.0
        time_rad = (hours / 24.0) * 2 * math.pi
        time_sin = torch.sin(time_rad)
        time_cos = torch.cos(time_rad)
        
        dur_norm = torch.log1p(dur_seq) / 8.0
        
        wd_rad = (weekday_seq.float() / 7.0) * 2 * math.pi
        wd_sin = torch.sin(wd_rad)
        wd_cos = torch.cos(wd_rad)
        
        diff_norm = diff_seq.float() / 7.0
        
        temporal_feats = torch.stack([time_sin, time_cos, dur_norm, wd_sin, wd_cos, diff_norm], dim=-1)
        temporal_emb = self.temporal_proj(temporal_feats)
        
        # Combine features
        x = torch.cat([loc_emb, user_emb, temporal_emb], dim=-1)
        x = self.input_norm(x)
        
        # Add positional encoding
        x = x + self.pe[:seq_len, :].unsqueeze(0)
        x = self.dropout(x)
        
        # Transformer layer
        attn_mask = ~mask
        attn_out, _ = self.attn(x, x, x, key_padding_mask=attn_mask)
        x = self.norm1(x + self.dropout(attn_out))
        
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        
        # Get last valid position
        seq_lens = mask.sum(dim=1) - 1
        indices_gather = seq_lens.unsqueeze(1).unsqueeze(2).expand(batch_size, 1, self.d_model)
        last_hidden = torch.gather(x, 1, indices_gather).squeeze(1)
        
        # Learned logits
        learned_logits = self.predictor(last_hidden)
        
        # === Ensemble: History + Learned ===
        # Normalize learned logits to similar scale as history scores
        learned_logits_normalized = F.softmax(learned_logits, dim=1) * self.num_locations
        
        # Combine with learned weight
        final_logits = history_scores + self.model_weight * learned_logits_normalized
        
        return final_logits
    
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

print("Model class defined!")

## 7. Instantiate Model

Create the model instance and move it to the appropriate device (GPU/CPU).
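
The parameter budget can be estimated by hand before instantiating anything. The tally below is a back-of-the-envelope sketch based on the layer sizes in the model class above. It assumes the standard parameter layout of `nn.MultiheadAttention` (a packed input projection plus an output projection, both with bias) and does not count the positional-encoding buffer, which is not trainable.

```python
def linear(n_in, n_out):
    """Parameter count of a fully connected layer with bias."""
    return n_in * n_out + n_out

num_locations, num_users, d = 1187, 46, 80

counts = {
    "loc_emb": num_locations * 56,
    "user_emb": num_users * 12,
    "temporal_proj": linear(6, 12),
    "layer_norms": 3 * (2 * d),                     # input_norm, norm1, norm2
    "attention": 3 * d * d + 3 * d + linear(d, d),  # in-projection + out-projection
    "feed_forward": linear(d, 160) + linear(160, d),
    "predictor": linear(d, 160) + linear(160, num_locations),
    "history_scalars": 4,
}

total = sum(counts.values())
for name, n in counts.items():
    print(f"{name:16s} {n:>8,}")
print(f"{'total':16s} {total:>8,}")
```

Under these assumptions the output layer of the predictor dominates the budget, because it scales with the number of locations.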

In [None]:
device = torch.device(config['system']['device'])
print(f"Using device: {device}")

model = HistoryCentricModel(
    num_locations=config['data']['num_locations'],
    num_users=config['data']['num_users'],
    max_seq_len=config['data']['max_seq_len']
)

model = model.to(device)

num_params = model.count_parameters()
print(f"\nModel: {config['model']['name']}")
print(f"Total parameters: {num_params:,}")

if num_params >= 500000:
    print(f"WARNING: Model has {num_params:,} parameters (limit is 500K)")
    print(f"Exceeded by: {num_params - 500000:,}")
else:
    print(f"✓ Model is within budget (remaining: {500000 - num_params:,})")

## 8. Evaluation Metrics

We implement comprehensive metrics for next-location prediction:

- **Accuracy@k**: Percentage of times the ground truth is in top-k predictions
- **MRR (Mean Reciprocal Rank)**: Average of 1/rank of the correct location
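- **NDCG@10 (Normalized Discounted Cumulative Gain)**: Ranking quality metric that discounts credit by rank and scores zero when the correct location falls outside the top 10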
- **F1 Score**: Weighted F1 score for top-1 predictions
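
These ranking metrics all derive from one quantity: the 1-based rank of the correct location in the sorted predictions. A minimal pure-Python sketch with hypothetical ranks (`ranking_metrics` is an illustrative helper, not the notebook's implementation):

```python
import math

def ranking_metrics(ranks, k=10):
    """Compute Acc@1, Acc@k, MRR and NDCG@k from 1-based ranks of the true location."""
    n = len(ranks)
    acc1 = sum(r == 1 for r in ranks) / n
    acck = sum(r <= k for r in ranks) / n
    mrr = sum(1.0 / r for r in ranks) / n
    # With a single relevant item, the ideal DCG is 1, so NDCG reduces to
    # 1 / log2(rank + 1), or 0 when the item falls outside the top k.
    ndcg = sum(1.0 / math.log2(r + 1) if r <= k else 0.0 for r in ranks) / n
    return {"acc@1": acc1, f"acc@{k}": acck, "mrr": mrr, "ndcg": ndcg}

# Hypothetical ranks of the correct location for three predictions.
demo_metrics = ranking_metrics([1, 3, 12])
print(demo_metrics)
```

The third prediction (rank 12) still contributes 1/12 to MRR but nothing to Acc@10 or NDCG@10.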

In [None]:
def get_mrr(prediction, targets):
    """Calculate MRR score."""
    index = torch.argsort(prediction, dim=-1, descending=True)
    hits = (targets.unsqueeze(-1).expand_as(index) == index).nonzero()
    ranks = (hits[:, -1] + 1).float()
    rranks = torch.reciprocal(ranks)
    return torch.sum(rranks).cpu().numpy()

def get_ndcg(prediction, targets, k=10):
    """Calculate NDCG score."""
    index = torch.argsort(prediction, dim=-1, descending=True)
    hits = (targets.unsqueeze(-1).expand_as(index) == index).nonzero()
    ranks = (hits[:, -1] + 1).float().cpu().numpy()
    not_considered_idx = ranks > k
    ndcg = 1 / np.log2(ranks + 1)
    ndcg[not_considered_idx] = 0
    return np.sum(ndcg)

def calculate_metrics(logits, true_y):
    """Calculate all metrics."""
    top1 = []
    result_ls = []
    
    for k in [1, 3, 5, 10]:
        if logits.shape[-1] < k:
            k = logits.shape[-1]
        prediction = torch.topk(logits, k=k, dim=-1).indices
        if k == 1:
            top1 = torch.squeeze(prediction).cpu()
        top_k = torch.eq(true_y[:, None], prediction).any(dim=1).sum().cpu().numpy()
        result_ls.append(top_k)
    
    result_ls.append(get_mrr(logits, true_y))
    result_ls.append(get_ndcg(logits, true_y))
    result_ls.append(true_y.shape[0])
    
    return np.array(result_ls, dtype=np.float32), true_y.cpu(), top1

def get_performance_dict(metrics_dict):
    """Convert raw counts to percentages."""
    perf = {
        'acc@1': metrics_dict['correct@1'] / metrics_dict['total'] * 100,
        'acc@3': metrics_dict['correct@3'] / metrics_dict['total'] * 100,
        'acc@5': metrics_dict['correct@5'] / metrics_dict['total'] * 100,
        'acc@10': metrics_dict['correct@10'] / metrics_dict['total'] * 100,
        'mrr': metrics_dict['rr'] / metrics_dict['total'] * 100,
        'ndcg': metrics_dict['ndcg'] / metrics_dict['total'] * 100,
        'f1': metrics_dict['f1'],
        'total': metrics_dict['total']
    }
    return perf

print("Metric functions defined!")

## 9. Loss Function and Optimizer

**Label Smoothing Cross-Entropy**: Instead of hard targets (one-hot), we use label smoothing to prevent overconfidence and improve generalization.
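
As a worked illustration of label smoothing, the sketch below builds the soft target distribution for a 4-class toy problem with the smoothing value used in this notebook (0.02) and evaluates the resulting loss. It is plain Python for clarity; the function names and logits are assumptions.

```python
import math

def smoothed_targets(target, n_class, smoothing):
    """Soft targets: (1 - s) on the true class plus s / n_class on every class."""
    base = smoothing / n_class
    return [base + (1.0 - smoothing if c == target else 0.0) for c in range(n_class)]

def smoothed_cross_entropy(logits, target, smoothing):
    """Cross-entropy between the smoothed targets and softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    q = smoothed_targets(target, len(logits), smoothing)
    return -sum(qi * lp for qi, lp in zip(q, log_probs))

q = smoothed_targets(target=2, n_class=4, smoothing=0.02)
print(q)  # the true class gets about 0.985, every other class 0.005

toy_logits = [0.1, 0.2, 3.0, -1.0]
loss = smoothed_cross_entropy(toy_logits, target=2, smoothing=0.02)
hard_loss = smoothed_cross_entropy(toy_logits, target=2, smoothing=0.0)
print(f"smoothed loss = {loss:.4f}, hard-target loss = {hard_loss:.4f}")
```

Even a confident, correct prediction pays a small penalty under smoothing, which discourages the logits from growing without bound.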

**AdamW Optimizer**: With separate weight decay for different parameter groups:
- Weights: Apply weight decay
- Biases and layer norms: No weight decay

**Learning Rate Scheduler**: ReduceLROnPlateau reduces LR when validation loss plateaus.
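
To get a feel for the scheduler settings, the sketch below lists the learning rate after successive reductions, using the notebook's values (initial rate 0.0025, factor 0.6, floor 5e-7). It only shows the sequence of values; in training, each reduction happens only after `patience` epochs without improvement in validation loss. `plateau_lr_schedule` is a hypothetical helper.

```python
def plateau_lr_schedule(initial_lr, factor, min_lr, num_reductions):
    """Learning rates after successive plateau-triggered reductions, floored at min_lr."""
    lrs = [initial_lr]
    for _ in range(num_reductions):
        lrs.append(max(lrs[-1] * factor, min_lr))
    return lrs

# Values taken from the training configuration in this notebook.
lrs = plateau_lr_schedule(initial_lr=0.0025, factor=0.6, min_lr=5e-7, num_reductions=20)
for i, lr in enumerate(lrs[:4]):
    print(f"after {i} reductions: {lr:.6f}")
print(f"after 20 reductions: {lrs[-1]:.1e}")
```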

In [None]:
class LabelSmoothingCrossEntropy(nn.Module):
    """Label smoothing for better generalization."""
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
    
    def forward(self, pred, target):
        n_class = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, target.unsqueeze(1), 1)
        one_hot = one_hot * (1 - self.smoothing) + self.smoothing / n_class
        log_prob = F.log_softmax(pred, dim=1)
        loss = -(one_hot * log_prob).sum(dim=1).mean()
        return loss

# Loss function
criterion = LabelSmoothingCrossEntropy(smoothing=config['training']['label_smoothing'])
print("Loss function initialized")

# Optimizer with separate weight decay for different param groups
param_groups = [
    {'params': [p for n, p in model.named_parameters() if 'bias' not in n and 'norm' not in n], 
     'weight_decay': config['training']['weight_decay']},
    {'params': [p for n, p in model.named_parameters() if 'bias' in n or 'norm' in n], 
     'weight_decay': 0.0}
]
optimizer = AdamW(param_groups, lr=config['training']['learning_rate'])
print("Optimizer initialized")

# Learning rate scheduler
# Note: the `verbose` argument is deprecated in recent PyTorch releases, so it is omitted.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=config['training']['lr_scheduler_factor'],
    patience=config['training']['lr_scheduler_patience'],
    min_lr=config['training']['min_lr']
)
print("LR scheduler initialized")

## 10. Training and Validation Functions

### Training Function:
1. Iterate through training batches
2. Forward pass through model
3. Calculate loss
4. Backward pass and gradient clipping
5. Update parameters

### Validation Function:
1. Switch to eval mode and disable gradient computation
2. Calculate predictions
3. Compute all metrics (Acc@k, MRR, NDCG, F1)
4. Return performance dictionary
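
The gradient clipping step rescales all gradients together when their combined norm exceeds a threshold (`grad_clip` is 1.0 in this notebook). The following pure-Python sketch illustrates the rescaling math on a two-element gradient; it is not PyTorch's implementation, and `clip_by_global_norm` is a hypothetical name.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale gradients so their combined L2 norm does not exceed max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (total_norm + eps))
    return [g * coef for g in grads], total_norm

clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))
print(f"norm before: {norm_before:.2f}, after: {norm_after:.4f}")
print(clipped)  # direction preserved, roughly [0.6, 0.8]
```

Gradients whose norm is already below the threshold pass through unchanged, so clipping only affects unusually large updates.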

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, device, grad_clip, epoch):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    num_batches = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch} [Train]', leave=False)
    
    for batch in pbar:
        # Move to device
        loc_seq = batch['loc_seq'].to(device)
        user_seq = batch['user_seq'].to(device)
        weekday_seq = batch['weekday_seq'].to(device)
        start_min_seq = batch['start_min_seq'].to(device)
        dur_seq = batch['dur_seq'].to(device)
        diff_seq = batch['diff_seq'].to(device)
        target = batch['target'].to(device)
        mask = batch['mask'].to(device)
        
        # Forward pass
        logits = model(loc_seq, user_seq, weekday_seq, start_min_seq, dur_seq, diff_seq, mask)
        
        # Calculate loss
        loss = criterion(logits, target)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        
        pbar.set_postfix({'loss': f'{total_loss/num_batches:.4f}'})
    
    avg_loss = total_loss / num_batches
    return avg_loss

@torch.no_grad()
def validate(model, data_loader, criterion, device, split_name='Val'):
    """Validate the model."""
    model.eval()
    
    metrics = {
        'correct@1': 0,
        'correct@3': 0,
        'correct@5': 0,
        'correct@10': 0,
        'rr': 0,
        'ndcg': 0,
        'f1': 0,
        'total': 0
    }
    
    true_ls = []
    top1_ls = []
    total_val_loss = 0
    num_batches = 0
    
    pbar = tqdm(data_loader, desc=f'{split_name:5s}', leave=False)
    
    for batch in pbar:
        loc_seq = batch['loc_seq'].to(device)
        user_seq = batch['user_seq'].to(device)
        weekday_seq = batch['weekday_seq'].to(device)
        start_min_seq = batch['start_min_seq'].to(device)
        dur_seq = batch['dur_seq'].to(device)
        diff_seq = batch['diff_seq'].to(device)
        target = batch['target'].to(device)
        mask = batch['mask'].to(device)
        
        logits = model(loc_seq, user_seq, weekday_seq, start_min_seq, dur_seq, diff_seq, mask)
        
        loss = criterion(logits, target)
        total_val_loss += loss.item()
        num_batches += 1
        
        result, batch_true, batch_top1 = calculate_metrics(logits, target)
        
        metrics['correct@1'] += result[0]
        metrics['correct@3'] += result[1]
        metrics['correct@5'] += result[2]
        metrics['correct@10'] += result[3]
        metrics['rr'] += result[4]
        metrics['ndcg'] += result[5]
        metrics['total'] += result[6]
        
        true_ls.extend(batch_true.tolist())
        if not batch_top1.shape:
            top1_ls.extend([batch_top1.tolist()])
        else:
            top1_ls.extend(batch_top1.tolist())
        
        pbar.set_postfix({'loss': f'{total_val_loss/num_batches:.4f}'})
    
    avg_val_loss = total_val_loss / num_batches
    
    # Calculate F1
    f1 = f1_score(true_ls, top1_ls, average='weighted')
    metrics['f1'] = f1
    
    perf = get_performance_dict(metrics)
    perf['val_loss'] = avg_val_loss
    
    return perf

print("Training and validation functions defined!")

## 11. Main Training Loop

The main training loop:
1. Trains for each epoch
2. Validates on validation set
3. Updates learning rate based on validation loss
4. Saves best model (lowest validation loss)
5. Early stopping if no improvement for patience epochs

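**Model Selection**: We use **validation loss** (not accuracy) to select the best model. Loss changes smoothly with the predicted probabilities, so it is usually a less noisy selection signal than top-1 accuracy, which only changes when the top prediction flips.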
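
The interaction between best-model tracking and early stopping can be traced on a short list of losses. This is a minimal sketch with hypothetical validation losses and a patience of 3 (the notebook uses 20); `early_stopping_trace` is an illustrative helper.

```python
def early_stopping_trace(val_losses, patience):
    """Return (best_epoch, stop_epoch) for a sequence of per-epoch validation losses."""
    best_loss, best_epoch, waited = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, waited = loss, epoch, 0
        else:
            waited += 1
        if waited >= patience:
            return best_epoch, epoch    # stopped early
    return best_epoch, len(val_losses)  # ran to completion

# Hypothetical validation losses; patience shortened to 3 for illustration.
best_epoch_demo, stop_epoch_demo = early_stopping_trace(
    [1.0, 0.8, 0.9, 0.85, 0.95, 0.7], patience=3
)
print(best_epoch_demo, stop_epoch_demo)  # 2 5
```

Training stops at epoch 5 and never sees the lower loss at epoch 6, which is the usual trade-off of a short patience.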

In [None]:
# Training state
best_val_loss = float('inf')
best_epoch = 0
best_model_state = None
epochs_without_improvement = 0
train_losses = []
val_losses = []

num_epochs = config['training']['num_epochs']
early_stop_patience = config['training']['early_stopping_patience']
grad_clip = config['training']['grad_clip']

print(f"Starting training on {device}")
print(f"Model parameters: {num_params:,}")
print(f"Using validation loss for model selection\n")
print("=" * 80)

start_time = time.time()

for epoch in range(1, num_epochs + 1):
    print(f"\n=== Epoch {epoch}/{num_epochs} ===")
    epoch_start = time.time()
    
    # Train
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, grad_clip, epoch)
    train_losses.append(train_loss)
    
    # Validate
    val_perf = validate(model, val_loader, criterion, device, split_name='Val')
    val_loss = val_perf['val_loss']
    val_losses.append(val_loss)
    
    # Learning rate scheduling
    scheduler.step(val_loss)
    
    # Display results
    print(f"Val - Loss: {val_loss:.4f} | Acc@1: {val_perf['acc@1']:.2f}% Acc@5: {val_perf['acc@5']:.2f}% Acc@10: {val_perf['acc@10']:.2f}% | F1: {100*val_perf['f1']:.2f}% MRR: {val_perf['mrr']:.2f}% NDCG: {val_perf['ndcg']:.2f}%")
    
    # Check for improvement
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        epochs_without_improvement = 0
        
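        # Save best model state. Clone each tensor: state_dict().copy() is a
        # shallow copy and would keep tracking the live weights as training continues.
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}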
        
        print(f"✓ New best model! Val Loss: {val_loss:.4f} (Acc@1: {val_perf['acc@1']:.2f}%)")
    else:
        epochs_without_improvement += 1
    
    epoch_time = time.time() - epoch_start
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} (best: {best_val_loss:.4f} @ epoch {best_epoch}) | Time: {epoch_time:.1f}s")
    print(f"Epochs without improvement: {epochs_without_improvement}/{early_stop_patience}")
    
    # Early stopping
    if epochs_without_improvement >= early_stop_patience:
        print(f"\nEarly stopping triggered after {epoch} epochs")
        break

total_time = time.time() - start_time

print("\n" + "=" * 80)
print("Training completed!")
print(f"Best validation loss: {best_val_loss:.4f} at epoch {best_epoch}")
print(f"Total training time: {total_time:.2f}s")
print("=" * 80)

## 12. Load Best Model and Evaluate on Test Set

After training completes, we:
1. Load the best model state (from epoch with lowest validation loss)
2. Evaluate on the held-out test set
3. Report comprehensive metrics

This gives us an unbiased estimate of model performance on unseen data.

In [None]:
# Load best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"Loaded best model from epoch {best_epoch} (Val Loss: {best_val_loss:.4f})")
else:
    print("Using final model state")

print("\n" + "=" * 80)
print("EVALUATION ON TEST SET")
print("=" * 80)

# Evaluate on test set
test_perf = validate(model, test_loader, criterion, device, split_name='Test')

print(f"\nTest - Loss: {test_perf['val_loss']:.4f} | Acc@1: {test_perf['acc@1']:.2f}% Acc@5: {test_perf['acc@5']:.2f}% Acc@10: {test_perf['acc@10']:.2f}% | F1: {100*test_perf['f1']:.2f}% MRR: {test_perf['mrr']:.2f}% NDCG: {test_perf['ndcg']:.2f}%")

## 13. Final Results Summary

Display all final metrics in a clear, organized format.

In [None]:
print("\n" + "=" * 80)
print("FINAL RESULTS")
print("=" * 80)
print(f"\nTraining Summary:")
print(f"  Best Validation Loss: {best_val_loss:.4f} (Epoch {best_epoch})")
print(f"  Total Epochs Trained: {epoch}")
print(f"  Training Time: {total_time:.2f}s")
print(f"  Model Parameters: {num_params:,}")

print(f"\nTest Set Performance:")
print(f"  Accuracy@1:  {test_perf['acc@1']:.2f}%")
print(f"  Accuracy@3:  {test_perf['acc@3']:.2f}%")
print(f"  Accuracy@5:  {test_perf['acc@5']:.2f}%")
print(f"  Accuracy@10: {test_perf['acc@10']:.2f}%")
print(f"  F1 Score:    {100 * test_perf['f1']:.2f}%")
print(f"  MRR:         {test_perf['mrr']:.2f}%")
print(f"  NDCG:        {test_perf['ndcg']:.2f}%")
print(f"  Test Loss:   {test_perf['val_loss']:.4f}")
print("=" * 80)

## 14. Training Visualization (Optional)

Visualize the training and validation loss curves over epochs.

In [None]:
try:
    import matplotlib.pyplot as plt
    
    plt.figure(figsize=(12, 4))
    
    # Loss curves
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss', linewidth=2)
    plt.plot(val_losses, label='Val Loss', linewidth=2)
    plt.axvline(x=best_epoch-1, color='r', linestyle='--', label=f'Best Epoch ({best_epoch})')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Validation loss only (zoomed)
    plt.subplot(1, 2, 2)
    plt.plot(val_losses, label='Val Loss', linewidth=2, color='orange')
    plt.axvline(x=best_epoch-1, color='r', linestyle='--', label=f'Best Epoch ({best_epoch})')
    plt.axhline(y=best_val_loss, color='g', linestyle=':', label=f'Best Val Loss ({best_val_loss:.4f})')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
except ImportError:
    print("Matplotlib not available for visualization")

## 15. Conclusion

This notebook demonstrated the complete training pipeline for the **HistoryCentricModel** from start to finish:

### What We Covered:

1. **Data Loading**: Loaded GeoLife trajectory sequences from pickle files
2. **Dataset Processing**: Handled variable-length sequences with custom collate function
3. **Model Architecture**: Implemented the compact HistoryCentricModel combining:
   - History-based scoring (recency + frequency)
   - Learned transformer patterns
   - Ensemble approach
4. **Training Loop**: 
   - Label smoothing loss
   - AdamW optimizer with weight decay
   - Learning rate scheduling
   - Early stopping
5. **Evaluation**: Comprehensive metrics on test set

### Key Results:
- The model achieves strong performance while staying under 500K parameters
- History-centric approach leverages the insight that most next locations are already in visit history
- Compact transformer learns temporal patterns efficiently

### Notes:
- This notebook is **fully self-contained** and can be run independently
- All code matches the production training script logic
- Results should be comparable to running `train_model.py` with the same configuration

---

**End of Notebook**