In [4]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [5]:
# Load the corpus: one training word per line of names.txt.
# The context manager closes the file handle as soon as the text is read,
# instead of leaving it open until garbage collection.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()

In [6]:
# Vocabulary: every distinct character that occurs in the corpus, sorted.
chars = sorted(set(''.join(words)))

# Character -> integer id. Corpus characters are numbered from 1 so that
# id 0 stays free for '.', the padding / end-of-word marker.
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi.update({'.': 0})

# Integer id -> character (inverse of stoi).
itos = {idx: ch for ch, idx in stoi.items()}

In [7]:
# Build the dataset: one (context, next-character) pair per character of
# every word, including the terminating '.'.

block_size = 3  # Context length: number of preceding characters used to predict the next one
X, Y = [], []

for name in words:
    # Every word starts from an all-padding context (id 0 == '.').
    window = [0] * block_size
    for symbol in name + '.':
        target = stoi[symbol]
        X.append(window)
        Y.append(target)
        # Slide the window: drop the oldest id, append the one just seen.
        window = [*window[1:], target]

# Convert the accumulated Python lists to tensors.
X = torch.tensor(X)
Y = torch.tensor(Y)
num = Y.shape[0]  # number of training examples


In [8]:
# Seeded generator so parameter initialisation is reproducible.
g = torch.Generator()
g.manual_seed(2147483647)

# Shapes in draw order (the order matters: every draw advances `g`):
#   C  (27, 2)   - embedding table, one row per character id
#   W1 (6, 200)  - hidden-layer weights; the forward pass flattens each
#                  context's embeddings to 6 values before multiplying
#   b1 (200,)    - hidden-layer bias
#   W2 (200, 27) - output-layer weights, one logit per character id
#   b2 (27,)     - output-layer bias
_param_shapes = [(27, 2), (6, 200), (200,), (200, 27), (27,)]

# All parameters are drawn from a standard normal distribution.
C, W1, b1, W2, b2 = parameters = [
    torch.randn(shape, generator=g) for shape in _param_shapes
]

In [9]:
# Total number of scalar values across all parameter tensors.
sum(map(torch.Tensor.nelement, parameters))

6881

In [10]:
# Turn on gradient tracking for every parameter tensor so that
# loss.backward() populates their .grad fields during training.
for tensor in parameters:
    tensor.requires_grad = True

In [11]:
# Train with minibatch SGD and a single step-decay of the learning rate.
num_steps = 200000   # total number of SGD updates
batch_size = 32      # examples per minibatch

for step in range(num_steps):

    # Minibatch Construct
    # Draw the row indices from the seeded generator `g` (not the unseeded
    # global RNG) so that a Restart & Run All reproduces the same run.
    ix = torch.randint(0, X.shape[0], (batch_size,), generator=g)

    # Forward Pass
    emb = C[X[ix]]  # embedding rows of C for every context position
    # Flatten each example's context embeddings into one row; the width is
    # derived from emb rather than hard-coded.
    h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y[ix])

    # Backward Pass
    for p in parameters:
        p.grad = None  # clear gradients from the previous step
    loss.backward()

    # Step decay: 0.1 for the first 100000 steps, 0.01 afterwards.
    lr = 0.1 if step < 100000 else 0.01

    # Update
    # In-place SGD step under no_grad, so autograd neither records the
    # update nor is bypassed through `.data`.
    with torch.no_grad():
        for p in parameters:
            p -= lr * p.grad


# Loss of the final minibatch only (a noisy estimate, not the dataset loss).
print(loss.item())

2.249478340148926


In [12]:
# Cross-entropy loss over the entire dataset X, Y with the current parameters.
# This is evaluation only, so run it under no_grad: autograd then does not
# build (and keep alive) a graph over every example in the dataset.
with torch.no_grad():
    emb = C[X]
    # Flatten each example's context embeddings into one row.
    h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y)
print(loss.item())

2.2234556674957275


In [13]:
# Testing Output
# Sample 10 words from the model, one character at a time.
for k in range(10):
    # Start from an all-padding context (id 0 == '.').
    prev = [0] * block_size

    out = []
    # Inference only: no autograd graph is needed while sampling.
    with torch.no_grad():
        while True:
            embs = C[prev]
            # Single context -> one row holding all of its embedding values.
            h = torch.tanh(embs.view(1, -1) @ W1 + b1)
            logits = (h @ W2 + b2).flatten()
            # Numerically stable softmax; exponentiating raw logits by hand
            # can overflow to inf and yield NaN probabilities.
            probs = F.softmax(logits, dim=0)

            idx = torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item()

            # Id 0 is the end-of-word marker: stop without emitting it.
            if idx == 0:
                break
            out.append(itos[idx])
            # Slide the context window forward by one character.
            prev = prev[1:] + [idx]
    print(''.join(out))

anel
tah
nyse
brixo
leizia
kalra
maelli
sin
ell
noya
