# Weak-to-Strong Generalization with In-Context Learning

This notebook demonstrates weak-to-strong (W2S) generalization using in-context learning on the TruthfulQA dataset.

## Overview

We investigate whether strong models can learn from weak model predictions via few-shot learning, testing:
- **Weak models**: Llama-3.1-8B, Llama-3.1-70B, Qwen-2.5-72B
- **Strong models**: Llama-3.1-8B, Llama-3.1-70B, Qwen-2.5-72B
- **Few-shot examples**: 0, 2, 4, 8, 16, 32, 64

## Experiments

1. **Baseline Evaluation**: Zero-shot performance of each model
2. **Weak-to-Strong Learning**: Prompt strong models with few-shot examples labeled by weak models (in-context, no fine-tuning)
3. **Gold Label Baseline**: Strong models with gold labels (upper bound)
4. **Analysis**: Response matching, PGR, and correlation analysis

In [None]:
# Setup and imports
import sys
from pathlib import Path
import json
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime

# Add code directory to path.
# The project root is taken to be the parent of the current working
# directory, so the notebook has to be launched from a direct child of the
# project root for the paths below to resolve.
project_root = Path.cwd().parent
code_dir = project_root / "code"

# Notebook cells get re-run; only insert the entry when it is not already
# present so sys.path does not accumulate duplicates.
if str(code_dir) not in sys.path:
    sys.path.insert(0, str(code_dir))

print(f"Project root: {project_root}")
print(f"Code directory: {code_dir}")

## 1. Compute Baselines

First, we evaluate each model's zero-shot performance on TruthfulQA.

In [None]:
# Import baseline evaluation code
from baseline_evals import baselines
from dataset_utils import load_truthful_qa
from api_utils import OpenRouterAPI

# Load the TruthfulQA splits with a fixed test size and seed.
print("Loading TruthfulQA dataset...")
test_set, train_set = load_truthful_qa(test_size=200, random_seed=42)
for _split_label, _split in (("Test", test_set), ("Train", train_set)):
    print(f"{_split_label} set size: {len(_split)}")

# Model identifiers evaluated throughout the notebook.
models = [
    "meta-llama/llama-3.1-8b-instruct",
    "meta-llama/llama-3.1-70b-instruct",
    "qwen/qwen-2.5-72b-instruct",
]

# API client; it is handed the path to the key file
# (update the path to wherever your API key lives).
api_key_path = project_root.joinpath("SECRETS", "api.key")
api = OpenRouterAPI(api_key_path=str(api_key_path))

# Directory that receives the baseline outputs; created if missing.
results_dir = project_root.joinpath("results", "weak_labels")
results_dir.mkdir(parents=True, exist_ok=True)

print(f"\nResults will be saved to: {results_dir}")

In [None]:
# Evaluate every model's baseline and collect the results keyed by model
# identifier.
_rule = "=" * 80
print(_rule, "Running baseline evaluations...", _rule, sep="\n")

baseline_results = {}

for model_id in models:
    print(f"\nEvaluating {model_id}...")

    outcome = baselines.run_baseline(
        model=model_id,
        test_set=test_set,
        api=api,
        output_dir=results_dir,
    )

    # Store the result before reporting, so it is kept even if the report
    # below fails on an unexpected result layout.
    baseline_results[model_id] = outcome

    accuracy_pct = outcome["results"]["accuracy"] * 100
    print(f"  Accuracy: {accuracy_pct:.2f}%")
    total_cost = outcome["cost"]["total_cost"]
    print(f"  Cost: ${total_cost:.4f}")

print("", _rule, "Baseline evaluations complete!", _rule, sep="\n")

## 2. Gold Label Sweep

Evaluate models with varying numbers of gold-labeled few-shot examples (upper bound).

In [None]:
# Import gold label sweep code
from sweep_gold_label_counts import sweep_number_of_gold_labels

# Few-shot example counts to sweep: zero-shot plus powers of two up to 64,
# i.e. [0, 2, 4, 8, 16, 32, 64].
num_shots_list = [0] + [2 ** exponent for exponent in range(1, 7)]

_rule = "=" * 80
print(_rule, "Running gold label sweep...", _rule, sep="\n")
print(f"Models: {models}")
print(f"Shot counts: {num_shots_list}")
print()

In [None]:
# Run the gold-label sweep for every model and shot count; outputs go to
# <project_root>/results.
gold_results = sweep_number_of_gold_labels.run_gold_sweep(
    models=models,
    num_shots_list=num_shots_list,
    test_set=test_set,
    train_set=train_set,
    api=api,
    output_dir=project_root.joinpath("results"),
)

print("", "=" * 80, "Gold label sweep complete!", "=" * 80, sep="\n")

## 3. Weak-to-Strong Learning

Prompt strong models in-context with varying numbers of few-shot examples whose labels come from weak models.

In [None]:
# Import weak-to-strong sweep code
from sweep_weak_label_counts import sweep_number_of_weak_labels

# Model pairs as (weak, strong) tuples. A pair with the same model on both
# sides is a self-supervision run.
llama_8b = "meta-llama/llama-3.1-8b-instruct"
llama_70b = "meta-llama/llama-3.1-70b-instruct"
qwen_72b = "qwen/qwen-2.5-72b-instruct"

model_pairs = [
    (llama_8b, llama_8b),    # Self-supervision
    (llama_8b, llama_70b),
    (llama_8b, qwen_72b),
    (llama_70b, llama_70b),  # Self-supervision
    (llama_70b, qwen_72b),
    (qwen_72b, qwen_72b),    # Self-supervision
]

_rule = "=" * 80
print(_rule, "Running weak-to-strong sweep...", _rule, sep="\n")
print(f"Model pairs: {len(model_pairs)}")
print(f"Shot counts: {num_shots_list}")
print()

In [None]:
# Run the weak-to-strong sweep for every (weak, strong) pair and shot count.
# Sweep outputs go to <project_root>/results; `results_dir` is passed as the
# weak-labels directory.
w2s_results = sweep_number_of_weak_labels.run_weak_sweep(
    model_pairs=model_pairs,
    num_shots_list=num_shots_list,
    test_set=test_set,
    train_set=train_set,
    api=api,
    output_dir=project_root.joinpath("results"),
    weak_labels_dir=results_dir,
)

print("", "=" * 80, "Weak-to-strong sweep complete!", "=" * 80, sep="\n")

## 4. Plotting: Accuracy vs. Few-Shot Examples

Visualize W2S performance compared to gold baselines.

In [None]:
# Import plotting code
from sweep_weak_label_counts import plot_weak_label_count_sweep

# Get latest results
def _load_latest_json(candidates, label):
    """Load the most recently modified JSON file among *candidates*.

    Args:
        candidates: List of ``Path`` objects pointing at sweep result files.
        label: Sweep name used in the progress messages ("weak" or "gold").

    Returns:
        A ``(path, data)`` tuple for the newest file, or ``(None, None)``
        when *candidates* is empty.
    """
    if not candidates:
        print(f"No {label} sweep results found")
        return None, None

    latest = max(candidates, key=lambda p: p.stat().st_mtime)
    print(f"Loading {label} sweep results: {latest.name}")
    # JSON is UTF-8 by specification; do not depend on the platform's
    # default text encoding.
    with open(latest, 'r', encoding='utf-8') as f:
        return latest, json.load(f)


results_dir_main = project_root / "results"
weak_sweep_files = list(results_dir_main.glob("weak_sweep_*.json"))
gold_sweep_files = list(results_dir_main.glob("gold_sweep_*.json"))

# `weak_results` / `gold_results_loaded` are None when no file was found;
# the plotting cells below check for that before using them.
latest_weak, weak_results = _load_latest_json(weak_sweep_files, "weak")
latest_gold, gold_results_loaded = _load_latest_json(gold_sweep_files, "gold")

In [None]:
# Accuracy-vs-shots plot for the weak-label sweep, drawn together with the
# loaded gold results. Skipped when no weak sweep results were loaded.
if weak_results:
    fig, ax = plot_weak_label_count_sweep.plot_weak_sweep(weak_results, gold_results_loaded)
    plt.tight_layout()
    plt.show()

    # Keep a timestamped copy on disk as well.
    plots_dir = project_root.joinpath("results", "plots")
    plots_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    accuracy_plot_path = plots_dir / f"weak_sweep_accuracy_{timestamp}.png"
    fig.savefig(accuracy_plot_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved to {plots_dir}")

In [None]:
# Plot self-supervision cases (pairs where the weak and strong model are
# the same). Skipped when no weak sweep results were loaded.
if weak_results:
    fig, ax = plot_weak_label_count_sweep.plot_self_supervision(
        weak_results,
        gold_results_loaded
    )
    plt.tight_layout()
    plt.show()

    # Resolve the plots directory in this cell instead of relying on an
    # earlier cell having defined `plots_dir`, so the cell also works when
    # run on its own.
    plots_dir = project_root / "results" / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fig.savefig(plots_dir / f"weak_sweep_self_supervision_{timestamp}.png", dpi=300, bbox_inches='tight')
    print(f"Plot saved to {plots_dir}")

## 5. Performance Gap Recovered (PGR)

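Calculate and visualize PGR: how much of the performance gap between weak and strong models is recovered by W2S learning. Using the standard definition, $\mathrm{PGR} = \dfrac{\mathrm{acc}_{\text{W2S}} - \mathrm{acc}_{\text{weak}}}{\mathrm{acc}_{\text{ceiling}} - \mathrm{acc}_{\text{weak}}}$, where the ceiling is either the strong model's zero-shot accuracy or its accuracy with gold-labeled few-shot examples (both variants are plotted below).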

In [None]:
# Plot PGR with zero-shot ceiling. Skipped when no weak sweep results were
# loaded.
if weak_results:
    fig, ax = plot_weak_label_count_sweep.plot_pgr(
        weak_results,
        baseline_results_dir=results_dir
    )
    # NOTE(review): the guard suggests plot_pgr can return a falsy figure
    # when it has nothing to plot — confirm against its implementation.
    if fig:
        plt.tight_layout()
        plt.show()

        # Resolve the plots directory in this cell instead of relying on an
        # earlier cell having defined `plots_dir`, so the cell also works
        # when run on its own.
        plots_dir = project_root / "results" / "plots"
        plots_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        fig.savefig(plots_dir / f"pgr_zero_shot_ceiling_{timestamp}.png", dpi=300, bbox_inches='tight')
        print(f"Plot saved to {plots_dir}")

In [None]:
# Plot PGR with gold-label ceiling (more fair comparison). Needs both the
# weak sweep results and the gold sweep results to be loaded.
if weak_results and gold_results_loaded:
    fig, ax = plot_weak_label_count_sweep.plot_pgr_with_gold_ceiling(
        weak_results,
        baseline_results_dir=results_dir,
        gold_results=gold_results_loaded
    )
    # NOTE(review): the guard suggests the plot function can return a falsy
    # figure when it has nothing to plot — confirm against its implementation.
    if fig:
        plt.tight_layout()
        plt.show()

        # Resolve the plots directory in this cell instead of relying on an
        # earlier cell having defined `plots_dir`, so the cell also works
        # when run on its own.
        plots_dir = project_root / "results" / "plots"
        plots_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        fig.savefig(plots_dir / f"pgr_gold_ceiling_{timestamp}.png", dpi=300, bbox_inches='tight')
        print(f"Plot saved to {plots_dir}")

## 6. Response Matching Analysis

Analyze how W2S predictions align with weak and strong model predictions.

### Grid Plots: All W2S Pairs

Generate comprehensive 3×3 grid visualizations showing all model pair combinations.

In [None]:
# Run grid plotting script
print("Generating response matching grid plots...")
print("This will create 3 comprehensive grid figures:")
print("  1. Response matching (all questions)")
print("  2. Response matching (when weak wrong)")
print("  3. W2S correctness conditioned on strong performance")
print()

# Import and run grid plotting. The script lives in a subdirectory of the
# code directory, which has to be importable first.
import sys
plotting_dir = str(code_dir / "plotting_and_analysis")
# Re-running this cell must not keep adding the same entry to sys.path.
if plotting_dir not in sys.path:
    sys.path.insert(0, plotting_dir)
from plot_all_response_matching_grid import main as plot_grid_main

plot_grid_main()

print("\nGrid plots saved to results/plots/response_matching/")

## 7. Correlation Analysis

Analyze correlations between weak, strong, and W2S predictions for all model pairs.

In [None]:
# Import correlation analysis
from analyze_w2s_correlations import analyze_w2s_pair, load_results

# Analyze all W2S pairs
print("="*80)
print("Correlation Analysis for All W2S Pairs")
print("="*80)

# The ground-truth answers are the same for every pair; build the list once
# instead of once per analyzed pair.
ground_truth = [q.answer for q in test_set]

for weak_model, strong_model in model_pairs:
    # Result files are named after the model identifier with '/' replaced
    # by '_'.
    weak_slug = weak_model.replace('/', '_')
    strong_slug = strong_model.replace('/', '_')

    w2s_pattern = f"w2s_{weak_slug}_to_{strong_slug}_*.json"
    w2s_files = list(results_dir.glob(w2s_pattern))

    # Report skipped pairs instead of dropping them silently, so a missing
    # analysis is visible in the output.
    if not w2s_files:
        print(f"\nSkipping {weak_model} → {strong_model}: no files matching {w2s_pattern} in {results_dir}")
        continue

    # Use the most recently modified W2S result for this pair.
    w2s_file = max(w2s_files, key=lambda p: p.stat().st_mtime)

    weak_baseline_file = results_dir / f"{weak_slug}_baseline.json"
    strong_baseline_file = results_dir / f"{strong_slug}_baseline.json"

    # dict.fromkeys de-duplicates while keeping order (self-supervision
    # pairs use the same baseline file twice).
    missing_baselines = [
        path.name
        for path in dict.fromkeys((weak_baseline_file, strong_baseline_file))
        if not path.exists()
    ]
    if missing_baselines:
        print(f"\nSkipping {weak_model} → {strong_model}: missing baseline file(s): {', '.join(missing_baselines)}")
        continue

    print(f"\n{weak_model} → {strong_model}")
    print("-" * 80)

    # Run analysis
    analyze_w2s_pair(
        w2s_file=w2s_file,
        weak_baseline_file=weak_baseline_file,
        strong_baseline_file=strong_baseline_file,
        ground_truth=ground_truth
    )

print("\n" + "="*80)

## 8. Summary Statistics

In [None]:
# Summarize the run: zero-shot baseline accuracy per model, then the
# 64-shot weak-to-strong accuracy per model pair (when sweep results were
# loaded).
_rule = "=" * 80
_thin_rule = "-" * 80

print(_rule, "EXPERIMENT SUMMARY", _rule, sep="\n")

print("\nBaseline Accuracies (Zero-Shot):")
print(_thin_rule)
for model_id, baseline in baseline_results.items():
    short_name = model_id.split('/')[-1]
    accuracy_pct = baseline['results']['accuracy'] * 100
    print(f"  {short_name:30s}: {accuracy_pct:6.2f}%")

if weak_results:
    print("\nWeak-to-Strong Results (64 shots):")
    print(_thin_rule)
    for pair_info in weak_results["model_pairs"].values():
        weak_name = pair_info["weak_model_name"]
        strong_name = pair_info["strong_model_name"]

        # Only the first entry recorded with 64 few-shot examples is
        # reported; pairs without one print nothing.
        entry_64 = next(
            (entry for entry in pair_info["results"] if entry["num_few_shot"] == 64),
            None,
        )
        if entry_64 is not None:
            accuracy_pct = entry_64["accuracy"] * 100
            print(f"  {weak_name:15s} → {strong_name:15s}: {accuracy_pct:6.2f}%")

print("\n" + _rule)
print("All results saved to:", project_root / "results")
print("All plots saved to:", project_root / "results" / "plots")
print(_rule)

## Conclusion

This notebook demonstrates:

1. **Baseline Performance**: Zero-shot capabilities of each model
2. **Gold Label Upper Bound**: Performance with ground truth few-shot examples
3. **Weak-to-Strong Learning**: Strong models can learn from weak labels via ICL
4. **Performance Gap Recovered**: Quantifies how much of the weak→strong gap is closed
5. **Response Matching**: Shows when W2S aligns with weak vs. strong predictions
6. **Conditional Analysis**: W2S ability to fix strong errors and preserve strong successes

Key findings can be extracted from the plots and correlation analyses above.