In [1]:
import torch
from utils import MLP

In [2]:
raw_names = open("./names.txt").read()
names = raw_names.split('\n')

In [3]:
X = []
Y = []
block_size = 3

letters_n = {v:i+1 for i,v in enumerate(sorted(list(set(raw_names.replace('\n','')))))}
letters_n['.'] = 0
n_letters = {v:i for i,v in letters_n.items()}

for name in names:
    context = [0] * block_size
    for char in name:
        n_char = letters_n[char]
        X.append(context)
        Y.append(n_char)
        
        context = context[1:] + [n_char]

X = torch.Tensor(X)
Y = torch.Tensor(Y)

X.dtype, Y.dtype

(torch.float32, torch.float32)

In [4]:
C = torch.randn((27,3))
C.dtype

torch.float32

In [5]:
embeddings = C[X.type(dtype=torch.int64)]

In [6]:
embeddings[1,:,:], embeddings.shape

(tensor([[ 1.0253,  0.6084, -0.0179],
         [ 1.0253,  0.6084, -0.0179],
         [-1.5227, -0.6673,  0.2682]]),
 torch.Size([196490, 3, 3]))

In [7]:
W1 = torch.randn((9, 100))
b1 = torch.randn(100)

In [8]:
embeddings.view(-1,9)[1], embeddings.view(-1, 9).shape

(tensor([ 1.0253,  0.6084, -0.0179,  1.0253,  0.6084, -0.0179, -1.5227, -0.6673,
          0.2682]),
 torch.Size([196490, 9]))

In [9]:
hypothesis = (embeddings.view(-1,9) @ W1 + b1).relu()

In [10]:
hypothesis

tensor([[0.0000, 0.0000, 0.0000,  ..., 3.4565, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 2.2855, 0.0000, 0.0000],
        [4.7100, 0.0000, 0.0000,  ..., 0.0000, 2.7183, 0.3967],
        ...,
        [3.8638, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.3029],
        [4.0175, 0.5541, 0.0000,  ..., 1.6632, 1.2080, 3.2167],
        [2.8860, 0.0000, 0.0000,  ..., 1.3694, 0.0000, 0.0000]])

In [11]:
hypothesis.shape

torch.Size([196490, 100])

In [12]:
W2 = torch.randn((100, 27))
b2 = torch.randn(27)

In [13]:
logits = hypothesis @ W2 + b2
counts = logits.exp()
counts.shape

torch.Size([196490, 27])

In [14]:
prob = counts / counts.sum(1, keepdim=True)
prob[0].sum()

tensor(1.0000)

In [15]:
model = MLP(9, [100, 27])

In [23]:
model(embeddings.view(-1,9)[0])

[tensor([9.9156], grad_fn=<ReluBackward0>),
 tensor([0.], grad_fn=<ReluBackward0>),
 tensor([0.], grad_fn=<ReluBackward0>),
 tensor([13.9064], grad_fn=<ReluBackward0>),
 tensor([0.], grad_fn=<ReluBackward0>),
 tensor([15.2431], grad_fn=<ReluBackward0>),
 tensor([0.], grad_fn=<ReluBackward0>),
 tensor([4.3940], grad_fn=<ReluBackward0>),
 tensor([18.5980], grad_fn=<ReluBackward0>),
 tensor([0.], grad_fn=<ReluBackward0>),
 tensor([28.8150], grad_fn=<ReluBackward0>),
 tensor([8.6418], grad_fn=<ReluBackward0>),
 tensor([0.], grad_fn=<ReluBackward0>),
 tensor([0.], grad_fn=<ReluBackward0>),
 tensor([3.8244], grad_fn=<ReluBackward0>),
 tensor([0.], grad_fn=<ReluBackward0>),
 tensor([1.3330], grad_fn=<ReluBackward0>),
 tensor([5.9164], grad_fn=<ReluBackward0>),
 tensor([11.5193], grad_fn=<ReluBackward0>),
 tensor([0.], grad_fn=<ReluBackward0>),
 tensor([15.6921], grad_fn=<ReluBackward0>),
 tensor([0.], grad_fn=<ReluBackward0>),
 tensor([0.], grad_fn=<ReluBackward0>),
 tensor([6.0311], grad_fn=