# TODO
1. Add start and end tokens to the text. (Later, use only one special token, because two separate tokens are redundant in the bigram probability matrix.)
2. Create a simple tokenizer that transforms the char.
3. Transform the data into a training data set and use one-hot encoding for the characters. (Note that multiplying by a one-hot vector is exactly the same as selecting the corresponding row of the bigram probability matrix.)
4. Create a weight matrix of 27 x 27 as a simple one layer NN.
5. Use softmax to transform the logits into probabilities and use the negative log-likelihood as the loss function to train the model.
6. Use model smoothing to avoid an infinite log-likelihood (regularizing W toward zero is equivalent to adding a positive value to the bigram counts).
7. Update the weight matrix with gradient descent.

In [12]:
import torch
import torch.nn.functional as F

# Read

In [None]:
# Load the corpus and split it into one entry per line (line terminators are
# dropped by splitlines). The context manager closes the file handle as soon
# as the read finishes; the explicit encoding avoids locale-dependent decoding.
with open('../names.txt', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()
print(words[:10])  # sanity check: show the first ten entries

In [None]:
# Flatten the corpus into a single string in which '.' separates consecutive
# words and also marks the very beginning and the very end of the text.
words_line = f".{'.'.join(words)}."

In [None]:
# Distinct symbols occurring in the flattened corpus (including the '.' marker).
words_set = {*words_line}
# Report the vocabulary size.
print(len(words_set))

# Tokenize

In [36]:
# Vocabulary in sorted order; a symbol's position in this list is its token id.
chars = sorted(words_set)
# id -> symbol lookup, used to decode sampled ids back into text.
itos = dict(enumerate(chars))
# symbol -> id lookup, the exact inverse of `itos`.
stoi = {char: i for i, char in itos.items()}

# Dataset

In [None]:
# Every pair of adjacent symbols in the corpus string is one training example:
# the first symbol is the model input, the second is the prediction target.
bigrams = list(zip(words_line, words_line[1:]))

# Token ids of the inputs (xs) and of the targets (ys), aligned by position.
xs = torch.tensor([stoi[prev_ch] for prev_ch, _ in bigrams])
ys = torch.tensor([stoi[next_ch] for _, next_ch in bigrams])

# Train

In [None]:
# Single-layer bigram model: row i of W holds the logits of the next-token
# distribution given that the current token is i.
g = torch.Generator().manual_seed(2147483647)
num_classes = len(stoi)  # vocabulary size, derived from the tokenizer instead of hard-coded
W = torch.randn((num_classes, num_classes), generator=g, requires_grad=True)
lr = 50
epochs = 150

# The inputs never change between epochs, so one-hot encode them once up front
# rather than rebuilding the same matrix on every iteration.
xenc = F.one_hot(xs, num_classes=num_classes).float()

for epoch in range(epochs):
    # forward
    logits = xenc @ W  # multiplying by a one-hot row selects the matching row of W
    # cross_entropy fuses softmax with the mean negative log-likelihood and
    # evaluates it via log-sum-exp, so large logits cannot overflow the way a
    # hand-written exp(logits) / sum(exp(logits)) can.
    loss = F.cross_entropy(logits, ys) + 0.01 * (W ** 2).mean()  # NOTE: regularization

    # backward
    W.grad = None
    loss.backward()
    # Update the leaf tensor in place without recording the step in autograd
    # (preferred over writing through W.data, which bypasses autograd's checks).
    with torch.no_grad():
        W -= lr * W.grad

    print(f"loss{epoch}: {loss.item():.4f}")

loss0: 3.7686
loss1: 3.3788
loss2: 3.1611
loss3: 3.0272
loss4: 2.9345
loss5: 2.8672
loss6: 2.8167
loss7: 2.7771
loss8: 2.7453
loss9: 2.7188
loss10: 2.6965
loss11: 2.6774
loss12: 2.6608
loss13: 2.6464
loss14: 2.6337
loss15: 2.6225
loss16: 2.6125
loss17: 2.6037
loss18: 2.5958
loss19: 2.5887
loss20: 2.5823
loss21: 2.5764
loss22: 2.5711
loss23: 2.5663
loss24: 2.5618
loss25: 2.5577
loss26: 2.5539
loss27: 2.5504
loss28: 2.5472
loss29: 2.5442
loss30: 2.5414
loss31: 2.5387
loss32: 2.5363
loss33: 2.5340
loss34: 2.5318
loss35: 2.5298
loss36: 2.5279
loss37: 2.5261
loss38: 2.5244
loss39: 2.5228
loss40: 2.5213
loss41: 2.5198
loss42: 2.5185
loss43: 2.5172
loss44: 2.5160
loss45: 2.5148
loss46: 2.5137
loss47: 2.5127
loss48: 2.5117
loss49: 2.5108
loss50: 2.5099
loss51: 2.5090
loss52: 2.5082
loss53: 2.5074
loss54: 2.5066
loss55: 2.5059
loss56: 2.5052
loss57: 2.5045
loss58: 2.5039
loss59: 2.5033
loss60: 2.5027
loss61: 2.5021
loss62: 2.5016
loss63: 2.5011
loss64: 2.5006
loss65: 2.5001
loss66: 2.4996
loss6

# Test

In [49]:
# Draw a few sequences from the trained bigram model.
num_samples = 5
g = torch.Generator().manual_seed(2147483647)
end_ix = stoi['.']  # '.' both opens every sequence and terminates it
for _ in range(num_samples):
    ix = end_ix
    out = [itos[ix]]  # NOTE: the printed sample starts with '.'
    while True:
        # Softmax over the logits row of the current token.
        weights = W[ix].exp()
        probs = weights / torch.sum(weights)
        # Sample the id of the next token from that distribution.
        ix = torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item()
        if ix == end_ix:
            break  # terminator drawn: stop without emitting it
        out.append(itos[ix])
    print(''.join(out))

.junide
.janasah
.p
.cfay
.a
