In [1]:
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from torch.nn.functional import one_hot

# Data

In [2]:
# Read one name per line. splitlines() copes with a trailing newline and with
# Windows line endings, so no empty or '\r'-suffixed "name" ends up in the list.
with open('names.txt') as f:
    data = f.read().splitlines()
data[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [3]:
# Vocabulary: every character that occurs in the names plus '.', which is used
# as the start/end marker. The set union avoids adding '.' twice, and sorting
# once (instead of inside the loop) gives a deterministic ordering.
characters = sorted(set(''.join(data)) | {'.'})

# hmap: character -> integer class index, in sorted order.
# The size follows the data rather than a hard-coded 27.
hmap = {ch: idx for idx, ch in enumerate(characters)}
print(hmap)

{'.': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26}


In [4]:
# Create dataset
# Each name is wrapped in '.' markers and cut into consecutive character pairs:
# the first character of a pair is the input, the second is the target.
X, y = [],[]

for name in data:
    name = '.' + name + '.'
    for ch1, ch2 in zip(name, name[1:]):
        X.append(hmap[ch1])
        y.append(hmap[ch2])

# One-hot width follows the vocabulary size instead of a hard-coded 27.
X = one_hot(torch.tensor(X), num_classes=len(hmap)).float()
y = torch.tensor(y).int()
X.shape, y.shape

(torch.Size([228146, 27]), torch.Size([228146]))

# Model

In [5]:
# Test
# Seeded generator so that a fresh kernel always starts from the same weights.
g = torch.Generator().manual_seed(2147483647)

# W maps a one-hot input row to one logit per class, so both of its dimensions
# equal the one-hot width of X.
vocab_size = X.shape[1]
W = torch.randn((vocab_size, vocab_size), generator=g, requires_grad=True)

logits = X @ W # these are log-counts
cts = logits.exp()  # exponentiate to get positive "counts"
probs = cts/cts.sum(1, keepdim=True)  # normalise each row so it sums to 1
probs

tensor([[0.2105, 0.0157, 0.1094,  ..., 0.0089, 0.0543, 0.0809],
        [0.0088, 0.0131, 0.0251,  ..., 0.0139, 0.0033, 0.0638],
        [0.0014, 0.0543, 0.0858,  ..., 0.0093, 0.0259, 0.0192],
        ...,
        [0.0095, 0.0804, 0.0090,  ..., 0.0383, 0.0348, 0.0057],
        [0.0838, 0.0186, 0.0157,  ..., 0.0054, 0.0523, 0.0229],
        [0.0309, 0.0433, 0.1325,  ..., 0.0260, 0.0716, 0.0459]],
       grad_fn=<DivBackward0>)

In [6]:
# Each row of probs comes from dividing by its own row sum, so every row must
# add up to 1. Assert that for all rows, then display the first as a spot check.
row_sums = probs.sum(dim=1)
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5)
probs[0].sum() # Sum along a row should be 1

tensor(1.0000, grad_fn=<SumBackward0>)

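These would be the model's next-character probabilities if W held the right weights. W is still random at this point, so the next step is to optimise W with gradient descent.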

# Training

In [15]:
# Settings read by the gradient-descent loop
n = X.shape[0]    # number of training examples (rows of X)
lr = 100          # step size applied to the gradient in each update
num_epochs = 200  # number of passes over the full dataset per run of the loop

In [17]:
# Full-batch gradient descent on W.
# NOTE: W is not re-initialised here, so running this cell again continues
# training from the current weights.
rows = torch.arange(n)   # loop-invariant row index used to pick the target entries
reg_strength = .001      # weight of the squared-weight penalty on W

for i in range(num_epochs):

    # forward pass: log_softmax normalises the logits in log space, so large
    # logits cannot overflow the way exp() followed by a division can
    logits = X @ W
    log_probs = torch.log_softmax(logits, dim=1)

    # loss function: mean negative log-likelihood of the true next character,
    # plus a penalty that pulls the weights towards zero
    loss = -log_probs[rows, y].mean() + reg_strength*(W**2).mean()
    if i % 20 == 0:
        print(loss.item())

    # backward pass
    W.grad = None
    loss.backward()

    # gradient step: no_grad keeps the in-place update out of the autograd
    # graph without going through the legacy .data attribute
    with torch.no_grad():
        W -= lr * W.grad


2.487898826599121
2.487605571746826
2.4873836040496826
2.487198829650879
2.4870452880859375
2.4869155883789062
2.486804246902466
2.486708641052246
2.4866254329681396
2.4865529537200928


In [9]:
# Display the target class indices built in the dataset cell
y

tensor([ 5, 13, 13,  ..., 26, 24,  0], dtype=torch.int32)