In [4]:
import torch
import torch.nn.functional as F
import matplotlib as plt
%matplotlib inline

In [5]:
words = open('names.txt', 'r').read().splitlines()

In [6]:
chars = sorted(list(set(''.join(words))))

In [7]:
stoi = { s:i+1 for i, s in enumerate(chars) }
stoi['.'] = 0
itos = { i:s for s, i in stoi.items() }

In [8]:
context_size = 3
# construct dataset
def construct_dataset(words):
    X, Y = [], []

    for w in words:
        context = [0] * context_size
        for ch in w:
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)

            context = context[1:] + [ix]

    return torch.tensor(X), torch.tensor(Y)

In [10]:
# split the data into batches
n1 = int(len(words) * 0.8)
n2 = int(len(words) * 0.9)

Xtr, Ytr = construct_dataset(words[:n1])
Xdev, Ydev = construct_dataset(words[n1:n2])
Xte, yte = construct_dataset(words[n2:])

torch.Size([157152, 3])

In [12]:
# init
vocab_size = 27
n_embd = 10
n_hidden = 100 

In [21]:
# define parameterse
C = torch.randn((vocab_size, n_embd)) # 27 characters embedded into a 10 dimensional feature vector
W1 = torch.randn((context_size * n_embd, n_hidden)) * (5/3) / ((context_size * n_embd) ** 0.5)
b1 = torch.randn(n_hidden) * 0.1
W2 = torch.randn(n_hidden, vocab_size) * 0.1
b2 = torch.randn(vocab_size) * 0.1

bngain = torch.randn((1, n_hidden)) * 0.1
bnbias = torch.randn((1, n_hidden)) * 0.1 + 0.1

In [72]:
parameters = [C, W1, b1, W2, b2]
sum(p.nelement() for p in parameters)
for p in parameters: 
    p.requires_grad = True

In [73]:
g = torch.Generator().manual_seed(20000100000)
# create a minibatch
batch_size = 32

ix = torch.randint(0, Xtr.shape[0], (batch_size, ), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix]

In [182]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [264]:
# Initialize training batch as a representation of C
emb = C[Xb]

embcat = emb.view(emb.shape[0], emb.shape[1] * emb.shape[2]) # [32, 3, 10] -> [32, 3] -> [32, 30]

# linear layer 01
hprebn = embcat @ W1 + b1
# batch normalization layer
bnmeani = hprebn.sum(0, keepdim=True) * 1 / batch_size
bndiff = hprebn - bnmeani
bndiff2 = bndiff ** 2
bnvari = 1 / (batch_size - 1) * (bndiff2.sum(0, keepdim=True)) # bessels correction
bnnorm = bndiff / ((bnvari + 1e-05) ** 0.5)

# bngain bnbias
hpreact = bnnorm * bngain + bnbias

# non linearity
h = torch.tanh(hpreact) # bring to -1 and 1
# linear layer 02
logits = h @ W2 + b2
print(logits.shape)
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes
print(logit_maxes.shape)
# cross entropy implementaiton
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum ** -1

probs = counts * counts_sum_inv # multiply expontiated logit by corresponding inverse sum(each prob adjusted relative to total prob sum, so they sum up to one)
logprobs = probs.log()
loss = -logprobs[range(batch_size), Yb].mean()

for p in parameters:
    p.grad = None

for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact, bnnorm, bnvari, bndiff2, bndiff, bnmeani, hpreact, embcat, emb]:
    t.retain_grad()

loss.backward()
loss

torch.Size([32, 27])
torch.Size([32, 1])


tensor(3.3277, grad_fn=<NegBackward0>)

In [279]:
# Exercise 1: backprop through the whole thing manually,
# backpropagating through exactly all of the variables
# as they are defined in the forward pass above, one by one

# -----------------
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(batch_size), Yb] = -1.0 / batch_size
dprobs = 1 / probs * dlogprobs
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs
dcounts_sum = -(counts_sum**-2) * dcounts_sum_inv
dcounts += dcounts_sum.sum(1, keepdim=True)
dnorm_logits = counts * dcounts
dlogits = dnorm_logits.clone()
dlogit_maxes = -(dnorm_logits.sum(1, keepdim=True))
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
# cmp('h', dh, h)
# cmp('W2', dW2, W2)
# cmp('b2', db2, b2)
# cmp('hpreact', dhpreact, hpreact)
# cmp('bngain', dbngain, bngain)
# cmp('bnbias', dbnbias, bnbias)
# cmp('bnraw', dbnraw, bnraw)
# cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
# cmp('bnvar', dbnvar, bnvar)
# cmp('bndiff2', dbndiff2, bndiff2)
# cmp('bndiff', dbndiff, bndiff)
# cmp('bnmeani', dbnmeani, bnmeani)
# cmp('hprebn', dhprebn, hprebn)
# cmp('embcat', dembcat, embcat)
# cmp('W1', dW1, W1)
# cmp('b1', db1, b1)
# cmp('emb', demb, emb)
# cmp('C', dC, C)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
