In [None]:
pip install transformers datasets tiktoken tqdm wandb numpy

In [None]:
pip install  --pre torch torchvision torchaudio   --extra-index-url https://download.pytorch.org/whl/nightly/cpu

In [None]:
# NOTE: Import packages! 
import os,pickle,requests,time,math,torch,tiktoken
import numpy as np
from contextlib import nullcontext
from model import GPTConfig, GPT

In [None]:
# NOTE: Download the tiny shakespeare dataset!
input_file_path = './data/shakespeare_char/input.txt'
data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'

# Fetch the corpus only when no local copy exists; later runs reuse the file on disk.
if not os.path.exists(input_file_path):
    # The target folder is not created anywhere else, so make sure it exists before writing.
    os.makedirs(os.path.dirname(input_file_path), exist_ok=True)
    # Download first and validate the response, so a failed request never leaves
    # an empty or error-page file behind that would be mistaken for the dataset.
    response = requests.get(data_url, timeout=30)
    response.raise_for_status()
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(response.text)

with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# Character-level vocabulary: every distinct symbol of the corpus, in sorted order.
chars = sorted(set(data))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# Two-way lookup tables; a character's id is its position in `chars`.
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}
def encode(s):
    """Map a string to a list of integer ids, one per character, via the `stoi` table.

    Raises KeyError for any character that is not a key of `stoi`.
    """
    return list(map(stoi.__getitem__, s))
def decode(l):
    """Turn an iterable of integer ids back into text by looking each one up in `itos`.

    Raises KeyError for any id that is not a key of `itos`.
    """
    return ''.join(itos[token] for token in l)

# The first 90% of the text is used for training, the remainder for validation.
n = len(data)
train_data, val_data = data[:int(n*0.9)], data[int(n*0.9):]

# Convert both splits to integer ids.
train_ids, val_ids = encode(train_data), encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# Persist the ids as raw uint16 binaries.
train_ids = np.asarray(train_ids, dtype=np.uint16)
val_ids = np.asarray(val_ids, dtype=np.uint16)
train_ids.tofile('./data/shakespeare_char/train.bin')
val_ids.tofile('./data/shakespeare_char/val.bin')

# Store the vocabulary size and both lookup tables next to the binaries.
meta = {'vocab_size': vocab_size, 'itos': itos, 'stoi': stoi}
with open('./data/shakespeare_char/meta.pkl', 'wb') as f:
    pickle.dump(meta, f)

In [None]:
# NOTE: Configuration
dataset = 'shakespeare_char'
data_dir = os.path.join('data', dataset)
# Memory-map both token files read-only as uint16 arrays.
train_data, val_data = (
    np.memmap(os.path.join(data_dir, split_file), dtype=np.uint16, mode='r')
    for split_file in ('train.bin', 'val.bin')
)

out_dir = './out'             # checkpoints are written here
eval_interval = 200           # run estimate_loss every this many iterations
eval_iters = 1                # batches averaged per split inside estimate_loss
batch_size = 12               # sequences per batch
block_size = 64               # tokens per sequence
learning_rate = 6e-4          # peak learning rate
max_iters = 5000              # training stops once iter_num exceeds this
decay_lr = True               # use the get_lr schedule instead of a constant rate
warmup_iters = 100            # length of the linear warmup in get_lr
lr_decay_iters = max_iters    # iteration at which the cosine decay reaches min_lr
min_lr = learning_rate/10.0
tokens_per_iter = batch_size * block_size
local_iter_num = 0
running_mfu = -1.0            # -1.0 marks "no measurement yet"
iter_num = 0
best_val_loss = 1e9

# Seed torch's RNG so batch sampling and weight initialisation repeat across runs.
seed = 1337
torch.manual_seed(seed)

# Use Apple's MPS backend when it is usable and fall back to the CPU otherwise,
# so the notebook also runs on machines without MPS.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
# device_type stays 'cpu': this makes ctx a no-op and keeps get_batch off its
# CUDA pinned-memory path.
device_type = 'cpu'
# autocast requires a torch.dtype object (not the string 'float16').
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=torch.float16)

In [None]:
# NOTE: Define some functions 

# NOTE:++++++++++++++
# NOTE: Build a batch loader
def get_batch(split):
    """Draw a random batch from the training or validation tokens.

    `split` selects `train_data` when it equals 'train' and `val_data` otherwise.
    Returns `(x, y)`: two int64 tensors with `batch_size` rows of `block_size`
    tokens each, placed on `device`; every row of `y` is the matching row of
    `x` shifted one token ahead in the source array.
    """
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,)).tolist()

    def windows(offset):
        # One block_size-long slice per start index, stacked into a 2-D tensor.
        return torch.stack([
            torch.from_numpy(source[s + offset:s + offset + block_size].astype(np.int64))
            for s in starts
        ])

    x, y = windows(0), windows(1)
    if device_type != 'cuda':
        return x.to(device), y.to(device)
    # On CUDA, pin the host tensors so the transfer can be issued non-blocking.
    return x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)

# NOTE:++++++++++++++
# NOTE: Build a loss estimator
@torch.no_grad()
def estimate_loss():
    """Estimate the current loss on both splits without tracking gradients.

    For each of 'train' and 'val', draws `eval_iters` batches with `get_batch`,
    runs the model under `ctx` and averages the losses. The model is switched
    to eval mode for the measurement and back to train mode before returning.

    Returns a dict mapping 'train' and 'val' to the mean loss (a 0-d tensor).
    """
    def mean_loss(split):
        per_batch = torch.zeros(eval_iters)
        for k in range(eval_iters):
            inputs, targets = get_batch(split)
            with ctx:
                _, batch_loss = model(inputs, targets)
            per_batch[k] = batch_loss.item()
        return per_batch.mean()

    model.eval()
    out = {split: mean_loss(split) for split in ('train', 'val')}
    model.train()
    return out

# NOTE:++++++++++++++
# NOTE: Build a learning rate adjustor
def get_lr(it):
    """Return the learning rate to use at iteration `it`.

    Three phases, driven by the module-level schedule settings:
      * `it < warmup_iters`: linear ramp from 0 up towards `learning_rate`;
      * `it > lr_decay_iters`: constant `min_lr`;
      * otherwise: cosine interpolation from `learning_rate` down to `min_lr`.
    """
    # Phase 1: linear warmup.
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # Phase 3: past the decay horizon, stay at the floor.
    if it > lr_decay_iters:
        return min_lr
    # Phase 2: fraction of the decay window already covered, in [0, 1].
    progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= progress <= 1
    # Cosine weight goes from 1 (start of decay) to 0 (end of decay).
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine_weight * (learning_rate - min_lr)

In [None]:
# NOTE: Build a model and a optimizer
# Take the vocabulary size from the dataset's meta file when one is present.
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
    # NOTE(review): pickle.load runs arbitrary code; only open meta files you created yourself.
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# Model hyperparameters; vocab_size falls back to 50304 when no meta file was found.
model_args = {
    'n_layer': 4,
    'n_head': 4,
    'n_embd': 128,
    'block_size': block_size,
    'bias': False,
    'vocab_size': 50304 if meta_vocab_size is None else meta_vocab_size,
    'dropout': 0.0,
}
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
model.to(device)

# Gradient scaling is switched off, so the scaler calls in the training loop are pass-throughs.
scaler = torch.cuda.amp.GradScaler(enabled=False)
# Positional arguments presumably are (weight_decay, learning_rate, betas, device_type) - confirm against model.py.
optimizer = model.configure_optimizers(1e-1, learning_rate, (0.9, 0.95), device_type)

In [None]:
# NOTE: Training phase begins
def _save_checkpoint(filename):
    """Write model/optimizer state and training bookkeeping to `out_dir/filename`."""
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'model_args': model_args,
        'iter_num': iter_num,
        'best_val_loss': best_val_loss,
    }
    torch.save(checkpoint, os.path.join(out_dir, filename))


if not os.path.exists(out_dir):
    print("make dir=>{}".format(out_dir))
# exist_ok avoids a crash if the directory appears between the check and the call.
os.makedirs(out_dir, exist_ok=True)

t0 = time.time()
X, Y = get_batch('train')
while True:
    # Apply this iteration's learning rate to every parameter group.
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Keep a snapshot of the untrained weights so they can be sampled from later.
    if iter_num == 0:
        print(f"saving initial checkpoint to {out_dir}")
        _save_checkpoint('ckpt_init.pt')

    # Periodic evaluation; checkpoint whenever the validation loss improves
    # (iteration 0 only records the baseline, it does not write ckpt.pt).
    if iter_num % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if losses['val'] < best_val_loss:
            best_val_loss = losses['val']
            if iter_num > 0:
                print(f"saving checkpoint to {out_dir}")
                _save_checkpoint('ckpt.pt')

    # Forward pass on the current batch.
    with ctx:
        logits, loss = model(X, Y)

    # Fetch the next batch before running the backward pass.
    X, Y = get_batch('train')
    scaler.scale(loss).backward()
    # Unscale before clipping so the threshold applies to the true gradient norm.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    # Per-iteration timing and logging.
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    lossf = loss.item()
    if local_iter_num >= 5: # let the training loop settle a bit
        mfu = model.estimate_mfu(batch_size, dt)
        # First measurement seeds the running value; afterwards it is an exponential moving average.
        running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
    print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    if iter_num > max_iters:
        break

In [None]:
# NOTE: Test phase begins
# Prompt text the generated samples will continue from.
start = "Shakespeare \n"

# Reload the character tables written during data preparation.
# NOTE(review): pickle.load runs arbitrary code; only open meta files you created yourself.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
# TODO want to make this more general to arbitrary encoder/decoder schemes
stoi = meta['stoi']
itos = meta['itos']


def encode(s):
    """Map each character of `s` to its integer id through `stoi`."""
    return [stoi[c] for c in s]


def decode(l):
    """Join the characters that `itos` assigns to the ids in `l`."""
    return ''.join([itos[i] for i in l])


# A prompt of the form 'FILE:<path>' is replaced by the contents of that file.
if start[:5] == 'FILE:':
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()

# Encode the prompt and add a leading batch dimension of size 1.
start_ids = encode(start)
x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)

In [None]:
#NOTE: Results output by initial models
# Load the untrained weights from the directory the training cell saved them to.
# NOTE(review): torch.load unpickles the file; only load checkpoints written by this notebook.
state_dict_init = torch.load(os.path.join(out_dir, 'ckpt_init.pt'), map_location=device)['model']

# Strip the '_orig_mod.' key prefix if present (presumably left by a compiled/wrapped module - confirm).
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict_init.items()):
    if k.startswith(unwanted_prefix):
        state_dict_init[k[len(unwanted_prefix):]] = state_dict_init.pop(k)
model.load_state_dict(state_dict_init)

# Sample in eval mode so train-only layers (e.g. dropout) cannot perturb the output,
# then restore whatever mode the model was in.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        y = model.generate(x, 500, temperature=0.8, top_k=200)
finally:
    model.train(was_training)
print('------------------------')
print(decode(y[0].tolist()))
print('------------------------')

In [None]:
#NOTE: Results output by final models
# Load the best checkpoint from the directory the training cell saved it to.
# NOTE(review): torch.load unpickles the file; only load checkpoints written by this notebook.
state_dict = torch.load(os.path.join(out_dir, 'ckpt.pt'), map_location=device)['model']

# Defined here as well so this cell does not depend on another sampling cell having run.
# Strip the '_orig_mod.' key prefix if present (presumably left by a compiled/wrapped module - confirm).
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

# Sample in eval mode so train-only layers (e.g. dropout) cannot perturb the output,
# then restore whatever mode the model was in.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        y = model.generate(x, 500, temperature=0.8, top_k=200)
finally:
    model.train(was_training)
print('------------------------')
print(decode(y[0].tolist()))
print('------------------------')