In [None]:
!pip install accelerate -U
!pip install numpy==2.0.1

!pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0
!pip install dataclasses==0.6

In [None]:
"""
Reference code for GPT-2 training and inference.
Will save the model weights into files, to be read from C as initialization.

References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py

Example launches to only benchmark the speed of bfloat16 compiled GPU training:
1 GPU:
python train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16
you can also turn on flash-attention by appending --flash=1
4 GPU:
torchrun --standalone --nproc_per_node=4 train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16
"""

In [None]:
import os
import math
import glob
import struct
import inspect
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed.optim import ZeroRedundancyOptimizer
import torch.distributed as dist

In [None]:
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model

In [None]:
class NewGELU(nn.Module):
    """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI""" # defines GeLU activation function, tanh version.
    def forward(self, input):
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
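
As a quick sanity check on the formula above, the following plain-Python sketch compares the tanh approximation with the exact erf-based GELU on a grid of inputs. The function names `gelu_tanh` and `gelu_exact` are illustrative and not part of the notebook.

```python
import math

def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU, same formula as NewGELU.forward."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# Largest gap between the two versions on a grid over [-5, 5].
grid = [i / 100.0 for i in range(-500, 501)]
max_gap = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in grid)
print(f"max |tanh - exact| on [-5, 5]: {max_gap:.2e}")
```

The two versions agree closely but not exactly, which is why the docstring warns about picking the right one when matching OpenAI's weights.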

In [None]:
# using a global to toggle flash-attention

# FLASH defaults to 0, so CausalSelfAttention uses the manual attention implementation below.
# Setting it to 1 switches to F.scaled_dot_product_attention, which can dispatch to a fused
# FlashAttention kernel: it computes the same attention without materializing the full (T, T)
# score matrix, which cuts memory use and memory traffic for long sequences.
FLASH = 0

## TIL notes for self attention

Certainly! Here’s a brief overview of embeddings and attention heads, their roles in the model, and why they are used:

### Embeddings
**What are they?**
Embeddings are dense vector representations of data, typically words or tokens, that capture semantic meanings and relationships in a way that is usable by machine learning models. In the context of models like GPT-2, embeddings transform discrete input tokens (e.g., words) into continuous vectors.

**Where are they in the pipeline?**
In transformer models, embeddings are usually the first layer. For GPT-2, this includes:
- **Token Embeddings (`wte`)**: Convert each token into a vector.
- **Positional Embeddings (`wpe`)**: Encode the position of each token in the sequence, allowing the model to understand the order of tokens.

**What do they do and why?**
Embeddings map tokens to a high-dimensional space where similar tokens are closer together, facilitating the model's understanding of language patterns and context. The positional embeddings ensure the model can consider the order of words, crucial for generating coherent text and understanding language structure.
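
A minimal sketch of the lookup-and-add step, in plain Python with made-up numbers. The tables `wte` and `wpe` here are tiny stand-ins for the real embedding matrices, and `embed` is a hypothetical helper.

```python
# Toy lookup tables; the numbers are made up for illustration.
wte = [  # token embeddings: vocab_size=4, n_embd=3
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2],
]
wpe = [  # positional embeddings: block_size=3, n_embd=3
    [0.00, 0.00, 0.00],
    [0.01, 0.01, 0.01],
    [0.02, 0.02, 0.02],
]

def embed(token_ids):
    """Return one vector per token: wte[token] + wpe[position]."""
    return [
        [t + p for t, p in zip(wte[tok], wpe[pos])]
        for pos, tok in enumerate(token_ids)
    ]

x = embed([2, 0, 2])  # the same token (id 2) appears at positions 0 and 2
for pos, vec in enumerate(x):
    print(pos, vec)
```

The same token id produces different vectors at positions 0 and 2, which is how the model can tell the two occurrences apart.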

### Attention Heads
**What are they?**
Attention heads are components of the attention mechanism within each transformer layer. They allow the model to focus on different parts of the input sequence when predicting each output token.

**Where are they in the pipeline?**
Attention heads are part of the self-attention layers in the transformer blocks. Each head computes its own attention scores independently, allowing the model to capture various aspects of the data.

**What do they do and why?**
Each attention head:
- Computes **queries**, **keys**, and **values** from the input embeddings using linear transformations.
- Uses the **queries** to assess the importance (or relevance) of other parts of the sequence (the **keys**) and aggregates the corresponding **values** based on these computed relevancies.

This mechanism enables the model to dynamically decide which parts of the input are important for each output token, allowing for contextually aware generation of text. The use of multiple heads increases the model's capacity to focus on different subspaces of the embedding dimension, capturing a wider array of relationships within the data.
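
The query/key/value mechanics for a single head can be sketched in plain Python as below. This is an illustration with hand-picked vectors, not the notebook's implementation: it leaves out the causal mask and the learned projections, and the helper names are assumptions.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v):
    """Single-head scaled dot-product attention (no mask).
    q, k, v are lists of T vectors, each of length d."""
    d = len(q[0])
    out, weights = [], []
    for qi in q:
        # Relevance of every key to this query, scaled by sqrt(d).
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        w = softmax(scores)
        weights.append(w)
        # Weighted average of the value vectors.
        out.append([sum(w[j] * v[j][c] for j in range(len(v))) for c in range(len(v[0]))])
    return out, weights

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0, 0.0], [0.0, 10.0]]
out, weights = attention(q, k, v)
print(weights)
print(out)
```

Each query puts more weight on the key it aligns with, so its output leans toward the matching value vector.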

### Summary
In the pipeline, embeddings serve as the initial transformation layer, translating raw input into a format suitable for processing by deeper network layers. Attention heads come into play within the self-attention mechanism, which is fundamental to transformers' ability to handle sequences of data. They allow the model to selectively focus on different segments of the input data, enhancing its ability to make contextually informed predictions.

query: I don't get how the embedding is split across a number of heads. Is this due to GPU parallelism? I do get that embedding = token embedding + positional embedding.

But how are the tokens split across the attention heads? I can't visualise it. Can you show me an example of how this works? Thanks.

---

Certainly! Let’s dive deeper into how embeddings are used within attention heads and the concept of splitting across multiple heads, which can indeed seem a bit abstract at first.

### Embeddings Recap
As you mentioned, the embedding for each token in a transformer model like GPT-2 is the sum of:
- **Token Embedding**: Converts each token (like a word) into a vector.
- **Positional Embedding**: Adds information about the position of each token in the sequence.

This results in a combined embedding that conveys both the meaning of the token and its position in the sequence.

### Attention Heads: Splitting the Embeddings
The purpose of having multiple attention heads is to allow the model to attend to different parts of the input sequence from different 'perspectives' or 'subspaces'. Here’s how it works step-by-step:

1. **Starting Point**: Assume each token in the sequence has been converted to a combined embedding vector (from token and positional embeddings). Suppose each vector has a dimensionality \(C\).

2. **Linear Projections**: Each attention head will transform the embedding vector into three different vectors: Queries (Q), Keys (K), and Values (V). These transformations are done through learned linear projections (typically implemented as a single linear layer split into three parts).

3. **Splitting Across Heads**:
   - For a transformer with multiple heads, the dimension \(C\) of each embedding vector is split into \(N\) parts, where \(N\) is the number of heads. Each part handles a fraction of the total dimension. For instance, if \(C = 768\) and \(N = 12\) (like in GPT-2), each head deals with a sub-vector of size 64.
   - Each head processes its portion of the vector independently. This means every head will produce its own Q, K, and V vectors from different subspaces of the embedding.

### Visualization Example
Let's visualize this with an example where \(C = 6\) (for simplicity) and \(N = 3\) heads. Each head will handle 2 dimensions:

- **Original Embeddings for a Sequence** (3 tokens):
  \[
  \begin{align*}
  \text{Token 1:} & \ [e_{11}, e_{12}, e_{13}, e_{14}, e_{15}, e_{16}] \\
  \text{Token 2:} & \ [e_{21}, e_{22}, e_{23}, e_{24}, e_{25}, e_{26}] \\
  \text{Token 3:} & \ [e_{31}, e_{32}, e_{33}, e_{34}, e_{35}, e_{36}]
  \end{align*}
  \]

- **Split Across 3 Heads**:
  - **Head 1** processes dimensions 1 and 2 of each token's embedding:
    - \( Q_1, K_1, V_1 \) from \([e_{11}, e_{12}]\), \([e_{21}, e_{22}]\), \([e_{31}, e_{32}]\)
  - **Head 2** processes dimensions 3 and 4:
    - \( Q_2, K_2, V_2 \) from \([e_{13}, e_{14}]\), \([e_{23}, e_{24}]\), \([e_{33}, e_{34}]\)
  - **Head 3** processes dimensions 5 and 6:
    - \( Q_3, K_3, V_3 \) from \([e_{15}, e_{16}]\), \([e_{25}, e_{26}]\), \([e_{35}, e_{36}]\)

Each head computes attention independently, focusing on different aspects (or 'slices') of the embedding space. This allows the model to integrate diverse perspectives when synthesizing the output at each step.

### GPU Parallelism
The splitting across heads is indeed highly conducive to parallel processing on GPUs. Each head operates independently of the others, which means their computations can be parallelized effectively, leading to significant performance gains especially with high-dimensional embeddings and long sequences.

This setup not only speeds up computations but also enriches the model's capacity to learn varied dependencies across the input data, leveraging multiple representational strategies simultaneously.

query: Wait, are you sure this is right? That the initial token embedding is split across all 3 heads, instead of Q, K, V all using the same input?


---

You've raised a great point for clarification. Let's refine the explanation regarding how the embeddings are processed through the attention mechanism, particularly in the context of dividing the input across multiple heads for computing queries (Q), keys (K), and values (V).

### How Embeddings Are Handled in Attention Heads

In transformer models like GPT-2, each head does **not** process only a specific part of the input token's embedding vector to produce its own Q, K, and V. Rather, the entire embedding vector is used by each head to generate its own Q, K, and V. However, each head transforms the embeddings into different representational subspaces, so while they all start with the same input, the projections they learn and the subsequent focus of their attention can differ significantly.

### Corrected Process:

1. **Input Embedding**: Every token in the sequence is first converted into a dense vector representation, which is the sum of its token embedding and its positional embedding.

2. **Projection in Attention Heads**:
   - Each attention head applies three distinct linear transformations to the entire embedding vector of each token, generating its own set of Q, K, and V:
     - `Q = W_Q * embedding`
     - `K = W_K * embedding`
     - `V = W_V * embedding`
   - These transformations are usually implemented as a single large linear layer (part of `self.c_attn` in your code) that outputs three times the embedding dimension, which is then split into Q, K, and V for each head.

3. **Dimensionality**:
   - If the model’s embedding size \(C\) is 768 and there are 12 heads, each head will output Q, K, and V vectors that are each \(768/12 = 64\) dimensions. Thus, each head transforms the full embedding vector into smaller, 64-dimensional Q, K, and V vectors.

4. **Independent Attention Calculations**:
   - Each head computes attention scores independently based on its Q, K, and V. The scores determine how much each element of the sequence (values) should be considered for each position in the output sequence.
   - This allows the model to focus on different parts of the sequence differently in each head, capturing various types of relationships in the data (e.g., syntactic, semantic).
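
One way to see that "project the full embedding once, then slice into heads" and "each head has its own weights applied to the full embedding" are the same computation is the plain-Python sketch below. The sizes and weight values are made up, biases are omitted, and `matvec` is a hypothetical helper.

```python
def matvec(W, x):
    """y = W @ x for a list-of-rows matrix W."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

n_embd, n_head = 4, 2
head_size = n_embd // n_head

x = [1.0, 2.0, 3.0, 4.0]          # one token's full embedding
W_q = [                           # made-up (n_embd x n_embd) query weights
    [0.1, 0.0, 0.2, 0.0],
    [0.0, 0.3, 0.0, 0.1],
    [0.5, 0.0, 0.0, 0.2],
    [0.0, 0.1, 0.4, 0.0],
]

# Route 1: project once, then slice the result into heads (what the view() call does).
q_full = matvec(W_q, x)
q_heads_sliced = [q_full[h * head_size:(h + 1) * head_size] for h in range(n_head)]

# Route 2: give each head its own block of rows and apply it to the FULL embedding.
q_heads_separate = [
    matvec(W_q[h * head_size:(h + 1) * head_size], x) for h in range(n_head)
]

print(q_heads_sliced)
print(q_heads_separate)
```

Both routes give identical per-head queries: the split happens on the projected output, never on the input embedding.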

### Visualization Example Revisited:

- **Given**: A single token's embedding vector is `[e1, e2, ..., e768]`.
- **In Each Head**: The entire vector is used but transformed differently:
  - **Head 1** produces:
    - `Q1 = W_Q1 * [e1, ..., e768]`
    - `K1 = W_K1 * [e1, ..., e768]`
    - `V1 = W_V1 * [e1, ..., e768]`
  - **Head 2** through **Head 12** similarly use the entire vector but project it using different weights.

By using the whole embedding vector in each head, the model can maintain a comprehensive understanding of each token while still allowing different heads to "specialize" in different types of relationships or features within the data. This design leverages the parallel processing power of GPUs effectively and is fundamental to the model's ability to generate nuanced and contextually appropriate outputs.

Query: then what happens to the separate 12 heads in the end? do they just dot product each other?


---

The outputs from the 12 attention heads in a transformer model like GPT-2 don't simply undergo a dot product with each other. Instead, there's a more structured way of integrating these multiple streams of outputs to enhance the model's ability to capture diverse aspects of the information. Here's what happens step-by-step after each of the attention heads processes its input:

### Combining the Outputs of Attention Heads

1. **Independent Attention Computations**:
   - Each head computes its own queries (Q), keys (K), and values (V).
   - The attention mechanism for each head calculates a set of attention scores based on the dot product of queries and keys, which is then used to weight the values. This results in an output vector from each head that emphasizes different parts of the input sequence according to the head's learned attention pattern.

2. **Concatenation of Head Outputs**:
   - The output vectors from each head are concatenated back together. If each head produces a vector of dimension \( \frac{C}{N} \) (where \( C \) is the embedding size and \( N \) is the number of heads), the concatenation of these vectors restores the original dimensionality \( C \). For example, if \( C = 768 \) and \( N = 12 \), each head outputs a vector of size 64, and concatenating these vectors results in a vector of size 768.

3. **Final Linear Projection**:
   - The concatenated vector is then passed through another learned linear projection (often represented by `self.c_proj` in your code). This step is crucial because it integrates the information from all the heads into a single output vector per token. This final projection layer combines the diverse representations produced by different heads into a coherent output that can be used in subsequent layers or for generating predictions.
   - This projection not only serves to mix the information from different heads but also maps the concatenated output back to the desired embedding dimensionality, preparing it for the next layer or the output layer of the model.
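
The concatenate-then-project step can be sketched for a single token as follows. The head outputs and the projection weights are made up, and the bias of the projection is left out for brevity.

```python
def matvec(W, x):
    """y = W @ x for a list-of-rows matrix W."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

# Per-head outputs for one token: n_head=2, head_size=2 (made-up values).
head_outputs = [[1.0, 2.0], [3.0, 4.0]]

# Concatenate: back to n_embd = n_head * head_size = 4.
concat = [val for head in head_outputs for val in head]

# Output projection (the role c_proj plays); each output mixes values from both heads.
W_proj = [
    [0.5, 0.0, 0.5, 0.0],
    [0.0, 0.5, 0.0, 0.5],
    [1.0, 0.0, -1.0, 0.0],
    [0.0, 1.0, 0.0, -1.0],
]
y = matvec(W_proj, concat)
print(concat)  # [1.0, 2.0, 3.0, 4.0]
print(y)       # [2.0, 3.0, -2.0, -2.0]
```

Concatenation alone keeps the heads side by side; it is the projection that lets information from different heads interact.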

### Why This Process?

This multi-headed approach allows the model to simultaneously focus on different types of relationships in the data—from very local (like specific syntactic ties) to more global (like thematic elements or long-range dependencies). By recombining the outputs, the model can leverage these diverse insights effectively, ensuring that no single aspect dominates unless it is contextually relevant.

This design is one of the reasons transformers are particularly effective for tasks involving complex dependencies within data, such as natural language understanding and generation, where different parts of a sequence can have various types of influence on each other.

### Visualization

Imagine you have a sequence of three words processed by a transformer model with three attention heads. Each head processes the sequence and focuses on different relationships. The outputs from these heads might look something like this for each word:

- **Head 1 Output**: [Focus on next word relationships]
- **Head 2 Output**: [Focus on whole sentence theme]
- **Head 3 Output**: [Focus on grammatical structure]

These vectors are then concatenated and passed through a final projection, synthesizing a comprehensive representation that balances all these perspectives. This processed output is what feeds into subsequent layers or contributes to the final task-specific computations, such as predicting the next word in a sequence.

## Back to coding

In [None]:
class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0 # n_embd must be divisible by n_head so that the projected
                                                  # Q, K, V can be split evenly across the heads.

                                                  # TIL: Q, K, V are computed from the full embedding, then each is split
                                                  # into n_head slices of size n_embd // n_head. Each head attends with
                                                  # its own slice, and the head outputs are concatenated back into one vector.

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 # marks this layer for scaled-down init, since its output is added to the residual stream
        # cache head count and embedding size
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        qkv = self.c_attn(x) # projects into qkv from input
        q, k, v = qkv.split(self.n_embd, dim=2) # splits it into 3 tensors
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        if FLASH:
            # flashattention
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # manual implementation of attention
            # this materializes the large (T,T) matrix for all the queries and keys
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # math.sqrt() used here instead of **-0.5 for the attention(q, k, v) equation
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y
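
The masking step in the manual branch above can be illustrated without PyTorch. The sketch below is a plain-Python analogue of `masked_fill(..., float('-inf'))` followed by `softmax`; `causal_softmax` is a hypothetical helper and the scores are made up.

```python
import math

def causal_softmax(scores):
    """Apply a lower-triangular mask, then softmax each row.
    scores is a T x T list of lists (query index = row, key index = column)."""
    T = len(scores)
    out = []
    for i in range(T):
        # Keys j > i lie in the future, so they get -inf like masked_fill does.
        masked = [scores[i][j] if j <= i else float("-inf") for j in range(T)]
        m = max(masked)
        exps = [math.exp(s - m) for s in masked]  # exp(-inf) == 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

att = causal_softmax([[0.0, 5.0, 5.0],
                      [1.0, 1.0, 9.0],
                      [2.0, 2.0, 2.0]])
for row in att:
    print(row)
```

Even though the raw scores for future positions are large, they receive exactly zero weight, so token `i` only ever attends to tokens `0..i`.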

## Notes

query: Ah yeah, here it is. Any component with "LLMC_RESIDUAL_SCALE_FLAG = 1" is treated like a residual. Does this change anything?

---

Thanks for sharing the weight initialization function. The inclusion of the `LLMC_RESIDUAL_SCALE_FLAG` in the initialization logic provides more context on how these specific layers are treated differently, especially in terms of initialization. Here's how this affects the overall behavior and processing in the model:

### Understanding `LLMC_RESIDUAL_SCALE_FLAG`

In your code, the `LLMC_RESIDUAL_SCALE_FLAG` is checked during the initialization of weights for layers within your model. Here’s how it impacts the initialization process:

1. **Scaled Initialization**: If a layer (such as `self.c_proj` in your MLP module) is flagged with `LLMC_RESIDUAL_SCALE_FLAG`, the standard deviation for initializing weights is adjusted. Specifically, the standard deviation is set to `0.02 / math.sqrt(2 * self.config.n_layer)`, which scales down the variance of the initial weights. This scaled initialization helps in controlling the magnitude of outputs from these layers, which is particularly important for layers involved in residual connections. By scaling down the weights:
   - **Enhances Training Stability**: Smaller weights in layers that are part of a residual pathway can help prevent exploding gradients, especially in deeper networks.
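   - **Depth-Aware Scaling**: The factor `2 * self.config.n_layer` counts the residual additions in the network: each of the `n_layer` blocks adds to the residual stream twice, once after attention and once after the MLP. Dividing the standard deviation by the square root of that count keeps the variance of the accumulated residual stream roughly constant at initialization, regardless of depth.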

2. **Purpose of Scaling**:
   - In deep learning, especially with networks that use residual connections, it’s crucial to maintain a balance in the flow of gradients during backpropagation. Scaled initialization as applied here ensures that the updates to weights are not disproportionately large, which can destabilize learning, especially in the early phases.

3. **Impact on Model Behavior**:
   - This tailored initialization approach is crucial for ensuring that the contributions from the residual connections are effectively integrated without overwhelming the network’s learning process. It subtly tunes the network to leverage the benefits of deep architectures (more layers) while mitigating associated risks like vanishing or exploding gradients.
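
To make the scaling concrete, here is a small sketch of the rule quoted above, evaluated for a few depths. `init_std` is a hypothetical helper, not a function from the notebook; only the formula `0.02 / math.sqrt(2 * n_layer)` comes from the discussion above.

```python
import math

def init_std(n_layer: int, residual_scaled: bool, base_std: float = 0.02) -> float:
    """Std used for a Linear layer's weights, following the rule quoted above."""
    if residual_scaled:
        # 2 * n_layer: each block adds to the residual stream twice (attention, MLP).
        return base_std / math.sqrt(2 * n_layer)
    return base_std

for n_layer in (12, 24, 48):
    print(n_layer, init_std(n_layer, False), round(init_std(n_layer, True), 6))
```

For the 12-layer configuration the flagged projections start at roughly a fifth of the default standard deviation, and the gap widens as the network gets deeper.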

### Broader Implications

`LLMC_RESIDUAL_SCALE_FLAG` is a marker attribute: it has no effect on the forward pass and only tells the weight-initialization function which layers write into the residual stream and should therefore start with smaller weights. In this notebook it is set on `c_proj` in both `CausalSelfAttention` and `MLP`.

### Conclusion

So yes, the flag changes something, but only at initialization: the flagged projections start with a smaller standard deviation, which keeps the residual stream from growing with depth at the start of training. The computation those layers perform is unchanged.

query: is MLP just a different name for feedforward? if not what does it do and what is it for


---

Yes, within the context of neural network architectures, particularly in transformers, the term "MLP" (Multi-Layer Perceptron) is often used to refer to a specific component known as a feedforward neural network. However, it's important to clarify the traditional and specific uses of these terms:

### Traditional Definitions:
- **MLP (Multi-Layer Perceptron)**: Traditionally, an MLP is a type of neural network that consists of multiple layers of perceptrons (basic neural network units) with non-linear activation functions, except for the input nodes. It usually includes one or more hidden layers between the input and output layers. MLPs are capable of learning non-linear models and are used widely across various types of regression and classification problems.

- **Feedforward Neural Network**: This is a broader category of neural networks where connections between the nodes do not form a cycle. This means the information moves in only one direction—forward—from the input nodes, through the hidden nodes (if any), and to the output nodes. There are no cycles or loops in the network. MLPs are a subset of feedforward neural networks.

### In the Context of Transformers:
- **MLP as Feedforward Network**: In transformer models, the term "MLP" is used specifically to describe the feedforward neural network component that follows the self-attention mechanism in each layer of the transformer. This MLP usually has two linear layers:
  - The first linear layer expands the dimensionality of the input, often by a factor (such as 4 times the size of the model’s embedding dimension).
  - This is followed by a non-linear activation function (like ReLU or GELU).
  - The second linear layer then projects the output back to the original embedding dimension.

### Role and Function of MLP in Transformers:
- **Integration and Transformation**: The MLP component in transformers serves to further process each position of the input sequence independently of the others. After the self-attention layer aggregates information across the input sequence, taking into account the relationship between different positions, the MLP is used to integrate and transform this information at each position separately.
  
- **Complexity and Non-linearity**: The expansion and subsequent contraction of dimensions, combined with the non-linear activation, allow the MLP to capture complex patterns in the data. This is crucial for tasks that require understanding intricate relationships and dependencies, as often encountered in natural language processing.

- **Enhancement of Model Capacity**: The inclusion of MLPs in each transformer layer significantly increases the model’s capacity (its ability to learn a wide variety of functions). This is essential for the performance on complex tasks that transformers are often employed for, such as machine translation, text generation, and various other forms of sequence-to-sequence modeling.
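
The "each position independently" property can be checked with a toy version of the expand, activate, contract pipeline. This is a plain-Python sketch with made-up weights and no biases; `mlp_token` and `mlp_sequence` are hypothetical names.

```python
import math

def gelu(x):
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def matvec(W, x):
    """y = W @ x for a list-of-rows matrix W."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

# Toy sizes: n_embd=2, hidden=4*n_embd=8. Weights are made up; biases omitted.
W_fc = [[0.1 * (r + c) for c in range(2)] for r in range(8)]      # 8 x 2 (expand)
W_proj = [[0.05 * (r - c) for c in range(8)] for r in range(2)]   # 2 x 8 (contract)

def mlp_token(x):
    """Expand -> GELU -> contract, for ONE token vector."""
    hidden = [gelu(h) for h in matvec(W_fc, x)]
    return matvec(W_proj, hidden)

def mlp_sequence(xs):
    """The same weights are applied to every position independently."""
    return [mlp_token(x) for x in xs]

seq_a = [[1.0, 2.0], [3.0, 4.0]]
seq_b = [[1.0, 2.0], [-7.0, 0.5]]   # only the second token differs
out_a, out_b = mlp_sequence(seq_a), mlp_sequence(seq_b)
print(out_a[0] == out_b[0])  # first position is unaffected by the change
```

Changing the second token leaves the first token's output untouched, in contrast to attention, where every position can influence the others.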

### Summary:
In summary, while the terms "MLP" and "feedforward neural network" can be used somewhat interchangeably in the context of transformers, an MLP in this setting specifically refers to a two-layer feedforward network with a non-linear activation used to process the output of the self-attention layers within each transformer block. This component plays a crucial role in enhancing the transformer's ability to process and integrate information across the sequence effectively.

query: Oh wait, a question about c_proj, which is tagged as residual. In the feed-forward/MLP block, if there is a residual connection, how is its output different? Does the entire MLP output get the residual treatment, or is it only partially residual while mostly still feeding the next layer?

---

In the context of transformers and MLPs specifically, a residual connection means that the output of the MLP is added back to the input of the MLP block. Layer normalization sits next to this addition: after it in the original Transformer (post-LN), and before the sublayer in GPT-2 and in this notebook's `Block` class (pre-LN). This pattern is a key component of the transformer architecture and has several important implications for how the network processes information and learns during training.

### How Residual Connections Work in MLP within Transformers:

1. **Processing in MLP**:
   - The input \( x \) to the MLP is first transformed by an expansion layer (increases dimensionality), followed by a non-linear activation (like GELU), and finally a contraction layer (reduces dimensionality back to the original size).
   - Let's denote the output of this process as \( \text{MLP}(x) \).

2. **Adding the Residual**:
   - The output of the MLP \( \text{MLP}(x) \) is then added to the original input \( x \). This operation is defined as:
     \[
     y = x + \text{MLP}(x)
     \]
   - Here, \( y \) represents the residual output, which is the sum of the original input and the transformed output.

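3. **Layer Normalization**:
   - Where normalization happens depends on the architecture. The original Transformer normalizes after the addition (post-LN):
     \[
     \text{output} = \text{LayerNorm}(y)
     \]
   - GPT-2, and the `Block` class in this notebook, normalize before the sublayer instead (pre-LN), so the update is \( y = x + \text{MLP}(\text{LayerNorm}(x)) \) and the sum \( y \) is passed on without being normalized. Either way, layer normalization rescales each token's features to zero mean and unit variance (before a learned scale and shift), which helps keep training stable.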

### Impact of Residual Connections:

- **Preservation of Information**: Residual connections help in preserving the information from the input throughout the layers of the network. By adding the input directly to the output, the network can maintain essential features that might be lost due to transformations within the MLP.

- **Easing the Training of Deep Networks**: They allow gradients to flow directly through the network via the added pathways, facilitating the training of very deep networks by combating problems associated with vanishing gradients.

- **Enhancing Learning Capability**: By combining the original information and the new features learned by the MLP, the network can refine its predictions, leading to potentially better performance and more nuanced understanding.

### How the Output is Affected:

- **Entire Output Treated Residually**: In transformers, typically the entire output of the MLP gets the "residual treatment". This means the whole transformed output is added back to the original input, not just a portion of it.

- **Influence on Subsequent Layers**: The residual output is passed to the next transformer block (if any), which applies its own layer normalization before attention; after the last block, the final layer norm `ln_f` normalizes it. This output contains both the original and the newly learned features, affecting how the subsequent layers process the input.

### Summary:
In summary, the residual connection around the MLP of a transformer adds the entire MLP output to the block's own input, so later layers see both the existing information and the newly computed features. This approach is fundamental to the design of modern neural architectures like transformers and is a large part of what makes deep stacks of blocks trainable.
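
The `Block.forward` in this notebook applies LayerNorm before each sublayer and adds the result to the un-normalized input. A minimal plain-Python sketch of that pattern follows; `sublayer` is a made-up stand-in for attention or the MLP, and `layer_norm` leaves out the learned scale and shift.

```python
import math

def layer_norm(x, eps=1e-5):
    """LayerNorm over one vector, without the learned scale and shift."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sublayer(x):
    """Stand-in for attention or the MLP; any function of the right shape works."""
    return [0.1 * v for v in x]

def pre_ln_residual(x):
    """x + f(LayerNorm(x)): the pattern used by Block.forward in this notebook."""
    update = sublayer(layer_norm(x))
    return [xi + ui for xi, ui in zip(x, update)]

x = [1.0, 2.0, 3.0, 4.0]
y = pre_ln_residual(x)
print(y)
```

The whole sublayer output is added to `x`; the original `x` passes through unchanged, which is what gives gradients a direct path through the network.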

## Back to Coding

In [None]:
class MLP(nn.Module): # In transformers, "MLP" and "feed-forward network" name the same component: a two-layer,
                      # position-wise network that expands to 4 * n_embd, applies GELU, then projects back
                      # to n_embd. See the notes above.

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd) # expand to 4x the embedding width
        self.gelu    = NewGELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd) # project back down to n_embd
        self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 # marks this layer for scaled-down init, since its output is added to the residual stream

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x) # the residual add itself happens in Block.forward
        return x

### Purpose and Functionality

This setup is typical of GPT-2 style transformer blocks, where each sublayer (self-attention and MLP) is preceded by layer normalization and wrapped in a residual connection. This architecture helps in several ways:

- **Enhanced Gradient Flow**: Residual connections allow gradients to flow directly through the network during backpropagation, which can significantly improve training efficiency and enable training of very deep networks.

- **Feature Preservation**: By adding the input directly to the output of processing units (attention and MLP), the network can preserve important features that might otherwise be lost or overly transformed.

- **Stable Inputs**: Layer normalization rescales each token's features to zero mean and unit variance (then applies a learned scale and shift) before they enter attention or the MLP, so each sublayer sees inputs on a consistent scale.

Overall, the `Block` class represents a fundamental building block of transformer models, encapsulating the essential mechanisms that make transformers powerful for handling sequential data across various tasks such as language modeling, text generation, and more.

In [None]:
class Block(nn.Module): # implementation of a single transformer layer

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x)) # pre-LN: normalise x, run self-attention, then add the result back onto x (first residual)
        x = x + self.mlp(self.ln_2(x)) # same pattern for the MLP: normalise, transform, add back (second residual)
        return x
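
The pre-norm residual pattern in `Block.forward` can be illustrated without PyTorch. The sketch below is a simplified illustration, not the notebook's implementation: `layer_norm` has no learned scale or shift, and `zero_sublayer` and `half_sublayer` are hypothetical stand-ins for attention or the MLP. It shows that when a sublayer contributes nothing, the block reduces to the identity, which is the path gradients can follow directly.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def pre_norm_residual(x, sublayer):
    """Compute x + sublayer(layer_norm(x)), the pattern used twice in Block.forward."""
    update = sublayer(layer_norm(x))
    return [xi + ui for xi, ui in zip(x, update)]

def zero_sublayer(h):
    # hypothetical sublayer that contributes nothing
    return [0.0 for _ in h]

def half_sublayer(h):
    # hypothetical sublayer that scales its (normalized) input
    return [0.5 * v for v in h]

x = [1.0, 2.0, 4.0]
identity_out = pre_norm_residual(x, zero_sublayer)  # input passes through unchanged
shifted_out = pre_norm_residual(x, half_sublayer)   # input plus a normalized update
print(identity_out)
print(shifted_out)
```

Because the sublayer only ever sees the normalized copy, the size of its update does not depend on the raw scale of `x`.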

In [None]:
# -----------------------------------------------------------------------------
# The main GPT-2 model

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
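
The `n_layer` field also controls the scaled initialization applied in `GPT._init_weights`, where layers carrying `LLMC_RESIDUAL_SCALE_FLAG` use a standard deviation of `0.02 / sqrt(2 * n_layer)`. As a rough worked example (the helper name `residual_init_std` is made up for illustration, and the layer counts are the ones listed in `from_pretrained`), the values for the four model sizes are:

```python
import math

# layer counts as listed in the from_pretrained table of this notebook
n_layers = {"gpt2": 12, "gpt2-medium": 24, "gpt2-large": 36, "gpt2-xl": 48}

def residual_init_std(n_layer, base_std=0.02):
    """Std used for residual projections: base_std / sqrt(2 * n_layer)."""
    # the factor 2 reflects the two residual additions per block (attention and MLP)
    return base_std / math.sqrt(2 * n_layer)

stds = {name: residual_init_std(n) for name, n in n_layers.items()}
for name, std in stds.items():
    print(f"{name:12s} n_layer={n_layers[name]:2d} residual std={std:.6f}")
```

Deeper models get a smaller standard deviation, so the sum of all residual contributions starts out at a similar scale regardless of depth.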

In [None]:
class GPT(nn.Module):

    def __init__(self, config):
        super().__init__() # initialises base class nn.Module
        self.config = config # stores configuration passed into the constructor

        self.transformer = nn.ModuleDict(dict( # storing various components of GPT model. Each component can now be accessed using a key, just like a python dict.
            wte = nn.Embedding(config.vocab_size, config.n_embd), # maps token indices to token embeddings
            wpe = nn.Embedding(config.block_size, config.n_embd), # provides positional information to tokens
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # stack of n_layer transformer blocks, applied in order
            ln_f = nn.LayerNorm(config.n_embd), # final layer normalisation, applied to the output of the last block before lm_head
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # maps n_embd=768 dimension to vocab_size=50257 for generating predictions
                                                                               # bias=False, no bias term here
        self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
        # Weight tying: the language modeling head and the token embedding share one (vocab_size, n_embd) matrix,
        # which reduces the model's parameter count and can improve performance by reusing learned representations.

        # init all weights, use a torch rng object to be very careful
        self.init_rng = torch.Generator() # initialises a RNG to ensure controlled randomness in weight initialisation
        self.init_rng.manual_seed(42)
        self.apply(self._init_weights) # applies _init_weights method to all submodules of the model, customising the weights' initialisation based on whether they are part
                                       # of linear layers, embedding layers or other configurations (handling special flags)

    def _init_weights(self, module): # this method is applied to all submodules of the model, as seen with the method above.
        if isinstance(module, nn.Linear):
            # apply special scaled init to the residual projections, per GPT-2 paper
            std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer)
            # std of the normal distribution used to initialise the weights. Layers carrying LLMC_RESIDUAL_SCALE_FLAG
            # use the reduced value 0.02/math.sqrt(2 * self.config.n_layer). Each block adds two contributions to the
            # residual stream, so this scaling keeps the variance of the accumulated activations from growing with
            # depth, which helps stabilise training (e.g. by mitigating exploding gradients).

            # we want to skip initializing lm_head, which shares parameters with wte
            # and wte was already initialized down below during the Embedding init
            if not hasattr(module, 'LLMC_SKIP_INIT'):
                torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng) # draw weights from N(0, std^2); modules flagged LLMC_SKIP_INIT are left untouched
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias) # if module has bias term, the bias is set to zero.
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng) # normal_ fills the tensor in place with samples from N(mean, std^2)

    def forward(self, idx, targets=None, return_logits=True):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = tok_emb + pos_emb

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        # there are performance reasons why not returning logits is prudent, if not needed
        if not return_logits:
            logits = None

        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Loads pretrained GPT-2 model weights from huggingface"""
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter with 2 or more dimensions will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print0(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print0(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        print0(f"using fused AdamW: {use_fused}")
        if zero_stage == 1:
            print0("using ZeroRedundancyOptimizer")
            optimizer = ZeroRedundancyOptimizer(**optim_groups[0], optimizer_class=torch.optim.AdamW,
                                                lr=learning_rate, betas=betas, fused=use_fused)
            optimizer.add_param_group(optim_groups[1])
        else:
            print0("using regular AdamW")
            optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused)
        return optimizer

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
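
The sampling step inside `generate` (temperature scaling, optional top-k cropping, softmax, then a multinomial draw) can be sketched for a single position in plain Python. This is an illustrative re-implementation under simplifying assumptions, not the notebook's code: `sample_top_k` is a made-up helper that works on a list of floats rather than a batched tensor.

```python
import math
import random

def sample_top_k(logits, temperature=1.0, top_k=None, rng=random):
    """Sample one index from logits, mirroring the steps in GPT.generate."""
    # scale by the desired temperature
    scaled = [l / temperature for l in logits]
    # keep only the top_k largest logits; everything below the k-th value becomes -inf
    if top_k is not None:
        k = min(top_k, len(scaled))
        kth = sorted(scaled, reverse=True)[k - 1]
        scaled = [l if l >= kth else -math.inf for l in scaled]
    # numerically stable softmax; exp(-inf) evaluates to 0.0
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # draw one index according to the probabilities
    idx = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return idx, probs

rng = random.Random(0)
logits = [2.0, 1.0, 0.5, -1.0]
idx, probs = sample_top_k(logits, temperature=0.8, top_k=2, rng=rng)
print(idx, probs)
```

With `top_k=2`, only the two largest logits keep non-zero probability, so the sampled index is always one of them. Lower temperatures sharpen the distribution further towards the largest logit.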

In [None]:
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

In [None]:
def _peek_data_shard(filename):
    # only reads the header, returns header data
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
    if header[0] != 20240520:
        print("ERROR: magic number mismatch in the data .bin file!")
        print("---> HINT: Are you passing in a correct file with --input_bin?")
        print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
        print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
        exit(1)
    assert header[1] == 1, "unsupported version"
    ntok = header[2] # number of tokens (claimed)
    return ntok # for now just return the number of tokens

In [None]:
def _load_data_shard(filename):
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        ntok = header[2] # number of tokens (claimed)
        # the rest of it are tokens, stored as uint16
        tokens = np.frombuffer(f.read(), dtype=np.uint16)
    assert len(tokens) == ntok, "number of tokens read does not match header?"
    return tokens
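
The shard layout read by `_peek_data_shard` and `_load_data_shard` (a header of 256 int32 values followed by uint16 tokens) can be exercised with a small round trip using only the standard library. This is a sketch under stated assumptions: `write_toy_shard` and `read_toy_shard` are hypothetical helpers, and little-endian byte order is assumed, whereas `np.frombuffer` uses the machine's native order.

```python
import os
import struct
import tempfile

HEADER_INTS = 256   # header is 256 int32 values (1024 bytes)
MAGIC = 20240520    # magic number checked by the loaders above
VERSION = 1

def write_toy_shard(path, tokens):
    """Write a header plus uint16 tokens in the layout the loaders expect."""
    header = [0] * HEADER_INTS
    header[0], header[1], header[2] = MAGIC, VERSION, len(tokens)
    with open(path, "wb") as f:
        f.write(struct.pack(f"<{HEADER_INTS}i", *header))
        f.write(struct.pack(f"<{len(tokens)}H", *tokens))

def read_toy_shard(path):
    """Read the shard back, applying the same checks as _load_data_shard."""
    with open(path, "rb") as f:
        header = struct.unpack(f"<{HEADER_INTS}i", f.read(HEADER_INTS * 4))
        assert header[0] == MAGIC, "magic number mismatch"
        assert header[1] == VERSION, "unsupported version"
        ntok = header[2]
        payload = f.read()
    assert len(payload) == 2 * ntok, "number of tokens read does not match header"
    return list(struct.unpack(f"<{ntok}H", payload))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.bin")
    write_toy_shard(path, [50256, 15496, 11, 995])
    roundtrip = read_toy_shard(path)
    size = os.path.getsize(path)

print(roundtrip, size)
```

Storing tokens as uint16 works here because the GPT-2 vocabulary (50257 entries) fits below 65536.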

In [None]:
class DistributedDataLoader:
    def __init__(self, filename_pattern, B, T, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B
        self.T = T

        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"

        # load and validate all data shards, count number of tokens in total
        ntok_total = 0
        for fname in self.files:
            shard_ntok = _peek_data_shard(fname)
            assert shard_ntok >= num_processes * B * T + 1
            ntok_total += shard_ntok
        self.ntok_total = ntok_total
        print0(f"DataLoader: total number of tokens: {ntok_total:,} across {len(self.files)} files")

        # kick things off
        self.current_shard = None
        self.reset()

    def reset(self):
        # we're being a bit clever here: if we already had shard 0 loaded,
        # then don't do the work to reload it, just reset the pointer
        if self.current_shard != 0:
            self.current_shard = 0
            self.tokens = _load_data_shard(self.files[self.current_shard])
        self.current_position = self.process_rank * self.B * self.T

    def advance(self): # advance to next data shard
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        B = self.B
        T = self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance the start pointer in current shard
        self.current_position += B * T * self.num_processes
        # if loading the next batch would be out of bounds advance the shard
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.advance()
        return x, y
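
To see how `DistributedDataLoader` divides one shard among processes, the position arithmetic of `reset` and `next_batch` can be replayed on its own. The helper `rank_windows` below is hypothetical and only tracks index ranges for a single shard. It does not load any data or advance to the next shard.

```python
def rank_windows(ntok, B, T, num_processes, process_rank):
    """Return the (start, end) token ranges one rank reads from a shard of ntok tokens."""
    pos = process_rank * B * T              # starting offset, as in reset()
    windows = []
    while True:
        # each batch reads B*T + 1 tokens: inputs plus one extra token for the shifted targets
        windows.append((pos, pos + B * T + 1))
        pos += B * T * num_processes        # skip over the chunks read by the other ranks
        # same boundary check as next_batch(): stop when the next read would not fit
        if pos + (B * T * num_processes + 1) > ntok:
            break
    return windows

rank0 = rank_windows(ntok=41, B=2, T=4, num_processes=2, process_rank=0)
rank1 = rank_windows(ntok=41, B=2, T=4, num_processes=2, process_rank=1)
print(rank0)
print(rank1)
```

Consecutive windows overlap by exactly one token: the last target of one chunk is the first input of the next.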

In [None]:
# -----------------------------------------------------------------------------
# Python -> C bridge utilities for saving params/grads/activations to .bin files

In [None]:
def write_fp32(tensor, file):
    t = tensor.detach().cpu().to(torch.float32)
    b = t.numpy().tobytes()
    file.write(b)

In [None]:
def write_bf16(tensor, file):
    t = tensor.detach().cpu().to(torch.bfloat16)
    # numpy doesn't have bf16 datatype so we have to trick it
    t = t.view(torch.int16) # trick: reinterpret as int16
    b = t.numpy().tobytes()
    file.write(b)
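
The "trick" in `write_bf16` works because a bfloat16 value is just a 16-bit pattern: the top half of the corresponding float32. The sketch below reproduces that bit pattern with the standard library, assuming round-to-nearest-even conversion and ignoring NaN handling. The helper names are made up for illustration, and the helpers return unsigned integers, whereas the int16 view above is signed (the bytes written are the same).

```python
import struct

def fp32_to_bf16_bits(x):
    """Return the 16-bit bfloat16 pattern for x (round to nearest even, NaN not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # add a rounding bias that depends on the lowest kept bit, then drop the low 16 bits
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bf16_bits_to_fp32(b):
    """Expand a bfloat16 bit pattern back to a Python float by zero-filling the low 16 bits."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

one_bits = fp32_to_bf16_bits(1.0)
approx_pi = bf16_bits_to_fp32(fp32_to_bf16_bits(3.14159))
print(hex(one_bits), approx_pi)
```

bfloat16 keeps the full float32 exponent range but only 8 bits of significand precision, so values such as 3.14159 come back slightly rounded.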

In [None]:
def write_tensors(model_tensors, L, file, dtype):
    # writes the GPT-2 model's weights to a binary file
    assert dtype in {"float32", "bfloat16"}
    write_fun = write_fp32 if dtype == "float32" else write_bf16
    write_fun(model_tensors["transformer.wte.weight"], file) # (V, C)
    write_fun(model_tensors["transformer.wpe.weight"], file) # (T, C)
    for i in range(L): # (L, C)
        write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
    for i in range(L): # (L, C)
        write_fun(model_tensors[f"transformer.h.{i}.ln_1.bias"], file)
    for i in range(L): # (L, 3C, C)
        write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file)
    for i in range(L): # (L, 3C)
        write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.bias"], file)
    for i in range(L): # (L, C, C)
        write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file)
    for i in range(L): # (L, C)
        write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.bias"], file)
    for i in range(L): # (L, C)
        write_fun(model_tensors[f"transformer.h.{i}.ln_2.weight"], file)
    for i in range(L): # (L, C)
        write_fun(model_tensors[f"transformer.h.{i}.ln_2.bias"], file)
    for i in range(L): # (L, 4C, C)
        write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file)
    for i in range(L): # (L, 4C)
        write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.bias"], file)
    for i in range(L): # (L, C, 4C)
        write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file)
    for i in range(L): # (L, C)
        write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.bias"], file)
    write_fun(model_tensors["transformer.ln_f.weight"], file) # (C, )
    write_fun(model_tensors["transformer.ln_f.bias"], file) # (C, )
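
The shape comments in `write_tensors` are enough to compute how many parameters end up in the file. As a sanity check, the sketch below sums those shapes. `gpt2_param_count` is a hypothetical helper, `lm_head` is not counted because it shares its weights with `wte`, and the count uses the unpadded vocabulary size.

```python
def gpt2_param_count(n_layer, n_embd, vocab_size=50257, block_size=1024):
    """Sum the tensor sizes written by write_tensors (before vocab padding)."""
    C = n_embd
    per_layer = (
        2 * C                  # ln_1 weight + bias
        + 3 * C * C + 3 * C    # attn.c_attn weight (3C, C) + bias (3C)
        + C * C + C            # attn.c_proj weight (C, C) + bias (C)
        + 2 * C                # ln_2 weight + bias
        + 4 * C * C + 4 * C    # mlp.c_fc weight (4C, C) + bias (4C)
        + C * 4 * C + C        # mlp.c_proj weight (C, 4C) + bias (C)
    )
    embeddings = vocab_size * C + block_size * C   # wte + wpe
    final_norm = 2 * C                             # ln_f weight + bias
    return embeddings + n_layer * per_layer + final_norm

small = gpt2_param_count(n_layer=12, n_embd=768)
xl = gpt2_param_count(n_layer=48, n_embd=1600)
print(f"gpt2: {small:,}  gpt2-xl: {xl:,}")
```

The results line up with the "124M" and "1558M" labels used for these checkpoints in this notebook.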

In [None]:
@torch.no_grad()
def pad_vocab(tensor, multiple=128, value=0):
    """
    The dimension of the vocab size in GPT-2 is 50,257
    which is unfortunately a very unfriendly number for a lot of
    matrix operations on the GPU. So we pad it to the nearest
    friendlier multiple, e.g. 50,304 if multiple=128 when we
    export the weights into C land. This is a NOOP algorithmically
    and is only done to make the tensor operations more efficient.
    """
    assert tensor.ndim == 2
    V, C = tensor.shape
    assert V == 50257, "just being defensive here"
    # calculate padded vocab size by rounding up to nearest multiple
    Vp = ((V + multiple - 1) // multiple) * multiple
    # pad the tensor
    pad_rows = Vp - V
    padded = tensor if pad_rows == 0 else F.pad(tensor, (0, 0, 0, pad_rows), value=value)
    assert padded.shape == (Vp, C)
    return padded
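
The rounding used in `pad_vocab` is plain integer arithmetic and can be checked in isolation. The helper name `round_up` is made up for this sketch.

```python
def round_up(n, multiple):
    """Round n up to the nearest multiple, as done for the padded vocab size Vp."""
    return ((n + multiple - 1) // multiple) * multiple

V = 50257
Vp = round_up(V, 128)
pad_rows = Vp - V
print(Vp, pad_rows)
```

A value that is already a multiple is returned unchanged, which is why `pad_vocab` has the `pad_rows == 0` branch.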

In [None]:
def write_model(model, filename, dtype):
    # everything we need to instantiate the model
    # 1) header is: version int, GPTConfig ints, padding to 1024 bytes
    assert dtype in {"float32", "bfloat16"} # float16 todo maybe later
    version = {
        "float32": 3, # 3: all tensors are fp32, padded vocab
        "bfloat16": 5, # 5: all tensors are bf16, padded vocab
    }[dtype]
    header = torch.zeros(256, dtype=torch.int32)
    header[0] = 20240326 # magic
    header[1] = version # checkpoint version
    header[2] = model.config.block_size
    header[3] = model.config.vocab_size
    header[4] = model.config.n_layer
    header[5] = model.config.n_head
    header[6] = model.config.n_embd
    # 2) the parameters follow the header
    params = {name: param.cpu() for name, param in model.named_parameters()}
    # pad the vocab to a multiple of 128 here at export, for efficiency in C
    wte = params["transformer.wte.weight"] # (V, C)
    wte_padded = pad_vocab(wte) # (Vp, C)
    params["transformer.wte.weight"] = wte_padded # (Vp, C)
    print(f"padded vocab size from {wte.size(0)} to {wte_padded.size(0)}")
    header[7] = wte_padded.size(0) # padded vocab size store in header
    # now write to file
    with open(filename, "wb") as file:
        file.write(header.numpy().tobytes()) # header
        write_tensors(params, model.config.n_layer, file, dtype) # params
    print(f"wrote {filename}")

In [None]:
def write_state(model, x, y, logits, loss, filename):
    # the state is used for debugging.
    # it contains information about the input, logits, loss, and the parameter gradients
    # this can be used for checking the computation correctness in C
    header = torch.zeros(256, dtype=torch.int32)
    header[0] = 20240327 # magic
    header[1] = 2 # run state version = 2 (1 -> 2 for padded vocab changes)
    header[2] = x.size(0) # batch size of the batch, B
    header[3] = x.size(1) # temporal extent of the batch, T
    grads = {name: param.grad.cpu() for name, param in model.named_parameters()}
    # pad the vocab grads here as well, to mirror write_model
    wte_grad = grads["transformer.wte.weight"] # (V, C)
    wte_grad_padded = pad_vocab(wte_grad, value=0) # (Vp, C) # TODO later maybe pad with nan?
    grads["transformer.wte.weight"] = wte_grad_padded # (Vp, C)
    print(f"padded vocab size in reference grads from {wte_grad.size(0)} to {wte_grad_padded.size(0)}")
    with open(filename, "wb") as file:
        # header
        file.write(header.numpy().tobytes())
        # input x
        file.write(x.cpu().numpy().astype("int32").tobytes()) # (B, T)
        # targets y
        file.write(y.cpu().numpy().astype("int32").tobytes()) # (B, T)
        # logits (result of the model forward pass)
        write_fp32(logits.cpu(), file)
        # loss (single float, result of the cross entropy loss)
        write_fp32(loss.cpu(), file)
        # gradients
        write_tensors(grads, model.config.n_layer, file, "float32")
    print(f"wrote {filename}")

In [None]:
def write_tokenizer(enc, filename):
    n = enc.max_token_value + 1
    header = torch.zeros(256, dtype=torch.int32)
    header[0] = 20240328 # magic
    header[1] = 2 # tokenizer version = 2 (1 -> 2: includes EOT token)
    header[2] = n # number of tokens
    header[3] = enc.eot_token # EOT token
    with open(filename, "wb") as file:
        file.write(header.numpy().tobytes())
        for i in range(n):
            b = enc.decode_bytes([i])
            length = len(b)
            assert length < 256, f"Token length exceeds 255: {length}"
            file.write(struct.pack("<B", length))  # Write the length as a 1-byte unsigned integer
            file.write(b)  # Write the actual bytes
    print(f"wrote {filename}")
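
The tokenizer file written above is a header of 256 int32 values followed by one record per token: a 1-byte length and then the raw bytes. A reader for that layout might look like the sketch below. Both helpers are hypothetical, use a toy four-entry vocabulary instead of tiktoken, and assume little-endian byte order.

```python
import io
import struct

def write_toy_tokenizer(stream, tokens, eot_token):
    """Write the same layout as write_tokenizer, for a list of bytes objects."""
    header = [0] * 256
    header[0] = 20240328      # magic
    header[1] = 2             # tokenizer version (includes EOT token)
    header[2] = len(tokens)   # number of tokens
    header[3] = eot_token     # EOT token id
    stream.write(struct.pack("<256i", *header))
    for b in tokens:
        assert len(b) < 256, "token length must fit in one byte"
        stream.write(struct.pack("<B", len(b)))  # 1-byte length prefix
        stream.write(b)                          # raw token bytes

def read_toy_tokenizer(stream):
    """Parse the header, then read n length-prefixed byte strings."""
    header = struct.unpack("<256i", stream.read(256 * 4))
    assert header[0] == 20240328, "magic number mismatch"
    assert header[1] == 2, "unsupported tokenizer version"
    n, eot = header[2], header[3]
    tokens = []
    for _ in range(n):
        (length,) = struct.unpack("<B", stream.read(1))
        tokens.append(stream.read(length))
    return tokens, eot

buf = io.BytesIO()
vocab = [b"hello", b" world", b"", b"<|endoftext|>"]
write_toy_tokenizer(buf, vocab, eot_token=3)
buf.seek(0)
decoded, eot = read_toy_tokenizer(buf)
print(decoded, eot)
```

The length prefix is what makes the format self-delimiting, since token byte strings may contain any byte value, including zero.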

In [None]:
# -----------------------------------------------------------------------------
# int main

In [None]:
def print0(*args, **kwargs):
    # modified print that only prints from the master process
    # if this is not a distributed run, it's just a print
    if int(os.environ.get("RANK", 0)) == 0:
        print(*args, **kwargs)

In [None]:
if __name__ == "__main__":
    import time
    import argparse
    import tiktoken
    print0(f"Running pytorch {torch.version.__version__}")

    # default settings will overfit a tiny batch of data
    # and save model weights and debug state to disk on the first iteration
    parser = argparse.ArgumentParser()
    # file system input / output
    parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on")
    parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on")
    parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
    parser.add_argument("--model", type=str, default="gpt2", help="gpt2|gpt2-medium|gpt2-large|gpt2-xl|d12|d24|d36|d48")
    # token layout for each step of the optimization
    parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
    parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
    parser.add_argument("--total_batch_size", type=int, default=256, help="total desired batch size, in units of #tokens")
    # workload (number of steps)
    parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
    parser.add_argument("--inference_only", type=int, default=0, help="only run inference")
    # optimization
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="learning rate")
    parser.add_argument("--warmup_iters", type=int, default=0, help="learning rate warmup iterations")
    parser.add_argument("--learning_rate_decay_frac", type=float, default=1.0, help="final learning rate, as a fraction of the initial learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="maximum gradient magnitude")
    # evaluation
    parser.add_argument("--val_loss_every", type=int, default=0, help="every how many steps to evaluate val loss?")
    parser.add_argument("--val_max_steps", type=int, default=20, help="how many batches of val to average?")
    parser.add_argument("--sample_every", type=int, default=0, help="how often to sample from the model?")
    # debugging
    parser.add_argument("--overfit_single_batch", type=int, default=1, help="overfit just one batch of data")
    # numerics
    parser.add_argument("--tensorcores", type=int, default=0, help="use tensorcores")
    # memory management
    parser.add_argument("--device", type=str, default="", help="by default we autodetect, or set it here")
    parser.add_argument("--compile", type=int, default=0, help="torch.compile the model")
    parser.add_argument("--flash", type=int, default=0, help="use flash attention")
    parser.add_argument("--dtype", type=str, default="float32", help="float32|float16|bfloat16")
    parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)")
    # python -> C bridge
    parser.add_argument("--write_tensors", type=int, default=1, help="write tensors to disk")
    args = parser.parse_args()

    # args error checking and convenience variables
    B, T = args.batch_size, args.sequence_length
    assert 1 <= T <= 1024
    assert args.dtype in {"float32", "float16", "bfloat16"}
    assert args.model in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl", "d12", "d24", "d36", "d48"}

    # set up DDP (distributed data parallel). torchrun sets this env variable
    ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
    if ddp:
        # use of DDP atm demands CUDA, we set the device appropriately according to rank
        assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
        init_process_group(backend='nccl')
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        device = f'cuda:{ddp_local_rank}'
        torch.cuda.set_device(device)
        master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
        seed_offset = 0 # each process gets the exact same seed
        zero_stage = args.zero_stage
    else:
        ddp_rank = 0
        ddp_local_rank = 0
        zero_stage = 0
        ddp_world_size = 1
        master_process = True
        seed_offset = 0
        # select the device
        if args.device:
            # provided explicitly by the user
            device = args.device
        else:
            # attempt to autodetect the device
            device = "cpu"
            if torch.cuda.is_available():
                device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                device = "mps"
    print(f"using device: {device}")
    device_type = 'cuda' if 'cuda' in device else 'cpu'

    # calculate gradient accumulation from the desired total batch size and the current run configuration
    tokens_per_fwdbwd = B * T * ddp_world_size
    assert args.total_batch_size % tokens_per_fwdbwd == 0
    grad_accum_steps = args.total_batch_size // tokens_per_fwdbwd
    print0(f"total desired batch size: {args.total_batch_size}")
    print0(f"=> calculated gradient accumulation steps: {grad_accum_steps}")

    # set up a context manager following the desired dtype and device
    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

    # rng / reproducibility
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    # set the torch precision mode to use TensorFloat32 (TF32) for matmuls
    # docs https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
    if args.tensorcores:
        torch.set_float32_matmul_precision('high')

    # turn on/off flash attention
    assert args.flash in {0, 1}
    FLASH = args.flash

    # init (and write) the tokenizer
    enc = tiktoken.get_encoding("gpt2")
    if master_process and args.write_tensors: # tokenizer is technically not tensors but ok
        write_tokenizer(enc, "gpt2_tokenizer.bin")

    # init the model, either from scratch or from OpenAI pretrained checkpoint
    if args.model[0] == "d":
        # from scratch (random weights)
        model_config = {
            "d12": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768),
            "d24": GPTConfig(block_size=1024, vocab_size=50257, n_layer=24, n_head=16, n_embd=1024),
            "d36": GPTConfig(block_size=1024, vocab_size=50257, n_layer=36, n_head=20, n_embd=1280),
            "d48": GPTConfig(block_size=1024, vocab_size=50257, n_layer=48, n_head=25, n_embd=1600),
        }[args.model]
        model = GPT(model_config)
    else:
        # load the GPT-2 model weights
        model = GPT.from_pretrained(args.model)
    model.train()
    model.to(device)
    if args.compile:
        if hasattr(config, "coordinate_descent_tuning"):
            config.coordinate_descent_tuning = True # suggested by @Chillee
        print0("compiling the model...")
        model = torch.compile(model)

    # -------------------------------------------------------------------------
    # Our own version of a simple DistributedDataLoader

    # load tokens
    train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
    val_loader = None
    if args.input_val_bin:
        val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)

    # -------------------------------------------------------------------------
    # PyTorch -> C bridge: save some weights and state for C to load later as reference

    # do one forward pass to generate ground truth for our C tests
    if master_process and args.write_tensors and (not args.inference_only):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        logits, loss = model(x, y)
        loss.backward()
        # save model params, in both float32 and bfloat16
        model_to_size = {"gpt2": "124M", "gpt2-medium": "355M", "gpt2-large": "774M", "gpt2-xl": "1558M"}
        model_to_size.update({f"d{d}": f"d{d}" for d in [12, 24, 36, 48]})
        model_size_str = model_to_size[args.model] # e.g. "124M", or "d12"
        write_model(model, f"gpt2_{model_size_str}.bin", dtype="float32")
        write_model(model, f"gpt2_{model_size_str}_bf16.bin", dtype="bfloat16")
        # save x, y, logits, loss, and parameter gradients, for debugging C
        # always store these in fp32 to have an accurate reference (?)
        write_state(model, x, y, logits, loss, f"gpt2_{model_size_str}_debug_state.bin")
        # reset the train_loader for the optimization below
        train_loader.reset()

    # -------------------------------------------------------------------------
    # main training loop

    # here we wrap model into DDP container
    if ddp:
        model = DDP(model, device_ids=[ddp_local_rank])
    raw_model = model.module if ddp else model # always contains the "raw" unwrapped model

    # init the optimizer
    optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay,
                                               learning_rate=args.learning_rate, betas=(0.9, 0.95),
                                               device_type=device, zero_stage=zero_stage)

    # learning rate decay scheduler (cosine with warmup)
    def get_lr(it):
        min_lr = args.learning_rate * args.learning_rate_decay_frac
        # 1) linear warmup for warmup_iters steps
        if it < args.warmup_iters:
            return args.learning_rate * (it+1) / args.warmup_iters
        # 2) if it > num_iterations, return min learning rate
        if it > args.num_iterations:
            return min_lr
        # 3) in between, use cosine decay down to min learning rate
        decay_ratio = (it - args.warmup_iters) / (args.num_iterations - args.warmup_iters)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
        return min_lr + coeff * (args.learning_rate - min_lr)

    # create the logging directory if it does not exist
    logfile = None
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
        logfile = os.path.join(args.output_dir, "main.log")
        # create the log file "main.log" inside it, and wipe it clean
        with open(logfile, "w") as f:
            pass

    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
    timings = []
    norm = -1.0   # dummy value to print in inference-only mode
    for step in range(args.num_iterations + 1):
        t0 = time.time()
        last_step = (step == args.num_iterations)

        # once in a while evaluate the validation dataset
        if (args.val_loss_every > 0 \
            and (step % args.val_loss_every == 0 or last_step)) \
            and (val_loader is not None):
            model.eval()
            val_loader.reset()
            with torch.no_grad():
                val_loss = 0.0
                for _ in range(args.val_max_steps):
                    x, y = val_loader.next_batch()
                    x, y = x.to(device), y.to(device)
                    _, loss = model(x, y, return_logits=False)
                    val_loss += loss.item()
                val_loss /= args.val_max_steps
            # log to console and to file
            print0(f"val loss {val_loss}")
            if master_process and logfile is not None:
                with open(logfile, "a") as f:
                    f.write("s:%d tel:%f\n" % (step, val_loss))

        # once in a while perform model inference on the master process
        if (args.sample_every > 0 \
            and (step % args.sample_every == 0 or last_step)) \
            and master_process:
            model.eval()
            # generate a short sample so we can eyeball the model's progress
            # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence
            start_ids = [enc.eot_token]
            xg = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
            max_new_tokens = 32
            temperature = 1.0
            top_k = 40
            yg = raw_model.generate(xg, max_new_tokens, temperature=temperature, top_k=top_k)
            print0('---------------')
            print0(enc.decode(yg[0].tolist()))
            print0('---------------')

        # bit confusing: we want to make sure to eval and sample on 0th iteration
        # but also after the very last iteration. so we loop for step <= num_iterations
        # instead of just < num_iterations (one extra due to <=), only to do
        # the validation/sampling one last time, and then we break right here as we're done.
        if last_step:
            break

        # --------------- TRAINING SECTION BEGIN -----------------
        model.train()
        optimizer.zero_grad(set_to_none=True)
        # if we are trying to overfit a single batch, we reset the loader here
        if args.overfit_single_batch:
            train_loader.reset()
        # micro-batch loop where we do gradient accumulation to reach desired total batch size
        lossf = 0.0 # for getting the mean loss (as simple float) over the accumulation steps
        for micro_step in range(grad_accum_steps):
            # fetch a batch
            x, y = train_loader.next_batch()
            x, y = x.to(device), y.to(device)
            if ddp:
                # we want only the last micro-step to sync grads in a DDP model
                # the official way to do this is with model.no_sync(), but that is a
                # context manager that bloats the code, so we just toggle this variable
                model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
            # forward pass
            with ctx:
                _, loss = model(x, y, return_logits=False)
                # we have to scale the loss to account for gradient accumulation,
                # because the gradients just add on each successive backward().
                # addition of gradients corresponds to a SUM in the objective, but
                # instead of a SUM we want MEAN, so we scale the loss here
                loss = loss / grad_accum_steps
                lossf += loss.detach() # keep track of the mean loss
            # backward pass
            if not args.inference_only:
                loss.backward()
        if ddp:
            dist.all_reduce(lossf, op=dist.ReduceOp.AVG)
        lossf = lossf.item()
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        # determine and set the learning rate for this iteration
        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        # step the optimizer
        optimizer.step()
        # --------------- TRAINING SECTION END -------------------
        # everything that follows now is just diagnostics, prints, logging, etc.

        # wait on the CPU for all device work to end so we get accurate per-iteration timings below
        if device == "mps":
            torch.mps.synchronize()
        elif device == "cuda":
            torch.cuda.synchronize()
        # time and print
        t1 = time.time()
        # note: the 0th iteration is often an outlier (much slower); it is still printed here,
        # but it is excluded from the timing average collected below
        tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1-t0)
        print0(f"step {step+1:4d}/{args.num_iterations} | train loss {lossf:.6f} | norm {norm:.4f} | lr {lr:.2e} | ({(t1-t0)*1000:.2f} ms | {tokens_per_second:.0f} tok/s)")
        # log to logfile
        if master_process and logfile is not None:
            with open(logfile, "a") as f:
                f.write("s:%d trl:%f\n" % (step, lossf))

        # keep track of smooth timings, last 20 iterations
        if step > 0 and step > args.num_iterations - 20:
            timings.append(t1-t0)

    # print the average of the last 20 timings, to get something smooth-ish
    timings = timings[-20:]
    print0(f"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms")
    print0(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")

    # -------------------------------------------------------------------------
    # clean up nice
    if ddp:
        destroy_process_group()
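
The `get_lr` schedule defined in the training script above (linear warmup, then cosine decay to a floor) can be checked in isolation. The sketch below copies its logic into a standalone function; the name `cosine_lr` and all default values are illustrative assumptions, not settings taken from the script.

```python
import math

def cosine_lr(it, max_lr=6e-4, decay_frac=0.1, warmup_iters=10, num_iterations=100):
    """Warmup + cosine schedule; all default values are illustrative assumptions."""
    min_lr = max_lr * decay_frac
    # 1) linear warmup: ramps up to max_lr over the first warmup_iters steps
    if it < warmup_iters:
        return max_lr * (it + 1) / warmup_iters
    # 2) past the end of training: stay at the floor
    if it > num_iterations:
        return min_lr
    # 3) cosine decay from max_lr down to min_lr
    decay_ratio = (it - warmup_iters) / (num_iterations - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)

# print the learning rate at a few representative steps
for it in (0, 9, 10, 55, 100):
    print(f"step {it:3d}: lr = {cosine_lr(it):.2e}")
```

With these assumed values the rate reaches `max_lr` at the last warmup step, sits halfway between `max_lr` and `min_lr` at the midpoint of the decay phase, and ends at `min_lr`.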

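The comment in the micro-batch loop explains why the loss is divided by `grad_accum_steps`: successive `backward()` calls add gradients, which corresponds to a sum of losses, whereas the target is their mean. The following pure-Python sketch illustrates this on a toy one-parameter model with a hand-derived gradient. The model, the data, and the helper `grad_mse` are hypothetical and chosen only for illustration. Because differentiation is linear, scaling each gradient by `1 / grad_accum_steps` has the same effect as scaling the loss.

```python
def grad_mse(w, xs, ys):
    """Gradient w.r.t. w of the mean squared error for the toy model y = w * x."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# reference: gradient of the mean loss over the full batch of 4 examples
full_grad = grad_mse(w, xs, ys)

# accumulation: 2 equally sized micro-batches of 2 examples each
grad_accum_steps = 2
micro = len(xs) // grad_accum_steps
accum_scaled = 0.0    # each micro-batch gradient divided by grad_accum_steps
accum_unscaled = 0.0  # plain sum of micro-batch gradients
for micro_step in range(grad_accum_steps):
    lo, hi = micro_step * micro, (micro_step + 1) * micro
    g = grad_mse(w, xs[lo:hi], ys[lo:hi])
    accum_scaled += g / grad_accum_steps
    accum_unscaled += g

print("full batch:      ", full_grad)
print("scaled accum:    ", accum_scaled)
print("unscaled accum:  ", accum_unscaled)
```

The scaled accumulation matches the full-batch gradient, while the unscaled sum is larger by a factor of `grad_accum_steps`. The equality relies on all micro-batches having the same size.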
# Below is PyTorch's official implementation of LayerNorm
`Module` is imported with ```from .module import Module```

In [None]:
class LayerNorm(Module):
    r"""Applies Layer Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
    is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
    the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which applies
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.
        bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
            :attr:`elementwise_affine` is ``True``). Default: ``True``.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
            The values are initialized to 1.
        bias:   the learnable bias of the module of shape
                :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
                The values are initialized to 0.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> # NLP Example
        >>> batch, sentence_length, embedding_dim = 20, 5, 10
        >>> embedding = torch.randn(batch, sentence_length, embedding_dim)
        >>> layer_norm = nn.LayerNorm(embedding_dim)
        >>> # Activate module
        >>> layer_norm(embedding)
        >>>
        >>> # Image Example
        >>> N, C, H, W = 20, 5, 10, 10
        >>> input = torch.randn(N, C, H, W)
        >>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
        >>> # as shown in the image below
        >>> layer_norm = nn.LayerNorm([C, H, W])
        >>> output = layer_norm(input)

    .. image:: ../_static/img/nn/layer_norm.jpg
        :scale: 50 %

    """

    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            if bias:
                self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            else:
                self.register_parameter('bias', None)
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            init.ones_(self.weight)
            if self.bias is not None:
                init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)

Below is the raw implementation of torch.nn.functional's ```F.layer_norm```

In [None]:
def layer_norm(input: Tensor, normalized_shape: List[int], weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5) -> Tensor:
    r"""Apply Layer Normalization for last certain number of dimensions.

    See :class:`~torch.nn.LayerNorm` for details.
    """
    if has_torch_function_variadic(input, weight, bias):
        return handle_torch_function(layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps)

    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
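
`F.layer_norm` hands the actual computation to `torch.layer_norm`, so the arithmetic is not visible in the Python source above. The formula from the `LayerNorm` docstring can be sketched in pure Python for a single vector, i.e. normalization over one last dimension. The helper `layer_norm_1d` is a hypothetical illustration of the math, not the code path PyTorch executes.

```python
import math

def layer_norm_1d(x, weight=None, bias=None, eps=1e-5):
    """Layer normalization of a list of floats over its only dimension."""
    n = len(x)
    mean = sum(x) / n
    # biased variance estimator: divide by n, not n - 1
    var = sum((v - mean) ** 2 for v in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    y = [(v - mean) * inv_std for v in x]
    # optional per-element affine transform (gamma and beta in the docstring)
    if weight is not None:
        y = [yi * w for yi, w in zip(y, weight)]
    if bias is not None:
        y = [yi + b for yi, b in zip(y, bias)]
    return y

x = [1.0, 2.0, 3.0, 4.0]
out = layer_norm_1d(x)
out_affine = layer_norm_1d(x, weight=[2.0] * 4, bias=[1.0] * 4)
print(out)
print(out_affine)
```

Without the affine parameters the output has zero mean and a variance slightly below one, since `eps` is added inside the square root. With `weight` and `bias` each normalized element is scaled and shifted independently.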