In [187]:
import torch
import matplotlib.pyplot as plt
import math


In [188]:
words = open('names.txt', 'r').read().splitlines()


In [189]:
# create data set
xs, ys = [], []
for name in words:
    name = '..' + name + '.'
    for ch1, ch2, ch3 in zip(name, name[1:], name[2:]):
        xs.append(ch1+ch2)
        ys.append(ch3)

In [190]:
# mappings
str_to_inx_pairs = {str:inx for inx, str in enumerate(sorted(set(xs)))}
str_to_inx_letter = {str:inx for inx, str in enumerate(sorted(set(ys)))}
str_to_inx_pairs['..'] = 0
str_to_inx_letter['.'] = 0
inx_to_str_pairs = {inx:str for str, inx in str_to_inx_pairs.items()}
inx_to_str_letter = {inx:str for str, inx in str_to_inx_letter.items()}
print(str_to_inx_pairs)
str_to_inx_letter

{'..': 0, '.a': 1, '.b': 2, '.c': 3, '.d': 4, '.e': 5, '.f': 6, '.g': 7, '.h': 8, '.i': 9, '.j': 10, '.k': 11, '.l': 12, '.m': 13, '.n': 14, '.o': 15, '.p': 16, '.q': 17, '.r': 18, '.s': 19, '.t': 20, '.u': 21, '.v': 22, '.w': 23, '.x': 24, '.y': 25, '.z': 26, 'aa': 27, 'ab': 28, 'ac': 29, 'ad': 30, 'ae': 31, 'af': 32, 'ag': 33, 'ah': 34, 'ai': 35, 'aj': 36, 'ak': 37, 'al': 38, 'am': 39, 'an': 40, 'ao': 41, 'ap': 42, 'aq': 43, 'ar': 44, 'as': 45, 'at': 46, 'au': 47, 'av': 48, 'aw': 49, 'ax': 50, 'ay': 51, 'az': 52, 'ba': 53, 'bb': 54, 'bc': 55, 'bd': 56, 'be': 57, 'bh': 58, 'bi': 59, 'bj': 60, 'bl': 61, 'bn': 62, 'bo': 63, 'br': 64, 'bs': 65, 'bt': 66, 'bu': 67, 'by': 68, 'ca': 69, 'cc': 70, 'cd': 71, 'ce': 72, 'cg': 73, 'ch': 74, 'ci': 75, 'cj': 76, 'ck': 77, 'cl': 78, 'co': 79, 'cp': 80, 'cq': 81, 'cr': 82, 'cs': 83, 'ct': 84, 'cu': 85, 'cx': 86, 'cy': 87, 'cz': 88, 'da': 89, 'db': 90, 'dc': 91, 'dd': 92, 'de': 93, 'df': 94, 'dg': 95, 'dh': 96, 'di': 97, 'dj': 98, 'dk': 99, 'dl': 100

{'.': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26}

In [191]:
xs = [str_to_inx_pairs[x] for x in xs]
ys = [str_to_inx_letter[y] for y in ys]

In [192]:
xs, ys = torch.tensor(xs), torch.tensor(ys)

In [193]:
import torch.nn.functional as F
xenc = F.one_hot(xs, num_classes=len(inx_to_str_pairs)).float()
xenc.shape

torch.Size([228146, 602])

In [194]:
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((len(inx_to_str_pairs), len(inx_to_str_letter)), generator=g, requires_grad=True)
W.shape

torch.Size([602, 27])

In [195]:
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))

True
1
NVIDIA GeForce MX130


In [196]:
# data set splits
train_range = math.ceil(len(xs) * 0.8)
dev_range = (len(xs) - train_range) // 2

training_set = xs[:train_range]
dev_set = xs[train_range:train_range+dev_range]
test_set = xs[train_range+dev_range:]

y_training_set = ys[:train_range]
y_dev_set = ys[train_range:train_range+dev_range]
y_test_set = ys[train_range+dev_range:]

assert training_set.nelement() + dev_set.nelement() + test_set.nelement() == xs.nelement(), "Bad split"

In [197]:
# train the NN
input_num = training_set.nelement()
xenc = F.one_hot(training_set, num_classes=len(str_to_inx_pairs)).float()

print(input_num)
for _ in range(50):
    # forward pass
    logits = torch.matmul(xenc, W)
    # counts = logits.exp()
    # probs = counts / counts.sum(1, keepdim=True)
    # loss = -probs[torch.arange(input_num), y_training_set].log().mean() + 0.01*(W**2).mean()
    loss = F.cross_entropy(logits, y_training_set) + 0.01*(W**2).mean()
    # print(loss)

    # backward pass
    W.grad = None
    loss.backward()

    # update
    W.data += -150 * W.grad
loss

182517


tensor(2.3981, grad_fn=<AddBackward0>)

In [198]:
# dev set evaluation
input_num = dev_set.nelement()
xenc = F.one_hot(dev_set, num_classes=len(str_to_inx_pairs)).float()

logits = xenc @ W
# -------- EQUIVALENT TO CROSS_ENTROPY --------
# counts = logits.exp()
# probs = counts / counts.sum(dim=1, keepdim=True)
# loss = -probs[torch.arange(input_num), y_dev_set].log().mean()
# -------- EQUIVALENT TO CROSS_ENTROPY --------
print(xenc.shape, W.shape, logits.shape)
loss = F.cross_entropy(logits, y_dev_set) # nll + softmax
loss

torch.Size([22814, 602]) torch.Size([602, 27]) torch.Size([22814, 27])


tensor(2.6422, grad_fn=<NllLossBackward0>)

In [199]:
# test set evaluation
input_num = test_set.nelement()
xenc = F.one_hot(test_set, num_classes=len(str_to_inx_pairs)).float()

logits = xenc @ W
loss = F.cross_entropy(logits, y_test_set)
loss

tensor(2.6578, grad_fn=<NllLossBackward0>)

In [200]:
# tune regularization using dev set
input_num = dev_set.nelement()
xenc = F.one_hot(dev_set, num_classes=len(str_to_inx_pairs)).float()

losses = []
for reg_term in [0.001, 0.01, 0.1, 1.0, 10]:
    W_tune = torch.randn((len(inx_to_str_pairs), len(inx_to_str_letter)), generator=g, requires_grad=True)
    print(W_tune[3,3])
    for _ in range(50):
        # forward pass
        logits = torch.matmul(xenc, W_tune)
        loss = F.cross_entropy(logits, y_dev_set) + reg_term*(W_tune**2).mean()
        # print(loss)

        # backward pass
        W_tune.grad = None
        loss.backward()

        # update
        W_tune.data += -100 * W.grad
    losses.append(loss)
losses

tensor(1.1870, grad_fn=<SelectBackward0>)
tensor(-0.5672, grad_fn=<SelectBackward0>)
tensor(-0.3240, grad_fn=<SelectBackward0>)
tensor(1.0157, grad_fn=<SelectBackward0>)
tensor(0.2222, grad_fn=<SelectBackward0>)


[tensor(6.7359, grad_fn=<AddBackward0>),
 tensor(6.4630, grad_fn=<AddBackward0>),
 tensor(6.4656, grad_fn=<AddBackward0>),
 tensor(7.3441, grad_fn=<AddBackward0>),
 tensor(18.0238, grad_fn=<AddBackward0>)]