# GRPO Training Pipeline - Modular Version

**Purpose**: Train GRPO policies with explicit optimization direction support and modular execution.

**Key Features**:
- ✅ **No silent failures** - explicit errors when things go wrong
- ✅ **Independent cells** - run any cell with checkpoint support
- ✅ **Optimization direction** - support both MINIMIZE and MAXIMIZE
- ✅ **Clean checkpoint management** - save/load with full metadata
- ✅ **Consistent with PARENT_SCALE** - correct handling of minimization objective
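Minimization support can be pictured as a sign flip around a policy that always maximizes reward. The Cell 5 output is consistent with this reading: it reports an internal reward of 0.8789 next to a target value of -0.8789. The sketch below shows that idea under the assumption that the conversion is plain negation. `DirectionSketch` is a hypothetical stand-in, not the `OptimizationConfig` class from `base_components`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DirectionSketch:
    """Hypothetical stand-in for the direction logic in OptimizationConfig."""

    direction: str  # "MINIMIZE" or "MAXIMIZE"

    def __post_init__(self) -> None:
        # Fail loudly on typos instead of silently treating them as MAXIMIZE.
        if self.direction not in ("MINIMIZE", "MAXIMIZE"):
            raise ValueError(f"unknown direction: {self.direction!r}")

    @property
    def is_minimizing(self) -> bool:
        return self.direction == "MINIMIZE"

    def to_maximization(self, target_value: float) -> float:
        # The policy always maximizes reward, so a target to be minimized is negated.
        return -target_value if self.is_minimizing else target_value

    def from_maximization(self, reward: float) -> float:
        # Negation is its own inverse, so mapping back is the same flip.
        return -reward if self.is_minimizing else reward

    def format_value(self, value: float) -> str:
        arrow = "↓ better" if self.is_minimizing else "↑ better"
        return f"{value:.4f} ({arrow})"


minimize = DirectionSketch("MINIMIZE")
print(minimize.format_value(minimize.from_maximization(0.8789)))  # -0.8789 (↓ better)
```

Because the flip is applied symmetrically, a checkpoint trained with one direction and evaluated with the other would silently invert its rankings. This is why Cell 3 refuses to resume across a direction mismatch.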

**Workflow**:
1. Configure training parameters and optimization direction
2. Load existing checkpoint OR initialize new training
3. Generate or load training SCMs
4. Train with appropriate reward signals
5. Save checkpoint with complete metadata
6. Quick validation with correct metrics

## 1. Setup and Configuration

In [1]:
#!/usr/bin/env python3
"""
Cell 1: Import base components and configure environment

This cell can be run independently at any time.
"""

import sys
import os
from pathlib import Path
import logging
import json
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple

# Add project root to path
project_root = Path.cwd().parent if Path.cwd().name == "experiments" else Path.cwd()
sys.path.insert(0, str(project_root))

# Import base components
from scripts.notebooks.base_components import (
    NotebookError, CheckpointManager, SCMGenerator, 
    OptimizationConfig, CheckpointMetadata, validate_environment,
    format_results_summary
)
from scripts.notebooks.config_templates import (
    create_training_config, TRAINING_MODES, OBJECTIVE_CONFIGS,
    get_quick_minimize_config, get_quick_maximize_config
)

# Import fixed configuration
from src.causal_bayes_opt.training.grpo_fixed_config import (
    create_grpo_config_with_fixes,
    create_bootstrap_phase_config,
    create_bootstrap_config,
    validate_fixed_config
)

# Core imports
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as onp
import pyrsistent as pyr

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
)
logger = logging.getLogger(__name__)

# Validate environment
try:
    env_info = validate_environment()
    print("✅ Environment Setup Complete")
    print(f"📁 Project root: {project_root}")
    print(f"🔧 JAX devices: {env_info['jax_devices']}")
    print(f"🔧 JAX backend: {env_info['jax_backend']}")
    print(f"📅 Date: {env_info['timestamp']}")
    print("\n🔧 Using GRPO configuration with collapse prevention fixes:")
    print("  - Global standardization for state enrichment")
    print("  - Increased entropy coefficient (0.1)")
    print("  - Bootstrap surrogate with structural priors")
    print("  - Adaptive reward system")
except Exception as e:
    raise NotebookError(f"Environment validation failed: {e}")

# Initialize checkpoint manager
checkpoint_dir = project_root / "checkpoints" / "grpo_training"
checkpoint_manager = CheckpointManager(checkpoint_dir)
print(f"\n📁 Checkpoint directory: {checkpoint_dir}")

✅ Environment Setup Complete
📁 Project root: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt
🔧 JAX devices: [CpuDevice(id=0)]
🔧 JAX backend: cpu
📅 Date: 2025-07-28 15:47:22

🔧 Using GRPO configuration with collapse prevention fixes:
  - Global standardization for state enrichment
  - Increased entropy coefficient (0.1)
  - Bootstrap surrogate with structural priors
  - Adaptive reward system

📁 Checkpoint directory: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training
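The first fix listed above is global standardization for state enrichment. The notebook does not show how it is computed. One plausible reading, stated here as an assumption, is this: z-scoring each variable separately erases differences in scale between variables, for example that the target is much larger than a parent. Pooling one mean and standard deviation over all variables keeps those relative magnitudes. The helpers below are hypothetical and written in plain Python to make the contrast concrete.

```python
from statistics import fmean, pstdev
from typing import List

Matrix = List[List[float]]  # rows = samples, columns = variables


def standardize_per_variable(rows: Matrix, eps: float = 1e-8) -> Matrix:
    """Z-score each variable (column) with its own mean and std."""
    stats = [(fmean(col), pstdev(col)) for col in zip(*rows)]
    return [[(x - m) / (s + eps) for x, (m, s) in zip(row, stats)] for row in rows]


def standardize_global(rows: Matrix, eps: float = 1e-8) -> Matrix:
    """Z-score every entry with one mean and std pooled over all variables."""
    flat = [x for row in rows for x in row]
    m, s = fmean(flat), pstdev(flat)
    return [[(x - m) / (s + eps) for x in row] for row in rows]


values = [[1.0, 10.0], [3.0, 30.0]]
print(standardize_per_variable(values))  # both columns map to roughly [-1, 1]
print(standardize_global(values))        # second column stays larger than the first
```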


## 2. Training Configuration

In [2]:
"""
Cell 2: Configure training parameters

Choose training mode and optimization objective.
This cell defines what kind of training will be performed.
"""

# SELECT TRAINING CONFIGURATION
TRAINING_MODE = "FULL"  # Options: "QUICK", "STANDARD", "FULL", "PRECISION"
OPTIMIZATION_OBJECTIVE = "TARGET_MINIMIZE"  # Options: "TARGET_MINIMIZE", "TARGET_MAXIMIZE", "STRUCTURE_FOCUSED", "BALANCED"
RANDOM_SEED = 42

# Optional: Load from existing checkpoint (set to None to train from scratch)
RESUME_FROM_CHECKPOINT = None  # Or path like "checkpoints/grpo_training/grpo_minimize_20250722_120000"

# Create configuration
try:
    config = create_training_config(
        mode=TRAINING_MODE,
        objective=OPTIMIZATION_OBJECTIVE,
        random_seed=RANDOM_SEED,
        checkpoint_dir=str(checkpoint_dir)
    )
    
    # Create optimization config
    optimization_config = OptimizationConfig(
        direction=config.optimization.direction,
        target_baseline=config.optimization.target_baseline
    )
    
    print("🎯 Training Configuration")
    print("=" * 50)
    print(f"Mode: {TRAINING_MODE} - {TRAINING_MODES[TRAINING_MODE].description}")
    print(f"Objective: {OPTIMIZATION_OBJECTIVE} - {OBJECTIVE_CONFIGS[OPTIMIZATION_OBJECTIVE].description}")
    print(f"Optimization: {optimization_config.direction}")
    print(f"Random seed: {RANDOM_SEED}")
    print(f"\nTraining parameters:")
    print(f"  Total episodes: {config.training.n_episodes}")
    print(f"  Episode length: {config.training.episode_length}")
    print(f"  Learning rate: {config.training.learning_rate}")
    print(f"  Number of SCMs: {config.experiment.scm_generation.num_scms}")  # Fixed path
    print(f"\nReward weights:")
    for component, weight in config.training.reward_weights.items():
        print(f"  {component}: {weight}")
    
    if RESUME_FROM_CHECKPOINT:
        print(f"\n🔄 Will resume from: {RESUME_FROM_CHECKPOINT}")
    else:
        print(f"\n🚀 Will train from scratch")
        
except Exception as e:
    raise NotebookError(f"Failed to create configuration: {e}")

🎯 Training Configuration
Mode: FULL - Production-quality training
Objective: TARGET_MINIMIZE - Minimize target variable (like PARENT_SCALE)
Optimization: MINIMIZE
Random seed: 42

Training parameters:
  Total episodes: 512
  Episode length: 12
  Learning rate: 0.001
  Number of SCMs: 64

Reward weights:
  optimization: 0.8
  discovery: 0.1
  efficiency: 0.1

🚀 Will train from scratch
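The reward weights printed above (optimization 0.8, discovery 0.1, efficiency 0.1) read naturally as coefficients of a weighted sum of per-step reward components. The sketch below assumes exactly that. `combine_rewards` and its sum-to-one check are hypothetical, not the trainer's API. The check happens to hold for the weights shown.

```python
from typing import Dict


def combine_rewards(components: Dict[str, float], weights: Dict[str, float]) -> float:
    """Hypothetical weighted sum of per-step reward components."""
    missing = set(weights) - set(components)
    if missing:
        # Fail loudly instead of silently treating a missing component as 0.
        raise KeyError(f"missing reward components: {sorted(missing)}")
    total_weight = sum(weights.values())
    if abs(total_weight - 1.0) > 1e-6:
        raise ValueError(f"reward weights sum to {total_weight}, expected 1.0")
    return sum(w * components[name] for name, w in weights.items())


weights = {"optimization": 0.8, "discovery": 0.1, "efficiency": 0.1}
step_rewards = {"optimization": 0.5, "discovery": 1.0, "efficiency": 0.0}
print(combine_rewards(step_rewards, weights))
```

With these weights, a full unit of discovery reward is worth the same as only 0.125 units of optimization reward, so the target objective dominates the training signal.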


In [3]:
"""
Cell 2.5: Create GRPO configuration with collapse prevention fixes

This cell creates the enhanced GRPO configuration that prevents posterior collapse.
"""

# Create GRPO configuration with fixes
USE_COLLAPSE_FIXES = True  # Set to False to use original configuration

if USE_COLLAPSE_FIXES:
    print("🔧 Creating GRPO configuration with collapse prevention fixes")
    
    # Create fixed configuration
    fixed_grpo_config = create_grpo_config_with_fixes(
        max_training_steps=50000,
        batch_size=64,
        group_size=64,
        use_bootstrap=True,
        use_adaptive_rewards=True,
        entropy_coefficient=0.1  # Increased from default 0.01
    )
    
    # Create phase configurations for bootstrap
    phase_config = create_bootstrap_phase_config(
        bootstrap_steps=100,
        transition_steps=50
    )
    
    bootstrap_config = create_bootstrap_config()
    
    # Validate configuration
    try:
        validate_fixed_config(fixed_grpo_config)
        print("✅ Configuration validation passed")
    except ValueError as e:
        print(f"⚠️ Configuration warning: {e}")
    
    # Update the training config with fixed parameters
    if 'config' in locals():
        # Update entropy coefficient
        config.training.grpo_config.entropy_coeff = 0.1
        
        # Add state enrichment configuration
        config.training.state_enrichment = {
            'standardize_values': True,
            'use_global_standardization': True,
            'channels': ['values', 'interventions', 'target', 'parent_probs', 'recency']
        }
        
        # Add adaptive rewards configuration
        config.training.adaptive_rewards = {
            'enabled': True,
            'structure_threshold': 0.95,
            'adaptation_rate': 0.1,
            'initial_weights': {'discovery': 0.7, 'optimization': 0.3},
            'final_weights': {'discovery': 0.05, 'optimization': 0.95},
            'update_frequency': 10
        }
        
        print("\n📊 Key configuration parameters:")
        print(f"  Entropy coefficient: {config.training.grpo_config.entropy_coeff}")
        print(f"  Standardization: Global")
        print(f"  Bootstrap surrogate: Enabled")
        print(f"  Adaptive rewards: Enabled")
        print(f"  Structure threshold: 0.95")
    
else:
    print("⚠️ Using original configuration (may experience collapse)")
    print("  Set USE_COLLAPSE_FIXES = True to use fixed configuration")

🔧 Creating GRPO configuration with collapse prevention fixes
✅ Configuration validation passed

📊 Key configuration parameters:
  Entropy coefficient: 0.1
  Standardization: Global
  Bootstrap surrogate: Enabled
  Adaptive rewards: Enabled
  Structure threshold: 0.95
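Cell 2.5 configures adaptive rewards as follows: initial weights of discovery 0.7 and optimization 0.3, final weights of 0.05 and 0.95, a structure threshold of 0.95, an adaptation rate of 0.1, and an update frequency of 10. The notebook does not show the update rule. The sketch below assumes one simple rule. Every `update_frequency` steps, and only once structure accuracy reaches the threshold, the weights move a fraction `adaptation_rate` of the way toward the final weights. `adapt_reward_weights` is hypothetical.

```python
from typing import Dict


def adapt_reward_weights(
    current: Dict[str, float],
    final: Dict[str, float],
    structure_accuracy: float,
    step: int,
    structure_threshold: float = 0.95,
    adaptation_rate: float = 0.1,
    update_frequency: int = 10,
) -> Dict[str, float]:
    """Hypothetical rule: move weights toward `final` once structure is learned."""
    if step % update_frequency != 0 or structure_accuracy < structure_threshold:
        return dict(current)  # no update on this step
    # Convex combination, so weights that sum to 1 keep summing to 1.
    return {
        k: (1 - adaptation_rate) * current[k] + adaptation_rate * final[k]
        for k in current
    }


weights = {"discovery": 0.7, "optimization": 0.3}
final = {"discovery": 0.05, "optimization": 0.95}
for step in range(1, 1001):
    weights = adapt_reward_weights(weights, final, structure_accuracy=0.97, step=step)
print(weights)  # discovery close to 0.05, optimization close to 0.95
```

Under this rule the gap to the final weights shrinks by a factor of 0.9 per update. After 100 updates (1000 steps) only about 0.003% of the original gap remains.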


In [4]:
"""
Cell 2.6: Configure Early Stopping (NEW - Prevents Over-training)

This cell enables early stopping to prevent over-training on solved SCMs
and maintain a balanced exploration/exploitation distribution.
"""

print("\n🛑 Configuring Early Stopping...")

# Enable early stopping to prevent over-training on solved SCMs
early_stopping_config = {
    'early_stopping_enabled': True,
    'convergence_accuracy_threshold': 0.95,  # Consider converged at 95% accuracy
    'convergence_patience': 5,                # Reduced from 10 - wait 5 episodes before declaring convergence
    'min_episodes_per_scm': 5,               # Reduced from 10 - train at least 5 episodes per SCM
    'max_episodes_per_scm': 30,              # Reduced from 50 - stop after 30 episodes even if not converged
    'reward_variance_threshold': 0.05        # Tighter threshold - was 0.1
}

# Update the configuration
config.training.update(early_stopping_config)

print("✓ Early stopping enabled with improved parameters")
print(f"  - Convergence threshold: {early_stopping_config['convergence_accuracy_threshold']}")
print(f"  - Patience: {early_stopping_config['convergence_patience']} episodes")
print(f"  - Min episodes per SCM: {early_stopping_config['min_episodes_per_scm']}")
print(f"  - Max episodes per SCM: {early_stopping_config['max_episodes_per_scm']}")
print(f"  - Reward variance threshold: {early_stopping_config['reward_variance_threshold']}")

# Also ensure the fixed entropy coefficient is maintained
if hasattr(config.training, 'grpo_config'):
    current_entropy = config.training.grpo_config.get('entropy_coeff', 0.01)
    if current_entropy < 0.1:
        config.training.grpo_config['entropy_coeff'] = 0.1
        print(f"\n✓ Maintained entropy coefficient: {current_entropy} → {config.training.grpo_config['entropy_coeff']}")

# Ensure global standardization is enabled
if hasattr(config.training, 'state_config'):
    config.training.state_config['standardize_values'] = True
    config.training.state_config['use_global_standardization'] = True
    print("✓ Global standardization verified")
else:
    # Add state config if missing
    config.training['state_config'] = {
        'standardize_values': True,
        'use_global_standardization': True
    }
    print("✓ Added state config with global standardization")

print("\n📊 Expected behavior with early stopping (FIXED):")
print("  - Training will progress through SCMs dynamically")
print("  - Simple SCMs (3-var) may converge in 5-10 episodes")
print("  - Complex SCMs (5-6 var) may take 15-25 episodes")
print("  - Overall training distribution should be ~40-60% discovery")
print("  - Prevents posterior collapse from over-training")

print("\n⚠️ Note: Episode counting bug has been fixed!")
print("  Each SCM will now track episodes correctly")


🛑 Configuring Early Stopping...
✓ Early stopping enabled with improved parameters
  - Convergence threshold: 0.95
  - Patience: 5 episodes
  - Min episodes per SCM: 5
  - Max episodes per SCM: 30
  - Reward variance threshold: 0.05
✓ Global standardization verified

📊 Expected behavior with early stopping (FIXED):
  - Training will progress through SCMs dynamically
  - Simple SCMs (3-var) may converge in 5-10 episodes
  - Complex SCMs (5-6 var) may take 15-25 episodes
  - Overall training distribution should be ~40-60% discovery
  - Prevents posterior collapse from over-training

⚠️ Note: Episode counting bug has been fixed!
  Each SCM will now track episodes correctly
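The Cell 2.6 parameters suggest a per-SCM stopping rule with three parts:
- Never stop before `min_episodes_per_scm`.
- Always stop at `max_episodes_per_scm`.
- In between, declare convergence when accuracy has stayed above the threshold for `convergence_patience` episodes and reward variance over that window is small.

The sketch below combines the accuracy and variance conditions with AND. That choice is an assumption; the trainer may weigh them differently. `should_stop_scm` is hypothetical.

```python
from statistics import pvariance
from typing import Sequence, Tuple


def should_stop_scm(
    accuracy_history: Sequence[float],
    reward_history: Sequence[float],
    accuracy_threshold: float = 0.95,
    patience: int = 5,
    min_episodes: int = 5,
    max_episodes: int = 30,
    reward_variance_threshold: float = 0.05,
) -> Tuple[bool, str]:
    """Hypothetical per-SCM stopping rule built from the Cell 2.6 parameters."""
    if len(accuracy_history) != len(reward_history):
        raise ValueError("histories must have one entry per episode")
    n = len(accuracy_history)
    if n < min_episodes:
        return False, "below minimum episodes"
    if n >= max_episodes:
        return True, "reached maximum episodes"
    if n < patience:
        return False, "not enough episodes for patience window"
    recent_acc = accuracy_history[-patience:]
    recent_rew = reward_history[-patience:]
    converged = (
        all(a >= accuracy_threshold for a in recent_acc)
        and pvariance(recent_rew) < reward_variance_threshold
    )
    return converged, "converged" if converged else "still learning"


print(should_stop_scm([0.99] * 6, [1.0] * 6))
print(should_stop_scm([0.99] * 6, [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]))
```

With `min_episodes` and `patience` both set to 5, the earliest possible convergence under this rule is the fifth episode on an SCM.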


## 3. Initialize or Load Training State

In [5]:
"""
Cell 3: Initialize new training or load from checkpoint

This cell handles checkpoint loading and training initialization.
Can be run independently to load a specific checkpoint.
"""

# Import trainer and related modules
try:
    from causal_bayes_opt.training.enriched_trainer import EnrichedGRPOTrainer
    from causal_bayes_opt.surrogate.bootstrap import create_bootstrap_surrogate_features
    from causal_bayes_opt.surrogate.phase_manager import PhaseConfig, BootstrapConfig
except ImportError as e:
    raise NotebookError(f"Failed to import training modules: {e}")

# Production phase configuration
PRODUCTION_PHASE_CONFIG = PhaseConfig(
    bootstrap_steps=100,
    transition_steps=50,
    exploration_noise_start=0.5,
    exploration_noise_end=0.1,
    transition_schedule="linear"
)

PRODUCTION_BOOTSTRAP_CONFIG = BootstrapConfig(
    structure_encoding_dim=128,
    use_graph_distance=True,
    use_structural_priors=True,
    noise_schedule="exponential_decay",
    min_noise_factor=0.1
)

# Update config with production settings
config.surrogate_integration = {
    'enabled': True,
    'phase_config': {
        'bootstrap_steps': PRODUCTION_PHASE_CONFIG.bootstrap_steps,
        'transition_steps': PRODUCTION_PHASE_CONFIG.transition_steps,
        'exploration_noise_start': PRODUCTION_PHASE_CONFIG.exploration_noise_start,
        'exploration_noise_end': PRODUCTION_PHASE_CONFIG.exploration_noise_end,
        'transition_schedule': PRODUCTION_PHASE_CONFIG.transition_schedule
    },
    'bootstrap_config': {
        'structure_encoding_dim': PRODUCTION_BOOTSTRAP_CONFIG.structure_encoding_dim,
        'use_graph_distance': PRODUCTION_BOOTSTRAP_CONFIG.use_graph_distance,
        'use_structural_priors': PRODUCTION_BOOTSTRAP_CONFIG.use_structural_priors,
        'noise_schedule': PRODUCTION_BOOTSTRAP_CONFIG.noise_schedule,
        'min_noise_factor': PRODUCTION_BOOTSTRAP_CONFIG.min_noise_factor
    }
}

# Initialize trainer state
trainer = None
starting_episode = 0
checkpoint_metadata = None

if RESUME_FROM_CHECKPOINT:
    print(f"📥 Loading checkpoint: {RESUME_FROM_CHECKPOINT}")
    try:
        checkpoint_data, checkpoint_metadata = checkpoint_manager.load_checkpoint(RESUME_FROM_CHECKPOINT)
        
        # Validate compatibility
        if checkpoint_metadata.optimization_config.direction != optimization_config.direction:
            raise NotebookError(
                f"Optimization direction mismatch! "
                f"Checkpoint: {checkpoint_metadata.optimization_config.direction}, "
                f"Config: {optimization_config.direction}"
            )
        
        print(f"✅ Loaded checkpoint: {checkpoint_metadata.name}")
        print(f"  Training mode: {checkpoint_metadata.training_config.get('mode', 'unknown')}")
        print(f"  Optimization: {checkpoint_metadata.optimization_config.direction}")
        print(f"  Timestamp: {checkpoint_metadata.timestamp}")
        
        # TODO: Restore trainer state (policy params, optimizer state) from checkpoint_data.
        # Until then a fresh trainer is created, so say so instead of resuming silently.
        logger.warning("Checkpoint parameters are not restored yet; the policy starts from scratch")
        trainer = EnrichedGRPOTrainer(config=config)
        starting_episode = checkpoint_metadata.training_results.get('episodes_completed', 0)
        
        print(f"  Starting from episode: {starting_episode}")
        
    except Exception as e:
        raise NotebookError(f"Failed to load checkpoint: {e}")
else:
    print("🚀 Initializing new training")
    try:
        # Add optimization direction to config
        config.optimization = optimization_config.__dict__
        
        # Initialize trainer
        trainer = EnrichedGRPOTrainer(config=config)
        
        print("✅ Trainer initialized successfully")
        print(f"  Optimization: {optimization_config.direction}")
        print(f"  Surrogate integration: {'Enabled' if config.surrogate_integration.enabled else 'Disabled'}")
        
    except Exception as e:
        raise NotebookError(f"Failed to initialize trainer: {e}")

print("\n✅ Training state ready")

🚀 Initializing new training


INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X2'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated chain SCM: 3 vars, 2 edges, target=X2
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated collider SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated mixed SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 4 variables, 3 edges, target='X2'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 4

✅ Trainer initialized successfully
  Optimization: MINIMIZE
  Surrogate integration: Enabled

✅ Training state ready
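Cell 3 defines `PhaseConfig` as follows: 100 bootstrap steps, then 50 transition steps, with exploration noise going from 0.5 to 0.1 on a `"linear"` schedule. The sketch below shows one way those numbers could map a step counter to a phase and a noise level. Several details are assumptions:
- Whether the counter is episodes or gradient updates is not shown.
- Holding noise at its starting value during bootstrap is assumed.
- The post-transition phase name `"learned"` is hypothetical.

The exponential decay in `BootstrapConfig` is a separate schedule and is not modeled here.

```python
from typing import Tuple


def surrogate_phase(
    step: int,
    bootstrap_steps: int = 100,
    transition_steps: int = 50,
    noise_start: float = 0.5,
    noise_end: float = 0.1,
) -> Tuple[str, float]:
    """Hypothetical phase lookup with a linear exploration-noise schedule."""
    if step < 0:
        raise ValueError("step must be non-negative")
    if step < bootstrap_steps:
        return "bootstrap", noise_start
    if step < bootstrap_steps + transition_steps:
        frac = (step - bootstrap_steps) / transition_steps  # runs from 0.0 toward 1.0
        return "transition", noise_start + frac * (noise_end - noise_start)
    return "learned", noise_end


for s in (0, 99, 100, 125, 150, 500):
    print(s, surrogate_phase(s))
```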


## 4. Generate or Load Training SCMs

In [6]:
"""
Cell 4: Generate training SCMs

This cell generates the SCMs for training.
Can be run independently to regenerate SCMs.
"""

print("🔬 Generating Training SCMs")
print("=" * 50)

# Initialize SCM generator
scm_generator = SCMGenerator()

# Generate SCMs
try:
    training_scms, scm_metadata = scm_generator.generate_balanced_scms(
        num_scms=config.experiment.scm_generation.num_scms,
        variable_range=tuple(config.experiment.scm_generation.variable_range),
        structure_types=config.experiment.scm_generation.structure_types,
        seed=RANDOM_SEED
    )
    
    print(f"\n✅ Generated {len(training_scms)} training SCMs")
    
    # Analyze distribution
    distribution = scm_generator._summarize_distribution(scm_metadata)
    print(f"\n📊 SCM Distribution:")
    print(f"  Structure types: {distribution['structure_types']}")
    print(f"  Variable counts: {distribution['variable_counts']}")
    
    # Calculate total episodes (integer division: any remainder is dropped)
    episodes_per_scm = config.training.n_episodes // len(training_scms)
    if episodes_per_scm == 0:
        raise NotebookError(
            f"n_episodes ({config.training.n_episodes}) is smaller than the number of SCMs ({len(training_scms)})"
        )
    total_episodes = len(training_scms) * episodes_per_scm
    print(f"\n📈 Training schedule:")
    print(f"  Episodes per SCM: {episodes_per_scm}")
    print(f"  Total episodes: {total_episodes}")
    
    # Store in config for trainer
    config.training.n_episodes = total_episodes
    
except Exception as e:
    raise NotebookError(f"Failed to generate SCMs: {e}")

# Optional: Save SCMs for reproducibility
scm_save_path = checkpoint_dir / "training_scms" / f"scms_{TRAINING_MODE}_{RANDOM_SEED}.json"
scm_save_path.parent.mkdir(parents=True, exist_ok=True)

# Save metadata only (SCMs are too complex to serialize directly)
with open(scm_save_path, 'w') as f:
    json.dump({
        'metadata': scm_metadata,
        'config': {
            'num_scms': len(training_scms),
            'seed': RANDOM_SEED,
            'variable_range': list(config.experiment.scm_generation.variable_range),
            'structure_types': list(config.experiment.scm_generation.structure_types)
        }
    }, f, indent=2)

print(f"\n💾 Saved SCM metadata to: {scm_save_path}")

INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 4 variables, 3 edges, target='X2'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 4 vars,

🔬 Generating Training SCMs

✅ Generated 64 training SCMs

📊 SCM Distribution:
  Structure types: {'fork': 16, 'chain': 16, 'collider': 16, 'mixed': 16}
  Variable counts: {3: 16, 4: 16, 5: 16, 6: 16}

📈 Training schedule:
  Episodes per SCM: 8
  Total episodes: 512

💾 Saved SCM metadata to: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/training_scms/scms_FULL_42.json
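Cell 4 splits the episode budget with integer division and then overwrites `n_episodes` with the product. Any remainder is therefore dropped, and a budget smaller than the number of SCMs yields zero episodes. The hypothetical helper below restates that arithmetic so the edge cases are easy to check.

```python
from typing import Tuple


def allocate_episodes(n_episodes: int, n_scms: int) -> Tuple[int, int, int]:
    """Restates the Cell 4 arithmetic: returns (per_scm, total, dropped)."""
    if n_scms <= 0:
        raise ValueError("need at least one SCM")
    per_scm = n_episodes // n_scms  # integer division, as in Cell 4
    total = per_scm * n_scms
    return per_scm, total, n_episodes - total


print(allocate_episodes(512, 64))  # (8, 512, 0)   matches the FULL run
print(allocate_episodes(500, 64))  # (7, 448, 52)  remainder is dropped
print(allocate_episodes(32, 64))   # (0, 0, 32)    no episodes at all
```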


## 5. Train GRPO Policy

In [7]:
"""
Cell 5: Train GRPO policy with correct optimization direction

This cell performs the actual training.
Progress is saved periodically for resumption.
"""

import time

print("🚀 Starting GRPO Policy Training")
print("=" * 70)
print(f"🔧 Training mode: {TRAINING_MODE}")
print(f"🎯 Optimization: {optimization_config.direction}")
print(f"📊 Total episodes: {total_episodes}")
print(f"⚖️ Reward weights: {config.training.reward_weights}")
print(f"✅ Surrogate integration: {'ACTIVE' if config.surrogate_integration.enabled else 'DISABLED'}")
print("=" * 70)

# Configure trainer for optimization direction
if hasattr(trainer, 'optimization_config'):
    trainer.optimization_config = optimization_config
else:
    # Inject optimization config
    trainer.config.optimization = optimization_config.__dict__

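# Set training SCMs (warn if the trainer offers no way to receive them)
if hasattr(trainer, 'set_training_scms'):
    trainer.set_training_scms(training_scms)
else:
    logger.warning("Trainer has no set_training_scms(); the SCMs from Cell 4 are not passed to it")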

# Training loop with explicit error handling
training_start_time = time.time()
training_metrics = {}

try:
    print("\n🏃 Starting Training Loop...")
    
    # Run training
    training_metrics = trainer.train()
    
    training_end_time = time.time()
    training_duration = training_end_time - training_start_time
    
    print(f"\n✅ Training completed!")
    print(f"⏱️ Training time: {training_duration/60:.1f} minutes")
    
    # Extract performance metrics
    performance = training_metrics.get('performance', {})
    final_reward = performance.get('final_reward', 0.0)
    
    # Convert reward to actual target value if minimizing
    if optimization_config.is_minimizing:
        final_target_value = optimization_config.convert_from_maximization(final_reward)
        print(f"\n📊 Final Results:")
        print(f"  Final reward (internal): {final_reward:.4f}")
        print(f"  Final target value: {optimization_config.format_improvement(final_target_value)}")
    else:
        print(f"\n📊 Final Results:")
        print(f"  Final target value: {optimization_config.format_improvement(final_reward)}")
    
    # Store optimization direction in metrics
    training_metrics['optimization_direction'] = optimization_config.direction
    training_metrics['duration_minutes'] = training_duration / 60
    
except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user")
    training_metrics['interrupted'] = True
    training_metrics['duration_minutes'] = (time.time() - training_start_time) / 60
    
except Exception as e:
    print(f"\n❌ Training failed with error: {e}")
    import traceback
    traceback.print_exc()
    raise NotebookError(f"Training failed: {e}") from e

# Get checkpoint path
checkpoint_path = training_metrics.get('checkpoint_path', checkpoint_dir / "grpo_final")
print(f"\n📁 Checkpoint saved to: {checkpoint_path}")

INFO:causal_bayes_opt.training.enriched_trainer:Starting enriched GRPO training


🚀 Starting GRPO Policy Training
🔧 Training mode: FULL
🎯 Optimization: MINIMIZE
📊 Total episodes: 512
⚖️ Reward weights: {'optimization': 0.8, 'discovery': 0.1, 'efficiency': 0.1}
✅ Surrogate integration: ACTIVE

🏃 Starting Training Loop...


INFO:causal_bayes_opt.training.enriched_trainer:üîç Per-Variable Encoding - Policy Output (call 10):
INFO:causal_bayes_opt.training.enriched_trainer:  Variable logits: [ 0.e+00 -1.e+09]
INFO:causal_bayes_opt.training.enriched_trainer:  Variables: ['X0', 'X1', 'X2'], Target: X1
INFO:causal_bayes_opt.training.enriched_trainer:  Target variable 'X1' at index 1, logit: -1000000000.0
INFO:causal_bayes_opt.training.enriched_trainer:  Variable selection:
INFO:causal_bayes_opt.training.enriched_trainer:    Temperature: 2.00
INFO:causal_bayes_opt.training.enriched_trainer:    Probabilities: [1. 0.]
INFO:causal_bayes_opt.training.enriched_trainer:    Selected: X0 (index 0)
INFO:causal_bayes_opt.training.enriched_trainer:  Value selection:
INFO:causal_bayes_opt.training.enriched_trainer:    Mean: 0.0000, Std: 1.0000
INFO:causal_bayes_opt.training.enriched_trainer:    Temperature: 1.50
INFO:causal_bayes_opt.training.enriched_trainer:    Sampled value: -2.0168
INFO:causal_bayes_opt.training.enrich


✅ Training completed!
⏱️ Training time: 44.5 minutes

📊 Final Results:
  Final reward (internal): 0.8789
  Final target value: -0.8789 (↓ better)

📁 Checkpoint saved to: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/enriched_grpo_final
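GRPO, in its commonly published form, replaces a learned value baseline with a group baseline. Each sampled action's reward is normalized by the mean and standard deviation of the rewards in its group (`group_size=64` in Cell 2.5). The entropy coefficient (0.1 here) weights a bonus that rewards keeping the action distribution spread out. `EnrichedGRPOTrainer`'s actual loss is not shown in this notebook, so the plain-Python sketch below only illustrates the two quantities.

```python
import math
from statistics import fmean, pstdev
from typing import List, Sequence


def group_relative_advantages(rewards: Sequence[float], eps: float = 1e-8) -> List[float]:
    """Score each sample against the mean and std of its own group (no critic)."""
    mu = fmean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def categorical_entropy(probs: Sequence[float]) -> float:
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return sum(-p * math.log(p) for p in probs if p > 0)


# One group of rewards for interventions sampled from the same state.
print(group_relative_advantages([1.0, 2.0, 3.0]))  # negative, zero, positive
print(group_relative_advantages([0.5, 0.5]))       # all zeros: no learning signal
print(categorical_entropy([1.0, 0.0]))             # fully peaked: zero entropy
print(categorical_entropy([0.5, 0.5]))             # uniform over two choices: ln 2
```

A group whose rewards are all identical produces zero advantages, so it contributes no policy-gradient signal. A distribution like the logged `Probabilities: [1. 0.]` has zero entropy, which is the situation the entropy bonus is meant to discourage.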


## 6. Save Checkpoint with Metadata

In [8]:
"""
Cell 6: Save checkpoint with complete metadata

This cell saves the trained model with all necessary metadata
for later evaluation and comparison.
"""
from omegaconf import OmegaConf

print("💾 Saving Checkpoint with Metadata")
print("=" * 50)

# Generate checkpoint name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
opt_direction = optimization_config.direction.lower()
checkpoint_name = f"grpo_{TRAINING_MODE.lower()}_{opt_direction}_{timestamp}"

# Extract actual model parameters from trainer
try:
    checkpoint_data = {
        'policy_params': trainer.policy_params,
        'policy_config': {
            'architecture': OmegaConf.to_container(config.training.architecture),
            'state_config': OmegaConf.to_container(config.training.state_config),
            'grpo_config': OmegaConf.to_container(config.training.grpo_config)
        },
        'training_metrics': training_metrics,
        'optimization_config': optimization_config.__dict__
    }
    print(f"✅ Extracted model parameters from trainer")
except Exception as e:
    print(f"⚠️ Warning: Could not extract model parameters: {e}")
    checkpoint_data = None

# Create checkpoint metadata
metadata = CheckpointMetadata(
    name=checkpoint_name,
    path=checkpoint_dir / checkpoint_name,
    optimization_config=optimization_config,
    training_config={
        'mode': TRAINING_MODE,
        'objective': OPTIMIZATION_OBJECTIVE,
        'config': OmegaConf.to_container(config.training),
        'reward_weights': dict(config.training.reward_weights),
        'scm_config': OmegaConf.to_container(config.experiment.scm_generation)
    },
    training_results={
        'duration_minutes': training_metrics.get('duration_minutes', 0),
        'final_performance': training_metrics.get('performance', {}),
        'episodes_completed': total_episodes if not training_metrics.get('interrupted', False) else starting_episode,
        'success': not training_metrics.get('interrupted', False)
    },
    timestamp=timestamp
)

# Save checkpoint
try:
    final_checkpoint_path = checkpoint_manager.save_checkpoint(
        checkpoint_data=checkpoint_data,
        metadata=metadata,
        checkpoint_name=checkpoint_name
    )
    
    print(f"\n✅ Checkpoint saved successfully!")
    print(f"📁 Location: {final_checkpoint_path}")
    print(f"📋 Name: {checkpoint_name}")
    if checkpoint_data is not None:
        print(f"📋 Includes: Model parameters, config, and metrics")
        print(f"🔄 Policy params saved separately: policy_params.pkl")
    else:
        print(f"⚠️ Warning: Only metadata saved (no model parameters)")
    
    # Display summary
    print(f"\n📊 Training Summary:")
    print(f"  Mode: {TRAINING_MODE}")
    print(f"  Optimization: {optimization_config.direction}")
    print(f"  Duration: {metadata.training_results['duration_minutes']:.1f} minutes")
    print(f"  Episodes: {metadata.training_results['episodes_completed']}")
    print(f"  Success: {'Yes' if metadata.training_results['success'] else 'No (interrupted)'}")
    
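    # Read from the saved metadata: `performance` is undefined if training was interrupted
    final_performance = metadata.training_results['final_performance']
    if 'final_reward' in final_performance:
        final_value = final_performance['final_reward']
        if optimization_config.is_minimizing:
            final_value = optimization_config.convert_from_maximization(final_value)
        print(f"  Final target value: {optimization_config.format_improvement(final_value)}")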
    
except Exception as e:
    raise NotebookError(f"Failed to save checkpoint: {e}")

# Store checkpoint name for easy access
TRAINED_CHECKPOINT = checkpoint_name
print(f"\n💡 Checkpoint name stored in variable: TRAINED_CHECKPOINT")
print(f"   Use this in evaluation notebook: '{TRAINED_CHECKPOINT}'")

INFO:scripts.notebooks.base_components:Saved checkpoint data: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/grpo_full_minimize_20250728_163157/checkpoint.pkl
INFO:scripts.notebooks.base_components:Saved policy params separately: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/grpo_full_minimize_20250728_163157/policy_params.pkl
INFO:scripts.notebooks.base_components:Saved checkpoint: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/grpo_full_minimize_20250728_163157
INFO:scripts.notebooks.base_components:Optimization: MINIMIZE


💾 Saving Checkpoint with Metadata
✅ Extracted model parameters from trainer

✅ Checkpoint saved successfully!
📁 Location: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/grpo_full_minimize_20250728_163157
📋 Name: grpo_full_minimize_20250728_163157
📋 Includes: Model parameters, config, and metrics
🔄 Policy params saved separately: policy_params.pkl

📊 Training Summary:
  Mode: FULL
  Optimization: MINIMIZE
  Duration: 44.5 minutes
  Episodes: 512
  Success: Yes
  Final target value: -0.8789 (↓ better)

💡 Checkpoint name stored in variable: TRAINED_CHECKPOINT
   Use this in evaluation notebook: 'grpo_full_minimize_20250728_163157'
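Cell 6 builds the checkpoint name from the mode, the direction and a timestamp. The log shows `CheckpointManager` writing `checkpoint.pkl` and `policy_params.pkl` into a per-run directory. The sketch below reproduces the naming pattern exactly and shows one possible save layout. `save_checkpoint_sketch` and the `metadata.json` file name are assumptions, not the `CheckpointManager` API.

```python
import json
import pickle
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional


def make_checkpoint_name(mode: str, direction: str, now: Optional[datetime] = None) -> str:
    """Same pattern as Cell 6: grpo_<mode>_<direction>_<YYYYmmdd_HHMMSS>."""
    now = now or datetime.now()
    return f"grpo_{mode.lower()}_{direction.lower()}_{now.strftime('%Y%m%d_%H%M%S')}"


def save_checkpoint_sketch(root: Path, name: str, policy_params: Any,
                           metadata: Dict[str, Any]) -> Path:
    """Hypothetical layout: pickled params plus a human-readable metadata file."""
    path = root / name
    path.mkdir(parents=True, exist_ok=False)  # refuse to overwrite an existing run
    with open(path / "policy_params.pkl", "wb") as f:
        pickle.dump(policy_params, f)
    with open(path / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2, default=str)  # str() anything non-JSON
    return path


if __name__ == "__main__":
    name = make_checkpoint_name("FULL", "MINIMIZE", datetime(2025, 7, 28, 16, 31, 57))
    print(name)  # grpo_full_minimize_20250728_163157
    with tempfile.TemporaryDirectory() as tmp:
        saved = save_checkpoint_sketch(Path(tmp), name, {"w": [1.0, 2.0]},
                                       {"optimization": "MINIMIZE", "episodes": 512})
        print(sorted(p.name for p in saved.iterdir()))  # ['metadata.json', 'policy_params.pkl']
```

Storing the optimization direction in the metadata is what lets a loader refuse a direction mismatch before any parameters are used.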


## Summary and Next Steps

**What we've accomplished:**
1. ✅ Configured training with explicit optimization direction
2. ✅ Trained GRPO policy with correct reward signals
3. ✅ Saved checkpoint with complete metadata
4. ✅ Validated policy behavior

**Key improvements over original notebook:**
- No silent failures - explicit errors when things go wrong
- Independent cells - can resume training or load checkpoints
- Optimization direction support - correctly handles minimization like PARENT_SCALE
- Clean checkpoint management - all metadata preserved

**Next steps:**
1. Use `grpo_evaluation_modular.ipynb` to evaluate this checkpoint
2. Compare minimization vs maximization policies
3. Analyze how optimization direction affects performance

In [9]:
"""
Cell 7: Phase 2 Active Learning Information

This trained GRPO policy can now be used for Phase 2 active learning.
"""
print("🎯 Phase 2 Active Learning - Next Steps")
print("=" * 50)
print(f"\nYour trained GRPO policy is ready for Phase 2 active learning!")
print(f"Checkpoint: {TRAINED_CHECKPOINT}")

print(f"\n📚 What is Phase 2?")
print(f"Phase 2 combines your trained GRPO policy with an active learning surrogate:")
print(f"  • Phase 1 (just completed): GRPO policy learns good intervention strategies")
print(f"  • Phase 2 (optional next): Use GRPO to guide active structure discovery")

print(f"\n🚀 To run Phase 2 evaluation:")
print(f"1. Open grpo_evaluation_modular.ipynb")
print(f"2. Set EVALUATION_MODE = 'PHASE2_ACTIVE_LEARNING'")
print(f"3. Use checkpoint: '{TRAINED_CHECKPOINT}'")

print(f"\n🔄 Phase 2 Benefits:")
print(f"  • Better structure learning: Active surrogate discovers true causal structure")
print(f"  • Guided exploration: GRPO policy provides intelligent intervention selection")
print(f"  • Measurable progress: Track F1/SHD improvements over time")

print(f"\n📊 Comparison Options:")
print(f"  • GRPO + Bootstrap (Phase 1 only) - what you just trained")
print(f"  • GRPO + Active Learning (Phase 2) - enhanced structure discovery")
print(f"  • Random + Active Learning - baseline for comparison")

🎯 Phase 2 Active Learning - Next Steps

Your trained GRPO policy is ready for Phase 2 active learning!
Checkpoint: grpo_full_minimize_20250728_163157

📚 What is Phase 2?
Phase 2 combines your trained GRPO policy with an active learning surrogate:
  • Phase 1 (just completed): GRPO policy learns good intervention strategies
  • Phase 2 (optional next): Use GRPO to guide active structure discovery

🚀 To run Phase 2 evaluation:
1. Open grpo_evaluation_modular.ipynb
2. Set EVALUATION_MODE = 'PHASE2_ACTIVE_LEARNING'
3. Use checkpoint: 'grpo_full_minimize_20250728_163157'

🔄 Phase 2 Benefits:
  • Better structure learning: Active surrogate discovers true causal structure
  • Guided exploration: GRPO policy provides intelligent intervention selection
  • Measurable progress: Track F1/SHD improvements over time

📊 Comparison Options:
  • GRPO + Bootstrap (Phase 1 only) - what you just trained
  • GRPO + Active Learning (Phase 2) - enhanced structure discovery
  •