In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [2]:
# read in all the words (one entry per line of names.txt)
# A context manager is used so the file handle is closed as soon as the
# contents have been read, instead of being left open until garbage collection.
with open('names.txt', 'r') as names_file:
    words = names_file.read().splitlines()

In [3]:
# Character vocabulary and the two lookup tables between characters and integers.
# chars: every distinct character appearing in `words`, in sorted order.
chars = sorted(set(''.join(words)))
# Characters are numbered starting at 1; index 0 is then assigned to '.'.
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi['.'] = 0
# Inverse table (integer -> character), obtained by flipping stoi's pairs.
itos = {idx: ch for ch, idx in stoi.items()}
vocab_size = len(itos)

In [4]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):
  """Convert words into (context, next-character) examples.

  Each word is extended with a trailing '.', and every character of the
  extended word becomes one target. Its context is the `block_size` integer
  codes that precede it, with 0 filling the positions before the word starts.
  Character codes are looked up in the module-level `stoi` table.

  Prints the shapes of the two resulting tensors and returns them as (X, Y),
  where X holds the contexts and Y the targets.
  """
  contexts, targets = [], []

  for word in words:
    window = [0] * block_size  # all-zero context at the start of each word
    for ch in word + '.':
      target = stoi[ch]
      contexts.append(window)
      targets.append(target)
      # slide the window: drop the oldest code, append the newest
      window = [*window[1:], target]

  X, Y = torch.tensor(contexts), torch.tensor(targets)
  print(X.shape, Y.shape)
  return X, Y

import random

# Seed Python's RNG, then shuffle `words` in place before cutting it into splits.
random.seed(42)
random.shuffle(words)

# Cut points at 80% and 90% of the word list.
n1, n2 = (int(fraction * len(words)) for fraction in (0.8, 0.9))

# Each call prints the shapes of the tensors it builds.
Xtr, Ytr = build_dataset(words[:n1])      # first 80% of the words: train set
Xdev, Ydev = build_dataset(words[n1:n2])  # next 10%: dev set
Xte, Yte = build_dataset(words[n2:])      # final 10%: test set

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [5]:
def cmp(s, dt, t):
    """Print how closely a candidate gradient matches the one stored on a tensor.

    s  -- label shown in the first column of the printed line
    dt -- candidate gradient tensor to check
    t  -- tensor whose `.grad` is used as the reference

    Reports exact element-wise equality, `torch.allclose` agreement, and the
    largest absolute element-wise difference. Returns nothing.
    """
    reference = t.grad
    exact = torch.all(dt == reference).item()
    close = torch.allclose(dt, reference)
    worst = torch.max(torch.abs(dt - reference)).item()
    report = ' | '.join([
        f'{s:15s}',
        f'exact: {str(exact):5s}',
        f'approx: {str(close):5s}',
        f'maxdiff: {worst}',
    ])
    print(report)

In [6]:
# Create the trainable tensors: embedding table C, two linear layers
# (W1/b1, W2/b2) and a gain/bias pair applied after normalization.
# NOTE: every torch.randn call below draws from the same generator `g`, so the
# order of these statements determines the initial values.
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
# one n_embd-dimensional row per vocabulary entry
C  = torch.randn((vocab_size, n_embd),            generator=g)

# W1 is scaled by (5/3) / sqrt(fan_in), with fan_in = n_embd * block_size
# (5/3 is presumably the tanh gain -- TODO confirm)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3) / (n_embd * block_size)**0.5 #0.2
b1 = torch.randn(n_hidden,                        generator=g) * 0.1

W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1

# bngain evaluates to a constant 1.1 in every position.
bngain = torch.ones((1, n_hidden)) * 0.1 + 1.0
# NOTE(review): zeros * 0.1 is still all zeros, so the 0.1 factor has no effect
# here -- confirm whether a random init (e.g. randn * 0.1) was intended.
bnbias = torch.zeros((1, n_hidden)) * 0.1

# These two tensors are not added to `parameters`, so requires_grad is never
# enabled on them below.
bnmean_running = torch.zeros((1, n_hidden))
bnstd_running = torch.ones((1, n_hidden))

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
# enable gradient tracking on every tensor in `parameters`
for p in parameters:
  p.requires_grad = True

4137


In [8]:
batch_size = 32
n = batch_size  # short alias for the batch size
# Minibatch: draw batch_size row indices into the training set from generator g
# (indices are sampled independently, so repeats are possible).
ix = torch.randint(low=0, high=Xtr.shape[0], size=(batch_size,), generator=g)
Xb = Xtr[ix]  # batch contexts
Yb = Ytr[ix]  # batch targets

In [9]:
# Forward pass on the minibatch, written as a chain of small named steps so
# that the gradient of every intermediate tensor can be retained and inspected.
emb = C[Xb] # look up the embedding row for every index in Xb
embcat = emb.view(emb.shape[0], -1) # flatten everything after the batch dimension
# Linear layer 1
hprebn = embcat @ W1 + b1
# Batch normalization, computed from the statistics of this batch
bnmeani = 1/n * hprebn.sum(dim=0, keepdim=True) # per-column mean over the batch
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(dim=0, keepdim=True) # variance with an n-1 divisor (Bessel's correction)
bnvar_inv = (bnvar + 1e-5)**(-0.5) # 1/sqrt(var + eps): an inverse standard deviation, despite the name
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact)
# Linear layer 2
logits = h @ W2 + b2
# Loss: softmax followed by negative log-likelihood, spelled out step by step
logits_maxes = logits.max(dim=1, keepdim=True).values
norm_logits = logits - logits_maxes # for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(dim=1, keepdim=True)
counts_sum_inv = counts_sum**(-1)
probs = counts * counts_sum_inv
logprobs = probs.log()
# mean negative log-probability of the target entry Yb[i] in each row i
loss = -logprobs[range(n), Yb].mean()

# Backward pass via autograd
# clear any gradients left on the parameters from a previous run
for p in parameters:
    p.grad = None
# Ask autograd to keep .grad on these intermediate tensors as well.
# NOTE: logits_maxes is not in this list, so its .grad is not retained.
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logits, h, hpreact, bnraw, bndiff, bndiff2, bnvar, bnvar_inv, bnmeani, hprebn, embcat, emb]:
    t.retain_grad()

loss.backward()
loss

tensor(3.3482, grad_fn=<NegBackward0>)

In [12]:
# For each of the n batch rows, the log-probability at that row's target index Yb[i]
logprobs[range(n), Yb]

tensor([-4.0580, -3.0728, -3.6750, -3.2631, -4.1653, -3.5406, -3.1162, -4.0795,
        -3.2095, -4.3294, -3.1081, -1.6111, -2.8121, -2.9719, -2.9798, -3.1644,
        -3.8541, -3.0233, -3.5830, -3.3694, -2.8526, -2.9453, -4.3805, -4.0618,
        -3.5177, -2.8368, -2.9712, -3.9312, -2.7585, -3.4454, -3.3162, -3.1384],
       grad_fn=<IndexBackward0>)

In [None]:
# Manual backprop through: loss = -logprobs[range(n), Yb].mean()
#
# `.grad` cannot be set to a Python float (PyTorch raises TypeError), and
# `logprobs[range(n), Yb]` is a fresh temporary tensor, so assigning to its
# `.grad` would be lost anyway. Instead, build the gradient as its own tensor
# with the same shape as `logprobs`.
#
# Only the n selected entries (row i, column Yb[i]) contribute to the loss.
# Each enters the mean with weight 1/n and the loss negates that mean, so the
# derivative at those entries is -1/n and zero everywhere else.
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n

# Compare against the gradient autograd retained on `logprobs`.
cmp('logprobs', dlogprobs, logprobs)

