In [1]:
import pathlib
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
NAMES_FILE = "names.txt"
TERM_TOK = "."
CONTEXT_SIZE = 2
SEED = 2147483647
LAMBDA = 0.01  # l2 regularization strength
LEARNING_RATE = 50
NUM_EPOCHS = 100  # goes through the entire dataset this many times

In [3]:
words = [line.strip() for line in pathlib.Path(NAMES_FILE).open("r").readlines()]
chars = [TERM_TOK] + sorted(list(set("".join(words))))
stoi = {s: i for i, s in enumerate(chars)}
itos = chars
vocab_size = len(chars)

In [4]:
# create trigram distribution
def build_dataset(words: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
    X, y = [], []
    context = [stoi[TERM_TOK]] * CONTEXT_SIZE
    for word in words:
        for c in word + TERM_TOK:
            ix = stoi[c]
            X.append(context)
            y.append(ix)
            context = context[1:] + [ix]
    X = torch.tensor(X)
    y = torch.tensor(y)
    return X, y


X, y = build_dataset(words)
X.shape, y.shape

(torch.Size([228146, 2]), torch.Size([228146]))

In [5]:
# counting method
trigrams = torch.cat((X, y.view(-1, 1)), dim=1)
N = torch.zeros((vocab_size,) * (CONTEXT_SIZE + 1), dtype=torch.long)
unique_trigrams, counts = torch.unique(trigrams, dim=0, return_counts=True)
N[torch.unbind(unique_trigrams, dim=1)] = counts
N.shape

torch.Size([27, 27, 27])

In [6]:
# probability distribution with add-1 smoothing
P = (N + 1).float()  # add-1 smoothing
P /= P.sum(dim=-1, keepdim=True)
assert torch.allclose(P.sum(dim=-1), torch.ones((vocab_size,) * CONTEXT_SIZE))

In [7]:
# sample
g = torch.Generator().manual_seed(SEED)
for i in range(5):
    out = []
    context = [stoi[TERM_TOK]] * CONTEXT_SIZE
    while True:
        p = P[*context]
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        if ix == 0:
            break
    print("".join(out))

cexzm.
zoglkurkicqzktyhwevmzimjttainrlkfukzkktda.
sfcxvpubjtbhrmgotzx.
iczieqctvujkwptedogkkjemkmmsedguenkbvgynywftbspmhwcivgbvtahlvsu.
dsdxxblnwglhpyiw.


In [8]:
# calculate loss
nll = -torch.sum(torch.log(P) * N)
print(f"{nll=}")
print(f"{nll / N.sum()}")

nll=tensor(503516.8750)
2.20699405670166
