# Predict a name from its first few letters

In [212]:
import torch
import torch.nn.functional as F
import random
print(f"{torch.cuda.is_available()=}")

torch.cuda.is_available()=True


In [51]:
import string
# Character vocabulary: index 0 is the '.' marker, 1..26 map to 'a'..'z'.
itos = dict(enumerate('.' + string.ascii_lowercase))
# Reverse lookup: character -> integer index.
stoi = {ch: idx for idx, ch in itos.items()}
print(stoi)

{'.': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26}


In [439]:
def encode(ss):
    """Translate each character of `ss` into its integer index from `stoi`.

    Returns a list of ints, one per character, in order. A character that
    is not a key of `stoi` raises KeyError.
    """
    return [stoi[ch] for ch in ss]

def decode(ii, tilldot=False):
    """Turn an iterable of integer indices back into a string via `itos`.

    With `tilldot` set, decoding stops at the first index 0 that follows at
    least one non-zero index; zeros before the first non-zero index are
    still emitted. Without `tilldot`, every index is decoded.
    """
    seen_letter = False
    chars = []
    for idx in ii:
        if idx != 0:
            seen_letter = True
        elif tilldot and seen_letter:
            # Index 0 after real characters: end of the word.
            break
        chars.append(itos[idx])
    return ''.join(chars)

In [30]:
# Load the word list: one entry per line of the file.
names_f = "names.txt"
# Explicit encoding so the read does not depend on the platform's locale default.
with open(names_f, encoding="utf-8") as f:
    words = f.read().splitlines()

# Fixed seed so the shuffled order is the same on every run.
random.seed(42)
random.shuffle(words)
print(words[:3])
print(len(words))

['yuheng', 'diondre', 'xavien']
32033


In [99]:
# Scratch check: list.append adds the element to the existing list in place.
aaa=[1,2,3]
aaa.append(4)
aaa

[1, 2, 3, 4]

In [105]:
"."*5


'.....'

In [106]:
def add_word(w, bsz, X, Y):
    """Append one (context, next-character) pair per position of `w`.

    The context is a window of `bsz` indices that starts as all zeros
    (index 0 is '.') and slides one character to the right after each pair.
    After the last character, one more pair with target 0 is appended.

    Args:
        w: word whose characters are all keys of `stoi`.
        bsz: number of indices in each context window.
        X: list extended in place with context windows (lists of ints).
        Y: list extended in place with the matching target indices.

    Raises:
        KeyError: if `w` contains a character missing from `stoi`.
    """
    xi = [0] * bsz
    for y in w:
        yi = stoi[y]
        X.append(xi)
        Y.append(yi)
        # Build a new list rather than mutating the one just stored in X.
        # Slicing after the concatenation keeps the window at exactly `bsz`
        # entries, including the bsz == 0 case.
        xi = (xi + [yi])[1:]
    X.append(xi)
    Y.append(0)

In [408]:
# Expand every word into (context, target) rows with a 4-index context window.
Xa, Ya = [], []
for w in words:
    add_word(w, 4, Xa, Ya)
print(len(Xa))
# Convert the accumulated Python lists into tensors.
X, Y = torch.tensor(Xa), torch.tensor(Ya)
print(f"{X.shape=}")
print(f"{Y.shape=}")

228146
X.shape=torch.Size([228146, 4])
Y.shape=torch.Size([228146])


In [409]:
# Split boundaries: first 80% of rows for training, next 10% for validation,
# the remainder for testing.
n1, n2 = (int(len(X) * frac) for frac in (0.8, 0.9))
print(n1, n2)
X_tr, X_val, X_tst = X[:n1], X[n1:n2], X[n2:]
Y_tr, Y_val, Y_tst = Y[:n1], Y[n1:n2], Y[n2:]
print(f"{X_tr.shape=}")

182516 205331
X_tr.shape=torch.Size([182516, 4])


In [410]:
# Model sizes.
att = 4        # number of context positions fed to the first layer
emb = 3        # embedding width per index
hidden = 100   # hidden-layer width

torch.manual_seed(0)

# Parameters are drawn uniformly from [0, 1). The creation order is kept
# fixed because it determines which random values each tensor receives.
E_w = torch.rand(size=(27, emb))
H0_w = torch.rand(emb * att, hidden)
H0_b = torch.rand(hidden)
H1_w = torch.rand(hidden, 27)
H1_b = torch.rand(27)

In [411]:
# Collect every trainable tensor so the training helpers can loop over them.
params = [E_w, H0_w, H0_b, H1_w, H1_b]
nparams = sum(t.numel() for t in params)
print(f"{nparams=}")
# Turn on gradient tracking in place for each parameter.
for t in params:
    t.requires_grad_(True)

nparams=4108


In [412]:
def forward(X):
    """Compute the network output for a batch of index contexts.

    Looks up a row of `E_w` for every index in `X`, merges dims 1 and 2 of
    the result, applies the tanh hidden layer (`H0_w`, `H0_b`) and then the
    output layer (`H1_w`, `H1_b`). The un-normalised output is returned.
    """
    embedded = E_w[X].flatten(start_dim=1, end_dim=2)
    hidden_act = (embedded @ H0_w + H0_b).tanh()
    return hidden_act @ H1_w + H1_b

def calc_loss(L, Y):
    """Return the cross-entropy of logits `L` against target class indices `Y`."""
    return F.cross_entropy(input=L, target=Y)

def backward(loss: torch.Tensor):
    """Zero any existing gradients on `params`, then backpropagate `loss`."""
    stale_grads = [p.grad for p in params if p.grad is not None]
    for g in stale_grads:
        g.zero_()
    loss.backward()

def update_params(step):
    """Apply one plain gradient-descent step to every tensor in `params`.

    Args:
        step: learning rate that scales each gradient before subtraction.

    Every parameter must already have a populated `.grad`; a parameter
    whose gradient is None raises TypeError on the multiplication.
    """
    # Update in place with autograd recording switched off, instead of going
    # through `.data`, which bypasses autograd's in-place correctness checks.
    with torch.no_grad():
        for p in params:
            p -= step * p.grad


In [413]:
def get_batch(X0, Y0, n):
    """Draw `n` random rows, with replacement, from the paired tensors.

    Args:
        X0: tensor of inputs, indexed along dim 0.
        Y0: tensor of targets, indexed with the same row ids as `X0`.
        n: number of rows to draw.

    Returns:
        Tuple `(X0[rids], Y0[rids])` for one shared set of random row ids.
    """
    # Sample ids from the tensor actually passed in, not from a global split
    # boundary, so the helper is correct for any dataset.
    rids = torch.randint(0, len(X0), (n,))
    return X0[rids], Y0[rids]

In [441]:

# Number of rows drawn per mini-batch (passed to get_batch as `n`).
batch = 32
X0, Y0 = get_batch(X_tr, Y_tr, batch)

# Initial forward pass; the printed value is the loss on one mini-batch
# before any update in this cell.
L = forward(X0)
loss = calc_loss(L, Y0)
print(loss)

N = 10000    # number of parameter updates
step = 0.01  # learning rate handed to update_params
for i in range(N):
    # Each iteration backpropagates the loss from the previous forward pass,
    # applies the update, then draws a fresh mini-batch and recomputes the loss.
    backward(loss)
    update_params(step)
    X0, Y0 = get_batch(X_tr, Y_tr, batch)
    L = forward(X0)
    loss = calc_loss(L, Y0)
# Loss on the last mini-batch drawn, after N updates (a single-batch value).
print(loss)


tensor(2.3403, grad_fn=<NllLossBackward0>)
tensor(2.4568, grad_fn=<NllLossBackward0>)


In [442]:
# Report the loss over the full validation split, then the full test split.
for split_label, X_split, Y_split in (
    ("Validation loss", X_val, Y_val),
    ("Test loss", X_tst, Y_tst),
):
    L = forward(X_split)
    loss = calc_loss(L, Y_split)
    print(split_label, loss)

Validation loss tensor(2.2963, grad_fn=<NllLossBackward0>)
Test loss tensor(2.3015, grad_fn=<NllLossBackward0>)


In [435]:
# Greedy completion of a single prompt: repeatedly feed the last `att`
# indices to the model and append the highest-scoring next index, stopping
# at index 0 ('.') or after 100 steps.
s = encode(".ale")

for i in range(100):
    # Left-pad with 0 ('.') so a prompt shorter than the context window
    # still forms a full-width input row; longer prompts are unaffected.
    ctx = ([0] * att + s)[-att:]
    x = torch.tensor([ctx])
    L = forward(x)
    ci = int(torch.argmax(L).item())
    if ci == 0:
        break
    s.append(ci)
print(decode(s))

.alen


In [445]:
# Greedily extend the first four characters of 25 words by seven steps and
# print each generated string next to the word it was seeded from.
# The slice is named `sample_words` so the integer `batch` size used by the
# training cell is not overwritten.
sample_words = words[30:55]
beg = []
for w in sample_words:
    # Left-pad with '.' so every prefix is exactly four characters; padding
    # only length-3 words would leave shorter words with a ragged row.
    beg.append(encode(w[:4].rjust(4, ".")))
x = torch.tensor(beg)
for i in range(7):
    # Predict from the last four columns and append the argmax as a new column.
    L = forward(x[:, -4:])
    y = torch.argmax(L, dim=1, keepdim=True)
    x = torch.cat([x, y], dim=-1)
for row, w in zip(x.tolist(), sample_words):
    print(decode(row, True), "   ", w)


kymb     kymberlynn
parrie     parrish
hous     houstyn
jamarie     jamaya
ahmon     ahmod
nivin     nivin
milla     milli
crise     cristiana
jaima     jaimee
mitcen     mitchell
nair     nairah
lore     lorena
gente     gentrie
torrie     torrion
savie     savian
benja     benjamine
aire     airess
knut     knute
sultan     sultana
dana     danai
azzan     azzan
issa     issabelle
abra     abrahim
aisley     aislyn
aery     aerys
