In [1]:
import torch
import torch.nn.functional as F

from sklearn.model_selection import train_test_split

## Data preparation: 80/10/10 train/dev/test split

In [2]:
# Load one name per line, then split 80% train / 10% dev / 10% test
# (the 20% remainder is halved into dev and test). The `with` block closes the
# file deterministically; an explicit encoding avoids locale-dependent decoding.
with open("./names.txt", "r", encoding="utf-8") as names_file:
    words = names_file.read().splitlines()
train_words, temp_words = train_test_split(words, train_size=0.8, random_state=42)
dev_words, test_words = train_test_split(temp_words, test_size=0.5, random_state=42)

In [3]:
tuple(len(split) for split in (train_words, dev_words, test_words))

(25626, 3203, 3204)

## Trigram model: predict the next character from the previous two

In [4]:
# Character vocabulary seen in training, plus every ordered pair of symbols
# (training characters and the "." boundary marker) used as trigram contexts.
train_chars = sorted(set("".join(train_words)))
alphabet = train_chars + ["."]
two_chars = sorted({first + second for first in alphabet for second in alphabet})

# Single characters: training characters -> 1..len(train_chars), "." -> 0.
stoi = {ch: idx for idx, ch in enumerate(train_chars, start=1)}
stoi["."] = 0
# Two-character contexts <-> integer index (row of W used by the model).
stoi2 = {pair: idx for idx, pair in enumerate(two_chars)}
itos2 = dict(enumerate(two_chars))

In [5]:
# Turn every training word into (two-character context, next character) pairs:
# wrap the word in a single "." on both ends and slide a length-3 window over it.
# NOTE(review): with only one leading ".", the first character of a word is
# never a prediction target (targets are the word's characters from the
# second onward, plus the closing ".").
contexts, targets = [], []
for word in train_words:
    padded = "." + word + "."
    for start in range(len(padded) - 2):
        contexts.append(stoi2[padded[start:start + 2]])
        targets.append(stoi[padded[start + 2]])

xs_t = torch.tensor(contexts)
ys_t = torch.tensor(targets)

# Placeholder so the name exists before train() rebinds the global W.
W = torch.empty(0)

In [6]:
def train(epochs=150, lr=75.0, reg_factor=0.01, seed=2147483647):
    """Fit the trigram weight matrix ``W`` with full-batch gradient descent.

    Re-initialises the global ``W`` from a seeded generator, then runs
    ``epochs`` gradient-descent steps on the global training tensors
    ``xs_t`` (context indices) / ``ys_t`` (target indices), printing the loss
    before each update.

    Args:
        epochs: number of gradient-descent steps. (Previously ignored: the loop
            was hard-coded to 150 iterations.)
        lr: learning rate applied to the gradient.
        reg_factor: weight of the L2 penalty ``(W**2).mean()`` added to the loss.
        seed: seed for the generator used to initialise ``W``.
    """
    global W
    g = torch.Generator().manual_seed(seed)
    # One row of logits per two-character context, one column per next character
    # (729 x 27 for the names vocabulary) instead of hard-coded sizes.
    W = torch.randn((len(stoi2), len(stoi)), generator=g, requires_grad=True)

    for i in range(epochs):
        # forward pass: indexing rows of W is equivalent to one_hot(xs_t) @ W
        logits = W[xs_t]
        # cross_entropy == mean(-log softmax(logits)[target]), computed stably;
        # plus L2 regularisation on the weights.
        loss = F.cross_entropy(logits, ys_t) + reg_factor * (W**2).mean()

        print(f"Epoch: {i}; Loss: {loss.item()}")

        # backward pass
        W.grad = None
        loss.backward()
        # Plain SGD step; no_grad keeps the in-place update out of the graph.
        with torch.no_grad():
            W -= lr * W.grad

In [7]:
# Fit W with the default schedule (150 full-batch gradient-descent steps).
train(epochs=150)

Epoch: 0; Loss: 3.7336502075195312
Epoch: 1; Loss: 3.6212451457977295
Epoch: 2; Loss: 3.519554376602173
Epoch: 3; Loss: 3.4285521507263184
Epoch: 4; Loss: 3.3480429649353027
Epoch: 5; Loss: 3.277346134185791
Epoch: 6; Loss: 3.215322494506836
Epoch: 7; Loss: 3.160595417022705
Epoch: 8; Loss: 3.1118381023406982
Epoch: 9; Loss: 3.067963123321533
Epoch: 10; Loss: 3.0281543731689453
Epoch: 11; Loss: 2.991804838180542
Epoch: 12; Loss: 2.9584555625915527
Epoch: 13; Loss: 2.9277443885803223
Epoch: 14; Loss: 2.8993759155273438
Epoch: 15; Loss: 2.873103380203247
Epoch: 16; Loss: 2.8487143516540527
Epoch: 17; Loss: 2.8260247707366943
Epoch: 18; Loss: 2.804870367050171
Epoch: 19; Loss: 2.7851061820983887
Epoch: 20; Loss: 2.7666003704071045
Epoch: 21; Loss: 2.749232530593872
Epoch: 22; Loss: 2.732896566390991
Epoch: 23; Loss: 2.7174949645996094
Epoch: 24; Loss: 2.7029411792755127
Epoch: 25; Loss: 2.6891579627990723
Epoch: 26; Loss: 2.676076889038086
Epoch: 27; Loss: 2.663638114929199
Epoch: 28; Los

In [8]:
def get_loss(word_set, weights=None):
    """Mean negative log-likelihood of a word list under the trigram model.

    Each word is wrapped in a single "." on both sides and split into
    (two-character context, next character) examples using the global
    ``stoi2`` / ``stoi`` lookups. No regularisation term is added, so the
    value is comparable across splits.

    NOTE(review): with only one leading ".", the first character of each word
    is never scored; this matches how the training tensors are built.

    Args:
        word_set: iterable of words (strings) to evaluate.
        weights: optional weight matrix to evaluate; defaults to the global
            ``W`` produced by ``train()``.

    Returns:
        A 0-dim tensor holding the mean NLL.

    Raises:
        ValueError: if ``word_set`` yields no trigram examples.
        KeyError: if a word contains a character missing from ``stoi``.
    """
    if weights is None:
        weights = W

    contexts, targets = [], []
    for w in word_set:
        chs = ["."] + list(w) + ["."]
        for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
            contexts.append(stoi2[ch1 + ch2])
            targets.append(stoi[ch3])
    if not contexts:
        raise ValueError("word_set produced no trigram examples to evaluate")

    # Local names deliberately differ from the global training tensors xs_t / ys_t.
    xs = torch.tensor(contexts, dtype=torch.long)
    ys = torch.tensor(targets, dtype=torch.long)

    with torch.no_grad():
        # cross_entropy applies log-softmax internally (log-sum-exp), avoiding the
        # exp() overflow of normalising raw counts by hand, and matches the data
        # term train() optimises.
        nll = F.cross_entropy(weights[xs], ys)

    return nll

In [9]:
# Mean NLL on the held-out dev split.
get_loss(word_set=dev_words)

tensor(2.2487)