# Day 12: Batch Normalization — Taming Internal Covariate Shift

**Building LLMs from Scratch** — Following Andrej Karpathy's makemore lectures.

---

## 1. Introduction

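**Batch Normalization** (Ioffe & Szegedy, 2015) was motivated by **internal covariate shift**: as parameters update during training, the distribution of activations at each layer shifts, making it harder for deeper layers to learn. BatchNorm normalizes each feature dimension to zero mean and unit variance using the statistics of the current mini-batch. It then applies a learned scale (γ) and shift (β), so the network can still represent whatever mean and variance it needs. (Later work, e.g. Santurkar et al., 2018, argues that the benefit comes mainly from a smoother optimization landscape rather than from reduced covariate shift.) This: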

- **Stabilizes training**: Gradients flow more smoothly; higher learning rates become feasible
- **Reduces sensitivity to initialization**: Networks converge faster and more reliably
- **Acts as mild regularization**: each example is normalized with statistics that depend on the other examples in its mini-batch, which injects noise during training

We implement BatchNorm1d from scratch and compare MLPs with and without it.

## 2. BatchNorm1d from Scratch

Implement a `BatchNorm1d` class:
- `gamma` (scale), `beta` (shift): learnable parameters
- `running_mean`, `running_var`: exponential moving averages for inference
- **Training**: use batch mean/var, update running stats
- **Eval**: use running stats (no batch dependency)

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


class BatchNorm1d:
    """Batch normalization over the feature dimension of a (batch, dim) input.

    Matches the arithmetic of ``torch.nn.BatchNorm1d`` for 2-D inputs:
    activations are normalized with the *biased* batch variance, while
    ``running_var`` accumulates the *unbiased* batch variance. The running
    statistics are what eval mode uses, so they should estimate the
    population variance.

    Args:
        dim: number of features (size of the last input dimension).
        eps: added to the variance before the square root for numerical stability.
        momentum: EMA weight given to the newest batch statistics.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.dim = dim
        self.eps = eps
        self.momentum = momentum
        # Learnable scale and shift
        self.gamma = torch.ones(dim, requires_grad=True)
        self.beta = torch.zeros(dim, requires_grad=True)
        # Running statistics (not learned, updated during training)
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x, training=True):
        """Normalize ``x`` (shape (batch, dim)) and apply the affine transform.

        With ``training=True`` the batch mean/variance are used and the running
        statistics are updated in place on this object. With ``training=False``
        the running statistics are used and nothing is mutated.
        """
        if training:
            batch_mean = x.mean(dim=0)
            # Normalize with the biased (population) variance of the batch.
            batch_var = x.var(dim=0, unbiased=False)
            with torch.no_grad():
                batch_size = x.shape[0]
                # running_var estimates the population variance, so apply
                # Bessel's correction n/(n-1) as torch.nn.BatchNorm1d does.
                # A single-sample batch has zero spread; keep it uncorrected
                # instead of dividing by zero.
                correction = batch_size / (batch_size - 1) if batch_size > 1 else 1.0
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var * correction
            mean, var = batch_mean, batch_var
        else:
            mean = self.running_mean
            var = self.running_var
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

    def parameters(self):
        """Return the learnable tensors [gamma, beta]."""
        return [self.gamma, self.beta]


# Sanity check: in training mode with gamma=1, beta=0, every output feature should
# have mean ~0 and std ~1. The std is measured with the biased estimator, the same one
# BatchNorm normalizes with; the default unbiased std would read ~sqrt(32/31) ≈ 1.016.
bn = BatchNorm1d(10)
x = torch.randn(32, 10)
out = bn(x, training=True)
out_mean = out.mean(0).detach()
out_std = out.std(0, unbiased=False).detach()
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Output mean (per dim): {out_mean[:5].tolist()}...")
print(f"Output std (per dim): {out_std[:5].tolist()}...")
print(f"Parameters: {[p.shape for p in bn.parameters()]}")
assert torch.allclose(out_mean, torch.zeros(10), atol=1e-4)
assert torch.allclose(out_std, torch.ones(10), atol=1e-3)

## 3. Dataset & Model

Same names dataset as before. Build an MLP with BatchNorm applied to the hidden layer's pre-activations (after the linear transform, before tanh). We'll also define a version without BatchNorm for comparison.

In [None]:
torch.manual_seed(42)

words = ['emma', 'olivia', 'ava', 'isabella', 'sophia', 'mia', 'charlotte', 'amelia', 'harper', 'evelyn',
         'abigail', 'emily', 'ella', 'elizabeth', 'camila', 'luna', 'sofia', 'avery', 'mila', 'aria']

chars = sorted(set(''.join(words)))
stoi = {'.': 0, **{c: i + 1 for i, c in enumerate(chars)}}
itos = {i: c for c, i in stoi.items()}

block_size = 3
X, Y = [], []
for w in words:
    # Left-pad with block_size '.' tokens (makemore convention). The first example of
    # every name is then '...' -> first letter, and each name yields len(w) + 1 examples.
    # With a single leading '.', the model would never see how names start.
    chs = ['.'] * block_size + list(w) + ['.']
    for i in range(len(chs) - block_size):
        context = chs[i:i + block_size]
        target = chs[i + block_size]
        X.append([stoi[c] for c in context])
        Y.append(stoi[target])

X = torch.tensor(X)
Y = torch.tensor(Y)

n, block_size = X.shape[0], X.shape[1]
vocab_size = len(stoi)
emb_dim = 10
hidden = 200

print(f"Dataset: {len(words)} names, {n} examples")
print(f"block_size={block_size}, vocab_size={vocab_size}, hidden={hidden}")

In [None]:
def init_params(use_bn=False):
    """Create a fresh set of parameters for the one-hidden-layer MLP.

    Every tensor is drawn from an unscaled standard normal. With a fan-in of
    block_size * emb_dim this makes the hidden pre-activations large, so tanh is
    likely to saturate. The naive init is deliberate: it gives BatchNorm
    something to fix.

    When ``use_bn`` is True, ``b1`` is effectively redundant. BatchNorm subtracts
    the batch mean right after the bias is added, which cancels any constant
    offset.

    Returns:
        (C, W1, b1, W2, b2, bn1, params): bn1 is a BatchNorm1d (or None without
        BN), and params is the flat list of tensors the SGD loop updates.
    """
    # Draw order (C, W1, b1, W2, b2) fixes which random numbers each tensor gets.
    shapes = [
        (vocab_size, emb_dim),
        (block_size * emb_dim, hidden),
        (hidden,),
        (hidden, vocab_size),
        (vocab_size,),
    ]
    C, W1, b1, W2, b2 = (torch.randn(*shape, requires_grad=True) for shape in shapes)

    bn1 = BatchNorm1d(hidden) if use_bn else None
    bn_params = bn1.parameters() if bn1 is not None else []
    params = [C, W1, b1, W2, b2] + bn_params
    return C, W1, b1, W2, b2, bn1, params


def forward_no_bn(X, Y, C, W1, b1, W2, b2, return_h=False):
    """Forward pass of the plain MLP: embed -> linear -> tanh -> linear -> cross-entropy.

    Args:
        X: context indices into the embedding table C.
        Y: target indices for the cross-entropy loss.
        C, W1, b1, W2, b2: embedding table and layer weights/biases.
        return_h: also return the post-tanh hidden activations.

    Returns:
        loss, or (loss, h) when return_h is True.
    """
    flat_emb = C[X].view(-1, block_size * emb_dim)
    post_tanh = torch.tanh(flat_emb @ W1 + b1)
    loss = F.cross_entropy(post_tanh @ W2 + b2, Y)
    if not return_h:
        return loss
    return loss, post_tanh


def forward_bn(X, Y, C, W1, b1, bn1, W2, b2, training=True, return_h=False):
    """Forward pass of the MLP with BatchNorm between the hidden linear layer and tanh.

    Args:
        X: context indices into the embedding table C.
        Y: target indices for the cross-entropy loss.
        C, W1, b1, W2, b2: embedding table and layer weights/biases.
        bn1: normalization layer applied to the hidden pre-activations.
        training: forwarded to bn1. With this file's BatchNorm1d, True uses
            batch statistics (and updates bn1's running stats); False uses the
            running statistics.
        return_h: also return the post-tanh activations and the pre-BN activations.

    Returns:
        loss, or (loss, h, preact) when return_h is True.
    """
    flat_emb = C[X].view(-1, block_size * emb_dim)
    pre_bn = flat_emb @ W1 + b1
    post_tanh = torch.tanh(bn1(pre_bn, training=training))
    loss = F.cross_entropy(post_tanh @ W2 + b2, Y)
    if not return_h:
        return loss
    return loss, post_tanh, pre_bn


# Smoke test: one training-mode forward pass of a freshly initialised BN model over the full dataset
C, W1, b1, W2, b2, bn1, params_bn = init_params(use_bn=True)
loss = forward_bn(
    X, Y,
    C=C, W1=W1, b1=b1, bn1=bn1, W2=W2, b2=b2,
    training=True,
)
print(f"Forward (with BN): loss = {loss.item():.4f}")

## 4. Before vs After BatchNorm

Train two models: one without BN, one with BN. Compare activation distributions and training loss curves.

In [None]:
def train_model(use_bn, steps=5000, batch_size=32, lr=0.1, seed=42):
    """Train one MLP variant with plain minibatch SGD.

    Args:
        use_bn: train the BatchNorm variant (forward_bn) instead of forward_no_bn.
        steps: number of SGD steps.
        batch_size: examples per minibatch, sampled uniformly with replacement.
        lr: constant learning rate.
        seed: RNG seed set before initialisation and batch sampling.

    Returns:
        (losses, model): losses holds the per-step minibatch loss values, and
        model is the tuple (C, W1, b1, bn1, W2, b2), with bn1 None when use_bn is False.
    """
    # Reseeding means both variants draw identical initial weights and the same
    # minibatch sequence (BatchNorm1d's init consumes no randomness), so BN is the
    # only difference between runs.
    torch.manual_seed(seed)
    C, W1, b1, W2, b2, bn1, params = init_params(use_bn=use_bn)
    losses = []
    for _ in range(steps):
        ix = torch.randint(0, n, (batch_size,))
        Xb, Yb = X[ix], Y[ix]
        if use_bn:
            loss = forward_bn(Xb, Yb, C, W1, b1, bn1, W2, b2, training=True)
        else:
            loss = forward_no_bn(Xb, Yb, C, W1, b1, W2, b2)
        losses.append(loss.item())
        for p in params:
            p.grad = None
        loss.backward()
        for p in params:
            p.data -= lr * p.grad
    # init_params already returns bn1=None without BN, so one tuple covers both cases.
    return losses, (C, W1, b1, bn1, W2, b2)


# Train both variants back to back; train_model reseeds, so they share init and batches.
runs = []
for flag, banner in ((False, "Training without BatchNorm..."), (True, "Training with BatchNorm...")):
    print(banner)
    runs.append(train_model(use_bn=flag))
(losses_no_bn, model_no_bn), (losses_bn, model_bn) = runs
print("Done.")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# Each panel: (axis, slice of the loss history to show, x-label, title)
loss_panels = [
    (axes[0], slice(None), 'Step', 'Training Loss Curves'),
    (axes[1], slice(-500, None), 'Step (last 500)', 'Loss (Zoomed)'),
]
for ax, window, xlabel, title in loss_panels:
    ax.plot(losses_no_bn[window], alpha=0.8, label='No BatchNorm')
    ax.plot(losses_bn[window], alpha=0.8, label='With BatchNorm')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Activation distributions: collect hidden activations from a fresh model at init and after a few steps
def get_activation_histograms(use_bn, num_steps=100, num_batches=20, batch_size=32, lr=0.1):
    """Sample post-tanh hidden activations before and after a short burst of SGD.

    Args:
        use_bn: probe the BatchNorm variant (forward_bn, training mode) instead of forward_no_bn.
        num_steps: SGD steps taken between the two snapshots.
        num_batches: number of random minibatches pooled into each snapshot.
        batch_size: minibatch size used for both sampling and training.
        lr: SGD learning rate.

    Returns:
        (acts_init, acts_after): 1-D tensors of flattened hidden activations.
    """
    torch.manual_seed(42)
    C, W1, b1, W2, b2, bn1, params = init_params(use_bn=use_bn)

    def _sample_batch():
        # Random minibatch (with replacement) from the global dataset X, Y.
        ix = torch.randint(0, n, (batch_size,))
        return X[ix], Y[ix]

    def _forward(Xb, Yb):
        # (loss, hidden activations) for whichever variant is being probed.
        if use_bn:
            loss, h, _ = forward_bn(Xb, Yb, C, W1, b1, bn1, W2, b2, training=True, return_h=True)
            return loss, h
        return forward_no_bn(Xb, Yb, C, W1, b1, W2, b2, return_h=True)

    def _snapshot():
        # Pool flattened activations over num_batches fresh minibatches; no autograd needed.
        with torch.no_grad():
            return torch.cat([_forward(*_sample_batch())[1].flatten() for _ in range(num_batches)])

    acts_init = _snapshot()

    for _ in range(num_steps):
        # Exactly one forward pass per step. A second pass would double the compute and
        # (for BN) fold every batch into the running statistics twice.
        loss, _ = _forward(*_sample_batch())
        for p in params:
            p.grad = None
        loss.backward()
        for p in params:
            p.data -= lr * p.grad

    acts_after = _snapshot()
    return acts_init, acts_after


acts_no_bn_init, acts_no_bn_after = get_activation_histograms(use_bn=False)
acts_bn_init, acts_bn_after = get_activation_histograms(use_bn=True)

fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# (row, col) -> (activations, bar colour, panel title); top row without BN, bottom row with BN
hist_panels = {
    (0, 0): (acts_no_bn_init, 'steelblue', 'No BN: Activations at Init'),
    (0, 1): (acts_no_bn_after, 'steelblue', 'No BN: Activations After 100 Steps'),
    (1, 0): (acts_bn_init, 'coral', 'With BN: Activations at Init'),
    (1, 1): (acts_bn_after, 'coral', 'With BN: Activations After 100 Steps'),
}
for (row, col), (acts, colour, title) in hist_panels.items():
    ax = axes[row, col]
    ax.hist(acts.numpy(), bins=50, alpha=0.7, color=colour, edgecolor='black', density=True)
    ax.set_title(title)
    ax.set_xlabel('Activation value')

plt.suptitle('Activation Distributions at Hidden Layer')
plt.tight_layout()
plt.show()

## 5. Train vs Eval Mode

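During **training**, BatchNorm normalizes with the current batch's mean and variance; during **eval**, it uses `running_mean` and `running_var`. Using the right mode matters. In training mode each prediction depends on the other examples in the batch, and for a batch of one the variance is zero, so every normalized activation collapses to β regardless of the input.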

In [None]:
# Use the trained model with BN
C, W1, b1, bn1, W2, b2 = model_bn


def _loss_with_batch_stats(xb, yb):
    """Training-mode (batch-statistics) loss that leaves bn1's running stats untouched.

    bn1(..., training=True) always folds the batch into running_mean/running_var,
    even inside torch.no_grad(). Without restoring them, the batch-of-1 call below
    (batch variance 0) would corrupt the very running stats the eval-mode losses use.
    """
    saved_mean, saved_var = bn1.running_mean.clone(), bn1.running_var.clone()
    try:
        return forward_bn(xb, yb, C, W1, b1, bn1, W2, b2, training=True)
    finally:
        bn1.running_mean, bn1.running_var = saved_mean, saved_var


# Single example (batch_size=1) — batch stats would be degenerate!
x_single = X[:1]
y_single = Y[:1]

with torch.no_grad():
    loss_train_mode = _loss_with_batch_stats(x_single, y_single)
    loss_eval_mode = forward_bn(x_single, y_single, C, W1, b1, bn1, W2, b2, training=False)

print(f"Batch size 1:")
print(f"  training=True  (batch stats): loss = {loss_train_mode.item():.4f}")
print(f"  training=False (running stats): loss = {loss_eval_mode.item():.4f}")
print(f"  Difference: {abs(loss_train_mode.item() - loss_eval_mode.item()):.4f}")

# With full batch, they should be closer
with torch.no_grad():
    loss_train_full = _loss_with_batch_stats(X, Y)
    loss_eval_full = forward_bn(X, Y, C, W1, b1, bn1, W2, b2, training=False)

print(f"\nFull batch:")
print(f"  training=True:  loss = {loss_train_full.item():.4f}")
print(f"  training=False: loss = {loss_eval_full.item():.4f}")
print(f"  Difference: {abs(loss_train_full.item() - loss_eval_full.item()):.4f}")
print("\n→ Always use training=False at inference!")

## 6. The Running Mean Update

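Running stats are updated via an exponential moving average with momentum $m$ (0.1 by default):
$$\hat{\mu} \leftarrow (1 - m)\,\hat{\mu} + m\,\mu_{\mathcal{B}}, \qquad \hat{\sigma}^2 \leftarrow (1 - m)\,\hat{\sigma}^2 + m\,\sigma^2_{\mathcal{B}}$$
where $\mu_{\mathcal{B}}$ and $\sigma^2_{\mathcal{B}}$ are the current batch's mean and variance, and $\hat{\mu}$, $\hat{\sigma}^2$ are `running_mean` and `running_var`.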

Visualize how `running_mean` converges during training.

In [None]:
torch.manual_seed(42)
C, W1, b1, W2, b2, bn1, params = init_params(use_bn=True)

momentum = 0.1
bn1.momentum = momentum  # make the EMA rate explicit (same as BatchNorm1d's default)
running_means = []  # first 5 dims of running_mean after each step's update
batch_means_history = []  # first 5 dims of that step's batch mean of the pre-BN activations

for step in range(500):
    ix = torch.randint(0, n, (32,))
    Xb, Yb = X[ix], Y[ix]
    # A single BN forward pass per step, so running_mean gets exactly one EMA update
    # per step. Calling bn1 separately before forward_bn would apply it twice.
    loss, _, preact = forward_bn(Xb, Yb, C, W1, b1, bn1, W2, b2, training=True, return_h=True)
    batch_means_history.append(preact[:, :5].mean(dim=0).detach())
    running_means.append(bn1.running_mean[:5].clone())
    for p in params:
        p.grad = None
    loss.backward()
    for p in params:
        p.data -= 0.1 * p.grad

running_means = torch.stack(running_means)  # (500, 5)
batch_means_history = torch.stack(batch_means_history)

fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

for d in range(5):
    axes[0].plot(running_means[:, d].numpy(), label=f'dim {d}', alpha=0.8)
axes[0].set_ylabel('running_mean')
axes[0].set_title('Running Mean Convergence (first 5 dims)')
axes[0].legend(loc='upper right', fontsize=8)
axes[0].grid(True, alpha=0.3)

for d in range(5):
    axes[1].plot(batch_means_history[:, d].numpy(), label=f'dim {d}', alpha=0.8)
axes[1].set_xlabel('Step')
axes[1].set_ylabel('batch_mean')
axes[1].set_title('Batch Mean (noisy per step)')
axes[1].legend(loc='upper right', fontsize=8)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
print("Running mean smooths out batch statistics → stable inference.")

---

**Building LLMs from Scratch** — [Day 12: Batch Normalization](https://omkarray.com/llm-day12.html) | [← Prev](llm_day11.ipynb) | [Next →](llm_day13.ipynb)