## Imports and setup

In [1]:
import sys
# Sanity check: at least one sys.path entry must mention "deep_learning_curriculum".
# On failure the full sys.path is shown as the assertion message.
# NOTE(review): presumably this guards the later `config` / `model` imports — confirm.
assert any("deep_learning_curriculum" in p for p in sys.path), sys.path

In [2]:
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass, field
import itertools as it
import random
from frozendict import frozendict
import os
import pathlib
import re
from urllib.request import urlopen

import numpy as np
import torch as t
from torch import nn, optim
from tqdm import tqdm

from config import Config
from model import Transformer

In [3]:
# Locate the repository root: walk up from the current working directory until
# reaching a directory whose path ends with "_curriculum".
PATH = pathlib.Path(os.getcwd())
while not str(PATH).endswith("_curriculum"):
    if PATH.parent == PATH:
        # The parent of the filesystem root is the root itself, so without this
        # guard the loop would spin forever when run outside the repository.
        raise RuntimeError(
            f"No ancestor of {os.getcwd()!r} ends with '_curriculum'; "
            "run this notebook from inside the curriculum repository."
        )
    PATH = PATH.parent
print(f"{PATH = }")

PATH = PosixPath('/home/matthewbaggins/code/deep_learning_curriculum')


## Train the model on Shakespeare's works

In [4]:
# Directory under the repository root where the downloaded corpus is cached.
data_path = PATH / "data"

def load_corpus_text() -> str:
    """Return the Shakespeare corpus text, downloading and caching it on first use.

    The text is cached at ``data_path / "shakespeare.txt"``. If that file
    exists it is read and returned; otherwise the text is fetched from Project
    Gutenberg, written to the cache, and returned.

    Returns:
        The corpus as a single string with ``"\\n"`` line endings.

    Raises:
        urllib.error.URLError: if the download fails.
        OSError: if the cache directory or file cannot be read or written.
    """
    # parents/exist_ok make this safe when intermediate directories are missing
    # and avoid the check-then-create race of a separate exists() test.
    data_path.mkdir(parents=True, exist_ok=True)

    shakespeare_path = data_path / "shakespeare.txt"

    if shakespeare_path.exists():
        print("Loading Shakespeare...")
        with open(shakespeare_path, "r", encoding="utf-8") as f:
            return f.read()

    print("Fetching Shakespeare..")
    url = "https://www.gutenberg.org/files/100/100-0.txt"
    # Close the HTTP response deterministically instead of leaking it.
    with urlopen(url) as response:
        text = response.read().decode("utf-8")
    # Reading the cache in text mode translates "\r\n" and "\r" to "\n".
    # Normalise the freshly downloaded text the same way so the first run
    # returns exactly what every later (cached) run returns.
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    with open(shakespeare_path, "w", encoding="utf-8") as f:
        f.write(text)
    return text
    

def tokenize(text: str) -> list[str]:
    """Split ``text`` at regex word boundaries into non-empty tokens.

    Runs of word characters and the runs of non-word characters between them
    each become one token, so ``"".join(tokenize(text)) == text``.

    ``re.split`` on ``\\b`` yields an empty string whenever the text starts or
    ends on a boundary (and ``[""]`` for empty input); those empty strings are
    dropped so they never become vocabulary entries.

    Args:
        text: The raw text to tokenize.

    Returns:
        The tokens in order of appearance; empty for empty input.
    """
    return [piece for piece in re.split(r"\b", text) if piece]

EOS_TOKEN_STR = "<EOS>"
EOS_TOKEN_INT = 0

@dataclass(frozen=True, slots=True)
class Corpus:
    text: str
    tokens_str: list[str]
    tokens_int: list[int]
    tok_int2str: frozendict[int, str]
    tok_str2int: frozendict[str, int]
    token_counts: frozendict[str, int]
    
    @classmethod
    def load(cls) -> Corpus:
        text = load_corpus_text()
        tokens_str = tokenize(text)
        token_counts = Counter(tokens_str)
        tok_int2str: dict[int, str] = {EOS_TOKEN_INT: EOS_TOKEN_STR}
        tok_str2int: dict[str, int] = {EOS_TOKEN_STR: EOS_TOKEN_INT}
        for tok_int, (tok_str, _tok_count) in enumerate(
            sorted(token_counts.items(), key=lambda x: x[1], reverse=True),
            start=1
        ):
            assert tok_int not in tok_int2str
            assert tok_str not in tok_str2int
            tok_int2str[tok_int] = tok_str
            tok_str2int[tok_str] = tok_int
        tokens_int = [tok_str2int[tok_str] for tok_str in tokens_str]
        corpus = cls(
            text=text,
            tokens_str=tokens_str,
            tokens_int=tokens_int,
            tok_int2str=frozendict(tok_int2str),
            tok_str2int=frozendict(tok_str2int),
            token_counts=frozendict(token_counts)
        )
        print(f"Shakespeare text: {len(text)} characters, {len(tokens_str)} tokens")
        return corpus
    
    def get_corpus_subsequences(self, subseq_len: int, *, pad_with_eos: bool = True) -> t.Tensor: # [batch pos]
        n_subseqs = len(self) // subseq_len
        if pad_with_eos:
            subseq_len -= 1
        subseqs = t.tensor(
            [
                self.tokens_int[i * subseq_len : (i + 1) * subseq_len]
                for i in range(n_subseqs)
            ],
            dtype=t.int64
        )
        if pad_with_eos:
            subseqs = t.cat(
                [
                    t.tensor(list(it.repeat(EOS_TOKEN_INT, n_subseqs))).reshape(n_subseqs, 1),
                    subseqs
                ],
                dim=1
            )
        return subseqs
    
    def __len__(self) -> int:
        return len(self.tokens_str)
    
    @property
    def vocab_size(self) -> int:
        return len(self.token_counts)

In [5]:
corpus = Corpus.load()

Loading Shakespeare...
Shakespeare text: 5392638 characters, 1991703 tokens


In [6]:
def loss_fn(logits: t.Tensor, tokens: t.Tensor) -> t.Tensor:
    """Mean next-token negative log-likelihood.

    The logits at position ``i`` are scored against the token at position
    ``i + 1``, so the final position's logits and the first token are unused.

    Args:
        logits: Indexed as ``[batch, pos, vocab]`` by this function.
        tokens: Integer token ids indexed as ``[batch, pos]``.

    Returns:
        A scalar tensor: the negated mean log-probability of the true next tokens.
    """
    next_tokens = tokens[:, 1:]
    pred_log_probs = t.log_softmax(logits[:, :-1], dim=-1)
    # Select, per position, the log-probability assigned to the true next token.
    picked = t.gather(pred_log_probs, -1, next_tokens.unsqueeze(-1)).squeeze(-1)
    return picked.mean().neg()

def acc_fn(logits: t.Tensor, tokens: t.Tensor) -> float:
    """Fraction of positions whose argmax prediction equals the next token.

    Uses the same alignment as the loss: logits at position ``i`` predict the
    token at position ``i + 1``.

    Args:
        logits: Indexed as ``[batch, pos, vocab]`` by this function.
        tokens: Integer token ids indexed as ``[batch, pos]``.

    Returns:
        The accuracy as a Python float, averaged in float64.
    """
    predicted = logits[:, :-1].argmax(-1)
    targets = tokens[:, 1:]
    hits = (targets == predicted).to(t.float64)
    return hits.mean().item()

In [7]:
@dataclass(frozen=True, slots=True)
class TrainingHistory:
    losses: list[list[float]] = field(default_factory=list)
    accuracies: list[list[float]] = field(default_factory=list)

    def __post_init__(self) -> None:
        assert len(self.losses) == len(self.accuracies)
    
    def __len__(self) -> int:
        return len(self.losses)

def stop_early(accuracies: list[float], early_stopping_epochs: int | None, min_acc: float = 1.0) -> bool:
    """Decide whether training may stop because accuracy has been high enough for long enough.

    Args:
        accuracies: Accuracy history, oldest first.
        early_stopping_epochs: How many of the most recent entries must all
            reach ``min_acc``. ``None`` (or a non-positive value) disables
            early stopping.
        min_acc: Accuracy threshold each of those entries must meet.

    Returns:
        True only if at least ``early_stopping_epochs`` entries exist and each
        of the last ``early_stopping_epochs`` entries is ``>= min_acc``.
    """
    if early_stopping_epochs is None or early_stopping_epochs <= 0:
        return False
    # Not enough history yet; this also stops an empty list from passing vacuously.
    if len(accuracies) < early_stopping_epochs:
        return False
    # Only the most recent window matters, not the whole history.
    recent = accuracies[-early_stopping_epochs:]
    return all(acc >= min_acc for acc in recent)

def train(
    model: nn.Module,
    batches: list[t.Tensor],
    optimizer: optim.Optimizer,
    *,
    n_epochs: int = 20,
    scheduler: optim.lr_scheduler.LRScheduler | optim.lr_scheduler.ReduceLROnPlateau | None = None, # type: ignore
) -> TrainingHistory:
    """Train ``model`` on ``batches`` for ``n_epochs`` passes, printing progress per batch.

    For every batch: zero the gradients, run the model, back-propagate the
    loss, step the optimizer, then record the batch loss and accuracy. If a
    scheduler is given it is stepped once per batch. ``ReduceLROnPlateau`` is
    fed the batch loss; any other scheduler is stepped without arguments.

    Args:
        model: Module called as ``model(batch)`` to produce logits.
        batches: Token batches, visited in the given order every epoch.
        optimizer: Optimizer that updates the model's parameters.
        n_epochs: Number of passes over ``batches``.
        scheduler: Optional learning-rate scheduler.

    Returns:
        A ``TrainingHistory`` holding, per epoch, the list of batch losses and
        the list of batch accuracies.
    """
    history = TrainingHistory()
    total_batches = len(batches)
    plateau_cls = optim.lr_scheduler.ReduceLROnPlateau

    for epoch_i in range(1, n_epochs + 1):
        print(f"Epoch {epoch_i} / {n_epochs}")

        # Register this epoch's lists up front so that partial results are
        # already part of the history while the epoch is running.
        epoch_losses: list[float] = []
        epoch_accs: list[float] = []
        history.losses.append(epoch_losses)
        history.accuracies.append(epoch_accs)

        for batch_i, batch in enumerate(batches, 1):
            optimizer.zero_grad()
            logits = model(batch)
            loss = loss_fn(logits, batch)
            loss.backward()
            optimizer.step()

            # Accuracy is a metric only; keep it out of autograd.
            with t.no_grad():
                batch_acc = acc_fn(logits, batch)
            epoch_losses.append(loss.item())
            epoch_accs.append(batch_acc)

            msg = f"\t Batch {batch_i} / {total_batches}: loss={epoch_losses[-1]:.3f}, acc={batch_acc:.3%}"

            if scheduler is None:
                pass
            elif isinstance(scheduler, plateau_cls):
                scheduler.step(loss)
                # The learning rate is read from the private `_last_lr`
                # attribute, and only reported when that attribute exists.
                if hasattr(scheduler, "_last_lr"):
                    last_lr: float = getattr(scheduler, "_last_lr")[-1]
                    msg += f", {last_lr=}"
            else:
                scheduler.step()
                last_lr: float = scheduler.get_last_lr()[-1]
                msg += f", {last_lr=}"

            print(msg)

    return history

In [8]:
# Number of subsequences per training batch.
BATCH_SIZE = 32

# Model hyperparameters. d_vocab adds one slot on top of the corpus vocabulary
# for the reserved EOS token, which vocab_size does not count.
cfg = Config(
    d_vocab=corpus.vocab_size + 1,
    d_model=128,
    n_heads=4,
    n_layers=2,
    n_ctx=128,
)

# Cut the corpus into fixed-width rows of token ids.
subseqs = corpus.get_corpus_subsequences(subseq_len=cfg.n_ctx - 1)

# Keep whole batches only; a trailing partial batch is dropped.
N_BATCHES = len(subseqs) // BATCH_SIZE
batches: list[t.Tensor] = [
    subseqs[start : start + BATCH_SIZE]
    for start in range(0, N_BATCHES * BATCH_SIZE, BATCH_SIZE)
]
# Shuffle the order of the batches once (rows within a batch stay contiguous).
random.shuffle(batches)

In [9]:
# Build the model from the configuration defined above.
model = Transformer(cfg)

# Initial learning rate for AdamW.
LR = 1e-3
optimizer = optim.AdamW(model.parameters(), lr=LR)

# Multiply the learning rate by `factor` once the monitored loss has stopped
# improving for more than `patience` scheduler steps. The training loop steps
# the scheduler once per batch, so patience is counted in batches here.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10)
# Alternative scheduler, currently disabled.
# NOTE(review): `lr_lambda` is not defined anywhere in the visible source.
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

N_EPOCHS = 2
# NOTE(review): LOG_EACH_EPOCHS is not referenced anywhere in the visible source.
LOG_EACH_EPOCHS = 1

In [10]:
# Run training; `th` collects per-epoch lists of per-batch losses and accuracies.
th = train(
    model,
    batches,
    optimizer=optimizer,
    n_epochs=N_EPOCHS,
    scheduler=scheduler
)

Epoch 1 / 2
	 Batch 1 / 490: loss=10.400, acc=0.000%, last_lr=[0.001]
	 Batch 2 / 490: loss=10.393, acc=7.515%, last_lr=[0.001]
	 Batch 3 / 490: loss=10.381, acc=36.458%, last_lr=[0.001]
	 Batch 4 / 490: loss=10.364, acc=39.658%, last_lr=[0.001]
	 Batch 5 / 490: loss=10.345, acc=35.615%, last_lr=[0.001]
	 Batch 6 / 490: loss=10.313, acc=35.789%, last_lr=[0.001]
	 Batch 7 / 490: loss=10.277, acc=33.036%, last_lr=[0.001]
	 Batch 8 / 490: loss=10.215, acc=34.077%, last_lr=[0.001]
	 Batch 9 / 490: loss=10.125, acc=36.905%, last_lr=[0.001]
	 Batch 10 / 490: loss=10.031, acc=34.598%, last_lr=[0.001]
	 Batch 11 / 490: loss=9.877, acc=36.607%, last_lr=[0.001]
	 Batch 12 / 490: loss=9.705, acc=37.153%, last_lr=[0.001]
	 Batch 13 / 490: loss=9.522, acc=33.631%, last_lr=[0.001]
	 Batch 14 / 490: loss=9.188, acc=37.227%, last_lr=[0.001]
	 Batch 15 / 490: loss=8.923, acc=35.045%, last_lr=[0.001]
	 Batch 16 / 490: loss=8.494, acc=35.293%, last_lr=[0.001]
	 Batch 17 / 490: loss=8.000, acc=34.449%, la

KeyboardInterrupt: 