## Increasing Efficiency by Generating more Heat!

#### Precision Optimisations

Unlike the less important transfomers from physics and the even less important transformers from the movies, our favourite transformer is one that actually grows in efficiency and performance as we generate more heat! The goal of this notebook will be to get the GPU cluster as hot as possible and complete as many computations in as short a time as possible.

The first element of this that we will take advantage of is to drive down the precision of our parameters. Deep learning has progressively begun to use less and less precision in their parameters as the size of models have grown. This is not only because of the improvements in training efficiency from being able to complete magnitudes more FLOPS of computation (up to 16x more TFLOPS if we decide to train in BFLOAT16), but also because the model just becomes a lot smaller during inference time.

Let's begin by pasting our current code base


In [1]:
from dataclasses import dataclass
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

@dataclass # adds all the dunder stuff to the class
class GPTConfig:
    block_size: int = 1024 # Context length
    vocab_size: int = 50257 # Word dictionary size
    n_layer: int = 12 # Number of layers (blocks)
    n_head: int = 12 # Number of heads
    n_embd: int = 768 # Embedding dimension

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"

print(f"Detected device: {device}")

import tiktoken

class DataLoader:
    def __init__(self, B, T):
        self.B = B
        self.T = T

        enc = tiktoken.get_encoding('gpt2')
        with open('data/sample_input.txt', 'r') as f:
            text = f.read()
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        self.tokens = self.tokens.to(device)
        print("num tokens in dataset:", len(self.tokens))
        print("batches to 1 epoch:", len(self.tokens) // (self.B * self.T))

        self.current_index = 0

    def get_next_batch(self):
        buf = self.tokens[self.current_index : self.current_index + self.B * self.T + 1]
        x = buf[:-1].view(self.B, self.T)
        y = buf[1:].view(self.B, self.T)
        self.current_index += self.B * self.T

        # Reset when out of bounds
        if self.current_index + (self.B * self.T + 1) > len(self.tokens):
            self.current_index = 0
        return x, y

    def reset(self):
        self.current_index = 0


class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        # Confirm dimensions will be evenly distributed among all the heads
        assert config.n_embd % config.n_head == 0
        # Create the key, query and value projections in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Following info will need to be stored since forward pass needs to 
        # separate the abomination above
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        # The masked filter but idk why OpenAI called it bias.
        # Resised to fit the above abomination
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1,1, config.block_size, config.block_size))
        # Linear projection out of the Attention block
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        
    def forward(self, x):
        # Batch, time and channel of data (batch size, sequence length & emb dim)
        B, T, C = x.size()
        # Calculate the qkv value combined in the shape (B,T,3 * C)
        qkv = self.c_attn(x)
        # Split n_embd size bits out for k, q, v along the channel dimension
        q, k, v = qkv.split(self.n_embd, dim=2)
        # Reshape each tensor to have the heads in dim=2 and then steal the weights for
        # those heads by taking the values from the embedding dimension.
        # Transpose the sequence length and head size so that the affinity calculation
        # is completed on the sequence length * head size face of the matrices
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1,2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1,2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1,2)
        
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(q.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1) # (B, n_heads, T, T)
        y = att @ v # (B, n_heads, T, T) @ (B, n_heads, T, head_size) = (B, n_heads, T, head_sze)
        # Re-orient tensor to shape (B, T, n_heads, head_size), followed by
        # Concatenation via the view method
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        # Contiguous method ensures that the entire data is stored in a nice way
        # such that no memory errors can occur when we do the concatenating.
        # This is more important in our gpt-2 size model because our memory usage
        # is high enough that our OS may split the memory to different places
        
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        # Expanding the dimensions of the data to let some 
        # computation occur
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        # gelu activation uses a tanh to approximate the real function
        # in the GPT2 due to an error in PyTorch but we need the exact
        # same stuff to import properly so we gotta suffer same way
        self.gelu = nn.GELU(approximate="tanh")
        # Projecting the data back into normal dimensions
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        
        # Adding a custom flag to the model to flag our scaling
        self.c_proj.CUSTOM_SCALING_INIT = 1
        
    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x
    
class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
        
    def forward(self, x):
        # Residual skip connection for the attention block with pre-norm
        x = x + self.attn(self.ln_1(x))
        # Residual skip connection for the MLP block with pre-norm
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        
        self.transformer = nn.ModuleDict(dict(
            # dictionary of token embedding weights (weight of token embeddings)
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            # dictionary of positional embedding weights (weight of positional embeddings)
            wpe = nn.Embedding(config.block_size, config.n_embd),
            # Attention blocks
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            # Final Layer norm (since pre-norm doesn't touch final block output)
            ln_f = nn.LayerNorm(config.n_embd)
        ))
        # Linear projection out of the Attention block
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # Weight sharing between wte and lm_head 
        self.transformer.wte.weight = self.lm_head.weight
        
        # Apply weight initialisation
        self.apply(self._init_weights) 
    
    def _init_weights(self, module):
        # Initialise the weights of the model using the same method as in GPT2
        std = 0.02
        if isinstance(module, nn.Linear):
            if hasattr(module, 'CUSTOM_SCALING_INIT'):
                std = (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
    
    def forward(self, text, targets=None):
        # Raw data input of b context length sized sequences of raw text
        B, T = text.size()
        assert T <= self.config.block_size, "Sequence length too high"
        # Create numbers from 0 to T
        pos = torch.arange(0, T, dtype=torch.long, device=text.device)
        # Fetch positional and token embeddings for all the values in text
        pos_emb = self.transformer.wpe(pos) # of size (T, n_embd)
        tok_emb = self.transformer.wte(text) # of size (B, T, n_embd)
        x = tok_emb + pos_emb

        for block in self.transformer.h:
            x = block(x)
        
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # of size (B, T, vocab_size)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            loss = None
        return logits, loss

Detected device: mps


In [3]:
import time

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
torch.manual_seed(42)

# Enable humongous batch size on GPU
# training_loader = DataLoader(B=16, T=1024)

# Enable small batch size on CPU
training_loader = DataLoader(B=8, T=1024)

model = GPT(GPTConfig())
model.to(device);

optimiser = torch.optim.AdamW(model.parameters(), lr=3e-4)
# scheduler = torch.optim.lr_scheduler.StepLR(optimiser, step_size=25, gamma=0.1)
steps = 2
for i in range(steps):
    t0 = time.time()
    x, y = training_loader.get_next_batch()
    optimiser.zero_grad()
    logits, loss = model(x, y)
    loss.backward()
    optimiser.step()
    t1 = time.time()
    # measure time taken for each step
    dt = (t1 - t0) * 1000
    # measure throughput of tokens/sec at each step
    tokens_per_sec = (training_loader.B * training_loader.T) / (t1 - t0)
    # scheduler.step()
    print(f"Step {i+1}: Loss={loss.item():.4f} Time={dt:.2f}ms tok/sec={tokens_per_sec:.2f}")

num tokens in dataset: 338025
batches to 1 epoch: 41
Step 1: Loss=11.0233 Time 18916.78ms Tokens/sec=433.05
Step 2: Loss=9.5772 Time 19765.78ms Tokens/sec=414.45


Let's begin by ensuring all the operaitons are done in a lower precision 32 bit mode. In normal 32 bit precision, the model trains with full 32 bit numbers, using the entire 24 mantissa bits to store the full precision. Using the "high" precision mode, we lower the usage of the mantisssa to only occupy 10 bits.

In [2]:
torch.set_float32_matmul_precision('high')

Now this supposed to speed up the training's throughput by many multiples depending on the GPU you own and how many tensor cores are present inside. In our case on a Macbook, this is very negligible and almost nonexitstent. I will test later on hired GPUs but that's what we're seeing right now.

Despite the above configuration downcasting all the float32's to TF32, the model still contains a bunch of 32 bit parameters flying around all over the place. The movement of all the data from place to place is still being done with the full float32 precision. This is why we were gaining some throughput (net amount of computations being done as seen in tok/s), our overall time per step was still high due to bandwidth limitations. Now let's tackle that by reducing the actual size of each parameter in the model.

The details of how the TF and FP numbers are represented is very detailed and stuff, but to summarise, the FP numbers use all available bits to represent numbers while the TF numbers use slightly less to do their representation. This means that the TF numbers are more dense in a way. Have a look at the image below to see the difference.

<image src="./notebooks/images/fp vs tf.png" width="500">

One thing to note is that the FP16 is on the standard set by the people who invented the idea of a float. The issue with FP16 is that it cuts the range of numbers possible on the float and that caused a lot of issues during training since casting between different precisions was super pain. BF16 was introduced in the Ampere series of GPUs and it's a 16 bit float that has a range of numbers that is similar to FP32 but cuts down on the number of precision bits in the mantissa.

To make this happen in our model, we will use the autocast context manager in PyTorch. This function will automatically cast the parameters and inputs to the appropriate precision based on the operation that is being performed. This is a very powerful feature and it will help us reduce the size of each parameter in the model.

Pytorch themselves reccomend for the autocast to only be applied to the forward pass and loss calculations in their documentation. Let's add that in to our model.

In [None]:
optimiser = torch.optim.AdamW(model.parameters(), lr=3e-4)
steps = 2
for i in range(steps):
    t0 = time.time()
    x, y = training_loader.get_next_batch()
    optimiser.zero_grad()
    # Autocast forward pass parameters to BF16
    with torch.autocast(device_type=device, dtype=torch.bfloat16): 
        logits, loss = model(x, y)
    loss.backward()
    optimiser.step()
    t1 = time.time()
    # measure time taken for each step
    dt = (t1 - t0) * 1000
    # measure throughput of tokens/sec at each step
    tokens_per_sec = (training_loader.B * training_loader.T) / (t1 - t0)
    print(f"Step {i+1}: Loss={loss.item():.4f} Time={dt:.2f}ms tok/sec={tokens_per_sec:.2f}")

This autocast actually doesn't apply to every component of the forawrd pass, such as the wte and wpe. Pytoech decides which things require the conversions based what computational and performance effect they have on the forward pass. It's not actually clear what exactly pytorch decides not to autocast but we can trust them that they will pick what is most optimal (I think).

#### Compiling for speed

There's a torch compilation feature that compiles the neural network models, kind of like gcc does for C code. It applies a bunch of optimisations across the model and caches a bunch of stuff to make the model run faster. I have no idea what it does exactly but it says that it reduces the Python overhead and the GPU reads/ writes to speed up the model. 

For how it reduces python overhead, since we already algorithmically define what functions we will use from PyTorch and python and how they will be used, the compiler can basically hard code those elements into the model. By only taking the most important parts of libraries (things we use) and removing the majority of the python interpreter's role, the model is able to run faster (kind of like compiling multiple files of C code into one executable).

As for how it reduces the number of read/ write operations, recall that GPUs have high bandwidth memory (HBM) in the same way that CPU's have RAM. Currently the GPU basically figures out kernels for what operations need to be completed and then sends those kernels to the GPU cores from the HBM when they need to be calculated. This is done for each step of the computation so something like input + 0.4 * torch.pow(input, 3.0) will make like 4 round trips to complete.

The compile operation basically looks to find every operation compeleted by the model that looks like this which it can complete more efficently by keeping that value in the GPU cores without needing to repeatedly send it to the HBM. It basically finds what parts of the operation can be completed during the duration that the memory is in the GPU cores. This is called **kernel fusion**.

This thing actually doesn't even work on MPS so we can't do it on the mac. Currently fighting between coding this entire thing on the cloud and just not running the code at all until I have all the code written locally. What a pain.

In [None]:
torch.set_float32_matmul_precision('high')

model = GPT(GPTConfig())
model.to(device);
# Compile the model
model = torch.compile(model)

Now within a GPU, there is a little bit of extremely fast memory called **SRAM**. This memory is magnitudes faster than the HBM but as always, this speed demon is extremely small so it isn't used alot or most efficiently until a torch compile is completed.

#### Flash Attention Algorithm
A big breakthrough research paper invented the flash attention algorithm which rewrites the old attention operation softmax((q @ k^T)/sqrt(n_embd)) @ v with kernel fusion into 1 kernel.

<image src="./notebooks/images/flash attention.png" width=200>

We make this change in code by simply replacing our manual attention calculation with the flash attention implementation inside PyTorch.

In [None]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        # Confirm dimensions will be evenly distributed among all the heads
        assert config.n_embd % config.n_head == 0
        # Create the key, query and value projections in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Following info will need to be stored since forward pass needs to 
        # separate the abomination above
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        # The masked filter but idk why OpenAI called it bias.
        # Resised to fit the above abomination
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1,1, config.block_size, config.block_size))
        # Linear projection out of the Attention block
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        
    def forward(self, x):
        # Batch, time and channel of data (batch size, sequence length & emb dim)
        B, T, C = x.size()
        # Calculate the qkv value combined in the shape (B,T,3 * C)
        qkv = self.c_attn(x)
        # Split n_embd size bits out for k, q, v along the channel dimension
        q, k, v = qkv.split(self.n_embd, dim=2)
        # Reshape each tensor to have the heads in dim=2 and then steal the weights for
        # those heads by taking the values from the embedding dimension.
        # Transpose the sequence length and head size so that the affinity calculation
        # is completed on the sequence length * head size face of the matrices
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1,2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1,2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1,2)

        # Efficient attention calculation using Flash Attention
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Re-orient tensor to shape (B, T, n_heads, head_size), followed by
        # Concatenation via the view method
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        # Contiguous method ensures that the entire data is stored in a nice way
        # such that no memory errors can occur when we do the concatenating.
        # This is more important in our gpt-2 size model because our memory usage
        # is high enough that our OS may split the memory to different places
        
        y = self.c_proj(y)
        return y

#### Funniest Optimisation
So you know how computers love the number 2 right? Well that's true for every calculation that happens inside the GPU too so using lots of numbers that are related to powers of 2 speeds things up a bit too!

We now literally need to pass over some of those weird numbers in our codebase and change them to nice powers of 2.

The first important number that looks weird is the vocab size of 50257, which is an odd number and can't be decomposed into nice powers of 2. The easiest way to fix numbers like these is to look to increase the number to the cleanest power of 2 number. In this case, we like the number 50304, which is divisible by 2, 4, 8, 16, 32 and 64! Those numbers sound much more computer like~

In [None]:
torch.set_float32_matmul_precision('high')

model = GPT(GPTConfig(vocab_size=50304))
model.to(device);
# Compile the model
model = torch.compile(model)

This will actually increase the number of flops of calculation that is completed but surpisingly it gives around a 5% increase in the speed! This is because the GPU usually breaks down problems into kernels that align to sizes of 2's, and then come back later after calculating those clean chunks to take the final leftover numbers and create a final few kernels to complete those operations. Hence those ugly numbers end up costing unnecessary time just to balance things out.

The change basically introduces an extra few token embedding slots that aren't used. This means we waste a tiny miniscule amount of space for a big jump in the speed and efficiency of the model.

Additionally, the higher vocab size means that the lm head at the end of the model has more parameters too and predicts a few more dimensions of possible tokens, albeit they most likely won't really be used for anything. The model will basically waste a bit of computation during training to learn that those last few tokens need to always have a probability of 0.

By this point, these optimisations alone have sped up the training by about 11x.

Let's move on to algorithmic and hyperparameter opitimisations that will help us speed things up and increase the performance of the model!

#### Hyperparameter Optimisations

Let's begin by setting the parameters of the Adam optimiser to the same one that was used for GPT3. The default betas for the momentum and adaptive scaling is 0.9 and 0.999 by default, but we will set them to 0.9 and 0.95. The deafult epsilon at the bottom of the step calculation is 1*10^-8 and that is also what is used in the GPT3 paper but we will explicitly mention it anyway.

In [None]:
optimiser = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)

Next we want to perform this thing called global norm clipping which basically adds a max limit to the norm of the gradients for the model. It calculates the norm by taking the squared sum of the gradients of every parameter in the model and then square roots that sum. It then makes sure that the final result is no larger than 1.

The purpose of this is to avoid extremely edge case batches where the batch produces a super high loss which produces a really high gradient and then cause big changes to the models parameters which produces an overall "shock" to the models optimisation. In order to avoid this possibility, the gradient is clipped by the norm to a maximum of 1

In [None]:
steps = 4

for step in range(steps):
    t0 = time.time()
    x, y = training_loader.get_next_batch()
    optimiser.zero_grad()
    # Autocast forward pass parameters to BF16
    # with torch.autocast(device_type=device, dtype=torch.bfloat16): 
    # MPS only supports FP16 not BF16 :(
    with torch.autocast(device_type=device, dtype=torch.float16): 
        logits, loss = model(x, y)
    loss.backward()
    # Clip the gradient norms to a maximum of 1
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimiser.step()
    t1 = time.time()
    # measure time taken for each step
    dt = (t1 - t0) * 1000
    # measure throughput of tokens/sec at each step
    tokens_per_sec = (training_loader.B * training_loader.T) / (t1 - t0)
    print(f"Step {step+1}: Loss={loss.item():.4f} Time={dt:.2f}ms tok/sec={tokens_per_sec:.2f}")

Next let's add in a learning rate scheduler. In GPT3 they used a cosine decay learning schedule with warmup. This can be seen in the image below.

<image src="./notebooks/images/cosine decay lr.png" width=500>

Let's implement it. In the GPT3 paper they say that they build the model with an initial learning rate around 6e-4 and then cosine decay down to 10% of that and then continue at that learning rate.

In [None]:
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 50
def get_lr(step):
    # Linear warmup for warmup steps
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # Min learning rate after max steps
    if step > max_steps:
        return min_lr
    # Cosine decay for the rest of the steps
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    # Coeff starts at 1 and goes to 0
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + (max_lr - min_lr) * coeff

In [None]:
for step in range(steps):
    t0 = time.time()
    x, y = training_loader.get_next_batch()
    optimiser.zero_grad()
    # Autocast forward pass parameters to BF16
    # with torch.autocast(device_type=device, dtype=torch.bfloat16): 
    # MPS only supports FP16 not BF16 :(
    with torch.autocast(device_type=device, dtype=torch.float16): 
        logits, loss = model(x, y)
    loss.backward()
    # Clip the gradient norms to a maximum of 1
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # Determine and set the learning rate for the current step
    lr = get_lr(step)
    for param_group in optimiser.param_groups:
        param_group['lr'] = lr
    optimiser.step()
    t1 = time.time()
    # measure time taken for each step
    dt = (t1 - t0) * 1000
    # measure throughput of tokens/sec at each step
    tokens_per_sec = (training_loader.B * training_loader.T) / (t1 - t0)
    print(f"Step {step+1}: Loss={loss.item():.4f} Time={dt:.2f}ms tok/sec={tokens_per_sec:.2f}")

Next, the model actually utilises a bit of weight decay in some of the layers for regularisation. Recall that applying weight decay means that weights are penalised for getting too large, which forces the model to better utilise the entire network during optimisation. This is done for most of the parameters except for some layers like layer norms and biases are left alone. This simplifies down to only applying weight decay to parameter sets that have 2 or more dimensions.

Weight decay is not applied to things like layer norms and biases because it affects their performance negatively since we don't particularly need those elements of the model to be regularised. 

In [None]:
import inspect

class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        
        self.transformer = nn.ModuleDict(dict(
            # dictionary of token embedding weights (weight of token embeddings)
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            # dictionary of positional embedding weights (weight of positional embeddings)
            wpe = nn.Embedding(config.block_size, config.n_embd),
            # Attention blocks
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            # Final Layer norm (since pre-norm doesn't touch final block output)
            ln_f = nn.LayerNorm(config.n_embd)
        ))
        # Linear projection out of the Attention block
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # Weight sharing between wte and lm_head 
        self.transformer.wte.weight = self.lm_head.weight
        
        # Apply weight initialisation
        self.apply(self._init_weights) 
    
    def _init_weights(self, module):
        # Initialise the weights of the model using the same method as in GPT2
        std = 0.02
        if isinstance(module, nn.Linear):
            if hasattr(module, 'CUSTOM_SCALING_INIT'):
                std = (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            
    def set_optimiser(self, weight_decay, learning_rate, device):
        # Fetch all parameters that require grad
        params = {pn: p for pn, p in self.named_parameters()}
        params = {pn: p for pn, p in params.items() if p.requires_grad}
        
        # Create groups for parameters that require weight decay
        decay_params = [p for n, p in params.items() if p.dim() >= 2]
        no_decay_params = [p for n, p in params.items() if p.dim() < 2]
        grouped_params = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0}
        ]
        
        num_decay_params = sum(p.numel() for p in decay_params)
        num_no_decay_params = sum(p.numel() for p in no_decay_params)
        print(f"Number of decay params: {num_decay_params}")
        print(f"Number of no decay params: {num_no_decay_params}")
        
        # Create AdamW optimiser with fused AdamW if available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and 'cuda' in device
        print(f"Using fused AdamW: {use_fused}")
        optimiser = torch.optim.AdamW(grouped_params,
                                        lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimiser
    
    def forward(self, text, targets=None):
        # Raw data input of b context length sized sequences of raw text
        B, T = text.size()
        assert T <= self.config.block_size, "Sequence length too high"
        # Create numbers from 0 to T
        pos = torch.arange(0, T, dtype=torch.long, device=text.device)
        # Fetch positional and token embeddings for all the values in text
        pos_emb = self.transformer.wpe(pos) # of size (T, n_embd)
        tok_emb = self.transformer.wte(text) # of size (B, T, n_embd)
        x = tok_emb + pos_emb

        for block in self.transformer.h:
            x = block(x)
        
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # of size (B, T, vocab_size)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            loss = None
        return logits, loss

Notice in the new set_optimiser function that we used the fused AdamW if it is available. What this does is that it applies a big optimisation on the backprop for the parameters by bunching the params in to a larger kernel rather than iterating over each param group and creating a kernel for each of those. This also grants us a 3% boost in performance.

In [None]:
optimiser = model.set_optimiser(weight_decay=0.1, learning_rate=6e-4, device=device)

for step in range(steps):
    t0 = time.time()
    x, y = training_loader.get_next_batch()
    optimiser.zero_grad()
    # Autocast forward pass parameters to BF16
    # with torch.autocast(device_type=device, dtype=torch.bfloat16): 
    # MPS only supports FP16 not BF16 :(
    with torch.autocast(device_type=device, dtype=torch.float16): 
        logits, loss = model(x, y)
    loss.backward()
    # Clip the gradient norms to a maximum of 1
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # Determine and set the learning rate for the current step
    lr = get_lr(step)
    for param_group in optimiser.param_groups:
        param_group['lr'] = lr
    optimiser.step()
    t1 = time.time()
    # measure time taken for each step
    dt = (t1 - t0) * 1000
    # measure throughput of tokens/sec at each step
    tokens_per_sec = (training_loader.B * training_loader.T) / (t1 - t0)
    print(f"Step {step+1}: Loss={loss.item():.4f} Time={dt:.2f}ms tok/sec={tokens_per_sec:.2f}")

#### Gradient Accumulation

Now all the hyperparameters we set was selected under the assumption that we would have about 0.5 million tokens per batch. Now that would need us to have nearly 500 token sequences simultaneously in the same batch but even if we use 8xH100's, we wouldn't be able to  fit all that in the GPUs. Hence, we will use this thing called gradient accumulation, which allows us to simulate arbitrarily larger batch sizes by sequentially calculating enough mini batches to reach the target batch size before performing backprop on the accumulated gradient from the mini batches.

Let's begin by rounding up that 500,000 token number to a cleaner number that is nicely divisible. Following that, we'll calculate how many mini-mini batches we will need to accumulate before we reach the required mini batch size.

In [2]:
total_batch_tokens = 524288
B = 16
T = 1024

assert total_batch_tokens % (B*T) == 0, "Total batch size must be divisible by batch size"

grad_accum_steps = total_batch_tokens // (B*T)
print(f"Gradient accumulation steps: {grad_accum_steps}")

Gradient accumulation steps: 32


Next let's fix up our training loop to account for gradient accumulation

In [None]:
for step in range(max_steps):
    t0 = time.time()
    optimiser.zero_grad()
    accumulated_loss = 0.0
    # Accumulate gradient prior to optimiser step
    for mini_step in range(grad_accum_steps):
        x, y = training_loader.get_next_batch()
        # Autocast forward pass parameters to BF16 (Enable when CUDA)
        # with torch.autocast(device_type=device, dtype=torch.bfloat16): 
        # MPS only supports FP16 not BF16 :(
        with torch.autocast(device_type=device, dtype=torch.float16): 
            logits, loss = model(x, y)

        # Ensure that loss normalisation is still present after accumulation
        loss = loss / grad_accum_steps
        accumulated_loss += loss.detach()
        loss.backward()

    # Clip the gradient norms to a maximum of 1
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # Determine and set the learning rate for the current step
    lr = get_lr(step)
    for param_group in optimiser.param_groups:
        param_group['lr'] = lr
    optimiser.step()
    t1 = time.time()

    # measure time taken for each step
    dt = (t1 - t0) * 1000
    tokens_processed = training_loader.B * training_loader.T * grad_accum_steps
    # measure throughput of tokens/sec at each step
    tokens_per_sec = tokens_processed / dt
    print(f"Step {step+1}: Loss={accumulated_loss.item():.4f} Time={dt:.2f}ms tok/sec={tokens_per_sec:.2f}")

Note that the loss is divided by the number of accumulation steps we complete. This is because we usually divide the loss for each step by the number of examples in the batch so that the loss is like batch_loss = 1/batch_size * sum(loss(example_i)). This means that when we accumulate gradients, this sum gets lost, so we have to manually add it back in.

Additionally we added a few variables that will help us track the loss, tok/s and stuff properly over each mini batch of 0.5M tokens.

#### GPU Parallelisation
Time for the big weapons of optimisation. We are gonnna begin with GPU parrallelisation. When we train, we will hire out 8 H100's so we will be able to utilise the entirety of those GPUs to train our model. We will basically make each of the 8 GPUs take a batch of data and compute a batch loss on them. This loss will then be averaged over the 8 GPUs to find the final gradient with which the optimiser step is taken. Let's add the code that will setup the Distributed Data Parallel (DDP) environment if it is available.

In a DDP environment, the rank is the identifier of which GPU in the current environment is being used, while the World Size describes how many GPU's are availaable, and hence the upper limit of what the rank of a DDP environment can be. If we had multiple nodes, we would have also had a local_rank to descripe which node is being referred to.

In [None]:
import os
from torch.distributed import init_process_group

ddp = int(os.environ.get('RANK', -1)) != -1 # Check if multiple GPUs are available
if ddp:
    assert torch.cuda.is_available(), "DDP requires CUDA"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ('RANK'))
    ddp_world_size = int(os.environ('WORLD_SIZE'))
    ddp_local_rank = int(os.environ('LOCAL_RANK'))
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0
    print(f"Running with DDP on {ddp_world_size} GPUs")
else:
    ddp_rank = 0
    ddp_world_size = 1
    ddp_local_rank = 0
    master_process = True
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"Running on device: {device}")


In each GPU instance, at each rank, the GPU will run the exact same piece of code, have it's own identical Python interpretor, running its own training loop.

Since each GPU will run one mini- mini batch, we will need to readjust our gradient accumulation steps too

In [None]:
batch_size, context_length = 8, 1024

# Set up gradient accumulation
total_batch_tokens = 524288

assert total_batch_tokens % (batch_size * context_length * ddp_world_size) == 0, "Total batch size must be divisible by mini batch size"

grad_accum_steps = total_batch_tokens // (batch_size*context_length)

Right now with 8 GPU's this is about 4 accumulating mini batches to reach the full batch size.

Now since there will be a parrallel process running on each GPU, we need to fix up our data loader to make sure that it does not fetch the exact same data batch for each of the process, since we have manual seeding right now.

We will basically just offset the data loader by 1 based on the process number so that each process is like 1 load ahead of it's previous process. We move in chunks but we still traverse the data the same way since our seeding between each process is unchanged.

In [None]:
class DataLoader:
    def __init__(self, B, T, process_rank, num_processes):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes

        enc = tiktoken.get_encoding('gpt2')
        with open('data/sample_input.txt', 'r') as f:
            text = f.read()
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        self.tokens = self.tokens.to(device)
        print("num tokens in dataset:", len(self.tokens))
        print("batches to 1 epoch:", len(self.tokens) // (self.B * self.T))

        # Calculate the starting index based on offset using process rank
        self.current_index = self.B * self.T * self.process_rank

    def get_next_batch(self):
        buf = self.tokens[self.current_index : self.current_index + self.B * self.T + 1]
        x = buf[:-1].view(self.B, self.T)
        y = buf[1:].view(self.B, self.T)
        
        # Update the current index for the next batch based on process
        self.current_index += self.B * self.T * self.process_rank

        # Reset when out of bounds
        if self.current_index + (self.B * self.T * self.num_processes + 1) > len(self.tokens):
            self.current_index = self.B * self.T * self.process_rank
        return x, y

Next we need to wrap our model in the ddp container. This container doesn't do anything different in the forward pass but in the backward pass, it calculates an average gradient for all the parameters across all the GPUs and then deposits that average to all the GPU's.

Infact, the DDP manager actually also dispatches averages for parameters from the other GPUs/ ranks while they are still computing the backward pass. So if one layer has been completed across all GPUs, it may start to dispatch the average gradients for those layers without waiting for every rank to complete their backward pass. 

In [None]:
from torch.nn.parallel import DistributedDataParallel as DDP

model = GPT(GPTConfig(vocab_size=50304))
model.to(device);
# Compile the model (on GPU) to remove Python overhead
# model = torch.compile(model)

if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model

Since we wrap the model with a DDP manager, we will need to make sure that the optimiser makes changes to the raw model inside the DDP wrapped model.

In [None]:
optimiser = raw_model.set_optimiser(weight_decay=0.1, learning_rate=6e-4, device=device)

Now with the dispatching averages, with the current setup, it will dispatch and synchronise across the ranks on every loss.backward() but we don't need syncing that frequently. Doing the sync once on the very last mini batch is enough. Because of this, we will enable a special flag on the loss.backward() that synchronises the gradients only during the final mini batch and disabled at all other times.

Note that we also need to fix up our accumulated loss tracker and the tokens per second to account for all the other GPU's

In [None]:
for step in range(max_steps):
    t0 = time.time()
    optimiser.zero_grad()
    accumulated_loss = 0.0
    # Accumulate gradient prior to optimiser step
    for mini_step in range(grad_accum_steps):
        x, y = training_loader.get_next_batch()
        # Autocast forward pass parameters to BF16 (Enable when CUDA)
        # with torch.autocast(device_type=device, dtype=torch.bfloat16): 
        # MPS only supports FP16 not BF16 :(
        with torch.autocast(device_type=device, dtype=torch.float16): 
            logits, loss = model(x, y)

        # Ensure that loss normalisation is still present after accumulation
        loss = loss / grad_accum_steps
        accumulated_loss += loss.detach()
        # Only sync gradients on the last mini batch
        if ddp:
            model.require_backward_grad_sync = (mini_step == grad_accum_steps - 1)
        loss.backward()
    # Ensure that the displayed loss is averaged across all the GPUs
    if ddp:
        torch.distributed.all_reduce(accumulated_loss, op=torch.distributed.ReduceOp.AVG) 

    # Clip the gradient norms to a maximum of 1
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # Determine and set the learning rate for the current step
    lr = get_lr(step)
    for param_group in optimiser.param_groups:
        param_group['lr'] = lr
    optimiser.step()
    t1 = time.time()

    # measure time taken for each step
    dt = (t1 - t0) * 1000
    # Calculate tokens processed across all GPUs
    tokens_processed = training_loader.B * training_loader.T * grad_accum_steps * ddp_world_size
    # measure throughput of tokens/sec at each step
    tokens_per_sec = tokens_processed / dt
    print(f"Step {step+1}: Loss={accumulated_loss.item():.4f} Time={dt:.2f}ms tok/sec={tokens_per_sec:.2f}")

Now with that, we acutally have the optimised model ready for hardcore training.

I have already copied a fineweb dataset fetcher file. In order to train our model, we just need to run the optimised-gpt2-train.py file in a GPU environment and we're good to go.

We need to make a couple of changes to our data loading process to now be able to work with the shards of data we have. We also additionally add a reset function so that we can run validation on the model during training

In [None]:
import numpy as np

def load_tokens(filename):
    np_tensor = np.load(filename)
    pt_tensor = torch.tensor(np_tensor, dtype=torch.long)
    return pt_tensor

class DataLoader:
    def __init__(self, B, T, process_rank, num_processes, split):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in ["train", "val"]
        
        # Get the shard filenames
        data_root = "edu_fineweb10B"
        shards = os.listdir(data_root)
        shards = [shard for shard in shards if split in shard]
        shards = sorted(shards)
        shards = [os.path.join(data_root, shard) for shard in shards]
        self.shards = shards
        
        assert len(self.shards) > 0, "No shards found"
        if master_process:
            print(f"Found {len(self.shards)} shards for {split} split")

        self.reset()

    def get_next_batch(self):
        buf = self.tokens[self.current_index : self.current_index + self.B * self.T + 1]
        x = buf[:-1].view(self.B, self.T)
        y = buf[1:].view(self.B, self.T)
        
        # Update the current index for the next batch based on process
        self.current_index += self.B * self.T * self.num_processes

        # Reset when out of bounds
        if self.current_index + (self.B * self.T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_index = self.B * self.T * self.process_rank
        return x, y
    
    def reset(self):
        # State init at shard 0
        self.current_shard = 0
        # Calculate the starting index based on offset using process rank
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_index = self.B * self.T * self.process_rank


In [None]:
training_loader = DataLoader(batch_size, context_length, ddp_rank, ddp_world_size, split="train")
validation_loader = DataLoader(batch_size, context_length, ddp_rank, ddp_world_size, split="val")

And then let's fix up our training loop to do validation while training.

In [None]:

for step in range(max_steps):
    t0 = time.time()
    
    # Validate every 100 steps
    if step % 100 == 0:
        model.eval()
        validation_loader.reset()
        with torch.no_grad():
            accumulated_val_loss = 0.0
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = validation_loader.get_next_batch()
                with torch.autocast(device_type=device, dtype=torch.bfloat16): 
                    logits, loss = model(x, y)
                loss = loss / val_loss_steps
                accumulated_val_loss += loss.detach()
        if ddp:
            torch.distributed.all_reduce(accumulated_val_loss, op=torch.distributed.ReduceOp.AVG)
        if master_process:
            print(f"Validation loss: {accumulated_val_loss.item():.4f}")
            
    # Training loop
    model.train()
    optimiser.zero_grad()
    accumulated_loss = 0.0
    # Accumulate gradient prior to optimiser step
    for mini_step in range(grad_accum_steps):
        x, y = training_loader.get_next_batch()
        # Autocast forward pass parameters to BF16 (Enable when CUDA) !!!
        # with torch.autocast(device_type=device, dtype=torch.bfloat16): 
        # MPS only supports FP16 not BF16 :(
        with torch.autocast(device_type=device, dtype=torch.float16): 
            logits, loss = model(x, y)

        # Ensure that loss normalisation is still present after accumulation
        loss = loss / grad_accum_steps
        accumulated_loss += loss.detach()
        # Only sync gradients on the last mini batch
        if ddp:
            model.require_backward_grad_sync = (mini_step == grad_accum_steps - 1)
        loss.backward()
    # Ensure that the displayed loss is averaged across all the GPUs
    if ddp:
        torch.distributed.all_reduce(accumulated_loss, op=torch.distributed.ReduceOp.AVG) 

    # Clip the gradient norms to a maximum of 1
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # Determine and set the learning rate for the current step
    lr = get_lr(step)
    for param_group in optimiser.param_groups:
        param_group['lr'] = lr
    optimiser.step()
    t1 = time.time()

    # measure time taken for each step
    dt = (t1 - t0) * 1000
    # Calculate tokens processed across all GPUs
    tokens_processed = training_loader.B * training_loader.T * grad_accum_steps * ddp_world_size
    # measure throughput of tokens/sec at each step
    tokens_per_sec = tokens_processed / dt
    print(f"Step {step+1}: loss={accumulated_loss.item():.4f} norm={norm:.4f} t={dt:.2f}ms tok/sec={tokens_per_sec:.2f}")

Next we'll set up the model configs to match the same ratio as in GPT3. We're shooting for same ratio because our training is on way smaller dataset.

In [None]:
# Set up learning rate scheduler
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 715 # Same ratio of warmup as GPT3
max_steps = 19073 # Exactly 1 epoch over the 10B token dataset

While we're at it, let's also make it so our model samples a little bit every now and then too so that we can visually see how the generation improves as the training proceeds.

In [None]:
enc = tiktoken.get_encoding("gpt2")

for step in range(max_steps):
    t0 = time.time()
    
    # Validate every 100 steps
    if step % 100 == 0:
        model.eval()
        validation_loader.reset()
        with torch.no_grad():
            accumulated_val_loss = 0.0
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = validation_loader.get_next_batch()
                with torch.autocast(device_type=device, dtype=torch.bfloat16): 
                    logits, loss = model(x, y)
                loss = loss / val_loss_steps
                accumulated_val_loss += loss.detach()
        if ddp:
            torch.distributed.all_reduce(accumulated_val_loss, op=torch.distributed.ReduceOp.AVG)
        if master_process:
            print(f"Validation loss: {accumulated_val_loss.item():.4f}")
    
    # Generate every 250 steps
    if ((step > 0 and step % 250 == 0) or step == max_steps - 1):
        model.eval()
        num_return_sequences = 4
        max_length = 32
        tokens = enc.encode("Hello, I'm a language model,")
        tokens = torch.tensor(tokens, dtype=torch.long)
        tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
        xgen = tokens.to(device)
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(42 + ddp_rank)
        while xgen.size(1) < max_length:
            # forward the model to get the logits
            with torch.no_grad():
                with torch.autocast(device_type=device, dtype=torch.bfloat16):
                    logits, loss = model(xgen) # (B, T, vocab_size)
                # take the logits at the last position
                logits = logits[:, -1, :] # (B, vocab_size)
                # get the probabilities
                probs = F.softmax(logits, dim=-1)
                # do top-k sampling of 50 (huggingface pipeline default)
                # topk_probs here becomes (5, 50), topk_indices is (5, 50)
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
                # select a token from the top-k probabilities
                # note: multinomial does not demand the input to sum to 1
                ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
                # gather the corresponding indices
                xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
                # append to the sequence
                xgen = torch.cat((xgen, xcol), dim=1)
        # print the generated text
        for i in range(num_return_sequences):
            tokens = xgen[i, :max_length].tolist()
            decoded = enc.decode(tokens)
            print(f"rank {ddp_rank} sample {i}: {decoded}")
    
    # Training loop
    model.train()
    optimiser.zero_grad()
    accumulated_loss = 0.0
    # Accumulate gradient prior to optimiser step
    for mini_step in range(grad_accum_steps):
        x, y = training_loader.get_next_batch()
        # Autocast forward pass parameters to BF16 (Enable when CUDA) !!!
        # with torch.autocast(device_type=device, dtype=torch.bfloat16): 
        # MPS only supports FP16 not BF16 :(
        with torch.autocast(device_type=device, dtype=torch.float16): 
            logits, loss = model(x, y)

        # Ensure that loss normalisation is still present after accumulation
        loss = loss / grad_accum_steps
        accumulated_loss += loss.detach()
        # Only sync gradients on the last mini batch
        if ddp:
            model.require_backward_grad_sync = (mini_step == grad_accum_steps - 1)
        loss.backward()
    # Ensure that the displayed loss is averaged across all the GPUs
    if ddp:
        torch.distributed.all_reduce(accumulated_loss, op=torch.distributed.ReduceOp.AVG) 

    # Clip the gradient norms to a maximum of 1
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # Determine and set the learning rate for the current step
    lr = get_lr(step)
    for param_group in optimiser.param_groups:
        param_group['lr'] = lr
    optimiser.step()
    t1 = time.time()

    # measure time taken for each step
    dt = (t1 - t0) * 1000
    # Calculate tokens processed across all GPUs
    tokens_processed = training_loader.B * training_loader.T * grad_accum_steps * ddp_world_size
    # measure throughput of tokens/sec at each step
    tokens_per_sec = tokens_processed / dt
    if master_process:
        print(f"Step {step+1}: loss={accumulated_loss.item():.4f} norm={norm:.4f} t={dt:.2f}ms tok/sec={tokens_per_sec:.2f}")
        with open(log_file, "a") as f:
            f.write(f"{step} train {accumulated_loss.item():.6f}\n")