# SIREN Training on PhotonSim Data - Example Workflow

This notebook demonstrates how to use the refactored training modules to train JAXSiren on PhotonSim lookup tables.

## Overview
1. Load PhotonSim HDF5 lookup table
2. Configure training parameters
3. Choose training mode (resume from checkpoint or start fresh)
4. Set up monitoring
5. Train the model
6. Analyze results with slice visualizations
7. Test model predictions

## Features
- **Checkpoint Resume**: Resumes from the latest checkpoint when `START_FRESH = False`
- **Fresh Start Option**: Set `START_FRESH = True` to clear all checkpoints and start from scratch
- **Enhanced Visualizations**: Angular profiles and 2D comparisons for different energies
- **Comprehensive Analysis**: Model performance evaluation and error analysis

In [None]:
import json
import logging
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Resolve project locations relative to the kernel's working directory.
# NOTE(review): this assumes the kernel was started inside the notebooks folder;
# the existence checks printed below make a wrong working directory visible.
current_dir = Path.cwd()
project_root = current_dir.parent  # diffCherenkov root (one level up from notebooks)
photonsim_root = project_root.parent / 'PhotonSim'  # PhotonSim root

# Add project paths. Guarded so that re-running this cell does not keep stacking
# duplicate entries at the front of sys.path (same guard the import cell uses).
# Insertion order is preserved: 'tools' ends up ahead of the project root.
for import_dir in (project_root, project_root / 'tools'):
    if str(import_dir) not in sys.path:
        sys.path.insert(0, str(import_dir))

print(f"Current dir: {current_dir}")
print(f"Project root: {project_root}")
print(f"PhotonSim root: {photonsim_root}")

# Verify paths exist
print(f"Project root exists: {project_root.exists()}")
print(f"Siren dir exists: {(project_root / 'siren').exists()}")
print(f"Training dir exists: {(project_root / 'siren' / 'training').exists()}")
print(f"PhotonSim root exists: {photonsim_root.exists()}")

In [None]:
# Import the refactored training modules with fallback strategies.
# Three strategies are tried in order; each one only runs if the previous one
# raised ImportError. If all three fail, diagnostics are printed and an
# ImportError is raised, so nothing below the try-chain runs on failure.
print("📦 Importing training modules...")

imported_successfully = False

try:
    # Strategy 1: Try standard package import
    from siren.training import (
        SIRENTrainer, 
        TrainingConfig, 
        PhotonSimDataset,
        TrainingMonitor,
        TrainingAnalyzer,
        LiveTrainingCallback
    )
    print("✅ Imported from siren.training package")
    
except ImportError as e1:
    print(f"❌ Package import failed: {e1}")
    print("🔄 Trying direct module imports...")
    
    try:
        # Strategy 2: Add siren directory to path and import training module
        siren_path = project_root / 'siren'
        if str(siren_path) not in sys.path:
            sys.path.insert(0, str(siren_path))
        
        from training import (
            SIRENTrainer, 
            TrainingConfig, 
            PhotonSimDataset,
            TrainingMonitor,
            TrainingAnalyzer,
            LiveTrainingCallback
        )
        print("✅ Imported from training module directly")
        
    except ImportError as e2:
        print(f"❌ Direct module import failed: {e2}")
        print("🔧 Trying manual imports from individual files...")
        
        try:
            # Strategy 3: Import from individual module files
            training_path = project_root / 'siren' / 'training'
            if str(training_path) not in sys.path:
                sys.path.insert(0, str(training_path))
            
            from trainer import SIRENTrainer, TrainingConfig
            from dataset import PhotonSimDataset
            from monitor import TrainingMonitor, LiveTrainingCallback
            from analyzer import TrainingAnalyzer
            
            print("✅ Manual imports from individual files successful")
            
        except ImportError as e3:
            print(f"❌ Manual import failed: {e3}")
            print("\n🚨 All import strategies failed!")
            print("Please check:")
            print(f"  1. Current working directory: {Path.cwd()}")
            print(f"  2. Project root: {project_root}")
            print(f"  3. Siren directory exists: {(project_root / 'siren').exists()}")
            print(f"  4. Training directory exists: {(project_root / 'siren' / 'training').exists()}")
            # Chain the last failure explicitly so the traceback names the real cause.
            raise ImportError("Could not import training modules with any strategy") from e3

# Reaching this line means one strategy succeeded: total failure raises above,
# so a separate "failed" branch here could never execute.
imported_successfully = True
print("✅ All training modules imported successfully!")
print("🚀 Ready to start training workflow")

## 1. Load PhotonSim Data

Load the HDF5 lookup table created in PhotonSim.

In [None]:
# Path to the PhotonSim HDF5 lookup table
h5_path = photonsim_root / 'output' / 'photon_lookup_table.h5'

# Fail fast when the table is missing. Every later cell needs `dataset`, so
# merely printing a hint would surface later as an unrelated NameError (or,
# worse, silently reuse a stale `dataset` left over in the kernel).
if not h5_path.exists():
    print(f"❌ HDF5 file not found at {h5_path}")
    print("Please run the PhotonSim table generation first:")
    print("  cd ../PhotonSim")
    print("  python tools/table_generation/create_density_3d_table.py --data-dir data/mu-")
    raise FileNotFoundError(f"PhotonSim HDF5 lookup table not found: {h5_path}")

print(f"✓ Found PhotonSim HDF5 file: {h5_path}")

# Load dataset, holding out 10% of the samples for validation
dataset = PhotonSimDataset(h5_path, val_split=0.1)

# Summarize what was loaded. Angles are passed through np.degrees for display,
# so they are presumably stored in radians — TODO confirm against PhotonSimDataset.
print(f"\nDataset info:")
print(f"  Data type: {dataset.data_type}")
print(f"  Total samples: {len(dataset.data['inputs']):,}")
print(f"  Train samples: {len(dataset.train_indices):,}")
print(f"  Val samples: {len(dataset.val_indices):,}")
print(f"  Energy range: {dataset.energy_range[0]:.0f}-{dataset.energy_range[1]:.0f} MeV")
print(f"  Angle range: {np.degrees(dataset.angle_range[0]):.1f}-{np.degrees(dataset.angle_range[1]):.1f} degrees")
print(f"  Distance range: {dataset.distance_range[0]:.0f}-{dataset.distance_range[1]:.0f} mm")

## 2. Configure Training Parameters

Set up the training configuration using the `TrainingConfig` dataclass.

In [None]:
# Create training configuration with PATIENCE-BASED learning rate scheduling.
# Every value is passed straight to TrainingConfig; the summary printed below
# is derived from `config` so it cannot drift from the values actually used.
config = TrainingConfig(
    # Model architecture
    hidden_features=256,
    hidden_layers=4,
    w0=30.0,
    
    # Training parameters
    learning_rate=1e-4,        # Start here, let patience reduce it intelligently
    weight_decay=1e-5,         # Regularization for stability
    batch_size=8192,           # Smaller batches for better gradients
    num_steps=10000,           # Enough steps to see patience in action
    
    # PATIENCE-BASED SCHEDULER (much better than fixed intervals!)
    use_patience_scheduler=True,
    # NOTE(review): this line used to be annotated "10 evals" while the value is 500.
    # With val_every=100 and num_steps=10000 there are only about 100 validation
    # passes, so if patience counts evaluations the LR would never be reduced.
    # Confirm the unit TrainingConfig expects before relying on the scheduler.
    patience=500,
    lr_reduction_factor=0.5,   # Cut LR in half when triggered
    min_lr=1e-7,              # Don't go below this
    
    # Stability features
    optimizer='adamw',         # AdamW with built-in weight decay
    grad_clip_norm=1.0,       # Prevent gradient explosions
    
    # Logging and checkpointing
    log_every=50,
    val_every=100,            # Check patience every 100 steps
    checkpoint_every=1000,
    
    seed=42
)

print("🧠 FIXED Training Configuration with Patience:")
print(f"  • Initial LR: {config.learning_rate:.2e}")
print(f"  • Patience: {config.patience} evaluations")
print(f"  • LR reduction: ×{config.lr_reduction_factor} when triggered")
print(f"  • Minimum LR: {config.min_lr:.2e}")
print(f"  • Validation every: {config.val_every} steps")
print(f"  • Optimizer: {config.optimizer.upper()} with gradient clipping")
print("\n📈 How it works:")
print("  → LR stays constant while validation loss improves")
# Previously hard-coded as "10 evaluations" / "÷ 2", which contradicted the
# configured patience printed a few lines above.
print(f"  → After {config.patience} evaluations with no improvement → LR × {config.lr_reduction_factor}")
print("  → Only updates LR when needed (not every step!)")
print("  → Much faster training with adaptive LR!")

## 3. Training Mode Selection

Choose whether to resume from an existing checkpoint or start fresh.

In [None]:
# Training mode configuration
START_FRESH = True  # Set to True to start from scratch, False to resume from checkpoint

# Set up output directory (relative to the kernel's working directory)
output_dir = Path('output') / 'photonsim_siren_training'
output_dir.mkdir(exist_ok=True, parents=True)

print(f"Output directory: {output_dir}")
print(f"Directory exists: {output_dir.exists()}")

# Check for existing checkpoints
existing_checkpoints = list(output_dir.glob('*.npz'))
existing_history = output_dir / 'training_history.json'

if existing_checkpoints or existing_history.exists():
    print(f"\n🔍 Found existing training data:")
    if existing_history.exists():
        # Kept under its own name so the on-disk history of an earlier run can
        # never be mistaken for the `history` returned by trainer.train() later.
        with open(existing_history, 'r') as f:
            previous_history = json.load(f)
        # An absent or empty 'step' series means there is no progress to report.
        if previous_history.get('step'):
            last_step = max(previous_history['step'])
            print(f"  - Training history up to step {last_step}")
    
    for checkpoint in existing_checkpoints:
        print(f"  - Checkpoint: {checkpoint.name}")
    
    # Only announces the mode; the trainer cell acts on START_FRESH.
    if START_FRESH:
        print(f"\n🔄 START_FRESH=True: Will clear existing data and start from scratch")
    else:
        print(f"\n▶️  START_FRESH=False: Will resume from latest checkpoint")
else:
    print(f"\n✨ No existing training data found. Starting fresh.")

## 4. Initialize Trainer and Monitor

Create the trainer and set up monitoring.

In [None]:
# Build the trainer; it is asked to resume only when START_FRESH is switched off
resume_requested = not START_FRESH
trainer = SIRENTrainer(
    dataset=dataset,
    config=config,
    output_dir=output_dir,
    resume_from_checkpoint=resume_requested,
)

# On a fresh start, also ask the trainer to clear its checkpoints
if START_FRESH:
    print("🧹 Clearing existing checkpoints...")
    trainer.clear_checkpoints()
    print("✅ Starting with clean slate")

print("✓ Trainer initialized")
print(f"✓ Output directory: {output_dir}")
print(f"✓ JAX device: {trainer.device}")

# Report the step training will begin from
if trainer.start_step <= 0:
    print("✓ Starting fresh training from step 0")
else:
    print(f"✓ Resuming from step {trainer.start_step}")
    n_history_entries = len(trainer.history['train_loss'])
    print(f"✓ Training history loaded with {n_history_entries} entries")

In [None]:
# Monitoring: constructed with the training output directory, live plotting on
monitor = TrainingMonitor(output_dir, live_plotting=True)

# Callback cadence, expressed in training steps
live_callback_options = dict(
    update_every=50,   # refresh monitored data every 50 steps
    plot_every=200,    # redraw plots every 200 steps
)
live_callback = LiveTrainingCallback(monitor, **live_callback_options)

# Register the callback on the trainer so it is invoked during training
trainer.add_callback(live_callback)

for status_line in (
    "✓ Monitoring setup complete",
    "✓ Live plotting enabled - plots will update during training",
):
    print(status_line)

In [None]:
# Start training; `history` holds the per-series records returned by the trainer
print("Starting SIREN training...")
history = trainer.train()

print("\n✓ Training completed!")
# Each series is reported only when it has entries: indexing [-1] on an empty
# series raises IndexError. The val_loss check was already present; train_loss
# now gets the same protection.
if len(history['train_loss']) > 0:
    print(f"Final train loss: {history['train_loss'][-1]:.6f}")
if history['val_loss']:
    print(f"Final val loss: {history['val_loss'][-1]:.6f}")

## 5. Train the Model

Start training with live monitoring.

```python
# Plot training history
fig = trainer.plot_training_history(save_path=output_dir / 'final_training_progress.png')
plt.show()

# Also plot monitoring dashboard
print("\n📊 Training Progress Dashboard:")
monitor_fig = monitor.plot_progress(save_path=output_dir / 'training_dashboard.png')
plt.show()
```

In [None]:
# Draw the training-history figure and save it alongside the other run outputs
history_plot_path = output_dir / 'final_training_progress.png'
fig = trainer.plot_training_history(save_path=history_plot_path)
plt.show()

## 7. Analyze Model Performance

Use the analyzer to evaluate model performance.

In [None]:
# Wrap the trainer and dataset in an analyzer
analyzer = TrainingAnalyzer(trainer, dataset)

# Evaluate on both splits with 50,000 samples requested
evaluation_results = analyzer.evaluate_model(n_samples=50000, splits=['train', 'val'])

# (display label, metrics key) pairs, in the order they are reported
metric_rows = (
    ("R²", 'r2'),
    ("RMSE", 'rmse'),
    ("MAE", 'mae'),
    ("Relative Error", 'relative_error'),
    ("Correlation", 'correlation'),
)

print("\nEvaluation Results:")
for split, split_eval in evaluation_results.items():
    metrics = split_eval['metrics']
    print(f"\n{split.upper()} Split:")
    for label, key in metric_rows:
        print(f"  {label}: {metrics[key]:.6f}")

In [None]:
# Break validation error down with the analyzer (20,000 samples requested)
error_analysis = analyzer.analyze_error_patterns(split='val', n_samples=20000)

print("\nError Analysis:")
for analysis_data in error_analysis.values():
    # Only entries that carry a 'dimension_name' are summarized here
    if 'dimension_name' not in analysis_data:
        continue
    dim_name = analysis_data['dimension_name']
    dim_range = analysis_data['dimension_range']
    range_low, range_high = dim_range[0], dim_range[1]
    avg_error = np.mean(analysis_data['mean_errors'])
    print(f"  {dim_name}: range {range_low:.2f}-{range_high:.2f}, avg error: {avg_error:.6f}")

In [None]:
# Produce the lookup-table slice comparison figure and save it to the output directory
slices_plot_path = output_dir / 'lookup_table_slices.png'
fig_slices = analyzer.plot_lookup_table_slices(save_path=slices_plot_path)
plt.show()

## 8. Lookup Table Slice Comparisons

Create slice visualizations comparing lookup table expectations vs SIREN model predictions.

## 9. Test Model Predictions

Test the model on specific energy/angle/distance combinations.

In [None]:
# Write the analyzer's results and the monitor's data to JSON in the output directory
analyzer.export_results(output_dir / 'analysis_results.json')
monitor.export_data(output_dir / 'monitoring_data.json')

# Paths listed for the reader. Only the two JSON exports above are written by
# this cell; the other entries are printed without checking that they exist.
exported_artifacts = (
    ("Model checkpoint", 'final_model.npz'),
    ("Training config", 'config.json'),
    ("Training history", 'training_history.json'),
    ("Analysis results", 'analysis_results.json'),
    ("Monitoring data", 'monitoring_data.json'),
    ("Plots", '*.png'),
)

print("✓ Results exported to:")
for label, artifact_name in exported_artifacts:
    print(f"  {label}: {output_dir / artifact_name}")

In [None]:
## 11. Model Summary

# Display a final summary of the trained model.

In [None]:
## 10. Save Results

# Export analysis results and model for future use.

In [None]:
## 11. Model Summary

# Display a final summary of the trained model.

In [None]:
# Fetch the monitor's summary (kept for interactive inspection; not printed below)
summary = monitor.get_summary()

banner = "=" * 60
print("\n" + banner)
print("TRAINING SUMMARY")
print(banner)

# --- Model ---
print("\nModel Architecture:")
print(f"  Hidden features: {config.hidden_features}")
print(f"  Hidden layers: {config.hidden_layers}")
print(f"  SIREN frequency (w0): {config.w0}")

# --- Optimisation settings ---
print("\nTraining Configuration:")
print(f"  Learning rate: {config.learning_rate:.2e}")
print(f"  Batch size: {config.batch_size:,}")
print(f"  Total steps: {config.num_steps:,}")
print(f"  Weight decay: {config.weight_decay:.2e}")

# --- Data ---
print("\nDataset Information:")
print(f"  Total samples: {len(dataset.data['inputs']):,}")
print(f"  Training samples: {len(dataset.train_indices):,}")
print(f"  Validation samples: {len(dataset.val_indices):,}")
energy_low, energy_high = dataset.energy_range[0], dataset.energy_range[1]
print(f"  Energy range: {energy_low:.0f}-{energy_high:.0f} MeV")

# --- Validation metrics, shown only when the 'val' split was evaluated ---
print("\nFinal Performance:")
if 'val' in evaluation_results:
    val_metrics = evaluation_results['val']['metrics']
    for label, key in (
        ("Validation R²", 'r2'),
        ("Validation RMSE", 'rmse'),
        ("Validation MAE", 'mae'),
        ("Relative Error", 'relative_error'),
    ):
        print(f"  {label}: {val_metrics[key]:.6f}")

print(f"\nOutput Directory: {output_dir}")
print("\n✓ Training completed successfully!")