Largely adopted from https://github.com/karpathy/build-nanogpt/

In [1]:
import os
import torch

In [2]:
DATASET_URL = (
    "https://raw.githubusercontent.com/karpathy/char-rnn/master"
    "/data/tinyshakespeare/input.txt"
)
DATA_DIR = os.path.expanduser("~/Data/tinyshakespeare")
DATA_FILENAME = "input.txt"
DATA_FILEPATH = os.path.join(DATA_DIR, DATA_FILENAME)

In [3]:
N_TEST_LINES = 8_000

In [4]:
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    DEVICE = "mps"
DEVICE_TYPE = "cuda" if DEVICE.startswith("cuda") else "cpu"
print(f"using device: {DEVICE}")


SEED = 1337
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

using device: mps


## Load data

In [5]:
from urllib import request

os.makedirs(DATA_DIR, exist_ok=True)
if not os.path.isfile(DATA_FILEPATH):
    content = request.urlopen(DATASET_URL)
    with open(DATA_FILEPATH, "wb") as f:
        f.write(content.read())

with open(DATA_FILEPATH, "r") as f:
    text = f.read()

print(example := text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



## Tokenization examples

In [6]:
import tiktoken

ENCODER = tiktoken.get_encoding("gpt2")
example_tokens = ENCODER.encode(example)
print(example_tokens[:24 + 1])

[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13, 198, 198, 3237, 25, 198, 5248, 461, 11, 2740, 13, 198]


In [7]:
example_buf = torch.tensor(example_tokens[:24 + 1])
example_x = example_buf[:-1].view(4, 6)
example_y = example_buf[1:].view(4, 6)  # predict the next token
print(example_x)
print(example_y)

tensor([[ 5962, 22307,    25,   198,  8421,   356],
        [ 5120,   597,  2252,    11,  3285,   502],
        [ 2740,    13,   198,   198,  3237,    25],
        [  198,  5248,   461,    11,  2740,    13]])
tensor([[22307,    25,   198,  8421,   356,  5120],
        [  597,  2252,    11,  3285,   502,  2740],
        [   13,   198,   198,  3237,    25,   198],
        [ 5248,   461,    11,  2740,    13,   198]])


## Train / test split

In [8]:
print(len(lines := text.splitlines()))
train_text = "\n".join(lines[:-N_TEST_LINES])
val_text = "\n".join(lines[-N_TEST_LINES:])
del lines

40000


## Utils

In [9]:
import re

class DataLoaderLite:
    # We are not doing DDP in this project whatsoever, and we're also only training on
    # tinyshakespeare, so this is a significantly downsized version of the DataLoaderLite
    # in the tutorial

    def __init__(
        self,
        B: int,
        T: int,
        text: str,
        loop: bool = False,
    ):
        self.B: int = B
        self.T: int = T
        self.text: str = text
        self.loop: bool = loop
        self.reset()

    @property
    def BT(self) -> int:
        return self.B * self.T

    def reset(self):
        self._buf = []
        self._pos = 0

    def next_tokens(self, add_pred_token: bool = False) -> list[int]:
        BT = self.BT
        n_tokens = BT + int(add_pred_token)
        pos_step = n_tokens * 4

        while len(self._buf) < n_tokens:
            segment = self.text[self._pos : self._pos + pos_step]
            if not(segment):
                raise RuntimeError("no tokens remaining")
            if match := re.search(r"\W+", segment[::-1]):
                segment = segment[: len(segment) - match.end() + 1]
            tokens = ENCODER.encode(segment)
            self._buf.extend(tokens)
            self._pos += len(segment)

        tokens = self._buf[: n_tokens]  # we want this many tokens
        self._buf = self._buf[BT :]  # remove BT tokens (not BT + 1!)
        return tokens

    def next_batch(self) -> tuple[torch.Tensor, torch.Tensor]:
        batch_tokens = torch.as_tensor(self.next_tokens(add_pred_token=True))
        try:
            x = (batch_tokens[:-1]).view(self.B, self.T) # inputs
            y = (batch_tokens[1:]).view(self.B, self.T) # targets
        except RuntimeError:
            raise RuntimeError("no more batches remaining")
        return x, y

    def __iter__(self):
        return self

    def __next__(self) -> tuple[torch.Tensor, torch.Tensor]:
        try:
            return self.next_batch()
        except RuntimeError:
            if self.loop:
                self.reset()
                try:
                    return self.next_batch()
                except RuntimeError:
                    raise StopIteration
            else:
                raise StopIteration

In [10]:
example = (
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor"
    " incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam,"
)
dataloader = DataLoaderLite(2, 6, example)

def print_next_tokens(add_pred_token: bool):
    print(dataloader.next_tokens(add_pred_token=add_pred_token))

def print_next_segment(add_pred_token: bool):
    print(ENCODER.decode(dataloader.next_tokens(add_pred_token=add_pred_token)).replace("\n", "\\n"))

dataloader.reset()
print_next_tokens(True)
print_next_tokens(True)
print_next_tokens(True)
print_next_tokens(True)
print()
dataloader.reset()
print_next_segment(True)
print_next_segment(True)
print_next_segment(True)
print_next_segment(True)
print()
dataloader.reset()
print_next_tokens(False)
print_next_tokens(False)
print_next_tokens(False)
print_next_tokens(False)
print()
dataloader.reset()
print_next_segment(False)
print_next_segment(False)
print_next_segment(False)
print_next_segment(False)

[43, 29625, 220, 2419, 388, 288, 45621, 1650, 716, 316, 11, 369, 8831]
[8831, 316, 333, 31659, 271, 2259, 220, 417, 270, 11, 10081, 466, 304]
[304, 3754, 4666, 10042, 753, 312, 312, 2797, 3384, 2248, 382, 2123, 220]
[220, 67, 349, 382, 2153, 2616, 435, 1557, 64, 13, 7273, 551, 320]

Lorem ipsum dolor sit amet, consect
sectetur adipiscing elit, sed do e
 eiusmod tempor incididunt ut labore et 
 dolore magna aliqua. Ut enim

[43, 29625, 220, 2419, 388, 288, 45621, 1650, 716, 316, 11, 369]
[8831, 316, 333, 220, 324, 541, 271, 2259, 1288, 270, 11, 10081]
[466, 304, 3754, 4666, 10042, 220, 1939, 312, 312, 2797, 3384, 2248]
[382, 2123, 288, 349, 382, 2153, 2616, 435, 1557, 64, 13, 7273]

Lorem ipsum dolor sit amet, con
sectetur adipiscing elit, sed
 do eiusmod tempor incididunt ut lab
ore et dolore magna aliqua. Ut


In [11]:
import math
from pydantic import BaseModel

class CosineLRSchedule(BaseModel):
    max_lr: float
    min_lr: float
    warmup_steps: float
    max_steps: float

    def __call__(self, it: int) -> float:
        # 1) linear warmup for warmup_iters steps
        if it < self.warmup_steps:
            return self.max_lr * (it + 1) / self.warmup_steps
        # 2) if it > lr_decay_iters, return min learning rate
        if it > self.max_steps:
            return self.min_lr
        # 3) in between, use cosine decay down to min learning rate
        decay_ratio = (it - self.warmup_steps) / (self.max_steps - self.warmup_steps)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # starts at 1 and goes to 0
        return self.min_lr + coeff * (self.max_lr - self.min_lr)

In [12]:
from chesslm.gpt import GPT
from torch.nn import functional as F

def evaluate_model(model: GPT, dataloader: DataLoaderLite, val_steps: int = 20) -> float:
    model.eval()
    dataloader.reset()
    with torch.no_grad():
        val_loss = 0.0
        for _ in range(val_steps):
            x, y = dataloader.next_batch()
            x, y = x.to(DEVICE), y.to(DEVICE)
            with torch.autocast(device_type=DEVICE_TYPE, dtype=torch.bfloat16):
                _, loss = model(x, y)
            loss = loss / val_steps
            val_loss += loss.detach().item()
    return val_loss

def write_model_checkpoint(
        model: GPT,
        step: int,
        val_loss: float,
        fpath: str
    ):
    checkpoint = {
        'model': model.state_dict(),
        'config': model.config,
        'step': step,
        'val_loss': val_loss
    }
    # you might also want to add optimizer.state_dict() and
    # rng seeds etc., if you wanted to more exactly resume training
    torch.save(checkpoint, fpath)

def generate_sequences(model: GPT, prefix: str, num_sequences: int, max_length: int) -> list[str]:
    tokens = ENCODER.encode(prefix)
    tokens = torch.tensor(tokens, dtype=torch.long)
    tokens = tokens.unsqueeze(0).repeat(num_sequences, 1)
    xgen = tokens.to(DEVICE)

    sample_rng = torch.Generator(device=DEVICE)
    sample_rng.manual_seed(SEED)

    while xgen.size(1) < max_length:
        # forward the model to get the logits
        with torch.autocast(device_type=DEVICE_TYPE, dtype=torch.bfloat16):
            logits, _ = model(xgen) # (B, T, vocab_size)

        # take the logits at the last position
        logits = logits[:, -1, :] # (B, vocab_size)

        # get the probabilities
        probs = F.softmax(logits, dim=-1)

        # do top-k sampling of 50 (huggingface pipeline default)
        # topk_probs here becomes (5, 50), topk_indices is (5, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)

        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)

        # gather the corresponding indices
        xcol = torch.gather(topk_indices, -1, ix) # (B, 1)

        # append to the sequence
        xgen = torch.cat((xgen, xcol), dim=1)

    return [ENCODER.decode(x[:max_length].tolist()) for x in xgen]

In [13]:
from chesslm.gpt import GPTConfig

MAX_STEPS = 100
WARMUP_STEPS = 10

PRINT_TRAIN_INTERVAL = 5
EVAL_INTERVAL = 25
EVAL_NUM_SEQUENCES = 4
EVAL_SEQUENCE_MAX_LENGTH = 32

MAX_LR = 6e-4
MIN_LR = MAX_LR * 0.1
WEIGHT_DECAY = 0.1
GRAD_ACCUM_STEPS = 2

T = 512
GPT_CONFIG = GPTConfig(
    block_size=T,
    vocab_size=ENCODER.n_vocab,
    n_layer=4,
    n_head=4,
    n_embd=384,
)

B = 4
BT = B * T

In [17]:
import time

model = GPT(GPT_CONFIG).to(DEVICE)

cosine_schedule = CosineLRSchedule(
    max_lr=MAX_LR,
    min_lr=MIN_LR,
    warmup_steps=WARMUP_STEPS,
    max_steps=MAX_STEPS,
)

train_loader = DataLoaderLite(B=B, T=T, text=train_text)
val_loader = DataLoaderLite(B=B, T=T, text=val_text, loop=True)

optimizer = model.configure_optimizers(
    weight_decay=WEIGHT_DECAY,
    learning_rate=MAX_LR,
    device_type=DEVICE_TYPE
)

for step in range(MAX_STEPS):
    t0 = time.time()
    last_step = (step == MAX_STEPS - 1)

    # once in a while:
    # evaluate our validation loss
    # generate from the model and print
    if step % EVAL_INTERVAL == 0 or last_step:
        model.eval()
        with torch.no_grad():
            val_loss = evaluate_model(model, val_loader)
            sequences = generate_sequences(
                model,
                "POMPEY:\n'Twas never merry world since, ",
                EVAL_NUM_SEQUENCES,
                EVAL_SEQUENCE_MAX_LENGTH,
            )

        print("\n--------")
        print(f"Validation loss: {val_loss:.4f}")
        for i, seq in enumerate(sequences, 1):
            seq = seq.replace('\n', '\\n')
            print(f"Sample {i}: {seq}")
        print("--------\n")

    # do one step of the optimization
    model.train()
    optimizer.zero_grad()
    train_loss = 0.0
    for micro_step in range(GRAD_ACCUM_STEPS):
        x, y = train_loader.next_batch()
        x, y = x.to(DEVICE), y.to(DEVICE)

        with torch.autocast(device_type=DEVICE_TYPE, dtype=torch.bfloat16):
            logits, loss = model(x, y)

        loss = loss / GRAD_ACCUM_STEPS
        loss.backward()

        train_loss += loss.detach().item()

    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr = cosine_schedule(step)

    optimizer.step()

    if DEVICE_TYPE == "cuda":
        torch.cuda.synchronize() # wait for the GPU to finish work

    t1 = time.time()
    dt = t1 - t0 # time difference in seconds
    tokens_processed = BT * GRAD_ACCUM_STEPS
    tokens_per_sec = tokens_processed / dt

    if step % PRINT_TRAIN_INTERVAL == 0:
        print(
            f"step {step:5d}"
            f" | loss: {train_loss:.6f}"
            f" | lr {lr:.4e}"
            f" | norm: {norm:.4f}"
            f" | dt: {dt*1000:.2f}ms"
            f" | tok/sec: {tokens_per_sec:.2f}"
        )

num decayed parameter tensors: 18, with 26,573,184 parameters
num non-decayed parameter tensors: 34, with 20,736 parameters
using fused AdamW: False

--------
Validation loss: 10.9735
Sample 1: POMPEY:\n'Twas never merry world since,  clumsyハ cruelty attribute embryholflameenaries questions threatens fire Trin communion retard.................. threatensCVE
Sample 2: POMPEY:\n'Twas never merry world since, weeney taxis Hook Scalia pseausibleessorsmorphRGB450 $\ging passing Tackle convince negotiatedCertain
Sample 3: POMPEY:\n'Twas never merry world since,  Pain DeskCVEUST crueltyizu GREEN coff Organ Cerweeney Proced Prep cellsoby BP pounding
Sample 4: POMPEY:\n'Twas never merry world since,  GREENorderingenaries HER Osirisayer KeaneKEY voice UnemploymentUST sophcand Squid Squid trial Herbert
--------

step     0 | loss: 10.974515 | lr 6.0000e-05 | norm: 6.2157 | dt: 1912.00ms | tok/sec: 2142.26
step     5 | loss: 9.492013 | lr 3.6000e-04 | norm: 3.3044 | dt: 542.63ms | tok/sec: 7548.46

RuntimeError: no tokens remaining