In [186]:
import torch 
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib  inline

In [187]:
words = open('names.txt', 'r').read().splitlines()
len(words)

32033

In [188]:
chars = sorted(list(set("".join(words))))
chars


['a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z']

In [189]:
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}
itos

{1: 'a',
 2: 'b',
 3: 'c',
 4: 'd',
 5: 'e',
 6: 'f',
 7: 'g',
 8: 'h',
 9: 'i',
 10: 'j',
 11: 'k',
 12: 'l',
 13: 'm',
 14: 'n',
 15: 'o',
 16: 'p',
 17: 'q',
 18: 'r',
 19: 's',
 20: 't',
 21: 'u',
 22: 'v',
 23: 'w',
 24: 'x',
 25: 'y',
 26: 'z',
 0: '.'}

In [190]:
block_size = 3
X, Y = [], []

for w in words:
    context = [0] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        context = context[1:] + [ix]

X = torch.tensor(X)
Y = torch.tensor(Y)

In [191]:
X.dtype, Y.dtype, X.shape, Y.shape

(torch.int64, torch.int64, torch.Size([228146, 3]), torch.Size([228146]))

In [192]:
C = torch.randn((27, 2))


In [193]:
emb = C[X]
emb.shape

torch.Size([228146, 3, 2])

In [194]:
W1 = torch.randn((6, 100))
b1 = torch.randn(100)

In [195]:
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
h.shape


torch.Size([228146, 100])

In [196]:
W2 = torch.randn((100, 27))
b2 = torch.randn(27)

logits = h @ W2 + b2
logits.shape

torch.Size([228146, 27])

In [197]:
counts = logits.exp()
prob = counts / counts.sum(1, keepdim=True)
prob.shape


torch.Size([228146, 27])

In [198]:
g = torch.Generator().manual_seed(40)
C = torch.randn((27, 2), generator=g)
W1 = torch.randn((6, 100), generator=g)
b2 = torch.randn((100), generator=g)
W2 = torch.randn((100, 27), generator=g)
b2 = torch.randn((27), generator=g)
parameters = [ C, W1, b2, W2, b2]

In [199]:
sum(p.nelement() for p in parameters)

3408

In [200]:
for p in parameters:
    p.requires_grad = True

In [201]:
lre = torch.linspace(-3, 0, 1000)
lrs = 10 ** lre


In [205]:
# lri  = []
# lossi = []
for i in range(10000):
    ix = torch.randint(0, X.shape[0], (32,))
        
    # forward pass
    emb = C[X[ix]]
    h =   torch.tanh(emb.view(-1, 6) @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y[ix])
    print(loss.item())

    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # upadate
    lr = 0.1
    for p in parameters:
        p.data += -lr * p.grad
    
    # track stats
    # lri.append(lre[i])
    # lossi.append(loss.item())


print(loss.item())


2.8315601348876953
2.349586248397827
2.663170576095581
2.551936626434326
2.5158851146698
2.306763172149658
2.431781053543091
2.378732442855835
2.252362012863159
2.896754264831543
2.7435686588287354
2.623305082321167
2.441997766494751
2.5565760135650635
2.3181941509246826
2.492682695388794
2.369250535964966
2.5825133323669434
2.493499517440796
2.4614460468292236
2.902031660079956
2.3759539127349854
2.82598614692688
2.8112597465515137
2.6171751022338867
2.6911871433258057
2.481391429901123
2.6028852462768555
2.534791946411133
2.4842894077301025
2.5097484588623047
2.4970908164978027
2.6996445655822754
2.7527713775634766
2.264712333679199
2.4246137142181396
2.7674624919891357
2.3959600925445557
2.4874353408813477
2.4961609840393066
2.1912262439727783
2.2534327507019043
2.299328088760376
2.4648141860961914
2.4762649536132812
2.8454556465148926
2.3609066009521484
2.7846198081970215
2.8053503036499023
2.504202127456665
2.3492774963378906
2.59248423576355
2.5515360832214355
2.37869930267334
2.