# Bigram Model with NN

In [1]:
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from torch.nn.functional import one_hot

# Data

In [2]:
# Load the corpus: one name per line.
# Iterating the file line by line and dropping blank/whitespace-only lines
# avoids an empty "name" (e.g. from a trailing newline at end of file),
# which would otherwise contribute a spurious '..' bigram downstream.
with open('names.txt') as f:
    data = [line.strip() for line in f if line.strip()]
data[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [3]:
# Vocabulary: every distinct character appearing in the names, plus '.',
# which is used as the start/end-of-name marker. The set union keeps '.'
# unique even if it already occurs in the data, and sorting once gives a
# deterministic order ('.' sorts ahead of the lowercase letters).
characters = sorted(set(''.join(data)) | {'.'})

# Character -> integer index. Built with enumerate so the mapping covers the
# whole vocabulary regardless of its size (no hard-coded 27).
hmap = {ch: idx for idx, ch in enumerate(characters)}
print(hmap)

{'.': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26}


In [4]:
# Build the bigram training set: for every adjacent character pair in a
# '.'-padded name, the first character is the input and the second the target.
input_ids, target_ids = [], []

for padded in ('.' + raw_name + '.' for raw_name in data):
    input_ids.extend(hmap[ch] for ch in padded[:-1])   # every char except the last
    target_ids.extend(hmap[ch] for ch in padded[1:])   # the char that follows it

# Inputs are one-hot float rows; targets stay as integer class indices.
X = one_hot(torch.tensor(input_ids), num_classes=27).float()
y = torch.tensor(target_ids).int()
X.shape, y.shape

(torch.Size([228146, 27]), torch.Size([228146]))

# Model

In [5]:
# Test: one forward pass with randomly initialised weights.
# A dedicated, fixed-seed generator makes the initial W (and therefore every
# result derived from it) reproducible after Restart Kernel -> Run All.
init_rng = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=init_rng, requires_grad=True)

logits = X @ W  # these are log-counts
cts = logits.exp()  # exponentiate to get positive "counts"
probs = cts / cts.sum(1, keepdim=True)  # normalise each row into a distribution
probs

tensor([[0.0304, 0.0378, 0.0221,  ..., 0.0694, 0.0070, 0.0258],
        [0.0057, 0.0741, 0.0158,  ..., 0.0146, 0.0288, 0.0456],
        [0.0252, 0.0454, 0.0151,  ..., 0.0687, 0.0098, 0.0449],
        ...,
        [0.0151, 0.1854, 0.0145,  ..., 0.2090, 0.0120, 0.0304],
        [0.0382, 0.0196, 0.0634,  ..., 0.0104, 0.0141, 0.0374],
        [0.0188, 0.0067, 0.0024,  ..., 0.0084, 0.0164, 0.2399]],
       grad_fn=<DivBackward0>)

In [6]:
torch.sum(probs[0])  # each row is a probability distribution, so this should be 1

tensor(1., grad_fn=<SumBackward0>)

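These probabilities would be the model's predictions for the next character if W held the right values. Since W is currently random, we now optimise it with gradient descent.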

# Training

In [38]:
# Gradient-descent settings: number of full-batch steps and the step size.
num_epochs, lr = 2000, 50
n = X.shape[0]  # number of training examples (rows of X)

In [39]:
# Full-batch gradient descent on W.
# NOTE: W is not re-initialised here, so re-running this cell continues
# training from whatever W currently holds.

# Loop-invariant pieces, built once instead of on every epoch.
row_idx = torch.arange(n)
targets = y.long()  # integer class indices used to pick each row's target log-prob

for i in range(num_epochs):

    # forward pass
    logits = X @ W  # log-counts
    # log-softmax is the numerically stable form of log(exp(logits) / sum(exp(logits)))
    log_probs = torch.log_softmax(logits, dim=1)

    # loss function: mean negative log-likelihood of the true next character
    loss = -log_probs[row_idx, targets].mean() + .01*(W**2).mean() # Last term is for regularisation
    if i % 100 == 0:
        print(f'Epoch {i}: {loss.item()}')

    # backward pass
    W.grad = None  # clear the gradient from the previous step
    loss.backward()
    # Update the parameter in place without recording the update in autograd.
    with torch.no_grad():
        W -= lr * W.grad


Epoch 0: 2.50633
Epoch 100: 2.48045
Epoch 200: 2.48043
Epoch 300: 2.48042
Epoch 400: 2.48041
Epoch 500: 2.48041
Epoch 600: 2.48041
Epoch 700: 2.48040
Epoch 800: 2.48040
Epoch 900: 2.48040
Epoch 1000: 2.48040
Epoch 1100: 2.48040
Epoch 1200: 2.48040
Epoch 1300: 2.48040
Epoch 1400: 2.48040
Epoch 1500: 2.48040
Epoch 1600: 2.48040
Epoch 1700: 2.48040
Epoch 1800: 2.48040
Epoch 1900: 2.48040


# Inference

In [9]:
# Inverse lookup table: integer index -> character.
rev_hmap = dict(zip(hmap.values(), hmap.keys()))
print(rev_hmap)

{0: '.', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z'}


In [40]:
# Sample 10 names from the model, one character at a time.
# A fixed-seed generator makes the sampled names reproducible.
sample_rng = torch.Generator().manual_seed(2147483647)

# Sampling needs no gradients; no_grad avoids building an autograd graph
# through W (which has requires_grad=True) on every step.
with torch.no_grad():
    for _ in range(10):

        next_index = torch.tensor(0)  # index 0 is '.', the start/end-of-name marker
        sampled_chars = []
        while True:
            x = one_hot(next_index, num_classes=27).float().reshape(1, 27)
            logits = x @ W
            probs = torch.softmax(logits, dim=1)  # stable equivalent of exp / row-sum
            # Sample the next character according to its probability
            # (torch.argmax(probs) would instead always take the most likely one).
            next_index = torch.multinomial(probs, num_samples=1, replacement=True, generator=sample_rng)
            sampled_chars.append(rev_hmap[next_index.item()])

            # Drawing '.' again means the name has ended.
            if next_index.item() == 0:
                break

        print(''.join(sampled_chars))

fiana.
shanali.
damah.
jaynn.
kanaha.
ovil.
janel.
chi.
sonai.
a.
