In [1]:
import math
import time
import os
import wget

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import tiktoken

import torch 
import torch.nn as nn 
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.distributed import init_process_group, destroy_process_group

from dataclasses import dataclass
from datasets import load_dataset

from GPT2.hellaswag import render_example, iterate_examples
from GPT2.gpt2_functions import *


# Seed PyTorch's global RNG, then build a separate generator seeded with the same value.
torch.manual_seed(42)
g = torch.Generator()
g.manual_seed(42)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Distributed data parallel (DDP) setup
#
# `torchrun` exports RANK, LOCAL_RANK and WORLD_SIZE. When RANK is absent
# (or explicitly -1) this is treated as an ordinary single-process run.
ddp = int(os.environ.get("RANK", -1)) != -1  # Is this a DDP run?

if not ddp:
    # Vanilla, non-DDP run: this process is the only rank and therefore the master.
    ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1
    master_process = True

    # Auto-detect the device: prefer CUDA, then Apple MPS, otherwise CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"using device: {device}")
else:
    # DDP requires CUDA here because the process group uses the NCCL backend.
    assert torch.cuda.is_available(), "We need CUDA for DDP"
    init_process_group(backend="nccl")

    ddp_rank = int(os.environ["RANK"])              # Global GPU index
    ddp_local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on the current node
    ddp_world_size = int(os.environ["WORLD_SIZE"])  # Total number of GPUs across nodes

    # Bind this process to its own GPU.
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)

    # Only rank 0 does logging, checkpointing, etc.
    master_process = ddp_rank == 0

# PyTorch distinguishes a device ("cuda:3") from its device type ("cuda").
# Anything that is not CUDA (including "mps") is reported as "cpu" here.
if device.startswith("cuda"):
    device_type = "cuda"
else:
    device_type = "cpu"



using device: cuda


In [3]:
# Reproducibility
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

# Define precision for float32 matrix multiplications
torch.set_float32_matmul_precision("high")

# Hyperparameters
total_batch_size = 524288 # 2**19 (close to 0.5M), counted in tokens per optimizer step
B                = 16     # Micro Batch Size (use gradient accumulation)
T                = 1024   # Sequence Length
vocab_size       = 50304  # 393 * 128; presumably the GPT-2 vocab padded to a round size - TODO confirm
max_lr           = 6e-4
min_lr           = max_lr * 0.1
warmup_steps     = 715
max_steps        = 19073 # 19,073 steps is ~1 epoch, if data is 10B tokens and batch size 0.5M tokens
weight_decay     = 0.1
# Initial optimizer learning rate. Derived from max_lr so the two values cannot
# drift apart; the training loop overwrites the LR of every param group each step.
learning_rate    = max_lr

# Grad accum step: number of micro-batches (across all ranks) that make up one full batch
assert total_batch_size % (B * T * ddp_world_size) == 0, "Make sure total batch size is div by B * T * ddp_world_size"
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)

# Encoder
enc = tiktoken.get_encoding("gpt2")

# Loader functions
# NOTE(review): the saved output of this cell is "AssertionError: no shards found for split train",
# so the data shards must already exist wherever DataLoaderLite looks for them - verify the data
# preparation step has been run before executing this cell.
train_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train")
val_loader   = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val"  )

# Model
model = GPT2(GPT2Config(vocab_size=vocab_size))
model.to(device)    # Move to device
# Keep a handle on the plain module BEFORE any wrapping. torch.compile and DDP both
# return wrappers around it; taking raw_model afterwards would, with use_compile=True,
# yield the compiled wrapper whose state_dict keys carry an "_orig_mod." prefix.
# The wrappers share the same parameters, so optimizing/checkpointing via raw_model is equivalent.
raw_model = model
use_compile = False # torch.compile interferes with the HellaSwag eval, so it is disabled
if use_compile:
    model = torch.compile(model)
if ddp:
    model = DDP(model, device_ids = [ddp_local_rank])

# Optimize
optimizer = raw_model.configure_optimizers(
    weight_decay  = weight_decay,
    learning_rate = learning_rate,
    device_type   = device_type
)

# print
if master_process:
    print(f"total desired batch size: {total_batch_size}")
    print(f"-> Calculated grad acc steps is {grad_accum_steps}")



AssertionError: no shards found for split train

In [None]:
# Log directory and log file
log_dir = "log"
log_file = os.path.join(log_dir, "log.txt")
os.makedirs(log_dir, exist_ok=True)
# Opening in "w" mode creates the file if needed and truncates any previous
# contents, so each run starts with an empty log.
with open(log_file, "w"):
    pass

# Main training loop. Each iteration optionally evaluates (validation loss and
# HellaSwag), then performs one optimizer step built from `grad_accum_steps`
# micro-batches. Under DDP every rank runs this loop; only the master process
# prints, writes to the log file and saves checkpoints.
for step in range(max_steps):
    # NOTE: t0 is taken before the evaluation blocks, so on evaluation steps the
    # reported dt / tok/sec also include evaluation time.
    t0 = time.time()
    last_step = (step == max_steps - 1)

    # Every 250 steps (and on the last step) evaluate validation loss over 20 batches
    if step % 250 == 0 or last_step:
        model.eval()
        val_loader.reset()
        with torch.no_grad():
            # Starts as a Python float; becomes a tensor after the first `+=` below.
            val_loss_accum = 0.0
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = val_loader.next_batch()
                x, y = x.to(device), y.to(device)
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(x, y)
                # Pre-divide so the running sum is the mean over val_loss_steps batches.
                loss = loss / val_loss_steps
                val_loss_accum += loss.detach()
        if ddp:
            # Average the per-rank mean loss across all processes.
            dist.all_reduce(val_loss_accum, op = dist.ReduceOp.AVG)
        if master_process:
            print(f"validation loss: {val_loss_accum.item():.4f}")
            with open(log_file, "a") as f:
                f.write(f"{step} val {val_loss_accum.item():.4f}\n")
            # Checkpoint every 5000 steps (never at step 0) and on the last step.
            if step > 0 and (step % 5000 == 0 or last_step):
                # optionally write model checkpoints
                checkpoint_path = os.path.join(log_dir, f"model_{step:05d}.pt")
                # Only model weights, config, step and validation loss are saved;
                # optimizer state is not part of the checkpoint.
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'config': raw_model.config,
                    'step': step,
                    'val_loss': val_loss_accum.item()
                }
                torch.save(checkpoint, checkpoint_path)

    # Every 250 steps (and on the last step) evaluate HellaSwag; skipped when torch.compile is used.
    # This shares its schedule with the validation block above, so the model is
    # already in eval mode when this runs.
    if (step % 250 == 0 or last_step) and (not use_compile):
        num_correct_norm = 0
        num_total = 0
        for i, example in enumerate(iterate_examples("val")):
            # only process examples where i % ddp_world_size == ddp_rank
            if i % ddp_world_size != ddp_rank:
                continue
            # render the example into tokens and labels
            _, tokens, mask, label = render_example(example)
            tokens = tokens.to(device)
            mask = mask.to(device)
            # get the logits
            with torch.no_grad():
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    # No targets are passed here; the returned loss is unused.
                    logits, loss = model(tokens)
                # Helper defined outside this notebook; presumably returns the index of
                # the candidate completion the model scores as most likely - verify.
                pred_norm = get_most_likely_row(tokens, mask, logits)
            num_total += 1
            num_correct_norm += int(pred_norm == label)
        # reduce the stats across all processes
        if ddp:
            # all_reduce needs tensors, so wrap the Python counters, sum them
            # across ranks, then convert back to plain ints.
            num_total = torch.tensor(num_total, dtype=torch.long, device=device)
            num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
            dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
            dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
            num_total = num_total.item()
            num_correct_norm = num_correct_norm.item()
        # NOTE(review): raises ZeroDivisionError if no examples were evaluated (num_total == 0).
        acc_norm = num_correct_norm / num_total
        if master_process:
            print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
            with open(log_file, "a") as f:
                f.write(f"{step} hella {acc_norm:.4f}\n")

    # do one step of the optimization
    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        if ddp:
            # Only synchronize gradients across ranks on the final micro-step.
            # This sets a DDP attribute directly; `model.no_sync()` is the
            # documented context-manager alternative.
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # we have to scale the loss to account for gradient accumulation,
        # because the gradients just add on each successive backward().
        # addition of gradients corresponds to a SUM in the objective, but
        # instead of a SUM we want MEAN. Scale the loss here so it comes out right
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        loss.backward()
    if ddp:
        # Average the accumulated loss across ranks (for reporting only).
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    # Clip the global gradient norm to 1.0; the returned pre-clip norm is logged below.
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # determine and set the learning rate for this iteration
    # (get_lr is defined outside this notebook; its schedule is not visible here)
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    if device_type == "cuda":
        torch.cuda.synchronize() # wait for the GPU to finish work
    t1 = time.time()
    dt = t1 - t0 # time difference in seconds
    # Tokens consumed by this optimizer step across all ranks and micro-steps.
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
    tokens_per_sec = tokens_processed / dt
    if master_process:
        print(f"step {step:5d} | loss: {loss_accum.item():.6f} | lr {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
        with open(log_file, "a") as f:
            f.write(f"{step} train {loss_accum.item():.6f}\n")

# Tear down the process group once training has finished.
# NOTE(review): not reached if the loop above raises; consider try/finally.
if ddp:
    destroy_process_group()