In [1]:
# import libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

In [432]:
# Read names from file
# One name per line; the context manager closes the file handle as soon as
# the contents have been read instead of leaving it to the garbage collector.
with open('first_name.txt', 'r') as f:
    names = f.read().splitlines()
# Lower-case every name so it can be encoded with the lowercase-only vocabulary.
names = [name.lower() for name in names]
names[:10], len(names)

(['aachal',
  'aadharsh',
  'aadhavi',
  'aadhira',
  'aadidev',
  'aadil',
  'aadita',
  'aaditya',
  'aadiv',
  'aadrik'],
 2195)

In [436]:
import random
# Peek at ten names chosen at random (without replacement) from the full list.
random.sample(names, k=10)

['chandrashekhar',
 'prayadarshi',
 'devbrat',
 'ravinderpreet',
 'suramya',
 'janaki',
 'anshit',
 'devirupa',
 'avnindra',
 'swasti']

In [23]:
# Create lookup tables
# itos maps index -> character: '.' (index 0) is the padding symbol and
# ';' (index 27) is the end-of-name marker.
itos = '.abcdefghijklmnopqrstuvwxyz;'
# stoi is the inverse mapping, character -> index.
stoi = dict(zip(itos, range(len(itos))))

In [35]:
import random

# Create dataset
def create_dataset(names, context_size=3):
    """Turn a list of names into (context, next-character) training pairs.

    Each name is left-padded with ``context_size`` '.' characters and
    terminated with ';', then a window of ``context_size`` characters is slid
    over it. Padding with exactly ``context_size`` dots (rather than a fixed
    two) guarantees that the first real character of every name is also a
    prediction target, whatever context size is requested.

    Args:
        names: iterable of strings whose characters are all keys of ``stoi``.
        context_size: number of preceding characters used as context.

    Returns:
        (X, Y) where X is an int64 tensor of shape (N, context_size + 1) --
        the encoded context followed by the window's start offset ``i`` as an
        extra position feature -- and Y is an int64 tensor of shape (N,) with
        the encoded character that follows each context.

    Raises:
        KeyError: if a name contains a character that is not in ``stoi``.
    """
    X, Y = [], []
    for name in names:
        padded = '.' * context_size + name + ';'
        encoded = [stoi[ch] for ch in padded]
        for i in range(len(padded) - context_size):
            # Context indices plus the window offset as a position feature.
            X.append(encoded[i:i + context_size] + [i])
            Y.append(encoded[i + context_size])
    # Explicit dtype/shape keeps the result well-formed even when no names
    # were supplied (torch.tensor([]) would otherwise be a 1-D float tensor).
    X = torch.tensor(X, dtype=torch.long).reshape(len(X), context_size + 1)
    Y = torch.tensor(Y, dtype=torch.long)
    return X, Y

# Shuffle the names once so the splits below are random partitions by name.
names = random.sample(names, len(names))
# Split boundaries: first 80% of names for training, next ~10% for
# validation, the remainder for testing.
n_train, n_val = int(0.8*len(names)), int(0.9*len(names))
# Encode every split with a 2-character context.
(X, Y), (X_val, Y_val), (X_test, Y_test) = [
    create_dataset(split, 2)
    for split in (names[:n_train], names[n_train:n_val], names[n_val:])
]
# Show the first ten training examples as (input row, target) pairs.
for pair in zip(X[:10], Y[:10]):
    print(*pair)

tensor([0, 0, 0]) tensor(25)
tensor([ 0, 25,  1]) tensor(21)
tensor([25, 21,  2]) tensor(20)
tensor([21, 20,  3]) tensor(9)
tensor([20,  9,  4]) tensor(11)
tensor([ 9, 11,  5]) tensor(1)
tensor([11,  1,  6]) tensor(27)
tensor([0, 0, 0]) tensor(2)
tensor([0, 2, 1]) tensor(8)
tensor([2, 8, 2]) tensor(1)


In [763]:
import random

# Create model
# All parameters are standard-normal initialised and track gradients. They
# are drawn in the order C, W1, b1, W2, b2:
#   C  - embedding table, one 2-d vector per vocabulary entry (28 entries)
#   W1 - hidden layer weights: two 2-d embeddings plus 1 position feature -> 100
#   b1 - hidden layer bias
#   W2 - output layer weights: 100 hidden units -> 28 logits
#   b2 - output layer bias
C, W1, b1, W2, b2 = (
    torch.randn(*shape, requires_grad=True)
    for shape in [(28, 2), (2*2 + 1, 100), (100,), (100, 28), (28,)]
)

# Flat list of every trainable tensor, used for regularisation and updates.
P = [C, W1, b1, W2, b2]

def forward(X):
    X, XI = X[:, :-1], X[:, -1:]
    X = F.embedding(X, C)
    X = X.view(-1, 2*2)
    X = torch.cat([X, XI.float()], 1)
    X = F.tanh(X @ W1 + b1)
    X = X @ W2 + b2
    return X

def loss(X, Y, weight_decay=1e-4):
    """Cross-entropy of the model on (X, Y) plus an L2 penalty on all parameters.

    Args:
        X: batch of encoded inputs, passed unchanged to ``forward``.
        Y: tensor of target class indices, one per row of X.
        weight_decay: coefficient multiplying the sum of squares of each
            tensor in ``P``. The default reproduces the previously
            hard-coded 0.0001.

    Returns:
        Scalar tensor: mean cross-entropy + weight_decay * sum(p**2) over P.
    """
    # Use a distinct local name so the function's own name is not shadowed,
    # and accumulate out-of-place rather than mutating the cross-entropy result.
    total = F.cross_entropy(forward(X), Y)
    for p in P:
        total = total + weight_decay * (p ** 2).sum()
    return total

def accuracy(X, Y):
    """Fraction of rows in X whose highest-scoring class equals the target in Y.

    Args:
        X: batch of encoded inputs, passed unchanged to ``forward``.
        Y: tensor of target class indices, one per row of X.

    Returns:
        Scalar float tensor in [0, 1].
    """
    # This is a pure evaluation, so skip building an autograd graph for the
    # forward pass; the returned value is identical.
    with torch.no_grad():
        predicted = forward(X).argmax(1)
    return (predicted == Y).float().mean()

def step(X, Y, lr=0.01, batch_size=32):
    """Run one mini-batch gradient-descent update on the parameters in ``P``.

    Args:
        X: full tensor of encoded inputs to sample the batch from.
        Y: tensor of targets aligned with X.
        lr: learning rate of the plain SGD update.
        batch_size: number of distinct rows to sample; capped at ``len(X)``.

    Returns:
        The batch loss (as a Python float) computed before the update.
    """
    # random.sample raises ValueError when asked for more items than exist,
    # so never request more rows than the dataset holds.
    batch_size = min(batch_size, len(X))
    idx = random.sample(range(len(X)), batch_size)
    l = loss(X[idx], Y[idx])
    # Clear stale gradients so backward() does not accumulate across steps.
    for p in P:
        p.grad = None
    l.backward()
    # Update in place without recording the operation in the autograd graph
    # (replaces writing through the legacy .data attribute).
    with torch.no_grad():
        for p in P:
            # A parameter that did not take part in the loss has no gradient.
            if p.grad is not None:
                p -= lr * p.grad
    return l.item()

def train(X, Y, epochs=10000, lr=0.01, batch_size=32):
    """Run ``epochs`` mini-batch updates, reporting progress every 1000 updates.

    Each iteration performs a single call to ``step`` (one sampled batch), so
    ``epochs`` counts updates rather than passes over the data. On iterations
    0, 1000, 2000, ... the iteration number, that batch's loss and the
    accuracy on the full (X, Y) are printed.
    """
    for iteration in range(epochs):
        batch_loss = step(X, Y, lr, batch_size)
        if iteration % 1000:
            continue
        full_accuracy = accuracy(X, Y).item()
        print(iteration, batch_loss, full_accuracy)

In [762]:
# total number of params
# (sum of the element counts of every tensor in the parameter list P)
sum(param.numel() for param in P)

3484

In [764]:
# Train model
# 10,000 updates at learning rate 0.01, each on a batch of 2048 sampled rows.
train(X, Y, epochs=10000, lr=0.01, batch_size=2048)

0 17.149017333984375 0.05961377173662186
1000 3.6223037242889404 0.24700404703617096
2000 3.21490216255188 0.2953973114490509
3000 2.980501651763916 0.31051063537597656
4000 2.857954263687134 0.31302952766418457
5000 2.873389482498169 0.3178383409976959
6000 2.708935260772705 0.3234867453575134
7000 2.7502834796905518 0.32440271973609924
8000 2.699195146560669 0.32638728618621826
9000 2.6800663471221924 0.3288298547267914


In [768]:
# Predict
def predict(name='', max_len=20):
    """Generate a name, optionally continuing from a given prefix.

    Characters are produced one at a time from the last two characters of
    the running context plus the current position. At position ``pos`` the
    next character is sampled from the softmax distribution with probability
    ``0.8 ** pos`` and taken greedily (argmax) otherwise, so later characters
    are increasingly deterministic. Generation stops at the end marker ';'
    or once ``max_len`` positions have been filled.

    Args:
        name: prefix to continue. It is lower-cased first, matching how the
            training names were normalised.
        max_len: upper bound on the position index (prefix included). The
            default reproduces the previously hard-coded limit of 20.

    Returns:
        The generated name with padding and end markers removed.

    Raises:
        KeyError: if the prefix contains a character that is not in ``stoi``.
    """
    name = name.lower()
    end_index = stoi[';']
    context = [0, 0, 0, 0]  # padding indices ahead of the first character
    context.extend(stoi[ch] for ch in name)
    # Inference only: no need to record an autograd graph.
    with torch.no_grad():
        for pos in range(len(name), max_len):
            x = torch.tensor([context[-2:] + [pos]])
            logits = forward(x)
            if random.random() < 0.8 ** pos:
                y = torch.multinomial(F.softmax(logits, 1), 1).item()
            else:
                y = logits.argmax(1).item()
            context.append(y)
            if y == end_index:
                break
    decoded = ''.join(itos[i] for i in context)
    return decoded.replace('.', '').replace(';', '')

# Print five names generated from scratch (empty prefix is predict's default).
for _ in range(5):
    print(predict())

nani
aesha
manin
suhun
karin
