In [1]:
from enumerations.data_splits import DataSplits

import torch
torch.manual_seed(1337)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

context_length = 256
batch_size = 64
embed_size = 512
num_layers = 6
num_heads = 8
forward_expansion = 4
dropout = 0.2

max_iters = 100
eval_iters = 2
eval_interval = 5
learning_rate = 3e-4

In [2]:
with open('datasets/shakespeare.txt') as file:
    text = file.read()

In [3]:
from tokenizer.bpe_tokenizer import BytePairEncodingTokenizer
from pipelines.text_to_tensor import create_pipeline

text = text[:50000]

tokenizer = BytePairEncodingTokenizer(1000)

tokenizer.fit([text])

print(tokenizer.vocab_size)

pipeline = create_pipeline(tokenizer, 0.9)

train_data, test_data = pipeline.transform(text)

1095


In [4]:
from typing import Tuple

import torch

def get_batch(split: DataSplits, batch_size: int, context_length: int, device: str = 'cpu') -> Tuple[torch.tensor, torch.tensor]:
    data = train_data if split == DataSplits.TRAIN else test_data

    assert len(data) - context_length >= 0, 'Length of data is shorter than context_length'

    idx = torch.randint(len(data) - context_length, (batch_size, ))
    x = torch.stack([data[i:i + context_length] for i in idx])
    y = torch.stack([data[i + 1:i + context_length + 1] for i in idx])
    x, y = x.to(device), y.to(device)
    return x, y

In [5]:
from language_model.model import LanguageModel

model = LanguageModel(
    tokenizer,
    embed_size,
    context_length,
    num_layers,
    num_heads,
    forward_expansion,
    dropout,
    device
)

print(sum([p.numel() for p in model.encoder.parameters()]) / 1e6, 'M parameters')

20.159559 M parameters


In [6]:
model.predict(' ', max_new_tokens=50)

" three*state,$ratslegshallNoblebarsyselfherdillLet'scupuefairhe.woreot.him.Officer@tribunesbeforecarmineThatstrangeDpleaseHectorit;heardkindifMes^minKingmarknewelsecruelIt4uchHis"

In [7]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.encoder.eval()
    for split in [DataSplits.TRAIN, DataSplits.TEST]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split, batch_size, context_length)
            logits, loss = model.encoder(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.encoder.train()
    return out

In [8]:
optimizer = torch.optim.AdamW(model.encoder.parameters(), lr=0.05)

for iter in range(max_iters):
    if iter % eval_interval == 0 or iter == max_iters -1 :
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses[DataSplits.TRAIN]:.4f}, test loss {losses[DataSplits.TEST]:.4f}")

    x_batch, y_batch = get_batch(DataSplits.TRAIN, batch_size, context_length)

    logits, loss = model.encoder(x_batch, y_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

model.predict(' ', 50)

step 0: train loss 7.0785, test loss 7.0803
step 1: train loss 7.1452, test loss 7.0664


' ayadaddidagainstbutguCOMINIUS:speakshowil\nwebeen-s!VIRGILIA:ghtcopRceCaiusghtTotellafme,atshiartthckyouguVIRGILIA:didoffTheyofftrs!CORIOLANUS:becomeshihimfiWaccbu'

In [9]:
model.predict(' ', 50)

" COMINIUS:ce,me,teaveitizenheardblewasthusWeiercosdrwarrmfrimeghtabhim,illupsSentMys!s!LARTIUS:bodi's!WhichayshowbrCORIOLANUS:AUFIDIUS:s!ayEarmIueoffanywork"