In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [4]:
words = open('names.txt').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [5]:
len(words)

32658

In [6]:
# build vocab of chars and mappings
chars = sorted(list(set(''.join(words))))
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 27: 'ó', 28: 'ę', 29: 'ł', 30: 'ń', 31: 'ś', 32: 'ż', 0: '.'}


In [85]:
# build the dataset
block_size = 3 #context len
X, Y = [], []
for w in words:
    # print(w)
    context = [0] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        # print(''.join(itos[i] for i in context), '--->', itos[ix])
        context = context[1:] + [ix]
    # print("\n")
X = torch.tensor(X)
Y = torch.tensor(Y)

In [86]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([233052, 3]), torch.int64, torch.Size([233052]), torch.int64)

In [87]:
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((33, 2), generator=g)

W1 = torch.randn((6, 100), generator=g)
b1 = torch.randn(100, generator=g)

W2 = torch.randn((100, 33), generator=g)
b2 = torch.randn(33, generator=g)

parameters = [W1, b1, W2, b2]

In [88]:
for p in parameters:
    p.requires_grad = True

In [89]:
learning_rate =  -1
for _ in range(100):
    # forward pass
    emb = C[X]
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y)
    print(loss.item())

    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    # update
    for p in parameters:
        p.data += learning_rate * p.grad


19.38705062866211
14.622990608215332
11.110272407531738
9.312739372253418
8.239185333251953
7.334122180938721
7.979617118835449
7.273725986480713
6.395797252655029
6.205406188964844
5.912548542022705
6.30939245223999
5.678480625152588
5.539745807647705
5.645369529724121
4.902589321136475
4.86363410949707
4.769598960876465
5.071596145629883
4.541385173797607
5.080757141113281
4.390317440032959
4.393413066864014
4.414956569671631
4.904275417327881
4.145936012268066
4.156111717224121
4.191665172576904
4.546576023101807
4.7404608726501465
3.971715211868286
4.000210285186768
4.2430949211120605
4.140524387359619
4.177298545837402
3.861210584640503
4.266016483306885
3.8248376846313477
3.969449996948242
3.933960199356079


KeyboardInterrupt: 