In [839]:
# Hyperparameters:
#   block_size - number of previous characters used as context for each prediction
#   Cdim       - width of each character's embedding vector
#   seed       - seed for the torch generators used for init, batching and sampling
block_size, Cdim, seed = 3, 10, 8566354565

In [840]:
# Load the word list (whitespace-separated). The encoding is explicit so the result
# does not depend on the machine's locale default.
with open("most_used_words.txt", encoding="utf-8") as f:
    words = f.read().split()
len(words)

97565

In [841]:
# Character vocabulary. "." is the word-boundary token and is pinned to index 0:
# prepare_dataset pads contexts with 0 and the sampling loop stops when index 0 is
# drawn. Relying on sort order alone would break if any character sorts before "."
# (e.g. "-" or "'"), and ".".join() adds no "." at all for a single-word list.
chars = ["."] + sorted(set("".join(words)) - {"."})
atoi = {a: i for i, a in enumerate(chars)}
itoa = {i: a for i, a in enumerate(chars)}

In [842]:
import torch

In [843]:
def prepare_dataset(words, context_len=None, char_index=None):
    """Build (context, next-character) pairs for a character-level model.

    Each word is extended with a trailing "." end token. For every character,
    the indices of the preceding ``context_len`` characters (left-padded with
    index 0) form one row of X, and that character's index is the target in Y.

    Args:
        words: iterable of strings.
        context_len: number of context characters per example; defaults to the
            global ``block_size``.
        char_index: mapping from character to integer index; defaults to the
            global ``atoi``. Must contain "." and every character in ``words``.

    Returns:
        (X, Y): X is a long tensor of shape (N, context_len), Y a long tensor of
        shape (N,), where N is the total number of characters plus one end token
        per word.

    Raises:
        KeyError: if a character is missing from ``char_index``.
    """
    if context_len is None:
        context_len = block_size
    if char_index is None:
        char_index = atoi

    X, Y = [], []
    for word in words:
        context = [0] * context_len
        for ch in word + ".":
            ix = char_index[ch]
            X.append(context)
            Y.append(ix)
            # Slide the window by one; written so it also stays empty when context_len == 0.
            context = (context + [ix])[1:]

    # Explicit dtype and shape: torch.tensor([]) would otherwise give a 1-D float
    # tensor for an empty word list instead of (0, context_len) indices.
    X = torch.tensor(X, dtype=torch.long).reshape(len(X), context_len)
    Y = torch.tensor(Y, dtype=torch.long)
    return X, Y

In [844]:
# Preparing data: 80/10/10 train/dev/test split at the word level.
# Shuffle a *copy* with a dedicated RNG so re-running this cell always yields the
# same split (shuffling `words` in place would re-shuffle an already-shuffled list
# and move words between splits). random.Random(77) reproduces the sequence that
# random.seed(77) + random.shuffle gave, without touching the global random state.
import random
shuffled_words = words.copy()
random.Random(77).shuffle(shuffled_words)

n1 = int(0.8*len(shuffled_words))
n2 = int(0.9*len(shuffled_words))

Xtr, Ytr = prepare_dataset(shuffled_words[:n1])
Xdev, Ydev = prepare_dataset(shuffled_words[n1:n2])
Xte, Yte = prepare_dataset(shuffled_words[n2:])
Xtr.shape, Xdev.shape, Xte.shape

(torch.Size([670727, 3]), torch.Size([83673, 3]), torch.Size([84292, 3]))

In [845]:
# Model parameters: embedding table C -> tanh hidden layer (W1, b1) -> output logits (W2, b2).
# The vocabulary size comes from the data rather than a hard-coded 27, so a larger
# character set does not cause out-of-range indexing into C.
# NOTE(review): the first logged training loss (~25.8) is far above ln(27) ~= 3.3,
# which suggests the unscaled randn init yields overconfident initial logits; scaling
# W2/b2 down is a common remedy. It is left unchanged here to keep results reproducible.
vocab_size = len(chars)
n_hidden = 200
g = torch.Generator().manual_seed(seed)
C = torch.randn((vocab_size, Cdim), generator=g)
W1 = torch.randn((Cdim*block_size, n_hidden), generator=g)
b1 = torch.randn((n_hidden,), generator=g)
W2 = torch.randn((n_hidden, vocab_size), generator=g)
b2 = torch.randn((vocab_size,), generator=g)
parameters = [C, W1, b1, W2, b2]
for parm in parameters:
    parm.requires_grad_(True)
sum(parm.nelement() for parm in parameters)

11897

In [846]:
# Mini-batch SGD: 100k steps of 50 examples; learning rate 0.1 for the first
# 50k steps, then 0.01. Batch indices come from their own seeded generator.
bg = torch.Generator().manual_seed(seed)
for i in range(100000):
    lr = 0.01 if i >= 50000 else 0.1

    # Draw a mini-batch of (context, target) rows.
    ix = torch.randint(0, Xtr.shape[0], (50,), generator=bg)
    Xb, Yb = Xtr[ix], Ytr[ix]

    # Forward: embed each context, flatten, tanh hidden layer, output logits.
    emb = C[Xb].view(-1, Cdim*block_size)
    h = torch.tanh(emb @ W1 + b1)
    logits = h @ W2 + b2
    loss = torch.nn.functional.cross_entropy(logits, Yb)

    # Backward: drop stale gradients, then backpropagate.
    for parm in parameters:
        parm.grad = None
    loss.backward()

    # Parameter update, kept out of the autograd graph.
    with torch.no_grad():
        for parm in parameters:
            parm -= lr * parm.grad

    if i % 5000 == 0:
        print(loss.item())
print(loss.item())

25.809968948364258
2.777660608291626
2.6043267250061035
2.6095633506774902
2.464766502380371
2.491333246231079
2.450813055038452
2.2223892211914062
2.6658132076263428
2.8251757621765137
2.679687261581421
2.5435519218444824
2.4230458736419678
2.5283854007720947
2.0144906044006348
2.3614447116851807
2.617199420928955
2.2770345211029053
2.255686044692993
2.448884963989258
2.7411715984344482


In [847]:
# Loss over the full training split. Evaluation only, so run under no_grad: otherwise
# autograd would build (and keep alive) a graph over every row's hidden activations.
with torch.no_grad():
    h = torch.tanh(C[Xtr].view(-1, Cdim*block_size) @ W1 + b1)
    logits = h @ W2 + b2
    loss = torch.nn.functional.cross_entropy(logits, Ytr)
print(loss.item())

2.335796356201172


In [848]:
# Loss over the dev split. Evaluation only, so run under no_grad to avoid building
# an autograd graph over every row's hidden activations.
with torch.no_grad():
    h = torch.tanh(C[Xdev].view(-1, Cdim*block_size) @ W1 + b1)
    logits = h @ W2 + b2
    loss = torch.nn.functional.cross_entropy(logits, Ydev)
print(loss.item())

2.343080759048462


In [849]:
# Loss over the held-out test split. Evaluation only, so run under no_grad to avoid
# building an autograd graph over every row's hidden activations.
with torch.no_grad():
    h = torch.tanh(C[Xte].view(-1, Cdim*block_size) @ W1 + b1)
    logits = h @ W2 + b2
    loss = torch.nn.functional.cross_entropy(logits, Yte)
print(loss.item())

2.3401174545288086


In [850]:
# Sample 10 words: start from an all-padding context, draw the next character from
# the model's softmax, slide it into the context, and stop once index 0 (".") is drawn.
# The word is accumulated in `out` and printed once (including the final ".").
g = torch.Generator().manual_seed(seed)
with torch.no_grad():  # inference only; no autograd graph needed
    for _ in range(10):
        out = ""
        st = [0]*block_size
        while True:
            h = torch.tanh(C[torch.tensor(st)].view(-1, Cdim*block_size) @ W1 + b1)
            logits = h @ W2 + b2
            prob = torch.nn.functional.softmax(logits, dim=1)
            ix = torch.multinomial(prob, num_samples=1, generator=g).item()
            out += itoa[ix]
            st = st[1:] + [ix]
            if ix == 0:
                break
        print(out)

carted.
valandy.
cotion.
caratived.
usly.
but.
dew.
tarncturwazorye.
rein.
deloneron.
