In [6]:
WORDS_FILE = 'names.txt'
words = open(WORDS_FILE, 'r', encoding="utf8").read().splitlines()
words = [w.lower() for w in words]
words = [w for w in words if (len(w) > 0)]
words[:10]
letters = sorted(list(set(''.join(words))))
chtoi = {ch:(i+1) for i,ch in enumerate(letters)}
chtoi['.'] = 0
itoch = {i:ch for ch,i in chtoi.items()}
N = len(letters) + 1

In [7]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

CTX_SZ = 4

def make_dataset(words):
    x, y = [], []
    for w in words:        
        ctx = [0] * CTX_SZ
        for ch in w + '.':
            idx = chtoi[ch]
            x.append(ctx)
            y.append(idx)
            ctx = ctx[1:] + [idx]
    x = torch.tensor(x)
    y = torch.tensor(y)
    return x, y

import random
random.seed(244)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

x_trn, y_trn = make_dataset(words[:n1])
x_dev, y_dev = make_dataset(words[n1:n2])
x_tst, y_tst = make_dataset(words[n2:])

N_SAMPLES_TRN = x_trn.shape[0]
N_SAMPLES_DEV = x_dev.shape[0]
N_SAMPLES_TST = x_tst.shape[0]

In [8]:
HID_L_SZ = 200
EMB_DIM = 3

C = torch.randn((N, EMB_DIM))
W1 = torch.randn((CTX_SZ * EMB_DIM, HID_L_SZ)) * 0.1
b1 = torch.randn(HID_L_SZ)                     * 0.01
W2 = torch.randn((HID_L_SZ, N))                * 0.01
b2 = torch.randn(N)                            * 0
params = [C, W1, b1, W2, b2]
for p in params:
    p.requires_grad = True

In [9]:
ITS = 400000
BATCH_SZ = 128

lre = torch.linspace(-3, 0, ITS)
lrs = 10**lre
LR = 0.1
losses = []

def eval_loss(x, y, batch):
    if batch is not None:
        emb = C[x[batch]].view(-1, CTX_SZ * EMB_DIM)
    else:
        emb = C[x].view(-1, CTX_SZ * EMB_DIM)
    h = torch.tanh(emb @ W1 + b1)
    logits = h @ W2 + b2
    if batch is not None:
        loss = F.cross_entropy(logits, y[batch])
    else:
        loss = F.cross_entropy(logits, y)
    return loss

for i in range(ITS):
    bix = torch.randint(0, N_SAMPLES_TRN, (BATCH_SZ,))
    loss = eval_loss(x_trn, y_trn, bix)
    for p in params:
        p.grad = None
    loss.backward()
    lr = (LR/10) if (i > ITS / 2) else LR
    for p in params:
        p.data -= LR * p.grad
    if i % 40000 == 0:
        print(loss.item())
    losses.append(loss.item())

3.2934956550598145


In [10]:
#plt.figure(figsize=(8,8))
#plt.scatter(C[:,0].data, C[:,1].data, s=200)
#for i in range(C.shape[0]):
#    plt.text(C[i,0].item(), C[i,1].item(), itoch[i], ha='center', va='center', color='white')

In [11]:
print(eval_loss(x_dev, y_dev, None).item())

2.3570163249969482


In [12]:
g = torch.Generator().manual_seed(2147483647)
idx = 0
for k in range(20):
    word = ''
    ctx = [0] * CTX_SZ
    while True:
        emb = C[torch.tensor(ctx)].view(-1, CTX_SZ * EMB_DIM)
        h = torch.tanh(emb.view(1, -1) @ W1 + b1)
        logits = h @ W2 + b2
        probs = F.softmax(logits, dim=1)
        idx = torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item()
        ctx = ctx[1:] + [idx]
        if idx == 0:
            print(word)
            break
        word += itoch[idx]

torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
