## Imports

In [139]:
import torch
from pathlib import Path
import matplotlib.pyplot as plt

## Load dataset

In [124]:
# Read one name per line. `splitlines()` handles both "\n" and "\r\n", and blank
# lines are dropped: with `split("\n")`, a trailing newline at the end of the file
# leaves an empty "" entry that later becomes a spurious "." -> "." training bigram.
names = [
    line.strip()
    for line in Path("../names.txt").read_text(encoding="utf-8").splitlines()
    if line.strip()
]
print(f"There are {len(names)} names in the dataset")
print(f"First 10 are:\n{names[:10]}")

## Prepare the character-level tokenizer

In [125]:
# Character <-> id maps. "." (id 0) marks the start/end of a name; a-z get ids 1-26.
tokenizer = {char: token_id for token_id, char in enumerate(".abcdefghijklmnopqrstuvwxyz")}
# Inverse map (id -> character), used to decode sampled ids back into text.
detokenizer = dict(zip(tokenizer.values(), tokenizer.keys()))
print(tokenizer)
print(detokenizer)

## Prepare the training dataset

In [227]:
# Build the bigram training set. Each name is wrapped in "." markers and every
# adjacent pair (ch1, ch2) becomes one example: "ch1 is followed by ch2".
# Blank entries are skipped so they cannot add a spurious "." -> "." pair.
vocab_size = len(tokenizer)
input_ids = []
target_ids = []
for name in names:
    name = name.strip()
    if not name:
        continue
    padded = "." + name + "."
    for ch1, ch2 in zip(padded, padded[1:]):
        input_ids.append(tokenizer[ch1])
        target_ids.append(tokenizer[ch2])

# Inputs: one-hot float rows of shape (num_bigrams, vocab_size), ready for `@ W`.
training_input = torch.nn.functional.one_hot(torch.tensor(input_ids), num_classes=vocab_size).float()
# Targets: integer class ids of shape (num_bigrams,), plus a one-hot float copy.
training_output = torch.tensor(target_ids)
training_output_onehot = torch.nn.functional.one_hot(training_output, num_classes=vocab_size).float()

In [195]:
# Sanity check: one-hot inputs should be (num_bigrams, 27) with dtype float32.
print(*(getattr(training_input, attr) for attr in ("shape", "dtype")))

In [228]:
# Targets as class ids (num_bigrams,) and as one-hot rows (num_bigrams, 27).
print(training_output.shape, training_output_onehot.shape, sep="\n")

## Initialize random weights

In [233]:
# Seeded generator so the initial weights (and therefore the training results)
# are reproducible under Restart & Run All.
init_generator = torch.Generator().manual_seed(2147483647)
# With one-hot inputs, `x @ W` selects row i of W, so W[i, j] is the logit for
# character j following character i. Size is derived from the tokenizer rather
# than hard-coded.
W = torch.randn(
    (len(tokenizer), len(tokenizer)),
    generator=init_generator,
    dtype=torch.float32,
    requires_grad=True,
)

## Training

In [234]:
# Full-batch gradient descent on the bigram model with a two-stage learning-rate
# schedule. After each stage, print the loss from that stage's last step, which is
# evaluated before that step's update.
LEARNING_RATE_SCHEDULE = [(0.1, 100), (0.01, 100)]  # (learning_rate, steps)
# NOTE(review): plain full-batch GD at these rates may move W only slowly;
# tune the schedule if the loss plateaus high.
for learning_rate, steps in LEARNING_RATE_SCHEDULE:
    for _ in range(steps):
        logits = training_input @ W
        # Class-id targets are equivalent to one-hot probability targets, but
        # cheaper, and they don't require torch>=1.10.
        loss = torch.nn.functional.cross_entropy(logits, training_output)
        # backward() *accumulates* into .grad, so clear the previous step's gradient.
        W.grad = None
        loss.backward()
        # In-place update of the leaf tensor, outside autograd tracking.
        with torch.no_grad():
            W -= learning_rate * W.grad
    print(f"lr={learning_rate}: loss={loss.item():.4f}")

## Sampling

In [239]:
# Draw 10 names from the bigram model. Each sample starts at the "." marker and is
# extended one character at a time until "." is drawn again. The printed string
# keeps both "." markers.
sampling_generator = torch.Generator().manual_seed(2147483647)  # reproducible samples
with torch.no_grad():  # inference only; don't build an autograd graph
    for _ in range(10):
        sample = "."
        token_id = tokenizer["."]
        while True:
            # one_hot(token_id) @ W is exactly row `token_id` of W.
            logits = W[token_id]
            # softmax == exp / sum(exp), but numerically stable for large logits.
            probs = torch.softmax(logits, dim=0)
            token_id = torch.multinomial(probs, 1, generator=sampling_generator).item()
            generated = detokenizer[token_id]
            sample += generated
            if generated == ".":
                break
        print(sample)