In [1]:
import torch
import torch.nn.functional as F
from torch import Generator
import matplotlib.pyplot as plt

In [2]:
names=open('names.txt','r').read().splitlines()
names[:4]

['emma', 'olivia', 'ava', 'isabella']

In [3]:
chars=sorted(list(set(''.join(names))))
stoi={s:i+1 for i,s in enumerate(chars)}
stoi['.']=0
itos={i:s for s,i in stoi.items()}

In [4]:
block_size=4
embedding_size=10
hidden_neurons=200
vocab_size=27

In [5]:
import random
def build_tensor(words):
    X,y=[],[]
    for word in words:
        word = list(word) + ['.']
        context=[0] * block_size
        for w in word:
            X.append(context)
            idx=stoi[w]
            y.append(idx)
            context=context[1:]+[idx]
    X=torch.tensor(X)
    y=torch.tensor(y)
    return X,y
random.seed(42)
random.shuffle(names)
n1=int(0.8 * len(names))
n2=int(0.9 * len(names))

x_train,y_train=build_tensor(names[:n1])
x_val,y_val=build_tensor(names[n1:n2])
x_test,y_test=build_tensor(names[n2:])

print("Training : ",x_train.shape,y_train.shape)
print("Validation : ",x_val.shape,y_val.shape)
print("Testing : ",x_test.shape,y_test.shape)


Training :  torch.Size([182625, 4]) torch.Size([182625])
Validation :  torch.Size([22655, 4]) torch.Size([22655])
Testing :  torch.Size([22866, 4]) torch.Size([22866])


In [6]:
x_train[:10],y_train[:10]

(tensor([[ 0,  0,  0,  0],
         [ 0,  0,  0, 25],
         [ 0,  0, 25, 21],
         [ 0, 25, 21,  8],
         [25, 21,  8,  5],
         [21,  8,  5, 14],
         [ 8,  5, 14,  7],
         [ 0,  0,  0,  0],
         [ 0,  0,  0,  4],
         [ 0,  0,  4,  9]]),
 tensor([25, 21,  8,  5, 14,  7,  0,  4,  9, 15]))

In [73]:
g=Generator().manual_seed(27654839)
C=torch.randn((vocab_size,embedding_size),generator=g)

# For first net (Sigmoid function)
W11=torch.randn((embedding_size * block_size,hidden_neurons),generator=g) # * (1/(embedding_size * block_size))**0.5
b11=torch.randn((hidden_neurons,),generator=g)# * (2/(hidden_neurons))**0.5

W12=torch.randn((hidden_neurons,embedding_size * block_size),generator=g)# * (2/(hidden_neurons))**0.5
b12=torch.randn((embedding_size * block_size,),generator=g) #* (2/(embedding_size * block_size))**0.5


W21=torch.randn((embedding_size * block_size * 2,hidden_neurons),generator=g)# * 0.5
b21=torch.randn((hidden_neurons,),generator=g)# * (2/hidden_neurons)**0.5

W22=torch.randn((hidden_neurons,vocab_size),generator=g) * 0.01
b22=torch.randn((vocab_size,),generator=g) * 0

nbias=torch.zeros((hidden_neurons))
ngains=torch.ones((hidden_neurons))


In [74]:
parameters=[C,W11,b11,W12,b12,W21,b21,W22,b22,nbias,ngains]
batch=100
for p in parameters:
    p.requires_grad=True

In [75]:
for i in range (60000):
    ix=torch.randint(0,x_train.shape[0],(batch,))
    Ct=torch.cat(C[x_train[ix]].unbind(1),1)
    lr=0.1 if i<30000 else 0.01
    
    z11 = Ct @ W11 + b11
    z11= ngains * ((z11 - z11.mean(0))/z11.std(0)) + nbias
    a11 = z11.tanh()
    z12 = a11 @ W12 + b12

    a12sig = z12.sigmoid()
    a12tanh=z12.tanh()

    It = a12sig * a12tanh

    Ct= (Ct * a12sig) + It
    Ot= torch.cat((Ct.tanh() , a12sig),dim=1)
    
    z21 = Ot @ W21 + b21
    a21=z21.tanh()
    z22 = a21 @ W22 + b22
    loss=F.cross_entropy(z22,y_train[ix])
    
    for p in parameters:
        p.grad=None
    loss.backward()
    for p in parameters:
        p.data -= lr * p.grad

        
print(loss.item())
    


2.0406246185302734


In [86]:
with torch.no_grad():
    emb=torch.cat(C[x_train].unbind(1),1)
    zemb=emb @ W11 + b11
    bnstd=zemb.std(0)
    bnmean=zemb.mean(0)

In [88]:
Ct=torch.cat(C[x_train].unbind(1),1)
z11 = Ct @ W11 + b11
z11= ngains * ((z11 - bnmean)/bnstd) + nbias
a11 = z11.tanh()
z12 = a11 @ W12 + b12
a12sig = z12.sigmoid()
a12tanh=z12.tanh()
It = a12sig * a12tanh
Ct= (Ct * a12sig) + It
Ot= torch.cat((Ct.tanh() , a12sig),dim=1)
z21 = Ot @ W21 + b21
a21=z21.tanh()
z22 = a21 @ W22 + b22
loss=F.cross_entropy(z22,y_train)
print(loss.item())

2.0890700817108154


In [89]:
Ct=torch.cat(C[x_test].unbind(1),1)
z11 = Ct @ W11 + b11
z11= ngains * ((z11 - bnmean)/bnstd) + nbias
a11 = z11.tanh()
z12 = a11 @ W12 + b12
a12sig = z12.sigmoid()
a12tanh=z12.tanh()
It = a12sig * a12tanh
Ct= (Ct * a12sig) + It
Ot= torch.cat((Ct.tanh() , a12sig),dim=1)
z21 = Ot @ W21 + b21
a21=z21.tanh()
z22 = a21 @ W22 + b22
loss=F.cross_entropy(z22,y_test)
print(loss.item())

2.1376121044158936


In [90]:
Ct=torch.cat(C[x_val].unbind(1),1)
z11 = Ct @ W11 + b11
z11= ngains * ((z11 - bnmean)/bnstd) + nbias
a11 = z11.tanh()
z12 = a11 @ W12 + b12
a12sig = z12.sigmoid()
a12tanh=z12.tanh()
It = a12sig * a12tanh
Ct= (Ct * a12sig) + It
Ot= torch.cat((Ct.tanh() , a12sig),dim=1)
z21 = Ot @ W21 + b21
a21=z21.tanh()
z22 = a21 @ W22 + b22
loss=F.cross_entropy(z22,y_val)
print(loss.item())

2.147047758102417


In [None]:
context = [0] * block_size
Ct = torch.zeros((1, embedding_size * block_size))
name = ''

while True:
    Z = torch.cat(C[context].unbind(1), dim=0).unsqueeze(0)
    z11 = Z @ W11 + b11
    z11= ngains * ((z11 - bnmean)/bnstd) + nbias
    a11 = z11.tanh()
    z12 = a11 @ W12 + b12
    a12sig = z12.sigmoid()
    a12tanh = z12.tanh()
    It = a12sig * a12tanh
    Ct = (Ct * a12sig) + It
    Ot = torch.cat((Ct.tanh(), a12sig), dim=1)
    z21 = Ot @ W21 + b21
    a21 = z21.tanh()
    z22 = a21 @ W22 + b22
    a22 = z22.softmax(dim=1)
    idx = torch.multinomial(a22, num_samples=1).item()
    if idx == 0:
        break
    name += itos[idx]
    context = context[1:] + [idx]
print(name)


torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
torch.Size([1, 27])
eisoaaeaotyyaeieeeyoizioyeaiaiiyeoaott
