In [59]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [60]:
# Read the training names, one per line. The context manager closes the file
# handle as soon as the read finishes instead of leaving it to the garbage collector.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [61]:
# Number of lines (names) read from names.txt.
len(words)

32033

In [62]:
# Character vocabulary and the lookup tables between characters and integer ids.
# Letters get ids 1..N in sorted order; id 0 is reserved for the '.' token.

chars = sorted(set(''.join(words)))
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi['.'] = 0
# Inverse table, built in stoi's insertion order (so '.' -> 0 comes last when printed).
itos = dict((i, s) for s, i in stoi.items())
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [63]:
# Build the training set: one (context, target) pair per character position in every word.

block_size = 3  # how many previous characters form the context
X, Y = [], []
for w in words:
    # Each word starts from a context made entirely of index 0 (the '.' token).
    context = [0 for _ in range(block_size)]
    for ch in w + '.':  # the trailing '.' supplies the end-of-word target
        ix = stoi[ch]
        X += [context]
        Y += [ix]
        # Slide the window: drop the oldest id, append the one just seen.
        context = [*context[1:], ix]

# Convert the Python lists to tensors for indexing and training.
X = torch.tensor(X)
Y = torch.tensor(Y)
        

In [64]:
# Embedding lookup table: 27 rows (one per entry in itos: 26 letters plus '.'), 2 columns each.
# No generator/seed is passed here, so the values differ on every run.
C = torch.randn((27, 2))

In [65]:
# Confirm the embedding table's dimensions.
C.shape

torch.Size([27, 2])

In [66]:
# Indexing with a list of ints returns the matching rows of C (here rows 5, 6 and 7).
C[[5,6,7]]

tensor([[ 1.5781, -1.3128],
        [-0.3133, -0.4032],
        [-2.1053,  0.8726]])

In [67]:
# Index C with the whole integer tensor X: every id in X is replaced by its row of C,
# so the result has X's shape plus a trailing axis of size 2.
emb = C[X]
emb.shape

torch.Size([228146, 3, 2])

In [68]:
# Hidden-layer parameters: 6 inputs (3 context positions x 2 embedding dims) -> 100 units.
w1 = torch.randn((6,100))
b1 = torch.randn(100)

In [69]:
# Flatten each example's 3x2 block of embeddings into a row of 6, apply the affine
# layer, then tanh. b1 (100,) broadcasts across the rows of the matrix product.
h = torch.tanh(emb.view(-1,6) @ w1 + b1)

In [70]:
# One 100-unit activation vector per example.
h.shape

torch.Size([228146, 100])

In [71]:
# Output-layer parameters: 100 hidden units -> 27 scores, one per entry in itos.
w2 = torch.randn((100,27))
b2 = torch.randn(27)

In [72]:
# Unnormalised scores over the 27 classes for every example; b2 broadcasts across rows.
logits = h @ w2 + b2

In [73]:
# One row of 27 logits per example.
logits.shape

torch.Size([228146, 27])

In [74]:
# Hand-written softmax over each row of logits.
# Subtracting the per-row maximum before exp() leaves probs unchanged (the common
# factor cancels in the ratio) but keeps exp() from overflowing to inf, which would
# otherwise turn the division into nan for large logits.
# Note: counts is therefore scaled so each row's largest entry is exp(0) == 1.
counts = (logits - logits.max(1, keepdim=True).values).exp()
probs = counts / counts.sum(1, keepdim=True)

In [75]:
# Sanity check: a row of probs should sum to 1.
probs[0].sum()

tensor(1.0000)

In [76]:
# Same shape as logits: one distribution over 27 classes per example.
probs.shape

torch.Size([228146, 27])

In [123]:
# NOTE(review): no cell above this point assigns `loss`, and the execution count (123) is
# out of order, so this display relies on kernel state from a cell that appears further
# down. On a fresh top-to-bottom run it would raise NameError unless `loss` is defined
# earlier. TODO: move below the first cell that computes `loss`.
loss

tensor(3.0532, grad_fn=<NllLossBackward0>)

In [48]:
#torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], 1).shape

In [49]:
#torch.cat(torch.unbind(emb, 1), 1).shape

In [50]:
#emb.shape

In [47]:
#emb.view(32,6)

In [38]:
# Check the dataset tensors: X holds one row of context ids per example, Y one target id each.
X.shape, Y.shape

(torch.Size([228146, 3]), torch.Size([228146]))

In [156]:
# Reproducible parameter initialisation: every tensor is drawn from one seeded generator.
g = torch.Generator().manual_seed(2147483647)
# Shapes listed in draw order: embedding table, hidden weights/bias, output weights/bias.
# The order matters because each draw advances the shared generator.
parameters = [
    torch.randn(shape, generator=g)
    for shape in ((27, 2), (6, 100), (100,), (100, 27), (27,))
]
c, w1, b1, w2, b2 = parameters

In [157]:
# Total number of scalar values across all parameter tensors.
sum(p.nelement() for p in parameters)

3481

In [158]:
# Forward pass over the full dataset, one step per line.
emb = c[X]                      # look up the embedding row for every context id
h = emb.view(-1, 6) @ w1        # flatten each example to 6 features, then project to 100
h = torch.tanh(h + b1)          # add the bias and squash
logits = h @ w2                 # project the hidden units to 27 class scores
logits = logits + b2
loss = F.cross_entropy(input=logits, target=Y)
loss

tensor(19.5052)

In [159]:
# Turn on gradient tracking for every parameter so that backward() can fill in .grad.
for p in parameters:
    p.requires_grad_(True)

In [160]:
# Candidate learning rates for a sweep: 1000 exponents evenly spaced over [-3, 0],
# i.e. rates from 10^-3 up to 10^0, uniformly spaced on a log scale.
lre = torch.linspace(start=-3, end=0, steps=1000)
lrs = 10 ** lre

In [181]:
# Minibatch SGD on the MLP parameters.
lri = []    # learning-rate exponents per step (only filled when the tracking lines are enabled)
lossi = []  # per-step losses (only filled when the tracking lines are enabled)

lr = 0.01        # fixed step size; set once because it never changes inside the loop
batch_size = 32  # number of examples sampled per step
for i in range(10000):
    #minibatch: sample row indices (with replacement) from the seeded generator `g`
    # so that a fresh top-to-bottom run reproduces the same batches
    ix = torch.randint(0, X.shape[0], (batch_size,), generator=g)

    #forward pass
    emb = c[X[ix]]
    # flatten all context embeddings of an example into one row; the width is taken
    # from emb itself rather than hard-coding context_length * embedding_dim
    h = torch.tanh(emb.view(emb.shape[0], -1) @ w1 + b1)
    logits = h @ w2 + b2
    loss = F.cross_entropy(logits, Y[ix])

    #backward pass: clear stale gradients, then backpropagate this batch's loss
    for p in parameters:
        p.grad = None
    loss.backward()

    #update: in-place step under no_grad, so autograd does not record the update
    # (this replaces writing through `.data`, which bypasses autograd's safety checks)
    with torch.no_grad():
        for p in parameters:
            p -= lr * p.grad

    #track stats (disabled: lre has 1000 entries but this loop runs 10000 steps,
    # so lre[i] would go out of range; enable only together with range(1000))
    #lri.append(lre[i])
    #lossi.append(loss.item())

# loss of the final minibatch only (noisy; not the full-dataset loss)
print(loss.item()) 

2.4164788722991943


In [182]:
# Loss over the full dataset rather than a single minibatch.
# This is evaluation only, so run it under no_grad: otherwise autograd records a graph
# (and keeps the intermediate activations alive) for the whole ~228k-row forward pass.
with torch.no_grad():
    emb = c[X]
    h = torch.tanh(emb.view(emb.shape[0], -1) @ w1 + b1)
    logits = h @ w2 + b2
    loss = F.cross_entropy(logits, Y)
loss

tensor(2.3936, grad_fn=<NllLossBackward0>)

In [120]:
# Maximum along dim 0 (across examples): the largest logit for each of the 27 classes,
# together with the row index where it occurs.
logits.max(0)

torch.return_types.max(
values=tensor([ 8.5005,  7.7635,  8.3645,  8.5227,  9.1544,  7.5232,  5.7845,  6.9921,
         6.3999, 10.7256,  6.8321,  5.5453,  6.7909,  6.2114,  7.9024,  5.2099,
         3.4723,  3.7873,  7.6372,  6.5402, 10.0558,  5.6155,  6.6399,  6.9849,
         9.6543,  6.0460,  8.4460], grad_fn=<MaxBackward0>),
indices=tensor([226577,  99983, 177314, 213504, 118844,  83287,  44601,  68426, 142168,
          4110,   4111,  25931,  99983,  99983,  41171, 223962, 119214, 215428,
        183814,  18968,   6324, 174542, 119214, 174065,  62301,  35923,  46622]))

In [121]:
# Display the target ids (torch abbreviates long tensors in the repr).
Y

tensor([ 5, 13, 13,  ..., 26, 24,  0])