In [31]:
# Load the dataset: one name per line, with the line terminators stripped.
# The context manager closes the file handle as soon as the read finishes
# instead of leaving it open until the garbage collector gets to it.
with open('names.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()

In [32]:
# Sanity check: show the first ten names to confirm the file was split into lines.
words[0:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [33]:
"""
Create an encoding from all of the chars we see in the data, plus a single '.' symbol that
marks both the start and the end of a name, and store the bigram occurrences of each
(first char, second char) pair in a tensor. NON neural network implementation.
"""
import torch

# Vocabulary: '.' is index 0; every character seen in the data gets 1..len(chars).
chars = sorted(set(''.join(words)))
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0

# Size the count table from the vocabulary rather than hard-coding 27, so the cell keeps
# working if the dataset's alphabet changes.
vocab_size = len(stoi)
N = torch.zeros((vocab_size, vocab_size), dtype=torch.int32)

# Iterate through each word and count up the bigram information for each name.
# N[a, b] is the number of times character b directly follows character a.
for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        N[stoi[ch1], stoi[ch2]] += 1

In [34]:
# Row 0 of the count table: how often each character follows the '.' marker.
N[0]

tensor([   0, 4410, 1306, 1542, 1690, 1531,  417,  669,  874,  591, 2422, 2963,
        1572, 2538, 1146,  394,  515,   92, 1639, 2055, 1308,   78,  376,  307,
         134,  535,  929], dtype=torch.int32)

In [35]:
# Inverse of stoi: maps an integer index back to its character.
itos = dict((index, symbol) for symbol, index in stoi.items())

In [36]:
"""
Bigram loop for word creation: normalise each row of the count table into a valid
probability distribution, then keep sampling the next character until the end
marker (index 0) is drawn.
"""
P = N.float()
P /= P.sum(1, keepdims=True)  # each row now sums to 1

# Seeded generator so the sampled names are reproducible across kernel restarts.
sample_rng = torch.Generator().manual_seed(2147483647)

for i in range(5):
    out = []
    index = 0  # start from the '.' marker
    while True:
        p = P[index]
        index = torch.multinomial(p, num_samples=1, replacement=True, generator=sample_rng).item()
        out.append(itos[index])
        if index == 0:  # '.' drawn again: the name is finished
            break
    print(''.join(out))

maneladhar.
ah.
milir.
titiashanadcyabe.
zisan.


In [40]:
"""
Neural network implementation: build the bigram training set and initialise the weights.
xs holds the index of the first character of every bigram (the input) and ys holds the
index of the character that follows it (the target).
"""
import torch.nn.functional as F

xs, ys = [], []

for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        xs.append(stoi[ch1])
        ys.append(stoi[ch2])

xs = torch.tensor(xs)
num = xs.nelement()  # number of bigram training examples
ys = torch.tensor(ys)

# One weight row per input character and one column per output character.
# The size comes from the vocabulary instead of a hard-coded 27, and the generator is
# seeded so the initial weights are reproducible across kernel restarts.
init_rng = torch.Generator().manual_seed(2147483647)
W = torch.randn((len(stoi), len(stoi)), generator=init_rng, requires_grad=True)

# Next step: one-hot encode the bigram 'features' (first character) and score them against
# the 'targets' (second character) inside the gradient descent loop.
# This note is a comment rather than a trailing string literal so that the cell does not
# echo it back as its output.

In [44]:
# Gradient descent on the single-layer bigram model:
# 100 steps, learning rate 10, L2 regularisation strength 0.01.

# The one-hot encoded inputs do not change between steps, so encode them once up front.
xenc = F.one_hot(xs, num_classes=W.shape[0]).float()

for k in range(100):
    #forward pass
    logits = xenc @ W  # one row of scores over the next character per training bigram
    # cross_entropy is the mean negative log-probability that softmax(logits) assigns to the
    # targets, i.e. the same quantity as exp -> normalise -> log -> mean, but computed in a
    # numerically stable way that cannot overflow in exp().
    loss = F.cross_entropy(logits, ys) + 0.01 * (W**2).mean()
    print(loss.item())

    #backward pass
    W.grad = None
    loss.backward()

    #update
    # Update in place under no_grad so autograd does not track the step; this replaces
    # the discouraged `.data` access.
    with torch.no_grad():
        W -= 10 * W.grad

2.720113515853882
2.7112045288085938
2.7030258178710938
2.6954712867736816
2.6884593963623047
2.6819231510162354
2.6758086681365967
2.670070171356201
2.664670944213867
2.659579277038574
2.6547675132751465
2.6502106189727783
2.6458890438079834
2.6417832374572754
2.637876510620117
2.6341545581817627
2.630603075027466
2.627211332321167
2.623966932296753
2.6208608150482178
2.6178832054138184
2.6150269508361816
2.612283229827881
2.6096460819244385
2.6071085929870605
2.6046650409698486
2.6023101806640625
2.60003924369812
2.5978472232818604
2.5957298278808594
2.5936837196350098
2.5917043685913086
2.589789390563965
2.587935209274292
2.5861382484436035
2.5843966007232666
2.582707166671753
2.5810675621032715
2.5794761180877686
2.577929973602295
2.576427698135376
2.5749664306640625
2.5735456943511963
2.5721631050109863
2.570816993713379
2.5695064067840576
2.5682294368743896
2.5669851303100586
2.56577205657959
2.5645885467529297
2.563434362411499
2.5623083114624023
2.5612082481384277
2.56013464927

In [46]:
# Sample names from the trained network: feed the current character through the layer,
# turn the scores into a probability distribution and draw the next character, until the
# end marker (index 0) is drawn.

# Seeded generator so the sampled names are reproducible for a given W.
nn_sample_rng = torch.Generator().manual_seed(2147483647)

for i in range(5):
    out = []
    index = 0  # start from the '.' marker
    while True:
        # Inference only: skip building an autograd graph for W.
        with torch.no_grad():
            xenc = F.one_hot(torch.tensor([index]), num_classes=W.shape[0]).float()
            logits = xenc @ W
            # softmax is exp-then-normalise, computed in a numerically stable way
            p = F.softmax(logits, dim=1)

        index = torch.multinomial(p, num_samples=1, replacement=True, generator=nn_sample_rng).item()
        out.append(itos[index])
        if index == 0:  # '.' drawn again: the name is finished
            break
    print(''.join(out))

wukilalyll.
letwee.
i.
bgpcty.
thaml.
