# Becoming a backprop ninja

This is for the 'part 4 - becoming a backprop ninja' lecture. He took issue with us calling 'loss.backward()' (blindly, I think, is how he'd put it) and so relying on PyTorch's autograd to compute our gradients. He thinks it's important and useful for us to understand what's going on, as he writes about in https://karpathy.medium.com/yes-you-should-understand-backprop-e2f06eab496b.

We did do micrograd already, but micrograd only thinks about scalars. This lecture is about tensors.

Historically, it's interesting to know that as recently as 2012, people wrote their backward passes by hand (or used algorithms other than backprop entirely), while now everyone just calls loss.backward(), and 'we've lost something'. They'd use Matlab! (Since it had a convenient tensor class.)

# Overall plan

We'll do the same multilayer perceptron network, and same training loop, but we'll do the backprop by hand.

# Setup

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
words = open('names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [3]:
len(words)

32033

In [4]:
chars = sorted(list(set(''.join(words))))

stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0

itos = {i:s for s,i in stoi.items()}

vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [5]:
# build the dataset
block_size = 3 # context length: how many chars we take to predict the next one

def build_dataset(words):
    X, Y = [], []
    
    for w in words:
        context = [0] * block_size
        # training row for each set of block_size chars (with '.' at end)
        for ch in w + '.':
            ix = stoi[ch] # 'next' char, after context
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix] # new context shifts right and adds prev 'next' char
            
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words)) # 80% for train
n2 = int(0.9*len(words)) # 90% mark: the remaining 20% splits into 10% dev/validation and 10% test

Xtr,  Ytr = build_dataset(words[:n1])    # 80%
Xdev, Ydev = build_dataset(words[n1:n2]) # 10%
Xte,  Yte = build_dataset(words[n2:])    # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [6]:
# new stuff

In [7]:
# util function to use later to compare manual gradients to PyTorch gradients
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}') 

In [8]:
n_embd = 10 # dimensionality of the character vectors
n_hidden = 64 # neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd),             generator=g)
# Layer one
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # effectively unused: BatchNorm subtracts the batch mean, which cancels this bias
# Layer two
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm params
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

# note: many of the params are initialized in non-standard ways because
# sometimes initializing with e.g. all zeros could mask an incorrect
# impl of the backward pass (because multiplying by 0 can simplify things)

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # total # of params
for p in parameters:
    p.requires_grad = True

4137


In [9]:
batch_size = 32
n = batch_size # a shorter var also, for convenience
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X, y

In [10]:
# forward pass, manually, 'chunkated' into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed chars into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# linear layer one
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # Bessel's correction, dividing by n-1 not n
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# non-linearity
h = torch.tanh(hpreact) # hidden layer
# linear layer two
logits = h @ W2 + b2 # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1 # with (1.0 / counts_sum) instead, I couldn't get the backward pass to match bit-exactly
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv,
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
          bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
          embcat, emb]:
    # we'll retain the gradient from pytorch so we can compare/check our work
    t.retain_grad()
loss.backward()
loss

tensor(3.3363, grad_fn=<NegBackward0>)

Terminology-wise, just to remember... we take the derivative of something with respect to something. For example, we take the derivative of the loss (function) with respect to the logprobs tensor, and store the result in a 'dlogprobs' var. dlogprobs holds one number per element of logprobs, and each number tells us how nudging that element up or down by a tiny bit changes the value of the loss function (i.e., 'at what rate does the loss change as this element of logprobs changes').

Also, when we have more than one variable, we're actually taking the _partial_ derivative with respect to one of them, while holding the others constant. (Terms that don't involve the variable we're differentiating with respect to are treated as constants, and since the derivative of a constant is 0, they drop out.)
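To make 'nudge one variable, hold the others fixed' concrete, here's a minimal sketch on a made-up two-variable function (the function `f` and the values of `a` and `b` are hypothetical, chosen only for illustration). It estimates the partial derivative df/da with a finite difference and compares that to the analytic answer and to PyTorch's autograd.

```python
import torch

# Hypothetical toy function of two variables: f(a, b) = a*b + b**2
def f(a, b):
    return a * b + b**2

a, b = 2.0, 3.0
h = 1e-6

# Nudge only `a`, holding `b` fixed: a finite-difference estimate of df/da
df_da_numeric = (f(a + h, b) - f(a, b)) / h
# Analytically, b**2 doesn't depend on a, so it's a constant and drops out: df/da = b
df_da_analytic = b

# Same thing via autograd (float64 to keep rounding noise out of the way)
at = torch.tensor(a, dtype=torch.float64, requires_grad=True)
bt = torch.tensor(b, dtype=torch.float64, requires_grad=True)
f(at, bt).backward()

print(df_da_numeric, df_da_analytic, at.grad.item())  # all approximately 3.0
print(bt.grad.item())  # df/db = a + 2*b = 8.0
```

The finite-difference estimate is only approximate (it depends on the step size `h`), which is why the notebook checks manual gradients against autograd rather than against numeric nudges.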

# Exercise one: backprop through everything manually 

We start with: what do we need to do to calculate the gradient of the loss with respect to every element of the logprobs tensor? This gives us dlogprobs.

In [12]:
logprobs.shape

torch.Size([32, 27])

logprobs is a (32, 27) tensor, so dlogprobs will have the same shape, since each element of logprobs gets its own derivative.

In [15]:
print(len(Yb))
Yb # an array of all of the correct indices

32


tensor([ 8, 14, 15, 22,  0, 19,  9, 14,  5,  1, 20,  3,  8, 14, 12,  0, 11,  0,
        26,  9, 25,  0,  1,  1,  7, 18,  9,  3,  5,  9,  0, 18])

logprobs has one row for each of the 32 examples in the batch (Yb holds the correct next character for each one), and each row has 27 values, the log probability of each possible character. '-logprobs[range(n), Yb].mean()' plucks out the log probability of the correct character in each row, takes the mean of those 32 log probabilities, and negates it to get the loss.
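The plucking can be tried on a toy batch. This is a minimal sketch with assumed sizes (4 examples, 5 classes instead of 32 and 27) and made-up targets. It also takes a shortcut, using `log_softmax` instead of the notebook's step-by-step counts/probs/log chain. It checks that `logprobs[range(n), Y]` selects one value per row, that `gather` does the same thing, and that the resulting loss matches `F.cross_entropy`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab = 4, 5                       # toy sizes (assumed; the notebook uses 32 and 27)
logits = torch.randn(n, vocab)
Y = torch.tensor([1, 0, 4, 2])        # hypothetical "correct next char" index per row

logprobs = logits.log_softmax(dim=1)  # one row of log probabilities per example
picked = logprobs[range(n), Y]        # row i, column Y[i] -> shape (n,)
loss = -picked.mean()

# gather does the same plucking explicitly along dim 1
picked_gather = logprobs.gather(1, Y.unsqueeze(1)).squeeze(1)

print(picked.shape)                                      # torch.Size([4])
print(torch.allclose(picked, picked_gather))             # True
print(torch.allclose(loss, F.cross_entropy(logits, Y)))  # True
```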

In [16]:
logprobs[range(n), Yb]

tensor([-3.9821, -3.0808, -3.6683, -3.2411, -4.0954, -3.5242, -3.1875, -4.1198,
        -3.1483, -4.2542, -3.1482, -1.6404, -2.7770, -2.9951, -2.9868, -3.1542,
        -3.7326, -3.0231, -3.6101, -3.4038, -2.8735, -3.0179, -4.3650, -4.0523,
        -3.3816, -2.8623, -2.9835, -3.9094, -2.6846, -3.3958, -3.2910, -3.1718],
       grad_fn=<IndexBackward0>)

In [None]:
# we have 32 numbers, but simplified to 3, it's
# loss = -(a + b + c) / 3
# loss = -a/3 - b/3 - c/3
# dloss/da = -1/3 (and likewise for b and c)

# generally, when we have n numbers (like 32), then
# dloss/da = -1/n

There are other numbers in the logprobs tensor besides the 32 we're talking about above (32*27 = 864 in total, so 832 others), but these don't matter: in the forward pass they're ignored, because we only pluck out a single logprobs value (out of 27) for each example in the minibatch. Conceptually, again, what a derivative is saying is 'if I ever so slightly tweak the value of this variable (logprobs here), how does the (loss) function's value change?' Elements that don't affect the loss don't change its value, so their gradient/derivative is zero.
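The 'mostly zeros' shape of dlogprobs can be checked directly. This is a minimal sketch with toy sizes and a random stand-in for logprobs (both assumptions, not the notebook's actual tensors). It builds the manual gradient (-1/n at each plucked entry, 0 elsewhere) and compares it to what autograd produces.

```python
import torch

torch.manual_seed(0)
n, vocab = 4, 5                                       # toy sizes, assumed for illustration
logprobs = torch.randn(n, vocab, requires_grad=True)  # stand-in values; only the indexing matters
Y = torch.tensor([1, 0, 4, 2])                        # hypothetical target indices

loss = -logprobs[range(n), Y].mean()
loss.backward()

# Manual gradient: -1/n at each plucked entry, 0 everywhere else
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Y] = -1.0 / n

print(torch.equal(dlogprobs, logprobs.grad))  # True
print((logprobs.grad != 0).sum().item())      # 4 non-zero entries out of 20
```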

You can type 'd/dx log(x)' into Wolfram Alpha and it'll tell you the result, which is 1/x.
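Here's a minimal sketch of the local-derivative-times-upstream-gradient step for `log`. The probabilities and the 'upstream' gradient are made-up values (assumptions for illustration, with -0.25 playing the role of -1/n). It compares (1/p) * upstream against autograd.

```python
import torch

probs = torch.tensor([0.9, 0.5, 0.05], requires_grad=True)  # toy probabilities
upstream = torch.tensor([-0.25, 0.0, -0.25])  # pretend dloss/dlogprobs coming from above

logprobs = probs.log()
# Feed the upstream gradient into autograd, which applies d(log p)/dp = 1/p for us
logprobs.backward(upstream)

# Manual: local derivative 1/p, multiplied by the upstream gradient (chain rule)
dprobs = (1.0 / probs.detach()) * upstream

print(torch.allclose(dprobs, probs.grad))  # True
print(dprobs)
```

Note how the correct character with the low probability (0.05) ends up with a much larger gradient (about -5) than the one with high probability (0.9, about -0.28). Dividing by p boosts the gradient for examples the network currently gets badly wrong.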

In [19]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n

# (1/probs) is the local derivative of log; multiplying by dlogprobs applies the chain rule
dprobs = (1.0 / probs) * dlogprobs


cmp('logprobs', dlogprobs, logprobs)
cmp('dprobs', dprobs, probs)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
dprobs          | exact: True  | approximate: True  | maxdiff: 0.0


In [18]:
dlogprobs

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         -0.0312,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0312,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0312,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000

I got here, with a discussion of the intuitive effect of the dprobs calc, at ~20m into the video. At least for now, I'm not going to type in everything... I may come back and do it after watching the video. With the specific content of this lecture (the math, going deeper into the guts), I'm curious to hear what he wants to say and emphasize, but I'm not as interested in typing out every last thing to reinforce it that way. Or at least that's what I'll tell myself, because I don't want to do that particular work right now.

Generally, getting the derivative for each step/line in the forward pass means understanding the actual 'local' derivative (which you can remember or look up via something like Wolfram Alpha), AND keeping in mind additional things, like:

- remember to apply the chain rule
- when a given value is used in multiple places/calcs, you sum the gradients flowing back from each use; this is also discussed in the micrograd lecture - he also calls this a 'routing' function
- understanding how broadcasting works is important so you can keep the shapes of a tensor and its derivative lined up, and also because broadcasting acts like an additional operation that you need to take the derivative of: a value that was broadcast (replicated) across rows in the forward pass gets its gradients summed across those rows in the backward pass
- he has a cool, bit-by-bit demonstration, on paper (which he then scanned in), of how he figures out the derivative of a matrix multiply plus a bias, starting around 40m or a bit after
- around 1hr he's doing the derivative of a bias variable and has to talk about broadcasting and summing the gradients along a particular dimension. I think that's needed because here (bngain/bnbias) he defined the parameter as a (1, n) tensor, while in a different part of the same network he defined a bias var (b1, b2) as just a vector. They hold the same amount of information, but the math operations are slightly different, and we can see exactly how as he defines the derivatives
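The matrix-multiply-plus-bias backward pass from the paper demonstration can be sketched in code. This is a minimal version with assumed toy sizes, random values, and a random stand-in for the upstream gradient `dlogits`. It uses the same names as the notebook's second linear layer (`logits = h @ W2 + b2`) and checks the manual formulas against autograd.

```python
import torch

torch.manual_seed(0)
n, n_hidden, vocab = 4, 6, 5                  # toy sizes, assumed
h = torch.randn(n, n_hidden, requires_grad=True)
W2 = torch.randn(n_hidden, vocab, requires_grad=True)
b2 = torch.randn(vocab, requires_grad=True)   # 1-D bias, broadcast across the n rows

logits = h @ W2 + b2
dlogits = torch.randn(n, vocab)               # pretend upstream gradient
logits.backward(dlogits)

# Manual backward for a matmul plus a broadcast bias:
dh = dlogits @ W2.detach().T   # (n, vocab) @ (vocab, n_hidden) -> (n, n_hidden)
dW2 = h.detach().T @ dlogits   # (n_hidden, n) @ (n, vocab) -> (n_hidden, vocab)
db2 = dlogits.sum(0)           # b2 was reused for every row, so sum over rows -> (vocab,)

for name, manual, auto in [('h', dh, h.grad), ('W2', dW2, W2.grad), ('b2', db2, b2.grad)]:
    print(name, manual.shape == auto.shape, torch.allclose(manual, auto, atol=1e-6))
```

A handy side effect: the transposes and the order of the matmul are basically forced by requiring that each gradient has the same shape as the tensor it belongs to.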

I got to 59:30 at this point, finishing the derivatives for bngain, bnbias, bnraw.
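For the bngain, bnbias and bnraw derivatives, here's a minimal sketch on toy shapes. The sizes, random values and the stand-in upstream gradient `dhpreact` are assumptions, not the notebook's actual tensors. Because bngain and bnbias are (1, n_hidden) and broadcast down the batch rows, their gradients are summed over dim 0 with `keepdim=True` so they keep that (1, n_hidden) shape. The 1-D b2 in the second linear layer, by contrast, gets a plain `.sum(0)`.

```python
import torch

torch.manual_seed(0)
n, n_hidden = 4, 6                                     # toy sizes, assumed
bnraw = torch.randn(n, n_hidden, requires_grad=True)
bngain = torch.randn(1, n_hidden, requires_grad=True)  # (1, n_hidden): broadcast down the rows
bnbias = torch.randn(1, n_hidden, requires_grad=True)

hpreact = bngain * bnraw + bnbias
dhpreact = torch.randn(n, n_hidden)                    # pretend upstream gradient
hpreact.backward(dhpreact)

# elementwise multiply: each factor's local derivative is the other factor
dbnraw = bngain.detach() * dhpreact                    # broadcasts up to (n, n_hidden)
# bngain/bnbias were reused across all n rows, so their gradients sum over dim 0;
# keepdim=True keeps the (1, n_hidden) shape that matches the parameters
dbngain = (bnraw.detach() * dhpreact).sum(0, keepdim=True)
dbnbias = dhpreact.sum(0, keepdim=True)

for name, manual, auto in [('bnraw', dbnraw, bnraw.grad),
                           ('bngain', dbngain, bngain.grad),
                           ('bnbias', dbnbias, bnbias.grad)]:
    print(f'{name:7s} shape ok: {manual.shape == auto.shape} | close: {torch.allclose(manual, auto, atol=1e-6)}')
```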