# GRPO Training Pipeline

**Purpose**: Train new GRPO policies or fine-tune existing checkpoints for causal discovery.

**Features**:
- ✅ **119x improvement** surrogate integration system
- ✅ **Multiple training modes**: QUICK, FULL, PRECISION
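- ⚠️ **Fine-tuning support** from existing checkpoints (not yet implemented: setting `FINETUNE_FROM_CHECKPOINT` currently falls back to training from scratch)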
- ✅ **Checkpoint management** with metadata
- ✅ **Training monitoring** and early stopping

**Workflow**:
1. Configure training parameters
2. Generate training SCMs
3. Train or fine-tune policy
4. Save checkpoint with metadata
5. Quick validation

**Output**: Trained checkpoint ready for evaluation in `grpo_evaluation_benchmark.ipynb`

## Environment Setup

In [1]:
#!/usr/bin/env python3
"""
Environment Setup for GRPO Training Pipeline
"""

import sys
import os
from pathlib import Path
import logging
import time
from datetime import datetime
import json
from typing import Dict, List, Any, Optional

# Project root configuration
project_root = Path.cwd().parent if Path.cwd().name == "experiments" else Path.cwd()
sys.path.insert(0, str(project_root))

# Core JAX imports
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as onp
import pyrsistent as pyr

# Configuration
from omegaconf import DictConfig, OmegaConf

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Project imports
from causal_bayes_opt.experiments.variable_scm_factory import VariableSCMFactory
from causal_bayes_opt.training.enriched_trainer import EnrichedGRPOTrainer
from causal_bayes_opt.data_structures.scm import get_variables, get_target, get_edges

# 119x Improvement System
from causal_bayes_opt.surrogate.bootstrap import create_bootstrap_surrogate_features
from causal_bayes_opt.surrogate.phase_manager import PhaseConfig, BootstrapConfig

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
)
logger = logging.getLogger(__name__)

# JAX configuration
jax.config.update("jax_enable_x64", True)
print(f"🔧 JAX devices: {jax.devices()}")
print(f"🔧 JAX backend: {jax.default_backend()}")

# Create directories
checkpoint_dir = project_root / "checkpoints" / "grpo_training"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

print("✅ Environment Setup Complete")
print(f"📁 Project root: {project_root}")
print(f"📁 Checkpoint directory: {checkpoint_dir}")

🔧 JAX devices: [CpuDevice(id=0)]
🔧 JAX backend: cpu
✅ Environment Setup Complete
📁 Project root: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt
📁 Checkpoint directory: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training


## Training Configuration

In [2]:
"""
Configure GRPO Training Parameters
"""

# Training mode selection
TRAINING_MODE = "QUICK"  # Options: "QUICK", "FULL", "PRECISION"

# Training objective selection
TRAINING_OBJECTIVE = "TARGET_FOCUSED"  # Options: "STRUCTURE_FOCUSED", "TARGET_FOCUSED", "BALANCED"

RANDOM_SEED = 42

# Fine-tuning configuration (set to None for training from scratch)
FINETUNE_FROM_CHECKPOINT = None  # Or path to existing checkpoint

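# Mode configurations
# Note: 'training_duration_minutes' is informational only. It is not passed to the
# trainer config, so it does not cap the actual runtime.
training_configs = {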
    "QUICK": {
        'episodes_per_scm': 3,
        'episode_length': 8,
        'learning_rate': 0.001,
        'training_duration_minutes': 5,
        'num_scms': 32,  # Number of training SCMs
        'description': 'Fast testing and development'
    },
    "FULL": {
        'episodes_per_scm': 8,
        'episode_length': 12,
        'learning_rate': 0.001,
        'training_duration_minutes': 15,
        'num_scms': 64,
        'description': 'Production-quality training'
    },
    "PRECISION": {
        'episodes_per_scm': 15,
        'episode_length': 15,
        'learning_rate': 0.0005,
        'training_duration_minutes': 30,
        'num_scms': 128,
        'description': 'Maximum quality training'
    }
}

# Reward weight configurations for different objectives
objective_configs = {
    "STRUCTURE_FOCUSED": {
        'reward_weights': {
            'optimization': 0.2,    # Low target optimization weight
            'discovery': 0.6,       # High structure discovery weight
            'efficiency': 0.2       # Medium efficiency weight
        },
        'description': 'Emphasizes quick SHD reduction and structure learning'
    },
    "TARGET_FOCUSED": {
        'reward_weights': {
            'optimization': 0.8,    # High target optimization weight
            'discovery': 0.1,       # Low structure discovery weight
            'efficiency': 0.1       # Low efficiency weight
        },
        'description': 'Prioritizes target variable maximization above all else'
    },
    "BALANCED": {
        'reward_weights': {
            'optimization': 0.5,    # Balanced target optimization
            'discovery': 0.3,       # Medium structure discovery
            'efficiency': 0.2       # Medium efficiency
        },
        'description': 'Balanced approach between all objectives'
    }
}

# Get configurations
train_config = training_configs[TRAINING_MODE]
objective_config = objective_configs[TRAINING_OBJECTIVE]

# Production configurations from Phase 4 validation
PRODUCTION_PHASE_CONFIG = PhaseConfig(
    bootstrap_steps=100,
    transition_steps=50,
    exploration_noise_start=0.5,
    exploration_noise_end=0.1,
    transition_schedule="linear"
)

PRODUCTION_BOOTSTRAP_CONFIG = BootstrapConfig(
    structure_encoding_dim=128,
    use_graph_distance=True,
    use_structural_priors=True,
    noise_schedule="exponential_decay",
    min_noise_factor=0.1
)

print(f"🎯 Training Mode: {TRAINING_MODE}")
print(f"🎯 Training Objective: {TRAINING_OBJECTIVE}")
print(f"📝 Description: {train_config['description']}")
print(f"📝 Objective: {objective_config['description']}")
print(f"⏱️ Duration: {train_config['training_duration_minutes']} minutes")
print(f"🎓 Learning rate: {train_config['learning_rate']}")
print(f"📊 Number of SCMs: {train_config['num_scms']}")
print(f"🔄 Episodes per SCM: {train_config['episodes_per_scm']}")
print(f"⚖️ Reward weights: {objective_config['reward_weights']}")

if FINETUNE_FROM_CHECKPOINT:
    print(f"\n🔧 Fine-tuning from: {FINETUNE_FROM_CHECKPOINT}")
else:
    print(f"\n🚀 Training from scratch")

🎯 Training Mode: QUICK
🎯 Training Objective: TARGET_FOCUSED
📝 Description: Fast testing and development
📝 Objective: Prioritizes target variable maximization above all else
⏱️ Duration: 5 minutes
🎓 Learning rate: 0.001
📊 Number of SCMs: 32
🔄 Episodes per SCM: 3
⚖️ Reward weights: {'optimization': 0.8, 'discovery': 0.1, 'efficiency': 0.1}

🚀 Training from scratch
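Each objective's `reward_weights` are three non-negative weights that sum to 1.0. How the trainer combines them is not shown in this notebook; the sketch below assumes a plain linear combination of per-step reward components and checks that every objective's weights are valid. The helper names and the component values are hypothetical.

```python
from typing import Dict

# Same weights as objective_configs above
OBJECTIVE_WEIGHTS: Dict[str, Dict[str, float]] = {
    "STRUCTURE_FOCUSED": {"optimization": 0.2, "discovery": 0.6, "efficiency": 0.2},
    "TARGET_FOCUSED": {"optimization": 0.8, "discovery": 0.1, "efficiency": 0.1},
    "BALANCED": {"optimization": 0.5, "discovery": 0.3, "efficiency": 0.2},
}


def validate_weights(weights: Dict[str, float], tol: float = 1e-9) -> None:
    """Raise if any weight is negative or the weights do not sum to 1."""
    if any(w < 0 for w in weights.values()):
        raise ValueError(f"negative reward weight in {weights}")
    total = sum(weights.values())
    if abs(total - 1.0) > tol:
        raise ValueError(f"reward weights sum to {total}, expected 1.0")


def combine_rewards(components: Dict[str, float], weights: Dict[str, float]) -> float:
    """Assumed linear combination: sum of weight * component for each weighted term."""
    missing = set(weights) - set(components)
    if missing:
        raise KeyError(f"missing reward components: {sorted(missing)}")
    return sum(weights[name] * components[name] for name in weights)


for w in OBJECTIVE_WEIGHTS.values():
    validate_weights(w)

# Hypothetical per-step components, each in [0, 1]
components = {"optimization": 1.0, "discovery": 0.0, "efficiency": 0.5}
for name, w in OBJECTIVE_WEIGHTS.items():
    print(name, round(combine_rewards(components, w), 3))
```

Under this assumption the same step scores highest under `TARGET_FOCUSED`, because the optimization component carries 0.8 of the weight.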


"""
Save Checkpoint with Training Metadata

Run this cell after the training cell: it uses checkpoint_path, total_episodes,
grpo_config, training_duration and performance, which are defined there.
"""

import shutil

print("💾 Saving Checkpoint with Metadata")
print("=" * 50)

# Generate checkpoint name with timestamp and objective
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_name = f"grpo_{TRAINING_MODE.lower()}_{TRAINING_OBJECTIVE.lower()}_{timestamp}"
final_checkpoint_dir = checkpoint_dir / checkpoint_name

# The trainer may report the path as a string; normalise to Path
checkpoint_path = Path(checkpoint_path) if checkpoint_path else None

# Copy the actual checkpoint files from trainer output to the new named directory
if checkpoint_path and checkpoint_path.exists():
    print(f"\n📁 Copying checkpoint files from: {checkpoint_path}")
    print(f"   to: {final_checkpoint_dir}")
    
    # Create the new checkpoint directory
    final_checkpoint_dir.mkdir(parents=True, exist_ok=True)
    
    if checkpoint_path.is_dir():
        # Copy files and subdirectories; the original may be deleted by the cleanup step below
        for item in checkpoint_path.iterdir():
            dest = final_checkpoint_dir / item.name
            if item.is_dir():
                shutil.copytree(item, dest, dirs_exist_ok=True)
            else:
                shutil.copy2(item, dest)
            print(f"   ✓ Copied: {item.name}")
    else:
        # Single file checkpoint
        shutil.copy2(checkpoint_path, final_checkpoint_dir / "checkpoint.pkl")
        print(f"   ✓ Copied checkpoint file")
    
    print(f"✅ Checkpoint files copied successfully")
else:
    print(f"⚠️ No checkpoint found at expected path: {checkpoint_path}")
    print(f"   Creating empty checkpoint directory: {final_checkpoint_dir}")
    final_checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Create metadata
metadata = {
    'training_config': {
        'mode': TRAINING_MODE,
        'objective': TRAINING_OBJECTIVE,
        'config': train_config,
        'objective_config': objective_config,
        'total_episodes': total_episodes,
        'learning_rate': grpo_config.training.learning_rate,
        'architecture': OmegaConf.to_container(grpo_config.training.architecture),
        'reward_weights': objective_config['reward_weights']
    },
    'scm_config': scm_config,
    'training_results': {
        'duration_minutes': training_duration / 60,
        'final_performance': performance,
        'timestamp': timestamp,
        'success': True
    },
    'environment': {
        'jax_backend': jax.default_backend(),
        'num_devices': len(jax.devices()),
        'random_seed': RANDOM_SEED
    },
    'surrogate_config': {
        'phase_config': OmegaConf.to_container(grpo_config.surrogate_integration.phase_config),
        'bootstrap_config': OmegaConf.to_container(grpo_config.surrogate_integration.bootstrap_config)
    }
}

# Save metadata both in the checkpoint directory and in the parent
# (default=str stringifies values json cannot encode, e.g. Path objects or array-valued metrics)
metadata_path = final_checkpoint_dir / "metadata.json"
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2, default=str)

# Also save in parent directory with checkpoint name
parent_metadata_path = final_checkpoint_dir.parent / f"{checkpoint_name}_metadata.json"
with open(parent_metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2, default=str)

print(f"\n✅ Checkpoint saved: {final_checkpoint_dir}")
print(f"📋 Metadata saved: {metadata_path}")
print(f"📋 Parent metadata saved: {parent_metadata_path}")

# Display summary
print(f"\n📊 Training Summary:")
print(f"  Mode: {TRAINING_MODE}")
print(f"  Objective: {TRAINING_OBJECTIVE}")
print(f"  Duration: {training_duration/60:.1f} minutes")
print(f"  Episodes: {total_episodes}")
print(f"  Reward weights: {objective_config['reward_weights']}")
print(f"  Final reward: {performance.get('final_reward', 'N/A')}")
print(f"  Checkpoint: {checkpoint_name}")
print(f"\n🎯 Training objective: {objective_config['description']}")

# Clean up old generic checkpoints if desired
CLEANUP_GENERIC_CHECKPOINTS = True  # Set to False to keep original checkpoints

if CLEANUP_GENERIC_CHECKPOINTS and checkpoint_path and checkpoint_path.exists():
    print(f"\n🧹 Cleaning up generic checkpoint: {checkpoint_path}")
    try:
        if checkpoint_path.is_dir():
            shutil.rmtree(checkpoint_path)
        else:
            checkpoint_path.unlink()
        print("   ✓ Cleanup complete")
    except Exception as e:
        print(f"   ⚠️ Could not clean up: {e}")

## Notes on Checkpoint Naming

After training completes:
1. The trainer saves checkpoints with generic names (e.g., `enriched_grpo_final`, `enriched_grpo_episode_50`)
2. The checkpoint saving cell above copies these to descriptively named directories (e.g., `grpo_quick_structure_focused_20250719_123456`)
3. The evaluation notebook can then identify the training objective from the checkpoint name

**Important**: If you want to evaluate intermediate checkpoints (episode_50, etc.), you'll need to manually copy and rename them with the appropriate objective in the name.
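The naming scheme is `grpo_{mode}_{objective}_{YYYYmmdd}_{HHMMSS}`, where the objective itself contains an underscore (e.g. `target_focused`), so a plain `split('_')` is ambiguous. The following is a minimal sketch of how an evaluation notebook might recover the mode and objective, plus a helper for the manual copy-and-rename of intermediate checkpoints. All three function names are hypothetical, and the single-file fallback name `checkpoint.pkl` mirrors the saving cell.

```python
import re
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional

# grpo_{mode}_{objective}_{YYYYmmdd}_{HHMMSS}; objective may contain underscores
_NAME_RE = re.compile(r"^grpo_(?P<mode>[a-z]+)_(?P<objective>[a-z_]+)_(?P<ts>\d{8}_\d{6})$")


def make_checkpoint_name(mode: str, objective: str, when: datetime) -> str:
    """Build a name the same way the checkpoint saving cell does."""
    return f"grpo_{mode.lower()}_{objective.lower()}_{when.strftime('%Y%m%d_%H%M%S')}"


def parse_checkpoint_name(name: str) -> Optional[Dict[str, str]]:
    """Return mode/objective/timestamp, or None for generic trainer names."""
    match = _NAME_RE.match(name)
    if match is None:
        return None
    return {
        "mode": match["mode"].upper(),
        "objective": match["objective"].upper(),
        "timestamp": match["ts"],
    }


def copy_with_descriptive_name(src: Path, dest_root: Path, mode: str, objective: str,
                               when: Optional[datetime] = None) -> Path:
    """Copy a generic checkpoint (file or directory) under a descriptive name."""
    dest = dest_root / make_checkpoint_name(mode, objective, when or datetime.now())
    if src.is_dir():
        shutil.copytree(src, dest)
    else:
        dest.mkdir(parents=True)
        shutil.copy2(src, dest / "checkpoint.pkl")
    return dest


print(parse_checkpoint_name("grpo_quick_target_focused_20250721_234828"))
print(parse_checkpoint_name("enriched_grpo_episode_50"))
```

Generic names such as `enriched_grpo_episode_50` return `None`, which is a convenient signal that a checkpoint still needs renaming before evaluation.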

In [3]:
"""
Generate Diverse SCMs for Training
"""

print("🔬 Generating Training SCMs")
print("=" * 50)

# SCM generation configuration
scm_config = {
    'variable_range': [3, 6],
    'structure_types': ['fork', 'chain', 'collider', 'mixed'],
    'noise_scale': 1.0,
    'edge_density_range': [0.3, 0.7],
    'target_selection': 'random'
}

# Create factory
scm_factory = VariableSCMFactory(
    noise_scale=scm_config['noise_scale'],
    coefficient_range=(-2.0, 2.0),
    seed=RANDOM_SEED
)

# Generate balanced SCM suite
training_scms = []
scm_metadata = []
key = random.PRNGKey(RANDOM_SEED)

# Calculate SCMs per configuration
n_structure_types = len(scm_config['structure_types'])
n_var_sizes = scm_config['variable_range'][1] - scm_config['variable_range'][0] + 1
scms_per_config = train_config['num_scms'] // (n_structure_types * n_var_sizes)
remaining_scms = train_config['num_scms'] % (n_structure_types * n_var_sizes)

# Generate SCMs
for structure_type in scm_config['structure_types']:
    for n_vars in range(scm_config['variable_range'][0], scm_config['variable_range'][1] + 1):
        n_instances = scms_per_config + (1 if remaining_scms > 0 else 0)
        remaining_scms = max(0, remaining_scms - 1)
        
        for instance in range(n_instances):
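            # subkey is not passed to the factory, which draws from its own seed
            key, subkey = random.split(key)
            
            # edge_density is fixed at 0.5; scm_config['edge_density_range'] is only
            # recorded in the metadata, not sampled
            scm = scm_factory.create_variable_scm(
                num_variables=n_vars,
                structure_type=structure_type,
                target_variable=None,
                edge_density=0.5
            )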
            
            training_scms.append(scm)
            
            scm_metadata.append({
                'structure_type': structure_type,
                'n_variables': n_vars,
                'target': get_target(scm),
                'n_edges': len(get_edges(scm)),
                'variables': list(get_variables(scm)),
                'instance': instance
            })

print(f"✅ Generated {len(training_scms)} training SCMs")

# Analyze distribution
structure_counts = {}
variable_counts = {}

for meta in scm_metadata:
    struct_type = meta['structure_type']
    n_vars = meta['n_variables']
    
    structure_counts[struct_type] = structure_counts.get(struct_type, 0) + 1
    variable_counts[n_vars] = variable_counts.get(n_vars, 0) + 1

print(f"\n📊 SCM Distribution:")
print(f"Structure types: {structure_counts}")
print(f"Variable counts: {variable_counts}")

# Total episodes
total_episodes = len(training_scms) * train_config['episodes_per_scm']
print(f"\n📈 Total training episodes: {total_episodes}")

INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 4 variables, 3 edges, target='X2'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 4 vars, 3 edges, target=X2
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 4 variables, 3 edges, target='X2'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 4 vars, 3 edges, target=X2
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 5 variables, 4 edges, target='X2'


🔬 Generating Training SCMs


INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 5 vars, 4 edges, target=X2
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 5 variables, 4 edges, target='X2'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 5 vars, 4 edges, target=X2
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 6 variables, 5 edges, target='X3'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 6 vars, 5 edges, target=X3
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 6 variables, 5 edges, target='X3'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 6 vars, 5 edges, target=X3
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X2'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated chain SCM: 3 vars, 2 edges, target=X2
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 

✅ Generated 32 training SCMs

📊 SCM Distribution:
Structure types: {'fork': 8, 'chain': 8, 'collider': 8, 'mixed': 8}
Variable counts: {3: 8, 4: 8, 5: 8, 6: 8}

📈 Total training episodes: 96
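The generation cell spreads `num_scms` over every (structure type, variable count) pair: each pair gets `num_scms // n_pairs` instances and any remainder adds one extra instance to the first pairs in iteration order. A standalone sketch of that arithmetic (the function name is hypothetical) reproduces the 32 SCMs and 96 episodes reported above for QUICK mode:

```python
from typing import Dict, List, Tuple


def allocate_scms(num_scms: int, structure_types: List[str],
                  variable_range: Tuple[int, int]) -> Dict[Tuple[str, int], int]:
    """Spread num_scms over (structure, n_vars) pairs as the generation cell does."""
    lo, hi = variable_range
    pairs = [(s, n) for s in structure_types for n in range(lo, hi + 1)]
    base, remainder = divmod(num_scms, len(pairs))
    counts = {}
    for pair in pairs:
        extra = 1 if remainder > 0 else 0
        remainder = max(0, remainder - 1)
        counts[pair] = base + extra
    return counts


structures = ["fork", "chain", "collider", "mixed"]
quick = allocate_scms(32, structures, (3, 6))
print(sum(quick.values()), set(quick.values()))  # 32 SCMs, 2 per pair
print("episodes:", sum(quick.values()) * 3)      # QUICK: 3 episodes per SCM
```

With 16 pairs, the FULL (64) and PRECISION (128) settings also divide evenly. A count that does not, such as 50, gives 4 instances to `(fork, 3)` and `(fork, 4)` and 3 to every other pair, so the distribution is slightly biased toward the first structure type.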


## Train GRPO Policy

In [4]:
"""
Train GRPO Policy with 119x Improvement System
"""

print("🚀 Starting GRPO Policy Training")
print("=" * 70)
print(f"🔧 Training mode: {TRAINING_MODE}")
print(f"🎯 Training objective: {TRAINING_OBJECTIVE}")
print(f"📊 Total episodes: {total_episodes}")
print(f"⚖️ Reward weights: {objective_config['reward_weights']}")
print(f"✅ 119x surrogate integration: ACTIVE")
print("=" * 70)

# Create training configuration
def create_grpo_training_config():
    """Create comprehensive GRPO training configuration."""
    config_dict = {
        'seed': RANDOM_SEED,
        
        'training': {
            'n_episodes': total_episodes,
            'episode_length': train_config['episode_length'],
            'learning_rate': train_config['learning_rate'],
            'gamma': 0.99,
            'max_intervention_value': 2.0,
            
            # Use dynamic reward weights based on objective
            'reward_weights': objective_config['reward_weights'],
            
            # Architecture
            'architecture': {
                'hidden_dim': 128,
                'num_layers': 2,
                'num_heads': 4,
                'key_size': 32,
                'widening_factor': 4,
                'dropout': 0.1,
                'policy_intermediate_dim': None
            },
            
            # State configuration
            'state_config': {
                'max_history_size': 100,
                'num_channels': 5,  # Per-variable encoding
                'standardize_values': True,
                'include_temporal_features': True
            },
            
            # GRPO configuration
            'grpo_config': {
                'group_size': 64,
                'interventions_per_state': 8,
                'clip_ratio': 0.2,
                'entropy_coeff': 0.01,
                'kl_penalty_coeff': 0.0,
                'max_grad_norm': 1.0,
                'scale_rewards': True
            }
        },
        
        # 119x Improvement System
        'surrogate_integration': {
            'enabled': True,
            'phase_config': {
                'bootstrap_steps': PRODUCTION_PHASE_CONFIG.bootstrap_steps,
                'transition_steps': PRODUCTION_PHASE_CONFIG.transition_steps,
                'exploration_noise_start': PRODUCTION_PHASE_CONFIG.exploration_noise_start,
                'exploration_noise_end': PRODUCTION_PHASE_CONFIG.exploration_noise_end,
                'transition_schedule': PRODUCTION_PHASE_CONFIG.transition_schedule
            },
            'bootstrap_config': {
                'structure_encoding_dim': PRODUCTION_BOOTSTRAP_CONFIG.structure_encoding_dim,
                'use_graph_distance': PRODUCTION_BOOTSTRAP_CONFIG.use_graph_distance,
                'use_structural_priors': PRODUCTION_BOOTSTRAP_CONFIG.use_structural_priors,
                'noise_schedule': PRODUCTION_BOOTSTRAP_CONFIG.noise_schedule,
                'min_noise_factor': PRODUCTION_BOOTSTRAP_CONFIG.min_noise_factor
            }
        },
        
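        # Experiment configuration. The trainer builds its own SCM rotation from these
        # settings (see the "Generated ... SCM" log lines in the output); the
        # `training_scms` list from the previous cell is not passed to it and is only
        # used to size the run and for the validation cell.
        'experiment': {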
            'scm_generation': {
                'use_variable_factory': True,
                'variable_range': scm_config['variable_range'],
                'structure_types': scm_config['structure_types'],
                'rotation_frequency': 5,
                'fallback_scms': ['fork_3var', 'chain_3var', 'collider_3var'],
                'num_scms': len(training_scms),
                'edge_density_range': scm_config['edge_density_range']
            }
        },
        
        # Logging
        'logging': {
            'checkpoint_dir': str(checkpoint_dir),
            'wandb': {'enabled': False},
            'level': 'INFO',
            'save_frequency': 50
        }
    }
    
    return OmegaConf.create(config_dict)

# Create configuration
grpo_config = create_grpo_training_config()

# Initialize trainer
training_start_time = time.time()
trainer = EnrichedGRPOTrainer(config=grpo_config)

# Load checkpoint if fine-tuning
if FINETUNE_FROM_CHECKPOINT:
    print(f"\n📥 Loading checkpoint for fine-tuning: {FINETUNE_FROM_CHECKPOINT}")
    # TODO: Implement checkpoint loading in trainer
    print("⚠️ Fine-tuning not yet implemented - training from scratch")

# Train
print("\n🏃 Starting Training Loop...")
training_metrics = trainer.train()

training_end_time = time.time()
training_duration = training_end_time - training_start_time

print(f"\n✅ Training completed!")
print(f"⏱️ Training time: {training_duration/60:.1f} minutes")

# Extract results (the trainer may report the checkpoint path as a string)
performance = training_metrics.get('performance', {})
checkpoint_path = Path(training_metrics.get('checkpoint_path', checkpoint_dir / "enriched_grpo_final"))

INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X2'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated chain SCM: 3 vars, 2 edges, target=X2
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated collider SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 3 variables, 2 edges, target='X1'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated mixed SCM: 3 vars, 2 edges, target=X1
INFO:causal_bayes_opt.experiments.test_scms:Created linear SCM with 4 variables, 3 edges, target='X2'
INFO:causal_bayes_opt.experiments.variable_scm_factory:Generated fork SCM: 4

🚀 Starting GRPO Policy Training
🔧 Training mode: QUICK
🎯 Training objective: TARGET_FOCUSED
📊 Total episodes: 96
⚖️ Reward weights: {'optimization': 0.8, 'discovery': 0.1, 'efficiency': 0.1}
✅ 119x surrogate integration: ACTIVE


INFO:causal_bayes_opt.training.enriched_trainer:✅ Using optimized GRPO config: group_size=64, interventions_per_state=8
INFO:causal_bayes_opt.training.enriched_trainer:Correct GRPO Config: group_size=64, lr=0.001000
INFO:causal_bayes_opt.training.enriched_trainer:Correct GRPO Config: entropy_coeff=0.010, clip_ratio=0.20
INFO:causal_bayes_opt.training.enriched_trainer:Initialized trainer with 6 max variables
INFO:causal_bayes_opt.training.enriched_trainer:GRPO group size: 64, update frequency: 1 episodes
INFO:causal_bayes_opt.training.enriched_trainer:Starting enriched GRPO training



🏃 Starting Training Loop...


INFO:causal_bayes_opt.training.enriched_trainer:✅ Parameters changed - norm delta: 0.000070085187
INFO:causal_bayes_opt.training.enriched_trainer:Episode 0: reward=0.617, intervention_rate=1.000, scm=fork_3var, F1=0.000, P(Parents)=0.000, SHD=2
INFO:causal_bayes_opt.training.enriched_trainer:🔍 Per-Variable Encoding - Policy Output (call 10):
INFO:causal_bayes_opt.training.enriched_trainer:  Variable logits: [-1.00000000e+09 -6.52063737e-02]
INFO:causal_bayes_opt.training.enriched_trainer:  Variables: ['X1', 'X2', 'X0'], Target: X1
INFO:causal_bayes_opt.training.enriched_trainer:  Target variable 'X1' at index 0, logit: -1000000000.0
INFO:causal_bayes_opt.training.enriched_trainer:  Variable selection:
INFO:causal_bayes_opt.training.enriched_trainer:    Temperature: 1.98
INFO:causal_bayes_opt.training.enriched_trainer:    Probabilities: [0. 1.]
INFO:causal_bayes_opt.training.enriched_trainer:    Selected: X2 (index 1)
INFO:causal_bayes_opt.training.enriched_trainer:  Value selection:
IN


✅ Training completed!
⏱️ Training time: 125.7 minutes
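The `grpo_config` used above sets `group_size=64`, `clip_ratio=0.2` and `scale_rewards=True`. GRPO replaces a learned value baseline with group statistics: each sampled action's reward is compared with the mean reward of its group, and the resulting advantage feeds a PPO-style clipped objective. `EnrichedGRPOTrainer`'s internals are not shown here, so this is a sketch of the general technique only; reading `scale_rewards=True` as "divide by the group standard deviation" is an assumption, and the rewards are hypothetical.

```python
import math
from statistics import fmean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], scale: bool = True,
                              eps: float = 1e-8) -> List[float]:
    """GRPO-style advantages: reward minus group mean, optionally / group std."""
    mean = fmean(rewards)
    centered = [r - mean for r in rewards]
    if not scale:
        return centered
    std = pstdev(rewards)
    return [c / (std + eps) for c in centered]


def clipped_surrogate(log_prob_new: float, log_prob_old: float,
                      advantage: float, clip_ratio: float = 0.2) -> float:
    """PPO-style clipped objective for one sample (to be maximised)."""
    ratio = math.exp(log_prob_new - log_prob_old)
    clipped = min(max(ratio, 1.0 - clip_ratio), 1.0 + clip_ratio)
    return min(ratio * advantage, clipped * advantage)


# Hypothetical rewards for a group of 4 interventions (the config uses 64)
rewards = [0.2, 0.4, 0.6, 0.8]
adv = group_relative_advantages(rewards)
print([round(a, 3) for a in adv])
print(clipped_surrogate(math.log(1.5), 0.0, advantage=1.0))
```

Because advantages are relative within a group, a group whose interventions all earn the same reward contributes no gradient signal. Large groups such as 64 make the mean and standard deviation estimates more stable.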


## Save Checkpoint with Metadata

In [5]:
"""
Save Checkpoint with Training Metadata
"""

print("💾 Saving Checkpoint with Metadata")
print("=" * 50)

# Generate checkpoint name with timestamp and objective
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_name = f"grpo_{TRAINING_MODE.lower()}_{TRAINING_OBJECTIVE.lower()}_{timestamp}"
final_checkpoint_dir = checkpoint_dir / checkpoint_name

# Create metadata
metadata = {
    'training_config': {
        'mode': TRAINING_MODE,
        'objective': TRAINING_OBJECTIVE,
        'config': train_config,
        'objective_config': objective_config,
        'total_episodes': total_episodes,
        'learning_rate': grpo_config.training.learning_rate,
        'architecture': OmegaConf.to_container(grpo_config.training.architecture),
        'reward_weights': objective_config['reward_weights']
    },
    'scm_config': scm_config,
    'training_results': {
        'duration_minutes': training_duration / 60,
        'final_performance': performance,
        'timestamp': timestamp,
        'success': True
    },
    'environment': {
        'jax_backend': jax.default_backend(),
        'num_devices': len(jax.devices()),
        'random_seed': RANDOM_SEED
    },
    'surrogate_config': {
        'phase_config': OmegaConf.to_container(grpo_config.surrogate_integration.phase_config),
        'bootstrap_config': OmegaConf.to_container(grpo_config.surrogate_integration.bootstrap_config)
    }
}

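# Save metadata next to the checkpoint directory. This cell writes only the metadata
# file; it does not create final_checkpoint_dir or copy the trainer's checkpoint files.
metadata_path = final_checkpoint_dir.parent / f"{checkpoint_name}_metadata.json"
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2, default=str)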

print(f"✅ Checkpoint saved: {final_checkpoint_dir}")
print(f"📋 Metadata saved: {metadata_path}")

# Display summary
print(f"\n📊 Training Summary:")
print(f"  Mode: {TRAINING_MODE}")
print(f"  Objective: {TRAINING_OBJECTIVE}")
print(f"  Duration: {training_duration/60:.1f} minutes")
print(f"  Episodes: {total_episodes}")
print(f"  Reward weights: {objective_config['reward_weights']}")
print(f"  Final reward: {performance.get('final_reward', 'N/A')}")
print(f"  Checkpoint: {checkpoint_name}")
print(f"\n🎯 Training objective: {objective_config['description']}")

💾 Saving Checkpoint with Metadata
✅ Checkpoint saved: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/grpo_quick_target_focused_20250721_234828
📋 Metadata saved: /Users/harellidar/Documents/Imperial/Individual_Project/causal_bayes_opt/checkpoints/grpo_training/grpo_quick_target_focused_20250721_234828_metadata.json

📊 Training Summary:
  Mode: QUICK
  Objective: TARGET_FOCUSED
  Duration: 125.7 minutes
  Episodes: 96
  Reward weights: {'optimization': 0.8, 'discovery': 0.1, 'efficiency': 0.1}
  Final reward: 0.8000869074181126
  Checkpoint: grpo_quick_target_focused_20250721_234828

🎯 Training objective: Prioritizes target variable maximization above all else
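`json.dump` raises `TypeError` for values it does not recognise, such as `pathlib.Path` objects or NumPy/JAX arrays and scalars. The `performance` dict in the recorded run serialized cleanly. However, if the trainer starts returning array-valued metrics, the metadata write would fail. Passing `default=str` avoids the error but stores arrays as their string repr. A minimal sketch of a `default` hook that keeps numbers as numbers, assuming array-like objects expose `.tolist()` or `.item()` (as NumPy and JAX arrays do), is shown below. `to_jsonable` and `FakeArray` are hypothetical names.

```python
import json
from pathlib import Path
from typing import Any


def to_jsonable(obj: Any) -> Any:
    """`default` hook for json.dump: convert common non-JSON types."""
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    # NumPy and JAX arrays/scalars expose .tolist(); prefer it when present
    if hasattr(obj, "tolist"):
        return obj.tolist()
    if hasattr(obj, "item"):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


class FakeArray:
    """Stand-in for a NumPy/JAX array (anything exposing .tolist())."""

    def __init__(self, values):
        self._values = list(values)

    def tolist(self):
        return list(self._values)


metadata = {
    "checkpoint": Path("checkpoints/grpo_training"),
    "rewards": FakeArray([0.5, 0.8]),
}
print(json.dumps(metadata, indent=2, default=to_jsonable))
```

In the notebook this would be used as `json.dump(metadata, f, indent=2, default=to_jsonable)`.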


## Quick Validation

In [6]:
"""
Quick Validation of Trained Policy
"""

print("🧪 Quick Policy Validation")
print("=" * 50)

# Test on a few SCMs
test_key = random.PRNGKey(999)
n_test = min(3, len(training_scms))

for i in range(n_test):
    test_scm = training_scms[i]
    test_meta = scm_metadata[i]
    
    # Create test state
    test_state = trainer._create_tensor_backed_state(test_scm, 0, 0.0)
    enriched_input = trainer.state_converter.convert_state_to_enriched_input(test_state)
    
    # Get policy output
    test_key, subkey = random.split(test_key)
    target_idx = test_meta['variables'].index(test_meta['target'])
    
    policy_output = trainer.policy_fn.apply(
        trainer.policy_params, subkey, enriched_input, target_idx, False
    )
    
    # Convert to action
    action = trainer._policy_output_to_action(policy_output, test_meta['variables'], test_meta['target'])
    selected_var_idx, intervention_value = action
    
    print(f"\n{i+1}. {test_meta['structure_type']} ({test_meta['n_variables']}v):")
    print(f"   Target: {test_meta['target']}")
    print(f"   Action: {test_meta['variables'][selected_var_idx]}={intervention_value:.3f}")
    print(f"   Magnitude: {abs(intervention_value):.3f}")

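# Smoke test only: this confirms the policy returns an action for each SCM;
# it does not check action quality or the configured intervention bounds.
print("\n✅ Policy is functioning correctly!")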
print(f"\n🎉 Training Complete! Use checkpoint '{checkpoint_name}' in evaluation notebook.")

INFO:causal_bayes_opt.training.enriched_trainer:🔍 Per-Variable Encoding - Policy Output (call 770):
INFO:causal_bayes_opt.training.enriched_trainer:  Variable logits: [-1.00000000e+09 -7.36612699e-02]
INFO:causal_bayes_opt.training.enriched_trainer:  Variables: ['X1', 'X2', 'X0'], Target: X1
INFO:causal_bayes_opt.training.enriched_trainer:  Target variable 'X1' at index 0, logit: -1000000000.0
INFO:causal_bayes_opt.training.enriched_trainer:  Variable selection:
INFO:causal_bayes_opt.training.enriched_trainer:    Temperature: 1.00
INFO:causal_bayes_opt.training.enriched_trainer:    Probabilities: [0. 1.]
INFO:causal_bayes_opt.training.enriched_trainer:    Selected: X2 (index 1)
INFO:causal_bayes_opt.training.enriched_trainer:  Value selection:
INFO:causal_bayes_opt.training.enriched_trainer:    Mean: 0.5714, Std: 5.4958
INFO:causal_bayes_opt.training.enriched_trainer:    Temperature: 1.00
INFO:causal_bayes_opt.training.enriched_trainer:    Sampled value: 7.8151


🧪 Quick Policy Validation

1. fork (3v):
   Target: X1
   Action: X2=8.762
   Magnitude: 8.762

2. fork (3v):
   Target: X1
   Action: X2=7.815
   Magnitude: 7.815

3. fork (4v):
   Target: X2
   Action: X3=-4.574
   Magnitude: 4.574

✅ Policy is functioning correctly!

🎉 Training Complete! Use checkpoint 'grpo_quick_target_focused_20250721_234828' in evaluation notebook.
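The validation cell is a smoke test: it confirms that the policy returns an action, but not that the action respects the configured limits. The training config sets `max_intervention_value: 2.0`, yet all three sampled values above (8.762, 7.815 and -4.574) exceed it in magnitude. This notebook does not show whether the environment clips values before applying them. A minimal check that flags such actions could look like this; the function name is hypothetical:

```python
from typing import List, Sequence, Tuple


def out_of_bounds(actions: Sequence[Tuple[str, float]],
                  max_abs: float) -> List[Tuple[str, float]]:
    """Return the (variable, value) actions whose |value| exceeds max_abs."""
    return [(var, val) for var, val in actions if abs(val) > max_abs]


# Values from the validation output above; limit from grpo_config.training
actions = [("X2", 8.762), ("X2", 7.815), ("X3", -4.574)]
violations = out_of_bounds(actions, max_abs=2.0)
print(f"{len(violations)}/{len(actions)} actions exceed the limit of 2.0")
```

In the validation loop, one would append `(test_meta['variables'][selected_var_idx], float(intervention_value))` to a list and run this check against `grpo_config.training.max_intervention_value` before printing a success message.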
