# 📊 Batch Normalization: Stabilizing Deep Network Training

This notebook demonstrates the impact of **Batch Normalization** on neural network training. We'll train two otherwise identical networks on MNIST, one with BatchNorm and one without, and compare their validation accuracy.

## What You'll Learn

1. How Batch Normalization affects training dynamics
2. Implementing BatchNorm in PyTorch
3. Comparing convergence speed with and without BatchNorm
4. When and why to use Batch Normalization

---

## 1. Setup and Imports

We import the necessary libraries for building and training neural networks with PyTorch.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
from sklearn.metrics import accuracy_score
from matplotlib import pyplot as plt

## 2. Device Configuration

We check for GPU availability to accelerate training. CUDA-enabled GPUs significantly speed up neural network computations.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 3. Loading the MNIST Dataset

**MNIST** is a classic dataset of handwritten digits (0-9), containing:
- 60,000 training images
- 10,000 test images
- 28×28 grayscale images

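We normalize the pixel values with mean=0.5 and std=0.5, i.e. `(x - 0.5) / 0.5`, which maps the `[0, 1]` range produced by `ToTensor()` to `[-1, 1]`. Roughly centered inputs help with training stability.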
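
The arithmetic of `transforms.Normalize((0.5,), (0.5,))` can be checked on a few hand-picked values. The sketch below applies the same per-channel formula with plain tensor operations; the helper name `normalize` and the sample pixel values are illustrative assumptions, not part of the notebook.

```python
import torch

def normalize(x, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize((0.5,), (0.5,)) for one channel."""
    return (x - mean) / std

# ToTensor() scales 8-bit pixels to [0, 1]; these sample values are illustrative.
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])
normalized = normalize(pixels)
print(normalized)  # values: -1.0, -0.5, 0.0, 1.0
```

The output range is `[-1, 1]`, but the resulting dataset mean is generally not exactly zero, because MNIST pixels are mostly dark background.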

For faster experimentation, we use a subset of the data (5,000 training, 1,000 test samples).

In [None]:
# Normalize images to [-1, 1] range
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.5,), (0.5,))
])

# Load full datasets
full_train_dataset = datasets.MNIST(root="data", train=True, download=True, transform=transform)
full_test_dataset = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# Create smaller subsets for faster training
train_subset = Subset(full_train_dataset, torch.arange(5000))
test_subset = Subset(full_test_dataset, torch.arange(1000))

# Create DataLoaders
train_loader = DataLoader(train_subset, batch_size=60, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=60, shuffle=False)

print(f"Training samples: {len(train_subset)}")
print(f"Test samples: {len(test_subset)}")

---

## 4. Model Without Batch Normalization

First, let's define a simple feedforward neural network **without** Batch Normalization.

### Architecture

```
Input (28×28 = 784) → Hidden (128) → ReLU → Hidden (64) → ReLU → Output (10)
```

This is a standard MLP with two hidden layers. Without BatchNorm, the network must learn to handle varying input distributions at each layer.

In [None]:
class MNISTClassifier(nn.Module):
    """Simple MLP without Batch Normalization."""
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),  
            nn.ReLU(),
            nn.Linear(128, 64), 
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.network(x)
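
As a quick sanity check on this architecture, the sketch below rebuilds the same layer stack on its own (so it runs without the rest of the notebook), pushes a random stand-in batch through it, and counts the learnable parameters. The names `mlp` and `dummy_batch` are illustrative.

```python
import torch
import torch.nn as nn

# Same layer stack as MNISTClassifier, rebuilt here so the snippet is self-contained.
mlp = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

dummy_batch = torch.randn(60, 1, 28, 28)  # random stand-in for one batch of images
logits = mlp(dummy_batch)
num_params = sum(p.numel() for p in mlp.parameters())

print(logits.shape)  # torch.Size([60, 10])
print(num_params)    # 109386 = (784*128 + 128) + (128*64 + 64) + (64*10 + 10)
```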

## 5. Training Function

We define a reusable training function that:
1. Trains the model for a specified number of epochs
2. Evaluates on the test set after each epoch
3. Returns validation accuracies for comparison

### Training Loop Steps

1. **Forward pass**: Compute predictions
2. **Compute loss**: Compare predictions to true labels
3. **Backward pass**: Compute gradients
4. **Update weights**: Apply gradients via optimizer

In [None]:
def train_model(model, train_loader, test_loader, optimizer, criterion, epochs=5):
    """
    Train a model and track validation accuracy.
    
    Args:
        model: Neural network to train
        train_loader: DataLoader for training data
        test_loader: DataLoader for test/validation data
        optimizer: Optimization algorithm
        criterion: Loss function
        epochs: Number of training epochs
    
    Returns:
        List of validation accuracies per epoch
    """
    val_accuracies = []
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        for batch, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()              # Clear gradients
            outputs = model(images)            # Forward pass
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()                    # Backpropagate
            optimizer.step()                   # Update weights
            
            if batch % 100 == 0:
                print(f"Epoch {epoch+1}, Batch: {batch}, Train Loss: {loss.item():.4f}")
        
        # Evaluation phase
        model.eval()
        y_pred, y_true = [], []
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                y_pred.extend(predicted.cpu().numpy())
                y_true.extend(labels.cpu().numpy())
        
        val_accuracy = accuracy_score(y_true, y_pred)
        val_accuracies.append(val_accuracy)
        print(f"Epoch {epoch+1} - Validation Accuracy: {val_accuracy:.4f}")
        print("-" * 50)
    
    return val_accuracies

## 6. Train Model Without BatchNorm

Let's train our baseline model without Batch Normalization for 10 epochs.

In [None]:
# Instantiate model, loss function, and optimizer
model = MNISTClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print("Training WITHOUT Batch Normalization")
print("=" * 50)
val_accuracies_no_bn = train_model(model, train_loader, test_loader, optimizer, criterion, epochs=10)

In [None]:
print("\nValidation Accuracies (No BatchNorm):")
print(val_accuracies_no_bn)

---

## 7. Model With Batch Normalization

Now let's create the same architecture but **with Batch Normalization** layers.

### What BatchNorm Does

For each mini-batch seen in training mode, BatchNorm:
1. Computes the per-feature mean and variance of the activations across the batch
2. Normalizes each feature to zero mean and unit variance
3. Applies learnable scale (γ) and shift (β) parameters

```
BatchNorm(x) = γ × (x - μ) / √(σ² + ε) + β
```
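
The formula above can be checked against `nn.BatchNorm1d` directly. This is a minimal sketch, assuming a freshly created layer in training mode (so γ = 1 and β = 0) and randomly generated activations; the tensor `x` and its scale and offset are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(60, 128) * 3.0 + 2.0  # illustrative activations: 60 samples, 128 features

bn = nn.BatchNorm1d(128)  # starts with gamma (weight) = 1 and beta (bias) = 0
bn.train()                # training mode: normalize with the current batch statistics

with torch.no_grad():
    mu = x.mean(dim=0)                  # per-feature batch mean
    var = x.var(dim=0, unbiased=False)  # per-feature biased batch variance
    manual = bn.weight * (x - mu) / torch.sqrt(var + bn.eps) + bn.bias
    matches = torch.allclose(manual, bn(x), atol=1e-5)

print(matches)  # True
```

Note that the normalization itself uses the biased variance (dividing by the batch size), and `bn.eps` is the small constant ε that guards against division by zero.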

### Architecture with BatchNorm

```
Input → Linear(128) → BatchNorm → ReLU → Linear(64) → BatchNorm → ReLU → Output(10)
```

We place `BatchNorm1d` after each hidden linear layer, before its activation function. The output layer is left unnormalized.

In [None]:
class MNISTClassifierWithBN(nn.Module):
    """MLP with Batch Normalization after each hidden layer."""
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),
            nn.BatchNorm1d(128),    # BatchNorm after first linear layer
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),     # BatchNorm after second linear layer
            nn.ReLU(),
            nn.Linear(64, 10)       # No BatchNorm before output
        )

    def forward(self, x):
        return self.network(x)
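
Each `BatchNorm1d` layer adds a small amount of state to the model. The sketch below inspects a standalone `nn.BatchNorm1d(128)` to show which tensors are learnable (γ and β, exposed as `weight` and `bias`) and which are buffers updated during training; the dictionary names are illustrative.

```python
import torch.nn as nn

bn = nn.BatchNorm1d(128)

# Learnable parameters: updated by the optimizer.
learnable = {name: tuple(p.shape) for name, p in bn.named_parameters()}
# Buffers: updated by forward passes in training mode, not by the optimizer.
buffers = {name: tuple(b.shape) for name, b in bn.named_buffers()}

print(learnable)  # {'weight': (128,), 'bias': (128,)}
print(buffers)    # {'running_mean': (128,), 'running_var': (128,), 'num_batches_tracked': ()}
```

For the model above, the two BatchNorm layers therefore add 2×128 + 2×64 = 384 learnable parameters on top of the plain MLP.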

## 8. Train Model With BatchNorm

Now let's train the BatchNorm model with the same hyperparameters.
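
Because neither run fixes a random seed, part of any difference between the two models may come from weight initialization and batch shuffling rather than from BatchNorm. One possible way to reduce that effect, sketched here with a hypothetical helper `build_first_layer`, is to reseed PyTorch's generator before building each model.

```python
import torch
import torch.nn as nn

def build_first_layer(seed):
    """Hypothetical helper: seed PyTorch's RNG, then create a layer."""
    torch.manual_seed(seed)
    return nn.Linear(28 * 28, 128)

layer_a = build_first_layer(seed=0)
layer_b = build_first_layer(seed=0)
layer_c = build_first_layer(seed=1)

same_seed_match = torch.equal(layer_a.weight, layer_b.weight)
diff_seed_match = torch.equal(layer_a.weight, layer_c.weight)
print(same_seed_match, diff_seed_match)  # True False
```

In this notebook, that would mean calling `torch.manual_seed(...)` with the same value before constructing each model; passing a seeded `torch.Generator` to `DataLoader` via its `generator` argument is one way to make the shuffling order repeatable as well.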

In [None]:
# Instantiate model with BatchNorm
model_with_bn = MNISTClassifierWithBN().to(device)
criterion_with_bn = nn.CrossEntropyLoss()
optimizer_with_bn = optim.Adam(model_with_bn.parameters(), lr=0.001)

print("Training WITH Batch Normalization")
print("=" * 50)
val_accuracies_with_bn = train_model(
    model_with_bn, train_loader, test_loader, 
    optimizer_with_bn, criterion_with_bn, epochs=10
)

In [None]:
print("\nValidation Accuracies (With BatchNorm):")
print(val_accuracies_with_bn)

---

## 9. Comparing Results

Let's plot the per-epoch validation accuracy of both models to see the impact of Batch Normalization.

In [None]:
plt.figure(figsize=(10, 6))
epochs = range(1, len(val_accuracies_no_bn) + 1)

plt.plot(epochs, val_accuracies_no_bn, 'b-o', label="Without BatchNorm", linewidth=2)
plt.plot(epochs, val_accuracies_with_bn, 'r-o', label="With BatchNorm", linewidth=2)

plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Validation Accuracy', fontsize=12)
plt.title('Batch Normalization: Impact on Training', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.ylim(0.8, 1.0)
plt.show()
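
Alongside the plot, convergence speed can be summarized as the first epoch at which each model reaches a chosen accuracy. The helper `first_epoch_reaching` below is a hypothetical addition, and the accuracy lists are placeholder values for illustration only, not results from this notebook.

```python
def first_epoch_reaching(accuracies, threshold):
    """Return the 1-based epoch at which accuracy first reaches threshold, or None."""
    for epoch, acc in enumerate(accuracies, start=1):
        if acc >= threshold:
            return epoch
    return None

# Placeholder values for illustration only; they are NOT results from this notebook.
example_no_bn = [0.80, 0.86, 0.90, 0.92]
example_with_bn = [0.88, 0.92, 0.93, 0.94]

print(first_epoch_reaching(example_no_bn, 0.90))    # 3
print(first_epoch_reaching(example_with_bn, 0.90))  # 2
print(first_epoch_reaching(example_no_bn, 0.99))    # None
```

With the real runs, the same function could be applied to `val_accuracies_no_bn` and `val_accuracies_with_bn`.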

## 10. Analysis

### Expected Observations

| Aspect | Without BatchNorm | With BatchNorm |
|--------|-------------------|----------------|
| **Early epochs** | Slower improvement | Faster improvement |
| **Final accuracy** | Good | Often slightly better |
| **Training stability** | May fluctuate | More stable |

### Why BatchNorm Helps

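1. **Reduces internal covariate shift**: Each layer receives inputs with more consistent statistics. This was the original motivation; later analyses suggest a smoother optimization landscape also plays a large part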
2. **Allows higher learning rates**: Normalized activations are more stable
3. **Acts as regularization**: Batch statistics add noise, reducing overfitting
4. **Reduces sensitivity to initialization**: Less dependent on careful weight initialization

### When to Use BatchNorm

- Deep networks (many layers)
- When training is unstable
- When you want to use higher learning rates
- CNNs (use `BatchNorm2d`)

### When BatchNorm May Not Help

- Very small batch sizes (noisy statistics)
- RNNs/Transformers (use LayerNorm instead)
- When batch statistics don't make sense (e.g., online learning)
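
The small-batch limitation is easiest to see at the extreme of a single sample, where a per-feature batch variance cannot be estimated. The sketch below, using an illustrative 4-feature layer, shows that `nn.BatchNorm1d` rejects such input in training mode but accepts it in evaluation mode.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
single_sample = torch.randn(1, 4)  # a "batch" containing one sample

bn.train()
try:
    bn(single_sample)
    train_mode_failed = False
except ValueError as err:
    # PyTorch refuses to compute batch statistics from one value per feature.
    train_mode_failed = True
    print(f"Training mode: {err}")

bn.eval()  # evaluation mode normalizes with the stored running statistics
with torch.no_grad():
    eval_output = bn(single_sample)
print(eval_output.shape)  # torch.Size([1, 4])
```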

---

## 11. Key Takeaways

1. **Batch Normalization normalizes layer inputs** using batch statistics
2. **Faster convergence**: Models with BatchNorm often train faster
3. **More stable training**: Reduces sensitivity to hyperparameters
4. **Implicit regularization**: Adds noise through batch statistics
5. **Remember `model.eval()`**: At inference time BatchNorm normalizes with running statistics accumulated during training, not with the current batch's statistics
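
The difference between the two modes can be seen on a single layer. This is a minimal sketch, assuming a fresh `nn.BatchNorm1d` with PyTorch's default momentum of 0.1 and an illustrative random batch `x`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(60, 4) * 2.0 + 5.0  # illustrative batch: mean near 5, std near 2

bn = nn.BatchNorm1d(4)  # running_mean starts at 0, running_var at 1

with torch.no_grad():
    bn.train()
    train_out = bn(x)  # normalizes with batch stats and updates the running stats

    # With momentum 0.1, one update moves the running mean 10% of the way
    # from its initial value (0) toward the batch mean.
    mean_updated = torch.allclose(bn.running_mean, 0.1 * x.mean(dim=0), atol=1e-6)

    bn.eval()
    eval_out = bn(x)  # normalizes with the running stats instead
    expected_eval = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
    eval_matches = torch.allclose(eval_out, expected_eval, atol=1e-5)
    modes_agree = torch.allclose(train_out, eval_out, atol=1e-3)

print(mean_updated)  # True
print(eval_matches)  # True
print(modes_agree)   # False
```

After only one batch the running statistics are still far from the data statistics, which is why the two outputs differ so much here; forgetting `model.eval()` at test time makes predictions depend on whichever samples share the batch.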

### Next Steps

- Try different learning rates with and without BatchNorm
- Experiment with `BatchNorm2d` for CNNs
- Compare with LayerNorm for sequence models