In [65]:
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [24]:
words = open('../../assets/names.txt', 'r').read().splitlines()


In [25]:
g = torch.Generator().manual_seed(2147483647)

In [26]:
chars = sorted(list(set(''.join(words))))
stoi = {char: i + 1 for i, char in enumerate(chars)}
stoi['.'] = 0
itos = {i: char for char, i in stoi.items()}

In [81]:
block_size = 3
Inputs, Outputs = [], []

for w in words[:5]:
    context = [0] * block_size

    for ch in w + '.':
        ix = stoi[ch]
        Inputs.append(context)
        Outputs.append(ix)

        context = context[1:] + [ix]

X = torch.tensor(Inputs)
Y = torch.tensor(Outputs)


In [82]:
X.shape, Y.shape

(torch.Size([32, 3]), torch.Size([32]))

In [83]:
C= torch.randn(27, 2, generator=g, requires_grad=True)

In [84]:
# Using X as a shopping list of which rows to pick from C!
emb = C[X]


In [85]:
# its ( 32, 3, 2 ) because it looks something like this
# [[num1, num2], [num1, num2], [num1, num2]] -> Index 0
# [[num1, num2], [num1, num2], [num1, num2]] -> Index 1
# [[num1, num2], [num1, num2], [num1, num2]] -> Index 2
# ...
# [[num1, num2], [num1, num2], [num1, num2]] -> Index 32

# Because X is of the shape [32, 3] and C is of the shape [27, 2] --> every value in X is an index to pick from C

emb.shape

torch.Size([32, 3, 2])

In [86]:
# Tanh activation function
W1 = torch.randn(6, 100)
B1 = torch.tensor(100)

h = torch.tanh(emb.view(-1, 6) @ W1 + B1)

In [87]:
# Softmax activation function
W2 = torch.randn(100, 27)
B2 = torch.tensor(27)

In [88]:
# Activation function, with Weight and Bias so we can tune the probablities of desired output more.
logits = h @ W2 + B2
counts = logits.exp()

# Normalize the counts
probs = counts / counts.sum(dim=1, keepdim=True)

In [89]:
loss = -probs[torch.arange(32), Y].log().mean()

In [90]:
loss

tensor(20.2819, grad_fn=<NegBackward0>)

In [92]:
loss = F.cross_entropy(logits, Y)
loss

tensor(20.2819, grad_fn=<NllLossBackward0>)

In [93]:
parameters = [C, W1, B1, W2, B2]

In [94]:
sum(p.nelement() for p in parameters)

3356

In [95]:
for p in parameters:
    p.grad = None
loss.backward()