## Backprop history and importance: 

It was a common practice as late as 2016 to write your own backward pass instead of the autograd engine we use to call `loss.backward()`, which is why we will spend this lecture trying to get an intuitive sense of backprop by writing own own code to execute it. 

Essentially we will introduce many __intermediate variables__ to track the flow of gradients a bit like we did in autograd. 

We will also __revert__ to our simple model of neural network with only a __single hidden layer__. 

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

from aux_files import Linear, Tanh, BatchNorm1d

In [2]:
words = open('names.txt', 'r').read().splitlines()

In [3]:
allchars = sorted(set(''.join(words)))

stoi = {s:i+1 for i,s in enumerate(allchars) }
stoi['.'] = 0

itos = {i:s for s,i in stoi.items()}
vocab_size = len(stoi)

In [4]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
words_shuffled = words[:]  # shallow copy -- to preserve across runs
random.shuffle(words_shuffled)
n1 = int(0.8*len(words_shuffled))
n2 = int(0.9*len(words_shuffled))

Xtr,  Ytr  = build_dataset(words_shuffled[:n1])     # 80%
Xdev, Ydev = build_dataset(words_shuffled[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words_shuffled[n2:])     # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


Done with boilerplate init code, now to more concrete stuff. 

Lets define a comparing function to check whether analytically calculated gradients are close to those calculated by pytorch.

In [67]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {str(maxdiff):5s}')

Utility of `cmp`: 
- Line 1: compares all elements of dt and t.grad and compresses it to a single True or False boolean rather than a whole matrix
- Line 2: gives come wiggle room during comparison
- Line 3: captures the maximum difference bw the two values across the entire gradient matrix
- Line 4 is a simple print statement

Initializating many of these parameters in non-standard ways because sometimes initializating with e.g. all zeros could mask an incorrect
implementation of the backward pass.

In [6]:
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5) # kaiming init
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 just for fun, it's useless because of BN
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0 # non standard init
bnbias = torch.randn((1, n_hidden))*0.1 # non standard init

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

4137


In [7]:
batch_size = 32
n = batch_size # a shorter variable also, for convenience
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

In [75]:
print(hprebn.shape, embcat.shape, W1.shape, b1.shape)
print(embcat.shape, emb.shape)
print(C.shape, emb.shape, Xb.shape)

torch.Size([32, 64]) torch.Size([32, 30]) torch.Size([30, 64]) torch.Size([64])
torch.Size([32, 30]) torch.Size([32, 3, 10])
torch.Size([27, 10]) torch.Size([32, 3, 10]) torch.Size([32, 3])


In [27]:
# forward pass, "chunkated" into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
  p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()
loss.backward()
loss

tensor(3.3394, grad_fn=<NegBackward0>)

All cells above this one are just the basics being carried over. Now lets define: 

$\frac{\partial Loss}{\partial {logprobs}}$ and start from here as we go back in the chain to compute derivative of loss wrt each intermediate variable. 

At each step we uncomment each `cmp` call to check the proximity of our analytically calculated grad and pytorch calculated grad.

In [50]:
# c = a * b 
# dc/db = a
print(counts.shape, counts_sum_inv.shape) # broadcasting is implicit
print('--------------------')
print(logits.shape, h.shape, W2.shape, b2.shape)
print('--------------------')
print(hpreact.shape, bnraw.shape, bngain.shape)
print('--------------------')
print(bndiff2.shape)

torch.Size([32, 27]) torch.Size([32, 1])
--------------------
torch.Size([32, 27]) torch.Size([32, 64]) torch.Size([64, 27]) torch.Size([27])
--------------------
torch.Size([32, 64]) torch.Size([32, 64]) torch.Size([1, 64])
--------------------
torch.Size([32, 64])


One tip to deduce formulae of gradients correctly is to check the shape of the original variable; the shape of the grad will also be same, since loss is a scalar. 

i.e. __if__ shape of logprobs = $(30,40)$ shape of dlogprobs also will be $(30,40)$. 

<span style="color:#FF0000; font-family: 'Bebas Neue'; font-size: 01em;">Caution:</span><br>
1. for `probs = counts_sum_inv * counts`, there are __2 steps__: broadcasting counts_sum_inv and then multiplication, which is why in `dcounts_sum_inv` we sum the gradients along dim 1

2. At some places, a `+=` is used to _accumulate_ gradients if a variable is repeated during backpass. 

3. `A*B` is element-wise multiplication and and `A@B` is matrix multiplication. 

In [78]:
# Exercise 1: bacmp('logprobs', dlogprobs, logprobs)

dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n

dprobs = 1/probs * dlogprobs # chain rule

dcounts = dprobs * counts_sum_inv # 1st contribution, rest comes from count_sum_inv definition!
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True) # this is tricky -- understand using shape analogy

dcounts_sum =   (-1/(counts_sum)**2) * dcounts_sum_inv

dcounts2 = torch.ones_like(counts) * dcounts_sum
dcounts += dcounts2 

dnorm_logits = dcounts * counts #exponential goes unscathed

dlogits = dnorm_logits.clone()
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True) # similar to dcounts_sum_inv -- broadcasting takes place to beware

dh = dlogits @ W2.T  
dW2 = h.T @ dlogits
db2 = dlogits.sum(0, keepdims = True) # trick: dim(b2) = [27] => broadcasting happens across rows, so grad sum must be across rows. 

dhpreact = dh * (1-h**2)

dbnbias = dhpreact.sum(0, keepdims=True) # since dim(bnbias) = 1,64 => sum must be along rows (0) while broadcasting -- same as db2
dbngain = (dhpreact * bnraw).sum(0, keepdims = True) # from dimensional analysis
dbnraw = bngain * dhpreact

dbnvar_inv = (dbnraw * bndiff).sum(0, keepdims = True)
dbndiff = dbnraw * bnvar_inv # will have to add another components from dbndiff2.grad

dbnvar = dbnvar_inv * (-0.5 * ((bnvar + 1e-5)**-1.5)) # simple derivative
dbndiff2 = dbnvar * torch.ones_like(bndiff2) * 1/(n-1) # decrypting broadcasting using math

dbndiff += 2*bndiff*dbndiff2 # add second component 

dbnmeani = -1* dbndiff.sum(0, keepdim=True)

dhprebn = dbndiff.clone()
dhprebn += dbnmeani * 1/n * torch.ones_like(hprebn) # since there were 2 components 

dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0, keepdim = True)

demb = dembcat.view(emb.shape[0], block_size, -1) #since this was just a dim transformation in forward pass too

# proceed based on intuition and dimensions of C, emb and Xb
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
  for j in range(Xb.shape[1]):
    ix = Xb[k,j]
    dC[ix] += demb[k,j]

print('---------GRAD comparison results--------------')
cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('bndiff', dbndiff, bndiff)
cmp('bnmeani', dbnmeani, bnmeani)
cmp('hprebn', dhprebn, hprebn)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)
cmp('emb', demb, emb)
cmp('C', dC, C)

# backpropagating through exactly all of the variables manually
# as they are defined in the forward pass above, one by one




---------GRAD comparison results--------------
logprobs        | exact: True  | approximate: True  | maxdiff: 0.0  
probs           | exact: True  | approximate: True  | maxdiff: 0.0  
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0  
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0  
counts          | exact: True  | approximate: True  | maxdiff: 0.0  
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0  
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0  
logits          | exact: False | approximate: True  | maxdiff: 9.313225746154785e-09
h               | exact: False | approximate: True  | maxdiff: 2.3283064365386963e-09
W2              | exact: False | approximate: True  | maxdiff: 1.210719347000122e-08
b2              | exact: False | approximate: True  | maxdiff: 1.4901161193847656e-08
hpreact         | exact: False | approximate: True  | maxdiff: 2.3283064365386963e-09
bngain          | exact: False | approxima

It does take some effort to trace it all the way back but matching dimensions of variables is really the trick. Understanding broadcasting, accumulation etc operations. 

### On Bessel's correction 

In the line: <br>
`bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True)` <br>
whether to use $\frac{1}{n-1}$ or $\frac{1}{n}$ is a bit of a confusion in the [original batch norm paper](https://arxiv.org/abs/1502.03167) since they use one during training and other during testing which can, however minutely, give fudged results. Andrej prefers using 
$\frac{1}{n-1}$ uniformly. More debate can be found [here](https://math.oxford.emory.edu/site/math117/besselCorrection/)


## Exercise 2

Backprop through cross_entropy but all in one go to complete this challenge look at the mathematical expression of the loss, take the derivative, simplify the expression, and just write it out.

In [82]:
# forward pass

# before:
# logit_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logit_maxes # subtract max for numerical stability
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdims=True)
# counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[range(n), Yb].mean()

# now:
loss_fast = F.cross_entropy(logits, Yb)
print(loss_fast.item(), 'diff:', (loss_fast - loss).item())

3.3394112586975098 diff: 0.0


Similarly, now in computation of `dlogits` we will derive a mathematical relation to prevent calculation of derivatives of `[logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes]` to reach upto `dlogits`. 

__So the goal is:__ `dlogits = f(logits, Yb)`

Note: an attempt to calculate the loss analytically from the logits matrix, BUT apparently that is not what Andrej meant LOL.

In [105]:
# write a variable logit matrix, subtract, exponentiate, take sum along dim 1 => probs. Calculate derivatives at 
# specific index using chain rule for division. Doesnt seem to work tho. 

den = logits.exp().sum(dim=1, keepdim=True)
den**2 # used in denominator
num1 = ((logits[range(n), Yb]).exp())
num1**2 # first term of numerator

num2 = den * num1.view(32,1) # .view to enforce element wise multiplication
 
print('num and den shape:',den.shape, num1.shape, num2.shape)

entries = ((num1**2).view(32,1) - num2)/ den**2

print('entries shape:',entries.shape)

dlogits_math = torch.zeros_like(logits)

for i in range(n): # n: batch size 32
    dlogits_math[i, Yb[i]] = entries[i]

dlogits_math[:2] # i have f ed up somewhere clearly. 

num and den shape: torch.Size([32, 1]) torch.Size([32]) torch.Size([32, 1])
entries shape: torch.Size([32, 1])


tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         -0.0185,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0441,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000]], grad_fn=<SliceBackward0>)

Andrej's sense of the implementation: 

In [106]:
dlogits = F.softmax(logits, 1)
dlogits[range(n), Yb] -= 1
dlogits /= n

cmp('logits', dlogits, logits)

logits          | exact: False | approximate: True  | maxdiff: 9.080395102500916e-09
