In [44]:
import torch
import torch.nn.functional as tfunc
import matplotlib.pyplot as plt


# Allocate every newly created tensor on the GPU by default. A CUDA device is
# required: the parameter cell also builds its RNG with device="cuda".
torch.set_default_device("cuda")

# Build Dataset

In [45]:
# Read one name per line. The context manager guarantees the file handle is
# closed, and the explicit encoding keeps decoding independent of the locale.
with open("../names.txt", "r", encoding="utf-8") as names_file:
    words = names_file.read().splitlines()

# Vocabulary: every distinct character gets an index starting at 1; index 0 is
# reserved for ".", which terminates each word and pads the initial context.
chars = sorted(set("".join(words)))
char_to_indx = {char: indx + 1 for indx, char in enumerate(chars)}
char_to_indx["."] = 0
indx_to_char = {indx: char for char, indx in char_to_indx.items()}

BLOCK_SIZE = 3  # number of preceding characters used to predict the next one

# Accumulate plain Python rows first, then convert once to tensors, so that
# `inputs` / `labels` only ever name tensors.
context_rows, label_rows = [], []
for word in words:
    # Every word starts from an all-"." context.
    context = [0] * BLOCK_SIZE
    for label in word + ".":
        label_indx = char_to_indx[label]
        context_rows.append(context)
        label_rows.append(label_indx)
        # Slide the window: drop the oldest character, append the one just seen.
        # This builds a new list, so rows already stored are not mutated.
        context = context[1:] + [label_indx]

inputs = torch.tensor(context_rows)
labels = torch.tensor(label_rows)

In [46]:
# Sanity check: one context row is appended per label, so both tensors should
# have the same first dimension, and `inputs` should have BLOCK_SIZE columns.
inputs.shape, labels.shape

(torch.Size([228331, 3]), torch.Size([228331]))

In [47]:
# Params
# Layer shapes are derived from named sizes (and from the dataset cell) so that
# changing BLOCK_SIZE or the vocabulary cannot silently desynchronise the
# weight matrices.
# NOTE(review): the original hard-coded a vocabulary of 27 — confirm that
# len(char_to_indx) == 27 for names.txt so initial values stay identical.
VOCAB_SIZE = len(char_to_indx)  # one embedding row / output logit per character, "." included
EMBED_DIM = 2                   # size of each character embedding
HIDDEN_SIZE = 100               # units in the tanh hidden layer

# Fixed seed so that parameter initialisation is reproducible.
gen = torch.Generator(device="cuda").manual_seed(2147483647)
C = torch.randn((VOCAB_SIZE, EMBED_DIM), generator=gen, requires_grad=True)  # Lookup table
# The hidden layer consumes BLOCK_SIZE embeddings concatenated together.
weights1 = torch.randn((BLOCK_SIZE * EMBED_DIM, HIDDEN_SIZE), generator=gen, requires_grad=True)
bias1 = torch.randn(HIDDEN_SIZE, generator=gen, requires_grad=True)
weights2 = torch.randn((HIDDEN_SIZE, VOCAB_SIZE), generator=gen, requires_grad=True)
bias2 = torch.randn(VOCAB_SIZE, generator=gen, requires_grad=True)
parameters = [C, weights1, bias1, weights2, bias2]
print(f"TOTAL PARAMS: {sum(param.nelement() for param in parameters)}")

TOTAL PARAMS: 3481


In [48]:
EPOCHS = 1000
LEARNING_RATE = 0.1

# Full-batch gradient descent: every iteration runs the whole dataset through
# the network and takes one step.
for _ in range(EPOCHS):
    # Forward pass.
    embed = C[inputs]  # look up one embedding per context position
    # Concatenate the per-position embeddings into a single row per example.
    joined_embed = embed.view(embed.shape[0], embed.shape[1] * embed.shape[2])
    layer1_out = torch.tanh(joined_embed @ weights1 + bias1)
    logits = layer1_out @ weights2 + bias2
    loss = tfunc.cross_entropy(logits, labels)
    # NOTE(review): this prints on every step; consider logging every N steps
    # to keep the cell output short.
    print(f"{loss.item()=}")

    # Backward pass: drop stale gradients first so they do not accumulate.
    for param in parameters:
        param.grad = None
    loss.backward()

    # Gradient-descent step. Updating in place under no_grad() keeps the update
    # out of the autograd graph without going through the discouraged `.data`.
    with torch.no_grad():
        for param in parameters:
            param -= LEARNING_RATE * param.grad

loss.item()=14.744425773620605
loss.item()=13.678918838500977
loss.item()=12.868958473205566
loss.item()=12.26665210723877
loss.item()=11.80759334564209
loss.item()=11.410181045532227
loss.item()=11.042360305786133
loss.item()=10.694540023803711
loss.item()=10.362454414367676
loss.item()=10.043938636779785
loss.item()=9.73803424835205
loss.item()=9.444653511047363
loss.item()=9.16425895690918
loss.item()=8.897695541381836
loss.item()=8.645363807678223
loss.item()=8.407538414001465
loss.item()=8.183372497558594
loss.item()=7.971580505371094
loss.item()=7.770554065704346
loss.item()=7.578848361968994
loss.item()=7.395312786102295
loss.item()=7.219094753265381
loss.item()=7.049475193023682
loss.item()=6.886199951171875
loss.item()=6.729057312011719
loss.item()=6.578070163726807
loss.item()=6.433620929718018
loss.item()=6.295994281768799
loss.item()=6.16542911529541
loss.item()=6.041723728179932
loss.item()=5.924273490905762
loss.item()=5.812203407287598
loss.item()=5.704646587371826
loss.