In [34]:
import matplotlib.pyplot as plt
%matplotlib inline

In [35]:
# Read the whole names file into memory, then break it into one entry per line.
with open('names.txt') as f:
    content = f.read()
words = content.splitlines()

In [36]:
# Number of lines read from names.txt (one entry per line); shown as the cell output.
len(words)

32033

# Preprocess the data

In [37]:
# Every distinct character that occurs in the data, in sorted order.
chars = sorted(set(''.join(words)))

# Character -> index, counting from 1; index 0 is assigned to the '.' token.
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi['.'] = 0

# Index -> character (the inverse of stoi).
itos = {index: symbol for symbol, index in stoi.items()}
vocab_size = len(itos)

In [38]:
import torch

# Build dataset splits
def build_dataset(words: list[str], block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn a list of words into (context, next-character) training pairs.

    Each word is terminated with '.' and scanned left to right. For every
    character, the ``block_size`` preceding character indices (left-padded with
    index 0) form one input row and the character's own index is the target.
    Characters are mapped to indices through the module-level ``stoi`` table.

    Args:
        words: Words to encode; every character must be a key of ``stoi``.
        block_size: Number of context characters per example.

    Returns:
        ``(X, Y)`` where ``X`` is an int64 tensor of shape
        ``(num_examples, block_size)`` and ``Y`` is an int64 tensor of shape
        ``(num_examples,)``.

    Raises:
        KeyError: If a character is not present in ``stoi``.
    """
    x, y = [], []
    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            x.append(context)
            y.append(ix)
            # Advance the rolling window of context
            context = context[1:] + [ix]
    # Explicit dtype and shape so that an empty `words` list still produces
    # well-formed (0, block_size) / (0,) index tensors rather than a float
    # tensor of shape (0,).
    X = torch.tensor(x, dtype=torch.long).view(len(x), block_size)
    Y = torch.tensor(y, dtype=torch.long)
    print(X.shape, Y.shape)
    return X, Y

import random

# Shuffle a *copy* with a fixed seed so the split is reproducible and this cell
# is idempotent. Shuffling `words` in place would re-permute the already
# shuffled list on every re-run and silently change which names end up in
# train/dev/test.
random.seed(42)
words_shuffled = list(words)
random.shuffle(words_shuffled)

# Split boundaries: first 80% train, next 10% dev, final 10% test.
n1 = int(0.8*len(words_shuffled))
n2 = int(0.9*len(words_shuffled))

CONTEXT_SZ = 5  # number of preceding characters given to the model as context
Xtr, Ytr = build_dataset(words_shuffled[:n1], CONTEXT_SZ)
Xdev, Ydev = build_dataset(words_shuffled[n1:n2], CONTEXT_SZ)
Xte, Yte = build_dataset(words_shuffled[n2:], CONTEXT_SZ)

torch.Size([182625, 5]) torch.Size([182625])
torch.Size([22655, 5]) torch.Size([22655])
torch.Size([22866, 5]) torch.Size([22866])


# Model

In [45]:
# Model hyperparameters.
EMBED_SZ = 32          # width of each character embedding vector
HIDDEN_LAYER_SZ = 96   # number of units in the hidden layer

# Seeded generator so the initial weights are the same on every run.
g = torch.Generator()
g.manual_seed(2147483647)

# The draw order (C, W1, b1, W2, b2) determines the values each tensor receives.
C = torch.randn(vocab_size, EMBED_SZ, generator=g)                     # embedding table
W1 = torch.randn(CONTEXT_SZ * EMBED_SZ, HIDDEN_LAYER_SZ, generator=g)  # hidden-layer weights
b1 = torch.randn(HIDDEN_LAYER_SZ, generator=g)                         # hidden-layer bias
W2 = torch.randn(HIDDEN_LAYER_SZ, vocab_size, generator=g)             # output-layer weights
b2 = torch.randn(vocab_size, generator=g)                              # output-layer bias

parameters = [C, W1, b1, W2, b2]

# Turn on gradient tracking for every parameter.
for p in parameters:
    p.requires_grad_(True)

# Total number of parameters (displayed as the cell output).
sum(p.numel() for p in parameters)

18939

# Training

In [50]:
def forward_pass(x: torch.Tensor) -> torch.Tensor:
    """Compute next-character logits for a batch of context windows.

    Reads the module-level parameters ``C``, ``W1``, ``b1``, ``W2``, ``b2`` and
    the sizes ``CONTEXT_SZ`` and ``EMBED_SZ``.

    Args:
        x: Integer index tensor of shape ``(N, CONTEXT_SZ)``.

    Returns:
        Unnormalised logits of shape ``(N, vocab_size)``.
    """
    emb = C[x] # (N, CONTEXT_SZ, EMBED_SZ)
    # torch.relu is used directly so this function does not depend on a global
    # `relu` object having been created elsewhere in the notebook.
    h = torch.relu(emb.view(-1, CONTEXT_SZ * EMBED_SZ) @ W1 + b1) # (N, HIDDEN_LAYER_SZ)
    return h @ W2 + b2 # (N, vocab_size)

In [46]:
from statistics import mean
from tqdm import tqdm
from torch import nn
from torch.nn import functional as F

MINI_BATCH_SZ = 128

# It doesn't make sense to step through the learning rate linearly. Instead, we'll
# step through the exponents linearly.
# Schedule: 100k steps at 10**-1, then 50k steps with the exponent going from -1 to -5.
lre = torch.cat([-1 * torch.ones(100000), torch.linspace(-1, -5, steps=50000)])
lrs = 10**lre

# Per-step training statistics (lri is initialised here but not filled by this cell).
lri = []
lossi = []
loglossi = []
loglossiavg = []
stepi = []

# Kept as a module-level name so code that looks up the global `relu` keeps working.
relu = nn.ReLU()

# Build the optimizer once; only its learning rate changes from step to step.
# Plain SGD (no momentum) carries no state, so this matches re-creating it each step.
sgd = torch.optim.SGD(parameters, lr=lrs[0].item(), weight_decay=0.001)

pbar = tqdm(range(len(lrs)))
for i in pbar:
    # Mini batch (drawn from the seeded generator `g` so runs are reproducible)
    ix = torch.randint(0, Xtr.shape[0], (MINI_BATCH_SZ,), generator=g) # (MINI_BATCH_SZ)
    logits = forward_pass(Xtr[ix])
    loss = F.cross_entropy(logits, Ytr[ix])

    # Update parameters
    lr = lrs[i].item()  # plain float rather than a 0-d tensor
    for group in sgd.param_groups:
        group['lr'] = lr
    sgd.zero_grad()
    loss.backward()
    sgd.step()

    # Track stats
    logloss = loss.log10().item()
    stepi.append(i)
    lossi.append(loss.item())
    loglossi.append(logloss)
    if i < 500:
        # Not enough history for a 500-step window yet; record the raw value.
        loglossiavg.append(logloss)
    else:
        loglossiavg.append(mean(loglossi[-500:]))
    if i % 1000 == 0:
        pbar.set_description("LR: %0.4f Loss %0.4f" % (lr, mean(lossi[-500:])))

print(loss.item())

LR: 0.0000 Loss 2.0648: 100%|██████████| 150000/150000 [01:14<00:00, 2025.33it/s]

1.9859323501586914





# Evaluate Loss

In [51]:
# Report the run configuration together with the parameter count.
run_config = dict(
    MINI_BATCH_SZ=MINI_BATCH_SZ,
    CONTEXT_SZ=CONTEXT_SZ,
    EMBED_SZ=EMBED_SZ,
    HIDDEN_LAYER_SZ=HIDDEN_LAYER_SZ,
    num_parameters=sum(param.numel() for param in parameters),
)
print(run_config)

# Full-split losses; gradients are not needed for evaluation.
with torch.no_grad():
    xtr_loss = F.cross_entropy(forward_pass(Xtr), Ytr)
    print("Xtr loss: {:0.4f}".format(xtr_loss))

    xdev_loss = F.cross_entropy(forward_pass(Xdev), Ydev)
    print("Xdev loss: {:0.4f}".format(xdev_loss))

{'MINI_BATCH_SZ': 128, 'CONTEXT_SZ': 5, 'EMBED_SZ': 32, 'HIDDEN_LAYER_SZ': 96, 'num_parameters': 18939}
Xtr loss: 2.0658
Xdev loss: 2.0889


# Sample from the Model

In [54]:
# Dedicated seeded generator so the sampled names are reproducible.
g = torch.Generator().manual_seed(2147483647 + 10)

# Sampling is inference only, so skip autograd bookkeeping for the forward passes.
with torch.no_grad():
    for _ in range(20):
        out = []
        context = [0] * CONTEXT_SZ  # start from a context filled with index 0 ('.')
        while True:
            x = torch.tensor([context])
            logits = forward_pass(x)

            # Softmax turns the logits into a probability distribution over characters
            probs = F.softmax(logits, dim=1)

            # Sample the next character index from that distribution
            ix = torch.multinomial(probs, num_samples=1, generator=g).item()
            context = context[1:] + [ix]
            out.append(itos[ix])
            # Index 0 is the '.' token, which ends the sampled name
            if ix == 0:
                break
        print(''.join(out))

carmah.
amille.
khirmin.
rehty.
salayane.
rahnen.
deliyah.
kaqui.
nellaiah.
maiiv.
kaleigh.
hamon.
cadessan.
shoilea.
jadbi.
watthon.
jaryni.
jaceenissa.
medde.
iil.
