In [46]:
import torch
from torch import nn
from torch.nn import functional as F
import tokenizers
import math

In [25]:
with open('corpus/tiny_shakespeare.txt', 'r') as file:
    raw_text = file.read()

In [26]:
tokenizer = tokenizers.WordTokenizer(raw_text, vocab_size=3000)

In [27]:
VOCAB_SIZE = len(tokenizer.get_vocab_mapping())
EMB_SIZE = 128
HIDDEN_SIZE = 128
CTX_LEN = 32

In [28]:
encoded = tokenizer.encode(raw_text)
batched = [encoded[i:i + CTX_LEN] for i in range(len(encoded) - CTX_LEN)]
batched = torch.tensor(batched)
sub_batched = batched[:10]

In [36]:
embedding = torch.randn((VOCAB_SIZE, EMB_SIZE)) * math.sqrt(2 / VOCAB_SIZE)
hiddenWx = torch.randn((EMB_SIZE, HIDDEN_SIZE)) * math.sqrt(2 / EMB_SIZE)
hiddenWh = torch.randn((EMB_SIZE, HIDDEN_SIZE)) * math.sqrt(2 / EMB_SIZE)
hiddenB = torch.zeros((HIDDEN_SIZE,))
outputW = torch.randn((HIDDEN_SIZE, VOCAB_SIZE)) * math.sqrt(2 / HIDDEN_SIZE)
outputB = torch.zeros((VOCAB_SIZE,))
params = [embedding, hiddenWx, hiddenWh, hiddenB, outputW, outputB]
for p in params:
    p.requires_grad = True

In [37]:
sum(p.numel() for p in params)

804153

In [38]:
def rnn_forward(x, state):
    emb = F.one_hot(x, num_classes=VOCAB_SIZE).float() @ embedding
    new_state = F.tanh(emb @ hiddenWx + state @ hiddenWh + hiddenB)
    logits = new_state @ outputW + outputB
    return logits, new_state

In [39]:
def predict(prompt: str, max_new_tokens: int):
    ids = tokenizer.encode(prompt)
    state, logits = torch.zeros((1, HIDDEN_SIZE)), None
    for i in ids:
        logits, state = rnn_forward(torch.tensor([i]), state)
    ids.append(torch.argmax(logits).item())
    for i in range(max_new_tokens):
        logits, state = rnn_forward(torch.tensor(ids[-1]), state)
        ids.append(torch.argmax(logits).item())
    # print(ids)
    return tokenizer.decode(ids)

In [40]:
print(predict('lord', 32))

lordthunderdrewoursremainsHastingstakesmissjoysslayhornroguefewforcerunseestBelieveMeratherforknavescontractforswornearsthreadstamptriumphantbeholdingloseBolingbrokestarsfightconsentrt


In [47]:
BATCH_SIZE = 10240
EPOCHS = 1
LR = 0.01

for epoch in range(EPOCHS):
    for b_i in range(0, len(batched), BATCH_SIZE):
        mini_batch = batched[b_i:b_i + BATCH_SIZE]
        state = torch.zeros((len(mini_batch), HIDDEN_SIZE))
        accu_loss = 0.0
        for i in range(CTX_LEN - 1):
            logits, state = rnn_forward(mini_batch[:, i], state.clone().detach().requires_grad_(True))
            logits = torch.softmax(logits, dim=1)
            y = mini_batch[:, i + 1]
            nll = -torch.log(logits[torch.arange(len(logits)), y]).mean()
            for p in params:
                p.grad = None
            nll.backward(retain_graph=False)
            for p in params:
                p.data -= LR * p.grad
            accu_loss += nll.item()
        print(f'Epoch: {epoch} Step: {b_i}, Loss: {accu_loss / (CTX_LEN - 1):.2f}')

Epoch: 0 Step: 0, Loss: 3.15
Epoch: 0 Step: 10240, Loss: 3.11
Epoch: 0 Step: 20480, Loss: 3.11
Epoch: 0 Step: 30720, Loss: 3.07
Epoch: 0 Step: 40960, Loss: 3.10
Epoch: 0 Step: 51200, Loss: 3.12
Epoch: 0 Step: 61440, Loss: 3.07
Epoch: 0 Step: 71680, Loss: 3.02
Epoch: 0 Step: 81920, Loss: 3.03
Epoch: 0 Step: 92160, Loss: 3.02
Epoch: 0 Step: 102400, Loss: 3.02
Epoch: 0 Step: 112640, Loss: 3.04
Epoch: 0 Step: 122880, Loss: 3.10
Epoch: 0 Step: 133120, Loss: 3.08
Epoch: 0 Step: 143360, Loss: 3.08
Epoch: 0 Step: 153600, Loss: 3.07
Epoch: 0 Step: 163840, Loss: 3.07
Epoch: 0 Step: 174080, Loss: 3.08
Epoch: 0 Step: 184320, Loss: 3.07
Epoch: 0 Step: 194560, Loss: 3.01
Epoch: 0 Step: 204800, Loss: 3.06
Epoch: 0 Step: 215040, Loss: 3.11
Epoch: 0 Step: 225280, Loss: 3.06
Epoch: 0 Step: 235520, Loss: 3.03
Epoch: 0 Step: 245760, Loss: 3.03
Epoch: 0 Step: 256000, Loss: 3.05


KeyboardInterrupt: 

In [48]:
predict('First', 32)

'First <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> '

In [44]:
raw_text[:100]

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'