# GRPO Evaluation - Modular Version

**Purpose**: Evaluate trained GRPO checkpoints with proper optimization direction handling.

**Key Features**:
- ✅ **Checkpoint-first approach** - load any checkpoint to evaluate
- ✅ **Auto-detect optimization** - reads direction from checkpoint metadata
- ✅ **Independent cells** - no need to run training first
- ✅ **Correct metrics** - handles both minimization and maximization
- ✅ **Multiple modes** - single checkpoint, compare checkpoints, compare objectives, Phase 2 active learning

**Workflow**:
1. Select evaluation mode and checkpoints
2. Load checkpoint(s) and validate metadata
3. Generate or load test SCMs
4. Run evaluation with baselines
5. Generate visualizations with correct labels
6. Export results for analysis
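
Handling both optimization directions mostly comes down to flipping comparisons in one place. Below is a minimal sketch that assumes the `'MINIMIZE'`/`'MAXIMIZE'` strings stored in checkpoint metadata. `is_better`, `best_so_far` and `signed_improvement` are hypothetical helpers. They are not part of `base_components`, and this is not necessarily how the evaluation module computes `mean_improvement`:

```python
from typing import List, Sequence


def is_better(a: float, b: float, direction: str) -> bool:
    """True if value `a` is strictly better than `b` under `direction`."""
    if direction == "MINIMIZE":
        return a < b
    if direction == "MAXIMIZE":
        return a > b
    raise ValueError(f"Unknown optimization direction: {direction}")


def best_so_far(values: Sequence[float], direction: str) -> List[float]:
    """Running best value along a target trajectory (running min or max)."""
    best: List[float] = []
    for v in values:
        if not best or is_better(v, best[-1], direction):
            best.append(v)
        else:
            best.append(best[-1])
    return best


def signed_improvement(initial: float, final: float, direction: str) -> float:
    """Improvement oriented so that a positive number always means 'better'."""
    if direction == "MINIMIZE":
        return initial - final
    if direction == "MAXIMIZE":
        return final - initial
    raise ValueError(f"Unknown optimization direction: {direction}")
```

For example, `best_so_far([3.0, 1.5, 2.0, 0.5], "MINIMIZE")` gives `[3.0, 1.5, 1.5, 0.5]`. Keeping the sign convention in a single helper means plots and summaries cannot silently disagree about which direction is "better".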

## 1. Setup and Configuration

In [1]:
#!/usr/bin/env python3
"""
Cell 1: Import base components and configure environment

This cell sets up the evaluation environment.
"""

import sys
import os
from pathlib import Path
import logging
import json
import time
import subprocess
import shutil
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass

# Add project root to path
project_root = Path.cwd().parent if Path.cwd().name == "experiments" else Path.cwd()
sys.path.insert(0, str(project_root))

# Import base components
from scripts.notebooks.base_components import (
    NotebookError, CheckpointManager, SCMGenerator,
    OptimizationConfig, CheckpointMetadata, validate_environment,
    format_results_summary
)
from scripts.notebooks.config_templates import create_evaluation_config

# Core imports
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as random
import pyrsistent as pyr

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import Image, display
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
)
logger = logging.getLogger(__name__)

# Validate environment
try:
    env_info = validate_environment()
    print("✅ Environment Setup Complete")
    print(f"📁 Project root: {project_root}")
    print(f"🔧 JAX devices: {env_info['jax_devices']}")
    print(f"📅 Date: {env_info['timestamp']}")
except Exception as e:
    raise NotebookError(f"Environment validation failed: {e}") from e

# Initialize checkpoint manager
checkpoint_dir = project_root / "checkpoints" / "grpo_training"
checkpoint_manager = CheckpointManager(checkpoint_dir)
print(f"\nüìÅ Checkpoint directory: {checkpoint_dir}")

INFO:2025-07-28 16:33:05,379:jax._src.xla_bridge:749: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)


✅ Environment Setup Complete
📁 Project root: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt
🔧 JAX devices: [CpuDevice(id=0)]
📅 Date: 2025-07-28 16:33:05

📁 Checkpoint directory: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training

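Cell 1 infers the project root from the working directory's name. That works only when the notebook is launched from `experiments/` or from the root itself. A more general variant walks up the directory tree until it finds a known marker. The sketch below uses `scripts/notebooks` as the marker, because Cell 1 imports from `scripts.notebooks`. `find_project_root` is a hypothetical helper, not part of the project:

```python
from pathlib import Path
from typing import Optional


def find_project_root(start: Optional[Path] = None, marker: str = "scripts/notebooks") -> Path:
    """Walk up from `start` (default: cwd) to the first directory containing `marker`."""
    current = (start or Path.cwd()).resolve()
    for candidate in (current, *current.parents):
        if (candidate / marker).is_dir():
            return candidate
    raise FileNotFoundError(f"No ancestor of {current} contains {marker!r}")
```

Usage would be `project_root = find_project_root()` before the `sys.path.insert` call. It fails loudly instead of silently picking the wrong directory.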

In [2]:
"""
Cell 2: Configure evaluation parameters

Set evaluation mode and parameters for the run.
"""

# EVALUATION MODE SELECTION
EVALUATION_MODE = "PHASE2_ACTIVE_LEARNING"  # Options: "SINGLE_CHECKPOINT", "COMPARE_CHECKPOINTS", "COMPARE_OBJECTIVES", "PHASE2_ACTIVE_LEARNING"

# Configuration for different modes
if EVALUATION_MODE == "SINGLE_CHECKPOINT":
    # Evaluate one checkpoint against baselines
    NUM_TEST_SCMS = 10
    RUNS_PER_METHOD = 3
    INTERVENTION_BUDGET = 10
    
elif EVALUATION_MODE == "COMPARE_CHECKPOINTS":
    # Compare multiple checkpoints
    COMPARISON_COUNT = 3  # Number of checkpoints to compare
    NUM_TEST_SCMS = 5  # Fewer SCMs for faster comparison
    RUNS_PER_METHOD = 2
    INTERVENTION_BUDGET = 8
    
elif EVALUATION_MODE == "COMPARE_OBJECTIVES":
    # Compare minimization vs maximization
    NUM_TEST_SCMS = 8
    RUNS_PER_METHOD = 3
    INTERVENTION_BUDGET = 10

elif EVALUATION_MODE == "PHASE2_ACTIVE_LEARNING":
    # Phase 2: Use GRPO policy with active learning surrogate
    NUM_TEST_SCMS = 10
    RUNS_PER_METHOD = 3
    INTERVENTION_BUDGET = 50  # More interventions for structure learning
    LEARNING_RATE = 1e-3  # For active surrogate
    OBSERVATION_SAMPLES = 30  # Initial observational data

else:
    raise NotebookError(f"Unknown evaluation mode: {EVALUATION_MODE}")

# General settings
RANDOM_SEED = 42

print("üéØ Evaluation Configuration")
print("=" * 50)
print(f"Mode: {EVALUATION_MODE}")
print(f"Test SCMs: {NUM_TEST_SCMS}")
print(f"Runs per method: {RUNS_PER_METHOD}")
print(f"Intervention budget: {INTERVENTION_BUDGET}")
print(f"Random seed: {RANDOM_SEED}")

if EVALUATION_MODE == "COMPARE_CHECKPOINTS":
    print(f"Checkpoints to compare: {COMPARISON_COUNT}")
elif EVALUATION_MODE == "PHASE2_ACTIVE_LEARNING":
    print(f"Learning rate: {LEARNING_RATE}")
    print(f"Observation samples: {OBSERVATION_SAMPLES}")
    print(f"\nüìö Phase 2 combines:")
    print(f"  ‚Ä¢ GRPO policy (guides interventions)")
    print(f"  ‚Ä¢ Active learning (discovers structure)")

🎯 Evaluation Configuration
Mode: PHASE2_ACTIVE_LEARNING
Test SCMs: 10
Runs per method: 3
Intervention budget: 50
Random seed: 42
Learning rate: 0.001
Observation samples: 30

📚 Phase 2 combines:
  • GRPO policy (guides interventions)
  • Active learning (discovers structure)

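`RANDOM_SEED` also drives every evaluation run. The evaluation cell derives a per-run seed as `RANDOM_SEED + scm_idx * 100 + seed_idx + offset`. The offset is 0 for GRPO, 1000 for the random baseline and 2000 for the oracle baseline. These seeds stay unique only while `scm_idx * 100 + seed_idx < 1000`, which means at most 10 test SCMs and at most 100 runs per method. The sketch below reproduces that formula and reports collisions. `run_seed`, `find_seed_collisions` and `METHOD_OFFSETS` are hypothetical names written for illustration:

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Seed offsets per method, as used in the evaluation cell.
METHOD_OFFSETS = {"grpo": 0, "random": 1000, "oracle": 2000}


def run_seed(base_seed: int, scm_idx: int, seed_idx: int, offset: int) -> int:
    """Per-run seed, mirroring RANDOM_SEED + scm_idx * 100 + seed_idx + offset."""
    return base_seed + scm_idx * 100 + seed_idx + offset


def find_seed_collisions(
    base_seed: int, num_scms: int, runs_per_method: int
) -> Dict[int, List[Tuple[str, int, int]]]:
    """Map each seed used more than once to the (method, scm_idx, seed_idx) runs sharing it."""
    users: Dict[int, List[Tuple[str, int, int]]] = defaultdict(list)
    for method, offset in METHOD_OFFSETS.items():
        for scm_idx in range(num_scms):
            for seed_idx in range(runs_per_method):
                seed = run_seed(base_seed, scm_idx, seed_idx, offset)
                users[seed].append((method, scm_idx, seed_idx))
    return {seed: runs for seed, runs in users.items() if len(runs) > 1}
```

With the current settings, `find_seed_collisions(42, 10, 3)` is empty. With 11 SCMs, GRPO's run on SCM index 10 reuses seed 1042, which is also the seed of the random baseline's first run on SCM 0. If the grid grows, deriving seeds with `numpy.random.SeedSequence.spawn` or `jax.random.fold_in` would avoid hand-tuned offsets.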

## 2. Evaluation Mode and Checkpoint Selection

In [3]:
"""
Cell 3: Select checkpoints for evaluation

This cell selects checkpoints according to the evaluation mode, discovering
them from their metadata instead of relying on hardcoded names.
"""

print("üìã Intelligent Checkpoint Selection")
print("=" * 50)

# Get available checkpoints
try:
    available_checkpoints = checkpoint_manager.list_checkpoints()
    if not available_checkpoints:
        raise NotebookError("No checkpoints found")
        
    print(f"Found {len(available_checkpoints)} total checkpoints")
    
    # Show usable checkpoints by direction
    usable_minimize = checkpoint_manager.find_usable_checkpoints('MINIMIZE')
    usable_maximize = checkpoint_manager.find_usable_checkpoints('MAXIMIZE')
    
    print(f"Usable checkpoints:")
    print(f"  MINIMIZE: {len(usable_minimize)}")
    print(f"  MAXIMIZE: {len(usable_maximize)}")
    print(f"  Total usable: {len(usable_minimize) + len(usable_maximize)}")
        
except Exception as e:
    raise NotebookError(f"Failed to analyze checkpoints: {e}")

# SELECT CHECKPOINTS BASED ON MODE
selected_checkpoints = []

if EVALUATION_MODE == "SINGLE_CHECKPOINT":
    # Find best MINIMIZE checkpoint (preferred for comparison with PARENT_SCALE)
    best_checkpoint = checkpoint_manager.find_best_checkpoint({
        'optimization_direction': 'MINIMIZE',
        'training_mode': 'FULL'
    })
    
    if best_checkpoint:
        selected_checkpoints = [best_checkpoint]
        print(f"üéØ Selected best MINIMIZE checkpoint: {best_checkpoint.name}")
        
        # Validate the selected checkpoint
        validation = checkpoint_manager.validate_checkpoint(best_checkpoint)
        if validation['is_valid']:
            print(f"  ✅ Checkpoint is valid and ready for evaluation")
        else:
            print(f"  ⚠️ Checkpoint issues: {validation['issues']}")
    else:
        # Fallback: try any usable checkpoint
        all_usable = checkpoint_manager.find_usable_checkpoints()
        if all_usable:
            selected_checkpoints = [all_usable[0]]
            print(f"üéØ No MINIMIZE checkpoint found, using: {all_usable[0].name}")
            print(f"  Optimization: {all_usable[0].optimization_config.direction}")
        else:
            raise NotebookError("No usable checkpoints found. Please ensure checkpoints have both metadata.json and checkpoint.pkl files.")

elif EVALUATION_MODE == "COMPARE_CHECKPOINTS":
    # Get multiple usable checkpoints
    all_usable = checkpoint_manager.find_usable_checkpoints()
    comparison_count = min(COMPARISON_COUNT, len(all_usable))
    selected_checkpoints = all_usable[:comparison_count]
    print(f"üìä Selected {comparison_count} checkpoints for comparison")

elif EVALUATION_MODE == "COMPARE_OBJECTIVES":
    # Get best from each optimization direction
    best_minimize = checkpoint_manager.find_best_checkpoint({'optimization_direction': 'MINIMIZE'})
    best_maximize = checkpoint_manager.find_best_checkpoint({'optimization_direction': 'MAXIMIZE'})
    
    selected_checkpoints = []
    if best_minimize:
        selected_checkpoints.append(best_minimize)
    if best_maximize:
        selected_checkpoints.append(best_maximize)
    
    if not (best_minimize and best_maximize):
        raise NotebookError("Need checkpoints from both MINIMIZE and MAXIMIZE directions for objective comparison")
    
    print(f"🔄 Selected checkpoints for objective comparison:")
    for ckpt in selected_checkpoints:
        print(f"  - {ckpt.name} ({ckpt.optimization_config.direction})")

elif EVALUATION_MODE == "PHASE2_ACTIVE_LEARNING":
    # Phase 2: Select best GRPO checkpoint to use with active learning
    best_checkpoint = checkpoint_manager.find_best_checkpoint({
        'optimization_direction': 'MINIMIZE',
        'training_mode': 'FULL'
    })
    
    if best_checkpoint:
        selected_checkpoints = [best_checkpoint]
        print(f"üéØ Selected checkpoint for Phase 2: {best_checkpoint.name}")
        print(f"  This GRPO policy will guide active learning")
        
        # Validate the selected checkpoint
        validation = checkpoint_manager.validate_checkpoint(best_checkpoint)
        if validation['is_valid']:
            print(f"  ✅ Checkpoint is valid and ready for Phase 2")
        else:
            print(f"  ⚠️ Checkpoint issues: {validation['issues']}")
    else:
        # Fallback: try any usable checkpoint
        all_usable = checkpoint_manager.find_usable_checkpoints()
        if all_usable:
            selected_checkpoints = [all_usable[0]]
            print(f"üéØ Using checkpoint for Phase 2: {all_usable[0].name}")
            print(f"  Optimization: {all_usable[0].optimization_config.direction}")
        else:
            raise NotebookError("No usable checkpoints found for Phase 2. Please train a GRPO policy first.")

else:
    raise NotebookError(f"Unknown evaluation mode: {EVALUATION_MODE}")

# Final validation
if not selected_checkpoints:
    raise NotebookError("No checkpoints selected for evaluation")

print(f"\n✅ Final Selection ({len(selected_checkpoints)} checkpoint(s)):")
for i, checkpoint in enumerate(selected_checkpoints, 1):
    print(f"{i}. {checkpoint.name}")
    print(f"     Optimization: {checkpoint.optimization_config.direction}")
    print(f"     Training mode: {checkpoint.training_config.get('mode', 'unknown')}")
    print(f"     Path: {checkpoint.path}")
    
    # Final validation
    validation = checkpoint_manager.validate_checkpoint(checkpoint)
    if validation['is_valid']:
        print(f"     Status: ✅ Ready for evaluation")
    else:
        print(f"     Status: ❌ Issues found: {validation['issues']}")
        raise NotebookError(f"Selected checkpoint {checkpoint.name} has validation issues: {validation['issues']}")

print(f"\nüöÄ All selected checkpoints validated and ready for evaluation!")

📋 Intelligent Checkpoint Selection
Found 4 total checkpoints
Usable checkpoints:
  MINIMIZE: 4
  MAXIMIZE: 0
  Total usable: 4
🎯 Selected checkpoint for Phase 2: grpo_full_minimize_20250728_163157
  This GRPO policy will guide active learning
  ✅ Checkpoint is valid and ready for Phase 2

✅ Final Selection (1 checkpoint(s)):
1. grpo_full_minimize_20250728_163157
     Optimization: MINIMIZE
     Training mode: FULL
     Path: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/grpo_full_minimize_20250728_163157
     Status: ✅ Ready for evaluation

🚀 All selected checkpoints validated and ready for evaluation!

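The single-checkpoint and Phase 2 branches follow the same pattern. They take the best checkpoint matching the preferred criteria and otherwise fall back to the first usable one. The sketch below shows that pattern over plain dictionaries. It assumes each candidate carries `optimization_direction`, `training_mode` and a numeric `score` where higher is better. That ranking field is hypothetical, because the notebook delegates ranking to `CheckpointManager.find_best_checkpoint`. `select_with_fallback` is likewise an illustrative name:

```python
from typing import Any, Dict, List, Optional

Checkpoint = Dict[str, Any]


def matches(ckpt: Checkpoint, criteria: Dict[str, Any]) -> bool:
    """True if every criterion key is present in `ckpt` with an equal value."""
    return all(ckpt.get(key) == value for key, value in criteria.items())


def select_with_fallback(
    candidates: List[Checkpoint], criteria: Dict[str, Any]
) -> Optional[Checkpoint]:
    """Best-scoring match for `criteria`; else the first candidate; None if there are none."""
    matching = [c for c in candidates if matches(c, criteria)]
    if matching:
        # Assumed ranking: a higher (hypothetical) 'score' is better.
        return max(matching, key=lambda c: c.get("score", float("-inf")))
    return candidates[0] if candidates else None
```

The fallback is what allows the notebook to run with only a MAXIMIZE checkpoint on disk. That is also why the selection cell prints the chosen checkpoint's optimization direction.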

## 3. Load and Validate Checkpoints

In [4]:
"""
Cell 4: Load checkpoint metadata and check compatibility

This cell records metadata for the selected checkpoints (model parameters
are not loaded yet) and checks optimization-direction compatibility.
"""

print("üì• Loading Checkpoint Metadata")
print("=" * 60)

# Store loaded checkpoint info
loaded_checkpoints = {}

for ckpt in selected_checkpoints:
    print(f"\nLoading: {ckpt.name}")
    try:
        # For now, we're using the metadata we already have
        # In production, this would load the actual model parameters
        loaded_checkpoints[ckpt.name] = {
            'metadata': ckpt,
            'optimization_config': ckpt.optimization_config,
            'training_config': ckpt.training_config,
            'model_params': None  # TODO: Load actual model parameters
        }
        
        print(f"  ✓ Optimization: {ckpt.optimization_config.direction}")
        print(f"  ✓ Training mode: {ckpt.training_config.get('mode', 'unknown')}")
        print(f"  ✓ Episodes completed: {ckpt.training_results.get('episodes_completed', 'unknown')}")
        print(f"  ✓ Duration: {ckpt.training_results.get('duration_minutes', 0):.1f} minutes")
        
        # Show reward weights if available
        if 'reward_weights' in ckpt.training_config:
            weights = ckpt.training_config['reward_weights']
            print(f"  ✓ Reward weights: opt={weights.get('optimization', 0):.1f}, "
                  f"struct={weights.get('discovery', 0):.1f}, "
                  f"eff={weights.get('efficiency', 0):.1f}")
                  
    except Exception as e:
        print(f"  ✗ Failed to load: {e}")
        raise NotebookError(f"Failed to load checkpoint {ckpt.name}: {e}") from e

print(f"\n✅ Loaded {len(loaded_checkpoints)} checkpoint(s) successfully")

# Check optimization compatibility for comparison modes
if EVALUATION_MODE == "COMPARE_OBJECTIVES":
    directions = [ckpt.optimization_config.direction for ckpt in selected_checkpoints]
    if len(set(directions)) == 1:
        print(f"\n⚠️ Warning: All checkpoints have the same optimization direction: {directions[0]}")
        print("   Objective comparison may not be meaningful.")
    else:
        print(f"\n✅ Comparing optimization directions: {set(directions)}")

📥 Loading Checkpoint Metadata

Loading: grpo_full_minimize_20250728_163157
  ✓ Optimization: MINIMIZE
  ✓ Training mode: FULL
  ✓ Episodes completed: 512
  ✓ Duration: 44.5 minutes
  ✓ Reward weights: opt=0.8, struct=0.1, eff=0.1

✅ Loaded 1 checkpoint(s) successfully

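This cell records only metadata and leaves `model_params` as `None`. The selection cell describes a usable checkpoint as a directory containing both `metadata.json` and `checkpoint.pkl`. The sketch below checks that layout and parses the metadata, assuming `metadata.json` holds a JSON object. `inspect_checkpoint_dir` is a hypothetical helper. It deliberately does not unpickle the parameters, because `pickle.load` can execute arbitrary code and should only be used on trusted files:

```python
import json
from pathlib import Path
from typing import Any, Dict


def inspect_checkpoint_dir(path: Path) -> Dict[str, Any]:
    """Report which expected files exist and return the parsed metadata, if present."""
    path = Path(path)
    metadata_file = path / "metadata.json"
    params_file = path / "checkpoint.pkl"
    report: Dict[str, Any] = {
        "has_metadata": metadata_file.is_file(),
        "has_params": params_file.is_file(),
        "metadata": None,
        "issues": [],
    }
    if report["has_metadata"]:
        report["metadata"] = json.loads(metadata_file.read_text())
    else:
        report["issues"].append("missing metadata.json")
    if not report["has_params"]:
        report["issues"].append("missing checkpoint.pkl")
    report["is_valid"] = not report["issues"]
    return report
```

The report mirrors the `is_valid`/`issues` shape printed by the selection cell, so a failure here can be read the same way.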

## 4. Generate Test SCMs

In [5]:
"""
Cell 5: Generate test SCMs for evaluation

Create a balanced set of test SCMs different from training.
"""

print("üî¨ Generating Test SCMs")
print("=" * 60)

# Initialize SCM generator
scm_generator = SCMGenerator()

# Offset the seed from RANDOM_SEED so the test SCMs are not drawn with the
# training seed (assuming training was seeded with RANDOM_SEED)
test_seed = RANDOM_SEED + 1000

try:
    test_scms, test_metadata = scm_generator.generate_balanced_scms(
        num_scms=NUM_TEST_SCMS,
        variable_range=(3, 6),
        structure_types=['fork', 'chain', 'collider', 'mixed'],
        seed=test_seed
    )
    
    print(f"\n✅ Generated {len(test_scms)} test SCMs")
    
    # Analyze distribution
    distribution = scm_generator._summarize_distribution(test_metadata)
    print(f"\nüìä Test Set Distribution:")
    print(f"  Structure types: {distribution['structure_types']}")
    print(f"  Variable counts: {distribution['variable_counts']}")
    
    # Save test SCM metadata
    test_scm_path = project_root / "results" / "test_scms" / f"test_scms_{test_seed}.json"
    test_scm_path.parent.mkdir(parents=True, exist_ok=True)
    
    with open(test_scm_path, 'w') as f:
        json.dump({
            'metadata': test_metadata,
            'config': {
                'num_scms': len(test_scms),
                'seed': test_seed,
                'variable_range': [3, 6],
                'structure_types': ['fork', 'chain', 'collider', 'mixed']
            },
            'generated_at': datetime.now().isoformat()
        }, f, indent=2)
    
    print(f"\nüíæ Saved test SCM metadata to: {test_scm_path}")
    
except Exception as e:
    raise NotebookError(f"Failed to generate test SCMs: {e}") from e

🔬 Generating Test SCMs


[2025-07-28 16:33:05,831][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 3 variables, 2 edges, target='X1'
[2025-07-28 16:33:05,832][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated fork SCM: 3 vars, 2 edges, target=X1
[2025-07-28 16:33:05,847][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 4 variables, 3 edges, target='X2'
[2025-07-28 16:33:05,848][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated fork SCM: 4 vars, 3 edges, target=X2
[2025-07-28 16:33:05,864][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 5 variables, 4 edges, target='X2'
[2025-07-28 16:33:05,864][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated fork SCM: 5 vars, 4 edges, target=X2
[2025-07-28 16:33:05,867][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 6 variables, 5 edges, target='X3'
[2025-07-28 16:33:05,867][causal_bayes_opt.experiments.variable_scm_factory]


✅ Generated 10 test SCMs

📊 Test Set Distribution:
  Structure types: {'fork': 4, 'chain': 4, 'collider': 2}
  Variable counts: {3: 3, 4: 3, 5: 2, 6: 2}

💾 Saved test SCM metadata to: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/results/test_scms/test_scms_1042.json


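The distribution printed above is a frequency count over the per-SCM metadata. Comparing it with the requested `structure_types` shows that no `mixed` SCM was generated in this run. The sketch below builds such a summary with `collections.Counter` and lists requested types that never appear. It assumes each metadata entry is a dict with `structure_type` and `num_variables` keys. Those key names are a guess and may differ from what `SCMGenerator._summarize_distribution` reads, and `summarize_distribution` is a hypothetical helper:

```python
from collections import Counter
from typing import Any, Dict, List, Sequence


def summarize_distribution(
    metadata: List[Dict[str, Any]], requested_types: Sequence[str]
) -> Dict[str, Any]:
    """Count structure types and variable counts; list requested types never generated."""
    structure_counts = Counter(m["structure_type"] for m in metadata)
    variable_counts = Counter(m["num_variables"] for m in metadata)
    return {
        "structure_types": dict(structure_counts),
        "variable_counts": dict(sorted(variable_counts.items())),
        "missing_types": [t for t in requested_types if t not in structure_counts],
    }
```

Surfacing `missing_types` explicitly makes it harder to report results as covering a structure family that the test set does not actually contain.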
## 5. Run Evaluation

In [6]:
"""
Cell 6: Run evaluation with proper optimization handling

Evaluate checkpoints against baselines with correct metrics.
"""

print("üèÅ Running Evaluation")
print("=" * 60)

# Create output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = project_root / "results" / f"evaluation_{EVALUATION_MODE.lower()}_{timestamp}"
output_dir.mkdir(parents=True, exist_ok=True)

print(f"üìÅ Output directory: {output_dir}")
print(f"\nEvaluation parameters:")
print(f"  Mode: {EVALUATION_MODE}")
print(f"  Test SCMs: {NUM_TEST_SCMS}")
print(f"  Runs per method: {RUNS_PER_METHOD}")
print(f"  Intervention budget: {INTERVENTION_BUDGET}")

# Store results
evaluation_results = {}
evaluation_start_time = time.time()

# Phase 2 Active Learning Implementation
if EVALUATION_MODE == "PHASE2_ACTIVE_LEARNING":
    print("\nüöÄ Running Phase 2: Active Learning with GRPO Policy")
    print("=" * 60)
    
    # Import Phase 2 modules
    from src.causal_bayes_opt.training.grpo_policy_loader import (
        load_grpo_policy, create_grpo_intervention_fn
    )
    from examples.demo_learning import (
        DemoConfig, create_learnable_surrogate_model,
        create_oracle_intervention_policy
    )
    from examples.complete_workflow_demo import run_progressive_learning_demo_with_scm
    
    # Ensure we have a checkpoint selected
    if not selected_checkpoints:
        raise NotebookError("No checkpoint selected for Phase 2 evaluation")
    
    ckpt = selected_checkpoints[0]  # Use first checkpoint
    print(f"\nLoading GRPO policy from: {ckpt.name}")
    
    try:
        # Load GRPO policy
        grpo_policy = load_grpo_policy(str(ckpt.path))
        print(f"✅ Loaded GRPO policy successfully")
        print(f"  Variables: {grpo_policy.variables}")
        print(f"  Target: {grpo_policy.target_variable}")
        
        # Run Phase 2 for each test SCM
        phase2_results = {}
        
        for scm_idx, scm in enumerate(test_scms):
            print(f"\n{'='*40}")
            print(f"SCM {scm_idx + 1}/{len(test_scms)}")
            
            # Create GRPO intervention function for this SCM
            grpo_intervention_fn = create_grpo_intervention_fn(
                loaded_policy=grpo_policy,
                scm=scm,
                intervention_range=(-2.0, 2.0)
            )
            
            # Run multiple seeds
            scm_results = []
            for seed_idx in range(RUNS_PER_METHOD):
                print(f"  Run {seed_idx + 1}/{RUNS_PER_METHOD}", end="... ")
                
                # Create new config for each run with updated seed
                phase2_config = DemoConfig(
                    n_observational_samples=OBSERVATION_SAMPLES,
                    n_intervention_steps=INTERVENTION_BUDGET,
                    learning_rate=LEARNING_RATE,
                    intervention_value_range=(-2.0, 2.0),
                    random_seed=RANDOM_SEED + scm_idx * 100 + seed_idx,
                    scoring_method="bic"
                )
                
                # Run Phase 2: GRPO policy + active learning
                result = run_progressive_learning_demo_with_scm(
                    scm=scm,
                    config=phase2_config,
                    pretrained_surrogate=None,  # Active learning from scratch
                    pretrained_acquisition=grpo_intervention_fn  # GRPO policy
                )
                
                scm_results.append(result)
                
                # Extract final value from target_progress if available
                if 'target_progress' in result and result['target_progress']:
                    final_value = result['target_progress'][-1]
                    print(f"Final value: {final_value:.4f}")
                else:
                    # Handle intervention_mean which might be a JAX array
                    final_value = result.get('intervention_mean', 'N/A')
                    if hasattr(final_value, 'item'):
                        # Convert JAX array to scalar
                        final_value = float(final_value.item())
                        print(f"Final value: {final_value:.4f}")
                    elif isinstance(final_value, (int, float)):
                        print(f"Final value: {final_value:.4f}")
                    else:
                        print(f"Final value: {final_value}")
            
            phase2_results[f"scm_{scm_idx}"] = scm_results
        
        # Also run baseline: Random + Active Learning
        print(f"\n{'='*60}")
        print("Running baseline: Random + Active Learning")
        
        random_active_results = {}
        for scm_idx, scm in enumerate(test_scms):
            print(f"\nSCM {scm_idx + 1}/{len(test_scms)}")
            
            scm_results = []
            for seed_idx in range(RUNS_PER_METHOD):
                print(f"  Run {seed_idx + 1}/{RUNS_PER_METHOD}", end="... ")
                
                # Create new config for baseline with different seed offset
                phase2_config = DemoConfig(
                    n_observational_samples=OBSERVATION_SAMPLES,
                    n_intervention_steps=INTERVENTION_BUDGET,
                    learning_rate=LEARNING_RATE,
                    intervention_value_range=(-2.0, 2.0),
                    random_seed=RANDOM_SEED + scm_idx * 100 + seed_idx + 1000,
                    scoring_method="bic"
                )
                
                result = run_progressive_learning_demo_with_scm(
                    scm=scm,
                    config=phase2_config,
                    pretrained_surrogate=None,  # Active learning
                    pretrained_acquisition=None  # Random interventions
                )
                
                scm_results.append(result)
                
                # Extract final value from target_progress if available
                if 'target_progress' in result and result['target_progress']:
                    final_value = result['target_progress'][-1]
                    print(f"Final value: {final_value:.4f}")
                else:
                    # Handle intervention_mean which might be a JAX array
                    final_value = result.get('intervention_mean', 'N/A')
                    if hasattr(final_value, 'item'):
                        # Convert JAX array to scalar
                        final_value = float(final_value.item())
                        print(f"Final value: {final_value:.4f}")
                    elif isinstance(final_value, (int, float)):
                        print(f"Final value: {final_value:.4f}")
                    else:
                        print(f"Final value: {final_value}")
            
            random_active_results[f"scm_{scm_idx}"] = scm_results
        
        # Run Oracle + Active Learning baseline
        print(f"\n{'='*60}")
        print("Running oracle baseline: Oracle + Active Learning")
        print("(Oracle knows true parents and tests interventions)")
        
        oracle_active_results = {}
        for scm_idx, scm in enumerate(test_scms):
            print(f"\nSCM {scm_idx + 1}/{len(test_scms)}")
            
            # Get variables and target from SCM
            from causal_bayes_opt.data_structures import get_variables, get_target
            variables = sorted(get_variables(scm))
            target = get_target(scm)
            
            # Create oracle intervention function for this SCM
            oracle_intervention_fn = create_oracle_intervention_policy(
                variables=variables,
                target_variable=target,
                scm=scm,
                intervention_value_range=(-2.0, 2.0)
            )
            
            scm_results = []
            for seed_idx in range(RUNS_PER_METHOD):
                print(f"  Run {seed_idx + 1}/{RUNS_PER_METHOD}", end="... ")
                
                # Create new config for oracle with different seed offset
                phase2_config = DemoConfig(
                    n_observational_samples=OBSERVATION_SAMPLES,
                    n_intervention_steps=INTERVENTION_BUDGET,
                    learning_rate=LEARNING_RATE,
                    intervention_value_range=(-2.0, 2.0),
                    random_seed=RANDOM_SEED + scm_idx * 100 + seed_idx + 2000,
                    scoring_method="bic"
                )
                
                result = run_progressive_learning_demo_with_scm(
                    scm=scm,
                    config=phase2_config,
                    pretrained_surrogate=None,  # Active learning
                    pretrained_acquisition=oracle_intervention_fn  # Oracle interventions
                )
                
                scm_results.append(result)
                
                # Extract final value from target_progress if available
                if 'target_progress' in result and result['target_progress']:
                    final_value = result['target_progress'][-1]
                    print(f"Final value: {final_value:.4f}")
                else:
                    # Handle intervention_mean which might be a JAX array
                    final_value = result.get('intervention_mean', 'N/A')
                    if hasattr(final_value, 'item'):
                        # Convert JAX array to scalar
                        final_value = float(final_value.item())
                        print(f"Final value: {final_value:.4f}")
                    elif isinstance(final_value, (int, float)):
                        print(f"Final value: {final_value:.4f}")
                    else:
                        print(f"Final value: {final_value}")
            
            oracle_active_results[f"scm_{scm_idx}"] = scm_results
        
        # Store results
        evaluation_results['phase2_grpo_active'] = {
            'results': phase2_results,
            'method': 'GRPO + Active Learning',
            'checkpoint': ckpt.name
        }
        
        evaluation_results['phase2_random_active'] = {
            'results': random_active_results,
            'method': 'Random + Active Learning',
            'checkpoint': 'baseline'
        }
        
        evaluation_results['phase2_oracle_active'] = {
            'results': oracle_active_results,
            'method': 'Oracle + Active Learning',
            'checkpoint': 'oracle'
        }
        
        # Also run standard GRPO evaluation for comparison
        print(f"\n{'='*60}")
        print("Running standard GRPO (Phase 1 only) for comparison...")
        
        # Import the unified evaluation function
        from src.causal_bayes_opt.evaluation import run_evaluation
        
        eval_config = {
            'n_scms': NUM_TEST_SCMS,
            'n_seeds': RUNS_PER_METHOD,
            'parallel': False,
            'experiment': {
                'runs_per_method': RUNS_PER_METHOD,
                'target': {
                    'max_interventions': INTERVENTION_BUDGET,
                    'n_observational_samples': OBSERVATION_SAMPLES,
                    'optimization_direction': ckpt.optimization_config.direction
                }
            }
        }
        
        standard_results = run_evaluation(
            checkpoint_path=ckpt.path,
            output_dir=output_dir / "grpo_standard",
            config=eval_config,
            test_scms=test_scms,
            methods=['grpo']
        )
        
        evaluation_results['phase1_grpo_bootstrap'] = {
            'results': standard_results,
            'method': 'GRPO + Bootstrap (Phase 1)',
            'checkpoint': ckpt.name
        }
        
    except Exception as e:
        print(f"\n❌ Phase 2 evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        raise NotebookError(f"Phase 2 evaluation failed: {e}") from e

else:
    # Standard evaluation for the remaining modes
    # Import the unified evaluation function
    from src.causal_bayes_opt.evaluation import run_evaluation

    # Run evaluation for each checkpoint
    for ckpt in selected_checkpoints:
        print(f"\n{'='*60}")
        print(f"Evaluating: {ckpt.name}")
        print(f"Optimization: {ckpt.optimization_config.direction}")
        
        # Create checkpoint-specific output directory
        ckpt_output_dir = output_dir / ckpt.name
        ckpt_output_dir.mkdir(exist_ok=True)
        
        # Prepare evaluation config
        eval_config = {
            'n_scms': NUM_TEST_SCMS,
            'n_seeds': RUNS_PER_METHOD,  # This is the correct parameter name
            'parallel': False,  # Disable parallel to avoid serialization issues
            'experiment': {
                'runs_per_method': RUNS_PER_METHOD,
                'target': {
                    'max_interventions': INTERVENTION_BUDGET,
                    'n_observational_samples': 100,
                    'optimization_direction': ckpt.optimization_config.direction
                },
                'scm_generation': {
                    'use_variable_factory': True,
                    'variable_range': [3, 6],
                    'structure_types': ['fork', 'chain', 'collider', 'mixed']
                }
            },
            'visualization': {
                'enabled': True,
                'plot_types': ['target_trajectory', 'f1_trajectory', 'shd_trajectory', 'method_comparison']
            }
        }
        
        # Run evaluation using the unified system
        print(f"\nRunning evaluation with unified system...")
        start_time = time.time()
        
        try:
            # Use our test SCMs if already generated
            comparison_results = run_evaluation(
                checkpoint_path=ckpt.path,
                output_dir=ckpt_output_dir,
                config=eval_config,
                test_scms=test_scms if 'test_scms' in locals() else None,
                methods=['grpo', 'random', 'learning', 'oracle']  # Compare against all baselines
            )
            
            duration = (time.time() - start_time) / 60
            print(f"\n✅ Evaluation completed in {duration:.1f} minutes")
            
            # Store results
            evaluation_results[ckpt.name] = {
                'results': comparison_results,
                'optimization_direction': ckpt.optimization_config.direction,
                'checkpoint_metadata': ckpt.to_dict(),
                'duration_minutes': duration
            }
            
            # Quick summary - FIXED: use method_metrics instead of method_results
            print("\nüìä Quick Summary:")
            for method_name, metrics in comparison_results.method_metrics.items():
                improvement = metrics.mean_improvement
                # Format based on optimization direction
                if ckpt.optimization_config.is_minimizing:
                    display_val = f"{-improvement:.4f} (lower is better)"
                else:
                    display_val = f"{improvement:.4f} (higher is better)"
                print(f"  {method_name}: {display_val}")
                
        except Exception as e:
            print(f"\n❌ Evaluation failed: {e}")
            import traceback
            traceback.print_exc()
            evaluation_results[ckpt.name] = {'error': str(e)}

total_duration = (time.time() - evaluation_start_time) / 60

print(f"\n{'='*60}")
print(f"✅ All evaluations complete!")
print(f"  Total duration: {total_duration:.1f} minutes")
print(f"  Successful: {sum(1 for r in evaluation_results.values() if 'error' not in r)}/{len(evaluation_results)}")
print(f"\nüìÅ Results saved to: {output_dir}")

🏁 Running Evaluation
📁 Output directory: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/results/evaluation_phase2_active_learning_20250728_163305

Evaluation parameters:
  Mode: PHASE2_ACTIVE_LEARNING
  Test SCMs: 10
  Runs per method: 3
  Intervention budget: 50

🚀 Running Phase 2: Active Learning with GRPO Policy


[2025-07-28 16:33:08,126][src.causal_bayes_opt.training.grpo_policy_loader][INFO] - Loaded variable-agnostic GRPO policy
[2025-07-28 16:33:08,126][src.causal_bayes_opt.training.grpo_policy_loader][INFO] - Loaded GRPO policy from /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/grpo_full_minimize_20250728_163157/policy_params.pkl
[2025-07-28 16:33:08,127][src.causal_bayes_opt.training.grpo_policy_loader][INFO] - Policy config: {'architecture': {'hidden_dim': 128, 'num_layers': 2, 'num_heads': 4, 'key_size': 32, 'widening_factor': 4, 'dropout': 0.1, 'policy_intermediate_dim': None, 'variable_agnostic': True}, 'state_config': {'max_history_size': 100, 'num_channels': 5, 'standardize_values': True, 'include_temporal_features': True, 'use_global_standardization': True}, 'grpo_config': {'group_size': 64, 'interventions_per_state': 8, 'clip_ratio': 0.2, 'entropy_coeff': 0.1, 'kl_penalty_coeff': 0.0, 'max_grad_norm': 1.0, 'scale_rewards': True}


Loading GRPO policy from: grpo_full_minimize_20250728_163157
✅ Loaded GRPO policy successfully
  Variables: None
  Target: None

SCM 1/10
  Run 1/3... üß† Progressive Learning Demo
üß† Using learnable surrogate model
üéØ Using pretrained acquisition policy

❌ Phase 2 evaluation failed: Unable to retrieve parameter 'w' for module 'EnrichedAcquisitionPolicyNetwork/EnrichedAttentionEncoder/~_project_enriched_input_with_roles/RoleBasedProjection/target_projection' All parameters must be created as part of `init`.


Traceback (most recent call last):
  File "/var/folders/2f/7z7glsfj1fd22nlr6wj56z5w0000gn/T/ipykernel_79305/1762222707.py", line 85, in <module>
    result = run_progressive_learning_demo_with_scm(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/examples/complete_workflow_demo.py", line 185, in run_progressive_learning_demo_with_scm
    intervention = intervention_fn(key=keys[step])
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/src/causal_bayes_opt/training/grpo_policy_loader.py", line 257, in select_intervention
    policy_output = loaded_policy.apply_fn(
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/harellidar/Library/Caches/pypoetry/virtualenvs/causal-bayes-opt-9Aj1r1ec-py3.12/lib/python3.12/site-packages/haiku/_src/transform.py", line 183, in apply_fn
    out, state = f.apply(params, None, *args, **k

NotebookError: Phase 2 evaluation failed: Unable to retrieve parameter 'w' for module 'EnrichedAcquisitionPolicyNetwork/EnrichedAttentionEncoder/~_project_enriched_input_with_roles/RoleBasedProjection/target_projection' All parameters must be created as part of `init`.
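
The error above comes from Haiku: `apply` asked for the parameter `w` of `.../RoleBasedProjection/target_projection`, but no such entry exists in the parameters passed to it. A likely cause is that `policy_params.pkl` was saved from a network whose module structure differs from the `EnrichedAcquisitionPolicyNetwork` being applied. One way to narrow this down is to compare the tree returned by `init` with the loaded tree. The sketch below assumes both are Haiku-style nested mappings `{module_path: {param_name: array}}`. `diff_param_trees` is a hypothetical helper, not part of this project or of Haiku, and the module names in the example are made up.

```python
from typing import Any, Dict, List, Mapping

import numpy as np


def diff_param_trees(
    expected: Mapping[str, Mapping[str, Any]],
    loaded: Mapping[str, Mapping[str, Any]],
) -> Dict[str, List[Any]]:
    """Compare two Haiku-style parameter trees of the form {module: {param: array}}."""
    expected_keys = {(m, p) for m, params in expected.items() for p in params}
    loaded_keys = {(m, p) for m, params in loaded.items() for p in params}

    # Parameters present in both trees but stored with different array shapes.
    shape_mismatches = []
    for module, param in sorted(expected_keys & loaded_keys):
        exp_shape = np.shape(expected[module][param])
        got_shape = np.shape(loaded[module][param])
        if exp_shape != got_shape:
            shape_mismatches.append((module, param, exp_shape, got_shape))

    return {
        # Requested by the network but absent from the checkpoint: the kind of gap
        # that makes Haiku report "Unable to retrieve parameter".
        "missing_from_checkpoint": sorted(expected_keys - loaded_keys),
        # Stored in the checkpoint but never created by the network.
        "unused_in_checkpoint": sorted(loaded_keys - expected_keys),
        "shape_mismatches": shape_mismatches,
    }


# Synthetic example (hypothetical module names and shapes).
example_expected = {
    "encoder/target_projection": {"w": np.zeros((5, 128)), "b": np.zeros(128)},
    "head": {"w": np.zeros((128, 1))},
}
example_loaded = {
    "head": {"w": np.zeros((64, 1))},
    "old_encoder": {"w": np.zeros((5, 128))},
}
example_report = diff_param_trees(example_expected, example_loaded)
for kind, entries in example_report.items():
    print(f"{kind}: {entries}")
```

With real parameters, `expected` would come from calling `init` on the `hk.transform`-ed network with representative inputs, and `loaded` from unpickling the checkpoint. The network function and input shapes that `grpo_policy_loader` uses are not shown in this notebook, so that wiring is left open.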

## 6. Generate Visualizations

In [None]:
"""
Cell 6: Generate visualizations with trajectory plots

Create comprehensive plots showing method performance over time.
"""

print("üìä Generating Visualizations")
print("=" * 60)

# Check if we have results to visualize
valid_results = {k: v for k, v in evaluation_results.items() if 'error' not in v}

if not valid_results:
    print("❌ No valid results to visualize")
else:
    if EVALUATION_MODE == "PHASE2_ACTIVE_LEARNING":
        # Phase 2 specific visualizations
        print("\nüìà Phase 2 Active Learning Visualizations")
        
        import matplotlib.pyplot as plt
        import seaborn as sns
        
        # Create figure with subplots
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Phase 2: GRPO + Active Learning Results', fontsize=16)
        
        # Extract results
        grpo_active_results = evaluation_results.get('phase2_grpo_active', {}).get('results', {})
        random_active_results = evaluation_results.get('phase2_random_active', {}).get('results', {})
        oracle_active_results = evaluation_results.get('phase2_oracle_active', {}).get('results', {})
        
        # 1. Target Value Trajectory
        ax = axes[0, 0]
        ax.set_title('Target Value Over Time')
        ax.set_xlabel('Interventions')
        ax.set_ylabel('Target Value')
        
        # Plot trajectories for each method
        for method_name, method_results, color in [
            ('GRPO + Active', grpo_active_results, 'blue'),
            ('Random + Active', random_active_results, 'orange'),
            ('Oracle + Active', oracle_active_results, 'green')
        ]:
            all_trajectories = []
            for scm_results in method_results.values():
                for result in scm_results:
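                    if 'target_progress' in result and len(result['target_progress']) > 0:
                        # Copy to a list so the padding below concatenates instead of adding elementwise
                        all_trajectories.append(list(result['target_progress']))
            
            if all_trajectories:
                # Average across runs, padding shorter runs with their last value
                max_len = max(len(t) for t in all_trajectories)
                padded = np.array([t + [t[-1]]*(max_len-len(t)) for t in all_trajectories])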
                mean_traj = np.mean(padded, axis=0)
                std_traj = np.std(padded, axis=0)
                
                x = np.arange(len(mean_traj))
                ax.plot(x, mean_traj, label=method_name, linewidth=2, color=color)
                ax.fill_between(x, mean_traj - std_traj, mean_traj + std_traj, alpha=0.2, color=color)
        
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        # 2. Structure Learning (F1 Score)
        ax = axes[0, 1]
        ax.set_title('Structure Learning (F1 Score)')
        ax.set_xlabel('Interventions')
        ax.set_ylabel('F1 Score')
        
        for method_name, method_results, color in [
            ('GRPO + Active', grpo_active_results, 'blue'),
            ('Random + Active', random_active_results, 'orange'),
            ('Oracle + Active', oracle_active_results, 'green')
        ]:
            f1_trajectories = []
            for scm_results in method_results.values():
                for result in scm_results:
                    if 'learning_history' in result:
                        # Extract F1 scores from learning history
                        f1_scores = []
                        for step in result['learning_history']:
                            if 'f1_score' in step:
                                f1_scores.append(step['f1_score'])
                        if f1_scores:
                            f1_trajectories.append(f1_scores)
            
            if f1_trajectories:
                # Average across runs
                max_len = max(len(t) for t in f1_trajectories)
                padded = np.array([t + [t[-1]]*(max_len-len(t)) for t in f1_trajectories])
                mean_f1 = np.mean(padded, axis=0)
                std_f1 = np.std(padded, axis=0)
                
                x = np.arange(len(mean_f1))
                ax.plot(x, mean_f1, label=method_name, linewidth=2, color=color)
                ax.fill_between(x, mean_f1 - std_f1, mean_f1 + std_f1, alpha=0.2, color=color)
        
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1])
        
        # 3. Final Performance Comparison
        ax = axes[1, 0]
        ax.set_title('Final Performance Comparison')
        
        # Collect final values
        method_final_values = {}
        for method_key, method_data, display_name in [
            ('phase2_grpo_active', grpo_active_results, 'GRPO + Active'),
            ('phase2_random_active', random_active_results, 'Random + Active'),
            ('phase2_oracle_active', oracle_active_results, 'Oracle + Active')
        ]:
            final_values = []
            for scm_results in method_data.values():
                for result in scm_results:
                    # Extract final value from target_progress if available
                    if 'target_progress' in result and result['target_progress']:
                        final_values.append(result['target_progress'][-1])
                    elif 'intervention_mean' in result:
                        # Handle intervention_mean which might be a JAX array
                        val = result['intervention_mean']
                        if hasattr(val, 'item'):
                            final_values.append(float(val.item()))
                        elif isinstance(val, (int, float)):
                            final_values.append(float(val))
            if final_values:
                method_final_values[display_name] = final_values
        
        # Also add Phase 1 GRPO if available
        if 'phase1_grpo_bootstrap' in evaluation_results:
            phase1_results = evaluation_results['phase1_grpo_bootstrap'].get('results', {})
            if hasattr(phase1_results, 'method_metrics'):
                for method, metrics in phase1_results.method_metrics.items():
                    if 'grpo' in method.lower():
                        method_final_values['GRPO + Bootstrap'] = [metrics.mean_final_value]
        
        # Create box plot
        if method_final_values:
            labels = list(method_final_values.keys())
            values = list(method_final_values.values())
            positions = range(len(labels))
            
            bp = ax.boxplot(values, positions=positions, labels=labels, patch_artist=True)
            
            # Color boxes
            colors = ['blue', 'orange', 'green', 'red'][:len(labels)]
            for patch, color in zip(bp['boxes'], colors):
                patch.set_facecolor(color)
                patch.set_alpha(0.7)
            
            ax.set_ylabel('Final Target Value')
            ax.grid(True, alpha=0.3, axis='y')
            
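            # Add sample size annotations (x in data coordinates, y in axes coordinates,
            # so the labels stay inside the plot even when target values are negative)
            for i, (label, vals) in enumerate(method_final_values.items()):
                ax.text(i, 0.98, f'n={len(vals)}', transform=ax.get_xaxis_transform(),
                        ha='center', va='top', fontsize=10)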
        
        # 4. Structure Learning Summary
        ax = axes[1, 1]
        ax.set_title('Structure Learning Performance')
        
        # Calculate average final F1 scores
        method_f1_scores = {}
        for method_name, method_results, display_name in [
            ('GRPO + Active', grpo_active_results, 'GRPO + Active'),
            ('Random + Active', random_active_results, 'Random + Active'),
            ('Oracle + Active', oracle_active_results, 'Oracle + Active')
        ]:
            f1_scores = []
            for scm_results in method_results.values():
                for result in scm_results:
                    if 'structure_learning_metrics' in result:
                        final_f1 = result['structure_learning_metrics'].get('final_f1', 0)
                        f1_scores.append(final_f1)
                    elif 'learning_history' in result and result['learning_history']:
                        # Try to get from last step
                        last_step = result['learning_history'][-1]
                        if 'f1_score' in last_step:
                            f1_scores.append(last_step['f1_score'])
            
            if f1_scores:
                method_f1_scores[display_name] = np.mean(f1_scores)
        
        # Create bar plot
        if method_f1_scores:
            methods = list(method_f1_scores.keys())
            scores = list(method_f1_scores.values())
            
            bars = ax.bar(methods, scores, alpha=0.7, color=['blue', 'orange', 'green'][:len(methods)])
            ax.set_ylabel('Average Final F1 Score')
            ax.set_ylim([0, 1])
            ax.grid(True, alpha=0.3, axis='y')
            
            # Add value labels on bars
            for bar, score in zip(bars, scores):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                        f'{score:.3f}', ha='center', va='bottom')
        
        plt.tight_layout()
        
        # Save the plot
        plot_path = output_dir / "phase2_comparison.png"
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"✅ Saved Phase 2 comparison plot: {plot_path}")
        
        # Display the plot
        display(Image(filename=str(plot_path)))
        
        # Print summary statistics
        print("\nüìä Phase 2 Summary Statistics:")
        print("-" * 60)
        print(f"{'Method':<25} {'Mean Final Value':>15} {'Std Dev':>15} {'F1 Score':>15}")
        print("-" * 60)
        
        for method_name, final_vals in method_final_values.items():
            if final_vals:
                mean_val = np.mean(final_vals)
                std_val = np.std(final_vals)
                f1_val = method_f1_scores.get(method_name, 'N/A')
                
                print(f"{method_name:<25} {mean_val:>15.4f} {std_val:>15.4f} ", end="")
                if isinstance(f1_val, float):
                    print(f"{f1_val:>15.3f}")
                else:
                    print(f"{'N/A':>15}")
        
        print("-" * 60)
        
        # Add intervention diversity analysis
        print("\nüìä Intervention Diversity Analysis:")
        print("-" * 60)
        
        for method_name, method_results in [
            ('GRPO + Active', grpo_active_results),
            ('Random + Active', random_active_results),
            ('Oracle + Active', oracle_active_results)
        ]:
            intervention_counts = {}
            total_interventions = 0
            
            for scm_results in method_results.values():
                for result in scm_results:
                    if 'intervention_counts' in result:
                        for var, count in result['intervention_counts'].items():
                            if var != 'observational':
                                intervention_counts[var] = intervention_counts.get(var, 0) + count
                                total_interventions += count
            
            if total_interventions > 0:
                print(f"\n{method_name}:")
                for var, count in sorted(intervention_counts.items()):
                    percentage = (count / total_interventions) * 100
                    print(f"  {var}: {count} ({percentage:.1f}%)")
        
        print("-" * 60)
        
    else:
        # Original visualization code for other modes
        # Process results for each checkpoint
        for ckpt_name, eval_data in valid_results.items():
            print(f"\nüìà Visualizations for: {ckpt_name}")
            
            comparison_results = eval_data['results']
            opt_direction = eval_data['optimization_direction']
            
            # The plots should already be generated by run_evaluation
            # Let's check what was created
            ckpt_output_dir = output_dir / ckpt_name
            plot_files = list(ckpt_output_dir.glob("*.png"))
            plot_files.extend(list((ckpt_output_dir / "plots").glob("*.png")))
            
            if plot_files:
                print(f"Found {len(plot_files)} plots:")
                for plot_file in plot_files:
                    print(f"  - {plot_file.name}")
                    # Display the plot
                    try:
                        display(Image(filename=str(plot_file)))
                    except Exception:
                        print(f"    (Could not display {plot_file.name})")
            
            # Also create a summary table - FIXED: use method_metrics instead of method_results
            print(f"\nüìä Performance Summary for {ckpt_name}:")
            print("-" * 80)
            print(f"{'Method':<30} {'Improvement':>15} {'Final Value':>15} {'Success Rate':>15}")
            print("-" * 80)
            
            for method_name, metrics in comparison_results.method_metrics.items():
                improvement = metrics.mean_improvement
                final_value = metrics.mean_final_value
                success_rate = metrics.success_rate if hasattr(metrics, 'success_rate') else metrics.n_successful / metrics.n_runs if metrics.n_runs > 0 else 0.0
                
                if opt_direction == "MINIMIZE":
                    # For minimization, negative improvement is good
                    improvement_str = f"{-improvement:>15.4f} ↓"
                else:
                    improvement_str = f"{improvement:>15.4f} ↑"
                
                print(f"{method_name:<30} {improvement_str} {final_value:>15.4f} {success_rate:>14.1%}")
            
            print("-" * 80)
            
            # Statistical significance - check if it exists
            if hasattr(comparison_results, 'statistical_tests') and comparison_results.statistical_tests:
                print("\nüîç Statistical Significance:")
                for test_name, test_result in comparison_results.statistical_tests.items():
                    if isinstance(test_result, dict):
                        p_value = test_result.get('p_value', 1.0)
                        significant = p_value < 0.05
                        print(f"  {test_name}: p={p_value:.4f} {'✅' if significant else '❌'}")
    
    print(f"\n✅ Visualization complete!")
    print(f"📁 All plots saved to: {output_dir}")

## 7. Export Results and Summary

In [None]:
"""
Cell 7: Export results and generate summary

Export evaluation results with proper handling for different result types.
"""

print("üìä Exporting Results")
print("=" * 60)

# Create a helper function to safely convert results
def convert_result_to_dict(result):
    """Safely convert any result type to dictionary."""
    if isinstance(result, dict):
        return result
    elif hasattr(result, 'to_dict'):
        return result.to_dict()
    elif hasattr(result, '__dict__'):
        # Convert object attributes to dict
        return {k: v for k, v in result.__dict__.items() if not k.startswith('_')}
    else:
        # Return as-is if we can't convert
        return result

# Fix export for all types of results
export_results = {}
for ckpt_name, result_data in evaluation_results.items():
    if 'error' in result_data:
        # Keep error results as-is
        export_results[ckpt_name] = result_data
    elif 'phase2_' in ckpt_name or 'phase1_' in ckpt_name:
        # Phase 2 and Phase 1 results are already dicts
        export_results[ckpt_name] = result_data
    else:
        # Standard results - handle ComparisonResults object
        results = result_data.get('results')
        
        # Convert ComparisonResults to dict if needed
        if results and hasattr(results, 'method_metrics'):
            # This is a ComparisonResults object
            results_dict = {
                'method_metrics': {},
                'scm_results': getattr(results, 'scm_results', {}),
                'statistical_tests': getattr(results, 'statistical_tests', {}),
                'metadata': getattr(results, 'metadata', {})
            }
            
            # Convert method metrics
            for method_name, metrics in results.method_metrics.items():
                if hasattr(metrics, '__dict__'):
                    results_dict['method_metrics'][method_name] = {
                        k: v for k, v in metrics.__dict__.items() 
                        if not k.startswith('_')
                    }
                else:
                    results_dict['method_metrics'][method_name] = metrics
            
            results = results_dict
        
        export_results[ckpt_name] = {
            'results': convert_result_to_dict(results),
            'optimization_direction': result_data.get('optimization_direction'),
            'checkpoint_metadata': result_data.get('checkpoint_metadata'),
            'duration_minutes': result_data.get('duration_minutes')
        }

# Prepare export data
export_data = {
    'evaluation_config': {
        'mode': EVALUATION_MODE,
        'num_test_scms': NUM_TEST_SCMS,
        'runs_per_method': RUNS_PER_METHOD,
        'intervention_budget': INTERVENTION_BUDGET,
        'random_seed': RANDOM_SEED,
        'timestamp': timestamp
    },
    'checkpoints_evaluated': [
        {
            'name': ckpt.name,
            'optimization_direction': ckpt.optimization_config.direction,
            'path': str(ckpt.path),
            'training_mode': ckpt.training_config.get('mode', 'unknown')
        }
        for ckpt in selected_checkpoints
    ],
    'results': export_results,
    'duration_minutes': total_duration if 'total_duration' in locals() else 0
}

# Save JSON results with custom encoder for numpy arrays
import numpy as np

class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
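        if isinstance(obj, Path):
            return str(obj)
        if hasattr(obj, '__array__'):
            # JAX arrays and other array-likes of any size (.item() fails when size > 1)
            return np.asarray(obj).tolist()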
        return super().default(obj)

json_path = output_dir / "evaluation_results.json"

try:
    with open(json_path, 'w') as f:
        json.dump(export_data, f, indent=2, cls=NumpyEncoder)
    print(f"✅ Results saved to: {json_path}")
except Exception as e:
    print(f"⚠️ Error saving JSON: {e}")
    # Try to save a simplified version
    try:
        simplified_data = {
            'evaluation_config': export_data['evaluation_config'],
            'checkpoints_evaluated': export_data['checkpoints_evaluated'],
            'error': 'Full results could not be serialized',
            'duration_minutes': export_data['duration_minutes']
        }
        with open(json_path, 'w') as f:
            json.dump(simplified_data, f, indent=2)
        print(f"✅ Saved simplified results to: {json_path}")
    except Exception:
        print(f"❌ Could not save results to JSON")

# Generate text summary
summary_path = output_dir / "evaluation_summary.txt"
with open(summary_path, 'w') as f:
    f.write(f"GRPO Evaluation Summary\n")
    f.write(f"="*60 + "\n\n")
    
    f.write(f"Evaluation Mode: {EVALUATION_MODE}\n")
    f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M')}\n")
    f.write(f"Duration: {total_duration:.1f} minutes\n\n" if 'total_duration' in locals() else "\n")
    
    f.write(f"Configuration:\n")
    f.write(f"  Test SCMs: {NUM_TEST_SCMS}\n")
    f.write(f"  Runs per method: {RUNS_PER_METHOD}\n")
    f.write(f"  Intervention budget: {INTERVENTION_BUDGET}\n")
    f.write(f"  Random seed: {RANDOM_SEED}\n\n")
    
    f.write(f"Checkpoints Evaluated:\n")
    for ckpt in selected_checkpoints:
        f.write(f"  - {ckpt.name} ({ckpt.optimization_config.direction})\n")
    
    f.write(f"\nKey Findings:\n")
    f.write("-" * 60 + "\n")
    
    # Summarize results based on evaluation mode
    if EVALUATION_MODE == "PHASE2_ACTIVE_LEARNING":
        # Phase 2 specific summary
        f.write("\nPhase 2 Active Learning Results:\n\n")
        
        # Extract method final values if available
        if 'method_final_values' in locals():
            for method_name, values in method_final_values.items():
                if values:
                    mean_val = np.mean(values)
                    std_val = np.std(values)
                    f.write(f"{method_name}:\n")
                    f.write(f"  Mean final value: {mean_val:.4f} ± {std_val:.4f}\n")
                    f.write(f"  Number of runs: {len(values)}\n\n")
        
        # Add structure learning summary if available
        if 'method_f1_scores' in locals():
            f.write("\nStructure Learning Performance:\n")
            for method, score in method_f1_scores.items():
                f.write(f"  {method}: F1 = {score:.3f}\n")
    
    else:
        # Standard evaluation summary
        for ckpt_name, result_data in export_results.items():
            if 'error' not in result_data and 'results' in result_data:
                f.write(f"\n{ckpt_name}:\n")
                
                results = result_data['results']
                if isinstance(results, dict) and 'method_metrics' in results:
                    # Sort methods by performance
                    method_metrics = results['method_metrics']
                    is_minimizing = result_data.get('optimization_direction') == 'MINIMIZE'
                    
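                    # Methods without a final value go last in either direction
                    missing = float('inf') if is_minimizing else float('-inf')
                    sorted_methods = sorted(
                        method_metrics.items(),
                        key=lambda x: x[1].get('mean_final_value', missing),
                        reverse=not is_minimizing
                    )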
                    
                    for rank, (method_name, metrics) in enumerate(sorted_methods[:5], 1):
                        final_value = metrics.get('mean_final_value', 'N/A')
                        marker = "*" if 'grpo' in method_name.lower() else " "
                        f.write(f"  {rank}. {marker} {method_name}: {final_value}\n")
    
    f.write(f"\n\nOutput Files:\n")
    f.write(f"  Results JSON: {json_path.name}\n")
    f.write(f"  Summary: {summary_path.name}\n")
    
    # List all plots
    plot_locations = []
    
    # Check for plots in main directory
    main_plots = list(output_dir.glob("*.png"))
    for plot in main_plots:
        f.write(f"  Plot: {plot.name}\n")
        plot_locations.append(plot)
    
    # Check for plots in checkpoint subdirectories
    if 'selected_checkpoints' in locals():
        for ckpt in selected_checkpoints:
            ckpt_plot_dir = output_dir / ckpt.name / "plots"
            if ckpt_plot_dir.exists():
                ckpt_plots = list(ckpt_plot_dir.glob("*.png"))
                for plot in ckpt_plots:
                    f.write(f"  Plot ({ckpt.name}): {plot.relative_to(output_dir)}\n")
                    plot_locations.append(plot)

print(f"✅ Summary saved to: {summary_path}")

# Display final summary
print(f"\nüéâ Evaluation Complete!")
print(f"Mode: {EVALUATION_MODE}")
print(f"Checkpoints: {len(selected_checkpoints)}")
print(f"Duration: {total_duration:.1f} minutes" if 'total_duration' in locals() else "")
print(f"\nOutput directory: {output_dir}")

# Show plot locations clearly
if 'plot_locations' in locals() and plot_locations:
    print(f"\nüìä Generated Plots ({len(plot_locations)} total):")
    for plot in plot_locations[:5]:  # Show first 5
        print(f"  - {plot.relative_to(output_dir)}")
    if len(plot_locations) > 5:
        print(f"  ... and {len(plot_locations) - 5} more")
else:
    print(f"\n⚠️ No plots were found in the output directory.")

print(f"\nüìä Key Insights:")
if EVALUATION_MODE == "SINGLE_CHECKPOINT":
    print(f"- Evaluated {selected_checkpoints[0].name} with {selected_checkpoints[0].optimization_config.direction} objective")
    print(f"- Compare against baselines to see if training improved performance")
    print(f"- Check trajectory plots to see learning progress over time")
elif EVALUATION_MODE == "COMPARE_OBJECTIVES":
    print(f"- Compared {len(selected_checkpoints)} checkpoints with different optimization directions")
    print(f"- Minimization policies optimize for lower target values")
    print(f"- Maximization policies optimize for higher target values")
    print(f"- Check plots to see which direction performs better")
elif EVALUATION_MODE == "PHASE2_ACTIVE_LEARNING":
    print(f"- Evaluated GRPO policy with active learning surrogate")
    print(f"- Compared GRPO+Active vs Random+Active vs Oracle+Active")
    print(f"- Oracle baseline shows theoretical best performance")
    print(f"- Check if GRPO guidance improves structure learning (F1 scores)")

print(f"\nüí° Next steps:")
print(f"1. Review the trajectory plots for learning progress")
print(f"2. Check the summary plots for method comparison")
print(f"3. Read detailed results in {json_path.name}")
print(f"4. Consider running with FULL mode for more thorough evaluation")
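
The summary file ranks methods by `mean_final_value`: ascending for `MINIMIZE`, descending otherwise. The sketch below expresses the same ordering as a standalone function, assuming the metrics are plain dicts as produced by the export step. `rank_methods` is hypothetical. It uses a sort key instead of `reverse=True`, so methods with a missing, non-numeric or NaN value land last in either direction.

```python
import math
import numbers
from typing import Any, List, Mapping, Tuple


def rank_methods(
    method_metrics: Mapping[str, Mapping[str, Any]],
    direction: str,
    key: str = "mean_final_value",
) -> List[Tuple[str, Any]]:
    """Order methods best-first for "MINIMIZE" or "MAXIMIZE"; missing values go last."""
    minimizing = direction.upper() == "MINIMIZE"

    def sort_key(item: Tuple[str, Mapping[str, Any]]) -> Tuple[int, float]:
        value = item[1].get(key)
        # numbers.Real also covers NumPy scalar types such as np.float32.
        if not isinstance(value, numbers.Real) or math.isnan(value):
            return (1, 0.0)  # the leading 1 sorts these after every real value
        return (0, value if minimizing else -value)

    return [(name, m.get(key)) for name, m in sorted(method_metrics.items(), key=sort_key)]


example_metrics = {
    "grpo": {"mean_final_value": 1.0},
    "random": {"mean_final_value": 3.0},
    "broken": {},
}
for rank, (name, value) in enumerate(rank_methods(example_metrics, "MINIMIZE"), 1):
    print(f"{rank}. {name}: {value}")
```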

## Summary

**What we've accomplished:**
1. ✅ Loaded checkpoints with optimization direction metadata
2. ✅ Ran evaluation with proper metric handling
3. ✅ Generated visualizations that show min/max correctly
4. ✅ Exported comprehensive results

**Key improvements over original notebook:**
- Checkpoint-first approach - no need to run training
- Auto-detects optimization direction from metadata
- Correctly displays minimization vs maximization results
- All cells are independent and can be re-run
- No silent failures - explicit errors throughout

**Understanding the results:**
- For MINIMIZE checkpoints: Lower target values are better
- For MAXIMIZE checkpoints: Higher target values are better
- The plots show "(↓ better)" or "(↑ better)" to clarify
- PARENT_SCALE baseline uses minimization by default
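
The two direction rules above can be folded into one sign convention, where a positive number always means the method moved the target the right way. The sketch assumes `initial` and `final` are target values and that the direction strings are `"MINIMIZE"` and `"MAXIMIZE"`, as used elsewhere in this notebook. `signed_improvement` is hypothetical, and whether `mean_improvement` from `run_evaluation` already follows this convention is not stated here.

```python
def signed_improvement(initial: float, final: float, direction: str) -> float:
    """Change in the target, signed so that a positive result always means "better"."""
    if direction == "MINIMIZE":
        return initial - final  # a decrease in the target is an improvement
    if direction == "MAXIMIZE":
        return final - initial  # an increase in the target is an improvement
    raise ValueError(f"unknown optimization direction: {direction!r}")


# The same move (5.0 -> 2.0) is an improvement when minimizing and a regression when maximizing.
print(signed_improvement(5.0, 2.0, "MINIMIZE"), signed_improvement(5.0, 2.0, "MAXIMIZE"))
```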