In [3]:
import torch
import random
import torch.nn.functional as F
import matplotlib as plt

%matplotlib inline

In [4]:
words = open('names.txt', 'r').read().splitlines()

In [7]:
chars = sorted(list(set(''.join(words))))

In [11]:
stoi = {s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

In [19]:
# block size for training and inference
block_size = 3
# construct data
def construct_data(words):
    X, Y = [], []

    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)

            context = context[1:] + [ix]
    
    return torch.tensor(X), torch.tensor(Y)


In [None]:
import random
random.seed(21476891001)
random.shuffle(words)

In [40]:
# split the data into batches (training, dev, test)

n1 = int(len(words) * 0.8)
n2 = int(len(words) * 0.9)

Xtr, Ytr = construct_data(words[:n1])
Xdev, Ydev = construct_data(words[n1:n2])
Xte, Yte = construct_data(words[n2:])

In [26]:
g = torch.Generator().manual_seed(21476891001)

In [28]:
# initialize values 
vocab_size = 27
n_emb = 10
n_hidden = 100

In [30]:
C = torch.randn((vocab_size, n_emb), generator=g)
W1 = torch.randn((n_emb * block_size, n_hidden), generator=g) * (5/3) / ((n_emb * block_size) ** 0.5) # add kaiming 
b1 = torch.randn(n_hidden, generator=g) * 0.1
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1

bngain = torch.randn((1, n_hidden)) * 0.1 + 0.1
bnbias = torch.randn((1, n_hidden)) * 0.1

In [31]:
parameters = [C, W1, b1, W2, b2]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True


6097


In [45]:
batch_size = 32
ix = torch.randint(0, Xtr.shape[0], (batch_size, ), generator=g)

Xb, Yb = Xtr[ix], Ytr[ix]
Xb.shape

torch.Size([182778, 3])


torch.Size([32, 3])

In [115]:
emb = C[Xb] # [27, 10] [32, 3] -> [32, 3, 10]

embcat = emb.view(emb.shape[0], emb.shape[1] * emb.shape[2])

# linear layer 01
# pre activation
bnprebn = embcat @ W1 + b1
# batch normalization
bnmeani =  1 / batch_size * (bnprebn.sum(0, keepdim=True))
bndiff = bnprebn - bnmeani
bndiff2 = bndiff ** 2
bnvari = 1 / (batch_size - 1) * bndiff2.sum(0, keepdim=True) # Bessels correction
bnnorm = bndiff / ((bnvari + 1e-5) ** 0.5)

hpreact = bnnorm * bngain + bnbias

# non linear layer
h = torch.tanh(hpreact)

# linear layer 02
logits = h @ W2 + b2
print(logits.shape)
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes
counts = norm_logits.exp()

# normalize counts
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1 # more exact

probs = counts * counts_sum_inv
log_probs = probs.log()
loss = -log_probs[range(batch_size), Yb].mean()


# pytorch
for p in parameters:
    p.grad = None
for t in (log_probs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact, bnnorm, bnvari, bndiff, bndiff2, bnmeani, embcat, emb):
    t.retain_grad()
loss.backward()
loss

torch.Size([32, 27])


tensor(3.2633, grad_fn=<NegBackward0>)