# Week 3 Day 13: Training Loop Details - Part 2

## Overview
In this notebook, we'll continue exploring key components of an efficient training loop, focusing on:
- Learning rate scheduling
- Putting it all together in a complete training loop
- Monitoring and visualization
- Checkpointing and resuming training

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
import math
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset, random_split
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, OneCycleLR
from typing import List, Dict, Tuple, Optional

# Seed both the PyTorch and NumPy generators so repeated runs draw the same numbers
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Pick the compute device: CUDA when PyTorch reports one, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Learning Rate Scheduling

Let's implement different learning rate schedules commonly used for training language models.

In [None]:
def get_linear_warmup_lr_scheduler(optimizer, warmup_steps, total_steps):
    """Build a ``LambdaLR`` with a linear ramp-up and a linear ramp-down.

    The multiplier applied to the optimizer's base learning rate grows from
    0 to 1 across ``warmup_steps`` and then falls in a straight line, hitting
    0 at ``total_steps``.  It never goes below 0.
    """
    # Both spans are floored at 1 so neither phase can divide by zero.
    warmup_span = float(max(1, warmup_steps))
    decay_span = float(max(1, total_steps - warmup_steps))

    def lr_lambda(current_step):
        if current_step >= warmup_steps:
            # Decay phase: fraction of the post-warmup budget still remaining.
            steps_left = float(total_steps - current_step)
            return max(0.0, steps_left / decay_span)
        # Warmup phase: fraction of the warmup completed so far.
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda)

def get_cosine_warmup_lr_scheduler(optimizer, warmup_steps, total_steps):
    """Linear warmup followed by cosine decay.

    The returned scheduler multiplies the optimizer's base learning rate by a
    factor that rises linearly from 0 to 1 over ``warmup_steps`` and then
    follows half a cosine wave from 1 down to 0 at ``total_steps``.  Once
    ``total_steps`` is reached the factor stays at 0.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        warmup_steps: Number of scheduler steps spent warming up.
        total_steps: Step at which the decay reaches zero.

    Returns:
        A ``LambdaLR`` wrapping the schedule.
    """
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        # Clamp so that stepping past total_steps holds the factor at 0;
        # without this the cosine keeps oscillating and the learning rate
        # climbs back up after the schedule is supposed to be over.
        progress = min(1.0, progress)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return LambdaLR(optimizer, lr_lambda)

def plot_lr_schedule(scheduler, steps):
    """Plot the learning rate a scheduler produces over ``steps`` steps.

    The scheduler is advanced ``steps`` times as a side effect, so pass a
    scheduler created only for visualization.  Only the first parameter
    group's learning rate is plotted.

    Args:
        scheduler: A PyTorch learning rate scheduler.
        steps: Number of scheduler steps to record and plot.
    """
    lrs = []
    for _ in range(steps):
        # Record before stepping so that point i on the x-axis is the
        # learning rate in effect at step i, including the initial rate at
        # step 0 (recording after the step would drop it and shift the curve).
        lrs.append(scheduler.get_last_lr()[0])
        scheduler.step()
    
    plt.figure(figsize=(10, 5))
    plt.plot(lrs)
    plt.title('Learning Rate Schedule')
    plt.xlabel('Steps')
    plt.ylabel('Learning Rate')
    plt.grid(True)
    plt.show()

In [None]:
# Schedule length used by every example below
total_steps = 1000
warmup_steps = 100  # 10% warmup

# A throwaway model: it only exists to give the optimizers parameters to manage
dummy_model = nn.Linear(10, 10)

# Example 1: linear warmup, then linear decay
optimizer = torch.optim.AdamW(dummy_model.parameters(), lr=0.001)
linear_scheduler = get_linear_warmup_lr_scheduler(optimizer, warmup_steps, total_steps)
plot_lr_schedule(linear_scheduler, total_steps)

# Example 2: linear warmup, then cosine decay (fresh optimizer so the schedule starts over)
optimizer = torch.optim.AdamW(dummy_model.parameters(), lr=0.001)
cosine_scheduler = get_cosine_warmup_lr_scheduler(optimizer, warmup_steps, total_steps)
plot_lr_schedule(cosine_scheduler, total_steps)

# Example 3: PyTorch's built-in one-cycle policy with cosine annealing
optimizer = torch.optim.AdamW(dummy_model.parameters(), lr=0.001)
onecycle_scheduler = OneCycleLR(
    optimizer,
    max_lr=0.01,
    total_steps=total_steps,
    pct_start=0.3,
    anneal_strategy='cos',
)
plot_lr_schedule(onecycle_scheduler, total_steps)

## 2. Complete Training Loop

Now let's put everything together into a complete training loop with all the components we've discussed.

In [None]:
# First, let's define a simple model and dataset if running this notebook separately
# If you're running this after Part 1, you can skip these definitions

try:
    # Check if model and dataset are defined from Part 1.
    # Evaluating a bare name raises NameError when it has not been defined yet.
    SimpleLanguageModel
    train_dataloader
    val_dataloader
    vocab_size
except NameError:
    print("Defining model and dataset for standalone execution...")
    
    class SimpleLanguageModel(nn.Module):
        """Small causal language model: embedding -> Transformer encoder -> vocabulary logits.

        Args:
            vocab_size: Number of token ids; also the size of the output logits.
            d_model: Embedding / hidden width.
            nhead: Number of attention heads per encoder layer.
            num_layers: Number of stacked encoder layers.
        """

        def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, d_model)
            encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
            self.output = nn.Linear(d_model, vocab_size)
        
        def forward(self, x):
            """Return logits for token ids ``x``; the sequence axis is dim 1 (batch_first)."""
            # Causal mask: -inf strictly above the diagonal blocks attention to
            # later positions, 0 elsewhere. It is built directly on the input's
            # device to avoid a CPU allocation and a host-to-device copy on
            # every forward pass.
            length = x.size(1)
            mask = torch.triu(
                torch.ones(length, length, device=x.device) * float('-inf'), diagonal=1
            )
            
            # Forward pass
            x = self.embedding(x)
            x = self.transformer(x, mask=mask)
            x = self.output(x)
            return x
    
    # Create synthetic dataset
    vocab_size = 1000
    seq_len = 64
    dataset_size = 1000
    batch_size = 32
    
    class SyntheticDataset(Dataset):
        """Dataset of random token sequences, for exercising the training loop.

        Every access draws a fresh random ``(x, y)`` pair, so the same index
        does not return the same sample twice.

        Args:
            size: Number of samples the dataset reports.
            vocab_size: Exclusive upper bound for the random token ids.
            seq_len: Length of each generated sequence.
        """

        def __init__(self, size, vocab_size=vocab_size, seq_len=seq_len):
            self.size = size
            # Stored on the instance so the dataset no longer depends on
            # module-level globals at access time.
            self.vocab_size = vocab_size
            self.seq_len = seq_len
        
        def __len__(self):
            return self.size
        
        def __getitem__(self, idx):
            x = torch.randint(0, self.vocab_size, (self.seq_len,))
            y = torch.randint(0, self.vocab_size, (self.seq_len,))
            return x, y
    
    # Create dataloaders with an 80/20 train/validation split
    dataset = SyntheticDataset(dataset_size)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

In [None]:
class TrainingLogger:
    """Logger for tracking and visualizing training progress.

    Each call to ``log_step`` appends one training loss.  Validation losses
    are optional; when one is supplied, the index of the accompanying training
    record is remembered so the two curves can share one x-axis even when
    per-step and per-epoch records are mixed in the same logger.
    """
    
    def __init__(self):
        self.train_losses = []
        self.val_losses = []
        # Index into train_losses at which each validation loss was logged
        self.val_steps = []
        self.learning_rates = []
        self.grad_norms = []
        self.times = []
    
    def log_step(self, train_loss, val_loss=None, lr=None, grad_norm=None, time_taken=None):
        """Record one entry; every argument except ``train_loss`` is optional.

        Args:
            train_loss: Training loss for this record (always stored).
            val_loss: Validation loss, stored together with its position.
            lr: Learning rate in effect.
            grad_norm: Gradient norm for this record.
            time_taken: Elapsed time for this record.
        """
        self.train_losses.append(train_loss)
        if val_loss is not None:
            self.val_losses.append(val_loss)
            self.val_steps.append(len(self.train_losses) - 1)
        if lr is not None:
            self.learning_rates.append(lr)
        if grad_norm is not None:
            self.grad_norms.append(grad_norm)
        if time_taken is not None:
            self.times.append(time_taken)
    
    def plot_losses(self):
        """Plot training losses and, if present, validation losses on one axis."""
        plt.figure(figsize=(10, 5))
        plt.plot(self.train_losses, label='Train Loss')
        if self.val_losses:
            if len(self.val_steps) == len(self.val_losses):
                # Place each validation point at the record where it was
                # measured instead of squeezing the curve to the left edge.
                plt.plot(self.val_steps, self.val_losses, marker='o', label='Val Loss')
            else:
                # val_losses was modified outside log_step; positions unknown.
                plt.plot(self.val_losses, label='Val Loss')
        plt.title('Training and Validation Loss')
        # The axis counts epochs only when every record carried a validation loss.
        per_epoch_only = bool(self.val_losses) and len(self.val_losses) == len(self.train_losses)
        plt.xlabel('Epochs' if per_epoch_only else 'Steps')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        plt.show()
    
    def plot_learning_rates(self):
        """Plot the logged learning rates; does nothing if none were logged."""
        if not self.learning_rates:
            return
        
        plt.figure(figsize=(10, 5))
        plt.plot(self.learning_rates)
        plt.title('Learning Rate Schedule')
        plt.xlabel('Steps')
        plt.ylabel('Learning Rate')
        plt.grid(True)
        plt.show()
    
    def plot_grad_norms(self):
        """Plot the logged gradient norms; does nothing if none were logged."""
        if not self.grad_norms:
            return
        
        plt.figure(figsize=(10, 5))
        plt.plot(self.grad_norms)
        plt.title('Gradient Norm')
        plt.xlabel('Steps')
        plt.ylabel('Norm')
        plt.grid(True)
        plt.show()

In [None]:
def complete_training_loop(model, train_dataloader, val_dataloader, epochs=5, 
                           lr=0.001, warmup_ratio=0.1, max_grad_norm=1.0, 
                           accumulation_steps=4, use_amp=True):
    """Complete training loop with all components.

    Combines AdamW, linear-warmup/cosine-decay scheduling, gradient
    accumulation, gradient clipping, optional mixed precision, per-epoch
    validation and logging.  Uses the module-level ``device``, ``evaluate``,
    ``get_cosine_warmup_lr_scheduler`` and ``TrainingLogger``.

    Args:
        model: Model mapping a batch of inputs to logits whose last dimension
            is the number of classes.
        train_dataloader: Sized iterable of ``(x, y)`` training batches.
        val_dataloader: Iterable of ``(x, y)`` validation batches.
        epochs: Number of passes over ``train_dataloader``.
        lr: Peak learning rate for AdamW.
        warmup_ratio: Fraction of all optimizer steps spent in linear warmup.
        max_grad_norm: Gradient-norm clipping threshold.
        accumulation_steps: Micro-batches accumulated per optimizer step.
        use_amp: Request mixed precision; only honored when ``device`` is CUDA.

    Returns:
        Tuple ``(model, logger)`` with the trained model and the
        ``TrainingLogger`` holding the recorded curves.
    """
    model = model.to(device)
    
    # CUDA autocast and gradient scaling only apply on a CUDA device.
    amp_enabled = use_amp and device.type == 'cuda'
    
    # Create optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    
    # A disabled GradScaler is a pass-through (scale/unscale_/update become
    # no-ops and step() just calls optimizer.step()), so one code path serves
    # both the AMP and the full-precision case.
    scaler = GradScaler(enabled=amp_enabled)
    
    # Loss function
    criterion = nn.CrossEntropyLoss()
    
    # The last accumulation window of an epoch is shorter when the number of
    # batches is not a multiple of accumulation_steps.
    num_batches = len(train_dataloader)
    tail_size = num_batches % accumulation_steps
    tail_start = num_batches - tail_size
    
    # One optimizer step per window, including the partial tail window, so
    # round up; rounding down would let the scheduler run past total_steps.
    steps_per_epoch = (num_batches + accumulation_steps - 1) // accumulation_steps
    total_steps = steps_per_epoch * epochs
    warmup_steps = int(total_steps * warmup_ratio)
    
    # Create scheduler
    scheduler = get_cosine_warmup_lr_scheduler(optimizer, warmup_steps, total_steps)
    
    # Create logger
    logger = TrainingLogger()
    
    # Training loop
    global_step = 0
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        start_time = time.time()
        optimizer.zero_grad()
        
        for batch_idx, (x, y) in enumerate(train_dataloader):
            x, y = x.to(device), y.to(device)
            
            # Number of micro-batches in the window this batch belongs to
            in_tail = tail_size > 0 and batch_idx >= tail_start
            window_size = tail_size if in_tail else accumulation_steps
            
            # Forward pass (autocast is a no-op when AMP is disabled)
            with autocast(enabled=amp_enabled):
                output = model(x)
                # Class count is read from the logits rather than a global
                batch_loss = criterion(output.view(-1, output.size(-1)), y.view(-1))
            
            # Divide by the real window size so the accumulated gradient is
            # the mean over the window, also for the shorter tail window.
            scaler.scale(batch_loss / window_size).backward()
            batch_loss_value = batch_loss.item()
            
            # Update weights at the end of each window or at the end of the epoch
            is_window_end = (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == num_batches
            if is_window_end:
                # Gradients must be unscaled before their norm is measured/clipped
                scaler.unscale_(optimizer)
                
                # clip_grad_norm_ returns the total norm it measured before
                # clipping, so no separate norm computation is needed.
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                
                # Step optimizer and update the loss scale
                scaler.step(optimizer)
                scaler.update()
                
                # Step scheduler
                scheduler.step()
                
                # Log learning rate and gradient norm
                current_lr = scheduler.get_last_lr()[0]
                logger.log_step(batch_loss_value, lr=current_lr, grad_norm=float(grad_norm))
                
                # Reset gradients
                optimizer.zero_grad()
                
                # Increment global step
                global_step += 1
            
            total_loss += batch_loss_value
            
            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{num_batches}, "
                      f"Loss: {batch_loss_value:.4f}, "
                      f"LR: {scheduler.get_last_lr()[0]:.6f}")
        
        # Validation
        val_loss = evaluate(model, val_dataloader, criterion, amp_enabled)
        
        # Print epoch stats
        avg_loss = total_loss / num_batches
        elapsed = time.time() - start_time
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Time: {elapsed:.2f}s")
        
        # Log epoch results
        logger.log_step(avg_loss, val_loss=val_loss, time_taken=elapsed)
    
    # Plot training progress
    logger.plot_losses()
    logger.plot_learning_rates()
    logger.plot_grad_norms()
    
    return model, logger

def evaluate(model, dataloader, criterion, use_amp=True):
    """Evaluate model on dataloader.

    Puts the model in eval mode (it is not switched back afterwards) and
    averages the per-batch loss with gradients disabled.

    Args:
        model: Model mapping a batch of inputs to logits whose last dimension
            is the number of classes.
        dataloader: Sized iterable of ``(x, y)`` batches.
        criterion: Loss taking ``(logits, targets)``.
        use_amp: Request autocast; only honored when the model is on CUDA.

    Returns:
        Mean of the per-batch losses as a float.
    """
    model.eval()
    
    # Send batches to wherever the model's parameters are; fall back to the
    # module-level device for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    amp_enabled = use_amp and target_device.type == 'cuda'
    
    total_loss = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(target_device), y.to(target_device)
            
            # autocast is a no-op when AMP is disabled
            with autocast(enabled=amp_enabled):
                output = model(x)
                # Class count is read from the logits rather than a global
                loss = criterion(output.view(-1, output.size(-1)), y.view(-1))
            
            total_loss += loss.item()
    
    return total_loss / len(dataloader)

## 3. Training the Model

Let's train our model using the complete training loop.

In [None]:
# Instantiate the language model and report how many parameters it has
model = SimpleLanguageModel(vocab_size)
param_count = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {param_count:,}")

# Train model (uncomment to run - may take some time)
# trained_model, logger = complete_training_loop(
#     model, train_dataloader, val_dataloader,
#     epochs=3,
#     lr=0.0005,
#     warmup_ratio=0.1,
#     max_grad_norm=1.0,
#     accumulation_steps=2,
#     use_amp=True
# )

## 4. Checkpointing and Resuming Training

Let's implement checkpointing and resuming training, which is essential for long training runs.

In [None]:
def save_checkpoint(model, optimizer, scheduler, scaler, epoch, global_step, loss, path):
    """Write the full training state to ``path`` with ``torch.save``.

    The checkpoint is a dict holding the model and optimizer state dicts, the
    scheduler and scaler state dicts (``None`` for whichever of the two was not
    supplied), and the ``epoch``, ``global_step`` and ``loss`` bookkeeping
    values.
    """
    # Scheduler and scaler are optional; a missing one is stored as None.
    scheduler_state = scheduler.state_dict() if scheduler else None
    scaler_state = scaler.state_dict() if scaler else None

    state = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler_state,
        scaler_state_dict=scaler_state,
        epoch=epoch,
        global_step=global_step,
        loss=loss,
    )
    torch.save(state, path)
    print(f"Checkpoint saved to {path}")

def load_checkpoint(model, optimizer, scheduler, scaler, path):
    """Load training checkpoint.

    Restores the model and optimizer in place, plus the scheduler and scaler
    when both the object is supplied and the checkpoint holds a non-empty
    state for it.

    Args:
        model: Model to restore.
        optimizer: Optimizer to restore.
        scheduler: Learning rate scheduler to restore, or ``None``.
        scaler: Gradient scaler to restore, or ``None``.
        path: Checkpoint location, as accepted by ``torch.load``.

    Returns:
        Tuple ``(epoch, global_step, loss)`` as stored in the checkpoint.
    """
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can
    # still be opened where that device is unavailable; the load_state_dict
    # calls below then copy the values into the existing objects.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Optional entries: tolerate checkpoints that lack the key entirely.
    scheduler_state = checkpoint.get('scheduler_state_dict')
    if scheduler is not None and scheduler_state:
        scheduler.load_state_dict(scheduler_state)
    scaler_state = checkpoint.get('scaler_state_dict')
    if scaler is not None and scaler_state:
        scaler.load_state_dict(scaler_state)

    epoch = checkpoint['epoch']
    global_step = checkpoint['global_step']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded from {path}, resuming from epoch {epoch+1}, step {global_step}")
    return epoch, global_step, loss

## 5. Summary and Key Insights

In this notebook, we've explored and implemented key components of an efficient and stable training loop for language models:

1. **Learning Rate Scheduling**:
   - Linear warmup + linear decay
   - Linear warmup + cosine decay
   - One-cycle policy

2. **Complete Training Loop**:
   - Mixed precision training
   - Gradient clipping and accumulation
   - AdamW optimizer with weight decay
   - Learning rate scheduling
   - Monitoring and logging

3. **Checkpointing**:
   - Saving model, optimizer, scheduler, and scaler states
   - Resuming training from checkpoints

These techniques are essential for training large language models efficiently and stably, especially when working with limited computational resources or when training needs to be resumed after interruptions.