## Section 3: algorithmic improvements to GPT2

The paper for gpt2 or code (inference code) doesn't talk too much about the algorithmic details or hyperparameters being used. So we refer to elements of gpt3 paper, since their architecture is very similar

Gpt2: less details, open weights <br>
Gpt3: more details, no weights

Key differences:<br>
- context length: 1024 vs 2048
- gpt3: trained for lot longer, on bigger dataset, more validation
- 1.6b vs 175b parameters

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import tiktoken
import time
import math

from sec2 import GPT, GPTConfig, DataLoaderLite ,get_device


  from .autonotebook import tqdm as notebook_tqdm


Refer [gpt3 paper](https://arxiv.org/pdf/2005.14165): __Section B - Details of model training__

In actual gpt3: ~__300 billion__ training tokens used, we use much lesser.

1. AdamW parameters: $\beta = $[0.99, 0.95], $\epsilon = 10^{-8}$

2. Gradient clipping at $\text{max norm} = 1.0 $
    - how it works:  scales all grads by $\frac{\text{max norm}}{\text{norm}}$ if norm > 1
    - why: to prevent _shocks_ in model training, in case when batch sampling is too unusual
    - useful to track it during training (is grad norm $\uparrow \text{or} \downarrow$ abnormally etc)

3. Cosine decay learning rate with warmup
    - handwritten _here_, plug and play function also available in pytorch
    - Other schedules can be used (active research area) 

4. Changing batch size (not implemented here)
    - Starts with 32k tokens per batch and increase linearly up to 4-12 billions tokens 
    - Gain is incremental anyway and complicates the arithmatic, so skipped _here_

5. weight decay, Fused AdamW 
    - weight decay for 2D tensors (matmuls and embeddings) and not for biases or layernorms (1D)
    - [Intuition behind weight decay](https://towardsdatascience.com/weight-decay-and-its-peculiar-effects-66e0aee3e7b8/): regularization, prevent individual weight becoming too big
    - [Fused adamw](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html) introduces fused kernels in later versions of pytorch which reduces some overheads when `device = 'cuda'` 

In [2]:
max_lr = 6e-4 # from gpt3 small 
min_lr = max_lr * 0.1
max_steps = 50
warmup_steps = 10

def get_lr(it):
    
    # 1. for warmup phase scale
    if it < warmup_steps:
        return max_lr * (it+1)/max_steps
    
    #2. for it > max_steps
    if it> max_steps:
        return min_lr

    # for in between: cosine decay with linear increment
    decay_ratio = (it - warmup_steps)/ (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
    return min_lr + coeff * (max_lr - min_lr)


In [3]:
device = get_device()

model = GPT(GPTConfig(vocab_size=50304)) #- random weights init
model.eval()
model.to(device)
# model = torch.compile(model) # compiles the model 

# B,T = 16,1024
B,T = 4,32
train_loader = DataLoaderLite(B,T)

# torch.set_float32_matmul_precision('high') -- old api, soon to be deprecated 
torch.backends.fp32_precision = "tf32" # new api, use "ieee" to enforce global fp32 precision 


# optimizer with weight decay + fused kernels. 
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device = device)

#training loop
for step in range(max_steps):
    t0 = time.time()
    x,y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()

    # single line autocast
    with torch.autocast(device_type = device, dtype = torch.bfloat16):
        logits, loss = model(x,y)

    loss.backward()

    #grad clipping
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    #determine and set learning rate for this iteration 
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    optimizer.step()
    torch.cuda.synchronize() # allow gpu bandwidth to catch up and clear the queue of operations
    t1 = time.time()
    dt = (t1-t0)*1000 # time difference in milliseconds
    tokens_per_sec = (train_loader.B * train_loader.T )/(t1-t0)
    print(f"step {step}: loss = {loss.item()} | lr = {lr:.2f} | norm = {norm:.4f} | dt: {dt:.2f} ms | {tokens_per_sec:.2f} tokens/sec")


#verify total no of parameters
total_params_M = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Total parameters: {total_params_M:.2f}M")

using device: cuda
Total tokens in dataset = 338025
Max batches in 1 epoch = 2640
num decayed parameter tensors: 50, with 124,354,560 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: True
step 0: loss = 10.84716796875 | lr = 0.00 | norm = 45.0423 | dt: 350.33 ms | 365.36 tokens/sec
step 1: loss = 10.10595703125 | lr = 0.00 | norm = 26.0900 | dt: 175.30 ms | 730.17 tokens/sec
step 2: loss = 9.3475341796875 | lr = 0.00 | norm = 18.2147 | dt: 175.70 ms | 728.53 tokens/sec
step 3: loss = 9.45416259765625 | lr = 0.00 | norm = 15.0661 | dt: 174.02 ms | 735.54 tokens/sec
step 4: loss = 9.084228515625 | lr = 0.00 | norm = 13.4934 | dt: 173.96 ms | 735.82 tokens/sec
step 5: loss = 8.8302001953125 | lr = 0.00 | norm = 8.8941 | dt: 173.82 ms | 736.39 tokens/sec
step 6: loss = 9.38555908203125 | lr = 0.00 | norm = 7.8695 | dt: 173.99 ms | 735.68 tokens/sec
step 7: loss = 9.274658203125 | lr = 0.00 | norm = 5.6511 | dt: 174.72 ms | 732.58 tokens/sec
step 