In [None]:
#@title 🎧 Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1_yOuaRupWcvvBB5tNnjVrtDllXqg6x4Q", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/03_00_intro.mp3"))


In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    # The device-properties attribute is `total_memory` (reported in bytes)
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# 🚀 On-Policy Distillation: Learning *What* to Improve with Hindsight Hints

*Part 3 of the Vizuara series on OpenClaw-RL*
*Estimated time: 55 minutes*

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant**. It has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://pods.vizuara.ai/courses/openclaw-rl/practice/3/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


In [None]:
#@title 🎧 Listen: Why This Matters
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_01_why_this_matters.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 1. Why Does This Matter?

In the previous notebook, we built Binary RL, a system that tells the model "that was good" or "that was bad." But it never tells the model **what** to do differently.

Imagine a music teacher. One teacher just says "Wrong note!" after every mistake. Another teacher says "You played F# instead of F natural ‚Äî flatten your finger on the fourth fret." Which teacher helps you improve faster?

**On-Policy Distillation (OPD)** is the second teacher. Instead of reducing feedback to a scalar (+1 or -1), it extracts a **textual hint** from the user's correction and uses it to create rich, **token-level** training signals.

By the end of this notebook, you will have implemented:
- **Hindsight hint extraction** from user feedback
- **Enhanced prompt construction** (original prompt + hint)
- **Teacher-student log-probability comparison** at the token level
- **Token-level advantage computation** ($A_t = \log \pi_{\text{teacher}} - \log \pi_{\text{student}}$)
- A full OPD training loop with a side-by-side comparison against Binary RL

In [None]:
# 🎯 Teaser: Token-level advantages show EXACTLY which tokens need to change
#
# Token:     "Here"  "is"  "a"  "JavaScript"  "sorting"  "function"
# Advantage:  +0.1   +0.1  +0.0    -3.8         +0.2       +0.1
#                                    ↑
#                         This token should DEFINITELY change!
#          (Teacher with hint would say "Python" here, so the log-prob it
#           assigns to "JavaScript" is far below the student's: A_t < 0)

In [None]:
#@title 🎧 Listen: Building Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_02_building_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 2. Building Intuition

Let us think carefully about what information is lost when we reduce feedback to a scalar.

Consider this conversation:

**User:** "Write me a function to sort a list in Python."

**Assistant:** "Here is a sorting function in JavaScript:
```
function sortArray(arr) { return arr.sort((a, b) => a - b); }
```"

**User:** "No, I said Python not JavaScript."

With **Binary RL**, the system assigns reward = -1 to the entire response. Every token gets the same negative signal. The model learns: "This whole response was bad." But *why* was it bad? Was it the function name? The logic? The formatting? Binary RL cannot say.

With **OPD**, the system:
1. Reads the user's correction: "No, I said Python not JavaScript"
2. Extracts a **hindsight hint**: "The user wants Python code, not JavaScript"
3. Feeds the original prompt + hint to the same model
4. The model (now acting as a "teacher" because it has the hint) would generate Python code
5. Compares the teacher's token probabilities to the student's token probabilities

The result? A **token-level advantage map** that says: "The tokens 'JavaScript', 'sortArray', 'arr.sort', etc. should all change. But the words 'sorting' and 'function' in the introductory sentence are fine."

This is dramatically more informative than a single -1.

### 🤔 Think About This

Here is a subtle point: the "teacher" in OPD is **the same model** as the student. The only difference is that the teacher sees the enhanced prompt (with the hint). Why is this important? Why not use a different, stronger model as the teacher?

(Answer: using the same model keeps the token-level differences meaningful. They reflect what *this specific model* would do differently with more information, not what a completely different model would do.)

In [None]:
#@title 🎧 Listen: Mathematics
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_03_mathematics.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 3. The Mathematics

### 3.1 Hindsight Hint Extraction

Given a user feedback message $f$, a judge model extracts a short textual hint $h$:

$$h = \text{Judge}(f)$$

For example:
- $f$ = "No, I said Python not JavaScript" → $h$ = "Use Python instead of JavaScript"
- $f$ = "That is too verbose, be more concise" → $h$ = "Keep the response short and direct"

### 3.2 Enhanced Prompt

The enhanced prompt concatenates the original prompt with the hint:

$$s_{\text{enhanced}} = s \oplus h$$

where $\oplus$ denotes concatenation.

### 3.3 Token-Level Advantage

The advantage at token position $t$ is the log-probability gap between teacher and student:

$$A_t = \log \pi_{\text{teacher}}(a_t \mid s_{\text{enhanced}}) - \log \pi_{\theta}(a_t \mid s)$$

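Computationally: at each token, we compare how confident the teacher (with the hint) is versus the student (without the hint). A large positive $A_t$ means the teacher is much more confident than the student, so the probability of this token should be pushed up strongly. A negative $A_t$ means the teacher finds the token less likely than the student did, so its probability should be pushed down. A near-zero $A_t$ means both agree and no correction is needed.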

Let us work through a concrete example. At token position $t$:
- Teacher (with hint) assigns log-prob = $-0.5$ to "Python"
- Student (without hint) assigns log-prob = $-2.3$ to "Python"
- $A_t = -0.5 - (-2.3) = 1.8$ → Strong signal: "increase the probability of this token"

At another position:
- Teacher assigns $-1.2$, Student assigns $-1.0$
- $A_t = -1.2 - (-1.0) = -0.2$ → Tiny signal: "student was already fine here"
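
The two worked examples can be checked in a few lines of plain Python. This is a minimal sketch: the helper name `token_advantage` is introduced here for illustration, and the log-probabilities are the ones quoted in the text. Exponentiating the gap also shows how it reads as a probability ratio.

```python
import math

def token_advantage(teacher_log_prob: float, student_log_prob: float) -> float:
    """Log-probability gap between the hinted teacher and the unhinted student."""
    return teacher_log_prob - student_log_prob

# Position where the hint matters ("Python" in the example above).
a_python = token_advantage(-0.5, -2.3)
# Position where teacher and student roughly agree.
a_other = token_advantage(-1.2, -1.0)

print(f"A_t ('Python') = {a_python:+.1f}")
print(f"A_t (other)    = {a_other:+.1f}")

# exp(A_t) is the ratio of teacher probability to student probability.
print(f"Teacher/student probability ratio at 'Python': {math.exp(a_python):.2f}")
```

An advantage of 1.8 therefore means the teacher assigns roughly six times the student's probability to that token, while -0.2 corresponds to a ratio just below one.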

### 3.4 OPD Loss

The OPD loss uses the same PPO-style clipped surrogate as Binary RL, but with **token-level advantages** instead of a single broadcasted scalar:

$$J_{\text{OPD}}(\theta) = \mathbb{E}\left[\min\left(\rho_t A_t,\; \text{clip}(\rho_t, 1-\epsilon, 1+\epsilon) A_t\right)\right]$$

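Here $\rho_t = \pi_{\theta}(a_t \mid s) / \pi_{\text{ref}}(a_t \mid s)$ is the probability ratio between the current policy and the reference policy, and $\epsilon$ is the clipping bound. The key difference from Binary RL: each token has its own advantage $A_t$, so the model receives directional guidance at every position.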
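
To see what the `min` and `clip` do for a single token, here is a scalar sketch in plain Python. It is not the batched PyTorch version asked for in TODO 1. The function name `clipped_term` and the example log-probabilities are assumptions made for illustration, and the ratio is taken as `exp(log_prob_new - log_prob_ref)`.

```python
import math

def clipped_term(log_prob_new: float, log_prob_ref: float,
                 advantage: float, eps: float = 0.2) -> float:
    """Per-token clipped surrogate: min(rho * A, clip(rho, 1-eps, 1+eps) * A)."""
    rho = math.exp(log_prob_new - log_prob_ref)
    rho_clipped = max(1.0 - eps, min(1.0 + eps, rho))
    return min(rho * advantage, rho_clipped * advantage)

# Hypothetical token whose probability has already grown a lot: rho = exp(0.5), about 1.65.
# Positive advantage: the clipped term is smaller, so the gain is capped at (1 + eps) * A.
capped = clipped_term(log_prob_new=-1.0, log_prob_ref=-1.5, advantage=1.8)

# Negative advantage with the same ratio: the unclipped term is smaller (more
# pessimistic), so it is the one that is kept.
penalised = clipped_term(log_prob_new=-1.0, log_prob_ref=-1.5, advantage=-0.2)

print(f"capped objective:    {capped:.3f}")
print(f"penalised objective: {penalised:.3f}")
```

The asymmetry is the point of the pessimistic bound: once the ratio leaves the interval $[1-\epsilon, 1+\epsilon]$ in the direction the advantage favours, the token stops contributing extra gain, but moves in the wrong direction are still fully penalised.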

In [None]:
#@title 🎧 Code Walkthrough: Setup And Imports
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_04_setup_and_imports.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 4. Let's Build It: Component by Component

### 4.1 Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from typing import List, Tuple, Optional
import re

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"✅ Using device: {device}")

In [None]:
#@title 🎧 Code Walkthrough: Hint Extractor
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_05_hint_extractor.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### 4.2 The Hindsight Hint Extractor

In the real system, hint extraction is done by a judge LLM. Here we build a rule-based version that captures the core idea:

In [None]:
class HindsightHintExtractor:
    """
    Extracts actionable hints from user feedback.
    In production, this is done by a judge LLM with majority voting.
    Here we use pattern matching to demonstrate the concept.
    """

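    # Patterns that indicate specific corrections.
    # They are matched against the lowercased feedback, so they must be lowercase.
    CORRECTION_PATTERNS = [
        (r"\b(?:no|not)\s*,?\s*(?:i\s+)?(?:said|asked|wanted|meant)\s+(.+)",
         "The user wants: {}"),
        # The lazy group needs a terminator; without one it captures a single character
        (r"\b(?:use|try|switch to)\s+(.+?)(?:\s+instead\b|[.!?]*$)",
         "Use {} instead"),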
        (r"(?:too|very)\s+(verbose|long|short|brief|formal|casual)",
         "Adjust tone: be less {}"),
        (r"(?:should|must|need to)\s+(.+)",
         "Important requirement: {}"),
    ]

    # Hints that are too trivial to be useful
    TRIVIAL_HINTS = ["ok", "fine", "sure", "yes", "no", "thanks"]

    def extract_hint(self, feedback: str) -> Optional[str]:
        """
        Extract a hindsight hint from user feedback.

        Args:
            feedback: The user's response after the assistant's message

        Returns:
            A short textual hint, or None if feedback is not corrective
        """
        feedback_lower = feedback.lower().strip()

        # Filter out trivial feedback
        if feedback_lower in self.TRIVIAL_HINTS:
            return None

        # Try each pattern
        for pattern, template in self.CORRECTION_PATTERNS:
            match = re.search(pattern, feedback_lower)
            if match:
                extracted = match.group(1).strip()
                hint = template.format(extracted)
                return hint

        # If no pattern matches but the feedback contains a negative word,
        # use the feedback itself as the hint. Whole-word matching avoids
        # false positives such as "no" inside "know".
        negative_words = ["wrong", "incorrect", "bad", "no", "not", "don't"]
        if any(re.search(rf"\b{re.escape(word)}\b", feedback_lower) for word in negative_words):
            return f"User correction: {feedback[:100]}"

        return None

    def extract_with_voting(self, feedback: str, num_votes: int = 3) -> Optional[str]:
        """
        Extract hints with majority voting (simulated).
        In production, the judge LLM generates m hints and the
        longest, most informative one is kept.

        Args:
            feedback: User feedback message
            num_votes: Number of extraction attempts

        Returns:
            The best hint (longest non-trivial one), or None
        """
        hints = []
        for _ in range(num_votes):
            hint = self.extract_hint(feedback)
            if hint is not None:
                hints.append(hint)

        if not hints:
            return None

        # Keep the longest, most informative hint
        return max(hints, key=len)

# Test the hint extractor
extractor = HindsightHintExtractor()

test_feedbacks = [
    "No, I said Python not JavaScript.",
    "Use Flask instead of Django.",
    "That's too verbose, be more concise.",
    "You should add error handling.",
    "Perfect, thanks!",
    "Wrong, the answer is 42.",
    "Great job!",
]

print("Hindsight Hint Extraction:")
for fb in test_feedbacks:
    hint = extractor.extract_hint(fb)
    emoji = "💡" if hint else "⚪"
    print(f"  {emoji} \"{fb}\"")
    print(f"     → Hint: {hint or '(no correction detected)'}\n")

In [None]:
#@title 🎧 Code Walkthrough: Enhanced Prompt And Log Probs
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_06_enhanced_prompt_and_log_probs.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### 4.3 Enhanced Prompt Construction

In [None]:
class EnhancedPromptBuilder:
    """
    Constructs enhanced prompts by appending hindsight hints.
    The enhanced prompt is what the model SHOULD have seen to get it right.
    """

    def __init__(self, hint_prefix: str = "\n[HINT: ", hint_suffix: str = "]\n"):
        self.hint_prefix = hint_prefix
        self.hint_suffix = hint_suffix

    def build(self, original_prompt: str, hint: str) -> str:
        """
        Build an enhanced prompt by appending the hint.

        Args:
            original_prompt: The original user prompt
            hint: The extracted hindsight hint

        Returns:
            Enhanced prompt with hint appended
        """
        return original_prompt + self.hint_prefix + hint + self.hint_suffix

# Demonstrate
builder = EnhancedPromptBuilder()

original = "Write me a function to sort a list."
hint = "The user wants: Python not JavaScript"
enhanced = builder.build(original, hint)

print("Original prompt:")
print(f"  \"{original}\"\n")
print("Enhanced prompt (with hint):")
print(f"  \"{enhanced}\"")
print("\nThe teacher model sees this enhanced prompt and generates a better response!")

### 4.4 Token-Level Log-Probability Computation

Now the core mechanic: computing log-probabilities for both the teacher and student models.
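
Before the PyTorch version, it may help to see the same two steps, log-softmax over the vocabulary and then selecting the target token's entry, written out in plain Python. This is only a sketch: the helper names and the logits are made up for illustration, and the cell below does the equivalent with `F.log_softmax` and `gather` on batched tensors.

```python
import math
from typing import List

def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary-sized list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def token_log_probs(logits_per_position: List[List[float]],
                    target_ids: List[int]) -> List[float]:
    """Pick, at each position, the log-probability of that position's target token."""
    return [log_softmax(logits)[t]
            for logits, t in zip(logits_per_position, target_ids)]

# Hypothetical logits for 2 positions over a 3-token vocabulary.
logits = [[2.0, 0.0, 0.0],   # position 0: the model favours token 0
          [0.0, 0.0, 0.0]]   # position 1: uniform over the vocabulary
targets = [0, 2]

selected = token_log_probs(logits, targets)
print([round(lp, 4) for lp in selected])
```

Running the same selection once with the student's context and once with the teacher's context yields the two per-token vectors whose difference is the OPD advantage.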

In [None]:
class TokenLogProbComputer:
    """
    Computes token-level log-probabilities for teacher and student.
    """

    def __init__(self, vocab_size=50, hidden_size=32, max_seq_len=30):
        """Create a simple model for demonstration."""
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        # Simplified teacher/student as embedding + linear
        # In reality, these are the SAME model with different inputs
        self.model = nn.Sequential(
            nn.Embedding(vocab_size, hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, vocab_size),
        ).to(device)

    def get_log_probs(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Get log-probabilities for each token position.

        Args:
            input_ids: (batch, seq_len) - token IDs

        Returns:
            log_probs: (batch, seq_len, vocab_size)
        """
        logits = self.model(input_ids)
        return F.log_softmax(logits, dim=-1)

    def get_token_log_probs(self, input_ids: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
        """
        Get log-probability of specific target tokens at each position.

        Args:
            input_ids: (batch, seq_len) - input context
            target_ids: (batch, seq_len) - target tokens to score

        Returns:
            selected_log_probs: (batch, seq_len)
        """
        all_log_probs = self.get_log_probs(input_ids)
        # Gather the log-prob for each target token
        selected = all_log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
        return selected

# Create the model
computer = TokenLogProbComputer()

# Demo: compute log-probs for teacher vs student
batch_size = 2
seq_len = 10

# Same response tokens for both teacher and student
response_ids = torch.randint(0, 50, (batch_size, seq_len)).to(device)

# Student sees the original prompt
student_input = torch.randint(0, 50, (batch_size, seq_len)).to(device)
# Teacher sees the enhanced prompt (different tokens due to hint)
teacher_input = torch.randint(0, 50, (batch_size, seq_len)).to(device)

student_log_probs = computer.get_token_log_probs(student_input, response_ids)
teacher_log_probs = computer.get_token_log_probs(teacher_input, response_ids)

print(f"Student log-probs shape: {student_log_probs.shape}")
print(f"Teacher log-probs shape: {teacher_log_probs.shape}")
print(f"\nSample student log-probs: {student_log_probs[0, :5].detach().cpu().tolist()}")
print(f"Sample teacher log-probs: {teacher_log_probs[0, :5].detach().cpu().tolist()}")

In [None]:
#@title 🎧 What to Look For: Token Advantages And Heatmap
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_07_token_advantages_and_heatmap.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### 4.5 Token-Level Advantage Computation

This is the heart of OPD. At each token position, we compute how much the teacher (with hint) disagrees with the student (without hint):

In [None]:
def compute_token_advantages(
    teacher_log_probs: torch.Tensor,
    student_log_probs: torch.Tensor
) -> torch.Tensor:
    """
    Compute token-level OPD advantages.

    A_t = log π_teacher(a_t | s + hint) - log π_student(a_t | s)

    Args:
        teacher_log_probs: (batch, seq_len) - teacher's per-token log-probs
        student_log_probs: (batch, seq_len) - student's per-token log-probs

    Returns:
        advantages: (batch, seq_len) - token-level advantages
    """
    return teacher_log_probs - student_log_probs

# Compute advantages
advantages = compute_token_advantages(teacher_log_probs, student_log_probs)

print(f"Token-level advantages shape: {advantages.shape}")
print(f"Sample advantages: {advantages[0].detach().cpu().numpy().round(3)}")
print(f"\nMean advantage: {advantages.mean().item():.4f}")
print(f"Max advantage:  {advantages.max().item():.4f}")
print(f"Min advantage:  {advantages.min().item():.4f}")

### 📊 Visualization: Token-Level Advantage Heatmap

In [None]:
def visualize_token_advantages(advantages, tokens=None, title="Token-Level OPD Advantages"):
    """Visualize token-level advantages as a heatmap."""
    adv_np = advantages.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(14, 3))

    # Normalize colormap around zero
    vmax = max(abs(adv_np.min()), abs(adv_np.max()))
    norm = mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)

    im = ax.imshow(adv_np, cmap='RdYlGn', norm=norm, aspect='auto')
    plt.colorbar(im, ax=ax, label='Advantage (green=increase, red=decrease)')

    if tokens:
        ax.set_xticks(range(len(tokens)))
        ax.set_xticklabels(tokens, rotation=45, ha='right', fontsize=9)
    else:
        ax.set_xlabel('Token Position')

    ax.set_ylabel('Batch')
    ax.set_title(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Visualize a simulated meaningful example
# Simulate a scenario where the teacher (with hint) is far more confident than
# the student about the token at position 3 ("Python")
simulated_teacher_lp = torch.tensor([[-1.0, -0.8, -0.5, -0.5, -0.9, -0.7, -1.1, -0.6]])
simulated_student_lp = torch.tensor([[-1.1, -0.9, -0.6, -2.3, -1.0, -0.8, -1.2, -0.7]])
simulated_adv = compute_token_advantages(simulated_teacher_lp, simulated_student_lp)

tokens = ["Here", "is", "a", "Python", "sorting", "function", "that", "works"]
visualize_token_advantages(simulated_adv, tokens=tokens,
                           title="OPD Advantage: Teacher Strongly Prefers 'Python' at Position 3")

print("Token-by-token breakdown:")
THRESHOLD = 0.15  # slightly above the 0.1 gaps so float32 rounding cannot flip a label
for i, (tok, adv) in enumerate(zip(tokens, simulated_adv[0].tolist())):
    direction = "↑ increase" if adv > THRESHOLD else "↓ decrease" if adv < -THRESHOLD else "→ keep"
    print(f"  Token '{tok}': A_t = {adv:+.2f}  ({direction})")

In [None]:
#@title 🎧 Listen: Stop And Think
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_08_stop_and_think.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### ✋ Stop and Think

Look at the advantage values above. Notice that:
- Position 3 ("Python") has $A_t = 1.8$, a very large positive advantage
- Most other positions have small values near zero

This is the power of OPD: it pinpoints **exactly which tokens** need to change. Binary RL would assign the same scalar to all 8 tokens. OPD gives each token its own gradient direction.

*Take a moment to appreciate this before continuing.*
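
One way to quantify the contrast is to ask how much of the total training signal lands on each token. The sketch below reuses the simulated log-probabilities from the heatmap example. The Binary RL reward of -1 is an assumed value chosen only for illustration; what matters is that it is identical for every token.

```python
# Per-token log-probs copied from the simulated heatmap example.
tokens = ["Here", "is", "a", "Python", "sorting", "function", "that", "works"]
teacher_lp = [-1.0, -0.8, -0.5, -0.5, -0.9, -0.7, -1.1, -0.6]
student_lp = [-1.1, -0.9, -0.6, -2.3, -1.0, -0.8, -1.2, -0.7]

# OPD: one advantage per token.
opd_adv = [t - s for t, s in zip(teacher_lp, student_lp)]

# Binary RL: a single scalar reward broadcast unchanged to every token.
reward = -1.0  # assumed value; any scalar gives the same uniform pattern
binary_adv = [reward] * len(tokens)

# Share of the total |advantage| carried by each token under OPD.
total = sum(abs(a) for a in opd_adv)
share = [abs(a) / total for a in opd_adv]
top = max(range(len(tokens)), key=lambda i: share[i])

print(f"OPD puts {share[top]:.0%} of its signal on '{tokens[top]}'")
print(f"Binary RL gives every token the same weight: {1 / len(tokens):.1%}")
```

Under these numbers the single informative token carries most of the OPD signal, whereas the broadcast scalar spreads it evenly across all eight positions.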

In [None]:
#@title 🎧 Before You Start: Todo1 Opd Loss
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_09_todo1_opd_loss.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 5. 🔧 Your Turn

### TODO 1: Implement the OPD Loss Function

Combine token-level advantages with the clipped surrogate loss:

In [None]:
def opd_loss(
    log_probs_new: torch.Tensor,     # (batch, seq_len): current policy
    log_probs_ref: torch.Tensor,     # (batch, seq_len): reference policy
    token_advantages: torch.Tensor,   # (batch, seq_len): per-token OPD advantages
    eps: float = 0.2,
    mask: torch.Tensor = None,        # (batch, seq_len): optional padding mask
) -> torch.Tensor:
    """
    Compute the OPD clipped surrogate loss.

    Unlike Binary RL where advantages are per-response (scalar),
    OPD advantages are per-token (vector).

    Args:
        log_probs_new: Log probs under current policy
        log_probs_ref: Log probs under reference policy
        token_advantages: Token-level advantages from teacher-student gap
        eps: Clipping bound
        mask: Binary mask (1 for real tokens, 0 for padding)

    Returns:
        Scalar loss value

    Steps:
        1. Compute ratio: ρ_t = exp(log_new - log_ref)
        2. Unclipped term: ρ_t * A_t  (element-wise, both are per-token)
        3. Clipped term: clip(ρ_t) * A_t
        4. Min of unclipped and clipped
        5. Apply mask if provided
        6. Return mean over all tokens
    """
    # ============ TODO ============
    # Step 1: Compute ratio
    # Step 2: Unclipped objective
    # Step 3: Clipped objective
    # Step 4: Pessimistic bound (min)
    # Step 5: Apply mask and average
    # ==============================

    loss = ???  # YOUR CODE HERE

    return loss

# ✅ Verification
batch, seq = 4, 10
log_new = torch.randn(batch, seq) * 0.1 - 2.0
log_ref = torch.randn(batch, seq) * 0.1 - 2.0
tok_advs = torch.randn(batch, seq) * 0.5  # Token-level advantages

loss = opd_loss(log_new, log_ref, tok_advs)
assert loss.dim() == 0, f"❌ Loss should be scalar, got shape {loss.shape}"
assert not torch.isnan(loss), "❌ Loss is NaN!"
print(f"✅ OPD loss computed successfully: {loss.item():.4f}")

In [None]:
#@title 🎧 Before You Start: Todo2 Hint Filtering
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_10_todo2_hint_filtering.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### TODO 2: Implement Hint Quality Filtering

Not all hints are useful. Trivial hints like "the response was wrong" add no information beyond what Binary RL already provides. Implement a quality filter:

In [None]:
def filter_hints(hints: List[Optional[str]], min_length: int = 10) -> List[Optional[str]]:
    """
    Filter hints by quality. Keep only informative hints.

    Args:
        hints: List of extracted hints (some may be None)
        min_length: Minimum character length for a hint to be considered informative

    Returns:
        Filtered list where low-quality hints are replaced with None

    Rules:
        1. None hints stay None
        2. Hints shorter than min_length characters → None (too trivial)
        3. Hints that are purely negative without direction → None
           (e.g., "That was wrong" has no actionable information)
        4. All other hints are kept
    """
    # ============ TODO ============
    # Filter each hint based on the rules above
    # ==============================

    filtered = ???  # YOUR CODE HERE

    return filtered

# ✅ Verification
test_hints = [
    "The user wants: Python not JavaScript",   # Good: specific correction
    "wrong",                                     # Bad: too short, no direction
    None,                                        # None: no hint extracted
    "Adjust tone: be less verbose and more concise when explaining code", # Good
    "bad",                                       # Bad: too short
    "Use Flask instead of Django for the web framework",  # Good
]

filtered = filter_hints(test_hints)
assert filtered[0] is not None, "❌ First hint should be kept"
assert filtered[1] is None, "❌ 'wrong' should be filtered out"
assert filtered[2] is None, "❌ None should stay None"
assert filtered[3] is not None, "❌ Long specific hint should be kept"
assert filtered[4] is None, "❌ 'bad' should be filtered out"
assert filtered[5] is not None, "❌ Specific framework hint should be kept"
print("✅ Hint quality filtering works correctly!")
print(f"   Kept {sum(1 for h in filtered if h is not None)}/{len(test_hints)} hints")

In [None]:
#@title 🎧 Code Walkthrough: Full Pipeline
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_11_full_pipeline.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 6. Putting It All Together: The Full OPD Pipeline

In [None]:
class OnPolicyDistillationPipeline:
    """
    The complete OPD pipeline: from user feedback to token-level training.
    """

    def __init__(self, model, vocab_size=50, hidden_size=32):
        self.hint_extractor = HindsightHintExtractor()
        self.prompt_builder = EnhancedPromptBuilder()
        self.model = model
        self.vocab_size = vocab_size

    def process_sample(self, prompt: str, response: str, feedback: str):
        """
        Process a single (prompt, response, feedback) triple.

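        Returns:
            dict with keys: hint, original_prompt, enhanced_prompt, response,
            feedback; or None if no actionable hint was extracted. Token-level
            advantages are not computed here; they come from the model's
            log-probs via compute_token_advantages.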
        """
        # Step 1: Extract hint
        hint = self.hint_extractor.extract_with_voting(feedback)
        if hint is None:
            return None

        # Step 2: Build enhanced prompt
        enhanced = self.prompt_builder.build(prompt, hint)

        return {
            "hint": hint,
            "original_prompt": prompt,
            "enhanced_prompt": enhanced,
            "response": response,
            "feedback": feedback,
        }

# Create pipeline
pipeline = OnPolicyDistillationPipeline(computer)

# Process several examples
examples = [
    {
        "prompt": "Write a sorting function",
        "response": "function sortArray(arr) { return arr.sort(); }",
        "feedback": "No, I said Python not JavaScript."
    },
    {
        "prompt": "Explain machine learning",
        "response": "Machine learning is a comprehensive field that encompasses...(500 words)",
        "feedback": "That's too verbose, be more concise."
    },
    {
        "prompt": "What is 2+2?",
        "response": "2+2 = 4",
        "feedback": "Thanks!"
    },
    {
        "prompt": "Help me with my Flask app",
        "response": "Here's how to do it with Django...",
        "feedback": "Use Flask instead of Django."
    },
]

print("OPD Pipeline - Processing Examples:\n")
for ex in examples:
    result = pipeline.process_sample(**ex)
    if result:
        print(f"  ✅ Prompt: \"{ex['prompt']}\"")
        print(f"     Hint: \"{result['hint']}\"")
        print(f"     Enhanced: \"{result['enhanced_prompt']}\"\n")
    else:
        print(f"  ⚪ Prompt: \"{ex['prompt']}\"")
        print(f"     Feedback: \"{ex['feedback']}\" → No actionable hint\n")

In [None]:
#@title 🎧 Code Walkthrough: Training Comparison
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_12_training_comparison.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 7. Training: OPD vs Binary RL Comparison

Let us compare both approaches on a synthetic task where OPD should shine:

In [None]:
# Create simple teacher and student models
class SimpleSeqModel(nn.Module):
    """A tiny sequence model for OPD demonstration."""
    def __init__(self, vocab_size=50, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.fc1 = nn.Linear(hidden, hidden)
        self.fc2 = nn.Linear(hidden, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, x):
        h = F.relu(self.fc1(self.embed(x)))
        return F.log_softmax(self.fc2(h), dim=-1)

    def get_token_log_probs(self, input_ids, target_ids):
        log_probs = self.forward(input_ids)
        return log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)

# Training comparison
vocab_size = 50
seq_len = 15

student_model = SimpleSeqModel(vocab_size).to(device)
ref_model = SimpleSeqModel(vocab_size).to(device)
ref_model.load_state_dict(student_model.state_dict())

# Binary RL baseline (separate copy)
binary_model = SimpleSeqModel(vocab_size).to(device)
binary_model.load_state_dict(student_model.state_dict())

opt_opd = torch.optim.Adam(student_model.parameters(), lr=3e-4)
opt_binary = torch.optim.Adam(binary_model.parameters(), lr=3e-4)

# Generate synthetic data where OPD has an advantage
# The "correct" tokens are known, so we can compute teacher log-probs
target_pattern = torch.arange(seq_len).to(device) % vocab_size  # Repeating pattern

opd_losses = []
binary_losses = []
opd_accuracies = []
binary_accuracies = []

num_steps = 150
batch_size = 16

print("Training OPD vs Binary RL...\n")

for step in range(num_steps):
    # Generate batch
    inputs = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)
    targets = target_pattern.unsqueeze(0).expand(batch_size, -1)

    # === OPD Training ===
    student_lp = student_model.get_token_log_probs(inputs, targets)
    with torch.no_grad():
        ref_lp = ref_model.get_token_log_probs(inputs, targets)
        # Teacher log-probs (simulated: teacher is more confident about correct tokens)
        teacher_lp = ref_lp + torch.randn_like(ref_lp) * 0.1 + 0.5  # Teacher is better

    token_advs = teacher_lp - student_lp.detach()
    ratio = torch.exp(student_lp - ref_lp.detach())
    unclipped = ratio * token_advs
    clipped = torch.clamp(ratio, 0.8, 1.2) * token_advs
    opd_obj = torch.min(unclipped, clipped).mean()

    opt_opd.zero_grad()
    (-opd_obj).backward()
    torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
    opt_opd.step()
    opd_losses.append(opd_obj.item())

    # === Binary RL Training ===
    binary_lp = binary_model.get_token_log_probs(inputs, targets)
    with torch.no_grad():
        ref_lp_b = ref_model.get_token_log_probs(inputs, targets)

    # Binary rewards: +1 if most tokens match target pattern
    preds = binary_model.forward(inputs).argmax(dim=-1)
    match_rate = (preds == targets).float().mean(dim=1)
    rewards = (match_rate > 0.3).float() * 2 - 1  # +1 if >30% match, else -1

    # Broadcast scalar reward to all tokens
    mean_r, std_r = rewards.mean(), rewards.std() + 1e-8
    advs = ((rewards - mean_r) / std_r).unsqueeze(1).expand_as(binary_lp)

    ratio_b = torch.exp(binary_lp - ref_lp_b.detach())
    unclipped_b = ratio_b * advs
    clipped_b = torch.clamp(ratio_b, 0.8, 1.2) * advs
    binary_obj = torch.min(unclipped_b, clipped_b).mean()

    opt_binary.zero_grad()
    (-binary_obj).backward()
    torch.nn.utils.clip_grad_norm_(binary_model.parameters(), 1.0)
    opt_binary.step()
    binary_losses.append(binary_obj.item())  # logs the surrogate objective (maximized), not its negation

    # Compute accuracies
    with torch.no_grad():
        opd_preds = student_model.forward(inputs).argmax(dim=-1)
        binary_preds = binary_model.forward(inputs).argmax(dim=-1)
        opd_accuracies.append((opd_preds == targets).float().mean().item())
        binary_accuracies.append((binary_preds == targets).float().mean().item())

    if (step + 1) % 50 == 0:
        print(f"Step {step+1}: OPD acc={opd_accuracies[-1]:.3f}, "
              f"Binary acc={binary_accuracies[-1]:.3f}")

print("\n✅ Training complete!")
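
The loop above feeds the two models different kinds of signal. The sketch below isolates that difference in plain Python, without tensors. The helper names `broadcast_advantages` and `token_advantages` are illustrative and not part of the notebook, and the numbers are made up. A batch-normalized scalar reward gives every token in a sequence the same advantage, and collapses to zero when every sequence in the batch earns the same reward. The teacher-minus-student gap, by contrast, varies token by token.

```python
import statistics

def broadcast_advantages(rewards, seq_len, eps=1e-8):
    """Normalize per-sequence rewards across the batch, then copy each value to every token."""
    mean_r = statistics.mean(rewards)
    std_r = statistics.stdev(rewards) + eps  # sample std, matching the default of torch.Tensor.std()
    return [[(r - mean_r) / std_r] * seq_len for r in rewards]

def token_advantages(teacher_lp, student_lp):
    """Per-token advantage: teacher log-prob minus student log-prob."""
    return [t - s for t, s in zip(teacher_lp, student_lp)]

# Mixed batch: each sequence gets one value, repeated for all of its tokens.
mixed = broadcast_advantages([1.0, -1.0, 1.0, -1.0], seq_len=3)
print("mixed batch, sequence 0:", mixed[0])

# Uniform batch: all rewards equal, so every advantage is zero and no gradient flows.
uniform = broadcast_advantages([1.0, 1.0, 1.0, 1.0], seq_len=3)
print("uniform batch, sequence 0:", uniform[0])

# Token-level signal: only the second token shows a large teacher-student gap.
per_token = token_advantages([-0.2, -0.1, -0.3], [-0.3, -2.0, -0.4])
print("token-level advantages:", per_token)
```

In the uniform case the scalar approach produces no learning signal for that step, which is one reason the binary baseline can stall while rewards are all `-1` early in training.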

In [None]:
#@title 🎧 What to Look For: Visualization Comparison
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_13_visualization_comparison.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


### 📊 Visualization: OPD vs Binary RL

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Accuracy curves
window = 5
opd_smooth = np.convolve(opd_accuracies, np.ones(window)/window, mode='valid')
binary_smooth = np.convolve(binary_accuracies, np.ones(window)/window, mode='valid')

axes[0].plot(opd_smooth, linewidth=2.5, color='#2ecc71', label='OPD (token-level)')
axes[0].plot(binary_smooth, linewidth=2.5, color='#3498db', label='Binary RL (scalar)')
axes[0].set_xlabel('Training Step', fontsize=12)
axes[0].set_ylabel('Token Accuracy', fontsize=12)
axes[0].set_title('OPD vs Binary RL: Learning Speed', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Token-level advantage distribution at the end of training
with torch.no_grad():
    test_inputs = torch.randint(0, vocab_size, (32, seq_len)).to(device)
    test_targets = target_pattern.unsqueeze(0).expand(32, -1)
    final_student_lp = student_model.get_token_log_probs(test_inputs, test_targets)
    final_teacher_lp = ref_model.get_token_log_probs(test_inputs, test_targets) + 0.5
    final_advs = (final_teacher_lp - final_student_lp).cpu().numpy().flatten()

axes[1].hist(final_advs, bins=40, color='#9b59b6', alpha=0.7, edgecolor='white')
axes[1].axvline(x=0, color='#e74c3c', linestyle='--', linewidth=2)
axes[1].set_xlabel('Token-Level Advantage', fontsize=12)
axes[1].set_ylabel('Count', fontsize=12)
axes[1].set_title('Final Token Advantage Distribution', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final OPD accuracy:       {opd_accuracies[-1]:.3f}")
print(f"Final Binary RL accuracy: {binary_accuracies[-1]:.3f}")
improvement = opd_accuracies[-1] - binary_accuracies[-1]
print(f"OPD accuracy gain:        {improvement:+.3f}")

In [None]:
#@title 🎧 What to Look For: Final Output
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_14_final_output.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 8. 🎯 Final Output: The Complete OPD Pipeline Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Hint extraction success rate
hint_results = {"Extracted": 0, "No hint": 0}
test_fbs = [
    "No, use Python", "Great!", "Too verbose", "Thanks",
    "Wrong framework", "Perfect", "Should add tests", "OK",
    "Not what I asked for", "Looks good",
]
for fb in test_fbs:
    hint = extractor.extract_hint(fb)
    if hint:
        hint_results["Extracted"] += 1
    else:
        hint_results["No hint"] += 1

axes[0, 0].bar(list(hint_results.keys()), list(hint_results.values()),
               color=['#2ecc71', '#95a5a6'], edgecolor='white', linewidth=2)
axes[0, 0].set_title('Hint Extraction Rate', fontsize=12, fontweight='bold')
axes[0, 0].set_ylabel('Count')

# 2. Token advantage heatmap
sample_advs = torch.tensor([[0.1, 0.0, -0.1, 1.8, 0.2, 0.1, -0.2, 0.0]])
im = axes[0, 1].imshow(sample_advs.numpy(), cmap='RdYlGn', aspect='auto',
                        vmin=-2, vmax=2)
axes[0, 1].set_xticks(range(8))
axes[0, 1].set_xticklabels(["Here", "is", "a", "Python", "sort", "fn", "that", "works"],
                            fontsize=9)
plt.colorbar(im, ax=axes[0, 1])
axes[0, 1].set_title('Token-Level Advantages', fontsize=12, fontweight='bold')

# 3. OPD vs Binary learning curves
axes[1, 0].plot(opd_smooth, linewidth=2, color='#2ecc71', label='OPD')
axes[1, 0].plot(binary_smooth, linewidth=2, color='#3498db', label='Binary RL')
axes[1, 0].legend()
axes[1, 0].set_title('Learning Speed Comparison', fontsize=12, fontweight='bold')
axes[1, 0].grid(True, alpha=0.3)

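# 4. Signal density comparison (hand-assigned illustrative scores, not measurements)
categories = ['Signal\nGranularity', 'Information\nDensity', 'Compute\nCost']
opd_vals = [5, 5, 4]      # Token-level signal, rich information, higher compute cost
binary_vals = [2, 2, 2]   # Scalar signal, coarse information, lower compute cost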

x = np.arange(len(categories))
width = 0.3
axes[1, 1].bar(x - width/2, binary_vals, width, label='Binary RL',
               color='#3498db', alpha=0.8)
axes[1, 1].bar(x + width/2, opd_vals, width, label='OPD',
               color='#2ecc71', alpha=0.8)
axes[1, 1].set_xticks(x)
axes[1, 1].set_xticklabels(categories)
axes[1, 1].set_ylabel('Score (1-5)')
axes[1, 1].legend()
axes[1, 1].set_title('Binary RL vs OPD Comparison', fontsize=12, fontweight='bold')

plt.suptitle('On-Policy Distillation: Complete Pipeline', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("🎉 Congratulations! You've built On-Policy Distillation from scratch!")
print("   ✅ Hindsight hint extraction from user feedback")
print("   ✅ Enhanced prompt construction")
print("   ✅ Teacher-student log-probability comparison")
print("   ✅ Token-level advantage computation")
print("   ✅ OPD loss function with clipped surrogate")
print("   ✅ Demonstrated OPD's advantage over Binary RL")

In [None]:
#@title 🎧 Wrap-Up: Reflection And Next Steps
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_15_reflection_and_next_steps.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")


## 9. Reflection and Next Steps

### 🤔 Reflection Questions
1. In OPD, the teacher and student are the **same model**. What would happen if we used a much larger model as the teacher? Would that be better or worse?
2. The hint quality filter discards short hints. Can you think of a case where a very short hint is still highly informative? (e.g., "Python" as a hint)
3. OPD requires computing log-probabilities twice (once for teacher, once for student). How could we reduce this compute cost?

### 🏆 Optional Challenges
1. **Weighted token advantages**: Instead of raw $A_t = \log \pi_{\text{teacher}} - \log \pi_{\text{student}}$, implement a version where advantages are weighted by the teacher's confidence.
2. **Adaptive hint selection**: Instead of always keeping the longest hint, implement a scoring function that balances hint length with specificity.
3. **Mixed training**: Implement a training loop that alternates between Binary RL steps (for samples without hints) and OPD steps (for samples with hints).
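
As a possible starting point for the first challenge, here is a minimal plain-Python sketch. It assumes "teacher confidence" means the teacher's probability of the realized token, $\exp(\log \pi_{\text{teacher}})$. That is only one interpretation, and the function name `weighted_token_advantages` is hypothetical. A tensor version would apply the same element-wise formula to the log-prob tensors.

```python
import math

def weighted_token_advantages(teacher_lp, student_lp):
    """Scale A_t = teacher_lp - student_lp by the teacher's token probability exp(teacher_lp).

    Assumption: 'confidence' is the teacher's probability of the token; other
    choices (e.g. an entropy-based weight) are equally valid answers to the challenge.
    """
    return [math.exp(t) * (t - s) for t, s in zip(teacher_lp, student_lp)]

# Two tokens with the same raw gap (log 3) but different teacher confidence.
teacher_lp = [math.log(0.9), math.log(0.1)]
student_lp = [math.log(0.3), math.log(0.1 / 3)]

raw = [t - s for t, s in zip(teacher_lp, student_lp)]
weighted = weighted_token_advantages(teacher_lp, student_lp)

print("raw advantages:     ", raw)
print("weighted advantages:", weighted)
```

With this weighting, a token the teacher itself finds unlikely contributes less to the update, even when the raw log-prob gap is the same.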