In [44]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [45]:
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(words[:8])

32033
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [46]:
chars = sorted(list(set(''.join(words))))
stoi = {c: i+1 for i, c in enumerate(chars)}
stoi['.'] = 0
itos = {i+1: c for i, c in enumerate(chars)}
vocab_size = len(stoi)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z'}
27


In [47]:
def buildDataset(words, block_size):
    x, y = [], []
    for w in words:
        #print(w)
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            x.append(context)
            y.append(ix)
            #print(''.join ([itos[i] for i in context]), '-->', itos[ix])
            context = context[1:] + [ix]
            
    x = torch.tensor(x)
    y = torch.tensor(y)
    print(x.shape, y.shape)
    return x, y
import random
block_size = 3
random.seed(42)
random.shuffle(words)
n1 = int (len(words)*.8)    
n2 = int (len(words)*.9)    
x_train, y_train = buildDataset(words[:n1], block_size)
x_val, y_val = buildDataset(words[n1:n2], block_size)
x_test, y_test = buildDataset(words[n2:], block_size)


torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [48]:
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    max_diff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approx: {str(app):5s} | max_diff: {max_diff:.5f}')

In [49]:
n_embed = 10
n_hidden = 64

C = torch.randn(vocab_size, n_embed)
#Layer one
W1 = torch.randn(n_embed * block_size, n_hidden) * (5/3) /((n_embed * block_size)**.5)
B1 = torch.randn(n_hidden) * .1
#Layer two
W2 = torch.randn(n_hidden, vocab_size) * .1
B2 = torch.randn(vocab_size) * .1
#Batch normalization
bngain = torch.randn(1, n_hidden) * 0.1 + 1
bnbias = torch.randn(1, n_hidden) * 0.1

params = [C, W1, B1, W2, B2, bngain, bnbias]
for p in params:
    p.requires_grad = True

In [50]:
batch_size = 32
n = batch_size
ix = torch.randint(0, x_train.shape[0], (n,))
Xb, Yb = x_train[ix], y_train[ix]

In [78]:
emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)
hprebn = embcat @ W1 + B1
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*bndiff2.sum(0, keepdim=True)
bnvar_inv = (bnvar + 1e-5)**-.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
h = torch.tanh(hpreact)
logits = h @ W2 + B2
logitmaxes = logits.max(1, keepdim=True)[0]
normlogits = logits - logitmaxes
counts = normlogits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = 1/counts_sum
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

for p in params:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, normlogits,logitmaxes,  logits, h, hpreact, hprebn, bnraw, bnvar, bnvar_inv, bndiff, bndiff2, bnmeani ,embcat, emb]:
    t.retain_grad()
loss.backward()
loss

tensor(3.3998, grad_fn=<NegBackward0>)

In [79]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1/n

cmp('logprobs', dlogprobs, logprobs)

logprobs        | exact: True  | approx: True  | max_diff: 0.00000


In [80]:
dprobs = dlogprobs / probs

cmp('probs', dprobs, probs)

probs           | exact: True  | approx: True  | max_diff: 0.00000


In [81]:
counts.shape, counts_sum_inv.shape
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)

cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)

counts_sum_inv  | exact: True  | approx: True  | max_diff: 0.00000


In [82]:
dcounts_sum = -1/(counts_sum**2) * dcounts_sum_inv

cmp('counts_sum', dcounts_sum, counts_sum)

counts_sum      | exact: False | approx: True  | max_diff: 0.00000


In [83]:
dcounts = dprobs * counts_sum_inv
dcounts += torch.ones_like(dcounts) * dcounts_sum

cmp('counts', dcounts, counts)

counts          | exact: False | approx: True  | max_diff: 0.00000


In [84]:
dnormlogits = counts * dcounts

cmp('normlogits', dnormlogits, normlogits)

normlogits      | exact: False | approx: True  | max_diff: 0.00000


In [85]:
dlogits = dnormlogits
dlogit_maxes = (-dnormlogits).sum(1, keepdim=True)

cmp('logit_maxes', dlogit_maxes, logitmaxes)

logit_maxes     | exact: False | approx: True  | max_diff: 0.00000


In [86]:
dlogits += F.one_hot(logits.max(1).indices, logits.shape[1]) * dlogit_maxes

cmp('logits', dlogits, logits)

logits          | exact: False | approx: True  | max_diff: 0.00000


In [87]:
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
dB2 = dlogits.sum(0)

cmp('h', dh, h), cmp('W2', dW2, W2), cmp('B2', dB2, B2)

h               | exact: False | approx: True  | max_diff: 0.00000
W2              | exact: False | approx: True  | max_diff: 0.00000
B2              | exact: False | approx: True  | max_diff: 0.00000


(None, None, None)

In [88]:
dhpreact = dh * (1 - h**2)

cmp('hpreact', dhpreact, hpreact)

hpreact         | exact: False | approx: True  | max_diff: 0.00000


In [89]:
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)

cmp('bngain', dbngain, bngain)

bngain          | exact: False | approx: True  | max_diff: 0.00000


In [90]:
dbnraw = bngain * dhpreact

cmp('bnraw', dbnraw, bnraw)

bnraw           | exact: False | approx: True  | max_diff: 0.00000


In [91]:
dbnbias = dhpreact.sum(0, keepdim=True)

cmp('bnbias', dbnbias, bnbias)

bnbias          | exact: False | approx: True  | max_diff: 0.00000


In [92]:
dbndiff = dbnraw * bnvar_inv

cmp('bndiff', dbndiff, bndiff)

bndiff          | exact: False | approx: False | max_diff: 0.00101


In [93]:
dbnvar_inv = (dbnraw * bndiff).sum(0, keepdim=True)

cmp('bnvar_inv', dbnvar_inv, bnvar_inv)

bnvar_inv       | exact: False | approx: True  | max_diff: 0.00000


In [94]:
dbnvar = -1/2 * (bnvar + 1e-5)**-1.5 * dbnvar_inv

cmp('bnvar', dbnvar, bnvar)

bnvar           | exact: False | approx: True  | max_diff: 0.00000


In [95]:
dbndiff2 = 1/(n-1) * dbnvar

cmp('bndiff2', dbndiff2, bndiff2)

bndiff2         | exact: False | approx: True  | max_diff: 0.00000


In [96]:
dbndiff += 2 * bndiff * dbndiff2

cmp('bndiff', dbndiff, bndiff)

bndiff          | exact: False | approx: True  | max_diff: 0.00000


In [97]:
dhrpebn = dbndiff.clone()

cmp('hprebn', dhrpebn, hprebn)

hprebn          | exact: False | approx: False | max_diff: 0.00105


In [98]:
dbnmeani = (-dbndiff).sum(0, keepdim=True)

cmp('bnmeani', dbnmeani, bnmeani)

bnmeani         | exact: False | approx: True  | max_diff: 0.00000


In [99]:
dhrpebn += 1/n * torch.ones_like(hprebn) * dbnmeani

cmp('hprebn', dhrpebn, hprebn)

hprebn          | exact: False | approx: True  | max_diff: 0.00000


In [100]:
dembcat = dhrpebn @ W1.T

cmp('embcat', dembcat, embcat)

embcat          | exact: False | approx: True  | max_diff: 0.00000


In [101]:
dW1 = embcat.T @ dhrpebn

cmp('W1', dW1, W1)

W1              | exact: False | approx: True  | max_diff: 0.00000


In [103]:
dB1 = dhrpebn.sum(0)

cmp('B1', dB1, B1)

B1              | exact: False | approx: True  | max_diff: 0.00000
