In [2]:
import torch 
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib  inline

In [3]:
words = open('names.txt', 'r').read().splitlines()
len(words)

32033

In [4]:
chars = sorted(list(set("".join(words))))
chars


['a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z']

In [5]:
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}
itos

{1: 'a',
 2: 'b',
 3: 'c',
 4: 'd',
 5: 'e',
 6: 'f',
 7: 'g',
 8: 'h',
 9: 'i',
 10: 'j',
 11: 'k',
 12: 'l',
 13: 'm',
 14: 'n',
 15: 'o',
 16: 'p',
 17: 'q',
 18: 'r',
 19: 's',
 20: 't',
 21: 'u',
 22: 'v',
 23: 'w',
 24: 'x',
 25: 'y',
 26: 'z',
 0: '.'}

In [6]:
block_size = 3
X, Y = [], []

for w in words:
    context = [0] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        context = context[1:] + [ix]

X = torch.tensor(X)
Y = torch.tensor(Y)

In [7]:
X.dtype, Y.dtype, X.shape, Y.shape

(torch.int64, torch.int64, torch.Size([228146, 3]), torch.Size([228146]))

In [8]:
C = torch.randn((27, 2))


In [9]:
emb = C[X]
emb.shape

torch.Size([228146, 3, 2])

In [10]:
W1 = torch.randn((6, 100))
b1 = torch.randn(100)

In [11]:
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
h.shape


torch.Size([228146, 100])

In [12]:
W2 = torch.randn((100, 27))
b2 = torch.randn(27)

logits = h @ W2 + b2
logits.shape

torch.Size([228146, 27])

In [13]:
counts = logits.exp()
prob = counts / counts.sum(1, keepdim=True)
prob.shape


torch.Size([228146, 27])

In [14]:
g = torch.Generator().manual_seed(40)
C = torch.randn((27, 2), generator=g)
W1 = torch.randn((6, 100), generator=g)
b2 = torch.randn((100), generator=g)
W2 = torch.randn((100, 27), generator=g)
b2 = torch.randn((27), generator=g)
parameters = [ C, W1, b2, W2, b2]

In [15]:
sum(p.nelement() for p in parameters)

3408

In [16]:
for p in parameters:
    p.requires_grad = True

In [17]:
lre = torch.linspace(-3, 0, 1000)
lrs = 10 ** lre


In [18]:
# lri  = []
# lossi = []
for i in range(10000):
    ix = torch.randint(0, X.shape[0], (32,))
        
    # forward pass
    emb = C[X[ix]]
    h =   torch.tanh(emb.view(-1, 6) @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y[ix])

    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # upadate
    lr = 0.1
    for p in parameters:
        p.data += -lr * p.grad
    
    # track stats
    # lri.append(lre[i])
    # lossi.append(loss.item())


print(loss.item())


2.407068967819214
