# 🧠 Building Decoding Strategies from Scratch

### 🎯 Objective
In this session, we’ll **recreate common decoding strategies step by step**, using only the model’s raw output scores (logits), which a softmax turns into probabilities.  
You’ve already seen how Hugging Face’s `generate()` method can perform greedy search, beam search, top-k, and nucleus (top-p) sampling for you.  
Now it’s time to **open the black box** and see *how those algorithms actually work under the hood.*


### 🧩 What You’ll Learn
By the end of this notebook, you’ll be able to:
- Access token probabilities from a language model (e.g., GPT-2).  
- Implement your own:
  - **Greedy Search**
  - **Beam Search**
  - **Top-k Sampling**
  - **Nucleus (Top-p) Sampling**
- Compare their behavior and outputs across different prompts.  
- Understand the trade-offs between determinism, diversity, and coherence.

In [None]:
from typing import List, Dict, Optional

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

## ⚙️ Setup and First Generation

Before we start building decoding strategies manually, let’s load a pretrained language model and generate some text using the **built-in Hugging Face API**.  
This will serve as our **baseline** — we’ll soon replace this simple `.generate()` call with our own decoding logic.


**`TODO:`**

1. **Import dependencies** — we’ll use `transformers` for the model and tokenizer, and `torch` for tensor operations.  
2. **Load GPT-2** — a small, autoregressive model trained to predict the next token in a sequence.  
3. **Prepare the input** — we’ll encode a short prompt into token IDs.  
4. **Generate text** — using the model’s default decoding (greedy search).  
5. **Decode the output** — back to human-readable text.


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
text = "I have a dream"


model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model.eval()

input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

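# max_new_tokens counts only generated tokens; do_sample=False makes greedy decoding explicit
outputs = model.generate(
    input_ids,
    max_new_tokens=5,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS to avoid a warning
)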
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")


## 🚀 Implementing Greedy Search

### 🧠 Concept

Greedy Search is the most straightforward decoding method:
- At each step, the model predicts a probability distribution over the vocabulary.
- We pick **only the most probable token** (the one with the highest logit or probability).
- That token is appended to the sequence and fed back into the model.
- The process repeats until we reach the desired length or an end-of-sequence (EOS) token.
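
Every strategy in this notebook starts from the same quantity: the logits at the last position of the sequence. The sketch below shows one way to extract them. It assumes a tiny, randomly initialised GPT-2 (the config sizes and the names `tiny_model` and `toy_ids` are hypothetical) so that it runs without downloading weights; with the pretrained `model` loaded earlier the indexing would be the same.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny randomly initialised GPT-2 (hypothetical sizes) so the snippet runs
# without downloading weights; its predictions are meaningless.
config = GPT2Config(vocab_size=50, n_positions=16, n_embd=8, n_layer=1, n_head=2)
tiny_model = GPT2LMHeadModel(config).eval()

toy_ids = torch.tensor([[3, 7, 11]])        # shape: [1, seq_len]

with torch.no_grad():                        # inference only, no gradients needed
    logits = tiny_model(toy_ids).logits     # shape: [1, seq_len, vocab_size]

# Only the last position predicts the *next* token.
next_token_logits = logits[0, -1, :]         # shape: [vocab_size]
next_token_probs = torch.softmax(next_token_logits, dim=-1)

print(tuple(logits.shape), next_token_probs.sum().item())
```

The logits tensor has one row per input position; positions before the last one predict tokens that are already known, which is why only `[0, -1, :]` is used during decoding.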

This strategy is **deterministic** and **fast**, but each choice is only locally optimal: the model never explores alternative continuations, so the output is often repetitive and the full sequence is not guaranteed to be the most probable one.

### ⚙️ Implementation Hints

In this method:
1. We run the model on the current sequence to obtain the logits.
2. We take the **argmax** over the last-token logits to choose the next token.
3. We append that token to the input sequence.
4. We recursively continue until we’ve generated the target number of tokens.
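
The four steps above can be sketched on a toy stand-in for the model. Everything here (`TOY_TABLE`, `toy_logits`, `greedy_decode_toy`, the 5-token vocabulary) is hypothetical and only illustrates the control flow. It is written as a loop rather than recursively, so treat it as a sketch and not as the reference solution for the exercise.

```python
import torch

# Hypothetical stand-in for a language model: row i holds the logits for the
# token that follows token i, over a toy vocabulary of 5 token IDs.
TOY_TABLE = torch.tensor([
    [0.1, 2.0, 0.3, 0.0, 0.5],   # after token 0, token 1 scores highest
    [0.0, 0.1, 3.0, 0.2, 0.1],   # after token 1, token 2 scores highest
    [0.2, 0.0, 0.1, 0.3, 2.5],   # after token 2, token 4 scores highest
    [1.5, 0.0, 0.2, 0.1, 0.3],   # after token 3, token 0 scores highest
    [0.0, 0.4, 0.1, 2.2, 0.3],   # after token 4, token 3 scores highest
])

def toy_logits(input_ids: torch.Tensor) -> torch.Tensor:
    """Return next-token logits (shape: [vocab]) for a [1, seq_len] sequence."""
    last_token = input_ids[0, -1].item()
    return TOY_TABLE[last_token]

def greedy_decode_toy(input_ids: torch.Tensor, length: int) -> torch.Tensor:
    """Append `length` tokens, always taking the argmax of the logits."""
    for _ in range(length):
        logits = toy_logits(input_ids)                           # step 1: get logits
        next_token = torch.argmax(logits).view(1, 1)             # step 2: argmax
        input_ids = torch.cat([input_ids, next_token], dim=-1)   # step 3: append
    return input_ids                                             # step 4: loop ended

toy_output = greedy_decode_toy(torch.tensor([[0]]), length=4)
print(toy_output.tolist())  # [[0, 1, 2, 4, 3]]
```

Running it twice gives the same sequence, which is the determinism mentioned above: there is no random draw anywhere in the loop.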

**`TODO:`** Implement the `greedy_search` function below.

In [None]:
def greedy_search(
    input_ids: torch.Tensor,   # Current sequence of token IDs (shape: [1, seq_len])
    length: int = 5            # Number of tokens left to generate
) -> torch.Tensor:
    """
    Performs recursive Greedy Search decoding.

    At each step, the model selects the most probable next token
    (the one with the highest logit value) and appends it to the sequence.

    Args:
        input_ids: Current sequence of token IDs (1 x seq_len tensor).
        length: Number of tokens left to generate.

    Returns:
        A tensor containing the full generated token sequence.
    """
    
    # TODO: Implement Greedy Search decoding

# Start generating text
output_ids = greedy_search(input_ids, length=5)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")

## 🧮 Helper Function: `get_log_prob`

This small utility computes the **log-probability** of a chosen token given the model’s output logits.  
It’s useful for tracking and comparing sequence scores across decoding strategies.

**How it works:**
1. Applies a softmax to convert logits into probabilities.  
2. Takes the logarithm of those probabilities.  
3. Returns the log-probability corresponding to the selected `token_id`.
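
As a small worked example on toy logits (the values below are made up), the two-step computation described above gives the same result as `torch.log_softmax`, which fuses both steps. The second part illustrates why log-probabilities are convenient for scoring: a sequence score is simply the sum of its per-token log-probabilities.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.1])   # toy logits over a 3-token vocabulary
token_id = 0                              # the token whose score we want

# Two-step version, mirroring the description: softmax, then log.
two_step = torch.log(torch.softmax(logits, dim=-1))[token_id]

# One-step version: log_softmax fuses both operations and is more
# numerically stable when logits are large in magnitude.
one_step = torch.log_softmax(logits, dim=-1)[token_id]

print(two_step.item(), one_step.item())

# A sequence score is the sum of per-token log-probabilities, which equals
# the log of the product of the per-token probabilities.
step_log_probs = torch.tensor([-0.5, -1.2, -0.3])   # made-up per-step values
sequence_score = step_log_probs.sum().item()
print(sequence_score)
```

Log-probabilities are always less than or equal to zero, so a score closer to zero means a more probable sequence.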


In [None]:
def get_log_prob(logits: torch.Tensor, token_id: int) -> float:
    """Return the log-probability of `token_id` under 1-D next-token `logits`."""
    # log_softmax equals log(softmax(logits)) but is computed in a single,
    # numerically stable step
    log_probabilities = torch.log_softmax(logits, dim=-1)

    # Get the log probability of the token
    token_log_probability = log_probabilities[token_id].item()

    return token_log_probability

## 🚀 Implementing Beam Search

### 🧠 Concept

Beam Search improves upon Greedy Search by keeping track of **multiple candidate sequences** (called *beams*) instead of just one.  
At each generation step, instead of picking only the single most likely token, we explore the **top-k next tokens** for each current sequence (where *k* is the beam width).  
We then keep only the *k* best overall sequences based on their **cumulative log-probability scores**.

This strategy balances **exploration** and **exploitation** — it often produces more coherent text than Greedy Search, though it’s more computationally expensive.

### ⚙️ Implementation Hints

In this implementation:
1. For each call, we compute the model’s output logits for the current sequence.  
2. We extract the **top-k tokens** (beam width) with the highest logit values.  
3. For each of these tokens:
   - Compute its log-probability.  
   - Append it to the input sequence.  
   - Recursively continue the search for the remaining steps.  
4. Once all beams reach the target length, we return all completed sequences with their total scores and pick the **best one**.
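
The recursive expansion in these hints can be sketched on a toy transition table. The names `TOY_TABLE` and `expand_toy` and the 3-token vocabulary are hypothetical, and the sketch uses plain Python lists instead of tensors, so it illustrates the idea and is not a drop-in solution for `beam_search`.

```python
import torch

# Hypothetical next-token logits for a 3-token vocabulary: row i is used
# after token i. This replaces the real model call for illustration only.
TOY_TABLE = torch.tensor([
    [1.0, 2.0, 0.0],
    [0.5, 0.0, 3.0],
    [2.0, 1.0, 0.0],
])

def expand_toy(sequence, length, beams, score=0.0):
    """Recursively expand the `beams` best tokens; return all finished paths."""
    if length == 0:
        return [{"sequence": sequence, "score": score}]
    # Log-probabilities of every possible next token after the last one.
    log_probs = torch.log_softmax(TOY_TABLE[sequence[-1]], dim=-1)
    top = torch.topk(log_probs, beams)
    finished = []
    for log_prob, token in zip(top.values.tolist(), top.indices.tolist()):
        # Scores accumulate by addition because they are log-probabilities.
        finished += expand_toy(sequence + [token], length - 1, beams, score + log_prob)
    return finished

candidates = expand_toy([0], length=3, beams=2)
best = max(candidates, key=lambda c: c["score"])
print(len(candidates))    # 8 finished paths (2 ** 3), since nothing is pruned
print(best["sequence"])   # [0, 1, 2, 0]
```

Note that this recursive form never discards a branch, so the number of finished paths grows as `beams ** length`. A beam search that keeps only the *k* best sequences after every step, as described in the concept section, would instead track *k* hypotheses at a time; the recursive version is simpler to write but only practical for short continuations.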

**`TODO:`** Implement the `beam_search` function below.  
Focus on:
- Expanding each beam with the top-k tokens.  
- Keeping track of cumulative scores.  
- Returning all completed sequences so you can later select the best one.


In [None]:
def beam_search(
    input_ids: torch.Tensor,       # Current input token IDs (shape: [1, seq_len])
    length: int,                   # Number of tokens left to generate
    beams: int,                    # Beam width (number of candidate sequences to keep)
    score: Optional[float] = None  # Cumulative log-probability of the sequence so far
) -> List[Dict[str, object]]:  # each dict holds a tensor ("new_input_ids") and a float ("score")
    
    """
    Performs recursive Beam Search decoding.

    Args:
        input_ids: Current sequence of token IDs (1 x seq_len tensor).
        length: Number of tokens left to generate.
        beams: Number of top candidate tokens to expand at each step.
        score: Optional cumulative log-probability for the sequence so far.

    Returns:
        A list of dictionaries, each containing:
            - "new_input_ids": the generated token sequence (tensor)
            - "score": the cumulative log-probability score (float)
    """
    
    # TODO: Implement Beam Search decoding

# Start generating text
output_ids = beam_search(input_ids, length=5, beams=2)
best_entry = max(output_ids, key=lambda x: x["score"])
best_output = tokenizer.decode(best_entry["new_input_ids"].squeeze().tolist(), skip_special_tokens=True)
print(f"Best Generated text: {best_output}")

## 🚀 Implementing Top-k Sampling from Scratch

### 🧠 Concept

So far, **Greedy Search** and **Beam Search** always choose the most likely tokens — making them deterministic but sometimes repetitive or predictable.  
To introduce more **creativity and diversity**, we can use **sampling-based decoding** methods.

**Top-k Sampling** limits randomness to a controlled subset of possible next tokens:
1. At each step, we look at the model’s logits for all tokens.
2. We keep only the **top-k most probable tokens**.
3. We **mask out** all other tokens by setting their logits to `-inf`.
4. We apply a softmax over the remaining tokens to obtain a proper probability distribution.
5. We **randomly sample** one token from this smaller set.
6. We append the chosen token and continue.

This way, we don’t always pick the top token (like in greedy search), but we avoid sampling from the long tail of improbable words — a balance between **coherence** and **diversity**.

### ⚙️ Implementation Hints

In this method:
- Use `torch.topk(logits, top_k)` to find the cutoff threshold.  
- Set all logits below that threshold to `-inf` (so their probability becomes zero after softmax).  
- Use `torch.multinomial()` to sample one token ID according to the resulting probability distribution.  
- Recursively repeat the process for the desired number of tokens.  
- Remember to use a **manual seed** (e.g., `torch.manual_seed(0)`) for reproducibility in experiments.
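
The masking and sampling hints can be tried on a single toy logits vector before wiring them into the recursive function. The values below are made up; only `torch.topk`, `masked_fill`, `torch.softmax`, and `torch.multinomial` are real PyTorch calls.

```python
import torch

torch.manual_seed(0)  # fixed seed so the sampled token is reproducible

logits = torch.tensor([1.0, 3.0, 0.5, 2.0, -1.0])  # toy 5-token vocabulary
top_k = 2

# topk returns values in descending order, so the last one is the k-th
# largest logit and acts as the cutoff threshold.
threshold = torch.topk(logits, top_k).values[-1]

# Mask everything below the threshold; softmax maps -inf to probability 0.
masked = logits.masked_fill(logits < threshold, float("-inf"))
probs = torch.softmax(masked, dim=-1)

# Draw one token ID according to the renormalized distribution.
next_token = torch.multinomial(probs, num_samples=1)

print(probs)
print(next_token.item())
```

Only tokens 1 and 3 keep non-zero probability here, so the sampled ID is always one of those two, whatever the seed.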


**`TODO:`** Implement the `top_k_sampling` function below.  
Try different `top_k` values (e.g., 5, 20, 100) and observe how the output’s **creativity** changes.


In [None]:
def top_k_sampling(
    input_ids: torch.Tensor,   # Current sequence of token IDs (shape: [1, seq_len])
    length: int,               # Number of tokens left to generate
    top_k: int                 # Number of top tokens to sample from at each step
) -> torch.Tensor:
    """
    Performs recursive Top-k Sampling decoding.

    At each step:
    1. We take the model’s logits for the next token.
    2. We keep only the top-k most probable tokens.
    3. We apply softmax to get a probability distribution.
    4. We sample one token from that reduced set (introducing controlled randomness).
    5. Append it to the sequence and continue recursively.

    Args:
        input_ids: Current sequence of token IDs (1 x seq_len tensor).
        length: Number of tokens left to generate.
        top_k: Number of highest-probability tokens to consider at each step.

    Returns:
        A tensor containing the full generated token sequence.
    """
    
    # TODO: Implement Top-k Sampling decoding

output_ids = top_k_sampling(input_ids, length=5, top_k=20)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")

## 🌡️ Adding Temperature to Top-k Sampling

### 🧠 Concept
**Temperature** controls how “confident” or “creative” the model’s sampling behavior is.  
Before applying softmax, we divide the logits by a temperature value:

\begin{align}
p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}
\end{align}

where $z_i$ is the logit of token $i$ and $T > 0$ is the temperature.

- **Low temperature (< 1)** → sharper distribution → more deterministic, focused outputs  
- **High temperature (> 1)** → flatter distribution → more random, diverse outputs  
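
A quick way to see this effect is to apply several temperatures to the same toy logits (the values and the name `probs_by_temperature` are illustrative only) and compare the resulting distributions.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0])  # toy logits over 3 tokens

probs_by_temperature = {}
for temperature in (0.5, 1.0, 2.0):
    # Divide the logits by T *before* the softmax.
    probs = torch.softmax(logits / temperature, dim=-1)
    probs_by_temperature[temperature] = probs
    print(temperature, [round(p, 3) for p in probs.tolist()])
```

The top token's probability shrinks as the temperature rises, while the ranking of the tokens stays the same. With `temperature=1.0` the distribution is the plain softmax, and as the temperature approaches 0 sampling behaves more and more like greedy search.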

**`TODO:`**  
Modify the Top-k Sampling implementation so that logits are divided by `temperature` **before** applying softmax.  
Then, experiment with different temperature values (e.g., `0.7`, `1.0`, `1.5`) and observe how the output’s creativity changes.


In [None]:
def top_k_sampling(
    input_ids: torch.Tensor,   # Current sequence of token IDs (shape: [1, seq_len])
    length: int,               # Number of tokens left to generate
    top_k: int,                # Number of top tokens to sample from at each step
    temperature: float = 1.0   # Temperature parameter controlling randomness
) -> torch.Tensor:
    """
    Performs recursive Top-k Sampling with Temperature.

    This decoding method introduces *controlled randomness*:
    - Restricts sampling to the top-k most probable tokens.
    - Scales the logits by a temperature before applying softmax.

    Args:
        input_ids: Current sequence of token IDs (1 x seq_len tensor).
        length: Number of tokens left to generate.
        top_k: Number of highest-probability tokens to consider at each step.
        temperature: Value > 0 that controls randomness.
            • Lower (<1) → sharper distribution, more deterministic.
            • Higher (>1) → flatter distribution, more random.

    Returns:
        A tensor containing the full generated token sequence.
    """
    
    # TODO: Implement Top-k Sampling with Temperature decoding

output_ids = top_k_sampling(input_ids, length=5, top_k=20, temperature=0.2)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")