In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# get all words: names.txt holds one name per line
# `with` guarantees the file handle is closed; an explicit encoding avoids
# depending on the platform's default text encoding.
with open('names.txt', 'r', encoding='utf-8') as f:
    words = f.read().splitlines()
words[:15]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn',
 'abigail',
 'emily',
 'elizabeth',
 'mila',
 'ella']

In [6]:
# Vocabulary: every distinct character in the dataset, sorted and mapped to 1..N.
# Index 0 is reserved for '.', which the dataset cell uses both as start-of-word
# padding and as the end-of-word target.
chars = sorted(set(''.join(words)))
# If '.' occurred in the data, the line `s_to_i['.'] = 0` below would silently
# overwrite its entry and leave index i+1 unused, so fail loudly instead.
assert '.' not in chars, "'.' is reserved as the start/end token but appears in the data"
s_to_i = {s:i+1 for i, s in enumerate(chars)}
s_to_i['.'] = 0
i_to_s = {i:s for s, i in s_to_i.items()}  # inverse mapping: index -> character
print(f's to i: {s_to_i}\n\ni to s: {i_to_s}')

s to i: {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, '.': 0}

i to s: {1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [11]:
# build dataset
# Each example pairs a sliding window of the previous `block_size` character
# indices (a row of X) with the index of the character that follows it (Y).
# Index 0 ('.') left-pads the window at the start of a word and is also the
# end-of-word target.

block_size = 3 # context length: how many previous chars are used to predict the next char
n_words = 5 # how many words to turn into examples; None uses every word
verbose = True # print every (context -> target) pair; switch off for large n_words

X, Y = [], []
for word in words[:n_words]:
    if verbose:
        print(word)
    context = [0] * block_size # start-of-word padding

    for char in word + '.':
        ix = s_to_i[char]
        X.append(context)
        Y.append(ix)
        if verbose:
            print(''.join(i_to_s[i] for i in context), '->', i_to_s[ix])
        # crop first character and append next; this builds a new list, so rows
        # already appended to X are not mutated
        context = context[1:] + [ix]

# Explicit dtype + reshape keep X as (num_examples, block_size) int64 and Y as
# (num_examples,) int64 even when no examples were produced (torch.tensor([])
# would otherwise give a float tensor of shape (0,)).
X = torch.tensor(X, dtype=torch.long).view(-1, block_size)
Y = torch.tensor(Y, dtype=torch.long)

emma
... -> e
..e -> m
.em -> m
emm -> a
mma -> .
olivia
... -> o
..o -> l
.ol -> i
oli -> v
liv -> i
ivi -> a
via -> .
ava
... -> a
..a -> v
.av -> a
ava -> .
isabella
... -> i
..i -> s
.is -> a
isa -> b
sab -> e
abe -> l
bel -> l
ell -> a
lla -> .
sophia
... -> s
..s -> o
.so -> p
sop -> h
oph -> i
phi -> a
hia -> .


In [12]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [14]:
# embed characters in an n dimensional space
# Lookup table C: row k holds the embedding vector for character index k.
vocab_size = len(s_to_i) # 27 for names.txt: 26 letters + '.'
n_embd = 2 # embedding dimensionality (2 so the embeddings are easy to plot)
# Fixed-seed generator so C (and the displayed output) is reproducible on re-run.
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g)
C

tensor([[-0.0245,  1.9602],
        [ 2.9721, -2.3247],
        [-1.0816,  0.4286],
        [-0.0668, -1.8977],
        [-0.4481, -0.4141],
        [-0.9444, -0.1084],
        [-1.1674,  0.3169],
        [ 0.2859, -1.5466],
        [-1.7160, -2.0809],
        [ 1.7888, -0.4585],
        [ 0.3697, -0.3948],
        [-0.8490, -1.0591],
        [ 0.2968,  0.2380],
        [-1.7551, -2.2854],
        [-0.3268, -0.3873],
        [-1.1596,  0.5439],
        [-1.0180,  0.4369],
        [ 0.0159,  1.5863],
        [-0.7533, -0.5934],
        [-0.7389,  0.1808],
        [ 0.9034,  0.0109],
        [ 0.5795,  0.2692],
        [-0.1713,  3.0976],
        [ 1.0953,  0.9038],
        [ 1.7837,  0.4909],
        [ 2.3491,  0.4712],
        [-0.3091, -1.2805]])