In [1]:
import torch
import torchinfo
import numpy as np
import math
import os
import time
from contextlib import nullcontext

import sys; sys.path.append('..')
from models import TransformerLM, AbstractTransformerLM, configure_optimizers

## Config

In [2]:
# Prefix of the tokenized dataset files; the data-loader cell appends
# '_train.bin', '_val.bin' and '_test.bin' to it.
data_path = '../data/tiny_shakespeare_char'

# I/O
eval_only = False # if True, script exits right after the first eval


# system
device = 'cuda'
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast

# 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
# NOTE: this name shadows the builtin compile(); it is kept because later cells read it.
compile = True

# evaluation and output
out_dir = '../out/out-shakespeare-char'
# exist_ok avoids the check-then-create race and makes the cell safe to re-run
os.makedirs(out_dir, exist_ok=True)
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 200
log_interval = 10 # don't print too too often

# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False

# wandb logging
wandb_log = False
wandb_project = 'shakespeare-char'
wandb_run_name = 'mini-gpt'

# optimization hyperparams
learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
decay_lr = True # whether to decay the learning rate
lr_decay_iters = 5000 # make equal to max_iters usually
weight_decay = 1e-1
min_lr = 1e-4 # learning_rate / 10 usually
beta1 = 0.9
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small
warmup_iters = 100
gradient_accumulation_steps = 1 # accumulate gradients over this many steps. simulates larger batch size

# batch size and block size
batch_size = 64 # sequences per batch
block_size = 256 # tokens per sequence

# DDP (distributed data parallel) training
ddp = False
master_process = True # gates evaluation, checkpointing and logging in the training loop

# TODO: set up DDP for future experiments

In [3]:
# Map the configured dtype name onto the torch dtype used for autocast.
ptdtype = {
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
}[dtype]

# Forward passes run inside `ctx`: a no-op context on CPU, autocast otherwise.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

## Data Loader

In [4]:
# data loader
# Each split is opened as a read-only uint16 memory map of '<data_path>_<split>.bin',
# so get_batch can slice windows out of it without loading the whole file up front.
train_data = np.memmap(f'{data_path}_train.bin', dtype=np.uint16, mode='r')
val_data = np.memmap(f'{data_path}_val.bin', dtype=np.uint16, mode='r')
test_data = np.memmap(f'{data_path}_test.bin', dtype=np.uint16, mode='r')

def get_batch(split, batch_size=batch_size, block_size=block_size):
    """Sample a random batch of (input, target) token windows from one split.

    Args:
        split: 'train' or 'val'; any other value selects the test split.
        batch_size: number of windows to sample (defaults to the configured batch_size).
        block_size: length of each window in tokens (defaults to the configured block_size).

    Returns:
        (x, y): int64 tensors of shape (batch_size, block_size) moved to `device`;
        y holds the same windows as x shifted one token to the right.
    """
    data = train_data if split == 'train' else val_data if split == 'val' else test_data
    # start offsets are drawn so that the shifted target window still fits inside `data`
    idx = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in idx])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in idx])
    # compare on device_type so indexed devices such as 'cuda:0' also take the pinned path
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [5]:
import pickle

# Character-level vocabulary metadata: 'vocab_size' is read here, and the
# 'stoi' / 'itos' lookup tables are read by the sampling cells further down.
# NOTE: pickle.load can execute arbitrary code; only load a metadata file you trust.
with open('../data/shakespeare_char_meta.pkl', 'rb') as meta_file:
    meta_data = pickle.load(meta_file)
vocab_size = meta_data['vocab_size']

## Transformer Model

In [6]:
# Constructor arguments for the baseline model; the same dict is stored in checkpoints.
model_args = dict(
    vocab_size=vocab_size, d_model=384, n_layers=6, n_heads=6, dff=None,
    dropout_rate=0.2, activation='relu', norm_first=True,
    max_block_size=block_size,  # tied to the window length that get_batch produces
    bias=True)
model = transformer_lm = TransformerLM(**model_args).to(device)

In [7]:
# Summarize the model on one dummy sequence of block_size token ids, on the configured device.
torchinfo.summary(model, input_data=torch.randint(0, 10, size=(1, block_size)), device=device)

Layer (type:depth-idx)                        Output Shape              Param #
TransformerLM                                 [1, 1, 65]                --
├─ModuleDict: 1-1                             --                        --
│    └─Embedding: 2-1                         [1, 256, 384]             24,960
│    └─Embedding: 2-2                         [256, 384]                98,304
│    └─ModuleList: 2-3                        --                        --
│    │    └─EncoderBlock: 3-1                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-2                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-3                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-4                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-5                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-6                 [1, 256, 384]             1,774,464
│    └─Linear: 2-4                           

### Training Setup

In [8]:
# Gradient scaler for mixed-precision training; it is enabled only when dtype is 'float16'.
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

In [9]:
# optimizer
# configure_optimizers comes from the project's `models` module (not shown here); its
# keyword is `device_type`, so pass the normalized type ('cuda'/'cpu') rather than the
# raw device string, which may carry an index such as 'cuda:0'.
optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device_type)

num decayed parameter tensors: 27, with 10,765,056 parameters
num non-decayed parameter tensors: 49, with 30,017 parameters
using fused AdamW: True


In [10]:
# Optionally wrap the model with torch.compile. `unoptimized_model` keeps a reference
# to the original module; after this cell `model` refers to the compiled wrapper.
if compile:
    print('compiling model...')
    unoptimized_model = model
    model = torch.compile(model)
    print('done.')

compiling model...
done.


In [11]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of the current global `model` on every data split.

    For each of 'train', 'val' and 'test', draws `eval_iters` random batches with
    get_batch, runs the model under `ctx` in eval mode, and averages the losses.

    Returns:
        dict mapping the split name to a 0-dim tensor holding the mean loss.
    """
    out = {}
    model.eval()
    try:
        for split in ['train', 'val', 'test']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                with ctx:
                    _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # put the model back in training mode even if evaluation is interrupted
        # (e.g. the cell is stopped), so later training steps do not run in eval mode
        model.train()
    return out

In [12]:
def get_lr(it):
    """Return the learning rate for iteration `it`.

    The schedule has three phases, driven by the notebook-level settings
    `warmup_iters`, `lr_decay_iters`, `learning_rate` and `min_lr`:
    linear warmup from 0, cosine decay from `learning_rate` to `min_lr`,
    then a constant `min_lr` once `it` exceeds `lr_decay_iters`.
    """
    if it < warmup_iters:
        # phase 1: ramp up linearly
        lr = learning_rate * it / warmup_iters
    elif it > lr_decay_iters:
        # phase 3: past the decay horizon, hold the floor
        lr = min_lr
    else:
        # phase 2: cosine interpolation between the peak and the floor
        progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= progress <= 1
        cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress)) # goes from 1 down to 0
        lr = min_lr + cosine_weight * (learning_rate - min_lr)
    return lr

In [13]:
# Start a Weights & Biases run only when logging is enabled; wandb is imported
# here so the notebook does not need it when wandb_log is False.
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name)

### Training Loop

In [14]:
# training loop
# Each pass of the `while` loop: set the learning rate for this iteration, periodically
# evaluate and checkpoint, run `gradient_accumulation_steps` forward/backward passes,
# clip gradients and step the optimizer, then log timing. The loop ends once iter_num
# exceeds max_iters, i.e. it performs optimizer steps for iter_num = 0..max_iters inclusive.
iter_num = 0
best_val_loss = 1e9

X, Y = get_batch('train', batch_size, block_size) # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
# NOTE(review): when `compile` is True, `model` here is the torch.compile wrapper created in
# the compile cell, so `raw_model` is that wrapper too and the checkpoint's state_dict is taken
# from it — confirm the key naming is what the checkpoint loader expects.
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0 # sentinel: no MFU estimate yet (logged as -100.00% until the first estimate)
while True:

    # determine and set the learning rate for this iteration
    # note: with decay_lr on, get_lr(0) is 0 (warmup starts from zero), so the first step uses lr 0
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    # (estimate_loss also computes the 'test' split, but only train/val are reported here)
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train/loss": losses['train'],
                "val/loss": losses['val'],
                "lr": lr,
                "mfu": running_mfu*100, # convert to percentage
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            # nothing is saved at iteration 0, before any optimizer step has been taken
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    # 'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # in DDP training we only need to sync gradients at the last micro step.
            # the official way to do this is with model.no_sync() context manager, but
            # I really dislike that this bloats the code and forces us to repeat code
            # looking at the source of that context manager, it just toggles this variable
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train', batch_size, block_size)
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()

    # clip the gradient
    if grad_clip != 0.0:
        # gradients are unscaled first so the clipping threshold applies to their true magnitude
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    # dt is the wall-clock time since the previous iteration finished, so it also
    # includes any evaluation/checkpointing done at the top of this iteration
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            # exponential moving average of the MFU estimates (first estimate taken as-is)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

step 0: train loss 4.5564, val loss 4.5639
iter 0: loss 4.5732, time 33891.68ms, mfu -100.00%
iter 10: loss 3.4819, time 46.14ms, mfu 8.11%
iter 20: loss 2.9775, time 46.12ms, mfu 8.11%
iter 30: loss 2.7137, time 45.98ms, mfu 8.11%
iter 40: loss 2.6052, time 45.87ms, mfu 8.12%
iter 50: loss 2.5684, time 45.92ms, mfu 8.12%
iter 60: loss 2.5493, time 45.89ms, mfu 8.12%
iter 70: loss 2.4961, time 46.19ms, mfu 8.12%
iter 80: loss 2.5003, time 46.07ms, mfu 8.12%
iter 90: loss 2.5080, time 46.19ms, mfu 8.12%
iter 100: loss 2.4887, time 46.21ms, mfu 8.12%
iter 110: loss 2.4660, time 46.40ms, mfu 8.11%
iter 120: loss 2.4318, time 46.19ms, mfu 8.11%
iter 130: loss 2.4129, time 46.16ms, mfu 8.11%
iter 140: loss 2.3724, time 46.22ms, mfu 8.11%
iter 150: loss 2.3654, time 46.19ms, mfu 8.11%
iter 160: loss 2.3205, time 45.86ms, mfu 8.11%
iter 170: loss 2.2785, time 45.93ms, mfu 8.12%
iter 180: loss 2.2060, time 45.97ms, mfu 8.12%
iter 190: loss 2.1525, time 46.01ms, mfu 8.12%
iter 200: loss 2.1044,

In [15]:
# Lookup tables between characters and token ids, taken from the dataset metadata.
char_to_idx = meta_data['stoi']
idx_to_char = meta_data['itos']

# Encode the text prompt as a (1, len(prompt)) tensor of token ids on the target device.
prompt = 'Romeo'
prompt_idx = (
    torch.from_numpy(np.array([char_to_idx[ch] for ch in prompt]))
    .unsqueeze(0)
    .to(device)
)
prompt_idx

tensor([[30, 53, 51, 43, 53]], device='cuda:0')

In [16]:
# Sample a continuation of the prompt and decode it back to text.
# The training loop leaves the model in training mode (estimate_loss ends with model.train()),
# and the model is built with dropout_rate=0.2. generate() is defined outside this notebook,
# so switch to eval mode and disable autograd here instead of assuming it does so itself.
# The previous mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        sample_gen = model.generate(prompt_idx, max_new_tokens=1000, temperature=1.0, top_k=None)[0]
finally:
    model.train(was_training)
sample_gen = ''.join([idx_to_char[idx] for idx in sample_gen.cpu().tolist()])
print(sample_gen)

Romeo.

AUTOLYCUS:
I have done syemberding.

Shepherd:
Above brief: here's stripp'd always i' the world's death,
bound here they did, glad one of answer.

Clown:
Nay, it may be so: he may be saying, it
mades a special tradition of it. You'll abefall'd
me a torth than most where may say.
Hear most gracious true, the fearful soul
I do not speak with my vow-fill'd blood,
And so I dreamtly bland and I, in a tongue.
You, lord Northumberland on Hermione,
And may progive a craction and cut a wrongot?
I speak not my rage; and he may live
Can pome of her succession.

KING RICHARD III:
Should I convent notwards me now. If it could please
The life must be cause. Go me to kell the maid
In your country, and does tell the malady stand
of a cready liance, that we intend you
Much us our orisons had been an old more But
Hath the petition of your mother shall affect
That you are to trother.

VOLUMNIA:
Our generally,
You please not four must be smile.

VALERIA:
And, now death; how my heart for from meani

In [17]:
# TODO: transform training loop above to re-usable training loop for any model/dataset
# TODO: add back ddp code for training on multiple gpus (on same node)
# TODO: use tqdm instead? would work well in-notebook but perhaps not for slurm jobs

## Abstract Transformer Model

In [18]:
# Constructor arguments for the abstract-transformer model; the same dict is stored in checkpoints.
model_args = dict(
    vocab_size=vocab_size, d_model=384, n_layers=6, n_heads_enc=4, n_heads_abs=2, dff=None,
    symbol_retrieval='sym_attn', symbol_retrieval_kwargs=dict(num_symbols=50, n_heads=4, model_dim=384), # FIXME make names consistent: d_model, model_dim
    dropout_rate=0.2, activation='relu', norm_first=True,
    max_block_size=block_size,  # tied to the window length that get_batch produces
    bias=True)
model = abstracttransformer_lm = AbstractTransformerLM(**model_args).to(device)

In [19]:
# Summarize the model on one dummy sequence of block_size token ids, on the configured device.
torchinfo.summary(model, input_data=torch.randint(0, 10, size=(1, block_size)), device=device)

Layer (type:depth-idx)                        Output Shape              Param #
AbstractTransformerLM                         [1, 1, 65]                --
├─ModuleDict: 1-1                             --                        --
│    └─Embedding: 2-1                         [1, 256, 384]             24,960
│    └─Embedding: 2-2                         [256, 384]                98,304
│    └─ModuleList: 2-3                        --                        --
│    │    └─AbstractEncoderBlock: 3-1         [1, 256, 384]             2,404,224
│    │    └─AbstractEncoderBlock: 3-14        --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-3         --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-4         [1, 256, 384]             2,404,224
│    │    └─AbstractEncoderBlock: 3-14        --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-6         --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-7

In [20]:
# torchinfo overcounts the number of parameters here (its summary marks several blocks as
# "(recursive)"); this seems related to the symbolic attention being shared across layers.
# The count reported by the model itself is the correct one (similar to TransformerLM).

num_params = model.get_num_params() # alternative: sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'# of params {num_params:,}')

# of params 13,544,129


In [21]:
# TODO: can we implement in a way that torchinfo can understand? i.e., without "recursive" and overcounting

### Training Setup

In [22]:
# grad scaler (enabled only when training in float16)
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
# configure_optimizers comes from the project's `models` module (not shown here); its
# keyword is `device_type`, so pass the normalized type ('cuda'/'cpu') rather than the
# raw device string, which may carry an index such as 'cuda:0'.
optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device_type)

num decayed parameter tensors: 42, with 13,605,120 parameters
num non-decayed parameter tensors: 62, with 37,313 parameters
using fused AdamW: True


In [23]:
# Optionally wrap the abstract-transformer model with torch.compile. `unoptimized_model`
# is rebound to this model's original module; `model` becomes the compiled wrapper.
if compile:
    print('compiling model...')
    unoptimized_model = model
    model = torch.compile(model)
    print('done.')

compiling model...
done.


In [24]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of the current global `model` on every data split.

    For each of 'train', 'val' and 'test', draws `eval_iters` random batches with
    get_batch, runs the model under `ctx` in eval mode, and averages the losses.
    `model` is looked up when the function is called, so it evaluates whichever
    model the notebook-level name currently refers to.

    Returns:
        dict mapping the split name to a 0-dim tensor holding the mean loss.
    """
    out = {}
    model.eval()
    try:
        for split in ['train', 'val', 'test']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                with ctx:
                    _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # put the model back in training mode even if evaluation is interrupted
        # (e.g. the cell is stopped), so later training steps do not run in eval mode
        model.train()
    return out

In [25]:
def get_lr(it):
    """Return the learning rate for iteration `it`.

    The schedule has three phases, driven by the notebook-level settings
    `warmup_iters`, `lr_decay_iters`, `learning_rate` and `min_lr`:
    linear warmup from 0, cosine decay from `learning_rate` to `min_lr`,
    then a constant `min_lr` once `it` exceeds `lr_decay_iters`.
    """
    if it < warmup_iters:
        # phase 1: ramp up linearly
        lr = learning_rate * it / warmup_iters
    elif it > lr_decay_iters:
        # phase 3: past the decay horizon, hold the floor
        lr = min_lr
    else:
        # phase 2: cosine interpolation between the peak and the floor
        progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= progress <= 1
        cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress)) # goes from 1 down to 0
        lr = min_lr + cosine_weight * (learning_rate - min_lr)
    return lr

In [26]:
# Start a Weights & Biases run for this model only when logging is enabled.
# NOTE(review): wandb_project / wandb_run_name are the values from the config cell
# ('mini-gpt'), so unless they are changed this run gets the same name as the earlier one.
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name)

### Training Loop

In [27]:
# training loop
# Each pass of the `while` loop: set the learning rate for this iteration, periodically
# evaluate and checkpoint, run `gradient_accumulation_steps` forward/backward passes,
# clip gradients and step the optimizer, then log timing. The loop ends once iter_num
# exceeds max_iters, i.e. it performs optimizer steps for iter_num = 0..max_iters inclusive.
iter_num = 0
best_val_loss = 1e9

X, Y = get_batch('train', batch_size, block_size) # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
# NOTE(review): when `compile` is True, `model` here is the torch.compile wrapper created in
# the compile cell, so `raw_model` is that wrapper too and the checkpoint's state_dict is taken
# from it — confirm the key naming is what the checkpoint loader expects.
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0 # sentinel: no MFU estimate yet (logged as -100.00% until the first estimate)
while True:

    # determine and set the learning rate for this iteration
    # note: with decay_lr on, get_lr(0) is 0 (warmup starts from zero), so the first step uses lr 0
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    # (estimate_loss also computes the 'test' split, but only train/val are reported here)
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train/loss": losses['train'],
                "val/loss": losses['val'],
                "lr": lr,
                "mfu": running_mfu*100, # convert to percentage
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            # nothing is saved at iteration 0, before any optimizer step has been taken
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    # 'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                # NOTE(review): this is the same out_dir/'ckpt.pt' path the TransformerLM training
                # loop earlier in this notebook writes to, so that checkpoint gets overwritten here.
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # in DDP training we only need to sync gradients at the last micro step.
            # the official way to do this is with model.no_sync() context manager, but
            # I really dislike that this bloats the code and forces us to repeat code
            # looking at the source of that context manager, it just toggles this variable
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train', batch_size, block_size)
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()

    # clip the gradient
    if grad_clip != 0.0:
        # gradients are unscaled first so the clipping threshold applies to their true magnitude
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    # dt is the wall-clock time since the previous iteration finished, so it also
    # includes any evaluation/checkpointing done at the top of this iteration
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            # exponential moving average of the MFU estimates (first estimate taken as-is)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

step 0: train loss 4.6440, val loss 4.6503
iter 0: loss 4.6687, time 33445.83ms, mfu -100.00%
iter 10: loss 3.5447, time 85.15ms, mfu 5.67%
iter 20: loss 3.0576, time 85.24ms, mfu 5.67%
iter 30: loss 2.7734, time 85.23ms, mfu 5.67%
iter 40: loss 2.6157, time 85.09ms, mfu 5.67%
iter 50: loss 2.6064, time 85.49ms, mfu 5.66%
iter 60: loss 2.5422, time 85.40ms, mfu 5.66%
iter 70: loss 2.5275, time 85.19ms, mfu 5.66%
iter 80: loss 2.5181, time 85.28ms, mfu 5.66%
iter 90: loss 2.4984, time 85.11ms, mfu 5.66%
iter 100: loss 2.4784, time 84.96ms, mfu 5.66%
iter 110: loss 2.4882, time 85.08ms, mfu 5.67%
iter 120: loss 2.4310, time 85.10ms, mfu 5.67%
iter 130: loss 2.4124, time 84.74ms, mfu 5.67%
iter 140: loss 2.3439, time 85.27ms, mfu 5.67%
iter 150: loss 2.3239, time 84.71ms, mfu 5.67%
iter 160: loss 2.2252, time 85.74ms, mfu 5.67%
iter 170: loss 2.1697, time 84.98ms, mfu 5.67%
iter 180: loss 2.1408, time 85.28ms, mfu 5.67%
iter 190: loss 2.0582, time 85.38ms, mfu 5.66%
iter 200: loss 2.0210,

In [30]:
# Lookup tables between characters and token ids, taken from the dataset metadata.
# idx_to_char is set here too so the decoding cell below does not depend on an earlier section.
char_to_idx = meta_data['stoi']
idx_to_char = meta_data['itos']

# Encode the text prompt as a (1, len(prompt)) tensor of token ids.
prompt = 'Romeo'
# use the configured `device` instead of a hard-coded 'cuda', matching the rest of the notebook
prompt_idx = torch.from_numpy(np.array([char_to_idx[c] for c in prompt])).unsqueeze(0).to(device)
prompt_idx

tensor([[30, 53, 51, 43, 53]], device='cuda:0')

In [31]:
# Sample a continuation of the prompt and decode it back to text.
# The training loop leaves the model in training mode (estimate_loss ends with model.train()),
# and the model is built with dropout_rate=0.2. generate() is defined outside this notebook,
# so switch to eval mode and disable autograd here instead of assuming it does so itself.
# The previous mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        sample_gen = model.generate(prompt_idx, max_new_tokens=1000, temperature=1.0, top_k=None)[0]
finally:
    model.train(was_training)
sample_gen = ''.join([idx_to_char[idx] for idx in sample_gen.cpu().tolist()])
print(sample_gen)

Romeo is too right
To say that truly, as I said, yet it would unplot
A tied where I loved you there.

BRUTUS:
For shame!
All this is the penural great Apollo's corse,
That is not the thought to scarlet me now
That no hage hath caused me and lament
To my blood of imprembrance. The gates hands revolt to
The glass of mine on.

MARCIUS:
They have they do curse they find the court: they will;
No more, love had they say, they follow the rise,
CaAnd now not.

MENENIUS:
I am so breath of the luke: I'll be a very coll,
Apply in one stripe to be a child, on still-sheep-charge:
That out-peach does thee with thy wealth follow'd,
Were thy wilt to seek against what thou hast
Suspicion as thou a palm thine body.

SICINIUS:
A bloody is cheered at the meaner, and infrint
Thy breck of Corioli and no more-ent air.
The commons' sharp-corn unon thy castle hurt by
the file stone: if give not the wisest warrant mine
I swear from off, something the weariest true,
And fine outward in sucributation thstain;
Onl