# Training Workflows & Experiment Structure

You now have data and models. This notebook orchestrates them into robust training and evaluation loops so experiments are reproducible and debuggable.

## Learning Objectives

- Implement clean training and validation loops with logging hooks.
- Monitor gradients, losses, and learning rates.
- Integrate gradient clipping, schedulers, and early stopping.
- Build patterns that extend to large-scale and attention-heavy systems later.

## Anatomy of a Training Step

1. Fetch a batch from the dataloader.
2. Run the forward pass.
3. Compute the loss.
4. Backpropagate gradients.
5. Update parameters and reset gradients.

Instrumentation (logging gradients, losses, learning rate) wraps these steps without cluttering the loop.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Seed PyTorch's global RNG first so the noise and weight initialisation below are repeatable.
torch.manual_seed(0)

# 256 evenly spaced inputs in [-3, 3]; unsqueeze(1) adds a feature axis -> shape (256, 1).
x = torch.linspace(-3, 3, steps=256).unsqueeze(1)
# Targets are sin(x) plus standard-normal noise scaled by 0.1.
y = torch.sin(x) + 0.1 * torch.randn_like(x)
train_ds = TensorDataset(x, y)
# Mini-batches of 32 samples, drawn in shuffled order.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)

# Small MLP: 1 input -> 32 hidden units (ReLU) -> 1 output.
model = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Mean squared error between predictions and targets.
criterion = nn.MSELoss()

def train_one_epoch(model, loader, optimizer, criterion):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Args:
        model: Module to optimise; it is switched to training mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Callable ``criterion(preds, targets)`` returning a scalar
            loss. The result is weighted by batch size, which assumes the
            criterion averages over the batch (the ``nn.MSELoss()`` default).

    Returns:
        The loss averaged over the samples actually seen, or ``nan`` when
        the loader yields no samples.
    """
    model.train()
    # Drop gradients left behind by earlier backward() calls so they cannot
    # leak into the first parameter update.
    optimizer.zero_grad()
    total_loss = 0.0
    seen = 0
    for xb, yb in loader:
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        batch_size = xb.size(0)
        total_loss += loss.item() * batch_size
        # Count samples as they arrive: len(loader.dataset) is wrong when the
        # loader drops or subsamples data, and unavailable for unsized datasets.
        seen += batch_size
    return total_loss / seen if seen else float("nan")

# Run a single training epoch on the sine data and report its average loss.
epoch_loss = train_one_epoch(model, train_loader, optimizer, criterion)
print(f"Epoch loss: {epoch_loss:.4f}")


### Gradient Monitoring

Exploding or vanishing gradients quietly sabotage experiments. Logging gradient norms helps you react early.

In [None]:
def gradient_norm(model: nn.Module):
    """Return the global L2 norm of the gradients currently stored on ``model``.

    Each parameter's gradient contributes the square of its own L2 norm; the
    result is the square root of that sum. Parameters whose ``grad`` is
    ``None`` are ignored, so a model with no gradients yields ``0.0``.
    """
    squared_sum = 0.0
    for param in model.parameters():
        grad = param.grad
        if grad is not None:
            squared_sum += grad.detach().norm(2).item() ** 2
    return squared_sum ** 0.5

# Pull one batch and run a forward/backward pass by hand so gradients exist to measure.
xb, yb = next(iter(train_loader))
loss = criterion(model(xb), yb)
loss.backward()
print(f"Gradient norm: {gradient_norm(model):.4f}")
# Clear the gradients just computed so they do not accumulate into later backward() calls.
optimizer.zero_grad()


## Mini Task – Validation Loop

Implement an evaluation helper that disables gradients, computes the loss and mean absolute error, and returns both metrics.

Attempt the starter cell before expanding the solution.

In [None]:
def evaluate(model, loader, criterion):
    """Starter stub for the validation-loop mini task.

    As written it only switches ``model`` to eval mode and then raises
    ``NotImplementedError``; replace the ``raise`` with your implementation.
    """
    model.eval()
    # TODO: accumulate loss and MAE without gradients
    raise NotImplementedError


In [None]:
def evaluate(model, loader, criterion):
    """Compute the mean loss and mean absolute error of ``model`` on ``loader``.

    The model runs in eval mode with gradient tracking disabled, and is put
    back into whichever mode (train or eval) it was in before the call.

    Args:
        model: Module to evaluate.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Callable ``criterion(preds, targets)`` returning a scalar
            loss. The result is weighted by batch size, which assumes the
            criterion averages over the batch (the ``nn.MSELoss()`` default).

    Returns:
        A ``(mean_loss, mean_abs_error)`` tuple averaged over the samples
        actually seen; both values are ``nan`` when the loader is empty.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_mae = 0.0
    seen = 0
    try:
        with torch.no_grad():
            for xb, yb in loader:
                preds = model(xb)
                batch_size = xb.size(0)
                total_loss += criterion(preds, yb).item() * batch_size
                total_mae += torch.mean(torch.abs(preds - yb)).item() * batch_size
                # Count samples as they arrive rather than trusting
                # len(loader.dataset), which is wrong for partial loaders.
                seen += batch_size
    finally:
        # Restore the caller's mode instead of forcing train mode, even if
        # evaluation raised part-way through.
        model.train(was_training)
    if seen == 0:
        return float("nan"), float("nan")
    return total_loss / seen, total_mae / seen

# The training loader is reused here, so these numbers measure fit on the training data.
val_loss, val_mae = evaluate(model, train_loader, criterion)
print(f"Val loss: {val_loss:.4f}, Val MAE: {val_mae:.4f}")


## Checkpointing & Scheduling

- Save weights **and** optimizer state (plus the scheduler's state, if you use one) so you can resume without losing optimizer statistics or your place in the learning-rate schedule.
- Pair AdamW with cosine or polynomial schedulers for large models.
- Log learning rates to detect unexpected schedule jumps.

In [None]:
from pathlib import Path

# Bundle the model weights and optimizer state with some bookkeeping so the
# run can be resumed later.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": 1,
    "notes": "sine regression warm-up",
}
checkpoint_path = Path("notebooks/beginner/checkpoint_sine.pt")
# torch.save does not create missing directories, so make sure the target
# folder exists before writing.
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(checkpoint, checkpoint_path)
print("Checkpoint saved")


## Comprehensive Exercise – Experiment Harness

Create an `Experiment` class that manages training, validation, gradient clipping, an optional learning-rate scheduler, early stopping, and simple logging. Demonstrate usage on the sine regression task.

In [None]:
class Experiment:
    """Starter skeleton for the experiment-harness exercise; both methods raise until implemented."""

    def __init__(self, model, optimizer, train_loader, val_loader, criterion, grad_clip=None, scheduler=None):
        """Placeholder constructor: currently raises ``NotImplementedError``."""
        # TODO: store components and initialize history
        raise NotImplementedError

    def train(self, epochs):
        """Placeholder training entry point: currently raises ``NotImplementedError``."""
        # TODO: implement training with validation metrics and optional early stopping
        raise NotImplementedError


In [None]:
class Experiment:
    """Small training harness with validation, gradient clipping and early stopping.

    Per-epoch average losses are appended to ``history["train"]`` and
    ``history["val"]``.
    """

    def __init__(self, model, optimizer, train_loader, val_loader, criterion, grad_clip=None, scheduler=None):
        """Store the experiment components and start an empty loss history.

        Args:
            model: Module to train.
            optimizer: Optimizer that updates ``model``'s parameters.
            train_loader: Iterable of ``(inputs, targets)`` training batches.
            val_loader: Iterable of ``(inputs, targets)`` validation batches.
            criterion: Callable ``criterion(preds, targets)`` returning a
                scalar loss; assumed to average over the batch.
            grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``,
                or ``None`` to disable clipping.
            scheduler: Optional learning-rate scheduler stepped once per epoch.
        """
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.grad_clip = grad_clip
        self.scheduler = scheduler
        self.history = {"train": [], "val": []}

    def _train_one_epoch(self):
        """Run one optimisation pass over ``train_loader``; return the mean loss.

        The mean is taken over the samples actually seen and is ``nan`` when
        the loader yields nothing.
        """
        self.model.train()
        # Drop gradients left behind by earlier backward() calls so they
        # cannot leak into the first parameter update.
        self.optimizer.zero_grad()
        total = 0.0
        seen = 0
        for xb, yb in self.train_loader:
            preds = self.model(xb)
            loss = self.criterion(preds, yb)
            loss.backward()
            if self.grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.optimizer.step()
            self.optimizer.zero_grad()
            batch_size = xb.size(0)
            total += loss.item() * batch_size
            # Count samples as they arrive: len(loader.dataset) is wrong when
            # the loader drops or subsamples data.
            seen += batch_size
        return total / seen if seen else float("nan")

    def _evaluate(self):
        """Return the mean loss on ``val_loader`` without tracking gradients.

        The model is put back into the mode it was in before the call. The
        result is ``nan`` when the loader yields nothing.
        """
        was_training = self.model.training
        self.model.eval()
        total = 0.0
        seen = 0
        try:
            with torch.no_grad():
                for xb, yb in self.val_loader:
                    batch_size = xb.size(0)
                    total += self.criterion(self.model(xb), yb).item() * batch_size
                    seen += batch_size
        finally:
            self.model.train(was_training)
        return total / seen if seen else float("nan")

    def _step_scheduler(self, val_loss):
        """Advance the scheduler, if any, by one epoch."""
        if self.scheduler is None:
            return
        # ReduceLROnPlateau needs the monitored metric; other schedulers
        # take no argument.
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()

    def train(self, epochs, patience=3, min_delta=1e-4):
        """Train for up to ``epochs`` epochs with early stopping on validation loss.

        Args:
            epochs: Maximum number of epochs to run.
            patience: Number of consecutive epochs without improvement
                after which training stops.
            min_delta: Minimum decrease in validation loss that counts as
                an improvement.

        Returns:
            The ``history`` dict with per-epoch ``"train"`` and ``"val"`` losses.
        """
        best_val = float("inf")
        patience_counter = 0
        for epoch in range(epochs):
            train_loss = self._train_one_epoch()
            val_loss = self._evaluate()
            self.history["train"].append(train_loss)
            self.history["val"].append(val_loss)
            self._step_scheduler(val_loss)
            print(f"Epoch {epoch+1}: train={train_loss:.4f}, val={val_loss:.4f}")
            if val_loss < best_val - min_delta:
                best_val = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print("Early stopping triggered")
                    break
        return self.history

# train_loader is passed as both the training and the validation loader, so the
# reported "val" loss is measured on the training data.
experiment = Experiment(model, optimizer, train_loader, train_loader, criterion, grad_clip=1.0)
experiment.train(epochs=5)


## Further Reading

- PyTorch Training Loop Recipes: https://pytorch.org/tutorials/recipes/recipes.html
- TorchMetrics, TensorBoard, Weights & Biases for richer experiment tracking
- “A Recipe for Training Neural Networks” by Andrej Karpathy