In [64]:
import torch
import torch.nn.functional as F

In [65]:
# Read the training corpus: one name per line. The encoding is pinned so the
# result does not depend on the platform's locale default.
with open('data/names.txt', 'r', encoding='utf-8') as f:
    names_list = f.read().splitlines()

In [66]:
# Vocabulary: every distinct character appearing in names_list, sorted
all_chars = sorted(set(''.join(names_list)))

# Encoder (char -> int). Real characters get ids 1..len(all_chars); id 0 is
# reserved for '.', which stands in for both the <S> and <E> boundary tokens.
stoi = {ch: idx for idx, ch in enumerate(all_chars, start=1)}
stoi['.'] = 0

# Decoder (int -> char): the inverse mapping of stoi
itos = {idx: ch for ch, idx in stoi.items()}

In [67]:
def built_dataset(text, encoder=None):
    """Turn a list of names into aligned bigram (input, target) index tensors.

    Each name is wrapped with the boundary token '.' on both sides, and every
    adjacent character pair (c1, c2) contributes one example: the index of c1
    is appended to ``xs`` and the index of c2 to ``ys``.

    Args:
        text: Iterable of strings (e.g. ``names_list``).
        encoder: Optional char -> int mapping. Defaults to the notebook-level
            ``stoi`` so existing calls keep working unchanged.

    Returns:
        Tuple ``(xs, ys)`` of 1-D ``torch.long`` tensors of equal length.

    Raises:
        KeyError: if a character (or '.') is missing from the encoder.
    """
    mapping = stoi if encoder is None else encoder
    xs, ys = [], []

    for t in text:
        t_str = ['.'] + list(t) + ['.']
        for c1, c2 in zip(t_str, t_str[1:]):
            xs.append(mapping[c1])
            ys.append(mapping[c2])

    # dtype is pinned: torch.tensor([]) would otherwise be float32, which
    # F.one_hot and index-based losses reject.
    return torch.tensor(xs, dtype=torch.long), torch.tensor(ys, dtype=torch.long)

In [68]:
# Bigram training pairs for the whole corpus: xs = current char id, ys = next char id
xs, ys = built_dataset(names_list)

In [69]:
# One-hot encode the inputs. The width is derived from the vocabulary
# (len(stoi)) instead of a hard-coded 27, so it stays in sync with the encoder.
xs_encoded = F.one_hot(xs, num_classes=len(stoi)).float()

In [70]:
# Dedicated, explicitly seeded RNG so the weight initialisation is reproducible
g = torch.Generator()
g.manual_seed(283839281)

In [71]:
# Weight matrix of the single-layer bigram model: multiplying a one-hot row
# for character i selects row i, i.e. the logits for the next character.
# Sized from the vocabulary instead of a hard-coded 27 so it matches the encoder.
vocab_size = len(stoi)
W = torch.randn(size=(vocab_size, vocab_size), generator=g, requires_grad=True)

In [72]:
def train(epochs, lr, weights=None, inputs=None, targets=None, log_every=10):
    """Fit the bigram weight matrix with full-batch gradient descent.

    Args:
        epochs: Number of gradient-descent steps.
        lr: Learning rate (step size).
        weights: Leaf tensor with ``requires_grad=True``, updated in place.
            Defaults to the notebook-level ``W``.
        inputs: Float input matrix of shape ``(N, vocab)``. Defaults to the
            one-hot ``xs_encoded``.
        targets: Integer next-character indices of shape ``(N,)``. Defaults
            to ``ys``.
        log_every: Print the loss every ``log_every`` epochs; ``0``/``None``
            disables logging.

    Returns:
        The loss (float) of the last epoch, or ``None`` when ``epochs`` is 0.
    """
    weights = W if weights is None else weights
    inputs = xs_encoded if inputs is None else inputs
    targets = ys if targets is None else targets

    loss = None
    for k in range(epochs):
        # forward pass
        logits = inputs @ weights
        # cross_entropy is the mean negative log of softmax(logits) at the
        # target index (what the manual exp/normalise/log computed), but it
        # works in log-space so large logits cannot overflow exp() to inf/NaN.
        loss = F.cross_entropy(logits, targets)
        if log_every and k % log_every == 0:
            print("epoch {}: loss = {}".format(k, loss))

        # Backpropagation: drop the previous gradient so backward() does not accumulate
        weights.grad = None
        loss.backward()

        # Gradient step outside autograd (replaces the discouraged `.data` access)
        with torch.no_grad():
            weights -= lr * weights.grad

    return None if loss is None else loss.item()

In [73]:
# Full-batch training run: 500 gradient steps with learning rate 5
train(lr=5, epochs=500)

epoch 0: loss = 3.8166439533233643
epoch 10: loss = 3.428638458251953
epoch 20: loss = 3.190485954284668
epoch 30: loss = 3.0372958183288574
epoch 40: loss = 2.9339752197265625
epoch 50: loss = 2.8602824211120605
epoch 60: loss = 2.804910659790039
epoch 70: loss = 2.761800765991211
epoch 80: loss = 2.7273361682891846
epoch 90: loss = 2.6991896629333496
epoch 100: loss = 2.675816059112549
epoch 110: loss = 2.6561551094055176
epoch 120: loss = 2.639446258544922
epoch 130: loss = 2.6251208782196045
epoch 140: loss = 2.612736940383911
epoch 150: loss = 2.601944923400879
epoch 160: loss = 2.5924651622772217
epoch 170: loss = 2.584073543548584
epoch 180: loss = 2.5765907764434814
epoch 190: loss = 2.569873332977295
epoch 200: loss = 2.56380558013916
epoch 210: loss = 2.558295488357544
epoch 220: loss = 2.55326771736145
epoch 230: loss = 2.548661231994629
epoch 240: loss = 2.5444252490997314
epoch 250: loss = 2.5405187606811523
epoch 260: loss = 2.5369060039520264
epoch 270: loss = 2.53355717

In [75]:
# Sample 20 names from the trained bigram model, using its own seeded
# generator so the samples are reproducible.
gg = torch.Generator().manual_seed(23879322)
n_classes = W.shape[0]  # one-hot width must match the weight matrix

# Inference only: no_grad stops autograd from recording a graph through W
# (created with requires_grad=True) for every sampled character.
with torch.no_grad():
    for i in range(20):
        ix = 0  # The start token '.' of the generated name
        generated_name = []
        while True:
            xenc = F.one_hot(torch.tensor([ix]), num_classes=n_classes).float()
            logits = xenc @ W
            counts = logits.exp()
            p = counts / counts.sum(1, keepdims=True)
            ix = torch.multinomial(p, num_samples=1, replacement=True, generator=gg).item()
            if ix == 0:  # Sampled the end token '.', so the name is complete
                break
            generated_name.append(itos[ix])

        print(''.join(generated_name))

deri
kay
hieprbchahy
wran
iyah
mo
e
men
canty
ph
aya
ery
aghafiauiylae
st
a
bea
zzm
y
brvelans
kos
