The aim is to understand backpropagation through tensors really well. Andrej argues that this understanding is essential to avoid shooting yourself in the foot, and he wrote a [blog post](https://karpathy.medium.com/yes-you-should-understand-backprop-e2f06eab496b) about it.

In [18]:
### content in this cell is the same as previous notebook Makemore MLP Better than ever
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline 

# Read data: one name per line. `with` closes the file handle as soon as it is read.
with open("names.txt", "r", encoding="utf-8") as f:
    names = f.read().splitlines()

# Build vocabulary: letters get indices 1..len(characters); '.' gets index 0
# (it is used both as the start-of-name padding and as the end-of-name token).
characters = sorted(set(''.join(names)))
str_to_int = {s: i + 1 for i, s in enumerate(characters)}
str_to_int['.'] = 0  # an int index, like every other entry of the mapping
int_to_str = {i: s for s, i in str_to_int.items()}
vocab_size = len(int_to_str)

block_size = 3  # context size: how many previous characters are used to predict the next one

def build_dataset(names, context_size=None, stoi=None):
    """Build (context, next-character) pairs from a list of names.

    Each name is terminated with '.', and every character of the terminated
    name becomes one target. Its context is the indices of the preceding
    ``context_size`` characters, left-padded with index 0.

    Args:
        names: iterable of name strings.
        context_size: number of context characters; defaults to the global
            ``block_size`` so existing ``build_dataset(names)`` calls still work.
        stoi: character -> integer index mapping; defaults to the global
            ``str_to_int``.

    Returns:
        ``(X, Y)``: ``X`` is a long tensor of shape ``(num_examples, context_size)``
        and ``Y`` a long tensor of shape ``(num_examples,)``. Both shapes are
        printed as a side effect.
    """
    if context_size is None:
        context_size = block_size
    if stoi is None:
        stoi = str_to_int
    xs, ys = [], []
    for name in names:
        context = [0] * context_size
        for character in name + ".":
            ix = int(stoi[character])  # force an int index even if the mapping stores floats
            xs.append(context)
            ys.append(ix)
            context = context[1:] + [ix]
    # Explicit dtype + reshape: always a 2-D long tensor, even for an empty list of names
    # (torch would otherwise infer a float tensor of shape (0,)).
    X = torch.tensor(xs, dtype=torch.long).reshape(len(xs), context_size)
    Y = torch.tensor(ys, dtype=torch.long)
    print(X.shape, Y.shape)
    return X, Y

import random 
random.seed(42)
random.shuffle(names)  # shuffle in place, reproducibly (fixed seed), before splitting
n1 = int(0.8*len(names))  # end of the training split: first 80% of the names
n2 = int(0.9*len(names))  # end of the dev split: next 10%; the remaining 10% is the test split

# Build context/target tensors for each split (build_dataset prints their shapes).
Xtr, Ytr = build_dataset(names[:n1])
Xdev, Ydev = build_dataset(names[n1:n2])
Xte, Yte = build_dataset(names[n2:])

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


# Function to Compare Gradients

In [19]:
def cmp(s, dt, t):
    """Compare a manually derived gradient with the one PyTorch autograd computed.

    Args:
        s: label printed at the start of the line.
        dt: manually computed gradient tensor.
        t: tensor whose ``.grad`` was populated by ``loss.backward()``.

    Prints whether ``dt`` matches ``t.grad`` exactly and approximately
    (``torch.allclose``), plus the largest absolute elementwise difference.
    Both checks also require identical shapes. ``==`` and ``allclose`` broadcast,
    so a gradient of the wrong shape (e.g. ``(1, H)`` instead of ``(H,)``)
    would otherwise be reported as a match.
    """
    same_shape = dt.shape == t.grad.shape
    ex = same_shape and torch.all(dt == t.grad).item()    # exact gradient
    app = same_shape and torch.allclose(dt, t.grad)       # approximate gradient
    maxdiff = (dt - t.grad).abs().max().item()            # maximum absolute difference
    line = f"{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}"
    if not same_shape:
        line += f" | shape mismatch: {tuple(dt.shape)} vs {tuple(t.grad.shape)}"
    print(line)

# Neural Network Initialization

In [20]:
n_emb = 10      # dimensionality of each character embedding vector
n_hidden = 64   # number of neurons in the hidden layer

# Every parameter below is drawn from this single seeded generator, so the order
# of these calls determines the exact initial values.
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_emb), generator=g)  # embedding table: one row per vocabulary entry
# Layer 1
# Scaled by the tanh gain 5/3 divided by sqrt(fan_in), where fan_in = n_emb*block_size.
W1 = torch.randn((n_emb*block_size, n_hidden), generator=g) * (5/3) / ((n_emb*block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g)  # Just for fun, it is useless due to BatchNorm (the batch-mean subtraction cancels any per-feature constant)
# Layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1  # small scale keeps the initial logits close to zero
b2 = torch.randn(vocab_size, generator=g) * 0.1
# Batch norm
bngain = torch.randn((1, n_hidden), generator=g)*0.1 + 1.0  # gain near 1
bnbias = torch.randn((1, n_hidden), generator=g)*0.1        # bias near 0

# Some of these parameters are initialized in non-standard ways because sometimes initializing them 
# with all zeros can mask an incorrect implementation of the backward pass

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))  # total number of scalar parameters
for p in parameters:
    p.requires_grad = True  # let autograd compute reference gradients to check the manual ones

4137


# Batching

In [21]:
batch_size = 32
n = batch_size  # short alias used throughout the forward and backward cells
# Draw batch_size random row indices into the training set (repeats are possible),
# using the same seeded generator g as the parameter initialization.
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb = Xtr[ix]
Yb = Ytr[ix]

# Forward Pass

In [22]:
# Forward pass, deliberately broken into small steps so that every intermediate
# tensor can be differentiated by hand in the backward-pass cell.
emb = C[Xb]                                       # embed each context character into a vector
embcat = emb.view(emb.shape[0], -1)               # concatenate the context embeddings per example
# First linear layer
hprebn = embcat @ W1 + b1                         # hidden pre-activation, before batch norm
# Batch norm
bnmeani = (1/n) * hprebn.sum(0, keepdim=True)     # per-feature mean over the batch
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = (1/(n-1)) * bndiff2.sum(0, keepdim=True)  # Bessel's correction: divide by n-1, not n
bnvar_inv = (bnvar + 1e-5)**(-0.5)                # 1/sqrt(var + eps): inverse standard deviation
bnraw = bndiff * bnvar_inv                        # normalize: divide by the standard deviation
hpreact = bngain * bnraw + bnbias                 # learned rescale and shift
# Non-linearity
h = torch.tanh(hpreact)                           # hidden layer
# Second linear layer
logits = h @ W2 + b2
# Cross-entropy loss, written out by hand
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes                # subtract the row max so exp() cannot overflow
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1                   # with (1./counts_sum) instead, the manual backprop is not bit-exact
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# Backward Pass

In [23]:
# Clear any gradients left over from a previous run before calling backward().
for p in parameters:
    p.grad = None
# Autograd keeps .grad only on leaf tensors by default; ask it to also keep the
# gradients of these intermediates so each manual derivative can be compared to them.
retained = (
    logprobs, probs, counts, counts_sum, counts_sum_inv,
    norm_logits, logit_maxes, logits, h, hpreact, bnraw,
    bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
    embcat, emb,
)
for t in retained:
    t.retain_grad()
loss.backward()
loss  # last expression of the cell: the notebook displays the loss

tensor(3.5571, grad_fn=<NegBackward0>)

In [68]:
# dloss/dlogprobs
# loss = -logprobs[range(n), Yb].mean() averages n selected entries (one per row) and
# negates the result, so each selected entry has derivative -1/n. All other entries of
# logprobs do not contribute to the loss, so their derivative is zero.
dlogprobs = torch.zeros_like(logprobs)  # zeros_like avoids hard-coding the shape
dlogprobs[range(n), Yb] = -1./n

# dloss/dprobs = (dloss/dlogprobs) * (dlogprobs/dprobs); since logprobs = probs.log(),
# the local derivative is 1/probs.
dprobs = (1.0 / probs) * dlogprobs

# dloss/dcounts_sum_inv = (dloss/dprobs) * (dprobs/dcounts_sum_inv)
# `probs = counts * counts_sum_inv` hides two operations: counts is (n, vocab_size) while
# counts_sum_inv is (n, 1), so counts_sum_inv is broadcast (replicated) across the columns.
# A value used several times receives the sum of the gradients from each use. Each use
# contributes counts * dprobs, so we sum over axis 1.
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)  # keepdim=True keeps it (n, 1) rather than (n,)

# dloss/dcounts has two branches, because counts_sum_inv itself depends on counts.
# First branch: through probs = counts * counts_sum_inv.
dcounts = dprobs * counts_sum_inv

# counts_sum_inv = counts_sum**-1, so the local derivative is -counts_sum**-2
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv

# Second branch of dcounts: counts_sum = counts.sum(1), so every element of a row
# receives that row's dcounts_sum.
dcounts += torch.ones_like(counts) * dcounts_sum

# The rest of the network
dnorm_logits = counts * dcounts                   # d/dx exp(x) = exp(x), which is counts
dlogits = dnorm_logits.clone()                    # first branch of logits; clone so the += below does not modify dnorm_logits
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
# Second branch of logits: logit_maxes was taken from the argmax entry of each row.
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
# Second linear layer: logits = h @ W2 + b2
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)
# tanh: d/dx tanh(x) = 1 - tanh(x)**2
dhpreact = (1.0 - h**2) * dh
# Batch-norm rescale and shift: hpreact = bngain * bnraw + bnbias
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
dbnraw = bngain * dhpreact
dbnbias = dhpreact.sum(0, keepdim=True)
# bnraw = bndiff * bnvar_inv
dbndiff = bnvar_inv * dbnraw   # first branch of bndiff
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# bnvar_inv = (bnvar + 1e-5)**-0.5
dbnvar = (-0.5*(bnvar + 1e-5)**(-1.5)) * dbnvar_inv
# bnvar = 1/(n-1) * bndiff2.sum(0): every row receives the same gradient
dbndiff2 = (1.0 / (n-1)) * torch.ones_like(bndiff2) * dbnvar
# bndiff2 = bndiff**2: second branch of bndiff
dbndiff += (2*bndiff) * dbndiff2
# bndiff = hprebn - bnmeani: first branch of hprebn; bnmeani was broadcast over rows
dhprebn = dbndiff.clone()
dbnmeani = (-dbndiff).sum(0, keepdim=True)        # keepdim=True to match bnmeani's (1, n_hidden) shape
# bnmeani = 1/n * hprebn.sum(0): second branch of hprebn
dhprebn += (1.0/n) * torch.ones_like(hprebn) * dbnmeani
# First linear layer: hprebn = embcat @ W1 + b1
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)                              # b1 is 1-D (n_hidden,), so no keepdim
# embcat = emb.view(...): undo the view
demb = dembcat.view(emb.shape)
# emb = C[Xb]: route each embedding's gradient back to the row of C it was read from;
# a row used several times in the batch accumulates all of its contributions.
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        char_ix = Xb[k, j]  # separate name so the minibatch index tensor `ix` is not overwritten
        dC[char_ix] += demb[k, j]


cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts', dcounts, counts)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes) 
cmp('logits', dlogits, logits) 
cmp('dh', dh, h)
cmp('dW2', dW2, W2)
cmp('db2', db2, b2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnraw', dbnraw, bnraw)
cmp('bnbias', dbnbias, bnbias)
cmp('bndiff', dbndiff, bndiff)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('dhprebn', dhprebn, hprebn)
cmp('dbnmeani', dbnmeani, bnmeani)
cmp('dembcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)
cmp('emb', demb, emb)
cmp('C', dC, C)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
dh              | exact: True  | approximate: True  | maxdiff: 0.0
dW2             | exact: True  | approximate: True  | maxdiff: 0.0
db2             | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: True  | approximate: True  | maxdiff: 0.0
bngain          | exact: True  | approximate: True  | maxdiff: 0.0
bnraw           | exact: True  | approximate: True  | maxdiff: 0.0
bnbias          | exact: True  | approximate: True  | maxdiff:

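A quick note: the original batch norm paper uses the biased estimate $1/m$ for the mini-batch variance during training. For the population variance used at test time, it uses the unbiased estimate, applying the correction factor $\frac{m}{m-1}$, i.e. effectively $1/(m-1)$. This train-test discrepancy is confusing, and one should typically avoid such discrepancies, especially when they are unmotivated. The forward pass above uses Bessel's correction, $1/(n-1)$, during training.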

In [69]:
# Simpler, fused version of dlogits. For the mean cross-entropy over n rows,
#   dloss/dlogits = (softmax(logits) - one_hot(Yb)) / n
# detach() makes this manual gradient a plain tensor: autograd builds no graph for it,
# and the in-place edits below do not modify a softmax output that autograd is tracking.
dlogits = F.softmax(logits.detach(), 1)
dlogits[range(n), Yb] -= 1.0
dlogits /= n


cmp('logits', dlogits, logits)

logits          | exact: False | approximate: True  | maxdiff: 6.51925802230835e-09
