In [1]:
words = open("../data/names.txt", "r").read().splitlines()
len(words)

32033

# Bigram implementation

In [2]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

In [3]:
import torch

N = torch.zeros((27, 27), dtype=torch.int32)


for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    N[ix1, ix2] += 1
    

In [4]:
# import matplotlib.pyplot as plt
# %matplotlib inline

# plt.figure(figsize=(16, 16))
# k = 0
# plt.imshow(N, cmap="Blues")
# for i in range(27):
#   for j in range(27):
#     chstr = itos[i] + itos[j]
#     plt.text(j, i, chstr, ha="center", va="bottom", color='gray')
#     plt.text(j, i, N[i, j].item(), ha="center", va="top", color='gray')

# plt.axis('off')

(-0.5, 26.5, 26.5, -0.5)

In [4]:
g = torch.Generator().manual_seed(2147483647)

result = []
for i in range(10):
    word = ''
    ix = 0
    while True:
        p = N[ix].float()
        p = p / p.sum()
        #p = torch.rand(size=(1, 27), generator=g)
        next = torch.multinomial(p, 1, replacement=True, generator=g).item()
        char = itos[next]
        if char == '.':
            break
        else:
            word += char
        ix = next
    result.append(word)
print(f'bigram: \n{result}')

bigram: 
['mor', 'axx', 'minaymoryles', 'kondlaisah', 'anchshizarie', 'odaren', 'iaddash', 'h', 'jhinatien', 'egushl']


In [5]:
P = (N + 1).float()
P /= P.sum(1, keepdim=True)

In [6]:
log_likelihood = 0.0
n = 0

for w in words[:3]:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        prob = P[ix1, ix2]
        logprob = torch.log(prob)
        print(f'{ch1}{ch2} {prob: .4f} {logprob: .4f}')
        log_likelihood += logprob
        n += 1
log_likelihood
nll = -log_likelihood
print(f'{nll}')
print(f'{nll/n}')

.e  0.0478 -3.0410
em  0.0377 -3.2793
mm  0.0253 -3.6753
ma  0.3885 -0.9454
a.  0.1958 -1.6305
.o  0.0123 -4.3965
ol  0.0779 -2.5526
li  0.1774 -1.7293
iv  0.0152 -4.1845
vi  0.3508 -1.0476
ia  0.1380 -1.9807
a.  0.1958 -1.6305
.a  0.1376 -1.9835
av  0.0246 -3.7041
va  0.2473 -1.3971
a.  0.1958 -1.6305
38.80856704711914
2.4255354404449463


# Trigram implementation

In [7]:
N3 = torch.zeros((27, 27, 27), dtype=torch.int32)
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    ix3 = stoi[ch3]
    N3[ix1, ix2, ix3] += 1

In [8]:
g = torch.Generator().manual_seed(2147483647)

result = []
for i in range(30):
    word = ''
    ix = 0
    iy = 1
    while True:
        p = N3[ix, iy].float()
        p = p / p.sum()
        next = torch.multinomial(p, 1, replacement=True, generator=g).item()
        char = itos[next]
        if char == '.':
            break
        else:
            word += char
        ix = iy
        iy = next
    result.append(word)
print(f'trigram: \n{result}')

trigram: 
['rri', 'bry', 'nii', 'yloswais', 'nnaaraen', 'la', 'quetony', 'sid', 'ra', 'rimalaalexiaganilley', 'lia', 'lyn', 's', 'la', 'v', 'dalizleigh', 'h', 'ullia', 'nian', 'da', 'l', 'berendecatrutandenneppalycethon', 'maraivyn', 'yton', 'jdenelaymira', 'nn', 'ltseberrysinlexton', 'd', 'yah', 'nn']


In [9]:
P3 = (N3 + 1).float()
P3 /= P3.sum(2, keepdim=True)

In [10]:
P3[0, 0].sum()

tensor(1.)

In [11]:
log_likelihood3 = 0.0
n3 = 0

for w in words[:3]:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        prob = P3[ix1, ix2, ix3]
        logprob = torch.log(prob)
        print(f'{ch1}{ch2}{ch3} {prob: .4f} {logprob: .4f}')
        log_likelihood3 += logprob
        n3 += 1
log_likelihood3
nll3 = -log_likelihood3
print(f'{nll3}')
print(f'{nll3/n3}')

.em  0.1855 -1.6847
emm  0.1269 -2.0645
mma  0.3744 -0.9825
ma.  0.0669 -2.7050
.ol  0.2494 -1.3887
oli  0.1084 -2.2223
liv  0.0219 -3.8195
ivi  0.2669 -1.3209
via  0.1578 -1.8465
ia.  0.3657 -1.0060
.av  0.0550 -2.9006
ava  0.1882 -1.6705
va.  0.1405 -1.9625
25.574190139770508
1.96724534034729
