# Tutorial: makemore (part 3) — init, activation stats, and batch norm (exercise-based)

This notebook turns the **video transcript** + the **original companion notebook** into a guided, exercise-based tutorial.

**Format per exercise**:  
1) Motivation / concept explanation  
2) Exercise (you do it)  
3) Solution (check yourself)

We’ll start from “nothing” and end with:
- a working character-level MLP language model,
- a deep “PyTorchified” MLP,
- careful diagnosis of **logits**, **tanh saturation**, and **gradient flow**,
- principled initialization (fan-in scaling + gain),
- batch norm (including running stats + inference behavior),
- diagnostic plots (activations, gradients, update-to-data ratios).

> Notes:
> - This is **character-level** language modeling on the popular `names.txt` dataset used in makemore.
> - Training to full convergence can take time on CPU. Default settings are chosen to be reasonable; you can scale them up.

---

In [None]:
# Exercise 0 (Setup): imports, reproducibility, and data download helpers

import os
import math
import random
from dataclasses import dataclass

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# notebook niceties
%matplotlib inline
plt.rcParams["figure.figsize"] = (8, 4)

# reproducibility
GLOBAL_SEED = 2147483647
g = torch.Generator().manual_seed(GLOBAL_SEED)
random.seed(42)

def get_device():
    # Use the GPU when CUDA is available; otherwise fall back to the CPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = get_device()
device

## Exercise 1 — Load words and build a vocabulary

### Motivation / concept
We will model the next character in a name given the previous `block_size` characters.  
To do this we need:
- the list of training words (names),
- a character vocabulary,
- integer encodings (`stoi`, `itos`).

In makemore, the special token `'.'` does double duty:
- as a **start** marker (the initial context is padded with dots),
- as an **end** marker (every name is terminated with a dot).
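
The encoding can be sketched on a toy word list before touching the real file. The three words below are a hypothetical stand-in for `names.txt`, chosen only to keep the printed mapping short.

```python
# Toy stand-in for names.txt (hypothetical sample, not the real dataset).
toy_words = ["emma", "ava", "mia"]

chars = sorted(set("".join(toy_words)))            # unique letters, sorted
stoi = {ch: i + 1 for i, ch in enumerate(chars)}   # letters get indices 1..N
stoi["."] = 0                                      # index 0 is reserved for '.'
itos = {i: ch for ch, i in stoi.items()}           # inverse mapping

encoded = [stoi[ch] for ch in "emma."]             # name plus end marker
decoded = "".join(itos[i] for i in encoded)
print(stoi)
print(encoded, "->", decoded)
```

Because `itos` is the exact inverse of `stoi`, decoding an encoded name returns the original string, which is a cheap sanity check to run on the real vocabulary too.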

### Exercise
1) Load the file `names.txt`. If it doesn’t exist, download it.
2) Print:
   - first 8 words,
   - total number of words,
   - vocabulary size,
   - `stoi`/`itos` for sanity.

### Solution
Run the next cell.


In [None]:
# Solution: load words + vocabulary

# If names.txt is missing, download a copy from the original makemore dataset location.
# (If your environment has no internet, you can place names.txt next to this notebook.)
NAMES_PATH = "names.txt"
if not os.path.exists(NAMES_PATH):
    import urllib.request
    url = "https://raw.githubusercontent.com/karpathy/makemore/master/names.txt"
    print(f"Downloading {url} -> {NAMES_PATH}")
    urllib.request.urlretrieve(url, NAMES_PATH)

words = open(NAMES_PATH, "r", encoding="utf-8").read().splitlines()

print("First 8 words:", words[:8])
print("Number of words:", len(words))

# build the vocabulary
chars = sorted(list(set("".join(words))))
stoi = {ch:i+1 for i,ch in enumerate(chars)}  # reserve 0 for '.'
stoi["."] = 0
itos = {i:ch for ch,i in stoi.items()}
vocab_size = len(itos)

print("vocab_size:", vocab_size)
print("itos:", itos)

## Exercise 2 — Build the dataset (X, Y) with a context window

### Motivation / concept
Each training example is:
- `X`: a length-`block_size` context of character indices
- `Y`: the next character index

Example with `block_size=3` for word `"emma"`:

- context `...` → target `e`
- context `..e` → target `m`
- context `.em` → target `m`
- context `emm` → target `a`
- context `mma` → target `.` (end token)
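
The sliding window above can be reproduced in a few lines of plain Python. This sketch works on characters rather than integer indices so the output is readable; the helper name `contexts_for` is made up for illustration.

```python
def contexts_for(word, block_size=3):
    """Return (context, target) string pairs for one word, using '.' as padding and end marker."""
    pairs = []
    context = ["."] * block_size          # start with an all-dots context
    for ch in word + ".":                 # the trailing '.' is the end token
        pairs.append(("".join(context), ch))
        context = context[1:] + [ch]      # drop the oldest char, append the new one
    return pairs

pairs = contexts_for("emma")
for ctx, target in pairs:
    print(f"{ctx} -> {target}")
```

A word of length `n` always yields `n + 1` examples, because the end token is itself a prediction target.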

We will create three splits:
- train (80%), dev/val (10%), test (10%)

### Exercise
Implement `build_dataset(words_subset, block_size)`:
- iterate through each word + `'.'`
- keep a sliding window context list
- collect X and Y tensors

### Solution
Run the cell. Also verify the shapes.


In [None]:
# Solution: dataset builder

block_size = 3  # context length

def build_dataset(words_subset, block_size):
    X, Y = [], []
    for w in words_subset:
        context = [0] * block_size
        for ch in w + ".":
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]  # shift left, append
    X = torch.tensor(X, dtype=torch.long)
    Y = torch.tensor(Y, dtype=torch.long)
    return X, Y

# split the words reproducibly
words_shuf = words[:]
random.shuffle(words_shuf)
n1 = int(0.8 * len(words_shuf))
n2 = int(0.9 * len(words_shuf))

Xtr, Ytr = build_dataset(words_shuf[:n1], block_size)
Xdev, Ydev = build_dataset(words_shuf[n1:n2], block_size)
Xte, Yte = build_dataset(words_shuf[n2:], block_size)

print("train:", Xtr.shape, Ytr.shape)
print("dev:  ", Xdev.shape, Ydev.shape)
print("test: ", Xte.shape, Yte.shape)

## Exercise 3 — Implement the baseline MLP forward pass (embedding → tanh → logits)

### Motivation / concept
We’ll implement the exact MLP structure from the original notebook:

1) Embed each character index into a learnable vector (embedding table `C`)
2) Concatenate the `block_size` embeddings into one vector
3) Linear layer → `tanh` nonlinearity
4) Linear layer to produce logits for all characters
5) Cross-entropy loss
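
To get a feel for the size of this model, the parameter count can be worked out from the tensor shapes alone. The sketch below assumes `vocab_size=27` and the default sizes used in this notebook (`n_embd=10`, `n_hidden=200`, `block_size=3`); change them and the total changes accordingly.

```python
import math

# Assumed sizes (notebook defaults, with vocab_size = 26 letters + '.').
vocab_size, n_embd, n_hidden, block_size = 27, 10, 200, 3

shapes = {
    "C":  (vocab_size, n_embd),             # embedding table
    "W1": (n_embd * block_size, n_hidden),  # concatenated embeddings -> hidden
    "b1": (n_hidden,),
    "W2": (n_hidden, vocab_size),           # hidden -> logits
    "b2": (vocab_size,),
}

counts = {name: math.prod(shape) for name, shape in shapes.items()}
total = sum(counts.values())
print(counts)
print("total parameters:", total)
```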

### Exercise
Create a function `forward_mlp(Xb, Yb, params)` that returns `(logits, loss)` where:
- `Xb` is a minibatch of contexts (shape `(B, block_size)`)
- `Yb` is the matching minibatch of targets (shape `(B,)`)
- `params` contains `C`, `W1`, an optional `b1`, `W2`, and `b2`

We’ll start with a **single hidden layer**.

### Solution
Run the next cell.


In [None]:
# Solution: baseline single-hidden-layer MLP forward pass

@dataclass
class MLPParams:
    C: torch.Tensor
    W1: torch.Tensor
    b1: torch.Tensor | None
    W2: torch.Tensor
    b2: torch.Tensor

def init_baseline_params(vocab_size, n_embd=10, n_hidden=200, block_size=3, *,
                         seed=GLOBAL_SEED, device=device):
    # Sample with a CPU generator for reproducibility, then move to `device`
    # (a CPU generator cannot be used to sample directly on a CUDA device).
    g_local = torch.Generator(device="cpu").manual_seed(seed)
    C  = torch.randn((vocab_size, n_embd), generator=g_local).to(device)
    W1 = torch.randn((n_embd*block_size, n_hidden), generator=g_local).to(device)
    b1 = torch.randn((n_hidden,), generator=g_local).to(device)  # (we'll later discuss why this can be problematic)
    W2 = torch.randn((n_hidden, vocab_size), generator=g_local).to(device)
    b2 = torch.randn((vocab_size,), generator=g_local).to(device)
    return MLPParams(C=C, W1=W1, b1=b1, W2=W2, b2=b2)

def forward_mlp(Xb, Yb, p: MLPParams, block_size=3):
    # Xb: (B, block_size) of integer indices
    emb = p.C[Xb]                 # (B, block_size, n_embd)
    embcat = emb.view(emb.shape[0], -1)  # (B, block_size * n_embd)
    hpre = embcat @ p.W1
    if p.b1 is not None:
        hpre = hpre + p.b1
    h = torch.tanh(hpre)          # (B, n_hidden)
    logits = h @ p.W2 + p.b2      # (B, vocab_size)
    loss = F.cross_entropy(logits, Yb)
    return logits, loss

# quick sanity check on a small batch
p0 = init_baseline_params(vocab_size=vocab_size, device=device)
# sample indices on the CPU (g is a CPU generator and Xtr lives on the CPU), then move the batch
ix = torch.randint(0, Xtr.shape[0], (32,), generator=g)
Xb, Yb = Xtr[ix].to(device), Ytr[ix].to(device)
logits, loss = forward_mlp(Xb, Yb, p0, block_size=block_size)
logits.shape, loss.item()

## Exercise 4 — What loss should we expect at initialization?

### Motivation / concept
At initialization, we typically want the network to behave like:
- it has **no idea** what the correct next character is,
- so it outputs an approximately **uniform** distribution over the `vocab_size` characters.

If `p = 1/vocab_size`, then the negative log-likelihood loss is `-log(p)`.

For `vocab_size=27`, that’s about `3.296`.

If you see a loss like **27** at iteration 0, it’s a big red flag: the network is **very confident and very wrong**, usually due to extreme logits.
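
The two numbers above can be checked with the standard library alone. This is just the arithmetic, assuming `vocab_size=27`; it also shows what a loss of 27 would mean in terms of the probability assigned to the correct character.

```python
import math

vocab_size = 27  # 26 letters + '.'

# Loss of a perfectly uniform prediction: -log(1/27) == log(27).
uniform_loss = -math.log(1.0 / vocab_size)
print(f"uniform loss: {uniform_loss:.4f}")

# A loss of 27 corresponds to giving the correct character probability exp(-27).
p_correct_at_loss_27 = math.exp(-27.0)
print(f"p(correct) at loss 27: {p_correct_at_loss_27:.2e}")
```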

### Exercise
Compute:
- expected init loss = `-log(1/vocab_size)`
- compare it with the actual first-batch loss from the baseline init

### Solution
Run this cell.


In [None]:
# Solution: expected init loss vs actual init loss

expected_loss = -torch.log(torch.tensor(1.0 / vocab_size)).item()
print("vocab_size:", vocab_size)
print("Expected loss at init (uniform):", expected_loss)

# take a fresh init and compute first-batch loss
p_bad = init_baseline_params(vocab_size=vocab_size, device=device)
logits0, loss0 = forward_mlp(Xb, Yb, p_bad, block_size=block_size)
print("Actual first-batch loss:", loss0.item())

# inspect the scale of logits (first row)
print("logits0[0] min/max:", logits0[0].min().item(), logits0[0].max().item())

## Exercise 5 — Toy demo: extreme logits cause “confidently wrong” loss explosions

### Motivation / concept
Softmax converts logits to probabilities:

- If all logits are equal (e.g. all zeros) → uniform probabilities → loss ≈ `log(vocab_size)`.
- If logits are extreme → one class gets probability ≈ 1, others ≈ 0.
  - If that class is incorrect → loss becomes huge.

This is why **logit scale** matters for stable initialization.
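
The effect is easy to see with hand-picked numbers. In the plain-Python sketch below, the logits and the target are made up for illustration, and the target is deliberately the class with the lowest logit, so scaling the logits up makes the model more confidently wrong.

```python
import math

def cross_entropy(logits, target):
    """Negative log-probability of `target` under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

base = [2.0, 0.0, -1.0, 0.5]  # hypothetical logits for a 4-class problem
target = 2                    # the class with the lowest logit

uniform = cross_entropy([0.0] * 4, target)  # all-equal logits -> log(4)
losses = {s: cross_entropy([s * z for z in base], target) for s in (1, 10, 50)}

print(f"uniform: {uniform:.4f}")
for s, loss in losses.items():
    print(f"scale={s:>2}: loss={loss:.4f}")
```

For large scales the loss approaches `scale * (max_logit - target_logit)`, so it grows roughly linearly with the logit scale.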

### Exercise
For a toy 4-class problem:
1) Sample random logits.
2) Compute loss for different scaling factors: `1x`, `10x`, `50x`.
3) Observe how loss behaves.

### Solution
Run the cell.


In [None]:
# Solution: toy softmax loss explosion with logit scaling

def toy_loss(logits, y):
    # logits: (K,), y integer label
    return F.cross_entropy(logits.view(1,-1), torch.tensor([y], device=logits.device)).item()

g_local = torch.Generator().manual_seed(GLOBAL_SEED + 123)
K = 4
y = 2

base_logits = torch.randn(K, generator=g_local).to(device)  # sample on CPU, then move
for scale in [1.0, 10.0, 50.0]:
    L = toy_loss(base_logits * scale, y)
    probs = F.softmax(base_logits * scale, dim=0)
    print(f"scale={scale:>4}: loss={L:8.4f}, probs={probs.detach().cpu().numpy()}")

## Exercise 6 — Fix #1: make the output layer less overconfident at init

### Motivation / concept
In the transcript, the first fix is:
- set the output bias `b2` to (near) zero,
- shrink `W2` so logits are near zero.

This drives initial probabilities toward uniform → initial loss near `log(vocab_size)`.

Important subtlety:
- Setting weights exactly to zero can cause **symmetry problems** in many settings.
- Here it is enough to make the output weights *tiny* (e.g. scale them by `0.01`) rather than exactly 0.

### Exercise
Modify initialization so:
- `b2 = 0`,
- `W2 *= 0.01` (or similar).

Verify:
- initial loss is near expected,
- logits range is much smaller.

### Solution
Run the cell.


In [None]:
# Solution: output layer fix

def init_fixed_output_params(vocab_size, n_embd=10, n_hidden=200, block_size=3, *,
                             seed=GLOBAL_SEED, W2_scale=0.01, device=device):
    g_local = torch.Generator(device="cpu").manual_seed(seed)
    # sample on the CPU generator, then move to `device`
    C  = torch.randn((vocab_size, n_embd), generator=g_local).to(device)
    W1 = torch.randn((n_embd*block_size, n_hidden), generator=g_local).to(device)
    b1 = torch.randn((n_hidden,), generator=g_local).to(device)
    W2 = torch.randn((n_hidden, vocab_size), generator=g_local).to(device) * W2_scale
    b2 = torch.zeros((vocab_size,), device=device)
    return MLPParams(C=C, W1=W1, b1=b1, W2=W2, b2=b2)

p_fix_out = init_fixed_output_params(vocab_size=vocab_size, device=device, W2_scale=0.01)
logits1, loss1 = forward_mlp(Xb, Yb, p_fix_out, block_size=block_size)
print("Expected loss:", expected_loss)
print("New first-batch loss:", loss1.item())
print("logits1[0] min/max:", logits1[0].min().item(), logits1[0].max().item())

## Exercise 7 — Diagnose tanh saturation (hidden activations squashed to ±1)

### Motivation / concept
`tanh` outputs values in the open interval `(-1, 1)`.  
If *pre-activations* are too large in magnitude, most `tanh` outputs become very close to `-1` or `+1`.

Why is that bad?

Backprop through `tanh` multiplies upstream gradient by:

\[
\frac{d}{dx}\tanh(x) = 1 - \tanh(x)^2
\]

If `tanh(x) ≈ ±1`, then `1 - tanh(x)^2 ≈ 0`, and gradients **vanish** through that neuron.
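
A few concrete values make the vanishing effect tangible. This sketch simply evaluates the local derivative `1 - tanh(x)^2` at a handful of arbitrarily chosen pre-activations.

```python
import math

def tanh_grad(x):
    """Local derivative of tanh at pre-activation x."""
    t = math.tanh(x)
    return 1.0 - t * t

grads = {x: tanh_grad(x) for x in (0.0, 1.0, 2.0, 5.0)}
for x, gval in grads.items():
    print(f"x={x:>3}: tanh={math.tanh(x):+.4f}, local grad={gval:.6f}")
```

At `x = 0` the gradient passes through unchanged, while at `x = 5` less than a thousandth of it survives.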

A nice diagnostic is:
- histogram of `hpre` and `h`,
- percent of activations with `|h| > 0.99` (near saturation).

### Exercise
For the first minibatch, compute:
- histogram of `hpre` and `h`,
- saturation percentage.

### Solution
Run the cell.


In [None]:
# Solution: tanh saturation diagnostics

def diagnose_tanh_saturation(Xb, p: MLPParams):
    with torch.no_grad():
        emb = p.C[Xb]
        embcat = emb.view(emb.shape[0], -1)
        hpre = embcat @ p.W1 + (p.b1 if p.b1 is not None else 0.0)
        h = torch.tanh(hpre)

        sat = (h.abs() > 0.99).float().mean().item() * 100

        print("hpre mean/std:", hpre.mean().item(), hpre.std().item())
        print("h   mean/std:", h.mean().item(), h.std().item())
        print(f"% saturated (|h|>0.99): {sat:.2f}%")

        plt.figure(figsize=(12,4))
        plt.subplot(1,2,1)
        plt.hist(hpre.cpu().view(-1).tolist(), bins=50)
        plt.title("hpre (pre-activation) histogram")
        plt.subplot(1,2,2)
        plt.hist(h.cpu().view(-1).tolist(), bins=50)
        plt.title("tanh(hpre) histogram")
        plt.show()

# Use the 'fixed output' params, so logits are sane, then diagnose hidden activations
diagnose_tanh_saturation(Xb, p_fix_out)

## Exercise 8 — Fix #2: scale the first layer to prevent tanh saturation

### Motivation / concept
If `hpre = embcat @ W1 + b1` has too-large variance, `tanh(hpre)` saturates.

We can fix this by scaling down:
- `W1` (and typically `b1` too)

In the original notebook/video, you saw “magic numbers” like `0.2`.  
Next we’ll replace that with principled scaling, but first let’s see the direct effect.

### Exercise
Try W1 scales: `1.0`, `0.5`, `0.2`, `0.1` and see:
- pre-activation std
- saturation percentage

### Solution
Run the cell.


In [None]:
# Solution: scan W1 scaling factor and see saturation

def clone_with_scaled_W1(p: MLPParams, scale):
    return MLPParams(C=p.C, W1=p.W1 * scale, b1=p.b1, W2=p.W2, b2=p.b2)

for s in [1.0, 0.5, 0.2, 0.1]:
    p_tmp = clone_with_scaled_W1(p_fix_out, s)
    with torch.no_grad():
        emb = p_tmp.C[Xb]
        embcat = emb.view(emb.shape[0], -1)
        hpre = embcat @ p_tmp.W1 + p_tmp.b1
        h = torch.tanh(hpre)
        sat = (h.abs() > 0.99).float().mean().item() * 100
        print(f"W1 scale={s:>4}: hpre std={hpre.std().item():.3f}, tanh sat%={sat:6.2f}")

## Exercise 9 — A principled way: fan-in scaling (variance preservation)

### Motivation / concept
If `x ~ N(0,1)` and `W ~ N(0,1)` are independent, each entry of `y = x @ W` is a sum of `fan_in` unit-variance terms, so its variance is about `fan_in` and its standard deviation grows like `sqrt(fan_in)`.

A classic approach is to scale weights by `1/sqrt(fan_in)` to keep activations roughly unit variance.
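
The `1/sqrt(fan_in)` factor can be motivated with a short variance calculation. This is a sketch that assumes the entries of `x` and `W` are independent with zero mean and variances \(\sigma_x^2\) and \(\sigma_W^2\) (symbols introduced here for illustration):

\[
\operatorname{Var}(y_j) = \operatorname{Var}\Big(\sum_{i=1}^{fan\_in} x_i W_{ij}\Big) = \sum_{i=1}^{fan\_in} \operatorname{Var}(x_i)\,\operatorname{Var}(W_{ij}) = fan\_in \cdot \sigma_x^2\,\sigma_W^2
\]

With \(\sigma_x = 1\), choosing \(\sigma_W = 1/\sqrt{fan\_in}\) gives \(\operatorname{Var}(y_j) = 1\).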

We’ll reproduce the idea:
- draw random inputs `x` with std ≈ 1
- draw weights `W`
- compare std of `y = xW` with/without scaling

### Exercise
Implement a small experiment:
- `x` shape `(1000, fan_in)`
- `W` shape `(fan_in, 200)`
- compare output std for scaling 1.0 vs `1/sqrt(fan_in)`.

### Solution
Run the cell.


In [None]:
# Solution: fan-in scaling experiment

fan_in = 10
n = 1000
m = 200
g_local = torch.Generator().manual_seed(GLOBAL_SEED + 999)

x = torch.randn((n, fan_in), generator=g_local).to(device)  # sample on CPU, then move
W = torch.randn((fan_in, m), generator=g_local).to(device)

y1 = x @ W
y2 = x @ (W / math.sqrt(fan_in))

print("x std:", x.std().item())
print("W std:", W.std().item())
print("y std (no scaling):      ", y1.std().item())
print("y std (1/sqrt(fan_in)):  ", y2.std().item())

## Exercise 10 — “Kaiming-ish” init for tanh: add a gain factor

### Motivation / concept
`tanh` is **contractive**: it tends to shrink variance (squash tails).
So for tanh networks, people often use a multiplicative **gain**.

In the transcript (and original notebook), a commonly used gain for tanh is:

\[
\text{gain} = 5/3
\]

A simple recipe:
\[
W \sim \mathcal{N}(0, \sigma^2) \quad\text{with}\quad \sigma = \frac{\text{gain}}{\sqrt{fan\_in}}
\]

For our first layer:
- input dimension is `fan_in = n_embd * block_size`
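
Plugging in the notebook defaults gives a concrete number for the first-layer standard deviation. The sketch below assumes `n_embd=10` and `block_size=3`; it only evaluates the recipe and does not touch any tensors.

```python
import math

n_embd, block_size = 10, 3     # assumed notebook defaults
fan_in = n_embd * block_size   # number of inputs to each hidden neuron
gain = 5 / 3                   # tanh gain from the recipe above

sigma = gain / math.sqrt(fan_in)
print(f"fan_in={fan_in}, sigma={sigma:.4f}")
```

The result, roughly `0.30`, is the same order of magnitude as hand-picked scales such as `0.2`, but it is derived from the layer shape instead of tuned by eye.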

### Exercise
Initialize `W1` with std = `(5/3)/sqrt(n_embd*block_size)` and compare saturation against an unscaled `W1`.

### Solution
Run the next cell.


In [None]:
# Solution: tanh-friendly init using std = gain/sqrt(fan_in)

def init_kaiming_tanh_params(vocab_size, n_embd=10, n_hidden=200, block_size=3, *,
                             seed=GLOBAL_SEED, W2_scale=0.01, gain=5/3, device=device):
    fan_in = n_embd * block_size
    g_local = torch.Generator(device="cpu").manual_seed(seed)

    # sample on the CPU generator, then move to `device`
    C  = torch.randn((vocab_size, n_embd), generator=g_local).to(device)
    W1 = torch.randn((fan_in, n_hidden), generator=g_local).to(device) * (gain / math.sqrt(fan_in))
    b1 = torch.randn((n_hidden,), generator=g_local).to(device) * 0.01  # tiny bias (optional)
    W2 = torch.randn((n_hidden, vocab_size), generator=g_local).to(device) * W2_scale
    b2 = torch.zeros((vocab_size,), device=device)
    return MLPParams(C=C, W1=W1, b1=b1, W2=W2, b2=b2)

p_kaiming = init_kaiming_tanh_params(vocab_size=vocab_size, device=device)
print("First-batch loss (kaiming-ish tanh init):", forward_mlp(Xb, Yb, p_kaiming, block_size=block_size)[1].item())
diagnose_tanh_saturation(Xb, p_kaiming)

## Exercise 11 — Implement BatchNorm on hidden pre-activations (single-layer MLP)

### Motivation / concept
BatchNorm (BN) normalizes activations using the current minibatch statistics:

For each hidden neuron dimension (over batch dimension):
- compute mean and variance
- normalize: \(\hat{x}=(x-\mu) / \sqrt{\sigma^2 + \epsilon}\)
- then apply learnable scale/shift: \(y = \gamma \hat{x} + \beta\)

Important: During **training**, BN uses batch statistics and updates running estimates.  
During **inference**, BN uses running estimates (or calibrated estimates).
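
The difference between the two modes can be sketched for a single hidden neuron in plain Python. The helper `batchnorm_column` and the numbers below are hypothetical; the point is that training mode derives the statistics from the batch, while inference mode plugs in fixed statistics and therefore also works for a batch of one.

```python
import math

def batchnorm_column(xs, gamma=1.0, beta=0.0, eps=1e-5, mean=None, var=None):
    """Normalize one feature column; pass mean/var explicitly to mimic inference mode."""
    if mean is None:
        mean = sum(xs) / len(xs)                          # batch mean
    if var is None:
        var = sum((x - mean) ** 2 for x in xs) / len(xs)  # biased batch variance
    return [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in xs]

batch = [1.0, 2.0, 3.0, 4.0]
train_out = batchnorm_column(batch)                       # uses batch statistics
infer_out = batchnorm_column([10.0], mean=2.5, var=1.25)  # uses fixed statistics

out_mean = sum(train_out) / len(train_out)
out_var = sum((y - out_mean) ** 2 for y in train_out) / len(train_out)
print(f"train mode: mean={out_mean:+.6f}, var={out_var:.6f}")
print(f"inference mode on a single example: {infer_out[0]:.4f}")
```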

### Exercise
Add BN to the forward pass:
- between `hpre` and `tanh`
- with parameters: `bngain (gamma)`, `bnbias (beta)`
- keep running mean/std estimates with EMA (exponential moving average)

We will start with a **manual BN** (like the original notebook).

### Solution
Run the next cell.


In [None]:
# Solution: single-hidden-layer model with manual BatchNorm

@dataclass
class BNState:
    gain: torch.Tensor        # (1, n_hidden)
    bias: torch.Tensor        # (1, n_hidden)
    mean_running: torch.Tensor  # (1, n_hidden)
    std_running: torch.Tensor   # (1, n_hidden)
    momentum: float = 0.001
    eps: float = 1e-5

def init_bn_state(n_hidden, device=device, momentum=0.001, eps=1e-5):
    return BNState(
        gain=torch.ones((1, n_hidden), device=device),
        bias=torch.zeros((1, n_hidden), device=device),
        mean_running=torch.zeros((1, n_hidden), device=device),
        std_running=torch.ones((1, n_hidden), device=device),
        momentum=momentum,
        eps=eps
    )

def forward_mlp_with_bn(Xb, Yb, p: MLPParams, bn: BNState, *, training=True):
    emb = p.C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpre = embcat @ p.W1  # NOTE: we omit b1 here because BN would subtract it out anyway (we'll discuss this)
    # BN
    if training:
        bnmean = hpre.mean(0, keepdim=True)
        bnstd = hpre.std(0, keepdim=True)
        hpre_bn = bn.gain * (hpre - bnmean) / (bnstd + bn.eps) + bn.bias
        with torch.no_grad():
            bn.mean_running = (1 - bn.momentum) * bn.mean_running + bn.momentum * bnmean
            bn.std_running  = (1 - bn.momentum) * bn.std_running  + bn.momentum * bnstd
    else:
        hpre_bn = bn.gain * (hpre - bn.mean_running) / (bn.std_running + bn.eps) + bn.bias

    h = torch.tanh(hpre_bn)
    logits = h @ p.W2 + p.b2
    loss = F.cross_entropy(logits, Yb)
    return logits, loss

# quick check: forward pass works
bn_state = init_bn_state(n_hidden=p_kaiming.W1.shape[1], device=device)
logits_bn, loss_bn = forward_mlp_with_bn(Xb, Yb, p_kaiming, bn_state, training=True)
loss_bn.item(), bn_state.mean_running.mean().item(), bn_state.std_running.mean().item()

## Exercise 12 — Training loop (single-layer MLP + BatchNorm)

### Motivation / concept
We’ll now train:
- embedding table `C`
- first layer `W1`
- output layer `W2`, `b2`
- BN parameters `gain`, `bias`

Running stats are buffers, **not** trained by gradient descent; they are updated by the EMA inside the forward pass.

We’ll use:
- minibatch training
- SGD with learning rate decay (like the original notebook)

### Exercise
Train for a moderate number of steps and plot the loss curve.

### Solution
Run the cell.


In [None]:
# Solution: training loop for single-layer MLP + BN

def sgd_train_single_layer_bn(
    Xtr, Ytr, Xdev, Ydev, vocab_size,
    *, n_embd=10, n_hidden=200, block_size=3,
    steps=20_000, batch_size=32,
    lr1=0.1, lr2=0.01, lr_decay_step=10_000,
    seed=GLOBAL_SEED, device=device
):
    # init model params (tanh-friendly + sane output layer)
    p = init_kaiming_tanh_params(vocab_size=vocab_size, n_embd=n_embd, n_hidden=n_hidden,
                                 block_size=block_size, seed=seed, device=device, W2_scale=0.01)
    bn = init_bn_state(n_hidden=n_hidden, device=device, momentum=0.001)

    # learnable parameters list
    params = [p.C, p.W1, p.W2, p.b2, bn.gain, bn.bias]
    for t in params:
        t.requires_grad = True

    lossi = []
    for i in range(steps):
        ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)  # CPU indices; the batch is moved below
        Xb, Yb = Xtr[ix].to(device), Ytr[ix].to(device)

        logits, loss = forward_mlp_with_bn(Xb, Yb, p, bn, training=True)

        for t in params:
            t.grad = None
        loss.backward()

        lr = lr1 if i < lr_decay_step else lr2
        with torch.no_grad():
            for t in params:
                t -= lr * t.grad

        lossi.append(loss.item())

        if i % max(1, steps // 5) == 0:  # guard against steps < 5
            print(f"step {i:>6}/{steps}: loss={loss.item():.4f}")

    # eval helper (BN in inference mode)
    @torch.no_grad()
    def split_loss(X, Y):
        logits, loss = forward_mlp_with_bn(X.to(device), Y.to(device), p, bn, training=False)
        return loss.item()

    train_loss = split_loss(Xtr, Ytr)
    dev_loss = split_loss(Xdev, Ydev)

    return p, bn, lossi, train_loss, dev_loss

p_trained, bn_trained, lossi, trL, devL = sgd_train_single_layer_bn(
    Xtr, Ytr, Xdev, Ydev, vocab_size,
    steps=10_000,  # increase if you want (e.g. 200_000 like the original notebook)
)

plt.plot(lossi)
plt.title("Training loss (single-layer MLP + BN)")
plt.xlabel("step")
plt.ylabel("loss")
plt.show()

print("Final train loss:", trL)
print("Final dev loss:  ", devL)

## Exercise 13 — Why we often remove the bias before BatchNorm

### Motivation / concept
If you compute:
- `hpre = x @ W + b`,
then BatchNorm immediately subtracts the batch mean:
- `hpre - mean(hpre)`.

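The constant bias `b` shifts all examples equally, so it becomes part of the mean and gets subtracted away.  
Result: that bias is **redundant**. With batch statistics its gradient is zero up to floating-point error, and BatchNorm's own shift `beta` takes over its role.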

This is why you see `bias=False` in linear/conv layers right before BatchNorm.
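
The cancellation can be checked by hand for one neuron. In this plain-Python sketch the pre-activation values and the bias are made up; centering the values with and without the bias gives the same result.

```python
def center(xs):
    """Subtract the batch mean, as BatchNorm does before scaling."""
    mean = sum(xs) / len(xs)
    return [x - mean for x in xs]

pre_acts = [0.3, -1.2, 2.0, 0.5]           # hypothetical x @ W values for one neuron
bias = 7.0                                 # hypothetical constant bias
with_bias = [x + bias for x in pre_acts]

# Largest elementwise difference between the two centered batches.
diff = max(abs(a - b) for a, b in zip(center(pre_acts), center(with_bias)))
print(f"max difference after centering: {diff:.2e}")
```

Since the centered values are identical, the batch standard deviation is too, so nothing downstream of BatchNorm can depend on the bias.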

### Exercise
Confirm empirically: try including a bias `b1` pre-BN and check whether its gradient becomes (near) zero.

### Solution
Run the cell.


In [None]:
# Solution: show that bias before BatchNorm is (usually) redundant

# Make a tiny run and record gradients
p_test = init_kaiming_tanh_params(vocab_size=vocab_size, device=device)
# add a bias explicitly
p_test = MLPParams(C=p_test.C, W1=p_test.W1, b1=torch.zeros(p_test.W1.shape[1], device=device, requires_grad=True),
                   W2=p_test.W2, b2=p_test.b2)
bn_test = init_bn_state(n_hidden=p_test.W1.shape[1], device=device)

# make W1, W2 require grad too
for t in [p_test.C, p_test.W1, p_test.W2, p_test.b2, bn_test.gain, bn_test.bias]:
    t.requires_grad = True

# forward with "bias before BN"
def forward_with_bias_before_bn(Xb, Yb, p, bn):
    emb = p.C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpre = embcat @ p.W1 + p.b1
    bnmean = hpre.mean(0, keepdim=True)
    bnstd = hpre.std(0, keepdim=True)
    hpre_bn = bn.gain * (hpre - bnmean) / (bnstd + bn.eps) + bn.bias
    h = torch.tanh(hpre_bn)
    logits = h @ p.W2 + p.b2
    return logits, F.cross_entropy(logits, Yb)

logits, loss = forward_with_bias_before_bn(Xb, Yb, p_test, bn_test)
for t in [p_test.C, p_test.W1, p_test.W2, p_test.b2, bn_test.gain, bn_test.bias, p_test.b1]:
    if t.grad is not None:
        t.grad = None
loss.backward()
print("||grad b1|| (mean abs):", p_test.b1.grad.abs().mean().item())
print("Explanation: BN subtracts the batch mean, so constant shifts typically cancel out.")

## Exercise 14 — “Calibrate” BatchNorm vs running stats (inference behavior)

### Motivation / concept
BatchNorm needs a strategy at inference time:
- You can compute mean/std over the entire training set (“calibration”),
- or use the running mean/std accumulated during training.

Ideally, running stats approximate calibrated stats.
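
How close the running estimate gets depends on the momentum and the number of steps. The sketch below runs the same EMA update against a constant target, which is a simplifying assumption: real batch statistics are noisy and drift while the weights train. The momentum `0.001` matches the `BNState` default used in this notebook.

```python
def ema_gap(momentum, steps, start=0.0, target=1.0):
    """Run the EMA update against a constant target and return the remaining gap."""
    running = start
    for _ in range(steps):
        running = (1 - momentum) * running + momentum * target
    return abs(target - running)

gap_short = ema_gap(momentum=0.001, steps=1_000)
gap_long = ema_gap(momentum=0.001, steps=10_000)
print(f"gap after  1,000 steps: {gap_short:.4f}")
print(f"gap after 10,000 steps: {gap_long:.2e}")
```

The gap shrinks as `(1 - momentum) ** steps`, so with a small momentum a short training run can leave the running statistics noticeably behind the calibrated ones.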

### Exercise
Compute calibrated `bnmean`, `bnstd` over the full train set and compare to `bn_trained.mean_running/std_running`.

### Solution
Run the cell.


In [None]:
# Solution: compare calibrated vs running stats

@torch.no_grad()
def calibrate_bn_stats_full_train(Xtr, p, bn_like, block_size=3):
    emb = p.C[Xtr.to(device)]
    embcat = emb.view(emb.shape[0], -1)
    hpre = embcat @ p.W1
    mean = hpre.mean(0, keepdim=True)
    std = hpre.std(0, keepdim=True)
    return mean, std

bnmean_cal, bnstd_cal = calibrate_bn_stats_full_train(Xtr, p_trained, bn_trained)

# compare summary statistics
print("Running mean (avg):     ", bn_trained.mean_running.mean().item())
print("Calibrated mean (avg):  ", bnmean_cal.mean().item())
print("Running std (avg):      ", bn_trained.std_running.mean().item())
print("Calibrated std (avg):   ", bnstd_cal.mean().item())

# compare typical elementwise error magnitudes
print("Mean abs diff (mean):   ", (bn_trained.mean_running - bnmean_cal).abs().mean().item())
print("Std abs diff (mean):    ", (bn_trained.std_running - bnstd_cal).abs().mean().item())

# Part B — “PyTorchify”: build layers (Linear, BatchNorm1d, Tanh) and a deeper network

We now switch to the **bonus section** of the transcript + companion notebook:
- implement small “modules” like `nn.Module`,
- stack them as Lego blocks,
- inspect activation and gradient statistics through depth,
- inspect update-to-data ratios.

This is a great set of tools to debug training stability.

---

## Exercise 15 — Implement tiny `Linear`, `BatchNorm1d`, and `Tanh` modules

### Motivation / concept
We want a tiny version of the PyTorch module API:
- `__call__` does forward pass and stores `self.out`
- `parameters()` returns trainable tensors

BatchNorm needs:
- `training` flag
- running mean/var buffers updated with exponential moving average
- (gamma, beta) as trainable parameters

### Exercise
Implement the three tiny module classes and verify:
- forward pass runs,
- `parameters()` returns correct tensors.

### Solution
Run the cell.


In [None]:
# Solution: tiny modules

class Linear:
    def __init__(self, fan_in, fan_out, *, bias=True, generator=None, device=device):
        if generator is None:
            generator = torch.Generator().manual_seed(GLOBAL_SEED)
        # default: variance-preserving init (std ~ 1/sqrt(fan_in))
        # sample with the (CPU) generator, then move to `device`
        self.weight = torch.randn((fan_in, fan_out), generator=generator).to(device) / math.sqrt(fan_in)
        self.bias = torch.zeros((fan_out,), device=device) if bias else None

    def __call__(self, x):
        self.out = x @ self.weight
        if self.bias is not None:
            self.out = self.out + self.bias
        return self.out

    def parameters(self):
        return [self.weight] + ([] if self.bias is None else [self.bias])

class BatchNorm1d:
    def __init__(self, dim, eps=1e-5, momentum=0.1, *, device=device):
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # parameters
        self.gamma = torch.ones((dim,), device=device)
        self.beta = torch.zeros((dim,), device=device)
        # buffers
        self.running_mean = torch.zeros((dim,), device=device)
        self.running_var = torch.ones((dim,), device=device)

    def __call__(self, x):
        if self.training:
            xmean = x.mean(0, keepdim=True)
            xvar = x.var(0, keepdim=True, unbiased=False)
        else:
            xmean = self.running_mean.view(1, -1)
            xvar = self.running_var.view(1, -1)

        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta

        if self.training:
            with torch.no_grad():
                # update buffers (EMA)
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean.view(-1)
                self.running_var  = (1 - self.momentum) * self.running_var  + self.momentum * xvar.view(-1)

        return self.out

    def parameters(self):
        return [self.gamma, self.beta]

class Tanh:
    def __call__(self, x):
        self.out = torch.tanh(x)
        return self.out
    def parameters(self):
        return []

# quick test
g_local = torch.Generator().manual_seed(GLOBAL_SEED + 2025)
lin = Linear(5, 3, bias=False, generator=g_local)
bn = BatchNorm1d(3, momentum=0.1)
nonlin = Tanh()

x = torch.randn((4,5), generator=g_local).to(device)  # sample on CPU, then move
y = nonlin(bn(lin(x)))
print("y shape:", y.shape)
print("num params:", sum(p.numel() for p in (lin.parameters()+bn.parameters()+nonlin.parameters())))

## Exercise 16 — Build a deep MLP and add diagnostic plots

### Motivation / concept
Deep MLPs can suffer from:
- vanishing/exploding activations
- vanishing/exploding gradients
- uneven update magnitudes across layers

We will track:
1) **Activation distributions** at tanh layers:
   - mean, std, and % saturation
2) **Gradient distributions** at tanh layers:
   - mean, std
3) Weight gradient distributions
4) **Update-to-data ratio** (log10), ideally around `-3` as a rough heuristic (each step changes a parameter by about one thousandth of its scale)
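
The update-to-data ratio is just a ratio of two scales on a log axis. The sketch below uses made-up standard deviations to show how the `-3` target arises; `log10_update_ratio` is a hypothetical helper, not part of the notebook.

```python
import math

def log10_update_ratio(lr, grad_std, weight_std):
    """log10 of (size of one SGD update) / (size of the parameter values)."""
    return math.log10(lr * grad_std / weight_std)

# Hypothetical scales: gradients 100x smaller than weights, learning rate 0.1.
ratio = log10_update_ratio(lr=0.1, grad_std=0.01, weight_std=1.0)
print(f"log10(update/data) = {ratio:.2f}")
```

Values well above `-3` suggest the updates are large relative to the parameters; values well below suggest that layer is learning slowly.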

### Exercise
Build a deep model:
- Embedding table `C`
- 5× blocks of (Linear → BatchNorm → Tanh)
- Final (Linear → BatchNorm) to vocab logits

Run ~1000 steps and produce the plots.

### Solution
Run the cell.


In [None]:
# Solution: deep net + diagnostics (short run)

# hyperparams (feel free to tweak)
n_embd = 10
n_hidden = 100
gain_tanh = 1.0  # try 1.0, 5/3, 2.0, etc. BatchNorm makes this less critical.

# re-init generator for determinism
g_deep = torch.Generator().manual_seed(GLOBAL_SEED)

C = torch.randn((vocab_size, n_embd), generator=g_deep).to(device)  # sample on CPU, then move

layers = [
    Linear(n_embd * block_size, n_hidden, bias=False, generator=g_deep), BatchNorm1d(n_hidden, momentum=0.1), Tanh(),
    Linear(n_hidden, n_hidden, bias=False, generator=g_deep),           BatchNorm1d(n_hidden, momentum=0.1), Tanh(),
    Linear(n_hidden, n_hidden, bias=False, generator=g_deep),           BatchNorm1d(n_hidden, momentum=0.1), Tanh(),
    Linear(n_hidden, n_hidden, bias=False, generator=g_deep),           BatchNorm1d(n_hidden, momentum=0.1), Tanh(),
    Linear(n_hidden, n_hidden, bias=False, generator=g_deep),           BatchNorm1d(n_hidden, momentum=0.1), Tanh(),
    Linear(n_hidden, vocab_size, bias=False, generator=g_deep),         BatchNorm1d(vocab_size, momentum=0.1)
]

# optional: scale weights by a "gain"
with torch.no_grad():
    for layer in layers:
        if isinstance(layer, Linear):
            layer.weight *= gain_tanh
    # make last layer less confident at init: shrink gamma (BN scale) for final BN
    layers[-1].gamma *= 0.1

# collect parameters
parameters = [C] + [p for layer in layers for p in layer.parameters()]
for p in parameters:
    p.requires_grad = True

def forward_deep(Xb):
    emb = C[Xb]                   # (B, block_size, n_embd)
    x = emb.view(emb.shape[0], -1) # (B, block_size*n_embd)
    for layer in layers:
        x = layer(x)
    return x  # logits

# training (short debug run)
max_steps = 1000
batch_size = 32
ud = []      # update/data ratio traces
lossi = []

for i in range(max_steps):
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)  # CPU indices; the batch is moved below
    Xb, Yb = Xtr[ix].to(device), Ytr[ix].to(device)

    logits = forward_deep(Xb)
    loss = F.cross_entropy(logits, Yb)

    # retain grads on layer outputs for plotting
    for layer in layers:
        if hasattr(layer, "out"):
            layer.out.retain_grad()

    for p in parameters:
        p.grad = None
    loss.backward()

    lr = 0.1
    with torch.no_grad():
        # update-to-data ratio logs
        ud.append([((lr * p.grad).std() / (p.data.std() + 1e-12)).log10().item() for p in parameters])
        for p in parameters:
            p.data -= lr * p.grad

    lossi.append(loss.item())
    if i % 200 == 0:
        print(f"step {i:>4}/{max_steps}: loss={loss.item():.4f}")

# ----- Diagnostics plots -----

# 1) Activation distributions at tanh layers
plt.figure(figsize=(16, 3))
legends = []
for i, layer in enumerate(layers[:-1]):  # ignore final BN
    if isinstance(layer, Tanh):
        t = layer.out.detach().cpu()  # move to CPU for torch.histogram and plotting
        sat = (t.abs() > 0.97).float().mean().item() * 100
        print(f"layer {i:2d} (Tanh): mean {t.mean():+.3f}, std {t.std():.3f}, saturated {sat:.2f}%")
        hy, hx = torch.histogram(t, density=True)
        plt.plot(hx[:-1].cpu(), hy.cpu())
        legends.append(f"layer {i}")
plt.legend(legends)
plt.title("Activation distributions at tanh layers")
plt.show()

# 2) Gradient distributions at tanh layers
plt.figure(figsize=(16, 3))
legends = []
for i, layer in enumerate(layers[:-1]):
    if isinstance(layer, Tanh):
        t = layer.out.grad.detach().cpu()  # move to CPU for torch.histogram and plotting
        print(f"layer {i:2d} (Tanh grad): mean {t.mean():+.3e}, std {t.std():.3e}")
        hy, hx = torch.histogram(t, density=True)
        plt.plot(hx[:-1].cpu(), hy.cpu())
        legends.append(f"layer {i}")
plt.legend(legends)
plt.title("Gradient distributions at tanh layers")
plt.show()

# 3) Update-to-data ratio traces (weights only are most interesting, but we'll plot all 2D params)
plt.figure(figsize=(16, 3))
legends = []
for pi, p in enumerate(parameters):
    if p.ndim == 2:
        plt.plot([ud_step[pi] for ud_step in ud])
        legends.append(f"param {pi} {tuple(p.shape)}")
plt.plot([0, len(ud)], [-3, -3], "k")  # heuristic target: updates roughly 1e-3 times the size of the weights
plt.legend(legends, fontsize=8)
plt.title("log10(update std / weight std) over time (2D params only)")
plt.show()

# loss curve
plt.plot(lossi)
plt.title(f"Debug training loss ({max_steps} steps)")
plt.xlabel("step")
plt.ylabel("loss")
plt.show()
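
As a reading aid for the update-to-data plot: the quantity logged in `ud` is log10 of the standard deviation of the update `lr * grad` divided by the standard deviation of the parameter itself. The following is a minimal, self-contained sketch on a made-up weight matrix and a made-up gradient (both are illustrative assumptions, not tensors from the model; the helper name `update_to_data_log10` is also hypothetical). It shows that an update about 1e-3 times the size of the weights lands near the black line at -3.

```python
import torch

def update_to_data_log10(param, grad, lr):
    """log10 of std(lr * grad) / std(param), the quantity traced in `ud`."""
    return ((lr * grad).std() / (param.std() + 1e-12)).log10().item()

gen = torch.Generator().manual_seed(0)                 # illustrative seed
w = torch.randn((200, 200), generator=gen)             # stand-in weight, std close to 1
grad = 0.01 * torch.randn((200, 200), generator=gen)   # stand-in gradient, std close to 0.01

# with lr = 0.1 the update std is about 1e-3 of the weight std
ratio = update_to_data_log10(w, grad, lr=0.1)
print(f"log10(update/data) = {ratio:.2f}")
```

Traces far above the line suggest updates that are large relative to the weights; traces far below suggest the parameter is learning slowly.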

## Exercise 17 — Sampling: generate names from the trained model

### Motivation / concept
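Sampling generates one name at a time:
1) start with a context of `block_size` dots (`...` for `block_size = 3`)
2) repeat until `'.'` (index 0) is sampled:
   - forward pass → logits
   - softmax → probabilities
   - sample the next character index from those probabilities
   - shift the context window left and append the sampled index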

### Exercise
Implement `sample_names(num_samples=20)` for the deep model above.
Make sure to set BatchNorm layers to inference mode (`training=False`).
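
Why inference mode matters here: sampling feeds the network one context at a time, so a batch norm layer left in training mode would normalize with the statistics of a single example. Below is a minimal sketch using a simplified stand-in function rather than the notebook's `BatchNorm1d` class (the name `bn_forward` and the running statistics are assumptions for illustration, and `gamma`/`beta` are omitted).

```python
import torch

def bn_forward(x, running_mean, running_var, training, eps=1e-5):
    """Simplified batch norm without gamma/beta, for illustration only."""
    if training:
        mean = x.mean(0, keepdim=True)
        var = x.var(0, keepdim=True, unbiased=False)  # batch statistics
    else:
        mean, var = running_mean, running_var         # statistics estimated during training
    return (x - mean) / torch.sqrt(var + eps)

x = torch.tensor([[2.0, -1.0, 0.5]])   # a "batch" containing a single example
running_mean = torch.zeros(1, 3)        # made-up running statistics
running_var = torch.ones(1, 3)

out_train = bn_forward(x, running_mean, running_var, training=True)
out_eval = bn_forward(x, running_mean, running_var, training=False)
print(out_train)  # all zeros: the example is its own batch mean
print(out_eval)   # close to x: the input still carries information
```

In training mode every feature is mapped to zero, so the layer discards its input entirely. With an unbiased variance estimate the single-example case is worse still, because the variance of one sample is undefined.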

### Solution
Run the cell.


In [None]:
# Solution: sampling from deep net

@torch.no_grad()
def set_eval_mode(layers):
    for layer in layers:
        if hasattr(layer, "training"):
            layer.training = False

@torch.no_grad()
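def sample_names(num_samples=20, *, block_size=block_size, generator_seed=GLOBAL_SEED+10):
    # block_size defaults to the global used to build the model, so the context matches the first Linear layer
    set_eval_mode(layers)  # layers stay in inference mode afterwards; set layer.training = True before training again
    g_samp = torch.Generator(device="cpu").manual_seed(generator_seed)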
    for _ in range(num_samples):
        out = []
        context = [0] * block_size
        while True:
            X = torch.tensor([context], dtype=torch.long, device=device)
            logits = forward_deep(X)
            probs = F.softmax(logits, dim=1)
            # sample on the CPU so the tensor's device matches the CPU generator
            ix = torch.multinomial(probs.cpu(), num_samples=1, generator=g_samp).item()
            context = context[1:] + [ix]
            out.append(ix)
            if ix == 0:
                break
        print("".join(itos[i] for i in out))

sample_names(20)

# Wrap-up

If you worked through the exercises, you should now understand:

- Why the **expected loss at initialization** is about `log(vocab_size)`, and why a much larger loss at step 0 signals overconfident logits
- How to diagnose & fix:
  - **logit scale** (output layer init)
  - **tanh saturation** (hidden layer init)
- The intuition behind **fan-in scaling** and **gain**
- BatchNorm mechanics:
  - per-batch mean/variance
  - learnable `gamma`/`beta`
  - running mean/var and inference behavior
  - why bias before BN is typically redundant
- How to “PyTorchify” a network into modules
- How to use diagnostic plots for:
  - activations, gradients, weight gradients
  - update-to-data ratios over time
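
As a quick check of the first bullet, here is a minimal sketch assuming a 27-symbol vocabulary (26 letters plus `'.'`, which is an assumption about `names.txt` rather than a value read from the notebook). The logits and targets are made up for illustration.

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 27                          # assumed: 26 letters plus '.'
gen = torch.Generator().manual_seed(0)   # illustrative seed
targets = torch.randint(0, vocab_size, (1000,), generator=gen)  # arbitrary targets

# Uniform prediction: all logits equal, so every character gets probability 1/vocab_size.
uniform_logits = torch.zeros(1000, vocab_size)
uniform_loss = F.cross_entropy(uniform_logits, targets).item()

# Overconfident prediction: large random logits put most of the mass on arbitrary classes.
confident_logits = 10.0 * torch.randn(1000, vocab_size, generator=gen)
confident_loss = F.cross_entropy(confident_logits, targets).item()

print(f"log(vocab_size) = {math.log(vocab_size):.4f}")
print(f"uniform loss    = {uniform_loss:.4f}")
print(f"overconfident   = {confident_loss:.4f}")
```

The uniform case matches `log(vocab_size)`, while the large-logit case is far higher because the network is confidently wrong on most examples.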

Next conceptual step (as the transcript hints): RNNs/GRUs/LSTMs become *very deep* when unrolled in time, so all of these stability ideas matter even more.

---