In [1]:
from src.model import GPT,Config
from src.trainer import Trainer
import pandas as pd
import os
import torch
import numpy as np
import tiktoken
import multiprocessing as mp
import time

  _C._set_float32_matmul_precision(precision)


In [2]:
# Filesystem locations used by the run.
logpath = "./log"                    # handed to Trainer as its final argument
DATASET_PATH = "./data/tinystories"  # matches the shard directory used by DataLoaderLite
# Seed applied to the NumPy and PyTorch random number generators.
SEED = 42

In [3]:

class DataLoaderLite:
    """Sequential batch loader over pre-tokenized ``.npy`` shards.

    Shards are the files in ``data_root`` whose name contains ``split``,
    visited in sorted filename order. Each call to :meth:`next_batch` takes a
    contiguous window of ``B * T + 1`` tokens from the current shard and
    returns it as an input tensor and a target tensor shifted by one token.
    Every process starts at offset ``B * T * process_rank`` and strides by
    ``B * T * num_processes``, so processes read disjoint windows of a shard.

    Args:
        B: number of sequences per batch.
        T: number of tokens per sequence.
        process_rank: index of this process; rank 0 prints the shard count.
        num_processes: total number of processes sharing the data.
        split: ``'train'`` or ``'val'``; used as a substring filter on filenames.
        data_root: directory that holds the shard files.

    Raises:
        AssertionError: if ``split`` is invalid or no shard matches it.
        ValueError: if a loaded shard is too short to provide one full batch
            at this process's starting offset.
    """

    def __init__(self, B, T, process_rank, num_processes, split='train',
                 data_root="./data/tinystories"):
        self.B, self.T = B, T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}

        # Collect the shard files of this split in a deterministic (sorted) order.
        shard_filenames = sorted(
            filename for filename in os.listdir(data_root) if split in filename
        )
        self.shard_filepaths = [os.path.join(data_root, filename) for filename in shard_filenames]
        assert len(self.shard_filepaths) > 0, f'no shards found for split {split}'
        master_process = process_rank == 0
        if master_process:
            print(f'found {len(self.shard_filepaths)} shards for split {split}')
        self.reset()

    def load_tokens(self, filepath):
        """Read one ``.npy`` shard and return its contents as a ``torch.long`` tensor."""
        tokens = torch.tensor(np.load(filepath).astype(np.int32), dtype=torch.long)
        return tokens

    def _load_shard(self, shard_index):
        """Make ``shard_index`` the current shard and rewind to this rank's start offset."""
        self.curr_shard = shard_index
        filepath = self.shard_filepaths[shard_index]
        self.tokens = self.load_tokens(filepath)
        self.curr_pos = self.B * self.T * self.process_rank
        # A window needs B*T inputs plus one extra token for the shifted targets.
        # Fail here with a readable message instead of a shape error inside view().
        needed = self.curr_pos + self.B * self.T + 1
        if len(self.tokens) < needed:
            raise ValueError(
                f'shard {filepath} has {len(self.tokens)} tokens but at least {needed} '
                f'are required for B={self.B}, T={self.T}, process_rank={self.process_rank}'
            )

    def reset(self):
        """Rewind the loader to the beginning of the first shard."""
        self._load_shard(0)

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each of shape ``(B, T)``.

        ``y`` holds the same tokens as ``x`` shifted one position to the left.
        After the batch is taken the read position advances; when the next
        window would run past the end of the shard, the following shard is
        loaded (wrapping around to the first one after the last).
        """
        B, T = self.B, self.T
        batch = self.tokens[self.curr_pos : self.curr_pos + B*T + 1]
        x_batch = batch[:-1].view(B, T)
        y_batch = batch[1:].view(B, T)
        self.curr_pos += B * T * self.num_processes
        if self.curr_pos + (B * T + 1) > len(self.tokens):
            self._load_shard((self.curr_shard + 1) % len(self.shard_filepaths))
        return x_batch, y_batch

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Coarse device family derived from the device string.
device_type = 'cuda' if device.startswith('cuda') else 'cpu'

# Seed the NumPy and PyTorch generators with the notebook-wide seed.
np.random.seed(SEED)
torch.manual_seed(SEED)
if device_type == 'cuda':
    # The CUDA generators are only seeded when a GPU was selected above.
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

# This notebook runs as a single process, which is treated as the master.
master_process = True

In [5]:
# --- batch geometry ---
MINI_BATCH_SIZE = 4    # sequences per batch (passed to the loaders as B)
CTX_LENGTH = 2048      # tokens per sequence (loader T and model context_length)

# --- model size ---
NUM_HEADS = 8
NUM_LAYERS = 10
EMBED_DIM = 768

# --- optimization ---
WEIGHT_DECAY = 0.1
MAX_LR = 1e-3
MIN_LR = MAX_LR * 0.1  # one tenth of MAX_LR

# --- schedule ---
EVAL_FREQ = 250
MAX_STEPS = 1000
# NOTE(review): warmup covers 715 of the 1000 steps — confirm this ratio is intended.
WARMUP_STEPS = 715

In [6]:
# Gradient-accumulation step count handed to Trainer.
# NOTE(review): presumably the number of micro-batches per optimizer step — confirm in src.trainer.
grad_accum_steps = 32

In [7]:
# Build the training loader first and the validation loader second; both use
# the same batch geometry and run as the only process (rank 0 of 1).
train_loader, val_loader = (
    DataLoaderLite(B=MINI_BATCH_SIZE, T=CTX_LENGTH, process_rank=0, num_processes=1, split=split_name)
    for split_name in ('train', 'val')
)

found 4 shards for split train
found 1 shards for split val


In [8]:
# The GPT-2 vocabulary has 50,257 entries (50,000 BPE merges + 256 byte tokens
# + 1 <endoftext> token). It is padded up to 50,304, a rounder number with
# many factors of two, instead of the odd 50,257.
gpt_config = Config(
    vocab_size=50304,
    context_length=CTX_LENGTH,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    embedding_dim=EMBED_DIM,
)

model = GPT(gpt_config)
# Count only the parameters that require gradients.
total_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'Total number of trainable parameters: {total_params:,}')
model.to(device)
# Compilation is currently disabled; re-enable with: model = torch.compile(model)
optimizer = model.configure_optimizer(
    weight_decay=WEIGHT_DECAY,
    lr=MAX_LR,
    device_type=device_type,
    master_process=master_process,
)
# GPT-2 byte-pair encoding from tiktoken, handed to the Trainer.
token_encoder = tiktoken.get_encoding('gpt2')


Total number of trainable parameters: 111,086,592
num decay parameter tensors: 42 with 110,985,216 parameters
num nodecay parameter tensors: 82 with 101,376 parameters
using fused AdamW optimizer: True


In [9]:
# Wall-clock timestamp taken before the trainer is constructed.
start_time = time.time()
trainer = Trainer(
    model, optimizer, train_loader, val_loader, token_encoder,
    EVAL_FREQ, grad_accum_steps, device, master_process, logpath,
)
trainer.train(MAX_STEPS, WARMUP_STEPS, MAX_LR, MIN_LR)
# Elapsed time converted from seconds to hours (3600 seconds per hour).
dt = (time.time() - start_time) / 3600
print(f"Total training time: {dt:.4f}hr")

Val loss: 10.9669
step    0 | loss: 10.965103 | lr: 1.40e-06 | norm: 21.0113 | dt: 9939.6815ms | tok/sec: 26.3735
step    1 | loss: 10.846848 | lr: 2.80e-06 | norm: 21.3282 | dt: 7599.0839ms | tok/sec: 34.4968
step    2 | loss: 10.654166 | lr: 4.20e-06 | norm: 19.0643 | dt: 7562.1984ms | tok/sec: 34.6651
step    3 | loss: 10.432758 | lr: 5.59e-06 | norm: 14.6875 | dt: 7711.6253ms | tok/sec: 33.9934
step    4 | loss: 10.152704 | lr: 6.99e-06 | norm: 11.8371 | dt: 7781.1313ms | tok/sec: 33.6897
step    5 | loss: 9.975197 | lr: 8.39e-06 | norm: 9.1772 | dt: 7794.3084ms | tok/sec: 33.6327


KeyboardInterrupt: 