# Branchy GPT

In this notebook we will try to train a custom BranchyGPT for experiment on Shakespeare_char dataset for experimental purposes, It might scale further to openwebtext after.

First please run 

    python data/shakespeare_char/prepare.py


In [1]:
import torch
import os
import numpy as np
import time
from contextlib import nullcontext

from model import GPTConfig, GPT
torch.manual_seed(1337)

<torch._C.Generator at 0x7f10c4426b30>

In [2]:
# Setting up checkpoint saving directory
out_dir = "./BranchyGPT_save"
dataset = "shakespeare_char"
dtype = torch.float16

# Get device between GPU or CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=dtype)
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

#Prepare dataset
gradient_accumulation_steps = 1 # used to simulate larger batch sizes
batch_size = 128
data_dir = os.path.join('data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')


In [3]:
# Get default conf, model is GPT2
gptconf = GPTConfig()
gptconf.block_size = 256
gptconf.n_layer = 6
gptconf.n_head = 6
gptconf.n_embd = 384
model = GPT(gptconf)
model.to(device)
print(gptconf)

number of parameters: 29.96M
GPTConfig(block_size=256, vocab_size=50304, n_layer=6, n_head=6, n_embd=384, dropout=0.0, bias=True)


In [4]:
# adamw optimizer
learning_rate = 1e-3 # max learning rate
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.99
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0

optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device)

if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model = model
    model = torch.compile(model) # requires PyTorch 2.0


num decayed parameter tensors: 26, with 30,031,872 parameters
num non-decayed parameter tensors: 50, with 30,720 parameters
using fused AdamW: True
compiling the model... (takes a ~minute)


In [5]:
def get_batch(split, block_size, batch_size, device):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

X, Y = get_batch('train', gptconf.block_size, batch_size, device) # fetch the very first batch

eval_iters = 200 # how many iterations to average loss over when evaluating
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, gptconf.block_size, batch_size, device)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [6]:
# Training loop
iter_num = 0
eval_interval = 100
best_val_loss = 0
max_iters = 2000
log_interval = 10
t0 = time.time()
while True:
    
    lr = learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    if iter_num % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if losses['val'] < best_val_loss:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': gptconf,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    for micro_step in range(gradient_accumulation_steps):
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
                # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train', gptconf.block_size, batch_size, device)
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)
    
    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")

    iter_num += 1

    if iter_num > max_iters:
        break


step 0: train loss 10.9045, val loss 10.8965
iter 0: loss 10.9042, time 34231.67ms
iter 10: loss 6.6493, time 174.91ms
iter 20: loss 3.9184, time 174.92ms
iter 30: loss 3.4536, time 176.65ms
iter 40: loss 3.4101, time 175.59ms
iter 50: loss 3.3761, time 177.68ms
iter 60: loss 3.3678, time 175.01ms
iter 70: loss 3.3611, time 178.31ms
iter 80: loss 3.3904, time 174.75ms
iter 90: loss 3.3788, time 175.37ms
step 100: train loss 3.3606, val loss 3.3982
iter 100: loss 3.3670, time 20155.28ms
iter 110: loss 3.3591, time 176.92ms
iter 120: loss 3.3660, time 176.46ms
iter 130: loss 3.3970, time 177.02ms
iter 140: loss 3.3918, time 175.55ms
iter 150: loss 3.3847, time 175.97ms
iter 160: loss 3.3573, time 175.44ms
iter 170: loss 3.4239, time 176.10ms
iter 180: loss 3.3341, time 176.35ms
iter 190: loss 3.3289, time 176.28ms
step 200: train loss 3.5250, val loss 3.5963
iter 200: loss 3.5595, time 20357.84ms
iter 210: loss 3.3558, time 176.36ms
iter 220: loss 3.3312, time 175.43ms
iter 230: loss 3.3

KeyboardInterrupt: 

In [19]:
import pickle

start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 10 # number of samples to draw
max_new_tokens = 500 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster

meta_path = "./data/shakespeare_char/meta.pkl"
print(f"Loading meta from {meta_path}...")
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
# TODO want to make this more general to arbitrary encoder/decoder schemes
stoi, itos = meta['stoi'], meta['itos']
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# encode the beginning of the prompt
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# run generation
with torch.no_grad():
    with ctx:
        for k in range(num_samples):
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            print(decode(y[0].tolist()))
            print('---------------')


Loading meta from ./data/shakespeare_char/meta.pkl...

se n tnaee t ayy nhphh n rhaeoi ante eehneesa yttnhsa hh tthaa    nae i a enetah  ypaop eh r sae ntht etstyot  he  o aoeaetpeayh  yteehtt  n  hppot ie peoat arhrao enie  at  aheh a rarpey  aeyh htarahpho iois ho ye hyhy nyeo  eiep e oetehaahseoo ec n te   ta sthtsoe io eae   e  s  re te  eea s tetes ie crco   oe  ot e ttptttp ihsoeaoteht  c pet hs hryee eo scate  ocry h chntattte   hoo  iosetaett e anoot   hioheoe tt   ceahhe achp  a ci    sett  oh ie i   etaaapai hh  r    oeeh  na opeeoah   eot
---------------

nie  e cet y   y  tpao erh orea ep     n    eo aaaoo  she   a    i yhech ee eoenhcsr tshaeieseorea  pttaiyy soesys h    h ie aseestshae pocretsho s e ceo iort    cne t  tah i    yetthn    orhhhos   a eh ttteee    hhcethnsyho nryneopnheisi e e a  o rrs tenoe a  h iho  acr cea   atooe etche oh ityanecsh    eea  ee tir aiocn  hc o a  a apoa oepeas osett o yehhoah th optsyet r ee ati  n s iee  h a rhtei pt eh  htooetrsh tones  

KeyboardInterrupt: 

In [22]:
import torch
a = {0:0, 1:0, 2:0, 3:0}
for i in range (10000):
    a[torch.multinomial(torch.tensor([0.25, 0.25, 0.25, 0.25]),1).item()] += 1
print(a)

{0: 2330, 1: 2856, 2: 2461, 3: 2353}


In [3]:
import torch
device_type = 'cuda'
batch_size = 64
device = torch.device(device_type)

train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
calib_data = train_data[:1000]

def get_batch(split):
    if split == 'train':
        data = train_data
    elif split == 'val':
        data = val_data
    elif split == 'calib':
        data = calib_data
    else:
        raise ValueError(f"invalid split: {split}")
    ix = torch.randint(len(data) - 128, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+128]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+128]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [5]:
from rejectOption import RejectOption
from model import BranchyGPT, GPTConfig
import numpy as np
import os
import torch

data_dir = "data/shakespeare_char"

X,Y = get_batch('calib')

# baby GPT model :)
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2
block_size = 256 # context of up to 256 previous characters
bias = False # do we use bias inside LayerNorm and Linear layers?

model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=65, dropout=dropout)
gptconf = GPTConfig(**model_args)
model = BranchyGPT(gptconf).to(torch.device('cuda:0'))
model = torch.compile(model)
print(X.device)
print(next(model.parameters()).device)
print(model(X)[0].shape)
reject_option = RejectOption(dataset=X, model=model)



number of parameters: 10.65M
cuda:0
cuda:0


  Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(seed_cuda_0, 0)
    tmp1 = index_expr(i2 + 384 * i1 + 49152 * i0, torch.int32)
    tmp2 = rand(tmp0, tmp1, torch.float32)
    tmp3 = constant(0.2, torch.float32)
    tmp4 = tmp2 > tmp3
    tmp5 = to_dtype(tmp4, torch.float32)
    tmp6 = load(primals_51, i1 + 128 * i0)
    tmp7 = load(primals_14, i2 + 384 * (tmp6))
    tmp8 = load(buf0, i1)
    tmp9 = load(primals_15, i2 + 384 * (tmp8))
    tmp10 = tmp7 + tmp9
    tmp11 = tmp5 * tmp10
    return tmp11
    ,
    ranges=[64, 128, 384],
    origins={convert_element_type, philox_rand_like, primals_14, primals_51, embedding_1, add, primals_15, philox_seed_like, mul, gt, embedding, unsqueeze}
  )
))
  Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(seed_cuda_0, 0)
    tmp1 = index_expr(i2 + 384 * i1 + 49152 * i0, torch.int32)
    tmp2 = rand(tmp0, tmp1, torch.float32)
    tmp3 = constant(0.2, torch.float32)
    tmp4 = tmp2 > tmp3
    tmp5 = to_dtype(tmp4, torch.float32)
   

torch.Size([6, 64, 1, 65])


  Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(arg50_1, i1 + 128 * i0)
    tmp1 = load(arg13_1, i2 + 384 * (tmp0))
    tmp2 = index_expr(i1, dtype=torch.int64)
    tmp3 = load(arg14_1, i2 + 384 * (tmp2))
    tmp4 = tmp1 + tmp3
    tmp5 = load(buf1, i1 + 128 * i0)
    tmp6 = tmp4 - tmp5
    tmp7 = load(buf2, i1 + 128 * i0)
    tmp8 = index_expr(384, torch.float32)
    tmp9 = tmp7 / tmp8
    tmp10 = constant(1e-05, torch.float32)
    tmp11 = tmp9 + tmp10
    tmp12 = rsqrt(tmp11)
    tmp13 = tmp6 * tmp12
    return tmp13
    ,
    ranges=[64, 128, 384],
    origins={embedding_1, unsqueeze, arg50_1, var_mean, add, sub, mul, iota, add_1, embedding, arg13_1, arg14_1, rsqrt}
  )
))
  Pointwise(
    'cuda',
    torch.float32,
    tmp0 = load(arg50_1, i1 + 128 * i0)
    tmp1 = load(arg13_1, i2 + 384 * (tmp0))
    tmp2 = index_expr(i1, dtype=torch.int64)
    tmp3 = load(arg14_1, i2 + 384 * (tmp2))
    tmp4 = tmp1 + tmp3
    tmp5 = load(buf1, i1 + 128 * i0)
    tmp6 = tmp4 - tmp5
   

AttributeError: 'tuple' object has no attribute 'shape'

In [5]:
import torch
test_tensor = torch.tensor([1,2,3,4,5,6,7,8,9,10], dtype=torch.float32)
print(test_tensor)
test_tensor = test_tensor.to(torch.device('cuda:0'))
print(test_tensor)
test_tensor = test_tensor.cpu()
print(test_tensor)

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], device='cuda:0')
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
