# Notebook 2: Batch Normalization Demo

**Course:** 21CSE558T - Deep Neural Network Architectures  
**Module:** 4 - CNNs (Week 2 of 3)  
**Date:** October 31, 2025  
**Duration:** ~25 minutes

---

## Learning Objectives

By the end of this notebook, you will be able to:
1. Explain the Internal Covariate Shift problem
2. Understand how Batch Normalization works mathematically
3. Know the correct placement of BatchNorm layers
4. Compare training speed with and without BatchNorm
5. Implement BatchNormalization in Keras correctly

---

## Setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, Activation, 
    MaxPooling2D, GlobalAveragePooling2D, Dense, Flatten
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import to_categorical

print(f"TensorFlow version: {tf.__version__}")

# Set random seed
np.random.seed(42)
tf.random.set_seed(42)

---

## Part 1: The Internal Covariate Shift Problem

### Character: Sneha's Factory Assembly Line

**Meet Character: Sneha - Factory Manager**

**The Problem:**

**Character: Sneha** manages an assembly line with 5 stations:

**Day 1:** Station 1 receives parts sized 10-20cm
- Station 2 calibrated for 10-20cm inputs ✅
- Station 3 calibrated for 15-25cm outputs from Station 2 ✅
- Everything works smoothly!

**Day 2:** Station 1 suddenly produces parts sized 50-80cm
- Station 2 NOT calibrated for 50-80cm (expects 10-20cm) ❌
- Station 2 produces weird 70-100cm outputs
- Station 3 completely confused (expects 15-25cm, gets 70-100cm) ❌
- Entire line breaks down!

**Character: Sneha's Solution:**
"Add a quality control checkpoint after EACH station that normalizes parts to the expected size range BEFORE sending them to the next station!"

**Batch Normalization = Quality Control for Neural Networks**

---

## Part 2: Understanding Batch Normalization Mathematics

In [None]:
# Simulate layer outputs WITHOUT normalization
# These values shift during training (Internal Covariate Shift)

np.random.seed(42)

# Batch of 32 samples, 10 features
batch_size = 32
num_features = 10

# Simulate unstable layer outputs
layer_output_epoch1 = np.random.randn(batch_size, num_features) * 2 + 5  # Mean≈5, Std≈2
layer_output_epoch2 = np.random.randn(batch_size, num_features) * 5 + 15  # Mean≈15, Std≈5

print("="*60)
print("INTERNAL COVARIATE SHIFT DEMONSTRATION")
print("="*60)

print("\n📊 Epoch 1 Layer Output:")
print(f"  Mean: {layer_output_epoch1.mean():.2f}")
print(f"  Std:  {layer_output_epoch1.std():.2f}")
print(f"  Range: [{layer_output_epoch1.min():.2f}, {layer_output_epoch1.max():.2f}]")

print("\n📊 Epoch 2 Layer Output (SHIFTED!):")
print(f"  Mean: {layer_output_epoch2.mean():.2f}")
print(f"  Std:  {layer_output_epoch2.std():.2f}")
print(f"  Range: [{layer_output_epoch2.min():.2f}, {layer_output_epoch2.max():.2f}]")

print("\n❌ Problem: Next layer expects inputs from Epoch 1 distribution")
print("   But gets inputs from Epoch 2 distribution (very different!)")
print("   → Slow, unstable training")
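Why does a shifted distribution slow training? One concrete mechanism, used as motivation in the original BatchNorm paper, is saturation: if the next layer used a sigmoid, large inputs push it into its flat region where gradients vanish. The sketch below assumes a sigmoid downstream (this notebook's models use ReLU) and regenerates data with the same means and stds as above using `np.random.default_rng` (so the exact values differ from the cell above).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def mean_sigmoid_grad(x):
    """Average local gradient sigmoid'(x) = s(x) * (1 - s(x)) over all entries."""
    s = sigmoid(x)
    return float(np.mean(s * (1.0 - s)))

rng = np.random.default_rng(42)
epoch1 = rng.normal(loc=5.0, scale=2.0, size=(32, 10))    # Mean≈5, Std≈2
epoch2 = rng.normal(loc=15.0, scale=5.0, size=(32, 10))   # Mean≈15, Std≈5
# Per-feature standardization of epoch 2 (the BN step before gamma/beta)
epoch2_bn = (epoch2 - epoch2.mean(axis=0)) / np.sqrt(epoch2.var(axis=0) + 1e-5)

grad_e1 = mean_sigmoid_grad(epoch1)
grad_e2 = mean_sigmoid_grad(epoch2)
grad_bn = mean_sigmoid_grad(epoch2_bn)

print(f"Avg sigmoid gradient, epoch 1 inputs:            {grad_e1:.5f}")
print(f"Avg sigmoid gradient, epoch 2 inputs:            {grad_e2:.5f}")
print(f"Avg sigmoid gradient, epoch 2 after normalizing: {grad_bn:.5f}")
```

The shifted epoch-2 inputs should give a much smaller average gradient than epoch 1, while the normalized version sits near the sigmoid's steep center.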

In [None]:
# Visualize the distribution shift
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Epoch 1 distribution
axes[0].hist(layer_output_epoch1.flatten(), bins=30, alpha=0.7, color='blue', edgecolor='black')
axes[0].axvline(layer_output_epoch1.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {layer_output_epoch1.mean():.2f}')
axes[0].set_title('Epoch 1: Layer Output Distribution', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Value')
axes[0].set_ylabel('Frequency')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Epoch 2 distribution (SHIFTED)
axes[1].hist(layer_output_epoch2.flatten(), bins=30, alpha=0.7, color='orange', edgecolor='black')
axes[1].axvline(layer_output_epoch2.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {layer_output_epoch2.mean():.2f}')
axes[1].set_title('Epoch 2: Layer Output Distribution (SHIFTED!)', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Value')
axes[1].set_ylabel('Frequency')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("\n💡 Distribution keeps shifting → Next layer must constantly adapt!")

### Batch Normalization: The Solution
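The three steps in `manual_batch_norm` correspond to the standard BatchNorm equations for one feature over a mini-batch $x_1, \dots, x_m$ (the variance uses $1/m$, matching `np.var`'s default):

$$
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_B\right)^2
$$

$$
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma\,\hat{x}_i + \beta
$$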

In [None]:
def manual_batch_norm(x, gamma=1.0, beta=0.0, epsilon=1e-5):
    """
    Manual Batch Normalization implementation.
    
    Steps:
    1. Calculate batch mean (μ) and variance (σ²)
    2. Normalize: x_hat = (x - μ) / sqrt(σ² + ε)
    3. Scale and shift: y = γ * x_hat + β
    
    Args:
        x: Input data (batch_size, num_features)
        gamma: Learnable scale parameter
        beta: Learnable shift parameter
        epsilon: Small constant to prevent division by zero
    
    Returns:
        Normalized output
    """
    # Step 1: Calculate batch statistics
    batch_mean = np.mean(x, axis=0)
    batch_var = np.var(x, axis=0)
    batch_std = np.sqrt(batch_var + epsilon)
    
    print("Step 1: Batch Statistics")
    print(f"  Batch Mean: {batch_mean[:3]} ...")
    print(f"  Batch Std:  {batch_std[:3]} ...")
    
    # Step 2: Normalize
    x_normalized = (x - batch_mean) / batch_std
    
    print("\nStep 2: After Normalization")
    print(f"  Mean: {x_normalized.mean():.6f} (≈0)")
    print(f"  Std:  {x_normalized.std():.6f} (≈1)")
    
    # Step 3: Scale and shift
    y = gamma * x_normalized + beta
    
    print("\nStep 3: After Scale (γ) and Shift (β)")
    print(f"  γ (gamma): {gamma}")
    print(f"  β (beta):  {beta}")
    print(f"  Output Mean: {y.mean():.6f}")
    print(f"  Output Std:  {y.std():.6f}")
    
    return y

# Apply Batch Normalization to Epoch 2 data
print("="*60)
print("APPLYING BATCH NORMALIZATION")
print("="*60)
print("\nOriginal Epoch 2 Statistics:")
print(f"  Mean: {layer_output_epoch2.mean():.2f}")
print(f"  Std:  {layer_output_epoch2.std():.2f}")
print("\n" + "-"*60 + "\n")

normalized_output = manual_batch_norm(layer_output_epoch2, gamma=1.0, beta=0.0)

print("\n" + "="*60)
print("✅ Result: Each feature now has Mean≈0, Std≈1, regardless of the input's original scale!")
print("="*60)
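`manual_batch_norm` covers training mode only. At inference time a single test image has no meaningful "batch statistics", so BatchNorm instead uses running averages collected during training. The sketch below is a minimal illustration, not the Keras implementation; it assumes `momentum=0.99` and `epsilon=1e-3` (the Keras `BatchNormalization` defaults) and keeps γ and β fixed rather than learned. The class name `SimpleBatchNorm` is hypothetical.

```python
import numpy as np

class SimpleBatchNorm:
    """Minimal BatchNorm for (batch, features) arrays with train/inference modes.

    Assumption: momentum=0.99 and epsilon=1e-3 mirror Keras defaults;
    gamma/beta stay at their initial values here (no gradient updates).
    """

    def __init__(self, num_features, momentum=0.99, epsilon=1e-3):
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        self.moving_mean = np.zeros(num_features)
        self.moving_var = np.ones(num_features)
        self.momentum = momentum
        self.epsilon = epsilon

    def __call__(self, x, training):
        if training:
            # Normalize with the current batch's own statistics...
            mean, var = x.mean(axis=0), x.var(axis=0)
            # ...and update the running estimates that inference will use
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            # Inference: fixed running estimates, so output doesn't depend on the batch
            mean, var = self.moving_mean, self.moving_var
        x_hat = (x - mean) / np.sqrt(var + self.epsilon)
        return self.gamma * x_hat + self.beta

rng = np.random.default_rng(0)
bn = SimpleBatchNorm(num_features=10)
for _ in range(1000):  # simulate 1000 training steps on epoch-2-like data
    batch = rng.normal(loc=15.0, scale=5.0, size=(32, 10))
    bn(batch, training=True)

print("Moving mean (first 3):", np.round(bn.moving_mean[:3], 2))
print("Moving var  (first 3):", np.round(bn.moving_var[:3], 2))

single = rng.normal(loc=15.0, scale=5.0, size=(1, 10))  # one "test image"
print("Inference output (first 3):", np.round(bn(single, training=False)[0, :3], 2))
```

This is why Keras layers take a `training` argument: `fit()` runs BatchNorm in training mode, while `predict()` and `evaluate()` use the moving statistics.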

In [None]:
# Visualize the normalization effect
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Before BatchNorm
axes[0].hist(layer_output_epoch2.flatten(), bins=30, alpha=0.7, color='orange', edgecolor='black')
axes[0].axvline(layer_output_epoch2.mean(), color='red', linestyle='--', linewidth=2, 
               label=f'Mean: {layer_output_epoch2.mean():.2f}')
axes[0].set_title('Before Batch Normalization\n(Shifted Distribution)', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Value')
axes[0].set_ylabel('Frequency')
axes[0].legend()
axes[0].grid(alpha=0.3)

# After BatchNorm
axes[1].hist(normalized_output.flatten(), bins=30, alpha=0.7, color='green', edgecolor='black')
axes[1].axvline(normalized_output.mean(), color='red', linestyle='--', linewidth=2, 
               label=f'Mean: {normalized_output.mean():.4f} ≈ 0')
axes[1].set_title('After Batch Normalization\n(Normalized: Mean≈0, Std≈1)', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Value')
axes[1].set_ylabel('Frequency')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("\n💡 Batch Normalization stabilizes layer inputs → Faster, more stable training!")
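Two details the plots above hide. First, the statistics are per feature: each column is standardized on its own, not just the pooled values. Second, the learnable γ and β mean normalization doesn't permanently restrict the layer: choosing γ = sqrt(σ² + ε) and β = μ reproduces the original activations exactly. A quick numerical check, using a print-free helper `batch_norm` (an illustrative name) with the same formula as `manual_batch_norm`:

```python
import numpy as np

def batch_norm(x, gamma, beta, epsilon=1e-5):
    """Training-mode BN over the batch axis (same formula as manual_batch_norm)."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta

rng = np.random.default_rng(42)
x = rng.normal(loc=15.0, scale=5.0, size=(32, 10))

# 1) Per-feature check: every column is standardized individually
y = batch_norm(x, gamma=1.0, beta=0.0)
print("Per-feature means, max |.|:", np.abs(y.mean(axis=0)).max())
print("Per-feature stds, min/max: ", y.std(axis=0).min(), y.std(axis=0).max())

# 2) gamma = sqrt(var + eps), beta = mean undoes the normalization
eps = 1e-5
gamma_identity = np.sqrt(x.var(axis=0) + eps)
beta_identity = x.mean(axis=0)
x_recovered = batch_norm(x, gamma=gamma_identity, beta=beta_identity, epsilon=eps)
print("Max reconstruction error:", np.abs(x_recovered - x).max())
```

In a real network γ and β are trained by gradient descent, so the layer can settle anywhere between "fully normalized" and "identity".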

---

## Part 3: Batch Normalization Placement

### Modern Best Practice: Conv → BatchNorm → Activation

In [None]:
print("="*60)
print("BATCH NORMALIZATION PLACEMENT")
print("="*60)

print("\n❌ OLD (Not Recommended in This Course):")
print("-" * 60)
print("Conv2D(32, activation='relu')  ← Activation built-in")
print("BatchNormalization()            ← After activation (differs from the original paper's placement)")

print("\n✅ MODERN (Correct Placement):")
print("-" * 60)
print("Conv2D(32)                      ← No activation")
print("BatchNormalization()            ← BEFORE activation")
print("Activation('relu')              ← Activation applied last")

print("\n💡 Why this order?")
print("  1. Conv2D: Linear transformation (can shift distribution)")
print("  2. BatchNorm: Stabilize distribution before non-linearity")
print("  3. ReLU: Apply non-linearity to normalized values")
print("\n  Result: Stable gradients, faster convergence!")
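A side effect of placing BatchNorm directly after the convolution: BN subtracts the batch mean, so any constant per-channel bias the conv adds is removed again, and β takes over the role of the shift. That is why the Conv → BN pattern is often written with `Conv2D(..., use_bias=False)`. A numpy check of the cancellation, treating each column as one channel (the data and bias values are arbitrary stand-ins):

```python
import numpy as np

def standardize(x, epsilon=1e-3):
    """BN normalization step (before gamma/beta), per column over the batch axis."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + epsilon)

rng = np.random.default_rng(7)
pre_activation = rng.normal(size=(64, 4))    # stand-in for conv outputs, 4 channels
bias = np.array([0.5, -2.0, 3.0, 10.0])      # a per-channel bias the conv might add

out_without_bias = standardize(pre_activation)
out_with_bias = standardize(pre_activation + bias)

print("Max difference:", np.abs(out_with_bias - out_without_bias).max())
```

The difference is at floating-point rounding level. In Keras the equivalent layer stack would be `Conv2D(32, (3, 3), use_bias=False)`, then `BatchNormalization()`, then `Activation('relu')`, saving one parameter per filter.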

In [None]:
# Build example architectures

# WITHOUT BatchNorm (Old approach)
model_without_bn = Sequential([
    Conv2D(32, (3,3), activation='relu', input_shape=(28, 28, 1)),
    MaxPooling2D((2,2)),
    Conv2D(64, (3,3), activation='relu'),
    MaxPooling2D((2,2)),
    Flatten(),
    Dense(10, activation='softmax')
], name='Without_BatchNorm')

print("\nModel WITHOUT Batch Normalization:")
print("="*60)
model_without_bn.summary()

# WITH BatchNorm (Modern approach)
model_with_bn = Sequential([
    Conv2D(32, (3,3), input_shape=(28, 28, 1)),  # No activation
    BatchNormalization(),                         # Add BatchNorm
    Activation('relu'),                           # Then activation
    MaxPooling2D((2,2)),
    
    Conv2D(64, (3,3)),                           # No activation
    BatchNormalization(),                         # Add BatchNorm
    Activation('relu'),                           # Then activation
    MaxPooling2D((2,2)),
    
    Flatten(),
    Dense(10, activation='softmax')
], name='With_BatchNorm')

print("\nModel WITH Batch Normalization:")
print("="*60)
model_with_bn.summary()

# Compare parameters
params_without = model_without_bn.count_params()
params_with = model_with_bn.count_params()
extra_params = params_with - params_without

print("\n" + "="*60)
print("PARAMETER COMPARISON")
print("="*60)
print(f"Without BatchNorm: {params_without:,} parameters")
print(f"With BatchNorm:    {params_with:,} parameters")
print(f"Extra parameters:  {extra_params:,} ({extra_params/params_without*100:.1f}% increase)")
print("\n💡 Small parameter overhead; Part 4 measures the training benefit.")
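The parameter counts can also be worked out by hand. Each Keras `BatchNormalization` layer stores 4 values per channel: γ and β (trainable) plus the moving mean and moving variance (non-trainable); `count_params()` includes both kinds. The helper names below are illustrative, and the shapes assume the two models above ('valid' padding, 2×2 pooling), so the totals should match the `summary()` output.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters, use_bias=True):
    """Weights (kh * kw * in_channels * filters) plus one bias per filter."""
    return kernel_h * kernel_w * in_channels * filters + (filters if use_bias else 0)

def dense_params(in_features, units):
    """Weight matrix plus one bias per unit."""
    return in_features * units + units

def batchnorm_params(channels):
    """gamma, beta (trainable) + moving_mean, moving_variance (non-trainable)."""
    return 4 * channels

conv1 = conv2d_params(3, 3, 1, 32)     # 28x28x1  -> 26x26x32 -> pool -> 13x13x32
conv2 = conv2d_params(3, 3, 32, 64)    # 13x13x32 -> 11x11x64 -> pool -> 5x5x64
dense = dense_params(5 * 5 * 64, 10)   # Flatten (1600) -> 10 classes

without_bn = conv1 + conv2 + dense
with_bn = without_bn + batchnorm_params(32) + batchnorm_params(64)
trainable_bn = 2 * (32 + 64)

print(f"Without BatchNorm: {without_bn:,}")
print(f"With BatchNorm:    {with_bn:,} (BN adds {with_bn - without_bn}, of which {trainable_bn} trainable)")
```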

---

## Part 4: Training Speed Comparison (Live Demo)

In [None]:
# Load Fashion-MNIST dataset
print("Loading Fashion-MNIST dataset...")
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Preprocess
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Use subset for faster demo (10% of data)
subset_size = len(x_train) // 10
x_train_subset = x_train[:subset_size]
y_train_subset = y_train[:subset_size]

print(f"\nTraining subset: {len(x_train_subset):,} samples")
print(f"Test set: {len(x_test):,} samples")

In [None]:
# Compile both models
model_without_bn.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

model_with_bn.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

print("Models compiled and ready to train!")

In [None]:
# Train WITHOUT BatchNorm
print("="*60)
print("Training Model WITHOUT Batch Normalization")
print("="*60)

history_without_bn = model_without_bn.fit(
    x_train_subset, y_train_subset,
    validation_data=(x_test, y_test),
    epochs=10,
    batch_size=128,
    verbose=1
)

In [None]:
# Train WITH BatchNorm
print("\n" + "="*60)
print("Training Model WITH Batch Normalization")
print("="*60)

history_with_bn = model_with_bn.fit(
    x_train_subset, y_train_subset,
    validation_data=(x_test, y_test),
    epochs=10,
    batch_size=128,
    verbose=1
)

In [None]:
# Plot training comparison
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Training Accuracy
axes[0].plot(history_without_bn.history['accuracy'], 'o-', label='Without BatchNorm', linewidth=2, markersize=8)
axes[0].plot(history_with_bn.history['accuracy'], 's-', label='With BatchNorm', linewidth=2, markersize=8)
axes[0].set_title('Training Accuracy Comparison', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy')
axes[0].legend(fontsize=12)
axes[0].grid(alpha=0.3)

# Validation Accuracy
axes[1].plot(history_without_bn.history['val_accuracy'], 'o-', label='Without BatchNorm', linewidth=2, markersize=8)
axes[1].plot(history_with_bn.history['val_accuracy'], 's-', label='With BatchNorm', linewidth=2, markersize=8)
axes[1].set_title('Validation Accuracy Comparison', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].legend(fontsize=12)
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

# Print final results
final_train_acc_without = history_without_bn.history['accuracy'][-1]
final_val_acc_without = history_without_bn.history['val_accuracy'][-1]
final_train_acc_with = history_with_bn.history['accuracy'][-1]
final_val_acc_with = history_with_bn.history['val_accuracy'][-1]

print("\n" + "="*60)
print("FINAL RESULTS (After 10 Epochs)")
print("="*60)
print(f"\nWithout BatchNorm:")
print(f"  Train Accuracy: {final_train_acc_without:.4f}")
print(f"  Val Accuracy:   {final_val_acc_without:.4f}")

print(f"\nWith BatchNorm:")
print(f"  Train Accuracy: {final_train_acc_with:.4f}")
print(f"  Val Accuracy:   {final_val_acc_with:.4f}")

improvement = (final_val_acc_with - final_val_acc_without) * 100
print(f"\n✅ Difference: {improvement:+.2f} percentage points of validation accuracy")
print("💡 BatchNorm typically converges faster; exact numbers vary from run to run.")
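Final-epoch accuracy doesn't directly measure *speed*. A small helper (hypothetical, not part of Keras) can count how many epochs each model needs to first reach a chosen validation accuracy, working on the plain dict that `model.fit(...).history` returns:

```python
def epochs_to_reach(accuracies, threshold):
    """Return the 1-based epoch at which accuracy first reaches `threshold`,
    or None if it never does."""
    for epoch, acc in enumerate(accuracies, start=1):
        if acc >= threshold:
            return epoch
    return None

def compare_convergence(history_a, history_b, threshold, key="val_accuracy"):
    """Epochs-to-threshold for two Keras-style history dicts."""
    return (epochs_to_reach(history_a[key], threshold),
            epochs_to_reach(history_b[key], threshold))
```

Usage: `compare_convergence(history_without_bn.history, history_with_bn.history, threshold=0.85)` returns a pair of epoch numbers, with `None` for a model that never reached the threshold. The 0.85 threshold is an arbitrary choice; pick one both models can plausibly reach within 10 epochs.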

---

## Part 5: When to Use Batch Normalization

In [None]:
print("="*60)
print("BATCH NORMALIZATION GUIDELINES")
print("="*60)

print("\n✅ USE Batch Normalization When:")
print("-" * 60)
print("  • Deep networks (>5 layers)")
print("  • Training is slow or unstable")
print("  • Want to use higher learning rates")
print("  • Building modern CNNs (almost always!)")
print("  • Batch size ≥ 16 (reliable statistics)")

print("\n⚠️ BE CAREFUL When:")
print("-" * 60)
print("  • Very small batch size (< 8)")
print("    → Use Layer Normalization or Group Normalization instead")
print("  • Recurrent networks (RNNs/LSTMs)")
print("    → Layer Normalization often better")
print("  • Real-time inference (slight overhead)")
print("    → Usually negligible; BN can be folded into the preceding layer's weights")

print("\n❌ DON'T Use (Rare Cases):")
print("-" * 60)
print("  • Extremely shallow networks (2-3 layers)")
print("    → Often too shallow to benefit much")
print("  • Already using heavy regularization (might over-regularize)")
print("    → Monitor for underfitting")
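On the inference-overhead point: at inference BN uses fixed moving statistics, so it is just a per-channel affine map and can be merged into the preceding layer's weights and bias. A numpy sketch for a dense (matrix-multiply) layer; the same per-output-channel algebra applies to convolutions. All arrays are random stand-ins for trained values.

```python
import numpy as np

rng = np.random.default_rng(3)
in_features, out_features, eps = 8, 4, 1e-3

# Stand-ins for a trained layer z = x @ W + b followed by inference-mode BN
W = rng.normal(size=(in_features, out_features))
b = rng.normal(size=out_features)
gamma = rng.uniform(0.5, 2.0, size=out_features)
beta = rng.normal(size=out_features)
moving_mean = rng.normal(size=out_features)
moving_var = rng.uniform(0.5, 2.0, size=out_features)

def layer_then_bn(x):
    """Original two-step computation: linear layer, then BN with moving stats."""
    z = x @ W + b
    return gamma * (z - moving_mean) / np.sqrt(moving_var + eps) + beta

# Fold BN into the layer: scale each output column, then adjust the bias
scale = gamma / np.sqrt(moving_var + eps)
W_folded = W * scale                         # broadcasts scale[j] onto column j
b_folded = (b - moving_mean) * scale + beta

x = rng.normal(size=(5, in_features))
print("Max difference:", np.abs(layer_then_bn(x) - (x @ W_folded + b_folded)).max())
```

After folding, the BN layer disappears from the inference graph entirely; only training needs it as a separate step.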

---

## Part 6: Practice Exercise - Build Modern CNN

In [None]:
# Exercise (solution shown): build a modern CNN with proper BatchNorm placement.
# The "# TODO" comments mark the lines students should add themselves.

def build_modern_cnn(input_shape=(32, 32, 3), num_classes=10):
    """
    Build modern CNN with BatchNormalization.
    
    Architecture pattern:
    Blocks 1-2: [Conv → BatchNorm → ReLU] × 2 → Pool → Dropout
    Block 3:    Conv → BatchNorm → ReLU → GlobalAveragePooling → Dropout → Dense
    """
    from tensorflow.keras.layers import Dropout
    
    model = Sequential([
        # Block 1: 32 filters
        Conv2D(32, (3,3), padding='same', input_shape=input_shape),
        BatchNormalization(),  # TODO: Add BatchNorm
        Activation('relu'),
        
        Conv2D(32, (3,3), padding='same'),
        BatchNormalization(),  # TODO: Add BatchNorm
        Activation('relu'),
        MaxPooling2D((2,2)),
        Dropout(0.2),
        
        # Block 2: 64 filters
        Conv2D(64, (3,3), padding='same'),
        BatchNormalization(),  # TODO: Add BatchNorm
        Activation('relu'),
        
        Conv2D(64, (3,3), padding='same'),
        BatchNormalization(),  # TODO: Add BatchNorm
        Activation('relu'),
        MaxPooling2D((2,2)),
        Dropout(0.3),
        
        # Block 3: 128 filters
        Conv2D(128, (3,3), padding='same'),
        BatchNormalization(),  # TODO: Add BatchNorm
        Activation('relu'),
        GlobalAveragePooling2D(),
        
        # Output
        Dropout(0.5),
        Dense(num_classes, activation='softmax')
    ], name='Modern_CNN_with_BatchNorm')
    
    return model

# Build and display model
modern_cnn = build_modern_cnn()
modern_cnn.summary()

print("\n✅ Modern CNN with BatchNormalization created!")
print("💡 Ready for Monday's Tutorial T11 (CIFAR-10 implementation)")
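For convolutional layers like the ones above, BatchNorm computes one mean and variance per channel, pooling over the batch and both spatial axes; in Keras this is what `BatchNormalization()` does with its default `axis=-1` on channels-last inputs. A numpy sketch of that training-mode computation (the helper `batch_norm_2d` is illustrative, not the Keras implementation):

```python
import numpy as np

def batch_norm_2d(x, gamma, beta, epsilon=1e-3):
    """Training-mode BN for channels-last feature maps of shape (N, H, W, C).

    Statistics are taken over (N, H, W): one mean/variance and one
    (gamma, beta) pair per channel.
    """
    mean = x.mean(axis=(0, 1, 2), keepdims=True)   # shape (1, 1, 1, C)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    return gamma * x_hat + beta                    # gamma, beta: shape (C,)

rng = np.random.default_rng(1)
channels = 32
feature_maps = rng.normal(loc=3.0, scale=4.0, size=(16, 32, 32, channels))
out = batch_norm_2d(feature_maps, gamma=np.ones(channels), beta=np.zeros(channels))

print("Output shape:", out.shape)
print("Per-channel mean, max |.|:", np.abs(out.mean(axis=(0, 1, 2))).max())
print("Per-channel std, min/max: ", out.std(axis=(0, 1, 2)).min(), out.std(axis=(0, 1, 2)).max())
```

This is also why a conv BatchNorm layer has only `4 × channels` parameters no matter how large the feature maps are.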

---

## Summary

### Key Takeaways

1. **Internal Covariate Shift Problem:**
   - Layer input distributions shift as earlier layers' weights change during training
   - Each layer must constantly adapt
   - Can slow down training and destabilize gradients

2. **Batch Normalization Solution:**
   - Normalizes each feature (or channel) over the batch to mean≈0, std≈1
   - Learns optimal scale (γ) and shift (β)
   - Stabilizes training, enables higher learning rates

3. **Correct Placement:**
   - **Modern:** Conv → BatchNorm → Activation
   - NOT Conv(activation) → BatchNorm

4. **Benefits:**
   - ✅ Often faster convergence (speedup varies by model and setup)
   - ✅ Mild regularization effect (from noise in batch statistics)
   - ✅ Less sensitive to initialization
   - ✅ Allows deeper networks
   - ✅ More stable gradients

5. **When to Use:**
   - Almost always for modern CNNs!
   - Deep networks (>5 layers)
   - Batch size ≥ 16

### Character: Sneha's Summary

**Character: Sneha** says:
- "Quality control checkpoints after each station keep production stable!"
- "Workers (next layers) always receive parts in the expected size range"
- "Assembly line runs smoothly and efficiently!"

### Next Steps

In the next notebook, we'll explore **Data Augmentation** - artificially expanding your training dataset!

---

**End of Notebook 2**