# GRPO Evaluation - Modular Version

**Purpose**: Evaluate trained GRPO checkpoints with proper optimization direction handling.

**Key Features**:
- ✅ **Checkpoint-first approach** - load any checkpoint to evaluate
- ✅ **Auto-detect optimization** - reads direction from checkpoint metadata
- ✅ **Independent cells** - no need to run training first
- ✅ **Correct metrics** - handles both minimization and maximization
- ✅ **Multiple modes** - single checkpoint, compare checkpoints, compare objectives

**Workflow**:
1. Select evaluation mode and checkpoints
2. Load checkpoint(s) and validate metadata
3. Generate or load test SCMs
4. Run evaluation with baselines
5. Generate visualizations with correct labels
6. Export results for analysis
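Direction handling comes down to one rule: whether "better" means lower or higher depends on the objective. The sketch below assumes the direction is stored as the string `'MINIMIZE'` or `'MAXIMIZE'`, as in the checkpoint metadata. `DirectionConfig` is a hypothetical stand-in, not the `OptimizationConfig` class from `base_components`:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DirectionConfig:
    """Hypothetical stand-in for OptimizationConfig: only the direction matters here."""
    direction: str  # 'MINIMIZE' or 'MAXIMIZE'

    def __post_init__(self):
        if self.direction not in ("MINIMIZE", "MAXIMIZE"):
            raise ValueError(f"Unknown optimization direction: {self.direction}")

    @property
    def is_minimizing(self) -> bool:
        return self.direction == "MINIMIZE"

    def is_better(self, a: float, b: float) -> bool:
        """Return True if target value `a` beats target value `b` under this direction."""
        return a < b if self.is_minimizing else a > b

    def best(self, values):
        """Pick the best value: min when minimizing, max when maximizing."""
        return min(values) if self.is_minimizing else max(values)


minimize = DirectionConfig("MINIMIZE")
maximize = DirectionConfig("MAXIMIZE")
print(minimize.is_better(-2.0, 1.0))    # True: the lower target wins
print(maximize.best([0.3, 1.5, -0.7]))  # 1.5
```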

## 1. Setup and Configuration

In [1]:
#!/usr/bin/env python3
"""
Cell 1: Import base components and configure environment

This cell sets up the evaluation environment.
"""

import sys
import os
from pathlib import Path
import logging
import json
import time
import subprocess
import shutil
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass

# Add project root to path
project_root = Path.cwd().parent if Path.cwd().name == "experiments" else Path.cwd()
sys.path.insert(0, str(project_root))

# Import base components
from scripts.notebooks.base_components import (
    NotebookError, CheckpointManager, SCMGenerator,
    OptimizationConfig, CheckpointMetadata, validate_environment,
    format_results_summary
)
from scripts.notebooks.config_templates import create_evaluation_config

# Core imports
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as random
import pyrsistent as pyr

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import Image, display
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
)
logger = logging.getLogger(__name__)

# Validate environment
try:
    env_info = validate_environment()
    print("✅ Environment Setup Complete")
    print(f"📁 Project root: {project_root}")
    print(f"🔧 JAX devices: {env_info['jax_devices']}")
    print(f"📅 Date: {env_info['timestamp']}")
except Exception as e:
    raise NotebookError(f"Environment validation failed: {e}")

# Initialize checkpoint manager
checkpoint_dir = project_root / "checkpoints" / "grpo_training"
checkpoint_manager = CheckpointManager(checkpoint_dir)
print(f"\n📁 Checkpoint directory: {checkpoint_dir}")



✅ Environment Setup Complete
📁 Project root: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt
🔧 JAX devices: [CpuDevice(id=0)]
📅 Date: 2025-07-23 20:36:17

📁 Checkpoint directory: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training
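The output above reports a CPU-only JAX device list. If the environment check fails instead, a quick way to narrow down the cause is to test which of the notebook's packages can be imported. This sketch assumes the package list mirrors this notebook's imports. `check_packages` is hypothetical and does not replicate `validate_environment()`, which also reports JAX devices:

```python
import importlib.util
from datetime import datetime


def check_packages(required):
    """Return a summary dict listing which required packages are importable.

    Hypothetical helper: the real validate_environment() in base_components
    also reports JAX devices, which this sketch does not query.
    """
    missing = [name for name in required if importlib.util.find_spec(name) is None]
    return {
        "missing": missing,
        "ok": not missing,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }


info = check_packages(["json", "numpy", "jax", "pyrsistent", "matplotlib", "seaborn"])
if not info["ok"]:
    print(f"Missing packages: {info['missing']}")
```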


In [2]:
"""
Cell 2: Configure evaluation parameters

Set evaluation mode and parameters for the run.
"""

# EVALUATION MODE SELECTION
EVALUATION_MODE = "SINGLE_CHECKPOINT"  # Options: "SINGLE_CHECKPOINT", "COMPARE_CHECKPOINTS", "COMPARE_OBJECTIVES"

# Configuration for different modes
if EVALUATION_MODE == "SINGLE_CHECKPOINT":
    # Evaluate one checkpoint against baselines
    NUM_TEST_SCMS = 10
    RUNS_PER_METHOD = 3
    INTERVENTION_BUDGET = 10
    
elif EVALUATION_MODE == "COMPARE_CHECKPOINTS":
    # Compare multiple checkpoints
    COMPARISON_COUNT = 3  # Number of checkpoints to compare
    NUM_TEST_SCMS = 5  # Fewer SCMs for faster comparison
    RUNS_PER_METHOD = 2
    INTERVENTION_BUDGET = 8
    
elif EVALUATION_MODE == "COMPARE_OBJECTIVES":
    # Compare minimization vs maximization
    NUM_TEST_SCMS = 8
    RUNS_PER_METHOD = 3
    INTERVENTION_BUDGET = 10

else:
    raise NotebookError(f"Unknown evaluation mode: {EVALUATION_MODE}")

# General settings
RANDOM_SEED = 42

print("🎯 Evaluation Configuration")
print("=" * 50)
print(f"Mode: {EVALUATION_MODE}")
print(f"Test SCMs: {NUM_TEST_SCMS}")
print(f"Runs per method: {RUNS_PER_METHOD}")
print(f"Intervention budget: {INTERVENTION_BUDGET}")
print(f"Random seed: {RANDOM_SEED}")

if EVALUATION_MODE == "COMPARE_CHECKPOINTS":
    print(f"Checkpoints to compare: {COMPARISON_COUNT}")


🎯 Evaluation Configuration
Mode: SINGLE_CHECKPOINT
Test SCMs: 10
Runs per method: 3
Intervention budget: 10
Random seed: 42
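The if/elif chain in Cell 2 can also be written as a lookup table. This keeps every mode's settings in one place and makes an unknown mode fail immediately. The sketch below uses the same values as Cell 2. `ModeConfig` and `get_mode_config` are illustrative names, and `ValueError` stands in for `NotebookError`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ModeConfig:
    num_test_scms: int
    runs_per_method: int
    intervention_budget: int
    comparison_count: Optional[int] = None  # only meaningful for COMPARE_CHECKPOINTS


# Same values as the if/elif chain in Cell 2, keyed by mode name.
MODE_CONFIGS = {
    "SINGLE_CHECKPOINT": ModeConfig(10, 3, 10),
    "COMPARE_CHECKPOINTS": ModeConfig(5, 2, 8, comparison_count=3),
    "COMPARE_OBJECTIVES": ModeConfig(8, 3, 10),
}


def get_mode_config(mode: str) -> ModeConfig:
    """Look up a mode's settings, failing loudly on an unknown mode."""
    try:
        return MODE_CONFIGS[mode]
    except KeyError:
        raise ValueError(f"Unknown evaluation mode: {mode}") from None


cfg = get_mode_config("SINGLE_CHECKPOINT")
print(cfg.num_test_scms, cfg.runs_per_method, cfg.intervention_budget)  # 10 3 10
```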


## 2. Evaluation Mode and Checkpoint Selection

In [3]:
"""
Cell 3: Select checkpoints for evaluation

This cell handles checkpoint selection based on the evaluation mode.
Uses intelligent discovery instead of hardcoded names.
"""

print("📋 Intelligent Checkpoint Selection")
print("=" * 50)

# Get available checkpoints
try:
    available_checkpoints = checkpoint_manager.list_checkpoints()
    if not available_checkpoints:
        raise NotebookError("No checkpoints found")
        
    print(f"Found {len(available_checkpoints)} total checkpoints")
    
    # Show usable checkpoints by direction
    usable_minimize = checkpoint_manager.find_usable_checkpoints('MINIMIZE')
    usable_maximize = checkpoint_manager.find_usable_checkpoints('MAXIMIZE')
    
    print(f"Usable checkpoints:")
    print(f"  MINIMIZE: {len(usable_minimize)}")
    print(f"  MAXIMIZE: {len(usable_maximize)}")
    print(f"  Total usable: {len(usable_minimize) + len(usable_maximize)}")
        
except Exception as e:
    raise NotebookError(f"Failed to analyze checkpoints: {e}")

# SELECT CHECKPOINTS BASED ON MODE
selected_checkpoints = []

if EVALUATION_MODE == "SINGLE_CHECKPOINT":
    # Find best MINIMIZE checkpoint (preferred for comparison with PARENT_SCALE)
    best_checkpoint = checkpoint_manager.find_best_checkpoint({
        'optimization_direction': 'MINIMIZE',
        'training_mode': 'QUICK'
    })
    
    if best_checkpoint:
        selected_checkpoints = [best_checkpoint]
        print(f"🎯 Selected best MINIMIZE checkpoint: {best_checkpoint.name}")
        
        # Validate the selected checkpoint
        validation = checkpoint_manager.validate_checkpoint(best_checkpoint)
        if validation['is_valid']:
            print(f"  ✅ Checkpoint is valid and ready for evaluation")
        else:
            print(f"  ⚠️ Checkpoint issues: {validation['issues']}")
    else:
        # Fallback: try any usable checkpoint
        all_usable = checkpoint_manager.find_usable_checkpoints()
        if all_usable:
            selected_checkpoints = [all_usable[0]]
            print(f"🎯 No MINIMIZE checkpoint found, using: {all_usable[0].name}")
            print(f"  Optimization: {all_usable[0].optimization_config.direction}")
        else:
            raise NotebookError("No usable checkpoints found. Please ensure checkpoints have both metadata.json and checkpoint.pkl files.")

elif EVALUATION_MODE == "COMPARE_CHECKPOINTS":
    # Get multiple usable checkpoints
    all_usable = checkpoint_manager.find_usable_checkpoints()
    comparison_count = min(COMPARISON_COUNT, len(all_usable))
    selected_checkpoints = all_usable[:comparison_count]
    print(f"📊 Selected {comparison_count} checkpoints for comparison")

elif EVALUATION_MODE == "COMPARE_OBJECTIVES":
    # Get best from each optimization direction
    best_minimize = checkpoint_manager.find_best_checkpoint({'optimization_direction': 'MINIMIZE'})
    best_maximize = checkpoint_manager.find_best_checkpoint({'optimization_direction': 'MAXIMIZE'})
    
    selected_checkpoints = []
    if best_minimize:
        selected_checkpoints.append(best_minimize)
    if best_maximize:
        selected_checkpoints.append(best_maximize)
    
    if not (best_minimize and best_maximize):
        raise NotebookError("Need checkpoints from both MINIMIZE and MAXIMIZE directions for objective comparison")
    
    print(f"🔄 Selected checkpoints for objective comparison:")
    for ckpt in selected_checkpoints:
        print(f"  - {ckpt.name} ({ckpt.optimization_config.direction})")

else:
    raise NotebookError(f"Unknown evaluation mode: {EVALUATION_MODE}")

# Final validation
if not selected_checkpoints:
    raise NotebookError("No checkpoints selected for evaluation")

print(f"✅ Final Selection ({len(selected_checkpoints)} checkpoint(s)):")
for i, checkpoint in enumerate(selected_checkpoints, 1):
    print(f"{i}. {checkpoint.name}")
    print(f"     Optimization: {checkpoint.optimization_config.direction}")
    print(f"     Training mode: {checkpoint.training_config.get('mode', 'unknown')}")
    print(f"     Path: {checkpoint.path}")
    
    # Final validation
    validation = checkpoint_manager.validate_checkpoint(checkpoint)
    if validation['is_valid']:
        print(f"     Status: ✅ Ready for evaluation")
    else:
        print(f"     Status: ❌ Issues found: {validation['issues']}")
        raise NotebookError(f"Selected checkpoint {checkpoint.name} has validation issues: {validation['issues']}")

print("🚀 All selected checkpoints validated and ready for evaluation!")



📋 Intelligent Checkpoint Selection
Found 4 total checkpoints
Usable checkpoints:
  MINIMIZE: 1
  MAXIMIZE: 2
  Total usable: 3
🎯 Selected best MINIMIZE checkpoint: grpo_quick_minimize_20250723_101252_fixed
  ✅ Checkpoint is valid and ready for evaluation
✅ Final Selection (1 checkpoint(s)):
1. grpo_quick_minimize_20250723_101252_fixed
     Optimization: MINIMIZE
     Training mode: QUICK
     Path: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/grpo_quick_minimize_20250723_101252_fixed
     Status: ✅ Ready for evaluation
🚀 All selected checkpoints validated and ready for evaluation\!
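The selection logic in this cell first filters to usable checkpoints and then matches on metadata. It can be sketched in plain Python under two assumptions. First, "usable" means both `metadata.json` and `checkpoint.pkl` exist, as the error message in the cell states. Second, "best" means most recent. The real `CheckpointManager.find_best_checkpoint` may rank checkpoints differently. `CheckpointInfo`, `find_usable`, and `find_best` are hypothetical:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class CheckpointInfo:
    """Hypothetical, simplified view of a checkpoint's metadata."""
    name: str
    optimization_direction: str
    training_mode: str
    created_at: str          # ISO-8601 timestamp, so string order == time order
    has_metadata: bool = True
    has_weights: bool = True


def find_usable(checkpoints: List[CheckpointInfo],
                direction: Optional[str] = None) -> List[CheckpointInfo]:
    """Keep checkpoints that have both files and, optionally, a matching direction."""
    return [
        c for c in checkpoints
        if c.has_metadata and c.has_weights
        and (direction is None or c.optimization_direction == direction)
    ]


def find_best(checkpoints: List[CheckpointInfo],
              criteria: Dict[str, str]) -> Optional[CheckpointInfo]:
    """Return the newest usable checkpoint whose attributes match every criterion.

    'Best' here means 'most recent', an assumption made for illustration only.
    """
    matches = [
        c for c in find_usable(checkpoints)
        if all(getattr(c, key) == value for key, value in criteria.items())
    ]
    return max(matches, key=lambda c: c.created_at, default=None)


ckpts = [
    CheckpointInfo("a", "MINIMIZE", "QUICK", "2025-07-23T10:12:52"),
    CheckpointInfo("b", "MAXIMIZE", "QUICK", "2025-07-22T09:00:00"),
    CheckpointInfo("c", "MINIMIZE", "QUICK", "2025-07-24T08:00:00", has_weights=False),
]
best = find_best(ckpts, {"optimization_direction": "MINIMIZE", "training_mode": "QUICK"})
print(best.name)  # a  (c is newer but has no weights file)
```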


## 3. Load and Validate Checkpoints

In [4]:
"""
Cell 4: Load checkpoint metadata and validate

This cell loads the selected checkpoints and validates their metadata.
"""

print("📥 Loading Checkpoint Metadata")
print("=" * 60)

# Store loaded checkpoint info
loaded_checkpoints = {}

for ckpt in selected_checkpoints:
    print(f"\nLoading: {ckpt.name}")
    try:
        # For now, we're using the metadata we already have
        # In production, this would load the actual model parameters
        loaded_checkpoints[ckpt.name] = {
            'metadata': ckpt,
            'optimization_config': ckpt.optimization_config,
            'training_config': ckpt.training_config,
            'model_params': None  # TODO: Load actual model parameters
        }
        
        print(f"  ✓ Optimization: {ckpt.optimization_config.direction}")
        print(f"  ✓ Training mode: {ckpt.training_config.get('mode', 'unknown')}")
        print(f"  ✓ Episodes completed: {ckpt.training_results.get('episodes_completed', 'unknown')}")
        print(f"  ✓ Duration: {ckpt.training_results.get('duration_minutes', 0):.1f} minutes")
        
        # Show reward weights if available
        if 'reward_weights' in ckpt.training_config:
            weights = ckpt.training_config['reward_weights']
            print(f"  ✓ Reward weights: opt={weights.get('optimization', 0):.1f}, "
                  f"struct={weights.get('discovery', 0):.1f}, "
                  f"eff={weights.get('efficiency', 0):.1f}")
                  
    except Exception as e:
        print(f"  ✗ Failed to load: {e}")
        raise NotebookError(f"Failed to load checkpoint {ckpt.name}: {e}")

print(f"\n✅ Loaded {len(loaded_checkpoints)} checkpoint(s) successfully")

# Check optimization compatibility for comparison modes
if EVALUATION_MODE == "COMPARE_OBJECTIVES":
    directions = [ckpt.optimization_config.direction for ckpt in selected_checkpoints]
    if len(set(directions)) == 1:
        print(f"\n⚠️ Warning: All checkpoints have same optimization direction: {directions[0]}")
        print("   Objective comparison may not be meaningful.")
    else:
        print(f"\n✅ Comparing optimization directions: {set(directions)}")

📥 Loading Checkpoint Metadata

Loading: grpo_quick_minimize_20250723_101252_fixed
  ✓ Optimization: MINIMIZE
  ✓ Training mode: QUICK
  ✓ Episodes completed: unknown
  ✓ Duration: 5.0 minutes
  ✓ Reward weights: opt=0.8, struct=0.1, eff=0.1

✅ Loaded 1 checkpoint(s) successfully
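The cell above leaves `model_params` as `None`, as noted in its TODO. Filling it in means reading the two files that every usable checkpoint is expected to contain. This sketch assumes the file names from the earlier error message (`metadata.json`, `checkpoint.pkl`) and that the pickle stores the parameters directly. `load_checkpoint_dir` is hypothetical:

```python
import json
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict


def load_checkpoint_dir(ckpt_dir: Path) -> Dict[str, Any]:
    """Load metadata.json and checkpoint.pkl from one checkpoint directory.

    Assumes the pickle holds the model parameters directly (hypothetical layout).
    Only unpickle files you trust: pickle can execute arbitrary code.
    """
    ckpt_dir = Path(ckpt_dir)
    metadata_file = ckpt_dir / "metadata.json"
    params_file = ckpt_dir / "checkpoint.pkl"
    for required in (metadata_file, params_file):
        if not required.exists():
            raise FileNotFoundError(f"Missing checkpoint file: {required}")

    with open(metadata_file) as f:
        metadata = json.load(f)
    with open(params_file, "rb") as f:
        model_params = pickle.load(f)
    return {"metadata": metadata, "model_params": model_params}


# Self-contained demo using a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    d = Path(tmp)
    (d / "metadata.json").write_text(json.dumps({"optimization_direction": "MINIMIZE"}))
    with open(d / "checkpoint.pkl", "wb") as f:
        pickle.dump({"layer_0": [0.1, 0.2]}, f)
    loaded = load_checkpoint_dir(d)
    print(loaded["metadata"]["optimization_direction"])  # MINIMIZE
```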


## 4. Generate Test SCMs

In [5]:
"""
Cell 5: Generate test SCMs for evaluation

Create a balanced set of test SCMs different from training.
"""

print("🔬 Generating Test SCMs")
print("=" * 60)

# Initialize SCM generator
scm_generator = SCMGenerator()

# Generate test SCMs with different seed than training
test_seed = RANDOM_SEED + 1000  # Offset the seed so test SCMs use a different random stream than training

try:
    test_scms, test_metadata = scm_generator.generate_balanced_scms(
        num_scms=NUM_TEST_SCMS,
        variable_range=(3, 6),
        structure_types=['fork', 'chain', 'collider', 'mixed'],
        seed=test_seed
    )
    
    print(f"\n✅ Generated {len(test_scms)} test SCMs")
    
    # Analyze distribution
    distribution = scm_generator._summarize_distribution(test_metadata)
    print(f"\n📊 Test Set Distribution:")
    print(f"  Structure types: {distribution['structure_types']}")
    print(f"  Variable counts: {distribution['variable_counts']}")
    
    # Save test SCM metadata
    test_scm_path = project_root / "results" / "test_scms" / f"test_scms_{test_seed}.json"
    test_scm_path.parent.mkdir(parents=True, exist_ok=True)
    
    with open(test_scm_path, 'w') as f:
        json.dump({
            'metadata': test_metadata,
            'config': {
                'num_scms': len(test_scms),
                'seed': test_seed,
                'variable_range': [3, 6],
                'structure_types': ['fork', 'chain', 'collider', 'mixed']
            },
            'generated_at': datetime.now().isoformat()
        }, f, indent=2)
    
    print(f"\n💾 Saved test SCM metadata to: {test_scm_path}")
    
except Exception as e:
    raise NotebookError(f"Failed to generate test SCMs: {e}")

🔬 Generating Test SCMs


[2025-07-23 20:36:18,135][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 3 variables, 2 edges, target='X1'
[2025-07-23 20:36:18,135][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated fork SCM: 3 vars, 2 edges, target=X1
[2025-07-23 20:36:18,150][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 4 variables, 3 edges, target='X2'
[2025-07-23 20:36:18,151][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated fork SCM: 4 vars, 3 edges, target=X2
[2025-07-23 20:36:18,169][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 5 variables, 4 edges, target='X2'
[2025-07-23 20:36:18,170][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated fork SCM: 5 vars, 4 edges, target=X2
[2025-07-23 20:36:18,175][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 6 variables, 5 edges, target='X3'
[2025-07-23 20:36:18,175][causal_bayes_opt.experiments.variable_scm_factory]


✅ Generated 10 test SCMs

📊 Test Set Distribution:
  Structure types: {'fork': 4, 'chain': 4, 'collider': 2}
  Variable counts: {3: 3, 4: 3, 5: 2, 6: 2}

💾 Saved test SCM metadata to: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/results/test_scms/test_scms_1042.json
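The distribution above contains no `mixed` SCMs, even though `'mixed'` was among the requested structure types. A small summary-and-check step makes such gaps visible. This sketch assumes each metadata entry has `structure_type` and `num_variables` keys. These field names are hypothetical, and the real `SCMGenerator` metadata may differ:

```python
from collections import Counter
from typing import Dict, List


def summarize_distribution(metadata: List[Dict]) -> Dict[str, Dict]:
    """Count SCMs per structure type and per variable count.

    Assumes each entry has 'structure_type' and 'num_variables' keys
    (hypothetical field names).
    """
    return {
        "structure_types": dict(Counter(m["structure_type"] for m in metadata)),
        "variable_counts": dict(sorted(Counter(m["num_variables"] for m in metadata).items())),
    }


def missing_types(summary: Dict[str, Dict], requested: List[str]) -> List[str]:
    """List requested structure types that produced no SCMs."""
    return [t for t in requested if t not in summary["structure_types"]]


meta = [
    {"structure_type": "fork", "num_variables": 3},
    {"structure_type": "chain", "num_variables": 4},
    {"structure_type": "fork", "num_variables": 3},
]
summary = summarize_distribution(meta)
print(summary)
print(missing_types(summary, ["fork", "chain", "collider", "mixed"]))  # ['collider', 'mixed']
```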


## 5. Run Evaluation

In [6]:
"""
Cell 6: Run evaluation with proper optimization handling

Evaluate checkpoints against baselines with correct metrics.
"""

print("🏁 Running Evaluation")
print("=" * 60)

# Create output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = project_root / "results" / f"evaluation_{EVALUATION_MODE.lower()}_{timestamp}"
output_dir.mkdir(parents=True, exist_ok=True)

print(f"📁 Output directory: {output_dir}")
print(f"\nEvaluation parameters:")
print(f"  Mode: {EVALUATION_MODE}")
print(f"  Test SCMs: {NUM_TEST_SCMS}")
print(f"  Runs per method: {RUNS_PER_METHOD}")
print(f"  Intervention budget: {INTERVENTION_BUDGET}")

# Store results
evaluation_results = {}
evaluation_start_time = time.time()

# Run evaluation for each checkpoint
for ckpt in selected_checkpoints:
    print(f"\n{'='*60}")
    print(f"Evaluating: {ckpt.name}")
    print(f"Optimization: {ckpt.optimization_config.direction}")
    
    # Create checkpoint-specific output directory
    ckpt_output_dir = output_dir / ckpt.name
    ckpt_output_dir.mkdir(exist_ok=True)
    
    # Run evaluation using unified pipeline
    cmd = [
        "poetry", "run", "python",
        str(project_root / "scripts" / "unified_pipeline.py"),
        f"--checkpoint={ckpt.path}",
        f"--num-scms={min(3, NUM_TEST_SCMS)}",  # Caps evaluation at 3 SCMs for speed, regardless of NUM_TEST_SCMS
        f"--runs-per-method={RUNS_PER_METHOD}",
        f"--intervention-budget={INTERVENTION_BUDGET}",
        f"--output-dir={ckpt_output_dir}",
        f"--optimization-direction={ckpt.optimization_config.direction}"
    ]
    
    print(f"\nRunning evaluation command...")
    print(f"  (This may take a few minutes)")
    
    start_time = time.time()
    
    try:
        # Run evaluation
        result = subprocess.run(cmd, capture_output=True, text=True)
        
        if result.returncode == 0:
            duration = (time.time() - start_time) / 60
            print(f"\n✅ Evaluation completed in {duration:.1f} minutes")
            
            # Load results
            results_file = ckpt_output_dir / "comparison_results.json"
            if not results_file.exists():
                # Try to find results in alternative locations
                alt_results = list(ckpt_output_dir.glob("*results*.json"))
                if alt_results:
                    results_file = alt_results[0]
            
            if results_file.exists():
                with open(results_file, 'r') as f:
                    results = json.load(f)
                
                # Store results with optimization info
                evaluation_results[ckpt.name] = {
                    'results': results,
                    'optimization_direction': ckpt.optimization_config.direction,
                    'checkpoint_metadata': ckpt.to_dict(),
                    'duration_minutes': duration
                }
                
                # Quick summary
                if 'statistical_analysis' in results:
                    print("\n📊 Quick Summary:")
                    summary = results['statistical_analysis'].get('summary_statistics', {})
                    for method, stats in list(summary.items())[:3]:  # Show the first 3 methods
                        mean_val = stats.get('target_improvement_mean', 0)
                        # Convert if needed
                        if ckpt.optimization_config.is_minimizing:
                            display_val = -mean_val  # Show actual minimized value
                        else:
                            display_val = mean_val
                        print(f"  {method}: {ckpt.optimization_config.format_improvement(display_val)}")
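            else:
                print(f"⚠️ No results file found at {results_file}")
                # Show the tail of the pipeline's stdout to help diagnose a silent run
                print(f"Pipeline stdout (last 500 chars): {result.stdout[-500:]}")
                evaluation_results[ckpt.name] = {'error': 'No results file found', 'stdout': result.stdout}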
                
        else:
            print(f"\n❌ Evaluation failed with return code {result.returncode}")
            # Tracebacks end with the most useful lines, so show the tail of stderr
            print(f"Error (last 500 chars): ...{result.stderr[-500:]}")
            evaluation_results[ckpt.name] = {'error': result.stderr}
            
    except Exception as e:
        print(f"\n❌ Evaluation failed with exception: {e}")
        evaluation_results[ckpt.name] = {'error': str(e)}

total_duration = (time.time() - evaluation_start_time) / 60

print(f"\n{'='*60}")
print(f"✅ All evaluations complete!")
print(f"  Total duration: {total_duration:.1f} minutes")
print(f"  Successful: {sum(1 for r in evaluation_results.values() if 'error' not in r)}/{len(evaluation_results)}")

🏁 Running Evaluation
📁 Output directory: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/results/evaluation_single_checkpoint_20250723_203618

Evaluation parameters:
  Mode: SINGLE_CHECKPOINT
  Test SCMs: 10
  Runs per method: 3
  Intervention budget: 10

Evaluating: grpo_quick_minimize_20250723_101252_fixed
Optimization: MINIMIZE

Running evaluation command...
  (This may take a few minutes)

✅ Evaluation completed in 0.0 minutes
⚠️ No results file found at /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/results/evaluation_single_checkpoint_20250723_203618/grpo_quick_minimize_20250723_101252_fixed/comparison_results.json

✅ All evaluations complete!
  Total duration: 0.0 minutes
  Successful: 0/1
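In the run above, the pipeline exited with code 0 in under a minute but wrote no results file, and the cell printed nothing from the subprocess. To diagnose this, it helps to show the tail of stdout/stderr and to enforce a timeout. This sketch uses only `subprocess.run`. `run_and_report` is hypothetical, and the 3600-second timeout is an arbitrary placeholder:

```python
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List


def run_and_report(cmd: List[str], results_dir: Path, timeout_s: float = 3600,
                   tail_chars: int = 500) -> Dict[str, Any]:
    """Run an evaluation command and report whether a results file appeared."""
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        return {"returncode": None, "error": f"timed out after {timeout_s}s",
                "results_file": None, "stdout_tail": "", "stderr_tail": ""}

    report: Dict[str, Any] = {
        "returncode": proc.returncode,
        # Tracebacks end with the useful part, so keep the tail of each stream.
        "stdout_tail": proc.stdout[-tail_chars:],
        "stderr_tail": proc.stderr[-tail_chars:],
        "results_file": None,
        "error": None,
    }
    candidates = sorted(Path(results_dir).glob("*results*.json"))
    if proc.returncode != 0:
        report["error"] = f"exit code {proc.returncode}"
    elif not candidates:
        report["error"] = "exit code 0 but no *results*.json written"
    else:
        report["results_file"] = str(candidates[0])
    return report


# Demo: a command that succeeds but writes nothing, reproducing the symptom above.
with tempfile.TemporaryDirectory() as tmp:
    r = run_and_report([sys.executable, "-c", "print('pipeline finished')"], Path(tmp))
    print(r["error"])                # exit code 0 but no *results*.json written
    print(r["stdout_tail"].strip())  # pipeline finished
```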


## 6. Generate Visualizations

In [7]:
"""
Cell 7: Generate visualizations with proper optimization labels

Create plots that correctly show minimization vs maximization.
"""

print("📊 Generating Visualizations")
print("=" * 60)

# Check if we have results to visualize
valid_results = {k: v for k, v in evaluation_results.items() if 'error' not in v}

if not valid_results:
    print("❌ No valid results to visualize")
else:
    # Extract method performance
    method_performance = {}
    
    for ckpt_name, eval_data in valid_results.items():
        results = eval_data['results']
        opt_direction = eval_data['optimization_direction']
        opt_config = OptimizationConfig(direction=opt_direction)
        
        if 'statistical_analysis' in results and 'summary_statistics' in results['statistical_analysis']:
            for method, stats in results['statistical_analysis']['summary_statistics'].items():
                key = f"{ckpt_name}_{method}"
                
                # Get raw value
                raw_value = stats.get('target_improvement_mean', 0.0)
                
                # Convert to actual value for minimization
                if opt_config.is_minimizing:
                    actual_value = -raw_value
                else:
                    actual_value = raw_value
                
                method_performance[key] = {
                    'checkpoint': ckpt_name,
                    'method': method,
                    'raw_value': raw_value,
                    'actual_value': actual_value,
                    'optimization_direction': opt_direction,
                    'display_value': opt_config.format_improvement(actual_value)
                }
    
    # Create visualizations based on mode
    if EVALUATION_MODE == "SINGLE_CHECKPOINT":
        # Single checkpoint visualization
        ckpt_name = selected_checkpoints[0].name
        opt_config = selected_checkpoints[0].optimization_config
        
        # Get methods for this checkpoint
        checkpoint_methods = {k: v for k, v in method_performance.items() if v['checkpoint'] == ckpt_name}
        
        if checkpoint_methods:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
            
            # Method performance bars
            methods = [v['method'] for v in checkpoint_methods.values()]
            values = [v['actual_value'] for v in checkpoint_methods.values()]
            
            # Color based on whether it's the trained policy
            colors = ['red' if 'Policy' in m or 'Trained' in m else 'lightblue' for m in methods]
            
            bars = ax1.bar(range(len(methods)), values, color=colors, alpha=0.7, edgecolor='black')
            ax1.set_xticks(range(len(methods)))
            ax1.set_xticklabels([m.replace(' + ', '\n') for m in methods], rotation=45, ha='right')
            
            # Set appropriate y-label based on optimization direction
            if opt_config.is_minimizing:
                ax1.set_ylabel('Target Value (Lower is Better)')
                ax1.invert_yaxis()  # Invert so lower values appear higher
            else:
                ax1.set_ylabel('Target Value (Higher is Better)')
            
            ax1.set_title(f'Method Performance - {ckpt_name}')
            ax1.grid(True, alpha=0.3, axis='y')
            
            # Add value labels
            for bar, val, method_data in zip(bars, values, checkpoint_methods.values()):
                ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                        method_data['display_value'], ha='center', va='bottom', fontsize=8)
            
            # Rankings (best to worst)
            if opt_config.is_minimizing:
                sorted_methods = sorted(zip(methods, values), key=lambda x: x[1])  # Lower is better
            else:
                sorted_methods = sorted(zip(methods, values), key=lambda x: x[1], reverse=True)  # Higher is better
            
            y_pos = range(len(sorted_methods))
            
            for i, (method, score) in enumerate(sorted_methods):
                color = 'red' if 'Policy' in method or 'Trained' in method else 'lightblue'
                ax2.barh(i, abs(score), color=color, alpha=0.7, edgecolor='black')
                display_text = opt_config.format_improvement(score)
                ax2.text(abs(score) + 0.02, i, display_text, va='center', fontsize=8)
            
            ax2.set_yticks(y_pos)
            ax2.set_yticklabels([m[0] for m in sorted_methods])
            ax2.invert_yaxis()  # barh draws index 0 at the bottom; invert so the best method is on top
            ax2.set_xlabel('Target Value (Absolute)')
            ax2.set_title('Method Rankings (Best to Worst)')
            ax2.grid(True, alpha=0.3, axis='x')
            
            plt.suptitle(f'Evaluation Results - {ckpt_name} ({opt_config.direction})', fontsize=14)
            plt.tight_layout()
            
            plot_path = output_dir / "single_checkpoint_results.png"
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            plt.show()
            
            print(f"\n💾 Saved plot: {plot_path}")
            
    elif EVALUATION_MODE in ["COMPARE_CHECKPOINTS", "COMPARE_OBJECTIVES"]:
        # Multi-checkpoint comparison
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        axes = axes.flatten()
        
        # Group by checkpoint
        checkpoints_data = {}
        for key, perf in method_performance.items():
            ckpt = perf['checkpoint']
            method = perf['method']
            if ckpt not in checkpoints_data:
                checkpoints_data[ckpt] = {}
            checkpoints_data[ckpt][method] = perf
        
        # 1. Side-by-side comparison
        ax1 = axes[0]
        all_methods = set()
        for methods in checkpoints_data.values():
            all_methods.update(methods.keys())
        all_methods = sorted(list(all_methods))
        
        x = np.arange(len(all_methods))
        width = 0.8 / len(checkpoints_data)
        
        for i, (ckpt_name, methods_data) in enumerate(checkpoints_data.items()):
            # NaN leaves a gap for methods a checkpoint lacks, instead of a misleading 0 bar
            values = [methods_data.get(m, {}).get('actual_value', np.nan) for m in all_methods]
            offset = (i - len(checkpoints_data)/2 + 0.5) * width
            
            # Get checkpoint info
            ckpt_info = next((c for c in selected_checkpoints if c.name == ckpt_name), None)
            if ckpt_info:
                label = f"{ckpt_info.optimization_config.direction}"
            else:
                label = ckpt_name
            
            bars = ax1.bar(x + offset, values, width, label=label, alpha=0.7)
        
        ax1.set_xlabel('Method')
        ax1.set_ylabel('Target Value')
        ax1.set_title('Method Performance Across Checkpoints')
        ax1.set_xticks(x)
        ax1.set_xticklabels([m.replace(' + ', '\n') for m in all_methods], rotation=45, ha='right')
        ax1.legend()
        ax1.grid(True, alpha=0.3, axis='y')
        
        # 2. Optimization direction comparison
        if EVALUATION_MODE == "COMPARE_OBJECTIVES":
            ax2 = axes[1]
            
            # Group by optimization direction
            direction_performance = {'MINIMIZE': [], 'MAXIMIZE': []}
            
            for ckpt_name, methods_data in checkpoints_data.items():
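                # Get the best trained-policy performance for this checkpoint's direction
                trained_methods = [m for m in methods_data if 'Policy' in m or 'Trained' in m]
                if trained_methods:
                    direction = methods_data[trained_methods[0]]['optimization_direction']
                    pick = min if direction == 'MINIMIZE' else max
                    best_method = pick(trained_methods, key=lambda m: methods_data[m]['actual_value'])
                    perf = methods_data[best_method]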
                    if direction in direction_performance:
                        direction_performance[direction].append({
                            'checkpoint': ckpt_name,
                            'value': perf['actual_value'],
                            'display': perf['display_value']
                        })
            
            # Plot comparison
            if direction_performance['MINIMIZE'] and direction_performance['MAXIMIZE']:
                labels = ['Minimization', 'Maximization']
                min_val = np.mean([p['value'] for p in direction_performance['MINIMIZE']])
                max_val = np.mean([p['value'] for p in direction_performance['MAXIMIZE']])
                
                bars = ax2.bar(labels, [min_val, max_val], color=['blue', 'red'], alpha=0.7)
                
                # Add value labels
                ax2.text(0, min_val + 0.01, f"{min_val:.3f}\n(Lower better)", ha='center')
                ax2.text(1, max_val + 0.01, f"{max_val:.3f}\n(Higher better)", ha='center')
                
                ax2.set_ylabel('Average Target Value')
                ax2.set_title('Optimization Direction Comparison')
                ax2.grid(True, alpha=0.3, axis='y')
        
        # 3. Summary statistics (bottom-right panel; the bottom-left panel is unused)
        axes[2].axis('off')
        ax_summary = axes[3]
        ax_summary.axis('off')
        
        summary_text = f"Evaluation Summary\n{'='*30}\n\n"
        summary_text += f"Mode: {EVALUATION_MODE}\n"
        summary_text += f"Checkpoints evaluated: {len(selected_checkpoints)}\n"
        summary_text += f"Test SCMs: {NUM_TEST_SCMS}\n"
        summary_text += f"Runs per method: {RUNS_PER_METHOD}\n\n"
        
        for ckpt in selected_checkpoints:
            if ckpt.name in checkpoints_data:
                summary_text += f"\n{ckpt.name}:\n"
                summary_text += f"  Optimization: {ckpt.optimization_config.direction}\n"
                
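                # Get the best trained method for this checkpoint's optimization direction
                methods = checkpoints_data[ckpt.name]
                trained = [m for m in methods.values() if 'Policy' in m['method'] or 'Trained' in m['method']]
                if trained:
                    pick = min if ckpt.optimization_config.is_minimizing else max
                    best = pick(trained, key=lambda m: m['actual_value'])
                    summary_text += f"  Best policy: {best['method']} = {best['display_value']}\n"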
        
        ax_summary.text(0.05, 0.95, summary_text, transform=ax_summary.transAxes, fontsize=10,
                       verticalalignment='top', family='monospace',
                       bbox=dict(boxstyle='round,pad=1', facecolor='lightgray', alpha=0.8))
        
        plt.suptitle('Multi-Checkpoint Comparison Results', fontsize=16)
        plt.tight_layout()
        
        plot_path = output_dir / "checkpoint_comparison_results.png"
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.show()
        
        print(f"\n💾 Saved plot: {plot_path}")

print("\n✅ Visualization complete!")

📊 Generating Visualizations
❌ No valid results to visualize

✅ Visualization complete!


## 7. Export Results and Summary

In [8]:
"""
Cell 7: Export results and generate summary report

Save all results for further analysis.
"""

print("💾 Exporting Results")
print("=" * 60)

# Prepare export data
export_data = {
    'evaluation_config': {
        'mode': EVALUATION_MODE,
        'num_test_scms': NUM_TEST_SCMS,
        'runs_per_method': RUNS_PER_METHOD,
        'intervention_budget': INTERVENTION_BUDGET,
        'random_seed': RANDOM_SEED,
        'timestamp': timestamp
    },
    'checkpoints_evaluated': [
        {
            'name': ckpt.name,
            'optimization_direction': ckpt.optimization_config.direction,
            'path': str(ckpt.path),
            'training_mode': ckpt.training_config.get('mode', 'unknown')
        }
        for ckpt in selected_checkpoints
    ],
    'results': evaluation_results,
    'method_performance': method_performance if 'method_performance' in locals() else {},
    'duration_minutes': total_duration if 'total_duration' in locals() else 0
}

# Save JSON results
json_path = output_dir / "evaluation_results.json"

with open(json_path, 'w') as f:
    json.dump(export_data, f, indent=2)

print(f"✅ Results saved to: {json_path}")

# Generate text summary
summary_path = output_dir / "evaluation_summary.txt"
with open(summary_path, 'w') as f:
    f.write(f"GRPO Evaluation Summary\n")
    f.write(f"="*60 + "\n\n")
    
    f.write(f"Evaluation Mode: {EVALUATION_MODE}\n")
    f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M')}\n")
    f.write(f"Duration: {total_duration:.1f} minutes\n\n" if 'total_duration' in locals() else "\n")
    
    f.write(f"Configuration:\n")
    f.write(f"  Test SCMs: {NUM_TEST_SCMS}\n")
    f.write(f"  Runs per method: {RUNS_PER_METHOD}\n")
    f.write(f"  Intervention budget: {INTERVENTION_BUDGET}\n")
    f.write(f"  Random seed: {RANDOM_SEED}\n\n")
    
    f.write(f"Checkpoints Evaluated:\n")
    for ckpt in selected_checkpoints:
        f.write(f"  - {ckpt.name} ({ckpt.optimization_config.direction})\n")
    
    f.write(f"\nKey Findings:\n")
    f.write("-" * 60 + "\n")
    
    # Summarize results
    if 'method_performance' in locals() and method_performance:
        # Group by checkpoint
        for ckpt in selected_checkpoints:
            f.write(f"\n{ckpt.name} ({ckpt.optimization_config.direction}):\n")
            
            # Get methods for this checkpoint
            ckpt_methods = {k: v for k, v in method_performance.items() if v['checkpoint'] == ckpt.name}
            
            if ckpt_methods:
                # Sort by performance
                if ckpt.optimization_config.is_minimizing:
                    sorted_methods = sorted(ckpt_methods.items(), key=lambda x: x[1]['actual_value'])
                else:
                    sorted_methods = sorted(ckpt_methods.items(), key=lambda x: x[1]['actual_value'], reverse=True)
                
                for rank, (key, perf) in enumerate(sorted_methods[:5], 1):  # Top 5
                    marker = "*" if 'Policy' in perf['method'] or 'Trained' in perf['method'] else " "
                    f.write(f"  {rank}. {marker} {perf['method']}: {perf['display_value']}\n")
    
    if EVALUATION_MODE == "COMPARE_OBJECTIVES":
        f.write(f"\n\nOptimization Direction Insights:\n")
        f.write("-" * 60 + "\n")
        f.write("This evaluation compared minimization vs maximization objectives.\n")
        f.write("Key observation: Different optimization directions require different\n")
        f.write("reward structures and may lead to different exploration strategies.\n")
    
    f.write(f"\n\nOutput Files:\n")
    f.write(f"  Results JSON: {json_path.name}\n")
    f.write(f"  Summary: {summary_path.name}\n")
    
    # List any plots
    plots = list(output_dir.glob("*.png"))
    for plot in plots:
        f.write(f"  Plot: {plot.name}\n")

print(f"✅ Summary saved to: {summary_path}")

# Display final summary
print(f"\n🎉 Evaluation Complete!")
print(f"Mode: {EVALUATION_MODE}")
print(f"Checkpoints: {len(selected_checkpoints)}")
print(f"Duration: {total_duration:.1f} minutes" if 'total_duration' in locals() else "")
print(f"\nOutput directory: {output_dir}")

print(f"\n📊 Key Insights:")
if EVALUATION_MODE == "SINGLE_CHECKPOINT":
    print(f"- Evaluated {selected_checkpoints[0].name} with {selected_checkpoints[0].optimization_config.direction} objective")
    print(f"- Compare against baselines to see if training improved performance")
elif EVALUATION_MODE == "COMPARE_OBJECTIVES":
    print(f"- Compared {len(selected_checkpoints)} checkpoints with different optimization directions")
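    print(f"- Minimization policies optimize for lower target values (the direction PARENT_SCALE uses by default)")
    print(f"- Maximization policies optimize for higher target values")
    print(f"- Check plots to see which direction performs better for your use case")
else:
    print(f"- Compared {len(selected_checkpoints)} checkpoints; each is ranked by its own optimization direction")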

print(f"\n💡 Next steps:")
print(f"1. Review the plots in {output_dir}")
print(f"2. Check the detailed results in {json_path.name}")
print(f"3. Read the summary report for key findings")

💾 Exporting Results
✅ Results saved to: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/results/evaluation_single_checkpoint_20250723_203618/evaluation_results.json
✅ Summary saved to: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/results/evaluation_single_checkpoint_20250723_203618/evaluation_summary.txt

🎉 Evaluation Complete!
Mode: SINGLE_CHECKPOINT
Checkpoints: 1
Duration: 0.0 minutes

Output directory: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/results/evaluation_single_checkpoint_20250723_203618

📊 Key Insights:
- Evaluated grpo_quick_minimize_20250723_101252_fixed with MINIMIZE objective
- Compare against baselines to see if training improved performance

💡 Next steps:
1. Review the plots in /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/results/evaluation_single_checkpoint_20250723_203618
2. Check the detailed results in evaluation_results.json
3. Read the summary report for key findings
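`json.dump` raises `TypeError` on values it cannot encode natively, such as NumPy scalars, NumPy or JAX arrays, and `Path` objects. If `evaluation_results` holds any of these, one option is to pass a `default` hook. A minimal sketch; `to_jsonable` is a hypothetical helper, and it relies on the `tolist()` method that NumPy scalars/arrays and JAX arrays provide:

```python
import json
from collections.abc import Mapping, Set
from pathlib import Path
from typing import Any


def to_jsonable(obj: Any) -> Any:
    """Fallback for json.dump(..., default=to_jsonable).

    json only calls this hook for objects it cannot encode natively, so plain
    dicts, lists, strings, and Python numbers never reach it.
    """
    if isinstance(obj, Path):
        return str(obj)
    if hasattr(obj, "tolist"):
        # NumPy scalars/arrays and JAX arrays expose tolist(), which returns
        # plain Python numbers or nested lists of them.
        return obj.tolist()
    if isinstance(obj, Mapping):
        # Non-dict mappings (e.g. read-only or persistent maps) become dicts.
        return dict(obj)
    if isinstance(obj, Set):
        return list(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
```

With this hook, the export line in Cell 7 would become `json.dump(export_data, f, indent=2, default=to_jsonable)`. Anything the hook does not recognize still raises, so unexpected types surface as errors instead of being silently stringified.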

## Summary

**What we've accomplished:**
1. ✅ Loaded checkpoints with optimization direction metadata
2. ✅ Ran evaluation with proper metric handling
3. ✅ Generated visualizations that show min/max correctly
4. ✅ Exported comprehensive results
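
To pick up the exported results in a later session, the JSON can be read back without re-running any cells. A minimal sketch that relies only on the `export_data` layout written in Cell 7 (a `checkpoints_evaluated` list whose entries carry `name` and `optimization_direction`); `load_checkpoint_directions` is a hypothetical helper name:

```python
import json
from pathlib import Path
from typing import Dict


def load_checkpoint_directions(json_path: Path) -> Dict[str, str]:
    """Map each evaluated checkpoint name to its optimization direction.

    Assumes the layout written by Cell 7: a top-level 'checkpoints_evaluated'
    list whose entries include 'name' and 'optimization_direction'.
    """
    with open(json_path) as f:
        export = json.load(f)
    return {
        entry["name"]: entry["optimization_direction"]
        for entry in export.get("checkpoints_evaluated", [])
    }
```

For example, `load_checkpoint_directions(output_dir / "evaluation_results.json")` returns a dict keyed by checkpoint name, which is enough to decide whether lower or higher values are better before re-plotting anything.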

**Key improvements over original notebook:**
- Checkpoint-first approach - no need to run training
- Auto-detects optimization direction from metadata
- Correctly displays minimization vs maximization results
- All cells are independent and can be re-run
- No silent failures - explicit errors throughout
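
Auto-detection without silent failures comes down to validating the stored direction instead of falling back to a default. A minimal sketch, assuming the metadata is a plain dict that stores the direction under `optimization_direction` (the key name is borrowed from the export format in Cell 7; the real checkpoint metadata and the `OptimizationConfig` class in `base_components` may differ). `Direction` and `parse_direction` are hypothetical names:

```python
from enum import Enum
from typing import Any, Mapping


class Direction(str, Enum):
    """Hypothetical stand-in for the direction stored in checkpoint metadata."""
    MINIMIZE = "MINIMIZE"
    MAXIMIZE = "MAXIMIZE"

    @property
    def is_minimizing(self) -> bool:
        return self is Direction.MINIMIZE


def parse_direction(metadata: Mapping[str, Any],
                    key: str = "optimization_direction") -> Direction:
    """Read the direction from metadata, raising instead of guessing a default."""
    if key not in metadata:
        raise KeyError(f"Checkpoint metadata has no '{key}' entry")
    raw = str(metadata[key]).strip().upper()
    try:
        return Direction(raw)
    except ValueError:
        raise ValueError(
            f"Unknown optimization direction {metadata[key]!r}; "
            f"expected one of {[d.value for d in Direction]}"
        ) from None
```

A checkpoint saved without a direction then fails loudly at load time, rather than being evaluated under the wrong objective.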

**Understanding the results:**
- For MINIMIZE checkpoints: Lower target values are better
- For MAXIMIZE checkpoints: Higher target values are better
- The plots show "(↓ better)" or "(↑ better)" to clarify
- PARENT_SCALE baseline uses minimization by default
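
Because "better" flips with the objective, comparisons against a baseline are easiest to read when the sign is normalized so that a positive number always means the policy did better. A minimal sketch of that convention; `signed_improvement` and `direction_label` are hypothetical helpers, not functions from `base_components`:

```python
def direction_label(is_minimizing: bool) -> str:
    """Axis-label suffix matching the '(↓ better)' / '(↑ better)' convention."""
    return "(↓ better)" if is_minimizing else "(↑ better)"


def signed_improvement(policy_value: float, baseline_value: float,
                       is_minimizing: bool) -> float:
    """Return how much the policy beats the baseline; positive means better.

    For minimization a lower target value is better, so the improvement is
    baseline - policy; for maximization it is policy - baseline.
    """
    diff = policy_value - baseline_value
    return -diff if is_minimizing else diff
```

For example, `signed_improvement(1.0, 3.0, is_minimizing=True)` returns `2.0` (the policy reached a target value 2.0 lower than the baseline), while the same numbers under maximization return `-2.0`.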