In [1]:
import torch
import torchinfo
import numpy as np
import math
import os
import time
from contextlib import nullcontext

## Config

In [2]:
# ---------------------------------------------------------------------------
# data
# ---------------------------------------------------------------------------
# common prefix of the tokenised splits ('<prefix>_train.bin', '<prefix>_val.bin', ...)
data_path = 'data/tiny_shakespeare_char'

# ---------------------------------------------------------------------------
# I/O
# ---------------------------------------------------------------------------
eval_only = False  # when True the training loop stops right after the first evaluation

# ---------------------------------------------------------------------------
# system
# ---------------------------------------------------------------------------
device = 'cuda'
# torch.autocast is keyed on the device family, not on the full device string
device_type = 'cpu' if 'cuda' not in device else 'cuda'

# autocast dtype name: 'float32', 'bfloat16' or 'float16'
# ('float16' is the setting that switches the GradScaler on further down)
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'
compile = True  # wrap the model with torch.compile before training

# ---------------------------------------------------------------------------
# evaluation and output
# ---------------------------------------------------------------------------
out_dir = 'out-shakespeare-char'
eval_interval = 250  # evaluate often, the model is expected to overfit
eval_iters = 200     # number of batches averaged per split at each evaluation
log_interval = 10    # print the training loss every this many iterations

# overfitting is expected on such a small dataset, so keep only improving checkpoints
always_save_checkpoint = False

# ---------------------------------------------------------------------------
# wandb logging
# ---------------------------------------------------------------------------
wandb_log = False
wandb_project = 'shakespeare-char'
wandb_run_name = 'mini-gpt'

# ---------------------------------------------------------------------------
# optimization hyperparams
# ---------------------------------------------------------------------------
learning_rate = 1e-3  # small networks tolerate a somewhat higher rate
max_iters = 5000
grad_clip = 1.0       # gradient-norm clipping threshold; 0.0 disables clipping
decay_lr = True       # use the warmup + cosine schedule instead of a constant rate
lr_decay_iters = 5000 # usually equal to max_iters
weight_decay = 1e-1
min_lr = 1e-4         # usually learning_rate / 10
beta1 = 0.9
beta2 = 0.99          # a bit larger because each iteration sees few tokens
warmup_iters = 100
gradient_accumulation_steps = 1  # micro-batches accumulated per optimizer step (simulates a larger batch)

# ---------------------------------------------------------------------------
# batch size and block size
# ---------------------------------------------------------------------------
batch_size = 64
block_size = 256

# ---------------------------------------------------------------------------
# DDP (distributed data parallel) training
# ---------------------------------------------------------------------------
ddp = False
master_process = True

# TODO: set up DDP for future experiments

In [3]:
# Translate the dtype name chosen in the config into the actual torch dtype object.
ptdtype = {
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
}[dtype]

# Context manager wrapped around every forward pass:
# a no-op on CPU, autocast to `ptdtype` on any other device type.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

## Data Loader

In [4]:
# data loader
# Each split is opened as a read-only memory map of uint16 values, so slices are
# read from disk on demand instead of loading the whole file up front.
train_data = np.memmap(
    f'{data_path}_train.bin',
    mode='r',
    dtype=np.uint16,
)
val_data = np.memmap(
    f'{data_path}_val.bin',
    mode='r',
    dtype=np.uint16,
)
test_data = np.memmap(
    f'{data_path}_test.bin',
    mode='r',
    dtype=np.uint16,
)

def get_batch(split, batch_size=batch_size, block_size=block_size):
    """Sample a random batch of (input, target) sequences from one data split.

    Args:
        split: 'train' or 'val' select the corresponding split; any other value
            falls back to the test split.
        batch_size: number of sequences to sample (defaults to the configured value).
        block_size: length of every sequence (defaults to the configured value).

    Returns:
        Tuple ``(x, y)`` of int64 tensors of shape (batch_size, block_size) on
        `device`; ``y`` is ``x`` shifted one position to the right in the data.
    """
    if split == 'train':
        data = train_data
    elif split == 'val':
        data = val_data
    else:
        data = test_data

    # random start offsets; the upper bound leaves room for the one-step-shifted target window
    starts = torch.randint(len(data) - block_size, (batch_size,)).tolist()
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in starts])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in starts])

    # Decide on the device *family* so that strings such as 'cuda:0' also take the pinned path.
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

## Model

In [5]:
import torch
from torch import nn
from blocks import EncoderBlock

class TransformerLM(nn.Module):
    """Transformer language model: token + learned positional embeddings, a stack
    of causal `EncoderBlock`s and a linear projection onto the vocabulary.

    Args:
        vocab_size: number of distinct tokens.
        d_model: embedding / hidden width.
        n_layers: number of `EncoderBlock`s.
        n_heads: number of attention heads passed to every block.
        dff: feed-forward width passed to every block.
        dropout_rate: dropout probability (embeddings and blocks).
        activation: activation specifier passed to every block.
        norm_first: passed to every block.
        max_block_size: longest sequence the model accepts (size of the positional table).
        bias: passed to every block.
    """
    def __init__(self, vocab_size, d_model, n_layers, n_heads, dff, dropout_rate, activation, norm_first, max_block_size, bias=True):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dff = dff
        self.dropout_rate = dropout_rate
        self.activation = activation
        self.norm_first = norm_first
        self.max_block_size = max_block_size
        self.bias = bias

        self.layers = nn.ModuleDict(dict(
            token_embedder = nn.Embedding(vocab_size, d_model),
            positional_embedder = nn.Embedding(max_block_size, d_model),
            dropout = nn.Dropout(dropout_rate),
            blocks = nn.ModuleList([EncoderBlock(d_model, n_heads, dff, dropout_rate,
                activation, norm_first, bias, causal=True) for _ in range(n_layers)]),
            final_out = nn.Linear(d_model, vocab_size)
        ))

    def forward(self, x, targets=None):
        """Run the model on a batch of token indices.

        Args:
            x: integer tensor of shape (b, t) with t <= max_block_size.
            targets: optional tensor of target indices; entries equal to -1 are
                ignored by the loss.

        Returns:
            ``(logits, loss)``. With targets, logits cover every position and
            loss is the cross-entropy; without targets, logits are computed for
            the last position only (shape (b, 1, vocab_size)) and loss is None.
        """
        device = x.device
        b, t = x.size()
        assert t <= self.max_block_size, f'Input sequence length {t} exceeds maximum block size {self.max_block_size}'

        token_embedding = self.layers.token_embedder(x)
        positions = torch.arange(0, t, dtype=torch.long, device=device)
        positional_embedding = self.layers.positional_embedder(positions)
        # The dropout layer was constructed but never applied; apply it to the
        # summed embeddings (identity in eval mode).
        x = self.layers.dropout(token_embedding + positional_embedding)

        for enc_block in self.layers.blocks:
            x = enc_block(x)

        if targets is not None:
            # compute loss if given targets
            logits = self.layers.final_out(x)
            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference: only the last position is needed; indexing with [-1] keeps the time dimension
            logits = self.layers.final_out(x[:, [-1], :])
            loss = None

        return logits, loss

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS

        Args:
            fwdbwd_per_iter: number of forward/backward passes (sequences) per iteration.
            dt: wall-clock duration of one iteration in seconds.

        Returns:
            Achieved FLOPS divided by the 312 TFLOPS A100 bfloat16 peak.
        """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        L, H, Q, T = self.n_layers, self.n_heads, self.d_model//self.n_heads, self.max_block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.layers.positional_embedder.weight.numel()
        return n_params

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        Args:
            idx: LongTensor of shape (b, t) with the conditioning tokens.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this value before sampling.
            top_k: if given, sample only among the k most likely tokens.

        Returns:
            LongTensor of shape (b, t + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at max_block_size
            # (the attribute is `max_block_size`; `block_size` does not exist on this class)
            idx_cond = idx if idx.size(1) <= self.max_block_size else idx[:, -self.max_block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

def configure_optimizers(model, weight_decay, learning_rate, betas, device_type):
    """Build an AdamW optimizer with weight decay applied only to >=2-D parameters.

    Args:
        model: module whose trainable parameters are optimized.
        weight_decay: decay applied to parameters with two or more dimensions.
        learning_rate: initial learning rate for both parameter groups.
        betas: ``(beta1, beta2)`` tuple forwarded to AdamW.
        device_type: device the model lives on; any string containing 'cuda'
            (e.g. 'cuda' or 'cuda:0') enables the fused kernel when available.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    import inspect  # local import: only needed for the fused-availability probe

    # start with all of the candidate parameters that require grad
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
    # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
    decay_params = [p for p in param_dict.values() if p.dim() >= 2]
    nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0}
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
    # Create AdamW optimizer and use the fused version if it is available:
    # the installed AdamW must accept `fused`, and the parameters must live on a CUDA device.
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and 'cuda' in str(device_type)
    extra_args = {'fused': True} if use_fused else {}
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
    print(f"using fused AdamW: {use_fused}")

    return optimizer

In [6]:
import pickle

# Dataset metadata written at preprocessing time; only 'vocab_size' is read here.
# NOTE: pickle.load can execute arbitrary code, so only load files you produced yourself.
# The `with` block closes the file handle, which the bare open() call never did.
with open('data/shakespeare_char_meta.pkl', 'rb') as meta_file:
    meta_data = pickle.load(meta_file)
vocab_size = meta_data['vocab_size']

In [7]:
# Constructor arguments are kept in a dict so the exact same values can be
# stored next to the weights when a checkpoint is written.
model_args = {
    'vocab_size': vocab_size,
    'd_model': 384,
    'n_layers': 6,
    'n_heads': 6,
    'dff': None,
    'dropout_rate': 0.2,
    'activation': 'relu',
    'norm_first': True,
    'max_block_size': 256,
    'bias': True,
}
# build the model and move it to the configured device; both names refer to the same object
transformer_lm = TransformerLM(**model_args).to(device)
model = transformer_lm

In [8]:
# Layer-by-layer summary for a dummy batch holding one sequence of the full context length.
# Uses the configured `device` and `block_size` instead of repeating the literals 'cuda' and 256.
torchinfo.summary(model, input_data=torch.randint(0, 10, size=(1, block_size)), device=device)

Layer (type:depth-idx)                        Output Shape              Param #
TransformerLM                                 [1, 1, 65]                --
├─ModuleDict: 1-1                             --                        --
│    └─Embedding: 2-1                         [1, 256, 384]             24,960
│    └─Embedding: 2-2                         [256, 384]                98,304
│    └─ModuleList: 2-3                        --                        --
│    │    └─EncoderBlock: 3-1                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-2                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-3                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-4                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-5                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-6                 [1, 256, 384]             1,774,464
│    └─Linear: 2-4                           

## Training Setup

In [17]:
# Gradient scaler used by the training loop; it is enabled only when the
# configured dtype is 'float16' and is constructed in disabled mode otherwise.
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

In [20]:
scaler?

[0;31mType:[0m        GradScaler
[0;31mString form:[0m <torch.cuda.amp.grad_scaler.GradScaler object at 0x7f1d3ac328d0>
[0;31mFile:[0m        ~/miniconda3/envs/abstract_transformer/lib/python3.11/site-packages/torch/cuda/amp/grad_scaler.py
[0;31mDocstring:[0m   <no docstring>

In [10]:
# optimizer
# `configure_optimizers` expects the device family, so pass `device_type`
# rather than the raw `device` string (they differ for values like 'cuda:0').
optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device_type)

num decayed parameter tensors: 27, with 10,765,056 parameters
num non-decayed parameter tensors: 49, with 30,017 parameters
using fused AdamW: True


In [11]:
# Optionally wrap the model with torch.compile, keeping a handle on the eager module.
if compile:
    print('compiling model...')
    unoptimized_model, model = model, torch.compile(model)
    print('done.')

compiling model...


done.


In [12]:
@torch.no_grad()
def estimate_loss():
    """Average the loss over `eval_iters` random batches for each data split.

    The model is put in eval mode for the measurement and switched back to
    train mode before returning.

    Returns:
        dict mapping 'train', 'val' and 'test' to a scalar tensor holding the
        mean loss of that split.
    """
    def _mean_loss(split_name):
        # one slot per evaluated batch, averaged at the end
        per_batch = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = get_batch(split_name)
            with ctx:
                _, batch_loss = model(inputs, targets)
            per_batch[step] = batch_loss.item()
        return per_batch.mean()

    model.eval()
    averages = {name: _mean_loss(name) for name in ('train', 'val', 'test')}
    model.train()
    return averages

In [13]:
def get_lr(it):
    """Learning rate for iteration `it`: linear warmup, then cosine decay, then a floor.

    * ``it < warmup_iters``: ramps linearly from 0 towards `learning_rate`.
    * ``it > lr_decay_iters``: constant `min_lr`.
    * otherwise: cosine interpolation from `learning_rate` down to `min_lr`.
    """
    if it < warmup_iters:
        lr = learning_rate * it / warmup_iters
    elif it > lr_decay_iters:
        lr = min_lr
    else:
        # fraction of the decay phase already completed
        progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= progress <= 1
        # goes from 1 at the start of the decay phase to 0 at its end
        cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
        lr = min_lr + cosine_weight * (learning_rate - min_lr)
    return lr

In [14]:
# Weights & Biases is imported and initialised only when logging is switched on
# in the config, so the package is not required otherwise.
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name)

## Training Loop

In [21]:
# training loop
# Each iteration: set the learning rate, periodically evaluate (and checkpoint
# when the validation loss improves), run the forward/backward micro-steps,
# clip, step the optimizer, then log timing and loss.
iter_num = 0
best_val_loss = 1e9

X, Y = get_batch('train', batch_size, block_size) # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0
while True:

    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train/loss": losses['train'],
                "val/loss": losses['val'],
                "lr": lr,
                "mfu": running_mfu*100, # convert to percentage
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            # never checkpoint the untrained model at iteration 0
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    # 'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                # torch.save does not create missing directories, so make sure out_dir exists
                os.makedirs(out_dir, exist_ok=True)
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # in DDP training we only need to sync gradients at the last micro step.
            # the official way to do this is with model.no_sync() context manager, but
            # I really dislike that this bloats the code and forces us to repeat code
            # looking at the source of that context manager, it just toggles this variable
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train', batch_size, block_size)
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()

    # clip the gradient
    if grad_clip != 0.0:
        # gradients must be unscaled before their norm is measured against grad_clip
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            # exponential moving average of the MFU estimate (seeded with the first measurement)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

step 0: train loss 4.5337, val loss 4.5311


BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

In [None]:
# TODO: transform training loop above to re-usable training loop for any model/dataset