<a href="https://colab.research.google.com/github/Suraj-Sedai/kv-cache-transformer/blob/main/experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MiniGPT-Inference [KV-CACHE TRANSFORMER]

**A From-Scratch Transformer Inference Engine**

MiniGPT-Inference is a from-scratch, production-grade Transformer inference engine designed to execute autoregressive decoding efficiently using Key-Value (KV) caching, incremental decoding, and batched generation.

Unlike training-focused implementations, this project centers on inference-time systems engineering, emphasizing:
- Computational complexity reduction
- Memory efficiency
- Deterministic correctness
- Measurable performance gains

The system is architected to reflect how modern large language models (LLMs) are served in real-world environments.

# Project Setup: Model Architecture Configuration

This section outlines the foundational configuration for our model. The `ModelConfig` dataclass is used to define key architectural hyperparameters, centralizing them for clarity, reusability, and ease of modification.

The parameters included in `ModelConfig` are typically found in transformer-based models and include:
*   `vocab_size`: The size of the vocabulary, representing the number of unique tokens the model can process.
*   `n_layers`: The number of transformer layers or blocks within the model's architecture.
*   `n_heads`: The number of attention heads used in the multi-head attention mechanism within each transformer layer.
*   `d_model`: The dimensionality of the model's embeddings and internal representations.
*   `block_size`: The maximum sequence length or context window that the model can process at once.
*   `dropout`: The dropout rate applied for regularization to prevent overfitting.

Declaring the dataclass with `frozen=True` makes a configuration immutable once created, which helps prevent accidental changes to the model's blueprint during its lifecycle. The derived `head_dim` property returns `d_model // n_heads`, and `__post_init__` asserts that `d_model` is divisible by `n_heads` so that this division is exact.

In [None]:
from dataclasses import dataclass

@dataclass(frozen=True)  # prevents accidental mutation
class ModelConfig:
    vocab_size: int
    n_layers: int
    n_heads: int
    d_model: int
    block_size: int
    dropout: float = 0.0

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_heads

    def __post_init__(self):
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"
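
As a quick check of the behaviour described above, the following sketch repeats the `ModelConfig` definition so that it runs on its own, then exercises the frozen-instance guard and the divisibility assertion. The hyperparameter values are hypothetical and chosen only for illustration.

```python
from dataclasses import dataclass, FrozenInstanceError


@dataclass(frozen=True)
class ModelConfig:
    vocab_size: int
    n_layers: int
    n_heads: int
    d_model: int
    block_size: int
    dropout: float = 0.0

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_heads

    def __post_init__(self):
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"


# Hypothetical values, for illustration only.
config = ModelConfig(vocab_size=100, n_layers=2, n_heads=4, d_model=32, block_size=16)
print(config.head_dim)  # 8

# Assigning to a field of a frozen dataclass raises FrozenInstanceError.
mutation_blocked = False
try:
    config.d_model = 64
except FrozenInstanceError:
    mutation_blocked = True

# 32 is not divisible by 3, so __post_init__ rejects this configuration.
invalid_rejected = False
try:
    ModelConfig(vocab_size=100, n_layers=2, n_heads=3, d_model=32, block_size=16)
except AssertionError:
    invalid_rejected = True
```

Note that `assert` statements are skipped when Python runs with the `-O` flag, so a stricter variant might raise `ValueError` instead.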

### Embedding Layers: Token and Positional

Transformer models rely on embedding layers to convert discrete input tokens into continuous vector representations, capturing both semantic meaning and sequential order.

#### `TokenEmbedding`

This layer converts numerical token IDs into dense vectors. Each unique token in the vocabulary is mapped to a `d_model`-dimensional vector, allowing the model to process linguistic information. This is achieved using `torch.nn.Embedding`, where `vocab_size` determines the number of unique tokens and `d_model` is the dimensionality of the embedding vectors.

#### `PositionalEmbedding`

Since Transformers process sequences in parallel and lack an inherent understanding of token order, positional embeddings are crucial. This layer provides a learned vector for each position within the input sequence up to `block_size`. These positional vectors are added to the token embeddings, injecting information about the absolute position of each token in the sequence. Like token embeddings, it uses `torch.nn.Embedding` to map position IDs to `d_model`-dimensional vectors.

**Key Concept:** The input to the first Transformer block is the sum of each token's embedding and its positional embedding. This combined representation lets the model distinguish identical tokens that appear at different positions.

In [None]:
import torch
import torch.nn as nn

# Assuming ModelConfig is defined in model.config or already imported
# from model.config import ModelConfig # Uncomment if not already imported

class TokenEmbedding(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.d_model
        )

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        token_ids: (B, T)
        returns:   (B, T, D)
        """
        return self.embedding(token_ids)


class PositionalEmbedding(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=config.block_size,
            embedding_dim=config.d_model
        )

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        """
        position_ids: (T,) or (B, T)
        returns:      (T, D) or (B, T, D); a (T, D) result broadcasts
                      over the batch when added to token embeddings
        """
        return self.embedding(position_ids)
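
To illustrate how the two embeddings combine, here is a minimal sketch that uses `nn.Embedding` directly rather than the wrapper classes above. The sizes and token IDs are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, block_size, d_model = 50, 8, 16  # hypothetical sizes

tok_emb = nn.Embedding(vocab_size, d_model)
pos_emb = nn.Embedding(block_size, d_model)

token_ids = torch.tensor([[5, 7, 5]])            # (B=1, T=3); token 5 appears twice
position_ids = torch.arange(token_ids.size(1))   # (T,) -> positions 0, 1, 2

# (1, 3, 16) + (3, 16): the positional term broadcasts over the batch dimension.
x = tok_emb(token_ids) + pos_emb(position_ids)

# The two occurrences of token 5 share a token embedding...
same_token_vectors_match = torch.equal(tok_emb(token_ids)[0, 0], tok_emb(token_ids)[0, 2])
# ...but their combined representations differ because the positions differ.
combined_differ = not torch.equal(x[0, 0], x[0, 2])
```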

### Scaled Dot-Product Attention

Scaled Dot-Product Attention is a fundamental component of the Transformer architecture, designed to efficiently compute attention weights. It takes three inputs: a query matrix (Q), a key matrix (K), and a value matrix (V). The core idea is to calculate a similarity score between the queries and keys, scale these scores, and then use them to weigh the values.

**Description:**

1.  **Similarity Calculation:** The attention scores are computed by taking the dot product of the query and key matrices. This measures how relevant each key is to each query.
2.  **Scaling:** The scores are then divided by the square root of the dimension of the keys (`d_k`). This scaling factor is crucial for preventing the dot products from becoming too large, especially with high `d_k` values, which can push the softmax function into regions with extremely small gradients, hindering training.
3.  **Masking (Optional):** If a mask is provided, typically for causality (to prevent attention to future tokens during sequence generation) or padding (to ignore non-existent tokens), the scores at masked positions are set to negative infinity (`-inf`). After the softmax operation, these positions receive an attention weight of exactly zero.
4.  **Softmax:** A softmax function is applied to the scaled scores to obtain attention weights. This normalizes the scores such that they sum to 1, representing a probability distribution over the values.
5.  **Weighted Sum:** Finally, these attention weights are multiplied by the value matrix (V). This creates a weighted sum of the values, where the weight assigned to each value is determined by its relevance to the query.
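
The effect of the scaling step can be checked numerically. The sketch below assumes query and key components drawn from a standard normal distribution and a hypothetical `d_k` of 512; under that assumption the raw dot products have a standard deviation near `sqrt(d_k)`, while the scaled ones stay near 1.

```python
import math
import torch

torch.manual_seed(0)
d_k = 512            # hypothetical key dimension
n_samples = 10_000

q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

raw = (q * k).sum(dim=-1)        # one dot product per (q, k) pair
scaled = raw / math.sqrt(d_k)    # the scaling used by the attention formula

raw_std = raw.std().item()
scaled_std = scaled.std().item()
print(f"raw std: {raw_std:.1f}, scaled std: {scaled_std:.2f}")
```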

**Mathematical Formula:**

The Scaled Dot-Product Attention mechanism is mathematically expressed as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
-   $Q$ is the Query matrix.
-   $K$ is the Key matrix.
-   $V$ is the Value matrix.
-   $d_k$ is the dimension of the key vectors.

In [None]:
import torch
import torch.nn as nn
import math


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        # Single-head usage: q, k, v are (B, T, D), so d_k equals d_model here.
        self.scale = 1.0 / math.sqrt(d_model)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        q, k, v: (B, T, D)
        mask:    (T, T) or (B, T, T)
        return:  (B, T, D)
        """

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(weights, v)

        return output
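
The following sketch applies the same computation with a lower-triangular causal mask and also returns the attention weights so they can be inspected. The function name and the tensor sizes are illustrative, not part of the module above.

```python
import math
import torch


def scaled_dot_product_attention(q, k, v, mask=None):
    """Functional sketch; returns both the output and the attention weights."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
B, T, D = 1, 4, 8  # hypothetical sizes
q, k, v = torch.randn(B, T, D), torch.randn(B, T, D), torch.randn(B, T, D)

causal = torch.tril(torch.ones(T, T))  # 1 = may attend, 0 = masked
out, weights = scaled_dot_product_attention(q, k, v, causal)

# Each row of weights sums to 1, and entries above the diagonal are exactly 0.
print(weights[0])
```

Because the first position can only attend to itself, its output row equals its own value vector.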


### Causal Self-Attention

`CausalSelfAttention` is a crucial component in transformer-based autoregressive models, such as GPT (Generative Pre-trained Transformer). It extends the concept of `ScaledDotProductAttention` by ensuring that during sequence generation, each token can only attend to previous tokens and itself, not future tokens. This is vital for tasks like language modeling where predicting the next word depends only on the words that have already occurred.

**Description:**

1.  **Initialization (`__init__`)**:
    *   It takes a `ModelConfig` object, which defines parameters like the number of attention heads (`n_heads`), the dimensionality of each head (`head_dim`), and the model's total dimension (`d_model`).
    *   `self.scale`: A scaling factor `1 / sqrt(head_dim)` is calculated, which is standard for scaled dot-product attention to prevent large dot products from pushing the softmax into regions with tiny gradients.
    *   `self.qkv_proj`: A linear projection layer that transforms the input `x` (with shape `(B, T, D)`) into Query (Q), Key (K), and Value (V) matrices. It outputs `3 * d_model` dimensions, which are then split into `d_model` for Q, K, and V respectively.
    *   `self.out_proj`: Another linear projection layer that takes the concatenated output from all attention heads and projects it back to the original `d_model` dimension.
    *   `self.causal_mask`: A lower triangular matrix (e.g., `[[1,0,0],[1,1,0],[1,1,1]]`) is created. This mask is used to block attention to future tokens. It's registered as a buffer, meaning it's part of the model's state but not a trainable parameter.

2.  **Forward Pass (`forward`)**:
    *   **Input**: `x` with shape `(B, T, D)`, where `B` is batch size, `T` is sequence length, and `D` is `d_model`.
    *   **QKV Projections**: The input `x` is passed through `self.qkv_proj` to get a combined `qkv` tensor. This `qkv` tensor is then split into `q`, `k`, and `v` tensors, each of shape `(B, T, D)`.
    *   **Multi-Head Reshaping**: Each `q`, `k`, and `v` tensor is reshaped to `(B, n_heads, T, head_dim)`. This involves splitting the `d_model` dimension into `n_heads` separate heads, each with `head_dim` dimensions. The `transpose(1, 2)` operation rearranges the dimensions to put the heads dimension before the sequence length dimension, which is standard for multi-head attention computations.
    *   **Attention Scores Calculation**: The core attention mechanism is computed:
        $$\text{scores} = (Q K^T) / \sqrt{d_k}$$
        Here, `q` and `k` (reshaped `(B, n_heads, T, head_dim)`) are multiplied (`torch.matmul`) to get the similarity scores. `k.transpose(-2, -1)` transposes the last two dimensions of `k`, effectively performing $K^T$. The result is then scaled by `self.scale` (`1 / sqrt(head_dim)`).
        The shape of `scores` is `(B, n_heads, T, T)`.
    *   **Causal Masking**: The `causal_mask` (a lower triangular matrix) is applied. For each position `i` in the sequence, the mask ensures that the attention scores for positions `j > i` (future tokens) are set to negative infinity. This means that after the softmax, these future positions will have an attention weight of zero, effectively preventing a token from attending to future tokens.
        `scores = scores.masked_fill(mask == 0, float("-inf"))`
    *   **Softmax**: A `softmax` function is applied to the scores along the last dimension (`dim=-1`) to obtain attention `weights`. This normalizes the scores so they sum to 1 for each query, representing a probability distribution over the values.
        `weights = torch.softmax(scores, dim=-1)`
    *   **Weighted Sum of Values**: The attention `weights` are then multiplied by the `v` (value) tensor (`torch.matmul`). This produces the weighted sum of values, where tokens with higher attention weights contribute more to the output.
        `out = torch.matmul(weights, v)`
        The shape of `out` is `(B, n_heads, T, head_dim)`.
    *   **Merge Heads**: The `out` tensor is reshaped back to `(B, T, D)`. This involves transposing the dimensions back and then concatenating the outputs from all heads (`contiguous().view(B, T, D)`).
    *   **Output Projection**: Finally, the merged output is passed through `self.out_proj` to produce the final output of the self-attention layer. This projection allows the model to learn a linear transformation on the combined information from all attention heads.
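
The head-splitting and head-merging reshapes described above are exact inverses of each other. A small sketch with hypothetical sizes:

```python
import torch

torch.manual_seed(0)
B, T, n_heads, head_dim = 2, 5, 4, 8   # hypothetical sizes
D = n_heads * head_dim

x = torch.randn(B, T, D)

# Split: (B, T, D) -> (B, T, H, Dh) -> (B, H, T, Dh)
heads = x.view(B, T, n_heads, head_dim).transpose(1, 2)

# Merge: (B, H, T, Dh) -> (B, T, H, Dh) -> (B, T, D)
# contiguous() is needed because transpose returns a non-contiguous view.
merged = heads.transpose(1, 2).contiguous().view(B, T, D)

# Head h of token t is the slice x[:, t, h * Dh : (h + 1) * Dh].
head_1_token_3 = heads[0, 1, 3]
```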

**Mathematical Intuition:**

The causal self-attention mechanism fundamentally implements the following operation for each head:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}} + M\right) V $$

Where:
*   $Q$, $K$, $V$ are the Query, Key, and Value matrices for a single head.
*   $d_k$ is `head_dim`, the dimensionality of the key vectors.
*   $M$ is the causal mask, where $M_{ij} = 0$ if $i \ge j$ (past and current tokens) and $M_{ij} = -\infty$ if $i < j$ (future tokens). This effectively makes the attention weights to future tokens zero.

The multi-head aspect involves performing this attention operation `n_heads` times in parallel with different linear projections for each head, and then concatenating and linearly projecting their outputs.
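
The additive mask $M$ in the formula and the `masked_fill` call used in the code are two ways of expressing the same operation. A minimal check on random scores, with a hypothetical sequence length:

```python
import torch

torch.manual_seed(0)
T = 4  # hypothetical sequence length
scores = torch.randn(T, T)
tril = torch.tril(torch.ones(T, T))

# Code-style masking: overwrite future positions with -inf.
filled = scores.masked_fill(tril == 0, float("-inf"))

# Formula-style masking: add M, which is 0 on/below the diagonal and -inf above it.
M = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
additive = scores + M

w_filled = torch.softmax(filled, dim=-1)
w_additive = torch.softmax(additive, dim=-1)
```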

In [None]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

        # causal mask (registered as buffer, not parameter)
        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("causal_mask", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, T, D)
        return: (B, T, D)
        """
        B, T, D = x.shape

        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # reshape for multi-head
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # apply causal mask
        mask = self.causal_mask[:T, :T]
        scores = scores.masked_fill(mask == 0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, v)

        # merge heads
        out = out.transpose(1, 2).contiguous().view(B, T, D)

        return self.out_proj(out)


### KVCache and CachedCausalSelfAttention: Optimizing Inference with KV Caching

To enhance the efficiency of autoregressive decoding in Transformer models, especially during inference, Key-Value (KV) caching is employed. This technique avoids redundant re-computation of keys and values for previously processed tokens, significantly speeding up generation. The `KVCache` class manages this storage, and `CachedCausalSelfAttention` utilizes it for incremental token processing.
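
A rough way to see the saving is to count, per layer, how many single-token key/value projections are computed while generating `n` tokens. The sketch below is a simplified count under the assumption that an uncached decoder re-runs the full sequence at every step; it ignores prompt prefill and the cost of the attention scores themselves, and the function names are illustrative.

```python
def kv_projections_without_cache(n_tokens: int) -> int:
    """Step t re-projects all t tokens seen so far: 1 + 2 + ... + n."""
    return sum(range(1, n_tokens + 1))


def kv_projections_with_cache(n_tokens: int) -> int:
    """Each step projects only the newest token."""
    return n_tokens


n = 100
without_cache = kv_projections_without_cache(n)
with_cache = kv_projections_with_cache(n)
print(without_cache, with_cache)  # 5050 100
```

The uncached count grows quadratically with the sequence length, while the cached count grows linearly.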

#### 1. `KVCache` Class

The `KVCache` class is a simple container designed to store the keys (K) and values (V) computed during the self-attention mechanism across multiple decoding steps. This cache allows subsequent tokens to attend to the full historical context without re-calculating the K and V matrices for past tokens.

**Operational Breakdown:**

*   **`__init__(self)`**: Initializes an empty cache by setting `self.keys` and `self.values` to `None`. This state indicates that no keys or values have been stored yet.

*   **`append(self, k_new: torch.Tensor, v_new: torch.Tensor)`**: This method adds new key and value tensors to the cache. It expects `k_new` and `v_new` to have the shape `(B, H, 1, Dh)`, representing the keys and values for the current token across batches and attention heads.
    *   If the cache is empty (`self.keys` is `None`), the new keys and values become the initial content of the cache.
    *   If the cache already contains data, `k_new` and `v_new` are concatenated with the existing `self.keys` and `self.values` along the sequence length dimension (dimension 2). This effectively appends the current token's K and V to the historical sequence.

*   **`reset(self)`**: Clears the cache by setting `self.keys` and `self.values` back to `None`. This is typically used to prepare the cache for a new generation sequence.
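
The append logic can be sketched with plain tensors to show how the sequence dimension grows by one per decoding step. The sizes are hypothetical and only the key tensor is tracked.

```python
import torch

torch.manual_seed(0)
B, H, Dh = 2, 4, 8   # hypothetical batch size, heads, head dimension

keys = None
seq_lens = []
for step in range(3):
    k_new = torch.randn(B, H, 1, Dh)   # keys for the current token only
    # Same rule as KVCache.append: the first call stores, later calls concatenate on dim=2.
    keys = k_new if keys is None else torch.cat([keys, k_new], dim=2)
    seq_lens.append(keys.shape[2])

print(seq_lens)  # [1, 2, 3]
```

Since `torch.cat` allocates a new tensor on every call, a common alternative is to preallocate a buffer of length `block_size` and write into it by index; that variant is not shown here.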

#### 2. `CachedCausalSelfAttention` Class

The `CachedCausalSelfAttention` module is a specialized version of the `CausalSelfAttention` designed for efficient token-by-token generation (inference) by leveraging the `KVCache`.

**Key Differences from `CausalSelfAttention`:**
*   It processes input `x_t` with a sequence length `T=1` (a single token at a time).
*   It takes a `kv_cache` object as an argument to store and retrieve past keys and values.
*   It *implicitly* handles causality by only attending to the `k_full` and `v_full` retrieved from the cache, which by its nature only contains past and current tokens.

**Operational Breakdown (`forward` method):**

*   **Input**: `x_t` with shape `(B, 1, D)` (a single token per batch) and `kv_cache`.

*   **Assertion**: `assert T == 1, "Cached attention expects exactly one token"` ensures that the module is used for incremental decoding.

*   **QKV Projections**: Similar to `CausalSelfAttention`, `x_t` is projected into query `q`, key `k`, and value `v` tensors for the *current* token.

*   **Multi-Head Reshaping**: `q`, `k`, and `v` are reshaped to `(B, n_heads, 1, head_dim)` to prepare for multi-head attention.

*   **Cache Append**: The newly computed `k` and `v` for the current token are appended to the `kv_cache` using `kv_cache.append(k, v)`. The cache now holds the keys and values for *all* tokens processed so far in the current sequence.

*   **Retrieve Full Cache**: The complete historical `keys` (`k_full`) and `values` (`v_full`) are retrieved from the `kv_cache`. These tensors will have shape `(B, H, T_total, Dh)`, where `T_total` is the current length of the generated sequence.

*   **Attention Score Calculation**: The query `q` (current token's query) is used to compute attention scores against `k_full` (all past and current keys). This ensures that the current token attends to the entire context generated so far.
    $$\text{scores} = (Q_{\text{current}} K_{\text{full}}^T) / \sqrt{d_k}$$
    The `scores` tensor will have shape `(B, n_heads, 1, T_total)`.

*   **Softmax and Weighted Sum**: `softmax` is applied to the scores, and the resulting attention weights are multiplied by `v_full` to produce the output `out` for the current token. This `out` tensor effectively summarizes the information from `v_full` relevant to the current `q`.

*   **Merge Heads and Output Projection**: The `out` tensor is reshaped back to `(B, 1, D)` and then passed through `self.out_proj` to yield the final output for the current token.

**Mathematical Intuition for Cached Self-Attention:**

The core attention computation within `CachedCausalSelfAttention` can be viewed as:

$$ \text{Attention}(\text{token}_t) = \text{softmax}\left(\frac{Q_t \cdot K_{\le t}^T}{\sqrt{d_k}}\right) \cdot V_{\le t} $$

Where:
*   $Q_t$ is the Query vector for the current token at position $t$.
*   $K_{\le t}$ represents the concatenated Key matrix containing keys for all tokens from position $1$ up to $t$ (retrieved from `kv_cache.keys`).
*   $V_{\le t}$ represents the concatenated Value matrix containing values for all tokens from position $1$ up to $t$ (retrieved from `kv_cache.values`).
*   $d_k$ is the `head_dim`, the dimensionality of the key vectors.

In this formulation, the causal masking that explicitly masks future tokens in `CausalSelfAttention` is implicitly handled. By only storing and using keys and values from tokens up to the current position ($K_{\le t}$ and $V_{\le t}$), the model naturally prevents attending to future information. The KV cache makes this process highly efficient by avoiding re-computation of $K_{\le t}$ and $V_{\le t}$ at each step; instead, it simply appends the new $K_t$ and $V_t$ to the existing cache.
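
The equivalence between masked full-sequence attention and incremental attention over a growing key/value prefix can be verified numerically. This is a single-head sketch with hypothetical sizes; slicing `k[: t + 1]` stands in for the contents of the cache at step `t`.

```python
import math
import torch

torch.manual_seed(0)
T, Dh = 5, 8   # hypothetical sequence length and head dimension
q, k, v = torch.randn(T, Dh), torch.randn(T, Dh), torch.randn(T, Dh)
scale = 1.0 / math.sqrt(Dh)

# Full-sequence attention with an explicit causal mask.
mask = torch.tril(torch.ones(T, T))
scores = (q @ k.T) * scale
full = torch.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1) @ v

# Incremental attention: at step t, only keys/values up to t exist, so no mask is needed.
outputs = []
for t in range(T):
    k_cache, v_cache = k[: t + 1], v[: t + 1]
    w = torch.softmax((q[t : t + 1] @ k_cache.T) * scale, dim=-1)   # (1, t + 1)
    outputs.append(w @ v_cache)                                      # (1, Dh)
incremental = torch.cat(outputs, dim=0)

max_diff = (full - incremental).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```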

In [None]:
class KVCache:
    def __init__(self):
        self.keys = None
        self.values = None

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor):
        """
        k_new, v_new: (B, H, 1, Dh)
        """
        if self.keys is None:
            self.keys = k_new
            self.values = v_new
        else:
            self.keys = torch.cat([self.keys, k_new], dim=2)
            self.values = torch.cat([self.values, v_new], dim=2)

    def reset(self):
        self.keys = None
        self.values = None


In [None]:
class CachedCausalSelfAttention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

    def forward(self, x_t: torch.Tensor, kv_cache) -> torch.Tensor:
        """
        x_t: (B, 1, D)
        kv_cache: KVCache
        return: (B, 1, D)
        """
        B, T, D = x_t.shape
        assert T == 1, "Cached attention expects exactly one token"

        qkv = self.qkv_proj(x_t)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)

        # append new K, V to cache
        kv_cache.append(k, v)

        # retrieve full cached K, V
        k_full = kv_cache.keys     # (B, H, T_total, Dh)
        v_full = kv_cache.values

        scores = torch.matmul(q, k_full.transpose(-2, -1)) * self.scale
        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, v_full)

        out = out.transpose(1, 2).contiguous().view(B, 1, D)
        return self.out_proj(out)


### FeedForward Layer

In the Transformer architecture, the FeedForward layer (also known as the Position-wise Feed-Forward Network or FFN) is applied independently to each position in the sequence. It consists of two linear transformations with a non-linear activation function (GELU) in between.

**Description:**

1.  **First Linear Layer (`self.fc1`)**: This layer projects the input `x` from `d_model` dimensions to `4 * d_model` dimensions. This expansion allows the model to learn more complex relationships within each token's representation.
2.  **Activation Function (`self.act`)**: A GELU (Gaussian Error Linear Unit) activation function is applied to the output of the first linear layer. GELU is a smooth alternative to ReLU that is widely used in Transformer-based models such as GPT.
3.  **Second Linear Layer (`self.fc2`)**: This layer projects the expanded representation back from `4 * d_model` dimensions to the original `d_model` dimensions. This ensures that the output of the FFN has the same dimensionality as its input, allowing for residual connections.

This layer processes each position identically but independently, meaning the same weights are used for all positions, but each position gets its own distinct computation.

**Mathematical Formula:**

The FeedForward layer can be mathematically expressed as:

$$\text{FFN}(x) = \text{GELU}(xW_1 + b_1)W_2 + b_2$$

Where:
*   $x$ is the input to the FeedForward network, typically the output of the self-attention sub-layer.
*   $W_1$ and $b_1$ are the weights and biases of the first linear transformation (from `d_model` to `4 * d_model`).
*   $W_2$ and $b_2$ are the weights and biases of the second linear transformation (from `4 * d_model` to `d_model`).
*   $\text{GELU}$ is the Gaussian Error Linear Unit activation function.

In [None]:
import torch
import torch.nn as nn

# ModelConfig is defined in an earlier cell; uncomment when running from the repo package.
# from model.config import ModelConfig


class FeedForward(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, 4 * config.d_model)
        self.fc2 = nn.Linear(4 * config.d_model, config.d_model)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))
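
The position-wise property can be checked by comparing one call on the whole sequence with separate calls on each position. This sketch uses an `nn.Sequential` stand-in with the same layer shapes as `FeedForward`; the sizes are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16   # hypothetical model dimension

# Stand-in with the same structure as FeedForward: Linear -> GELU -> Linear.
ffn = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),
    nn.GELU(),
    nn.Linear(4 * d_model, d_model),
)

x = torch.randn(2, 5, d_model)   # (B, T, D)

with torch.no_grad():
    full = ffn(x)
    # Apply the same weights to each position separately, then restack along T.
    per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)
```

This independence is what allows the block to process a single token of shape `(B, 1, D)` during incremental decoding without changing the result.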

### TransformerBlock

The `TransformerBlock` is the core building block of the model. This implementation is a decoder block specialised for incremental inference: it uses `CachedCausalSelfAttention` and takes a single-token input of shape `(B, 1, D)`.

Each `TransformerBlock` consists of two main sub-layers, each preceded by layer normalization and wrapped in a residual connection:

1.  **Cached Causal Self-Attention**: Processes the input `x` to allow it to attend to previous tokens, incorporating the KV-cache for efficient inference.
2.  **Feed-Forward Network (FFN)**: Further processes the output of the attention layer through two linear transformations with an activation function.

**Description:**

*   **`self.ln1` (Layer Normalization)**: Applied before the attention sub-layer. Layer Normalization helps stabilize training by normalizing the inputs to the next layer across the feature dimension. It ensures that the mean and variance of the inputs are consistent.
*   **`self.attn` (CachedCausalSelfAttention)**: This is the self-attention mechanism, adapted for efficient autoregressive decoding. It takes the layer-normalized input `self.ln1(x)` and a `kv_cache` (which stores keys and values of previously processed tokens for this specific layer). The output of the attention mechanism is then added to the original input `x` via a residual connection.
*   **`self.ln2` (Layer Normalization)**: Applied before the Feed-Forward Network. Similar to `self.ln1`, it normalizes the input to the FFN.
*   **`self.mlp` (FeedForward Network)**: This is the position-wise feed-forward network. It takes the layer-normalized output of the attention sub-layer `self.ln2(x)` and processes it. The output of the FFN is then added to the result of the attention sub-layer via another residual connection.

**Forward Pass Logic (`forward` method):**

The `forward` method implements the following sequence of operations:
1.  **Attention Sub-layer**: The input `x` is first normalized by `self.ln1`. This normalized input is then passed to the `self.attn` module along with the `kv_cache` for the current layer. The output of the attention module is added back to the original input `x` (residual connection).
    *   `x = x + self.attn(self.ln1(x), kv_cache)`
2.  **Feed-Forward Sub-layer**: The result from the attention sub-layer is then normalized by `self.ln2`. This normalized result is passed to the `self.mlp` module. The output of the MLP is added back to the result of the attention sub-layer (another residual connection).
    *   `x = x + self.mlp(self.ln2(x))`
3.  **Output**: The final result `x` is the output of the Transformer block.

**Mathematical Formula:**

Let $x$ be the input to the Transformer block, and $x_{\text{cache}}$ represent the `kv_cache` for the current layer.

1.  **Layer Normalization 1**: $x' = \text{LayerNorm}_1(x)$
2.  **Cached Causal Self-Attention**: $x'' = \text{CachedCausalSelfAttention}(x', x_{\text{cache}})$
3.  **Residual Connection 1**: $x_{\text{attn}} = x + x''$
4.  **Layer Normalization 2**: $x''' = \text{LayerNorm}_2(x_{\text{attn}})$
5.  **Feed-Forward Network**: $x'''' = \text{FeedForward}(x''')$
6.  **Residual Connection 2**: $x_{\text{output}} = x_{\text{attn}} + x''''$

Thus, the entire block's operation can be summarized as:

$$ \text{TransformerBlock}(x, x_{\text{cache}}) = x_{\text{attn}} + \text{FeedForward}(\text{LayerNorm}_2(x_{\text{attn}})), \quad x_{\text{attn}} = x + \text{CachedCausalSelfAttention}(\text{LayerNorm}_1(x), x_{\text{cache}}) $$

This structure, often referred to as "Pre-Normalization" or "Pre-LN" Transformer, applies layer normalization before the self-attention and FFN sub-layers.

In [None]:
import torch
import torch.nn as nn

# These classes are defined in earlier cells; uncomment when running from the repo package.
# from model.config import ModelConfig
# from model.attention import CachedCausalSelfAttention
# from model.layers import FeedForward


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = CachedCausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.d_model)
        self.mlp = FeedForward(config)

    def forward(self, x: torch.Tensor, kv_cache) -> torch.Tensor:
        """
        x: (B, 1, D)
        kv_cache: KVCache for this layer
        """
        x = x + self.attn(self.ln1(x), kv_cache)
        x = x + self.mlp(self.ln2(x))
        return x
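
To show how a stack of blocks might be driven during decoding, the sketch below keeps one cache per layer and feeds one token at a time. It is a condensed, self-contained stand-in rather than the classes above: `Block` folds the cached attention into the block, random vectors replace embedded tokens, and all sizes are hypothetical.

```python
import math
import torch
import torch.nn as nn


class KVCache:
    def __init__(self):
        self.keys = None
        self.values = None

    def append(self, k_new, v_new):
        if self.keys is None:
            self.keys, self.values = k_new, v_new
        else:
            self.keys = torch.cat([self.keys, k_new], dim=2)
            self.values = torch.cat([self.values, v_new], dim=2)


class Block(nn.Module):
    """Condensed Pre-LN block with cached attention (illustrative stand-in)."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, d_model // n_heads
        self.ln1 = nn.LayerNorm(d_model)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def attend(self, x, cache):
        B, T, D = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q, k, v = (t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        cache.append(k, v)   # the cache now covers every token seen by this layer
        scores = q @ cache.keys.transpose(-2, -1) / math.sqrt(self.head_dim)
        out = torch.softmax(scores, dim=-1) @ cache.values
        return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, D))

    def forward(self, x, cache):
        x = x + self.attend(self.ln1(x), cache)   # residual around attention
        x = x + self.mlp(self.ln2(x))             # residual around the FFN
        return x


torch.manual_seed(0)
d_model, n_heads, n_layers, steps = 16, 4, 2, 3   # hypothetical sizes
blocks = nn.ModuleList([Block(d_model, n_heads) for _ in range(n_layers)])
caches = [KVCache() for _ in range(n_layers)]      # one cache per layer

with torch.no_grad():
    for _ in range(steps):
        x = torch.randn(1, 1, d_model)   # stand-in for the embedded current token
        for block, cache in zip(blocks, caches):
            x = block(x, cache)
```

Each layer needs its own cache because its keys and values are computed from that layer's input, which differs from layer to layer.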
