In [1]:
# helper function ugh
def load_py(filename):
    import requests, importlib
    if not filename.endswith(".py"): filename += ".py"
    url = f"https://raw.githubusercontent.com/BenAF002/data_science/refs/heads/main/colab_checkpoints/{filename}"
    code = requests.get(url).text

    # write to /content for colab
    destination = f"/content/{filename}"
    with open(destination, "w") as f:
        f.write(code)

    # handle module reloading
    module_name = filename[:-3] # Strip the .py extension
    if module_name in globals():
        importlib.reload(globals()[module_name])
    else:
        globals()[module_name] = importlib.import_module(module_name)

load_py("gpt_homebrew")  # will need to recall each time script is updated/saved

In [2]:
import time
from dataclasses import dataclass
import numpy as np
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

import tiktoken
enc = tiktoken.get_encoding('gpt2')

In [3]:
device = 'cpu'
if torch.cuda.is_available():
    device='cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device='mps'
print(f'using device: {device}')

using device: cuda


In [4]:
from torch.utils.data import Dataset, DataLoader

# data manager for tinyshakespeare data
class ShakeMgr:
    def __init__(self, seq_len: int = 32):
        if 'tinyshakespeare_data' not in globals():
            import urllib.request
            url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

            # read text in from url
            with urllib.request.urlopen(url) as response:
                tinyshakespeare_data = response.read().decode('utf-8')

        self.seq_len = seq_len
        
        self.text = tinyshakespeare_data
        self.tokens = enc.encode(self.text)

        # stack token sequences into data buffer for dataloader
        buf_len  = len(self.tokens) - (len(self.tokens) % (seq_len + 1))
        self.buffer = self.tokens[:buf_len]
        self.buffer = torch.tensor(self.buffer).reshape(-1, seq_len + 1)

    def train_test_split(self, train_pct: float = 0.8):
        idx = np.random.permutation(data_len := len(self.buffer))
        train_data = self.buffer[idx[:int(data_len * train_pct)]]
        test_data = self.buffer[idx[int(data_len * train_pct):]]

        return ShakeDS(train_data), ShakeDS(test_data)  # return train-test split

# dataset class for tinyshakespeare data, compatible with pytorch dataloader
class ShakeDS(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx: int):
        x = self.data[idx, :-1]  # shifted input sequence
        y = self.data[idx, 1:]   # shifted target sequence
        return x, y

In [5]:
# set up new training data with maximum batch-size sequences
data = ShakeMgr(seq_len=1024)  # max batch-size
train, test = data.train_test_split(train_pct = 0.8)

In [5]:
# gpt = gh.GPT.from_pretrained('gpt2', lora_rank=0, lora_alpha=16)
gpt = gpt_homebrew.GPT(gpt_homebrew.GPTConfig())
gpt.to(device)

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='tanh')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [None]:
dl = DataLoader(train, batch_size=8, shuffle=True)

# try adjusting precision of pytorch float32 operations
# setting to 'high' uses TensorFlow32 (TF32) which subtley adjusts dtypes in operations
# increases TFLOPS on GPUs with tensorcores (like A100) by ~8X -- may improve other GPUs too
# TF32 is available on GPUs after the "Ampere" series (of which A100 is one)
torch.set_float32_matmul_precision('high')  # 'highest' is default

optimizer = torch.optim.AdamW(gpt.parameters(), lr=3e-4)

import time
for i in range(10):
    t0 = time.time()

    xb, yb = next(iter(dl))

    # send to device during training / inference instead of during dataset creation
    # this is much more efficient for memory usage and dataloader speed
    xb, yb = xb.to(device), yb.to(device)  

    logits, loss = gpt(xb, yb)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # this waits for the GPU to finish operations - otherwise CPU REPL stuff may finish before GPU
    # causing us to print time EARLY!
    torch.cuda.synchronize()
    t1 = time.time()
    dt = (t1 - t0) * 1000  # time in miliseconds
    tokens_per_sec = (dl.batch_size * xb.shape[1]) / (t1 - t0)
    print(f"step {i}, loss: {loss.item():,.6f}, dt: {dt:,.2f}ms, tok/sec: {tokens_per_sec:,.2f}")

  _C._set_float32_matmul_precision(precision)


step 0, loss: 10.992279, dt: 1,298.27ms, tok/sec: 6,309.92
step 1, loss: 9.610030, dt: 927.51ms, tok/sec: 8,832.28
step 2, loss: 9.131944, dt: 926.60ms, tok/sec: 8,840.93
step 3, loss: 9.067685, dt: 934.72ms, tok/sec: 8,764.16
step 4, loss: 8.617716, dt: 936.50ms, tok/sec: 8,747.47
step 5, loss: 8.445677, dt: 937.33ms, tok/sec: 8,739.75
step 6, loss: 8.125593, dt: 931.83ms, tok/sec: 8,791.29
step 7, loss: 7.868392, dt: 933.69ms, tok/sec: 8,773.79
step 8, loss: 7.779512, dt: 932.62ms, tok/sec: 8,783.84
step 9, loss: 7.629560, dt: 938.43ms, tok/sec: 8,729.48


Adjusting to use TF32 in matmul operations increased approximate tokens per second from ~15.7k --> ~42.7k.\
This is less than the ~8k improvement suggested by TFLOPS because of memory constraints.

We can do better by improving memory efficiency - i.e. using a floating-point format for the parameter data itself, not just the operation (a bit arcane......). The trade off is reduced precision and potentially a reduced range of value expressions. The format BF16 is a good middle-ground. It maintains the range of expression of FP32 (the default for `float32` dtype), but reduces the precision to reduce the bit-cost in memory. 

[This tutorial](https://docs.pytorch.org/tutorials/recipes/recipes/amp_recipe.html) explains the pytorch implementation for adjusting float formats, particularly using `torch.autocast`. This can be used to adust the formats during the forward pass automatically.\
This only converts some datatypes. The implementation used here won't convert the embedding weights. It sounds like the `nn.Linear` layers will be changed which is good since they tend to be the most expensive here.

**Again**: this is only possible to do in Ampere GPUs, older GPUs don't support BF16 (I guess)

In [None]:
# add BF16 autocast
for i in range(10):
    t0 = time.time()

    xb, yb = next(iter(dl))

    # send to device during training / inference instead of during dataset creation
    # this is much more efficient for memory usage and dataloader speed
    xb, yb = xb.to(device), yb.to(device)  

    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = gpt(xb, yb)
    
    loss.backward()
    optimizer.step()

    # this waits for the GPU to finish operations - otherwise CPU REPL stuff may finish before GPU
    # causing us to print time EARLY!
    torch.cuda.synchronize()
    t1 = time.time()
    dt = (t1 - t0) * 1000  # time in miliseconds
    tokens_per_sec = (dl.batch_size * xb.shape[1]) / (t1 - t0)
    print(f"step {i}, loss: {loss.item():,.6f}, dt: {dt:,.2f}ms, tok/sec: {tokens_per_sec:,.2f}")

step 0, loss: 7.383627, dt: 943.68ms, tok/sec: 8,680.94
step 1, loss: 7.230249, dt: 749.67ms, tok/sec: 10,927.51
step 2, loss: 7.142229, dt: 746.48ms, tok/sec: 10,974.18
step 3, loss: 6.902237, dt: 748.69ms, tok/sec: 10,941.82
step 4, loss: 6.789080, dt: 749.39ms, tok/sec: 10,931.58
step 5, loss: 6.764679, dt: 751.12ms, tok/sec: 10,906.43
step 6, loss: 6.723163, dt: 749.41ms, tok/sec: 10,931.24
step 7, loss: 6.539473, dt: 750.65ms, tok/sec: 10,913.22
step 8, loss: 6.486143, dt: 751.10ms, tok/sec: 10,906.62
step 9, loss: 6.443795, dt: 750.19ms, tok/sec: 10,919.83


## torch.compile

This is a *compiler* for pytorch code which can be used improve efficiency a lot.\
[Docs](https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)

Testing out by just calling compile on the model:

In [9]:
gpt = torch.compile(gpt)

In [12]:
dl = DataLoader(train, batch_size=8, shuffle=True)
import time

torch.set_float32_matmul_precision('high')  # 'highest' is default

optimizer = torch.optim.AdamW(gpt.parameters(), lr=3e-4)
for i in range(10):
    t0 = time.time()

    xb, yb = next(iter(dl))

    # send to device during training / inference instead of during dataset creation
    # this is much more efficient for memory usage and dataloader speed
    xb, yb = xb.to(device), yb.to(device)  

    # add BF16 autocast
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = gpt(xb, yb)
    
    loss.backward()
    optimizer.step()

    # this waits for the GPU to finish operations - otherwise CPU REPL stuff may finish before GPU
    # causing us to print time EARLY!
    torch.cuda.synchronize()
    t1 = time.time()
    dt = (t1 - t0) * 1000  # time in miliseconds
    tokens_per_sec = (dl.batch_size * xb.shape[1]) / (t1 - t0)
    print(f"step {i}, loss: {loss.item():,.6f}, dt: {dt:,.2f}ms, tok/sec: {tokens_per_sec:,.2f}")

  _C._set_float32_matmul_precision(precision)
W1210 02:15:23.297000 1460 torch/_inductor/utils.py:1558] [0/0] Not enough SMs to use max_autotune_gemm mode


step 0, loss: 10.983019, dt: 37,981.59ms, tok/sec: 215.68
step 1, loss: 9.558714, dt: 336.32ms, tok/sec: 24,357.79
step 2, loss: 9.116148, dt: 337.41ms, tok/sec: 24,279.05
step 3, loss: 9.207137, dt: 342.12ms, tok/sec: 23,944.77
step 4, loss: 8.644671, dt: 343.04ms, tok/sec: 23,880.37
step 5, loss: 8.542477, dt: 338.34ms, tok/sec: 24,212.33
step 6, loss: 8.274139, dt: 339.65ms, tok/sec: 24,118.90
step 7, loss: 8.107955, dt: 342.17ms, tok/sec: 23,941.45
step 8, loss: 7.832716, dt: 339.43ms, tok/sec: 24,134.72
step 9, loss: 7.663104, dt: 341.04ms, tok/sec: 24,020.36


So, this added some time in compilation, but it DOUBLED the approximate tokens per second!!!

It sounds like these improvements mainly come from circumventing the Python overhead (compiling the code entirely in C?) and reducing GPU "read/write". On this second point, basically each time a computation happens in the code, we need to bus data from the GPU memory to the GPU itself for processing and then back to the memory.\
This can become really expensive even when we have very high bandwidth memory because we may end up needing to do a very large amount of transfers. For example, a fairly simple statement like the following could have five separate passes:\
`x = 0.5 * x * (1 - torch.tanh(math.sqrt(2 / math.pi) * x))`

The compiler batches these operations so that many or all of them are applied at once on the GPU without the need to bus back and forth to the memory for each operation. Apparently this is an example of ***Kernel Fusion***... may be worth reading about

It sounds like the gist of kernel fusion is using the small amount of memory (SRAM) that lives in the GPU to do sets of the operations for us instead of streaming between the HBM that is off of the GPU chip for each operation. So, maybe it's not too far off to think of this as chunkating the stream from HBM to GPU chip, processing a set of operations on the GPU chip, then streaming back - rather than streaming one chunk per operation...

There are some operations that `torch.compile()` **does not find**

Enter FlashAttention

## FlashAttention

From the [2022 paper](https://arxiv.org/pdf/2205.14135), FlashAttention fundamentally works pretty similarly to `torch.compile()` it seems. Basically, it reduces the read/writes between the GPU and HBM. It goes a step further than `torch.compile` in that it actually rewrites teh attention alogirithm. It explicitly separates $K$, $Q$, $V$ matrices, iteratively loading them to SRAM and writing the matmul outputs to HBM, completely avoiding reading/writing the *full $T\times T$ attention matrix* to HBM. This entire process actually *costs more TFLOPS* than ordinary attention, but it is *dramatically* faster because it manages the memory read/write operations much better.

Interestingly, FlashAttention built upon online softmax first proposed in a [2018 paper from Nvidia](https://arxiv.org/pdf/1805.02867)

The implementation of FlashAttention in pytorch is super simple: `F.scaled_dot_product_attention()`, that's it...

In [13]:
# reload after adding FlashAttention
load_py("gpt_homebrew")  # will need to recall each time script is updated/saved
gpt = gpt_homebrew.GPT(gpt_homebrew.GPTConfig())
gpt.to(device)

gpt = torch.compile(gpt)  # recompile after reloading

In [14]:
dl = DataLoader(train, batch_size=8, shuffle=True)
optimizer = torch.optim.AdamW(gpt.parameters(), lr=3e-4)
for i in range(10):
    t0 = time.time()

    xb, yb = next(iter(dl))

    # send to device during training / inference instead of during dataset creation
    # this is much more efficient for memory usage and dataloader speed
    xb, yb = xb.to(device), yb.to(device)  

    # add BF16 autocast
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = gpt(xb, yb)
    
    loss.backward()
    optimizer.step()

    # this waits for the GPU to finish operations - otherwise CPU REPL stuff may finish before GPU
    # causing us to print time EARLY!
    torch.cuda.synchronize()
    t1 = time.time()
    dt = (t1 - t0) * 1000  # time in miliseconds
    tokens_per_sec = (dl.batch_size * xb.shape[1]) / (t1 - t0)
    print(f"step {i}, loss: {loss.item():,.6f}, dt: {dt:,.2f}ms, tok/sec: {tokens_per_sec:,.2f}")

step 0, loss: 10.932065, dt: 25,433.62ms, tok/sec: 322.09
step 1, loss: 9.635462, dt: 228.31ms, tok/sec: 35,880.57
step 2, loss: 9.057121, dt: 235.18ms, tok/sec: 34,832.91
step 3, loss: 8.782685, dt: 249.91ms, tok/sec: 32,780.10
step 4, loss: 8.509007, dt: 244.93ms, tok/sec: 33,445.99
step 5, loss: 8.315421, dt: 233.01ms, tok/sec: 35,156.71
step 6, loss: 8.074077, dt: 236.29ms, tok/sec: 34,669.51
step 7, loss: 7.929491, dt: 243.47ms, tok/sec: 33,646.60
step 8, loss: 7.672122, dt: 244.01ms, tok/sec: 33,571.91
step 9, loss: 7.453282, dt: 239.39ms, tok/sec: 34,219.81


So, that's about a 40% increase in speed!

## Powers of 2

Powers of 2 are nice numbers for NNs. This is because most low-level operations happen in powers of twos (like if we partition matmul operations between tensorcores which do $4\times 4$ matrix operations). When we don't have powers of twos, then more leg-work is done at a low-level.

In our case, a bad number is the ***vocab size*** of 50,257. Very baaadddd.\
The hackneyed solution which actually works is just to find ways to increase numbers to the nearest nice number (i.e. a number that can be factorized into many powers of two). So, we can just override the vocabulary size to be 50,304, a nice number, and get better efficiency. This will just add input dimensions to the embedding and output dimensions to the final output classification head, but that's fine since it will be more efficient. Since no input tokens will use these additional embeddings, we don't really need to care about them.

In [15]:
gpt = gpt_homebrew.GPT(gpt_homebrew.GPTConfig(vocab_size=50304))  # nice number
gpt.to(device)

gpt = torch.compile(gpt)  # recompile after reloading

In [16]:
dl = DataLoader(train, batch_size=8, shuffle=True)
optimizer = torch.optim.AdamW(gpt.parameters(), lr=3e-4)
for i in range(10):
    t0 = time.time()

    xb, yb = next(iter(dl))

    # send to device during training / inference instead of during dataset creation
    # this is much more efficient for memory usage and dataloader speed
    xb, yb = xb.to(device), yb.to(device)  

    # add BF16 autocast
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = gpt(xb, yb)
    
    loss.backward()
    optimizer.step()

    # this waits for the GPU to finish operations - otherwise CPU REPL stuff may finish before GPU
    # causing us to print time EARLY!
    torch.cuda.synchronize()
    t1 = time.time()
    dt = (t1 - t0) * 1000  # time in miliseconds
    tokens_per_sec = (dl.batch_size * xb.shape[1]) / (t1 - t0)
    print(f"step {i}, loss: {loss.item():,.6f}, dt: {dt:,.2f}ms, tok/sec: {tokens_per_sec:,.2f}")

step 0, loss: 10.969835, dt: 23,063.76ms, tok/sec: 355.19
step 1, loss: 9.600793, dt: 222.01ms, tok/sec: 36,899.41
step 2, loss: 9.291704, dt: 221.79ms, tok/sec: 36,935.15
step 3, loss: 8.853453, dt: 239.39ms, tok/sec: 34,220.43
step 4, loss: 8.704416, dt: 239.32ms, tok/sec: 34,230.86
step 5, loss: 8.334323, dt: 227.06ms, tok/sec: 36,077.91
step 6, loss: 8.144485, dt: 228.03ms, tok/sec: 35,925.67
step 7, loss: 7.994779, dt: 229.42ms, tok/sec: 35,706.67
step 8, loss: 7.777615, dt: 237.81ms, tok/sec: 34,447.79
step 9, loss: 7.580957, dt: 233.12ms, tok/sec: 35,140.74


So that's like a ~4-5% improvement...

## Hyperparameter Tuning

Here, Andrej tries to copy the training hyperparameters from GPT-3

In [18]:
# add betas used in GPT-3
optimizer = torch.optim.AdamW(gpt.parameters(), lr=3e-4, betas=(0.9, 0.95))

Next, they limit the $l2$-norm of ***all gradients*** for *all parameters* to be no more than $1.0$. Pytorch makes this easy with `torch.nn.utils.clip_grad_norm_`. Basically, the $l2$-norms for each layer of gradients are computed, squared, summed, and rooted - this is equivalent to collecting all of the gradients into a single vector and computing the $l2$-norm of that vector:
$$
G = \sqrt{\sum_{i=1}^k \| \nabla L(w_i)\|_2^2}
$$
Then, if the total gradient norm exceeds a `max_norm` $C$, the gradients are simply scaled down by:
$$ \alpha = \frac{C}{G} $$
Yielding:
$$\nabla L(w_i)_{\text{clipped}} = \alpha \cdot \nabla L(w_i)$$

This basically just helps to reduce the risk of exploding gradients.

In [None]:
def train_loop(gpt, dl, optimizer, device, steps=10):
    for i in range(steps):
        t0 = time.time()

        xb, yb = next(iter(dl))

        # send to device during training / inference instead of during dataset creation
        # this is much more efficient for memory usage and dataloader speed
        xb, yb = xb.to(device), yb.to(device)  

        # add BF16 autocast
        optimizer.zero_grad()
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = gpt(xb, yb)
        
        loss.backward()

        # add gradient clipping
        norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), max_norm=1.0)

        optimizer.step()

        # this waits for the GPU to finish operations - otherwise CPU REPL stuff may finish before GPU
        # causing us to print time EARLY!
        torch.cuda.synchronize()
        t1 = time.time()
        dt = (t1 - t0) * 1000  # time in miliseconds
        tokens_per_sec = (dl.batch_size * xb.shape[1]) / (t1 - t0)
        print(
            f"step {i}, loss: {loss.item():,.6f}, dt: {dt:,.2f}ms, tok/sec: {tokens_per_sec:,.2f}"
            + f", grad norm: {norm:,.4f}"
        )

train_loop(gpt, dl, optimizer, device, steps=10)

step 0, loss: 7.428323, dt: 462.63ms, tok/sec: 17,707.33, grad norm: 1.5343868732452393
step 1, loss: 8.012007, dt: 229.21ms, tok/sec: 35,739.91, grad norm: 14.025282859802246
step 2, loss: 7.422108, dt: 227.86ms, tok/sec: 35,951.23, grad norm: 3.2190914154052734
step 3, loss: 7.007826, dt: 229.67ms, tok/sec: 35,669.00, grad norm: 2.499519109725952
step 4, loss: 6.743257, dt: 237.70ms, tok/sec: 34,463.96, grad norm: 1.6739602088928223
step 5, loss: 6.769724, dt: 243.08ms, tok/sec: 33,701.45, grad norm: 3.9348113536834717
step 6, loss: 6.550073, dt: 241.84ms, tok/sec: 33,873.68, grad norm: 1.6428518295288086
step 7, loss: 6.547624, dt: 237.44ms, tok/sec: 34,501.40, grad norm: 1.2088583707809448
step 8, loss: 6.518197, dt: 237.17ms, tok/sec: 34,540.18, grad norm: 1.3318595886230469
step 9, loss: 6.389904, dt: 239.82ms, tok/sec: 34,158.65, grad norm: 1.3799618482589722


## Learning Rate Scheduler

In GPT-3 they use a cosine-decay learning-rate scheduler. This gradually decreases the learning rate over the course of training. Learning-rate decay is useful for allowing us to learn more quickly early in training whil reducing the risks of unstable gradients later in training. One of the neat things about the scheduler that the GPT-3 paper uses is that it starts with a very low learning-rate, presumably this promotes stability in the very early phases of training, where getting bad gradients (or learning the wrong things) early on can be very costly. It very rapidly increases the learning rate to the maximum, and then begins decreasing from there down to 10% of the maximum.

In [14]:
def get_lr(step):
    max_lr = 3e-4
    min_lr = max_lr * 0.1
    warmup_steps = 10
    max_steps = 50

    # warm up phase -- linearly increase to max_lr
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # if past max_steps, do nothing
    if step > max_steps:
        return min_lr
    # in-between, do annealing
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)  # interpolate
    assert 0 <= decay_ratio <= 1, "Invalid decay_ratio"
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # cosine decay coefficient
    return min_lr + coeff * (max_lr - min_lr)  

In [16]:
def train_loop(gpt, dl, optimizer, device, steps=10):
    for i in range(steps):
        t0 = time.time()

        xb, yb = next(iter(dl))

        # send to device during training / inference instead of during dataset creation
        # this is much more efficient for memory usage and dataloader speed
        xb, yb = xb.to(device), yb.to(device)  

        # add BF16 autocast
        optimizer.zero_grad()
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = gpt(xb, yb)
        
        loss.backward()

        # add gradient clipping
        norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), max_norm=1.0)

        # lr cosine annealing
        lr = get_lr(i)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        optimizer.step()

        # this waits for the GPU to finish operations - otherwise CPU REPL stuff may finish before GPU
        # causing us to print time EARLY!
        torch.cuda.synchronize()
        t1 = time.time()
        dt = (t1 - t0) * 1000  # time in miliseconds
        tokens_per_sec = (dl.batch_size * xb.shape[1]) / (t1 - t0)
        print(
            f"step {i} | loss: {loss.item():,.6f} | dt: {dt:,.2f}ms | tok/sec: {tokens_per_sec:,.2f}"
            + f" | grad norm: {norm:,.4f} | lr: {lr:e}"
        )

In [18]:
# reload after adding custom weight decay -- weight decay of 0.1 added to all weights (i.e. not to biases/norms)
load_py("gpt_homebrew")  # will need to recall each time script is updated/saved
gpt = gpt_homebrew.GPT(gpt_homebrew.GPTConfig())
gpt.to(device)

gpt = torch.compile(gpt)  # recompile after reloading

In [19]:
dl = DataLoader(train, batch_size=8, shuffle=True)

torch.set_float32_matmul_precision('high')  # 'highest' is default

# weight decay from paper
optimizer = gpt.configure_optimizers(weight_decay=0.1, learning_rate=3e-4, betas=(0.9, 0.95), device_type='cuda')
# optimizer = torch.optim.AdamW(gpt.parameters(), lr=3e-4, betas=(0.9, 0.95))

train_loop(gpt, dl, optimizer, device, steps=50)

num decayed parameter tensors: 50, with 124,318,464 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: True
step 0 | loss: 10.996070 | dt: 24,698.77ms | tok/sec: 331.68 | grad norm: 25.7995 | lr: 3.000000e-05
step 1 | loss: 9.807817 | dt: 220.69ms | tok/sec: 37,119.82 | grad norm: 10.7909 | lr: 6.000000e-05
step 2 | loss: 9.378258 | dt: 228.02ms | tok/sec: 35,926.57 | grad norm: 10.1761 | lr: 9.000000e-05
step 3 | loss: 9.805191 | dt: 245.63ms | tok/sec: 33,350.45 | grad norm: 11.4488 | lr: 1.200000e-04
step 4 | loss: 9.313437 | dt: 241.86ms | tok/sec: 33,870.98 | grad norm: 4.9837 | lr: 1.500000e-04
step 5 | loss: 8.875429 | dt: 221.94ms | tok/sec: 36,910.16 | grad norm: 3.3518 | lr: 1.800000e-04
step 6 | loss: 8.916994 | dt: 223.19ms | tok/sec: 36,704.97 | grad norm: 4.7965 | lr: 2.100000e-04
step 7 | loss: 8.658592 | dt: 234.16ms | tok/sec: 34,984.46 | grad norm: 2.0000 | lr: 2.400000e-04
step 8 | loss: 8.479709 | dt: 241.24ms | tok/sec: 33,

## Gradient Accumulation

GPT-3 Small was trained with batches of ~0.5M tokens, or ~488 sequences each of length 1,024.\
We obviously cannot directly replicate this batch size because it won't fit in our available memory. But we can simulate this batch size using ***Gradient Accumulation***

This sounds fancy, but it actually just amounts to getting gradients for forward passes through multiple batches, tracking them all, then updating at once!

This is pretty simple to implement because the `.backwards()` method already does gradient accumulation - i.e. it just ADDS gradients (this is why we always call optimizer.zero_grad() ofc). 

In [12]:
# testing with total_batch_size < 0.5M for time and since tinyshakespeare may be too small for that...
def accum_train_loop(gpt, dl, optimizer, device, steps: int = 10, total_batch_size: int = 65536):
    """Training loop with gradient accumulation"""
    assert total_batch_size % (dl.batch_size * gpt.config.block_size) == 0, "total_batch_size not evenly divisible by minibatches"
    accum_steps = total_batch_size // (dl.batch_size * gpt.config.block_size)

    for i in range(steps):
        t0 = time.time()
        loss_accum = 0.0
        for acc_step in range(accum_steps):
            xb, yb = next(iter(dl))

            # send to device during training / inference instead of during dataset creation
            # this is much more efficient for memory usage and dataloader speed
            xb, yb = xb.to(device), yb.to(device)  

            # add BF16 autocast
            optimizer.zero_grad()
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits, loss = gpt(xb, yb)
            
            loss /= accum_steps  # <-- normalize to account for smaller minibatch size (since cross_entropy has mean reduction by default)
            loss_accum += loss.detach()
            loss.backward()  # this AUTOMATICALLY accumulates gradients in the gradient tensors - just adds them

        # add gradient clipping
        norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), max_norm=1.0)

        # lr cosine annealing
        lr = get_lr(i)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        optimizer.step()

        # this waits for the GPU to finish operations - otherwise CPU REPL stuff may finish before GPU
        # causing us to print time EARLY!
        torch.cuda.synchronize()
        t1 = time.time()
        dt = (t1 - t0) * 1000  # time in miliseconds
        tokens_per_sec = (dl.batch_size * xb.shape[1] * accum_steps) / (t1 - t0)
        print(
            f"step {i} | loss: {loss_accum.item():,.6f} | dt: {dt:,.2f}ms | tok/sec: {tokens_per_sec:,.2f}"
            + f" | grad norm: {norm:,.4f} | lr: {lr:e}"
        )

In [16]:
# reload after adding custom weight decay -- weight decay of 0.1 added to all weights (i.e. not to biases/norms)
load_py("gpt_homebrew")  # will need to recall each time script is updated/saved
gpt = gpt_homebrew.GPT(gpt_homebrew.GPTConfig())
gpt.to(device)

gpt = torch.compile(gpt)  # recompile after reloading

dl = DataLoader(train, batch_size=8, shuffle=True)

torch.set_float32_matmul_precision('high')  # 'highest' is default

# weight decay from paper
optimizer = gpt.configure_optimizers(weight_decay=0.1, learning_rate=3e-4, betas=(0.9, 0.95), device_type='cuda')

accum_train_loop(gpt, dl, optimizer, device, steps=10, total_batch_size=524288) # // 8)

num decayed parameter tensors: 50, with 124,318,464 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: True
step 0 | loss: 10.896723 | dt: 15,714.33ms | tok/sec: 33,363.68 | grad norm: 0.4496 | lr: 3.000000e-05
step 1 | loss: 9.660885 | dt: 14,404.30ms | tok/sec: 36,398.01 | grad norm: 0.1724 | lr: 6.000000e-05
step 2 | loss: 9.324739 | dt: 14,082.58ms | tok/sec: 37,229.53 | grad norm: 0.2252 | lr: 9.000000e-05
step 3 | loss: 9.688659 | dt: 13,727.82ms | tok/sec: 38,191.65 | grad norm: 0.1484 | lr: 1.200000e-04
step 4 | loss: 9.312862 | dt: 13,746.38ms | tok/sec: 38,140.08 | grad norm: 0.0870 | lr: 1.500000e-04
step 5 | loss: 9.016571 | dt: 13,710.54ms | tok/sec: 38,239.78 | grad norm: 0.0662 | lr: 1.800000e-04
step 6 | loss: 8.772388 | dt: 13,764.61ms | tok/sec: 38,089.56 | grad norm: 0.0467 | lr: 2.100000e-04
step 7 | loss: 8.618524 | dt: 13,873.31ms | tok/sec: 37,791.13 | grad norm: 0.0336 | lr: 2.400000e-04
step 8 | loss: 8.483671 | dt: 13,

## Parallelizing

Can parallelize across GPUs in pytorch by using `DistributedDataParallel`.\
Each GPU will do forward and backwards passes in parallel. But the gradients that they use to update parameters will be the average of the gradients computed in the parallel backwards passes.

In [18]:
# copy pasta
import os
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

# set up DDP (distributed data parallel).
# torchrun command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
    # use of DDP atm demands CUDA, we set the device appropriately according to rank
    assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])              # int ID of each process across all nodes
    ddp_local_rank = int(os.environ['LOCAL_RANK'])  # local rank (int ID) of process on this node - same as rank then
    ddp_world_size = int(os.environ['WORLD_SIZE'])  # number of processes across all nodes - usually = number of GPUs
    device = f'cuda:{ddp_local_rank}'  # set device to cuda:<local_rank> -- get correct GPU
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.
else:
    # vanilla, non-DDP run -- this is single-GPU scenario
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    # attempt to autodetect device
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"using device: {device}")

using device: cuda


Need to adjust accumulation for parallelization *and* dataloader.\
Using pytorch `DataLoader`, we can adjust for DDP by adding a `DistributedSampler` sampler. This partitions the dataset across the GPUs (ranks).

In [23]:
from torch.utils.data import DistributedSampler
sampler = DistributedSampler(train, num_replicas=ddp_world_size, rank=ddp_rank, shuffle=True)
dl = DataLoader(train, batch_size=8, sampler=sampler)

In [26]:
# init model with GPT
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)  # set common seed for all GPUs so they all init the SAME model
gpt = gpt_homebrew.GPT(gpt_homebrew.GPTConfig()).to(device)
gpt.compile()

# wrap model in DDP
if ddp:
    gpt = DDP(gpt, device_ids=[ddp_local_rank])

raw_gpt = gpt.module if ddp else gpt  # unwrap DDP for optimizer

DDP will synchronize gradients for us by averaging them across the GPUs before updating. So, we have separate but identical models living in each GPU, all processing separate batches of the training data but receiving *the same parameter updates*.

In [27]:
total_batch_size = 524288
steps = 10

# adjust accumulation steps for DDP
assert total_batch_size % (dl.batch_size * gpt.config.block_size * ddp_world_size) == 0, \
    "total_batch_size not evenly divisible by minibatches"
accum_steps = total_batch_size // (dl.batch_size * gpt.config.block_size * ddp_world_size)  # distribute
if master_process:
    print(f"accum_steps: {accum_steps} (total_batch_size: {total_batch_size}, "
          f"dl.batch_size: {dl.batch_size}, block_size: {gpt.config.block_size}, "
          f"ddp_world_size: {ddp_world_size})")


torch.set_float32_matmul_precision('high')  # 'highest' is default
optimizer = raw_gpt.configure_optimizers(weight_decay=0.1, learning_rate=3e-4, betas=(0.9, 0.95), device_type='cuda')
for i in range(steps):
    t0 = time.time()
    loss_accum = 0.0
    for acc_step in range(accum_steps):
        xb, yb = next(iter(dl))

        # send to device during training / inference instead of during dataset creation
        # this is much more efficient for memory usage and dataloader speed
        xb, yb = xb.to(device), yb.to(device)  

        # add BF16 autocast
        optimizer.zero_grad()
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = gpt(xb, yb)
        
        loss /= accum_steps  # <-- normalize to account for smaller minibatch size (since cross_entropy has mean reduction by default)
        loss_accum += loss.detach()

        # under ddp, only sync gradients on LAST accumulation step
        # without this, ddp would automatically sync after EVERY backward()
        if ddp:
            gpt.requires_backward_grad_sync = (acc_step == accum_steps - 1)

        loss.backward()  # this AUTOMATICALLY accumulates gradients in the gradient tensors - just adds them

    # this averages loss_accum across all DDP processes
    if ddp:
        dist.allr_reduce(loss_accum, op=dist.ReduceOp.AVG)

    # add gradient clipping
    norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), max_norm=1.0)

    # lr cosine annealing
    lr = get_lr(i)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    optimizer.step()

    # this waits for the GPU to finish operations - otherwise CPU REPL stuff may finish before GPU
    # causing us to print time EARLY!
    torch.cuda.synchronize()
    t1 = time.time()
    dt = (t1 - t0) * 1000  # time in miliseconds
    tokens_per_sec = (dl.batch_size * xb.shape[1] * accum_steps * ddp_world_size) / (t1 - t0)

    # only master process does logging
    if master_process:
        print(
            f"step {i} | loss: {loss_accum.item():,.6f} | dt: {dt:,.2f}ms | tok/sec: {tokens_per_sec:,.2f}"
            + f" | grad norm: {norm:,.4f} | lr: {lr:e}"
        )

# clean up after training
if ddp:
    destroy_process_group()

accum_steps: 64 (total_batch_size: 524288, dl.batch_size: 8, block_size: 1024, ddp_world_size: 1)
num decayed parameter tensors: 50, with 124,318,464 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: True
step 0 | loss: 11.011212 | dt: 14,185.72ms | tok/sec: 36,958.85 | grad norm: 0.4259 | lr: 3.000000e-05
step 1 | loss: 9.682494 | dt: 14,096.88ms | tok/sec: 37,191.78 | grad norm: 0.1992 | lr: 6.000000e-05
step 2 | loss: 9.236091 | dt: 13,718.60ms | tok/sec: 38,217.31 | grad norm: 0.1986 | lr: 9.000000e-05
step 3 | loss: 9.449409 | dt: 13,486.06ms | tok/sec: 38,876.28 | grad norm: 0.1612 | lr: 1.200000e-04
step 4 | loss: 9.027285 | dt: 13,539.00ms | tok/sec: 38,724.28 | grad norm: 0.0805 | lr: 1.500000e-04
step 5 | loss: 8.736452 | dt: 13,558.33ms | tok/sec: 38,669.07 | grad norm: 0.0601 | lr: 1.800000e-04
step 6 | loss: 8.529930 | dt: 13,681.62ms | tok/sec: 38,320.61 | grad norm: 0.0401 | lr: 2.100000e-04
step 7 | loss: 8.378812 | dt: 13,779.

Were we to actually run this in a distributed way, we could use:\
`torchrun --standalone --nproc_per_node=8 train_gpt2.py` or similar