# Imports

In [1]:
import torch

In [2]:
import torch.nn.functional as F


# loading the dataset

In [10]:
def get_data():
    with open('names.txt', 'r') as f:
        names = f.read().splitlines()
    return names
    
names = get_data()

In [7]:
# creating "stoi" and "itos"
chars = sorted(list(set(''.join(names))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}



# creating the training data

In [26]:
xs, ys = [], []

for n in names:
    chs = ['.', '.'] + list(n) + ['.','.']
    for chs1, chs2, chs3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[chs1]
        ix2 = stoi[chs2]
        ix3 = stoi[chs3]
        xs.append((ix1, ix2))
        ys.append(ix3)
    

xs = torch.tensor(xs)
ys = torch.tensor(ys)


In [32]:
#one-hot encoded of xs
xenc = F.one_hot(xs, num_classes=27).float()
xenc[0]

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [28]:
xenc[0][0]

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [34]:
xenc.shape

torch.Size([260179, 2, 27])

In [31]:
test = torch.cat([xenc[0][0], xenc[0][1]])
test

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [35]:
xenc_cat = xenc.view(-1, 2*27)
xenc_cat[0]

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

# NN

In [37]:
# initializing W
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((2*27, 27), generator=g, requires_grad=True)

In [47]:
num = xenc.shape[0]
num

260179

In [50]:
for i in range(100):

    logits = xenc_cat @ W #equivalent to count's logarithm
    counts = logits.exp()
    probs = counts/counts.sum(1, keepdim=True)
    loss = -probs[torch.arange(num), ys].log().mean()
    print("loss: ", loss.item())

    W.grad = None
    loss.backward()

    #update
    W.data += -10 * W.grad

    

loss:  2.2992334365844727
loss:  2.295475482940674
loss:  2.2918198108673096
loss:  2.2882637977600098
loss:  2.2848024368286133
loss:  2.28143310546875
loss:  2.2781527042388916
loss:  2.2749581336975098
loss:  2.271845579147339
loss:  2.268812417984009
loss:  2.265856981277466
loss:  2.2629752159118652
loss:  2.2601656913757324
loss:  2.25742506980896
loss:  2.254751682281494
loss:  2.252143383026123
loss:  2.2495977878570557
loss:  2.24711275100708
loss:  2.2446866035461426
loss:  2.2423171997070312
loss:  2.2400026321411133
loss:  2.237741470336914
loss:  2.2355315685272217
loss:  2.2333714962005615
loss:  2.23125958442688
loss:  2.2291946411132812
loss:  2.2271745204925537
loss:  2.225198268890381
loss:  2.2232649326324463
loss:  2.221372365951538
loss:  2.21951961517334
loss:  2.217705488204956
loss:  2.215928316116333
loss:  2.2141876220703125
loss:  2.212482213973999
loss:  2.210810899734497
loss:  2.209172487258911
loss:  2.207566261291504
loss:  2.205991268157959
loss:  2.204