In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import math
from model import Model_args, GPT
import time
import lzma


In [2]:
# ---- Model architecture ----
n_layer = 12              # transformer layers
n_head = 12               # attention heads per layer
n_embed = 768             # embedding width
bias = False
dropout = 0.1

# ---- Data / batching ----
block_size = 128          # length of each sampled sequence (see get_batch)
batch_size = 32           # sequences per batch
dataset_path = '/path/to/your/data'          # searched recursively for train.bin / val.bin
init_from = 'scratch'                         # 'scratch' or 'resume'
checkpoint_save_dir = '/path/to/checkpoints'  # checkpoint.pt is written to / read from here

# ---- Evaluation ----
eval_iters = 200          # batches averaged per split on each evaluation
eval_interval = 10_000    # iterations between evaluations (and checkpoint saves)

# ---- Learning-rate schedule: linear warmup, then cosine decay to min_lr ----
learning_rate = 2e-4      # peak learning rate, reached at the end of warmup
warmup_iters = 2_000
lr_decay_iters = 100_000  # after this iteration the lr stays at min_lr
# NOTE(review): lr_decay_iters < max_iters, so the final 2/3 of training runs at min_lr — confirm intended.
min_lr = 6e-5

# ---- Optimizer ----
max_iters = 300_000
weight_decay = 0.1
betas = (0.9, 0.95)
grad_clip = 1.0           # max gradient norm; 0.0 disables clipping


In [3]:
#System settings
device = 'cuda'  # any torch device string, e.g. 'cuda' or 'cuda:1'
# autocast and configure_optimizers take the bare device *type*; derive it from
# `device` so the two settings can never drift apart (e.g. 'cuda:1' -> 'cuda').
device_type = torch.device(device).type
# Prefer bfloat16 when the GPU supports it; otherwise use float16
# (the GradScaler created later is enabled only for float16).
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'

ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Mixed-precision context used around forward passes.
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

print(device)

cuda


In [None]:
#Aggregate data from all subfolders
def aggregate_data(split, root=None):
    """Concatenate every ``{split}.bin`` file found anywhere under a directory tree.

    Use this when the dataset is spread across several folders. With
    OpenWebText, putting all the data in a single folder caused the Jupyter
    kernel to shut down unexpectedly.

    Each file is memory-mapped as a flat array of ``np.uint16`` values. The
    files are joined in sorted path order, so the result does not depend on
    the filesystem's directory-listing order.

    NOTE: ``np.concatenate`` copies every shard into one in-memory array, so
    the whole split must fit in RAM.

    Args:
        split: file stem to look for, e.g. ``'train'`` or ``'val'``.
        root: directory searched recursively; defaults to the global ``dataset_path``.

    Returns:
        1-D ``np.uint16`` array with the contents of all matching files.

    Raises:
        FileNotFoundError: if no ``{split}.bin`` file exists under ``root``.
    """
    if root is None:
        root = dataset_path
    target = f'{split}.bin'
    file_paths = sorted(
        os.path.join(dirpath, target)
        for dirpath, _dirnames, filenames in os.walk(root)
        if target in filenames
    )
    if not file_paths:
        # Without this check np.concatenate([]) fails with an unhelpful ValueError.
        raise FileNotFoundError(f"No '{target}' files found under {root!r}")
    shards = [np.memmap(path, dtype=np.uint16, mode='r') for path in file_paths]
    concatenated_data = np.concatenate(shards)
    print(f"Aggregated {len(file_paths)} files for {split} with total size: {concatenated_data.shape[0]}")
    return concatenated_data

# Load both splits up front (train first, then val); get_batch samples from these arrays.
train_data, val_data = (aggregate_data(split_name) for split_name in ('train', 'val'))

def get_batch(split):
    """Sample a random batch of (input, target) sequences from one split.

    Args:
        split: ``'train'`` samples from ``train_data``; any other value samples
            from ``val_data``.

    Returns:
        ``(x, y)``: int64 tensors of shape ``(batch_size, block_size)`` on
        ``device``. ``y`` is the window one position later than ``x``, so
        ``y[:, t]`` is the element following ``x[:, t]``.
    """
    data = train_data if split == 'train' else val_data

    # Random start offsets; the upper bound leaves room for the one-step shift in y.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])

    if torch.device(device).type == 'cuda':
        # Pinned host memory allows an asynchronous (non_blocking) host-to-GPU copy.
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        # pin_memory() needs CUDA; a plain copy is correct for other devices.
        x, y = x.to(device), y.to(device)
    return x, y

model_args = dict(n_layer=n_layer, n_head=n_head, n_embed=n_embed, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout)

iter_num = 0          # overwritten from the checkpoint when resuming
best_val_loss = 1e9   # sentinel larger than any real loss; overwritten when resuming

# Two modes: train a fresh model from scratch, or resume training from checkpoint.pt.
# An explicit raise (rather than assert) keeps this check active under `python -O`.
if init_from not in ('scratch', 'resume'):
    raise ValueError(f"init_from must be 'scratch' or 'resume', got {init_from!r}")
if init_from == 'scratch':
    print("Training model from scratch")
    # NOTE(review): presumably GPT-2's 50257-token vocab padded up to a multiple of 64 — confirm.
    model_args['vocab_size'] = 50304
    gpt_args = Model_args(**model_args)
    model = GPT(gpt_args)
elif init_from == 'resume':
    print("Resuming training")
    ckpt_path = os.path.join(checkpoint_save_dir, 'checkpoint.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    # Architecture fields must match the saved weights, so take them from the checkpoint.
    # dropout is intentionally not in this list: the configured value is kept.
    checkpoint_model_args = checkpoint['model_args']
    for k in ['n_layer', 'n_head', 'n_embed', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    gpt_args = Model_args(**model_args)
    model = GPT(gpt_args)
    state_dict = checkpoint['model']
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']

# Loss scaling is enabled only for float16; when disabled the scaler passes values through unchanged.
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# Move the model before building the optimizer so optimizer state lives on `device`.
model.to(device)
optimizer = model.configure_optimizers(weight_decay, learning_rate, betas, device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None  # drop the reference so the loaded checkpoint can be garbage-collected

def estimate_loss():
    """Estimate the mean loss on the train and val splits.

    For each split, averages the loss over ``eval_iters`` freshly sampled
    batches. The model is in eval mode and gradient tracking is disabled.
    Training mode is restored afterwards, even if evaluation raises.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to a 0-dim tensor holding the mean loss.
    """
    model.eval()
    out = {}
    try:
        # No backward pass happens here, so don't build the autograd graph.
        with torch.no_grad():
            for split in ['train', 'val']:
                losses = torch.zeros(eval_iters)
                for k in range(eval_iters):
                    X, Y = get_batch(split)
                    with ctx:
                        _, loss = model(X, Y)
                    losses[k] = loss.item()
                out[split] = losses.mean()
    finally:
        model.train()
    return out

def get_lr(now_iter):
    """Return the learning rate for iteration ``now_iter``.

    Schedule:
      * ``now_iter < warmup_iters``: linear ramp from 0 toward ``learning_rate``.
      * ``warmup_iters <= now_iter <= lr_decay_iters``: cosine decay from
        ``learning_rate`` down to ``min_lr``.
      * ``now_iter > lr_decay_iters``: constant ``min_lr``.
    """
    if now_iter < warmup_iters:
        return learning_rate * now_iter / warmup_iters
    if now_iter > lr_decay_iters:
        return min_lr
    # progress goes 0 -> 1 across [warmup_iters, lr_decay_iters]; the coefficient goes 1 -> 0.
    progress = (now_iter - warmup_iters) / (lr_decay_iters - warmup_iters)
    cosine_coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine_coeff * (learning_rate - min_lr)

# Create the checkpoint directory up front so the first torch.save cannot fail on a missing path.
os.makedirs(checkpoint_save_dir, exist_ok=True)

X, Y = get_batch('train')   # first batch; each iteration fetches the next one after its forward pass
t_before = time.time()

while True:
    # Apply this iteration's scheduled learning rate to every parameter group.
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Periodically evaluate both splits and save a checkpoint.
    if iter_num > 0 and iter_num % eval_interval == 0:
        loss_dict = estimate_loss()
        print(f"Iteration {iter_num}, train loss: {float(loss_dict['train']):.4f}, "
              f"val loss: {float(loss_dict['val']):.4f}")
        best_val_loss = min(loss_dict['val'], best_val_loss)
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': model_args,
            'iter_num': iter_num,
            'best_val_loss': best_val_loss
        }
        torch.save(checkpoint, os.path.join(checkpoint_save_dir, 'checkpoint.pt'))   #save checkpoint
        print(f"Checkpoint saved at {checkpoint_save_dir}/checkpoint.pt")

    # Only the forward pass runs under autocast; PyTorch advises running backward outside it.
    with ctx:
        _, loss = model(X, Y)
    # Draw a fresh batch every iteration. Fetching it here lets the non_blocking
    # host-to-GPU copy in get_batch overlap with the backward pass.
    X, Y = get_batch('train')

    # Stop training if the loss diverges (e.g. exploding gradients).
    if not torch.isfinite(loss):
        print("Loss is NaN or Inf. Stopping training.")
        break

    scaler.scale(loss).backward()

    if grad_clip > 0.0:
        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

    scaler.step(optimizer)
    scaler.update()

    optimizer.zero_grad(set_to_none=True)

    # Wall-clock time for this iteration.
    t_after = time.time()
    dt = t_after - t_before
    t_before = t_after

    if iter_num % 1000 == 0:
        print(f"Iteration {iter_num}, loss: {loss.item()}, time: {dt * 1000:.2f}ms")

    iter_num += 1
    # Inclusive bound: iterations 0..max_iters all run before stopping.
    if iter_num > max_iters:
        break
