# Bigram Model with NN

In [1]:
import torch
from torch.nn.functional import one_hot

# Data

In [2]:
with open('names.txt') as f:
    data = f.read().split('\n')
data[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [3]:
characters = list(set([i for i in ''.join(data)]))
characters.insert(1, '.')
hmap = {}
for n in range(0,27):
    hmap[sorted(characters)[n]] = n
print(hmap)

{'.': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26}


In [4]:
# Create dataset
X, y = [],[]

for name in data:
    name = '.' + name + '.'
    for ch1, ch2 in zip(name, name[1:]):
        X.append(hmap[ch1])
        y.append(hmap[ch2])
    # break

In [5]:
X = one_hot(torch.tensor(X), num_classes=27).float()
y = torch.tensor(y).int()
X.shape, y.shape

(torch.Size([228146, 27]), torch.Size([228146]))

# Model

In [7]:
# Test
W = torch.randn((27,27), requires_grad=True)
logits = X @ W # these are log-counts
cts = logits.exp()
probs = cts/cts.sum(1, keepdim=True) # softmax
probs

tensor([[0.0110, 0.0053, 0.0664,  ..., 0.0534, 0.0099, 0.0346],
        [0.0147, 0.0612, 0.0393,  ..., 0.2099, 0.0072, 0.0102],
        [0.0528, 0.0196, 0.0082,  ..., 0.0146, 0.0169, 0.0365],
        ...,
        [0.0217, 0.0232, 0.0047,  ..., 0.1152, 0.0065, 0.0663],
        [0.0100, 0.1927, 0.0139,  ..., 0.0519, 0.0343, 0.0657],
        [0.1069, 0.0825, 0.0111,  ..., 0.0325, 0.1544, 0.0102]],
       grad_fn=<DivBackward0>)

In [8]:
probs[0].sum() # Sum along a row should be 1

tensor(1., grad_fn=<SumBackward0>)

This would have given the output, if W was correct. So, we optimise W

# Training

In [9]:
num_epochs = 300
lr = 50
n = len(X)

In [10]:
X, y

(tensor([[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 1., 0., 0.]]),
 tensor([ 5, 13, 13,  ..., 26, 24,  0], dtype=torch.int32))

In [11]:
for i in range(num_epochs):
    
    # forward pass
    logits = X @ W 
    cts = logits.exp()
    probs = cts/cts.sum(1, keepdim=True)

    # loss function
    loss = -probs[torch.arange(n), y].log().mean() + .01*(W**2).mean() # Negative log likelyhood loss; Last term is for regularisation
    if not i%100:
        print(f'Epoch {i}: {loss.item()}')
    
    
    # backward pass
    W.grad = None
    loss.backward()
    W.data += lr * -W.grad


Epoch 0: 3.7964553833007812
Epoch 100: 2.4904351234436035
Epoch 200: 2.4831576347351074


# Inference

In [12]:
rev_hmap = {value:key for key,value in hmap.items()}
print(rev_hmap)

{0: '.', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z'}


In [13]:
for i in range(10):

    next_index = torch.tensor(0)
    l = []
    while True:
        x = one_hot(next_index,num_classes=27,).float().reshape(1,27)
        logits = x @ W 
        cts = logits.exp()
        probs = cts/cts.sum(1, keepdim=True)
        # print(probs)
        # next_index = torch.argmax(probs) # This chooses the most likely one
        next_index = torch.multinomial(probs,num_samples=1, replacement=True) # This chooses according to probability
        l.append(rev_hmap[next_index.item()])
        
        if next_index == 0:
            break

    print(''.join(l))  

ryninadann.
e.
sesennele.
mbn.
ve.
te.
mayson.
marata.
ka.
banenandedly.
