In [2]:
import random

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

In [249]:
# Load the dataset: one name per line. splitlines() strips the line endings.
# The context manager closes the file handle, which a bare open(...).read() never does.
with open('data/names.txt', 'r', encoding='utf-8') as f:
    words = f.read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [4]:
# Character vocabulary. Every distinct character in `words` gets an index
# 1..N in sorted order. Index 0 is reserved for '.', the start/end token.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}  # inverse mapping: index -> character

In [330]:
block_size = 5  # context length: number of previous characters used to predict the next one

def build_dataset(words, block_size=None, stoi=None):
    """Turn a list of words into (context, next-character) training pairs.

    Each word is processed as ``word + '.'``. The context starts as `block_size`
    copies of the '.' index and slides forward one character at a time, so every
    word contributes ``len(word) + 1`` examples.

    Args:
        words: iterable of strings; every character must be a key of `stoi`.
        block_size: context length. Defaults to the notebook-level `block_size`
            (read at call time).
        stoi: char -> index mapping that contains '.'. Defaults to the
            notebook-level `stoi` (read at call time).

    Returns:
        X: int64 tensor of shape (num_examples, block_size) holding context indices.
        Y: int64 tensor of shape (num_examples,) holding target indices.
        Both have 0 rows (but the same rank/dtype) when `words` yields no examples.

    Raises:
        KeyError: if a word contains a character that is not in `stoi`.
    """
    if block_size is None:
        block_size = globals()['block_size']
    if stoi is None:
        stoi = globals()['stoi']
    pad = stoi['.']  # '.' doubles as the start-of-word padding token

    X, Y = [], []
    for w in words:
        context = [pad] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]  # builds a new list, so rows already in X are not mutated

    # Explicit dtype + view keep X as (N, block_size) int64 even when N == 0;
    # torch.tensor([]) alone would be a 1-D float tensor that cannot index an embedding.
    X = torch.tensor(X, dtype=torch.long).view(-1, block_size)
    Y = torch.tensor(Y, dtype=torch.long)
    return X, Y

# 80/10/10 train/val/test split over whole words, so all examples derived from
# one word land in the same split.
SPLIT_SEED = 42
# A dedicated seeded RNG makes the split reproducible across kernel restarts
# without touching the global `random` state.
# NOTE: this shuffles `words` in place, so re-running this cell without reloading
# `words` reshuffles an already-shuffled list and yields a different split.
random.Random(SPLIT_SEED).shuffle(words)
train_split = int(.8 * len(words))  # end of the training words
test_split = int(.9 * len(words))   # end of the validation words / start of the test words

train_X, train_Y = build_dataset(words[:train_split])
val_X, val_Y = build_dataset(words[train_split:test_split])
test_X, test_Y = build_dataset(words[test_split:])

# Every word yields len(word) + 1 examples (its characters plus the closing '.').
assert train_X.shape[0] + val_X.shape[0] + test_X.shape[0] == sum(len(w) + 1 for w in words)

train_X.shape, val_X.shape, test_X.shape

(torch.Size([182367, 5]), torch.Size([22860, 5]), torch.Size([22919, 5]))

In [334]:
emb_dim = 20            # dimensionality of each character embedding
hidden_units = 300      # width of the tanh hidden layer
vocab_size = len(stoi)  # number of tokens including '.'; avoids a hardcoded 27

# Seeded generator so parameter initialisation is reproducible across restarts.
g = torch.Generator().manual_seed(2147483647)

# NOTE(review): plain randn init is kept as-is. Each W1 pre-activation sums
# emb_dim*block_size products, so tanh is likely to saturate, and W2 makes the
# initial logits large. Consider scaling W1/W2 (and zeroing B2) — TODO evaluate.
C = torch.randn((vocab_size, emb_dim), generator=g)  # embedding table
W1 = torch.randn((emb_dim * block_size, hidden_units), generator=g)
B1 = torch.randn((hidden_units,), generator=g)
W2 = torch.randn((hidden_units, vocab_size), generator=g)
B2 = torch.randn((vocab_size,), generator=g)
parameters = [C, W1, B1, W2, B2]
for p in parameters:
    p.requires_grad = True

In [None]:
epochs = 300000         # number of SGD *steps* (one minibatch each), not full passes over train_X
batch_size = 64
lr_decay_step = 200000  # step at which the learning rate drops from 0.1 to 0.01

step_losses = []  # minibatch loss at every step, for inspecting the training curve
for epoch in range(epochs):
    # Sample a random minibatch of training-example indices.
    batch = torch.randint(0, train_X.shape[0], (batch_size,))

    # Forward pass: embed the context, flatten it, apply the tanh hidden layer, then the output logits.
    emb = C[train_X[batch]]
    h = torch.tanh(emb.view(-1, emb_dim * block_size) @ W1 + B1)
    logits = h @ W2 + B2
    loss = F.cross_entropy(logits, train_Y[batch])

    # Backward pass.
    for p in parameters:
        p.grad = None
    loss.backward()

    # Step-decay learning rate, then a plain SGD update done outside autograd
    # tracking (preferred over mutating `.data`).
    lr = 0.1 if epoch < lr_decay_step else 0.01
    with torch.no_grad():
        for p in parameters:
            p -= lr * p.grad
    step_losses.append(loss.item())

# Loss of the final minibatch only: a noisy estimate, not the full training loss.
print(loss.item())

2.159745693206787


In [341]:
# Loss over the entire training split. no_grad: only the value is needed, and
# building an autograd graph over every training example wastes a lot of memory.
with torch.no_grad():
    emb = C[train_X]
    h = torch.tanh(emb.view(-1, emb_dim * block_size) @ W1 + B1)
    logits = h @ W2 + B2
    loss = F.cross_entropy(logits, train_Y)
print(loss.item())

2.061523199081421


In [340]:
# Loss over the entire validation split. no_grad: evaluation needs no gradients,
# so skip building the autograd graph.
with torch.no_grad():
    emb = C[val_X]
    h = torch.tanh(emb.view(-1, emb_dim * block_size) @ W1 + B1)
    logits = h @ W2 + B2
    loss = F.cross_entropy(logits, val_Y)
print(loss.item())

2.1416330337524414
