In [2]:
import torch
import torch.nn.functional as F
import matplotlib as plt
%matplotlib inline

In [3]:
words = open('names.txt', 'r').read().splitlines()

In [4]:
chars = sorted(list(set(''.join(words))))

In [5]:
stoi = { s:i+1 for i, s in enumerate(chars) }
stoi['.'] = 0
itos = { i:s for s, i in stoi.items() }

In [6]:
context_size = 3
# construct dataset
def construct_dataset(words):
    X, Y = [], []

    for w in words:
        context = [0] * context_size
        for ch in w:
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)

            context = context[1:] + [ix]

    return torch.tensor(X), torch.tensor(Y)

In [7]:
# split the data into batches
n1 = int(len(words) * 0.8)
n2 = int(len(words) * 0.9)

Xtr, Ytr = construct_dataset(words[:n1])
Xdev, Ydev = construct_dataset(words[n1:n2])
Xte, yte = construct_dataset(words[n2:])

In [8]:
# init
vocab_size = 27
n_embd = 10
n_hidden = 100 

In [9]:
# define parameterse
C = torch.randn((vocab_size, n_embd)) # 27 characters embedded into a 10 dimensional feature vector
W1 = torch.randn((context_size * n_embd, n_hidden)) * (5/3) / ((context_size * n_embd) ** 0.5)
b1 = torch.randn(n_hidden) * 0.1
W2 = torch.randn(n_hidden, vocab_size) * 0.1
b2 = torch.randn(vocab_size) * 0.1

bngain = torch.randn((1, n_hidden)) * 0.1
bnbias = torch.randn((1, n_hidden)) * 0.1 + 0.1

In [10]:
parameters = [C, W1, b1, W2, b2, bngain, bnbias]
sum(p.nelement() for p in parameters)
for p in parameters: 
    p.requires_grad = True

In [11]:
g = torch.Generator().manual_seed(20000100000)
# create a minibatch
batch_size = 32

ix = torch.randint(0, Xtr.shape[0], (batch_size, ), generator=g)
print(Ytr[ix].shape)
Xb, Yb = Xtr[ix], Ytr[ix]

torch.Size([32])


In [12]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

torch.Size([27, 10])


In [14]:

# Initialize training batch as a representation of C
emb = C[Xb]

embcat = emb.view(emb.shape[0], emb.shape[1] * emb.shape[2]) # [32, 3, 10] -> [32, 3] -> [32, 30]

# linear layer 01
hprebn = embcat @ W1 + b1
# batch normalization layer
bnmeani = hprebn.sum(0, keepdim=True) * 1 / batch_size
bndiff = hprebn - bnmeani
bndiff2 = bndiff ** 2
bnvari = 1/(batch_size - 1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
bnvar_inv = (bnvari + 1e-5)**-0.5
bnnorm = bndiff * bnvar_inv

# bngain bnbias
hpreact = bnnorm * bngain + bnbias
# non linearity
h = torch.tanh(hpreact) # bring to -1 and 1
# linear layer 02
logits = h @ W2 + b2
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes
# cross entropy implementaiton
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum ** -1

probs = counts * counts_sum_inv # multiply expontiated logit by corresponding inverse sum(each prob adjusted relative to total prob sum, so they sum up to one)
logprobs = probs.log()
loss = -logprobs[range(batch_size), Yb].mean()

for p in parameters:
    p.grad = None

for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact,bngain, bnnorm, bnvar_inv, bnvari, bndiff2, bndiff, bnmeani, embcat, emb, hprebn]:
    t.retain_grad()

loss.backward()
loss

tensor(3.3382, grad_fn=<NegBackward0>)

In [15]:
# Exercise 1: backprop through the whole thing manually,
# backpropagating through exactly all of the variables
# as they are defined in the forward pass above, one by one

# -----------------
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(batch_size), Yb] = -1.0 / batch_size
dprobs = (1 / probs) * dlogprobs
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts_sum = -(1/counts_sum**2) * dcounts_sum_inv
dcounts = counts_sum_inv * dprobs
dcounts += dcounts_sum
dnorm_logits = dcounts * counts
dlogit_maxes = -dnorm_logits.sum(1, keepdim=True)
dlogits = dnorm_logits.clone()
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)
dhpreact = (1.0 - h**2) * dh
dbngain = (bnnorm * dhpreact).sum(0, keepdim=True)
dbnbias = dhpreact.sum(0, keepdim=True)
dbnraw = bngain * dhpreact
dbndiff = bnvar_inv * dbnraw
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
dbnvari = (-0.5 * (bnvari + 1e-5) ** -1.5) * dbnvar_inv
dbndiff2 = (1.0 / (batch_size-1) *torch.ones_like(bndiff2)) * dbnvari
dbndiff += 2*(bndiff) * dbndiff2
dbnmeani = (-dbndiff).sum(0)
dhprebn = dbndiff.clone()
dhprebn += 1.0 / batch_size * (torch.ones_like(hprebn) * dbnmeani)
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)
demb =  dembcat.view(emb.shape)
dC = torch.zeros_like(C)
print(Xb.shape)
for k in range(Xb.shape[0]):
    for i in range(Xb.shape[1]):
        ix = Xb[k, i] # index of embeddings
        dC[ix] += demb[k, i]
print(demb.shape)
cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnnorm)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvari, bnvari)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('bndiff', dbndiff, bndiff)
cmp('bnmeani', dbnmeani, bnmeani)
cmp('hprebn', dhprebn, hprebn)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)
cmp('emb', demb, emb)
cmp('C', dC, C)

torch.Size([32, 3])
torch.Size([32, 3, 10])
logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: True  | approximate: True  | maxdiff: 0.0
bngain          | exact: True  | approximate: True  | maxdiff: 0.0
bnbias          | exact: True  | approximate: True  | maxdiff: 0.0
bnraw           | 

In [16]:
# Exercise 2: backprop through cross_entropy but all in one go
# to complete this challenge look at the mathematical expression of the loss,
# take the derivative, simplify the expression, and just write it out

# forward pass

# before:
# logit_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logit_maxes # subtract max for numerical stability
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdims=True)
# counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[range(n), Yb].mean()

# now:
loss_fast = F.cross_entropy(logits, Yb)
print(loss_fast.item(), 'diff:', (loss_fast - loss).item())

3.3381776809692383 diff: 0.0


In [17]:
# backward pass

# -----------------
dlogits = F.softmax(logits, 1) # softmax along rows
dlogits[range(batch_size), Yb] -= 1
dlogits /= batch_size
# -----------------

cmp('logits', dlogits, logits)

#cmp('logits', dlogits, logits) # I can only get approximate to be true, my maxdiff is 6e-9

logits          | exact: False | approximate: True  | maxdiff: 6.170012056827545e-09


max diff: tensor(1.1921e-07, grad_fn=<MaxBackward1>)


In [230]:
# Exercise 3: backprop through batchnorm but all in one go
# to complete this challenge look at the mathematical expression of the output of batchnorm,
# take the derivative w.r.t. its input, simplify the expression, and just write it out
# BatchNorm paper: https://arxiv.org/abs/1502.03167

# forward pass

# before:
# bnmeani = 1/n*hprebn.sum(0, keepdim=True)
# bndiff = hprebn - bnmeani
# bndiff2 = bndiff**2
# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
# bnvar_inv = (bnvar + 1e-5)**-0.5
# bnraw = bndiff * bnvar_inv
# hpreact = bngain * bnraw + bnbias

# now:
# print(hprebn.shape) # -> [32, 100]

hpreact_fast = bngain * (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) + bnbias
print('max diff:', (hpreact_fast - hpreact).abs().max())

max diff: tensor(1.1921e-07, grad_fn=<MaxBackward1>)


In [236]:
# backward pass

# before we had:
# dbnraw = bngain * dhpreact
# dbndiff = bnvar_inv * dbnraw
# dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv
# dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar
# dbndiff += (2*bndiff) * dbndiff2
# dhprebn = dbndiff.clone()
# dbnmeani = (-dbndiff).sum(0)
# dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)

# calculate dhprebn given dhpreact (i.e. backprop through the batchnorm)
# (you'll also need to use some of the variables from the forward pass up above)
 
# -----------------
# YOUR CODE HERE :)

dbnraw = (dhpreact - dhpreact.sum(0))
dbninv = ()
print(dbnraw)
# -----------------

cmp('hprebn', dhprebn, hprebn) # I can only get approximate to be true, my maxdiff is 9e-10

tensor([[-0.0348,  0.0204,  0.0031,  ...,  0.0341,  0.0131,  0.0067],
        [-0.0330,  0.0123, -0.0030,  ...,  0.0315,  0.0124,  0.0087],
        [-0.0328,  0.0124, -0.0029,  ...,  0.0314,  0.0124,  0.0084],
        ...,
        [-0.0401,  0.0162, -0.0004,  ...,  0.0343,  0.0201,  0.0102],
        [-0.0346,  0.0172, -0.0033,  ...,  0.0343,  0.0152,  0.0098],
        [-0.0325,  0.0171, -0.0038,  ...,  0.0270,  0.0145,  0.0123]],
       grad_fn=<SubBackward0>)
hprebn          | exact: False | approximate: True  | maxdiff: 2.3283064365386963e-10


In [26]:
print(logits.shape)


torch.Size([32, 27])
