# Makemore using Bigram

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns # for heatmap

In [None]:
# Load the dataset: every whitespace-separated token in names.txt is one name.
# The encoding is pinned so the result does not depend on the platform's
# default locale encoding.
with open('names.txt', 'r', encoding='utf-8') as f:
    words = f.read().split()
print(len(words))  # number of names loaded

In [None]:
# Vocabulary: every distinct character that occurs in any name, plus the '.'
# token that is later used to mark the start and end of a name, in sorted order.
unique_chars = {ch for name in words for ch in name}
unique_chars.add('.')
characters = sorted(unique_chars)
print(characters)

In [None]:
# hmap: character -> integer index, following the sorted order of `characters`.
# enumerate() covers the whole vocabulary; zipping against a fixed range(27)
# would silently drop every character past the 27th.
hmap = {c: n for n, c in enumerate(characters)}
# char_hmap: the inverse lookup, integer index -> character.
char_hmap = {n: c for c, n in hmap.items()}
print(hmap,'\n', char_hmap)

In [None]:
# Bigram count table, one row/column per symbol. Every cell starts at 1 rather
# than 0 ("fake count" smoothing) so that no bigram ends up with probability
# zero, which would produce an infinite term in the log-likelihood.
arr = torch.full((27, 27), 1, dtype=torch.int32)

In [None]:
# Accumulate bigram counts over the whole dataset.
for word in words:
    word = f'.{word}.'  # wrap the name in '.' so its first and last characters also form bigrams
    for first, second in zip(word[:-1], word[1:]):
        row, col = hmap[first], hmap[second]
        arr[row, col] += 1  # row = first character, column = second character

In [None]:
# Normalise each row of counts into a probability distribution.
# dim=1 sums across the columns of every row; keepdim=True keeps the result as
# a column so the division broadcasts row by row.
row_totals = arr.sum(dim=1, keepdim=True)
P = arr / row_totals
P[0].sum()  # sanity check: a normalised row should sum to 1

In [None]:
# Visualise the bigram probability table as an annotated heatmap.
tick_labels = list(hmap)  # the characters, in index order, label both axes
plt.figure(figsize=(15, 15))
ax = sns.heatmap(
    P,
    annot=True,
    fmt=".2f",
    cmap="viridis",
    cbar=False,
    annot_kws={"ha": 'center', "va": 'center'},
    xticklabels=tick_labels,
    yticklabels=tick_labels,
)
ax.set_xlabel("Second character")  # columns index the second character of the bigram
ax.set_ylabel("First character")   # rows index the first character of the bigram
plt.show()

In [None]:
# Small demo of torch.multinomial on a random 3-way distribution.
g = torch.Generator()

weights = torch.rand(3)      # three random non-negative weights (drawn from the global RNG, not from g)
p = weights / weights.sum()  # normalise so the weights sum to 1
print(p)
# Draw 20 indices (with replacement) according to p, using generator g.
next_index = torch.multinomial(p, 20, replacement=True, generator=g)
print(next_index)

In [None]:
# Sample a single next-character index from row 0 of P, i.e. the distribution
# over second characters when the first character has index 0.
prob_distribution_of_char = P[0]
print(prob_distribution_of_char)
sample = torch.multinomial(prob_distribution_of_char, 1, replacement=True, generator=g)
next_index = sample.item()  # unwrap the one-element tensor into a Python int
print(next_index)

In [None]:
# Generate 10 names from the bigram model and report, for each one, the average
# negative log likelihood of the transitions that produced it.
g = torch.Generator()  # new generator object, independent of whatever earlier cells consumed

for _ in range(10):
    out = []
    index_of_char = 0  # start from index 0 -- assumes '.' maps to 0 in hmap
    log_likelihood = 0
    while True:
        # Distribution over the next character given the current one.
        prob_distribution_of_char = P[index_of_char]
        next_index = torch.multinomial(prob_distribution_of_char, num_samples=1, replacement=True, generator=g).item()
        out.append(char_hmap[next_index])
        # Accumulate the negative log probability of the transition just taken.
        log_likelihood -= torch.log(P[index_of_char, next_index])
        index_of_char = next_index
        if next_index == 0:
            # Index 0 was sampled again: the name is finished.
            break
    # `out` holds one character per sampled transition (including the final
    # one), so len(out) is exactly the number of terms summed above and is the
    # correct divisor for the average.
    n = len(out)
    print(''.join(out), end=' -> ')
    print(f'Average negative log likelihood: {log_likelihood.item()/n:.2f}\n')
    