# Backpropagation at the backend

# Part 1: Setup boilerplate code

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Get the dataset
!wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt

--2024-09-18 22:31:03--  https://raw.githubusercontent.com/karpathy/makemore/master/names.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 228145 (223K) [text/plain]
Saving to: ‘names.txt.2’


2024-09-18 22:31:03 (28.2 MB/s) - ‘names.txt.2’ saved [228145/228145]



In [3]:
# Read all the dataset and show first 8 names
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [4]:
# Building position to character mapping
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [5]:
# Building the dataset with context length
block_size = 3 # character context for next character prediction, so last 3 character is used to predict next character

def build_dataset(words):
  X, Y = [], []
  for w in words:
    context = [0] * block_size # Initially context start with [0,0,0] which is [...]
    for ch in w + '.': # . is the ending character for each word
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # Rolling - the first character of context is removed and the next character is added from word [...]->[..e]
  X = torch.tensor(X) # context for train [[...], [..e], [.em]], obviously X contains pos not characters
  Y = torch.tensor(Y) # next character pred [e, m, m], obviously Y contains pos not characters
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr, Ytr = build_dataset(words[:n1]) # Train split - 80%
Xdev, Ydev = build_dataset(words[n1:n2]) # Dev/Val split - 10%
Xte, Yte = build_dataset(words[n2:]) # Test split - 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [6]:
# Utility function to compare out gradient calculation to Pytorch implementation - Unit test
def test_gradients_cmp(s, dt, t):
  # s - spaces needed, dt - Our grad implementation, t - Pytorch grad implementation
  ex = torch.all(dt == t.grad).item() # Check if all values match
  app = torch.allclose(dt, t.grad) # Check if values approximately match
  maxdiff = (dt - t.grad).abs().max().item() # What is the max diff of any element between the implementation
  print(f'{s:15s} | Exact: {str(ex):5s} | Approx: {str(app):5s} | Maxdiff: {maxdiff}')

In [7]:
# Lets define the model layers in torch
n_embed = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((vocab_size, n_embed), generator=g)

# Layer 1
W1 = torch.randn((n_embed * block_size, n_hidden), generator=g) * (5/3)/((n_embed * block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1

# Layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1

# Batch Normalization parameters
bngain = torch.randn((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden)) * 0.1

# Note: A lot of non standard initialization since all zero initialization may mask incorrect backward pass implementation
# List of parameters to optimize
parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # Number of paramters in total
for p in parameters:
  p.requires_grad = True

4137


In [8]:
# Creating our single iteration mini batch
batch_size = 32
n = batch_size # Shorter variable name for convinence
# Constructing a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # Batch of inputs and targets

In [9]:
# Forward pass "chunkated" so that our manual implementation of backward pass can be easily done one at a time
emb = C[Xb] # Get the embedding for entire minibatch
embcat = emb.view(emb.shape[0], -1) # Flatten the embedding, last 2 dimensions of embedding is flattened so its (train_size, context_length)

# Linear Layer 1
hprebn = embcat @ W1 + b1 # Calculate the hidden layer pre activation

# BatchNorm Layer gamma * (x - mean)/sqrt(std^2+eps) + beta, gamma = gain and beta = bias
bnmean = 1/n*hprebn.sum(0, keepdim=True) # Get the mean
bndiff = hprebn - bnmean
bndiff2 = bndiff**2
bnvar = 1/(n-1)*bndiff2.sum(0, keepdim=True) # Note: Bessel's correction (dividing by n-1, not n) and std^2 is var
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias

# Non Linearity
h = torch.tanh(hpreact) # Calculate the hidden layer activation

# Linear Layer 2
logits = h @ W2 + b2 # Calculate the output layer pre activation

# Loss calculation
# For numerical stability we subtract the max logit value, basically normalizing logits
logits_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logits_maxes
# Getting loss
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean() # Taking the mean of all log probs and negating it

# Backward Pass
for p in parameters:
  p.grad = None
for t in [logprobs, probs, counts, counts_sum_inv, counts_sum, # afaik there is no cleaner way
          norm_logits, logits_maxes, logits, h, hpreact, bnraw,
          bnvar_inv, bnvar, bndiff2, bndiff, bnmean, hprebn,
          embcat, emb]:
          t.retain_grad()
loss.backward()
loss

tensor(3.3480, grad_fn=<NegBackward0>)

# Part 2: Backpropagation Manual Implementation

In [10]:
# Loss Backprop
dlogprobs = torch.zeros_like(logprobs) # dloss/dlogprobs
dlogprobs[range(n), Yb] = -1.0/n

dprobs = 1.0/probs * dlogprobs # dlogprobs/dprobs, dlogprobs comes from chain rule
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True) # dprobs/dcounts_sum_inv, dprobs comes from chain rule, sum(1,keepdim_true), since counts * counts_sum_inv have shape (32,27) and (32,1) so to take of broadcasting
dcounts = counts_sum_inv * dprobs # dprobs/dcounts, dcounts comes from chain rule
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv # dcounts_sum_inv/dcounts_sum, dcounts_sum_inv comes from chain rule
dcounts += torch.ones_like(counts) * dcounts_sum # dcounts_sum/dcounts, dcounts_sum comes from chain rule, since this is row wise addition the gradient just passes equally to all summing elements, so counts size and grad was 1
dnorm_logits = counts * dcounts # dcounts/dnorm_logits, dcounts comes from chain rule
dlogits = 1.0 * dnorm_logits.clone() # dnorm_logits/dlogits, dnorm_logits comes from chain rule
dlogit_maxes = (-1.0 * dnorm_logits).sum(1, keepdim=True) # dnorm_logits/dlogit_maxes, dnorm_logits comes from chain rule, since dlogit_maxes is 32,1 dim
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes # dlogit_maxes/dlogits, dlogit_maxes comes from chain rule, we get index of max logit and pass through one hot and product with dlogit_max means the max logit will be impact by gradient change

test_gradients_cmp('logprobs', dlogprobs, logprobs)
test_gradients_cmp('probs', dprobs, probs)
test_gradients_cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
test_gradients_cmp('counts_sum', dcounts_sum, counts_sum)
test_gradients_cmp('counts', dcounts, counts)
test_gradients_cmp('norm_logits', dnorm_logits, norm_logits)
test_gradients_cmp('logit_maxes', dlogit_maxes, logits_maxes)
test_gradients_cmp('logits', dlogits, logits)

# Linear Layer 2 Backprop
dh = dlogits @ W2.T # dlogits/dh, dlogits comes from chain rule
dW2 = h.T @ dlogits # dlogits/dW2, dlogits comes from chain rule
db2 = dlogits.sum(0) # dlogits/db2, dlogits comes from chain rule

test_gradients_cmp('h', dh, h)
test_gradients_cmp('W2', dW2, W2)
test_gradients_cmp('b2', db2, b2)

# Non Linearity Backprop
dhpreact = (1.0 - h**2) * dh # dh/dhpreact, dh comes from chain rule

test_gradients_cmp('hpreact', dhpreact, hpreact)

# Batch Norm Backprop
dbngain = (bnraw * dhpreact).sum(0, keepdim=True) # dhpreact/dbngain, dhpreact comes from chain rule, again broadcast bngain - (1,64), [bnraw - (32,64), bnbias - (1,64)], * - element wise multiply
dbnraw = bngain * dhpreact # dhpreact/dbnraw, dhpreact comes from chain rule
dbnbias = dhpreact.sum(0, keepdim=True) # dhpreact/dbnbias, dhpreact comes from chain rule
dbndiff = bnvar_inv * dbnraw # dbnraw/dbdiff, dbnraw comes from chain rule
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True) # dbnraw/dbvar_inv, dbnraw comes from chain rule, bnvar_inv - (1,64)
dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv # dbnvar_inv/dbnvar, dbnvar_inv comes from chain rule
dbndiff2 = 1/(n-1) * torch.ones_like(bndiff2) * dbnvar # dbnvar/dbdiff2, dbnvar comes from chain rule
dbndiff += (2*bndiff) * dbndiff2 # dbndiff2/dbdiff, dbndiff2 comes from chain rule
dbnmean = -1.0*dbndiff.sum(0) # dbndiff/dbnmean, dbndiff comes from chain rule, bnmean - (1,64)
dhprebn = 1.0*dbndiff.clone() # dbndiff/dhprebn, dbndiff comes from chain rule
dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmean) # dbnmean/dhprebn, dbnmean comes from chain rule, hprebn - (32,64)

test_gradients_cmp('bngain', dbngain, bngain)
test_gradients_cmp('bnraw', dbnraw, bnraw)
test_gradients_cmp('bnbias', dbnbias, bnbias)
test_gradients_cmp('bndiff', dbndiff, bndiff)
test_gradients_cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
test_gradients_cmp('bnvar', dbnvar, bnvar)
test_gradients_cmp('bndiff2', dbndiff2, bndiff2)
test_gradients_cmp('bnmean', dbnmean, bnmean)
test_gradients_cmp('hprebn', dhprebn, hprebn)

# Linear Layer 1 Backprop
dembcat = dhprebn @ W1.T # dhprebn/dembcat, dhprebn comes from chain rule
dW1 = embcat.T @ dhprebn # dhprebn/dW1, dhprebn comes from chain rule
db1 = dhprebn.sum(0) # dhprebn/db1, dhprebn comes from chain rule

test_gradients_cmp('embcat', dembcat, embcat)
test_gradients_cmp('W1', dW1, W1)
test_gradients_cmp('b1', db1, b1)

# Embedding Backprop
demb = dembcat.view(emb.shape[0], emb.shape[1], -1) # dembcat/demb, dembcat comes from chain rule emb - (32,3,10), embcat (32,30)
dC = torch.zeros_like(C) # demb/dC, demb comes from chain rule
for k in range(Xb.shape[0]): # Xb contains all the minibatch train indices
  for j in range(Xb.shape[1]):
    ix = Xb[k,j]
    dC[ix] += demb[k,j]

test_gradients_cmp('emb', demb, emb)
test_gradients_cmp('C', dC, C)

logprobs        | Exact: True  | Approx: True  | Maxdiff: 0.0
probs           | Exact: True  | Approx: True  | Maxdiff: 0.0
counts_sum_inv  | Exact: True  | Approx: True  | Maxdiff: 0.0
counts_sum      | Exact: True  | Approx: True  | Maxdiff: 0.0
counts          | Exact: True  | Approx: True  | Maxdiff: 0.0
norm_logits     | Exact: True  | Approx: True  | Maxdiff: 0.0
logit_maxes     | Exact: True  | Approx: True  | Maxdiff: 0.0
logits          | Exact: True  | Approx: True  | Maxdiff: 0.0
h               | Exact: True  | Approx: True  | Maxdiff: 0.0
W2              | Exact: True  | Approx: True  | Maxdiff: 0.0
b2              | Exact: True  | Approx: True  | Maxdiff: 0.0
hpreact         | Exact: False | Approx: True  | Maxdiff: 4.656612873077393e-10
bngain          | Exact: False | Approx: True  | Maxdiff: 3.725290298461914e-09
bnraw           | Exact: False | Approx: True  | Maxdiff: 4.656612873077393e-10
bnbias          | Exact: False | Approx: True  | Maxdiff: 3.725290298461914e-0

# Part 3: Cross Entropy and Batch Norm One Go Backpropagation Manual Implementation

In [11]:
# Cross Entropy short backprop implementation
dlogits = F.softmax(logits, 1) # Getting softmax for each train example across row axis, which is like probability
dlogits[range(n), Yb] -= 1 # Now grad is prob_i if i!=y  and prob_i-1 if i==y
dlogits /= n # Now propagate 1/n (averageness) aross all

test_gradients_cmp('logits', dlogits, logits)

# Batch Norm short backprop implementation
dhprebn = bngain * bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))

test_gradients_cmp('hprebn', dhprebn, hprebn)

logits          | Exact: False | Approx: True  | Maxdiff: 7.2177499532699585e-09
hprebn          | Exact: False | Approx: True  | Maxdiff: 9.313225746154785e-10


# Part 4: Bringing it all together

In [18]:
# Setting up the model
n_embed = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((vocab_size, n_embed), generator=g)

# Layer 1
W1 = torch.randn((n_embed * block_size, n_hidden), generator=g) * (5/3)/((n_embed * block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1

# Layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1

# Batch Normalization parameters
bngain = torch.randn((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden)) * 0.1

# Note: A lot of non standard initialization since all zero initialization may mask incorrect backward pass implementation
# List of parameters to optimize
parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # Number of paramters in total
for p in parameters:
  p.requires_grad = True

# Training and Optimization
max_steps = 200000
batch_size = 32
n = batch_size
lossi = [] # Stores the loss after every step

with torch.no_grad():
  for i in range(max_steps):

    # Construct a minibatch
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # Forward pass
    emb = C[Xb] # Get the embedding for entire minibatch
    embcat = emb.view(emb.shape[0], -1) # Flatten the embedding, last 2 dimensions of embedding is flattened so its (train_size, context_length)

    # Linear Layer
    hprebn = embcat @ W1 + b1 # Calculate the hidden layer pre activation

    # BatchNorm Layer
    bnmean = hprebn.mean(0, keepdim=True) # Get the mean
    bnvar = hprebn.var(0, keepdim=True, unbiased=True) # Get the variance
    bnvar_inv = (bnvar + 1e-5)**-0.5
    bnraw = (hprebn - bnmean) * bnvar_inv
    hpreact = bngain * bnraw + bnbias

    # Non Linearity
    h = torch.tanh(hpreact) # Calculate the hidden layer activation

    # Linear Layer
    logits = h @ W2 + b2 # Calculate the output layer preactivation

    # Loss
    loss = F.cross_entropy(logits, Yb)

    # Backward Pass
    for p in parameters:
      p.grad = None
    # loss.backward()

    # Manual Backpropagation
    # Loss Backprop
    dlogits = F.softmax(logits, 1)
    dlogits[range(n), Yb] -= 1
    dlogits /= n

    # Linear Layer 2 Backprop
    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(0)

    # Non Linearity Backprop
    dhpreact = (1.0 - h**2) * dh

    # Batch Norm Backprop
    dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
    dbnbias = dhpreact.sum(0, keepdim=True)
    dhprebn = bngain * bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))

    # Linear Layer 1 Backprop
    dembcat = dhprebn @ W1.T
    dW1 = embcat.T @ dhprebn
    db1 = dhprebn.sum(0)

    # Embedding Backprop
    demb = dembcat.view(emb.shape[0], emb.shape[1], -1)
    dC = torch.zeros_like(C)
    for k in range(Xb.shape[0]):
      for j in range(Xb.shape[1]):
        ix = Xb[k,j]
        dC[ix] += demb[k,j]

    grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]

    # Update
    lr = 0.1 if i < 100000 else 0.01 # Learning rate decay
    for p,grad in zip(parameters,grads):
      # p.data += -lr * p.grad
      p.data += -lr * grad

    # Track stats
    if i % 10000 == 0: # Print every once every 10k steps
      print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())

    # if i >= 100:
    #   break # After Debug: Remove this

12297
      0/ 200000: 3.7880
  10000/ 200000: 2.1641
  20000/ 200000: 2.3662
  30000/ 200000: 2.4217
  40000/ 200000: 1.9684
  50000/ 200000: 2.4886
  60000/ 200000: 2.4104
  70000/ 200000: 2.0820
  80000/ 200000: 2.4035
  90000/ 200000: 2.1479
 100000/ 200000: 1.9805
 110000/ 200000: 2.2685
 120000/ 200000: 1.9944
 130000/ 200000: 2.3934
 140000/ 200000: 2.2971
 150000/ 200000: 2.2150
 160000/ 200000: 1.9154
 170000/ 200000: 1.9014
 180000/ 200000: 2.0973
 190000/ 200000: 1.9360


In [16]:
# Compare the gradients to torch grads [Dont run this after debug]
for p, g in zip(parameters, grads):
  test_gradients_cmp(str(tuple(p.shape)), g, p)

(27, 10)        | Exact: False | Approx: True  | Maxdiff: 1.1175870895385742e-08
(30, 200)       | Exact: False | Approx: True  | Maxdiff: 9.313225746154785e-09
(200,)          | Exact: False | Approx: True  | Maxdiff: 3.725290298461914e-09
(200, 27)       | Exact: False | Approx: True  | Maxdiff: 1.4901161193847656e-08
(27,)           | Exact: False | Approx: True  | Maxdiff: 7.450580596923828e-09
(1, 200)        | Exact: False | Approx: True  | Maxdiff: 2.7939677238464355e-09
(1, 200)        | Exact: False | Approx: True  | Maxdiff: 7.450580596923828e-09


In [19]:
# Calaibrate batch norm paramters at end of training
with torch.no_grad():
  # Pass training set through
  emb = C[Xtr]
  embcat = emb.view(emb.shape[0], -1)
  hpreact = embcat @ W1
  # Measure the mean/std over the entire training set
  bnmean = hpreact.mean(0, keepdim=True)
  bnstd = hpreact.var(0, keepdim=True, unbiased=True)

In [20]:
# Calculate loss for different split after tarining
@torch.no_grad() # this decorator disables gradient tracking
def split_loss(split):
  x,y = {
    'train': (Xtr, Ytr),
    'val': (Xdev, Ydev),
    'test': (Xte, Yte),
  }[split]
  emb = C[x] # (split_size, block_size, n_embd)
  embcat = emb.view(emb.shape[0], -1) # concat into (split_size, block_size * n_embd)
  hpreact = embcat @ W1 + b1 # (split_size, n_hidden)
  hpreact = bngain * (hpreact - bnmean) / bnstd + bnbias
  h = torch.tanh(hpreact) # (split_size, n_hidden)
  logits = h @ W2 + b2 # (split_size, vocab_size)
  loss = F.cross_entropy(logits, y) # BTW, negative log likelihood of softmax is cross entropy
  print(split, loss.item())

split_loss('train')
split_loss('val')
split_loss('test')

train 2.376091241836548
val 2.3904130458831787
test 2.3936409950256348


In [None]:
# Make predictions from the model
g = torch.Generator().manual_seed(21474836)
for _ in range(20):

    out = []
    context = [0] * block_size # initialize with all ...
    while True:
      emb = C[torch.tensor([context])] # (1,block_size,d) -> (1,3,10)
      h = torch.tanh(emb.view(1, -1) @ W1 + b1) # There is only example so 1, and (1,-1) can be writen as (1,30)
      logits = h @ W2 + b2
      probs = F.softmax(logits, dim=1)
      ix = torch.multinomial(probs, num_samples=1, generator=g).item()
      context = context[1:] + [ix]
      out.append(ix)
      if ix == 0:
        break

    print(''.join(itos[i] for i in out))