# Lesson 17: Training and Evaluating Your First Neural Network

In the last two lessons, we learned the crucial pieces:
1.  **Lesson 15:** We learned the 5-step training loop (Forward -> Loss -> Zero Grads -> Backward -> Step).
2.  **Lesson 16:** We learned how to load and prepare data with `Dataset` and `DataLoader`.

Today, we combine these pieces. We will define our first true **Neural Network** using `nn.Module` and train it end to end to recognize handwritten digits.

## Recommended Videos

These concepts are the heart of deep learning. Watching visual explanations can be extremely helpful.

* **What is a Neural Network? (3Blue1Brown):** A foundational and beautiful explanation of what a neural network *is*. 
    * [https://www.youtube.com/watch?v=aircAruvnKk](https://www.youtube.com/watch?v=aircAruvnKk)

* **What is Backpropagation? (3Blue1Brown):** The best explanation of how gradients flow backward through the network. This is the "how it learns" part.
    * [https://www.youtube.com/watch?v=Ilg3gGewQ5U](https://www.youtube.com/watch?v=Ilg3gGewQ5U)

* **Cross-Entropy Loss (StatQuest):** A fantastic, simple explanation of *why* we use Cross-Entropy for classification.
    * [https://www.youtube.com/watch?v=7q7E91pA9aQ](https://www.youtube.com/watch?v=7q7E91pA9aQ)

* **Activation Functions (ReLU):** A quick overview of different activation functions and why ReLU is so popular.
    * [https://www.youtube.com/watch?v=68BZlGvl1W4](https://www.youtube.com/watch?v=68BZlGvl1W4)

## 1. Core Concepts: The Model and The Loss

Before we write code, let's solidify the key concepts we're using.

### What is a Perceptron (or a `nn.Linear` Layer)?

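A perceptron (or a "Linear Layer" in PyTorch) is the most basic building block of a neural network. It performs the operation **`output = weights * input + bias`**: each output is a weighted sum of the input features plus a bias. (Strictly speaking, a classic perceptron also applies an activation to that sum; `nn.Linear` computes only the weighted sum.)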

* **Analogy:** Think of it as a panel of "smart knobs".
* The **`weights`** are the knobs. They decide how much *importance* to give to each input feature. If a weight is high, that input feature strongly affects the output. If it's zero, that input is ignored.
* The **`bias`** is a "starting offset" knob. It helps shift the entire output up or down, regardless of the input. 

The model "learns" by finding the perfect settings for all these knobs (`weights` and `biases`) to get the correct answer.
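
To make the "knobs" concrete, here is a minimal pure-Python sketch of what a linear layer computes. The helper name `linear` and all the numbers are made up for illustration; `nn.Linear` does the same arithmetic on tensors, far more efficiently.

```python
def linear(inputs, weights, bias):
    """Compute one output per row of `weights`: sum(w_i * x_i) + b."""
    return [
        sum(w * x for w, x in zip(row, inputs)) + b
        for row, b in zip(weights, bias)
    ]

# Hypothetical values: 3 input features, 2 output units.
x = [1.0, 2.0, 3.0]
W = [
    [0.5, 0.0, -1.0],  # unit 0 ignores the second feature (weight is 0)
    [1.0, 1.0, 1.0],   # unit 1 weighs every feature equally
]
b = [0.1, -2.0]

outputs = linear(x, W, b)
print(outputs)
```

Changing the second input feature has no effect on the first output, because its weight is zero. That is the "ignored input" case described above.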

### What is an Activation Function (like `nn.ReLU`)?

If you just stack a bunch of Linear layers, the result is still one big linear function, so the network can only learn straight-line relationships. To learn complex patterns (like the shape of a "3"), we need to introduce **non-linearity**.

An activation function is a simple "switch" applied after each layer. The most popular one is **ReLU (Rectified Linear Unit)**. 

Its logic is `f(x) = max(0, x)`:
* If the input is positive (e.g., 5.0), it passes it through (output is 5.0).
* If the input is negative (e.g., -3.0), it clips it to **zero**.

This simple "bend" at zero is enough to let our network learn curves, corners, and all sorts of complex patterns.
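
The claim that stacked linear layers stay linear can be checked with a small sketch. The one-input "layers" below use made-up weights and biases; they are an illustration, not part of the model we build later.

```python
def relu(x):
    """ReLU: pass positive values through, clip negatives to zero."""
    return max(0.0, x)

def layer1(x):
    return 2.0 * x + 1.0   # hypothetical weight and bias

def layer2(x):
    return -3.0 * x + 0.5  # hypothetical weight and bias

def stacked_linear(x):
    # Two linear layers with no activation in between.
    return layer2(layer1(x))

def collapsed(x):
    # The same function written as ONE linear layer:
    # -3 * (2x + 1) + 0.5 = -6x - 2.5
    return -6.0 * x - 2.5

def stacked_with_relu(x):
    # ReLU between the layers introduces a bend at x = -0.5.
    return layer2(relu(layer1(x)))

for x in [-2.0, -0.5, 0.0, 1.0]:
    print(x, stacked_linear(x), collapsed(x), stacked_with_relu(x))
```

Without the activation, the two layers always agree with the single collapsed layer. With ReLU in between, the output is flat for inputs below -0.5 and sloped above it, which no single linear layer can reproduce.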

### What is One-Hot Encoding?

Our labels are numbers like `[5, 0, 4, 1, 9, ...]`. We have a problem: does the model think that the digit `9` is "better" or "more" than the digit `1`? This is a false relationship that could confuse the model.

We need to treat each digit as a separate, independent category. We do this with **One-Hot Encoding**.

We create a vector of all possible classes (10 in our case, for digits 0-9) and set the correct class to `1` and all others to `0`.

* `3` becomes `[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]` (index 3 is "on")
* `9` becomes `[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]` (index 9 is "on")
* `0` becomes `[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]` (index 0 is "on")

The model's final output is a vector of 10 scores, one per class. Once those scores are turned into probabilities, a good prediction looks close to this one-hot vector.
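
As a minimal sketch, the encoding above can be written in a few lines of plain Python. The helper name `one_hot` is chosen here for illustration.

```python
def one_hot(label, num_classes=10):
    """Return a list of zeros with a single 1 at index `label`."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} is outside 0..{num_classes - 1}")
    vector = [0] * num_classes
    vector[label] = 1
    return vector

for digit in (3, 9, 0):
    print(digit, one_hot(digit))
```

PyTorch also ships `torch.nn.functional.one_hot` for tensors of class indices, although we will not need it in this lesson.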

### What is Cross-Entropy Loss?

This is the loss function we use for classification. It's the standard way to measure "how wrong" our model's predicted class scores are.

**Analogy: The "Confidence Penalty"**

Imagine the correct answer is `3`. The one-hot vector is `[0, 0, 0, 1, 0, ...]`. 
Our model outputs **logits** (raw scores), which we can think of as *confidence scores*.

* **Good Prediction:** `[-1.2, 0.5, 0.9, 8.5, 1.1, ...]`
    * The model is *very confident* (8.5) that the answer is `3`. It is correct and confident. **LOW LOSS**.

* **Unsure Prediction:** `[0.1, 0.2, 0.3, 0.5, 0.1, ...]`
    * The model's highest score is for `3`, so it's *correct*, but it's not *confident* (0.5 is not much higher than 0.3). **MEDIUM LOSS**.

* **Bad Prediction:** `[9.0, -1.5, 0.1, 0.2, -0.5, ...]`
    * The model is *very confident* (9.0) that the answer is `0`. It is confident and **WRONG**. **HIGH LOSS**.

**Cross-Entropy Loss** perfectly captures this. It *punishes* the model heavily for being confident and wrong. This forces the model to become *both* accurate and confident in its correct predictions.

**PyTorch Nuance:** `nn.CrossEntropyLoss` is a 2-in-1 function. It applies a log-softmax to the logits (softmax turns logits into probabilities) and then calculates the negative log-likelihood loss, all in one numerically stable step. This means the model should output raw logits, with no softmax layer at the end, and we pass in the *raw class index* (e.g., `3`) as the target, not the one-hot vector.
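
To see the "confidence penalty" in numbers, here is a pure-Python sketch of the per-sample formula, `-log(softmax(logits)[target])`. As a simplifying assumption, it uses only the five scores shown in each example above, treated as a 5-class problem, so the values will differ from a real 10-class run.

```python
import math

def cross_entropy(logits, target):
    """Loss for one sample: -log(softmax(logits)[target])."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

TARGET = 3  # index of the correct class

examples = {
    "good": [-1.2, 0.5, 0.9, 8.5, 1.1],
    "unsure": [0.1, 0.2, 0.3, 0.5, 0.1],
    "bad": [9.0, -1.5, 0.1, 0.2, -0.5],
}

for name, logits in examples.items():
    print(f"{name:>6}: loss = {cross_entropy(logits, TARGET):.4f}")
```

The loss is close to zero for the confident correct prediction, moderate for the unsure one, and largest for the confident wrong one.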

## 2. Setup: Imports and Data Loading

Let's import everything we need and use our `data_loader.py` file to get the data.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import sys
import os

current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
data_loader_path = os.path.join(parent_dir, '16 - Datasets')

if data_loader_path not in sys.path:
    sys.path.append(data_loader_path)

from data_loader import get_mnist_loaders

%matplotlib inline

BATCH_SIZE = 64
train_loader, val_loader, test_loader = get_mnist_loaders(batch_size=BATCH_SIZE)

## 3. Defining the Neural Network (`nn.Module`)

This is where we build our model's architecture. We create a class that inherits from `nn.Module`. This is the standard, most flexible way to build models in PyTorch.

It has two main parts:
1.  **`__init__(self)`**: The constructor. This is where we *define and initialize* all the layers our model will use (e.g., linear layers, flatten, etc.).
2.  **`forward(self, x)`**: This function defines the *data flow*. It describes how an input `x` moves through the layers we defined in `__init__` to produce an output.

In [None]:
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, num_classes):
        super().__init__()
        
        # This layer flattens each [1, 28, 28] image into a [784] vector (the batch dimension is kept)
        self.flatten = nn.Flatten()
        
        # We define our network as a sequence of layers
        self.layers = nn.Sequential(
            # 1st Layer: 784 inputs -> 128 outputs
            nn.Linear(input_size, hidden_size1),
            nn.ReLU(), # Activation Function
            
            # 2nd Layer: 128 inputs -> 64 outputs
            nn.Linear(hidden_size1, hidden_size2),
            nn.ReLU(), # Activation Function
            
            # Output Layer: 64 inputs -> 10 outputs (logits)
            nn.Linear(hidden_size2, num_classes)
        )

    def forward(self, x):
        # The forward pass defines how data flows through the layers
        x = self.flatten(x)
        x = self.layers(x)
        return x

# --- Define our constants ---
INPUT_SIZE = 784  # 28*28 pixels
HIDDEN_SIZE_1 = 128
HIDDEN_SIZE_2 = 64
NUM_CLASSES = 10   # Digits 0-9

# --- Instantiate the model ---
model = SimpleNN(INPUT_SIZE, HIDDEN_SIZE_1, HIDDEN_SIZE_2, NUM_CLASSES)
print("Our Model Architecture:")
print(model)
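
As a quick sanity check on the architecture, we can count the learnable "knobs" by hand. This is a minimal sketch in plain Python; the helper name `linear_param_count` is made up for illustration.

```python
def linear_param_count(in_features, out_features):
    """One weight per (input, output) pair, plus one bias per output."""
    return in_features * out_features + out_features

# Layer widths of SimpleNN: 784 -> 128 -> 64 -> 10
layer_sizes = [784, 128, 64, 10]

total_params = 0
for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
    count = linear_param_count(n_in, n_out)
    print(f"Linear({n_in}, {n_out}): {count} parameters")
    total_params += count

print(f"Total: {total_params}")  # 109386
```

You can compare this with the model itself using `sum(p.numel() for p in model.parameters())`.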

## 4. The Full Training Loop (Run 1: Normal Training)

This is the main event! We will now write the full 5-step training loop and add a validation loop inside it to check our progress. An **epoch** is one full pass through the *entire* training dataset.

In [None]:
# --- Hyperparameters ---
LEARNING_RATE = 0.01
NUM_EPOCHS = 10

# --- Define Loss and Optimizer ---
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

# --- We will log our progress here ---
train_losses = []
val_losses = []
val_accuracies = []

print(f"--- Starting Training ({NUM_EPOCHS} Epochs) ---")

for epoch in range(NUM_EPOCHS):
    
    # --- TRAINING LOOP ---
    model.train() # Set the model to training mode (e.g., activates dropout)
    current_train_loss = 0.0
    for images, labels in train_loader: 
        # 1. Forward Pass: Get model's predictions
        outputs = model(images)
        
        # 2. Calculate Loss: Measure how wrong the predictions are
        loss = loss_fn(outputs, labels)
        
        # 3. Zero Gradients: Reset gradients from previous loop
        optimizer.zero_grad()
        
        # 4. Backward Pass: Calculate gradients for all parameters
        loss.backward()
        
        # 5. Update Parameters: Take a step using the optimizer
        optimizer.step()
        
        current_train_loss += loss.item() * images.size(0)

    # Calculate average training loss for the epoch
    epoch_train_loss = current_train_loss / len(train_loader.dataset)
    train_losses.append(epoch_train_loss)

    # --- VALIDATION LOOP ---
    model.eval() # Set the model to evaluation mode (e.g., deactivates dropout)
    current_val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad(): # Disable gradient calculation for efficiency
        for images, labels in val_loader:
            outputs = model(images)
            
            # Calculate loss
            loss = loss_fn(outputs, labels)
            current_val_loss += loss.item() * images.size(0)
            
            # Calculate accuracy
            predicted = outputs.argmax(dim=1) # Get the index of the max score
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Calculate average validation loss and accuracy for the epoch
    epoch_val_loss = current_val_loss / len(val_loader.dataset)
    val_losses.append(epoch_val_loss)
    
    epoch_val_acc = 100 * correct / total
    val_accuracies.append(epoch_val_acc)
    
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.2f}%")

print("--- Training Finished ---")
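
Why multiply each batch loss by `images.size(0)` and divide by the dataset size at the end? By default, `nn.CrossEntropyLoss` returns the *mean* loss over the batch, and the last batch is often smaller than the others. The sketch below uses made-up numbers to show how a plain average of batch means differs from the per-sample average.

```python
# Hypothetical per-batch mean losses; the last batch holds only 8 samples.
batch_mean_losses = [0.50, 0.40, 1.00]
batch_sizes = [64, 64, 8]

# Naive: average the batch means, so every batch counts equally.
naive_average = sum(batch_mean_losses) / len(batch_mean_losses)

# Weighted: turn each mean back into a sum, then divide by the sample count.
total_loss = sum(loss * n for loss, n in zip(batch_mean_losses, batch_sizes))
weighted_average = total_loss / sum(batch_sizes)

print(f"Naive average of batch means: {naive_average:.4f}")
print(f"Per-sample (weighted) average: {weighted_average:.4f}")
```

The naive average lets the small last batch count as much as a full one. The weighted version is what the training loop above computes.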

## 5. Plotting and Final Evaluation

Now we'll use Matplotlib to visualize the data we logged. Plotting these curves is one of the most useful ways to check whether our model is learning correctly.

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))

# Plot 1: Training & Validation Loss
ax1.plot(train_losses, label='Training Loss')
ax1.plot(val_losses, label='Validation Loss')
ax1.set_title('Training vs. Validation Loss (10 Epochs)')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True)

# Plot 2: Validation Accuracy
ax2.plot(val_accuracies, label='Validation Accuracy', color='green')
ax2.set_title('Validation Accuracy (10 Epochs)')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

### Final Evaluation (The "Final Exam")

If the validation loss is decreasing and the accuracy is increasing, our model looks good! Now it's time for the "final exam": running our trained model on the **Test Set** (the data it has *never* seen).

In [None]:
model.eval() # Set model to evaluation mode
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

final_accuracy = 100 * correct / total
print("\n--- FINAL TEST ACCURACY ---")
print(f"Accuracy on {total:,} test images: {final_accuracy:.2f}%")
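
The accuracy calculation boils down to "pick the class with the highest score and compare it to the label". Here is a minimal pure-Python sketch with hypothetical 3-class logits. The helper names are made up for illustration.

```python
def argmax(scores):
    """Index of the largest score (the predicted class)."""
    return max(range(len(scores)), key=lambda i: scores[i])

def accuracy_percent(batch_logits, labels):
    """Percentage of samples whose top-scoring class matches the label."""
    correct = sum(
        1 for logits, label in zip(batch_logits, labels)
        if argmax(logits) == label
    )
    return 100 * correct / len(labels)

# Hypothetical logits for 4 samples and 3 classes.
batch_logits = [
    [2.0, 0.1, -1.0],   # predicts class 0
    [0.2, 0.3, 0.1],    # predicts class 1
    [-0.5, 0.0, 1.5],   # predicts class 2
    [1.0, 3.0, 0.5],    # predicts class 1
]
labels = [0, 1, 2, 0]   # the last sample is misclassified

print(f"Accuracy: {accuracy_percent(batch_logits, labels):.2f}%")
```

Note that accuracy only looks at *which* class wins, not by how much. That is why we track the loss as well.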

## 6. Overfitting and Early Stopping

What happens if we let the model train for much longer? Let's re-initialize our model and train it for 30 epochs to see.

**Overfitting** is when the model stops learning the *general patterns* (e.g., what a "7" looks like) and starts *memorizing* the specific quirks of the training data (e.g., "this specific pixel on this specific image is always off").

When this happens, its performance on new, unseen data (the validation set) gets *worse*. The training loss will continue to go down, but the validation loss will start to go **up**.

In [None]:
# --- Re-initialize model and optimizer ---
model_long_run = SimpleNN(INPUT_SIZE, HIDDEN_SIZE_1, HIDDEN_SIZE_2, NUM_CLASSES)
optimizer_long_run = optim.SGD(model_long_run.parameters(), lr=LEARNING_RATE)
loss_fn_long_run = nn.CrossEntropyLoss()

# --- New lists for logging ---
long_train_losses = []
long_val_losses = []
long_val_accuracies = []

NUM_EPOCHS_LONG = 30

print("--- Starting LONG Training (30 Epochs) to demonstrate overfitting ---")

for epoch in range(NUM_EPOCHS_LONG):
    model_long_run.train()
    current_train_loss = 0.0
    for images, labels in train_loader: 
        outputs = model_long_run(images)
        loss = loss_fn_long_run(outputs, labels)
        optimizer_long_run.zero_grad()
        loss.backward()
        optimizer_long_run.step()
        current_train_loss += loss.item() * images.size(0)
    
    epoch_train_loss = current_train_loss / len(train_loader.dataset)
    long_train_losses.append(epoch_train_loss)

    model_long_run.eval()
    current_val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model_long_run(images)
            loss = loss_fn_long_run(outputs, labels)
            current_val_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_val_loss = current_val_loss / len(val_loader.dataset)
    long_val_losses.append(epoch_val_loss)
    epoch_val_acc = 100 * correct / total
    long_val_accuracies.append(epoch_val_acc)
    
    # We log everything, even if the results don't change much
    print(f"Epoch {epoch+1}/{NUM_EPOCHS_LONG} | Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.2f}%")

print("--- Long Training Finished ---")

### Plotting the Overfitting Example

Now let's plot the 30-epoch run. Pay close attention to the loss graph.

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))

# Plot 1: Training & Validation Loss (30 Epochs)
ax1.plot(long_train_losses, label='Training Loss')
ax1.plot(long_val_losses, label='Validation Loss')
ax1.set_title('Training vs. Validation Loss (30 Epochs)')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True)

# Plot 2: Validation Accuracy (30 Epochs)
ax2.plot(long_val_accuracies, label='Validation Accuracy', color='green')
ax2.set_title('Validation Accuracy (30 Epochs)')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

### Analysis of the Plots

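Look at the 'Training vs. Validation Loss' plot. The pattern to look for is:

1.  **Training Loss** (blue line) consistently goes down. The model is getting better and better at matching the training data.
2.  **Validation Loss** (orange line) goes down for the first ~10-15 epochs, but then it flattens out and **starts to rise again**.

**The turning point, where validation loss hits its minimum, is where overfitting begins.** From there on, the model is learning the *noise* of the training set, not the *pattern* of digits, and its performance on unseen validation data gets worse.

Your exact curves will depend on the random initialization and the data split. With a small learning rate, the validation loss may still be decreasing at epoch 30, in which case the model has simply not started to overfit yet.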

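**Early Stopping** is the simple technique of monitoring the validation loss and keeping the model from the epoch where the validation loss was at its **minimum** (around epoch 10-15 in this graph), not the model from the final epoch. In practice, you save a checkpoint whenever the validation loss improves and stop training once it has failed to improve for a few epochs in a row.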
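
Here is a minimal sketch of that logic in plain Python. The `patience` parameter (how many epochs without improvement we tolerate) and the list of validation losses are assumptions made up for illustration; they are not results from the run above.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return (best_epoch, stop_epoch), both counted from 1."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0

    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss = loss
            best_epoch = epoch  # in a real loop, save a model checkpoint here
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch  # stop training early

    return best_epoch, len(val_losses)  # never triggered: ran every epoch

# Hypothetical validation losses: improving, then drifting upward.
hypothetical_val_losses = [0.60, 0.45, 0.38, 0.35, 0.36, 0.34, 0.37, 0.39, 0.41, 0.44]

best, stopped = early_stopping_epoch(hypothetical_val_losses, patience=3)
print(f"Best epoch: {best}, training stopped at epoch: {stopped}")
```

A patience greater than 1 keeps a single noisy epoch (like epoch 5 here) from ending training too soon. In a real training loop, the checkpoint step would typically be `torch.save(model.state_dict(), path)`, and you would reload that checkpoint after stopping.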