Context:

In the previous notebook (4-activations...ipynb), we used PyTorch's autograd (`loss.backward()`) for backpropagation. It's bad to rely on a framework's autograd without learning its internals, because we won't know why it is performing well or not. We've implemented our own backpropagation for scalars in micrograd, but implementing backpropagation for tensors ourselves, instead of using the framework's autograd, will help us debug neural nets.
Learning the internals of backpropagation will also deepen our understanding.

# Makemore 5: Becoming a backprop ninja

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as ply
%matplotlib inline

In [2]:
# read in all the words (one per line)
# the context manager closes the file handle as soon as the contents are read
with open("names.txt") as f:
    words = f.read().splitlines()
print(len(words))  # number of words
print(max(len(w) for w in words))  # length of the longest word
print(words[:8])  # first few words as a sanity check

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [3]:
# build the character vocabulary and the lookup tables in both directions
chars = sorted(set(''.join(words)))
# letters are numbered from 1; index 0 is reserved for the '.' marker
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
# inverse table: integer index -> character
itos = {idx: ch for ch, idx in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [4]:
# build dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):
    """Turn a list of words into (context, next-character) training tensors.

    Each word is terminated with '.' and scanned with a sliding window of
    `block_size` character indices that starts out as all zeros. Every window
    becomes one row of X, and the index of the character that follows it
    becomes the matching entry of Y. Character indices come from `stoi`.

    Prints the shapes of both tensors and returns them as (X, Y).
    """
    inputs, targets = [], []

    for word in words:
        window = [0] * block_size  # context before the first character
        for ch in word + '.':
            target = stoi[ch]
            inputs.append(window)
            targets.append(target)
            # slide the window: drop the oldest index, append the newest
            window = [*window[1:], target]

    X = torch.tensor(inputs)
    Y = torch.tensor(targets)

    print(X.shape, Y.shape)
    return X, Y

import random

# Fixed seed so the shuffle, and therefore the split, is the same on every run
random.seed(42)
random.shuffle(words)

# Cut points: first 80% of the words for training, next 10% for dev, last 10% for test
n1, n2 = (int(frac * len(words)) for frac in (0.8, 0.9))

(Xtr, Ytr), (Xdev, Ydev), (Xte, Yte) = (
    build_dataset(split) for split in (words[:n1], words[n1:n2], words[n2:])
)

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [5]:
# boilerplate done, now on to the action

In [6]:
# Utility function to compare manual gradients to PyTorch gradients
def cmp(s, dt, t):
    """Print how closely a hand-derived gradient matches autograd's.

    s  -- label printed at the start of the report line
    dt -- manually computed gradient
    t  -- tensor whose `.grad` (filled in by autograd) is the reference

    The report line shows exact element-wise equality, `torch.allclose`
    agreement, and the largest absolute element-wise difference.
    """
    reference = t.grad
    is_exact = torch.all(dt == reference).item()
    is_close = torch.allclose(dt, reference)
    largest_gap = (dt - reference).abs().max().item()
    print(f'{s:15s} | exact: {str(is_exact):5s} | approximate: {str(is_close):5s} | maxdiff: {largest_gap}')

In [7]:
n_embed = 10 # dimensionality of character embedding vectors
n_hidden = 64 # number of neurons in hidden layer of MLP
torch_seed = 2147483647

# Every random draw below goes through this seeded generator, so the initial
# parameters are the same on every run.
g = torch.Generator().manual_seed(torch_seed) # for reproducibility
C = torch.randn((vocab_size, n_embed), generator=g)

# Layer 1
W1 = torch.randn((n_embed * block_size, n_hidden), generator=g) * (5/3)/((n_embed * block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1 # just for understanding, useless because of batch normalization
# Layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1
# Batch norm parameters
bngain = torch.randn((1, n_hidden), generator=g) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden), generator=g) * 0.1

# Instead of zeros, the parameters above are initialized with small random numbers,
# because initializing with all zeros could sometimes mask an incorrect implementation of the backward pass

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # total number of parameters
for p in parameters:
    p.requires_grad = True

4137


In [8]:
batch_size = 32
n = batch_size  # shorter alias for convenience
# draw batch_size random row indices (repeats possible) into the training set
ix = torch.randint(0, len(Xtr), (batch_size,), generator=g)
Xb = Xtr[ix]
Yb = Ytr[ix]

In [9]:
# Forward pass, "chunkated" into smaller steps that are possible to backward one at a time
emb = C[Xb] # embed chars into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors

# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation

# BatchNorm layer
bnmeani = 1 / n*hprebn.sum(0, keepdim=True) # equivalent of hprebn.mean(0, keepdim=True)

# hprebn - hprebn_mean
bndiff = hprebn - bnmeani

# Variance - average squared deviations from mean
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction dividing by n-1 not n
bnvar_inv = (bnvar + 1e-5)**-0.5 # 1 / square root -> power of -0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias

# Non-linearity
h = torch.tanh(hpreact) # hidden layer

# Linear layer 2
logits = h @ W2 + b2 # output layer

# Cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # Subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
    p.grad = None

# Intermediate (non-leaf) tensors only keep .grad after backward() if retain_grad()
# was called on them, so every tensor that is later checked with cmp() must be listed here.
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv,
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
          bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani, embcat, emb]:
    t.retain_grad()
loss.backward()
loss

tensor(3.5082, grad_fn=<NegBackward0>)

## Exercise 1: Backpropagating through the atomic compute graph

### dlogprobs

In [10]:
# dlogprobs
# dlogprobs is the derivative of the loss with respect to logprobs,
# i.e. how the loss is influenced by logprobs

In [11]:
# Compare the shape of logprobs with the shape of the target indices Yb
logprobs.shape, Yb.shape

(torch.Size([32, 27]), torch.Size([32]))

In [12]:
# For each example in the batch, pluck out the log-probability of the correct next
# character (one of the 27 entries in its row), using Yb as the column index.
# The loss is the negative mean of these values.
logprobs[range(n), Yb]

tensor([-3.5725, -3.6400, -4.4233, -3.3555, -3.9925, -4.2959, -2.5485, -3.5424,
        -2.6516, -2.5485, -3.9843, -4.1128, -3.5419, -4.1123, -4.1636, -4.2203,
        -3.3555, -3.3250, -4.2712, -3.8505, -2.5782, -2.6619, -3.2985, -4.3270,
        -3.1013, -4.2558, -3.1564, -2.8716, -3.6606, -1.8493, -3.1537, -3.8404],
       grad_fn=<IndexBackward0>)

In [13]:
# Toy example with three values: loss = -(a + b + c)/3
# Our batch has 32 examples, so
# loss = -(a + b + ....)/32
# loss = -a/32 + -b/32 +......
# dloss/da = -1/32
# in general: -1/n

The derivative of the loss with respect to the plucked-out entries of logprobs is -1/n. What about the other entries, which are not plucked out? Since they don't participate in the loss, their derivative is zero.

In [14]:
# Entries that are not plucked out for the loss keep a gradient of zero
dlogprobs = torch.zeros_like(logprobs)
# Setting the plucked-out indices to -1/n
dlogprobs[range(n), Yb] = -1.0/n
# Check the manual gradient against autograd's logprobs.grad
cmp('logprobs', dlogprobs, logprobs)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0


### probs

In [15]:
# probs
# how probs affects logprobs
# logprobs is the log of probs

In [16]:
# logprobs = probs.log()
# d(logprobs)/d(probs) = 1/probs, because d/dx ln(x) = 1/x
# ln(probs) = log_e(probs) where e = 2.71828...
# torch.log() - uses natural log
# So dprobs = local derivative * upstream gradient (by chain rule)
# dprobs = 1 / probs * dlogprobs

For the derivative above, I initially assumed torch.log() was base 10 and concluded that the derivative of torch.log(x) is 1 / (x ln(10)).
After reading [torch.log](https://pytorch.org/docs/stable/generated/torch.log.html#torch.log), I found that it implements the natural log, so the derivative of log(x) is simply 1/x.

In [17]:
# Shape of probs; its gradient dprobs has to have the same shape
probs.shape

torch.Size([32, 27])

In [18]:
# Chain rule through logprobs = probs.log(): local derivative 1/probs times upstream dlogprobs
dprobs = 1 / probs * dlogprobs

In [19]:
# Inspect the shape of the manual gradient for probs
dprobs.shape

torch.Size([32, 27])

In [20]:
# Check the manual gradient against autograd's probs.grad
cmp('dprobs', dprobs, probs)

dprobs          | exact: True  | approximate: True  | maxdiff: 0.0


### counts_sum_inv

In [21]:
# how probs is affected by counts_sum_inv
# probs = counts * counts_sum_inv
# d(probs)/d(counts_sum_inv) = counts, so dcounts_sum_inv = counts * dprobs (local derivative * upstream gradient)

In [22]:
# Element-wise contribution of every entry of probs; this is reduced over dim 1 in a later cell
dcounts_sum_inv = counts * dprobs

In [23]:
# Compare the shapes of the two factors of probs = counts * counts_sum_inv
counts.shape, counts_sum_inv.shape

(torch.Size([32, 27]), torch.Size([32, 1]))

If we look at the shapes, PyTorch implicitly broadcasts counts_sum_inv across the columns of counts in the forward pass to perform the element-wise multiplication.

In [24]:
# Compare the tensor's shape with the shape of the gradient computed so far
counts_sum_inv.shape, dcounts_sum_inv.shape

(torch.Size([32, 1]), torch.Size([32, 27]))

In [25]:
# Shape of the gradient autograd produced for counts_sum_inv (the shape our result must end up with)
counts_sum_inv.grad.shape

torch.Size([32, 1])

In [26]:
# counts_sum_inv was broadcast along dim 1 in the forward pass, so each of its entries
# was used by every column of its row; summing the per-column gradients along dim 1
# gives a gradient with the same shape as counts_sum_inv
dcounts_sum_inv = dcounts_sum_inv.sum(1, keepdims=True)
dcounts_sum_inv.shape

torch.Size([32, 1])

In [27]:
# Check the manual gradient against autograd's counts_sum_inv.grad
cmp('dcounts_sum_inv', dcounts_sum_inv, counts_sum_inv)

dcounts_sum_inv | exact: True  | approximate: True  | maxdiff: 0.0


### counts_sum

In [28]:
# how counts_sum_inv is affected by counts_sum
# d(counts_sum_inv) / d(counts_sum) = ??
# counts_sum_inv = counts_sum**-1
# derivative of x**-1 -> -1/x**2
# dcounts_sum = (-1 / counts_sum**2) * dcounts_sum_inv

In [29]:
# Compare the shapes of counts_sum and counts_sum_inv
counts_sum.shape, counts_sum_inv.shape

(torch.Size([32, 1]), torch.Size([32, 1]))

In [30]:
# Chain rule through counts_sum_inv = counts_sum**-1: local derivative is -1/counts_sum**2
dcounts_sum = (-1.0/counts_sum**2) * dcounts_sum_inv
# Inspect the shape of the gradient just computed (it has to match counts_sum)
dcounts_sum.shape

torch.Size([32, 1])

Shapes hold good.

In [31]:
# Check the manual gradient against autograd's counts_sum.grad
cmp('counts_sum', dcounts_sum, counts_sum)

counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0


### counts

In [71]:
# dcounts
# counts receives gradient along two paths, because it feeds both probs and counts_sum
# we'll need d(probs)/d(counts) and d(counts_sum)/d(counts)
# d(probs)/d(counts): it's a multiplication
# contribution via probs = counts_sum_inv * dprobs (local derivative * upstream gradient)

In [72]:
# First path (counts -> probs): counts_sum_inv is broadcast across the columns of dprobs
dcounts = counts_sum_inv * dprobs

In [73]:
# The gradient must have the same shape as the tensor it belongs to
dcounts.shape, counts.shape

(torch.Size([32, 27]), torch.Size([32, 27]))

In [74]:
# d(counts_sum)/d(counts)
# counts_sum = counts.sum(1, keepdims=True)
# The derivative of a sum with respect to each of its terms is 1, so the gradient just passes through
# to keep shapes, we'll create ones of counts' shape and multiply the upstream gradient (dcounts_sum) with it

In [75]:
# Second path (counts -> counts_sum): broadcast dcounts_sum to every column of counts.
# Kept in a separate tensor here; it is added to dcounts (the path via probs) in a later cell.
dcounts1 = torch.ones_like(counts) * dcounts_sum

In [76]:
# Both contributions must have the same shape before they are added
dcounts.shape, dcounts1.shape

(torch.Size([32, 27]), torch.Size([32, 27]))

In [77]:
# counts is used twice in the forward pass, so the gradients from both paths are summed
dcounts = dcounts + dcounts1

In [78]:
# Check the manual gradient against autograd's counts.grad
cmp('dcounts', dcounts, counts)

dcounts         | exact: True  | approximate: True  | maxdiff: 0.0


### Overall

In [81]:
# Manual backward pass so far, collected in one cell: loss -> logprobs -> probs -> counts
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n  # only the target entries feed the loss, each with weight -1/n
dprobs = 1/probs * dlogprobs  # d/dx ln(x) = 1/x, times the upstream gradient
dcounts_sum_inv = (counts * dprobs).sum(1, keepdims=True)  # sum over dim 1 undoes the forward broadcast
dcounts_sum = (-1.0/counts_sum**2) * dcounts_sum_inv  # d/dx x**-1 = -1/x**2
dcounts = counts_sum_inv * dprobs  # path 1: counts -> probs
dcounts += torch.ones_like(counts) * dcounts_sum  # path 2: counts -> counts_sum

# Verify each manual gradient against the one computed by autograd
cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
# Checks for the remaining tensors are left commented out: their manual gradients
# are not computed in this cell.
# cmp('norm_logits', dnorm_logits, norm_logits)
# cmp('logit_maxes', dlogit_maxes, logit_maxes)
# cmp('logits', dlogits, logits)
# cmp('h', dh, h)
# cmp('W2', dW2, W2)
# cmp('b2', db2, b2)
# cmp('hpreact', dhpreact, hpreact)
# cmp('bngain', dbngain, bngain)
# cmp('bnbias', dbnbias, bnbias)
# cmp('bnraw', dbnraw, bnraw)
# cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
# cmp('bnvar', dbnvar, bnvar)
# cmp('bndiff2', dbndiff2, bndiff2)
# cmp('bndiff', dbndiff, bndiff)
# cmp('bnmeani', dbnmeani, bnmeani)
# cmp('hprebn', dhprebn, hprebn)
# cmp('embcat', dembcat, embcat)
# cmp('W1', dW1, W1)
# cmp('b1', db1, b1)
# cmp('emb', demb, emb)
# cmp('C', dC, C)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
