Rewriting loss.backward() manually.

In [14]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [15]:
words = open('names.txt', 'r').read().splitlines()

In [16]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [17]:
block_size = 3
def build_dataset(words):
    X, Y = [], []
    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            idx = stoi[ch]
            X.append(context)
            Y.append(idx)
            context = context[1:] + [idx]
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [18]:
# utility function to compare manual gradients to PyTorch gradients
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item() # true if all elements in both tensors are the same
    app = torch.allclose(dt, t.grad) # true if all elements are approximately close
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [19]:
n_embd = 10 # dimensionality of the character embedding vectors
n_hidden = 64 # number of neurons in the hidden layer

g = torch.Generator().manual_seed(456789096543)
C = torch.randn((vocab_size, n_embd), generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1 # useless but let's keep it
# Layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden)) * 0.1 + 1.0 # all zeros could mask an incorrect implementation of the backward pass
bnbias = torch.randn((1, n_hidden)) * 0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

4137


In [20]:
batch_size = 32
# construct a minibatch
idx = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[idx], Ytr[idx]

In [52]:
# forward pass

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre activation
# BatchNorm layer
bnmeani = 1/batch_size*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(batch_size-1)*(bndiff2).sum(0, keepdim=True) # Bessel's correction, dividing by n-1
bnvar_inv = (bnvar + 1e05)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Nonlinearity
h = torch.tanh(hpreact)
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss (F.cross_entropy(logits, Yb) functionality)
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability (so that the highest number is 0 and exp doesn't overflow)
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(batch_size), Yb].mean()

# PyTorch backward pass
for p in parameters:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact, bnraw, bnvar_inv,
             bnvar, bndiff2, bndiff, hprebn, bnmeani, embcat, emb]:
    t.retain_grad()
loss.backward()
loss

tensor(3.3039, grad_fn=<NegBackward0>)

Manual backprop, finding derivatives, etc. in the order given in the forward pass

---
First we find dlogprobs. We see that loss = -logprobs\[range(batch_size), Yb].mean(). Since we're taking the mean and the derivative with respect to variables at a specific index, the derivative is -1/batch_size.

Simpler example: logprobs: a + b + c \
loss = -1/3a - 1/3b - 1/3c \
dloss/da = -1/3, or -1/n

Therefore,\
dlogprobs = -1/n

---
We get logprobs by taking the logarithm of probs (logprobs depends on probs through a log). Therefore, dprobs = dlogprobs * 1/probs.

---
dcounts_sum_inv = counts must be true because when we take the derivative of probs with respect to dcounts_sum_inv, we're left with counts. But we also need to make sure the dimensions match, so we multiply counts * dprobs (by chain rule) and sum horizontally across rows:\
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)

During the forward pass, for each example in the batch, we have probs\[i, j] = counts\[i, j] * counts_sum_inv\[i].\
counts_sum_inv\[i] is a scalar that gets multiplied by all elements in row i of counts.\
During the backward pass, dloss/dcounts_sum_inv\[i] is given by:

$$\frac{\partial \text{loss}}{\partial \text{counts-sum}_i^{-1}} = \sum_j \frac{\partial \text{loss}}{\partial \text{probs}_{i,j}} \cdot \frac{\partial \text{probs}_{i,j}}{\partial \text{counts-sum}_i^{-1}}$$

We know that$\frac{\partial \text{probs}_{i,j}}{\partial \text{counts-sum}_i^{-1}} = \text{counts}_{i, j}$, so:

$$\frac{\partial \text{loss}}{\partial \text{counts-sum}_i^{-1}} = \sum_j \frac{\partial \text{loss}}{\partial \text{probs}_{i,j}} \cdot \text{counts}_{i, j}$$

Basically, the gradient is dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True) because counts_sum_inv\[i] is broadcast and used in computing every row i in probs. We must accumulate gradients, so we just sum.

---
For counts_sum, use the same chain rule and recognize that derivative of counts_sum^{-1} is -counts_sum^{-2}

---
dcounts = counts_sum_inv * dprobs. No additional summation required because counts_sum_inv is 32 x 1 and dprobs is 32 x 27. As a result we'll get a single vector. That's the first gradient.

counts.shape is 32 x 27 and counts_sum.shape is 32 x 1.\
$$\frac{\partial \text{counts-sum}_i}{\text{counts}_{i, j}} = 1$$
because counts_sum is about summing counts along rows. The gradient along rows other than row $i$ would be 0 (no contribution). Therefore, when we evaluate the gradient of the loss:
$$\frac{\partial \text{loss}}{\partial \text{counts}_{i,j}} = \frac{\partial \text{loss}}{\partial \text{counts-sum}_{i}} \cdot \frac{\partial \text{counts-sum}_i}{\partial \text{counts}_{i, j}} = \frac{\partial \text{loss}}{\partial \text{counts-sum}_{i}} \cdot 1$$
dcounts_sum is the partial derivative in the result.\
torch.ones_like(counts) * dcounts_sum broadcasts dcounts_sum\[i] to all elements in row i.

---
Now we take the derivative of the loss with respect to norm_logits. Chain rule + recognizing the derivative of the exponential is unchanged.\
As a result, we get dnorm_logits = e^{norm_logits} * dcounts = counts * dcounts

---
norm_logits depends on logits and logit_maxes. Check the shapes again:\
norm_logits is 32 x 27, logits is 32 x 27, and logit_maxes is 32 x 1.\
norm_logits = logits - logit_maxes\
When we find dlogits, we use the chain rule and the fact that the derivative with respect to logits is 1 to get dlogits = dnorm_logits.
For logit_maxes, it is -1 * dnorm_logits, but we also need to sum along the rows. Lesssgoooo, I'm finally understanding this.

---
Second branch of dlogits:\
logit_maxes = logits.max(1, keepdim=True).values\
The local derivative is 1 (derivative of logit_maxes with respect to logits). We only need to mask one entry per row to be 1 (the index of the maximum logit). One solution is to use F.one_hot(logits.max(1).indices, num_classes=logits.shape\[1]). There are 27 categories -> num_classes. Using the chain rule, we multiply by dlogit_maxes, which is a column vector, so it will broadcast. Whichever of the bits is turned on using F.one_hot() will also be multiplied by the gradient with respect to logit_maxes.

---
We have logits = h @ W2 + b2. Taking the derivative with respect to h, we get W2 * dlogits. However, let's look at the shapes:\
dlogits -> 32 x 27, h -> 32 x 64, W2 -> 64 x 27, b2 -> 27

bias vector will become a row vector in the broadcasting and replicate vertically.\
Consider a simple example d = a @ b + c:
$$\begin{bmatrix} d_{11} & d_{12} \\ d_{21} & d_{22} \end{bmatrix} = \begin{bmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{bmatrix} \begin{bmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \end{bmatrix} + \begin{bmatrix} c_{1} & c_{2} \\ c_{1} & c_{2} \end{bmatrix}$$
=>
$$d_{11} = a_{11}b_{11} + a_{12}b_{21} + c_1 \qquad d_{12} = a_{11}b_{12} + a_{12}b_{22} + c_2$$
$$d_{21} = a_{21}b_{11} + a_{22}b_{21} + c_1 \qquad d_{22} = a_{21}b_{12} + a_{22}b_{22} + c_2$$
Taking the derivative of the loss with respect to $a_{11}, a_{12}, a_{21}, a_{22}$, we find that it's a simple matrix multiplication:
$$\frac{\partial L}{\partial a} = \frac{dL}{dd} \times b^T \qquad \frac{\partial L}{\partial b} = a^T \times \frac{dL}{dd} \qquad \frac{\partial L}{\partial c} = \frac{dL}{dd} \cdot \text{sum(0)}$$
sum(0) means sum along the columns.\
Or just make sure that the dimensions work out.

---
h = torch.tanh(hpreact)
d/dx tanh(x) = sech^2 x <- not very helpful. If $a = \tanh(z) = \frac{e^z - e^{-z}}{e^z + e^{-z}}$, then a simpler formula is $\frac{da}{dz} = 1 - a^2$. Then, using the chain rule, dhpreact = (1.0 - h**2) * dh.

---
Then we want to backpropagate into bngain, bnraw, and bnbias.\
hpreact: 32 x 64, bngain: 1 x 64, bnraw: 32 x 64, bnbias: 1 x 64 \
We're just shifting and scaling, so it makes sense that bnraw has the same dimensions as hpreact.

Use the chain rule: dbngain = bnraw * dhpreact, but dimensions must be 1 x 64 and currently it is 64 x 64, so we need to sum across the columns. Same for dbnraw, but we don't need to sum because of broadcasting. For dbnbias, we need dimension 1 x 64, so we sum dhpreact across the columns

---
Backpropagate through bndiff and bnvar_inv. Similar to other expressions, use the chain rule and match dimensions.\
Dimensions are: bnraw - 32 x 64, bndiff - 32 x 64, bnvar_inv - 1 x 64\
dbndiff = bnvar_inv * dbnraw (dimensions preserved)\
dbnvar_inv = (bndiff * dbnraw).sum(0)

---
Similarly:\
bnvar_inv = (bnvar + 1e05)^-0.5\
dbnvar = -0.5(bnvar + 1e05)^-1.5 * dbnvar_inv # Yayy, correct

bnvar = 1/(batch_size-1)*(bndiff2).sum(0)\
dbndiff2 = 1/(batch_size-1)*torch.ones_like(bndiff2) * dbnvar because the local derivative is just an array of 1s scaled by 1/n-1

bndiff2 = bndiff**2\
dbndiff = 2*bndiff * dbndiff2

bndiff = hprebn - bnmeani\
Shapes: bndiff - 32 x 64, hprebn - 32 x 64, bnmeani - 1 x 64\
dhprebn = 1 * dbndiff\
dbnmeani = (-dbndiff).sum(0) # sum across columns to match dimensions

bnmeani = 1/batch_size*hprebn.sum(0, keepdim=True)\
dhprebn = 1/batch_size * dbnmeani

hprebn = embcat @ W1 + b1\
Shapes: W1 - 30 x 64, hprebn - 32 x 64, embcat - 32 x 30, b1 - 64\
dembcat = dhprebn @ W1.T \
dW1 = embcat.T @ dhprebn \
db1 = dhprebn.sum(0

embcat = emb.view(emb.shape\[0], -1)\
Shapes: embcat - 32 x 30, emb - 32 x 3 x 10\
To revert, we rerepresent it in the original shape: dembcat.view(emb.shape

---
Indexing in the forward pass: emb = C\[Xb]\
Shapes: emb - 32 x 3 x 10 (32 examples, 3 characters, 10d embedding), C - 27 x 10 (lookup table with 27 chatracters, each 10d), Xb - 32 x 3 (32 examples, context length of 3)\
dC must accumulate the gradients. We iterate through all examples in the batch and each position in the context window, get the vocabulary index used at that position, and add the gradient to dC\[idx].

In [65]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(batch_size), Yb] = -1.0/batch_size

cmp('logprobs', dlogprobs, logprobs)

dprobs = 1.0/probs * dlogprobs
cmp('probs', dprobs, probs)

dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
cmp('dcounts_sum_inv', dcounts_sum_inv, counts_sum_inv)

dcounts_sum = -counts_sum**-2 * dcounts_sum_inv
cmp('dcounts_sum', dcounts_sum, counts_sum)

dcounts = counts_sum_inv * dprobs # first gradient with respect to dcounts
dcounts += torch.ones_like(counts) * dcounts_sum # second gradient
cmp('dcounts', dcounts, counts)

dnorm_logits = norm_logits.exp() * dcounts
cmp('dnorm_logits', dnorm_logits, norm_logits)

dlogits = dnorm_logits.clone() # this is no the final derivative for the logits because there's more

dlogit_maxes = -dnorm_logits.sum(1, keepdim=True)
cmp('dlogit_maxes', dlogit_maxes, logit_maxes)

dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes # second branch of dlogits
cmp('dlogits', dlogits, logits)

dh = dlogits @ W2.T
cmp('dh', dh, h)

dW2 = h.T @ dlogits
cmp('dW2', dW2, W2)

db2 = dlogits.sum(0)
cmp('db2', db2, b2)

dhpreact = (1.0 - h**2) * dh
cmp('dhpreact', dhpreact, hpreact)

dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
cmp('dbngain', dbngain, bngain)

dbnraw = bngain * dhpreact # dimension preserved
cmp('dbnraw', dbnraw, bnraw)

dbnbias = dhpreact.sum(0, keepdim=True)
cmp('dbnbias', dbnbias, bnbias)

dbndiff = bnvar_inv * dbnraw
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# cmp('dbndiff', dbndiff, bndiff)
cmp('dbnvar_inv', dbnvar_inv, bnvar_inv)

dbnvar = -0.5*(bnvar + 1e05)**-1.5 * dbnvar_inv
cmp('dbnvar', dbnvar, bnvar)

dbndiff2 = 1/(batch_size-1)*torch.ones_like(bndiff2) * dbnvar
cmp('dbndiff2', dbndiff2, bndiff2)

dbndiff += 2*bndiff * dbndiff2 # second branch of dbndiff
cmp('dbndiff', dbndiff, bndiff)

dhprebn = dbndiff.clone()
dbnmeani = (-dbndiff).sum(0)
dhprebn += (1/batch_size) * torch.ones_like(hprebn) * dbnmeani # second branch of dhprebn
cmp('dhprebn', dhprebn, hprebn)
cmp('dbnmeani', dbnmeani, bnmeani)

dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0, keepdim=True)
cmp('dembcat', dembcat, embcat)
cmp('dW1', dW1, W1)
cmp('db1', db1, b1)

demb = dembcat.view(emb.shape)
cmp('demb', demb, emb)

dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        idx = Xb[k, j]
        dC[idx] += demb[k, j]

cmp('dC', dC, C)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
dcounts_sum_inv | exact: True  | approximate: True  | maxdiff: 0.0
dcounts_sum     | exact: True  | approximate: True  | maxdiff: 0.0
dcounts         | exact: True  | approximate: True  | maxdiff: 0.0
dnorm_logits    | exact: True  | approximate: True  | maxdiff: 0.0
dlogit_maxes    | exact: True  | approximate: True  | maxdiff: 0.0
dlogits         | exact: True  | approximate: True  | maxdiff: 0.0
dh              | exact: True  | approximate: True  | maxdiff: 0.0
dW2             | exact: True  | approximate: True  | maxdiff: 0.0
db2             | exact: True  | approximate: True  | maxdiff: 0.0
dhpreact        | exact: True  | approximate: True  | maxdiff: 0.0
dbngain         | exact: True  | approximate: True  | maxdiff: 0.0
dbnraw          | exact: True  | approximate: True  | maxdiff: 0.0
dbnbias         | exact: True  | approximate: True  | maxdiff:

Exercise 2: simplify the loss calculation



$$\text{loss} = -\log P_y = -\log \frac{e^{l_y}}{\sum_j e^{l_j}}$$
where the last expression arises from applying softmax (raising to the power of the exponential and normalizing)

$$\frac{\partial \text{loss}}{\partial l_i} = \frac{\partial}{\partial l_i}\left[-\log \frac{e^{l_y}}{\sum_j e^{l_j}}\right] = -\frac{\sum_j e^{l_j}}{e^{l_y}} \cdot \frac{\partial}{\partial l_i}\left[ \frac{e^{l_y}}{\sum_j e^{l_j}} \right]$$

When $i \neq y$:
$$\frac{\partial \text{loss}}{\partial l_i} = -\frac{\sum_j e^{l_j}}{e^{l_y}} \left[0 \cdot \frac{1}{\sum_j e^{l_j}} - e^{l_y} \cdot \frac{e^{l_i}}{(\sum_j e^{l_j})^2} \right] = \frac{e^{l_i}}{\sum_j e^{l_j}} = P_i$$
When $i = y$:
$$\frac{\partial \text{loss}}{\partial l_i} = -\frac{\sum_j e^{l_j}}{e^{l_y}} \left[e^{l_y} \cdot \frac{1}{\sum_j e^{l_j}} - e^{l_y} \cdot \frac{e^{l_i}}{(\sum_j e^{l_j})^2} \right] = \frac{e^{l_i}}{\sum_j e^{l_j}} - 1 = P_i - 1$$

Simplifies to $P_i$ and $P_i - 1$

In [67]:
# backward pass
dlogits = F.softmax(logits, 1) # apply softmax along the rows
dlogits[range(batch_size), Yb] -= 1  # subtract 1 from the gradient
dlogits /= batch_size # divide the gradient by the batch size (take the average)

cmp('logits', dlogits, logits) # only approximate is true

logits          | exact: False | approximate: True  | maxdiff: 8.381903171539307e-09


Exercise 3: backprop through batchnorm but all in one go

$$\mu = \frac{1}{m} \sum_i x_i \qquad \sigma^2 = \frac{1}{m-1}\sum_i (x_i-\mu)^2$$
$$\hat{x}_i=\frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}} \qquad y_i = \gamma \hat{x}_i + \beta$$
We know $\partial L/\partial y_i$ and need to find $\partial L/\partial x_i$
$$\frac{\partial L}{\partial \hat x_i}=\gamma \cdot \frac{\partial L}{\partial y_i}$$
$$\frac{\partial L}{\partial \hat \sigma^2}=\sum_i \frac{\partial L}{\partial \hat x_i}\frac{\partial \hat x_i}{\partial \sigma^2} = \gamma \sum_i \frac{\partial L}{\partial y_i} \cdot \frac{\partial}{\partial \sigma^2}\left[ (x_i - \mu)(\sigma^2 + \varepsilon)^{-1/2} \right] = -\frac{1}{2} \gamma \sum_i \frac{\partial L}{\partial y_i} \cdot (x_i - \mu)(\sigma^2 + \varepsilon)^{-3/2}$$
The next one has two arrows because it appears in two equations: one from $\hat x$ and one from $\sigma^2$:
$$\frac{\partial L}{\partial \mu} = \sum_i \frac{\partial L}{\partial \hat x_i}\frac{\partial \hat x_i}{\partial \mu} + \frac{\partial L}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial \mu}= \gamma \sum_i \frac{\partial L}{\partial y_i} \cdot (-(\sigma^2 + \varepsilon)^{-1/2}) + \frac{\partial L}{\partial \sigma^2}\cdot 0 = -\gamma \sum_i \frac{\partial L}{\partial y_i} \cdot (\sigma^2 + \varepsilon)^{-1/2}$$
Finally, to derive $\partial L/\partial x_i$, we first need to recognize that there are three arrows eminating from $x_i$: to $\mu, \sigma^2$, and $\hat x$. Use the chain rule and add up the contributions from those three:
$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat x_i} \frac{\partial \hat x_i}{x_i} + \frac{\partial L}{\partial \mu} \frac{\partial \mu}{x_i} + \frac{\partial L}{\partial \sigma^2} \frac{\partial \sigma^2}{x_i} = \gamma \frac{\partial L}{\partial y_i}(\sigma^2 + \varepsilon)^{-1/2}  -\frac{\gamma}{m} \sum_j \frac{\partial L}{\partial y_i} \cdot (\sigma^2 + \varepsilon)^{-1/2} -\frac{1}{2} \gamma \sum_j \frac{\partial L}{\partial y_i} \cdot (x_i - \mu)(\sigma^2 + \varepsilon)^{-3/2} \cdot \frac{2}{m-1} (x_i - \mu) =$$
$$=\frac{\gamma(\sigma^2 + \varepsilon)^{-1/2}}{m} \left[ m\frac{\partial L}{\partial y_i} - \sum_j \frac{\partial L}{\partial y_j} - \frac{m}{m-1} \hat x_i \sum_j \frac{\partial L}{\partial y_j} \hat{x_j} \right]$$

I skipped some derivation steps, but implementing this expression in code, we get:

In [68]:
# m is batch_size
dhprebn = bngain * bnvar_inv/batch_size * (batch_size*dhpreact - dhpreact.sum(0) - batch_size/(batch_size-1) * bnraw * (dhpreact * bnraw).sum(0))
cmp('hprebn', dhprebn, hprebn) # approximation

hprebn          | exact: False | approximate: True  | maxdiff: 3.637978807091713e-12
