In [694]:
# Construct the bigram training set

import torch
import torch.nn.functional as F

# Read one name per line. The context manager guarantees the file handle is
# closed as soon as the contents have been read.
with open("../data/names.txt", "r") as names_file:
    words = names_file.read().splitlines()

# Vocabulary: every distinct character that occurs in the names, sorted.
alphabet = sorted(set("".join(words)))
LIMITER = "."

# Character -> index. Index 0 is reserved for the start/end delimiter, so the
# characters of the alphabet are numbered from 1.
atoi = {s: i + 1 for i, s in enumerate(alphabet)}
atoi[LIMITER] = 0

# Index -> character (inverse mapping of atoi).
itoa = {i: s for s, i in atoi.items()}

# Vocabulary size, delimiter included.
C = len(atoi)
# Zero-initialised C x C integer table; it is allocated here but not filled in
# by this cell.
N = torch.zeros((C, C), dtype=torch.int32)

# Training pairs: for every pair of adjacent characters in a delimiter-wrapped
# name, xs receives the index of the first character and ys the index of the
# character that follows it.
xs, ys = [], []

for w in words:
    chs = [LIMITER] + list(w) + [LIMITER]
    for ch1, ch2 in zip(chs, chs[1:]):
        xs.append(atoi[ch1])
        ys.append(atoi[ch2])

xs = torch.tensor(xs, dtype=torch.long)
ys = torch.tensor(ys, dtype=torch.long)

# Total number of bigram training examples
total = xs.nelement()

In [695]:
# Initialize a C x C weight matrix: one row per input character and one column
# (output logit) per candidate next character.
# A dedicated, seeded generator makes the random initialisation reproducible
# when the notebook is restarted and run from the top.
init_gen = torch.Generator().manual_seed(2147483647)
# Create the tensor directly as float32 so that W itself is the leaf tensor
# that receives gradients. Casting with .float() after the fact would produce a
# non-leaf copy (whose .grad stays None) whenever the default dtype is not
# float32.
W = torch.randn((C, C), generator=init_gen, dtype=torch.float32, requires_grad=True)

# One-hot encode every input index in a single call; the result has one row per
# bigram and C columns.
xenc = F.one_hot(xs, num_classes=C).float()

In [696]:
# Hyper-parameters of the training run.
NUM_STEPS = 500       # number of gradient-descent iterations
LEARNING_RATE = 10    # step size of the plain gradient-descent update
L2_STRENGTH = 0.01    # weight of the mean-squared-weight penalty
LOG_EVERY = 100       # print the loss once every this many steps

# Row indices used to pick, for each example, the probability assigned to its
# target character. Loop-invariant, so it is built once.
rows = torch.arange(xs.nelement())

for i in range(NUM_STEPS):
    # === Forward ===
    logits = xenc @ W

    # Row-wise softmax. F.softmax gives the same result as exp / row-sum but is
    # computed in a numerically stable way, so large logits cannot overflow
    # exp() into inf and poison the loss with nan.
    probs = F.softmax(logits, dim=1)

    # Penalise large weights (mean of squared entries, scaled).
    regularization_term = L2_STRENGTH * (W**2).mean()
    # Average negative log-likelihood of the correct next character.
    loss = -(probs[rows, ys].log().mean())

    total_loss = loss + regularization_term

    if i % LOG_EVERY == 0:
        print(total_loss)

    # === Backward ===
    # Clear the gradient from the previous step before back-propagating.
    W.grad = None
    total_loss.backward()

    # Gradient-descent step. The in-place update is performed under no_grad so
    # autograd does not record it, instead of reaching around autograd via .data.
    with torch.no_grad():
        W -= LEARNING_RATE * W.grad

tensor(3.8206, grad_fn=<AddBackward0>)
tensor(2.5763, grad_fn=<AddBackward0>)
tensor(2.5190, grad_fn=<AddBackward0>)
tensor(2.5026, grad_fn=<AddBackward0>)
tensor(2.4948, grad_fn=<AddBackward0>)


In [697]:

# Sample 10 names from the trained model, one character at a time.
g = torch.Generator().manual_seed(1234)
for _ in range(10):
    out = []
    # Every name starts from the delimiter token.
    ix = atoi[LIMITER]
    # Sampling only reads W, so no autograd graph needs to be recorded.
    with torch.no_grad():
        while True:
            # One-hot encode the current character. A separate name is used so
            # the training matrix `xenc` is not overwritten, which would break
            # re-running the training cell after this one.
            ix_enc = F.one_hot(torch.tensor([ix]), num_classes=C).float()
            logits = ix_enc @ W
            # Distribution over the next character (numerically stable softmax).
            probs = F.softmax(logits, dim=1)

            ix = torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item()

            ch = itoa[ix]
            # The delimiter marks the end of the name and is not part of it.
            if ch == LIMITER:
                break
            out.append(ch)

    print("".join(out))

kyason
kanana
milaveela
ra
drisan
lurtolisyrsh
tobay
n
myaeahare
chucinailelulelone
