## Imports and setup

In [1]:
import sys
assert any("deep_learning_curriculum" in p for p in sys.path)

In [2]:
from __future__ import annotations

from dataclasses import dataclass, field
import math
import random
import re
from typing import cast, Self

import einops
from fancy_einsum import einsum
from matplotlib import pyplot as plt
import numpy as np
import torch as t
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F

def pad_last_dim(x: t.Tensor, max_dim: int) -> t.Tensor:
    assert x.ndim > 0
    assert x.size(-1) <= max_dim
    if x.size(-1) == max_dim:
        return x
    padding = t.zeros(*x.shape[:-1], max_dim - x.size(-1))
    return t.cat((x, padding), dim=-1)

def tokenize(text: str) -> list[str]:
    """Split text into words/tokens based on word boundaries (`\\b`)."""
    return re.split(r"\b", text)

def to_one_hot(x: t.Tensor, max_last_dim: int | None = None) -> t.Tensor:
    assert x.ndim == 2
    x_one_hot = F.one_hot(x)
    return pad_last_dim(x_one_hot, max_last_dim or x_one_hot.size(-1)).to(dtype=t.float32)

def is_one_hot(x: t.Tensor) -> bool:
    assert x.ndim >= 2
    return cast(bool, ((x != 0).sum(-1) == 1).all().item())

@dataclass(frozen=True, slots=True)
class Config:
    d_model: int
    d_vocab: int
    n_layers: int
    n_heads: int
    n_ctx: int
    dropout: float = field(default=0.0, kw_only=True, repr=False)
    epsilon: float = field(default=1e-6, kw_only=True, repr=False)
    
    def __post_init__(self) -> None:
        assert self.d_model % self.n_heads == 0
        assert 0 <= self.dropout <= .9, f"unreasonable dropout: {self.dropout}"
        
    @property
    def d_mlp(self) -> int:
        return self.d_model * 4
    
    @property
    def d_head(self) -> int:
        return self.d_model // self.n_heads
    
    @classmethod
    def dummy(cls, **kwargs) -> Self:
        return cls(**(dict(d_model=20, n_ctx=5000, d_vocab=1000, n_layers=0, n_heads=1) | kwargs))
    
    def random_resid(self, n_batches: int = 16) -> t.Tensor:
        return t.rand(n_batches, self.n_ctx, self.d_model)
    
    def random_tokens(self, n_batches: int = 16, n_ctx: int | None = None, max_token: int | None = None) -> t.Tensor:
        n_ctx = n_ctx or self.n_ctx
        max_token = max_token or self.d_vocab - 1
        tokens = (max_token * t.rand(n_batches, n_ctx)).round().to(dtype=t.int64)
        return tokens
    
    def random_tokens_one_hot(self, n_batches: int = 16, max_token: int | None = None) -> t.Tensor:
        tokens = self.random_tokens(n_batches=n_batches, max_token=max_token)
        tokens_one_hot = to_one_hot(tokens, self.d_vocab)
        assert tokens_one_hot.shape[-2:] == (self.n_ctx, self.d_vocab)
        assert is_one_hot(tokens_one_hot)
        return tokens_one_hot
    
Config.dummy().random_tokens_one_hot().shape

torch.Size([16, 5000, 1000])

### Positional embedding

In [3]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config) -> None:
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty(cfg.n_ctx, cfg.d_model))
        nn.init.normal_(self.W_pos)
    
    def forward(self, tokens: t.Tensor) -> t.Tensor:
        pe = self.W_pos[:tokens.size(-1), :]
        batch_pe = einops.repeat(pe, "pos d_model -> batch pos d_model", batch=tokens.size(0))
        return batch_pe.clone()

    @classmethod
    def test(cls) -> None:
        cfg = Config.dummy()
        pe = cls(cfg)
        x = cfg.random_tokens()
        y = pe(x)
        assert x.ndim == 2
        assert y.ndim == 3
        assert x.shape[:2] == y.shape[:2]
        assert y.size(-1) == cfg.d_model
        print("PASSED!")
    
PosEmbed.test()

PASSED!


## Embedding

In [4]:
class Embed(nn.Module):
    def __init__(self, cfg: Config) -> None:
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty(cfg.d_vocab, cfg.d_model))
        nn.init.normal_(self.W_E)
        
    def forward(self, tokens: t.Tensor) -> None:
        assert tokens.ndim == 2
        assert tokens.max() <= self.cfg.d_vocab - 1
        assert tokens.size(1) <= self.cfg.n_ctx
        
        tokens_one_hot = to_one_hot(tokens, self.cfg.d_vocab)
        # print(tokens_one_hot.shape, self.cfg.d_vocab)
        return einsum(
            "batch pos d_vocab, d_vocab d_model -> batch pos d_model",
            tokens_one_hot,
            self.W_E
        ) * math.sqrt(self.cfg.d_model)
    
    @classmethod
    def test(cls) -> None:
        cfg = Config.dummy()
        embed = cls(cfg)
        tokens = cfg.random_tokens()
        embs = embed(tokens)
        assert embs.isnan().sum().item() == 0
        # print(f"{embs.shape = }")
        print("PASSED!")
        
Embed.test()

embs.shape = torch.Size([16, 5000, 20])
PASSED!


## Unembedding

In [7]:
class Unembed(nn.Module):
    def __init__(self, cfg: Config) -> None:
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty(cfg.d_model, cfg.d_vocab))
        nn.init.normal_(self.W_U)
        self.b_U = nn.Parameter(t.zeros(cfg. d_vocab))
        
    def forward(self, x: t.Tensor) -> t.Tensor:
        return einsum(
            "batch pos d_model, d_model d_vocab -> batch pos d_vocab",
            x,
            self.W_U
        ) + self.b_U
    
    @classmethod
    def test(cls) -> None:
        cfg = Config.dummy()
        ue = cls(cfg)
        x = cfg.random_resid()
        y = ue(x)
        print(f"PASSED! {tuple(y.shape)}")
        # print(y.isnan().sum())
        
Unembed.test()

PASSED! (16, 5000, 1000)


## MLP

In [8]:
class MLP(nn.Module):
    def __init__(self, cfg: Config) -> None:
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Linear(cfg.d_model, cfg.d_mlp)
        self.dropout = nn.Dropout(p=cfg.dropout)
        self.W_out = nn.Linear(cfg.d_mlp, cfg.d_model)
    
    def forward(self, x: t.Tensor) -> t.Tensor:
        return self.W_out(self.dropout(F.relu(self.W_in(x))))
    
    @classmethod
    def test(cls) -> None:
        cfg = Config.dummy()
        mlp = cls(cfg)
        x = cfg.random_resid()
        y = mlp(x)
        assert x.shape == y.shape
        assert y.isnan().sum().item() == 0
        print(f"PASSED! shape: {tuple(x.shape)}")

MLP.test()

PASSED! shape: (16, 5000, 20)


## Attention

In [9]:
class SelfAttention(nn.Module):
    IGNORE: t.Tensor
    
    def __init__(self, cfg: Config) -> None:
        super().__init__()
        self.cfg = cfg
        # Q
        self.W_Q = nn.Parameter(t.empty(cfg.d_model, cfg.n_heads, cfg.d_head))
        nn.init.normal_(self.W_Q)
        self.b_Q = nn.Parameter(t.zeros(cfg.n_heads, cfg.d_head))
        # K
        self.W_K = nn.Parameter(t.empty(cfg.d_model, cfg.n_heads, cfg.d_head))
        nn.init.normal_(self.W_K)
        self.b_K = nn.Parameter(t.zeros(cfg.n_heads, cfg.d_head))
        # V
        self.W_V = nn.Parameter(t.empty(cfg.d_model, cfg.n_heads, cfg.d_head))
        nn.init.normal_(self.W_V)
        self.b_V = nn.Parameter(t.zeros(cfg.n_heads, cfg.d_head))
        # O
        self.W_O = nn.Parameter(t.empty(cfg.n_heads, cfg.d_head, cfg.d_model))
        nn.init.normal_(self.W_O)
        self.b_O = nn.Parameter(t.zeros(cfg.d_model))
        # buffer
        self.register_buffer("IGNORE", t.tensor(-1e6))
    
    def apply_causal_mask(self, attn_scores: t.Tensor) -> t.Tensor:
        assert attn_scores.ndim == 4
        assert attn_scores.size(2) == attn_scores.size(3)
        return attn_scores.where(
            t.ones_like(attn_scores).triu().flip(-1) == 0,
            self.IGNORE
        )
    
    def forward(self, x: t.Tensor) -> t.Tensor:
        assert x.ndim == 3 # batch pos d_model
        assert x.size(2) == self.cfg.d_model
        assert x.size(1) <= self.cfg.n_ctx
        
        q = einsum(
            "batch pos d_model, d_model n_heads d_head -> batch pos n_heads d_head",
            x, 
            self.W_Q
        ) + self.b_Q
        k = einsum(
            "batch pos d_model, d_model n_heads d_head -> batch pos n_heads d_head",
            x, 
            self.W_K
        ) + self.b_K
        v = einsum(
            "batch pos d_model, d_model n_heads d_head -> batch pos n_heads d_head",
            x, 
            self.W_V
        ) + self.b_V
        
        attn_scores = einsum(
            "batch q_pos n_heads d_head, batch k_pos n_heads d_head -> batch n_heads q_pos k_pos", 
            q,
            k
        ) / math.sqrt(self.cfg.d_head)
        attn_scores_masked = self.apply_causal_mask(attn_scores)
        attn_pattern = attn_scores_masked.softmax(-1)
        
        z = einsum(
            "batch n_heads q_pos k_pos, batch k_pos n_heads d_head -> batch q_pos n_heads d_head",
            attn_pattern,
            v
        )
        
        # print(f"W_O: {tuple(self.W_O.shape)}\nz: {tuple(z.shape)}")
        
        o = einsum(
            "n_heads d_head d_model, batch pos n_heads d_head -> batch pos d_model", 
            self.W_O,
            z
        ) + self.b_O
        return o
        
    
    @classmethod
    def test(cls) -> None:
        cfg = Config.dummy()
        x = cfg.random_resid()
        attn = cls(cfg)
        y = attn(x)
        assert x.shape == y.shape
        assert y.isnan().sum().item() == 0
        print(f"Passed! shape: {tuple(x.shape)}")
        
SelfAttention.test()

Passed! shape: (16, 5000, 20)


## LayerNorm

In [10]:
def normalize(x: t.Tensor, epsilon: float = 1e-6) -> t.Tensor:
    return (x - x.mean(-1, keepdim=True)) / (x.std() + epsilon)

class LayerNorm(nn.Module):
    def __init__(self, cfg: Config) -> None:
        super().__init__()
        self.cfg = cfg
        self.scale = nn.Parameter(t.empty(cfg.d_model))
        nn.init.normal_(self.scale)
        self.translation = nn.Parameter(t.zeros(cfg.d_model))
        
        
    def forward(self, x: t.Tensor) -> t.Tensor:
        return normalize(x, self.cfg.epsilon) * self.scale + self.translation
        
    @classmethod
    def test(cls) -> None:
        cfg = Config.dummy()
        ln = cls(cfg)
        x = cfg.random_resid()
        y = ln(x)
        assert x.shape == y.shape
        assert y.isnan().sum().item() == 0
        print("PASSED!")
        
LayerNorm.test()

PASSED!


## TransformerBlock

In [11]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config) -> None:
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = SelfAttention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)
        
    def forward(self, x: t.Tensor) -> t.Tensor:
        return x + self.mlp(self.ln2(self.attn(self.ln1(x))))
    
    @classmethod
    def test(cls) -> None:
        cfg = Config.dummy()
        tb = cls(cfg)
        x = cfg.random_resid()
        y = tb(x)
        assert x.shape == y.shape
        assert y.isnan().sum().item() == 0
        print("PASSED!")

TransformerBlock.test()

PASSED!


## Full model

In [12]:
class Transformer(nn.Module):
    def __init__(self, cfg: Config) -> None:
        super().__init__()
        self.cfg = cfg
        self.pos_embed = PosEmbed(cfg)
        self.embed = Embed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.unembed = Unembed(cfg)
        
    def forward(self, tokens: t.Tensor) -> t.Tensor:
        assert tokens.ndim == 2
        assert tokens.size(1) <= self.cfg.n_ctx
        assert tokens.dtype == t.int64
        assert tokens.max() <= self.cfg.d_vocab
        
        pos_embeddings = self.pos_embed(tokens)
        embeddings = self.embed(tokens)
        resid = pos_embeddings + embeddings
        for block_i, block in enumerate(self.blocks):
            resid = block(resid)
        # print("check", resid.isnan().sum().item())
        logits = self.unembed(resid)
        return logits
    
    @classmethod
    def test(cls) -> None:
        cfg = Config.dummy(n_layers=2)
        transformer = cls(cfg)
        tokens = cfg.random_tokens()
        logits = transformer(tokens)
        preds = logits.argmax(-1)
        print(logits.shape)
        print(preds.shape)
        print(cfg)
        nans = logits.isnan().sum().item()
        assert nans == 0, f"{nans = }"
        print(preds[0])
        
Transformer.test()

torch.Size([16, 5000, 1000])
torch.Size([16, 5000])
Config(d_model=20, d_vocab=1000, n_layers=2, n_heads=1, n_ctx=5000)
tensor([919, 993, 407,  ..., 522, 475, 244])


## Train the model on reversing random tokens

IDK what I'm doing wrong but I was unable to get it to more than ~35% accuracy.

In [13]:
cfg = Config(
    d_model=256,
    d_vocab=10,
    n_layers=6,
    n_heads=8,
    n_ctx=10,
)

N = 1024
TRAIN_FRAC = 0.5
BATCH_SIZE = 32
assert int(N * TRAIN_FRAC) % BATCH_SIZE == 0

START_TOKEN = cfg.d_vocab - 1
random_tokens = cfg.random_tokens(N, n_ctx=cfg.n_ctx // 2 - 1, max_token=START_TOKEN - 1)
start_tokens = START_TOKEN * t.ones(N, 1)
data = t.cat((start_tokens, random_tokens), dim=-1).to(dtype=t.int64)
data = t.cat((data, data.flip(-1)), dim=-1)
# random_tokens_reversed = random_tokens.flip(-1)
# data = t.cat((start_tokens, random_tokens, random_tokens_reversed), dim=-1).to(dtype=t.int64)
split_ind = int(TRAIN_FRAC * N)
train_data_flat = data[:split_ind]
train_data = train_data_flat.reshape(-1, BATCH_SIZE, cfg.n_ctx)
test_data_flat = data[split_ind:]
test_data = test_data_flat.reshape(-1, BATCH_SIZE, cfg.n_ctx)

train_data.shape,test_data.shape

# assert train_data.size(1) == cfg.n_ctx # - 1
# assert train_data[:, 1:-1].max() == START_TOKEN - 1
# assert train_data[:, 0].max() == train_data[:, 0].min() == START_TOKEN == train_data[:, -1].max() == train_data[:, -1].min()
# assert train_data.dtype == t.int64

(torch.Size([16, 32, 10]), torch.Size([16, 32, 10]))

In [14]:
def loss_fn(logits: t.Tensor, tokens: t.Tensor) -> t.Tensor:
    # assert logits.shape[:2] == tokens.shape
    # assert logits.size(1) % 2 == 1
    # n_ctx = logits.size(1) // 2
    # logits = logits[:, n_ctx:-1]
    # tokens = tokens[:, n_ctx+1:].unsqueeze(-1)
    # assert logits.ndim == tokens.ndim == 3
    logits = logits[:, :-1]
    tokens = tokens[:, 1:].unsqueeze(-1)
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens)[..., 0]
    return -correct_log_probs.mean()

def acc_fn(logits: t.Tensor, tokens: t.Tensor) -> float:
    n_ctx = logits.size(1) // 2
    logits = logits[:, n_ctx:-1]
    preds = logits.argmax(-1)
    tokens = tokens[:, n_ctx+1:]
    acc = (preds == tokens).mean(dtype=t.float).item()
    return acc
    
# loss_fn = nn.CrossEntropyLoss()

model = Transformer(cfg)
train_logits = model(train_data_flat)
assert train_logits.isnan().sum().item() == 0
loss = loss_fn(train_logits, train_data_flat)
print(f"{loss = }")


# loss = loss_fn(train_preds, train_data.to(dtype=t.float))


# acc = acc_fn(train_logits, train_data)
# print(f"{loss = }; {acc = }")

loss = tensor(422.6483, grad_fn=<NegBackward0>)


In [68]:
N_EPOCHS = 1000
LR = 1e-4
# BETAS = (.9, .99)

log_each_epochs = N_EPOCHS // 50

model = Transformer(cfg)
optimizer = optim.Adam(model.parameters(), lr=LR)
N_MILESTONES = 4
milestones = np.linspace(N_EPOCHS // N_MILESTONES, N_EPOCHS, N_MILESTONES)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.1, patience=10)

train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

for epoch_i in range(N_EPOCHS):
    epoch_train_losses = []
    epoch_test_losses = []
    epoch_train_accuracies = []
    epoch_test_accuracies = []
    for train_batch, test_batch in zip(train_data, test_data):
        # Forward
        train_logits = model(train_batch)
        test_logits = model(test_batch)
        # Loss
        train_loss = loss_fn(train_logits, train_batch)
        test_loss = loss_fn(test_logits, test_batch)
        # Backward and update
        train_loss.backward()
        optimizer.step()
        # Accuracy
        train_acc = acc_fn(train_logits, train_batch)
        test_acc = acc_fn(test_logits, test_batch)
        # Append
        epoch_train_losses.append(train_loss.item())
        epoch_test_losses.append(test_loss.item())
        epoch_train_accuracies.append(train_acc)
        epoch_test_accuracies.append(test_acc)
        
    
    # Measure    
    epoch_train_loss = t.tensor(epoch_train_losses).mean().item()
    epoch_test_loss = t.tensor(epoch_test_losses).mean().item()
    train_losses.append(epoch_train_loss)
    test_losses.append(epoch_test_loss)
    epoch_train_acc = t.tensor(epoch_train_accuracies).mean().item()
    epoch_test_acc = t.tensor(epoch_test_accuracies).mean().item()
    train_accuracies.append(epoch_train_acc)
    test_accuracies.append(epoch_test_acc)
    
    if log_each_epochs and epoch_i % log_each_epochs == 0:
        print(f"[{epoch_i}] \n    loss: train={epoch_train_loss:.3f}; test={epoch_test_loss:.3f};\n    acc:  train={epoch_train_acc:.2%}; test={epoch_test_acc:.2%}")

[0] 
    loss: train=285.656; test=284.506;
    acc:  train=11.72%; test=10.06%
[20] 
    loss: train=80.662; test=88.524;
    acc:  train=26.42%; test=22.31%
[40] 
    loss: train=39.626; test=44.791;
    acc:  train=33.74%; test=30.22%
[60] 
    loss: train=32.910; test=36.502;
    acc:  train=31.74%; test=29.83%
[80] 
    loss: train=36.694; test=39.027;
    acc:  train=29.49%; test=27.54%
[100] 
    loss: train=38.767; test=40.889;
    acc:  train=31.64%; test=30.32%
[120] 
    loss: train=46.558; test=49.009;
    acc:  train=28.42%; test=27.64%
[140] 
    loss: train=52.575; test=53.422;
    acc:  train=30.32%; test=29.20%
[160] 
    loss: train=52.438; test=54.450;
    acc:  train=30.62%; test=28.91%
[180] 
    loss: train=64.006; test=65.967;
    acc:  train=29.98%; test=29.74%
[200] 
    loss: train=60.328; test=62.368;
    acc:  train=28.08%; test=29.05%
[220] 
    loss: train=58.725; test=59.570;
    acc:  train=33.79%; test=32.91%
[240] 
    loss: train=67.123; test=66.127;


In [60]:
batch = train_data[0][:1]
logits = model(batch)
preds = logits.argmax(-1)
print(batch)
print(preds)

tensor([[9, 8, 1, 7, 3, 3, 7, 1, 8, 9]])
tensor([[5, 0, 0, 3, 5, 5, 3, 5, 9, 2]])


In [13]:

logits = model(train_data)
preds = logits.argmax(-1)
print(train_data[0])
print(preds[0])

tensor([11, 11, 15,  5, 11, 29, 10, 19, 19, 10, 29, 11,  5, 15, 11, 11])
tensor([26, 11,  7, 11,  9, 20, 11, 10, 19, 20, 11, 11,  8,  7, 11,  9])


In [14]:
logits = model(test_data)
preds = logits.argmax(-1)
print(test_data[0])
print(preds[0])

tensor([26, 16, 28, 18,  7, 12, 14, 11, 11, 14, 12,  7, 18, 28, 16, 26])
tensor([26,  6, 18, 26, 14,  5, 28, 12, 11,  7,  5, 14, 28, 18, 12, 22])


In [15]:
def generate(model: Transformer, tokens: t.Tensor, max_new_tokens: int | None = None) -> t.Tensor:
    assert tokens.ndim == 2
    assert tokens.size(1) < model.cfg.n_ctx
    max_new_tokens = max_new_tokens or model.cfg.n_ctx - tokens.size(-1)
    new_tokens = tokens.detach().clone()
    for i in range(max_new_tokens):
        logits = model(new_tokens)
        preds = logits.argmax(-1)
        assert preds.ndim == 2
        final_preds = preds[:, -1].unsqueeze(-1)
        new_tokens = t.cat((new_tokens, final_preds), dim=-1)
    return new_tokens

x = train_data[:1]
tokens = x[:, :model.cfg.n_ctx // 2 + 2]
y = generate(model, tokens)
print(x)
print(y)

tensor([[11, 11, 15,  5, 11, 29, 10, 19, 19, 10, 29, 11,  5, 15, 11, 11]])
tensor([[11, 11, 15,  5, 11, 29, 10, 19, 19, 10, 23, 15,  7,  4,  2, 25]])


## Train the model on Shakespeare's works

In [11]:
import os
import pathlib
import re
from urllib.request import urlopen


shakespeare_path = pathlib.Path("../data/shakespeare.txt")

if shakespeare_path.exists():
    print("Loading Shakespeare...")
    with open(shakespeare_path, "r", encoding="utf-8") as f:
        text = f.read()
else:
    print("Fetching Shakespeare..")
    url = "https://www.gutenberg.org/files/100/100-0.txt"
    text = urlopen(url).read().decode("utf-8")
    with open(shakespeare_path, "w", encoding="utf-8") as f:
        f.write(text)

tokens = tokenize(text)

print(f"Shakespeare text: {len(text)} characters, {len(tokens)} tokens")


Loading Shakespeare...
Shakespeare text: 5392638 characters, 1991703 tokens


In [30]:
from collections import Counter
from typing import Iterable, TypeVar

@dataclass(frozen=True, slots=True)
class Tokenizer:
    d_vocab: int
    tok2int: dict[str, int]
    int2tok :dict[int, str]
    
    @staticmethod
    def split_into_tokens(text: str) -> list[str]:
        return re.split(r"\b", text)
    
    @classmethod
    def make(cls, text: str) -> Tokenizer:
        tokens = cls.split_into_tokens(text)
        token_counts = Counter(tokens)
        d_vocab = len(token_counts) + 1 # plus BOS/EOS
        tok2int = {tok: i for i, (tok, _) in enumerate(sorted(token_counts.items(), key=lambda x: x[1], reverse=True))}
        int2tok = {i: tok for tok, i in tok2int.items()}
        assert len(tok2int) == d_vocab - 1
        return cls(d_vocab, tok2int, int2tok)
    
    @property
    def eos(self) -> int:
        return self.d_vocab - 1
    @property
    def token_set(self) -> set[str]:
        return set(self.tok2int)
    
    def tokenize(self, text: str) -> tuple[list[str], list[int]]:
        tokens = self.split_into_tokens(text)
        assert set(tokens) <= self.token_set
        token_ids = [self.tok2int[tok] for tok in tokens]
        return tokens, token_ids
    
    def decode(self, inds: Iterable[int]) -> list[str]:
        assert all(i < self.d_vocab for i in inds)
        return [self.int2tok[i] for i in inds]
        


T = TypeVar("T")
def split_into_pieces(xs: list[T], n_pieces: int, piece_length: int) -> list[list[T]]:
    assert n_pieces * piece_length < len(xs)
    # max_start_ind = len(xs) - piece_length
    # start_inds = [random.randint(0, max_start_ind) for _ in range(n_pieces)]
    pieces = [
        xs[i * piece_length : (i + 1) * piece_length] 
        for i in range(n_pieces)
    ]
    assert len(pieces) == n_pieces
    assert all(len(p) == piece_length for p in pieces)
    return pieces
    

tokenizer = Tokenizer.make(text)
tokens, token_ids = tokenizer.tokenize(text)

cfg = Config(
    d_model=128,
    d_vocab=tokenizer.d_vocab,
    n_layers=4,
    n_heads=8,
    n_ctx=256,
)

random.seed(42)
pieces = split_into_pieces(token_ids, n_pieces=256, piece_length=cfg.n_ctx)
token_tensor = t.tensor(pieces)#.unsqueeze(0)

print(tokenizer.decode(pieces[0])[:20])

['\ufeff', 'The', ' ', 'Project', ' ', 'Gutenberg', ' ', 'eBook', ' ', 'of', ' ', 'The', ' ', 'Complete', ' ', 'Works', ' ', 'of', ' ', 'William']


In [31]:
def loss_fn(logits: t.Tensor, tokens: t.Tensor) -> t.Tensor:
    logits = logits[:, :-1]
    tokens = tokens[:, 1:].unsqueeze(-1)
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens).squeeze(-1)
    return -correct_log_probs.mean()

In [32]:
N_BATCHES = 8
assert token_tensor.size(0) % N_BATCHES == 0
batches = token_tensor.reshape(N_BATCHES, -1, token_tensor.size(-1))
batches.shape

torch.Size([8, 32, 256])

In [33]:
model = Transformer(cfg)

N_EPOCHS = 100
LR = 1e-4

optimizer = optim.AdamW(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.1, patience=4)

loss_history = []


for epoch_i in range(N_EPOCHS):
    epoch_losses = []

    for batch_i, batch in enumerate(batches):
        # Forward
        logits = model(batch)
        # Loss
        loss = loss_fn(logits, batch)
        # Backward and update
        loss.backward()
        optimizer.step()
        # Append
        epoch_losses.append(loss.item())
    
    # Measure
    epoch_loss = t.tensor(epoch_losses).mean().item()
    loss_history.append(epoch_loss)
    
    print(f"[{epoch_i}]: loss = {epoch_loss:.3f}")

[0]: loss = 497.613
[1]: loss = 475.608
[2]: loss = 456.961
[3]: loss = 438.606
[4]: loss = 417.623
[5]: loss = 394.853
[6]: loss = 370.869
[7]: loss = 356.367
[8]: loss = 350.797
[9]: loss = 343.557
[10]: loss = 333.555
[11]: loss = 320.975
[12]: loss = 308.607
[13]: loss = 300.223
[14]: loss = 294.453
[15]: loss = 289.125
[16]: loss = 283.711
[17]: loss = 278.533
[18]: loss = 274.038
[19]: loss = 269.739
[20]: loss = 265.175
[21]: loss = 261.107
[22]: loss = 257.682
[23]: loss = 253.756
[24]: loss = 248.863
[25]: loss = 243.134
[26]: loss = 237.000
[27]: loss = 231.162
[28]: loss = 226.516
[29]: loss = 223.017
[30]: loss = 220.411
[31]: loss = 218.557
[32]: loss = 216.727
[33]: loss = 213.388
[34]: loss = 207.911
[35]: loss = 201.511
[36]: loss = 195.534
[37]: loss = 190.740
[38]: loss = 187.545
[39]: loss = 185.531
[40]: loss = 183.982
[41]: loss = 182.198
[42]: loss = 179.854
[43]: loss = 177.266
[44]: loss = 174.394
[45]: loss = 171.092
[46]: loss = 167.486
[47]: loss = 163.795
[4

KeyboardInterrupt: 

In [35]:
from datetime import datetime
import pickle

dt_str = datetime.now().isoformat().replace(":", "-").split(".")[0]
model_filepath = f"../models/model-1-{dt_str}.pkl"
with open(model_filepath, "wb") as f:
    pickle.dump(model, f)

# TODO

- add BOS token
- retrain model with max number of splits
- generate shakespeare or sth