In [None]:
import torch
import torch.nn.functional as F

# Load the corpus: one word per line. The `with` block guarantees the file
# handle is closed as soon as the contents have been read.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()

# Vocabulary: every distinct character that occurs in the corpus, in sorted order.
chars = sorted(set(''.join(words)))

# Character -> integer index. Indices start at 1 because index 0 is reserved
# for the special '.' token assigned below.
stoi = {s : i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
# Inverse mapping: integer index -> character.
itos = {i : s for s, i in stoi.items()}

In [None]:
# Build the bigram training set. Each word is wrapped in '.' on both sides,
# then every adjacent character pair contributes one example:
# xs holds the index of the first character, ys the index of the one after it.
xs, ys = [], []

for w in words:
    padded = '.' + w + '.'
    # Inputs are all characters but the last; targets are all but the first.
    xs.extend(stoi[c] for c in padded[:-1])
    ys.extend(stoi[c] for c in padded[1:])

xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.numel()
print("number of examples: ", num)

# initialize the 'network': a single 27x27 weight matrix drawn from a
# seeded generator so the starting point is reproducible
g = torch.Generator()
g.manual_seed(2147483647)
W = torch.randn(27, 27, generator=g, requires_grad=True)

In [None]:
# Train W with full-batch gradient descent.
# Each input index is one-hot encoded over 27 classes, so `xenc @ W` selects
# the row of W belonging to the input character.
xenc = F.one_hot(xs, num_classes=27).float()

for k in range(10000):

    # forward pass:
    logits = xenc @ W
    # cross_entropy is the mean negative log-probability that softmax(logits)
    # assigns to the targets in ys. It is evaluated in log-space, so large
    # logits cannot overflow the way an explicit exp()-then-normalise can.
    # The second term is an L2 penalty pulling the weights towards zero.
    loss = F.cross_entropy(logits, ys) + 0.01 * (W ** 2).mean()

    if k % 1000 == 0:
        print(loss.item())

    # backward pass:
    W.grad = None  # clear the previous gradient so it does not accumulate
    loss.backward()

    # update: step against the gradient (learning rate 10). Done under no_grad
    # so the in-place change to W is not recorded by autograd.
    with torch.no_grad():
        W -= 10 * W.grad



In [None]:
# sample from the network

# Sampling only needs forward passes. W has requires_grad=True, so without
# no_grad every step would record an autograd graph that is never used.
with torch.no_grad():
    for _ in range(5):

        out = []
        ix = 0  # index 0 is the '.' token: start of a word
        while True:
            # One-hot encode the current character. A dedicated name is used so
            # the training-set `xenc` is not overwritten by this 1x27 tensor.
            ix_enc = F.one_hot(torch.tensor([ix]), num_classes=27).float()
            logits = ix_enc @ W
            counts = logits.exp()
            # Normalise each row to sum to 1: distribution over the next character.
            probs = counts / counts.sum(1, keepdim=True)

            ix = torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item()

            # Drawing '.' again marks the end of the word.
            if ix == 0:
                break
            out.append(itos[ix])

        print(''.join(out))