<a href="https://colab.research.google.com/github/Suraj-Sedai/kv-cache-transformer/blob/main/experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MiniGPT-Inference [KV-CACHE TRANSFORMER]

**A From-Scratch Transformer Inference Engine**

MiniGPT-Inference is a from-scratch, production-grade Transformer inference engine designed to execute autoregressive decoding efficiently using Key–Value (KV) caching, incremental decoding, and batched generation.

Unlike training-focused implementations, this project centers on inference-time systems engineering, emphasizing:
- Computational complexity reduction
- Memory efficiency
- Deterministic correctness
- Measurable performance gains

The system is architected to reflect how modern large language models (LLMs) are served in real-world environments.

# Project Setup: Model Architecture Configuration

This section outlines the foundational configuration for our model. The `ModelConfig` dataclass is used to define key architectural hyperparameters, centralizing them for clarity, reusability, and ease of modification.

`ModelConfig` holds the hyperparameters common to transformer-based models:
*   `vocab_size`: The size of the vocabulary, representing the number of unique tokens the model can process.
*   `n_layers`: The number of transformer layers or blocks within the model's architecture.
*   `n_heads`: The number of attention heads used in the multi-head attention mechanism within each transformer layer.
*   `d_model`: The dimensionality of the model's embeddings and internal representations.
*   `block_size`: The maximum sequence length or context window that the model can process at once.
*   `dropout`: The dropout rate applied for regularization to prevent overfitting.

Declaring the `dataclass` with `frozen=True` makes the configuration immutable once it is created, which helps prevent accidental changes to the model's blueprint during its lifecycle. The `head_dim` property derives the per-head dimensionality as `d_model // n_heads`, and `__post_init__` asserts that `d_model` is divisible by `n_heads`, so an invalid configuration fails at construction time.

In [None]:
from dataclasses import dataclass

@dataclass(frozen=True)  # prevents accidental mutation
class ModelConfig:
    vocab_size: int
    n_layers: int
    n_heads: int
    d_model: int
    block_size: int
    dropout: float = 0.0

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_heads

    def __post_init__(self):
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"
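The behaviour described above can be checked with a short sketch. The class is repeated in compact form only so that the snippet runs on its own, and the hyperparameter values are illustrative assumptions, not values taken from the project.

```python
from dataclasses import dataclass, FrozenInstanceError


# Compact copy of the ModelConfig above so this snippet runs on its own.
@dataclass(frozen=True)
class ModelConfig:
    vocab_size: int
    n_layers: int
    n_heads: int
    d_model: int
    block_size: int
    dropout: float = 0.0

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_heads

    def __post_init__(self):
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"


# Illustrative values only (assumed for this example).
config = ModelConfig(vocab_size=1000, n_layers=2, n_heads=4, d_model=64, block_size=32)
head_dim = config.head_dim  # 64 // 4

# frozen=True turns attribute assignment into an error.
try:
    config.d_model = 128
    mutation_blocked = False
except FrozenInstanceError:
    mutation_blocked = True

# __post_init__ rejects a d_model that the heads cannot split evenly.
try:
    ModelConfig(vocab_size=1000, n_layers=2, n_heads=3, d_model=64, block_size=32)
    invalid_rejected = False
except AssertionError:
    invalid_rejected = True

print(head_dim, mutation_blocked, invalid_rejected)  # 16 True True
```

Note that `assert` statements are skipped when Python runs with the `-O` flag, so the divisibility check only applies in normal (non-optimized) runs.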

### Embedding Layers: Token and Positional

Transformer models rely on embedding layers to convert discrete input tokens into continuous vector representations, capturing both semantic meaning and sequential order.

#### `TokenEmbedding`

This layer converts numerical token IDs into dense vectors. Each unique token in the vocabulary is mapped to a `d_model`-dimensional vector, allowing the model to process linguistic information. This is achieved using `torch.nn.Embedding`, where `vocab_size` determines the number of unique tokens and `d_model` is the dimensionality of the embedding vectors.

#### `PositionalEmbedding`

Since Transformers process sequences in parallel and lack an inherent understanding of token order, positional embeddings are crucial. This layer provides a learned vector for each position within the input sequence, up to `block_size`. These positional vectors are added to the token embeddings, injecting information about the absolute position of each token in the sequence. Like token embeddings, it uses `torch.nn.Embedding` to map position IDs to `d_model`-dimensional vectors.

**Key Concept:** The input to the first Transformer block is the sum of each token's embedding and its corresponding positional embedding. This combined representation allows the model to differentiate between identical tokens appearing at different positions.
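As a minimal sketch of this idea, the snippet below sums two plain `nn.Embedding` lookups. The sizes and token IDs are assumptions chosen for illustration. It shows that the same token ID yields the same token vector at every position, but a different combined vector once the position is added.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only (assumed for this example).
vocab_size, block_size, d_model = 50, 8, 16
tok_emb = nn.Embedding(vocab_size, d_model)
pos_emb = nn.Embedding(block_size, d_model)

# One sequence in which token ID 7 appears at positions 0 and 2.
token_ids = torch.tensor([[7, 3, 7, 9]])        # (B=1, T=4)
position_ids = torch.arange(token_ids.size(1))  # (T,)

# pos_emb(position_ids) has shape (T, D) and broadcasts over the batch dimension.
tok = tok_emb(token_ids)                        # (B, T, D)
x = tok + pos_emb(position_ids)                 # (B, T, D)

same_token_vectors = torch.equal(tok[0, 0], tok[0, 2])
same_after_position = torch.allclose(x[0, 0], x[0, 2])
print(x.shape, same_token_vectors, same_after_position)
```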

In [None]:
import torch
import torch.nn as nn

# Assuming ModelConfig is defined in model.config or already imported
# from model.config import ModelConfig # Uncomment if not already imported

class TokenEmbedding(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.d_model
        )

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        token_ids: (B, T)
        returns:   (B, T, D)
        """
        return self.embedding(token_ids)


class PositionalEmbedding(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=config.block_size,
            embedding_dim=config.d_model
        )

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        """
        position_ids: (T,) or (B, T)
        returns:      (T, D) or (B, T, D); a (T, D) result broadcasts over the batch
        """
        return self.embedding(position_ids)

### Scaled Dot-Product Attention

Scaled Dot-Product Attention is a fundamental component of the Transformer architecture, designed to efficiently compute attention weights. It takes three inputs: a query matrix (Q), a key matrix (K), and a value matrix (V). The core idea is to calculate a similarity score between the queries and keys, scale these scores, and then use them to weigh the values.

**Description:**

1.  **Similarity Calculation:** The attention scores are computed by taking the dot product of the query and key matrices. This measures how relevant each key is to each query.
2.  **Scaling:** The scores are then divided by the square root of the dimension of the keys (`d_k`). This scaling factor is crucial for preventing the dot products from becoming too large, especially with high `d_k` values, which can push the softmax function into regions with extremely small gradients, hindering training.
3.  **Masking (Optional):** If a mask is provided, typically for causality (to prevent attention to future tokens in sequence generation) or padding (to ignore non-existent tokens), the scores at masked positions are set to negative infinity (`-inf`) before the softmax. Since $e^{-\infty} = 0$, these positions receive an attention weight of exactly zero.
4.  **Softmax:** A softmax function is applied to the scaled scores to obtain attention weights. This normalizes the scores such that they sum to 1, representing a probability distribution over the values.
5.  **Weighted Sum:** Finally, these attention weights are multiplied by the value matrix (V). This creates a weighted sum of the values, where the weight assigned to each value is determined by its relevance to the query.
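The five steps above can be sketched as a single function. This is a hypothetical helper written for illustration (the name `sdpa_steps` and the tensor sizes are assumptions), not the project's implementation. It returns the attention weights alongside the output so that the effect of the mask can be inspected.

```python
import math
import torch


def sdpa_steps(q, k, v, mask=None):
    """Hypothetical helper. q, k, v: (B, T, D); mask: (T, T) with 1 = keep, 0 = block."""
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1))               # 1. similarity, (B, T, T)
    scores = scores / math.sqrt(d_k)                            # 2. scaling
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))   # 3. masking
    weights = torch.softmax(scores, dim=-1)                     # 4. softmax over keys
    return torch.matmul(weights, v), weights                    # 5. weighted sum


torch.manual_seed(0)
B, T, D = 2, 4, 8  # illustrative sizes
q, k, v = torch.randn(B, T, D), torch.randn(B, T, D), torch.randn(B, T, D)
causal = torch.tril(torch.ones(T, T))

out, weights = sdpa_steps(q, k, v, mask=causal)
print(out.shape)  # torch.Size([2, 4, 8])

# Weights on future positions (strictly above the diagonal) are exactly zero.
future = torch.triu(torch.ones(T, T), diagonal=1).bool()
future_weight = weights[:, future].abs().max().item()
print(future_weight)
```

With a causal mask, the first query can only attend to itself, so the first output row equals the first value row.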

**Mathematical Formula:**

The Scaled Dot-Product Attention mechanism is mathematically expressed as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
-   $Q$ is the Query matrix.
-   $K$ is the Key matrix.
-   $V$ is the Value matrix.
-   $d_k$ is the dimension of the key vectors.
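To see why the $\sqrt{d_k}$ factor matters, the sketch below draws random queries and keys with unit variance (an assumption made purely for illustration) and compares the spread of raw and scaled dot products. It then compares how peaked the softmax is in each case.

```python
import math
import torch

torch.manual_seed(0)
d_k = 256

# Unit-variance queries and keys (an assumption made for illustration).
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)

raw = (q * k).sum(dim=-1)          # 10,000 sample dot products
scaled = raw / math.sqrt(d_k)

raw_std = raw.std().item()         # close to sqrt(d_k) = 16
scaled_std = scaled.std().item()   # close to 1
print(raw_std, scaled_std)

# One query scored against 8 keys: unscaled scores give a more peaked softmax.
scores = torch.randn(8, d_k) @ torch.randn(d_k)   # shape (8,)
peak_raw = torch.softmax(scores, dim=-1).max().item()
peak_scaled = torch.softmax(scores / math.sqrt(d_k), dim=-1).max().item()
print(peak_raw, peak_scaled)
```

A softmax that puts nearly all of its mass on one key has near-zero gradients for the other keys, which is the saturation effect the scaling is meant to avoid.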

In [None]:
import torch
import torch.nn as nn
import math


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        # Single-head attention: the key dimension d_k equals d_model here.
        self.scale = 1.0 / math.sqrt(d_model)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        q, k, v: (B, T, D)
        mask:    (T, T) or (B, T, T)
        return:  (B, T, D)
        """

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(weights, v)

        return output


### Causal Self-Attention

`CausalSelfAttention` is a crucial component in transformer-based autoregressive models, such as GPT (Generative Pre-trained Transformer). It extends the concept of `ScaledDotProductAttention` by ensuring that during sequence generation, each token can only attend to previous tokens and itself, not future tokens. This is vital for tasks like language modeling where predicting the next word depends only on the words that have already occurred.

**Description:**

1.  **Initialization (`__init__`)**:
    *   It takes a `ModelConfig` object, which defines parameters like the number of attention heads (`n_heads`), the dimensionality of each head (`head_dim`), and the model's total dimension (`d_model`).
    *   `self.scale`: A scaling factor `1 / sqrt(head_dim)` is calculated, which is standard for scaled dot-product attention to prevent large dot products from pushing the softmax into regions with tiny gradients.
    *   `self.qkv_proj`: A linear projection layer that transforms the input `x` (with shape `(B, T, D)`) into Query (Q), Key (K), and Value (V) matrices. It outputs `3 * d_model` dimensions, which are then split into `d_model` for Q, K, and V respectively.
    *   `self.out_proj`: Another linear projection layer that takes the concatenated output from all attention heads and projects it back to the original `d_model` dimension.
    *   `self.causal_mask`: A lower triangular matrix (e.g., `[[1,0,0],[1,1,0],[1,1,1]]`) is created. This mask is used to block attention to future tokens. It's registered as a buffer, meaning it's part of the model's state but not a trainable parameter.

2.  **Forward Pass (`forward`)**:
    *   **Input**: `x` with shape `(B, T, D)`, where `B` is batch size, `T` is sequence length, and `D` is `d_model`.
    *   **QKV Projections**: The input `x` is passed through `self.qkv_proj` to get a combined `qkv` tensor. This `qkv` tensor is then split into `q`, `k`, and `v` tensors, each of shape `(B, T, D)`.
    *   **Multi-Head Reshaping**: Each `q`, `k`, and `v` tensor is reshaped to `(B, n_heads, T, head_dim)`. This involves splitting the `d_model` dimension into `n_heads` separate heads, each with `head_dim` dimensions. The `transpose(1, 2)` operation rearranges the dimensions to put the heads dimension before the sequence length dimension, which is standard for multi-head attention computations.
    *   **Attention Scores Calculation**: The core attention mechanism is computed:
        $$\text{scores} = (Q K^T) / \sqrt{d_k}$$
        Here, `q` and `k` (reshaped `(B, n_heads, T, head_dim)`) are multiplied (`torch.matmul`) to get the similarity scores. `k.transpose(-2, -1)` transposes the last two dimensions of `k`, effectively performing $K^T$. The result is then scaled by `self.scale` (`1 / sqrt(head_dim)`).
        The shape of `scores` is `(B, n_heads, T, T)`.
    *   **Causal Masking**: The `causal_mask` (a lower triangular matrix) is applied. For each position `i` in the sequence, the mask ensures that the attention scores for positions `j > i` (future tokens) are set to negative infinity. This means that after the softmax, these future positions will have an attention weight of zero, effectively preventing a token from attending to future tokens.
        `scores = scores.masked_fill(mask == 0, float("-inf"))`
    *   **Softmax**: A `softmax` function is applied to the scores along the last dimension (`dim=-1`) to obtain attention `weights`. This normalizes the scores so they sum to 1 for each query, representing a probability distribution over the values.
        `weights = torch.softmax(scores, dim=-1)`
    *   **Weighted Sum of Values**: The attention `weights` are then multiplied by the `v` (value) tensor (`torch.matmul`). This produces the weighted sum of values, where tokens with higher attention weights contribute more to the output.
        `out = torch.matmul(weights, v)`
        The shape of `out` is `(B, n_heads, T, head_dim)`.
    *   **Merge Heads**: The `out` tensor is reshaped back to `(B, T, D)`. This involves transposing the dimensions back and then concatenating the outputs from all heads (`contiguous().view(B, T, D)`).
    *   **Output Projection**: Finally, the merged output is passed through `self.out_proj` to produce the final output of the self-attention layer. This projection allows the model to learn a linear transformation on the combined information from all attention heads.
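The split and merge steps are easiest to follow by tracking shapes. The sketch below uses illustrative sizes (assumed for this example) and a stand-in tensor in place of the real attention output. It shows which slice of the model dimension each head receives and that the merge recovers the original layout.

```python
import torch

torch.manual_seed(0)
B, T, n_heads, head_dim = 2, 5, 4, 8   # illustrative sizes
D = n_heads * head_dim

x = torch.randn(B, T, D)

# Split: (B, T, D) -> (B, T, n_heads, head_dim) -> (B, n_heads, T, head_dim)
heads = x.view(B, T, n_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 8])

# Head h holds columns [h * head_dim, (h + 1) * head_dim) of the original tensor.
h = 1
same_slice = torch.equal(heads[:, h], x[:, :, h * head_dim:(h + 1) * head_dim])

# Stand-in for the attention output, stored in (B, n_heads, T, head_dim) order
# as a matmul result would be.
out = heads.contiguous()

# Merge: transpose gives a non-contiguous view, so .contiguous() is needed before .view.
swapped = out.transpose(1, 2)                  # (B, T, n_heads, head_dim)
needs_copy = not swapped.is_contiguous()
merged = swapped.contiguous().view(B, T, D)    # (B, T, D)
round_trip = torch.equal(merged, x)
print(same_slice, needs_copy, round_trip)
```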

**Mathematical Intuition:**

The causal self-attention mechanism fundamentally implements the following operation for each head:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}} + M\right) V $$

Where:
*   $Q$, $K$, $V$ are the Query, Key, and Value matrices for a single head.
*   $d_k$ is `head_dim`, the dimensionality of the key vectors.
*   $M$ is the causal mask, where $M_{ij} = 0$ if $i \ge j$ (past and current tokens) and $M_{ij} = -\infty$ if $i < j$ (future tokens). This effectively makes the attention weights to future tokens zero.

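The multi-head aspect involves performing this attention operation `n_heads` times in parallel with different linear projections for each head, and then concatenating and linearly projecting their outputs. In this implementation the per-head projections are fused into the single `qkv_proj` layer and separated by reshaping.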
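The additive mask $M$ and the `masked_fill` call in the forward pass describe the same operation. The sketch below, for a single head with illustrative sizes (assumed for this example), builds both forms and compares the resulting attention weights.

```python
import math
import torch

torch.manual_seed(0)
T, d_k = 4, 8  # illustrative sizes
q, k = torch.randn(T, d_k), torch.randn(T, d_k)
scores = q @ k.T / math.sqrt(d_k)                # scaled scores, (T, T)

# Boolean form, as in the forward pass: 1 = visible, 0 = blocked.
tril = torch.tril(torch.ones(T, T))
w_fill = torch.softmax(scores.masked_fill(tril == 0, float("-inf")), dim=-1)

# Additive form M: 0 where j <= i, -inf where j > i.
M = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
w_add = torch.softmax(scores + M, dim=-1)

masks_agree = torch.equal(M == 0, tril.bool())
weights_agree = torch.allclose(w_fill, w_add)
print(masks_agree, weights_agree)
```

Because every entry of $M$ is either $0$ or $-\infty$, the result is the same whether $M$ is added before or after the division by $\sqrt{d_k}$.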

In [None]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

        # causal mask (registered as buffer, not parameter)
        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("causal_mask", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, T, D)
        return: (B, T, D)
        """
        B, T, D = x.shape

        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # reshape for multi-head
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # apply causal mask
        mask = self.causal_mask[:T, :T]
        scores = scores.masked_fill(mask == 0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, v)

        # merge heads
        out = out.transpose(1, 2).contiguous().view(B, T, D)

        return self.out_proj(out)
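One way to sanity-check causality is to overwrite the last few input positions and confirm that the outputs at earlier positions do not change. The sketch below uses a hypothetical functional mirror of the forward pass (`causal_self_attention` is an assumed name, and the sizes are illustrative) so that it runs on its own.

```python
import math
import torch
import torch.nn as nn


def causal_self_attention(x, qkv_proj, out_proj, n_heads):
    """Hypothetical functional mirror of CausalSelfAttention.forward. x: (B, T, D)."""
    B, T, D = x.shape
    head_dim = D // n_heads
    q, k, v = qkv_proj(x).chunk(3, dim=-1)
    q = q.view(B, T, n_heads, head_dim).transpose(1, 2)
    k = k.view(B, T, n_heads, head_dim).transpose(1, 2)
    v = v.view(B, T, n_heads, head_dim).transpose(1, 2)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
    mask = torch.tril(torch.ones(T, T, device=x.device))
    scores = scores.masked_fill(mask == 0, float("-inf"))
    out = torch.matmul(torch.softmax(scores, dim=-1), v)
    return out_proj(out.transpose(1, 2).contiguous().view(B, T, D))


torch.manual_seed(0)
B, T, D, n_heads = 1, 6, 32, 4   # illustrative sizes
qkv_proj, out_proj = nn.Linear(D, 3 * D), nn.Linear(D, D)

x = torch.randn(B, T, D)
x_changed = x.clone()
x_changed[:, 4:] = torch.randn(B, T - 4, D)   # overwrite positions 4 and 5 only

with torch.no_grad():
    y = causal_self_attention(x, qkv_proj, out_proj, n_heads)
    y_changed = causal_self_attention(x_changed, qkv_proj, out_proj, n_heads)

# Positions 0..3 never see positions 4 and 5, so their outputs should match.
past_unchanged = torch.allclose(y[:, :4], y_changed[:, :4], atol=1e-6)
future_changed = not torch.allclose(y[:, 4:], y_changed[:, 4:], atol=1e-6)
print(past_unchanged, future_changed)
```

The same check can be run against the `CausalSelfAttention` module itself by calling it on `x` and `x_changed` in place of the helper.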
