In [227]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [228]:
# Load the dataset: names.txt holds one name per line.
# The context manager closes the file handle as soon as the contents are read,
# instead of leaving it open for the lifetime of the kernel.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [229]:
# Number of lines (one entry per line) read from names.txt.
len(words)

32033

In [230]:
# Character vocabulary and the two lookup tables between characters and integers.
# Every distinct character found in the dataset gets an index starting at 1
# (in sorted order); index 0 is reserved for the '.' boundary marker.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0

# Inverse mapping: integer index -> character.
itos = {idx: ch for ch, idx in stoi.items()}
itos

{1: 'a',
 2: 'b',
 3: 'c',
 4: 'd',
 5: 'e',
 6: 'f',
 7: 'g',
 8: 'h',
 9: 'i',
 10: 'j',
 11: 'k',
 12: 'l',
 13: 'm',
 14: 'n',
 15: 'o',
 16: 'p',
 17: 'q',
 18: 'r',
 19: 's',
 20: 't',
 21: 'u',
 22: 'v',
 23: 'w',
 24: 'x',
 25: 'y',
 26: 'z',
 0: '.'}

In [231]:
# Build the dataset: every character of every word (plus the terminating '.')
# becomes one target, paired with the block_size character indices preceding it.
block_size = 3  # context length: how many previous characters feed each prediction
X, Y = [], []

for word in words:
    # Left-pad with index 0 ('.') so the first characters also get a full context.
    encoded = [0] * block_size + [stoi[ch] for ch in word + '.']
    for pos in range(block_size, len(encoded)):
        X.append(encoded[pos - block_size:pos])  # the block_size indices before pos
        Y.append(encoded[pos])                   # the index to predict

X = torch.tensor(X)
Y = torch.tensor(Y)

In [232]:
# Shapes and dtypes of the context tensor X and the target tensor Y.
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([228146, 3]), torch.int64, torch.Size([228146]), torch.int64)

In [233]:
# Embedding lookup table: one randomly initialised 2-dimensional row
# for each of the 27 character indices.
C = torch.randn(27, 2)

In [234]:
# Indexing C with the integer tensor X looks up a row of C for every entry of X;
# [1,1] then selects the embedding for example 1, context position 1.
C[X][1,1]

tensor([0.3552, 0.1907])

In [235]:
# Embed all context characters at once: each index in X is replaced by its row of C,
# so the result gains a trailing dimension of size 2.
emb = C[X]
emb.shape

torch.Size([228146, 3, 2])

In [236]:
# Hidden layer parameters: 6 inputs (3 context positions x 2 embedding dims)
# feeding 100 hidden units, plus one bias per hidden unit.
W1 = torch.randn(6, 100)
b1 = torch.randn((100,))

In [237]:
# Embeddings of context position 0 for every example.
emb[:, 0, :].shape

torch.Size([228146, 2])

In [238]:
# Flatten each example by concatenating its three per-position embeddings along dim 1.
# The three positions are spelled out by hand, so this only works for a context of 3.
torch.cat([emb[:,0,:], emb[:,1,:], emb[:,2,:]], 1).shape

torch.Size([228146, 6])

In [239]:
# Same concatenation without hard-coding the positions: unbind splits emb along
# dim 1 into a tuple of slices, which cat then joins along dim 1.
torch.cat(torch.unbind(emb, 1), 1).shape

torch.Size([228146, 6])

In [240]:
# More efficient way to get the same (N, 6) layout: view only reinterprets the
# shape of the existing tensor rather than building a new one with cat.
# Passing -1 lets torch infer that dimension from the total number of elements.
emb.view(-1, 6).shape

torch.Size([228146, 6])

In [241]:
# Hidden activations: flatten each example's embeddings to 6 values, apply the
# affine map (W1, b1), then squash with tanh. b1 broadcasts across the rows.
h = (emb.view(-1, 6) @ W1 + b1).tanh()

In [242]:
# Output layer: maps the 100 hidden activations to 27 scores (logits),
# one per character index.
W2 = torch.randn(100, 27)
b2 = torch.randn((27,))

# b2 broadcasts across the rows of the matrix product.
logits = torch.matmul(h, W2) + b2

In [243]:
# Softmax over the class dimension (dim 1), written out by hand.
# Subtracting each row's maximum before exponentiating leaves the resulting
# probabilities mathematically unchanged, but keeps exp() from overflowing to
# inf (and the division from producing nan) when a logit is large.
counts = (logits - logits.max(1, keepdim=True).values).exp()
# keepdim=True keeps the row sums as a column so the division broadcasts per row.
prob = counts / counts.sum(1, keepdim=True)

In [244]:
# One row of probabilities per example, one column per output class.
prob.shape

torch.Size([228146, 27])

In [245]:
# Sanity check: the probabilities of a single example should sum to 1.
prob[0].sum()

tensor(1.)

In [247]:
# Average negative log-likelihood: for every example pick the probability the
# model assigned to the correct next character, then take the mean of -log.
# The row index is derived from Y rather than a hard-coded count, so the cell
# keeps working when the dataset size changes.
loss = -prob[torch.arange(Y.shape[0]), Y].log().mean()
loss

tensor(16.6972)

In [None]:
# --- Summary ----

In [None]:
# NOTE(review): the recorded output of this cell shows 32 examples, while the
# dataset cell earlier in this notebook produced 228146. Presumably X and Y were
# rebuilt from a subset of `words` in a run that is not reflected in the saved
# cells; confirm before relying on the numbers that follow.
X.shape, Y.shape

(torch.Size([32, 3]), torch.Size([32]))

In [None]:
# Re-initialise every parameter from one seeded generator so runs are repeatable.
# The draw order (C, W1, b1, W2, b2) determines the values and must not change.
g = torch.Generator().manual_seed(2147483647)

parameters = [
    torch.randn(dims, generator=g)
    for dims in ((27, 2), (6, 100), (100,), (100, 27), (27,))
]
C, W1, b1, W2, b2 = parameters

In [None]:
# Total number of scalar values across all parameter tensors.
sum(p.nelement() for p in parameters)

3481

In [None]:
# Switch on gradient tracking for every parameter tensor so that a later
# backward() call populates its .grad attribute.
for param in parameters:
    param.requires_grad_(True)

In [None]:
# Plain gradient descent over the whole of X for a fixed number of steps.
learning_rate = 0.1  # step size applied to every parameter

for i in range(1000):

    # forward pass
    emb = C[X]  # (num_examples, 3, 2)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
    logits = h @ W2 + b2
    # cross_entropy takes the raw logits and the integer targets directly
    loss = F.cross_entropy(logits, Y)
    print(loss.item())

    # backward pass
    # clear gradients left over from the previous step before recomputing them
    for p in parameters:
        p.grad = None
    loss.backward()

    # update
    # Done in place under no_grad so autograd does not record the update itself.
    # This replaces writing through `.data`, which bypasses autograd's checks.
    with torch.no_grad():
        for p in parameters:
            p += -learning_rate * p.grad


17.76971435546875
13.656403541564941
11.298772811889648
9.452458381652832
7.984263896942139
6.891322135925293
6.100015640258789
5.452036380767822
4.8981523513793945
4.4146647453308105
3.985849618911743
3.602830648422241
3.262141704559326
2.961381435394287
2.6982975006103516
2.469712972640991
2.271660804748535
2.101283550262451
1.9571771621704102
1.837485909461975
1.7380971908569336
1.6535120010375977
1.5790897607803345
1.5117669105529785
1.449605107307434
1.3913124799728394
1.335992455482483
1.2830536365509033
1.232192039489746
1.183382511138916
1.1367992162704468
1.0926648378372192
1.0510931015014648
1.012027621269226
0.9752711057662964
0.940557062625885
0.907613217830658
0.8761924505233765
0.8460894227027893
0.8171362280845642
0.7891996502876282
0.7621750831604004
0.7359820008277893
0.7105579972267151
0.6858614087104797
0.6618654727935791
0.6385660767555237
0.6159822344779968
0.594166100025177
0.5732107758522034
0.553256630897522
0.5344884991645813
0.5171172022819519
0.50133180618286