In [57]:
words = open('names.txt', 'r').read().splitlines()

In [58]:
words[0]

'emma'

In [59]:
# a trigram is split as follows, x is the input to NN, and y is the expected o/p
x = [('.','.'), ('.', 'e')]
y = ['e', 'm']

In [68]:
# helper dictionary
stoi = {s:i+1 for i,s in enumerate(sorted(list(set(''.join(words)))))}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

In [61]:
import torch
import torch.nn.functional as F

In [62]:
# since we can't pass the characters as is, we will one hot encode inputs and expected outputs
# say we have N examples
# x will have shape (2,27) , each row representing each character in the input bigram
# xs will have shape (N, 2, 27), each row representing each example
# output will be of shape (N, 27), each row is a probability distribution of last character in the trigram

In [88]:
xs = []
ys = []
for w in words[:]:
    chs = ['.','.'] + list(w) + ['.', '.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        x = torch.zeros(27).float()
        x[stoi[ch1]] += 1
        x[stoi[ch2]] += 1
        xs.append(x)
        ys.append(stoi[ch3])
xenc = torch.stack(xs)
xenc.shape

torch.Size([260179, 27])

In [89]:
# we need a weight matrix that takes 
# [N, 2, 27] to [N, 27]
W = torch.randn((27,27), requires_grad=True)
(xenc @ W).shape

torch.Size([260179, 27])

In [90]:
num_examples = len(xs)

In [91]:
for k in range(100):
    # forward pass
    xenc = torch.stack(xs)
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    loss = -probs[torch.arange(num_examples), ys].log().mean()
    loss += 0.1*(W**2).mean() # regularization loss
    print(loss)
    
    # backward pass
    W.grad = None
    loss.backward()
    W.data += -5 * W.grad

tensor(4.5456, grad_fn=<AddBackward0>)
tensor(4.2396, grad_fn=<AddBackward0>)
tensor(3.9882, grad_fn=<AddBackward0>)
tensor(3.7883, grad_fn=<AddBackward0>)
tensor(3.6458, grad_fn=<AddBackward0>)
tensor(3.5442, grad_fn=<AddBackward0>)
tensor(3.4626, grad_fn=<AddBackward0>)
tensor(3.3943, grad_fn=<AddBackward0>)
tensor(3.3364, grad_fn=<AddBackward0>)
tensor(3.2867, grad_fn=<AddBackward0>)
tensor(3.2433, grad_fn=<AddBackward0>)
tensor(3.2047, grad_fn=<AddBackward0>)
tensor(3.1699, grad_fn=<AddBackward0>)
tensor(3.1383, grad_fn=<AddBackward0>)
tensor(3.1093, grad_fn=<AddBackward0>)
tensor(3.0825, grad_fn=<AddBackward0>)
tensor(3.0577, grad_fn=<AddBackward0>)
tensor(3.0347, grad_fn=<AddBackward0>)
tensor(3.0133, grad_fn=<AddBackward0>)
tensor(2.9934, grad_fn=<AddBackward0>)
tensor(2.9748, grad_fn=<AddBackward0>)
tensor(2.9574, grad_fn=<AddBackward0>)
tensor(2.9411, grad_fn=<AddBackward0>)
tensor(2.9258, grad_fn=<AddBackward0>)
tensor(2.9114, grad_fn=<AddBackward0>)
tensor(2.8978, grad_fn=<A

In [96]:
for i in range(5):
    out = []
    ix1 = 0
    ix2 = 0
    while True:
        x1 = torch.zeros(1,27)
        x1[0, ix1] += 1
        x1[0, ix2] += 1
        logits =  x1 @ W
        counts = logits.exp()
        p = counts / counts.sum(1, keepdim=True)
        ix3 = torch.multinomial(p, num_samples=1, replacement=True).item()
        out.append(itos[ix3])
        if ix3 == 0:
            break
        ix1, ix2 = ix2, ix3
    print(''.join(out))


zen.
ahi.
slabyrealn.
tpnu.
moan.
