In [50]:
import math
import random

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

In [3]:
# Load the dataset: one name per line. The `with` block guarantees the file
# handle is closed (a bare open(...).read() leaves it to the garbage collector).
with open('names.txt', 'r') as names_file:
    words = names_file.read().splitlines()

In [4]:
# Character vocabulary. Letters get indices 1..len(chars) in sorted order;
# index 0 is reserved for '.', a single token that marks both the start and
# the end of a name.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}  # inverse mapping: index -> character
vocab_size = len(itos)

In [5]:
# build the dataset
block_size = 3  # context length: how many previous characters are used to predict the next one

def build_dataset(words):
    """Turn a list of names into (context, next character) training pairs.

    Every name is terminated with '.'. The context starts as `block_size`
    zeros (the index of '.') and slides one character to the right after
    each example. Each row of X holds the `block_size` previous character
    indices, and the matching entry of Y is the index of the character
    that follows them.

    Prints the shapes of X and Y, then returns (X, Y) as tensors.
    """
    xs, ys = [], []
    for name in words:
        window = [0] * block_size
        for ch in name + '.':
            target = stoi[ch]
            xs.append(window)
            ys.append(target)
            window = window[1:] + [target]  # drop the oldest char, append the new one
    X, Y = torch.tensor(xs), torch.tensor(ys)
    print(X.shape, Y.shape)
    return X, Y

# Shuffle a *copy* of the names with a private RNG seeded at 42, then split
# 80% / 10% / 10% into train / dev / test.
# Shuffling `words` in place made this cell non-idempotent: every re-run
# reshuffled the already-shuffled list and silently produced a different split.
# random.Random(42) yields the same permutation as random.seed(42) followed by
# random.shuffle, so the first-run split is unchanged.
words_shuffled = list(words)
random.Random(42).shuffle(words_shuffled)

n1 = int(0.8 * len(words_shuffled))  # end of the train split
n2 = int(0.9 * len(words_shuffled))  # end of the dev split
Xtr, Ytr = build_dataset(words_shuffled[:n1])      # 80% train
Xdev, Ydev = build_dataset(words_shuffled[n1:n2])  # 10% dev
Xte, Yte = build_dataset(words_shuffled[n2:])      # 10% test


torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [6]:
def cmp(s, dt, t):
    """Compare a hand-derived gradient with autograd's and print a one-line report.

    Args:
        s: label printed in the first column.
        dt: the manually computed gradient.
        t: the tensor whose `.grad` was populated by `loss.backward()`
           (non-leaf tensors need `retain_grad()` before the backward call).

    Raises:
        ValueError: if `t.grad` is None.

    The shapes must match exactly. Plain `dt == t.grad` / `torch.allclose`
    broadcast, so e.g. a (1, 27) gradient compared against a (27,) parameter
    would otherwise be reported as a match even though its shape is wrong.
    """
    if t.grad is None:
        raise ValueError(f'{s}: t.grad is None - call retain_grad() on non-leaf tensors before backward()')
    same_shape = dt.shape == t.grad.shape
    ex = same_shape and torch.all(dt == t.grad).item()
    app = same_shape and torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    shape_note = '' if same_shape else f' | shape mismatch: {tuple(dt.shape)} vs {tuple(t.grad.shape)}'
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff:{maxdiff}{shape_note}')

In [None]:
n_emb = 10      # dimensionality of each character embedding
n_hidden = 64   # number of neurons in the hidden layer

# Seed torch's global RNG so the initial parameters (and the minibatch drawn
# afterwards with torch.randint) are reproducible under Restart & Run All.
torch.manual_seed(2147483647)

C = torch.randn((vocab_size, n_emb))

# Layer 1: tanh gain (5/3) divided by sqrt(fan_in), where fan_in = n_emb * block_size.
W1 = torch.randn((n_emb * block_size, n_hidden)) * (5/3) / (n_emb * block_size)**0.5
# b1 is cancelled by the batchnorm mean subtraction, so it has no effect on the
# output. It is kept (non-zero) only so that its gradient can be checked too.
b1 = torch.randn(n_hidden) * 0.1

# Layer 2
W2 = torch.randn((n_hidden, vocab_size)) * 0.1
b2 = torch.randn(vocab_size) * 0.1

# BatchNorm gain/bias. They are deliberately initialized to non-trivial values
# (instead of gain=1, bias=0) because that helps expose an incorrect manual
# backprop implementation.
bngain = torch.rand((1, n_hidden)) * 0.1 + 1
bnbias = torch.rand((1, n_hidden)) * 0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
for p in parameters:
    p.requires_grad = True

sum(p.nelement() for p in parameters)  # total number of parameters

4137


In [8]:
# Draw one random minibatch of training examples (indices sampled with replacement).
batch_size = 32
n = batch_size  # short alias used throughout the manual backprop formulas

ix = torch.randint(low=0, high=Xtr.shape[0], size=(batch_size,))
Xb = Xtr[ix]
Yb = Ytr[ix]

In [99]:
#forward pass
# Deliberately written as many small steps so that every intermediate tensor
# gets its own .grad, which the manual backprop cells compare against.
emb = C[Xb] # embedding of every context character: (n, block_size, n_emb)
embcat = emb.view(emb.shape[0],-1) # concatenate the context embeddings: (n, block_size*n_emb)
pre_bn = embcat @ W1 + b1 # hidden-layer pre-activation, before batchnorm

# batchnorm, step by step
bnmean = pre_bn.sum(0, keepdim=True)/n
bndiff = pre_bn - bnmean
bndiff2 = bndiff**2
bnvar = bndiff2.sum(0, keepdim=True)/(n-1) # Bessel's correction: divide by n-1, not n
bnvar_inv = (bnvar + 1e-5)**-0.5 # the 1e-5 avoids division by zero when a column's variance is 0
bnraw = bndiff * bnvar_inv
pre_act = bngain*bnraw + bnbias 

h= torch.tanh(pre_act) # hidden-layer activation

logits = h @ W2 + b2 # output layer

# cross-entropy loss, step by step
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract the row max so exp() cannot overflow
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts*counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean() # mean negative log-probability of the correct next character

# clear gradients left over from a previous run of this cell
for p in parameters:
    p.grad = None

# non-leaf tensors do not keep .grad by default; retain them so cmp() can read them
for t in [logprobs, probs, counts_sum_inv, counts_sum,counts,norm_logits,logit_maxes,
          logits,h, pre_act, bnraw, bnvar, bnvar_inv, bndiff, bndiff2, bnmean,pre_bn, emb, embcat]:
    t.retain_grad()

loss.backward()
loss

tensor(3.5622, grad_fn=<NegBackward0>)

In [87]:
# sanity check: shape of the pre-batchnorm activations
pre_bn.shape

torch.Size([32, 64])

Backprop by-hand implementation 

In [57]:
# Backprop starts at the loss: the derivative of the loss w.r.t. every element
# of logprobs (an n x vocab_size tensor, like logprobs itself).
# loss = -logprobs[range(n), Yb].mean() picks one entry per row (column Yb[i])
# and averages them: loss = -(a + b + c + ...)/n.
# So each picked entry gets d loss / d entry = -1/n, and every other entry gets 0.
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[torch.arange(n), Yb] = -1.0 / n
cmp('logprobs', dlogprobs, logprobs)

logprobs        | exact: True  | approximate: True  | maxdiff:0.0


In [78]:
# bngain (1, n_hidden) is broadcast over the batch rows of bnraw (n, n_hidden),
# which is why its gradient must be summed over dim 0
bngain.shape, bnraw.shape

(torch.Size([1, 64]), torch.Size([32, 64]))

In [98]:
# inspect autograd's gradient for pre_bn (available because retain_grad() was called on it)
pre_bn.grad

  pre_bn.grad


In [106]:
# Manual backprop through the whole network, one intermediate at a time, in
# reverse order of the forward pass. Every d<name> must have exactly the same
# shape as <name> (and hence as <name>.grad), not merely a broadcastable one.

# --- cross-entropy ---
dprobs = dlogprobs * 1/probs  # d/dx log(x) = 1/x
dcounts_sum_inv = (counts*dprobs).sum(1, keepdim=True)  # counts_sum_inv (n, 1) was broadcast across each row, so sum over the row
dcounts_sum = -1 * counts_sum**-2 * dcounts_sum_inv  # d/dx x**-1 = -x**-2
dcounts = torch.ones_like(counts)*dcounts_sum + counts_sum_inv * dprobs  # counts feeds two branches (probs and counts_sum), so the gradients add
dnorm_logits = counts*dcounts  # d/dx exp(x) = exp(x), and counts = exp(norm_logits)
dlogit_maxes = (-1 * torch.ones_like(logit_maxes) * dnorm_logits).sum(1, keepdim=True)  # logit_maxes (n, 1) was broadcast across each row, so sum over the row
# logits feed norm_logits directly and also the row max; max() only passes
# gradient to the argmax position of each row (a one-hot mask)
dlogits = (torch.ones_like(logits) * dnorm_logits) + F.one_hot(logits.max(1).indices, num_classes=logits.shape[1])*dlogit_maxes

# --- output layer: logits = h @ W2 + b2 ---
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)  # shape (vocab_size,) like b2; keepdim=True gave (1, vocab_size), which only "matched" via broadcasting
dh = dlogits @ W2.T

# --- tanh: h = tanh(pre_act) ---
dpre_act = (1 - h**2) * dh  # d/dx tanh(x) = 1 - tanh(x)**2, reusing h

# --- batchnorm scale and shift: pre_act = bngain*bnraw + bnbias ---
dbnraw = bngain * dpre_act
dbngain = (bnraw * dpre_act).sum(0, keepdim=True)  # bngain (1, n_hidden) was broadcast over the batch
dbnbias = dpre_act.sum(0, keepdim=True)  # bnbias (1, n_hidden) was broadcast over the batch

# --- batchnorm normalization ---
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
dbnvar = -0.5 * (bnvar + 1e-5)**-1.5 * dbnvar_inv
# every element of bndiff2 contributes 1/(n-1) to its column's variance; the
# result must be (n, n_hidden) like bndiff2, not the (1, n_hidden) of bnvar
dbndiff2 = torch.ones_like(bndiff2) * dbnvar / (n-1)
dbndiff = bnvar_inv * dbnraw + 2*bndiff*dbndiff2  # bndiff feeds both bnraw and bndiff2
dbnmean = (-torch.ones_like(bndiff) * dbndiff).sum(0, keepdim=True)
dpre_bn = torch.ones_like(pre_bn) * dbndiff + (torch.ones_like(pre_bn)/n * dbnmean)  # direct path + path through the batch mean

# --- first layer: pre_bn = embcat @ W1 + b1 ---
dembcat = dpre_bn @ W1.T
dW1 = embcat.T @ dpre_bn
db1 = dpre_bn.sum(0)  # shape (n_hidden,) like b1
demb = dembcat.view(emb.shape)  # undo the concatenation: back to the shape of emb

# --- embedding lookup: emb = C[Xb] ---
# each row of C accumulates the gradients of every position where it was looked up
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k, j]
        dC[ix] += demb[k, j]

cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)
cmp('pre_act', dpre_act, pre_act)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('bndiff', dbndiff, bndiff)
cmp('bnmean', dbnmean, bnmean)
cmp('pre_bn', dpre_bn, pre_bn)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)
cmp('emb', demb, emb)
cmp('C', dC, C)

probs           | exact: True  | approximate: True  | maxdiff:0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff:0.0
counts_sum      | exact: True  | approximate: True  | maxdiff:0.0
counts          | exact: True  | approximate: True  | maxdiff:0.0
norm_logits     | exact: True  | approximate: True  | maxdiff:0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff:0.0
logits          | exact: True  | approximate: True  | maxdiff:0.0
h               | exact: True  | approximate: True  | maxdiff:0.0
W2              | exact: True  | approximate: True  | maxdiff:0.0
b2              | exact: True  | approximate: True  | maxdiff:0.0
pre_act         | exact: True  | approximate: True  | maxdiff:0.0
bngain          | exact: True  | approximate: True  | maxdiff:0.0
bnbias          | exact: True  | approximate: True  | maxdiff:0.0
bnraw           | exact: True  | approximate: True  | maxdiff:0.0
bnvar_inv       | exact: True  | approximate: True  | maxdiff:0.0
bnvar     

In [107]:
# Exercise 2: backprop through cross_entropy in a single step.
# Write down the mathematical expression for the loss, differentiate it,
# simplify the result, and implement the simplified formula directly.

# forward pass

# step-by-step version used earlier (kept for reference):
# logit_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logit_maxes # subtract max for numerical stability
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdims=True)
# counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[range(n), Yb].mean()

# fused version: it should agree with the step-by-step loss up to float rounding
loss_fast = F.cross_entropy(logits, Yb)
diff = (loss_fast - loss).item()
print(loss_fast.item(), 'diff:', diff)

3.5622336864471436 diff: 2.384185791015625e-07


In [None]:
# backward pass

# For loss = mean_i( -log softmax(logits_i)[Yb_i] ), the gradient w.r.t. the
# logits simplifies to (softmax(logits) - one_hot(Yb)) / n.
# detach(): dlogits is a plain tensor that is edited in place below, so it
# must not be part of the autograd graph.
dlogits = F.softmax(logits.detach(), 1)
dlogits[range(n), Yb] -= 1
dlogits /= n

# Only approximate equality is expected: autograd's chain of small ops rounds
# differently from this closed form (maxdiff around 6e-9).
cmp('logits', dlogits, logits)