In [58]:
import torch, math
import matplotlib.pyplot as plt
%matplotlib inline

In [59]:
words = open('data/names.txt').read().splitlines()
words = list(set(w.lower() for w in words))
len(words)

21974

In [60]:
chars = sorted(list(set("".join(words))))
stoi = {s:i for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
vocab_size

29

In [67]:
block_size = 3

def build_dataset(words):
    X,Y = [], []
    context = [0]*block_size
    for w in words:
        for ch in w+'.':
            ix = stoi[ch]
            Y.append(ix)
            X.append(context)
            context = context[1:] + [ix]
    X,Y = torch.tensor(X), torch.tensor(Y)
    print(X.shape, Y.shape)
    return X,Y

In [68]:
import random

random.shuffle(words)

n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,Ytr = build_dataset(words[:n1])
Xdev,Ydev = build_dataset(words[n1:n2])
Xte,Yte = build_dataset(words[n2:])

torch.Size([125808, 3]) torch.Size([125808])
torch.Size([15762, 3]) torch.Size([15762])
torch.Size([15710, 3]) torch.Size([15710])


In [69]:
def cmp(s, dt, t):
    exc = torch.all(dt==t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt-t.grad).abs().max().item()
    print(f"{s:15s} | exact {str(exc):5s} | abs {str(app):5s} | maxdiff {maxdiff:}")

In [70]:
n_emb = 10
n_hidden = 64

C = torch.randn((vocab_size, n_emb))

W1 = torch.randn((n_emb*block_size, n_hidden)) * (5/3)/((n_emb*block_size)**0.5)
b1 = torch.randn(n_hidden) * 0.1

bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

W2 = torch.randn((n_hidden, vocab_size)) * 0.1
b2 = torch.randn(vocab_size) * 0.1


parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

4287


In [71]:
batch_size = 32
n = batch_size
bix = torch.randint(0, Xtr.shape[0], (n,))
Xb,Yb = Xtr[bix], Ytr[bix]
Xb.shape, Yb.shape

(torch.Size([32, 3]), torch.Size([32]))

In [106]:
emb = C[Xb]
emb_cat = emb.view(emb.shape[0], -1)
hprebn = emb_cat @ W1 + b1

bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2

bnvar = 1/(n-1)*bndiff2.sum(0, keepdim=True)
bnvar_inv = (bnvar + 1e-5)**-0.5

bnraw = bndiff * bnvar_inv

hpreact = bngain * bnraw + bnbias

h = torch.tanh(hpreact)

logits = h @ W2 + b2

logits_max = logits.max(1, keepdim=True).values
logits_norm = logits - logits_max
counts = logits_norm.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts*counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb]


for p in parameters:
    p.grad = None
    
    
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv,
          logits_norm, logits_max, logits, h, hpreact, bnraw, bnvar,
          bnvar_inv, bndiff2, bndiff, hprebn, bnmeani,emb_cat, emb
         ]:
    t.retain_grad()