In [1]:
import torch
import math
import time
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass
import tiktoken
import inspect
import os
import numpy as np
import json
# Seed the CPU and CUDA random number generators.
# NOTE(review): a later cell re-seeds both generators with 100 before the model is constructed,
# so these values only affect anything executed in between.
torch.manual_seed(42)
torch.cuda.manual_seed(42)


In [2]:
@dataclass
class GPTConfig:
    """Hyperparameters read by the GPT model and its sub-modules."""
    vocab_size: int = 50304  # rows of the token embedding and output width of lm_head
    block_size: int = 1024  # rows of the position embedding table and side of the causal-mask buffer
    n_head: int = 12  # attention heads per block; CausalSelfAttention asserts it divides n_embd
    n_layer: int = 12  # number of Block modules stacked in the transformer
    n_embd: int = 768  # embedding width shared by attention, MLP and LayerNorm layers


In [3]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position attends only to itself and earlier positions.

    A single linear layer (`c_attn`) produces queries, keys and values; the per-head
    outputs are concatenated again and passed through `c_proj`.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        # One fused projection yields q, k and v side by side along the last dimension.
        self.c_attn = nn.Linear(width, 3 * width)
        self.n_head = config.n_head
        self.n_embd = width
        self.c_proj = nn.Linear(width, width)
        # Marker attribute; GPT._init_weight checks for it to pick a different init std.
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # Lower-triangular mask buffer. forward() does not read it (the fused attention
        # kernel is called with is_causal=True), but it remains part of the state_dict.
        lower_triangle = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer('bias', lower_triangle.view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        """Apply causal attention to `x` of shape (B, T, C); the result has the same shape."""
        batch, seq_len, width = x.size()
        per_head_shape = (batch, seq_len, self.n_head, width // self.n_head)

        queries, keys, values = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, n_head, T, C // n_head)
        queries = queries.view(per_head_shape).transpose(1, 2)
        keys = keys.view(per_head_shape).transpose(1, 2)
        values = values.view(per_head_shape).transpose(1, 2)

        attended = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)

        # (B, n_head, T, C // n_head) -> (B, T, C): put the heads back next to each other.
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)
        

In [4]:
class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, tanh-GELU, project back."""

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden_width, config.n_embd)
        # Marker attribute; GPT._init_weight checks for it to pick a different init std.
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        """Transform `x` of shape (B, T, C); the last dimension is unchanged in size."""
        return self.c_proj(self.gelu(self.c_fc(x)))
        

In [5]:
class Block(nn.Module):
    """One transformer layer: pre-LayerNorm attention and pre-LayerNorm MLP, each added residually."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run both residual sub-layers on `x` of shape (B, T, C)."""
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))
        

In [6]:
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token and learned position embeddings are summed, passed through `n_layer`
    `Block` modules and a final LayerNorm, then projected to vocabulary logits by
    `lm_head`, whose weight matrix is shared with the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd)
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias = False)
        # Weight tying: the token embedding and the output projection share one tensor.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weight)

    def _init_weight(self, module):
        """Initialise one sub-module; invoked for every sub-module through `self.apply`.

        Linear and Embedding weights are drawn from N(0, 0.02^2) and Linear biases
        are zeroed. Linear layers tagged with `NANOGPT_SCALE_INIT` (the projections
        that write into the residual stream) use std 0.02 / sqrt(2 * n_layer).
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                # Each block adds to the residual stream twice (attention + MLP), so
                # 2 * n_layer additions accumulate; scale by the inverse square root
                # of that count. The exponent must apply to the whole product.
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean = 0.0, std = std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean = 0.0, std = 0.02)

    def forward(self, idx, targets = None):
        """Compute logits, and the loss when `targets` is given.

        Args:
            idx: token ids of shape (B, T) with T <= config.block_size.
            targets: optional token ids of shape (B, T).

        Returns:
            (logits, loss): logits has shape (B, T, vocab_size); loss is the mean
            cross-entropy over all positions, or None when `targets` is None.

        Raises:
            ValueError: if T exceeds config.block_size (no position embedding exists
                for the extra positions).
        """
        T = idx.shape[1]
        if T > self.config.block_size:
            raise ValueError(
                f"Cannot forward a sequence of length {T}: block_size is {self.config.block_size}"
            )
        word_embd = self.transformer.wte(idx)
        pos_embd = self.transformer.wpe(torch.arange(T, dtype = torch.long, device = idx.device))
        # pos_embd is (T, C) and broadcasts over the batch dimension of word_embd.
        x = word_embd + pos_embd
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def configure_optimizer(self, model, learning_rate, weight_decay, device):
        """Build an AdamW optimizer for `model` with weight decay on matrices only.

        Args:
            model: module whose trainable parameters are optimised.
            learning_rate: initial learning rate for both parameter groups.
            weight_decay: decay applied to parameters with at least 2 dimensions;
                lower-dimensional parameters (biases, LayerNorm) get no decay.
            device: device the model lives on, e.g. 'cpu', 'cuda' or 'cuda:0'.

        Returns:
            The configured `torch.optim.AdamW` instance.
        """
        trainable_params = [param for _, param in model.named_parameters() if param.requires_grad]

        decay_params = [param for param in trainable_params if param.dim() >= 2]
        non_decay_params = [param for param in trainable_params if param.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": non_decay_params, "weight_decay": 0.0},
        ]
        total_decayed_params = sum(p.numel() for p in decay_params)
        total_non_decayed_params = sum(p.numel() for p in non_decay_params)
        print(f"Decayed parameters : {total_decayed_params} | Non decayed params: {total_non_decayed_params}")

        # Use the fused kernel only when this PyTorch exposes the argument and the
        # model is on a CUDA device; the prefix match also accepts 'cuda:N'.
        fused_supported = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_supported and str(device).startswith('cuda')
        # Pass the keyword only when enabling it so older PyTorch versions that lack
        # the parameter do not raise a TypeError.
        extra_args = {"fused": True} if use_fused else {}
        optimizer = torch.optim.AdamW(
            optim_groups, lr = learning_rate, betas = (0.9, 0.95), eps = 1e-8, **extra_args
        )
        return optimizer

In [7]:
# Re-seed the generators and choose the compute device: the GPU when one is
# visible to PyTorch, otherwise the CPU.
torch.manual_seed(100)
if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.manual_seed(100)
else:
    device = 'cpu'

In [8]:
#min(self.B, math.ceil((end_index-self.curr_start)/self.T))

class DataLoaderLite:
    """Sequential loader over token shards stored as `.npy` files.

    Each call to `next_batch` returns a (B, T) input tensor and the (B, T) target
    tensor shifted by one token. With several processes, rank r starts at offset
    B*T*r and every rank then advances by B*T*world_size, so the ranks read
    disjoint, interleaved windows of the same shard.
    """

    # Index of the shard loaded at construction time and restored by reset().
    _FIRST_SHARD_IDX = 1

    def __init__(self, B, T, rank, world_size, num_shards, split = 'train'):
        """
        Args:
            B: sequences per batch.
            T: tokens per sequence.
            rank: index of this process.
            world_size: total number of processes.
            num_shards: modulus used when advancing to the next shard.
            split: name inserted into the shard file prefix, e.g. 'train' or 'val'.
        """
        self.B = B
        self.T = T
        self.rank = rank
        self.world_size = world_size
        self.shard_idx = self._FIRST_SHARD_IDX
        self.num_shards = num_shards
        self.file = f"./data/fineweb_10BT/shard_{split}"
        self.enc_text = self.load_shard(self.shard_idx)
        self.curr_start = self.calculate_initial_start()

    def calculate_initial_start(self):
        """Return this rank's starting token offset within a shard."""
        return self.B * self.T * self.rank

    def load_shard(self, shard_idx):
        """Load shard `shard_idx` from disk as a uint16 NumPy array."""
        return np.load(self.file + str(shard_idx) + ".npy").astype(np.uint16)

    def next_batch(self):
        """Return the next `(data, labels)` pair of int64 tensors, each of shape (B, T).

        After reading, the cursor advances; when the following step would not leave
        a full window for every rank, the next shard is loaded and the cursor
        returns to this rank's initial offset.
        """
        end_index = self.curr_start + self.B * self.T
        # Widen to int64 in NumPy first: torch.tensor cannot ingest uint16 arrays on
        # PyTorch builds without uint16 support, and the result is the same either way.
        window = self.enc_text[self.curr_start: end_index + 1].astype(np.int64)
        data = torch.tensor(window[:-1], dtype = torch.long).view(self.B, self.T)
        labels = torch.tensor(window[1:], dtype = torch.long).view(self.B, self.T)
        self.curr_start = self.curr_start + self.B * self.T * self.world_size
        if ((self.curr_start + self.B * self.T * self.world_size + 1) > len(self.enc_text)):
            # NOTE(review): starting at 1 and wrapping modulo num_shards visits
            # 1, 2, ..., num_shards - 1, 0 -- confirm this matches the shard file names.
            self.shard_idx = (self.shard_idx + 1) % self.num_shards
            self.enc_text = self.load_shard(self.shard_idx)
            self.curr_start = self.calculate_initial_start()
        return data, labels

    def reset(self):
        """Rewind to the first batch of the first shard.

        The shard is reloaded only when the loader has moved past it, so repeated
        evaluation passes always start from the same data.
        """
        if self.shard_idx != self._FIRST_SHARD_IDX:
            self.shard_idx = self._FIRST_SHARD_IDX
            self.enc_text = self.load_shard(self.shard_idx)
        self.curr_start = self.calculate_initial_start()
                

In [9]:
import os
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch.distributed as dist

# Distributed setup. The RANK environment variable is present only when the
# script is started by a distributed launcher; its absence selects the
# single-process configuration.
isDDP = int(os.environ.get("RANK", -1)) != -1
if isDDP:
    assert torch.cuda.is_available(), "the 'nccl' backend used for DDP requires CUDA"
    dist.init_process_group(backend = 'nccl')
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    device = 'cuda:' + str(local_rank)
    # Bind this process to its own GPU so that default-device CUDA calls
    # (e.g. torch.cuda.synchronize()) and collectives do not all target GPU 0.
    torch.cuda.set_device(device)
    # Only the master process prints progress.
    isMaster = rank == 0
else:
    # Single-process run: `device` keeps the value chosen in the earlier cell.
    rank = 0
    local_rank = 0
    world_size = 1
    isMaster = True
    

In [10]:
# Build the model with the default configuration, put it in training mode on
# the selected device, and compile it.
model = torch.compile(GPT(GPTConfig()).train().to(device))

# Under DDP the compiled model is wrapped so gradients are synchronised across
# processes. `raw_model` is the module underneath the DDP wrapper (or the
# compiled model itself in a single-process run).
if isDDP:
    model = DDP(model, device_ids = [local_rank])
raw_model = model.module if isDDP else model

In [11]:
# Learning-rate schedule: linear warmup to max_lr, cosine decay to min_lr.
# NOTE(review): the training cell runs for ~19k steps while the decay below
# finishes after max_steps = 100, after which the rate stays at min_lr --
# confirm these values are intended and not left over from a short test run.
max_steps = 100
warmup_steps = 10
max_lr = 6e-4
min_lr = 0.1 * max_lr
def get_lr(epoch):
    """Return the learning rate for 1-based optimisation step `epoch`.

    Steps up to and including `warmup_steps` ramp linearly to `max_lr`; steps
    after `max_steps` return `min_lr`; steps in between follow a half-cosine from
    `max_lr` down to `min_lr`.
    """
    if epoch <= warmup_steps:
        # Use the argument rather than a global loop variable, so the function
        # works outside the training loop and step 1 gets a non-zero rate.
        return max_lr * epoch / warmup_steps
    if epoch > max_steps:
        return min_lr

    curr_ratio = (epoch - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1 + math.cos(math.pi * curr_ratio))  # goes 1 -> 0 over the decay phase
    return min_lr + coeff * (max_lr - min_lr)
    
    

In [12]:

def calculate_and_propagate_loss(x, y, acc_loss):
    """Run one micro-batch forward and backward under bfloat16 autocast.

    Reads the module-level `model`, `device` and `acc_steps`.

    Args:
        x: input token ids passed to the model.
        y: target token ids passed to the model.
        acc_loss: running sum of the scaled micro-batch losses.

    Returns:
        (logits, loss, acc_loss): the model logits, this micro-batch's loss divided
        by `acc_steps`, and `acc_loss` increased by the detached scaled loss.
    """
    # torch.autocast expects a bare device type; `device` may carry an index
    # ('cuda:0') when running under DDP.
    device_type = 'cuda' if str(device).startswith('cuda') else 'cpu'
    with torch.autocast(device_type = device_type, dtype = torch.bfloat16):
        logits, loss = model(x, y)
        # Scale so the gradients summed over acc_steps micro-batches equal the
        # gradient of the mean loss.
        loss = loss / acc_steps
    acc_loss += loss.detach()
    loss.backward()
    return logits, loss, acc_loss

In [13]:
# Source URL of the HellaSwag validation split (one JSON object per line) and
# the local path it is stored at.
hellaswag_data_url = 'https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl'
hellaswag_data_location = './data/hellaswag/val_data.jsonl'

In [14]:
def download(url, file):
    """Download `url` and store the response body at path `file`.

    The parent directory is created when missing. Data is streamed into a
    temporary `<file>.part` and renamed only after the transfer completes, so an
    interrupted download never leaves a truncated file at `file`.

    Raises:
        urllib.error.URLError: if the request fails (including HTTP error statuses).
        OSError: if the destination cannot be written.
    """
    # Imported here because the notebook's import cell provides neither module.
    import shutil
    import urllib.request

    file = os.fspath(file)
    target_dir = os.path.dirname(file)
    if target_dir:
        os.makedirs(target_dir, exist_ok = True)

    partial_file = file + '.part'
    try:
        with urllib.request.urlopen(url, timeout = 60) as response, open(partial_file, 'wb') as f:
            shutil.copyfileobj(response, f, length = 1 << 20)  # copy in 1 MiB chunks
        os.replace(partial_file, file)
    finally:
        # Left over only when the transfer or the rename failed.
        if os.path.exists(partial_file):
            os.remove(partial_file)
            

In [15]:
# Fetch the HellaSwag validation file on first use; later runs reuse the local copy.
if not os.path.isfile(hellaswag_data_location):
    download(hellaswag_data_url, hellaswag_data_location)

# Keep the raw lines (one JSON document each, newline included); they are
# parsed at evaluation time.
with open(hellaswag_data_location, 'r') as f:
    hellaswag_val_data = list(f)

In [16]:
def create_tensor_example_batch(example, tokenizer = None):
    """Turn one HellaSwag record into padded tensors, one row per candidate ending.

    Args:
        example: dict with keys 'ctx' (str), 'endings' (list of str) and 'label'.
        tokenizer: object with an `encode(str) -> list[int]` method. Defaults to
            the tiktoken 'gpt2' encoding, the one the rest of the notebook uses
            for the model's inputs.

    Returns:
        (examples_tensor, padding_mask, ending_mask, label) where, with N candidate
        endings and L the longest tokenised row:
          * examples_tensor: (N, L) int64, context + ending tokens, zero-padded right;
          * padding_mask: (N, L) float, 1 on real tokens and 0 on padding;
          * ending_mask: (N, L) float, 1 on the ending's tokens only;
          * label: index of the correct ending as an int.
    """
    if tokenizer is None:
        tokenizer = tiktoken.get_encoding('gpt2')

    ctx_enc = tokenizer.encode(example['ctx'])
    # The ending continues the context text, so tokenise it with the separating
    # space it would have there instead of gluing the two strings' tokens together.
    endings_enc = [tokenizer.encode(" " + ending) for ending in example['endings']]
    examples_enc = [ctx_enc + end_enc for end_enc in endings_enc]

    num_choices = len(examples_enc)
    max_len = max(len(ex_enc) for ex_enc in examples_enc)

    examples_tensor = torch.zeros((num_choices, max_len), dtype = torch.long)
    padding_mask = torch.zeros((num_choices, max_len))
    ending_mask = torch.zeros((num_choices, max_len))
    for row, ex_enc in enumerate(examples_enc):
        row_len = len(ex_enc)
        examples_tensor[row, :row_len] = torch.tensor(ex_enc, dtype = torch.long)
        # Index from the left: a `[-pad_len:]` slice would select the whole row
        # when pad_len is 0, i.e. for the longest candidate.
        padding_mask[row, :row_len] = 1
        ending_mask[row, len(ctx_enc):row_len] = 1

    return examples_tensor, padding_mask, ending_mask, int(example['label'])

In [None]:
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
epochs = 19074 #(10^10 / 2**19 i.e. dataset size/batch size)
weight_decay = 0.1
# Micro-batch shape. The data loaders and the gradient-accumulation count must
# use the same B and T, otherwise the effective batch differs from
# total_batch_size; B matches the micro-batch the loaders were built with.
B = 16
T = 1024
total_batch_size = 2**19
assert total_batch_size % (B * T * world_size) == 0, \
    "total_batch_size must be a multiple of B * T * world_size"
acc_steps = total_batch_size // (B*T*world_size)
train_dataloader = DataLoaderLite(B, T, rank, world_size, 9, 'train')
val_dataloader = DataLoaderLite(B, T, rank, world_size, 9, 'val')
model_checkpoints_dir = 'checkpoints/'

# torch.autocast needs the bare device type; `device` may be 'cuda:N' under DDP.
device_type = 'cuda' if str(device).startswith('cuda') else 'cpu'

# configure_optimizer takes the module to optimise explicitly, followed by the
# learning rate, the weight decay and the device.
optimizer = raw_model.configure_optimizer(raw_model, 6e-4, weight_decay, device)

torch.set_float32_matmul_precision('high')


# Prompt used for the periodic sampling demo, repeated once per sample.
encoder = tiktoken.get_encoding("gpt2")
generator_feed_tokens = torch.tensor(encoder.encode("Hello, I'm a language model,"), dtype = torch.long)
num_rand_generation = 5
generator_feed_tokens = generator_feed_tokens.unsqueeze(0).repeat(num_rand_generation,1).to(device)
generator_max_length = 50
checkpoint_dir = 'checkpoints/'

# ---------------------------------------------------------------------------
# Training loop: each iteration optionally evaluates (validation loss,
# HellaSwag, sampling), then performs one optimizer step accumulated over
# acc_steps micro-batches.
# ---------------------------------------------------------------------------
for ep in range(epochs):
    last_step = ep == epochs - 1
    eval_now = last_step or (ep % 500) == 0

    # ---- validation loss ----
    if (ep > 0) and eval_now:
        eval_steps = 10
        model.eval()
        val_dataloader.reset()
        # Initialised once, outside the loop, so it accumulates over all eval steps.
        val_loss_acc = torch.zeros((), device = device)
        with torch.no_grad():
            for val_step in range(eval_steps):
                val_x, val_y = val_dataloader.next_batch()
                val_x, val_y = val_x.to(device), val_y.to(device)
                with torch.autocast(device_type = device_type, dtype = torch.bfloat16):
                    _, val_loss = model(val_x, val_y)
                val_loss_acc += val_loss.detach() / eval_steps
        if isDDP:
            torch.distributed.all_reduce(val_loss_acc, op = dist.ReduceOp.AVG)

        if ((not isDDP) or (isMaster)):
            print(f"Epoch : {ep + 1} | Val Loss: {val_loss_acc.item():.4f}")

    # ---- HellaSwag accuracy ----
    if eval_now:
        model.eval()
        hl_total_correct = 0
        hl_total_elem = 0
        with torch.no_grad():
            for i, example in enumerate(hellaswag_val_data):
                # Round-robin split of the examples across processes.
                if ((i - rank) % world_size) != 0:
                    continue
                json_ex = json.loads(example)
                example_tensor, padding_mask, ending_mask, label = create_tensor_example_batch(json_ex)
                example_tensor = example_tensor.to(device)
                ending_mask = ending_mask.to(device)
                logits, _ = model(example_tensor)
                # Position t predicts token t + 1: drop the last logit and the first token.
                shift_logits = logits[:, :-1, :].contiguous().float()
                shift_targets = example_tensor[:, 1:].contiguous()
                shift_mask = ending_mask[:, 1:].contiguous()
                # cross_entropy applies log-softmax itself, so it must receive raw logits.
                token_loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_targets.view(-1),
                    reduction = 'none',
                ).view(shift_targets.size(0), -1)
                # Mean loss over each candidate's ending tokens; the lowest one wins.
                loss_sum = torch.sum(token_loss * shift_mask, dim = 1)
                loss_avg = loss_sum / torch.sum(shift_mask, dim = 1)
                pred_label = torch.argmin(loss_avg).item()
                hl_total_correct += int(pred_label == int(label))
                hl_total_elem += 1
        if isDDP:
            # all_reduce operates on tensors, not Python ints.
            hl_counts = torch.tensor([hl_total_correct, hl_total_elem], dtype = torch.long, device = device)
            torch.distributed.all_reduce(hl_counts, op = dist.ReduceOp.SUM)
            hl_total_correct, hl_total_elem = hl_counts.tolist()

        if isMaster:
            print(f"Hellaswag accurancy: {(hl_total_correct / max(hl_total_elem, 1))* 100}")

    # ---- sample a few continuations of the prompt ----
    if eval_now:
        model.eval()
        generated_tokens = torch.clone(generator_feed_tokens)
        with torch.no_grad():
            for i in range(generator_max_length):
                gen_logits, _ = model(generated_tokens)
                gen_logits = gen_logits[:, -1, :] # B, C
                gen_probs = F.softmax(gen_logits, dim = -1)
                # Top-50 sampling: draw within the 50 most likely tokens, then
                # map the drawn positions back to vocabulary ids.
                topk_probs, topk_indices = torch.topk(gen_probs, 50, dim = -1) # B,50 B,50
                ix = torch.multinomial(topk_probs, 1) #B,1
                tokens_ix = torch.gather(topk_indices, 1, ix)
                generated_tokens = torch.cat((generated_tokens, tokens_ix), dim = -1)

        for i in range(num_rand_generation):
            print(f"Epoch : {ep + 1} | Rank: {rank} | Generated Text: {encoder.decode(generated_tokens.detach().cpu().numpy()[i])}")

    # ---- one optimisation step ----
    model.train()  # the evaluation sections above leave the model in eval mode
    t0 = time.time()
    optimizer.zero_grad()
    lr = get_lr(ep+1)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    acc_loss = 0.0
    total_tokens_processed = 0
    for step in range(acc_steps):
        x,y = train_dataloader.next_batch()
        total_tokens_processed += x.numel()
        x,y = x.to(device), y.to(device)
        if isDDP and (step != (acc_steps - 1)):
            # Skip the gradient all-reduce on every micro-batch except the last.
            with model.no_sync():
                logits, loss, acc_loss = calculate_and_propagate_loss(x,y, acc_loss)
        else:
            logits, loss, acc_loss = calculate_and_propagate_loss(x,y, acc_loss)
    if isDDP:
        torch.distributed.all_reduce(acc_loss, op = dist.ReduceOp.AVG)
        # Every rank processes the same number of tokens per step.
        total_tokens_processed *= world_size
    optimizer.step()
    if device_type == 'cuda':
        torch.cuda.synchronize()  # wait for queued GPU work so the timing is meaningful
    t1 = time.time()
    dt = (t1 - t0) * 1000
    tk_sec = total_tokens_processed/(t1 - t0)
    if ((not isDDP) or (isMaster)):
        print(f"Epoch : {ep + 1} | Loss: {acc_loss.item():.4f} | Time taken: {dt:.2f} milli sec | Tokens per sec: {tk_sec:.2f}")

    # ---- checkpoint (master process only, to avoid concurrent writes) ----
    if isMaster and (ep > 0) and (last_step or (ep % 5000) == 0):
        os.makedirs(checkpoint_dir, exist_ok = True)
        checkpoint_dict = {
            'model_dict': raw_model.state_dict(),
            'config': raw_model.config,
            # Checkpoint steps coincide with evaluation steps, so this is current.
            'val_loss': val_loss_acc.item(),
            'step': ep
        }
        checkpoint_file_name = checkpoint_dir + f"model_epoch_{ep:05d}.pt"
        torch.save(checkpoint_dict, checkpoint_file_name)

# Tear down the process group once, after the last step; doing it inside the
# loop would break every collective from the second iteration on.
if isDDP:
    destroy_process_group()
