# Becoming backprop ninja

In [1]:
import random

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

In [2]:
# Read one name per line; `with` makes sure the file handle is closed afterwards.
with open("names.txt", "r", encoding="utf-8") as names_file:
    words = names_file.read().splitlines()
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [3]:
# number of names in the dataset
len(words)

32033

In [4]:
# Character vocabulary: letters get indices 1..26 in sorted order, and index 0
# is reserved for the '.' token that marks the start/end of a name.
chars = sorted(set("".join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi["."] = 0
itos = {idx: ch for ch, idx in stoi.items()}  # inverse lookup: index -> character
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [5]:
block_size = 3  # context length: how many previous characters are used to predict the next one


def build_dataset(words, context_len=None, char_to_ix=None):
    """Turn a list of names into (context, next-character) training pairs.

    Each name is processed with a rolling window that starts as `context_len`
    zeros and is terminated with '.', so every character -- including the
    final '.' -- yields one example whose input is the `context_len`
    character indices preceding it.

    Args:
        words: iterable of strings to encode.
        context_len: number of previous characters per example; defaults to the
            notebook-level `block_size`.
        char_to_ix: mapping from character to integer index; defaults to the
            notebook-level `stoi`. Must contain '.'.

    Returns:
        (X, Y): int64 tensors of shape (num_examples, context_len) and
        (num_examples,). Also prints both shapes.
    """
    # Fall back to the notebook globals only when not given explicitly, so the
    # function can be reused (and tested) with other context lengths/vocabularies.
    if context_len is None:
        context_len = block_size
    if char_to_ix is None:
        char_to_ix = stoi

    X, Y = [], []
    for w in words:
        context = [0] * context_len
        for ch in w + ".":
            ix = char_to_ix[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]  # crop the oldest index and append the new one

    # Explicit dtype + reshape keep the documented shapes even for an empty word list
    # (torch.tensor([]) would otherwise be a 1-D float tensor).
    X = torch.tensor(X, dtype=torch.long).reshape(len(Y), context_len)
    Y = torch.tensor(Y, dtype=torch.long)
    print(X.shape, Y.shape)
    return X, Y

# Shuffle a *copy* so re-running this cell always gives the same split;
# shuffling `words` in place would reshuffle an already-shuffled list on a re-run.
random.seed(42)
words_shuffled = list(words)
random.shuffle(words_shuffled)
n1 = int(0.8 * len(words_shuffled))  # first 80% -> training set
n2 = int(0.9 * len(words_shuffled))  # next 10% -> dev set, last 10% -> test set
Xtr, Ytr = build_dataset(words_shuffled[:n1])
Xdev, Ydev = build_dataset(words_shuffled[n1:n2])
Xte, Yte = build_dataset(words_shuffled[n2:])

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [8]:
# utility function that we will use later to compare manual gradients to PyTorch gradients
def cmp(s, dt, t):
    """Print how closely a manually computed gradient matches autograd's.

    Args:
        s: label printed at the start of the line.
        dt: manually computed gradient tensor.
        t: tensor whose ``.grad`` was populated by ``loss.backward()``.

    Prints whether ``dt`` equals ``t.grad`` exactly, whether they are close
    (``torch.allclose`` with default tolerances), and the largest absolute
    elementwise difference. Because ``==`` and ``allclose`` broadcast, a
    gradient with the wrong shape (e.g. ``(64,)`` instead of ``(1, 64)``) would
    otherwise look correct, so a shape mismatch is appended to the line.

    Raises:
        ValueError: if ``t.grad`` is None (for example when ``retain_grad()``
            was not called on a non-leaf tensor before ``backward()``).
    """
    if t.grad is None:
        raise ValueError(f"{s}: t.grad is None; call retain_grad() before backward()")
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    shape_note = ""
    if dt.shape != t.grad.shape:
        shape_note = f" | SHAPE MISMATCH: {tuple(dt.shape)} vs {tuple(t.grad.shape)}"
    print(f"{s:15s} | exact: {ex} | approx: {str(app):5s} | maxdiff: {maxdiff}{shape_note}")

In [10]:
n_emb = 10  # size of each character embedding vector
n_hidden = 64  # number of neurons in the hidden layer

# Every random tensor below is drawn from the single seeded generator `g`, in this
# order; the minibatch cell later draws from `g` as well, so reordering these
# lines would change every downstream number.
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_emb), generator=g)  # embedding lookup table, one row per character
# layer 1
# scaled by (5/3)/sqrt(fan_in); 5/3 is the gain PyTorch recommends for tanh
W1 = torch.randn((n_emb*block_size, n_hidden), generator=g) * (5/3)/((n_emb*block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1 # just for fun, it's useless because of batch normalization
# layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1
# BatchNorm Parameters
# NOTE(review): ones*0.1 + 1.0 makes every gain the same constant (1.1), and
# zeros*0.1 is still all zeros, so neither `*0.1` adds any randomness. Uniform or
# zero parameters can hide broadcasting mistakes in a manual backward pass;
# consider small random values instead (that would change all downstream outputs).
bngain = torch.ones((1, n_hidden))*0.1 + 1.0
bnbias = torch.zeros((1, n_hidden))*0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))  # total number of trainable scalars
for p in parameters:
    p.requires_grad = True  # let autograd compute reference gradients for cmp()

4137


In [11]:
batch_size = 32
n = batch_size  # short alias used throughout the manual backward pass

# construct a minibatch: draw `batch_size` random row indices into the training set
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb = Xtr[ix]  # batch inputs (contexts)
Yb = Ytr[ix]  # batch targets (next characters)

In [13]:
# forward pass, broken into small steps so that every intermediate can be backpropagated by hand
emb = C[Xb] # embed characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# Linear Layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre activation
# BatchNorm Layer
bnmeani = 1/n * hprebn.sum(0, keepdim=True) # per-unit mean over the batch (dim 0)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1) * (bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1 not n)
bnvar_inv = (bnvar + 1e-5) ** -0.5 # the 1e-5 keeps this finite if a variance is 0
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non Linearity
h = torch.tanh(hpreact) # hidden layer
logits = h @ W2 + b2  # output layer
# cross entropy loss (same as F.cross_entropy)
logit_maxes = logits.max(1, keepdim=True).values
norm_logits =   logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum ** -1 # same as 1/counts_sum
probs = counts * counts_sum_inv
logprobs = probs.log()
loss =  -logprobs[range(n), Yb].mean() # mean negative log-likelihood of the correct characters

# pytorch backward pass
for p in parameters:
    p.grad = None
# keep .grad on these non-leaf intermediates so cmp() can check each manual gradient
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv,  norm_logits, logit_maxes, logits, h, hpreact, bnraw,bnvar_inv,bnvar, bndiff2,bndiff, hprebn, bnmeani, embcat, emb]:
    t.retain_grad()
loss.backward()
loss

tensor(3.3482, grad_fn=<NegBackward0>)

#### Let the game begin

In [None]:
# dlogprobs = ???
# dlogprobs will hold d(loss)/d(logprobs): the gradient of the loss with respect to every element of logprobs

In [14]:
# one log-probability per (example, character) pair
logprobs.shape

torch.Size([32, 27])

In [18]:
# indexing with (range(n), Yb) plucks one entry per row: the entries the loss actually uses
logprobs[range(n), Yb].shape

torch.Size([32])

In [None]:
# loss = - (a+b+c) / 3
# dloss / da = -1/3 
# if we have n values, then dloss / da = -1/n, similarly for b and c
# dloss/da = -1/n

`logprobs` has shape (32, 27), but only 32 of its elements (one per row, at the index of the correct next character `Yb`) are plucked out for the loss. The gradient of every other element is therefore zero, because those elements do not take part in the loss calculation.

In [19]:
# The loss averages the n plucked-out entries, so each of them gets gradient -1/n;
# every other entry of logprobs does not affect the loss and keeps gradient 0.
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[torch.arange(n), Yb] = -1.0 / n

cmp("dlogprobs", dlogprobs, logprobs)

dlogprobs       | exact: True | approx: True  | maxdiff: 0.0


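Next, `logprobs = probs.log()` applies the log elementwise to every element of `probs`. Since $\frac{d}{dx}\log x = \frac{1}{x}$, the local derivative is `1/probs`, which the chain rule multiplies by the incoming gradient `dlogprobs`.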

In [21]:
# logprobs = probs.log() is elementwise, so the local derivative is 1/probs;
# chain rule: local derivative * incoming (global) gradient dlogprobs.
dprobs = (1.0/probs) * dlogprobs

cmp("dprobs", dprobs, probs)

drpobs          | exact: True | approx: True  | maxdiff: 0.0


In [None]:
# dcounts_sum_inv

In [22]:
# in probs = counts * counts_sum_inv, the single-column counts_sum_inv is broadcast across the columns of counts
counts.shape, counts_sum_inv.shape

(torch.Size([32, 27]), torch.Size([32, 1]))

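So `counts * counts_sum_inv` first broadcasts `counts_sum_inv` (32×1) across the 27 columns of `counts` (32×27) and then multiplies elementwise — this is not a matrix multiplication.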

In [23]:
# c = a * b with broadcasting, e.g. a[3x3] * b[3x1] gives c[3x3]:
# a11*b1, a12*b1, a13*b1
# a21*b2, a22*b2, a23*b2
# a31*b3, a32*b3, a33*b3
# Elementwise, dc/da = b and dc/db = a.
#
# Here probs = counts * counts_sum_inv, so the local derivative with respect to
# counts_sum_inv is counts, multiplied by the incoming gradient dprobs. Each b_i
# was replicated across its row, and a value used several times accumulates the
# gradients of all its uses (as in micrograd), so the per-column contributions
# are summed back into a single column.
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)  # keepdim keeps the (n, 1) shape

cmp("dcounts_sum_inv", dcounts_sum_inv, counts_sum_inv)

dcounts_sum_inv | exact: True | approx: True  | maxdiff: 0.0


In [25]:
# probs = counts * counts_sum_inv, so the probs branch contributes counts_sum_inv * dprobs.
dcounts = counts_sum_inv * dprobs # local deri * global deri

# counts is used twice -- in probs and in counts_sum -- so this is only the
# contribution through probs. The check is expected to report False until the
# counts_sum branch is added further below.
cmp("dcounts", dcounts, counts)

dcounts         | exact: False | approx: False | maxdiff: 0.00623944029211998


In [26]:
# counts_sum_inv = counts_sum**-1, and d/dx x**-1 = -1/x**2
dcounts_sum = -1.0 / (counts_sum**2) * dcounts_sum_inv
cmp("dcounts_sum", dcounts_sum, counts_sum)

dcounts_sum     | exact: True | approx: True  | maxdiff: 0.0


In [28]:
# now handling counts_sum = counts.sum(1, keepdim=True):
# each row of counts collapses into a single entry of counts_sum

counts.shape, counts_sum.shape

(torch.Size([32, 27]), torch.Size([32, 1]))

In [29]:
# counts_sum = counts.sum(1, keepdim=True): every row sum has derivative 1 with
# respect to each element of its row, so dcounts_sum is copied across the columns:
# a11, a12, a13    --------> b1 (= a11 + a12 + a13)
# a21, a22, a23    --------> b2 (= a21 + a22 + a23)
# a31, a32, a33    --------> b3 (= a31 + a32 + a33)
#
# dcounts = (branch through probs) + (branch through counts_sum). Both terms are
# recomputed here instead of using `dcounts += ...`, so re-running this cell
# cannot add the counts_sum branch twice.
dcounts = counts_sum_inv * dprobs + torch.ones_like(counts) * dcounts_sum

cmp("dcounts", dcounts, counts)

dcounts         | exact: True | approx: True  | maxdiff: 0.0


In [30]:
# counts = norm_logits.exp(), and d/dx exp(x) = exp(x), so the local derivative
# is counts itself (equivalent to norm_logits.exp() * dcounts)
dnorm_logits = counts * dcounts
cmp("dnorm_logits", dnorm_logits, norm_logits)

dnorm_logits    | exact: True | approx: True  | maxdiff: 0.0


In [31]:
# dlogit_maxes: logit_maxes has one entry per row and is broadcast against logits

norm_logits.shape, logits.shape, logit_maxes.shape

(torch.Size([32, 27]), torch.Size([32, 27]), torch.Size([32, 1]))

In [32]:
# norm_logits = logits - logit_maxes:
#  - the logits branch passes dnorm_logits through unchanged (derivative 1); the
#    max branch is added later, so this is not yet the final dlogits
#  - logit_maxes was broadcast across columns with derivative -1, so its
#    gradient is the row-sum of -dnorm_logits
dlogits = dnorm_logits.clone() # not yet final deri for logits
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
cmp("dlogit_maxes", dlogit_maxes, logit_maxes)

dlogit_maxes    | exact: True | approx: True  | maxdiff: 0.0


In [33]:
# logit_maxes = logits.max(1).values only routes gradient to the argmax position
# of each row, so dlogit_maxes is scattered there through a one-hot mask.
# Both contributions are summed explicitly (instead of `dlogits += ...`) so that
# re-running this cell cannot count the max branch twice.
dlogits = dnorm_logits + F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes

cmp("dlogits", dlogits, logits)

dlogits         | exact: True | approx: True  | maxdiff: 0.0


In [35]:
# logits = h @ W2 + b2: these shapes decide which matmuls/sums give dh, dW2 and db2
dlogits.shape, h.shape, W2.shape, b2.shape

(torch.Size([32, 27]),
 torch.Size([32, 64]),
 torch.Size([64, 27]),
 torch.Size([27]))

In [36]:
# logits = h @ W2 + b2:
#   dh  = dlogits @ W2.T  -> same shape as h
#   dW2 = h.T @ dlogits   -> same shape as W2
#   db2 = dlogits summed over the batch, because b2 is broadcast across rows
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)

cmp("dh", dh, h)
cmp("dW2", dW2, W2)
cmp("db2", db2, b2)

dh              | exact: True | approx: True  | maxdiff: 0.0
dW2             | exact: True | approx: True  | maxdiff: 0.0
db2             | exact: True | approx: True  | maxdiff: 0.0


In [37]:
# h = tanh(hpreact), and d/dx tanh(x) = 1 - tanh(x)**2 = 1 - h**2
dhpreact = (1.0 - h**2) * dh
cmp("dhpreact", dhpreact, hpreact)

dhpreact        | exact: False | approx: True  | maxdiff: 9.313225746154785e-10


In [38]:
# bngain and bnbias have a single row that is broadcast over the batch rows of bnraw
hpreact.shape, bngain.shape, bnraw.shape, bnbias.shape

(torch.Size([32, 64]),
 torch.Size([1, 64]),
 torch.Size([32, 64]),
 torch.Size([1, 64]))

In [39]:
# hpreact = bngain * bnraw + bnbias, where bngain/bnbias are single rows broadcast
# over the batch: their gradients are summed over dim 0, and keepdim=True keeps
# them at the parameters' (1, n_hidden) shape.
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
dbnraw = bngain * dhpreact
dbnbias = dhpreact.sum(0, keepdim=True)

cmp("dbngain", dbngain, bngain)
cmp("dbnraw", dbnraw, bnraw)
cmp("dbnbias", dbnbias, bnbias)

dbngain         | exact: False | approx: True  | maxdiff: 3.725290298461914e-09
dbnraw          | exact: False | approx: True  | maxdiff: 9.313225746154785e-10
bnbias          | exact: False | approx: True  | maxdiff: 3.725290298461914e-09


In [40]:
# bnvar_inv has a single row that is broadcast over the batch rows of bndiff
bnraw.shape, bndiff.shape, bnvar_inv.shape

(torch.Size([32, 64]), torch.Size([32, 64]), torch.Size([1, 64]))

In [51]:
# bnraw = bndiff * bnvar_inv, with the single-row bnvar_inv broadcast over the batch.
# bndiff also feeds bndiff2, so dbndiff below is only a partial gradient and its
# check is expected to fail until that branch is added.
dbndiff = bnvar_inv * dbnraw
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)  # sum over the broadcast (batch) dim

cmp("dbndiff", dbndiff, bndiff)
cmp("dbnvar_inv", dbnvar_inv, bnvar_inv)

dbndiff         | exact: False | approx: False | maxdiff: 0.0011080320691689849
dbnvar_inv      | exact: False | approx: True  | maxdiff: 3.725290298461914e-09


In [46]:
# bnvar_inv = (bnvar + 1e-5)**-0.5, so d(bnvar_inv)/d(bnvar) = -0.5 * (bnvar + 1e-5)**-1.5
dbnvar = (-0.5 * (bnvar + 1e-5)**(-1.5)) * dbnvar_inv
cmp("dbnvar", dbnvar, bnvar)

dnvar           | exact: False | approx: True  | maxdiff: 4.656612873077393e-10


In [47]:
# bnvar collapses the batch dimension of bndiff2 into a single row
bnvar.shape, bndiff2.shape

(torch.Size([1, 64]), torch.Size([32, 64]))

In [49]:
# bnvar = 1/(n-1) * bndiff2.sum(0): every element of a column enters with weight
# 1/(n-1), so each one receives dbnvar/(n-1), broadcast down the batch.
dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar

cmp("dbndiff2", dbndiff2, bndiff2)

dbndiff2        | exact: False | approx: True  | maxdiff: 1.4551915228366852e-11


In [52]:
# bndiff feeds both bnraw (= bndiff * bnvar_inv) and bndiff2 (= bndiff**2).
# Both branches are summed explicitly (instead of `dbndiff += ...`) so re-running
# this cell cannot add the bndiff2 branch twice.
dbndiff = bnvar_inv * dbnraw + (2 * bndiff) * dbndiff2
cmp("dbndiff", dbndiff, bndiff)

dbndiff         | exact: False | approx: True  | maxdiff: 6.984919309616089e-10


In [55]:
# bnmeani has a single row that is broadcast over the batch rows of hprebn
bndiff.shape, hprebn.shape, bnmeani.shape

(torch.Size([32, 64]), torch.Size([32, 64]), torch.Size([1, 64]))

In [64]:
# bndiff = hprebn - bnmeani: hprebn receives dbndiff directly (derivative 1), while
# bnmeani was broadcast over the batch with derivative -1, so its gradient is
# -dbndiff summed over dim 0. keepdim=True keeps dbnmeani at bnmeani's
# (1, n_hidden) shape instead of a (n_hidden,) tensor that only matches by broadcasting.
dhprebn = dbndiff.clone()  # first branch only; the bnmeani branch is added later
dbnmeani = (-dbndiff).sum(0, keepdim=True)

cmp("dhprebn", dhprebn, hprebn)
cmp("dbnmeani", dbnmeani, bnmeani)

dhprebn         | exact: False | approx: False | maxdiff: 0.0011309990659356117
dhmeani         | exact: False | approx: True  | maxdiff: 3.725290298461914e-09


In [57]:
# bnmeani averages hprebn over the batch dimension
bnmeani.shape, hprebn.shape

(torch.Size([1, 64]), torch.Size([32, 64]))

In [65]:
# bnmeani = 1/n * hprebn.sum(0): every row of hprebn receives dbnmeani / n. This is
# combined with the direct bndiff branch in one expression (instead of
# `dhprebn += ...`) so re-running this cell cannot add the branch twice.
dhprebn = dbndiff + 1.0/n * (torch.ones_like(hprebn) * dbnmeani)
cmp("dhprebn", dhprebn, hprebn)

dhprebn         | exact: False | approx: True  | maxdiff: 6.984919309616089e-10


In [66]:
# shapes for backpropagating through hprebn = embcat @ W1 + b1
hprebn.shape, embcat.shape, W1.shape, b1.shape

(torch.Size([32, 64]),
 torch.Size([32, 30]),
 torch.Size([30, 64]),
 torch.Size([64]))

In [67]:
# hprebn = embcat @ W1 + b1: same pattern as the second linear layer
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)  # b1 is broadcast over the batch rows

cmp("dembcat", dembcat, embcat)
cmp("dW1", dW1, W1)
cmp("db1", db1, b1)

dembcat         | exact: False | approx: True  | maxdiff: 1.3969838619232178e-09
dW1             | exact: False | approx: True  | maxdiff: 5.587935447692871e-09
db1             | exact: False | approx: True  | maxdiff: 4.6566128730773926e-09


In [68]:
# embcat is emb with its last two dimensions flattened into one
embcat.shape, emb.shape

(torch.Size([32, 30]), torch.Size([32, 3, 10]))

In [69]:
# embcat = emb.view(emb.shape[0], -1) only reshapes, so the gradient is reshaped back
demb = dembcat.view(emb.shape)
cmp("demb", demb, emb)

demb            | exact: False | approx: True  | maxdiff: 1.3969838619232178e-09


In [70]:
# emb = C[Xb]: each entry of Xb selects one row of C
emb.shape, C.shape, Xb.shape

(torch.Size([32, 3, 10]), torch.Size([27, 10]), torch.Size([32, 3]))

In [71]:
# emb = C[Xb] gathers rows of C; each row of C accumulates the gradients of every
# batch position that looked it up (a row used several times sums its gradients).
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        char_ix = Xb[k, j]  # distinct name so the minibatch index tensor `ix` is not overwritten
        dC[char_ix] += demb[k, j]

cmp("dC", dC, C)

dC              | exact: False | approx: True  | maxdiff: 5.587935447692871e-09


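Calculating `dlogits` in one go: backpropagating through the whole cross-entropy loss at once gives $\frac{\partial\,\text{loss}}{\partial\,\text{logits}_{ij}} = \frac{1}{n}\left(\text{softmax}(\text{logits})_{ij} - \mathbb{1}[j = Y_i]\right)$.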

In [72]:
# Fused backward through the whole cross-entropy, derived with pen and paper:
# dloss/dlogits = (softmax(logits) - one_hot(Yb)) / n
dlogits = F.softmax(logits, dim=1)
dlogits[range(n), Yb] -= 1
dlogits /= n

cmp("dlogits", dlogits, logits)

dlogits         | exact: False | approx: True  | maxdiff: 8.614733815193176e-09


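Calculating the gradient through the whole BatchNorm layer in one go. With $g =$ `dhpreact`, $\hat{x} =$ `bnraw`, $\gamma =$ `bngain` and $\hat{\sigma}^{-1} =$ `bnvar_inv`:

$$\frac{\partial\,\text{loss}}{\partial\,\text{hprebn}_i} = \frac{\gamma\,\hat{\sigma}^{-1}}{n}\left(n\,g_i - \sum_j g_j - \frac{n}{n-1}\,\hat{x}_i \sum_j g_j\,\hat{x}_j\right)$$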

In [74]:
# Closed-form backward through the whole BatchNorm layer (hprebn -> hpreact), using
# dhpreact as the incoming gradient; the n/(n-1) factor comes from the
# Bessel-corrected variance used in the forward pass.
dhprebn = bngain * bnvar_inv/n * (n*dhpreact -dhpreact.sum(0) -n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))

cmp("dhprebn", dhprebn, hprebn)

dhprebn         | exact: False | approx: True  | maxdiff: 9.313225746154785e-10
