In [1]:
# model performance from training on bigrams isn't great.
#   - the only context we have is the last character.
#   - evidently, this is not how names or words "work" - we need more context.
# if we extend the technique from our "manually tuned" bigram model, things get out of hand quickly:
#   - we can predict the next character from the last two characters, instead of just the last character.
#   - however, now we need to tune 27 * 27 = 729 parameters instead of 27.
#   - using the last three characters, this becomes 27 * 27 * 27 = 19683 parameters.
#   - O(27^n) parameters - this quickly becomes infeasible.
# what we'll be doing now is solving the problem with an MLP, following Bengio et al. 2003

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
torch.set_default_device("mps")

In [2]:
# read in all the words
words = open('res/names.txt', 'r').read().split()
print(words[:8])
print(len(words))

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']
32033


In [3]:
# build the vocabulary of characters and mappings to/from integer ids
chars = sorted(list(set(''.join(words))))  # tokens a-z

stoi = {s:i+1 for i,s in enumerate(chars)}  # map each token to unique id
stoi['.'] = 0  # add encoding for terminating token

itos = {i:s for s,i in stoi.items()}  # create the reverse mapping
print(stoi)
print(itos)

{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, '.': 0}
{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [4]:
# define hyperparameters we'll use later
vocab_size = 27   # vocabulary size: how big is our vocabulary of tokens?
ngram_len = 3     # context length: how many characters do we take to predict the next one?
embed_dim = 2     # embedding dimension: what is the length of the embedding vector for each token?
hidden_dim = 100  # how many neurons are in the hidden layer of the MLP?

In [5]:
# build the dataset
X, Y = [], []
for i, w in enumerate(words):
  context = [0] * ngram_len
  for ch in w + '.':
    ix = stoi[ch]
    X.append(context)
    Y.append(ix)
    if i < 3:
      print(''.join(itos[i] for i in context), '--->', itos[ix])
    context = context[1:] + [ix] # crop and append

X = torch.tensor(X)
Y = torch.tensor(Y)

... ---> e
..e ---> m
.em ---> m
emm ---> a
mma ---> .
... ---> o
..o ---> l
.ol ---> i
oli ---> v
liv ---> i
ivi ---> a
via ---> .
... ---> a
..a ---> v
.av ---> a
ava ---> .


In [6]:
# embedding lookup table - squeeze 1 hot encoded token into 2-dimensional space
C = torch.randn((vocab_size, embed_dim), requires_grad=True)

In [7]:
# lookup the embedding for every input in the training set
#   - the shape of our input becomes [batch_size, ngram_len, embed_dim]
emb = C[X]
print(emb)
print(emb.shape)

tensor([[[ 0.6106,  1.8206],
         [ 0.6106,  1.8206],
         [ 0.6106,  1.8206]],

        [[ 0.6106,  1.8206],
         [ 0.6106,  1.8206],
         [ 0.8544,  0.5877]],

        [[ 0.6106,  1.8206],
         [ 0.8544,  0.5877],
         [ 0.3054,  2.0366]],

        ...,

        [[ 1.9114,  1.2000],
         [ 1.9114,  1.2000],
         [ 1.8408,  0.7235]],

        [[ 1.9114,  1.2000],
         [ 1.8408,  0.7235],
         [ 1.9114,  1.2000]],

        [[ 1.8408,  0.7235],
         [ 1.9114,  1.2000],
         [ 1.6874, -1.1933]]], device='mps:0', grad_fn=<IndexBackward0>)
torch.Size([228146, 3, 2])


In [8]:
# hidden layer of MLP
W1 = torch.randn((ngram_len * embed_dim, hidden_dim), requires_grad=True)
b1 = torch.randn(hidden_dim, requires_grad=True)

emb1 = emb.view(emb.shape[0], -1)  # reshape embeddings
a1 = torch.tanh(emb1 @ W1 + b1)  # calculate activation of hidden layer
print(a1.shape)

torch.Size([228146, 100])


In [9]:
# output layer of MLP
W2 = torch.randn((hidden_dim, vocab_size), requires_grad=True)
b2 = torch.randn(vocab_size, requires_grad=True)

logits = a1 @ W2 + b2  # calculate output activation
print(logits.shape)

torch.Size([228146, 27])


In [10]:
# apply softmax to output
counts = logits.exp()
probs = counts / counts.sum(dim=1, keepdim=True)
print(probs.shape)
print("Sum of first row:", probs[0].sum().item())

torch.Size([228146, 27])
Sum of first row: 0.9999998807907104


In [11]:
# calculate NLL loss
Y_pred = probs[torch.arange(0, len(Y)), Y]  # from each sample, select probability that we output the correct next token
loss = -Y_pred.log().mean()  # log, negate and mean the selected probabilities
print(loss)
# alternatively, we can just use pytorch's cross entropy loss
print(F.cross_entropy(logits, Y))

tensor(16.0064, device='mps:0', grad_fn=<NegBackward0>)
tensor(16.0064, device='mps:0', grad_fn=<NllLossBackward0>)


In [12]:
# backward pass
parameters = [C, W1, b1, W2, b2]
for p in parameters:
    p.grad = None  # zero grad
loss.backward()

In [13]:
# update parameters
for p in parameters:
    p.data -= 0.1 * p.grad
    print(p.grad)

tensor([[ 0.9688, -0.3917],
        [ 0.3161,  0.0612],
        [ 0.0135,  0.0149],
        [-0.1218, -0.1740],
        [ 0.1167,  0.0534],
        [ 0.1910,  0.1317],
        [ 0.0302,  0.0169],
        [-0.0099,  0.0047],
        [-0.0940, -0.2259],
        [-0.0649, -0.0011],
        [-0.1143, -0.0973],
        [-0.0188, -0.0502],
        [ 0.0030,  0.0919],
        [ 0.1930,  0.0542],
        [-0.0782, -0.4885],
        [ 0.0415,  0.0446],
        [-0.0245, -0.0345],
        [-0.0052,  0.0054],
        [-0.3049, -0.3419],
        [ 0.0788,  0.0575],
        [ 0.0718, -0.0259],
        [-0.0677, -0.0289],
        [-0.0300, -0.1069],
        [ 0.0214,  0.0093],
        [ 0.0051,  0.0044],
        [ 0.1481, -0.0598],
        [ 0.0296, -0.0349]], device='mps:0')
tensor([[-1.3975e-02, -9.2984e-02, -5.6051e-02,  2.7991e-02,  4.9425e-02,
          2.7716e-02,  3.0799e-02, -6.8132e-02,  1.8786e-01,  5.2439e-02,
          8.9711e-03,  1.7737e-02, -6.3090e-02,  2.7649e-02,  4.3962e-02,
     

In [21]:
# putting it all together... full batch gradient descent

# create the layers of the model
C = torch.randn((vocab_size, embed_dim), requires_grad=True)
W1 = torch.randn((ngram_len * embed_dim, hidden_dim), requires_grad=True)
b1 = torch.randn(hidden_dim, requires_grad=True)
W2 = torch.randn((hidden_dim, vocab_size), requires_grad=True)
b2 = torch.randn(vocab_size, requires_grad=True)
parameters = [C, W1, b1, W2, b2]

# define a forward pass for the model
def forward(xs):
    # embedding
    emb = C[xs]
    # hidden layer
    flat_emb = emb.view(emb.shape[0], -1)
    h = torch.tanh(flat_emb @ W1 + b1)
    # output layer
    return h @ W2 + b2

# train the model
for epoch in range(100):
    # forward pass
    logits = forward(X)
    loss = F.cross_entropy(logits, Y)
    print(f'Epoch {epoch}: Loss= {loss.item()}')
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    # update
    for p in parameters:
        p.data -= 0.1 * p.grad

# Note: this seems to get numerically unstable sometimes...

Epoch 0: Loss= 17.79652214050293
Epoch 1: Loss= 16.011938095092773
Epoch 2: Loss= 14.524364471435547
Epoch 3: Loss= 13.348714828491211
Epoch 4: Loss= 12.470897674560547
Epoch 5: Loss= 11.78131103515625
Epoch 6: Loss= 11.183290481567383
Epoch 7: Loss= 10.648771286010742
Epoch 8: Loss= 10.170660018920898
Epoch 9: Loss= 9.744149208068848
Epoch 10: Loss= 9.362197875976562
Epoch 11: Loss= 9.01722240447998
Epoch 12: Loss= 8.702160835266113
Epoch 13: Loss= 8.412109375
Epoch 14: Loss= 8.143003463745117
Epoch 15: Loss= 7.8919830322265625
Epoch 16: Loss= 7.656604766845703
Epoch 17: Loss= 7.435487270355225
Epoch 18: Loss= 7.2281575202941895
Epoch 19: Loss= 7.034721851348877
Epoch 20: Loss= 6.855053424835205
Epoch 21: Loss= 6.687909126281738
Epoch 22: Loss= 6.531892776489258
Epoch 23: Loss= 6.385890960693359
Epoch 24: Loss= 6.249068737030029
Epoch 25: Loss= 6.120635509490967
Epoch 26: Loss= 5.999819278717041
Epoch 27: Loss= 5.8858962059021
Epoch 28: Loss= 5.778189182281494
Epoch 29: Loss= 5.676074