# LLM Decoding from Scratch

This notebook demonstrates how to implement text generation decoding strategies from scratch using a real language model (GPT-2 small).

## What We'll Build

| Strategy | Description | Use Case |
|----------|-------------|----------|
| **Greedy** | Always pick the highest-probability token | Deterministic, fast |
| **Temperature** | Scale logits before softmax, then sample | Control randomness |
| **Top-p (Nucleus)** | Sample from the smallest set of tokens whose cumulative probability ≥ p | Adaptive candidate set |
| **Beam Search** | Track several candidate sequences and keep the best-scoring ones | Higher-likelihood output |

---

## Key Insight: LLMs Only Predict the Next Token

```
Input: "The cat sat on the"
                          ↓
              [LLM computes logits]
                          ↓
         logits = [2.1, -0.5, 3.2, ...] (vocab_size)
                          ↓
              [Apply decoding strategy]
                          ↓
                  Next token: "mat"
```

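The **decoding strategy** is the rule that turns this logit vector into a single next token. Generation repeats that step: append the chosen token to the input and run the model again.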
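Every strategy in this notebook only has to map one logit vector to one token id; the surrounding loop stays the same. A minimal pure-Python sketch of that loop, assuming a hypothetical `toy_next_logits` function in place of GPT-2 and a 4-token vocabulary (both are illustrative, not part of the notebook):

```python
from typing import Callable, List

def toy_next_logits(token_ids: List[int]) -> List[float]:
    # Hypothetical stand-in for an LLM with a 4-token vocabulary:
    # it strongly prefers (last_token + 1) % 4 as the next token.
    favored = (token_ids[-1] + 1) % 4
    return [3.0 if i == favored else 0.0 for i in range(4)]

def argmax_strategy(logits: List[float]) -> int:
    # Simplest possible strategy: index of the largest logit
    return max(range(len(logits)), key=lambda i: logits[i])

def generate(prompt_ids: List[int],
             strategy: Callable[[List[float]], int],
             max_new_tokens: int) -> List[int]:
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = toy_next_logits(ids)   # 1. score every vocabulary entry
        ids.append(strategy(logits))    # 2. strategy picks one id, 3. append and repeat
    return ids

print(generate([0], argmax_strategy, max_new_tokens=5))  # → [0, 1, 2, 3, 0, 1]
```

Swapping `argmax_strategy` for a sampling function is all it takes to change the strategy; the GPT-2 functions below follow the same pattern, calling `model(input_tensor)` where this sketch calls `toy_next_logits`.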

In [None]:
!pip install torch==2.1.0 transformers==4.36.0 matplotlib==3.8.2 numpy==1.26.2

---
## Setup: Load GPT-2 Small

In [None]:
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
import matplotlib.pyplot as plt

# Load GPT-2 small (124M parameters, the smallest GPT-2 checkpoint)
print("Loading GPT-2 small...")
model_name = "gpt2"  # gpt2 = gpt2-small (124M params)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()  # Set to evaluation mode

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print(f"Model loaded on: {device}")
print(f"Vocabulary size: {tokenizer.vocab_size}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

---
## Understanding Logits

The model outputs **logits**: raw, unnormalized scores, one per vocabulary token. Softmax turns them into a probability distribution over the next token.

```
logits[i] = how much the model "likes" token i as the next token
```
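Softmax converts logits to probabilities by exponentiating each score and dividing by the total. A minimal pure-Python sketch (the notebook itself uses `F.softmax`), applied to the three illustrative logits from the diagram above:

```python
import math
from typing import List

def softmax(logits: List[float]) -> List[float]:
    # Subtracting the max logit leaves the result unchanged (softmax is
    # shift-invariant) but keeps math.exp from overflowing on large logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.1, -0.5, 3.2]
probs = softmax(logits)
print([round(p, 3) for p in probs])  # → [0.245, 0.018, 0.737]
```

The largest logit (3.2) gets the largest probability, and the outputs sum to 1.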

In [None]:
def get_next_token_logits(prompt: str) -> torch.Tensor:
    """
    Get logits for the next token given a prompt.
    
    Returns: tensor of shape (vocab_size,)
    """
    # Tokenize input
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    
    # Forward pass (no gradient needed for inference)
    with torch.no_grad():
        outputs = model(input_ids)
        # outputs.logits shape: (batch=1, seq_len, vocab_size)
        # We want the logits for the LAST position (next token prediction)
        logits = outputs.logits[0, -1, :]  # shape: (vocab_size,)
    
    return logits

# Demo: Get logits for a prompt
prompt = "The capital of France is"
logits = get_next_token_logits(prompt)

print(f"Prompt: '{prompt}'")
print(f"Logits shape: {logits.shape}")
print(f"Logits range: [{logits.min():.2f}, {logits.max():.2f}]")

# Show top 10 tokens by logit value
top_k_logits, top_k_indices = torch.topk(logits, 10)
print("\nTop 10 tokens by logit:")
for i, (logit, idx) in enumerate(zip(top_k_logits, top_k_indices)):
    token = tokenizer.decode([idx])
    print(f"  {i+1}. '{token}' (logit: {logit:.2f})")

---
## Strategy 1: Greedy Decoding

**Simplest strategy**: Always pick the token with the highest logit.

```python
next_token_id = torch.argmax(logits).item()
```

**Pros**: Deterministic, fast  
**Cons**: Prone to repetitive loops; always yields the same, often bland, continuation
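This cell says "highest logit", while the `greedy_decode` docstring below says "highest probability". Both pick the same token: softmax divides every `exp(logit)` by the same total, and `exp` is strictly increasing, so the ranking never changes and greedy can skip softmax entirely. A small pure-Python check with made-up logits:

```python
import math
from typing import List

def softmax(logits: List[float]) -> List[float]:
    m = max(logits)                           # shift for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values: List[float]) -> int:
    # Index of the largest value (the first one wins on ties)
    return max(range(len(values)), key=lambda i: values[i])

logits = [1.3, 4.0, -2.2, 3.9]                # made-up logits
probs = softmax(logits)
print(argmax(logits), argmax(probs))          # → 1 1
```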

In [None]:
def greedy_decode(prompt: str, max_new_tokens: int = 50) -> str:
    """
    Greedy decoding: always pick the highest probability token.
    
    Algorithm:
    1. Get logits for next token
    2. Pick token with highest logit (argmax)
    3. Append to sequence
    4. Repeat
    """
    # Start with the prompt tokens
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    generated_ids = input_ids[0].tolist()  # Convert to list for easier manipulation
    
    for _ in range(max_new_tokens):
        # Create input tensor from current sequence
        input_tensor = torch.tensor([generated_ids]).to(device)
        
        # Get logits for next token
        with torch.no_grad():
            outputs = model(input_tensor)
            logits = outputs.logits[0, -1, :]  # (vocab_size,)
        
        # GREEDY: Pick the token with highest logit
        next_token_id = torch.argmax(logits).item()
        
        # Append to sequence
        generated_ids.append(next_token_id)
        
        # Stop if we hit end of text token
        if next_token_id == tokenizer.eos_token_id:
            break
    
    # Decode back to text
    return tokenizer.decode(generated_ids)

# Test greedy decoding
prompt = "Once upon a time"
print(f"Prompt: '{prompt}'")
print("\n" + "="*50)
print("GREEDY OUTPUT:")
print("="*50)
output = greedy_decode(prompt, max_new_tokens=50)
print(output)

print("\n" + "-"*50)
print("Note: Greedy is deterministic - same input always gives same output")
print("Run it again - you'll get the exact same text!")

---
## Strategy 2: Temperature Sampling

**Idea**: Scale logits before converting to probabilities.

```python
probs = F.softmax(logits / temperature, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1).item()
```
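A minimal pure-Python sketch of one temperature-sampling step, assuming made-up logits and using the standard library's `random.Random.choices` in place of `torch.multinomial`. Like `multinomial`, `choices` accepts unnormalized weights, so the softmax denominator can be skipped:

```python
import math
import random
from typing import List

def sample_with_temperature(logits: List[float], temperature: float,
                            rng: random.Random) -> int:
    if temperature <= 0:
        raise ValueError("temperature must be > 0; use argmax for greedy decoding")
    scaled = [x / temperature for x in logits]    # 1. divide logits by T
    m = max(scaled)
    weights = [math.exp(x - m) for x in scaled]   # 2. unnormalized softmax
    # 3. draw one index with probability proportional to its weight
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

rng = random.Random(0)  # fixed seed so the draws are reproducible
logits = [2.0, 1.0, 0.0]
print([sample_with_temperature(logits, 1.0, rng) for _ in range(10)])
```

With a very small temperature, such as 0.01, every draw in practice returns index 0, which is the greedy choice.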

| Temperature | Effect |
|-------------|--------|
| T < 1 | Sharper distribution (more deterministic) |
| T = 1 | Original distribution |
| T > 1 | Flatter distribution (more random) |
| T → 0 | Approaches greedy |
| T → ∞ | Uniform random |
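The table's trends can be checked numerically. A pure-Python sketch that prints `softmax(logits / T)` for a made-up 3-token logit vector at several temperatures:

```python
import math
from typing import List

def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    scaled = [x / temperature for x in logits]   # divide by T first
    m = max(scaled)                              # shift for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
for t in [0.1, 0.5, 1.0, 2.0, 100.0]:
    probs = softmax_with_temperature(logits, t)
    print(f"T={t:>5}: " + ", ".join(f"{p:.3f}" for p in probs))
```

The top token's probability falls steadily as T grows. At T=0.1 nearly all mass sits on one token (close to greedy), and at T=100 all three tokens are close to 1/3 (close to uniform).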

In [None]:
def temperature_decode(prompt: str, temperature: float = 1.0, max_new_tokens: int = 50) -> str:
    """
    Temperature sampling: scale logits before softmax.
    
    Algorithm:
    1. Get logits for next token
    2. Divide logits by temperature
    3. Apply softmax to get probabilities
    4. Sample from the distribution
    5. Repeat
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    generated_ids = input_ids[0].tolist()
    
    for _ in range(max_new_tokens):
        input_tensor = torch.tensor([generated_ids]).to(device)
        
        with torch.no_grad():
            outputs = model(input_tensor)
            logits = outputs.logits[0, -1, :]
        
        # TEMPERATURE: Scale logits
        scaled_logits = logits / temperature
        
        # Convert to probabilities
        probs = F.softmax(scaled_logits, dim=-1)
        
        # Sample from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1).item()
        
        generated_ids.append(next_token_id)
        
        if next_token_id == tokenizer.eos_token_id:
            break
    
    return tokenizer.decode(generated_ids)

# Compare different temperatures
prompt = "The meaning of life is"
print(f"Prompt: '{prompt}'\n")

for temp in [0.3, 0.7, 1.0, 1.5]:
    print(f"\n{'='*50}")
    print(f"TEMPERATURE = {temp}")
    print(f"{'='*50}")
    output = temperature_decode(prompt, temperature=temp, max_new_tokens=40)
    print(output)

---
## Strategy 3: Top-p (Nucleus) Sampling

**Idea**: Sample from the smallest set of tokens whose cumulative probability ≥ p.

```python
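sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
cutoff = int((cumulative < p).sum())      # first index where cumulative >= p
nucleus = sorted_indices[: cutoff + 1]    # keep tokens up to and including cutoff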
```

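**Why not a fixed top-k cutoff?** Top-k always keeps exactly k tokens, whereas the size of the top-p nucleus adapts to the shape of the distribution:
- Confident prediction (one dominant token) → small nucleus
- Uncertain prediction (flat distribution) → larger nucleus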
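A minimal pure-Python sketch of the selection and renormalization steps, using two made-up distributions and the same p for both, to show the nucleus shrinking or growing with the model's confidence:

```python
from typing import List, Set

def top_p_filter(probs: List[float], p: float) -> List[float]:
    """Zero out tokens outside the nucleus and renormalize the rest."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: Set[int] = set()
    cumulative = 0.0
    for i in order:                  # most likely token first
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= p:          # smallest set whose mass reaches p
            break
    mass = sum(probs[i] for i in kept)
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]

confident = [0.92, 0.05, 0.02, 0.01]         # one dominant token
uncertain = [0.30, 0.25, 0.20, 0.15, 0.10]   # flat distribution
for name, dist in [("confident", confident), ("uncertain", uncertain)]:
    filtered = top_p_filter(dist, p=0.85)
    size = sum(1 for q in filtered if q > 0)
    print(f"{name}: nucleus size = {size}")
# → confident: nucleus size = 1
# → uncertain: nucleus size = 4
```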

In [None]:
def top_p_decode(prompt: str, p: float = 0.9, temperature: float = 1.0, max_new_tokens: int = 50) -> str:
    """
    Top-p (nucleus) sampling: sample from smallest set with cumulative prob >= p.
    
    Algorithm:
    1. Get logits and convert to probabilities
    2. Sort probabilities in descending order
    3. Compute cumulative sum
    4. Find cutoff where cumsum >= p
    5. Keep only tokens up to cutoff
    6. Renormalize and sample
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    generated_ids = input_ids[0].tolist()
    
    for _ in range(max_new_tokens):
        input_tensor = torch.tensor([generated_ids]).to(device)
        
        with torch.no_grad():
            outputs = model(input_tensor)
            logits = outputs.logits[0, -1, :]
        
        # Apply temperature first
        scaled_logits = logits / temperature
        
        # Convert to probabilities
        probs = F.softmax(scaled_logits, dim=-1)
        
        # TOP-P: Sort probabilities descending
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        
        # Compute cumulative sum
        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
        
        # Find the cutoff: the first sorted position where cumsum reaches p.
        # Mark that position and everything after it for removal...
        sorted_indices_to_remove = cumsum_probs >= p
        # ...then shift the mask right by one so the token that reaches p is kept
        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
        sorted_indices_to_remove[0] = False
        
        # Set removed tokens to 0 probability
        sorted_probs[sorted_indices_to_remove] = 0
        
        # Renormalize
        sorted_probs = sorted_probs / sorted_probs.sum()
        
        # Create full probability tensor
        filtered_probs = torch.zeros_like(probs)
        filtered_probs[sorted_indices] = sorted_probs
        
        # Sample
        next_token_id = torch.multinomial(filtered_probs, num_samples=1).item()
        
        generated_ids.append(next_token_id)
        
        if next_token_id == tokenizer.eos_token_id:
            break
    
    return tokenizer.decode(generated_ids)

# Compare different p values
prompt = "The secret to happiness is"
print(f"Prompt: '{prompt}'\n")

for p_val in [0.5, 0.8, 0.95, 0.99]:
    print(f"\n{'='*50}")
    print(f"TOP-P = {p_val}")
    print(f"{'='*50}")
    output = top_p_decode(prompt, p=p_val, temperature=0.8, max_new_tokens=40)
    print(output)

---
## Strategy 4: Beam Search

**Idea**: Track multiple candidate sequences ("beams") and keep the best ones.

```
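Step 1: Start with a single beam containing the prompt
Step 2: For each beam, take its most likely next tokens (the code below keeps the top 2 × n)
Step 3: Score every (beam, token) extension by its cumulative log-probability
Step 4: Keep the top n extensions as the new beams (n = num_beams)
Step 5: Repeat until every beam has emitted EOS or max_new_tokens is reached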
```
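Why tracking several beams helps: greedy commits to the locally best token, which can rule out a sequence that is more probable overall. A pure-Python sketch over a hypothetical two-step next-token table `NEXT` (made-up probabilities, not GPT-2 output), comparing greedy with the steps above:

```python
import math
from typing import Dict, List, Tuple

# Hypothetical next-token table: context (tuple of tokens) -> {token: probability}.
# It only covers two steps, so both searches below run for steps=2.
NEXT: Dict[Tuple[str, ...], Dict[str, float]] = {
    (): {"A": 0.5, "B": 0.4, "C": 0.1},
    ("A",): {"x": 0.4, "y": 0.3, "z": 0.3},
    ("B",): {"x": 0.9, "y": 0.05, "z": 0.05},
    ("C",): {"x": 0.34, "y": 0.33, "z": 0.33},
}

def greedy(steps: int) -> Tuple[List[str], float]:
    seq: List[str] = []
    score = 0.0
    for _ in range(steps):
        dist = NEXT[tuple(seq)]
        tok = max(dist, key=dist.get)        # locally best token
        seq.append(tok)
        score += math.log(dist[tok])
    return seq, score

def beam_search(steps: int, num_beams: int) -> Tuple[List[str], float]:
    beams: List[Tuple[List[str], float]] = [([], 0.0)]  # one initial beam
    for _ in range(steps):
        candidates = [
            (seq + [tok], score + math.log(p))   # extend every beam by every token
            for seq, score in beams
            for tok, p in NEXT[tuple(seq)].items()
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]           # keep the n best extensions
    return beams[0]

g_seq, g_score = greedy(2)
b_seq, b_score = beam_search(2, num_beams=2)
print(g_seq, round(math.exp(g_score), 2))  # → ['A', 'x'] 0.2
print(b_seq, round(math.exp(b_score), 2))  # → ['B', 'x'] 0.36
```

Greedy takes "A" (0.5 > 0.4) and ends at 0.5 × 0.4 = 0.2; keeping "B" alive lets beam search reach 0.4 × 0.9 = 0.36.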

**Scoring**: Usually the sum of the token log-probabilities (higher = better)
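Summing log-probabilities ranks beams by their joint probability, by the chain rule (x is the prompt, y_1 … y_T the generated tokens):

$$
\text{score}(y_{1:T}) = \sum_{t=1}^{T} \log P(y_t \mid x, y_{<t}) = \log \prod_{t=1}^{T} P(y_t \mid x, y_{<t}) = \log P(y_{1:T} \mid x)
$$

Every term is at most 0, so each extra token can only lower the score, and raw sums therefore favor shorter sequences. Length normalization (dividing the score by a power of T) is a common remedy; `beam_search_decode` below compares beams by the raw sum.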

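**Pros**: Usually finds sequences with higher total probability than greedy (but is not guaranteed to find the single most likely one)  
**Cons**: Slower (one forward pass per beam per step in the code below), and output can still be repetitive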

In [None]:
def beam_search_decode(prompt: str, num_beams: int = 3, max_new_tokens: int = 50) -> str:
    """
    Beam search: track multiple candidate sequences.
    
    Each beam is a tuple: (token_ids, cumulative_log_prob)
    
    Algorithm:
    1. Start with one beam containing the prompt
    2. For each beam, get log-probs for all next tokens
    3. Consider all (beam, next_token) combinations
    4. Keep top num_beams combinations by cumulative log-prob
    5. Repeat until all beams hit EOS or max length
    6. Return best beam
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    initial_ids = input_ids[0].tolist()
    
    # Initialize beams: (token_ids, cumulative_log_prob)
    beams = [(initial_ids, 0.0)]
    completed_beams = []
    
    for step in range(max_new_tokens):
        all_candidates = []
        
        for beam_ids, beam_score in beams:
            # Skip if this beam is already complete
            if beam_ids[-1] == tokenizer.eos_token_id:
                completed_beams.append((beam_ids, beam_score))
                continue
            
            # Get next token probabilities
            input_tensor = torch.tensor([beam_ids]).to(device)
            with torch.no_grad():
                outputs = model(input_tensor)
                logits = outputs.logits[0, -1, :]
            
            # Convert to log-probabilities
            log_probs = F.log_softmax(logits, dim=-1)
            
            # Get top candidates for this beam
            top_log_probs, top_indices = torch.topk(log_probs, num_beams * 2)
            
            for log_prob, idx in zip(top_log_probs, top_indices):
                new_ids = beam_ids + [idx.item()]
                new_score = beam_score + log_prob.item()
                all_candidates.append((new_ids, new_score))
        
        # If no active candidates, we're done
        if not all_candidates:
            break
        
        # Keep top num_beams candidates
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        beams = all_candidates[:num_beams]
    
    # Add remaining beams to completed
    completed_beams.extend(beams)
    
    # Return best beam (highest raw log-prob sum; unnormalized, so shorter beams are favored)
    best_beam = max(completed_beams, key=lambda x: x[1])
    return tokenizer.decode(best_beam[0])

# Test beam search
prompt = "The quick brown fox"
print(f"Prompt: '{prompt}'\n")

print("="*50)
print("GREEDY:")
print("="*50)
print(greedy_decode(prompt, max_new_tokens=30))

print(f"\n{'='*50}")
print(f"BEAM SEARCH (num_beams=3):")
print(f"{'='*50}")
print(beam_search_decode(prompt, num_beams=3, max_new_tokens=30))