In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds for reproducibility

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Emergent Reasoning Behaviors and Distillation -- Vizuara

In this notebook, we explore the fascinating emergent behaviors that arise when models are trained with RL, and implement a simple rejection sampling and distillation pipeline.

**What you will build:** Analysis tools for detecting emergent reasoning patterns, and a distillation pipeline that transfers reasoning from a large model to a small one.

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import re
%matplotlib inline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
torch.manual_seed(42)

In [None]:
!pip install -q transformers

## 1. Why Does This Matter?

When DeepSeek trained their R1 model with RL, something unexpected happened. The model spontaneously developed reasoning strategies that were never taught:
- **Self-verification:** checking its own calculations
- **Backtracking:** catching and correcting mistakes
- **Adaptive depth:** thinking longer on harder problems

These behaviors emerged purely from the reward signal. Understanding how and why they emerge is crucial for building better reasoning models. And distillation lets us transfer these capabilities to smaller, more deployable models.

## 2. Building Intuition

Imagine you are training for a math competition. Nobody tells you to double-check your work. Nobody tells you to re-read the problem when confused. But over time, you discover that these strategies help you score higher.

RL does the same thing for language models. The model discovers reasoning strategies because they lead to higher rewards, not because they are explicitly taught.
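
To see why a strategy such as double-checking can be favored by an outcome-only reward, consider a toy calculation. The numbers here (a base accuracy `p` and a probability `c` that a check catches a wrong attempt and triggers one retry) are illustrative assumptions, not measurements from any model.

```python
def expected_reward(p, c=0.0):
    """Expected binary reward for one problem.

    p: probability a single attempt is correct (assumed).
    c: probability a self-check catches a wrong attempt and
       triggers one retry (assumed; c=0 means no checking).
    """
    # Correct on the first try, or wrong -> caught -> correct on the retry.
    return p + (1 - p) * c * p

for p in (0.9, 0.6, 0.3):
    plain = expected_reward(p)
    checked = expected_reward(p, c=0.8)
    print(f"p={p:.1f}  no check: {plain:.3f}  with check: {checked:.3f}  gain: {checked - plain:+.3f}")
```

Under these assumptions the checking policy always earns at least as much reward in expectation, so updates that raise the probability of checking are reinforced even though the reward never mentions checking.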

### Think About This
- Why would self-verification emerge from a binary (correct/incorrect) reward?
- Why would the model learn to produce longer reasoning for harder problems?

## 3. The Mathematics

### Rejection Sampling for Distillation

Given a large model $\pi_L$ and a dataset of prompts $\{x_1, \ldots, x_N\}$:

1. For each prompt $x_i$, sample $K$ completions: $y_{i,1}, \ldots, y_{i,K} \sim \pi_L(\cdot | x_i)$
2. Keep only correct completions: $\mathcal{D}_{\text{filtered}} = \{(x_i, y_{i,k}) : R(y_{i,k}) = 1\}$
3. Train small model $\pi_S$ via SFT on $\mathcal{D}_{\text{filtered}}$

$$\mathcal{L}_{\text{distill}} = -\sum_{(x,y) \in \mathcal{D}_{\text{filtered}}} \sum_t \log p_{\pi_S}(y_t | y_{<t}, x)$$
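
As a concrete reading of this loss, the sketch below evaluates it by hand for a tiny filtered set. The per-token probabilities are made-up values standing in for $p_{\pi_S}(y_t | y_{<t}, x)$; a real implementation would obtain them from the small model's logits.

```python
import math

def distill_loss(token_probs_per_example):
    """Sum of negative log-probabilities over all completion tokens.

    token_probs_per_example: one list per (x, y) pair, holding the
    probability the small model assigns to each target token y_t.
    """
    return -sum(math.log(p) for probs in token_probs_per_example for p in probs)

# Hypothetical probabilities for two filtered completions.
filtered_probs = [
    [0.9, 0.8, 0.95],  # completion 1: three tokens
    [0.5, 0.7],        # completion 2: two tokens
]
loss = distill_loss(filtered_probs)
print(f"L_distill = {loss:.4f}")
```

Training libraries typically report the mean over tokens rather than the sum, which changes the scale of the loss but not which tokens contribute to it.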

**Computational meaning:** We use the large model as a "solution generator," filter for quality, then train the small model on the best solutions. The small model learns from the large model's successful reasoning patterns.
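
A rough way to budget $K$ is to ask how much data survives the filter. The sketch below assumes every sample is independently correct with the same probability `p_correct`, which is a simplification (real accuracy varies by prompt); the function name and numbers are illustrative.

```python
def rejection_sampling_yield(n_prompts, k, p_correct):
    """Expected outcomes if every sample is independently correct with probability p_correct."""
    expected_kept = n_prompts * k * p_correct
    # A prompt contributes nothing only if all K samples are wrong.
    p_covered = 1 - (1 - p_correct) ** k
    return expected_kept, p_covered

for k in (1, 4, 8, 16):
    kept, covered = rejection_sampling_yield(n_prompts=100, k=k, p_correct=0.2)
    print(f"K={k:2d}  expected kept: {kept:6.1f}  P(prompt has >=1 correct): {covered:.3f}")
```

In this model the acceptance rate stays at `p_correct` regardless of $K$; what larger $K$ buys is a bigger dataset and a better chance that hard prompts contribute at least one correct solution.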

## 4. Let's Build It -- Component by Component

### 4.1 Analyzing Emergent Behaviors

Let us build tools to detect emergent reasoning patterns in model outputs.

In [None]:
def detect_self_verification(text):
    """Detect if the model checks its own work."""
    patterns = [
        r'[Ll]et me (verify|check|double.?check)',
        r'[Vv]erif(y|ying|ication)',
        r'[Cc]heck(ing)?\s*(this|that|my|the)',
        r'[Yy]es,?\s*(that|this) is correct',
        r'\b[Cc]orrect[.!]',  # \b keeps "incorrect." from matching
    ]
    for pattern in patterns:
        if re.search(pattern, text):
            return True
    return False

def detect_backtracking(text):
    """Detect if the model corrects a mistake."""
    patterns = [
        r'\b[Ww]ait\b',       # word boundaries avoid hits like "await"
        r'\b[Aa]ctually\b',
        r'[Ll]et me reconsider',
        r'[Tt]hat\'s (wrong|incorrect|not right)',
        r'[Ii] made (a|an) (error|mistake)',
        r'[Ll]et me (redo|recalculate)',
    ]
    for pattern in patterns:
        if re.search(pattern, text):
            return True
    return False

def measure_reasoning_depth(text):
    """Count the number of reasoning steps."""
    steps = re.findall(r'[Ss]tep \d+', text)
    if steps:
        return len(steps)
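    # Fallback: count sentences in the think block. Splitting on end
    # punctuation followed by whitespace keeps decimals like "0.20" intact.
    think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
    if think_match:
        sentences = [s for s in re.split(r'(?<=[.!?])\s+', think_match.group(1).strip()) if s]
        return len(sentences)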
    return 0

# Test on example reasoning traces
examples = [
    "<think>\nStep 1: 3 * 7 = 21. Step 2: Let me verify: 7+7+7 = 21. Correct.\n</think>\nThe answer is 21.",
    "<think>\nStep 1: 5 * 3 = 18. Wait, that's wrong. 5 * 3 = 15. Let me recalculate.\n</think>\nThe answer is 15.",
    "<think>\nStep 1: 2 + 2 = 4.\n</think>\nThe answer is 4.",
    "<think>\nStep 1: First, find 20% of 150.\nStep 2: 20% = 0.20.\nStep 3: 0.20 * 150 = 30.\nStep 4: Let me check: 10% of 150 = 15, so 20% = 30. Yes, correct.\n</think>\nThe answer is 30.",
]

print("=== Analyzing Reasoning Traces ===\n")
for i, ex in enumerate(examples):
    print(f"Example {i+1}:")
    print(f"  Self-verification: {detect_self_verification(ex)}")
    print(f"  Backtracking: {detect_backtracking(ex)}")
    print(f"  Reasoning depth: {measure_reasoning_depth(ex)} steps")
    print()
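
Per-trace flags become more useful once aggregated over a batch of generations, since the rates are what one would track over training. The following is a minimal sketch: `behavior_rates` is a hypothetical helper, and the two lambda detectors are simplified stand-ins so the block runs on its own (the notebook's detector functions can be passed in instead).

```python
import re

def behavior_rates(traces, detectors):
    """Fraction of traces flagged by each detector.

    detectors: dict mapping a label to a function text -> bool.
    """
    n = max(len(traces), 1)  # avoid division by zero on an empty batch
    return {name: sum(bool(fn(t)) for t in traces) / n for name, fn in detectors.items()}

# Simplified stand-in detectors (hypothetical; swap in the notebook's functions).
toy_detectors = {
    "verification": lambda t: re.search(r"\b[Ll]et me (verify|check)\b", t) is not None,
    "backtracking": lambda t: re.search(r"\b[Ww]ait\b", t) is not None,
}
toy_traces = [
    "Step 1: 3 * 7 = 21. Let me verify: 7+7+7 = 21.",
    "Step 1: 5 * 3 = 18. Wait, that's wrong. 5 * 3 = 15.",
    "Step 1: 2 + 2 = 4.",
    "Step 1: 0.20 * 150 = 30. Let me check: 10% is 15, so 20% is 30.",
]
print(behavior_rates(toy_traces, toy_detectors))
# -> {'verification': 0.5, 'backtracking': 0.25}
```

Keyword detectors like these are heuristics: they measure how often certain phrases appear, not whether the check or correction was actually sound.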

### Visualization Checkpoint: Emergent Behavior Frequency

In [None]:
# Simulate analyzing many outputs from a trained model
# In practice, these would come from actual model generations
np.random.seed(42)

# Simulated data: behavior frequency over training
training_steps = np.arange(0, 1000, 50)

# Self-verification increases with training
verification_rate = 1 / (1 + np.exp(-(training_steps - 300) / 100)) * 0.6 + np.random.randn(len(training_steps)) * 0.03

# Backtracking peaks mid-training then stabilizes
backtracking_rate = 0.3 * np.exp(-((training_steps - 500) ** 2) / (2 * 200**2)) + 0.1 + np.random.randn(len(training_steps)) * 0.02

# Average reasoning depth increases
avg_depth = 2 + 3 / (1 + np.exp(-(training_steps - 400) / 150)) + np.random.randn(len(training_steps)) * 0.3

fig, axes = plt.subplots(1, 3, figsize=(16, 4))

axes[0].plot(training_steps, np.clip(verification_rate, 0, 1), color='#4CAF50', linewidth=2)
axes[0].set_xlabel('Training Steps'); axes[0].set_ylabel('Frequency')
axes[0].set_title('Self-Verification Rate'); axes[0].grid(True, alpha=0.3)
axes[0].set_ylim(0, 1)

axes[1].plot(training_steps, np.clip(backtracking_rate, 0, 1), color='#FF9800', linewidth=2)
axes[1].set_xlabel('Training Steps'); axes[1].set_ylabel('Frequency')
axes[1].set_title('Backtracking Rate'); axes[1].grid(True, alpha=0.3)
axes[1].set_ylim(0, 0.5)

axes[2].plot(training_steps, np.clip(avg_depth, 1, 8), color='#2196F3', linewidth=2)
axes[2].set_xlabel('Training Steps'); axes[2].set_ylabel('Avg Steps')
axes[2].set_title('Average Reasoning Depth'); axes[2].grid(True, alpha=0.3)

plt.suptitle('Emergent Behaviors During RL Training', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

### 4.2 Rejection Sampling

The core of distillation: generate many solutions, keep only the correct ones.

In [None]:
def rejection_sampling(generator_fn, problems, K=8):
    """
    Generate K solutions per problem, keep only correct ones.

    Args:
        generator_fn: Function that generates a completion given a prompt
        problems: List of (prompt, ground_truth) tuples
        K: Number of solutions to generate per problem

    Returns:
        filtered_data: List of (prompt, correct_completion) tuples
    """
    filtered_data = []
    stats = {"total_generated": 0, "total_kept": 0}

    for prompt, truth in problems:
        correct_solutions = []
        for _ in range(K):
            completion = generator_fn(prompt)
            stats["total_generated"] += 1

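            # Check if correct: extract the final answer and normalize it
            # (drop a trailing period and thousands separators) before comparing
            predicted = re.search(r'[Tt]he answer is[:\s]*(\-?[\d,\.]+)', completion)
            if predicted and predicted.group(1).rstrip('.').replace(',', '') == truth:
                correct_solutions.append(completion)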

        # Keep all correct solutions
        for sol in correct_solutions:
            filtered_data.append((prompt, sol))
            stats["total_kept"] += 1

    return filtered_data, stats

# Simulate a "large model" whose individual samples are correct about 60% of the time
def simulated_large_model(prompt):
    """Simulate a large model generating chain-of-thought solutions."""
    # In reality, this would be model.generate(...)
    # Here we simulate with template-based responses
    if np.random.random() < 0.6:  # 60% correct
        # Generate a correct chain-of-thought
        numbers = re.findall(r'\d+', prompt)
        if len(numbers) >= 2:
            a, b = int(numbers[0]), int(numbers[1])
            if '*' in prompt or 'times' in prompt:
                answer = a * b
                return f"<think>\nStep 1: {a} * {b} = {answer}.\nStep 2: Let me verify: {answer}.\n</think>\nThe answer is {answer}."
            elif '+' in prompt:
                answer = a + b
                return f"<think>\nStep 1: {a} + {b} = {answer}.\n</think>\nThe answer is {answer}."
            elif '-' in prompt:
                answer = a - b
                return f"<think>\nStep 1: {a} - {b} = {answer}.\n</think>\nThe answer is {answer}."
    # Wrong answer
    return f"<think>\nI think the answer is 999.\n</think>\nThe answer is 999."

# Test rejection sampling
test_problems = [
    ("What is 4 * 8?", "32"),
    ("What is 17 + 25?", "42"),
    ("What is 90 - 34?", "56"),
    ("What is 6 * 7?", "42"),
]

filtered, stats = rejection_sampling(simulated_large_model, test_problems, K=8)
print(f"Generated: {stats['total_generated']}")
print(f"Kept (correct): {stats['total_kept']}")
print(f"Acceptance rate: {stats['total_kept']/stats['total_generated']*100:.1f}%")
print(f"\nFiltered dataset size: {len(filtered)}")
for prompt, sol in filtered[:3]:
    print(f"\n  Q: {prompt}")
    print(f"  A: {sol[:100]}")
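
Because every correct sample is kept, easy prompts can dominate the filtered set while hard prompts vanish from it. A quick check is to count accepted completions per prompt. The helper name and toy data below are illustrative assumptions, not part of the pipeline above.

```python
from collections import Counter

def per_prompt_counts(filtered_data, prompts):
    """Number of accepted completions for each prompt (0 if none survived)."""
    counts = Counter(prompt for prompt, _ in filtered_data)
    return {p: counts.get(p, 0) for p in prompts}

# Hypothetical filtered output: the first prompt dominates, the third is missing.
toy_filtered = [
    ("What is 4 * 8?", "... The answer is 32."),
    ("What is 4 * 8?", "... The answer is 32."),
    ("What is 17 + 25?", "... The answer is 42."),
]
toy_prompts = ["What is 4 * 8?", "What is 17 + 25?", "What is 90 - 34?"]

counts = per_prompt_counts(toy_filtered, toy_prompts)
coverage = sum(c > 0 for c in counts.values()) / len(toy_prompts)
print(counts)
print(f"Coverage: {coverage:.2f}")
```

If coverage is low or the counts are very uneven, options include raising $K$ or capping the number of completions kept per prompt.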

### 4.3 Distillation: Train a Smaller Model

In [None]:
def distill(small_model, tokenizer, filtered_data, epochs=20, lr=5e-5):
    """Fine-tune a small model on filtered (correct) solutions from the large model."""
    optimizer = torch.optim.AdamW(small_model.parameters(), lr=lr)
    losses = []

    for epoch in range(epochs):
        small_model.train()
        epoch_loss = 0

        for prompt, completion in filtered_data:
            text = f"Question: {prompt}\n{completion}"
            tokens = tokenizer(text, return_tensors="pt", truncation=True,
                             max_length=256, padding="max_length")
            tokens = {k: v.to(device) for k, v in tokens.items()}
            labels = tokens["input_ids"].clone()
            labels[tokens["attention_mask"] == 0] = -100

            loss = small_model(**tokens, labels=labels).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / max(len(filtered_data), 1)
        losses.append(avg_loss)
        if (epoch + 1) % 5 == 0:
            print(f"Distillation Epoch {epoch+1}, Loss: {avg_loss:.4f}")

    return losses

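# Create a fresh small model and its tokenizer for distillation
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 ships without a padding token; reuse EOS so padding="max_length" works
tokenizer.pad_token = tokenizer.eos_token

small_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
small_model.resize_token_embeddings(len(tokenizer))  # no-op unless tokens were added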

print("Distilling large model reasoning into small model...")
distill_losses = distill(small_model, tokenizer, filtered)
print("Distillation complete!")
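
The loss in Section 3 sums over completion tokens only, whereas `distill` computes the loss over the question tokens as well. One common variant masks the prompt positions in the labels. The sketch below shows the idea on plain lists with made-up token IDs; `mask_prompt_labels` and `prompt_len` are hypothetical names, and it assumes right-side padding. In practice `prompt_len` could be obtained by tokenizing the prompt prefix on its own.

```python
IGNORE_INDEX = -100  # positions with this label are skipped by PyTorch's cross-entropy loss

def mask_prompt_labels(input_ids, attention_mask, prompt_len):
    """Build labels that train only on completion tokens.

    input_ids: token IDs for prompt + completion (+ padding).
    attention_mask: 1 for real tokens, 0 for padding.
    prompt_len: number of leading tokens that belong to the prompt.
    """
    labels = []
    for pos, (tok, keep) in enumerate(zip(input_ids, attention_mask)):
        if pos < prompt_len or keep == 0:
            labels.append(IGNORE_INDEX)  # prompt or padding: no loss
        else:
            labels.append(tok)           # completion token: contributes to the loss
    return labels

# Made-up IDs: 3 prompt tokens, 2 completion tokens, 2 padding tokens.
ids = [11, 12, 13, 21, 22, 0, 0]
mask = [1, 1, 1, 1, 1, 0, 0]
print(mask_prompt_labels(ids, mask, prompt_len=3))
# -> [-100, -100, -100, 21, 22, -100, -100]
```

Whether to mask the prompt is a design choice: masking matches the formula more closely, while training on the full text is simpler and also works for short prompts like the ones used here.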

### Visualization Checkpoint: Distillation Training

In [None]:
plt.figure(figsize=(10, 4))
plt.plot(distill_losses, color='#9C27B0', linewidth=2)
plt.xlabel('Epoch'); plt.ylabel('Loss')
plt.title('Distillation Training Loss')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 5. Your Turn -- TODO Exercises

### TODO 1: Vary the Number of Samples K

Run rejection sampling with different values of $K$ and observe how the acceptance rate and the filtered dataset size change.

In [None]:
# ============ TODO ============
# Run rejection_sampling with K = 2, 4, 8, 16, 32
# Plot: K vs acceptance rate and K vs filtered dataset size
#
# k_values = [2, 4, 8, 16, 32]
# acceptance_rates = []
# dataset_sizes = []
#
# for k in k_values:
#     filtered, stats = rejection_sampling(simulated_large_model, test_problems, K=k)
#     rate = stats['total_kept'] / stats['total_generated']
#     acceptance_rates.append(rate)
#     dataset_sizes.append(len(filtered))
#
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
# ax1.plot(k_values, acceptance_rates, 'o-')
# ax1.set_xlabel('K (samples per problem)')
# ax1.set_ylabel('Acceptance Rate')
# ax2.plot(k_values, dataset_sizes, 'o-')
# ax2.set_xlabel('K'); ax2.set_ylabel('Filtered Dataset Size')
# plt.tight_layout(); plt.show()
# ==============================

### TODO 2: Analyze Reasoning Depth vs Problem Difficulty

Create problems of varying difficulty and measure how reasoning depth changes.

In [None]:
# ============ TODO ============
# Create easy, medium, and hard problems
# Generate solutions and measure average reasoning depth for each
#
# easy = [("What is 2 + 3?", "5"), ("What is 4 * 2?", "8")]
# medium = [("What is 15% of 200?", "30"), ("What is 125 - 67?", "58")]
# hard = [("If 3 boxes of 12 items each, remove 7. How many left?", "29")]
#
# Measure reasoning depth for each difficulty level
# ==============================

## 6. Putting It All Together

Let us run the distilled model on the test problems and check its answers against the ground truth.

In [None]:
# Evaluate the distilled model on the test problems
print("=== Distilled Model Evaluation ===\n")

def generate_from_model(model, tokenizer, prompt, max_new=100):
    text = f"Question: {prompt}\n"
    inputs = tokenizer(text, return_tensors="pt").to(device)
    model.eval()
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new,
                           do_sample=True, temperature=0.7,
                           pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens, not the prompt
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=False)

def is_correct(completion, truth):
    """Extract the final numeric answer and compare it with the ground truth."""
    m = re.search(r'[Tt]he answer is[:\s]*(\-?[\d,\.]+)', completion)
    return bool(m) and m.group(1).rstrip('.').replace(',', '') == truth

for prompt, truth in test_problems:
    comp = generate_from_model(small_model, tokenizer, prompt)
    print(f"Q: {prompt} (Truth: {truth})")
    print(f"A: {comp[:200]}")
    print(f"Correct: {'Yes' if is_correct(comp, truth) else 'No'}")
    print("-" * 50)

## 7. Training and Results

The distillation pipeline mirrors a key finding reported for DeepSeek-R1:
- A 32B model distilled from the 671B DeepSeek-R1 outperformed the same 32B base model trained with large-scale RL directly
- For small models, learning from a stronger model's filtered solutions can be more effective than discovering reasoning through RL alone
- Rejection sampling is a simple but powerful way to curate high-quality data

## 8. Final Output

We have built a complete understanding of reasoning model training:

| Component | Purpose |
|-----------|---------|
| Emergent behavior analysis | Detect self-verification, backtracking, depth |
| Rejection sampling | Generate many, keep only correct solutions |
| Distillation via SFT | Transfer reasoning from large to small model |

The complete pipeline: **Large model + RL -> Generate solutions -> Filter correct -> Train small model**

## 9. Reflection and Next Steps

### Think About This
1. Is distillation "cheating"? The small model never did RL itself.
2. What reasoning behaviors might a 671B model develop that a 7B model cannot?
3. Could we iterate: distill, then RL-train the distilled model, then distill again?

### Key Takeaway
Emergent reasoning behaviors (self-verification, backtracking, adaptive depth) arise naturally from RL with verifiable rewards. Distillation is a practical way to make these capabilities accessible in smaller models: RL discovers the reasoning strategies, and SFT on filtered outputs transfers them.