# Part 4: Becoming a Backpropagation Ninja

Here we take the 2-layer MLP (with `BatchNorm`) from the previous part and backpropagate through it manually, without calling PyTorch autograd's `loss.backward()`. We work backward through the cross-entropy loss, the 2nd linear layer, tanh, batchnorm, the 1st linear layer, and the embedding table.

## 1. Starter Code from Previous Notebooks

In [1]:
import torch
import matplotlib.pyplot as plt 
import torch.nn.functional as F
%matplotlib inline

In [2]:
# read in all the words
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [3]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s: i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [4]:
# build the dataset
import random
block_size = 3  # context length: how many characters do we take to predict the next one?


def build_dataset(words):
  X, Y = [], []

  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix]  # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y


random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte = build_dataset(words[n2:])     # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])
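
To make the sliding context window concrete, here is a minimal sketch of what `build_dataset` produces for a single word. The helper `windows` is a hypothetical name introduced only for this illustration, and the vocabulary is hard-coded as the 26 lowercase letters plus `.` instead of being read from `names.txt`.

```python
# Illustration only: a hard-coded vocabulary (a-z plus '.'), not read from names.txt.
block_size = 3
chars = 'abcdefghijklmnopqrstuvwxyz'
stoi = {ch: i + 1 for i, ch in enumerate(chars)}
stoi['.'] = 0
itos = {i: ch for ch, i in stoi.items()}


def windows(word):
  """Hypothetical helper: (context, target) index pairs for one word."""
  context = [0] * block_size  # start with a context of all '.'
  pairs = []
  for ch in word + '.':
    ix = stoi[ch]
    pairs.append((context, ix))
    context = context[1:] + [ix]  # crop and append
  return pairs


pairs = windows('emma')
for context, ix in pairs:
  print(''.join(itos[i] for i in context), '--->', itos[ix])
```

Each word of length `k` contributes `k + 1` examples, because the terminating `.` is also a prediction target.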


## 2. Setting Up for Manual Backpropagation

### Utility Function for Gradient Comparison
This helper function `cmp` will be our sanity check. It compares a manually calculated gradient (`dt`) with the gradient PyTorch computes automatically (`t.grad`).

It checks for:
- `exact`: Whether the tensors are bit-for-bit identical.
- `approximate`: Whether they are very close in value (useful for floating-point comparisons).
- `maxdiff`: The maximum absolute difference between any two corresponding elements in the tensors.

We'll use this function after each step of our manual backpropagation to verify that our calculations are correct.

In [5]:
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()  # exact
  app = torch.allclose(dt, t.grad)  # approximate
  maxdiff = (dt - t.grad).abs().max().item()  # max difference
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')
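
As a quick illustration of how `cmp` is meant to be used, the sketch below checks a hand-derived gradient on a toy function, `y = sum(x**2)`, whose derivative is `2*x`. The tensor `x` and the function are assumptions made up for this example and are not part of the MLP; `cmp` is repeated so the snippet runs on its own.

```python
import torch


def cmp(s, dt, t):
  # same helper as in the notebook, repeated so this snippet is self-contained
  ex = torch.all(dt == t.grad).item()  # exact
  app = torch.allclose(dt, t.grad)  # approximate
  maxdiff = (dt - t.grad).abs().max().item()  # max difference
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')


# toy example (not part of the MLP): y = sum(x^2), so dy/dx = 2x
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x**2).sum()
y.backward()  # autograd fills in x.grad

dx = 2 * x.detach()  # our manual derivative
cmp('x', dx, x)
```

For the real network the pattern is the same: compute `dlogprobs`, `dprobs`, and so on by hand, then call `cmp` against the matching tensor.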

### Initializing Parameters
We initialize all the parameters for our 2-layer MLP: the embedding table, the weights and biases of each layer, and the gain and bias for Batch Normalization.

In [6]:
n_embd = 10  # the dimensionality of the character embedding vectors
n_hidden = 64  # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647)  # for reproducibility
C = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * \
    (5/3)/((n_embd * block_size)**0.5)
# using b1 just for fun, it's useless because of BN
b1 = torch.randn(n_hidden,                        generator=g) * 0.1
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))  # number of parameters in total
for p in parameters:
  p.requires_grad = True

4137
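
The total can be checked by hand from the tensor shapes. The sketch below recomputes it with plain arithmetic, assuming the hyperparameters used in this notebook (`vocab_size = 27`, `n_embd = 10`, `block_size = 3`, `n_hidden = 64`); the `sizes` dictionary is introduced only for this illustration.

```python
# hyperparameters as used in this notebook
vocab_size, n_embd, block_size, n_hidden = 27, 10, 3, 64

# number of elements in each parameter tensor
sizes = {
  'C': vocab_size * n_embd,               # 27 * 10 = 270
  'W1': n_embd * block_size * n_hidden,   # 30 * 64 = 1920
  'b1': n_hidden,                         # 64
  'W2': n_hidden * vocab_size,            # 64 * 27 = 1728
  'b2': vocab_size,                       # 27
  'bngain': n_hidden,                     # 64
  'bnbias': n_hidden,                     # 64
}
total = sum(sizes.values())
print(total)  # 4137
```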


### Constructing a Minibatch
We create a single minibatch of 32 examples (`Xb`, `Yb`) from our training set. We will perform one full forward and backward pass on this single batch to analyze the gradient calculations at each step.

In [7]:
batch_size = 32
n = batch_size  # a shorter alias, for convenience
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix]  # batch X,Y

### The Updated Forward Pass and PyTorch Backward Pass

Here we set up the exercise. The cell below performs a complete forward pass, but instead of writing it in a few compact lines, it stores every intermediate calculation in its own variable.
- **Why break it down?** By storing every result (e.g. `hprebn`, `bndiff`, `bnraw`, `logits`, `probs`) we can analyze the gradient of the loss with respect to each of these intermediate variables. This allows us to manually backpropagate the gradients one step at a time.

- **Numerical Stability:**

    - `norm_logits = logits - logit_maxes`: Subtracting the maximum value from the logits before exponentiating is a standard trick to prevent numerical instability. `exp()` of large positive numbers can result in infinity (`inf`) but this trick keeps the inputs to `exp()` at or below zero.

    - `counts_sum_inv = counts_sum**-1`: Writing the reciprocal as `x**-1` rather than `1.0 / x` can make it easier for our manual gradient to match PyTorch's bit for bit, which matters when we check for exact equality.
    
- `t.retain_grad()`: This command tells PyTorch to save the gradient for intermediate non-leaf variables (like `h`, `logits`, `probs`). Normally PyTorch only saves gradients for the leaf nodes (our parameters). We need these intermediate gradients to check our work.

- `loss.backward()`: This is PyTorch's automatic differentiation engine in action. It calculates the gradients for all variables in one go. We run it here to get the "correct answers" that we will compare our manual calculations against.
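
The max-subtraction trick described above can be seen on a small example. This is a minimal sketch with made-up logits (the values `1000, 999, 998` are chosen only to force an overflow in float32); it is not taken from the network in this notebook.

```python
import torch

# made-up logits, large enough that exp() overflows in float32
logits = torch.tensor([[1000.0, 999.0, 998.0]])

# naive softmax: exp(1000) is inf, and inf / inf gives nan
naive = logits.exp() / logits.exp().sum(1, keepdim=True)

# stable softmax: subtract the row maximum first, so exp() only sees values <= 0
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes
counts = norm_logits.exp()
stable = counts / counts.sum(1, keepdim=True)

# shifting every logit in a row by the same constant leaves the softmax unchanged
small = torch.tensor([[2.0, 1.0, 0.0]])
reference = small.exp() / small.exp().sum(1, keepdim=True)

print(naive)
print(stable)
print(torch.allclose(stable, reference))
```

Because the result does not depend on the shift, the loss itself does not change when `logit_maxes` is subtracted; only the intermediate values stay finite.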

In [8]:
emb = C[Xb]  # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1)  # concatenate the vectors
# Linear layer 1
hprebn = embcat @ W1 + b1  # hidden layer pre-activation
# BatchNorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)  # batch mean
bndiff = hprebn - bnmeani  # batch difference from mean
bndiff2 = bndiff**2  # batch squared difference from mean
# note: Bessel's correction (dividing by n-1, not n)
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True)
bnvar_inv = (bnvar + 1e-5)**-0.5  # inverse standard deviation
bnraw = bndiff * bnvar_inv  # normalized batch
hpreact = bngain * bnraw + bnbias  # affine transform
# Non-linearity
h = torch.tanh(hpreact)  # hidden layer
# Linear layer 2
logits = h @ W2 + b2  # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes  # subtract max for numerical stability
counts = norm_logits.exp()  # unnormalized probabilities
counts_sum = counts.sum(1, keepdims=True)  # normalization constant
counts_sum_inv = counts_sum**-1  # inverse of normalization constant
probs = counts * counts_sum_inv  # probabilities for each class
logprobs = probs.log()  # log-probabilities for each class
loss = -logprobs[range(n), Yb].mean()  # average negative log-likelihood loss

# PyTorch backward pass
for p in parameters:
  p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv,  # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
          bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
          embcat, emb]:
  t.retain_grad()
loss.backward()
loss

tensor(3.3461, grad_fn=<NegBackward0>)
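
Two claims from this section can be checked in isolation: that the decomposed loss agrees with `F.cross_entropy`, and that `retain_grad()` makes the gradient of a non-leaf tensor available after `backward()`. The sketch below uses a small random linear layer as a stand-in for the network; the shapes, the seed, and the names `W` and `x` are assumptions chosen only for this example.

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)  # arbitrary seed, for reproducibility
n, n_in, vocab_size = 4, 5, 27        # toy sizes, not the notebook's

W = torch.randn((n_in, vocab_size), generator=g, requires_grad=True)  # leaf tensor
x = torch.randn((n, n_in), generator=g)
Yb = torch.randint(0, vocab_size, (n,), generator=g)

logits = x @ W        # non-leaf tensor
logits.retain_grad()  # ask autograd to keep logits.grad after backward()

# the same decomposed cross entropy as in the forward pass above
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()
loss.backward()

print(torch.allclose(loss, F.cross_entropy(logits, Yb)))  # decomposed loss vs built-in
print(logits.grad.shape)  # populated because of retain_grad()
```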