In [8]:
import torch
import torch.nn.functional as F

from sklearn.model_selection import train_test_split

## Data prep

In [9]:
words = open("./names.txt", "r").read().splitlines()
train_words, temp_words = train_test_split(words, train_size=0.8, random_state=42)
dev_words, test_words = train_test_split(temp_words, test_size=0.5, random_state=42)

In [10]:
len(train_words), len(dev_words), len(test_words)

(25626, 3203, 3204)

## Trigram

In [11]:
train_chars = sorted(list(set(''.join(train_words))))
two_chars = set()
for c1 in train_chars+["."]:
  for c2 in train_chars+["."]:
    two_chars.add(c1+c2)

two_chars = sorted(list(two_chars))

stoi = {s:i+1 for i,s in enumerate(train_chars)}
stoi["."] = 0
stoi2 = {s:i for i,s in enumerate(two_chars)}
itos2 = {i:s for i,s in enumerate(two_chars)}

In [12]:
xs_t, ys_t = [], []
for w in train_words:
  chs = ["."] + list(w) + ["."]
  for ch1,ch2,ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi2[ch1+ch2]
    ix2 = stoi[ch3]
    xs_t.append(ix1)
    ys_t.append(ix2)

xs_t = torch.tensor(xs_t)
ys_t = torch.tensor(ys_t)

W = torch.empty(0)

In [34]:
def train(epochs=150):
    global W
    g = torch.Generator().manual_seed(2147483647)
    W = torch.randn((729, 27), generator=g, requires_grad=True)

    for i in range(150):
        # forward pass
        logits = W[xs_t]
        counts = logits.exp()
        probs = counts / counts.sum(1, keepdim=True)
        # loss = -probs[torch.arange(xs_t.nelement()), ys_t].log().mean() + reg_factor*(W**2).mean()
        loss = F.cross_entropy(probs, ys_t) + 0.01*(W**2).mean()
    
        print(f"Epoch: {i}; Loss: {loss.item()}")
    
      # backward pass
        W.grad = None
        loss.backward()
        with torch.no_grad():
            W.data += -75 * W.grad

In [35]:
train()

Epoch: 0; Loss: 3.30454158782959
Epoch: 1; Loss: 3.3042337894439697
Epoch: 2; Loss: 3.303915023803711
Epoch: 3; Loss: 3.3035848140716553
Epoch: 4; Loss: 3.3032443523406982
Epoch: 5; Loss: 3.3028922080993652
Epoch: 6; Loss: 3.302529811859131
Epoch: 7; Loss: 3.302157402038574
Epoch: 8; Loss: 3.3017749786376953
Epoch: 9; Loss: 3.3013832569122314
Epoch: 10; Loss: 3.30098295211792
Epoch: 11; Loss: 3.3005733489990234
Epoch: 12; Loss: 3.3001549243927
Epoch: 13; Loss: 3.29972767829895
Epoch: 14; Loss: 3.2992913722991943
Epoch: 15; Loss: 3.298845052719116
Epoch: 16; Loss: 3.2983877658843994
Epoch: 17; Loss: 3.2979187965393066
Epoch: 18; Loss: 3.2974355220794678
Epoch: 19; Loss: 3.2969377040863037
Epoch: 20; Loss: 3.296422004699707
Epoch: 21; Loss: 3.295886993408203
Epoch: 22; Loss: 3.2953293323516846
Epoch: 23; Loss: 3.2947463989257812
Epoch: 24; Loss: 3.294135570526123
Epoch: 25; Loss: 3.2934932708740234
Epoch: 26; Loss: 3.292816638946533
Epoch: 27; Loss: 3.2921030521392822
Epoch: 28; Loss: 3.

In [30]:
def get_loss(word_set):
    xs_t, ys_t = [], []
    for w in word_set:
        chs = ["."] + list(w) + ["."]
        for ch1,ch2,ch3 in zip(chs, chs[1:], chs[2:]):
            ix1 = stoi2[ch1+ch2]
            ix2 = stoi[ch3]
            xs_t.append(ix1)
            ys_t.append(ix2)
    
    xs_t = torch.tensor(xs_t)
    ys_t = torch.tensor(ys_t)

    with torch.no_grad():
        logits = W[xs_t]
        counts = logits.exp()
        probs = counts / counts.sum(1, keepdim=True)
    
        nll = -probs[torch.arange(xs_t.nelement()), ys_t].log().mean()

    return nll

In [31]:
get_loss(dev_words)

tensor(3.5356)