# Building `makemore` Part 4: Becoming a Backprop Ninja


Lecture: [YouTube](https://youtu.be/q8SA3rM6ckI)

We are going to manually reimplement what `loss.backward()` does. In this way, we can better understand how gradients flow in the backward pass and get intuition that will prevent us from committing silly mistakes when building a network.

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

%matplotlib inline
plt.style.use("seaborn-v0_8-whitegrid")

In [2]:
words = open("./names.txt", "r").read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [3]:
len(words)

32033

In [4]:
# build the vocabulary of characters and mapping to/from integers
chars = sorted(list(set("".join(words))))
s2i = {s: i + 1 for i, s in enumerate(chars)}
s2i["."] = 0
i2s = {i: s for s, i in s2i.items()}
vocab_size = len(i2s)
print(i2s)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [5]:
block_size = (
    3  # context length: how many characters do we take to predict the next one?
)


def build_dataset(words):
    x, y = [], []

    for w in words:
        context = [0] * block_size
        for ch in w + ".":
            ix = s2i[ch]
            x.append(context)
            y.append(ix)
            context = context[1:] + [ix]  # crop and append

    x = torch.tensor(x)
    y = torch.tensor(y)
    print(x.shape, y.shape)
    return x, y


import random

random.seed(42)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

x_trn, y_trn = build_dataset(words[:n1])  # 80%
x_val, y_val = build_dataset(words[n1:n2])  # 10%
x_tst, y_tst = build_dataset(words[n2:])  # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [6]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(
        f"{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}"
    )
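
As a quick usage sketch, `cmp` can be tried on a toy graph before applying it to the full network. The tensors `x` and `y` below are made up for illustration; the only requirement is that the tensor passed as `t` has a populated `.grad`, which for intermediate tensors means calling `retain_grad()` before `backward()`.

```python
import torch


def cmp(s, dt, t):
    # same utility as in the notebook: compare a manual gradient `dt` with `t.grad`
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(
        f"{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}"
    )


# toy graph (illustrative values only): loss = sum(x**2)
x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
y = x**2  # intermediate (non-leaf) tensor
y.retain_grad()  # ask autograd to keep y.grad after backward()
loss = y.sum()
loss.backward()

dy = torch.ones_like(y)  # d(loss)/dy = 1 for a plain sum
dx = 2 * x.detach() * dy  # chain rule through y = x**2
cmp("y", dy, y)
cmp("x", dx, x)
```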

In [7]:
n_embd = 10  # the dimensionality of the character embedding vectors
n_hidden = 64  # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647)  # for reproducibility
C = torch.randn((vocab_size, n_embd), generator=g)
# layer 1
W1 = (
    torch.randn((n_embd * block_size, n_hidden), generator=g)
    * (5 / 3)
    / ((n_embd * block_size) ** 0.5)
)
b1 = (
    torch.randn(n_hidden, generator=g) * 0.1
)  # using b1 just for fun, it's useless because of BN
# layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1
# batchnorm parameters
bngain = torch.randn((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden)) * 0.1

# note: i am initializing many of these parameters (e.g., biases) in non-standard ways
# because sometimes initializing with e.g. all zeros could mask an incorrect
# implementation of the backward pass

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))  # number of parameters in total
for p in parameters:
    p.requires_grad = True

4137


In [8]:
batch_size = 32
n = batch_size  # a shorter variable also, for convenience
# construct a minibatch
ix = torch.randint(0, x_trn.shape[0], (batch_size,), generator=g)
xb, yb = x_trn[ix], y_trn[ix]  # batch X,Y

Below is the forward pass, "chunkated" into smaller steps that we can backpropagate through one at a time.

In [9]:
emb = C[xb]  # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1)  # concatenate the vectors
# linear layer 1
hprebn = embcat @ W1 + b1  # hidden layer pre-activation
# batchnorm layer
bnmeani = 1 / n * hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = (
    1 / (n - 1) * (bndiff2).sum(0, keepdim=True)
)  # note: Bessel's correction (dividing by n-1, not n)
bnvar_inv = (bnvar + 1e-5) ** -0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# non-linearity
h = torch.tanh(hpreact)  # hidden layer
# linear layer 2
logits = h @ W2 + b2  # output layer
# cross entropy loss (same as F.cross_entropy(logits, yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes  # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = (
    counts_sum**-1
)  # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), yb].mean()

# PyTorch backward pass
for p in parameters:
    p.grad = None

for t in [
    logprobs,
    probs,
    counts,
    counts_sum,
    counts_sum_inv,  # afaik there is no cleaner way
    norm_logits,
    logit_maxes,
    logits,
    h,
    hpreact,
    bnraw,
    bnvar_inv,
    bnvar,
    bndiff2,
    bndiff,
    hprebn,
    bnmeani,
    embcat,
    emb,
]:
    t.retain_grad()
loss.backward()
loss

tensor(3.3092, grad_fn=<NegBackward0>)
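
The batchnorm part of the forward pass computes the variance by hand with Bessel's correction. As a sanity check, the sketch below reproduces those steps on a random stand-in for `hprebn` (the tensor `x` is an assumption, not the notebook's data) and compares the result with `Tensor.var`, which applies the same correction by default.

```python
import torch

torch.manual_seed(0)
n = 32
x = torch.randn(n, 64)  # random stand-in for hprebn, shape (batch, n_hidden)

# same steps as bnmeani / bndiff / bndiff2 / bnvar in the forward pass
mean = x.sum(0, keepdim=True) / n
diff = x - mean
var_manual = (diff**2).sum(0, keepdim=True) / (n - 1)  # Bessel's correction

# Tensor.var uses the unbiased (n - 1) estimate by default
var_torch = x.var(0, keepdim=True)
print(torch.allclose(var_manual, var_torch))

# the biased estimate (dividing by n) differs by the factor (n - 1) / n
var_biased = (diff**2).sum(0, keepdim=True) / n
print((var_manual / var_biased).mean().item())
```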

### Exercise 1

Backprop through the whole thing manually, backpropagating through exactly all of the variables as they are defined in the forward pass above, one by one.

A few notes before we start:
* As a naming convention, we are going to name each variable storing the partial derivative of the loss w.r.t. each parameter group, `d` + `<parameter group>`. For instance, the partial derivative of the loss w.r.t. `logprobs`, $ \frac{\partial J(w, b)}{\partial \text{logprobs}} $, will be named `dlogprobs`.
* We are also omitting $ \frac{\partial J(w, b)}{\partial J(w, b)} $ as that is equal to 1.
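
To illustrate the naming convention on the very first step, here is a small sketch with made-up shapes. The loss is the negative mean of one selected entry per row, so its derivative w.r.t. `logprobs` is `-1/n` at the selected entries and zero everywhere else. The values of `logprobs` and the targets `yb` below are stand-ins, not the notebook's tensors.

```python
import torch

torch.manual_seed(0)
n, vocab = 4, 5  # tiny illustrative sizes
logprobs = torch.randn(n, vocab, requires_grad=True)  # stand-in values
yb = torch.tensor([0, 3, 1, 4])  # hypothetical targets

loss = -logprobs[range(n), yb].mean()
loss.backward()

# manual gradient: only the entries picked by yb influence the loss
dlogprobs = torch.zeros(n, vocab)
dlogprobs[range(n), yb] = -1.0 / n
print(torch.allclose(dlogprobs, logprobs.grad))
```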

In [10]:
# create a matrix of zeros with the same shape as the `logprobs` array, to store the
# gradients `dlogprobs`
dlogprobs = torch.zeros_like(logprobs)
# update the gradient only of the elements corresponding to the ground truth predictions 
# with the partial derivative of the loss w.r.t. logprobs
dlogprobs[range(n), yb] = -1.0 / n
cmp("logprobs", dlogprobs, logprobs)

dprobs = (1.0 / probs) * dlogprobs # multiply by dlogprobs as we are using the chain rule
cmp("probs", dprobs, probs)

# in this case, we need to keep in mind that `counts`, `dprobs`, and `counts_sum_inv` have
# different shapes:
# >>> counts.shape, dprobs.shape, counts_sum_inv.shape
# (torch.Size([32, 27]), torch.Size([32, 27]), torch.Size([32, 1]))
# we need `dcounts_sum_inv` to be of shape (32, 1), thus accumulating the gradients of 
# each row
# d11 d12 d13     b1 (= d11 + d12 + d13)
# d21 d22 d23 --> b2 (= d21 + d22 + d23)
# d31 d32 d33     b3 (= d31 + d32 + d33)
dcounts_sum_inv = (counts * dprobs).sum(1, keepdims=True)
cmp("counts_sum_inv", dcounts_sum_inv, counts_sum_inv)

# `dcounts` is a trickier example, as `counts` is used in two places in the computational
# graph:
# 1. probs = counts * counts_sum_inv, AND...
# 2. counts_sum = counts.sum(1, keepdims=True)
# we are going to work initially on the first contribution, and later add the second part 
# to it to compute the total gradient
dcounts = counts_sum_inv * dprobs

dcounts_sum = (-counts_sum ** -2) * dcounts_sum_inv
cmp("counts_sum", dcounts_sum, counts_sum)

# let's now resume working on `dcounts` by processing the second contribution:
# `counts_sum = counts.sum(1, keepdims=True)`
# also in this case, we need to keep an eye on the dimension of the variables:
# >>> counts.shape, dcounts_sum.shape
# (torch.Size([32, 27]), torch.Size([32, 1]))
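# what we want to accomplish is to add to all elements in `dcounts` the `dcounts_sum`
# contribution. since `dcounts_sum` has one value per row, that value is added to every
# element of the corresponding row of `dcounts`: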
# a11 a12 a13   b1 --> a11 a12 a13   b1 b1 b1    
# a21 a22 a23 + b2 --> a21 a22 a23 + b2 b2 b2  
# a31 a32 a33   b3 --> a31 a32 a33   b3 b3 b3    
# this is a simple broadcasting operation, where `dcounts_sum` expands from (32, 1) 
# to (32, 27)
# NOTE: karpathy implements it in a slightly more convoluted way:
# >>> dcounts += torch.ones_like(counts) * dcounts_sum 
dcounts += dcounts_sum # the += operation is because we are summing the contribution of both branches
cmp('counts', dcounts, counts)
# NOTE: the sum() operation fans out the gradient equally to all elements that were
# included in the sum (here, every element in a row of `counts`)! >>> risk of accidental gradient explosion??

dnorm_logits = counts * dcounts  # counts = norm_logits.exp(); in this way we save some FLOPs :-)
cmp('norm_logits', dnorm_logits, norm_logits)

# `logits` is used in two places in the computational graph:
# >>> logit_maxes = logits.max(1, keepdim=True).values
# >>> norm_logits = logits - logit_maxes  # subtract max for numerical stability
# the shape of `logits` and `logit_maxes` is not the same, as there is an implicit 
# broadcasting operation in the `logits - logit_maxes` operation:
# >>> logits.shape, logit_maxes.shape
# (torch.Size([32, 27]), torch.Size([32, 1]))
# `logit_maxes` is broadcast from (32, 1) to (32, 27)
# (c = norm_logits, a = logits, b = logit_maxes)
# c11 c12 c13   a11 a12 a13   b1
# c21 c22 c23 = a21 a22 a23 - b2
# c31 c32 c33   a31 a32 a33   b3
dlogits = dnorm_logits.clone()  # 1st contribution

dlogit_maxes = (-dnorm_logits).sum(1, keepdims=True)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
if not torch.allclose(dlogit_maxes, torch.zeros_like(dlogit_maxes)):
    # NOTE: `dlogit_maxes` should be very close to 0, because that offset we use to normalize the logit vector before passing it to torch.exp()
    # should have no impact on the loss! In fact, we could have used any constant to normalize the logit vector without impacting the value of
    # loss. The only reason why we are using max() is to guarantee that the max value of the "normalized" logit vector before passing it to 
    # torch.exp() is zero, to avoid any potential overflow.
    print("Error: some elements of `dlogit_maxes` are not close to zero!")

# here we want the derivative to flow through where those maximum values occurred in the logit vector
# NOTE: karpathy implements it with a different syntax, but the outcome is the same:
# >>> dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
dlogits[range(n), logits.max(1).indices] += dlogit_maxes.view(-1, )  # 2nd contribution
cmp('logits', dlogits, logits)

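# forward pass: logits = h @ W2 + b2
# the full manual derivation is in the reMarkable notes; the short version is that the
# shapes dictate the result: dh must match h (32, 64), dW2 must match W2 (64, 27), and
# db2 must match b2 (27,), which leaves one way to combine dlogits (32, 27) with W2 and h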
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)

# thanks to google search:
# a = tanh(x)
# d/dx tanh(x) = 1-a**2
dhpreact = (1.0 - h**2) * dh
cmp('hpreact', dhpreact, hpreact)

# forward pass: hpreact = bngain * bnraw + bnbias
# here we have an element-wise multiplication & broadcasting
# >>> bngain.shape, bnraw.shape, bnbias.shape
# (torch.Size([1, 64]), torch.Size([32, 64]), torch.Size([1, 64]))
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
dbnraw = bngain * dhpreact
dbnbias = dhpreact.sum(0, keepdim=True)
cmp('bngain', dbngain, bngain)
cmp('bnraw', dbnraw, bnraw)
cmp('bnbias', dbnbias, bnbias)

# forward pass: bnraw = bndiff * bnvar_inv
dbndiff = bnvar_inv * dbnraw  # 1st contribution
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)

# forward pass: bnvar_inv = (bnvar + 1e-5) ** -0.5
dbnvar = (-0.5 * (bnvar + 1e-5) ** -1.5) * dbnvar_inv
cmp('bnvar', dbnvar, bnvar)

# forward pass: bnvar = (
#                   1 / (n - 1) * (bndiff2).sum(0, keepdim=True)
#               )  # note: Bessel's correction (dividing by n-1, not n)
# Note: karpathy writes it in an equivalent, but different, way: dbndiff2 = (1.0/(n - 1.0)) *  torch.ones_like(bndiff2) * dbnvar
dbndiff2 = (1.0 / (n - 1.0)) * dbnvar
cmp('bndiff2', dbndiff2, bndiff2)

# forward pass: bndiff2 = bndiff**2
dbndiff += 2 * bndiff * dbndiff2  # 2nd contribution
cmp('bndiff', dbndiff, bndiff)

# forward pass: bndiff = hprebn - bnmeani
dbnmeani = (-dbndiff).sum(0, keepdim=True)
dhprebn = dbndiff.clone()  # 1st contribution
cmp('bnmeani', dbnmeani, bnmeani)

# forward pass: bnmeani = 1 / n * hprebn.sum(0, keepdim=True)
# Note: karpathy writes it in the equivalent expression: dhprebn += (1.0 / n) * torch.ones_like(hprebn) * dbnmeani 
#  We prefer to keep the broadcasting operation implicit.
dhprebn += (1.0 / n) * dbnmeani  # 2nd contribution
cmp('hprebn', dhprebn, hprebn)

# forward pass: hprebn = embcat @ W1 + b1  # hidden layer pre-activation
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)

# forward pass: embcat = emb.view(emb.shape[0], -1)  # concatenate the vectors
demb = dembcat.view(emb.shape)
cmp('emb', demb, emb)

# forward pass: emb = C[xb]  # embed the characters into vectors
# print(emb.shape, C.shape, xb.shape)
# print(xb[:5])
# torch.Size([32, 3, 10]) torch.Size([27, 10]) torch.Size([32, 3])
# tensor([[ 1,  1,  4],
#         [18, 14,  1],
#         [11,  5,  9],
#         [ 0,  0,  1],
#         [12, 15, 14]])
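# emb[bs, cntx_len, emb_sz]: embedding vector of each character in each context
# C[n_chars, emb_sz]: embedding matrix
# xb[bs, char_idx]: for each example, the row indices into C of its context characters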
dC = torch.zeros_like(C)
for k in range(xb.shape[0]): # bs
    for j in range(xb.shape[1]): # char_idx
        ix = xb[k, j]
        dC[ix] += demb[k, j, :]
cmp('C', dC, C)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: False | approximate: True  | maxdiff: 4.656612873077393e-10
bngain          | exact: False | approximate: True  | maxdiff: 2.3283064365386963e-09
bnraw           | exact: False | approximate: True  | maxdiff: 6.984919309616089e-10
bnbias 
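
The double loop that accumulates `dC` at the end of the cell above can also be written without Python loops. The sketch below is one possible vectorized alternative using `Tensor.index_add_`, checked against autograd on random stand-in tensors; the random upstream gradient and the variable names `dC_loop` and `dC_vec` are assumptions for illustration, not the notebook's own solution.

```python
import torch

torch.manual_seed(0)
vocab_size, n_embd, block_size, n = 27, 10, 3, 32
C = torch.randn(vocab_size, n_embd, requires_grad=True)
xb = torch.randint(0, vocab_size, (n, block_size))  # stand-in minibatch of indices

emb = C[xb]
emb.retain_grad()
loss = (emb * torch.randn_like(emb)).sum()  # arbitrary scalar so that emb.grad is not trivial
loss.backward()
demb = emb.grad

# loop version, as in the notebook: rows of C used several times accumulate gradient
dC_loop = torch.zeros_like(C)
for k in range(xb.shape[0]):
    for j in range(xb.shape[1]):
        dC_loop[xb[k, j]] += demb[k, j]

# vectorized version: add every row of demb into the row of dC selected by xb
dC_vec = torch.zeros_like(C)
dC_vec.index_add_(0, xb.reshape(-1), demb.reshape(-1, n_embd))

print(torch.allclose(dC_loop, C.grad), torch.allclose(dC_vec, C.grad))
```

A plain `dC[xb] += demb` would not be equivalent, because indexed in-place addition does not accumulate when the same index appears more than once.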

Notes:
1. Always check the shape of the tensors. The parameter group and the derivative of the loss w.r.t. that parameter group must have the same shape. At times we will need to undo a broadcasting operation (by summing over the broadcast axis) or an aggregating operation such as a sum or average over certain axes (by broadcasting the gradient back over those axes).
1. If a parameter group appears in multiple branches of the computational graph, we must sum the gradients coming from all of those branches.
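
Both notes can be checked on a toy graph. In the sketch below, `a` plays the role of `counts` (it is used in two branches) and `b` the role of `counts_sum_inv` (it is broadcast across columns). The shapes, the names, and the made-up loss are assumptions chosen only to keep the example small.

```python
import torch

torch.manual_seed(0)
a = torch.randn(4, 3, requires_grad=True)  # used in two branches, like `counts`
b = torch.randn(4, 1, requires_grad=True)  # broadcast over columns, like `counts_sum_inv`

c = a * b  # branch 1: b is broadcast from (4, 1) to (4, 3)
s = a.sum(1, keepdim=True)  # branch 2: a is reduced from (4, 3) to (4, 1)
c.retain_grad()
s.retain_grad()
loss = (c**2).sum() + (s**3).sum()  # arbitrary scalar loss
loss.backward()

dc, ds = c.grad, s.grad

# note 1: broadcast in the forward pass => sum over the broadcast axis in the backward pass
db = (a.detach() * dc).sum(1, keepdim=True)

# note 2: a feeds two branches => add both contributions
da = b.detach() * dc  # contribution through c = a * b
da = da + ds  # contribution through s = a.sum(...); ds broadcasts back to (4, 3)

print(db.shape == b.shape, torch.allclose(db, b.grad))
print(da.shape == a.shape, torch.allclose(da, a.grad))
```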

**TODO**: Reduce `maxdiff` of `hpreact`. According to some preliminary research on YouTube and GitHub, this seems to be caused by the PyTorch version being used. Is there a bug in PyTorch? Our math matches Karpathy's.
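
One plausible explanation, stated here as an assumption that we have not verified against the PyTorch source, is ordinary float32 rounding: two mathematically identical expressions can differ in their last bits if the operations are carried out in a different order. The sketch below computes the tanh derivative in two algebraically equivalent ways on random stand-in activations and prints how far apart the results are.

```python
import torch

torch.manual_seed(0)
h = torch.tanh(torch.randn(32, 64))  # stand-in activations in float32

g1 = 1.0 - h**2  # the form used in the manual backward pass
g2 = (1.0 - h) * (1.0 + h)  # algebraically identical, different rounding

maxdiff = (g1 - g2).abs().max().item()
print("bit exact:", torch.all(g1 == g2).item())
print("maxdiff:", maxdiff)
print("float32 machine epsilon:", torch.finfo(torch.float32).eps)
```

If the discrepancy reported by `cmp` is of this kind, an `approximate: True` result with a tiny `maxdiff` would be expected rather than a sign of incorrect math.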

### Exercise 2

Backprop through `cross_entropy` but all in one go. To complete this challenge look at the mathematical expression of the loss, take the derivative, simplify the expression, and just write it out.

In [11]:
# forward pass

# before:
# logit_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logit_maxes  # subtract max for numerical stability
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdims=True)
# counts_sum_inv = (
#     counts_sum**-1
# )  # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[range(n), yb].mean()

# now:
loss_fast = F.cross_entropy(logits, yb)
print(f"Loss: {loss_fast.item():.8f} (diff: {(loss_fast - loss).item()})")

Loss: 3.30922031 (diff: -2.384185791015625e-07)


In [12]:
# backward pass
# dlogits = ???
# cmp('logits', dlogits, logits)
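
The cell above is left unfinished. As a hedged sketch of where the derivation leads: for a mean softmax cross-entropy loss, the standard closed form is $ \frac{\partial J}{\partial \text{logits}_{ij}} = \frac{1}{n} \left( p_{ij} - \mathbb{1}[j = y_i] \right) $, where $ p = \text{softmax}(\text{logits}) $. The code below checks that formula against autograd on random stand-in logits and targets, not on the notebook's tensors, so it should be read as an illustration rather than the final answer for this cell.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab_size = 32, 27
logits = torch.randn(n, vocab_size, requires_grad=True)  # stand-in logits
yb = torch.randint(0, vocab_size, (n,))  # stand-in targets

loss = F.cross_entropy(logits, yb)
loss.backward()

# closed form: softmax probabilities, minus 1 at the correct class, averaged over the batch
dlogits = F.softmax(logits.detach(), dim=1)
dlogits[range(n), yb] -= 1.0
dlogits /= n

print("maxdiff:", (dlogits - logits.grad).abs().max().item())
```

Each row of `dlogits` sums to (approximately) zero: the gradient pulls probability away from the incorrect classes and pushes the same amount toward the correct one.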