In [None]:
import torch
import torch.nn.functional as F
from helper import *

In [11]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [74]:
words = read_words('names.txt')
stoi, itos = get_mapping(words)

In [75]:
train_split, val_split = 0.8, 0.1
X, Y = build_dataset(words, stoi, block_size=3)
X, Y = torch.tensor(X), torch.tensor(Y)


n = len(X)
n1 = round(n * train_split)
n2 = round(n * val_split)


X_train, Y_train = X[:n1], Y[:n1]
X_val, Y_val = X[n1:n1+n2], Y[n1:n1+n2]
X_test, Y_test = X[n1+n2:], Y[n1+n2:]

In [100]:
nchars = len(stoi.keys())
W = torch.rand(nchars, nchars, nchars) * 0.1
W.requires_grad = True

In [102]:
iterations = 1000
lr = 0.5
reg  = 0.01

for k in range(iterations):

    W.grad = None
    logits = W[X_train[:, 0], X_train[:, 1]]

    loss = F.cross_entropy(logits, Y_train) + reg*torch.mean(W ** 2)

    loss.backward()

    W.data -= lr * W.grad
    pred = logits.argmax(dim = 1)

    acc = (pred == Y_train).float().mean().data
    print(f"iteration {k} loss {loss.data}, acc {acc * 100}")

iteration 0 loss 3.1808459758758545, acc 19.68255043029785
iteration 1 loss 3.180293321609497, acc 19.68255043029785
iteration 2 loss 3.1797478199005127, acc 19.713232040405273
iteration 3 loss 3.1792094707489014, acc 19.718711853027344
iteration 4 loss 3.178678035736084, acc 19.718711853027344
iteration 5 loss 3.1781527996063232, acc 19.718711853027344
iteration 6 loss 3.1776347160339355, acc 19.73185920715332
iteration 7 loss 3.1771228313446045, acc 19.73185920715332
iteration 8 loss 3.176616907119751, acc 19.73185920715332
iteration 9 loss 3.176116943359375, acc 19.77733612060547
iteration 10 loss 3.1756224632263184, acc 19.77733612060547
iteration 11 loss 3.1751341819763184, acc 19.77733612060547
iteration 12 loss 3.1746509075164795, acc 19.77733612060547
iteration 13 loss 3.17417311668396, acc 19.77733612060547
iteration 14 loss 3.1737005710601807, acc 19.82171630859375
iteration 15 loss 3.1732327938079834, acc 19.82171630859375
iteration 16 loss 3.172769784927368, acc 19.82171630

In [103]:
with torch.no_grad():
    p = W[X_test[:, 0], X_test[:, 1]]
    pred = p.argmax(dim = 1)
    loss = F.cross_entropy(logits, Y_train) + reg*torch.mean(W ** 2)
    acc = (pred == Y_test).float().mean().data
    print(f"loss {loss.data}, acc {acc * 100}")

loss 2.975931167602539, acc 21.819934844970703


In [104]:
g = torch.Generator().manual_seed(2147483647)

for k in range(5):

    out = []
    idx0 = 0
    idx1 = 0
    idx2 = None

    while True:
        logits = W[idx0, idx1]
        p = logits.exp() / logits.exp().sum()
        idx2 = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[idx2])
        if idx2 == 0:
            break
        idx0 = idx1
        idx1 = idx2
    print(''.join(out))

dexzm.
aoglkurkicqzktyhwmvmzimjttainrlkfukzkatda.
rfcxvpubjtbhrmgotzx.
iczixqctvujkwptedogkkjemkmmsidguenkbvgynywftbspmhwcivgbvtahlvsu.
asdxxblnwglhpyiw.


In [98]:
W = torch.rand(nchars * 2, nchars) * 0.1
b = torch.zeros(nchars)

W.requires_grad = True
b.requires_grad = True

In [99]:
iterations = 100
lr = 0.5
reg  = torch.tensor(0.01)
reg.requires_grad = True

for k in range(iterations):

    W.grad = None
    b.grad = None
    reg.grad = None
    
    logits = (W[X_train].sum(1)) + b

    loss = F.cross_entropy(logits, Y_train) + reg*torch.mean(W ** 2)

    # counts = logits.exp()

    # counts_norm = counts * counts.sum(dim = 1, keepdim=True)**-1

    # logs = -counts_norm[torch.arange(x_train.shape[0]), y_train].log()
    # loss = logs.mean() + reg*torch.mean(W ** 2)

    loss.backward()

    W.data -= lr * W.grad
    b.data -= lr * b.grad
    reg.data -= reg * reg.grad

    pred = logits.argmax(dim = 1)

    acc = (pred == Y_train).float().mean().data
    print(f"loss {loss.data}, acc {acc * 100}")


with torch.no_grad():
    p = (W[x_test].sum(1)) + b
    pred = p.argmax(dim = 1)
    acc = (pred == y_test).float().mean().data
    print(f"acc on val {acc * 100}")

loss 3.2963011264801025, acc 3.504331111907959
loss 3.2496845722198486, acc 19.305599212646484
loss 3.20658540725708, acc 20.822717666625977
loss 3.167017698287964, acc 21.366777420043945
loss 3.130943536758423, acc 21.479642868041992
loss 3.098236322402954, acc 21.883987426757812
loss 3.068695545196533, acc 21.921245574951172
loss 3.0420548915863037, acc 21.936038970947266
loss 3.018009662628174, acc 21.96124267578125
loss 2.996251106262207, acc 21.96124267578125
loss 2.976480007171631, acc 22.066438674926758
loss 2.9584341049194336, acc 22.028085708618164
loss 2.941887855529785, acc 22.11246109008789
loss 2.926650285720825, acc 22.079586029052734
loss 2.9125683307647705, acc 22.04178237915039
loss 2.89951229095459, acc 22.03466033935547
loss 2.887378454208374, acc 22.03411102294922
loss 2.876075267791748, acc 22.027536392211914
loss 2.8655264377593994, acc 22.01055335998535
loss 2.8556630611419678, acc 22.006717681884766
loss 2.8464250564575195, acc 22.00562286376953
loss 2.837759017

KeyboardInterrupt: 

In [8]:
g = torch.Generator().manual_seed(2147483647)

for k in range(5):

    out = []
    idx0 = 0
    idx1 = 0
    idx2 = None

    while True:
        logits = (W[idx0] + W[idx1]) + b
        p = logits.exp() / logits.exp().sum()
        idx2 = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[idx2])
        if idx2 == 0:
            break
        idx0 = idx1
        idx1 = idx2
    print(''.join(out))

dexze.
aoallurailazityhn.
rllimjtnainrlkaan.
ka.
aa.
