# Deep Dive into Backprop: Becoming a Backprop Ninja


## Imports and setup


In [171]:
### Imports and setup
import torch
import torch.nn.functional as F 
import matplotlib.pyplot as plt 
from tqdm import tqdm
%matplotlib inline
# Load the corpus: one name per line, stripped and lower-cased.
# The encoding is pinned so the result does not depend on the platform's locale default.
# Alternative corpus: 'data/combined_arabic_names_cleaned.txt'
with open('data/combined_english_names_cleaned.txt', 'r', encoding='utf-8') as f:
    names = [line.strip().lower() for line in f]

# Number of preceding characters used as context for each prediction.
CONTEXT_SIZE = 3

# Vocabulary: '.' takes index 0 (build_dataset uses it as padding and as the
# end-of-name marker), followed by every character found in the corpus, sorted.
alphabet = sorted(set(''.join(names)))
chars = ['.'] + alphabet
stoi = {c: i for i, c in enumerate(chars)}   # character -> index
itos = {i: c for c, i in stoi.items()}       # index -> character
print(stoi)
print(itos)
print(alphabet)

def build_dataset(names, context_size, p=False):
    """Turn a list of names into (context, next-character) training pairs.

    Every name is terminated with '.' and scanned left to right. For each
    character, the indices of the `context_size` preceding characters
    (left-padded with the index of '.') form one input row, and the index of
    the character itself is the target. Characters are mapped to indices via
    the module-level `stoi` table (`itos` is only used for printing).

    Args:
        names: iterable of strings; every character must be a key of `stoi`.
        context_size: number of preceding characters per example.
        p: when True, print each name and every "context --> target" pair.

    Returns:
        (X, Y): int64 tensors of shape (N, context_size) and (N,), where N is
        the total number of characters over all names plus one terminating
        '.' per name. With no names the shapes are (0, context_size) and (0,).

    Raises:
        KeyError: if a name contains a character missing from `stoi`.
    """
    contexts = []
    targets = []

    for name in names:
        if p:
            print(name)

        window = [stoi['.']] * context_size
        for c in name + '.':
            target = stoi[c]
            contexts.append(window)
            targets.append(target)

            if p:
                print(''.join(itos[x] for x in window), f' --> {c}')

            # Slide the window by building a NEW list, so the row that was
            # just stored in `contexts` is never mutated.
            window = window[1:] + [target]

    # Pin dtype and shape explicitly: torch.tensor([]) would otherwise come
    # back as a float32 tensor of shape (0,), which cannot index an embedding.
    X = torch.tensor(contexts, dtype=torch.long).reshape(len(contexts), context_size)
    Y = torch.tensor(targets, dtype=torch.long)
    return X, Y
## MLP Setup


# Model-wide constants.
VOCAB_SIZE = len(chars)  # one entry per vocabulary character, '.' included
EPS = 1e-5

# Split the names 80% / 10% / 10% into train / validation / test, in the
# order they appear in `names`, and build the example tensors for each part.
n1 = int(0.8 * len(names))
n2 = int(0.9 * len(names))
(x_train, y_train), (x_val, y_val), (x_test, y_test) = [
    build_dataset(subset, CONTEXT_SIZE)
    for subset in (names[:n1], names[n1:n2], names[n2:])
]


{'.': 0, '-': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, 'i': 10, 'j': 11, 'k': 12, 'l': 13, 'm': 14, 'n': 15, 'o': 16, 'p': 17, 'q': 18, 'r': 19, 's': 20, 't': 21, 'u': 22, 'v': 23, 'w': 24, 'y': 25, 'z': 26}
{0: '.', 1: '-', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: 'i', 11: 'j', 12: 'k', 13: 'l', 14: 'm', 15: 'n', 16: 'o', 17: 'p', 18: 'q', 19: 'r', 20: 's', 21: 't', 22: 'u', 23: 'v', 24: 'w', 25: 'y', 26: 'z'}
['-', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z']


In [172]:
# Comparing gradients

def compare(s, dt, t):
    """Print how a hand-computed gradient compares with autograd's.

    Args:
        s: label printed at the start of the report line.
        dt: manually derived gradient tensor.
        t: tensor whose `.grad` (filled in by autograd) is the reference.

    Prints one line stating whether `dt` equals `t.grad` element-wise,
    whether the two pass `torch.allclose`, and the largest absolute
    difference between them. Returns nothing.
    """
    reference = t.grad
    is_exact = (dt == reference).all().item()
    is_close = torch.allclose(dt, reference)
    worst = (dt - reference).abs().max().item()

    print(f'{s:15s} | exact: {str(is_exact):5s} | approximate: {str(is_close):5s} | maxdiff: {worst}')


In [173]:
EMBEDDING_SIZE = 10
HIDDEN_SIZE = 64

# Seed the global RNG so the initial weights, and anything drawn from the
# global RNG after this cell, are identical on every Restart & Run All.
SEED = 2147483647
torch.manual_seed(SEED)

# Embedding table: one EMBEDDING_SIZE-dimensional row per vocabulary entry.
C = torch.randn((VOCAB_SIZE, EMBEDDING_SIZE))

# Layer 1
# Weights are scaled by (5/3) / sqrt(fan_in); 5/3 is the value that
# torch.nn.init.calculate_gain('tanh') returns.
W1 = torch.randn(( EMBEDDING_SIZE * CONTEXT_SIZE,HIDDEN_SIZE)) * (5/3) / ( EMBEDDING_SIZE * CONTEXT_SIZE)**0.5
b1 = torch.randn((HIDDEN_SIZE)) * 0.1

# Layer 2
W2 = torch.randn((HIDDEN_SIZE,VOCAB_SIZE)) * (5/3) / (HIDDEN_SIZE)**0.5
b2 = torch.randn((VOCAB_SIZE)) * 0.1

# Batch Norm: learnable per-feature gain and bias, starting at the identity transform.
bngain = torch.ones((1,HIDDEN_SIZE))
bnbias = torch.zeros((1,HIDDEN_SIZE))

# Running statistics buffers; they are not in `parameters`, so they never get requires_grad.
bnmean_running = torch.zeros((1,HIDDEN_SIZE))
bnstd_running = torch.ones((1,HIDDEN_SIZE))


parameters = [C,W1,b1,W2,b2, bnbias,bngain]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
# Turn on gradient tracking only after initialisation, so the tensors stay autograd leaves.
for p in parameters:
  p.requires_grad = True

4137


In [174]:

# Draw one random minibatch from the training split; indices are sampled
# independently, so the same row may appear more than once.
batch_size = 32
n = batch_size  # shorter alias, for convenience
ix = torch.randint(low=0, high=len(x_train), size=(batch_size,))
Xb = x_train[ix]  # batch of contexts
Yb = y_train[ix]  # batch of targets

In [175]:
# forward pass, "chunkated" into smaller steps that are possible to backward one at a time
# Every intermediate gets its own name so that a matching d<name> gradient can be
# derived by hand and checked against autograd with compare().

emb = C[Xb] # embed the characters into vectors: one row of C per index in Xb
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors: flatten everything after the batch dim
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
# Per-feature mean over the batch (dim 0); keepdim=True keeps a leading dim of 1 for broadcasting.
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani

bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
# Inverse standard deviation; the 1e-5 added to the variance has the same value as the
# module-level EPS constant, which this cell does not reference.
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values # per-row maximum
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True) # per-row normaliser, kept 2-D so it broadcasts over columns
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
# Mean negative log-probability of the target character of each of the n rows.
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
# Clear gradients left over from any previous run of this cell.
for p in parameters:
  p.grad = None
# Intermediate (non-leaf) tensors do not keep .grad by default; retain_grad() asks
# autograd to store it so that every step can be compared with the manual result.
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()
loss.backward()
loss # last expression of the cell: displays the loss value

tensor(3.5028, grad_fn=<NegBackward0>)

In [176]:
# Closed-form gradient of the mean cross-entropy loss with respect to the logits:
#   dlogits = (softmax(logits) - one_hot(Yb)) / batch_size
# computed in one step and checked against the gradient autograd stored in logits.grad.
with torch.no_grad():  # pure numerical check; keep it out of the autograd graph
    p = F.softmax(logits, 1)
    # p = torch.exp(logits) /  torch.exp(logits).sum(1, keepdims=True)
    dll = p.clone()  # work on a copy so that p still holds the probabilities afterwards
    dll[range(dll.shape[0]), Yb] -= 1  # subtract 1 at the target class of every row
    dll /= dll.shape[0]                # the loss is a mean over the batch
compare('logits', dll, logits)
Yb.shape, logits.shape, dll.shape


torch.Size([32])
torch.Size([32])
logits          | exact: False | approximate: True  | maxdiff: 5.122274160385132e-09


(torch.Size([32]), torch.Size([32, 27]), torch.Size([32, 27]))

In [177]:
# Manual backward pass: walk the forward computation in reverse, one intermediate
# at a time. Each d<name> is d(loss)/d(<name>) and has the same shape as <name>.
# The arithmetic is written to mirror the forward expressions so that results can
# be compared with autograd bit for bit.

# loss = -logprobs[range(n), Yb].mean(): only the n selected entries get gradient, -1/n each.
dlogprobs =  torch.zeros_like(logprobs)
dlogprobs[range(n),Yb] = -1/logprobs.shape[0]

# logprobs = probs.log()
dprobs = dlogprobs / probs

# probs = counts * counts_sum_inv. counts_sum_inv has one column and was broadcast
# across every column of counts in the forward pass, so its gradient is the sum of
# the per-column contributions along dim 1.
dcounts_sum_inv = (dprobs  * counts).sum(1, keepdim=True)

# counts_sum_inv = counts_sum**-1
dcounts_sum = -counts_sum**-2 * dcounts_sum_inv
# counts is used twice: directly in probs, and through counts_sum (a row sum, whose
# gradient is broadcast back to every column). Add both contributions.
dcounts =  (dprobs  * counts_sum_inv) + 1 * dcounts_sum
# counts = norm_logits.exp()
dnorm_logits = norm_logits.exp() * dcounts
# norm_logits = logits - logit_maxes, with logit_maxes broadcast across columns.
dlogit_maxes = -1 * dnorm_logits.sum(1, keepdim=True)

# logits reaches the loss directly through norm_logits and, at the position of each
# row's maximum, through logit_maxes; the one-hot mask routes the second term there.
dlogits = dnorm_logits +  F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes

# Linear layer 2: logits = h @ W2 + b2
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
# No keepdim here: b2 is 1-D, and a gradient must have the shape of its parameter.
db2 = dlogits.sum(0)
# h = tanh(hpreact), and d tanh(x)/dx = 1 - tanh(x)**2
dhpreact = (1.0 - torch.pow(h,2)) * dh

# hpreact = bngain * bnraw + bnbias; gain and bias have a leading dim of 1 and were
# broadcast over the batch, hence the sums over dim 0.
dbngain = (dhpreact * bnraw).sum(0, keepdim=True)

dbnbias = dhpreact.sum(0,keepdim=True)
dbnraw = bngain * dhpreact

# bnraw = bndiff * bnvar_inv (bnvar_inv broadcast over the batch)
dbnvar_inv = (dbnraw * bndiff).sum(0,keepdim=True)
dbndiff = bnvar_inv * dbnraw
# bnvar_inv = (bnvar + 1e-5)**-0.5
dbnvar = -.5*(bnvar + 1e-5)**-1.5 * dbnvar_inv


# bnvar = 1/(n-1) * bndiff2.sum(0): every row receives the same share.
dbndiff2 = torch.ones_like(bndiff2) * (1/(n-1)) * dbnvar
# bndiff also feeds bndiff2 = bndiff**2; accumulate this second path onto the first.
dbndiff += 2*bndiff * dbndiff2


# bndiff = hprebn - bnmeani (bnmeani broadcast over the batch)
dbnmeani = dbndiff.sum(0, keepdim=True) * -1
# hprebn feeds bndiff directly and bnmeani = 1/n * hprebn.sum(0).
dhprebn = dbndiff  + dbnmeani  * torch.ones_like(hprebn) * (1/(n)) 

# Linear layer 1: hprebn = embcat @ W1 + b1
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
# No keepdim, for the same reason as db2: b1 is 1-D.
db1 = dhprebn.sum(0)



# embcat is a view of emb, so the gradient only needs reshaping.
demb = dembcat.view(emb.shape)
dC = torch.zeros_like(C)

# emb = C[Xb]: a row of C may be looked up many times in one batch, so the gradient
# of every lookup is accumulated into that row. The index is used inline so that this
# loop does not overwrite `ix`, the minibatch index tensor.
for i in range(demb.shape[0]):
    for j in range(demb.shape[1]):
        dC[Xb[i][j]] += demb[i][j]

# Check every hand-derived gradient against the one autograd stored in .grad,
# in the same order the manual backward pass computes them.
for _label, _manual, _tensor in [
    ('logprobs', dlogprobs, logprobs),
    ('probs', dprobs, probs),
    ('counts_sum_inv', dcounts_sum_inv, counts_sum_inv),
    ('counts_sum', dcounts_sum, counts_sum),
    ('counts', dcounts, counts),
    ('norm_logits', dnorm_logits, norm_logits),
    ('logit_maxes', dlogit_maxes, logit_maxes),
    ('logits', dlogits, logits),
    ('h', dh, h),
    ('W2', dW2, W2),
    ('b2', db2, b2),
    ('hpreact', dhpreact, hpreact),
    ('bngain', dbngain, bngain),
    ('bnbias', dbnbias, bnbias),
    ('bnraw', dbnraw, bnraw),
    ('bnvar_inv', dbnvar_inv, bnvar_inv),
    ('bnvar', dbnvar, bnvar),
    ('bndiff2', dbndiff2, bndiff2),
    ('bndiff', dbndiff, bndiff),
    ('bnmeani', dbnmeani, bnmeani),
    ('hprebn', dhprebn, hprebn),
    ('embcat', dembcat, embcat),
    ('W1', dW1, W1),
    ('b1', db1, b1),
    ('emb', demb, emb),
    ('C', dC, C),
]:
    compare(_label, _manual, _tensor)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: False | approximate: True  | maxdiff: 1.862645149230957e-09
bngain          | exact: False | approximate: True  | maxdiff: 6.51925802230835e-09
bnbias          | exact: False | approximate: True  | maxdiff: 7.450580596923828e-09
bnraw    