## Train the model on Shakespeare's works

In [None]:
import os
import pathlib
import re
from urllib.request import urlopen


shakespeare_path = pathlib.Path("../data/shakespeare.txt")

# Load the corpus from the local cache, downloading (and caching) it on first use.
if shakespeare_path.exists():
    print("Loading Shakespeare...")
    with open(shakespeare_path, "r", encoding="utf-8") as f:
        text = f.read()
else:
    print("Fetching Shakespeare..")
    url = "https://www.gutenberg.org/files/100/100-0.txt"
    # Close the HTTP response deterministically instead of leaving it open.
    with urlopen(url) as response:
        text = response.read().decode("utf-8")
    # Reading the cached file in text mode translates "\r\n" and "\r" to "\n"
    # (universal newlines). Apply the same translation here so the download run
    # tokenizes exactly the text that later cache runs will see.
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    # open(..., "w") does not create missing directories.
    shakespeare_path.parent.mkdir(parents=True, exist_ok=True)
    with open(shakespeare_path, "w", encoding="utf-8") as f:
        f.write(text)

# NOTE(review): `tokenize` is not defined in this cell; presumably it comes from
# an earlier notebook cell (the Tokenizer class below has its own method) — confirm.
tokens = tokenize(text)

print(f"Shakespeare text: {len(text)} characters, {len(tokens)} tokens")


Loading Shakespeare...
Shakespeare text: 5392638 characters, 1991703 tokens


In [None]:
from collections import Counter
from typing import Iterable, TypeVar

@dataclass(frozen=True, slots=True)
class Tokenizer:
    d_vocab: int
    tok2int: dict[str, int]
    int2tok :dict[int, str]
    
    @staticmethod
    def split_into_tokens(text: str) -> list[str]:
        return re.split(r"\b", text)
    
    @classmethod
    def make(cls, text: str) -> Tokenizer:
        tokens = cls.split_into_tokens(text)
        token_counts = Counter(tokens)
        d_vocab = len(token_counts) + 1 # plus BOS/EOS
        tok2int = {tok: i for i, (tok, _) in enumerate(sorted(token_counts.items(), key=lambda x: x[1], reverse=True))}
        int2tok = {i: tok for tok, i in tok2int.items()}
        assert len(tok2int) == d_vocab - 1
        return cls(d_vocab, tok2int, int2tok)
    
    @property
    def eos(self) -> int:
        return self.d_vocab - 1
    @property
    def token_set(self) -> set[str]:
        return set(self.tok2int)
    
    def tokenize(self, text: str) -> tuple[list[str], list[int]]:
        tokens = self.split_into_tokens(text)
        assert set(tokens) <= self.token_set
        token_ids = [self.tok2int[tok] for tok in tokens]
        return tokens, token_ids
    
    def decode(self, inds: Iterable[int]) -> list[str]:
        assert all(i < self.d_vocab for i in inds)
        return [self.int2tok[i] for i in inds]
        


T = TypeVar("T")


def split_into_pieces(xs: list[T], n_pieces: int, piece_length: int) -> list[list[T]]:
    """Cut the start of *xs* into consecutive, non-overlapping pieces.

    Piece ``i`` is ``xs[i * piece_length : (i + 1) * piece_length]``. Items
    after the last piece are ignored.

    Raises:
        ValueError: if either count is negative, or if *xs* holds fewer than
            ``n_pieces * piece_length`` items. An exact fit is allowed.
    """
    if n_pieces < 0 or piece_length < 0:
        raise ValueError("n_pieces and piece_length must be non-negative")
    needed = n_pieces * piece_length
    if needed > len(xs):
        raise ValueError(
            f"need {needed} items for {n_pieces} pieces of length {piece_length}, "
            f"got {len(xs)}"
        )
    pieces = [xs[i * piece_length : (i + 1) * piece_length] for i in range(n_pieces)]
    assert all(len(p) == piece_length for p in pieces)
    return pieces
    

# Build the vocabulary from the full corpus, then encode the corpus with it.
tokenizer = Tokenizer.make(text)
tokens, token_ids = tokenizer.tokenize(text)

cfg = Config(
    d_model=128,
    d_vocab=tokenizer.d_vocab,
    n_layers=4,
    n_heads=8,
    n_ctx=256,
)

# split_into_pieces takes consecutive slices (no randomness); this seed only
# affects later code that draws from Python's `random` module.
random.seed(42)
# 256 non-overlapping n_ctx-token sequences taken from the start of the corpus.
pieces = split_into_pieces(token_ids, n_pieces=256, piece_length=cfg.n_ctx)
token_tensor = t.tensor(pieces)  # (n_pieces, n_ctx)

# Sanity check: show the first 20 tokens of the first piece.
print(tokenizer.decode(pieces[0][:20]))

['\ufeff', 'The', ' ', 'Project', ' ', 'Gutenberg', ' ', 'eBook', ' ', 'of', ' ', 'The', ' ', 'Complete', ' ', 'Works', ' ', 'of', ' ', 'William']


In [None]:
def loss_fn(logits: t.Tensor, tokens: t.Tensor) -> t.Tensor:
    """Mean next-token negative log-likelihood.

    The logits at position ``i`` are scored against the token at position
    ``i + 1``. The averaging covers every (batch, position) pair.
    """
    # Shift: the last position has no next token, the first token has no predictor.
    pred_logits = logits[:, :-1]
    targets = tokens[:, 1:]
    log_probs = pred_logits.log_softmax(-1)
    target_log_probs = t.gather(log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)
    return target_log_probs.mean().neg()

In [None]:
N_BATCHES = 8
# Every batch must receive the same number of sequences.
assert token_tensor.shape[0] % N_BATCHES == 0
# Split dim 0 into N_BATCHES equal groups, keeping the sequence (last) dim.
batches = token_tensor.reshape(N_BATCHES, -1, token_tensor.shape[-1])
batches.shape

torch.Size([8, 32, 256])

In [None]:
model = Transformer(cfg)

N_EPOCHS = 100
LR = 1e-4

optimizer = optim.AdamW(model.parameters(), lr=LR)
# Multiply the LR by 0.1 after `patience` epochs without improvement in the
# value passed to scheduler.step() (the epoch loss, below).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.1, patience=4)

loss_history = []


for epoch_i in range(N_EPOCHS):
    epoch_losses = []

    for batch in batches:
        # PyTorch accumulates .grad across backward() calls; without this every
        # step would apply the sum of all previous gradients.
        optimizer.zero_grad()
        # Forward
        logits = model(batch)
        # Loss
        loss = loss_fn(logits, batch)
        # Backward and update
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    # Measure
    epoch_loss = t.tensor(epoch_losses).mean().item()
    loss_history.append(epoch_loss)
    # ReduceLROnPlateau only acts when it is stepped with the monitored metric.
    scheduler.step(epoch_loss)

    print(f"[{epoch_i}]: loss = {epoch_loss:.3f}")

[0]: loss = 497.613
[1]: loss = 475.608
[2]: loss = 456.961
[3]: loss = 438.606
[4]: loss = 417.623
[5]: loss = 394.853
[6]: loss = 370.869
[7]: loss = 356.367
[8]: loss = 350.797
[9]: loss = 343.557
[10]: loss = 333.555
[11]: loss = 320.975
[12]: loss = 308.607
[13]: loss = 300.223
[14]: loss = 294.453
[15]: loss = 289.125
[16]: loss = 283.711
[17]: loss = 278.533
[18]: loss = 274.038
[19]: loss = 269.739
[20]: loss = 265.175
[21]: loss = 261.107
[22]: loss = 257.682
[23]: loss = 253.756
[24]: loss = 248.863
[25]: loss = 243.134
[26]: loss = 237.000
[27]: loss = 231.162
[28]: loss = 226.516
[29]: loss = 223.017
[30]: loss = 220.411
[31]: loss = 218.557
[32]: loss = 216.727
[33]: loss = 213.388
[34]: loss = 207.911
[35]: loss = 201.511
[36]: loss = 195.534
[37]: loss = 190.740
[38]: loss = 187.545
[39]: loss = 185.531
[40]: loss = 183.982
[41]: loss = 182.198
[42]: loss = 179.854
[43]: loss = 177.266
[44]: loss = 174.394
[45]: loss = 171.092
[46]: loss = 167.486
[47]: loss = 163.795
[4

KeyboardInterrupt: 

In [None]:
from datetime import datetime
import pickle

# Timestamp without ':' (illegal in some filesystems) and without microseconds.
dt_str = datetime.now().isoformat().replace(":", "-").split(".")[0]
model_filepath = f"../models/model-1-{dt_str}.pkl"
# open(..., "wb") does not create missing directories.
pathlib.Path(model_filepath).parent.mkdir(parents=True, exist_ok=True)
with open(model_filepath, "wb") as f:
    # NOTE(review): pickling the whole module ties the file to this exact class
    # definition; torch.save(model.state_dict(), ...) is the more portable format.
    pickle.dump(model, f)

# TODO

- add a BOS token
- retrain the model on the maximum number of pieces the corpus allows
- generate Shakespeare-style text (or something similar)