In [1]:
import torch
import torch.nn.functional as F



In [2]:
names = open('../data/names.txt', 'r').read().splitlines()
names[:5]

['emma', 'olivia', 'ava', 'isabella', 'sophia']

In [3]:
chs = sorted(set(''.join(names)))
stoi = {s: i+1 for i, s in enumerate(chs)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}

In [4]:
X, y = [], []

block_size = 3


for name in names:
    
    context = [0] * block_size
    
    for ch in name + '.':
        
        idx = stoi[ch]
        
        X.append(context)
        y.append(idx)
        
        context = context[1:] + [idx]
    
X = torch.tensor(X)
y = torch.tensor(y)

In [6]:
emb_size = 2
C = torch.randn(X.shape[0], emb_size)
W1 = torch.randn(C.shape[1] * block_size, 100)
b1 = torch.randn(100)
W2 = torch.randn(100, 27)
b2 = torch.randn(27)
parameters = [C, W1, b1, W2, b2]
for p in parameters:
    p.requires_grad = True

In [7]:
emb = C[X].view(-1, emb_size*block_size)
emb.shape

torch.Size([228146, 6])

In [8]:
h = torch.tanh(emb @ W1 + b1)
h.shape

torch.Size([228146, 100])

In [9]:
logits = h @ W2 + b2
logits

tensor([[  6.7799,   3.6519,   2.0483,  ...,  10.3901,  -4.2688,   3.4928],
        [  3.5535,  10.7990,  -2.2045,  ...,   9.7832,  -8.4031,   8.0200],
        [ -1.5537,  12.3513,  -8.3341,  ...,   6.6112,   0.6207,  13.7417],
        ...,
        [  1.7672, -10.0364,  -6.9470,  ...,  -5.3411,  -9.3878,  -9.7148],
        [ -1.1193, -10.3435, -13.0041,  ...,   0.4428,  -3.7063,  -5.6216],
        [  3.7308,  -3.2651, -17.6721,  ...,   5.2089,  -6.1043,   1.4450]],
       grad_fn=<AddBackward0>)

In [10]:
loss = F.cross_entropy(logits, y)
loss

tensor(13.3536, grad_fn=<NllLossBackward0>)

In [11]:
for p in parameters:
    p.grad = None

In [12]:
loss.backward()

In [13]:
for p in parameters:
    p.data += -0.1 * p.grad

In [54]:
losses = []

for i in range(100_000):
    
    ix = torch.randint(0, X.shape[0], (1, 32))
    
    emb = C[X[ix]].view(-1, emb_size*block_size)
    h = torch.tanh(emb @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, y[ix][0])
    
    if i % 10_000 == 0:
        print(loss.item())
        losses.append(loss.item())
    
    for p in parameters:
        p.grad = None
        
    loss.backward()
    
    for p in parameters:
        p.data += -0.1 * p.grad
        
f"avg loss = {sum(losses) / len(losses)}"

2.207291841506958
2.3573668003082275
2.1506762504577637
2.90224289894104
2.4033026695251465
2.4140212535858154
2.281104803085327
2.1587536334991455
2.129782199859619
2.335639715194702


'avg loss = 2.3340182065963746'

In [84]:
for _ in range(10):

    ch = [0] * block_size
    
    new_name = []
    
    while True:

        emb = C[ch].view(block_size * emb_size)
        h = torch.tanh(emb @ W1 + b1)
        logits = h @ W2 + b2
        counts = logits.exp()
        probs = counts / counts.sum()

        out = torch.multinomial(probs, 1)
        
        if out == 0:
            break
        
        new_name.append(itos[out.item()])
        
        ch = ch[1:] + [out.item()]
        
    print(''.join(new_name))



sur
milismana
tevpel
das
horsie
lon
slavy
quon
elia
jaytana
