# 01 - Decoding Strategies for Text Generation

## Context

The generation strategy you choose directly affects the quality of an LLM's
output. The same model can produce wildly different text depending on whether
you use greedy decoding, beam search, or nucleus sampling.

**CoCounsel context:** Legal responses need reproducible, accurate text -- not
creative sampling. A contract clause that cites a nonexistent statute is worse
than useless; it is dangerous. Deterministic decoding makes output reproducible
but does not by itself make it correct, so understanding how each decoding
strategy works lets you pick the right configuration for each legal task:
deterministic greedy decoding for citation extraction, conservative sampling
for drafting, and so on.

### Autoregressive Generation

Modern LLMs generate text **one token at a time**. At each step:

1. The model takes the full sequence so far (prompt + previously generated tokens).
2. It runs a forward pass and produces **logits** -- a raw score for every token
   in the vocabulary (e.g., 50,257 scores for GPT-2).
3. A **decoding strategy** selects the next token from those logits.
4. The selected token is appended to the sequence, and we repeat.

The decoding strategy is the decision-making layer between the model's raw
predictions and the final text. In this notebook, we implement every major
strategy from scratch -- no `model.generate()` -- so you understand exactly
what happens at each step.
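
The four steps above can be sketched without a real model. In the minimal sketch below, `VOCAB`, `toy_logits`, and `generate` are hypothetical stand-ins introduced only for illustration (they are not part of GPT-2 or any library); only the loop structure mirrors the steps listed.

```python
# Toy autoregressive loop. VOCAB and toy_logits are hypothetical stand-ins
# for a real tokenizer and a real model forward pass.
VOCAB = ["the", "court", "held", "<eos>"]


def toy_logits(sequence):
    """Return one raw score per vocabulary entry, given the sequence so far.

    This fake "model" simply favors the token that follows the last one in VOCAB.
    """
    last = sequence[-1]
    preferred = min(last + 1, len(VOCAB) - 1)
    return [5.0 if i == preferred else 0.0 for i in range(len(VOCAB))]


def generate(prompt_ids, max_tokens=10):
    sequence = list(prompt_ids)                    # step 1: sequence so far
    for _ in range(max_tokens):
        logits = toy_logits(sequence)              # step 2: one score per token
        # step 3: decoding strategy (greedy here: index of the largest score)
        next_id = max(range(len(logits)), key=logits.__getitem__)
        sequence.append(next_id)                   # step 4: append and repeat
        if VOCAB[next_id] == "<eos>":
            break
    return sequence


generated = generate([0])
print([VOCAB[i] for i in generated])
```

Swapping the single line marked "step 3" is all it takes to change the decoding strategy; the rest of the loop stays the same.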

## Setup

We load GPT-2 small (124M parameters) from Hugging Face. This model is small
enough to run on a CPU but large enough to produce coherent text.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reproducibility
torch.manual_seed(42)

# Load GPT-2 small
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

print(f"Model: {model_name}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Vocabulary size: {tokenizer.vocab_size:,}")
print(f"Device: {next(model.parameters()).device}")

In [None]:
# Define our legal prompt and get initial logits
legal_prompt = "The court held that the defendant"

input_ids = tokenizer.encode(legal_prompt, return_tensors="pt")
print(f"Prompt: {legal_prompt!r}")
print(f"Token IDs: {input_ids[0].tolist()}")
print(f"Tokens: {[tokenizer.decode(tid) for tid in input_ids[0]]}")
print(f"Number of tokens: {input_ids.shape[1]}")

# Get logits from the model
with torch.no_grad():
    outputs = model(input_ids)
    # logits shape: (batch, seq_len, vocab_size)
    # We want the logits for the NEXT token, so we take the last position
    next_token_logits = outputs.logits[0, -1, :]

print(f"\nLogits shape (full): {outputs.logits.shape}")
print(f"Next-token logits shape: {next_token_logits.shape}")
print(f"Logits range: [{next_token_logits.min():.2f}, {next_token_logits.max():.2f}]")

## Understanding Logits

The model outputs **logits** -- raw, unnormalized scores for every token in the
vocabulary. These are not probabilities yet. To convert logits to probabilities,
we apply the **softmax** function:

$$P(\text{token}_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}$$

where $z_i$ is the logit for token $i$. Softmax has two key properties:

1. All outputs are positive (due to the exponential).
2. All outputs sum to 1 (due to the normalization).

This gives us a valid probability distribution over the entire vocabulary.
The decoding strategy then decides **how to select a token** from this
distribution.
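
As a small worked example of the formula, the sketch below computes softmax in plain Python on made-up logits for a four-token vocabulary. The helper name `softmax` and the example values are assumptions for illustration; subtracting the maximum logit before exponentiating is a common numerical-stability trick that does not change the result.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    shift = max(logits)  # subtracting the max leaves the result unchanged
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


example_logits = [2.0, 1.0, 0.0, -1.0]  # made-up scores for a 4-token vocabulary
example_probs = softmax(example_logits)

for z, prob in zip(example_logits, example_probs):
    print(f"logit {z:+.1f} -> probability {prob:.4f}")
print(f"sum = {sum(example_probs):.6f}")
```

Every output is positive and the outputs sum to 1, which matches the two properties listed above.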

In [None]:
# Convert logits to probabilities
probs = F.softmax(next_token_logits, dim=-1)

print(f"Probabilities sum: {probs.sum():.6f}")
print(f"Number of tokens with prob > 0.01: {(probs > 0.01).sum().item()}")
print(f"Number of tokens with prob > 0.001: {(probs > 0.001).sum().item()}")
print(f"Number of tokens with prob > 0.0001: {(probs > 0.0001).sum().item()}")
print(f"\nMost probability mass is concentrated in a small number of tokens.")
print(f"Top 10 tokens account for {probs.topk(10).values.sum():.1%} of total probability.")
print(f"Top 100 tokens account for {probs.topk(100).values.sum():.1%} of total probability.")

In [None]:
# Show top-10 most likely next tokens
top_k_probs, top_k_indices = probs.topk(10)

print(f"Prompt: {legal_prompt!r}")
print(f"\nTop 10 next-token predictions:")
print(f"{'Rank':<6} {'Token':<20} {'Probability':<15} {'Logit':<10}")
print("-" * 51)
for rank, (prob, idx) in enumerate(zip(top_k_probs, top_k_indices), 1):
    token_str = tokenizer.decode(idx.item())
    logit_val = next_token_logits[idx].item()
    print(f"{rank:<6} {token_str!r:<20} {prob.item():<15.4f} {logit_val:<10.2f}")

In [None]:
# Visualize the probability distribution over the vocabulary
top_n = 30
top_probs, top_indices = probs.topk(top_n)
top_tokens = [tokenizer.decode(idx.item()).strip() for idx in top_indices]

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left: bar chart of top-30 tokens
ax = axes[0]
bars = ax.barh(range(top_n - 1, -1, -1), top_probs.numpy(), color="steelblue")
ax.set_yticks(range(top_n - 1, -1, -1))
ax.set_yticklabels(top_tokens, fontsize=8)
ax.set_xlabel("Probability")
ax.set_title(f"Top {top_n} Next-Token Probabilities\nPrompt: {legal_prompt!r}")
ax.grid(axis="x", alpha=0.3)

# Right: full distribution (sorted, log scale)
ax = axes[1]
sorted_probs, _ = probs.sort(descending=True)
ax.plot(sorted_probs.numpy(), color="steelblue", linewidth=1.5)
ax.set_xlabel("Token rank (sorted by probability)")
ax.set_ylabel("Probability (log scale)")
ax.set_yscale("log")
ax.set_title("Full Vocabulary Probability Distribution")
ax.axvline(x=10, color="red", linestyle="--", alpha=0.5, label="Top 10")
ax.axvline(x=100, color="orange", linestyle="--", alpha=0.5, label="Top 100")
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("The distribution is extremely skewed -- a handful of tokens hold most")
print("of the probability mass, while thousands of tokens are near zero.")
print("This is why raw logits need processing before selecting a token:")
print("different strategies handle this long tail differently.")

## Greedy Decoding

The simplest decoding strategy: at each step, pick the token with the
**highest probability** (i.e., take the argmax of the logits).

**Advantages:**
- Deterministic -- same input always produces same output.
- Fast -- no sampling, no maintaining multiple candidates.
- Often produces grammatically correct text.

**Disadvantages:**
- Tends to produce **repetitive text** -- the model gets stuck in loops.
- Misses higher-probability sequences that require a locally suboptimal
  choice early on (greedy is not globally optimal).
- Output lacks diversity -- no variation between runs.
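
To make the "not globally optimal" point concrete, here is a minimal sketch with a hypothetical two-step model. The probability tables are invented for illustration: the locally best first token leads to a lower-probability sequence overall.

```python
# Hypothetical two-step model: first-token probabilities, then
# second-token probabilities conditioned on the first token.
first_step = {"A": 0.6, "B": 0.4}
second_step = {
    "A": {"x": 0.5, "y": 0.5},
    "B": {"x": 0.9, "y": 0.1},
}

# Greedy: take the most likely token at each step.
g1 = max(first_step, key=first_step.get)
g2 = max(second_step[g1], key=second_step[g1].get)
greedy_prob = first_step[g1] * second_step[g1][g2]

# Exhaustive search: score every two-token sequence.
all_sequences = {
    (t1, t2): first_step[t1] * second_step[t1][t2]
    for t1 in first_step
    for t2 in second_step[t1]
}
best_sequence = max(all_sequences, key=all_sequences.get)

print(f"Greedy picks {(g1, g2)} with probability {greedy_prob:.2f}")
print(f"Best sequence is {best_sequence} with probability {all_sequences[best_sequence]:.2f}")
```

Greedy commits to "A" because 0.6 > 0.4, but the sequence starting with "B" has the higher joint probability.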

In [None]:
def greedy_decode(model, tokenizer, prompt, max_tokens=50):
    """Generate text using greedy decoding (always pick the argmax token).

    Args:
        model: A HuggingFace causal LM.
        tokenizer: The corresponding tokenizer.
        prompt: The input text string.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The full generated text (prompt + generated tokens).
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    for _ in range(max_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
            # Get logits for the last token position
            logits = outputs.logits[0, -1, :]

        # Greedy: pick the token with the highest logit
        next_token_id = torch.argmax(logits).unsqueeze(0).unsqueeze(0)

        # Append to the sequence
        input_ids = torch.cat([input_ids, next_token_id], dim=1)

        # Stop if we generate the EOS token
        if next_token_id.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


# Run greedy decoding on the legal prompt
greedy_output = greedy_decode(model, tokenizer, legal_prompt, max_tokens=80)

print("=" * 70)
print("GREEDY DECODING")
print("=" * 70)
print(f"Prompt: {legal_prompt!r}")
print(f"\nGenerated text:\n{greedy_output}")
print("\n" + "=" * 70)
print("Notice: greedy decoding often produces repetitive loops.")
print("The model gets stuck repeating the same phrases because the")
print("highest-probability next token keeps leading back to the same state.")

## Beam Search

Beam search addresses greedy decoding's key weakness: greedy decoding makes
locally optimal choices that may not be globally optimal. A sequence might
require picking a lower-probability token early on to reach a much
higher-probability sequence overall.

**How it works:**

1. Start with the prompt as the single initial beam.
2. At each step, expand each beam by considering every possible next token.
3. Score each expanded sequence by summing the log-probabilities.
4. Keep only the top-k sequences (where k = `num_beams`).
5. Repeat until `max_tokens` or all beams hit EOS.
6. Return the highest-scoring beam.

Beam search explores more of the search space than greedy decoding while
remaining tractable (we never explore the full exponential tree).

**Note:** Beam search is still deterministic -- it always produces the same
output for the same input. It generally finds higher-probability sequences
than greedy decoding but is slower (k forward passes per step).
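
The steps above can be sketched on a toy problem before applying them to GPT-2. In the sketch below, `TABLE` and `toy_beam_search` are hypothetical names, and the conditional probabilities are invented; with `num_beams=1` the procedure reduces to greedy decoding.

```python
import math

# Hypothetical conditional probabilities: TABLE[prefix] maps each possible
# next token to its probability. Sequences end after two tokens.
TABLE = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.5, "y": 0.5},
    ("B",): {"x": 0.9, "y": 0.1},
}


def toy_beam_search(table, num_beams, steps):
    beams = [((), 0.0)]  # (token tuple, cumulative log-probability)
    for _ in range(steps):
        candidates = []
        for prefix, score in beams:
            for token, prob in table[prefix].items():
                # Adding log-probabilities is equivalent to multiplying probabilities
                candidates.append((prefix + (token,), score + math.log(prob)))
        candidates.sort(key=lambda item: item[1], reverse=True)
        beams = candidates[:num_beams]  # prune to the best num_beams
    return beams[0]


for width in (1, 2):
    sequence, log_prob = toy_beam_search(TABLE, num_beams=width, steps=2)
    print(f"num_beams={width}: {sequence}, probability={math.exp(log_prob):.2f}")
```

Summing log-probabilities rather than multiplying raw probabilities also avoids numerical underflow on long sequences.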

In [None]:
def beam_search(model, tokenizer, prompt, num_beams=5, max_tokens=50):
    """Generate text using beam search.

    Maintains `num_beams` candidate sequences at each step, expanding all
    candidates and keeping the top-scoring ones.

    Args:
        model: A HuggingFace causal LM.
        tokenizer: The corresponding tokenizer.
        prompt: The input text string.
        num_beams: Number of beams (candidate sequences) to maintain.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The highest-scoring generated text.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Each beam is a tuple of (token_ids_tensor, cumulative_log_prob)
    beams = [(input_ids, 0.0)]

    for step in range(max_tokens):
        all_candidates = []

        for seq, score in beams:
            # If this beam already ended with EOS, keep it as-is
            if seq[0, -1].item() == tokenizer.eos_token_id:
                all_candidates.append((seq, score))
                continue

            with torch.no_grad():
                outputs = model(seq)
                logits = outputs.logits[0, -1, :]

            # Convert to log-probabilities
            log_probs = F.log_softmax(logits, dim=-1)

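            # Get top-k candidates for this beam. Expanding only the top
            # num_beams tokens is enough: no lower-ranked continuation of this
            # beam can survive pruning to num_beams candidates overall.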
            top_log_probs, top_indices = log_probs.topk(num_beams)

            for i in range(num_beams):
                next_token_id = top_indices[i].unsqueeze(0).unsqueeze(0)
                new_seq = torch.cat([seq, next_token_id], dim=1)
                new_score = score + top_log_probs[i].item()
                all_candidates.append((new_seq, new_score))

        # Keep only the top num_beams candidates
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        beams = all_candidates[:num_beams]

        # Early stop if all beams ended with EOS
        if all(
            seq[0, -1].item() == tokenizer.eos_token_id
            for seq, _ in beams
        ):
            break

    # Return the highest-scoring beam
    best_seq, best_score = beams[0]
    return tokenizer.decode(best_seq[0], skip_special_tokens=True)


# Run beam search on the legal prompt
beam_output = beam_search(model, tokenizer, legal_prompt, num_beams=5, max_tokens=50)

print("=" * 70)
print("BEAM SEARCH (num_beams=5)")
print("=" * 70)
print(f"Prompt: {legal_prompt!r}")
print(f"\nGenerated text:\n{beam_output}")
print("\n" + "=" * 70)
print("\nComparison with greedy:")
print(f"  Greedy: {greedy_output[:120]}...")
print(f"  Beam:   {beam_output[:120]}...")
print("\nBeam search finds higher-probability sequences by exploring multiple")
print("paths simultaneously, but it is still deterministic and can still repeat.")

## Temperature Scaling

Temperature scaling adjusts the **sharpness** of the probability distribution
before sampling. The idea is simple: divide the logits by a temperature value
$T$ before applying softmax.

$$P(\text{token}_i) = \frac{e^{z_i / T}}{\sum_j e^{z_j / T}}$$

The effect:

- **T < 1** (low temperature): Makes the distribution **sharper**. The
  highest-probability token gets even more probability mass. As T approaches 0,
  this approaches greedy decoding.
- **T = 1**: No change (the default).
- **T > 1** (high temperature): Makes the distribution **flatter**. Probability
  mass spreads more evenly across tokens. As T approaches infinity, this
  approaches uniform random sampling.

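**For legal text:** Lower temperatures (0.1 - 0.5) are generally preferred
because they concentrate sampling on the model's highest-probability
continuations. High probability is not the same as factual accuracy, though:
a low temperature reduces randomness but does not prevent a confidently wrong
output.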
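
A small numeric check of the formula, using made-up logits for three tokens. The helper `softmax_with_temperature` and the values in `demo_logits` are assumptions for illustration, not part of the notebook's own code.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature (temperature must be positive)."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    scaled = [z / temperature for z in logits]
    shift = max(scaled)  # for numerical stability
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


demo_logits = [3.0, 2.0, 1.0]  # made-up scores for three tokens
for t in (0.5, 1.0, 2.0):
    dist = softmax_with_temperature(demo_logits, t)
    print(f"T={t}: " + ", ".join(f"{p:.3f}" for p in dist))
```

The ranking of tokens never changes with temperature; only the gap between them does.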

In [None]:
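def apply_temperature(logits, temperature):
    """Scale logits by temperature before softmax.

    Args:
        logits: Raw model logits, shape (vocab_size,).
        temperature: Positive scaling factor. Lower = sharper, higher = flatter.

    Returns:
        Scaled logits.

    Raises:
        ValueError: If temperature is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0; use greedy decoding for T -> 0")
    return logits / temperature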


# Demonstrate temperature scaling on the next-token logits from our legal prompt
temperatures = [0.1, 0.5, 1.0, 2.0]

print(f"Prompt: {legal_prompt!r}")
print(f"\nEffect of temperature on top-5 token probabilities:\n")

for temp in temperatures:
    scaled_logits = apply_temperature(next_token_logits, temp)
    scaled_probs = F.softmax(scaled_logits, dim=-1)
    top5_probs, top5_indices = scaled_probs.topk(5)

    tokens_str = ", ".join(
        f"{tokenizer.decode(idx.item()).strip()!r}: {p:.4f}"
        for p, idx in zip(top5_probs, top5_indices)
    )
    print(f"  T={temp:<4} -> {tokens_str}")

In [None]:
# Visualize probability distributions at different temperatures
top_n_vis = 20

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for ax, temp in zip(axes, temperatures):
    scaled_logits = apply_temperature(next_token_logits, temp)
    scaled_probs = F.softmax(scaled_logits, dim=-1)
    top_probs_t, top_indices_t = scaled_probs.topk(top_n_vis)
    top_tokens_t = [tokenizer.decode(idx.item()).strip() for idx in top_indices_t]

    colors = ["#d62728" if i == 0 else "steelblue" for i in range(top_n_vis)]
    ax.barh(
        range(top_n_vis - 1, -1, -1), top_probs_t.numpy(), color=colors
    )
    ax.set_yticks(range(top_n_vis - 1, -1, -1))
    ax.set_yticklabels(top_tokens_t, fontsize=8)
    ax.set_xlabel("Probability")

    # Show entropy as a measure of distribution spread
    entropy = -(scaled_probs * torch.log(scaled_probs + 1e-10)).sum().item()
    ax.set_title(f"Temperature = {temp}\nEntropy = {entropy:.2f} nats")
    ax.grid(axis="x", alpha=0.3)

fig.suptitle(
    f"Next-Token Probability Distribution at Different Temperatures\n"
    f"Prompt: {legal_prompt!r}",
    fontsize=13,
    y=1.02,
)
plt.tight_layout()
plt.show()

print("Low temperature (0.1): Almost all mass on the top token -- nearly greedy.")
print("High temperature (2.0): Mass is spread across many tokens -- high randomness.")
print("For legal text, low temperatures reduce randomness, but they do not guarantee factual output.")

## Top-k Sampling

Top-k sampling restricts the candidate pool to the **k most likely tokens**,
then renormalizes the probabilities and samples from this reduced distribution.

**How it works:**

1. Compute probabilities from logits (with optional temperature scaling).
2. Keep only the top-k tokens; set all other probabilities to zero.
3. Renormalize the remaining probabilities so they sum to 1.
4. Sample from this distribution.

**Advantages:**
- Prevents sampling very unlikely tokens (which cause incoherent output).
- Introduces controlled randomness -- output varies between runs.

**Disadvantages:**
- Fixed k does not adapt to the distribution. When the model is confident
  (one token has 95% probability), k=50 still considers 49 unlikely tokens.
  When the model is uncertain, k=50 might cut off reasonable options.
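
The fixed-cutoff limitation can be seen on two made-up six-token distributions. In this sketch, `top_k_filter` and both distributions are hypothetical and written in plain Python for illustration.

```python
def top_k_filter(probs, k):
    """Keep the k largest probabilities, zero the rest, and renormalize."""
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    kept_mass = sum(probs[i] for i in keep)
    return [probs[i] / kept_mass if i in keep else 0.0 for i in range(len(probs))]


# Two made-up 6-token distributions.
confident = [0.95, 0.01, 0.01, 0.01, 0.01, 0.01]
uncertain = [0.20, 0.18, 0.17, 0.16, 0.15, 0.14]

for name, dist in (("confident", confident), ("uncertain", uncertain)):
    filtered = top_k_filter(dist, k=3)
    dropped_mass = 1.0 - sum(sorted(dist, reverse=True)[:3])
    print(f"{name}: kept {[round(p, 3) for p in filtered]}, dropped mass {dropped_mass:.2f}")
```

With the same `k`, the confident distribution keeps two tokens that barely matter, while the uncertain one discards nearly half of its probability mass.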

In [None]:
def top_k_sample(logits, k=50, temperature=1.0):
    """Sample a token using top-k sampling.

    Keeps only the top-k highest-probability tokens, zeroes out the rest,
    renormalizes, and samples.

    Args:
        logits: Raw model logits, shape (vocab_size,).
        k: Number of top tokens to keep.
        temperature: Temperature scaling factor.

    Returns:
        Sampled token index (int).
    """
    # Apply temperature
    scaled_logits = apply_temperature(logits, temperature)

    # Find the top-k logits and their indices
    top_k_logits, top_k_indices = scaled_logits.topk(k)

    # Create a new logits tensor filled with -inf (zeroed after softmax)
    filtered_logits = torch.full_like(scaled_logits, float("-inf"))
    filtered_logits.scatter_(0, top_k_indices, top_k_logits)

    # Softmax to get renormalized probabilities
    token_probs = F.softmax(filtered_logits, dim=-1)

    # Sample from the distribution
    return torch.multinomial(token_probs, num_samples=1).item()


def generate_top_k(model, tokenizer, prompt, k=50, temperature=1.0, max_tokens=50):
    """Generate text using top-k sampling.

    Args:
        model: A HuggingFace causal LM.
        tokenizer: The corresponding tokenizer.
        prompt: The input text string.
        k: Number of top tokens to consider at each step.
        temperature: Temperature scaling factor.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The full generated text.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    for _ in range(max_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits[0, -1, :]

        next_token_id = top_k_sample(logits, k=k, temperature=temperature)
        next_token_tensor = torch.tensor([[next_token_id]])
        input_ids = torch.cat([input_ids, next_token_tensor], dim=1)

        if next_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


# Run top-k sampling with different k values
print(f"Prompt: {legal_prompt!r}")
print()

k_values = [5, 20, 50, 100]
for k in k_values:
    torch.manual_seed(42)  # Reproducibility
    output = generate_top_k(
        model, tokenizer, legal_prompt, k=k, temperature=0.8, max_tokens=50
    )
    print(f"k={k:<4} -> {output}")
    print()

## Top-p (Nucleus) Sampling

Top-p sampling (also called **nucleus sampling**, introduced by Holtzman et al.,
2019) fixes the key limitation of top-k: instead of using a fixed number of
tokens, it dynamically selects tokens until their **cumulative probability**
reaches a threshold $p$.

**How it works:**

1. Compute probabilities from logits (with optional temperature scaling).
2. Sort tokens by probability (highest first).
3. Compute the cumulative sum of probabilities.
4. Keep tokens until the cumulative probability reaches $p$.
5. Zero out all other tokens, renormalize, and sample.

**Why this is better than top-k:**

- When the model is **confident** (one token has 90% probability), top-p=0.9
  might only keep 1-2 tokens -- staying focused.
- When the model is **uncertain** (probability is spread across many tokens),
  top-p=0.9 might keep 50+ tokens -- allowing diversity.

Top-p **adapts to the shape of the distribution**, which top-k cannot.
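
As a quick illustration of this adaptivity, the sketch below counts how many tokens the nucleus contains for two made-up distributions. The helper `nucleus_size` and the example values are assumptions for illustration only.

```python
def nucleus_size(probs, p):
    """Number of top-ranked tokens needed for cumulative probability to reach p."""
    cumulative = 0.0
    for count, prob in enumerate(sorted(probs, reverse=True), start=1):
        cumulative += prob
        if cumulative >= p:
            return count
    return len(probs)


# Two made-up 6-token distributions.
confident = [0.95, 0.01, 0.01, 0.01, 0.01, 0.01]
uncertain = [0.20, 0.18, 0.17, 0.16, 0.15, 0.14]

print("confident:", nucleus_size(confident, p=0.9), "token(s)")
print("uncertain:", nucleus_size(uncertain, p=0.9), "token(s)")
```

The same threshold yields a nucleus of one token in the confident case and all six tokens in the uncertain case.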

In [None]:
def top_p_sample(logits, p=0.9, temperature=1.0):
    """Sample a token using top-p (nucleus) sampling.

    Includes the smallest set of tokens whose cumulative probability reaches
    (is at least) p, then renormalizes and samples.

    Args:
        logits: Raw model logits, shape (vocab_size,).
        p: Cumulative probability threshold.
        temperature: Temperature scaling factor.

    Returns:
        Sampled token index (int).
    """
    # Apply temperature
    scaled_logits = apply_temperature(logits, temperature)

    # Convert to probabilities
    token_probs = F.softmax(scaled_logits, dim=-1)

    # Sort by probability (descending)
    sorted_probs, sorted_indices = torch.sort(token_probs, descending=True)

    # Compute cumulative sum
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Find tokens to remove. Subtracting each token's own probability gives
    # the cumulative mass of the tokens ranked before it; a token is removed
    # only if that earlier mass already reaches p. The token that first pushes
    # the total to p is therefore still included.
    sorted_mask = cumulative_probs - sorted_probs >= p

    # Zero out tokens beyond the threshold
    sorted_probs[sorted_mask] = 0.0

    # Renormalize
    sorted_probs = sorted_probs / sorted_probs.sum()

    # Sample from the filtered distribution
    sampled_index = torch.multinomial(sorted_probs, num_samples=1).item()

    # Map back to the original vocabulary index
    return sorted_indices[sampled_index].item()


def generate_top_p(model, tokenizer, prompt, p=0.9, temperature=1.0, max_tokens=50):
    """Generate text using top-p (nucleus) sampling.

    Args:
        model: A HuggingFace causal LM.
        tokenizer: The corresponding tokenizer.
        prompt: The input text string.
        p: Cumulative probability threshold.
        temperature: Temperature scaling factor.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The full generated text.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    for _ in range(max_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits[0, -1, :]

        next_token_id = top_p_sample(logits, p=p, temperature=temperature)
        next_token_tensor = torch.tensor([[next_token_id]])
        input_ids = torch.cat([input_ids, next_token_tensor], dim=1)

        if next_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


# Run top-p sampling with different p values
print(f"Prompt: {legal_prompt!r}")
print()

p_values = [0.5, 0.8, 0.9, 0.95]
for p_val in p_values:
    torch.manual_seed(42)
    output = generate_top_p(
        model, tokenizer, legal_prompt, p=p_val, temperature=0.8, max_tokens=50
    )
    print(f"p={p_val:<5} -> {output}")
    print()

In [None]:
# Compare top-p vs top-k: how many tokens does each include?
# This demonstrates why top-p adapts better to different distributions.

# Get next-token logits for two different prompts to show distribution variation
comparison_prompts = [
    "The court held that the defendant",
    "The contract shall be governed by the laws of",
]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, prompt_text in zip(axes, comparison_prompts):
    ids = tokenizer.encode(prompt_text, return_tensors="pt")
    with torch.no_grad():
        out = model(ids)
        logits_i = out.logits[0, -1, :]

    p_dist = F.softmax(logits_i, dim=-1)
    sorted_p, _ = p_dist.sort(descending=True)
    cumsum = torch.cumsum(sorted_p, dim=-1)

    # How many tokens needed to reach p=0.9?
    n_tokens_p90 = (cumsum < 0.9).sum().item() + 1

    ax.plot(sorted_p[:100].numpy(), label="Token probability", color="steelblue")
    ax.axvline(
        x=n_tokens_p90, color="red", linestyle="--",
        label=f"top-p=0.9 boundary ({n_tokens_p90} tokens)",
    )
    ax.axvline(
        x=50, color="orange", linestyle="--",
        label="top-k=50 boundary",
    )
    ax.set_xlabel("Token rank")
    ax.set_ylabel("Probability")
    ax.set_title(f"Prompt: {prompt_text!r}\n(top-p=0.9 uses {n_tokens_p90} tokens)")
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)

plt.suptitle(
    "Top-p adapts to distribution shape; top-k uses a fixed cutoff",
    fontsize=12, y=1.02,
)
plt.tight_layout()
plt.show()

print("When the distribution is peaked (model is confident), top-p includes fewer tokens.")
print("When the distribution is flat (model is uncertain), top-p includes more tokens.")
print("Top-k always includes exactly k tokens regardless of the distribution shape.")

## Side-by-Side Comparison

Let's run all strategies on the same legal prompt and compare the outputs
directly. For sampling methods, we use `torch.manual_seed()` for
reproducibility.

In [None]:
# Generate with all strategies
max_gen_tokens = 60

results = {}

# 1. Greedy
results["Greedy"] = greedy_decode(
    model, tokenizer, legal_prompt, max_tokens=max_gen_tokens
)

# 2. Beam search
results["Beam Search (k=5)"] = beam_search(
    model, tokenizer, legal_prompt, num_beams=5, max_tokens=max_gen_tokens
)

# 3. Top-k sampling
torch.manual_seed(42)
results["Top-k (k=50, T=0.7)"] = generate_top_k(
    model, tokenizer, legal_prompt, k=50, temperature=0.7, max_tokens=max_gen_tokens
)

# 4. Top-p sampling
torch.manual_seed(42)
results["Top-p (p=0.9, T=0.7)"] = generate_top_p(
    model, tokenizer, legal_prompt, p=0.9, temperature=0.7, max_tokens=max_gen_tokens
)

# 5. Low temperature top-p (conservative, good for legal)
torch.manual_seed(42)
results["Top-p (p=0.9, T=0.3)"] = generate_top_p(
    model, tokenizer, legal_prompt, p=0.9, temperature=0.3, max_tokens=max_gen_tokens
)

# Display results
print("=" * 80)
print("SIDE-BY-SIDE COMPARISON")
print(f"Prompt: {legal_prompt!r}")
print("=" * 80)

for strategy, output in results.items():
    print(f"\n--- {strategy} ---")
    # Show only the generated part (after the prompt)
    generated = output[len(legal_prompt):]
    print(f"{legal_prompt}[{generated}]")

print("\n" + "=" * 80)
print("\nKey observations:")
print("- Greedy and beam search are deterministic but may be repetitive.")
print("- Sampling methods (top-k, top-p) produce more diverse output.")
print("- Lower temperature + top-p gives more conservative text, which suits legal drafting.")
print("- Higher temperature adds creativity but risks factual errors.")

## Probability Distribution Visualization

For a single generation step, let's visualize the full probability distribution
and highlight which tokens each strategy would select. This makes the
differences concrete and visual.

In [None]:
# Get fresh logits for the visualization
with torch.no_grad():
    vis_outputs = model(tokenizer.encode(legal_prompt, return_tensors="pt"))
    vis_logits = vis_outputs.logits[0, -1, :]

vis_probs = F.softmax(vis_logits, dim=-1)

# Get top-30 tokens for display
top30_probs, top30_indices = vis_probs.topk(30)
top30_tokens = [tokenizer.decode(idx.item()).strip() for idx in top30_indices]

# Determine which tokens each strategy would consider
# Greedy: just the top-1 token
greedy_selection = {0}  # index into top-30

# Top-k (k=10): top 10 tokens
topk_selection = set(range(min(10, 30)))

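# Top-p (p=0.9): tokens until cumulative prob >= 0.9
cumulative = torch.cumsum(top30_probs, dim=-1)
topp_selection = set()
for i in range(30):
    topp_selection.add(i)
    if cumulative[i].item() >= 0.9:
        break

# The loop above only sees the 30 displayed tokens. If the nucleus is larger,
# topp_selection is capped at 30, so report the true size as well.
full_cumulative = torch.cumsum(vis_probs.sort(descending=True).values, dim=-1)
full_nucleus_size = (full_cumulative < 0.9).sum().item() + 1
if full_nucleus_size > 30:
    print(f"Note: top-p=0.9 needs {full_nucleus_size} tokens; only the first 30 are shown.")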

# Create the visualization
fig, ax = plt.subplots(figsize=(14, 8))

x = np.arange(30)
bar_colors = []
for i in range(30):
    if i in greedy_selection:
        bar_colors.append("#d62728")   # red for greedy
    elif i in topp_selection:
        bar_colors.append("#2ca02c")   # green for top-p
    elif i in topk_selection:
        bar_colors.append("#1f77b4")   # blue for top-k only
    else:
        bar_colors.append("#cccccc")   # gray for excluded

bars = ax.bar(x, top30_probs.numpy(), color=bar_colors, edgecolor="white", linewidth=0.5)

# Add token labels
ax.set_xticks(x)
ax.set_xticklabels(top30_tokens, rotation=60, ha="right", fontsize=8)
ax.set_ylabel("Probability", fontsize=11)
ax.set_title(
    f"Next-Token Probabilities with Strategy Selections Highlighted\n"
    f"Prompt: {legal_prompt!r}",
    fontsize=12,
)
ax.grid(axis="y", alpha=0.3)

# Legend
from matplotlib.patches import Patch

legend_elements = [
    Patch(facecolor="#d62728", label="Greedy (top-1)"),
    Patch(facecolor="#2ca02c", label=f"Top-p=0.9 ({len(topp_selection)} tokens)"),
    Patch(facecolor="#1f77b4", label="Top-k=10 (10 tokens)"),
    Patch(facecolor="#cccccc", label="Excluded by all"),
]
ax.legend(handles=legend_elements, loc="upper right", fontsize=10)

# Add cumulative probability line on secondary y-axis
ax2 = ax.twinx()
ax2.plot(
    x, cumulative.numpy(), color="black", linewidth=1.5, linestyle="--",
    alpha=0.6, marker=".", markersize=3,
)
ax2.axhline(y=0.9, color="black", linewidth=0.8, linestyle=":", alpha=0.4)
ax2.set_ylabel("Cumulative probability", fontsize=11)
ax2.set_ylim(0, 1.05)

plt.tight_layout()
plt.show()

print(f"Greedy selects: {top30_tokens[0]!r} (probability: {top30_probs[0]:.4f})")
print(f"Top-k=10 considers: {top30_tokens[:10]}")
print(f"Top-p=0.9 considers: {top30_tokens[:len(topp_selection)]}")
print()
print(f"Top-p adapts: it includes {len(topp_selection)} tokens to reach 90% cumulative probability.")
print(f"Top-k always includes exactly 10 tokens regardless of the distribution.")

## Exercises

### Exercise (a): Combine Top-k + Top-p + Temperature

In practice, production systems often combine multiple strategies. Implement
a generation function that applies all three: temperature scaling, then top-k
filtering, then top-p filtering, then sampling.

Experiment with different combinations and find settings that produce good
legal text -- coherent, non-repetitive, and factually conservative.

```python
def combined_sample(logits, k=50, p=0.9, temperature=0.7):
    """Sample using temperature + top-k + top-p combined."""
    # Step 1: Apply temperature
    scaled_logits = logits / temperature

    # Step 2: Top-k filtering -- keep only top-k logits
    top_k_logits, top_k_indices = scaled_logits.topk(k)
    filtered = torch.full_like(scaled_logits, float("-inf"))
    filtered.scatter_(0, top_k_indices, top_k_logits)

    # Step 3: Top-p filtering on the remaining tokens
    probs = F.softmax(filtered, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    mask = (cumulative - sorted_probs) >= p
    sorted_probs[mask] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()

    # Step 4: Sample
    sampled_idx = torch.multinomial(sorted_probs, num_samples=1).item()
    return sorted_indices[sampled_idx].item()


# Try these combinations on the legal prompt:
configs = [
    {"k": 50, "p": 0.9, "temperature": 0.3},   # Conservative
    {"k": 50, "p": 0.9, "temperature": 0.7},   # Balanced
    {"k": 100, "p": 0.95, "temperature": 1.0},  # Creative
]
```

### Exercise (b): Implement Repetition Penalty

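Repetition is a common problem in text generation, especially with greedy
decoding. Implement a **repetition penalty** that lowers the logits of tokens
already present in the sequence (prompt and generated tokens alike), making them
less likely to be selected again. Positive logits are divided by the penalty
factor and negative logits are multiplied by it, so the score decreases in both
cases.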

```python
def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """Penalize tokens that have already been generated.

    For each token that appears in generated_ids:
    - If its logit is positive, divide by penalty (makes it smaller).
    - If its logit is negative, multiply by penalty (makes it more negative).

    Args:
        logits: Raw model logits, shape (vocab_size,).
        generated_ids: List of previously generated token IDs.
        penalty: Penalty factor (> 1.0 to penalize repetition).

    Returns:
        Modified logits with repetition penalty applied.
    """
    penalized_logits = logits.clone()
    for token_id in set(generated_ids):
        if penalized_logits[token_id] > 0:
            penalized_logits[token_id] /= penalty
        else:
            penalized_logits[token_id] *= penalty
    return penalized_logits


# Modify greedy_decode to use repetition penalty:
def greedy_decode_with_penalty(model, tokenizer, prompt, max_tokens=50, penalty=1.2):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated_ids = input_ids[0].tolist()

    for _ in range(max_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
            next_logits = outputs.logits[0, -1, :]

        # Apply repetition penalty
        next_logits = apply_repetition_penalty(next_logits, generated_ids, penalty)

        next_token_id = torch.argmax(next_logits).item()
        generated_ids.append(next_token_id)
        next_token_tensor = torch.tensor([[next_token_id]])
        input_ids = torch.cat([input_ids, next_token_tensor], dim=1)

        if next_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


# Compare greedy with and without repetition penalty:
# greedy_output = greedy_decode(model, tokenizer, legal_prompt, max_tokens=80)
# penalized_output = greedy_decode_with_penalty(
#     model, tokenizer, legal_prompt, max_tokens=80, penalty=1.2
# )
#
# Try different penalty values: 1.0 (no penalty), 1.2, 1.5, 2.0
# Observe how repetition decreases but coherence may also change.
```
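
The two branches in `apply_repetition_penalty` exist because dividing a negative logit by a factor above 1 moves it toward zero, which would raise the token's score rather than lower it. A small check with made-up logits (the helper `penalize` is hypothetical, for illustration only):

```python
def penalize(logit, penalty):
    """Sign-aware repetition penalty for a single logit."""
    return logit / penalty if logit > 0 else logit * penalty


for logit in (2.0, -2.0):
    naive = logit / 1.2                 # always divide, regardless of sign
    sign_aware = penalize(logit, 1.2)   # divide if positive, multiply if negative
    print(f"logit {logit:+.1f}: naive {naive:+.3f}, sign-aware {sign_aware:+.3f}")
```

For the negative logit, the naive version produces a larger (less negative) value, while the sign-aware version pushes it further down as intended.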