<a href="https://colab.research.google.com/github/Suraj-Sedai/minigpt-inference/blob/main/notebooks/experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MiniGPT-Inference [KV-CACHE TRANSFORMER]

**A From-Scratch Transformer Inference Engine**

MiniGPT-Inference is a from-scratch, production-grade Transformer inference engine designed to execute autoregressive decoding efficiently using Key–Value (KV) caching, incremental decoding, and batched generation.

Unlike training-focused implementations, this project centers on inference-time systems engineering, emphasizing:
- Computational complexity reduction
- Memory efficiency
- Deterministic correctness
- Measurable performance gains

The system is architected to reflect how modern large language models (LLMs) are served in real-world environments.

# Project Setup: Model Architecture Configuration

This section outlines the foundational configuration for our model. The `ModelConfig` dataclass is used to define key architectural hyperparameters, centralizing them for clarity, reusability, and ease of modification.

The parameters included in `ModelConfig` are typically found in transformer-based models and include:
*   `vocab_size`: The size of the vocabulary, representing the number of unique tokens the model can process.
*   `n_layers`: The number of transformer layers or blocks within the model's architecture.
*   `n_heads`: The number of attention heads used in the multi-head attention mechanism within each transformer layer.
*   `d_model`: The dimensionality of the model's embeddings and internal representations.
*   `block_size`: The maximum sequence length or context window that the model can process at once.
*   `dropout`: The dropout rate applied for regularization to prevent overfitting.

Declaring the dataclass with `frozen=True` makes the configuration immutable once it is created, which helps prevent accidental changes to the model's blueprint during its lifecycle. The derived `head_dim` property returns `d_model // n_heads`, and `__post_init__` asserts that `d_model` is divisible by `n_heads` so that this division is exact.

In [655]:
from dataclasses import dataclass

@dataclass(frozen=True)  # prevents accidental mutation
class ModelConfig:
    vocab_size: int
    n_layers: int
    n_heads: int
    d_model: int
    block_size: int
    dropout: float = 0.0

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_heads

    def __post_init__(self):
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"
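
A short usage sketch may make the frozen-dataclass behaviour concrete. The hyperparameter values are arbitrary and chosen only for illustration, and `ModelConfig` is repeated in compact form so the snippet runs on its own.

```python
from dataclasses import dataclass, FrozenInstanceError


@dataclass(frozen=True)
class ModelConfig:
    vocab_size: int
    n_layers: int
    n_heads: int
    d_model: int
    block_size: int
    dropout: float = 0.0

    @property
    def head_dim(self) -> int:
        return self.d_model // self.n_heads

    def __post_init__(self):
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"


# Illustrative values (assumptions, not the notebook's actual settings).
config = ModelConfig(vocab_size=100, n_layers=2, n_heads=4, d_model=32, block_size=16)
head_dim = config.head_dim  # 32 // 4

# frozen=True turns attribute assignment into an error.
try:
    config.d_model = 64
    mutated = True
except FrozenInstanceError:
    mutated = False

# __post_init__ rejects a d_model that the heads cannot divide evenly.
try:
    ModelConfig(vocab_size=100, n_layers=2, n_heads=3, d_model=32, block_size=16)
    rejected = False
except AssertionError:
    rejected = True
```

Note that `assert` statements are skipped when Python runs with `-O`, so a stricter variant could raise `ValueError` instead.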

### Embedding Layers: Token and Positional

Transformer models rely on embedding layers to convert discrete input tokens into continuous vector representations, capturing both semantic meaning and sequential order.

#### `TokenEmbedding`

This layer converts numerical token IDs into dense vectors. Each unique token in the vocabulary is mapped to a `d_model`-dimensional vector, allowing the model to process linguistic information. This is achieved using `torch.nn.Embedding`, where `vocab_size` determines the number of unique tokens and `d_model` is the dimensionality of the embedding vectors.

#### `PositionalEmbedding`

Since Transformers process sequences in parallel and lack an inherent understanding of token order, positional embeddings are crucial. This layer provides a learned vector for each position within the input sequence up to `block_size`. These positional vectors are added to the token embeddings, injecting information about the absolute position of each token in the sequence. Like token embeddings, it uses `torch.nn.Embedding` to map position IDs to `d_model`-dimensional vectors.

**Key Concept:** The input to the first Transformer block is the sum of the token embedding and its corresponding positional embedding. This combined representation allows the model to differentiate between identical tokens appearing at different positions.

In [656]:
import torch
import torch.nn as nn

# Assuming ModelConfig is defined in model.config or already imported
# from model.config import ModelConfig # Uncomment if not already imported

class TokenEmbedding(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.d_model
        )

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        token_ids: (B, T)
        returns:   (B, T, D)
        """
        return self.embedding(token_ids)


class PositionalEmbedding(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=config.block_size,
            embedding_dim=config.d_model
        )

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        """
        position_ids: (T,) or (B, T)
        returns:      (T, D) or (B, T, D)
        """
        return self.embedding(position_ids)
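
The sum of the two embeddings can be sketched with plain `nn.Embedding` layers. The sizes and token IDs below are illustrative assumptions rather than values from this project.

```python
import torch
import torch.nn as nn

# Illustrative sizes only (assumptions, not the notebook's settings).
vocab_size, block_size, d_model = 50, 16, 8

token_emb = nn.Embedding(vocab_size, d_model)
pos_emb = nn.Embedding(block_size, d_model)

token_ids = torch.tensor([[5, 7, 5], [1, 2, 3]])   # (B=2, T=3)
position_ids = torch.arange(token_ids.shape[1])     # (T,)

tok = token_emb(token_ids)                          # (B, T, D)
pos = pos_emb(position_ids)                         # (T, D)

# (B, T, D) + (T, D): the positional term broadcasts over the batch.
x = tok + pos

# Token 5 sits at positions 0 and 2 of the first sequence: the token
# embeddings are identical, but the summed inputs differ by position.
same_token_vec = torch.equal(tok[0, 0], tok[0, 2])
same_input_vec = torch.equal(x[0, 0], x[0, 2])
```

Passing a 1-D `position_ids` tensor relies on broadcasting; a `(B, T)` tensor of positions would work as well.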

### Scaled Dot-Product Attention

Scaled Dot-Product Attention is a fundamental component of the Transformer architecture, designed to efficiently compute attention weights. It takes three inputs: a query matrix (Q), a key matrix (K), and a value matrix (V). The core idea is to calculate a similarity score between the queries and keys, scale these scores, and then use them to weigh the values.

**Description:**

1.  **Similarity Calculation:** The attention scores are computed by taking the dot product of the query and key matrices. This measures how relevant each key is to each query.
2.  **Scaling:** The scores are then divided by the square root of the dimension of the keys (`d_k`). This scaling factor is crucial for preventing the dot products from becoming too large, especially with high `d_k` values, which can push the softmax function into regions with extremely small gradients, hindering training.
3.  **Masking (Optional):** If a mask is provided, typically for causality (to prevent attention to future tokens in sequence generation) or padding (to ignore non-existent tokens), the scores at masked positions are replaced with a large negative value (here, `-inf`). After the softmax operation, these positions receive an attention weight of exactly zero when `-inf` is used.
4.  **Softmax:** A softmax function is applied to the scaled scores to obtain attention weights. This normalizes the scores such that they sum to 1, representing a probability distribution over the values.
5.  **Weighted Sum:** Finally, these attention weights are multiplied by the value matrix (V). This creates a weighted sum of the values, where the weight assigned to each value is determined by its relevance to the query.

**Mathematical Formula:**

The Scaled Dot-Product Attention mechanism is mathematically expressed as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
-   $Q$ is the Query matrix.
-   $K$ is the Key matrix.
-   $V$ is the Value matrix.
-   $d_k$ is the dimension of the key vectors.
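
The role of the $\sqrt{d_k}$ factor can be checked numerically. The sketch below assumes query and key components drawn independently from a standard normal distribution, in which case the dot product has variance $d_k$. The sample count and `d_k = 512` are arbitrary choices.

```python
import math
import torch

torch.manual_seed(0)
d_k, n_samples = 512, 4096

q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

raw = (q * k).sum(dim=-1)         # unscaled dot products, std close to sqrt(d_k)
scaled = raw / math.sqrt(d_k)     # scaled dot products, std close to 1

raw_std = raw.std().item()
scaled_std = scaled.std().item()

# Softmax over the same 8 scores, with and without scaling: the unscaled
# version is at least as peaked, which is what shrinks the gradients.
raw_peak = torch.softmax(raw[:8], dim=-1).max().item()
scaled_peak = torch.softmax(scaled[:8], dim=-1).max().item()
```

Real activations are not exactly standard normal, so this only illustrates the trend the scaling is meant to counter.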

In [657]:
import torch
import torch.nn as nn
import math


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.scale = 1.0 / math.sqrt(d_model)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        q, k, v: (B, T, D)
        mask:    (T, T) or (B, T, T)
        return:  (B, T, D)
        """

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(weights, v)

        return output
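
As a quick sanity check of the masking step, here is a functional sketch of the same computation. The helper `attention_with_weights` is hypothetical: unlike the module above, it also returns the attention weights so they can be inspected.

```python
import math
import torch


def attention_with_weights(q, k, v, mask=None):
    """Hypothetical helper: scaled dot-product attention that also returns weights."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
B, T, D = 2, 4, 8
q, k, v = torch.randn(B, T, D), torch.randn(B, T, D), torch.randn(B, T, D)
causal_mask = torch.tril(torch.ones(T, T))   # 1 = may attend, 0 = blocked

out, weights = attention_with_weights(q, k, v, causal_mask)

# Every row is a probability distribution over the allowed keys.
rows_sum_to_one = torch.allclose(weights.sum(dim=-1), torch.ones(B, T))
# Entries above the diagonal (future positions) receive zero weight.
future_weight = torch.triu(weights, diagonal=1).abs().max().item()
# The first position can only attend to itself, so its output is v[:, 0].
first_is_own_value = torch.allclose(out[:, 0], v[:, 0])
```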


### Causal Self-Attention

`CausalSelfAttention` is a crucial component in transformer-based autoregressive models, such as GPT (Generative Pre-trained Transformer). It extends the concept of `ScaledDotProductAttention` by ensuring that during sequence generation, each token can only attend to previous tokens and itself, not future tokens. This is vital for tasks like language modeling where predicting the next word depends only on the words that have already occurred.

**Description:**

1.  **Initialization (`__init__`)**:
    *   It takes a `ModelConfig` object, which defines parameters like the number of attention heads (`n_heads`), the dimensionality of each head (`head_dim`), and the model's total dimension (`d_model`).
    *   `self.scale`: A scaling factor `1 / sqrt(head_dim)` is calculated, which is standard for scaled dot-product attention to prevent large dot products from pushing the softmax into regions with tiny gradients.
    *   `self.qkv_proj`: A linear projection layer that transforms the input `x` (with shape `(B, T, D)`) into Query (Q), Key (K), and Value (V) matrices. It outputs `3 * d_model` dimensions, which are then split into `d_model` for Q, K, and V respectively.
    *   `self.out_proj`: Another linear projection layer that takes the concatenated output from all attention heads and projects it back to the original `d_model` dimension.
    *   `self.causal_mask`: A lower triangular matrix (e.g., `[[1,0,0],[1,1,0],[1,1,1]]`) is created. This mask is used to block attention to future tokens. It's registered as a buffer, meaning it's part of the model's state but not a trainable parameter.

2.  **Forward Pass (`forward`)**:
    *   **Input**: `x` with shape `(B, T, D)`, where `B` is batch size, `T` is sequence length, and `D` is `d_model`.
    *   **QKV Projections**: The input `x` is passed through `self.qkv_proj` to get a combined `qkv` tensor. This `qkv` tensor is then split into `q`, `k`, and `v` tensors, each of shape `(B, T, D)`.
    *   **Multi-Head Reshaping**: Each `q`, `k`, and `v` tensor is reshaped to `(B, n_heads, T, head_dim)`. This involves splitting the `d_model` dimension into `n_heads` separate heads, each with `head_dim` dimensions. The `transpose(1, 2)` operation rearranges the dimensions to put the heads dimension before the sequence length dimension, which is standard for multi-head attention computations.
    *   **Attention Scores Calculation**: The core attention mechanism is computed:
        $$\text{scores} = (Q K^T) / \sqrt{d_k}$$
        Here, `q` and `k` (reshaped `(B, n_heads, T, head_dim)`) are multiplied (`torch.matmul`) to get the similarity scores. `k.transpose(-2, -1)` transposes the last two dimensions of `k`, effectively performing $K^T$. The result is then scaled by `self.scale` (`1 / sqrt(head_dim)`).
        The shape of `scores` is `(B, n_heads, T, T)`.
    *   **Causal Masking**: The `causal_mask` (a lower triangular matrix) is applied. For each position `i` in the sequence, the mask ensures that the attention scores for positions `j > i` (future tokens) are set to negative infinity. This means that after the softmax, these future positions will have an attention weight of zero, effectively preventing a token from attending to future tokens.
        `scores = scores.masked_fill(mask == 0, float("-inf"))`
    *   **Softmax**: A `softmax` function is applied to the scores along the last dimension (`dim=-1`) to obtain attention `weights`. This normalizes the scores so they sum to 1 for each query, representing a probability distribution over the values.
        `weights = torch.softmax(scores, dim=-1)`
    *   **Weighted Sum of Values**: The attention `weights` are then multiplied by the `v` (value) tensor (`torch.matmul`). This produces the weighted sum of values, where tokens with higher attention weights contribute more to the output.
        `out = torch.matmul(weights, v)`
        The shape of `out` is `(B, n_heads, T, head_dim)`.
    *   **Merge Heads**: The `out` tensor is reshaped back to `(B, T, D)`. This involves transposing the dimensions back and then concatenating the outputs from all heads (`contiguous().view(B, T, D)`).
    *   **Output Projection**: Finally, the merged output is passed through `self.out_proj` to produce the final output of the self-attention layer. This projection allows the model to learn a linear transformation on the combined information from all attention heads.
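
The head split and merge described in the steps above are pure reshapes, which a small round-trip sketch can confirm. The sizes are illustrative assumptions.

```python
import torch

B, T, n_heads, head_dim = 2, 5, 4, 8
D = n_heads * head_dim

# Distinct values make it easy to see where each element ends up.
x = torch.arange(B * T * D, dtype=torch.float32).view(B, T, D)

# Split: (B, T, D) -> (B, T, H, Dh) -> (B, H, T, Dh)
heads = x.view(B, T, n_heads, head_dim).transpose(1, 2)

# Merge: (B, H, T, Dh) -> (B, T, H, Dh) -> (B, T, D)
merged = heads.transpose(1, 2).contiguous().view(B, T, D)

# Head h holds the feature slice [h * head_dim, (h + 1) * head_dim) of every token.
head_1_matches_slice = torch.equal(heads[:, 1], x[:, :, head_dim:2 * head_dim])
round_trip_ok = torch.equal(merged, x)
```

The `.contiguous()` call is needed because `transpose` returns a non-contiguous view, and `view` requires a compatible memory layout.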

**Mathematical Intuition:**

The causal self-attention mechanism fundamentally implements the following operation for each head:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}} + M\right) V $$

Where:
*   $Q$, $K$, $V$ are the Query, Key, and Value matrices for a single head.
*   $d_k$ is `head_dim`, the dimensionality of the key vectors.
*   $M$ is the causal mask, where $M_{ij} = 0$ if $i \ge j$ (past and current tokens) and $M_{ij} = -\infty$ if $i < j$ (future tokens). This effectively makes the attention weights to future tokens zero.

The multi-head aspect involves performing this attention operation `n_heads` times in parallel with different linear projections for each head, and then concatenating and linearly projecting their outputs.
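
The additive mask $M$ in the formula and the `masked_fill` call used in the code are two ways of writing the same operation. A minimal sketch with arbitrary random scores:

```python
import torch

torch.manual_seed(0)
T = 4
scores = torch.randn(T, T)                   # stand-in for Q K^T / sqrt(d_k)
tril = torch.tril(torch.ones(T, T))          # 1 on and below the diagonal

# Code-style masking: overwrite future positions with -inf.
via_masked_fill = torch.softmax(
    scores.masked_fill(tril == 0, float("-inf")), dim=-1
)

# Formula-style masking: add M, with M_ij = 0 for i >= j and -inf for i < j.
M = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
via_additive = torch.softmax(scores + M, dim=-1)

same_result = torch.allclose(via_masked_fill, via_additive)
```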

In [658]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

        # causal mask (registered as buffer, not parameter)
        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("causal_mask", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, T, D)
        return: (B, T, D)
        """
        B, T, D = x.shape

        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # reshape for multi-head
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # apply causal mask
        mask = self.causal_mask[:T, :T]
        scores = scores.masked_fill(mask == 0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, v)

        # merge heads
        out = out.transpose(1, 2).contiguous().view(B, T, D)

        return self.out_proj(out)
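
One way to test causality is to perturb the last token and confirm that earlier outputs do not move. The sketch below uses `TinyCausalSelfAttention`, a hypothetical compact stand-in that performs the same computation as `CausalSelfAttention` but takes plain integers instead of a `ModelConfig`.

```python
import math
import torch
import torch.nn as nn


class TinyCausalSelfAttention(nn.Module):
    """Hypothetical compact stand-in for CausalSelfAttention."""

    def __init__(self, d_model, n_heads, block_size):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, d_model // n_heads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("causal_mask", mask)

    def forward(self, x):
        B, T, D = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q, k, v = (
            t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(self.causal_mask[:T, :T] == 0, float("-inf"))
        out = torch.softmax(scores, dim=-1) @ v
        return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, D))


torch.manual_seed(0)
attn = TinyCausalSelfAttention(d_model=16, n_heads=4, block_size=8).eval()

x = torch.randn(1, 6, 16)
x_perturbed = x.clone()
x_perturbed[:, -1] += 1.0            # change only the last token

with torch.no_grad():
    y, y_perturbed = attn(x), attn(x_perturbed)

# Positions 0..T-2 never see the last token, so their outputs are unchanged.
earlier_unchanged = torch.allclose(y[:, :-1], y_perturbed[:, :-1], atol=1e-6)
last_changed = not torch.allclose(y[:, -1], y_perturbed[:, -1], atol=1e-6)
```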


### KVCache and CachedCausalSelfAttention: Optimizing Inference with KV Caching

To enhance the efficiency of autoregressive decoding in Transformer models, especially during inference, Key-Value (KV) caching is employed. This technique avoids redundant re-computation of keys and values for previously processed tokens, significantly speeding up generation. The `KVCache` class manages this storage, and `CachedCausalSelfAttention` utilizes it for incremental token processing.

#### 1. `KVCache` Class

The `KVCache` class is a simple container designed to store the keys (K) and values (V) computed during the self-attention mechanism across multiple decoding steps. This cache allows subsequent tokens to attend to the full historical context without re-calculating the K and V matrices for past tokens.

**Operational Breakdown:**

*   **`__init__(self)`**: Initializes an empty cache by setting `self.keys` and `self.values` to `None`. This state indicates that no keys or values have been stored yet.

*   **`append(self, k_new: torch.Tensor, v_new: torch.Tensor)`**: This method adds new key and value tensors to the cache. It expects `k_new` and `v_new` to have the shape `(B, H, 1, Dh)`, representing the keys and values for the current token across batches and attention heads.
    *   If the cache is empty (`self.keys` is `None`), the new keys and values become the initial content of the cache.
    *   If the cache already contains data, `k_new` and `v_new` are concatenated with the existing `self.keys` and `self.values` along the sequence length dimension (dimension 2). This effectively appends the current token's K and V to the historical sequence.

*   **`reset(self)`**: Clears the cache by setting `self.keys` and `self.values` back to `None`. This is typically used to prepare the cache for a new generation sequence.
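
Because each cache stores keys and values of shape `(B, H, T, Dh)`, its memory footprint grows linearly with the sequence length. The arithmetic can be sketched as follows, assuming one cache per layer and `float32` storage. The helper name `kv_cache_bytes` and the example sizes are hypothetical.

```python
def kv_cache_bytes(n_layers, batch_size, n_heads, head_dim, seq_len, bytes_per_element=4):
    """Hypothetical helper: total bytes held by per-layer KV caches.

    Assumes one cache per layer, each storing keys and values of shape
    (batch_size, n_heads, seq_len, head_dim); 4 bytes per element = float32.
    """
    elements_per_layer = 2 * batch_size * n_heads * seq_len * head_dim  # keys + values
    return n_layers * elements_per_layer * bytes_per_element


# Illustrative sizes (assumptions): 12 layers, 12 heads of width 64, 1024 tokens.
size = kv_cache_bytes(n_layers=12, batch_size=1, n_heads=12, head_dim=64, seq_len=1024)
size_mib = size / 2**20   # 72 MiB for these sizes

# Doubling the sequence length doubles the footprint.
size_double = kv_cache_bytes(n_layers=12, batch_size=1, n_heads=12, head_dim=64, seq_len=2048)
```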

#### 2. `CachedCausalSelfAttention` Class

The `CachedCausalSelfAttention` module is a specialized version of the `CausalSelfAttention` designed for efficient token-by-token generation (inference) by leveraging the `KVCache`.

**Key Differences from `CausalSelfAttention`:**
*   It processes input `x_t` with a sequence length `T=1` (a single token at a time).
*   It takes a `kv_cache` object as an argument to store and retrieve past keys and values.
*   It *implicitly* handles causality by only attending to the `k_full` and `v_full` retrieved from the cache, which by its nature only contains past and current tokens.

**Operational Breakdown (`forward` method):**

*   **Input**: `x_t` with shape `(B, 1, D)` (a single token per batch) and `kv_cache`.

*   **Assertion**: `assert T == 1, "Cached attention expects exactly one token"` ensures that the module is used for incremental decoding.

*   **QKV Projections**: Similar to `CausalSelfAttention`, `x_t` is projected into query `q`, key `k`, and value `v` tensors for the *current* token.

*   **Multi-Head Reshaping**: `q`, `k`, and `v` are reshaped to `(B, n_heads, 1, head_dim)` to prepare for multi-head attention.

*   **Cache Append**: The newly computed `k` and `v` for the current token are appended to the `kv_cache` using `kv_cache.append(k, v)`. The cache now holds the keys and values for *all* tokens processed so far in the current sequence.

*   **Retrieve Full Cache**: The complete historical `keys` (`k_full`) and `values` (`v_full`) are retrieved from the `kv_cache`. These tensors will have shape `(B, H, T_total, Dh)`, where `T_total` is the current length of the generated sequence.

*   **Attention Score Calculation**: The query `q` (current token's query) is used to compute attention scores against `k_full` (all past and current keys). This ensures that the current token attends to the entire context generated so far.
    $$\text{scores} = (Q_{\text{current}} K_{\text{full}}^T) / \sqrt{d_k}$$
    The `scores` tensor will have shape `(B, n_heads, 1, T_total)`.

*   **Softmax and Weighted Sum**: `softmax` is applied to the scores, and the resulting attention weights are multiplied by `v_full` to produce the output `out` for the current token. This `out` tensor effectively summarizes the information from `v_full` relevant to the current `q`.

*   **Merge Heads and Output Projection**: The `out` tensor is reshaped back to `(B, 1, D)` and then passed through `self.out_proj` to yield the final output for the current token.

**Mathematical Intuition for Cached Self-Attention:**

The core attention computation within `CachedCausalSelfAttention` can be viewed as:

$$ \text{Attention}(\text{token}_t) = \text{softmax}\left(\frac{Q_t \cdot K_{\le t}^T}{\sqrt{d_k}}\right) \cdot V_{\le t} $$

Where:
*   $Q_t$ is the Query vector for the current token at position $t$.
*   $K_{\le t}$ represents the concatenated Key matrix containing keys for all tokens from position $1$ up to $t$ (retrieved from `kv_cache.keys`).
*   $V_{\le t}$ represents the concatenated Value matrix containing values for all tokens from position $1$ up to $t$ (retrieved from `kv_cache.values`).
*   $d_k$ is the `head_dim`, the dimensionality of the key vectors.

In this formulation, the causal masking that explicitly masks future tokens in `CausalSelfAttention` is implicitly handled. By only storing and using keys and values from tokens up to the current position ($K_{\le t}$ and $V_{\le t}$), the model naturally prevents attending to future information. The KV cache makes this process highly efficient by avoiding re-computation of $K_{\le t}$ and $V_{\le t}$ at each step; instead, it simply appends the new $K_t$ and $V_t$ to the existing cache.
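
The savings can be estimated by counting work per decoding step. The sketch below assumes the uncached baseline re-runs full causal attention over the whole prefix at every step. The function `decoding_work` is a hypothetical counting model for a single head in a single layer, not a measurement.

```python
def decoding_work(total_tokens: int, use_cache: bool) -> dict:
    """Hypothetical counting model for one attention head in one layer."""
    kv_projections = 0   # number of per-token K/V projections computed
    score_entries = 0    # number of query-key dot products computed
    for t in range(1, total_tokens + 1):
        if use_cache:
            kv_projections += 1      # only the new token is projected
            score_entries += t       # one query against t cached keys
        else:
            kv_projections += t      # the whole prefix is re-projected
            score_entries += t * t   # a full t x t score matrix is rebuilt
    return {"kv_projections": kv_projections, "score_entries": score_entries}


T = 128
uncached = decoding_work(T, use_cache=False)
cached = decoding_work(T, use_cache=True)
projection_ratio = uncached["kv_projections"] / cached["kv_projections"]
```

Under these assumptions the cached path does linear rather than quadratic projection work over `T` generated tokens, and quadratic rather than cubic score work.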

In [659]:
class KVCache:
    def __init__(self):
        self.keys = None
        self.values = None

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor):
        """
        k_new, v_new: (B, H, 1, Dh)
        """
        if self.keys is None:
            self.keys = k_new
            self.values = v_new
        else:
            self.keys = torch.cat([self.keys, k_new], dim=2)
            self.values = torch.cat([self.values, v_new], dim=2)

    def reset(self):
        self.keys = None
        self.values = None
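
Appending with `torch.cat` allocates a new tensor and copies the whole history at every step. One possible alternative, shown here only as a sketch and not part of this project, is to preallocate buffers up to a maximum length and write into them in place. The class `PreallocatedKVCache` and its constructor arguments are hypothetical.

```python
import torch


class PreallocatedKVCache:
    """Hypothetical variant of KVCache using fixed-size buffers written in place."""

    def __init__(self, batch, n_heads, max_len, head_dim, dtype=torch.float32):
        self.keys_buf = torch.zeros(batch, n_heads, max_len, head_dim, dtype=dtype)
        self.values_buf = torch.zeros_like(self.keys_buf)
        self.length = 0   # number of valid positions currently stored

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor):
        """k_new, v_new: (B, H, t, Dh), usually with t = 1."""
        end = self.length + k_new.shape[2]
        assert end <= self.keys_buf.shape[2], "cache is full"
        self.keys_buf[:, :, self.length:end] = k_new
        self.values_buf[:, :, self.length:end] = v_new
        self.length = end

    @property
    def keys(self):
        return self.keys_buf[:, :, :self.length]     # view, no copy

    @property
    def values(self):
        return self.values_buf[:, :, :self.length]   # view, no copy

    def reset(self):
        self.length = 0


torch.manual_seed(0)
cache = PreallocatedKVCache(batch=2, n_heads=4, max_len=8, head_dim=16)
chunks = [(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16)) for _ in range(3)]
for k_new, v_new in chunks:
    cache.append(k_new, v_new)

# Reference result built the same way as the torch.cat-based KVCache.
expected_keys = torch.cat([k for k, _ in chunks], dim=2)
expected_values = torch.cat([v for _, v in chunks], dim=2)
```

The trade-off is that memory for `max_len` positions is reserved up front, whether or not generation reaches that length.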


In [660]:
class CachedCausalSelfAttention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

    def forward(self, x_t: torch.Tensor, kv_cache) -> torch.Tensor:
        """
        x_t: (B, 1, D)
        kv_cache: KVCache
        return: (B, 1, D)
        """
        B, T, D = x_t.shape
        assert T == 1, "Cached attention expects exactly one token"

        qkv = self.qkv_proj(x_t)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, 1, self.n_heads, self.head_dim).transpose(1, 2)

        # append new K, V to cache
        kv_cache.append(k, v)

        # retrieve full cached K, V
        k_full = kv_cache.keys     # (B, H, T_total, Dh)
        v_full = kv_cache.values

        scores = torch.matmul(q, k_full.transpose(-2, -1)) * self.scale
        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, v_full)

        out = out.transpose(1, 2).contiguous().view(B, 1, D)
        return self.out_proj(out)


In [661]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)
        self.proj = nn.Linear(config.d_model, config.d_model)

        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("mask", mask)

    def forward(self, x):
        B, T, D = x.shape

        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, -1, self.head_dim).transpose(1, 2)
        k = k.view(B, T, -1, self.head_dim).transpose(1, 2)
        v = v.view(B, T, -1, self.head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * self.scale
        att = att.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        att = torch.softmax(att, dim=-1)

        out = att @ v
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.proj(out)


In [662]:
class CachedAttentionWrapper(nn.Module):
    def __init__(self, attn):
        super().__init__()
        # The wrapped module must expose `qkv_proj`, `out_proj`, `n_heads`,
        # `head_dim`, and `scale`; its weights are reused, not copied.
        self.attn = attn   # shared weights

    def forward(self, x_t, kv_cache):
        """
        x_t: (B, 1, D)
        kv_cache: KVCache
        return: (B, 1, D)
        """
        B, T, D = x_t.shape
        assert T == 1, "Cached attention expects exactly one token"

        attn = self.attn

        # project the current token with the wrapped module's weights
        qkv = attn.qkv_proj(x_t)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, 1, attn.n_heads, attn.head_dim).transpose(1, 2)
        k = k.view(B, 1, attn.n_heads, attn.head_dim).transpose(1, 2)
        v = v.view(B, 1, attn.n_heads, attn.head_dim).transpose(1, 2)

        # append new K, V to cache
        kv_cache.append(k, v)

        # retrieve full cached K, V
        k_full = kv_cache.keys     # (B, H, T_total, Dh)
        v_full = kv_cache.values

        scores = torch.matmul(q, k_full.transpose(-2, -1)) * attn.scale
        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, v_full)

        out = out.transpose(1, 2).contiguous().view(B, 1, D)
        return attn.out_proj(out)


In [663]:
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim ** -0.5

        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("causal_mask", mask)

    def forward_full(self, x):
        B, T, D = x.shape
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, -1, self.head_dim).transpose(1, 2)
        k = k.view(B, T, -1, self.head_dim).transpose(1, 2)
        v = v.view(B, T, -1, self.head_dim).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        mask = self.causal_mask[:T, :T]
        scores = scores.masked_fill(mask == 0, float("-inf"))

        out = (scores.softmax(-1) @ v)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.out_proj(out)

    def forward_cached(self, x, kv_cache):
        B, _, D = x.shape
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, 1, -1, self.head_dim).transpose(1, 2)
        k = k.view(B, 1, -1, self.head_dim).transpose(1, 2)
        v = v.view(B, 1, -1, self.head_dim).transpose(1, 2)

        kv_cache.append(k, v)

        k_full = kv_cache.keys
        v_full = kv_cache.values

        scores = (q @ k_full.transpose(-2, -1)) * self.scale
        out = (scores.softmax(-1) @ v_full)

        out = out.transpose(1, 2).contiguous().view(B, 1, D)
        return self.out_proj(out)


### FeedForward Layer

In the Transformer architecture, the FeedForward layer (also known as the Position-wise Feed-Forward Network or FFN) is applied independently to each position in the sequence. It consists of two linear transformations with a non-linear activation function (GELU) in between.

**Description:**

1.  **First Linear Layer (`self.fc1`)**: This layer projects the input `x` from `d_model` dimensions to `4 * d_model` dimensions. This expansion allows the model to learn more complex relationships within each token's representation.
2.  **Activation Function (`self.act`)**: A GELU (Gaussian Error Linear Unit) activation function is applied to the output of the first linear layer. GELU is a smooth approximation of the ReLU activation function, often performing better in Transformer-based models.
3.  **Second Linear Layer (`self.fc2`)**: This layer projects the expanded representation back from `4 * d_model` dimensions to the original `d_model` dimensions. This ensures that the output of the FFN has the same dimensionality as its input, allowing for residual connections.

This layer processes each position identically but independently, meaning the same weights are used for all positions, but each position gets its own distinct computation.

**Mathematical Formula:**

The FeedForward layer can be mathematically expressed as:

$$\text{FFN}(x) = \text{GELU}(xW_1 + b_1)W_2 + b_2$$

Where:
*   $x$ is the input to the FeedForward network, typically the output of the self-attention sub-layer.
*   $W_1$ and $b_1$ are the weights and biases of the first linear transformation (from `d_model` to `4 * d_model`).
*   $W_2$ and $b_2$ are the weights and biases of the second linear transformation (from `4 * d_model` to `d_model`).
*   $\text{GELU}$ is the Gaussian Error Linear Unit activation function.

In [664]:
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, 4 * config.d_model)
        self.fc2 = nn.Linear(4 * config.d_model, config.d_model)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))
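
The position-wise property can be verified directly: applying the network to a whole sequence should match applying it to each position separately. The sketch below assumes an equivalent `nn.Sequential` stack in place of the `FeedForward` class and an arbitrary `d_model` of 16.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16   # illustrative size (assumption)

# Same structure as FeedForward: Linear -> GELU -> Linear with a 4x expansion.
ffn = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),
    nn.GELU(),
    nn.Linear(4 * d_model, d_model),
).eval()

x = torch.randn(2, 5, d_model)   # (B, T, D)
with torch.no_grad():
    full = ffn(x)
    # Run each position on its own, then stitch the results back together.
    per_position = torch.cat(
        [ffn(x[:, t:t + 1]) for t in range(x.shape[1])], dim=1
    )

max_diff = (full - per_position).abs().max().item()

# Parameter count: two weight matrices (4 * D * D each) plus biases (4 * D and D).
n_params = sum(p.numel() for p in ffn.parameters())
```

This independence is what lets the cached path feed a single `(B, 1, D)` token through the FFN without any stored state.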

### TransformerBlock

The `TransformerBlock` is the core building block of the model. This implementation is a decoder-style block: its attention sub-layer is causal, and it offers two paths, `forward_full` for a whole sequence of shape `(B, T, D)` and `forward_cached` for incremental inference on a single token of shape `(B, 1, D)` with a KV cache. The description below follows the cached path, whose attention computation matches `CachedCausalSelfAttention`.

Each `TransformerBlock` consists of two main sub-layers, each preceded by layer normalization and wrapped in a residual connection:

1.  **Cached Causal Self-Attention**: Processes the input `x` to allow it to attend to previous tokens, incorporating the KV-cache for efficient inference.
2.  **Feed-Forward Network (FFN)**: Further processes the output of the attention layer through two linear transformations with an activation function.

**Description:**

*   **`self.ln1` (Layer Normalization)**: Applied before the attention sub-layer. Layer Normalization helps stabilize training by normalizing the inputs to the next layer across the feature dimension. It ensures that the mean and variance of the inputs are consistent.
*   **`self.attn` (`Attention`, cached path)**: This is the self-attention mechanism, adapted for efficient autoregressive decoding. Its `forward_cached` method performs the same computation as `CachedCausalSelfAttention`: it takes the layer-normalized input `self.ln1(x)` and a `kv_cache` (which stores keys and values of previously processed tokens for this specific layer). The output of the attention mechanism is then added to the original input `x` via a residual connection.
*   **`self.ln2` (Layer Normalization)**: Applied before the Feed-Forward Network. Similar to `self.ln1`, it normalizes the input to the FFN.
*   **`self.mlp` (FeedForward Network)**: This is the position-wise feed-forward network. It takes the layer-normalized output of the attention sub-layer `self.ln2(x)` and processes it. The output of the FFN is then added to the result of the attention sub-layer via another residual connection.

**Forward Pass Logic (`forward_cached` method):**

The `forward_cached` method implements the following sequence of operations (`forward_full` follows the same two steps without a cache):
1.  **Attention Sub-layer**: The input `x` is first normalized by `self.ln1`. This normalized input is then passed to `self.attn.forward_cached` along with the `kv_cache` for the current layer. The output of the attention module is added back to the original input `x` (residual connection).
    *   `x = x + self.attn.forward_cached(self.ln1(x), kv_cache)`
2.  **Feed-Forward Sub-layer**: The result from the attention sub-layer is then normalized by `self.ln2`. This normalized result is passed to the `self.mlp` module. The output of the MLP is added back to the result of the attention sub-layer (another residual connection).
    *   `x = x + self.mlp(self.ln2(x))`
3.  **Output**: The final result `x` is the output of the Transformer block.

**Mathematical Formula:**

Let $x$ be the input to the Transformer block, and $x_{\text{cache}}$ represent the `kv_cache` for the current layer.

1.  **Layer Normalization 1**: $x' = \text{LayerNorm}_1(x)$
2.  **Cached Causal Self-Attention**: $x'' = \text{CachedCausalSelfAttention}(x', x_{\text{cache}})$
3.  **Residual Connection 1**: $x_{\text{attn}} = x + x''$
4.  **Layer Normalization 2**: $x''' = \text{LayerNorm}_2(x_{\text{attn}})$
5.  **Feed-Forward Network**: $x'''' = \text{FeedForward}(x''')$
6.  **Residual Connection 2**: $x_{\text{output}} = x_{\text{attn}} + x''''$

Thus, the entire block's operation can be summarized as:

$$ \text{TransformerBlock}(x, x_{\text{cache}}) = x_{\text{attn}} + \text{FeedForward}(\text{LayerNorm}_2(x_{\text{attn}})), \quad x_{\text{attn}} = x + \text{CachedCausalSelfAttention}(\text{LayerNorm}_1(x), x_{\text{cache}}) $$

This structure, often referred to as "Pre-Normalization" or "Pre-LN" Transformer, applies layer normalization before the self-attention and FFN sub-layers.

In [665]:
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.ln2 = nn.LayerNorm(config.d_model)

        self.attn = Attention(config)
        self.mlp = FeedForward(config)

    def forward_full(self, x):
        x = x + self.attn.forward_full(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

    def forward_cached(self, x, kv_cache):
        x = x + self.attn.forward_cached(self.ln1(x), kv_cache)
        x = x + self.mlp(self.ln2(x))
        return x


In [666]:
class ReferenceTransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.d_model)
        self.mlp = FeedForward(config)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


### MiniGPTInferenceModel: The Autoregressive Inference Engine

The `MiniGPTInferenceModel` class encapsulates the entire Transformer architecture, specifically designed for efficient autoregressive inference (token-by-token generation). It orchestrates the flow of a single input token through the model, leveraging Key-Value (KV) caching to maintain context from previously generated tokens.

**Description:**

1.  **Initialization (`__init__`)**:
    *   **`self.token_emb`**: An `nn.Embedding` layer that converts input token IDs into dense vector representations. Its size is `(vocab_size, d_model)`.
    *   **`self.pos_emb`**: An `nn.Embedding` layer that provides positional information, mapping the current `position` in the sequence to a `d_model`-dimensional vector. Its size is `(block_size, d_model)`.
    *   **`self.blocks`**: A `nn.ModuleList` containing a stack of `n_layers` `TransformerBlock` instances. Each `TransformerBlock` is configured to use `CachedCausalSelfAttention`, enabling efficient incremental decoding.
    *   **`self.ln_f`**: A final `nn.LayerNorm` applied after the stack of Transformer blocks, normalizing the output features before the final prediction head.
    *   **`self.lm_head`**: A linear layer (`nn.Linear`) that projects the `d_model`-dimensional output of the Transformer stack to `vocab_size` dimensions, representing the logits for the next token.
    *   **`self.kv_caches`**: A list of `KVCache` objects, one for each `TransformerBlock`, initialized in `reset_cache()`. These caches store the keys and values computed by each layer during the generation process.
    *   **`self.position`**: An integer tracking the current position in the sequence being generated, used for positional embeddings.

2.  **`reset_cache(self)`**:
    *   This method initializes or clears the KV caches for all Transformer layers and resets the `self.position` counter to 0. It should be called at the beginning of a new generation sequence.

3.  **`forward_step(self, token_ids: torch.Tensor)`**:
    *   This method performs a single forward pass for one token (`T=1`), generating the logits for the next token. It's decorated with `@torch.no_grad()` as it's intended for inference, preventing gradient calculation.
    *   **Input**: `token_ids` is a tensor of shape `(B, 1)`, where `B` is the batch size and `1` indicates a single token.
    *   **Positional Embedding**: The current `self.position` is used to retrieve the appropriate positional embedding. This position is incremented after each `forward_step`.
    *   **Input Embedding**: The input `token_ids` are combined with their corresponding positional embeddings: `x = token_emb(token_ids) + pos_emb(position_ids)`.
    *   **Transformer Blocks**: The embedded input `x` is sequentially passed through each `TransformerBlock`. Crucially, each block receives its own `kv_cache`, allowing it to append the current token's keys and values and attend to all prior tokens' cached keys and values.
    *   **Final Layer Norm and Head**: After passing through all blocks, the output `x` is normalized by `self.ln_f` and then projected through `self.lm_head` to produce `logits` of shape `(B, 1, vocab_size)`.
    *   **Output**: The method returns `logits.squeeze(1)`, of shape `(B, vocab_size)`: the unnormalized scores (pre-softmax) for the next token across the vocabulary.

**Mathematical Intuition for `forward_step`:**

Let $t$ be the current token position in the sequence. The `forward_step` processes a single token at a time:

1.  **Input Embedding**: The input token $w_t$ is first converted into its vector representation, combining token and positional information:
    $$ \text{embedding}_t = \text{TokenEmbed}(w_t) + \text{PosEmbed}(t) $$

2.  **Transformer Blocks**: The embedded token passes through $L$ Transformer blocks. For each block $l$ from $1$ to $L$:
    $$ h_t^{(l)} = \text{TransformerBlock}^{(l)}(h_t^{(l-1)}, \text{KVCache}^{(l)}) $$
    where $h_t^{(0)} = \text{embedding}_t$. Each `TransformerBlock` computes its self-attention using $Q_t^{(l)}$ (query for current token) and $K_{\text{cached}}^{(l)}$, $V_{\text{cached}}^{(l)}$ (cached keys and values for all tokens up to $t$).

3.  **Final Layer Normalization**: After all blocks, the output is normalized:
    $$ h_t^{\text{final}} = \text{LayerNorm}_{\text{final}}(h_t^{(L)}) $$

4.  **Language Model Head**: The final representation is projected to vocabulary size to obtain logits for the next token:
    $$ \text{logits}_{t+1} = \text{LMHead}(h_t^{\text{final}}) $$
    These logits represent the unnormalized probabilities for each token in the vocabulary to be the next token in the sequence.

In [667]:
class MiniGPTInferenceModel(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.block_size, config.d_model)

        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.reset_cache()
        self.reference_blocks = nn.ModuleList([
            ReferenceTransformerBlock(config) for _ in range(config.n_layers)
        ])


    def reset_cache(self):
        self.kv_caches = [KVCache() for _ in range(self.config.n_layers)]
        self.position = 0

    @torch.no_grad()
    def forward_step(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        token_ids: (B, 1)
        returns logits: (B, vocab_size)
        """
        B, T = token_ids.shape
        assert T == 1, "Inference model expects exactly one token at a time"

        pos = torch.full((B, 1), self.position, device=token_ids.device)

        x = self.token_emb(token_ids) + self.pos_emb(pos)

        for block, cache in zip(self.blocks, self.kv_caches):
            x = block.forward_cached(x, cache)


        x = self.ln_f(x)
        logits = self.lm_head(x)

        self.position += 1
        return logits.squeeze(1)

    # Helper used only by the cached vs. non-cached equivalence test
    @torch.no_grad()
    def forward_full(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        input_ids: (B, T)
        returns logits: (B, T, vocab_size)
        """
        B, T = input_ids.shape
        positions = torch.arange(T, device=input_ids.device).unsqueeze(0)

        x = self.token_emb(input_ids) + self.pos_emb(positions)

        for block in self.blocks:
            x = block.forward_full(x)


        x = self.ln_f(x)
        return self.lm_head(x)

### `generate` Function: Autoregressive Text Generation

The `generate` function is responsible for orchestrating the autoregressive generation of new tokens given an initial prompt (`input_ids`). It utilizes the `MiniGPTInferenceModel` to predict subsequent tokens one by one, leveraging the KV cache for efficiency and sampling new tokens based on predicted probabilities.

**Description:**

1.  **Initialization (`model.reset_cache()`)**:
    *   Before starting a new generation sequence, the `model.reset_cache()` method is called. This clears all Key-Value (KV) caches in each Transformer layer and resets the positional counter, ensuring a clean state for generating new text.

2.  **Feeding Prompt Tokens (`for t in range(input_ids.size(1)):`)**:
    *   The `input_ids` (the initial prompt) are fed into the model one token at a time. For each token in the prompt, `model.forward_step()` is called. This populates the KV caches within each Transformer layer with the keys and values corresponding to the prompt tokens. Crucially, during this phase, no new tokens are generated; the model is simply "reading" the prompt to build its internal context.

3.  **Starting Generation (`token = input_ids[:, -1:]`)**:
    *   After processing the prompt, the last token of the prompt is selected as the starting point for actual generation. This `token` will be the first input to `model.forward_step()` during the iterative generation loop.
    *   An `outputs` list is initialized with this starting token to collect all generated tokens.

4.  **Iterative Token Generation (`for _ in range(max_new_tokens):`)**:
    *   The core generation loop runs for `max_new_tokens` iterations, where `max_new_tokens` specifies how many new tokens to generate.
    *   **Forward Step**: `logits = model.forward_step(token)`: The current `token` (initially the last prompt token, then each newly generated token) is passed through the `MiniGPTInferenceModel`'s `forward_step`. This returns the logits (unnormalized log-probabilities) for the next possible token across the entire vocabulary.
    *   **Temperature Scaling**: `logits = logits / temperature`: The `temperature` parameter controls the randomness of the generation. A higher temperature (e.g., > 1.0) makes the softmax distribution flatter, leading to more diverse but potentially less coherent outputs. A lower temperature (e.g., < 1.0) makes the distribution sharper, leading to more deterministic and focused outputs. `temperature = 1.0` means no change.
    *   **Probability Distribution**: `probs = torch.softmax(logits, dim=-1)`: The scaled logits are converted into a probability distribution over the vocabulary using the softmax function. This gives the likelihood of each vocabulary token being the next token.
    *   **Token Sampling**: `token = torch.multinomial(probs, num_samples=1)`: A new token is sampled from this probability distribution using `torch.multinomial`. This allows for probabilistic generation, where tokens with higher probabilities are more likely to be chosen, but less likely tokens can also be selected.
    *   **Append Output**: `outputs.append(token)`: The newly sampled token is appended to the `outputs` list.

5.  **Concatenate Outputs (`return torch.cat(outputs, dim=1)`)**:
    *   After `max_new_tokens` have been generated, all collected `outputs` (including the initial prompt's last token and all subsequent generated tokens) are concatenated along dimension 1 to form the complete generated sequence.

**Mathematical Intuition for `generate`:**

Given an initial prompt $P = (p_1, p_2, \ldots, p_m)$, the function first initializes the model's state:

1.  **Prompt Processing**: The model processes each token $p_i$ of the prompt sequentially:
    $$ \text{model.forward_step}(p_i) \text{ for } i=1 \text{ to } m $$
    This populates the KV caches within the model based on the prompt's context.

2.  **Iterative Generation**: Starting with $w_0 = p_m$ (the last token of the prompt), the function iteratively generates new tokens $w_1, w_2, \ldots, w_N$ for $N = \text{max_new_tokens}$.
    For each step $k$ from $0$ to $N-1$:
    *   **Predict Logits**: The model computes logits for the next token given the current token $w_k$ and the accumulated KV cache from previous tokens:
        $$ L_{k+1} = \text{model.forward_step}(w_k) $$
    *   **Apply Temperature**: Scale the logits by temperature:
        $$ L'_{k+1} = \frac{L_{k+1}}{\text{temperature}} $$
    *   **Compute Probabilities**: Convert scaled logits to probabilities:
        $$ P_{k+1} = \text{softmax}(L'_{k+1}) $$
    *   **Sample Next Token**: Sample the next token $w_{k+1}$ from the probability distribution $P_{k+1}$:
        $$ w_{k+1} \thicksim \text{Multinomial}(P_{k+1}) $$

3.  **Final Sequence**: The returned sequence is the last prompt token followed by the generated tokens, matching the `outputs` list in the code:
    $$ \text{Output} = (p_m, w_1, \ldots, w_N) $$
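
The softmax and sampling steps can be tried on a toy example. The following is a minimal sketch in plain Python, not the notebook's PyTorch code: the logits are made up, and `softmax_with_temperature` and `sample_index` are hypothetical helper names. It draws many samples and compares the empirical frequencies against the softmax probabilities, which is the behavior `torch.multinomial` provides for a single draw per step.

```python
import math
import random
from collections import Counter

def softmax_with_temperature(logits, temperature):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_index(probs, rng):
    """Draw one index from a categorical distribution via the inverse CDF."""
    u = rng.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return i
    return len(probs) - 1  # guard against floating-point round-off

logits = [2.0, 1.0, 0.0]  # made-up logits for a 3-token vocabulary
probs = softmax_with_temperature(logits, temperature=1.0)

rng = random.Random(0)
n_draws = 10_000
counts = Counter(sample_index(probs, rng) for _ in range(n_draws))
freqs = [counts[i] / n_draws for i in range(len(probs))]

print("probs:", [round(p, 3) for p in probs])
print("freqs:", [round(f, 3) for f in freqs])
```

The most likely token dominates the draws, but the less likely tokens still appear at roughly their assigned probabilities.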

In [668]:
@torch.no_grad()
def generate(
    model,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float = 1.0
):
    model.reset_cache()

    # feed prompt tokens first
    for t in range(input_ids.size(1)):
        model.forward_step(input_ids[:, t:t+1])

    token = input_ids[:, -1:]

    outputs = [token]

    for _ in range(max_new_tokens):
        logits = model.forward_step(token)
        logits = logits / temperature

        probs = torch.softmax(logits, dim=-1)
        token = torch.multinomial(probs, num_samples=1)

        outputs.append(token)

    return torch.cat(outputs, dim=1)
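
Note that in `generate` the prompt loop already passes the last prompt token through `forward_step`, and the generation loop then feeds that same token a second time at the next position, discarding the logits from the prompt pass. One possible alternative, sketched here in plain Python under stated assumptions, keeps the logits from the final prompt step and samples from them directly. `ToyStepModel` is a hypothetical stand-in that only records what it is fed, the decoding is greedy for simplicity, and the function returns the full prompt plus the new tokens, which differs from the return value of `generate`.

```python
class ToyStepModel:
    """Hypothetical stand-in exposing only reset_cache() and forward_step()."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
        self.reset_cache()

    def reset_cache(self):
        self.fed = []  # (position, token) pairs, in the order they were fed

    def forward_step(self, token):
        self.fed.append((len(self.fed), token))
        # Made-up rule: the "predicted" next token is (token + 1) mod vocab_size.
        logits = [0.0] * self.vocab_size
        logits[(token + 1) % self.vocab_size] = 1.0
        return logits

def greedy_generate_reusing_prefill(model, prompt, max_new_tokens):
    """Greedy decoding that feeds every token at most once (non-empty prompt assumed)."""
    model.reset_cache()
    logits = None
    for token in prompt:  # prefill: keep the logits of the final prompt step
        logits = model.forward_step(token)
    generated = []
    for i in range(max_new_tokens):
        token = max(range(len(logits)), key=logits.__getitem__)  # argmax
        generated.append(token)
        if i + 1 < max_new_tokens:  # the final token needs no forward pass
            logits = model.forward_step(token)
    return list(prompt) + generated

toy_model = ToyStepModel(vocab_size=10)
toy_out = greedy_generate_reusing_prefill(toy_model, [1, 2, 3], max_new_tokens=4)
print(toy_out)
print(toy_model.fed)
```

With this ordering each token occupies exactly one position in the cache, so the positions seen during generation line up with the positions `forward_full` would assign to the same sequence.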


# Step-by-step validation strategy for a KV-cached Transformer

#### KVCache Validation

This code snippet demonstrates the functionality of the `KVCache` class. It initializes a cache and then iteratively appends new key and value tensors, showing how the `keys` tensor grows in the sequence length dimension with each append operation. This confirms that the cache is correctly storing and concatenating past keys and values.

In [669]:
B, H, Dh = 1, 2, 4
cache = KVCache()

for t in range(3):
    k = torch.randn(B, H, 1, Dh)
    v = torch.randn(B, H, 1, Dh)
    cache.append(k, v)

    print(f"Step {t}: keys shape =", cache.keys.shape)


Step 0: keys shape = torch.Size([1, 2, 1, 4])
Step 1: keys shape = torch.Size([1, 2, 2, 4])
Step 2: keys shape = torch.Size([1, 2, 3, 4])
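
Since each append adds one `(B, H, 1, Dh)` slice to both keys and values, the cache grows linearly with sequence length. As a rough sketch of the arithmetic, assuming float32 storage (4 bytes per element, the default dtype of `torch.randn`) and no over-allocation, the footprint can be estimated as follows; `kv_cache_bytes` is a hypothetical helper, not part of the notebook.

```python
def kv_cache_bytes(batch, n_heads, head_dim, seq_len, n_layers, bytes_per_elem=4):
    """Bytes held by keys plus values across all layers at a given sequence length."""
    per_tensor = batch * n_heads * seq_len * head_dim  # one (B, H, T, Dh) tensor
    return 2 * n_layers * per_tensor * bytes_per_elem  # factor 2: keys and values

# Shapes from the validation cell: B=1, H=2, Dh=4, a single layer, float32.
for t in (1, 2, 3, 16):
    print(f"T={t:2d}: {kv_cache_bytes(1, 2, 4, t, n_layers=1)} bytes")
```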


#### Cached Attention (Shape and Determinism) Validation

This section validates the `CachedCausalSelfAttention` module, ensuring that its output shape remains consistent `(B, 1, D)` during incremental decoding (processing one token at a time). It also implicitly checks that the internal caching mechanism functions without altering the expected output dimensions.

In [670]:
config = ModelConfig(
    vocab_size=100,
    d_model=8,
    n_heads=2,
    n_layers=1,
    block_size=16
)

attn = CachedCausalSelfAttention(config)
cache = KVCache()

x = torch.randn(1, 1, 8)

for t in range(3):
    y = attn(x, cache)
    print(f"Step {t}: output shape =", y.shape)


Step 0: output shape = torch.Size([1, 1, 8])
Step 1: output shape = torch.Size([1, 1, 8])
Step 2: output shape = torch.Size([1, 1, 8])


#### TransformerBlock (Residual Safety) Validation

Here, a single `TransformerBlock` is tested in an incremental fashion. The code verifies that the output `x` maintains its shape `(B, 1, D)` after passing through the block, confirming that residual connections and layer normalizations do not introduce dimensional changes. It also shows that the `kv_cache` associated with this block correctly accumulates keys and values for each step.

In [671]:
block = TransformerBlock(config)
cache = KVCache()

x = torch.randn(1, 1, 8)

for t in range(3):
    x = block.forward_cached(x, cache)
    print(f"Step {t}: x shape =", x.shape,
          "cache length =", cache.keys.size(2))


Step 0: x shape = torch.Size([1, 1, 8]) cache length = 1
Step 1: x shape = torch.Size([1, 1, 8]) cache length = 2
Step 2: x shape = torch.Size([1, 1, 8]) cache length = 3


#### Full Model `forward_step` Validation

This demonstrates the `forward_step` method of the `MiniGPTInferenceModel`. It feeds a single token at a time and verifies that the model produces logits of the expected shape `(B, vocab_size)` and that the internal `position` counter correctly increments with each step, reflecting the current length of the generated sequence.

In [672]:
model = MiniGPTInferenceModel(config)

token = torch.tensor([[5]])

for t in range(3):
    logits = model.forward_step(token)
    print(
        f"Step {t}: logits shape =",
        logits.shape,
        "position =",
        model.position
    )


Step 0: logits shape = torch.Size([1, 100]) position = 1
Step 1: logits shape = torch.Size([1, 100]) position = 2
Step 2: logits shape = torch.Size([1, 100]) position = 3


#### Minimal Generation Smoke Test

This is a basic test of the `generate` function. It provides a short `input_ids` prompt and requests a few new tokens. The output shape is checked to ensure that the function returns a tensor with the expected batch size and a sequence length of `1 + max_new_tokens`: `generate` returns the last prompt token followed by the generated tokens, so a 3-token prompt with `max_new_tokens=5` yields length 6.

In [673]:
input_ids = torch.tensor([[1, 2, 3]])
out = generate(model, input_ids, max_new_tokens=5)
print(out.shape)


torch.Size([1, 6])


# Cached vs Non-Cached Attention Equivalence Test

### Equivalence Test

The test calculates the maximum absolute difference between the logits produced by `forward_step` (which calls each block's `forward_cached` and uses the KV cache for incremental processing) and `forward_full` (which processes the entire sequence at once). A very small `max_diff` (close to float32 precision limits, on the order of `1e-07`) indicates that the cached and non-cached implementations are functionally equivalent. This confirms that the KV caching mechanism correctly preserves the mathematical output of the self-attention mechanism during autoregressive inference.

In [674]:
# Re-initialize the model so it picks up the latest class definition.
# `config` is expected to be available from an earlier cell.
model = MiniGPTInferenceModel(config)

torch.manual_seed(0)

input_ids = torch.tensor([[1, 2, 3, 4]])

# cached decoding
model.reset_cache()
cached_logits = []

for t in range(input_ids.size(1)):
    logits = model.forward_step(input_ids[:, t:t+1])
    cached_logits.append(logits)

cached_logits = torch.stack(cached_logits, dim=1)

# full forward
full_logits = model.forward_full(input_ids)

# compare
max_diff = (cached_logits - full_logits).abs().max()
print("Max difference:", max_diff.item())

Max difference: 2.384185791015625e-07


### Max Difference Output

The output `Max difference: 2.384185791015625e-07` is extremely small, about twice the float32 machine epsilon (≈1.19e-07). This indicates that the logits produced by `forward_step` (which uses the KV cache for incremental processing) and `forward_full` (which processes the entire sequence at once) are practically identical. The tiny difference observed is well within typical floating-point precision errors in numerical computations, confirming that the KV caching mechanism is implemented correctly and maintains functional equivalence with the full forward pass.
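
The reason the two paths agree can be seen in a stripped-down setting. The sketch below is plain Python with a single attention head, random made-up queries, keys, and values, and no learned projections; `full_causal_attention` and `cached_attention_step` are hypothetical names, not the notebook's methods. The full version scores every pair and masks the future, while the cached version scores one query against the keys stored so far.

```python
import math
import random

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def full_causal_attention(Q, K, V):
    """Score all T x T pairs, mask the future with -inf, then softmax each row."""
    T, scale = len(Q), 1.0 / math.sqrt(len(Q[0]))
    out = []
    for t in range(T):
        scores = [dot(Q[t], K[j]) * scale if j <= t else float("-inf")
                  for j in range(T)]
        weights = softmax(scores)
        out.append([sum(w * V[j][d] for j, w in enumerate(weights))
                    for d in range(len(V[0]))])
    return out

def cached_attention_step(q, k, v, cache):
    """Append this step's key/value, then attend over everything cached so far."""
    cache["keys"].append(k)
    cache["values"].append(v)
    scale = 1.0 / math.sqrt(len(q))
    weights = softmax([dot(q, kj) * scale for kj in cache["keys"]])
    return [sum(w * vj[d] for w, vj in zip(weights, cache["values"]))
            for d in range(len(v))]

rng = random.Random(0)
T, Dh = 4, 3
Q, K, V = ([[rng.gauss(0, 1) for _ in range(Dh)] for _ in range(T)]
           for _ in range(3))

full = full_causal_attention(Q, K, V)
cache = {"keys": [], "values": []}
cached = [cached_attention_step(Q[t], K[t], V[t], cache) for t in range(T)]

max_diff = max(abs(a - b) for fr, cr in zip(full, cached) for a, b in zip(fr, cr))
print("max abs difference:", max_diff)
```

Masked positions contribute a weight of exactly zero, so row `t` of the full computation reduces to the same sum the cached step evaluates. In the PyTorch model the small nonzero difference most likely comes from float32 arithmetic being carried out in a different order in the two code paths.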

# Top-k and Top-p (Nucleus) Sampling: Quality Control in Generation

### `top_k_filter` Function

This function implements Top-K sampling, a method for controlling the diversity of generated text. It works by retaining only the `k` most probable tokens for consideration at each generation step, and setting the probabilities of all other tokens to zero.

**How it works:**
1.  `torch.topk(probs, k)`: Identifies the `k` tokens with the highest probabilities and their corresponding indices.
2.  `filtered = torch.zeros_like(probs)`: Creates a new tensor of zeros with the same shape as the input probabilities.
3.  `filtered.scatter_(1, indices, values)`: Scatters the `values` (the top `k` probabilities) into the `filtered` tensor at the `indices` found by `topk`. All other positions remain zero.
4.  `return filtered / filtered.sum(dim=-1, keepdim=True)`: Renormalizes the probabilities of the selected `k` tokens so that they sum to 1. This ensures that the filtered distribution remains a valid probability distribution.

**Effect:** Top-K sampling reduces the vocabulary size from which the next token is chosen, making the generation process less prone to selecting very unlikely tokens, thus improving coherence.
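
As a small worked example of the four steps, here is a plain-Python sketch on a made-up 4-token distribution. `top_k_filter_list` is a hypothetical list-based helper that mirrors the logic for a single row; it is not the notebook's tensor implementation.

```python
def top_k_filter_list(probs, k):
    """Keep the k largest probabilities, zero the rest, and renormalize."""
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    kept = [p if i in top else 0.0 for i, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]

probs = [0.5, 0.25, 0.15, 0.1]  # made-up distribution over 4 tokens
filtered = top_k_filter_list(probs, k=2)
print(filtered)
```

With `k=2` the two surviving tokens hold 0.75 of the original mass, so after renormalization they end up at 2/3 and 1/3, and the remaining tokens can no longer be sampled.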

In [675]:
def top_k_filter(probs, k):
    values, indices = torch.topk(probs, k)
    filtered = torch.zeros_like(probs)
    filtered.scatter_(1, indices, values)
    return filtered / filtered.sum(dim=-1, keepdim=True)


### `top_p_filter` (Nucleus Sampling) Function

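This function implements Top-P sampling, also known as Nucleus Sampling. The method restricts sampling to a small set of the most probable tokens whose cumulative probability is about `p`, and then redistributes the probability mass among only these selected tokens. This implementation keeps every token whose cumulative probability (in descending order) is at most `p`, plus the single most probable token unconditionally. The token that would push the cumulative mass past `p` is excluded, so the retained mass before renormalization can fall below `p`; the more common formulation includes that token, keeping the smallest set whose cumulative probability reaches `p`.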

**How it works:**
1.  `sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)`: Sorts the probabilities in descending order and keeps track of their original indices.
2.  `cumulative_probs = torch.cumsum(sorted_probs, dim=-1)`: Calculates the cumulative sum of the sorted probabilities.
3.  `mask = cumulative_probs <= p` and `mask[..., 0] = True`: Creates a boolean mask. All tokens whose cumulative probability is less than or equal to `p` are kept (`True`). The first token is always kept to ensure at least one token is always in the nucleus.
4.  `filtered_sorted_probs = torch.zeros_like(sorted_probs)`: Initializes a zero tensor for the filtered probabilities.
5.  `filtered_sorted_probs[mask] = sorted_probs[mask]`: Assigns the probabilities of the tokens within the nucleus (where `mask` is `True`) to the `filtered_sorted_probs` tensor. All others remain zero.
6.  `result_probs.scatter_(dim=-1, index=sorted_indices, src=filtered_sorted_probs)`: Reconstructs the probability distribution in the original token order, placing the filtered probabilities back into their correct positions.
7.  `return result_probs / result_probs.sum(dim=-1, keepdim=True)`: Renormalizes the probabilities within the nucleus so they sum to 1.

**Effect:** Top-P sampling dynamically adjusts the effective vocabulary size based on the shape of the probability distribution. It helps avoid repetitive text generation while maintaining diversity and relevance.
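
The cutoff rule is easiest to see on concrete numbers. The following plain-Python sketch uses a made-up distribution and a hypothetical list-based helper, `top_p_filter_list`. By default it applies the `cumulative_probs <= p` rule described above; the optional flag switches to the common variant that also keeps the token crossing the threshold, shown only for comparison.

```python
def top_p_filter_list(probs, p, include_crossing_token=False):
    """Nucleus filtering for one row of probabilities, followed by renormalization."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = [0.0] * len(probs)
    cumulative = 0.0
    for rank, i in enumerate(order):
        before = cumulative
        cumulative += probs[i]
        if include_crossing_token:
            keep = before < p        # keep while the mass before this token is below p
        else:
            keep = cumulative <= p   # keep while the running mass stays within p
        if keep or rank == 0:        # the most probable token is always kept
            kept[i] = probs[i]
    total = sum(kept)
    return [v / total for v in kept]

probs = [0.5, 0.3, 0.15, 0.05]  # made-up distribution over 4 tokens
strict = top_p_filter_list(probs, p=0.9)
inclusive = top_p_filter_list(probs, p=0.9, include_crossing_token=True)
print("cumulative <= p   :", [round(v, 3) for v in strict])
print("crossing included :", [round(v, 3) for v in inclusive])
```

With `p=0.9` the running totals are 0.5, 0.8, 0.95, and 1.0. The `<= p` rule keeps two tokens (mass 0.8), while the inclusive variant keeps three (mass 0.95).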

In [676]:
def top_p_filter(probs, p):
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Create a mask for elements whose cumulative probability is <= p
    mask = cumulative_probs <= p
    mask[..., 0] = True  # Always keep at least one token

    # Set probabilities of tokens outside the mass p to zero
    filtered_sorted_probs = torch.zeros_like(sorted_probs)
    # This operation correctly assigns values where mask is True, maintaining 2D shape
    filtered_sorted_probs[mask] = sorted_probs[mask]

    # Re-order the filtered probabilities back to the original index positions
    # using scatter_ with the original sorted_indices (which is 2D)
    result_probs = torch.zeros_like(probs)
    result_probs.scatter_(dim=-1, index=sorted_indices, src=filtered_sorted_probs)

    # Renormalize the filtered probabilities
    return result_probs / result_probs.sum(dim=-1, keepdim=True)

### `generate1` Function

This is an enhanced version of the `generate` function that incorporates temperature scaling, Top-K, and Top-P (Nucleus) sampling for more flexible and controlled text generation. It processes tokens one by one, leveraging the KV cache for efficiency.

**Key Features:**
1.  **Cache Reset:** `model.reset_cache()` ensures a clean state for each new generation sequence.
2.  **Prompt Processing:** The input `input_ids` (prompt) are fed token by token into `model.forward_step()` to build up the KV cache without generating new tokens.
3.  **Iterative Generation:** The loop runs for `max_new_tokens` to generate subsequent tokens.
4.  **Greedy Sampling (`temperature = 0.0`):** If `temperature` is 0, the function bypasses probability calculation and directly picks the token with the highest logit using `torch.argmax`. This results in deterministic, greedy generation.
5.  **Temperature Scaling:** For `temperature > 0.0`, `logits` are divided by `temperature` before `softmax`. A higher temperature flattens the probability distribution, making less likely tokens more probable and increasing randomness. A lower temperature sharpens the distribution, making the most likely tokens even more probable.
6.  **Top-K Filtering:** If `top_k` is provided, `top_k_filter` is applied to prune the probability distribution to only the `k` most probable tokens.
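7.  **Top-P Filtering:** If `top_p` is provided, `top_p_filter` is applied to keep only the most probable tokens whose cumulative probability stays at or below `p` (the top token is always kept). When both filters are set, Top-P operates on the already renormalized Top-K distribution.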
8.  **Token Sampling:** `torch.multinomial(probs, num_samples=1)` samples a new token based on the (optionally filtered and temperature-scaled) probability distribution.
9.  **Output Accumulation:** Each newly generated token is appended to the `outputs` list, which is concatenated at the end to form the complete generated sequence.

In [677]:
@torch.no_grad()
def generate1(
    model,
    input_ids,
    max_new_tokens,
    temperature=1.0,
    top_k=None,
    top_p=None
):
    model.reset_cache()

    for t in range(input_ids.size(1)):
        model.forward_step(input_ids[:, t:t+1])

    token = input_ids[:, -1:]
    outputs = [token]

    for _ in range(max_new_tokens):
        logits = model.forward_step(token)

        if temperature == 0.0:
            # Greedy sampling (argmax)
            token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)

            if top_k is not None:
                probs = top_k_filter(probs, top_k)

            if top_p is not None:
                probs = top_p_filter(probs, top_p)

            token = torch.multinomial(probs, num_samples=1)
        outputs.append(token)

    return torch.cat(outputs, dim=1)

## Testing Top-k / Top-p Sampling (Correctness + Behavior) Test Cases

This section contains several test cases designed to validate the correctness and behavior of the `generate1` function, particularly focusing on the `top_k` and `top_p` sampling strategies and the effect of `temperature`.

#### 1. Deterministic sanity check (no sampling)
**Purpose:** To verify that when `temperature` is set to `0.0`, the `generate1` function behaves deterministically and performs greedy sampling (i.e., always picking the token with the highest probability). By setting a `torch.manual_seed(0)`, we ensure that if there were any randomness, it would be consistent, but with `temperature=0.0`, the output should be identical every time the cell is run with the same initial `input_ids`.

**Expected Outcome:** The output sequence of tokens should be the same each time, confirming greedy selection.

In [678]:
torch.manual_seed(0)

input_ids = torch.tensor([[1, 2, 3]])

out = generate1(
    model,
    input_ids,
    max_new_tokens=5,
    temperature=0.0,   # argmax behavior
    top_k=None,
    top_p=None
)

print(out)


tensor([[ 3, 42, 51, 95, 82, 51]])


#### 2. Top-k constraint test (hard cutoff)
**Purpose:** To confirm that the `top_k_filter` function correctly identifies and keeps only the `k` most probable tokens, setting all other probabilities to zero, and then renormalizing the remaining probabilities.

**Expected Outcome:** The `print` statement should show `Non-zero entries: 5` (or whatever `k` is set to), indicating that only the top `k` probabilities are non-zero after filtering.

In [679]:
torch.manual_seed(0)

logits = torch.randn(1, model.config.vocab_size)
probs = torch.softmax(logits, dim=-1)

filtered = top_k_filter(probs, k=5)

print("Non-zero entries:", (filtered > 0).sum().item())


Non-zero entries: 5


#### 3. Top-p constraint test (probability mass)
**Purpose:** To verify that `top_p_filter` keeps only the most probable tokens whose cumulative probability stays within `p` (always including the single most probable token), sets the others to zero, and renormalizes the result.

**Expected Outcome:** The `print` statement should show a `Cumulative prob` very close to `1.0`, because the retained probabilities are renormalized. This confirms that the filtered output is a valid probability distribution; it does not by itself show which tokens were kept.

In [680]:
torch.manual_seed(0)

logits = torch.randn(1, model.config.vocab_size)
probs = torch.softmax(logits, dim=-1)

filtered = top_p_filter(probs, p=0.9)

sorted_probs, _ = torch.sort(filtered, descending=True)
print("Cumulative prob:", sorted_probs.sum().item())


Cumulative prob: 0.9999999403953552


#### 4. Stochasticity test (sampling actually random)
**Purpose:** To demonstrate that when `temperature > 0.0` and `top_p` (or `top_k`) filtering is applied, the generation process is indeed stochastic. By setting different random seeds (`torch.manual_seed(42)` and `torch.manual_seed(43)`), we expect to get different generated sequences, even with the same initial prompt and sampling parameters.

**Expected Outcome:** `out1` and `out2` should contain different sequences of generated tokens, proving that the sampling mechanism introduces randomness.

In [681]:
torch.manual_seed(42)
out1 = generate1(model, input_ids, 10, temperature=1.0, top_p=0.9)

torch.manual_seed(43)
out2 = generate1(model, input_ids, 10, temperature=1.0, top_p=0.9)

print(out1)
print(out2)


tensor([[ 3, 60, 29, 11, 80, 14, 64, 17, 42, 22, 25]])
tensor([[ 3, 18, 71, 73, 73, 50, 39, 45, 80, 65, 92]])


#### 5. Entropy control (temperature effect)
**Purpose:** To illustrate how the `temperature` parameter influences the 'creativity' or 'randomness' of the generated text. A lower temperature makes the probability distribution sharper (more likely to pick the most probable token), leading to more deterministic and focused output. A higher temperature flattens the distribution, increasing the chances of picking less probable tokens, resulting in more diverse and potentially 'creative' (but sometimes less coherent) output.

**Expected Outcome:** As `temp` increases, the generated sequences are expected to diverge more significantly, demonstrating the increased entropy in sampling.
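
The effect of temperature on the distribution itself, independent of any particular sampled sequence, can be quantified with Shannon entropy. This is a minimal plain-Python sketch on made-up logits; `softmax_with_temperature` and `entropy_bits` are hypothetical helper names.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy_bits(probs):
    """Shannon entropy in bits; zero-probability entries contribute nothing."""
    return -sum(p * math.log2(p) for p in probs if p > 0)

logits = [3.0, 1.0, 0.5, 0.0, -1.0]  # made-up logits for a 5-token vocabulary
entropies = {}
for temp in (0.3, 0.7, 1.2):
    probs = softmax_with_temperature(logits, temp)
    entropies[temp] = entropy_bits(probs)
    print(f"temp={temp}: top prob={max(probs):.3f}, entropy={entropies[temp]:.3f} bits")
```

Entropy rises with temperature while the probability of the top token falls, which is why low-temperature runs tend to stay close to the greedy sequence. Individual sampled sequences at nearby temperatures can still coincide when the same random seed is used.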

In [682]:
for temp in [0.3, 0.7, 1.2]:
    torch.manual_seed(0)
    out = generate1(
        model,
        input_ids,
        10,
        temperature=temp,
        top_p=0.9
    )
    print(f"Temp={temp} →", out)


Temp=0.3 → tensor([[ 3, 47, 44, 51,  6, 83, 99, 47, 28, 55, 34]])
Temp=0.7 → tensor([[ 3, 58, 28, 51, 29, 82,  4, 40, 28, 10,  7]])
Temp=1.2 → tensor([[ 3, 58, 28, 51,  8, 36,  4, 40, 28, 10,  7]])


# Benchmark: Cached vs Non-Cached Decoding Speed

### `benchmark` Function: Cached vs Non-Cached Decoding Speed

This `benchmark` function is designed to measure and compare the inference speed of the `MiniGPTInferenceModel` using two different decoding strategies:
1.  **Full Decoding (`use_cache=False`):** Simulates a standard Transformer decoding process where the entire sequence is re-fed to the model at each step, causing redundant computations.
2.  **Cached Decoding (`use_cache=True`):** Leverages the KV cache to store past keys and values, only processing the most recent token incrementally.

**Description:**

*   **`@torch.no_grad()`**: Disables gradient calculation, as this function is for inference/benchmarking only, not training.
*   **`model.eval()`**: Sets the model to evaluation mode, disabling dropout and batch normalization updates.
*   **`model.reset_cache()` (for `use_cache=True`)**: Ensures a clean KV cache state before starting a new generation sequence.
*   **Loop for `max_new_tokens`**: The core of the benchmark iterates `max_new_tokens` times to simulate the generation of new tokens.

    *   **If `use_cache=True` (Cached Decoding):**
        *   The model receives only the last generated token (`x[:, -1:]`) as input to `model.forward_step()`. This is efficient because the KV cache handles the historical context.
        *   The `argmax` of the logits is taken to deterministically select the next token.
        *   The `next_token` (which is a single token) is concatenated to `x` for the next iteration.

    *   **If `use_cache=False` (Full Decoding):**
        *   The *entire sequence* generated so far (`x`) is fed to `model.forward_full()` at each step. This method recomputes attention over the growing sequence every time, demonstrating the computational cost without caching.
        *   The `logits[:, -1, :]` selects the logits for the last token in the sequence (which represents the prediction for the next token).
        *   The `argmax` is used to select the next token.
        *   The `next_token` is concatenated to `x`, growing the sequence for the next iteration.

*   **`torch.cuda.synchronize()`**: If a GPU is available, this ensures all GPU computations are complete before measuring the end time, providing accurate timing.
*   **Return Value**: Returns the total time taken for generation.
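
As a rough illustration of the redundant work described above, the following sketch counts how many token positions are fed through the model over a whole generation run. The helper `tokens_processed` is hypothetical. It mirrors only the input sizes of the two branches and ignores the fact that cached attention still reads the stored keys and values at every step.

```python
def tokens_processed(prompt_len: int, max_new_tokens: int, use_cache: bool) -> int:
    """Count token positions fed to the model across all decoding steps."""
    total = 0
    seq_len = prompt_len
    for _ in range(max_new_tokens):
        # With the cache, one token is fed per step; without it, the whole
        # sequence generated so far is fed again.
        total += 1 if use_cache else seq_len
        seq_len += 1
    return total

full = tokens_processed(prompt_len=1, max_new_tokens=15, use_cache=False)
cached = tokens_processed(prompt_len=1, max_new_tokens=15, use_cache=True)
print(full, cached)  # 120 15
```

For a one-token prompt and 15 new tokens, full decoding feeds 1 + 2 + ... + 15 = 120 positions while cached decoding feeds 15. The ratio grows roughly linearly with sequence length.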

### Benchmark Execution and Results

This code block sets up and executes the `benchmark` function to measure the speed of full vs. cached decoding.

**Description:**

*   **`torch.manual_seed(0)`**: Ensures reproducibility of random number generation, particularly for `torch.randint` used to create initial `input_ids`.
*   **`initial_prompt_length` and `max_new_tokens_actual`**: These variables are carefully set to ensure that the total sequence length during generation (`initial_prompt_length + max_new_tokens_actual`) does not exceed `model.config.block_size` (which is 16). This is crucial because `model.forward_full` (used for non-cached benchmarking) has a hard limit on the sequence length it can process due to positional embeddings.
    *   `input_ids` is initialized with a single token (`initial_prompt_length = 1`).
    *   `max_new_tokens` is set to `block_size - initial_prompt_length` (15), meaning a total of 15 new tokens will be generated.
*   **Benchmarking Calls**: The `benchmark` function is called twice:
    *   `t_full = benchmark(model, input_ids, max_new_tokens, use_cache=False)`: Measures time for full, non-cached decoding.
    *   `t_cached = benchmark(model, input_ids, max_new_tokens, use_cache=True)`: Measures time for cached decoding.
*   **Output**: The elapsed times for both methods and the speedup factor (`t_full / t_cached`) are printed.

**Expected Outcome:** Cached decoding (`t_cached`) is expected to be faster than full decoding (`t_full`), giving a speedup factor greater than 1. With a total sequence of only 16 tokens the gap is small (1.18x in the run recorded below), because full decoding has little redundant work to skip at this length; the advantage should grow as sequences get longer.

In [683]:
import time

@torch.no_grad()
def benchmark(model, input_ids, max_new_tokens, use_cache=True):
    model.eval()

    if use_cache:
        model.reset_cache()

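    x = input_ids.clone()

    # Drain any pending GPU work so it is not billed to this run.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # perf_counter is monotonic and better suited to short intervals than time.time().
    start = time.perf_counter()

    for _ in range(max_new_tokens):
        if use_cache:
            # Feed only the newest token; the KV cache supplies the history.
            logits = model.forward_step(x[:, -1:])
        else:
            # Re-run the model over the whole sequence generated so far.
            logits = model.forward_full(x)

        # If per-position logits (B, T, vocab) came back, keep the last position.
        if logits.dim() == 3:
            logits = logits[:, -1, :]

        # Greedy selection keeps both runs deterministic and comparable.
        next_token = logits.argmax(dim=-1, keepdim=True)
        x = torch.cat([x, next_token], dim=1)

    # Wait for queued GPU kernels to finish before reading the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    return time.perf_counter() - start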



In [684]:
torch.manual_seed(0)

# Adjust initial prompt length and max_new_tokens to fit within config.block_size (which is 16)
initial_prompt_length = 1
max_new_tokens_actual = model.config.block_size - initial_prompt_length # 16 - 1 = 15

input_ids = torch.randint(0, model.config.vocab_size, (1, initial_prompt_length))
max_new_tokens = max_new_tokens_actual

t_full = benchmark(model, input_ids, max_new_tokens, use_cache=False)
t_cached = benchmark(model, input_ids, max_new_tokens, use_cache=True)

print(f"Full decoding time:   {t_full:.4f}s")
print(f"Cached decoding time: {t_cached:.4f}s")
print(f"Speedup: {t_full / t_cached:.2f}x")

Full decoding time:   0.0101s
Cached decoding time: 0.0086s
Speedup: 1.18x
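
Timings around 10 ms are sensitive to noise and to one-off warm-up costs in whichever call happens to run first. One common mitigation, sketched here as an assumption rather than something the notebook does, is to discard a few warm-up runs and report the median of several repeats. `median_of_runs` and `toy_run` are hypothetical names; the callable is expected to return elapsed seconds, as `benchmark` does.

```python
import statistics
import time
from typing import Callable

def median_of_runs(run: Callable[[], float], warmup: int = 2, repeats: int = 10) -> float:
    """Call `run` a few times to warm up, then return the median of `repeats` timings."""
    for _ in range(warmup):
        run()  # result discarded: absorbs lazy initialization and cache effects
    samples = [run() for _ in range(repeats)]
    return statistics.median(samples)

def toy_run() -> float:
    """Stand-in workload that times itself and returns elapsed seconds."""
    start = time.perf_counter()
    sum(i * i for i in range(10_000))
    return time.perf_counter() - start

print(f"median: {median_of_runs(toy_run):.6f}s")
```

In this notebook the callable could be `lambda: benchmark(model, input_ids, max_new_tokens, use_cache=True)`, with the same wrapper applied to the `use_cache=False` case before computing the ratio.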


# Logits Processor (Quality Control Layer in Decoding) & Repetition Penalty (Preventing Token Loops in Decoding)

### LogitsProcessor Class

The `LogitsProcessor` class acts as a flexible and composable layer for post-processing raw model logits before token sampling. It allows for the application of various decoding strategies like temperature scaling, Top-K filtering, Top-P (Nucleus) filtering, and repetition penalty in a modular way.

**Description:**

1.  **Initialization (`__init__`)**:
    *   `self.temperature`: Controls the randomness of sampling. A value of `1.0` leaves the logits unchanged, values `> 1.0` increase randomness, and values between `0.0` and `1.0` make sampling more deterministic, approaching greedy (argmax) selection as the value nears `0.0`. Because the logits are divided by this value, exactly `0.0` is not valid here and would produce infinite or NaN logits.
    *   `self.top_k`: If set to an integer, limits the sampling pool to the `k` tokens with the highest logits. The logits of all other tokens are set to negative infinity, so their probabilities become zero after softmax.
    *   `self.top_p`: If set to a float between 0 and 1, implements nucleus sampling. It finds the smallest set of tokens whose cumulative probability exceeds `p` and samples only from these tokens. The logits of tokens outside this set are set to negative infinity.
    *   `self.repetition_penalty`: If set to a float `> 1.0`, it penalizes tokens that have already appeared in the `generated_tokens` sequence. This helps prevent the model from generating repetitive text.

2.  **Call Method (`__call__`)**:
    *   Takes raw `logits` (shape `B, vocab_size`) and optionally `generated_tokens` (shape `B, T`) as input.
    *   **0. Repetition Penalty**: If `repetition_penalty` is specified and `generated_tokens` are provided, the logits of tokens that are present in `generated_tokens` are modified. If a logit is positive, it's divided by `penalty`; if negative, it's multiplied by `penalty`. This makes previously generated tokens less likely to be chosen again.
    *   **1. Temperature Scaling**: Divides `logits` by `self.temperature`. This step is crucial for controlling the shape of the probability distribution. Higher temperature flattens the distribution, increasing randomness.
    *   **2. Top-K Filtering**: If `self.top_k` is specified, it identifies the `top_k` logits and sets all other logits to `float("-inf")`. This effectively makes their probabilities zero after softmax.
    *   **3. Top-P Filtering (Nucleus Sampling)**: If `self.top_p` is specified, it sorts the logits, calculates cumulative probabilities, and identifies the smallest set of tokens that make up `self.top_p` of the probability mass. Logits outside this set are then set to `float("-inf")`.
    *   **Returns**: The processed `logits` with modifications from repetition penalty, temperature, Top-K, and Top-P applied. These modified logits can then be passed to a softmax function and subsequently to a sampling mechanism (like `torch.multinomial`).
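
The Top-P step is the least obvious of the four, so a small worked example may help. The sketch below repeats the same sort, cumulative-sum, shift, and un-sort sequence on a toy distribution. The standalone function name `top_p_filter` and the toy probabilities are assumptions made for illustration.

```python
import torch

def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Mask tokens outside the nucleus; logits has shape (B, vocab_size)."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

    # Mark positions past the threshold, then shift right by one so the
    # token that crosses the threshold is still kept.
    cutoff = cumulative > top_p
    cutoff[:, 1:] = cutoff[:, :-1].clone()
    cutoff[:, 0] = False

    sorted_logits = sorted_logits.masked_fill(cutoff, float("-inf"))
    # Undo the sort so each logit returns to its original vocabulary index.
    return torch.gather(sorted_logits, -1, torch.argsort(sorted_indices, dim=-1))

# Toy distribution, deliberately unsorted: P = [0.15, 0.50, 0.05, 0.30].
probs = torch.tensor([[0.15, 0.50, 0.05, 0.30]])
filtered = top_p_filter(probs.log(), top_p=0.9)
kept = torch.isfinite(filtered)
print(kept)  # tokens 0, 1 and 3 are kept; token 2 is masked
```

Sorted, the probabilities are 0.50, 0.30, 0.15, 0.05 with cumulative sums 0.50, 0.80, 0.95, 1.00. The first three tokens are the smallest set whose mass exceeds 0.9, so only the 0.05 token is removed.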

**How it works (Detailed):**

*   **Repetition Penalty**: Directly manipulates the raw logits. A higher penalty value more strongly discourages repetition.
*   **Temperature**: Directly scales the logits. Higher temperature makes the `softmax` distribution flatter, increasing the probability of less likely tokens. Lower temperature makes it sharper.
*   **Top-K**: A hard cutoff. Only the `k` highest logits are considered. This helps prevent very unlikely tokens from being sampled.
*   **Top-P**: A dynamic cutoff. It ensures that the model samples from a diverse yet probable set of tokens. It adapts to the shape of the probability distribution for each prediction step.
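
By contrast with the dynamic Top-P cutoff, the Top-K cutoff is a single threshold per row. A minimal sketch, using a hypothetical standalone function `top_k_filter` and toy logits:

```python
import torch

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest logits per row and set the rest to -inf."""
    values, _ = torch.topk(logits, k, dim=-1)
    threshold = values[:, -1].unsqueeze(-1)  # k-th largest logit in each row
    return logits.masked_fill(logits < threshold, float("-inf"))

toy_logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
filtered = top_k_filter(toy_logits, k=2)
print(filtered)  # only the logits 3.0 and 2.0 remain finite
```

Because the comparison is a strict `<`, logits that tie with the k-th largest value all survive, so slightly more than `k` tokens can remain when ties occur.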

In [685]:
class LogitsProcessor:
    def __init__(
        self,
        temperature: float = 1.0,
        top_k: int | None = None,
        top_p: float | None = None,
        repetition_penalty: float | None = None,
    ):
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty

    def __call__(
        self,
        logits: torch.Tensor,
        generated_tokens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        logits: (B, vocab_size)
        generated_tokens: (B, T)
        """

        # 0. repetition penalty
        if (
            self.repetition_penalty is not None
            and generated_tokens is not None
        ):
            penalty = self.repetition_penalty
            # Clone first so the caller's tensor is not modified in place.
            logits = logits.clone()

            for b in range(logits.size(0)):
                unique_tokens = torch.unique(generated_tokens[b])
                for token in unique_tokens:
                    token = token.item()
                    if logits[b, token] > 0:
                        logits[b, token] /= penalty
                    else:
                        logits[b, token] *= penalty

        # 1. temperature
        if self.temperature != 1.0:
            logits = logits / self.temperature

        # 2. top-k
        if self.top_k is not None:
            values, _ = torch.topk(logits, self.top_k, dim=-1)
            min_values = values[:, -1].unsqueeze(-1)
            logits = torch.where(
                logits < min_values,
                torch.full_like(logits, float("-inf")),
                logits,
            )

        # 3. top-p
        if self.top_p is not None:
            sorted_logits, sorted_indices = torch.sort(
                logits, descending=True, dim=-1
            )
            probs = torch.softmax(sorted_logits, dim=-1)
            cumulative_probs = probs.cumsum(dim=-1)

            cutoff = cumulative_probs > self.top_p
            cutoff[:, 1:] = cutoff[:, :-1].clone()
            cutoff[:, 0] = False

            sorted_logits = sorted_logits.masked_fill(
                cutoff, float("-inf")
            )

            logits = torch.gather(
                sorted_logits,
                dim=-1,
                index=torch.argsort(sorted_indices, dim=-1),
            )

        return logits
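
The repetition penalty above loops over batch rows and tokens in Python. For larger batches or vocabularies, the same rule can be expressed with `gather` and `scatter`. This is a sketch of one possible alternative, not the notebook's implementation; `apply_repetition_penalty` is a hypothetical name, and `generated_tokens` is assumed to be an integer (`long`) tensor on the same device as `logits`.

```python
import torch

def apply_repetition_penalty(
    logits: torch.Tensor, generated_tokens: torch.Tensor, penalty: float
) -> torch.Tensor:
    """Vectorized repetition penalty; returns a new tensor.

    logits: (B, vocab_size), generated_tokens: (B, T).
    """
    # Logits of every token in the history (duplicates included).
    seen = logits.gather(-1, generated_tokens)
    # Shrink positive logits, push negative ones further down.
    seen = torch.where(seen > 0, seen / penalty, seen * penalty)
    # Out-of-place scatter writes the penalized values back. Duplicate token
    # ids receive identical values, so each token is penalized exactly once.
    return logits.scatter(-1, generated_tokens, seen)

logits = torch.tensor([[5.0, -2.0, 6.0, 1.0]])
history = torch.tensor([[0, 1, 0]])
out = apply_repetition_penalty(logits, history, penalty=2.0)
print(out)  # tensor([[ 2.5000, -4.0000,  6.0000,  1.0000]])
```

The input tensor is left untouched, which avoids the need for callers to pass a clone.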


### LogitsProcessor Test Case Explanation

This test case (`PZYgMzuJwZV4`) validates the functionality of the `LogitsProcessor` class, particularly focusing on how it applies temperature scaling, Top-K, Top-P, and the repetition penalty.

**First Part: General Logits Processing and Sampling**

1.  **Initialization**: A `LogitsProcessor` instance is created with `temperature=0.7`, `top_k=50`, `top_p=0.9`, and `repetition_penalty=1.2`.
2.  **Forward Pass**: `model.forward_step(x[:, -1:])` generates raw logits for the next token based on the model's current state.
3.  **Processing Logits**: `logits = processor(logits, generated_tokens=x)` applies all the configured processing steps (temperature, top-k, top-p, and repetition penalty based on `x`) to the raw logits.
4.  **Sampling**: `torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)` samples a `next_token` from the processed probability distribution.

**Purpose**: This part demonstrates the integrated usage of the `LogitsProcessor` in a typical generation step and confirms that a token can be sampled from the processed distribution.
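
The sampling step relies on the fact that a logit of negative infinity becomes a probability of exactly zero after softmax. The sketch below checks this on hypothetical processed logits (the values are made up for illustration) by drawing many samples.

```python
import torch

torch.manual_seed(0)

# Hypothetical processed logits: tokens 0 and 3 were removed by filtering.
processed = torch.tensor([[float("-inf"), 3.0, 2.0, float("-inf")]])

probs = torch.softmax(processed, dim=-1)

# Draw many samples to confirm that masked tokens are never chosen.
samples = torch.multinomial(probs, num_samples=1000, replacement=True)
sampled_ids = set(samples.flatten().tolist())
print(probs, sampled_ids)
```

Only token ids 1 and 2 can appear in `sampled_ids`, since the two masked tokens carry zero probability.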

**Second Part: Repetition Penalty Specific Test Case**

This section explicitly verifies the `repetition_penalty` mechanism.

1.  **Simulated Generated Tokens**: `simulated_generated_tokens = torch.tensor([[10, 20, 10, 30]])` creates a mock sequence of previously generated tokens in which token `10` appears twice and tokens `20` and `30` appear once.
2.  **Raw Logits Creation**: `raw_logits = torch.randn(1, model.config.vocab_size)` generates random raw logits. Fixed high logits are then assigned to tokens `10` and `20` (both already in the history) and to token `40` (not in the history) so that the penalty effect is easy to observe.
3.  **Repetition Penalty Application**: A new `LogitsProcessor` instance `processor_rp` is created with `repetition_penalty=2.0`. It then processes the `raw_logits` using `simulated_generated_tokens`.
4.  **Verification**: Print statements show the original and processed logits for tokens `10`, `20`, and `40`.
5.  **Assertions**: Assertions are used to programmatically confirm:
    *   `processed_logits[0, 10].item() < raw_logits[0, 10].item()`: Logit for token 10 (previously generated and positive) should be reduced.
    *   `processed_logits[0, 20].item() < raw_logits[0, 20].item()`: Logit for token 20 (previously generated and positive) should be reduced.
    *   `processed_logits[0, 40].item() == raw_logits[0, 40].item()`: Logit for token 40 (non-repeated) should remain unchanged.

**Purpose**: This detailed setup isolates and confirms that the `repetition_penalty` correctly modifies the logits of previously generated tokens, making them less likely to be sampled again, thus preventing repetitive text generation.
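
The assertions above exercise only positive logits. The reason the penalty multiplies negative logits instead of dividing them can be seen from a small worked example; the three-token logits below are an assumption chosen for illustration.

```python
import torch

penalty = 2.0
logits = torch.tensor([-2.0, 0.0, 1.0])  # token 0 was already generated

def prob_of_token0(token0_logit: float) -> float:
    """Softmax probability of token 0 after replacing its logit."""
    adjusted = logits.clone()
    adjusted[0] = token0_logit
    return float(torch.softmax(adjusted, dim=-1)[0])

p_original = prob_of_token0(-2.0)
p_divided = prob_of_token0(-2.0 / penalty)     # -1.0: moves toward zero
p_multiplied = prob_of_token0(-2.0 * penalty)  # -4.0: moves away from zero

print(f"original={p_original:.4f} divided={p_divided:.4f} multiplied={p_multiplied:.4f}")
```

Dividing a negative logit raises the token's probability, which is the opposite of a penalty, whereas multiplying lowers it.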

In [688]:
torch.manual_seed(0) # Ensure reproducibility

# Re-initialize x with a valid 2D token_ids tensor.
# Using the last token from a previously generated sequence (e.g., out1 from a prior cell)
# out1 is from cell DRw2FcC-NubE and has shape (1, 11)
# This ensures x has shape (1, 1) as expected by forward_step
x = out1[:, -1:]

# Reset model state before proceeding to ensure position counter is valid
model.reset_cache()

processor = LogitsProcessor(
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.2,
)

logits = model.forward_step(x[:, -1:])
logits = processor(logits, generated_tokens=x)
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

# Print next token to verify
print(f"Next token: {next_token.item()}")

print("\n#### Repetition Penalty Test Case")

# Simulate some generated tokens with repetitions
simulated_generated_tokens = torch.tensor([[10, 20, 10, 30]]) # Token 10 is repeated

# Simulate raw logits for the next token
# Let's say token 10 has a high logit initially
raw_logits = torch.randn(1, model.config.vocab_size)
raw_logits[0, 10] = 5.0 # High logit for repeated token 10
raw_logits[0, 20] = 4.0 # High logit for repeated token 20
raw_logits[0, 40] = 6.0 # High logit for a new token 40

print(f"Original logit for token 10: {raw_logits[0, 10].item():.4f}")
print(f"Original logit for token 20: {raw_logits[0, 20].item():.4f}")
print(f"Original logit for token 40: {raw_logits[0, 40].item():.4f}")

processor_rp = LogitsProcessor(repetition_penalty=2.0)
processed_logits = processor_rp(raw_logits.clone(), generated_tokens=simulated_generated_tokens)

print(f"Processed logit for token 10 (repeated): {processed_logits[0, 10].item():.4f}")
print(f"Processed logit for token 20 (repeated): {processed_logits[0, 20].item():.4f}")
print(f"Processed logit for token 40 (new): {processed_logits[0, 40].item():.4f}")

# Verify that logits for repeated tokens are penalized (divided by penalty if positive)
assert processed_logits[0, 10].item() < raw_logits[0, 10].item(), "Repeated token 10 logit was not penalized!"
assert processed_logits[0, 20].item() < raw_logits[0, 20].item(), "Repeated token 20 logit was not penalized!"
assert processed_logits[0, 40].item() == raw_logits[0, 40].item(), "Non-repeated token logit was altered!"

print("Repetition penalty test passed!")

Next token: 58

#### Repetition Penalty Test Case
Original logit for token 10: 5.0000
Original logit for token 20: 4.0000
Original logit for token 40: 6.0000
Processed logit for token 10 (repeated): 2.5000
Processed logit for token 20 (repeated): 2.0000
Processed logit for token 40 (new): 6.0000
Repetition penalty test passed!
