# Lab 1.5.4 Solution: Normalization Comparison

This notebook contains solutions to the exercises from Lab 1.5.4.

---

In [None]:
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(42)
%matplotlib inline

## Exercise 1 Solution: Group Normalization Implementation

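Group Normalization splits the channels into groups and, for each sample independently, normalizes over all channels (and spatial positions) within each group.
It's a middle ground between LayerNorm (1 group) and InstanceNorm (C groups), and because it never uses batch statistics it behaves the same at any batch size.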

In [None]:
class GroupNorm:
    """
    Group Normalization implementation.
    
    Divides channels into groups and normalizes within each group,
    independently for every sample, so the statistics never depend on the
    batch size. Works well with small batch sizes where BatchNorm struggles.
    
    Inputs are channel-first: (batch_size, num_channels, *spatial), e.g.
    (N, C), (N, C, L) or (N, C, H, W).
    
    Args:
        num_groups: Number of groups to divide channels into
        num_channels: Number of channels (must be divisible by num_groups)
        eps: Small constant for numerical stability
    """
    
    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
        assert num_channels % num_groups == 0, "num_channels must be divisible by num_groups"
        
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.channels_per_group = num_channels // num_groups
        
        # Learnable parameters (per channel)
        self.gamma = np.ones(num_channels)
        self.beta = np.zeros(num_channels)
        
        # Gradients (overwritten by every backward call)
        self.dgamma = np.zeros_like(self.gamma)
        self.dbeta = np.zeros_like(self.beta)
        
        # Cache for backward pass
        self.cache = {}
    
    def _param_view(self, param: np.ndarray, ndim: int) -> np.ndarray:
        """Reshape a per-channel vector so it broadcasts along axis 1 of an ndim-D input."""
        return param.reshape((1, self.num_channels) + (1,) * (ndim - 2))
    
    def forward(self, X: np.ndarray) -> np.ndarray:
        """
        Forward pass.
        
        Args:
            X: Input of shape (batch_size, num_channels, *spatial), e.g.
               (batch_size, num_channels) or (batch_size, num_channels, height, width)
        
        Returns:
            Normalized output of same shape
        
        Raises:
            ValueError: if X has fewer than 2 dims or X.shape[1] != num_channels
        """
        if X.ndim < 2 or X.shape[1] != self.num_channels:
            raise ValueError(
                f"expected input of shape (batch, {self.num_channels}, ...), got {X.shape}"
            )
        self.input_shape = X.shape
        batch_size = X.shape[0]
        
        # (batch, C, *spatial) -> (batch, groups, channels_per_group, *spatial).
        # Channels of a group are adjacent, so this is a plain reshape.
        X_grouped = X.reshape(batch_size, self.num_groups, self.channels_per_group, *X.shape[2:])
        
        # Statistics over every axis except batch and group
        axis = tuple(range(2, X_grouped.ndim))
        mean = X_grouped.mean(axis=axis, keepdims=True)
        var = X_grouped.var(axis=axis, keepdims=True)
        std = np.sqrt(var + self.eps)
        
        X_norm_grouped = (X_grouped - mean) / std
        X_norm = X_norm_grouped.reshape(X.shape)
        
        # Per-channel affine transform
        out = self._param_view(self.gamma, X.ndim) * X_norm + self._param_view(self.beta, X.ndim)
        
        # Cache for backward (X_norm is stored in the input's shape)
        self.cache = {
            'X_grouped': X_grouped,
            'X_norm': X_norm,
            'X_norm_grouped': X_norm_grouped,
            'mean': mean,
            'var': var,
            'std': std,
        }
        
        return out
    
    def backward(self, dout: np.ndarray) -> np.ndarray:
        """
        Backward pass. Sets self.dgamma and self.dbeta.
        
        Args:
            dout: Gradient of loss w.r.t. output (same shape as the forward input)
        
        Returns:
            Gradient of loss w.r.t. input
        """
        X_norm = self.cache['X_norm']
        X_norm_grouped = self.cache['X_norm_grouped']
        std = self.cache['std']
        
        # gamma/beta are per channel: reduce over every axis except axis 1
        reduce_axes = (0,) + tuple(range(2, dout.ndim))
        self.dgamma = np.sum(dout * X_norm, axis=reduce_axes)
        self.dbeta = np.sum(dout, axis=reduce_axes)
        
        # Gradient w.r.t. normalized input, regrouped like the forward pass
        dX_norm = dout * self._param_view(self.gamma, dout.ndim)
        dX_norm_grouped = dX_norm.reshape(X_norm_grouped.shape)
        
        # Standard normalization backward, per (sample, group):
        #   dX = (dX_norm - mean(dX_norm) - X_norm * mean(dX_norm * X_norm)) / std
        axis = tuple(range(2, X_norm_grouped.ndim))
        dX_grouped = (
            dX_norm_grouped
            - dX_norm_grouped.mean(axis=axis, keepdims=True)
            - X_norm_grouped * (dX_norm_grouped * X_norm_grouped).mean(axis=axis, keepdims=True)
        ) / std
        
        # Reshape back to original shape
        return dX_grouped.reshape(self.input_shape)
    
    def parameters(self):
        return [(self.gamma, self.dgamma), (self.beta, self.dbeta)]

In [None]:
# Test GroupNorm implementation
print("Testing Group Normalization")
print("=" * 50)

# Test with 2D input (batch, channels)
batch_size = 4
num_channels = 8
num_groups = 2  # 4 channels per group

gn = GroupNorm(num_groups=num_groups, num_channels=num_channels)
X = np.random.randn(batch_size, num_channels) * 5 + 3  # Non-zero mean, large variance

out = gn.forward(X)

print(f"Input shape: {X.shape}")
print(f"Output shape: {out.shape}")
print(f"\nInput stats (per sample):")
print(f"  Mean: {X.mean(axis=1)}")
print(f"  Std:  {X.std(axis=1)}")
print(f"\nOutput stats (per sample, should be ~0 mean, ~1 std):")
print(f"  Mean: {out.mean(axis=1)}")
print(f"  Std:  {out.std(axis=1)}")


def numerical_input_grad(layer, X, dout, h=1e-5):
    """Central-difference estimate of d(sum(layer.forward(X) * dout)) / dX.

    Calls layer.forward repeatedly, which overwrites the layer's cache, so
    compute the analytical gradient before calling this.
    """
    grad = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        X_plus = X.copy()
        X_plus[idx] += h
        X_minus = X.copy()
        X_minus[idx] -= h
        grad[idx] = np.sum((layer.forward(X_plus) - layer.forward(X_minus)) * dout) / (2 * h)
    return grad


def gradient_check(layer, X, dout):
    """Max absolute difference between analytical and numerical input gradients."""
    layer.forward(X)
    analytical = layer.backward(dout)
    return np.max(np.abs(analytical - numerical_input_grad(layer, X, dout)))


# Verify gradients with numerical check
print("\n" + "=" * 50)
print("Gradient check...")

dout = np.random.randn(*out.shape)
error = gradient_check(gn, X, dout)
print(f"Max gradient error: {error:.2e}")
print(f"Gradient check: {'PASSED' if error < 1e-4 else 'FAILED'}")

# Also exercise the 4D (batch, channels, height, width) path with a
# non-trivial gamma/beta; the 2D check above uses gamma=1, beta=0 only.
gn_4d = GroupNorm(num_groups=2, num_channels=4)
gn_4d.gamma = np.random.randn(4)
gn_4d.beta = np.random.randn(4)
X_4d = np.random.randn(2, 4, 3, 3) * 2 + 1
dout_4d = np.random.randn(*X_4d.shape)
error_4d = gradient_check(gn_4d, X_4d, dout_4d)
print(f"\n4D input {X_4d.shape}: max gradient error {error_4d:.2e} "
      f"({'PASSED' if error_4d < 1e-4 else 'FAILED'})")

## Exercise 2 Solution: Normalization Method Comparison

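Train the same MLP (784 → 256 → 128 → 10) on a 5,000-image MNIST training subset with each normalization method. Every run uses identical initialization, shuffling, and hyperparameters, and we compare training and test accuracy.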

In [None]:
# Import or define all normalization layers
class BatchNorm:
    """Batch Normalization over axis 0 for inputs of shape (batch, num_features).

    Training mode (``training = True``) normalizes with the batch mean/variance
    and updates exponential running averages. Eval mode (``training = False``)
    normalizes with the running averages instead.

    Args:
        num_features: Size of the feature axis (length of gamma/beta).
        eps: Small constant added to the variance for numerical stability.
        momentum: Weight of the current batch in the running-average update.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.momentum = momentum
        self.training = True
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        # Gradients, overwritten by every backward call
        self.dgamma = np.zeros_like(self.gamma)
        self.dbeta = np.zeros_like(self.beta)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.cache = {}
    
    def forward(self, X):
        """Normalize X and apply gamma/beta; caches what backward needs."""
        if self.training:
            mean = X.mean(axis=0)
            var = X.var(axis=0)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        
        X_norm = (X - mean) / np.sqrt(var + self.eps)
        # Remember which statistics were used: the backward formula depends on it
        self.cache = {'X': X, 'X_norm': X_norm, 'mean': mean, 'var': var,
                      'training': self.training}
        return self.gamma * X_norm + self.beta
    
    def backward(self, dout):
        """Return dL/dX and set self.dgamma / self.dbeta."""
        X, X_norm, var = self.cache['X'], self.cache['X_norm'], self.cache['var']
        N = X.shape[0]
        self.dgamma = np.sum(dout * X_norm, axis=0)
        self.dbeta = np.sum(dout, axis=0)
        dX_norm = dout * self.gamma
        std = np.sqrt(var + self.eps)
        if not self.cache['training']:
            # Running statistics are constants w.r.t. X, so the layer is affine
            return dX_norm / std
        # Batch statistics depend on every sample in the batch
        dX = (1/N) * (1/std) * (N * dX_norm - np.sum(dX_norm, axis=0) - X_norm * np.sum(dX_norm * X_norm, axis=0))
        return dX


class LayerNorm:
    """Layer Normalization over the last axis.

    Every position along the leading axes is treated as an independent sample:
    it is normalized to zero mean / unit variance over the last axis, then
    scaled by gamma and shifted by beta.

    Args:
        normalized_shape: Shape of gamma/beta (the size of the last axis).
        eps: Small constant added to the variance for numerical stability.
    """
    def __init__(self, normalized_shape, eps=1e-5):
        self.eps = eps
        self.gamma = np.ones(normalized_shape)
        self.beta = np.zeros(normalized_shape)
        # Gradients, overwritten by every backward call
        self.dgamma = np.zeros_like(self.gamma)
        self.dbeta = np.zeros_like(self.beta)
        self.cache = {}
    
    def forward(self, X):
        """Normalize X over its last axis and apply gamma/beta."""
        mean = X.mean(axis=-1, keepdims=True)
        var = X.var(axis=-1, keepdims=True)
        X_norm = (X - mean) / np.sqrt(var + self.eps)
        self.cache = {'X': X, 'X_norm': X_norm, 'mean': mean, 'var': var}
        return self.gamma * X_norm + self.beta
    
    def backward(self, dout):
        """Return dL/dX and set self.dgamma / self.dbeta."""
        X, X_norm, var = self.cache['X'], self.cache['X_norm'], self.cache['var']
        N = X.shape[-1]
        # gamma/beta are shared by every leading position, so reduce over all
        # leading axes (not just axis 0) to get gradients shaped like gamma.
        leading_axes = tuple(range(dout.ndim - 1))
        self.dgamma = np.sum(dout * X_norm, axis=leading_axes)
        self.dbeta = np.sum(dout, axis=leading_axes)
        dX_norm = dout * self.gamma
        std = np.sqrt(var + self.eps)
        dX = (1/N) * (1/std) * (N * dX_norm - np.sum(dX_norm, axis=-1, keepdims=True) - X_norm * np.sum(dX_norm * X_norm, axis=-1, keepdims=True))
        return dX


class RMSNorm:
    """Root-mean-square normalization over the last axis.

    Divides each position by sqrt(mean(X**2) + eps) over the last axis and
    scales by gamma. There is no mean subtraction and no beta (bias) term.

    Args:
        normalized_shape: Shape of gamma (the size of the last axis).
        eps: Small constant added inside the square root for stability.
    """
    def __init__(self, normalized_shape, eps=1e-5):
        self.eps = eps
        self.gamma = np.ones(normalized_shape)
        # Gradient, overwritten by every backward call
        self.dgamma = np.zeros_like(self.gamma)
        self.cache = {}
    
    def forward(self, X):
        """Scale X by its inverse RMS over the last axis, then by gamma."""
        rms = np.sqrt(np.mean(X**2, axis=-1, keepdims=True) + self.eps)
        X_norm = X / rms
        self.cache = {'X': X, 'X_norm': X_norm, 'rms': rms}
        return self.gamma * X_norm
    
    def backward(self, dout):
        """Return dL/dX and set self.dgamma."""
        X_norm, rms = self.cache['X_norm'], self.cache['rms']
        # gamma is shared by every leading position: reduce over all leading axes
        leading_axes = tuple(range(dout.ndim - 1))
        self.dgamma = np.sum(dout * X_norm, axis=leading_axes)
        dX_norm = dout * self.gamma
        dX = (1/rms) * (dX_norm - X_norm * np.mean(dX_norm * X_norm, axis=-1, keepdims=True))
        return dX

In [None]:
# Load MNIST data
import gzip, os, urllib.request

def load_mnist(path='../data', mirrors=None):
    """
    Download (if missing) and load MNIST from gzipped IDX files.

    Args:
        path: Directory used to cache the four MNIST ``.gz`` files.
        mirrors: Base URLs tried in order for each missing file. Defaults to the
            original LeCun URL followed by a fallback mirror in case the
            original host is unavailable.

    Returns:
        (X_train, y_train, X_test, y_test): images as float32 arrays of shape
        (n, 784) scaled to [0, 1]; labels as uint8 arrays of shape (n,).

    Raises:
        OSError: if a missing file cannot be fetched from any mirror (the last
            download error is re-raised).
        ValueError: if a cached file has the wrong IDX magic number or size,
            or if ``mirrors`` is empty when a download is needed.
    """
    if mirrors is None:
        mirrors = ('http://yann.lecun.com/exdb/mnist/',
                   'https://ossci-datasets.s3.amazonaws.com/mnist/')
    os.makedirs(path, exist_ok=True)
    files = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
             't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']

    def download(name, dest):
        # Fetch into a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file at `dest` that
        # later runs would skip re-downloading and then fail to parse.
        if not mirrors:
            raise ValueError("mirrors must contain at least one base URL")
        tmp = dest + '.part'
        last_error = None
        for base in mirrors:
            try:
                urllib.request.urlretrieve(base + name, tmp)
            except OSError as exc:  # URLError / HTTPError are OSError subclasses
                last_error = exc
                continue
            os.replace(tmp, dest)
            return
        if os.path.exists(tmp):
            os.remove(tmp)
        raise last_error

    def read_idx(fp, magic, header_len):
        # IDX header: big-endian int32 magic number followed by the dimension sizes
        with gzip.open(fp) as f:
            header = np.frombuffer(f.read(header_len), dtype='>i4')
            data = np.frombuffer(f.read(), dtype=np.uint8)
        if header.size != header_len // 4 or header[0] != magic:
            raise ValueError(f"{fp}: not an IDX file with magic number {magic}")
        expected = int(np.prod(header[1:]))
        if data.size != expected:
            raise ValueError(f"{fp}: expected {expected} data bytes, found {data.size}")
        return data

    def load_img(fp):
        return read_idx(fp, 2051, 16).reshape(-1, 784).astype(np.float32) / 255

    def load_lbl(fp):
        return read_idx(fp, 2049, 8)

    for name in files:
        dest = os.path.join(path, name)
        if not os.path.exists(dest):
            download(name, dest)

    train_img, train_lbl, test_img, test_lbl = (os.path.join(path, f) for f in files)
    return (load_img(train_img), load_lbl(train_lbl),
            load_img(test_img), load_lbl(test_lbl))

X_train, y_train, X_test, y_test = load_mnist()
# Keep only the first 5,000 training examples so every experiment below runs quickly
n_train_subset = 5000
X_train = X_train[:n_train_subset]
y_train = y_train[:n_train_subset]

In [None]:
class NormalizedMLP:
    """
    MLP with configurable normalization.

    Every hidden layer is Linear -> (optional normalization) -> ReLU; the final
    layer is Linear -> softmax. ``backward`` takes one plain SGD step on the
    mean cross-entropy of the batch passed to the last ``forward`` call.

    Args:
        layer_sizes: Layer widths including input and output, e.g. [784, 256, 128, 10].
        norm_type: 'batch', 'layer', 'rms' or 'group'; any other value
            (e.g. 'none') disables normalization.
    """
    
    def __init__(self, layer_sizes, norm_type='none'):
        self.norm_type = norm_type
        self.layers = []
        self.norms = []
        
        for i in range(len(layer_sizes) - 1):
            # He initialization (scaled for ReLU)
            W = np.random.randn(layer_sizes[i], layer_sizes[i+1]) * np.sqrt(2.0 / layer_sizes[i])
            b = np.zeros(layer_sizes[i+1])
            self.layers.append({'W': W, 'b': b, 'cache': {}})
            
            # Add normalization layer (except for output layer)
            if i < len(layer_sizes) - 2:
                if norm_type == 'batch':
                    self.norms.append(BatchNorm(layer_sizes[i+1]))
                elif norm_type == 'layer':
                    self.norms.append(LayerNorm(layer_sizes[i+1]))
                elif norm_type == 'rms':
                    self.norms.append(RMSNorm(layer_sizes[i+1]))
                elif norm_type == 'group':
                    # Largest group count <= 4 that divides the layer width
                    num_groups = min(4, layer_sizes[i+1])
                    while layer_sizes[i+1] % num_groups != 0:
                        num_groups -= 1
                    self.norms.append(GroupNorm(num_groups, layer_sizes[i+1]))
                else:
                    self.norms.append(None)
    
    def forward(self, X, training=True):
        """
        Run the network and return softmax class probabilities.

        Args:
            X: Input batch of shape (batch, layer_sizes[0]).
            training: Copied onto normalization layers that have a ``training``
                attribute (in this notebook only BatchNorm) before they run.

        Caches per-layer inputs/activations and ``self.probs`` for ``backward``.
        """
        out = X
        
        for i, layer in enumerate(self.layers[:-1]):
            layer['cache']['X'] = out
            out = out @ layer['W'] + layer['b']
            layer['cache']['Z'] = out
            
            # Apply normalization
            if self.norms[i] is not None:
                if hasattr(self.norms[i], 'training'):
                    self.norms[i].training = training
                out = self.norms[i].forward(out)
            
            out = np.maximum(0, out)  # ReLU
            layer['cache']['A'] = out
        
        # Output layer
        self.layers[-1]['cache']['X'] = out
        out = out @ self.layers[-1]['W'] + self.layers[-1]['b']
        
        # Softmax (shifted by the row max for numerical stability)
        out_shifted = out - np.max(out, axis=1, keepdims=True)
        exp_out = np.exp(out_shifted)
        self.probs = exp_out / np.sum(exp_out, axis=1, keepdims=True)
        
        return self.probs
    
    def backward(self, targets, lr):
        """
        Backpropagate the mean cross-entropy of the last forward batch and
        apply one SGD step to all weights, biases and normalization parameters.

        Args:
            targets: Integer class labels for the batch given to the last forward().
            lr: Learning rate.
        """
        batch_size = len(targets)
        # d(mean cross-entropy)/d(logits). Dividing by batch_size once here puts
        # every downstream gradient -- including norm gamma/beta, whose
        # backward() sums over the batch -- on the same mean-loss scale.
        grad = self.probs.copy()
        grad[np.arange(batch_size), targets] -= 1
        grad /= batch_size
        
        last = len(self.layers) - 1
        for i in range(last, -1, -1):
            layer = self.layers[i]
            
            if i < last:
                # ReLU backward: A is the post-ReLU activation, so A > 0 marks active units
                grad = grad * (layer['cache']['A'] > 0)
                
                # Normalization backward + parameter update
                norm = self.norms[i]
                if norm is not None:
                    grad = norm.backward(grad)
                    norm.gamma -= lr * norm.dgamma
                    if hasattr(norm, 'beta'):
                        norm.beta -= lr * norm.dbeta
            
            dW = layer['cache']['X'].T @ grad
            db = grad.sum(axis=0)
            
            # Propagate through W *before* updating it; using the freshly
            # updated weights would give the wrong upstream gradient.
            if i > 0:
                grad = grad @ layer['W'].T
            
            layer['W'] -= lr * dW
            layer['b'] -= lr * db
    
    def predict(self, X):
        """Return predicted class indices (norm layers run with training=False)."""
        return np.argmax(self.forward(X, training=False), axis=1)

In [None]:
# Train one model per normalization method under identical conditions
print("Comparing Normalization Methods")
print("=" * 70)

norm_types = ['none', 'batch', 'layer', 'rms', 'group']
all_histories = {}
n_train = len(X_train)

for norm_type in norm_types:
    print(f"\nTraining with {norm_type.upper()} normalization...")
    # Re-seed so every method starts from the same weights and sees the same shuffles
    np.random.seed(42)

    model = NormalizedMLP([784, 256, 128, 10], norm_type=norm_type)
    history = {'train_acc': [], 'test_acc': []}

    for epoch in range(20):
        shuffled = np.random.permutation(n_train)
        minibatches = (shuffled[lo:lo + 64] for lo in range(0, n_train, 64))
        for batch_idx in minibatches:
            model.forward(X_train[batch_idx], training=True)
            model.backward(y_train[batch_idx], 0.1)

        train_acc = np.mean(model.predict(X_train) == y_train)
        test_acc = np.mean(model.predict(X_test[:2000]) == y_test[:2000])
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)

        # Progress line on every 5th epoch
        if (epoch + 1) % 5 == 0:
            print(f"  Epoch {epoch+1:2d} | Train: {train_acc:.2%} | Test: {test_acc:.2%}")

    all_histories[norm_type] = history
    print(f"  Final test accuracy: {test_acc:.2%}")

In [None]:
# Visualize comparison: training accuracy (left panel), test accuracy (right panel)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

colors = {'none': 'gray', 'batch': 'blue', 'layer': 'green', 'rms': 'red', 'group': 'purple'}
panels = [
    (axes[0], 'train_acc', 'Training Accuracy by Normalization Method'),
    (axes[1], 'test_acc', 'Test Accuracy by Normalization Method'),
]

for ax, metric, title in panels:
    for norm_type, history in all_histories.items():
        curve = history[metric]
        ax.plot(range(1, len(curve) + 1), curve, color=colors[norm_type],
                label=f'{norm_type.upper()}', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary table of the last-epoch accuracies
rule = "=" * 50
print("\n" + rule)
print("Final Results Summary")
print(rule)
print(f"{'Method':<12} {'Final Train':<15} {'Final Test':<15}")
print("-" * 50)
for norm_type, history in all_histories.items():
    print(f"{norm_type.upper():<12} {history['train_acc'][-1]:<15.2%} {history['test_acc'][-1]:<15.2%}")

## Exercise 3 Solution: Small Batch Training Comparison

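BatchNorm struggles with small batches because its mean and variance are estimated from only a few samples, which makes them noisy. The running averages it uses at inference are built from those same noisy estimates.
Let's compare normalization methods with batch size = 4.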

In [None]:
# Small batch training experiment
print("Small Batch Training (batch_size=4)")
print("=" * 70)

small_batch_histories = {}
batch_size = 4  # Very small batch size

for norm_type in ['none', 'batch', 'layer', 'group']:
    print(f"\nTraining with {norm_type.upper()} normalization (batch_size={batch_size})...")
    np.random.seed(42)  # same initial weights and shuffles for every method
    
    model = NormalizedMLP([784, 128, 64, 10], norm_type=norm_type)
    history = {'train_acc': [], 'test_acc': []}
    
    for epoch in range(15):
        indices = np.random.permutation(len(X_train))
        for start in range(0, len(X_train), batch_size):
            batch_idx = indices[start:start+batch_size]
            if len(batch_idx) < batch_size:
                continue  # Skip incomplete batches
            model.forward(X_train[batch_idx], training=True)
            model.backward(y_train[batch_idx], 0.05)  # Lower LR for small batches
        
        train_acc = np.mean(model.predict(X_train) == y_train)
        test_acc = np.mean(model.predict(X_test[:2000]) == y_test[:2000])
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
    
    small_batch_histories[norm_type] = history
    print(f"  Final test accuracy: {test_acc:.2%}")

# Plot small batch comparison. Epochs are numbered from 1 and each method keeps
# the same colour as in the Exercise 2 figure, so the two plots line up.
small_batch_colors = {'none': 'gray', 'batch': 'blue', 'layer': 'green', 'group': 'purple'}
fig, ax = plt.subplots(figsize=(10, 5))
for norm_type, history in small_batch_histories.items():
    epochs = range(1, len(history['test_acc']) + 1)
    ax.plot(epochs, history['test_acc'], color=small_batch_colors[norm_type],
            label=norm_type.upper(), linewidth=2)

ax.set_xlabel('Epoch')
ax.set_ylabel('Test Accuracy')
ax.set_title(f'Small Batch Training (batch_size={batch_size}): BatchNorm vs LayerNorm vs GroupNorm')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

# Final numbers from this run, so the insight below can be checked against them
print("\nFinal test accuracy by method:")
for norm_type, history in small_batch_histories.items():
    print(f"  {norm_type.upper():<6} {history['test_acc'][-1]:.2%}")

print("\n" + "=" * 50)
print("Key Insight: BatchNorm struggles with small batches!")
print("LayerNorm and GroupNorm are more stable because they")
print("normalize within each sample, not across the batch.")
print("=" * 50)

---

## Key Takeaways

1. **BatchNorm** - Best for large batches, CNN training
2. **LayerNorm** - Best for transformers, RNNs, small batches
3. **GroupNorm** - Good middle ground, works well in vision with small batches
4. **RMSNorm** - Simpler and cheaper than LayerNorm (no mean subtraction, no bias term); used in modern LLMs (Llama, etc.)

### When to use what:
- **Large batch CNN training**: BatchNorm
- **Transformers/NLP**: LayerNorm or RMSNorm
- **Small batch or instance-level processing**: LayerNorm or GroupNorm
- **Maximum efficiency in LLMs**: RMSNorm