# Part 5: The Standard Training Loop 🔄

We've written the same 5 lines of training code twice now. It's time to organize it.

In this notebook, we'll build a **robust, reusable training loop**.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline

## 1. Setup Data (Train/Val Split)

In practice, we split the data into a **Training** set (used to fit the model's parameters) and a **Validation** set (held out to measure how well the model does on data it has not seen).

Let's use our quadratic data again.

In [None]:
# Generate data
X = torch.linspace(-5, 5, 200).reshape(-1, 1)
y = X ** 2 + 1 + torch.randn(200, 1) * 2

# Split (80% Train, 20% Validation)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Make sure everything is a float32 tensor. Depending on the sklearn version, the
# split may hand back tensors or NumPy arrays; torch.as_tensor accepts both.
X_train = torch.as_tensor(X_train, dtype=torch.float32)
X_val = torch.as_tensor(X_val, dtype=torch.float32)
y_train = torch.as_tensor(y_train, dtype=torch.float32)
y_val = torch.as_tensor(y_val, dtype=torch.float32)

print(f"Train Shape: {X_train.shape}, Val Shape: {X_val.shape}")
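If you would rather not depend on sklearn for the split, the same 80/20 shuffle can be sketched in plain PyTorch. This is only one possible approach, not the method used in the rest of the notebook; the names `perm`, `train_idx`, and `val_idx` are introduced here for illustration.

```python
import torch

torch.manual_seed(42)  # fixed seed so the shuffle is reproducible

# Same synthetic quadratic data as above
X = torch.linspace(-5, 5, 200).reshape(-1, 1)
y = X ** 2 + 1 + torch.randn(200, 1) * 2

# Shuffle the row indices, then cut at the 80% mark
n_train = int(0.8 * len(X))
perm = torch.randperm(len(X))
train_idx, val_idx = perm[:n_train], perm[n_train:]

# Index X and y with the same indices so each input keeps its target
X_train, y_train = X[train_idx], y[train_idx]
X_val, y_val = X[val_idx], y[val_idx]

print(f"Train Shape: {X_train.shape}, Val Shape: {X_val.shape}")
```

Because the result is already a tensor, no conversion step is needed afterwards.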

## 2. Define Function for One Training Step

Let's wrap the "Forward -> Loss -> Backward -> Step" logic into a function.

In [None]:
def train_step(model, X_batch, y_batch, loss_fn, optimizer):
    # 1. Set model to training mode (important for Dropout/BatchNorm)
    model.train()
    
    # 2. Forward pass
    predictions = model(X_batch)
    
    # 3. Compute loss
    loss = loss_fn(predictions, y_batch)
    
    # 4. Clear old gradients, then run the backward pass
    optimizer.zero_grad()
    loss.backward()
    
    # 5. Optimizer step
    optimizer.step()
    
    return loss.item()
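To see the function in action before plugging it into the real loop, here is a small sanity check. It is a sketch under stated assumptions: the toy problem (learning `y = 2x`), the `SGD` optimizer, and every `toy_*` name are made up for this demo. `train_step` is repeated in condensed form so the cell runs on its own.

```python
import torch
import torch.nn as nn
import torch.optim as optim

def train_step(model, X_batch, y_batch, loss_fn, optimizer):
    # Condensed copy of the function defined above
    model.train()
    loss = loss_fn(model(X_batch), y_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

torch.manual_seed(0)

# Toy problem (an assumption for this demo): learn y = 2x
X_toy = torch.linspace(-1, 1, 20).reshape(-1, 1)
y_toy = 2 * X_toy

toy_model = nn.Linear(1, 1)
toy_loss_fn = nn.MSELoss()
toy_optimizer = optim.SGD(toy_model.parameters(), lr=0.1)

# Call the step 50 times on the full batch and record each loss
toy_losses = [
    train_step(toy_model, X_toy, y_toy, toy_loss_fn, toy_optimizer)
    for _ in range(50)
]

print(f"First loss: {toy_losses[0]:.4f}, Last loss: {toy_losses[-1]:.4f}")
```

If the last loss is not lower than the first on a problem this simple, something in the step (most often the order of the calls) is wrong.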

## 3. Define Function for Validation

We also need a function to check performance without training (inference).

In [None]:
def val_step(model, X_batch, y_batch, loss_fn):
    # 1. Set model to eval mode
    model.eval()
    
    # 2. Disable gradient calculation
    with torch.no_grad():
        predictions = model(X_batch)
        loss = loss_fn(predictions, y_batch)
    
    return loss.item()
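The two switches used above do different jobs, which the following sketch tries to make visible. `model.eval()` changes how certain layers behave, while `torch.no_grad()` stops PyTorch from recording operations for backpropagation. The `nn.Dropout` layer and the small `nn.Linear` model are illustrative choices, not part of the model trained in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

# Training mode: each entry is zeroed with probability 0.5,
# and the survivors are scaled by 1 / (1 - p) = 2
dropout.train()
out_train = dropout(x)

# Eval mode: Dropout passes the input through unchanged
dropout.eval()
out_eval = dropout(x)

print("train mode:", out_train)
print("eval mode: ", out_eval)

# no_grad: the output is not connected to the autograd graph
linear = nn.Linear(8, 1)
with torch.no_grad():
    pred = linear(x)

print("requires_grad inside no_grad:", pred.requires_grad)  # False
```

Neither switch implies the other, which is why `val_step` uses both.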

## 4. The Loop

Now we just loop over epochs and call our functions.

In [None]:
# Model, Loss, Optimizer
model = nn.Sequential(
    nn.Linear(1, 20),
    nn.ReLU(),
    nn.Linear(20, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
)

loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training
epochs = 500
train_losses = []
val_losses = []

for epoch in range(epochs):
    train_loss = train_step(model, X_train, y_train, loss_fn, optimizer)
    val_loss = val_step(model, X_val, y_val, loss_fn)
    
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    if (epoch+1) % 50 == 0:
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")

## 5. Plot Loss Curves

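Plotting both curves is one of the quickest ways to debug training: if the validation loss starts rising while the training loss keeps falling, the model is overfitting.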

In [None]:
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.legend()
plt.title("Training vs Validation Loss")
plt.show()
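The same check can be done numerically. The sketch below finds the epoch with the lowest validation loss and the gap between the two losses at the end of training. The loss values are made up purely for illustration; in the notebook you would use the `train_losses` and `val_losses` lists filled by the loop.

```python
# Made-up loss values for illustration only (NOT results from this notebook)
train_losses = [9.0, 5.0, 3.0, 2.0, 1.5, 1.2, 1.0]
val_losses = [9.5, 5.5, 3.6, 3.1, 3.3, 3.8, 4.4]

# Index of the epoch where validation loss was lowest (0-based)
best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)

# Gap between validation and training loss at the final epoch
final_gap = val_losses[-1] - train_losses[-1]

print(f"Best epoch: {best_epoch + 1} (Val Loss={val_losses[best_epoch]:.2f})")
print(f"Final Val - Train gap: {final_gap:.2f}")
```

In this made-up example the training loss falls every epoch while the validation loss bottoms out at epoch 4 and then climbs, which is the overfitting pattern to look for in the plot.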

## 🧠 Summary

1. **`model.train()`** vs **`model.eval()`**: Switches behavior (e.g. Dropout).
2. **`optimizer.zero_grad()`** must be called before `loss.backward()`, because PyTorch adds new gradients to any that are already stored.
3. **`with torch.no_grad()`**: Use this for validation/inference.
4. Always monitor **Train vs Val loss** to spot overfitting.
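The reason gradients need clearing (point 2) is easy to check on a single parameter. This minimal sketch uses a bare tensor `w` instead of a model and optimizer, so it calls `w.grad.zero_()` directly rather than `optimizer.zero_grad()`.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)

# First backward pass: d(3w)/dw = 3
(3 * w).backward()
first = w.grad.item()

# Second backward pass WITHOUT clearing: the new gradient is added to the old one
(3 * w).backward()
accumulated = w.grad.item()

# Clearing first leaves only the gradient of the latest pass
w.grad.zero_()
(3 * w).backward()
fresh = w.grad.item()

print(first, accumulated, fresh)  # 3.0 6.0 3.0
```

In a training loop, the accumulated value would mix gradients from earlier batches into the current update.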

Next up: **DataLoaders** - Handling messy real-world data!