In [1]:
# data
# Each line of names.txt becomes one entry of `names`.
# Read inside a context manager so the file handle is closed as soon as the
# read finishes instead of staying open in the kernel until garbage collection.
with open("names.txt") as names_file:
    names = names_file.read().splitlines()

In [5]:
# vocabulary
import torch

# Every distinct character that occurs in any name, plus the two boundary
# tokens, in sorted order so the ids are stable across runs.
chars = sorted(set().union(*names, ['<s>', '</s>']))
# Lookup tables in both directions.
stoi = dict(zip(chars, range(len(chars))))  # token -> id
itos = dict(enumerate(chars))  # id -> token
vocab_size = len(stoi)

vocab_size
# chars

28

In [12]:
# counts tensor
# counts[a, b, c] = number of times token c followed the token pair (a, b)
counts = torch.zeros(vocab_size, vocab_size, vocab_size, dtype=torch.float32)

for name in names:
    # Two start markers give the first real character a full two-token context;
    # the end marker lets the table record where names stop.
    ids = [stoi[tok] for tok in ('<s>', '<s>', *name, '</s>')]
    # Slide a window of width three over the id sequence.
    for first, second, third in zip(ids, ids[1:], ids[2:]):
        counts[first, second, third] += 1

counts


tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 4.4100e+03,  ..., 1.3400e+02,
          5.3500e+02, 9.2900e+02],
         [0.0000e+00, 0.0000e+00, 2.0700e+02,  ..., 2.7000e+01,
          1.7300e+02, 1.5200e+02],
         ...,
         [0.0000e+00, 0.0000e+00, 5.7000e+01,  ..., 1.0000e+00,
          1.700

In [14]:
# probs
# Normalise every (first, second) context over the last axis so that each
# seen context holds a distribution over the next token.
context_totals = counts.sum(dim=2, keepdim=True)
ratios = counts / context_totals
# Contexts that never occurred divide 0 by 0 and come out as NaN; zero them.
probs = ratios.masked_fill(torch.isnan(ratios), 0.0)

probs

tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.1377,  ..., 0.0042, 0.0167, 0.0290],
         [0.0000, 0.0000, 0.0469,  ..., 0.0061, 0.0392, 0.0345],
         ...,
         [0.0000, 0.0000, 0.4254,  ..., 0.0075, 0.1269, 0.0821],
         [0.0000, 0.0000, 0.4598,  ..., 0.0000, 0.0000, 0.0037],
         [0.0000, 0.0000, 0.4909,  ..., 0.0000, 0.0980, 0.0011]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0719, 0.0000, 0.0000,  ..., 0.0000, 0.0360, 0.

In [15]:
# generate names
import torch.nn.functional as F

def generate_names(generator=None):
    """Sample a single name from the trigram table ``probs``.

    Starts from the context ('<s>', '<s>') and repeatedly draws the next
    character from ``probs[i1, i2]`` until either '</s>' is drawn or the
    current context carries no probability mass.

    Args:
        generator: optional ``torch.Generator`` forwarded to
            ``torch.multinomial``. Pass a seeded generator to make the
            sample reproducible. The default ``None`` uses torch's global
            RNG, which is the previous behaviour.

    Returns:
        The sampled characters joined into a ``str``, without the boundary
        tokens. The string is empty if sampling stops on the first step.
    """
    name = ['<s>', '<s>']
    while True:
        i1 = stoi[name[-2]]  # id of the second-to-last token
        i2 = stoi[name[-1]]  # id of the last token

        p = probs[i1, i2]  # weights over the next token for this context

        # torch.multinomial rejects an all-zero weight vector, so stop here.
        if p.sum() == 0:
            break

        next_idx = torch.multinomial(p, num_samples=1, generator=generator).item()
        next_char = itos[next_idx]

        if next_char == '</s>':
            break

        name.append(next_char)

    # Drop the two leading start markers.
    return ''.join(name[2:])
        
        

In [17]:
# Seed torch's global RNG so a fresh "Restart & Run All" prints the same ten
# names every time; change or remove the seed to draw a different batch.
torch.manual_seed(42)
for _ in range(10):
    print(generate_names())

jadia
aandris
mael
lingeriana
hia
tryannia
zelynor
jurdyna
alia
branna
