# Import data utils

In [1]:
import pandas as pd
from tokenise import TextDataset, pad_sequences
from torch.utils.data import DataLoader, Dataset


df = pd.read_parquet("hf://datasets/gamino/wiki_medical_terms/wiki_medical_terms.parquet")
articles = df.iloc[:, 1]

dataset = TextDataset(articles, model="gpt2", seq_length=512)

# Create the dataloader
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=pad_sequences)

for data in dataloader:
    x, y, att = data['input_ids'], data['targets'], data['attention_mask']
    print(x,y,att)
    break

tensor([[11736,   377, 20106,  ..., 28836,   425,  3912],
        [   64,  7751,   291,  ...,   291,  5768,  3264]]) tensor([[  377, 20106,   292,  ...,   425,  3912,   543],
        [ 7751,   291,  1741,  ...,  5768,  3264,   777]]) tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])


# Import models

In [2]:
from dataclasses import dataclass

@dataclass
class GPTConfig:
    vocab_size = dataset.vocab_size
    block_size = 1024
    n_layer = 6
    n_head = 8
    n_embd = 512
    dropout = 0.1
    bias = False

config = GPTConfig()

## Transformer blocks

In [3]:
from customGPT import Block, GPT

## Attention

In [4]:
import torch 
from torch import nn
import torch.nn.functional as F 

import math

In [6]:
from attention import CausalSelfAttention

C_GPT = GPT(CausalSelfAttention, config)
C_GPT

Number of parameters: 44.63M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 512)
    (wpe): Embedding(1024, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (qkv_proj): Linear(in_features=512, out_features=1536, bias=False)
          (out_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear

In [7]:
from attention import FlashAttention
F_GPT = GPT(FlashAttention, config)
F_GPT

Number of parameters: 44.65M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 512)
    (wpe): Embedding(1024, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): FlashAttention(
          (qkv): Linear(in_features=512, out_features=1536, bias=True)
          (proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=50257, bias=False)
)

In [8]:
from attention import SparseAttention
S_GPT = GPT(SparseAttention, config)
S_GPT

Number of parameters: 44.65M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 512)
    (wpe): Embedding(1024, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): SparseAttention(
          (qkv): Linear(in_features=512, out_features=1536, bias=True)
          (proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=50257, bias=False)
)

# 5 epoch test

In [9]:
#training dependencies
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler
from contextlib import nullcontext
import torch.distributed as dist
from torch.distributed import init_process_group, destroy_process_group

import os

In [10]:
# Default config values designed to train a GPT-2 (124M) on OpenWebText
# I/O
out_dir = 'out'
eval_interval = 2000
log_interval = 1
eval_iters = 200
eval_only = False
always_save_checkpoint = True
init_from = 'scratch'  # 'scratch' or 'resume' or 'gpt2*'

# Data
# data_path = 'data/openwebtext.txt'  # Path to your text file
gradient_accumulation_steps = 5 * 8
batch_size = 2
block_size = 512

# Model
n_layer = 6  # Reduce the number of layers
n_head = 8   # Reduce the number of attention heads
n_embd = 512 # Reduce the embedding size
dropout = 0.0
bias = False

# AdamW optimizer
learning_rate = 6e-4
max_iters = 5
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0

# Learning rate decay settings
decay_lr = True
warmup_iters = 200
lr_decay_iters = 1000
min_lr = 6e-5

# DDP settings
backend = 'gloo'

# System
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
dtype = 'float32'  # MPS currently supports only float32
compile = False  # Disable compilation for now as PyTorch 2.0 is not yet stable on MPS
# -----------------------------------------------------------------------------

# Collect configuration keys
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
config = {k: globals()[k] for k in config_keys}

# Various initializations and derived attributes
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'mps:{ddp_local_rank}' if torch.backends.mps.is_available() else f'cuda:{ddp_local_rank}'
    torch.mps.set_device(device)
    master_process = ddp_rank == 0
    seed_offset = ddp_rank
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    master_process = True
    seed_offset = 0
    ddp_world_size = 1

tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"Tokens per iteration will be: {tokens_per_iter:,}")

if master_process:
    os.makedirs(out_dir, exist_ok=True)

torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'mps' if 'mps' in device else 'cpu'
ptdtype = torch.float32  # Currently, MPS supports only float32
ctx = nullcontext()

Tokens per iteration will be: 40,960


In [11]:
#configuration
class GPTConfig:
    def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, dropout, bias=True):
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.bias = bias
        
# Initialize iteration number and best validation loss
iter_num = 0
best_val_loss = 1e9

# Model initialization
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=len(dataset.tokens), dropout=dropout)

if init_from == 'scratch':
    print("Initializing a new model from scratch")
    gptconf = GPTConfig(**model_args)

    #select Attentions
    model = GPT(SparseAttention, gptconf)
elif init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_from, override_args)
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = getattr(model.config, k)

if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size

model.to(device)

Initializing a new model from scratch
Number of parameters: 19.18M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(512, 512)
    (wpe): Embedding(512, 512)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): SparseAttention(
          (qkv): Linear(in_features=512, out_features=1536, bias=True)
          (proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
          (3): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=512, bias=False)
)

In [None]:
# Initialize a GradScaler
scaler = GradScaler(enabled=(dtype == 'float16' and 'cuda' in device))

# Configure the optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None  # Free up memory

# Wrap model into DDP container if needed
if ddp:
    print(f"Starting parallel process with rank {ddp_rank}")
    model = DDP(model, device_ids=[ddp_local_rank])


# Function to estimate loss over splits
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        with tqdm(total=eval_iters, desc=f"Evaluating {split}") as pbar:
            for k in range(eval_iters):
                batch = next(iter(dataloader))
                X, Y = batch['input_ids'].to(device), batch['targets'].to(device)
                with ctx:
                    logits, loss = model(X, Y)
                losses[k] = loss.item()
                pbar.update(1)
        out[split] = losses.mean()
    model.train()
    return out

# Learning rate decay scheduler
def get_lr(it):
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    if it > lr_decay_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)

In [None]:
import time
from tqdm import tqdm


# wandb.init(
#     # set the wandb project where this run will be logged
#     project=project_name, name= run_name, config=config)

raw_model = model.module if ddp else model
local_iter_num = 0
running_mfu = -1.0
t0 = time.time()

while iter_num <= max_iters:
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # if wandb_log:
        #     wandb.log({
        #         "iter": iter_num,
        #         "train_loss": losses['train'],
        #         "val_loss": losses['val'],
        #         "lr": lr,
        #         "mfu": running_mfu * 100,  # convert to percentage
        #     })


        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    with tqdm(total=gradient_accumulation_steps, desc=f"Training step {iter_num}") as pbar:
        total_train_loss = 0.0  # Initialize total training loss
        for micro_step in range(gradient_accumulation_steps):
            if ddp:
                model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
            batch = next(iter(dataloader))
            X, Y = batch['input_ids'].to(device), batch['targets'].to(device)
            with ctx:
                logits, loss = model(X, Y)
                loss = loss / gradient_accumulation_steps
            scaler.scale(loss).backward()
            pbar.update(1)

    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5:
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt * 1000:.2f}ms, mfu {running_mfu * 100:.2f}%")

        # if wandb_log:
        #     wandb.log({
        #         "iter": iter_num,
        #         "train_loss": total_train_loss,
        #         "lr": lr,
        #         "mfu": running_mfu * 100,  # convert to percentage
        #     })

    iter_num += 1
    local_iter_num += 1

if ddp:
    destroy_process_group()

Evaluating train: 100%|██████████| 200/200 [00:21<00:00,  9.28it/s]
Evaluating val: 100%|██████████| 200/200 [00:19<00:00, 10.47it/s]


step 0: train loss 2.9166, val loss 2.9313


Training step 0: 100%|██████████| 40/40 [00:11<00:00,  3.48it/s]


iter 0: loss 2.9253, time 52281.32ms, mfu -100.00%


Training step 1: 100%|██████████| 40/40 [00:11<00:00,  3.48it/s]


iter 1: loss 3.0903, time 11924.31ms, mfu -100.00%


Training step 2: 100%|██████████| 40/40 [00:12<00:00,  3.12it/s]


iter 2: loss 2.5330, time 13175.06ms, mfu -100.00%


Training step 3: 100%|██████████| 40/40 [00:11<00:00,  3.36it/s]


iter 3: loss 2.2273, time 12327.11ms, mfu -100.00%


Training step 4: 100%|██████████| 40/40 [00:10<00:00,  3.73it/s]


iter 4: loss 3.6481, time 11076.78ms, mfu -100.00%


Training step 5: 100%|██████████| 40/40 [00:11<00:00,  3.43it/s]


iter 5: loss 2.1135, time 11989.06ms, mfu 0.15%
