# Week 3 Day 13: Training Loop Details - Part 1

## Overview
In this notebook, we'll implement and explore key components of an efficient and stable training loop for language models, focusing on:
- Mixed precision training (AMP)
- Gradient clipping and accumulation
- AdamW optimizer configuration
- Learning rate scheduling

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
import math
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset, random_split
from typing import List, Dict, Tuple, Optional

# Fix the NumPy and PyTorch RNG seeds so repeated runs draw the same numbers
np.random.seed(42)
torch.manual_seed(42)

# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Simple Language Model

Let's start by defining a simple language model to use for our training loop examples.

In [None]:
class SimpleLanguageModel(nn.Module):
    """A simple decoder-only (causal) transformer language model for demonstration.

    Token ids are embedded, scaled by sqrt(d_model), combined with positional
    encodings, run through a stack of causally masked self-attention layers
    and projected to per-token vocabulary logits.

    Args:
        vocab_size: Number of embedding rows and width of the output logits.
        d_model: Embedding / hidden size of the transformer layers.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked transformer layers.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
        dropout: Dropout probability used by the positional encoding and the layers.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4, dim_feedforward=1024, dropout=0.1):
        super().__init__()

        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # A decoder-only model has no encoder memory to attend to.
        # nn.TransformerDecoderLayer always runs cross-attention over `memory`,
        # so it cannot be called with memory=None. Self-attention-only layers
        # combined with a causal mask are the decoder-only equivalent.
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )

        # Attribute name kept as `transformer_decoder` for existing references
        self.transformer_decoder = nn.TransformerEncoder(layer, num_layers=num_layers)

        # Output projection
        self.output = nn.Linear(d_model, vocab_size)

        # Initialize weights
        self.init_weights()

    def init_weights(self):
        """Draw embedding and output weights from U(-0.1, 0.1) and zero the output bias."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.output.bias.data.zero_()
        self.output.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask=None):
        """Compute next-token logits for a batch of token ids.

        Args:
            src: Token ids; the layers are batch-first, and dimension 1 is
                used as the sequence length when building the default mask.
            src_mask: Optional attention mask passed to the transformer stack.
                When None, a causal mask for the sequence length is generated.

        Returns:
            The output of the final linear projection (last dimension is
            vocab_size).
        """
        # Embed tokens and scale by sqrt(d_model)
        src = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)

        # Add positional encoding
        src = self.pos_encoder(src)

        # Create causal mask if not provided
        if src_mask is None:
            src_mask = generate_square_subsequent_mask(src.size(1)).to(src.device)

        # Causally masked self-attention over the sequence
        output = self.transformer_decoder(src, mask=src_mask)

        # Project to vocabulary
        return self.output(output)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first input, then apply dropout.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Both even
            and odd values are supported.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions that are precomputed; inputs whose
            dimension 1 exceeds this cannot be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        # Sine on even feature indices, cosine on odd feature indices.
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine source is trimmed to d_model // 2 columns.
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Add batch dimension and register as buffer (not a parameter)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return dropout(x + encoding), using the first x.size(1) positions."""
        # Add positional encoding to input
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

def generate_square_subsequent_mask(sz):
    """Build an (sz, sz) additive causal mask.

    Entries on and below the main diagonal are 0.0; entries strictly above
    it are -inf.
    """
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float)
    # Keep -inf only strictly above the diagonal; triu zeroes everything else
    return torch.triu(blocked, diagonal=1)

## 2. Synthetic Dataset

Let's create a synthetic dataset for our training examples.

In [None]:
class SyntheticDataset(Dataset):
    """Random token sequences with a short injected pattern, for language modeling.

    Each stored sequence has seq_len + 1 tokens so that an item can be split
    into an input and a target shifted by one position.
    """

    def __init__(self, vocab_size=1000, seq_len=64, size=10000):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.size = size

        # One randomly generated, pattern-carrying sequence per example
        self.data = [self._make_sequence() for _ in range(size)]

    def _make_sequence(self):
        """Draw one random sequence and overwrite a short span with a pattern."""
        tokens = torch.randint(0, self.vocab_size, (self.seq_len + 1,))

        # Where the pattern begins and which of the three kinds it is
        begin = torch.randint(0, self.seq_len - 10, (1,)).item()
        kind = torch.randint(0, 3, (1,)).item()

        if kind == 0:
            # Five copies of a single token
            repeated = torch.randint(0, self.vocab_size, (1,)).item()
            tokens[begin:begin + 5] = repeated
        elif kind == 1:
            # Five consecutive token ids
            first = torch.randint(0, self.vocab_size - 5, (1,)).item()
            tokens[begin:begin + 5] = torch.arange(first, first + 5)
        else:
            # Six tokens alternating between two ids
            even_token = torch.randint(0, self.vocab_size, (1,)).item()
            odd_token = torch.randint(0, self.vocab_size, (1,)).item()
            tokens[begin:begin + 6:2] = even_token
            tokens[begin + 1:begin + 6:2] = odd_token

        return tokens

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return (input, target): the sequence without its last / first token."""
        tokens = self.data[idx]
        return tokens[:-1], tokens[1:]

# Data hyperparameters
vocab_size = 1000
seq_len = 64
dataset_size = 5000
batch_size = 32

# Build the synthetic corpus and hold out the part not used for training (90/10 split)
dataset = SyntheticDataset(vocab_size, seq_len, dataset_size)
val_size = len(dataset) - int(0.9 * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

# Batch the two splits; only the training loader shuffles
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size)

# Pull one training batch to confirm the tensor shapes
x_sample, y_sample = next(iter(train_dataloader))
print(f"Input shape: {x_sample.shape}")
print(f"Target shape: {y_sample.shape}")

## 3. Mixed Precision Training (AMP)

Let's implement mixed precision training using PyTorch's Automatic Mixed Precision (AMP).

In [None]:
def train_with_amp(model, train_dataloader, val_dataloader, epochs=5, lr=0.001):
    """Train a model with AdamW and automatic mixed precision (AMP).

    The model is moved to the module-level `device`. Every epoch runs one pass
    over `train_dataloader`, then calls `evaluate` on `val_dataloader`, and
    prints the mean training loss, the validation loss and the elapsed time.

    Args:
        model: Module mapping a batch of token ids to logits whose last
            dimension is the number of classes.
        train_dataloader: Sized iterable yielding (input, target) batches.
        val_dataloader: Iterable of (input, target) batches for validation.
        epochs: Number of passes over the training data.
        lr: AdamW learning rate.

    Returns:
        The trained model (on `device`).
    """
    model = model.to(device)

    # Mixed precision is only turned on when running on CUDA; otherwise
    # autocast and the scaler are explicitly disabled.
    use_amp = device.type == 'cuda'

    # Create optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Create gradient scaler for AMP
    scaler = GradScaler(enabled=use_amp)

    # Loss function
    criterion = nn.CrossEntropyLoss()

    num_batches = len(train_dataloader)

    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        start_time = time.time()

        for batch_idx, (x, y) in enumerate(train_dataloader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            # Forward pass with autocast (mixed precision)
            with autocast(enabled=use_amp):
                output = model(x)
                # Flatten using the logits' own class dimension rather than a
                # global vocab size, so any model/vocabulary pairing works.
                loss = criterion(output.reshape(-1, output.size(-1)), y.reshape(-1))

            # Backward pass with gradient scaling
            scaler.scale(loss).backward()

            # Update weights with gradient scaling
            scaler.step(optimizer)
            scaler.update()

            # Read the scalar once; each .item() call copies from the device
            batch_loss = loss.item()
            total_loss += batch_loss

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{num_batches}, "
                      f"Loss: {batch_loss:.4f}")

        # Validation
        val_loss = evaluate(model, val_dataloader, criterion)

        # Print epoch stats (max() avoids dividing by zero on an empty loader)
        avg_loss = total_loss / max(num_batches, 1)
        elapsed = time.time() - start_time
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Time: {elapsed:.2f}s")

    return model

def evaluate(model, dataloader, criterion):
    """Return the mean per-batch loss of `model` over `dataloader`.

    The model is switched to eval mode and is left in that mode; gradients
    are not tracked during the pass.

    Args:
        model: Module mapping a batch of inputs to logits whose last
            dimension is the number of classes.
        dataloader: Sized iterable yielding (input, target) batches.
        criterion: Loss called as criterion(flat_logits, flat_targets).

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If `dataloader` has no batches.
    """
    model.eval()
    total_loss = 0.0

    # Mixed precision is only turned on when running on CUDA
    use_amp = device.type == 'cuda'

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)

            # Forward pass with autocast
            with autocast(enabled=use_amp):
                output = model(x)
                # Use the logits' own class dimension instead of a global vocab size
                loss = criterion(output.reshape(-1, output.size(-1)), y.reshape(-1))

            total_loss += loss.item()

    return total_loss / len(dataloader)

## 4. Gradient Clipping and Accumulation

Now let's implement gradient clipping and accumulation.

In [None]:
def train_with_grad_clip_accum(model, train_dataloader, val_dataloader, epochs=5, lr=0.001, 
                              max_grad_norm=1.0, accumulation_steps=4):
    """Train with AMP, gradient accumulation and gradient-norm clipping.

    Gradients from `accumulation_steps` consecutive batches are accumulated
    before each optimizer step. The last group of an epoch may hold fewer
    batches; its losses are divided by that group's actual size so the
    accumulated gradient is still a mean over the group. Before each step the
    gradients are unscaled and clipped to `max_grad_norm`.

    Args:
        model: Module mapping a batch of token ids to logits whose last
            dimension is the number of classes.
        train_dataloader: Sized iterable yielding (input, target) batches.
        val_dataloader: Iterable of (input, target) batches for validation.
        epochs: Number of passes over the training data.
        lr: AdamW learning rate.
        max_grad_norm: Maximum gradient norm passed to clip_grad_norm_.
        accumulation_steps: Batches accumulated per optimizer step (>= 1).

    Returns:
        The trained model (on the module-level `device`).

    Raises:
        ValueError: If `accumulation_steps` is less than 1.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    model = model.to(device)

    # Mixed precision is only turned on when running on CUDA
    use_amp = device.type == 'cuda'

    # Create optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Create gradient scaler for AMP
    scaler = GradScaler(enabled=use_amp)

    # Loss function
    criterion = nn.CrossEntropyLoss()

    # Effective batch size, derived from the loader actually passed in
    loader_batch_size = getattr(train_dataloader, 'batch_size', None)
    if loader_batch_size is not None:
        effective_batch_size = loader_batch_size * accumulation_steps
        print(f"Effective batch size: {effective_batch_size}")

    num_batches = len(train_dataloader)
    # Size and first batch index of the trailing incomplete group.
    # tail_size is 0 when num_batches divides evenly; tail_start then equals
    # num_batches, so no batch index ever reaches it.
    tail_size = num_batches % accumulation_steps
    tail_start = num_batches - tail_size

    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        start_time = time.time()
        optimizer.zero_grad()

        for batch_idx, (x, y) in enumerate(train_dataloader):
            x, y = x.to(device), y.to(device)

            # Number of batches contributing to the group this batch is in
            group_size = tail_size if batch_idx >= tail_start else accumulation_steps

            # Forward pass with autocast
            with autocast(enabled=use_amp):
                output = model(x)
                # Use the logits' own class dimension instead of a global vocab size
                loss = criterion(output.reshape(-1, output.size(-1)), y.reshape(-1))

            # Unnormalized loss for logging; read once to avoid repeated syncs
            batch_loss = loss.item()

            # Backward pass with gradient scaling, normalized by the group size
            scaler.scale(loss / group_size).backward()

            # Update weights after accumulation steps (or at the end of the epoch)
            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == num_batches:
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitude
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                # Step optimizer and update scaler
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            total_loss += batch_loss

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{num_batches}, "
                      f"Loss: {batch_loss:.4f}")

        # Validation
        val_loss = evaluate(model, val_dataloader, criterion)

        # Print epoch stats (max() avoids dividing by zero on an empty loader)
        avg_loss = total_loss / max(num_batches, 1)
        elapsed = time.time() - start_time
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Time: {elapsed:.2f}s")

    return model