# Bigram Model with NN

In [1]:
import torch
from torch.nn.functional import one_hot

# Data

In [2]:
with open('names.txt') as f:
    data = f.read().split('\n')
data[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [3]:
characters = list(set([i for i in ''.join(data)]))
characters.insert(1, '.')
hmap = {}
for n in range(0,27):
    hmap[sorted(characters)[n]] = n
print(hmap)

{'.': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26}


In [4]:
# Create dataset
X, y = [],[]

for name in data:
    name = '.' + name + '.'
    for ch1, ch2 in zip(name, name[1:]):
        X.append(hmap[ch1])
        y.append(hmap[ch2])
    # break

In [5]:
X

[0,
 5,
 13,
 13,
 1,
 0,
 15,
 12,
 9,
 22,
 9,
 1,
 0,
 1,
 22,
 1,
 0,
 9,
 19,
 1,
 2,
 5,
 12,
 12,
 1,
 0,
 19,
 15,
 16,
 8,
 9,
 1,
 0,
 3,
 8,
 1,
 18,
 12,
 15,
 20,
 20,
 5,
 0,
 13,
 9,
 1,
 0,
 1,
 13,
 5,
 12,
 9,
 1,
 0,
 8,
 1,
 18,
 16,
 5,
 18,
 0,
 5,
 22,
 5,
 12,
 25,
 14,
 0,
 1,
 2,
 9,
 7,
 1,
 9,
 12,
 0,
 5,
 13,
 9,
 12,
 25,
 0,
 5,
 12,
 9,
 26,
 1,
 2,
 5,
 20,
 8,
 0,
 13,
 9,
 12,
 1,
 0,
 5,
 12,
 12,
 1,
 0,
 1,
 22,
 5,
 18,
 25,
 0,
 19,
 15,
 6,
 9,
 1,
 0,
 3,
 1,
 13,
 9,
 12,
 1,
 0,
 1,
 18,
 9,
 1,
 0,
 19,
 3,
 1,
 18,
 12,
 5,
 20,
 20,
 0,
 22,
 9,
 3,
 20,
 15,
 18,
 9,
 1,
 0,
 13,
 1,
 4,
 9,
 19,
 15,
 14,
 0,
 12,
 21,
 14,
 1,
 0,
 7,
 18,
 1,
 3,
 5,
 0,
 3,
 8,
 12,
 15,
 5,
 0,
 16,
 5,
 14,
 5,
 12,
 15,
 16,
 5,
 0,
 12,
 1,
 25,
 12,
 1,
 0,
 18,
 9,
 12,
 5,
 25,
 0,
 26,
 15,
 5,
 25,
 0,
 14,
 15,
 18,
 1,
 0,
 12,
 9,
 12,
 25,
 0,
 5,
 12,
 5,
 1,
 14,
 15,
 18,
 0,
 8,
 1,
 14,
 14,
 1,
 8,
 0,
 12,
 9,
 12,
 12,
 9,
 1,
 1

In [6]:
X = one_hot(torch.tensor(X), num_classes=27).float()
y = torch.tensor(y).int()
X.shape, y.shape

(torch.Size([228146, 27]), torch.Size([228146]))

In [7]:
X

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 1., 0., 0.]])

# Model

In [8]:
# Test
W = torch.randn((27,27), requires_grad=True)
logits = X @ W # these are log-counts
cts = logits.exp()
probs = cts/cts.sum(1, keepdim=True) # softmax
probs

tensor([[0.0586, 0.0063, 0.0058,  ..., 0.0277, 0.0167, 0.0213],
        [0.2093, 0.0243, 0.0200,  ..., 0.0057, 0.0183, 0.0815],
        [0.0123, 0.0050, 0.0236,  ..., 0.0239, 0.0205, 0.0568],
        ...,
        [0.0063, 0.0049, 0.0553,  ..., 0.0457, 0.0631, 0.0304],
        [0.0649, 0.0609, 0.0429,  ..., 0.0131, 0.0572, 0.0059],
        [0.0047, 0.1060, 0.0159,  ..., 0.0160, 0.0109, 0.0193]],
       grad_fn=<DivBackward0>)

In [9]:
probs[0].sum() # Sum along a row should be 1

tensor(1.0000, grad_fn=<SumBackward0>)

This would have given the output, if W was correct. So, we optimise W

# Training

In [10]:
num_epochs = 300
lr = 50
n = len(X)

In [11]:
for i in range(num_epochs):
    
    # forward pass
    logits = X @ W 
    cts = logits.exp()
    probs = cts/cts.sum(1, keepdim=True)

    # loss function
    loss = -probs[torch.arange(n), y].log().mean() + .01*(W**2).mean() # Last term is for regularisation
    if not i%100:
        print(f'Epoch {i}: {loss.item()}')
    
    
    # backward pass
    W.grad = None
    loss.backward()
    W.data += lr * -W.grad


Epoch 0: 3.7484958171844482
Epoch 100: 2.4913718700408936
Epoch 200: 2.483287811279297


# Inference

In [12]:
rev_hmap = {value:key for key,value in hmap.items()}
print(rev_hmap)

{0: '.', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z'}


In [13]:
for i in range(10):

    next_index = torch.tensor(0)
    l = []
    while True:
        x = one_hot(next_index,num_classes=27,).float().reshape(1,27)
        logits = x @ W 
        cts = logits.exp()
        probs = cts/cts.sum(1, keepdim=True)
        # print(probs)
        # next_index = torch.argmax(probs) # This chooses the most likely one
        next_index = torch.multinomial(probs,num_samples=1, replacement=True) # This chooses according to probability
        l.append(rev_hmap[next_index.item()])
        
        if next_index == 0:
            break

    print(''.join(l))  

de.
micanoro.
llicyalda.
ellara.
souan.
in.
unevyn.
arinigayi.
gahul.
katinthecendeqtle.
