# Manual Backprop Exercise

## Setup

In [206]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [207]:
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [208]:
chars = sorted(set(''.join(words)))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [209]:
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):
  X, Y = [], []

  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


## Exercise

In [210]:
def cmp(s, dt, t):
    """Utility function to compare my gradient with PyTorch's"""
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [211]:
# NN setup
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 just for fun, it's useless because of BN
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

# Note: I am initializating many of these parameters in non-standard ways
# because sometimes initializating with e.g. all zeros could mask an incorrect
# implementation of the backward pass.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

4137


In [212]:
batch_size = 32
n = batch_size # a shorter variable also, for convenience
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

In [284]:
# forward pass, "chunkated" into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
  p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()
loss.backward()
loss

tensor(3.3367, grad_fn=<NegBackward0>)

In [214]:
hprebn.shape

torch.Size([32, 64])

d### Exercise 1: Backprop through whole network manually

In [291]:
# Loss is the average negative log likelihood i.e. -1/n ( log(a) + lob(b)...), so dlogprogs is -1/n for each values a,b... and 0 for the other values
# Holds derivative of loss wrt all elements of log probs. The elements not plucked out have 0 derivative while the rest have -1/n
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1/n

# This is dlogprobs/dprobs x dloss/dlogprobs. dlogprobs/dprobs is derivative of an elementwise log, so the derivatives are 1/probs as d(log(x)) = 1/x for each p in probs.
# We have how much tweaking each Xij in dprobs impacts the loss, so using the chain rule tweaking each element in dprobs has impact of  local derivative (effect on logprobs) x error signal (logprobs effect on loss).
dprobs =(1.0/probs) * dlogprobs

# Probs = counts * counts_sum_inv. The elementwise derivative is counts * dprobs, but since counts_sum_inv is broadcasted to match the dimensions of counts, we must sum and collapse it into a column vector of the derivatives wrt each element of sum inv. Essentially the actual derivative is the sum of all the items of the matrix row's derivative
dcounts_sum_inv = (counts * dprobs).sum(1,keepdim=True)

# dcounts_sum is trivial as counts_sum_inv just equals 1/counts_sum so simple power rule application
dcounts_sum = -1 * counts_sum **-2 * dcounts_sum_inv

# Dcounts has 2 components contributing to it, we need to accumulate the gradients.
# First, the gradient directly from counts_sum_inv tweaks
dcounts = dprobs * counts_sum_inv
# Then since counts_sum is a function of counts. dcounts_sum/dcounts * dloss / dcounts. Tweaking each individual element of counts by 1 tweaks corresponding counts_sum entry by 1, hence tweaking loss by that corresponding entry in dcounts_sum.
dcounts += torch.ones_like(counts) * dcounts_sum
# Easy backprop through elementwise exp()
dnorm_logits = dcounts * norm_logits.exp()
# Backprop through simple addition simply passes back the error signal
dlogits = dnorm_logits
# logit_maxes is a 32x1 tensor, it is broadcasted (to 32x27) + elementwise subtracted from 32x27 logits to get norm logits.
# Hence tweaking each element of logit maxes has the cumulative effect of -1 * the sum of each row of dnorm_logits
dlogit_maxes = (-dnorm_logits).sum(1,keepdim=True)
# Then since dlogit_maxes is a function of logits, we accumulate more gradient into dlogits.
# The gradient backpropped is the max components of each row multiplied by dlogit_maxes.
dlogits += F.one_hot(logits.max(1).indices, num_classes=vocab_size) * dlogit_maxes
# For the matrix operations, the derivative is a matrix multiply between transpose and error signal. Just remember dimensions must make sense. The bias is added to each row so tweaking each bias component changes dL by the sum of dlogits for a column
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)
# Simple backprop through elementwise tanh.
dhpreact = dh * (1 - h**2)
# Batchnorm layer
# Similar logic re broadcasting as previously
dbngain = (dhpreact * bnraw).sum(0,keepdim = True)
dbnbias = dhpreact.sum(0, keepdim = True)
dbnraw = bngain * dhpreact
dbnvar_inv = (dbnraw * bndiff).sum(0)
dbndiff = dbnraw * bnvar_inv
dbnvar = -0.5 * (bnvar + 1e-5)  ** -1.5 * dbnvar_inv
dbndiff2 = dbnvar * (1/(n-1)) * torch.ones_like(bndiff2)
dbndiff += 2 * bndiff * dbndiff2
dbnmeani = (-dbndiff).sum(0)
dhprebn = torch.ones_like(hprebn) * dbndiff
dhprebn += torch.ones_like(hprebn) * 1/n * dbnmeani
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)

# Finally, backprop the gradients to the embedding vectors.
# Emb is a 3d tensor holding, for each triple of indices in Xb, the 10 dim char features for each character
# All Demb does is hold the gradients from dembcat in a different format, therefore we just view it in necessary dimension.
demb = dembcat.view(-1,3,10)
# Now we have the gradients for each feature vector in the batch, we pass back to C to actually adjust the feature vectors
# For each character in the demb gradients we have, we need to increment the gradients of the actual feature encoding.

# Initialise to 0
dC = torch.zeros_like(C)
# For each training example
for i in range(demb.shape[0]):
    # For each feature vector (e.g. 3 per example). Represents the gradients vector/tweak impact per character
    for j in range(demb.shape[1]):
        # Access the 10 dimensional gradients vector
        gradients = demb[i,j]
        # Get the index of the character
        ix = Xb[i,j]
        # Accumulate the gradient into dC
        dC[ix] += gradients

tensor([[[-4.7125e-01,  7.8682e-01, -3.2844e-01, -4.3297e-01,  1.3729e+00,
           2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4040e-01],
         [-4.7125e-01,  7.8682e-01, -3.2844e-01, -4.3297e-01,  1.3729e+00,
           2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4040e-01],
         [-9.6478e-01, -2.3211e-01, -3.4762e-01,  3.3244e-01, -1.3263e+00,
           1.1224e+00,  5.9641e-01,  4.5846e-01,  5.4011e-02, -1.7400e+00]],

        [[ 1.2815e+00, -6.3182e-01, -1.2464e+00,  6.8305e-01, -3.9455e-01,
           1.4388e-02,  5.7216e-01,  8.6726e-01,  6.3149e-01, -1.2230e+00],
         [ 4.6827e-01, -6.5650e-01,  6.1662e-01, -6.2198e-01,  5.1007e-01,
           1.3563e+00,  2.3445e-01, -4.5585e-01, -1.3132e-03, -5.1161e-01],
         [-4.7125e-01,  7.8682e-01, -3.2844e-01, -4.3297e-01,  1.3729e+00,
           2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4040e-01]],

        [[-5.6533e-01,  5.4281e-01,  1.7549e-01, -2.2901e+00, -7.0928e-01,
          -2.92

In [275]:
cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('bndiff', dbndiff, bndiff)
cmp('bnmeani', dbnmeani, bnmeani)
cmp('hprebn', dhprebn, hprebn)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)
cmp('emb', demb, emb)
cmp('C', dC, C)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: False | approximate: True  | maxdiff: 4.6566128730773926e-09
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: True  | approximate: True  | maxdiff: 0.0
bngain          | exact: True  | approximate: True  | maxdiff: 0.0
bnbias          | exact: True  | approximate: True  | maxdiff: 0.0
bnraw           | exact: True  | approximat

### Exercise 2: Softmax + CrossEntropy simpler backprop analytically

Done pen/paper

In [292]:

# Exercise 2: backprop through cross_entropy but all in one go
# to complete this challenge look at the mathematical expression of the loss,
# take the derivative, simplify the expression, and just write it out

# forward pass

# before:
# logit_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logit_maxes # subtract max for numerical stability
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdims=True)
# counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[range(n), Yb].mean()
loss_fast = F.cross_entropy(logits, Yb)
print(loss_fast.item(), 'diff:', (loss_fast - loss).item())

3.3366892337799072 diff: -2.384185791015625e-07


In [294]:
# backward pass

dlogits = F.softmax(logits, 1)
dlogits[range(n), Yb] -= 1
dlogits /= n

cmp('logits', dlogits, logits)

logits          | exact: False | approximate: True  | maxdiff: 4.423782229423523e-09


Since the gradients of the logits per row are just the softmax probs, and p-1 for the correct value, think of each row in logits as being adjusted like a pulley system where the total adjustment is 1 as p1 + p2 + p3 + (py-1) +.. +pn = 0. We are adjusting the logits to make the likely outcome more likely, exactly by the amount we are reducing the rest

### Exercise 3: BatchNorm layer simpler backprop analytically

Done pen/paper

In [297]:
# Exercise 3: backprop through batchnorm but all in one go
# to complete this challenge look at the mathematical expression of the output of batchnorm,
# take the derivative w.r.t. its input, simplify the expression, and just write it out

# forward pass

# before:
# bnmeani = 1/n*hprebn.sum(0, keepdim=True)
# bndiff = hprebn - bnmeani
# bndiff2 = bndiff**2
# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
# bnvar_inv = (bnvar + 1e-5)**-0.5
# bnraw = bndiff * bnvar_inv
# hpreact = bngain * bnraw + bnbias

# now:
hpreact_fast = bngain * (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) + bnbias
print('max diff:', (hpreact_fast - hpreact).abs().max())

max diff: tensor(4.7684e-07, grad_fn=<MaxBackward1>)


In [298]:
# backward pass

# before we had:
# dbnraw = bngain * dhpreact
# dbndiff = bnvar_inv * dbnraw
# dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv
# dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar
# dbndiff += (2*bndiff) * dbndiff2
# dhprebn = dbndiff.clone()
# dbnmeani = (-dbndiff).sum(0)
# dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)

# calculate dhprebn given dhpreact (i.e. backprop through the batchnorm)
# (you'll also need to use some of the variables from the forward pass up above)

dhprebn = bngain*bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))

cmp('hprebn', dhprebn, hprebn) # I can only get approximate to be true, my maxdiff is 9e-10

hprebn          | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10
