# Lesson 18: Model Improvement - Custom Classes and Optimizers

In the previous lesson, we trained a model using standard settings (SGD, Learning Rate 0.01). It worked, but can we do better?

Today we dive deeper into the "brain" of the training process:
1.  **Advanced `nn.Module`**: Building flexible, reusable model classes.
2.  **Optimizers**: Why `Adam` often trains faster than `SGD`.
3.  **Learning Rate**: The most important hyperparameter and how to tune it.

## Recommended Videos

Understanding optimizers is much easier with visuals. Check these out:

* **Optimizers Explained (Visualization):** A great visual comparison of how SGD, Momentum, and Adam navigate the "loss landscape".
    * [https://www.youtube.com/watch?v=mdKjMPmcWjY](https://www.youtube.com/watch?v=mdKjMPmcWjY)

* **The Learning Rate (Andrew Ng):** A classic explanation of how step size affects convergence.
    * [https://www.youtube.com/watch?v=4qJaSmvhxi8](https://www.youtube.com/watch?v=4qJaSmvhxi8)

## 1. Setup and Data Loading

We use the same setup as before, importing our reusable `data_loader`.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import sys
import os
import time

# --- Fix import path for data_loader ---
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
data_loader_path = os.path.join(parent_dir, '16 - Datasets')
if data_loader_path not in sys.path:
    sys.path.append(data_loader_path)

from data_loader import get_mnist_loaders

%matplotlib inline

# --- Constants ---
BATCH_SIZE = 64
INPUT_SIZE = 784
NUM_CLASSES = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load Data
train_loader, val_loader, test_loader = get_mnist_loaders(batch_size=BATCH_SIZE)

## 2. Improving the Model Architecture (`nn.Module`)

We can make our model class more flexible. Instead of hardcoding layers, we can pass configuration parameters to `__init__`.

We will also introduce a new layer type: **Dropout**.

### What is Dropout?
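Dropout is a regularization technique that helps reduce **Overfitting**.
* During training, each neuron's output is zeroed ("turned off") with probability `p` (e.g., 20%), and the surviving outputs are scaled up by `1 / (1 - p)` so their expected value stays the same.
* This forces the network not to rely too heavily on any single neuron. It makes the network more robust, like a team where every member can handle tasks if someone is sick.
* Dropout is only active in training mode (`model.train()`). In evaluation mode (`model.eval()`) it passes inputs through unchanged, which is why the training loop switches between the two modes.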
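A quick way to check this behaviour is to feed a tensor of ones through a standalone `nn.Dropout` layer in both modes. This is a small sketch; the tensor size and the seed are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)          # arbitrary seed so the random mask is repeatable
p = 0.2
dropout = nn.Dropout(p=p)
x = torch.ones(10_000)

dropout.train()               # training mode: dropout is active
y_train = dropout(x)
frac_zeroed = (y_train == 0).float().mean().item()
survivor_values = y_train[y_train != 0].unique()

dropout.eval()                # evaluation mode: dropout does nothing
y_eval = dropout(x)

print(f"fraction zeroed in train mode: {frac_zeroed:.3f}")  # close to p
print(f"surviving values: {survivor_values.tolist()}")      # scaled by 1 / (1 - p)
print(f"eval output unchanged: {torch.equal(y_eval, x)}")
```

Roughly 20% of the entries come out as zero, every survivor equals `1 / (1 - 0.2) = 1.25`, and in eval mode the input passes through untouched.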

In [None]:
class FlexibleNN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size, dropout_rate=0.0):
        """
        Args:
            input_size (int): Size of input vector (784)
            hidden_sizes (list): A list of integers, e.g., [128, 64]
            output_size (int): Number of classes (10)
            dropout_rate (float): Probability of an element to be zeroed.
        """
        super(FlexibleNN, self).__init__()
        
        self.flatten = nn.Flatten()
        
        # We can build layers dynamically using a list
        layers = []
        
        # Input Layer -> First Hidden Layer
        prev_size = input_size
        
        for size in hidden_sizes:
            layers.append(nn.Linear(prev_size, size))
            layers.append(nn.ReLU())
            if dropout_rate > 0:
                layers.append(nn.Dropout(p=dropout_rate))
            prev_size = size
        
        # Last Hidden Layer -> Output Layer
        layers.append(nn.Linear(prev_size, output_size))
        
        # Wrap the list in nn.Sequential
        self.model_stack = nn.Sequential(*layers)

    def forward(self, x):
        x = self.flatten(x)
        logits = self.model_stack(x)
        return logits

# Example: Creating a deeper model easily
model = FlexibleNN(784, [256, 128, 64], 10, dropout_rate=0.2).to(device)
print(model)
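To see what a given `hidden_sizes` list costs, you can count parameters: each `nn.Linear(a, b)` holds `a * b` weights plus `b` biases, while `ReLU` and `Dropout` have none. The sketch below uses two hypothetical helpers, `count_linear_params` (the formula) and `build_stack` (a copy of the layer-building loop from `FlexibleNN`), so it runs on its own and checks the formula against PyTorch's own count.

```python
import torch.nn as nn

def count_linear_params(input_size, hidden_sizes, output_size):
    """Hypothetical helper: a*b weights + b biases for every Linear(a, b)."""
    sizes = [input_size, *hidden_sizes, output_size]
    return sum(a * b + b for a, b in zip(sizes[:-1], sizes[1:]))

def build_stack(input_size, hidden_sizes, output_size, dropout_rate=0.0):
    """Hypothetical helper mirroring the loop in FlexibleNN.__init__."""
    layers, prev = [], input_size
    for size in hidden_sizes:
        layers += [nn.Linear(prev, size), nn.ReLU()]
        if dropout_rate > 0:
            layers.append(nn.Dropout(p=dropout_rate))
        prev = size
    layers.append(nn.Linear(prev, output_size))
    return nn.Sequential(*layers)

for hidden in ([128, 64], [256, 128, 64]):
    stack = build_stack(784, hidden, 10, dropout_rate=0.2)
    actual = sum(p.numel() for p in stack.parameters())
    print(hidden, actual, count_linear_params(784, hidden, 10))
```

Most of the parameters sit in the first layer (`784 * 256` weights for the deeper model), so widening the first hidden layer grows the model much faster than adding a small layer at the end.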

## 3. Optimizers: SGD vs. Adam

The optimizer is the algorithm that updates the weights based on the gradients.

### SGD (Stochastic Gradient Descent)
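* **Analogy:** You are walking down a misty mountain. You look at your feet, see which way is down, and take a step in that direction. The step length is the learning rate times the steepness of the slope (the gradient). "Stochastic" means the slope is estimated from a small random batch of data rather than from the whole dataset.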
* **Pros:** Simple, well-understood.
* **Cons:** Can get stuck in local valleys. If the slope is steep, it might overshoot. If the slope is flat, it moves very slowly.
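The plain SGD update is `w ← w - lr * gradient`. The sketch below checks that `torch.optim.SGD` (with its default `momentum=0`) does exactly that on a tiny least-squares problem; the data is random and only for illustration.

```python
import torch

torch.manual_seed(0)
lr = 0.01
w = torch.randn(3, requires_grad=True)   # weights to optimize
x = torch.randn(8, 3)                    # a random "mini-batch" of 8 samples
y = torch.randn(8)

loss = ((x @ w - y) ** 2).mean()         # mean squared error
loss.backward()                          # fills w.grad

expected = w.detach() - lr * w.grad      # manual gradient-descent step
opt = torch.optim.SGD([w], lr=lr)        # momentum defaults to 0
opt.step()

print(torch.allclose(w.detach(), expected))
```

Because the step is proportional to `w.grad`, a steep slope gives a large step (risk of overshooting) and a flat one gives a tiny step (slow progress).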

### Adam (Adaptive Moment Estimation)
* **Analogy:** You are a heavy ball rolling down the mountain. 
    1.  **Momentum:** If you are moving fast in one direction, you keep going (even if the slope changes slightly).
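    2.  **Adaptivity:** Adam keeps a running average of each weight's squared gradients. Where gradients have been small, it takes relatively bigger steps; where they have been large, it takes smaller, more careful steps.
* **Pros:** Often converges much faster in the early epochs and is less sensitive to the exact choice of learning rate.
* **Cons:** More complex math (PyTorch handles it), and it stores two extra running averages per parameter, so it uses more memory than plain SGD.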
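The two ideas above map onto two running averages: `m` (momentum, the average gradient) and `v` (the average squared gradient, used for adaptivity). This sketch writes out the textbook Adam update by hand, including the bias correction for `m` and `v` starting at zero, and compares it against `torch.optim.Adam` with PyTorch's default hyperparameters. The quadratic loss and random target are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8   # PyTorch's defaults for optim.Adam

w_torch = torch.randn(5, requires_grad=True)
w_manual = w_torch.detach().clone()
m = torch.zeros_like(w_manual)   # first moment: running mean of gradients (momentum)
v = torch.zeros_like(w_manual)   # second moment: running mean of squared gradients
opt = torch.optim.Adam([w_torch], lr=lr, betas=(beta1, beta2), eps=eps)
target = torch.randn(5)

for t in range(1, 11):
    # PyTorch's update
    opt.zero_grad()
    loss = ((w_torch - target) ** 2).sum()
    loss.backward()
    opt.step()

    # Manual update; the gradient of sum((w - target)^2) is 2 * (w - target)
    g = 2 * (w_manual - target)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)     # bias correction (m and v start at zero)
    v_hat = v / (1 - beta2 ** t)
    w_manual = w_manual - lr * m_hat / (v_hat.sqrt() + eps)

print("max difference:", (w_torch.detach() - w_manual).abs().max().item())
```

Dividing by `sqrt(v_hat)` is what makes the step size adaptive: a weight whose gradients are consistently large gets its step shrunk, and one with small gradients gets it enlarged.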

## 4. The Learning Rate (LR)

This is arguably the single most important setting.

* **Low LR (e.g., 0.0001):** The model learns very slowly. Each step is tiny, so it may need many more epochs to reach the same accuracy.
* **High LR (e.g., 0.1 or 1.0):** The model makes huge changes. It might jump *over* the optimal solution and never settle down. The loss can even diverge to infinity (showing up as `inf` or `nan`).
* **Good LR:** The loss drops quickly at first and then settles near a minimum.
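All three regimes show up even on a one-dimensional toy problem. This sketch minimizes `f(w) = w**2` (gradient `2 * w`) with plain gradient descent; the function and the three learning rates are illustrative choices, not the MNIST setup. Each step multiplies `w` by `(1 - 2 * lr)`, so the iteration shrinks toward the minimum at 0 only while `|1 - 2 * lr| < 1`.

```python
def gradient_descent(lr, w0=1.0, steps=20):
    """Minimize f(w) = w**2 with plain gradient descent and return the final w."""
    w = w0
    for _ in range(steps):
        w = w - lr * 2 * w      # gradient of w**2 is 2*w
    return w

for lr in (0.001, 0.1, 1.1):
    print(f"lr={lr:<6} -> w after 20 steps = {gradient_descent(lr):.4g}")
```

With `lr=0.001` the weight has barely moved, with `lr=0.1` it is close to 0, and with `lr=1.1` every step overshoots by more than it corrects, so `|w|` grows. Where the "too high" threshold lies depends entirely on the problem; on this toy it is `lr = 1`.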

## 5. Experimentation Time

Let's define a reusable training function so we can easily run multiple experiments and compare them.

In [None]:
def run_experiment(name, model, optimizer, num_epochs=5):
    print(f"\n--- Starting Experiment: {name} ---")
    loss_fn = nn.CrossEntropyLoss()
    val_accuracies = []
    
    start_time = time.time()
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
        
        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)  # index of the highest logit = predicted class
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        acc = 100 * correct / total
        val_accuracies.append(acc)
        print(f"Epoch {epoch+1}: Val Acc = {acc:.2f}%")
    
    print(f"Finished in {time.time() - start_time:.2f} seconds")
    return val_accuracies
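Each experiment below builds a model with fresh random weights, so part of any difference between runs may come from initialization rather than from the optimizer. One option is to call `torch.manual_seed` right before constructing each model, so every optimizer starts from identical weights. The sketch below uses a hypothetical `make_model` helper with a stand-in architecture to show that the same seed reproduces the same weights. If the data loaders rely on PyTorch's default random generator, the seed also influences shuffling; GPU kernels may still add small run-to-run variation.

```python
import torch
import torch.nn as nn

def make_model(seed):
    """Hypothetical helper: seed the RNG, then build a small stand-in model."""
    torch.manual_seed(seed)     # fixes the random weight initialization
    return nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

a, b, c = make_model(0), make_model(0), make_model(1)
same = all(torch.equal(p, q) for p, q in zip(a.parameters(), b.parameters()))
different = not all(torch.equal(p, q) for p, q in zip(a.parameters(), c.parameters()))
print(f"same seed -> identical weights: {same}")
print(f"different seed -> different weights: {different}")
```

In this notebook, the equivalent would be a `torch.manual_seed(...)` line before each `FlexibleNN(...)` call.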

### Experiment 1: SGD (The Baseline)
Plain SGD (PyTorch's default `momentum=0`) with a learning rate of 0.01.

In [None]:
model_sgd = FlexibleNN(784, [128, 64], 10).to(device)
opt_sgd = optim.SGD(model_sgd.parameters(), lr=0.01)

results_sgd = run_experiment("SGD (lr=0.01)", model_sgd, opt_sgd)

### Experiment 2: Adam (The Modern Standard)
Same Learning Rate, just changing the optimizer.

In [None]:
model_adam = FlexibleNN(784, [128, 64], 10).to(device)
opt_adam = optim.Adam(model_adam.parameters(), lr=0.01)

results_adam = run_experiment("Adam (lr=0.01)", model_adam, opt_adam)

### Experiment 3: Adam with Lower Learning Rate
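Adam normalizes each update by the recent gradient magnitude, so every weight moves by roughly `lr` per step no matter how small its gradient is. That can make 0.01 too high. Let's try 0.001, the default learning rate for `optim.Adam` in PyTorch (and a common default in other libraries).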

In [None]:
model_adam_low = FlexibleNN(784, [128, 64], 10).to(device)
opt_adam_low = optim.Adam(model_adam_low.parameters(), lr=0.001)

results_adam_low = run_experiment("Adam (lr=0.001)", model_adam_low, opt_adam_low)
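Why can 0.01 suit SGD but be too large for Adam? A minimal sketch: take a single optimizer step on the linear loss `grad_scale * w`, whose gradient is exactly `grad_scale`, and measure how far `w` moves. The helper `first_step_size` and the three gradient scales are illustrative assumptions.

```python
import torch

def first_step_size(opt_cls, grad_scale, lr=0.01):
    """Hypothetical helper: one optimizer step on loss = grad_scale * w; return |change in w|."""
    w = torch.tensor([1.0], dtype=torch.float64, requires_grad=True)
    opt = opt_cls([w], lr=lr)
    loss = (grad_scale * w).sum()      # gradient dL/dw = grad_scale
    loss.backward()
    opt.step()
    return abs(1.0 - w.item())

for scale in (1e-3, 1.0, 1e3):
    sgd = first_step_size(torch.optim.SGD, scale)
    adam = first_step_size(torch.optim.Adam, scale)
    print(f"gradient {scale:g}: SGD moved {sgd:.3g}, Adam moved {adam:.3g}")
```

SGD's step scales with the gradient (`lr * gradient`), while Adam's first step is about `lr` regardless of the gradient's size. With tiny MNIST gradients, Adam at 0.01 therefore moves weights far more than SGD at 0.01 does.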

## 6. Final Comparison

Let's plot the results side-by-side to see the difference.

In [None]:
plt.figure(figsize=(10, 6))
plt.plot(results_sgd, label='SGD (lr=0.01)', marker='o')
plt.plot(results_adam, label='Adam (lr=0.01)', marker='o')
plt.plot(results_adam_low, label='Adam (lr=0.001)', marker='o')

plt.title('Optimizer Comparison: Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True)
plt.show()

## Conclusion

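You will most likely see that:
1.  **Adam** reaches high validation accuracy in fewer epochs than **SGD**.
2.  **Adam with high LR (0.01)** may be unstable or erratic.
3.  **Adam with low LR (0.001)** usually gives the smoothest and best results for this problem.

This is why **Hyperparameter Tuning** is so important: changing a single line of code can noticeably change the final accuracy. Since each model starts from different random weights, small gaps between runs may be noise, so rerun an experiment before drawing firm conclusions.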