In [226]:
import torch
import matplotlib.pyplot as plt

In [227]:
words = open('names.txt', 'r').read().splitlines()
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [228]:
# create trigrams
letters = sorted(list(set(''.join(words))))
words = ('..' + word + '.' for word in words) # .. - beginning, . - end
trigrams = [(word[i:i+2], word[i+2]) for word in words for i in range(len(word)) if i+2 < len(word)]
letter_pairs = sorted(set(tup[0] for tup in trigrams))
trigrams

[('..', 'e'),
 ('.e', 'm'),
 ('em', 'm'),
 ('mm', 'a'),
 ('ma', '.'),
 ('..', 'o'),
 ('.o', 'l'),
 ('ol', 'i'),
 ('li', 'v'),
 ('iv', 'i'),
 ('vi', 'a'),
 ('ia', '.'),
 ('..', 'a'),
 ('.a', 'v'),
 ('av', 'a'),
 ('va', '.'),
 ('..', 'i'),
 ('.i', 's'),
 ('is', 'a'),
 ('sa', 'b'),
 ('ab', 'e'),
 ('be', 'l'),
 ('el', 'l'),
 ('ll', 'a'),
 ('la', '.'),
 ('..', 's'),
 ('.s', 'o'),
 ('so', 'p'),
 ('op', 'h'),
 ('ph', 'i'),
 ('hi', 'a'),
 ('ia', '.'),
 ('..', 'c'),
 ('.c', 'h'),
 ('ch', 'a'),
 ('ha', 'r'),
 ('ar', 'l'),
 ('rl', 'o'),
 ('lo', 't'),
 ('ot', 't'),
 ('tt', 'e'),
 ('te', '.'),
 ('..', 'm'),
 ('.m', 'i'),
 ('mi', 'a'),
 ('ia', '.'),
 ('..', 'a'),
 ('.a', 'm'),
 ('am', 'e'),
 ('me', 'l'),
 ('el', 'i'),
 ('li', 'a'),
 ('ia', '.'),
 ('..', 'h'),
 ('.h', 'a'),
 ('ha', 'r'),
 ('ar', 'p'),
 ('rp', 'e'),
 ('pe', 'r'),
 ('er', '.'),
 ('..', 'e'),
 ('.e', 'v'),
 ('ev', 'e'),
 ('ve', 'l'),
 ('el', 'y'),
 ('ly', 'n'),
 ('yn', '.'),
 ('..', 'a'),
 ('.a', 'b'),
 ('ab', 'i'),
 ('bi', 'g'),
 ('ig'

In [229]:
# mappings
str_to_inx_pairs = {str:inx for inx, str in enumerate(letter_pairs, start=0)}
str_to_inx_letter = {str:inx for inx, str in enumerate(letters, start=1)}
str_to_inx_pairs['..'] = 0
str_to_inx_letter['.'] = 0
inx_to_str_pairs = {inx:str for str, inx in str_to_inx_pairs.items()}
inx_to_str_letter = {inx:str for str, inx in str_to_inx_letter.items()}
print(str_to_inx_pairs)
str_to_inx_letter

{'..': 0, '.a': 1, '.b': 2, '.c': 3, '.d': 4, '.e': 5, '.f': 6, '.g': 7, '.h': 8, '.i': 9, '.j': 10, '.k': 11, '.l': 12, '.m': 13, '.n': 14, '.o': 15, '.p': 16, '.q': 17, '.r': 18, '.s': 19, '.t': 20, '.u': 21, '.v': 22, '.w': 23, '.x': 24, '.y': 25, '.z': 26, 'aa': 27, 'ab': 28, 'ac': 29, 'ad': 30, 'ae': 31, 'af': 32, 'ag': 33, 'ah': 34, 'ai': 35, 'aj': 36, 'ak': 37, 'al': 38, 'am': 39, 'an': 40, 'ao': 41, 'ap': 42, 'aq': 43, 'ar': 44, 'as': 45, 'at': 46, 'au': 47, 'av': 48, 'aw': 49, 'ax': 50, 'ay': 51, 'az': 52, 'ba': 53, 'bb': 54, 'bc': 55, 'bd': 56, 'be': 57, 'bh': 58, 'bi': 59, 'bj': 60, 'bl': 61, 'bn': 62, 'bo': 63, 'br': 64, 'bs': 65, 'bt': 66, 'bu': 67, 'by': 68, 'ca': 69, 'cc': 70, 'cd': 71, 'ce': 72, 'cg': 73, 'ch': 74, 'ci': 75, 'cj': 76, 'ck': 77, 'cl': 78, 'co': 79, 'cp': 80, 'cq': 81, 'cr': 82, 'cs': 83, 'ct': 84, 'cu': 85, 'cx': 86, 'cy': 87, 'cz': 88, 'da': 89, 'db': 90, 'dc': 91, 'dd': 92, 'de': 93, 'df': 94, 'dg': 95, 'dh': 96, 'di': 97, 'dj': 98, 'dk': 99, 'dl': 100

{'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26,
 '.': 0}

In [230]:
N = torch.zeros(len(letter_pairs), len(letters) + 1, dtype=torch.int32)
for trigram in trigrams:
    j = str_to_inx_letter[trigram[1]]
    i = str_to_inx_pairs[trigram[0]]
    N[i, j] += 1

print(N.shape)
# plt.figure(figsize=(19,19))
# plt.imshow(N, cmap='Blues')
# for j in range(N.shape[0]):
#     for i in range(N.shape[1]):
#         chstr = inx_to_str_pairs[j] + inx_to_str_letter[i]
#         plt.text(i, j, chstr, ha="center", va="bottom", color='gray')
#         plt.text(i, j, N[j, i].item(), ha="center", va="top", color='gray')
# plt.axis('off')
# print(N[:, -1])

torch.Size([602, 27])


In [231]:
# plt.imshow(N)
plt.show()
P = (N+1).float() # avoid getting inf log
P /= P.sum(dim=1, keepdim=True)  # normalize


In [232]:
generator = torch.Generator().manual_seed(2147483647)
for _ in range(10):
    out = '..'
    inx = 0
    while True:
        last_two_chars = out[-2:]
        inx = str_to_inx_pairs[last_two_chars]
        prob_vector = P[inx]
        inx = torch.multinomial(prob_vector, num_samples=1, replacement=True, generator=generator).item()
        next_char = inx_to_str_letter[inx]
        out += next_char
        if inx == 0:
            break
    print(out)

..miq.
..axx.
..mereyannyaar.
..knooraen.
..el.
..marviovania.
..odarimalabelon.
..hamirelslen.
..elyn.
..rae.
