# Notebook 28: EAGLE Speculation Concept

## Inference Engineering Course

---

## Overview

**EAGLE** (Extrapolation Algorithm for Greater Language-model Efficiency) is a speculative decoding method that drafts future tokens from the target model's own **hidden states** (features) using a lightweight trained head, rather than running a separate, smaller draft model.

### Standard Speculative Decoding vs EAGLE

```
Standard Speculative Decoding:        EAGLE:

Draft Model (small) → draft tokens    Hidden States → feature regression
Target Model → verify                      → draft tokens
Accept/Reject                         Target Model → verify
                                      Accept/Reject

Problem: Draft model may be           Advantage: Uses target model's own
         poorly aligned               knowledge for better drafts
```

### What You'll Learn

| Topic | Description |
|-------|-------------|
| EAGLE Architecture | How hidden-state-based drafting works |
| Feature Regression | Predicting next-token representations |
| Token Tree | Tree-structured draft verification |
| Simplified Implementation | Build the core concept from scratch |
| Comparison | EAGLE vs standard speculative decoding |
| Speedup Analysis | Theoretical and practical speedups |

### Prerequisites
- Understanding of transformer architecture (attention, hidden states)
- Familiarity with speculative decoding concepts (Notebook 10)
- No GPU required (conceptual + simulation notebook)

In [None]:
# ============================================================
# Install dependencies
# ============================================================
!pip install matplotlib numpy torch -q

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import warnings
warnings.filterwarnings('ignore')

plt.style.use('seaborn-v0_8-whitegrid')
print("Dependencies loaded!")

---

## Section 1: The Key Insight Behind EAGLE

### Why Hidden States?

In a transformer, the hidden state at position $t$ contains rich information about the context. EAGLE's key insight is:

> **The hidden state at position $t$ can be used to predict the hidden state at position $t+1$, which can then predict the token at $t+1$.**

Compared with a separate draft model, this approach tends to produce better-aligned drafts (and therefore higher acceptance rates) at lower cost:
1. Hidden states capture the target model's **internal representation**
2. The prediction is conditioned on the **full context** the target model has seen
3. A lightweight regression head is much faster than a full forward pass

### EAGLE Architecture

```
                    Target Model (frozen)
                    ┌──────────────────┐
Input tokens ──────►│  Transformer     │──── hidden states (h_t)
                    │  Layers          │          │
                    └──────────────────┘          │
                                                  ▼
                    EAGLE Head (trainable)   ┌──────────┐
                    ┌──────────────────┐     │Feature   │
          h_t ─────►│ Lightweight      │────►│Regression│──► h_{t+1} (predicted)
      token_t ─────►│ Transformer      │     └──────────┘       │
                    └──────────────────┘                        ▼
                                                             LM Head ──► token_{t+1}
```

In [None]:
# ============================================================
# Visualize EAGLE vs Standard Speculative Decoding
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(18, 8))

# ---- Standard Speculative Decoding ----
ax = axes[0]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('Standard Speculative Decoding', fontsize=14, fontweight='bold')

# Draft model
draft_box = FancyBboxPatch((0.5, 6), 4, 2, boxstyle="round,pad=0.2",
                            facecolor='#FFCDD2', edgecolor='#C62828', linewidth=2)
ax.add_patch(draft_box)
ax.text(2.5, 7, 'Draft Model\n(Small, separate)', ha='center', va='center',
        fontsize=11, fontweight='bold')

# Target model
target_box = FancyBboxPatch((0.5, 2), 4, 2, boxstyle="round,pad=0.2",
                             facecolor='#BBDEFB', edgecolor='#1565C0', linewidth=2)
ax.add_patch(target_box)
ax.text(2.5, 3, 'Target Model\n(Large, expensive)', ha='center', va='center',
        fontsize=11, fontweight='bold')

# Arrow: draft tokens flow from the draft model down to the target model
ax.annotate('Draft tokens', xy=(2.5, 4.3), xytext=(2.5, 5.5),
            arrowprops=dict(arrowstyle='->', lw=2, color='gray'),
            fontsize=10, ha='center')

# Output
ax.text(6.5, 7, 'Problems:\n- Misaligned\n  representations\n- Separate training\n- Lower acceptance rate',
        fontsize=10, va='top', color='#C62828',
        bbox=dict(boxstyle='round', facecolor='#FFF3E0', alpha=0.8))

# ---- EAGLE ----
ax = axes[1]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('EAGLE Speculation', fontsize=14, fontweight='bold')

# Target model
target_box2 = FancyBboxPatch((0.5, 5.5), 4, 2.5, boxstyle="round,pad=0.2",
                              facecolor='#BBDEFB', edgecolor='#1565C0', linewidth=2)
ax.add_patch(target_box2)
ax.text(2.5, 6.75, 'Target Model\n(frozen)', ha='center', va='center',
        fontsize=11, fontweight='bold')

# EAGLE head
eagle_box = FancyBboxPatch((0.5, 2), 4, 2, boxstyle="round,pad=0.2",
                            facecolor='#C8E6C9', edgecolor='#2E7D32', linewidth=2)
ax.add_patch(eagle_box)
ax.text(2.5, 3, 'EAGLE Head\n(Lightweight, trainable)', ha='center', va='center',
        fontsize=11, fontweight='bold')

# Arrow: hidden states flow from the target model down into the EAGLE head
ax.annotate('Hidden states', xy=(2.5, 4.3), xytext=(2.5, 4.95),
            arrowprops=dict(arrowstyle='->', lw=2, color='#2E7D32'),
            fontsize=10, ha='center', color='#2E7D32')

# Output
ax.text(6.5, 7, 'Advantages:\n- Uses target model\'s\n  own representations\n- Better alignment\n- Higher acceptance rate\n- 2-4x speedup',
        fontsize=10, va='top', color='#2E7D32',
        bbox=dict(boxstyle='round', facecolor='#E8F5E9', alpha=0.8))

plt.tight_layout()
plt.savefig('eagle_vs_standard.png', dpi=150, bbox_inches='tight')
plt.show()

---

## Section 2: Feature Regression - Predicting Next Hidden States

The core of EAGLE is a **feature regression** network that predicts the hidden state of the next token given:
1. The current hidden state $h_t$ from the target model
2. The current token embedding $e_t$

$$\hat{h}_{t+1} = \text{EAGLEHead}(h_t, e_t)$$

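Here $h_t$ denotes the feature (last-layer hidden state) from which the LM head produced token $x_t$, and $e_t$ is the embedding of that sampled token. Feeding in the sampled token matters: under sampling, the same feature can lead to different next tokens, and $e_t$ tells the head which one was actually chosen. The predicted hidden state $\hat{h}_{t+1}$ is then passed through the target model's (frozen) LM head to get a distribution over $x_{t+1}$.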

Let's implement a simplified version.

In [None]:
# ============================================================
# Simplified EAGLE Head Implementation
# ============================================================

class SimplifiedEAGLEHead(nn.Module):
    """
    Simplified EAGLE head that predicts the next hidden state
    from the current hidden state and token embedding.
    
    Real EAGLE uses a fully connected layer followed by a single
    transformer decoder layer; we use a feedforward network for clarity.
    """
    
    def __init__(self, hidden_dim: int, embed_dim: int, num_layers: int = 2):
        super().__init__()
        
        # Combine hidden state and embedding
        self.input_projection = nn.Linear(hidden_dim + embed_dim, hidden_dim)
        
        # Feature regression layers
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),  # SiLU (Swish) activation, common in modern LLMs
                nn.LayerNorm(hidden_dim),
            ])
        self.regression = nn.Sequential(*layers)
        
        # Output projection (predicts next hidden state)
        self.output_projection = nn.Linear(hidden_dim, hidden_dim)
    
    def forward(self, hidden_state, token_embedding):
        """
        Predict the next hidden state.
        
        Args:
            hidden_state: [batch, hidden_dim] - current hidden state from target model
            token_embedding: [batch, embed_dim] - current token embedding
        
        Returns:
            predicted_hidden: [batch, hidden_dim] - predicted next hidden state
        """
        # Concatenate hidden state and embedding
        combined = torch.cat([hidden_state, token_embedding], dim=-1)
        
        # Project to hidden dimension
        x = self.input_projection(combined)
        
        # Regression (with residual connection)
        x = x + self.regression(x)
        
        # Output projection
        predicted_hidden = self.output_projection(x)
        
        return predicted_hidden


# ---- Setup dimensions ----
HIDDEN_DIM = 256   # Target model hidden dimension
EMBED_DIM = 128    # Token embedding dimension
VOCAB_SIZE = 1000  # Vocabulary size

# Create the EAGLE head
eagle_head = SimplifiedEAGLEHead(HIDDEN_DIM, EMBED_DIM)

# Create a simulated LM head (shared with target model)
lm_head = nn.Linear(HIDDEN_DIM, VOCAB_SIZE)

eagle_params = sum(p.numel() for p in eagle_head.parameters())
lm_params = sum(p.numel() for p in lm_head.parameters())

print("Simplified EAGLE Head (toy dimensions):")
print(f"  Hidden dim: {HIDDEN_DIM}")
print(f"  Embed dim: {EMBED_DIM}")
print(f"  EAGLE head parameters (new, trainable): {eagle_params:,}")
print(f"  LM head parameters (reused from target, frozen): {lm_params:,}")
print(f"  Parameters used per draft step: {eagle_params + lm_params:,}")
print("\n  Note: these are toy sizes, so don't compare them to real draft models.")
print("  At real scale the EAGLE head is roughly one decoder layer of the target,")
print("  a small fraction of the full model rather than a separate network.")

In [None]:
# ============================================================
# Demonstrate the EAGLE prediction process
# ============================================================

torch.manual_seed(42)

# Simulate target model hidden states for a sequence
seq_len = 10

# These would come from the target model's last layer
hidden_states = torch.randn(seq_len, HIDDEN_DIM)
token_embeddings = torch.randn(seq_len, EMBED_DIM)

# Use the EAGLE head to predict the next hidden state. The head is untrained,
# so the resulting token distribution will be close to uniform.
eagle_head.eval()
with torch.no_grad():
    # Predict hidden state for position t+1 from position t
    predicted_hidden = eagle_head(
        hidden_states[-1:],      # Last hidden state
        token_embeddings[-1:]    # Last token embedding
    )
    
    # Convert predicted hidden state to token probabilities
    logits = lm_head(predicted_hidden)
    probs = F.softmax(logits, dim=-1)
    
    # Get top-k predictions
    top_k = 5
    top_probs, top_indices = torch.topk(probs[0], top_k)

print("EAGLE Prediction Process:")
print("=" * 50)
print(f"1. Hidden state shape: {hidden_states[-1:].shape}")
print(f"2. Token embedding shape: {token_embeddings[-1:].shape}")
print(f"3. Predicted hidden shape: {predicted_hidden.shape}")
print(f"4. Logits shape: {logits.shape}")
print(f"\nTop-{top_k} predicted tokens:")
for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
    print(f"  Token {idx.item():4d}: probability = {prob.item():.4f}")
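The head above has random weights, so its predictions carry no information. EAGLE trains its head on features extracted from the frozen target model, combining a Smooth L1 regression loss between predicted and true next features with a cross-entropy loss between the token distributions the LM head produces from them. The sketch below illustrates that objective under stated assumptions: toy dimensions, random stand-in features, a two-layer MLP in place of the real head, and a classification weight of 0.1 chosen here as an assumed value. Names are prefixed `toy_` so they don't clash with the notebook's variables.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

HIDDEN, EMBED, VOCAB, SEQ = 64, 32, 100, 16  # toy sizes (assumed)
W_CLS = 0.1  # assumed weight on the classification loss

# Frozen pieces that would come from the target model (random stand-ins here)
toy_lm_head = nn.Linear(HIDDEN, VOCAB).requires_grad_(False)
toy_embed = nn.Embedding(VOCAB, EMBED).requires_grad_(False)

# Toy trainable head: [feature ; token embedding] -> next feature
toy_head = nn.Sequential(
    nn.Linear(HIDDEN + EMBED, HIDDEN), nn.SiLU(), nn.Linear(HIDDEN, HIDDEN)
)
opt = torch.optim.AdamW(toy_head.parameters(), lr=1e-3)

# Stand-ins for target features h_1..h_SEQ and the token each one produced
feats = torch.randn(SEQ, HIDDEN)
toks = toy_lm_head(feats).argmax(dim=-1)  # greedy token x_t from feature h_t


def eagle_style_loss(feats, toks):
    # Input pairs (h_t, e_t); the regression target is h_{t+1}
    inputs = torch.cat([feats[:-1], toy_embed(toks[:-1])], dim=-1)
    pred = toy_head(inputs)
    target = feats[1:]
    reg = F.smooth_l1_loss(pred, target)
    # Classification: match the target model's next-token distribution
    target_probs = F.softmax(toy_lm_head(target), dim=-1)
    cls = F.cross_entropy(toy_lm_head(pred), target_probs)  # soft-label CE
    return reg + W_CLS * cls


initial_loss = eagle_style_loss(feats, toks).item()
for step in range(300):
    loss = eagle_style_loss(feats, toks)
    opt.zero_grad()
    loss.backward()
    opt.step()
final_loss = eagle_style_loss(feats, toks).item()
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

In practice the features and tokens come from running the frozen target model over a training corpus rather than from random tensors; only the head's parameters are updated.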

---

## Section 3: Multi-Step Drafting (Auto-Regressive EAGLE)

EAGLE generates multiple draft tokens by chaining predictions:

```
Step 0: h_t (from target model) → EAGLE → ĥ_{t+1} → token_{t+1}
Step 1: ĥ_{t+1} (from EAGLE)   → EAGLE → ĥ_{t+2} → token_{t+2}
Step 2: ĥ_{t+2} (from EAGLE)   → EAGLE → ĥ_{t+3} → token_{t+3}
...and so on for K draft steps
```

Each step uses the **predicted** hidden state as input for the next prediction.
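Because each step feeds on the previous prediction, draft token $k$ only counts if tokens $1, \dots, k-1$ were all accepted. If $a_i$ is the probability that position $i$ is accepted given that all earlier positions were, the expected number of accepted draft tokens is $\sum_{k=1}^{K} \prod_{i=1}^{k} a_i$. A small helper makes the effect of decaying acceptance concrete; the per-position rates below are illustrative inputs, not values measured for EAGLE.

```python
from itertools import accumulate
import operator


def expected_accepted(per_position_accept):
    """Expected accepted draft tokens for a linear draft.
    per_position_accept[i] = P(accept position i | all earlier accepted)."""
    survival = accumulate(per_position_accept, operator.mul)  # prod_{i<=k} a_i
    return sum(survival)


flat = expected_accepted([0.8] * 6)
decaying = expected_accepted([0.9, 0.8, 0.7, 0.6, 0.5, 0.4])
print(f"flat 0.8 x 6: {flat:.3f} | decaying 0.9..0.4: {decaying:.3f}")
```

Even though the decaying schedule starts higher, the last few positions contribute little once acceptance drops, which is why deep linear drafts see diminishing returns.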

In [None]:
# ============================================================
# Multi-step EAGLE drafting
# ============================================================

class EAGLEDrafter:
    """
    Generates multiple draft tokens using EAGLE's
    auto-regressive hidden state prediction.
    """
    
    def __init__(self, eagle_head, lm_head, embedding_table):
        self.eagle_head = eagle_head
        self.lm_head = lm_head
        self.embedding_table = embedding_table  # Token → embedding lookup
    
    def draft(self, hidden_state, token_embedding, num_draft_tokens=5, temperature=1.0):
        """
        Generate multiple draft tokens from a starting hidden state.
        
        Returns:
            draft_tokens: list of predicted token IDs
            draft_probs: list of token probabilities
            draft_hiddens: list of predicted hidden states
        """
        draft_tokens = []
        draft_probs = []
        draft_hiddens = []
        
        current_hidden = hidden_state
        current_embedding = token_embedding
        
        self.eagle_head.eval()
        with torch.no_grad():
            for step in range(num_draft_tokens):
                # Predict next hidden state
                predicted_hidden = self.eagle_head(current_hidden, current_embedding)
                draft_hiddens.append(predicted_hidden)
                
                # Get token probabilities
                logits = self.lm_head(predicted_hidden) / temperature
                probs = F.softmax(logits, dim=-1)
                
                # Sample token
                token_id = torch.multinomial(probs[0], 1).item()
                token_prob = probs[0, token_id].item()
                
                draft_tokens.append(token_id)
                draft_probs.append(token_prob)
                
                # Update for next step
                current_hidden = predicted_hidden
                current_embedding = self.embedding_table(torch.tensor([[token_id]]))
                current_embedding = current_embedding.squeeze(0)
        
        return draft_tokens, draft_probs, draft_hiddens


# Create components
embedding_table = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
drafter = EAGLEDrafter(eagle_head, lm_head, embedding_table)

# Generate draft tokens
initial_hidden = torch.randn(1, HIDDEN_DIM)
initial_embedding = torch.randn(1, EMBED_DIM)

draft_tokens, draft_probs, draft_hiddens = drafter.draft(
    initial_hidden, initial_embedding, num_draft_tokens=8
)

print("Multi-Step EAGLE Drafting:")
print("=" * 50)
for i, (token, prob) in enumerate(zip(draft_tokens, draft_probs)):
    confidence = "HIGH" if prob > 0.1 else "MED" if prob > 0.01 else "LOW"
    bar = '#' * int(prob * 100)
    print(f"  Step {i}: Token {token:4d} | P={prob:.4f} | [{confidence:4s}] {bar}")

print(f"\nGenerated {len(draft_tokens)} draft tokens")
print(f"Average confidence: {np.mean(draft_probs):.4f}")
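print("Note: with a trained head, confidence typically decreases at later draft")
print("positions because prediction errors compound; this untrained toy head")
print("will not show that trend.")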

---

## Section 4: Token Tree Structure

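EAGLE organizes draft tokens as a **tree** rather than a single linear chain: at each step it keeps several candidate tokens, so if the top candidate is rejected, a sibling branch may still match the target. The original EAGLE uses a fixed tree shape; EAGLE-2 makes the tree dynamic, expanding the branches the draft head is most confident about.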

```
Linear Draft:           Tree Draft (EAGLE):

t0 → t1 → t2 → t3       t0 ─┬─ t1a ─┬─ t2a → t3a
                            │       ├─ t2b → t3b
                            │       └─ t2c
                            └─ t1b ─┬─ t2d
                                    └─ t2e

If t1 is wrong,         If t1a is wrong,
everything after it     the t1b branch may
is wasted.              still be correct!
```
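One way to grow a tree dynamically, in the spirit of EAGLE-2, is best-first expansion: score each candidate node by its cumulative path probability (the product of draft probabilities from the root) and keep adding the highest-scoring frontier node until a node budget is spent. The sketch below is a simplified illustration, not EAGLE-2's exact expand-and-rerank procedure; `toy_next_token_probs` is a hypothetical deterministic stand-in for the draft head so the example runs without a model.

```python
import heapq
import math


def toy_next_token_probs(path, vocab=8):
    """Hypothetical stand-in for the draft head: a deterministic
    distribution over `vocab` tokens that depends on the path so far."""
    seed = 1 + sum((i + 1) * (tok + 1) for i, tok in enumerate(path))
    logits = [3.0 * math.sin(seed * (k + 1)) for k in range(vocab)]
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def build_dynamic_tree(budget, top_k=3, max_depth=4):
    """Best-first growth: repeatedly add the frontier node with the highest
    cumulative path probability until `budget` nodes are chosen.
    Returns {path_tuple: cumulative_probability}."""
    chosen = {}
    frontier = []  # min-heap of (-cumulative_prob, path)

    def push_children(path, path_prob):
        if len(path) >= max_depth:
            return
        probs = toy_next_token_probs(path)
        best = sorted(range(len(probs)), key=lambda t: probs[t], reverse=True)
        for tok in best[:top_k]:
            heapq.heappush(frontier, (-path_prob * probs[tok], path + (tok,)))

    push_children((), 1.0)  # children of the last accepted token (the root)
    while frontier and len(chosen) < budget:
        neg_prob, path = heapq.heappop(frontier)
        chosen[path] = -neg_prob
        push_children(path, -neg_prob)
    return chosen


tree = build_dynamic_tree(budget=10)
for path, prob in sorted(tree.items(), key=lambda kv: -kv[1]):
    print(f"{' -> '.join(map(str, path)):<16} P={prob:.4f}")
```

Children enter the frontier only after their parent has been chosen, and a child's cumulative probability can never exceed its parent's, so the selected nodes always form a connected tree rooted at the last accepted token.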

In [None]:
# ============================================================
# Implement Token Tree Structure
# ============================================================

class TokenTreeNode:
    """A node in the EAGLE token tree."""
    
    def __init__(self, token_id, probability, hidden_state, parent=None, depth=0):
        self.token_id = token_id
        self.probability = probability
        self.hidden_state = hidden_state
        self.parent = parent
        self.children = []
        self.depth = depth
        self.is_verified = False
    
    def add_child(self, token_id, probability, hidden_state):
        child = TokenTreeNode(token_id, probability, hidden_state, 
                             parent=self, depth=self.depth + 1)
        self.children.append(child)
        return child
    
    def path_probability(self):
        """Compute the cumulative probability from root to this node."""
        prob = self.probability
        node = self.parent
        while node is not None:
            prob *= node.probability
            node = node.parent
        return prob


class EAGLETokenTree:
    """
    Builds a tree of draft tokens for EAGLE speculation.
    """
    
    def __init__(self, eagle_head, lm_head, embedding_table, 
                 max_depth=4, branch_factor=3, top_k=5):
        self.eagle_head = eagle_head
        self.lm_head = lm_head
        self.embedding_table = embedding_table
        self.max_depth = max_depth
        self.branch_factor = branch_factor
        self.top_k = top_k
    
    def build_tree(self, root_hidden, root_embedding, root_token_id=0):
        """Build the draft token tree."""
        root = TokenTreeNode(root_token_id, 1.0, root_hidden)
        
        # BFS to build tree level by level
        current_level = [root]
        total_nodes = 1
        
        self.eagle_head.eval()
        with torch.no_grad():
            for depth in range(self.max_depth):
                next_level = []
                
                for node in current_level:
                    # Get embedding for this node's token
                    if node is root:
                        embedding = root_embedding
                    else:
                        embedding = self.embedding_table(
                            torch.tensor([[node.token_id]])
                        ).squeeze(0)
                    
                    # Predict next hidden state
                    predicted_hidden = self.eagle_head(
                        node.hidden_state, embedding
                    )
                    
                    # Get top-k tokens
                    logits = self.lm_head(predicted_hidden)
                    probs = F.softmax(logits, dim=-1)
                    top_probs, top_indices = torch.topk(probs[0], self.top_k)
                    
                    # Add children (branch_factor determines width)
                    n_branches = min(self.branch_factor, self.top_k)
                    # Reduce branching at deeper levels
                    n_branches = max(1, n_branches - depth)
                    
                    for i in range(n_branches):
                        child = node.add_child(
                            top_indices[i].item(),
                            top_probs[i].item(),
                            predicted_hidden
                        )
                        next_level.append(child)
                        total_nodes += 1
                
                current_level = next_level
        
        return root, total_nodes
    
    def get_all_paths(self, root):
        """Get all root-to-leaf paths in the tree."""
        paths = []
        
        def dfs(node, current_path):
            current_path.append((node.token_id, node.probability))
            if not node.children:
                paths.append(list(current_path))
            else:
                for child in node.children:
                    dfs(child, current_path)
            current_path.pop()
        
        dfs(root, [])
        return paths


# Build a token tree
tree_builder = EAGLETokenTree(
    eagle_head, lm_head, embedding_table,
    max_depth=4, branch_factor=3, top_k=5
)

root, total_nodes = tree_builder.build_tree(
    torch.randn(1, HIDDEN_DIM),
    torch.randn(1, EMBED_DIM)
)

paths = tree_builder.get_all_paths(root)

print(f"Token Tree Statistics:")
print(f"  Total nodes: {total_nodes}")
print(f"  Total paths: {len(paths)}")
print(f"  Max depth: {max(len(p) for p in paths) - 1} (excluding root)")
print(f"\nTop 5 paths by cumulative probability:")

# Sort paths by cumulative probability
path_probs = [(path, np.prod([p[1] for p in path])) for path in paths]
path_probs.sort(key=lambda x: x[1], reverse=True)

for i, (path, prob) in enumerate(path_probs[:5]):
    tokens = [str(p[0]) for p in path]
    print(f"  Path {i+1}: [{' → '.join(tokens)}] P={prob:.6f}")

In [None]:
# ============================================================
# Visualize the Token Tree
# ============================================================

def visualize_token_tree(root, max_display_depth=4):
    """Create a visual representation of the token tree."""
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.set_xlim(-1, 11)
    ax.set_ylim(-1, 8)
    ax.axis('off')
    ax.set_title('EAGLE Token Tree Structure', fontsize=16, fontweight='bold')
    
    # Collect nodes by level
    levels = {}
    queue = [(root, 0)]
    while queue:
        node, depth = queue.pop(0)
        if depth > max_display_depth:
            continue
        if depth not in levels:
            levels[depth] = []
        levels[depth].append(node)
        for child in node.children:
            queue.append((child, depth + 1))
    
    # Position nodes
    positions = {}
    for depth, nodes in levels.items():
        n = len(nodes)
        for i, node in enumerate(nodes):
            x = depth * 2.5 + 0.5
            y = 7 - (i + 0.5) * 7 / n if n > 0 else 3.5
            positions[id(node)] = (x, y)
    
    # Draw edges
    for depth, nodes in levels.items():
        for node in nodes:
            if id(node) in positions:
                for child in node.children:
                    if id(child) in positions:
                        x1, y1 = positions[id(node)]
                        x2, y2 = positions[id(child)]
                        # Floor keeps low-probability edges visible (an untrained head gives tiny probs)
                        alpha = max(0.2, min(1.0, child.probability * 3))
                        ax.plot([x1 + 0.3, x2 - 0.3], [y1, y2], 
                               color='gray', alpha=alpha, linewidth=1.5)
    
    # Draw nodes
    for depth, nodes in levels.items():
        for node in nodes:
            if id(node) in positions:
                x, y = positions[id(node)]
                prob = node.probability
                
                # Color based on probability
                if prob > 0.3:
                    color = '#4CAF50'
                elif prob > 0.1:
                    color = '#FF9800'
                else:
                    color = '#F44336'
                
                circle = plt.Circle((x, y), 0.3, color=color, alpha=0.7)
                ax.add_patch(circle)
                ax.text(x, y, f't{node.token_id}', ha='center', va='center',
                       fontsize=7, fontweight='bold', color='white')
                ax.text(x, y - 0.45, f'{prob:.2f}', ha='center', 
                       fontsize=6, color='gray')
    
    # Add depth labels
    for depth in levels:
        ax.text(depth * 2.5 + 0.5, 7.5, f'Depth {depth}', 
               ha='center', fontsize=10, fontweight='bold')
    
    # Legend
    legend_elements = [
        mpatches.Patch(color='#4CAF50', alpha=0.7, label='High prob (>0.3)'),
        mpatches.Patch(color='#FF9800', alpha=0.7, label='Medium prob (0.1-0.3)'),
        mpatches.Patch(color='#F44336', alpha=0.7, label='Low prob (<0.1)'),
    ]
    ax.legend(handles=legend_elements, loc='lower right', fontsize=10)
    
    plt.tight_layout()
    plt.savefig('token_tree.png', dpi=150, bbox_inches='tight')
    plt.show()

visualize_token_tree(root)

---

## Section 5: Verification and Acceptance

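After building the draft tree, EAGLE verifies all paths simultaneously with the target model. This is done in a **single forward pass** by using a tree attention mask.

### Verification Process

1. Flatten all tree nodes into a sequence (parents before children); each node's position id is the prefix length plus its depth, so siblings share a position
2. Create a tree attention mask: each node attends to the accepted prefix, its ancestors, and itself, but not to siblings or other branches
3. Run the target model once on the full sequence
4. Compare target model probabilities with draft probabilities at every node
5. Accept the longest root-to-leaf prefix that passes the acceptance test, then append one token sampled from the target model at the position right after it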
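Steps 1 and 2 can be sketched in NumPy. The tree is encoded as a list of parent indices in parents-first order, with `-1` marking nodes that hang directly off the last accepted token; this encoding and the function name `tree_attention_mask` are illustrative assumptions, not a particular library's API. The function returns a boolean mask whose columns are the prefix tokens followed by the draft nodes, plus a position id per draft node.

```python
import numpy as np


def tree_attention_mask(parents, prefix_len):
    """parents[i]: index of node i's parent in the flattened tree, or -1 if
    node i is a child of the last accepted token. Nodes must be listed
    parents-first. Returns (mask, position_ids); mask[i, j] is True if draft
    node i may attend to column j (prefix tokens first, then draft nodes)."""
    n = len(parents)
    tree_mask = np.zeros((n, n), dtype=bool)
    for i in range(n):
        assert parents[i] < i, "list parents before their children"
        j = i
        while j != -1:  # walk up: self, parent, grandparent, ...
            tree_mask[i, j] = True
            j = parents[j]
    depth = tree_mask.sum(axis=1) - 1  # number of draft ancestors
    position_ids = prefix_len + depth  # siblings share a position id
    prefix_cols = np.ones((n, prefix_len), dtype=bool)  # every node sees the prefix
    return np.concatenate([prefix_cols, tree_mask], axis=1), position_ids


# Nodes: 0=t1a, 1=t1b, 2=t2a (under t1a), 3=t2b (under t1a), 4=t2d (under t1b)
mask, pos = tree_attention_mask([-1, -1, 0, 0, 1], prefix_len=3)
print(mask.astype(int))
print("position ids:", pos)
```

Because siblings share a position id but cannot see each other, one forward pass scores every branch as if it had been decoded on its own.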

In [None]:
# ============================================================
# Simulate the verification process
# ============================================================

np.random.seed(42)

def simulate_eagle_verification(draft_tokens, draft_probs):
    """
    Simulate verification of a linear draft (toy model).
    
    In reality, the target model computes exact probabilities for every
    draft position in a single forward pass. Here we fake the target
    probability as draft_prob * U(0.5, 1.5), so the acceptance ratio
    target_prob / draft_prob is just the U(0.5, 1.5) draw. Whenever
    draft_prob < 2/3 (so the clip never binds), each token is accepted
    with probability 0.875 regardless of draft_prob.
    """
    results = []
    
    for i, (token, draft_prob) in enumerate(zip(draft_tokens, draft_probs)):
        # Simulated target-model probability for the drafted token
        target_prob = np.clip(draft_prob * np.random.uniform(0.5, 1.5), 0, 1)
        
        # Speculative-sampling test: accept with probability min(1, p / q)
        acceptance_ratio = min(1.0, target_prob / max(draft_prob, 1e-10))
        is_accepted = np.random.random() < acceptance_ratio
        
        results.append({
            'position': i,
            'token': token,
            'draft_prob': draft_prob,
            'target_prob': target_prob,
            'accepted': is_accepted,
        })
        if not is_accepted:
            break  # Stop at first rejection (linear draft)
    
    return results


# Run multiple verification simulations
n_simulations = 1000
draft_length = 8
acceptance_lengths = []

for _ in range(n_simulations):
    # Generate draft
    tokens, probs, _ = drafter.draft(
        torch.randn(1, HIDDEN_DIM),
        torch.randn(1, EMBED_DIM),
        num_draft_tokens=draft_length
    )
    
    # Verify
    results = simulate_eagle_verification(tokens, probs)
    n_accepted = sum(1 for r in results if r['accepted'])
    acceptance_lengths.append(n_accepted)

# Analyze results
mean_accepted = np.mean(acceptance_lengths)
acceptance_rate = mean_accepted / draft_length

print(f"Verification Simulation ({n_simulations} trials):")
print(f"  Draft length: {draft_length} tokens")
print(f"  Average accepted: {mean_accepted:.2f} tokens")
print(f"  Acceptance rate: {acceptance_rate:.1%}")
print(f"  Tokens per target pass: {mean_accepted + 1:.2f} (accepted + 1 target token)")
print("  (Upper bound on speedup: ignores the cost of drafting)")

# Visualize acceptance distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histogram of accepted lengths
axes[0].hist(acceptance_lengths, bins=range(0, draft_length + 2), 
             color='steelblue', alpha=0.7, edgecolor='white', linewidth=1)
axes[0].axvline(mean_accepted, color='red', linestyle='--', linewidth=2,
               label=f'Mean: {mean_accepted:.2f}')
axes[0].set_xlabel('Number of Accepted Tokens', fontsize=12)
axes[0].set_ylabel('Frequency', fontsize=12)
axes[0].set_title('Distribution of Accepted Draft Tokens', fontsize=13, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Acceptance rate by position
position_acceptance = np.zeros(draft_length)
position_counts = np.zeros(draft_length)

for _ in range(n_simulations):
    tokens, probs, _ = drafter.draft(
        torch.randn(1, HIDDEN_DIM),
        torch.randn(1, EMBED_DIM),
        num_draft_tokens=draft_length
    )
    results = simulate_eagle_verification(tokens, probs)
    for r in results:
        position_counts[r['position']] += 1
        if r['accepted']:
            position_acceptance[r['position']] += 1

position_rates = position_acceptance / np.maximum(position_counts, 1)
axes[1].bar(range(draft_length), position_rates, color='steelblue', alpha=0.7)
axes[1].set_xlabel('Draft Position', fontsize=12)
axes[1].set_ylabel('Acceptance Rate\n(given position reached)', fontsize=12)
axes[1].set_title('Acceptance Rate by Position', fontsize=13, fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('verification_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
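The simulation above fakes the target probabilities. With explicit distributions, the standard speculative-sampling rule for one position is: draw $x \sim q$ from the draft, accept with probability $\min(1, p(x)/q(x))$, and on rejection sample a replacement from the normalized residual $\max(0, p - q)$. This makes the emitted token distributed exactly as the target's $p$, and the acceptance probability is $\sum_x \min(p(x), q(x))$. A minimal NumPy sketch with toy distributions (assumed for illustration) and an empirical check:

```python
import numpy as np


def speculative_step(p, q, rng):
    """One position of speculative sampling.
    p: target distribution, q: draft distribution (1-D arrays summing to 1).
    Returns (token, accepted)."""
    x = rng.choice(len(q), p=q)  # draft proposes x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):  # accept with prob min(1, p/q)
        return x, True
    residual = np.maximum(p - q, 0.0)  # rejected: resample from (p - q)+
    residual /= residual.sum()
    return rng.choice(len(p), p=residual), False


rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.2])  # toy target distribution
q = np.array([0.2, 0.3, 0.5])  # toy draft distribution
n = 50_000
samples = [speculative_step(p, q, rng) for _ in range(n)]
emitted = np.array([t for t, _ in samples])
accept_rate = np.mean([a for _, a in samples])

print("empirical token freqs:", np.bincount(emitted, minlength=3) / n)
print(f"acceptance rate: {accept_rate:.3f} (theory {np.minimum(p, q).sum():.3f})")
```

The empirical frequencies match $p$ even though tokens were proposed from $q$; the draft only affects how often the cheap path (acceptance) is taken.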

---

## Section 6: Comparing EAGLE with Standard Speculative Decoding

In [None]:
# ============================================================
# Side-by-side comparison simulation
# ============================================================

np.random.seed(42)

# Simulation parameters
target_model_time = 50  # ms per forward pass
draft_model_time = 10   # ms per forward pass (separate draft model)
eagle_head_time = 2     # ms per forward pass (lightweight head)
tokens_to_generate = 100

def simulate_autoregressive(tokens_to_gen, model_time):
    """Simulate standard autoregressive generation."""
    total_time = tokens_to_gen * model_time
    return total_time, tokens_to_gen  # forward passes = tokens

def simulate_speculative(tokens_to_gen, target_time, draft_time, 
                         draft_length, acceptance_rate):
    """Simulate standard speculative decoding."""
    generated = 0
    total_time = 0
    forward_passes = 0
    
    while generated < tokens_to_gen:
        # Draft phase
        total_time += draft_length * draft_time
        
        # Verify phase
        total_time += target_time
        forward_passes += 1
        
        # Accept tokens
        accepted = 0
        for _ in range(draft_length):
            if np.random.random() < acceptance_rate:
                accepted += 1
            else:
                break
        
        generated += accepted + 1  # +1 for the correction token
    
    return total_time, forward_passes

def simulate_eagle(tokens_to_gen, target_time, eagle_time, 
                   draft_length, acceptance_rate):
    """Simulate EAGLE speculative decoding."""
    generated = 0
    total_time = 0
    forward_passes = 0
    
    while generated < tokens_to_gen:
        # Eagle draft phase (much faster than separate draft model)
        total_time += draft_length * eagle_time
        
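        # Verify phase (one target forward pass; this toy models a linear
        # draft, not a tree)
        total_time += target_time
        forward_passes += 1
        
        # Same acceptance loop as simulate_speculative: EAGLE's advantage
        # enters only through the cheaper eagle_time and the assumed
        # acceptance_rate passed in by the caller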
        accepted = 0
        for _ in range(draft_length):
            if np.random.random() < acceptance_rate:
                accepted += 1
            else:
                break
        
        generated += accepted + 1
    
    return total_time, forward_passes


# Run comparisons
n_runs = 500

methods = {
    'Autoregressive': [],
    'Speculative (draft model)': [],
    'EAGLE': [],
}

for _ in range(n_runs):
    ar_time, ar_passes = simulate_autoregressive(tokens_to_generate, target_model_time)
    methods['Autoregressive'].append(ar_time)
    
    spec_time, spec_passes = simulate_speculative(
        tokens_to_generate, target_model_time, draft_model_time,
        draft_length=5, acceptance_rate=0.65  # illustrative, not measured
    )
    methods['Speculative (draft model)'].append(spec_time)
    
    eagle_time_total, eagle_passes = simulate_eagle(
        tokens_to_generate, target_model_time, eagle_head_time,
        draft_length=6, acceptance_rate=0.75  # assumed higher acceptance (illustrative)
    )
    methods['EAGLE'].append(eagle_time_total)

# Compute statistics
print(f"Generation of {tokens_to_generate} tokens:")
print("=" * 60)
for method, times in methods.items():
    mean_time = np.mean(times)
    speedup = np.mean(methods['Autoregressive']) / mean_time
    print(f"  {method:30s}: {mean_time:>7.0f}ms | Speedup: {speedup:.2f}x")

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Box plot of generation times
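bp = axes[0].boxplot([methods[m] for m in methods.keys()],
                     patch_artist=True,
                     boxprops=dict(alpha=0.7))
# Set tick labels explicitly (boxplot's `labels=` argument is deprecated in newer Matplotlib)
axes[0].set_xticks([1, 2, 3], labels=['Auto-\nregressive', 'Speculative\n(draft model)', 'EAGLE'])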
colors = ['#E57373', '#FF9800', '#4CAF50']
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)

axes[0].set_ylabel('Total Generation Time (ms)', fontsize=12)
axes[0].set_title(f'Time to Generate {tokens_to_generate} Tokens', 
                  fontsize=13, fontweight='bold')
axes[0].grid(True, alpha=0.3, axis='y')

# Speedup bar chart
method_names = list(methods.keys())
speedups = [np.mean(methods['Autoregressive']) / np.mean(methods[m]) for m in method_names]
bars = axes[1].bar(method_names, speedups, color=colors, alpha=0.8,
                   edgecolor='white', linewidth=2)

for bar, s in zip(bars, speedups):
    axes[1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.05,
                f'{s:.2f}x', ha='center', fontweight='bold', fontsize=13)

axes[1].set_ylabel('Speedup', fontsize=12)
axes[1].set_title('Speedup Comparison', fontsize=13, fontweight='bold')
axes[1].axhline(y=1, color='gray', linestyle='--', alpha=0.5)
axes[1].grid(True, alpha=0.3, axis='y')
axes[1].tick_params(axis='x', rotation=15)

plt.tight_layout()
plt.savefig('eagle_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================================
# Speedup sensitivity analysis
# ============================================================

# How does EAGLE speedup depend on acceptance rate and draft length?

acceptance_rates = np.arange(0.3, 0.95, 0.05)
draft_lengths = [2, 4, 6, 8, 10]

fig, ax = plt.subplots(figsize=(12, 7))

for dl in draft_lengths:
    speedups = []
    for ar in acceptance_rates:
        # Expected accepted tokens per round
        expected_accepted = sum(ar**i for i in range(1, dl + 1))
        tokens_per_round = expected_accepted + 1  # +1 correction
        
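        # Time per round: dl draft-head steps + one target verification pass
        draft_time_total = dl * eagle_head_time
        verify_time = target_model_time  # assumes verifying dl+1 tokens costs about one decode step (memory-bound)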
        round_time = draft_time_total + verify_time
        
        # Autoregressive time for same tokens
        ar_time_total = tokens_per_round * target_model_time
        
        speedup = ar_time_total / round_time
        speedups.append(speedup)
    
    ax.plot(acceptance_rates * 100, speedups, '-o', linewidth=2, 
            markersize=6, label=f'Draft length = {dl}')

ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5, label='No speedup')
ax.set_xlabel('Acceptance Rate (%)', fontsize=12)
ax.set_ylabel('Theoretical Speedup', fontsize=12)
ax.set_title('EAGLE Speedup vs Acceptance Rate and Draft Length',
            fontsize=14, fontweight='bold')
ax.legend(fontsize=10, title='Configuration')
ax.grid(True, alpha=0.3)
ax.set_ylim(0.5, 5)

# Add annotation for typical EAGLE operating point
ax.annotate('Typical EAGLE\noperating point',
            xy=(75, 3.0), fontsize=11, fontweight='bold',
            xytext=(55, 4.2),
            arrowprops=dict(arrowstyle='->', color='green', lw=2),
            color='green')

plt.tight_layout()
plt.savefig('speedup_sensitivity.png', dpi=150, bbox_inches='tight')
plt.show()

print("Key Observations (under this simplified cost model):")
print("1. Higher acceptance rate → more speedup")
print("2. Longer drafts help only when acceptance rate is high")
print("3. At the ~70-80% acceptance assumed for EAGLE → roughly 2.5-4x speedup")
print("4. Diminishing returns beyond draft length ~8")

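The per-round expectation in the cell above is a truncated geometric series, so it has a closed form. With per-token acceptance rate $\alpha$, draft length $k$, and the same cost model ($k$ draft-head steps plus one verification pass costing about one target step):

$$\mathbb{E}[\text{tokens per round}] = 1 + \sum_{i=1}^{k}\alpha^{i} = \frac{1-\alpha^{k+1}}{1-\alpha}, \qquad \text{speedup} \approx \frac{\mathbb{E}[\text{tokens per round}]\,T_{\text{target}}}{k\,T_{\text{head}} + T_{\text{target}}}$$

The minimal sketch below codes this up and searches for the draft length that maximizes the modeled speedup. The timings `t_target=50.0` and `t_head=1.0` (ms) are hypothetical placeholders, not the notebook's constants; substitute `target_model_time` and `eagle_head_time` to match the plot.

```python
def expected_tokens_per_round(alpha: float, k: int) -> float:
    """Closed form of 1 + sum_{i=1..k} alpha**i.

    The leading 1 is the correction token the target emits every round.
    """
    if alpha == 1.0:
        return float(k + 1)
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def theoretical_speedup(alpha: float, k: int, t_target: float, t_head: float) -> float:
    """Same cost model as the sensitivity cell: k head steps + one verification pass."""
    return expected_tokens_per_round(alpha, k) * t_target / (k * t_head + t_target)


def best_draft_length(alpha: float, t_target: float, t_head: float, k_max: int = 16) -> int:
    """Draft length in 1..k_max that maximizes the modeled speedup."""
    return max(range(1, k_max + 1),
               key=lambda k: theoretical_speedup(alpha, k, t_target, t_head))


# Hypothetical timings (ms)
for alpha in (0.5, 0.7, 0.8, 0.9):
    k = best_draft_length(alpha, t_target=50.0, t_head=1.0)
    s = theoretical_speedup(alpha, k, 50.0, 1.0)
    print(f"alpha={alpha:.1f}: best k={k:2d}, modeled speedup={s:.2f}x")
```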
---

## Section 7: Training the EAGLE Head

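The EAGLE head is trained to minimize a **feature regression loss**: the squared error between the predicted and actual next hidden states.

$$\mathcal{L} = \left\lVert \hat{h}_{t+1} - h_{t+1} \right\rVert_2^2, \qquad \hat{h}_{t+1} = f_\theta(h_t, e_t)$$

Here $f_\theta$ is the EAGLE head, $h_t$ is the target model's hidden state at position $t$, and $e_t$ is the embedding of the token at position $t$. `F.mse_loss` in the cell below averages this squared error over the batch and hidden dimensions.

Training data comes from running the frozen target model over a text corpus and collecting $(h_t, e_t, h_{t+1})$ triples: the current hidden state, the current token's embedding, and the next hidden state.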

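Building those triples is a one-step shift along the sequence axis. The sketch below assumes hidden states and token embeddings are already collected as arrays (`make_regression_pairs` and `fit_linear_head` are hypothetical helpers) and, to keep it self-contained, fits a *linear* least-squares head in closed form. That head is only a stand-in: EAGLE's actual draft head is a small learned network (the paper describes a fully connected layer followed by one transformer decoder layer). The toy data mirrors the simulated target model used in the next cell.

```python
import numpy as np


def make_regression_pairs(hidden_states, token_embeddings):
    """Shift by one step to build (h_t, e_t) -> h_{t+1} training pairs.

    hidden_states:    (batch, seq_len, hidden_dim), from the frozen target model
    token_embeddings: (batch, seq_len, embed_dim), embedding of the token at each position
    Returns X: (batch*(seq_len-1), hidden_dim+embed_dim) and Y: (batch*(seq_len-1), hidden_dim).
    """
    h_t, e_t, h_next = hidden_states[:, :-1], token_embeddings[:, :-1], hidden_states[:, 1:]
    X = np.concatenate([h_t, e_t], axis=-1).reshape(-1, h_t.shape[-1] + e_t.shape[-1])
    Y = h_next.reshape(-1, h_next.shape[-1])
    return X, Y


def fit_linear_head(X, Y):
    """Least-squares linear map minimizing ||X W - Y||^2 (a stand-in, not EAGLE's head)."""
    W, *_ = np.linalg.lstsq(X, Y, rcond=None)
    return W


# Toy "target model": noisy orthogonal transition between consecutive hidden states
rng = np.random.default_rng(0)
B, T, D, E = 8, 20, 16, 4
A = np.linalg.qr(rng.standard_normal((D, D)))[0]
states = [rng.standard_normal((B, D))]
for _ in range(T - 1):
    states.append(states[-1] @ A.T + 0.1 * rng.standard_normal((B, D)))
hidden = np.stack(states, axis=1)        # (B, T, D)
emb = rng.standard_normal((B, T, E))     # random embeddings, as in the training cell

X, Y = make_regression_pairs(hidden, emb)
W = fit_linear_head(X, Y)
head_mse = float(np.mean((X @ W - Y) ** 2))
copy_mse = float(np.mean((X[:, :D] - Y) ** 2))   # baseline: predict h_{t+1} = h_t
print(f"linear head MSE: {head_mse:.4f} | copy-h_t baseline MSE: {copy_mse:.4f}")
```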
In [None]:
# ============================================================
# Training simulation for EAGLE head
# ============================================================

torch.manual_seed(42)

# Create a "target model" that generates hidden state sequences
# (In practice, you'd run the real target model)
class SimulatedTargetModel(nn.Module):
    """Simulates a target model generating correlated hidden states."""
    def __init__(self, hidden_dim):
        super().__init__()
        self.transition = nn.Linear(hidden_dim, hidden_dim)
        nn.init.orthogonal_(self.transition.weight)
    
    def generate_sequence(self, batch_size, seq_len, hidden_dim):
        h = torch.randn(batch_size, hidden_dim)
        hidden_states = [h]
        for _ in range(seq_len - 1):
            h = self.transition(h) + 0.1 * torch.randn_like(h)
            hidden_states.append(h)
        return torch.stack(hidden_states, dim=1)


# Setup
target_sim = SimulatedTargetModel(HIDDEN_DIM)
eagle_trainable = SimplifiedEAGLEHead(HIDDEN_DIM, EMBED_DIM)
optimizer = torch.optim.Adam(eagle_trainable.parameters(), lr=1e-3)

# Training
num_epochs = 100
batch_size = 32
seq_len = 20
losses = []

print("Training EAGLE head...")
for epoch in range(num_epochs):
    # Each "epoch" is one optimizer step on a freshly generated batch from the "target model"
    with torch.no_grad():
        hidden_seq = target_sim.generate_sequence(batch_size, seq_len, HIDDEN_DIM)
    
    # Random token embeddings (in practice, these come from actual tokens)
    embeddings = torch.randn(batch_size, seq_len, EMBED_DIM)
    
    # Train: predict h_{t+1} from (h_t, e_t)
    total_loss = 0
    for t in range(seq_len - 1):
        predicted_h = eagle_trainable(hidden_seq[:, t], embeddings[:, t])
        target_h = hidden_seq[:, t + 1]
        
        loss = F.mse_loss(predicted_h, target_h)
        total_loss += loss
    
    avg_loss = total_loss / (seq_len - 1)
    optimizer.zero_grad()
    avg_loss.backward()
    optimizer.step()
    
    losses.append(avg_loss.item())
    if (epoch + 1) % 20 == 0:
        print(f"  Epoch {epoch+1}/{num_epochs}: Loss = {avg_loss.item():.4f}")

# Plot training loss
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses, linewidth=2, color='steelblue')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('MSE Loss', fontsize=12)
ax.set_title('EAGLE Head Training Loss', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('eagle_training.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nTraining complete!")
print(f"Initial loss: {losses[0]:.4f}")
print(f"Final loss: {losses[-1]:.4f}")
print(f"Reduction: {(1 - losses[-1]/losses[0])*100:.1f}%")

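The training loop above uses plain MSE. The EAGLE paper instead combines a Smooth L1 feature-regression loss with a cross-entropy loss between the token distributions obtained by passing the true and the predicted features through the target model's frozen LM head. The NumPy sketch below illustrates that combined objective; `eagle_style_loss`, the array shapes, and `w_cls = 0.1` are illustrative assumptions, not the paper's code.

```python
import numpy as np


def smooth_l1(pred, target, beta=1.0):
    """Mean Smooth L1 loss: quadratic for |diff| < beta, linear beyond."""
    diff = np.abs(pred - target)
    return float(np.mean(np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)))


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def eagle_style_loss(pred_h, true_h, lm_head, w_cls=0.1):
    """Feature regression + token classification through a frozen LM head.

    pred_h, true_h: (N, hidden_dim); lm_head: (hidden_dim, vocab). Returns (total, reg, cls).
    """
    reg = smooth_l1(pred_h, true_h)
    target_probs = np.exp(log_softmax(true_h @ lm_head))   # target's next-token distribution
    pred_logp = log_softmax(pred_h @ lm_head)               # draft distribution from predicted features
    cls = float(-np.mean(np.sum(target_probs * pred_logp, axis=-1)))  # cross-entropy
    return reg + w_cls * cls, reg, cls


rng = np.random.default_rng(0)
true_h = rng.standard_normal((32, 16))
pred_h = true_h + 0.3 * rng.standard_normal((32, 16))
lm_head = rng.standard_normal((16, 100))
total, reg, cls = eagle_style_loss(pred_h, true_h, lm_head)
print(f"total={total:.4f}  smooth_l1={reg:.4f}  cross_entropy={cls:.4f}")
```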
---

## Summary & Key Takeaways

| Concept | Key Insight |
|---------|-------------|
| **EAGLE Core Idea** | Use hidden states from target model to draft tokens, not a separate model |
| **Feature Regression** | Lightweight head predicts next hidden state from current state + embedding |
| **Token Tree** | Tree structure increases chance of finding correct paths |
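| **Verification** | One target forward pass verifies the whole tree, using a tree attention mask so each node attends only to its ancestors |
| **Acceptance Rate** | Higher than a separate draft model; this notebook assumes ~70-80% for EAGLE vs ~60-70% for standard spec decoding |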
| **Speedup** | 2-4x typical; depends on acceptance rate and draft length |
| **Training Cost** | Much cheaper than training a draft model (feature regression only) |

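The **Verification** row relies on a tree attention mask. In the minimal sketch below (the parent-array encoding and helper names are assumptions, not a particular library's format), each draft node may attend to itself and its ancestors, so every root-to-leaf path is scored in one forward pass without seeing sibling branches. Attention to the already-cached prompt is always allowed and is omitted from the mask.

```python
import numpy as np


def tree_attention_mask(parents):
    """mask[i, j] is True iff draft node i may attend to draft node j.

    parents[i] is node i's parent index, or -1 if node i hangs directly off the
    last accepted token. Nodes must be ordered so every parent precedes its children.
    """
    n = len(parents)
    mask = np.zeros((n, n), dtype=bool)
    for i, p in enumerate(parents):
        if p >= i:
            raise ValueError("parents must precede their children")
        mask[i, i] = True          # a node sees itself
        if p >= 0:
            mask[i] |= mask[p]     # ...and everything its parent sees (its ancestors)
    return mask


def tree_depths(parents):
    """Depth of each node (1 = child of the last accepted token), usable as position offsets."""
    depths = []
    for p in parents:
        depths.append(1 if p < 0 else depths[p] + 1)
    return depths


# Two top-level candidates (0, 1); node 0 has children 2 and 3; node 2 has child 4.
parents = [-1, -1, 0, 0, 2]
print(tree_attention_mask(parents).astype(int))
print(tree_depths(parents))
```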
### EAGLE vs Alternatives

| Method | Draft Source | Acceptance Rate (illustrative) | Overhead | Setup Cost |
|--------|-------------|----------------|----------|------------|
| Standard Speculative | Small model | 60-70% | High (separate full model) | Train/find draft model |
| EAGLE | Hidden states + static tree | 70-80% | Very low (tiny head) | Train head on features |
| EAGLE-2 | Hidden states + dynamic (context-aware) tree | 75-85% | Low (tree overhead) | None beyond EAGLE (reuses the EAGLE head) |
| Self-speculative | Same model (layer skipping / early exit) | 50-65% | Low (partial forward passes) | Choose layers to skip or exit from |

---

## Exercises

### Exercise 1: Tree Width Optimization
Experiment with different tree configurations (branch factor, depth, total budget). Find the optimal tree shape for different acceptance rates.

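One possible starting point for Exercise 1 is a scoring function for tree shapes. The model below is a simplifying assumption: the rank-$j$ candidate at any node matches the target's (greedy) token with probability `rank_probs[j]`, siblings are mutually exclusive, and depths are independent. Under it, linearity of expectation gives the expected accepted length as the sum, over all nodes, of the probability that the whole path to that node is accepted. `expected_accepted`, `chain`, and the nested-list tree encoding are hypothetical helpers.

```python
from typing import List

# A node is represented by the list of its children; the child at index j is the
# rank-j draft candidate (0 = most likely) among its siblings.
Tree = List["Tree"]


def expected_accepted(children: Tree, rank_probs: List[float], path_prob: float = 1.0) -> float:
    """Expected number of accepted draft tokens (the +1 correction token is excluded).

    E[accepted] = sum over nodes of P(every token on the path to that node is accepted).
    """
    total = 0.0
    for rank, grandchildren in enumerate(children):
        p = path_prob * rank_probs[rank]
        total += p + expected_accepted(grandchildren, rank_probs, p)
    return total


def count_nodes(children: Tree) -> int:
    return sum(1 + count_nodes(c) for c in children)


def chain(depth: int) -> Tree:
    """A single path of `depth` top-1 candidates (plain sequential drafting)."""
    return [] if depth == 0 else [chain(depth - 1)]


# Same 6-node budget: a chain vs. a tree that also keeps some rank-1 alternatives.
wide = [[chain(2), []], []]
for probs in ([0.7, 0.15], [0.5, 0.25]):
    print(f"rank_probs={probs}: chain={expected_accepted(chain(6), probs):.3f}, "
          f"tree={expected_accepted(wide, probs):.3f}")
```

With these illustrative probabilities, the 6-node chain scores higher when the top-1 guess is strong (`[0.7, 0.15]`), while the branching tree scores higher when it is weaker (`[0.5, 0.25]`); searching over shapes for a fixed node budget is the core of the exercise.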
### Exercise 2: Real Model Hidden States
Using GPT-2, extract actual hidden states and train an EAGLE head to predict them. Compare the prediction quality with random baselines.

### Exercise 3: Adaptive Draft Length
Implement an adaptive system that adjusts draft length based on recent acceptance rates (draft fewer tokens when acceptance is low).

### Exercise 4: EAGLE vs Medusa
Compare EAGLE's feature regression approach with Medusa's multiple independent heads. What are the trade-offs?

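For Exercise 4, the structural difference shows up in a toy dataflow sketch. All weights are random and every name is hypothetical; the Medusa heads are reduced to a residual SiLU layer with one shared LM head for brevity, and the EAGLE step to a single `tanh` layer over the concatenated feature and token embedding, so neither matches the real architectures.

```python
import numpy as np

rng = np.random.default_rng(0)
D, V, K = 16, 50, 4                      # hidden size, vocab size, draft length (toy)
lm_head = rng.standard_normal((D, V))    # stand-in for the target's frozen LM head
embed = rng.standard_normal((V, D))      # stand-in for the target's token embeddings


def silu(x):
    return x / (1.0 + np.exp(-x))


def medusa_draft(h_t, medusa_heads):
    """Medusa-style: head k reads the SAME hidden state h_t and guesses the token
    k+1 positions ahead; the K guesses never see each other, so they run in parallel."""
    return [int(np.argmax((h_t + silu(h_t @ W)) @ lm_head)) for W in medusa_heads]


def eagle_draft(h_t, next_token, feature_head, k):
    """EAGLE-style: extrapolate features one step at a time, feeding the embedding of
    each token back in. next_token is the token the target already emitted after h_t."""
    h, tok, drafted = h_t, next_token, []
    for _ in range(k):
        h = np.tanh(np.concatenate([h, embed[tok]]) @ feature_head)  # predicted next feature
        tok = int(np.argmax(h @ lm_head))                             # token read off that feature
        drafted.append(tok)
    return drafted


h_t = rng.standard_normal(D)
medusa_heads = [rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(K)]
feature_head = rng.standard_normal((2 * D, D)) / np.sqrt(2 * D)

print("Medusa draft:", medusa_draft(h_t, medusa_heads))       # K independent lookups
print("EAGLE draft: ", eagle_draft(h_t, 7, feature_head, K))  # K sequential steps
```

The trade-off to explore: Medusa's K guesses come from one hidden state in a single parallel step but cannot condition on the tokens drafted before them, while EAGLE's guesses do condition on them at the cost of K small sequential steps.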
In [None]:
# ============================================================
# Exercise 3 Starter: Adaptive Draft Length
# ============================================================

from collections import deque

class AdaptiveEAGLE:
    """Adapts draft length based on a windowed estimate of per-token acceptance."""
    
    def __init__(self, min_draft=2, max_draft=10, window_size=20):
        self.min_draft = min_draft
        self.max_draft = max_draft
        self.window_size = window_size
        # Each entry: (accepted tokens, 1 if the round ended in a rejection else 0)
        self.recent_rounds = deque(maxlen=window_size)
        self.current_draft_length = 5  # Start in the middle
    
    def update(self, num_accepted, num_drafted):
        """Update based on last round's results.
        
        Drafts are accepted as a prefix (verification stops at the first rejection),
        so accepted/drafted underestimates per-token acceptance for long drafts.
        Instead, estimate it as accepted / (accepted + rejections) over the window.
        """
        rejected = 1 if num_accepted < num_drafted else 0
        self.recent_rounds.append((num_accepted, rejected))
        
        total_accepted = sum(a for a, _ in self.recent_rounds)
        total_rejected = sum(r for _, r in self.recent_rounds)
        avg_rate = total_accepted / max(total_accepted + total_rejected, 1)
        
        # Adjust draft length
        if avg_rate > 0.8:
            self.current_draft_length = min(self.max_draft, self.current_draft_length + 1)
        elif avg_rate < 0.5:
            self.current_draft_length = max(self.min_draft, self.current_draft_length - 1)
    
    def get_draft_length(self):
        return self.current_draft_length


def sample_accepted_prefix(draft_length, p):
    """Accepted tokens in one round: stop at the first rejection (per-token acceptance p)."""
    accepted = 0
    while accepted < draft_length and np.random.random() < p:
        accepted += 1
    return accepted


# Demo
np.random.seed(0)  # reproducible demo
adaptive = AdaptiveEAGLE(window_size=5)  # short window so the demo adapts quickly
print("Adaptive Draft Length Demo:")
print(f"Initial draft length: {adaptive.get_draft_length()}")

# Simulate varying per-token acceptance rates
for i in range(30):
    if i < 10:
        p = 0.85   # High acceptance
    elif i < 20:
        p = 0.4    # Low acceptance
    else:
        p = 0.7    # Medium
    
    dl = adaptive.get_draft_length()
    acc = sample_accepted_prefix(dl, p)
    adaptive.update(acc, dl)
    if (i + 1) % 5 == 0:
        print(f"  Round {i+1}: Accepted {acc}/{dl}, New draft length: {adaptive.get_draft_length()}")

print("\nThe adaptive system increases draft length when acceptance is high")
print("and decreases it when acceptance is low, trading draft cost against wasted drafts.")