<a href="https://colab.research.google.com/github/Suraj-Sedai/kv-cache-transformer/blob/main/experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MiniGPT-Inference [KV-CACHE TRANSFORMER]

**A From-Scratch Transformer Inference Engine**

MiniGPT-Inference is a from-scratch, production-grade Transformer inference engine designed to execute autoregressive decoding efficiently using Key–Value (KV) caching, incremental decoding, and batched generation.

Unlike training-focused implementations, this project centers on inference-time systems engineering, emphasizing:
- Computational complexity reduction
- Memory efficiency
- Deterministic correctness
- Measurable performance gains

The system is architected to reflect how modern large language models (LLMs) are served in real-world environments.
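
To make the complexity claim concrete, here is a rough counting sketch. It assumes a simplified cost model, with a hypothetical helper `positions_processed`, in which the work per decoding step is proportional to the number of token positions fed through the network. Without a cache, each step re-runs the entire prefix. With a KV cache, the prompt is processed once (the prefill) and every later step feeds only the newest token.

```python
def positions_processed(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Count token positions fed through the model while generating `new_tokens` tokens.

    Simplified cost model (an assumption for illustration): one unit is one token
    position pushed through every layer once. Attention over cached keys still grows
    with sequence length and is not counted here.
    """
    total = 0
    seq_len = prompt_len
    for step in range(new_tokens):
        if use_cache and step > 0:
            total += 1        # only the newest token is fed; past keys/values are reused
        else:
            total += seq_len  # the whole sequence is (re)processed; step 0 is the prefill
        seq_len += 1          # the sampled token is appended to the sequence
    return total


print(positions_processed(8, 4, use_cache=False))  # 38  (8 + 9 + 10 + 11)
print(positions_processed(8, 4, use_cache=True))   # 11  (8 + 1 + 1 + 1)
```

Under this model, decoding without a cache costs \(NP + N(N-1)/2\) units for prompt length \(P\) and \(N\) new tokens, which is quadratic in \(N\), while decoding with a cache costs \(P + N - 1\), which is linear.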

# Project Setup: Model Architecture Configuration

This section outlines the foundational configuration for our model. The `ModelConfig` dataclass is used to define key architectural hyperparameters, centralizing them for clarity, reusability, and ease of modification.

`ModelConfig` holds the hyperparameters common to most Transformer-based models:
*   `vocab_size`: The size of the vocabulary, representing the number of unique tokens the model can process.
*   `n_layers`: The number of transformer layers or blocks within the model's architecture.
*   `n_heads`: The number of attention heads used in the multi-head attention mechanism within each transformer layer.
*   `d_model`: The dimensionality of the model's embeddings and internal representations.
*   `block_size`: The maximum sequence length or context window that the model can process at once.
*   `dropout`: The dropout rate applied for regularization to prevent overfitting.

Declaring the class with `@dataclass(frozen=True)` makes each configuration immutable once created, which prevents accidental changes to the model's blueprint during its lifecycle. The `head_dim` property derives the per-head dimension as `d_model // n_heads`, and `__post_init__` checks that `d_model` is divisible by `n_heads` so that this division is exact.
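
The behaviour described above can be seen with a small stand-in. The sketch below uses a hypothetical two-field `TinyConfig` rather than the real `ModelConfig`, so that it runs on its own. It shows the derived `head_dim`, the error raised on assignment to a frozen instance, and `dataclasses.replace`, the standard-library way to build a modified copy (it re-runs `__init__`, and therefore `__post_init__`).

```python
import dataclasses
from dataclasses import dataclass


# Hypothetical two-field stand-in for ModelConfig, used only to illustrate
# frozen-dataclass behaviour.
@dataclass(frozen=True)
class TinyConfig:
    d_model: int
    n_heads: int

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_heads

    def __post_init__(self):
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")


cfg = TinyConfig(d_model=128, n_heads=4)
print(cfg.head_dim)  # 32

try:
    cfg.d_model = 256  # assignment on a frozen dataclass raises
except dataclasses.FrozenInstanceError as err:
    print("blocked:", type(err).__name__)

# Build a modified copy instead; validation runs again on the new instance.
bigger = dataclasses.replace(cfg, d_model=256)
print(bigger.head_dim)  # 64

try:
    dataclasses.replace(cfg, n_heads=3)  # 128 % 3 != 0
except ValueError as err:
    print("rejected:", err)
```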

In [None]:
from dataclasses import dataclass

@dataclass(frozen=True)  # prevents accidental mutation
class ModelConfig:
    vocab_size: int
    n_layers: int
    n_heads: int
    d_model: int
    block_size: int
    dropout: float = 0.0

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_heads

    def __post_init__(self):
        # Raise instead of assert so the check is not stripped under `python -O`.
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")

### Embedding Layers: Token and Positional

Transformer models rely on embedding layers to convert discrete input tokens into continuous vector representations, capturing both semantic meaning and sequential order.

#### `TokenEmbedding`

This layer converts integer token IDs into dense vectors. Each token in the vocabulary is mapped to a learned `d_model`-dimensional vector, which the rest of the network operates on. This is implemented with `torch.nn.Embedding`, where `vocab_size` sets the number of rows in the embedding table (one per token) and `d_model` sets the dimensionality of each vector.
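
An embedding lookup is simply row indexing into a `(vocab_size, d_model)` weight matrix. The short check below, using toy sizes chosen for illustration, confirms that `nn.Embedding` returns exactly those rows and that a repeated token id always maps to the same vector.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, d_model = 10, 4  # toy sizes for illustration
emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

token_ids = torch.tensor([[1, 5, 5, 2]])  # (B=1, T=4)
vectors = emb(token_ids)                  # (1, 4, 4)

# The lookup is plain row indexing into the (vocab_size, d_model) weight matrix.
assert torch.equal(vectors, emb.weight[token_ids])
# The same token id (5) maps to the same vector wherever it appears.
assert torch.equal(vectors[0, 1], vectors[0, 2])
print(vectors.shape)  # torch.Size([1, 4, 4])
```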

#### `PositionalEmbedding`

Since Transformers process all positions in parallel and self-attention has no built-in notion of token order, positional embeddings are crucial. This layer learns one vector for each position from `0` to `block_size - 1`. These positional vectors are added to the token embeddings, injecting information about each token's absolute position in the sequence. Like the token embeddings, it uses `torch.nn.Embedding`, here mapping position IDs to `d_model`-dimensional vectors.
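
During incremental decoding the position ids must continue from where the cached sequence left off, rather than restarting at 0. The sketch below assumes learned absolute positions as described above. It uses a hypothetical helper `position_ids(past_len, new_len)` and toy sizes, and enforces the `block_size` limit because positions beyond it have no embedding row.

```python
import torch
import torch.nn as nn

block_size, d_model = 8, 4  # toy sizes for illustration
pos_emb = nn.Embedding(num_embeddings=block_size, embedding_dim=d_model)


def position_ids(past_len: int, new_len: int) -> torch.Tensor:
    """Hypothetical helper: absolute positions for `new_len` tokens that follow
    `past_len` already-processed (cached) tokens."""
    end = past_len + new_len
    if end > block_size:
        raise ValueError(f"sequence length {end} exceeds block_size={block_size}")
    return torch.arange(past_len, end)


prefill = position_ids(past_len=0, new_len=5)  # tensor([0, 1, 2, 3, 4])
step = position_ids(past_len=5, new_len=1)     # tensor([5]): the next token's position

# A (T,) id tensor yields a (T, D) result that broadcasts over the batch dimension.
print(pos_emb(prefill).shape)  # torch.Size([5, 4])
print(pos_emb(step).shape)     # torch.Size([1, 4])
```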

**Key Concept:** The input to the first Transformer block is the sum of each token's embedding and the embedding of its position. This combined representation allows the model to distinguish identical tokens appearing at different positions.

In [None]:
import torch
import torch.nn as nn

# Assuming ModelConfig is defined in model.config or already imported
# from model.config import ModelConfig # Uncomment if not already imported

class TokenEmbedding(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.d_model
        )

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        token_ids: (B, T)
        returns:   (B, T, D)
        """
        return self.embedding(token_ids)


class PositionalEmbedding(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=config.block_size,
            embedding_dim=config.d_model
        )

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        """
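        position_ids: (T,) or (B, T), with every id in [0, block_size)
        returns:      (T, D) or (B, T, D); a (T, D) result broadcasts when
                      added to (B, T, D) token embeddings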
        """
        return self.embedding(position_ids)
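
The two embeddings are combined by elementwise addition. The sketch below uses plain `nn.Embedding` tables as stand-ins for `TokenEmbedding` and `PositionalEmbedding` (toy sizes, chosen for illustration) and relies on broadcasting a `(T, D)` positional table over a `(B, T, D)` batch. Token 3 appears twice in the first sequence, so the two summed inputs differ only by their positional vectors.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, block_size, d_model = 16, 8, 4  # toy sizes for illustration

# Plain embedding tables standing in for TokenEmbedding / PositionalEmbedding.
tok_emb = nn.Embedding(vocab_size, d_model)
pos_emb = nn.Embedding(block_size, d_model)

token_ids = torch.tensor([[3, 7, 3],
                          [1, 2, 9]])       # (B=2, T=3)
B, T = token_ids.shape
positions = torch.arange(T)                 # (T,)

# (B, T, D) + (T, D) -> (B, T, D) via broadcasting
x = tok_emb(token_ids) + pos_emb(positions)
print(x.shape)  # torch.Size([2, 3, 4])

# Same token at positions 0 and 2: identical token vector, different position vector.
print(torch.allclose(x[0, 0], x[0, 2]))  # False
```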

### Scaled Dot-Product Attention

Scaled Dot-Product Attention is a fundamental component of the Transformer architecture, crucial for allowing the model to weigh the importance of different parts of the input sequence when processing a specific element.

**Description:**

At its core, this mechanism computes an output as a weighted sum of **value** vectors, where the weights are the softmax of the dot products between the **query** and the corresponding **key** vectors. Before the softmax, the dot products are divided by \(\sqrt{d_k}\), the square root of the key dimension. If the query and key components are roughly independent with unit variance, their dot product has variance \(d_k\); without the scaling, larger dimensions produce larger scores that push the softmax into saturated regions with extremely small gradients, hindering training. In the original paper \(d_k\) is the per-head dimension (`head_dim` in `ModelConfig`); the module below receives it as its `d_model` argument.
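
The variance argument is easy to check numerically. The sketch below draws query and key entries i.i.d. from a standard normal distribution (an idealised assumption) and compares the sample variance of raw and scaled dot products for a few key dimensions.

```python
import math

import torch

torch.manual_seed(0)

# Query/key entries drawn i.i.d. from N(0, 1): then q . k has variance d_k.
n_samples = 10_000
for d_k in (16, 64, 256):
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)      # one dot product per sample
    scaled = dots / math.sqrt(d_k)  # the 1/sqrt(d_k) scaling from the formula
    print(f"d_k={d_k:4d}  var(q.k)={dots.var().item():8.1f}  "
          f"var(q.k/sqrt(d_k))={scaled.var().item():.2f}")
```

The raw variance should track \(d_k\), while the scaled variance stays close to 1 regardless of dimension.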

Masking is applied to prevent attention to certain positions, typically future tokens in decoder architectures (causal masking) or padded tokens.

**Mathematical Formulation:**

Given a Query matrix \(Q\), Key matrix \(K\), and Value matrix \(V\):

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V $$

Where:
*   \(Q\) (Query), \(K\) (Key), and \(V\) (Value) are matrices derived from the input embeddings.
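*   \(d_k\) is the dimension of the key vectors; in the code below it is the `d_model` constructor argument, which should equal the per-head dimension (`head_dim`) when attention is split across multiple heads.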
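*   \(M\) is an optional additive mask: its entries are \(0\) for positions that may be attended to and \(-\infty\) (or a very large negative number) for positions that should be ignored, so those positions receive zero weight after the softmax. The code below takes a 0/1 mask instead and fills the scores with \(-\infty\) wherever `mask == 0`, which has the same effect.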
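
A causal mask for a decoder can be built with `torch.tril`. The sketch below assumes the same 0/1 convention as the `mask == 0` test in `ScaledDotProductAttention` (1 = may attend, 0 = blocked) and uses random scores as a stand-in for \(QK^T/\sqrt{d_k}\).

```python
import torch

T = 4
# Lower-triangular causal mask: row i may attend to columns 0..i only.
mask = torch.tril(torch.ones(T, T))

torch.manual_seed(0)
scores = torch.randn(T, T)  # stand-in for QK^T / sqrt(d_k)
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = torch.softmax(scores, dim=-1)

# Future positions (j > i) get exactly zero weight, and every row still sums to 1.
assert torch.all(weights.triu(diagonal=1) == 0)
assert torch.allclose(weights.sum(dim=-1), torch.ones(T))
print(weights)
```

Row 0 can only attend to itself, so its single weight is 1. A row whose mask is entirely zero would contain only \(-\infty\) scores and produce NaN after the softmax.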

In [None]:
import torch
import torch.nn as nn
import math


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.scale = 1.0 / math.sqrt(d_model)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
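        q:       (..., T_q, D)
        k, v:    (..., T_k, D); T_k may differ from T_q, e.g. cached keys/values
        mask:    broadcastable to (..., T_q, T_k); 1 = attend, 0 = blocked
        return:  (..., T_q, D)
        Note: a query row whose mask is all zeros yields NaN after the softmax.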
        """

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(weights, v)

        return output
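
Because `forward` only requires `q` and `k` to agree on the last dimension, the same computation supports KV-cached decoding: a single new query can attend over all stored keys and values. The sketch below checks that one-query-per-step decoding against a growing cache reproduces full causal attention. It re-declares the forward pass as a plain function `attention` so that it runs on its own, assumes a single head (so \(d_k = D\)), and uses random tensors as stand-ins for the per-token query/key/value projections.

```python
import math

import torch


def attention(q, k, v, mask=None):
    """Single-head scaled dot-product attention, mirroring ScaledDotProductAttention.forward.

    Re-declared here as a plain function only so the check is self-contained.
    q: (B, T_q, D); k, v: (B, T_k, D); mask: 1 = attend, 0 = blocked.
    """
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
B, T, D = 2, 6, 8
# Random stand-ins for the per-token query/key/value projections.
q, k, v = torch.randn(B, T, D), torch.randn(B, T, D), torch.randn(B, T, D)

# 1) Full recomputation: all T queries at once under a causal mask.
full = attention(q, k, v, mask=torch.tril(torch.ones(T, T)))  # (B, T, D)

# 2) Incremental decoding: one query per step against a growing key/value cache.
k_cache, v_cache, outputs = [], [], []
for t in range(T):
    k_cache.append(k[:, t : t + 1])  # store this step's key   (B, 1, D)
    v_cache.append(v[:, t : t + 1])  # store this step's value (B, 1, D)
    K = torch.cat(k_cache, dim=1)    # (B, t + 1, D)
    V = torch.cat(v_cache, dim=1)
    # No mask needed: the cache only holds positions <= t.
    outputs.append(attention(q[:, t : t + 1], K, V))  # (B, 1, D)
incremental = torch.cat(outputs, dim=1)  # (B, T, D)

print(torch.allclose(full, incremental, atol=1e-5))  # True
```

Concatenating on every step keeps the sketch short; an inference engine would more typically preallocate cache tensors up to `block_size` and write into them in place.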
