In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F

In [6]:
words = open('../data/names.txt').read().splitlines()

In [18]:
vocab = sorted(set(''.join(words)))

stoi = {s:i+1 for i, s in enumerate(vocab)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}

In [133]:
block_size = 5 # context length: how many characters do we take to predict the next one?
X, Y = [], []
for w in words:
  
  #print(w)
  context = [0] * block_size
  for ch in w + '.':
    ix = stoi[ch]
    X.append(context)
    Y.append(ix)
    context = context[1:] + [ix] # crop and append
  
x = torch.tensor(X)
y = torch.tensor(Y)

In [134]:
# Embed x into a 2 dimensional feature vector

C = torch.randn((27, 2))
w1 = torch.randn((2 * block_size, 100))
b1 = torch.randn(100)
w2 = torch.randn((100, 27))
b2 = torch.randn(27)

parameters = [C, w1, b1, w2, b2]

sum(p.nelement() for p in parameters)

3881

In [135]:
epochs = 200000
bs = 64


for p in parameters:
    p.requires_grad = True

for epoch in range(epochs):
    batch = torch.randint(0, x.shape[0], (bs,))
    # forward pass
    x_in = C[x[batch]]
    x_in = x_in.view(x_in.shape[0], -1) @ w1 + b1
    x_in = F.tanh(x_in)
    logits = x_in @ w2 + b2

    # backward pass
    loss = F.cross_entropy(logits, y[batch])
    print(loss.item())

    for p in parameters:
        p.grad = None

    loss.backward()

    lr = 0.1 if epoch < 100000 else 0.01
    for p in parameters:
        p.data += -lr * p.grad


18.547077178955078
15.163475036621094
16.200468063354492
13.8974609375
14.048439025878906
10.066896438598633
12.072203636169434
13.080248832702637
10.739436149597168
9.979900360107422
11.879819869995117
10.519529342651367
10.339436531066895
11.811232566833496
10.251545906066895
8.939838409423828
9.633538246154785
7.984615802764893
8.763277053833008
8.660447120666504
7.101648330688477
7.619985103607178
7.415341377258301
5.940225124359131
6.5482306480407715
7.505264759063721
5.8265533447265625
6.166872978210449
6.178145408630371
6.318881988525391
6.2144036293029785
6.04902458190918
6.126405239105225
6.256136417388916
5.316727638244629
6.69320011138916
6.319162845611572
5.107863903045654
5.112107276916504
5.366641521453857
6.459537982940674
5.086668491363525
5.081257343292236
5.454922199249268
4.832509994506836
4.925400257110596
5.261524200439453
5.473140716552734
4.894525051116943
4.064647674560547
4.552493572235107
5.583275318145752
4.704357147216797
4.576202392578125
5.261595249176025


In [136]:
# Sample from the model 

for _ in range(20):
    res = ""
    ix = [0] * block_size
    while True:
        inp = C[torch.tensor(ix).unsqueeze(0)]
        inp = inp.view(1, -1) @ w1 + b1
        inp = F.tanh(inp)

        logits = inp @ w2 + b2
        counts = torch.exp(logits)
        probs = counts / counts.sum(1, keepdim=True)

        out = torch.multinomial(probs, 1, replacement=True)
        curr_ix = out.item()

        if curr_ix == 0:
            break

        ix = ix[1:] + [curr_ix]
        res += itos[curr_ix]

    print(res)    
    

mieiar
brasesrangy
ratk
diallu
adyr
abefelle
janat
dawn
memecerrango
amnie
daeziorsa
jilre
konces
lintir
ameymani
dhumcan
mioad
annalla
anli
lwior
