# GRPO Training Pipeline - Modular Version

**Purpose**: Train GRPO policies with explicit optimization direction support and modular execution.

**Key Features**:
- ✅ **No silent failures** - explicit errors when things go wrong
- ✅ **Independent cells** - run any cell with checkpoint support
- ✅ **Optimization direction** - support both MINIMIZE and MAXIMIZE
- ✅ **Clean checkpoint management** - save/load with full metadata
- ✅ **Consistent with PARENT_SCALE** - correct handling of minimization objective

**Workflow**:
1. Configure training parameters and optimization direction
2. Load existing checkpoint OR initialize new training
3. Generate or load training SCMs
4. Train with appropriate reward signals
5. Save checkpoint with complete metadata
6. Quick validation with correct metrics
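The training loop maximizes a reward, so a minimization objective has to be flipped into a maximization problem and flipped back for reporting. The sketch below shows that sign convention with a hypothetical `DirectionConfig` class. It borrows the `is_minimizing`, `convert_from_maximization` and `format_improvement` names that appear later in this notebook, but its behavior is only an assumption inferred from the printed output (`-0.8106 (↓ better)`). It is not the actual `OptimizationConfig` implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DirectionConfig:
    """Hypothetical stand-in for OptimizationConfig (assumed semantics)."""
    direction: str = "MINIMIZE"  # "MINIMIZE" or "MAXIMIZE"

    @property
    def is_minimizing(self) -> bool:
        return self.direction == "MINIMIZE"

    def convert_to_maximization(self, target_value: float) -> float:
        # A maximizing learner sees -y when the goal is to minimize y.
        return -target_value if self.is_minimizing else target_value

    def convert_from_maximization(self, internal_value: float) -> float:
        # Negation is its own inverse, so the same flip maps back.
        return -internal_value if self.is_minimizing else internal_value

    def format_improvement(self, value: float) -> str:
        arrow = "↓ better" if self.is_minimizing else "↑ better"
        return f"{value:.4f} ({arrow})"


cfg = DirectionConfig("MINIMIZE")
print(cfg.format_improvement(cfg.convert_from_maximization(0.8106)))
# prints: -0.8106 (↓ better)
```

Keeping the flip in one place means the trainer never needs to know the direction. Only the conversion at the boundaries changes.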

## 1. Setup and Configuration

In [1]:
#!/usr/bin/env python3
"""
Cell 1: Import base components and configure environment

This cell can be run independently at any time.
"""

import sys
import os
from pathlib import Path
import logging
import json
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple

# Add project root to path
project_root = Path.cwd().parent if Path.cwd().name == "experiments" else Path.cwd()
sys.path.insert(0, str(project_root))

# Import base components
from scripts.notebooks.base_components import (
    NotebookError, CheckpointManager, SCMGenerator, 
    OptimizationConfig, CheckpointMetadata, validate_environment,
    format_results_summary
)
from scripts.notebooks.config_templates import (
    create_training_config, TRAINING_MODES, OBJECTIVE_CONFIGS,
    get_quick_minimize_config, get_quick_maximize_config
)

# Core imports
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as onp
import pyrsistent as pyr

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
)
logger = logging.getLogger(__name__)

# Validate environment
try:
    env_info = validate_environment()
    print("✅ Environment Setup Complete")
    print(f"📁 Project root: {project_root}")
    print(f"🔧 JAX devices: {env_info['jax_devices']}")
    print(f"🔧 JAX backend: {env_info['jax_backend']}")
    print(f"📅 Date: {env_info['timestamp']}")
except Exception as e:
    raise NotebookError(f"Environment validation failed: {e}")

# Initialize checkpoint manager
checkpoint_dir = project_root / "checkpoints" / "grpo_training"
checkpoint_manager = CheckpointManager(checkpoint_dir)
print(f"\n📁 Checkpoint directory: {checkpoint_dir}")

INFO:2025-07-23 10:07:34,087:jax._src.xla_bridge:749: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)


✅ Environment Setup Complete
📁 Project root: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt
🔧 JAX devices: [CpuDevice(id=0)]
🔧 JAX backend: cpu
📅 Date: 2025-07-23 10:07:34

📁 Checkpoint directory: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training


## 2. Training Configuration

In [2]:
"""
Cell 2: Configure training parameters

Choose training mode and optimization objective.
This cell defines what kind of training will be performed.
"""

# SELECT TRAINING CONFIGURATION
TRAINING_MODE = "QUICK"  # Options: "QUICK", "STANDARD", "FULL", "PRECISION"
OPTIMIZATION_OBJECTIVE = "TARGET_MINIMIZE"  # Options: "TARGET_MINIMIZE", "TARGET_MAXIMIZE", "STRUCTURE_FOCUSED", "BALANCED"
RANDOM_SEED = 42

# Optional: Load from existing checkpoint (set to None to train from scratch)
RESUME_FROM_CHECKPOINT = None  # Or path like "checkpoints/grpo_training/grpo_minimize_20250722_120000"

# Create configuration
try:
    config = create_training_config(
        mode=TRAINING_MODE,
        objective=OPTIMIZATION_OBJECTIVE,
        random_seed=RANDOM_SEED,
        checkpoint_dir=str(checkpoint_dir)
    )
    
    # Create optimization config
    optimization_config = OptimizationConfig(
        direction=config.optimization.direction,
        target_baseline=config.optimization.target_baseline
    )
    
    print("🎯 Training Configuration")
    print("=" * 50)
    print(f"Mode: {TRAINING_MODE} - {TRAINING_MODES[TRAINING_MODE].description}")
    print(f"Objective: {OPTIMIZATION_OBJECTIVE} - {OBJECTIVE_CONFIGS[OPTIMIZATION_OBJECTIVE].description}")
    print(f"Optimization: {optimization_config.direction}")
    print(f"Random seed: {RANDOM_SEED}")
    print(f"\nTraining parameters:")
    print(f"  Total episodes: {config.training.n_episodes}")
    print(f"  Episode length: {config.training.episode_length}")
    print(f"  Learning rate: {config.training.learning_rate}")
    print(f"  Number of SCMs: {config.experiment.scm_generation.num_scms}")
    print(f"\nReward weights:")
    for component, weight in config.training.reward_weights.items():
        print(f"  {component}: {weight}")
    
    if RESUME_FROM_CHECKPOINT:
        print(f"\n🔄 Will resume from: {RESUME_FROM_CHECKPOINT}")
    else:
        print(f"\n🚀 Will train from scratch")
        
except Exception as e:
    raise NotebookError(f"Failed to create configuration: {e}")

🎯 Training Configuration
Mode: QUICK - Fast testing and development
Objective: TARGET_MINIMIZE - Minimize target variable (like PARENT_SCALE)
Optimization: MINIMIZE
Random seed: 42

Training parameters:
  Total episodes: 96
  Episode length: 8
  Learning rate: 0.001
  Number of SCMs: 32

Reward weights:
  optimization: 0.8
  discovery: 0.1
  efficiency: 0.1

🚀 Will train from scratch
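Cell 2 combines two independent choices. The training mode sets the scale (episodes, episode length, learning rate, number of SCMs), and the objective sets the direction and reward weights. The sketch below shows that composition using the hypothetical names `MODE_PRESETS`, `OBJECTIVE_PRESETS` and `build_settings`. It is not the `create_training_config` implementation from `config_templates`, and it fills in only the QUICK / TARGET_MINIMIZE values printed above.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class TrainingSettings:
    """Hypothetical subset of the fields printed by Cell 2."""
    n_episodes: int
    episode_length: int
    learning_rate: float
    num_scms: int
    reward_weights: Dict[str, float]


# Hypothetical presets; the values are copied from the QUICK / TARGET_MINIMIZE output above.
MODE_PRESETS = {
    "QUICK": {"n_episodes": 96, "episode_length": 8, "learning_rate": 1e-3, "num_scms": 32},
}
OBJECTIVE_PRESETS = {
    "TARGET_MINIMIZE": {
        "direction": "MINIMIZE",
        "reward_weights": {"optimization": 0.8, "discovery": 0.1, "efficiency": 0.1},
    },
}


def build_settings(mode: str, objective: str) -> Tuple[TrainingSettings, str]:
    """Combine a mode preset (how long to train) with an objective preset (what to reward)."""
    # Unknown keys raise immediately instead of silently falling back to defaults.
    if mode not in MODE_PRESETS:
        raise KeyError(f"Unknown mode {mode!r}; options: {sorted(MODE_PRESETS)}")
    if objective not in OBJECTIVE_PRESETS:
        raise KeyError(f"Unknown objective {objective!r}; options: {sorted(OBJECTIVE_PRESETS)}")
    objective_preset = OBJECTIVE_PRESETS[objective]
    settings = TrainingSettings(
        reward_weights=dict(objective_preset["reward_weights"]),  # copy so presets stay untouched
        **MODE_PRESETS[mode],
    )
    return settings, objective_preset["direction"]


settings, direction = build_settings("QUICK", "TARGET_MINIMIZE")
print(direction, settings.n_episodes, settings.reward_weights)
```

Because mode and objective are orthogonal, each new preset on one axis works with every preset on the other.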


## 3. Initialize or Load Training State

In [3]:
"""
Cell 3: Initialize new training or load from checkpoint

This cell handles checkpoint loading and training initialization.
Can be run independently to load a specific checkpoint.
"""

# Import trainer and related modules
try:
    from causal_bayes_opt.training.enriched_trainer import EnrichedGRPOTrainer
    from causal_bayes_opt.surrogate.bootstrap import create_bootstrap_surrogate_features
    from causal_bayes_opt.surrogate.phase_manager import PhaseConfig, BootstrapConfig
except ImportError as e:
    raise NotebookError(f"Failed to import training modules: {e}")

# Production phase configuration
PRODUCTION_PHASE_CONFIG = PhaseConfig(
    bootstrap_steps=100,
    transition_steps=50,
    exploration_noise_start=0.5,
    exploration_noise_end=0.1,
    transition_schedule="linear"
)

PRODUCTION_BOOTSTRAP_CONFIG = BootstrapConfig(
    structure_encoding_dim=128,
    use_graph_distance=True,
    use_structural_priors=True,
    noise_schedule="exponential_decay",
    min_noise_factor=0.1
)

# Update config with production settings
config.surrogate_integration = {
    'enabled': True,
    'phase_config': {
        'bootstrap_steps': PRODUCTION_PHASE_CONFIG.bootstrap_steps,
        'transition_steps': PRODUCTION_PHASE_CONFIG.transition_steps,
        'exploration_noise_start': PRODUCTION_PHASE_CONFIG.exploration_noise_start,
        'exploration_noise_end': PRODUCTION_PHASE_CONFIG.exploration_noise_end,
        'transition_schedule': PRODUCTION_PHASE_CONFIG.transition_schedule
    },
    'bootstrap_config': {
        'structure_encoding_dim': PRODUCTION_BOOTSTRAP_CONFIG.structure_encoding_dim,
        'use_graph_distance': PRODUCTION_BOOTSTRAP_CONFIG.use_graph_distance,
        'use_structural_priors': PRODUCTION_BOOTSTRAP_CONFIG.use_structural_priors,
        'noise_schedule': PRODUCTION_BOOTSTRAP_CONFIG.noise_schedule,
        'min_noise_factor': PRODUCTION_BOOTSTRAP_CONFIG.min_noise_factor
    }
}

# Initialize trainer state
trainer = None
starting_episode = 0
checkpoint_metadata = None

if RESUME_FROM_CHECKPOINT:
    print(f"📥 Loading checkpoint: {RESUME_FROM_CHECKPOINT}")
    try:
        checkpoint_data, checkpoint_metadata = checkpoint_manager.load_checkpoint(RESUME_FROM_CHECKPOINT)
        
        # Validate compatibility
        if checkpoint_metadata.optimization_config.direction != optimization_config.direction:
            raise NotebookError(
                f"Optimization direction mismatch! "
                f"Checkpoint: {checkpoint_metadata.optimization_config.direction}, "
                f"Config: {optimization_config.direction}"
            )
        
        print(f"✅ Loaded checkpoint: {checkpoint_metadata.name}")
        print(f"  Training mode: {checkpoint_metadata.training_config.get('mode', 'unknown')}")
        print(f"  Optimization: {checkpoint_metadata.optimization_config.direction}")
        print(f"  Timestamp: {checkpoint_metadata.timestamp}")
        
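        # TODO: restore trainer state from checkpoint_data. Until then a fresh trainer is
        # created, so policy parameters are NOT resumed; warn instead of failing silently.
        logger.warning("Checkpoint parameters are not restored yet; initializing a fresh trainer.")
        trainer = EnrichedGRPOTrainer(config=config)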
        starting_episode = checkpoint_metadata.training_results.get('episodes_completed', 0)
        
        print(f"  Starting from episode: {starting_episode}")
        
    except Exception as e:
        raise NotebookError(f"Failed to load checkpoint: {e}")
else:
    print("🚀 Initializing new training")
    try:
        # Add optimization direction to config
        config.optimization = optimization_config.__dict__
        
        # Initialize trainer
        trainer = EnrichedGRPOTrainer(config=config)
        
        print("✅ Trainer initialized successfully")
        print(f"  Optimization: {optimization_config.direction}")
        print(f"  Surrogate integration: {'Enabled' if config.surrogate_integration.enabled else 'Disabled'}")
        
    except Exception as e:
        raise NotebookError(f"Failed to initialize trainer: {e}")

print("\n✅ Training state ready")

[2025-07-23 10:07:36,369][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 3 variables, 2 edges, target='X1'
[2025-07-23 10:07:36,369][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated fork SCM: 3 vars, 2 edges, target=X1
[2025-07-23 10:07:36,383][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 3 variables, 2 edges, target='X2'
[2025-07-23 10:07:36,383][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated chain SCM: 3 vars, 2 edges, target=X2
[2025-07-23 10:07:36,385][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 3 variables, 2 edges, target='X1'
[2025-07-23 10:07:36,385][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated collider SCM: 3 vars, 2 edges, target=X1


🚀 Initializing new training


[2025-07-23 10:07:36,456][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 3 variables, 2 edges, target='X1'
[2025-07-23 10:07:36,456][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated mixed SCM: 3 vars, 2 edges, target=X1
[2025-07-23 10:07:36,471][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 4 variables, 3 edges, target='X2'
[2025-07-23 10:07:36,471][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated fork SCM: 4 vars, 3 edges, target=X2
[2025-07-23 10:07:36,473][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 4 variables, 3 edges, target='X3'
[2025-07-23 10:07:36,474][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated chain SCM: 4 vars, 3 edges, target=X3
[2025-07-23 10:07:36,475][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 4 variables, 3 edges, target='X2'
[2025-07-23 10:07:36,475][causal_bayes_opt.experiments.variable_scm_factor

✅ Trainer initialized successfully
  Optimization: MINIMIZE
  Surrogate integration: Enabled

✅ Training state ready
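The phase configuration above defines a bootstrap phase of 100 steps, a transition phase of 50 steps, and a linear exploration-noise schedule from 0.5 to 0.1. This notebook does not show how `EnrichedGRPOTrainer` interprets these fields. The sketch below assumes two things: the learned surrogate is blended in linearly once bootstrap ends, and noise decays over both phases combined. `surrogate_weight` and `exploration_noise` are hypothetical helpers.

```python
def surrogate_weight(step: int, bootstrap_steps: int = 100, transition_steps: int = 50) -> float:
    """Share of the learned surrogate in the blended features (assumed semantics)."""
    if step < bootstrap_steps:
        return 0.0  # bootstrap phase: rely only on bootstrap structural features
    if transition_steps <= 0:
        return 1.0
    # Linear hand-over from bootstrap features to the learned surrogate.
    return min(1.0, (step - bootstrap_steps) / transition_steps)


def exploration_noise(step: int, start: float = 0.5, end: float = 0.1,
                      bootstrap_steps: int = 100, transition_steps: int = 50) -> float:
    """Linear decay from `start` to `end` over bootstrap + transition, then held at `end`."""
    horizon = bootstrap_steps + transition_steps
    if horizon <= 0:
        return end
    frac = min(max(step / horizon, 0.0), 1.0)
    return start + (end - start) * frac


for step in (0, 100, 125, 150, 300):
    print(step, surrogate_weight(step), round(exploration_noise(step), 3))
```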


## 4. Generate or Load Training SCMs

In [4]:
"""
Cell 4: Generate training SCMs

This cell generates the SCMs for training.
Can be run independently to regenerate SCMs.
"""

print("🔬 Generating Training SCMs")
print("=" * 50)

# Initialize SCM generator
scm_generator = SCMGenerator()

# Generate SCMs
try:
    training_scms, scm_metadata = scm_generator.generate_balanced_scms(
        num_scms=config.experiment.scm_generation.num_scms,
        variable_range=tuple(config.experiment.scm_generation.variable_range),
        structure_types=config.experiment.scm_generation.structure_types,
        seed=RANDOM_SEED
    )
    
    print(f"\n✅ Generated {len(training_scms)} training SCMs")
    
    # Analyze distribution
    distribution = scm_generator._summarize_distribution(scm_metadata)
    print(f"\n📊 SCM Distribution:")
    print(f"  Structure types: {distribution['structure_types']}")
    print(f"  Variable counts: {distribution['variable_counts']}")
    
    # Calculate total episodes
    episodes_per_scm = config.training.n_episodes // len(training_scms)
    total_episodes = len(training_scms) * episodes_per_scm
    print(f"\n📈 Training schedule:")
    print(f"  Episodes per SCM: {episodes_per_scm}")
    print(f"  Total episodes: {total_episodes}")
    
    # Store in config for trainer
    config.training.n_episodes = total_episodes
    
except Exception as e:
    raise NotebookError(f"Failed to generate SCMs: {e}")

# Save SCM metadata for reproducibility (the SCM objects themselves are not serialized)
scm_save_path = checkpoint_dir / "training_scms" / f"scms_{TRAINING_MODE}_{RANDOM_SEED}.json"
scm_save_path.parent.mkdir(parents=True, exist_ok=True)

# Save metadata only (SCMs are too complex to serialize directly)
with open(scm_save_path, 'w') as f:
    json.dump({
        'metadata': scm_metadata,
        'config': {
            'num_scms': len(training_scms),
            'seed': RANDOM_SEED,
            'variable_range': list(config.experiment.scm_generation.variable_range),
            'structure_types': list(config.experiment.scm_generation.structure_types)
        }
    }, f, indent=2)

print(f"\n💾 Saved SCM metadata to: {scm_save_path}")

[2025-07-23 10:07:41,778][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 3 variables, 2 edges, target='X1'
[2025-07-23 10:07:41,779][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated fork SCM: 3 vars, 2 edges, target=X1
[2025-07-23 10:07:41,780][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 3 variables, 2 edges, target='X1'
[2025-07-23 10:07:41,781][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated fork SCM: 3 vars, 2 edges, target=X1
[2025-07-23 10:07:41,782][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 4 variables, 3 edges, target='X2'
[2025-07-23 10:07:41,782][causal_bayes_opt.experiments.variable_scm_factory][INFO] - Generated fork SCM: 4 vars, 3 edges, target=X2
[2025-07-23 10:07:41,784][causal_bayes_opt.experiments.test_scms][INFO] - Created linear SCM with 4 variables, 3 edges, target='X2'
[2025-07-23 10:07:41,784][causal_bayes_opt.experiments.variable_scm_factory]

🔬 Generating Training SCMs

✅ Generated 32 training SCMs

📊 SCM Distribution:
  Structure types: {'fork': 8, 'chain': 8, 'collider': 8, 'mixed': 8}
  Variable counts: {3: 8, 4: 8, 5: 8, 6: 8}

📈 Training schedule:
  Episodes per SCM: 3
  Total episodes: 96

💾 Saved SCM metadata to: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/training_scms/scms_QUICK_42.json
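The distribution printed above has 8 SCMs per structure type and 8 per variable count. The logs also show a generation order of a 3-variable fork twice, followed by a 4-variable fork. Both facts fit an even split over a grid of structure type by variable count. The sketch below reproduces that split, along with the episode arithmetic from Cell 4. `balanced_scm_specs` is a hypothetical helper that returns specifications only, not SCM objects. The inclusive `(3, 6)` variable range is inferred from the printed counts.

```python
from collections import Counter
from itertools import product
from typing import Dict, List, Sequence, Tuple


def balanced_scm_specs(num_scms: int, variable_range: Tuple[int, int],
                       structure_types: Sequence[str]) -> List[Dict]:
    """Spread SCM specs evenly over (structure type, variable count) cells."""
    lo, hi = variable_range
    counts = list(range(lo, hi + 1))  # assumes an inclusive range
    cells = list(product(structure_types, counts))
    if num_scms % len(cells) != 0:
        raise ValueError(f"{num_scms} SCMs cannot be split evenly over {len(cells)} cells")
    per_cell = num_scms // len(cells)
    return [
        {"structure_type": s, "n_variables": n}
        for s, n in cells
        for _ in range(per_cell)
    ]


specs = balanced_scm_specs(32, (3, 6), ["fork", "chain", "collider", "mixed"])
print(Counter(s["structure_type"] for s in specs))
print(Counter(s["n_variables"] for s in specs))

# Episode schedule as in Cell 4: integer division, then recompute the total.
n_episodes = 96
episodes_per_scm = n_episodes // len(specs)     # 96 // 32 = 3
total_episodes = episodes_per_scm * len(specs)  # 3 * 32 = 96
```

Because the schedule uses integer division, any episodes that do not divide evenly across SCMs are dropped. For example, 100 requested episodes over 32 SCMs would give 3 per SCM and 96 in total.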


## 5. Train GRPO Policy

In [5]:
"""
Cell 5: Train GRPO policy with correct optimization direction

This cell performs the actual training.
Progress is saved periodically for resumption.
"""

import time

print("🚀 Starting GRPO Policy Training")
print("=" * 70)
print(f"🔧 Training mode: {TRAINING_MODE}")
print(f"🎯 Optimization: {optimization_config.direction}")
print(f"📊 Total episodes: {total_episodes}")
print(f"⚖️ Reward weights: {config.training.reward_weights}")
print(f"✅ Surrogate integration: {'ACTIVE' if config.surrogate_integration.enabled else 'DISABLED'}")
print("=" * 70)

# Configure trainer for optimization direction
if hasattr(trainer, 'optimization_config'):
    trainer.optimization_config = optimization_config
else:
    # Inject optimization config
    trainer.config.optimization = optimization_config.__dict__

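# Set training SCMs (warn loudly if the trainer cannot accept them)
if hasattr(trainer, 'set_training_scms'):
    trainer.set_training_scms(training_scms)
else:
    logger.warning("Trainer has no set_training_scms(); the SCMs generated in Cell 4 will not be used.")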

# Training loop with explicit error handling
training_start_time = time.time()
training_metrics = {}

try:
    print("\n🏃 Starting Training Loop...")
    
    # Run training
    training_metrics = trainer.train()
    
    training_end_time = time.time()
    training_duration = training_end_time - training_start_time
    
    print(f"\n✅ Training completed!")
    print(f"⏱️ Training time: {training_duration/60:.1f} minutes")
    
    # Extract performance metrics
    performance = training_metrics.get('performance', {})
    final_reward = performance.get('final_reward', 0.0)
    
    # Convert reward to actual target value if minimizing
    if optimization_config.is_minimizing:
        final_target_value = optimization_config.convert_from_maximization(final_reward)
        print(f"\n📊 Final Results:")
        print(f"  Final reward (internal): {final_reward:.4f}")
        print(f"  Final target value: {optimization_config.format_improvement(final_target_value)}")
    else:
        print(f"\n📊 Final Results:")
        print(f"  Final target value: {optimization_config.format_improvement(final_reward)}")
    
    # Store optimization direction in metrics
    training_metrics['optimization_direction'] = optimization_config.direction
    training_metrics['duration_minutes'] = training_duration / 60
    
except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user")
    training_metrics['interrupted'] = True
    training_metrics['duration_minutes'] = (time.time() - training_start_time) / 60
    
except Exception as e:
    print(f"\n❌ Training failed with error: {e}")
    import traceback
    traceback.print_exc()
    raise NotebookError(f"Training failed: {e}")

# Get checkpoint path
checkpoint_path = training_metrics.get('checkpoint_path', checkpoint_dir / "grpo_final")
print(f"\n📁 Checkpoint saved to: {checkpoint_path}")

[2025-07-23 10:07:41,846][causal_bayes_opt.training.enriched_trainer][INFO] - Starting enriched GRPO training


🚀 Starting GRPO Policy Training
🔧 Training mode: QUICK
🎯 Optimization: MINIMIZE
📊 Total episodes: 96
⚖️ Reward weights: {'optimization': 0.8, 'discovery': 0.1, 'efficiency': 0.1}
✅ Surrogate integration: ACTIVE

🏃 Starting Training Loop...


[2025-07-23 10:08:05,610][causal_bayes_opt.training.enriched_trainer][INFO] - ✅ Parameters changed - norm delta: 0.000023709518
[2025-07-23 10:08:05,755][causal_bayes_opt.training.enriched_trainer][INFO] - Episode 0: reward=0.617, intervention_rate=1.000, scm=fork_3var, F1=0.000, P(Parents)=0.000, SHD=2
[2025-07-23 10:08:05,861][causal_bayes_opt.training.enriched_trainer][INFO] - 🔍 Per-Variable Encoding - Policy Output (call 10):
[2025-07-23 10:08:05,862][causal_bayes_opt.training.enriched_trainer][INFO] -   Variable logits: [-7.99504128e-02 -1.00000000e+09]
[2025-07-23 10:08:05,862][causal_bayes_opt.training.enriched_trainer][INFO] -   Variables: ['X2', 'X1', 'X0'], Target: X1
[2025-07-23 10:08:05,863][causal_bayes_opt.training.enriched_trainer][INFO] -   Target variable 'X1' at index 1, logit: -1000000000.0
[2025-07-23 10:08:05,883][causal_bayes_opt.training.enriched_trainer][INFO] -   Variable selection:
[2025-07-23 10:08:05,883][causal_bayes_opt.training.enriched_trainer][INFO] -  


✅ Training completed!
⏱️ Training time: 5.2 minutes

📊 Final Results:
  Final reward (internal): 0.8106
  Final target value: -0.8106 (↓ better)

📁 Checkpoint saved to: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/enriched_grpo_final
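Cell 5 treats `trainer.train()` as a black box. The sketch below shows two ingredients that GRPO-style training generally relies on. The first is a scalar reward, formed here as a weighted sum of the three components listed in `reward_weights`. The second is group-relative advantages: in the standard GRPO formulation, each sampled action's reward is standardized against the mean and standard deviation of its own group, instead of against a learned value baseline. It is an assumption that `EnrichedGRPOTrainer` combines rewards exactly this way. Under a MINIMIZE objective, the optimization component would presumably be computed from the negated target value so that larger is still better. The function names and the sample group are hypothetical.

```python
from statistics import fmean, pstdev
from typing import Dict, List

# Reward weights printed for the QUICK / TARGET_MINIMIZE run.
REWARD_WEIGHTS = {"optimization": 0.8, "discovery": 0.1, "efficiency": 0.1}


def composite_reward(components: Dict[str, float],
                     weights: Dict[str, float] = REWARD_WEIGHTS) -> float:
    """Weighted sum of reward components (assumed combination rule)."""
    # A missing component raises KeyError instead of silently counting as zero.
    return sum(weight * components[name] for name, weight in weights.items())


def group_relative_advantages(rewards: List[float], eps: float = 1e-8) -> List[float]:
    """GRPO-style advantages: standardize each reward against its own sampled group."""
    if len(rewards) < 2:
        raise ValueError("need at least two samples per group")
    mean = fmean(rewards)
    std = pstdev(rewards)  # population standard deviation of the group
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical group of four interventions sampled for the same state.
group = [
    {"optimization": 0.9, "discovery": 0.2, "efficiency": 1.0},
    {"optimization": 0.4, "discovery": 0.6, "efficiency": 1.0},
    {"optimization": 0.1, "discovery": 0.0, "efficiency": 0.5},
    {"optimization": 0.7, "discovery": 0.3, "efficiency": 0.0},
]
rewards = [composite_reward(c) for c in group]
advantages = group_relative_advantages(rewards)
print([round(a, 3) for a in advantages])
```

Advantages within a group sum to roughly zero. The policy is therefore pushed toward the better-than-average interventions in each group, whatever the absolute reward scale.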


## 6. Save Checkpoint with Metadata

In [ ]:
"""
Cell 6: Save checkpoint with complete metadata

This cell saves the trained model with all necessary metadata
for later evaluation and comparison.
"""
from omegaconf import OmegaConf
print("💾 Saving Checkpoint with Metadata")
print("=" * 50)

# Generate checkpoint name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
opt_direction = optimization_config.direction.lower()
checkpoint_name = f"grpo_{TRAINING_MODE.lower()}_{opt_direction}_{timestamp}"

# Extract actual model parameters from trainer
try:
    checkpoint_data = {
        'policy_params': trainer.policy_params,
        'policy_config': {
            'architecture': OmegaConf.to_container(config.training.architecture),
            'state_config': OmegaConf.to_container(config.training.state_config),
            'grpo_config': OmegaConf.to_container(config.training.grpo_config)
        },
        'training_metrics': training_metrics,
        'optimization_config': optimization_config.__dict__
    }
    print(f"✅ Extracted model parameters from trainer")
except Exception as e:
    print(f"⚠️ Warning: Could not extract model parameters: {e}")
    checkpoint_data = None

# Create checkpoint metadata
metadata = CheckpointMetadata(
    name=checkpoint_name,
    path=checkpoint_dir / checkpoint_name,
    optimization_config=optimization_config,
    training_config={
        'mode': TRAINING_MODE,
        'objective': OPTIMIZATION_OBJECTIVE,
        'config': OmegaConf.to_container(config.training),
        'reward_weights': dict(config.training.reward_weights),
        'scm_config': OmegaConf.to_container(config.experiment.scm_generation)
    },
    training_results={
        'duration_minutes': training_metrics.get('duration_minutes', 0),
        'final_performance': performance if 'performance' in locals() else {},
        'episodes_completed': total_episodes if not training_metrics.get('interrupted', False) else starting_episode,
        'success': not training_metrics.get('interrupted', False)
    },
    timestamp=timestamp
)

# Save checkpoint
try:
    final_checkpoint_path = checkpoint_manager.save_checkpoint(
        checkpoint_data=checkpoint_data,
        metadata=metadata,
        checkpoint_name=checkpoint_name
    )
    
    print(f"\n✅ Checkpoint saved successfully!")
    print(f"📁 Location: {final_checkpoint_path}")
    print(f"📋 Name: {checkpoint_name}")
    if checkpoint_data is not None:
        print(f"📋 Includes: Model parameters, config, and metrics")
    else:
        print(f"⚠️ Warning: Only metadata saved (no model parameters)")
    
    # Display summary
    print(f"\n📊 Training Summary:")
    print(f"  Mode: {TRAINING_MODE}")
    print(f"  Optimization: {optimization_config.direction}")
    print(f"  Duration: {metadata.training_results['duration_minutes']:.1f} minutes")
    print(f"  Episodes: {metadata.training_results['episodes_completed']}")
    print(f"  Success: {'Yes' if metadata.training_results['success'] else 'No (interrupted)'}")
    
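    # Read from the metadata so this also works when training was interrupted
    # (in that case Cell 5 never defined `performance`).
    final_performance = metadata.training_results['final_performance']
    if 'final_reward' in final_performance:
        final_value = final_performance['final_reward']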
        if optimization_config.is_minimizing:
            final_value = optimization_config.convert_from_maximization(final_value)
        print(f"  Final target value: {optimization_config.format_improvement(final_value)}")
    
except Exception as e:
    raise NotebookError(f"Failed to save checkpoint: {e}")

# Store checkpoint name for easy access
TRAINED_CHECKPOINT = checkpoint_name
print(f"\n💡 Checkpoint name stored in variable: TRAINED_CHECKPOINT")
print(f"   Use this in evaluation notebook: '{TRAINED_CHECKPOINT}'")
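`CheckpointManager` hides the on-disk format. Purely as an illustration, the sketch below does three things:

- stores parameters with `pickle` and metadata as JSON in a per-run directory;
- reuses the naming pattern from Cell 6;
- validates the optimization direction on load, as Cell 3 does.

The file names (`params.pkl`, `metadata.json`) and helper functions are hypothetical and are not the actual `CheckpointManager` layout.

```python
import json
import pickle
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Tuple


def make_checkpoint_name(mode: str, direction: str, when: datetime) -> str:
    # Same pattern as Cell 6: grpo_<mode>_<direction>_<YYYYmmdd_HHMMSS>
    return f"grpo_{mode.lower()}_{direction.lower()}_{when.strftime('%Y%m%d_%H%M%S')}"


def save_checkpoint(root: Path, name: str, params: Any, metadata: Dict[str, Any]) -> Path:
    path = root / name
    path.mkdir(parents=True, exist_ok=False)  # refuse to overwrite an existing checkpoint
    with open(path / "params.pkl", "wb") as f:
        pickle.dump(params, f)
    with open(path / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)
    return path


def load_checkpoint(path: Path, expected_direction: str) -> Tuple[Any, Dict[str, Any]]:
    with open(path / "metadata.json") as f:
        metadata = json.load(f)
    # Check compatibility before touching the parameters, as Cell 3 does.
    if metadata["optimization_direction"] != expected_direction:
        raise ValueError(
            f"Optimization direction mismatch: checkpoint={metadata['optimization_direction']}, "
            f"expected={expected_direction}"
        )
    with open(path / "params.pkl", "rb") as f:
        params = pickle.load(f)  # only unpickle files you trust
    return params, metadata


with tempfile.TemporaryDirectory() as tmp:
    name = make_checkpoint_name("QUICK", "MINIMIZE", datetime(2025, 7, 23, 10, 12, 52))
    saved = save_checkpoint(Path(tmp), name, {"w": [0.1, 0.2]},
                            {"mode": "QUICK", "optimization_direction": "MINIMIZE"})
    params, meta = load_checkpoint(saved, expected_direction="MINIMIZE")
    try:
        load_checkpoint(saved, expected_direction="MAXIMIZE")
        mismatch_rejected = False
    except ValueError:
        mismatch_rejected = True

print(name, params, mismatch_rejected)
```

Keeping metadata in plain JSON lets you inspect a checkpoint's direction and mode without unpickling anything.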

## 7. Quick Validation

In [7]:
"""
Cell 7: Quick validation of trained policy

Test the trained policy on a few SCMs to ensure it's working correctly.
It needs a trainer (from Cell 3 or Cell 5) and the SCMs generated in Cell 4.
"""

print("🧪 Quick Policy Validation")
print("=" * 50)

# Ensure we have a trainer
if trainer is None:
    raise NotebookError("No trainer available. Run training or load checkpoint first.")

# Test on a few SCMs
test_key = random.PRNGKey(999)
n_test = min(3, len(training_scms))

print(f"\nTesting on {n_test} SCMs:")
print("-" * 40)

for i in range(n_test):
    test_scm = training_scms[i]
    test_meta = scm_metadata[i]
    
    try:
        # Create test state
        test_state = trainer._create_tensor_backed_state(test_scm, 0, 0.0)
        enriched_input = trainer.state_converter.convert_state_to_enriched_input(test_state)
        
        # Get policy output
        test_key, subkey = random.split(test_key)
        target_idx = test_meta['variables'].index(test_meta['target'])
        
        policy_output = trainer.policy_fn.apply(
            trainer.policy_params, subkey, enriched_input, target_idx, False
        )
        
        # Convert to action
        action = trainer._policy_output_to_action(policy_output, test_meta['variables'], test_meta['target'])
        selected_var_idx, intervention_value = action
        
        print(f"\n{i+1}. {test_meta['structure_type']} ({test_meta['n_variables']}v):")
        print(f"   Target: {test_meta['target']}")
        print(f"   Action: {test_meta['variables'][selected_var_idx]}={intervention_value:.3f}")
        print(f"   Magnitude: {abs(intervention_value):.3f}")
        
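        # Report the sign of the intervention value. This alone does not show the effect on the
        # target, which also depends on the sign of the causal effect from the intervened variable.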
        if optimization_config.is_minimizing:
            print(f"   Direction: {'Decrease' if intervention_value < 0 else 'Increase'} (minimizing target)")
        else:
            print(f"   Direction: {'Increase' if intervention_value > 0 else 'Decrease'} (maximizing target)")
            
    except Exception as e:
        print(f"\n{i+1}. Failed to test: {e}")
        logger.warning(f"Validation failed for SCM {i}: {e}")

print("\n✅ Validation complete!")
print(f"\n🎉 Training Complete!")
print(f"📋 Checkpoint: '{TRAINED_CHECKPOINT}'")
print(f"🔧 Optimization: {optimization_config.direction}")
print(f"\n💡 Next steps:")
print(f"   1. Use evaluation notebook to compare against baselines")
print(f"   2. Checkpoint name: '{TRAINED_CHECKPOINT}'")

[2025-07-23 10:12:52,167][causal_bayes_opt.training.enriched_trainer][INFO] - 🔍 Per-Variable Encoding - Policy Output (call 770):
[2025-07-23 10:12:52,168][causal_bayes_opt.training.enriched_trainer][INFO] -   Variable logits: [ 4.25992218e-03 -1.00000000e+09]
[2025-07-23 10:12:52,168][causal_bayes_opt.training.enriched_trainer][INFO] -   Variables: ['X2', 'X1', 'X0'], Target: X1
[2025-07-23 10:12:52,168][causal_bayes_opt.training.enriched_trainer][INFO] -   Target variable 'X1' at index 1, logit: -1000000000.0
[2025-07-23 10:12:52,171][causal_bayes_opt.training.enriched_trainer][INFO] -   Variable selection:
[2025-07-23 10:12:52,171][causal_bayes_opt.training.enriched_trainer][INFO] -     Temperature: 1.00
[2025-07-23 10:12:52,171][causal_bayes_opt.training.enriched_trainer][INFO] -     Probabilities: [1. 0.]
[2025-07-23 10:12:52,172][causal_bayes_opt.training.enriched_trainer][INFO] -     Selected: X2 (index 0)
[2025-07-23 10:12:52,172][causal_bayes_opt.training.enriched_trainer][INF

🧪 Quick Policy Validation

Testing on 3 SCMs:
----------------------------------------

1. fork (3v):
   Target: X1
   Action: X2=-11.259
   Magnitude: 11.259
   Direction: Decrease (minimizing target)

2. fork (3v):
   Target: X1
   Action: X2=-7.179
   Magnitude: 7.179
   Direction: Decrease (minimizing target)

3. fork (4v):
   Target: X2
   Action: X1=-10.808
   Magnitude: 10.808
   Direction: Decrease (minimizing target)

✅ Validation complete!

🎉 Training Complete!
📋 Checkpoint: 'grpo_quick_minimize_20250723_101252'
🔧 Optimization: MINIMIZE

💡 Next steps:
   1. Use evaluation notebook to compare against baselines
   2. Checkpoint name: 'grpo_quick_minimize_20250723_101252'
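Cell 7 calls the private `trainer._policy_output_to_action`. The logs show the target's logit set to `-1e9`, which gives it zero probability (`Probabilities: [1. 0.]`). If variable selection is a temperature-scaled softmax with the target masked out, it can be sketched as below. `masked_probabilities` and `select_variable` are hypothetical helpers. They cover only the choice of variable, not the intervention value.

```python
import math
import random
from typing import List, Sequence

MASK_LOGIT = -1e9  # same magnitude as the target logit printed in the logs


def masked_probabilities(logits: Sequence[float], target_idx: int,
                         temperature: float = 1.0) -> List[float]:
    """Softmax over variables with the target masked out (assumed behavior)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [MASK_LOGIT if i == target_idx else logit / temperature
              for i, logit in enumerate(logits)]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in scaled]  # exp(-1e9) underflows to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def select_variable(logits: Sequence[float], target_idx: int, rng: random.Random,
                    temperature: float = 1.0, greedy: bool = False) -> int:
    probs = masked_probabilities(logits, target_idx, temperature)
    if greedy:
        return max(range(len(probs)), key=probs.__getitem__)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Hypothetical logits for variables ['X2', 'X1', 'X0'] with target X1 at index 1.
logits = [0.5, 2.0, -0.3]
rng = random.Random(0)
print(masked_probabilities(logits, target_idx=1))
print(select_variable(logits, target_idx=1, rng=rng))
```

Masking before the softmax, rather than zeroing probabilities afterwards, keeps the remaining probabilities normalized with no extra step.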


## Summary and Next Steps

**What we've accomplished:**
1. ✅ Configured training with explicit optimization direction
2. ✅ Trained GRPO policy with correct reward signals
3. ✅ Saved checkpoint with complete metadata
4. ✅ Validated policy behavior

**Key improvements over original notebook:**
- No silent failures - explicit errors when things go wrong
- Independent cells - can load checkpoint metadata and resume the episode count (restoring policy parameters is still a TODO in Cell 3)
- Optimization direction support - correctly handles minimization like PARENT_SCALE
- Clean checkpoint management - all metadata preserved

**Next steps:**
1. Use `grpo_evaluation_modular.ipynb` to evaluate this checkpoint
2. Compare minimization vs maximization policies
3. Analyze how optimization direction affects performance