<a href="https://colab.research.google.com/github/ashishvinodkumar/GPT2/blob/main/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
# !pip3 install tiktoken
# !pip3 install torch
# !pip3 install transformers

# Training Workflow

### Import Packages

In [6]:
from gpt2 import GPT2, GPT2Config, DataLoaderLite
import torch
from torch.nn import functional as F
import time
import math

### Set Device & Args

In [7]:
# Choose the compute backend: prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    # The hasattr guard keeps this working on PyTorch builds without an MPS backend.
    device = 'mps'
else:
    device = 'cpu'
print(f'Using Device: {device}')

# Generation settings (sequence count and length cap).
# NOTE(review): presumably consumed by a sampling cell — confirm where these are used.
num_return_sequences, max_length = 5, 30

Using Device: cuda


In [8]:
# Fix the PyTorch RNG seed (and the CUDA RNG seed when a GPU is present)
# so that repeated runs start from the same random state.
rng_seed = 1337
torch.manual_seed(rng_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(rng_seed)

### Prepare Dataset

In [9]:
# Tokens consumed per optimizer step (~0.5M, as per GPT paper).
total_batch_size = 2**19
# B = micro batch size (sequences per forward/backward pass), T = max sequence length.
B, T = 16, 1024

# One micro-batch holds B*T tokens; accumulate gradients over enough
# micro-batches to reach the desired total batch size.
tokens_per_micro_batch = B * T
assert total_batch_size % tokens_per_micro_batch == 0, "Total Batch Size must be divisible by B*T"
grand_accum_steps = total_batch_size // tokens_per_micro_batch

print(f'Total Desired Batch Size: {total_batch_size}')
print(f'Grand Accumulate Steps: {grand_accum_steps}')

# Training text is read from a path relative to the notebook's working directory.
input_text = './data/input.txt'
train_loader = DataLoaderLite(B=B, T=T, input_text=input_text)

Total Desired Batch Size: 524288
Grand Accumulate Steps: 32
Loaded 338025 tokens
1 Epoch = 20 batches


### Initialize Model

In [10]:
# Allow float32 matrix multiplications to use TF32 when available, which speeds up training.
# Trade-off: TF32 carries less numerical precision than full float32.
torch.set_float32_matmul_precision('high')

In [11]:
# Build the model from scratch with random weights (no pretrained HF checkpoint).
# NOTE(review): 50304 = 393 * 128; presumably the tokenizer vocabulary padded up to a
# multiple of 128 for kernel efficiency — confirm against the tokenizer in gpt2.py.
config = GPT2Config(vocab_size=50304)
model = GPT2(config)
model.to(device)
# Compile the model; the compiled wrapper replaces the eager module from here on.
model = torch.compile(model)

# Learning-rate schedule: linear warm-up followed by cosine decay.
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 50

def get_lr(it):
    """Return the learning rate for the 0-based optimizer step ``it``.

    Schedule, driven by the module-level ``max_lr``, ``min_lr``,
    ``warmup_steps`` and ``max_steps``:

    * ``it < warmup_steps``: linear warm-up, reaching ``max_lr`` on the last
      warm-up step.
    * ``warmup_steps <= it <= max_steps``: cosine decay from ``max_lr`` down
      to ``min_lr``.
    * ``it > max_steps``: constant ``min_lr``.
    """
    if it < warmup_steps:
        # (it + 1) so that step 0 already trains with a non-zero learning rate.
        return max_lr * (it + 1) / warmup_steps
    elif it > max_steps:
        return min_lr
    else:
        # Fraction of the decay phase completed, in [0, 1].
        progress = (it - warmup_steps) / (max_steps - warmup_steps)
        assert 0.0 <= progress <= 1.0
        # Cosine weight goes smoothly from 1 (start of decay) to 0 (end of decay).
        cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr + cosine_weight * (max_lr - min_lr)

# Build the optimizer through the model's own helper. The initial learning rate is
# taken from max_lr so the schedule constants stay the single source of truth;
# the training loop overwrites it every step with get_lr(step).
optimizer = model.configure_optimizers(
    weight_decay=0.1,
    learning_rate=max_lr,
    betas=(0.9, 0.95),
    device_type=device,
)

# Training loop: one optimizer update per iteration, with gradients accumulated
# over grand_accum_steps micro-batches to reach the desired total batch size.
for step in range(max_steps):
    # perf_counter is monotonic, so the measured interval cannot jump with wall-clock changes.
    t0 = time.perf_counter()
    optimizer.zero_grad()
    loss_accum = 0.0

    for micro_step in range(grand_accum_steps):
        x, y = train_loader.next_batch()
        x = x.to(device)
        y = y.to(device)

        # NOTE(review): autocast receives the raw device string; confirm bfloat16 autocast
        # is supported for 'cpu'/'mps' on the installed PyTorch version.
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # Scale each micro-batch loss so the accumulated gradient matches one large batch
        # (assumes the model returns a mean-reduced loss — TODO confirm in gpt2.py).
        loss = loss / grand_accum_steps
        loss_accum += loss.detach()  # detached running total, used only for logging
        loss.backward()

    # Clip the global gradient norm to 1.0; the returned value is the pre-clip norm.
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # Determine and set the learning rate for this iteration.
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    optimizer.step()

    # Wait for queued device work to finish so the timing below is accurate.
    # Guarded by backend: an unconditional torch.cuda.synchronize() fails when
    # the notebook has fallen back to 'cpu' or 'mps'.
    if device == 'cuda':
        torch.cuda.synchronize()
    elif device == 'mps':
        torch.mps.synchronize()

    t1 = time.perf_counter()
    dt = round((t1 - t0)*1000, 3) # time difference in ms.
    tokens_per_second = round((train_loader.B * train_loader.T * grand_accum_steps) / (t1-t0), 3)
    print(f'step: {step} | loss: {loss_accum.item():.4e} | lr: {lr:.4e} | norm: {norm:.4f} | dt: {dt} | tokens/sec: {tokens_per_second}')

num decayed parameter tensors: 50, with 124,354,560 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: True
step: 0 | loss: 1.0939e+01 | lr: 6.0000e-05 | norm: 27.0126 | dt: 35394.091 | tokens/sec: 14812.868
step: 1 | loss: 9.6493e+00 | lr: 1.2000e-04 | norm: 9.5178 | dt: 2783.526 | tokens/sec: 188353.879
step: 2 | loss: 9.2256e+00 | lr: 1.8000e-04 | norm: 5.7292 | dt: 2781.923 | tokens/sec: 188462.421
step: 3 | loss: 9.8131e+00 | lr: 2.4000e-04 | norm: 8.2066 | dt: 2781.04 | tokens/sec: 188522.266
step: 4 | loss: 9.1916e+00 | lr: 3.0000e-04 | norm: 4.2994 | dt: 2781.193 | tokens/sec: 188511.891
step: 5 | loss: 8.6780e+00 | lr: 3.6000e-04 | norm: 3.6285 | dt: 2783.095 | tokens/sec: 188383.069
step: 6 | loss: 8.2950e+00 | lr: 4.2000e-04 | norm: 1.9535 | dt: 2781.343 | tokens/sec: 188501.759
step: 7 | loss: 8.0680e+00 | lr: 4.8000e-04 | norm: 2.8519 | dt: 2781.697 | tokens/sec: 188477.783
step: 8 | loss: 7.7142e+00 | lr: 5.4000e-04 | norm: 1.9108