# Week 3 Day 13: Training Loop Details (Complete)

## Overview
In this notebook, we'll implement and explore key components of an efficient and stable training loop for language models, focusing on:
- Mixed precision training (AMP)
- Gradient clipping and accumulation
- AdamW optimizer with learning rate scheduling
- Monitoring, logging, and checkpointing

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
import math
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset, random_split
from torch.optim.lr_scheduler import LambdaLR
from typing import List, Dict, Tuple, Optional

# Seed the NumPy and PyTorch generators so repeated runs give the same numbers
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 1. Model and Data Definition

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    The encoding table is computed once in ``__init__`` and stored as a
    non-trainable buffer of shape ``(1, max_len, d_model)``. Even columns hold
    sines and odd columns hold cosines of the scaled position.

    Args:
        d_model: Size of the last (feature) dimension of the input.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the angles to the number of cosine slots actually available.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading axis of size 1 lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its positions, after dropout.

        Args:
            x: Tensor indexed as ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f'Sequence length {seq_len} exceeds the maximum supported length {self.pe.size(1)}'
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

def generate_square_subsequent_mask(sz):
    """Build an additive causal attention mask of shape ``(sz, sz)``.

    Entries on and below the main diagonal are ``0.0``; entries strictly above
    it are ``-inf``, so a position cannot attend to later positions.
    """
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float32)
    # Keep -inf only strictly above the diagonal; everything else becomes 0.
    return torch.triu(blocked, diagonal=1)

class SimpleLanguageModel(nn.Module):
    """Small causal Transformer language model built from ``nn.TransformerDecoder``.

    Token ids are embedded, scaled by ``sqrt(d_model)``, given positional
    encodings, passed through a stack of decoder layers and projected to
    vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids (embedding rows / output logits).
        d_model: Embedding and hidden size.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked decoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
        dropout: Dropout probability used by the encoding and the layers.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, src, src_mask=None):
        """Compute next-token logits for every position of ``src``.

        Args:
            src: Token ids indexed as ``(batch, seq_len)``.
            src_mask: Optional additive attention mask of shape
                ``(seq_len, seq_len)``. A causal mask is built when omitted.

        Returns:
            Logits with a last dimension of size ``vocab_size``.
        """
        src = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)
        src = self.pos_encoder(src)
        if src_mask is None:
            src_mask = generate_square_subsequent_mask(src.size(1)).to(src.device)
        # The input doubles as the decoder "memory" (decoder-only usage), so the
        # cross-attention must be masked as well; otherwise every position can
        # read later tokens through the memory and the next-token task leaks.
        output = self.transformer_decoder(src, src, tgt_mask=src_mask, memory_mask=src_mask)
        return self.output(output)

class SyntheticDataset(Dataset):
    """Dataset of uniformly random token sequences for next-token prediction.

    Each stored sample holds ``seq_len + 1`` token ids drawn from
    ``[0, vocab_size)``; indexing yields the sample without its last token as
    the input and the sample without its first token as the target.
    """

    def __init__(self, vocab_size=1000, seq_len=64, size=10000):
        # One extra token per sample so inputs and targets both have seq_len items.
        sample_shape = (seq_len + 1,)
        self.data = [
            torch.randint(0, vocab_size, sample_shape)
            for _ in range(size)
        ]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(inputs, targets)`` pair for sample ``idx``."""
        tokens = self.data[idx]
        inputs = tokens[:-1]
        targets = tokens[1:]  # inputs shifted left by one position
        return inputs, targets

# Build the synthetic corpus and split it 90% / 10% into training and validation sets
vocab_size = 1000
seq_len = 64
dataset = SyntheticDataset(vocab_size, seq_len, 5000)

num_train = int(0.9 * len(dataset))
num_val = len(dataset) - num_train
train_dataset, val_dataset = random_split(dataset, [num_train, num_val])

# Only the training batches are shuffled; validation order stays fixed
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32)

## 2. Learning Rate Scheduling

In [None]:
def get_cosine_warmup_lr_scheduler(optimizer, warmup_steps, total_steps):
    """Create a linear-warmup, cosine-decay learning-rate scheduler.

    The returned ``LambdaLR`` multiplies the optimizer's base learning rate by
    a factor that rises linearly from 0 to 1 over ``warmup_steps`` and then
    follows half a cosine wave down to 0 at ``total_steps``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        warmup_steps: Number of scheduler steps spent in linear warmup.
        total_steps: Step at which the factor reaches 0.

    Returns:
        A ``LambdaLR`` instance wrapping ``optimizer``.
    """
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        # Clamp so that stepping past total_steps keeps the factor at 0 rather
        # than letting the cosine swing the learning rate back up.
        progress = min(1.0, progress)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)

## 3. Complete Training Loop

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, criterion, use_amp=True):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` again. Gradients are disabled for the
    whole call.

    Args:
        model: Module mapping an input batch to logits whose last dimension is
            the number of classes.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(logits_2d, targets_1d)``.
        use_amp: Request CUDA autocast; ignored when the model is not on CUDA.

    Returns:
        The sum of per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    # Follow the model's own device instead of a module-level global so the
    # function works wherever the model currently lives.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None
    # CUDA autocast is only meaningful (and warning-free) on a CUDA device.
    amp_enabled = bool(use_amp) and model_device is not None and model_device.type == 'cuda'

    total_loss = 0.0
    num_batches = 0
    for x, y in dataloader:
        if model_device is not None:
            x, y = x.to(model_device), y.to(model_device)
        with autocast(enabled=amp_enabled):
            output = model(x)
            # Flatten to (tokens, classes); take the class count from the
            # logits themselves rather than from a global vocabulary size.
            loss = criterion(output.view(-1, output.size(-1)), y.view(-1))
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError('evaluate() received a dataloader with no batches')
    return total_loss / num_batches

def complete_training_loop(model, train_dataloader, val_dataloader, epochs=5, lr=0.001, warmup_ratio=0.1, max_grad_norm=1.0, accumulation_steps=4, use_amp=True):
    """Train ``model`` with AdamW, warmup/cosine LR, AMP, clipping and accumulation.

    The model is moved to the module-level ``device``. Gradients from
    ``accumulation_steps`` consecutive batches are accumulated before each
    optimizer step; a shorter group at the end of an epoch is also applied
    rather than discarded. After every epoch the model is evaluated on
    ``val_dataloader``.

    Args:
        model: Module mapping an input batch to logits whose last dimension is
            the number of classes.
        train_dataloader: Sized iterable of ``(inputs, targets)`` batches.
        val_dataloader: Iterable of validation batches passed to ``evaluate``.
        epochs: Number of passes over ``train_dataloader``.
        lr: Peak learning rate for AdamW.
        warmup_ratio: Fraction of all optimizer steps spent in linear warmup.
        max_grad_norm: Maximum gradient norm used for clipping.
        accumulation_steps: Number of batches accumulated per optimizer step.
        use_amp: Request mixed precision; only honoured when ``device`` is CUDA.

    Returns:
        ``(model, history)`` where ``history`` has the keys ``'train_loss'``
        and ``'val_loss'`` (one entry per epoch) and ``'lr'`` (one entry per
        optimizer step, read after the scheduler has advanced).

    Raises:
        ValueError: If ``accumulation_steps`` is below 1 or the training
            dataloader has no batches.
    """
    if accumulation_steps < 1:
        raise ValueError('accumulation_steps must be at least 1')
    num_batches = len(train_dataloader)
    if num_batches == 0:
        raise ValueError('train_dataloader must contain at least one batch')

    model.to(device)
    # CUDA autocast / GradScaler only apply on a CUDA device; disabling them
    # elsewhere avoids the "CUDA is not available" warnings.
    amp_enabled = bool(use_amp) and device.type == 'cuda'
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scaler = GradScaler(enabled=amp_enabled)
    criterion = nn.CrossEntropyLoss()

    # Number of batches in the short group at the end of each epoch (0 if none).
    tail_size = num_batches % accumulation_steps
    # Round up so the short trailing group counts as an optimizer step too.
    steps_per_epoch = math.ceil(num_batches / accumulation_steps)
    total_steps = steps_per_epoch * epochs
    warmup_steps = int(total_steps * warmup_ratio)
    scheduler = get_cosine_warmup_lr_scheduler(optimizer, warmup_steps, total_steps)

    history = {'train_loss': [], 'val_loss': [], 'lr': []}

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        optimizer.zero_grad()
        for batch_idx, (x, y) in enumerate(train_dataloader):
            x, y = x.to(device), y.to(device)
            with autocast(enabled=amp_enabled):
                output = model(x)
                batch_loss = criterion(output.view(-1, output.size(-1)), y.view(-1))
            epoch_loss += batch_loss.item()

            # Divide by the size of the group this batch belongs to so the
            # accumulated gradient is the mean over the group, including the
            # shorter group at the end of the epoch.
            in_tail = tail_size > 0 and batch_idx >= num_batches - tail_size
            group_size = tail_size if in_tail else accumulation_steps
            scaler.scale(batch_loss / group_size).backward()

            is_last_batch = batch_idx + 1 == num_batches
            if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitude.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
                history['lr'].append(scheduler.get_last_lr()[0])

        avg_train_loss = epoch_loss / num_batches
        val_loss = evaluate(model, val_dataloader, criterion, amp_enabled)
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(val_loss)
        print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}')

    return model, history

# Build a fresh model and run the full training loop on it
model = SimpleLanguageModel(vocab_size)
trained_model, history = complete_training_loop(model, train_dataloader, val_dataloader, epochs=5)

# Show the per-epoch losses and the per-step learning rate side by side
fig, (loss_ax, lr_ax) = plt.subplots(1, 2, figsize=(12, 5))

for history_key, curve_label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    loss_ax.plot(history[history_key], label=curve_label)
loss_ax.legend()
loss_ax.set_title('Loss Curves')

lr_ax.plot(history['lr'], label='Learning Rate')
lr_ax.set_title('Learning Rate Schedule')
lr_ax.legend()

fig.tight_layout()
plt.show()