In [1]:
import torch
import torch.nn.functional as F

from sklearn.model_selection import train_test_split

import pickle

## Data prep

In [2]:
words = open("./names.txt", "r").read().splitlines()
train_words, temp_words = train_test_split(words, train_size=0.8, random_state=42)
dev_words, test_words = train_test_split(temp_words, test_size=0.5, random_state=42)

In [3]:
len(train_words), len(dev_words), len(test_words)

(25626, 3203, 3204)

## Trigram

In [4]:
train_chars = sorted(list(set(''.join(train_words))))
two_chars = set()
for c1 in train_chars+["."]:
  for c2 in train_chars+["."]:
    two_chars.add(c1+c2)

two_chars = sorted(list(two_chars))

stoi = {s:i+1 for i,s in enumerate(train_chars)}
stoi["."] = 0
stoi2 = {s:i for i,s in enumerate(two_chars)}
itos2 = {i:s for i,s in enumerate(two_chars)}

In [5]:
xs_t, ys_t = [], []
for w in train_words:
  chs = ["."] + list(w) + ["."]
  for ch1,ch2,ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi2[ch1+ch2]
    ix2 = stoi[ch3]
    xs_t.append(ix1)
    ys_t.append(ix2)

xs_t = torch.tensor(xs_t)
ys_t = torch.tensor(ys_t)

W = torch.empty(0)

In [6]:
def train(reg_factor, epochs=150):
    global W
    g = torch.Generator().manual_seed(2147483647)
    W = torch.randn((729, 27), generator=g, requires_grad=True)
    for i in range(150):
        # forward pass
        logits = W[xs_t]
        counts = logits.exp()
        probs = counts / counts.sum(1, keepdim=True)
        loss = -probs[torch.arange(xs_t.nelement()), ys_t].log().mean() + reg_factor*(W**2).mean()
    
        print(f"Epoch: {i}; Loss: {loss.item()}")
    
      # backward pass
        W.grad = None
        loss.backward()
        with torch.no_grad():
            W.data += -75 * W.grad

In [7]:
def get_loss(word_set):
    xs_t, ys_t = [], []
    for w in word_set:
        chs = ["."] + list(w) + ["."]
        for ch1,ch2,ch3 in zip(chs, chs[1:], chs[2:]):
            ix1 = stoi2[ch1+ch2]
            ix2 = stoi[ch3]
            xs_t.append(ix1)
            ys_t.append(ix2)
    
    xs_t = torch.tensor(xs_t)
    ys_t = torch.tensor(ys_t)

    with torch.no_grad():
        logits = W[xs_t]
        counts = logits.exp()
        probs = counts / counts.sum(1, keepdim=True)
    
        nll = -probs[torch.arange(xs_t.nelement()), ys_t].log().mean()

    return nll

In [8]:
train(reg_factor=0.01)

Epoch: 0; Loss: 3.7336504459381104
Epoch: 1; Loss: 3.6212451457977295
Epoch: 2; Loss: 3.519554376602173
Epoch: 3; Loss: 3.4285521507263184
Epoch: 4; Loss: 3.3480427265167236
Epoch: 5; Loss: 3.277346134185791
Epoch: 6; Loss: 3.215322494506836
Epoch: 7; Loss: 3.160595417022705
Epoch: 8; Loss: 3.11183762550354
Epoch: 9; Loss: 3.0679633617401123
Epoch: 10; Loss: 3.028154134750366
Epoch: 11; Loss: 2.991804599761963
Epoch: 12; Loss: 2.9584553241729736
Epoch: 13; Loss: 2.927744150161743
Epoch: 14; Loss: 2.8993759155273438
Epoch: 15; Loss: 2.873103380203247
Epoch: 16; Loss: 2.848714590072632
Epoch: 17; Loss: 2.8260245323181152
Epoch: 18; Loss: 2.804870367050171
Epoch: 19; Loss: 2.7851061820983887
Epoch: 20; Loss: 2.766599655151367
Epoch: 21; Loss: 2.749232530593872
Epoch: 22; Loss: 2.732896566390991
Epoch: 23; Loss: 2.7174947261810303
Epoch: 24; Loss: 2.7029411792755127
Epoch: 25; Loss: 2.6891579627990723
Epoch: 26; Loss: 2.676077127456665
Epoch: 27; Loss: 2.663638114929199
Epoch: 28; Loss: 2.

In [9]:
get_loss(test_words)

tensor(2.2715)

In [42]:
reg_factors = torch.arange(0.001, 0.15, 0.009)

In [43]:
reg_factors.shape

torch.Size([17])

In [47]:
reg_factors

tensor([0.0010, 0.0100, 0.0190, 0.0280, 0.0370, 0.0460, 0.0550, 0.0640, 0.0730,
        0.0820, 0.0910, 0.1000, 0.1090, 0.1180, 0.1270, 0.1360, 0.1450])

In [13]:
dev_losses = {}
for reg_factor in reg_factors:
    train(reg_factor=reg_factor)
    dev_losses[reg_factor] = get_loss(dev_words)

Epoch: 0; Loss: 3.7336504459381104
Epoch: 1; Loss: 3.6212451457977295
Epoch: 2; Loss: 3.519554376602173
Epoch: 3; Loss: 3.4285521507263184
Epoch: 4; Loss: 3.3480427265167236
Epoch: 5; Loss: 3.277346134185791
Epoch: 6; Loss: 3.215322494506836
Epoch: 7; Loss: 3.160595417022705
Epoch: 8; Loss: 3.111837863922119
Epoch: 9; Loss: 3.0679633617401123
Epoch: 10; Loss: 3.0281543731689453
Epoch: 11; Loss: 2.991804599761963
Epoch: 12; Loss: 2.9584553241729736
Epoch: 13; Loss: 2.927744150161743
Epoch: 14; Loss: 2.8993759155273438
Epoch: 15; Loss: 2.873103380203247
Epoch: 16; Loss: 2.848714590072632
Epoch: 17; Loss: 2.8260245323181152
Epoch: 18; Loss: 2.804870367050171
Epoch: 19; Loss: 2.7851061820983887
Epoch: 20; Loss: 2.766599655151367
Epoch: 21; Loss: 2.749232530593872
Epoch: 22; Loss: 2.732896327972412
Epoch: 23; Loss: 2.7174947261810303
Epoch: 24; Loss: 2.7029411792755127
Epoch: 25; Loss: 2.6891579627990723
Epoch: 26; Loss: 2.676077127456665
Epoch: 27; Loss: 2.663638114929199
Epoch: 28; Loss: 

In [14]:
dev_losses

{tensor(0.0100): tensor(2.2487),
 tensor(0.0600): tensor(2.2488),
 tensor(0.1100): tensor(2.2495),
 tensor(0.1600): tensor(2.2507),
 tensor(0.2100): tensor(2.2522),
 tensor(0.2600): tensor(2.2542),
 tensor(0.3100): tensor(2.2565),
 tensor(0.3600): tensor(2.2591),
 tensor(0.4100): tensor(2.2619),
 tensor(0.4600): tensor(2.2650),
 tensor(0.5100): tensor(2.2683),
 tensor(0.5600): tensor(2.2718),
 tensor(0.6100): tensor(2.2754),
 tensor(0.6600): tensor(2.2791),
 tensor(0.7100): tensor(2.2830),
 tensor(0.7600): tensor(2.2870),
 tensor(0.8100): tensor(2.2910),
 tensor(0.8600): tensor(2.2951),
 tensor(0.9100): tensor(2.2992),
 tensor(0.9600): tensor(2.3034)}

In [15]:
with open("./dev_losses.pkl", 'wb') as f:
    pickle.dump(dev_losses, f)

In [16]:
test_losses = {}
for reg_factor in reg_factors:
    train(reg_factor=reg_factor)
    test_losses[reg_factor] = get_loss(test_words)

Epoch: 0; Loss: 3.7336504459381104
Epoch: 1; Loss: 3.6212451457977295
Epoch: 2; Loss: 3.519554376602173
Epoch: 3; Loss: 3.4285521507263184
Epoch: 4; Loss: 3.3480427265167236
Epoch: 5; Loss: 3.277346134185791
Epoch: 6; Loss: 3.215322494506836
Epoch: 7; Loss: 3.160595417022705
Epoch: 8; Loss: 3.11183762550354
Epoch: 9; Loss: 3.0679633617401123
Epoch: 10; Loss: 3.028154134750366
Epoch: 11; Loss: 2.991804599761963
Epoch: 12; Loss: 2.9584553241729736
Epoch: 13; Loss: 2.927744150161743
Epoch: 14; Loss: 2.8993759155273438
Epoch: 15; Loss: 2.873103380203247
Epoch: 16; Loss: 2.848714590072632
Epoch: 17; Loss: 2.8260245323181152
Epoch: 18; Loss: 2.804870367050171
Epoch: 19; Loss: 2.7851061820983887
Epoch: 20; Loss: 2.7666001319885254
Epoch: 21; Loss: 2.749232530593872
Epoch: 22; Loss: 2.732896566390991
Epoch: 23; Loss: 2.7174947261810303
Epoch: 24; Loss: 2.7029411792755127
Epoch: 25; Loss: 2.6891579627990723
Epoch: 26; Loss: 2.676077127456665
Epoch: 27; Loss: 2.663638114929199
Epoch: 28; Loss: 2

In [17]:
test_losses

{tensor(0.0100): tensor(2.2715),
 tensor(0.0600): tensor(2.2713),
 tensor(0.1100): tensor(2.2716),
 tensor(0.1600): tensor(2.2725),
 tensor(0.2100): tensor(2.2738),
 tensor(0.2600): tensor(2.2755),
 tensor(0.3100): tensor(2.2775),
 tensor(0.3600): tensor(2.2799),
 tensor(0.4100): tensor(2.2825),
 tensor(0.4600): tensor(2.2854),
 tensor(0.5100): tensor(2.2885),
 tensor(0.5600): tensor(2.2918),
 tensor(0.6100): tensor(2.2952),
 tensor(0.6600): tensor(2.2988),
 tensor(0.7100): tensor(2.3025),
 tensor(0.7600): tensor(2.3063),
 tensor(0.8100): tensor(2.3102),
 tensor(0.8600): tensor(2.3142),
 tensor(0.9100): tensor(2.3182),
 tensor(0.9600): tensor(2.3223)}

In [18]:
with open("./test_losses.pkl", 'wb') as f:
    pickle.dump(test_losses, f)