## Summary

✅ **Training Workflow Complete!**

This notebook demonstrates the complete SIREN training and evaluation workflow:

1. **Data Loading**: PhotonSim HDF5 lookup table with proper normalization
2. **Model Training**: JAX/Flax SIREN with linear-scale training (matching CProfSiren)
3. **Model Saving**: Complete metadata preservation for inference
4. **Model Loading**: Standalone inference capabilities with proper denormalization
5. **Validation**: Visual comparison between SIREN predictions and original lookup table

**Key Features:**
- 🔄 **Normalization Handling**: Automatic input/output scaling for inference
- 📊 **Metadata Preservation**: Complete dataset info, training config, and normalization parameters
- 🚀 **Easy Inference**: Simple prediction interface for single points or batches
- 📈 **Visualization**: Direct comparison with original lookup table data
- 🎯 **Accuracy**: Aims for high-fidelity reproduction of Cherenkov light patterns, checked against the lookup table with 2D slice and angular-profile comparisons

The trained model can now be used as a drop-in replacement for lookup table interpolation in diffCherenkov simulations!

## ⚡ Key Improvements Based on CProfSiren Analysis

This notebook has been updated to match the successful training approach from CProfSiren:

1. **Linear Training**: Train on actual photon density values (not log-transformed)
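2. **Loss Scaling**: Multiply the MSE loss by 1000 for better gradient flow
3. **Output Squaring**: The SIREN model squares its output to ensure non-negative densities
4. **Large Batches**: Use the largest batch size memory allows (CProfSiren used batch_size=65536; the configuration below uses 10,000)
5. **Learning-Rate Schedule**: CProfSiren used StepLR with a 10× drop at step 2000; this notebook instead lowers the LR with a patience-based rule when validation loss plateaus
6. **Proper SIREN Handling**: Unpack the `(output, coords)` tuple returned by the SIREN instead of treating it as a single array

These changes target the observed performance degradation and should restore good training behavior.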
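Items 1 to 3 can be illustrated with a minimal NumPy sketch of a SIREN-style MLP: sine hidden layers, a linear last layer whose output is squared, and a linear-space MSE scaled by 1000. This is not the notebook's JAX/Flax model; the initialization follows the SIREN paper's uniform scheme with biases simplified to the same bounds, and the layer sizes are illustrative assumptions.

```python
import numpy as np

def init_siren(rng, sizes, w0=30.0):
    """SIREN-style init: first layer U(-1/n_in, 1/n_in), later layers
    U(-sqrt(6/n_in)/w0, sqrt(6/n_in)/w0). Biases reuse the same bounds (a simplification)."""
    params = []
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        bound = 1.0 / n_in if i == 0 else np.sqrt(6.0 / n_in) / w0
        W = rng.uniform(-bound, bound, size=(n_in, n_out))
        b = rng.uniform(-bound, bound, size=(n_out,))
        params.append((W, b))
    return params

def siren_forward(params, x, w0=30.0):
    """Sine hidden layers, linear last layer, then squaring so outputs are >= 0."""
    h = x
    for W, b in params[:-1]:
        h = np.sin(w0 * (h @ W + b))
    W, b = params[-1]
    out = h @ W + b
    return out ** 2

def scaled_mse(pred, target, scale=1000.0):
    """Linear-space MSE multiplied by a constant (1000 in the CProfSiren setup)."""
    return scale * np.mean((pred - np.reshape(target, pred.shape)) ** 2)

# Toy usage: 3 normalized inputs (energy, angle, distance) -> 1 density
rng = np.random.default_rng(0)
params = init_siren(rng, [3, 64, 64, 64, 1])
x = rng.uniform(-1.0, 1.0, size=(8, 3))
pred = siren_forward(params, x)
loss = scaled_mse(pred, np.zeros(8))
```

Because the output is squared, the network can never predict a negative density, which matters when most linear targets are zero or very small.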

In [3]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get project paths correctly
current_dir = Path.cwd()
project_root = current_dir.parent  # diffCherenkov root (one level up from notebooks)
photonsim_root = project_root.parent / 'PhotonSim'  # PhotonSim root

# Add project paths
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(project_root / 'tools'))

print(f"Current dir: {current_dir}")
print(f"Project root: {project_root}")
print(f"PhotonSim root: {photonsim_root}")

# Verify paths exist
print(f"Project root exists: {project_root.exists()}")
print(f"Siren dir exists: {(project_root / 'siren').exists()}")
print(f"Training dir exists: {(project_root / 'siren' / 'training').exists()}")
print(f"PhotonSim root exists: {photonsim_root.exists()}")

Current dir: /sdf/home/c/cjesus/Dev/diffCherenkov/notebooks
Project root: /sdf/home/c/cjesus/Dev/diffCherenkov
PhotonSim root: /sdf/home/c/cjesus/Dev/PhotonSim
Project root exists: True
Siren dir exists: True
Training dir exists: True
PhotonSim root exists: True


In [4]:
# Import the refactored training modules
print("📦 Importing training modules...")

try:
    # Add siren directory to path and import training module
    siren_path = project_root / 'siren'
    if str(siren_path) not in sys.path:
        sys.path.insert(0, str(siren_path))
    
    from training import (
        SIRENTrainer, 
        TrainingConfig, 
        PhotonSimDataset,
        TrainingMonitor,
        TrainingAnalyzer,
        LiveTrainingCallback
    )
    print("✅ Imported from training module directly")
    
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print("🔧 Trying manual imports from individual files...")
    
    # Import from individual module files
    training_path = project_root / 'siren' / 'training'
    if str(training_path) not in sys.path:
        sys.path.insert(0, str(training_path))
    
    from trainer import SIRENTrainer, TrainingConfig
    from dataset import PhotonSimDataset
    from monitor import TrainingMonitor, LiveTrainingCallback
    from analyzer import TrainingAnalyzer
    
    print("✅ Manual imports from individual files successful")

print("✅ All training modules imported successfully!")
print("🚀 Ready to start training workflow")

📦 Importing training modules...


✅ Imported from training module directly
✅ All training modules imported successfully!
🚀 Ready to start training workflow


In [5]:
# Path to the PhotonSim HDF5 lookup table
h5_path = photonsim_root / 'output' / 'photon_lookup_table.h5'

# Check if file exists and load dataset
if not h5_path.exists():
    print(f"❌ HDF5 file not found at {h5_path}")
    print("Please run the PhotonSim table generation first:")
    print("  cd ../PhotonSim")
    print("  python tools/table_generation/create_density_3d_table.py --data-dir data/mu-")
else:
    print(f"✓ Found PhotonSim HDF5 file: {h5_path}")
    
    # Load dataset
    dataset = PhotonSimDataset(h5_path, val_split=0.1)
    
    print(f"\nDataset info:")
    print(f"  Data type: {dataset.data_type}")
    print(f"  Total samples: {len(dataset.data['inputs']):,}")
    print(f"  Train samples: {len(dataset.train_indices):,}")
    print(f"  Val samples: {len(dataset.val_indices):,}")
    print(f"  Energy range: {dataset.energy_range[0]:.0f}-{dataset.energy_range[1]:.0f} MeV")
    print(f"  Angle range: {np.degrees(dataset.angle_range[0]):.1f}-{np.degrees(dataset.angle_range[1]):.1f} degrees")
    print(f"  Distance range: {dataset.distance_range[0]:.0f}-{dataset.distance_range[1]:.0f} mm")

# Analyze the data distribution to understand training challenges
# (placed after loading so that `dataset` exists the first time this cell runs)
if 'dataset' in globals():
    print("\n📊 Data Analysis:")
    
    # Check target value distribution
    targets = dataset.data['targets']
    targets_nonzero = targets[targets > 0]
    
    print(f"  • Total samples: {len(targets):,}")
    print(f"  • Non-zero samples: {len(targets_nonzero):,} ({len(targets_nonzero)/len(targets)*100:.1f}%)")
    print(f"  • Target range: [{targets.min():.2e}, {targets.max():.2e}]")
    print(f"  • Target mean: {targets.mean():.2e}")
    print(f"  • Target std: {targets.std():.2e}")
    
    # Check log-transformed targets
    targets_log = dataset.data['targets_log']
    print(f"\n  • Log target range: [{targets_log.min():.2f}, {targets_log.max():.2f}]")
    print(f"  • Log target mean: {targets_log.mean():.2f}")
    
    # Show distribution
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    # Linear scale histogram
    ax1.hist(targets[targets > 0], bins=100, alpha=0.7)
    ax1.set_xlabel('Photon Density')
    ax1.set_ylabel('Count')
    ax1.set_yscale('log')
    ax1.set_title('Target Distribution (Linear Scale)')
    
    # Log scale histogram  
    ax2.hist(targets_log, bins=100, alpha=0.7, color='orange')
    ax2.set_xlabel('Log10(Photon Density)')
    ax2.set_ylabel('Count')
    ax2.set_title('Target Distribution (Log Scale)')
    
    plt.tight_layout()
    plt.show()
    
    print("\n💡 Insights:")
    print("  • Data has many small/zero values - importance sampling could help")
    print("  • Wide dynamic range suggests log-space training might be beneficial")
    print("  • But CProfSiren trained directly on linear values with MSE loss")
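The insights above mention importance sampling. A minimal sketch, assuming a 1-D array of linear targets and an index array such as `dataset.train_indices`: oversample non-zero bins and return importance weights so that a weighted mean of per-sample losses remains an unbiased estimate of the uniform-sampling loss. `importance_sample_indices` and `nonzero_frac` are hypothetical and not part of `PhotonSimDataset`.

```python
import numpy as np

def importance_sample_indices(targets, indices, batch_size, rng, nonzero_frac=0.5):
    """Sample `batch_size` indices so that about `nonzero_frac` have target > 0.

    Returns (batch_indices, weights); weights 1 / (N * p_i) make a weighted mean
    of per-sample losses an unbiased estimate of the uniform-sampling mean.
    """
    targets = np.asarray(targets).ravel()
    indices = np.asarray(indices)
    nonzero = targets[indices] > 0
    n_nonzero = int(nonzero.sum())
    n_zero = len(indices) - n_nonzero
    # Split the probability mass between the non-zero and zero groups
    probs = np.where(nonzero,
                     nonzero_frac / max(n_nonzero, 1),
                     (1.0 - nonzero_frac) / max(n_zero, 1))
    probs = probs / probs.sum()  # renormalize (needed if one group is empty)
    positions = rng.choice(len(indices), size=batch_size, replace=True, p=probs)
    weights = 1.0 / (len(indices) * probs[positions])
    return indices[positions], weights

# Toy usage: 8 zero bins and 2 non-zero bins
rng = np.random.default_rng(0)
toy_targets = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1.0, 2.0])
batch_idx, batch_w = importance_sample_indices(toy_targets, np.arange(10), 10_000, rng)
```

In a training step the loss would then become a weighted MSE such as `jnp.mean(weights * (output - targets) ** 2)`; whether this helps convergence on this table is untested here.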

## 3. Configure Training Parameters

Set up the training configuration.

In [None]:
# CRITICAL FIX: Modify dataset to train on linear values (not log)
# This matches the successful CProfSiren approach
import jax
import jax.numpy as jnp

class LinearPhotonSimDataset:
    """Wrapper to make dataset return linear values instead of log values"""
    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        # Copy necessary attributes directly
        self.data = base_dataset.data
        self.train_indices = base_dataset.train_indices
        self.val_indices = base_dataset.val_indices
        self.normalized_bounds = base_dataset.normalized_bounds
        self.metadata = base_dataset.metadata
        self.energy_range = base_dataset.energy_range
        self.angle_range = base_dataset.angle_range
        self.distance_range = base_dataset.distance_range
    
    def get_batch(self, batch_size, rng, split='train', normalized=True):
        # Get normalized inputs but LINEAR targets
        if split == 'train':
            indices = self.train_indices
        else:
            indices = self.val_indices
            
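        # Random sampling (with replacement); convert to NumPy so the indices
        # can index the arrays stored in self.data
        batch_indices = np.asarray(jax.random.choice(rng, indices, shape=(batch_size,)))
        
        # Get normalized inputs
        inputs = self.data['inputs_normalized'][batch_indices]
        # Get LINEAR targets (not log!), shaped (batch, 1) to match the model output
        targets = np.asarray(self.data['targets'][batch_indices]).reshape(-1, 1)
        
        return jnp.array(inputs), jnp.array(targets)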
    
    def get_sample_input(self):
        return jnp.array(self.data['inputs_normalized'][:1])
    
    @property
    def has_validation(self):
        return len(self.val_indices) > 0

# Wrap the dataset to use linear values
linear_dataset = LinearPhotonSimDataset(dataset)

print("✅ Dataset configured for linear training (matching CProfSiren)")
print(f"  • Input normalization: [-1, 1]")
print(f"  • Target values: Linear scale (not log)")
print(f"  • Ready for training with MSE loss × 1000")
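Target shape matters for the scaled MSE. The shape of `data['targets']` is not shown in this notebook, so as an assumption: if targets arrive as `(batch,)` while the SIREN output is `(batch, 1)`, broadcasting silently produces a `(batch, batch)` error matrix. The NumPy example below shows the effect; `jnp` arrays follow the same broadcasting rules.

```python
import numpy as np

pred = np.arange(4.0).reshape(4, 1)   # model output shaped (batch, 1)
target = np.arange(4.0)               # targets shaped (batch,)

# (batch, 1) - (batch,) broadcasts to (batch, batch): every prediction is
# compared with every target, so even a perfect fit has non-zero loss
mse_broadcast = np.mean((pred - target) ** 2)

# Reshaping targets to (batch, 1) gives the intended element-wise MSE
mse_elementwise = np.mean((pred - target.reshape(-1, 1)) ** 2)

print(mse_broadcast, mse_elementwise)  # 2.5 0.0
```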

In [None]:
# Create training configuration with PATIENCE-BASED learning rate scheduling
config = TrainingConfig(
    # Model architecture - same as CProfSiren
    hidden_features=256,
    hidden_layers=3,        # CProfSiren used 3 layers
    w0=30.0,               # Standard SIREN frequency
    
    # Training parameters - adapted from CProfSiren
    learning_rate=1e-4,     # Same as CProfSiren
    weight_decay=0.0,       # CProfSiren didn't use weight decay
    batch_size=10_000,      # CProfSiren used 65536; use the largest size memory allows
    num_steps=25000,        # More steps to see patience in action
    
    # PATIENCE-BASED LR SCHEDULER - much better than fixed!
    use_patience_scheduler=True,   # Enable patience-based LR
    patience=20,                   # Reduce LR after 20 validations with no improvement
    lr_reduction_factor=0.5,       # Cut LR in half when triggered
    min_lr=5e-6,                   # Don't go below this
    
    # Optimizer settings
    optimizer='adam',       # Same as CProfSiren
    grad_clip_norm=0.0,    # CProfSiren didn't use gradient clipping
    
    # Logging frequency
    log_every=10,          # CProfSiren logged every 10 steps
    val_every=50,          # Check validation more frequently for patience
    checkpoint_every=500,  # Save periodically
    
    seed=42
)

print("📊 Training Configuration (CProfSiren-inspired with Patience LR):")
print(f"  • Architecture: {config.hidden_layers} layers × {config.hidden_features} features")
print(f"  • Initial LR: {config.learning_rate:.2e}")
print(f"  • Batch Size: {config.batch_size:,} (large for stable gradients)")
print(f"  • Total Steps: {config.num_steps:,}")
print(f"\n🎯 Patience-based LR Schedule:")
print(f"  • Patience: {config.patience} validation checks")
print(f"  • LR reduction: ×{config.lr_reduction_factor} when triggered")
print(f"  • Minimum LR: {config.min_lr:.2e}")
print(f"  • Validation every: {config.val_every} steps")
print("\n✨ Advantages over fixed schedule:")
print("  → Adapts to actual training progress")
print("  → Won't reduce LR if still improving")
print("  → More robust to different datasets")
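The reach of this schedule can be worked out from the configuration: halving from 1e-4 hits the 5e-6 floor after five reductions, and each reduction needs at least `patience × val_every = 20 × 50 = 1000` steps without validation improvement. `patience_budget` is a hypothetical helper that mirrors the ladder logic used later in this notebook.

```python
def patience_budget(learning_rate, factor, min_lr, patience, val_every, num_steps):
    """Summarize a patience-based LR ladder (hypothetical helper)."""
    # LR values reachable by repeated multiplication by `factor`, clamped at `min_lr`
    ladder = [learning_rate]
    lr = learning_rate
    while lr > min_lr:
        lr *= factor
        ladder.append(max(lr, min_lr))
    # Minimum number of optimizer steps between two reductions
    steps_per_reduction = patience * val_every
    # Upper bound on plateau-triggered reductions that fit in the run
    max_reductions = num_steps // steps_per_reduction
    return ladder, steps_per_reduction, max_reductions

ladder, steps_per_reduction, max_reductions = patience_budget(1e-4, 0.5, 5e-6, 20, 50, 25_000)
print([f"{lr:.2e}" for lr in ladder])
print(steps_per_reduction, max_reductions)
```

With 25,000 steps the run could in principle reach the floor (five reductions need at least 5,000 plateau steps); whether it does depends on how often validation loss improves.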

In [None]:
# Training mode configuration
START_FRESH = True

# Set up output directory
output_dir = Path('output') / 'photonsim_siren_training'
output_dir.mkdir(exist_ok=True, parents=True)

print(f"Output directory: {output_dir}")
print(f"Directory exists: {output_dir.exists()}")

# Check for existing checkpoints
existing_checkpoints = list(output_dir.glob('*.npz'))
existing_history = output_dir / 'training_history.json'

if existing_checkpoints or existing_history.exists():
    print(f"\n🔍 Found existing training data:")
    if existing_history.exists():
        import json
        with open(existing_history, 'r') as f:
            history = json.load(f)
            if history.get('step'):
                last_step = max(history['step'])
                print(f"  - Training history up to step {last_step}")
    
    for checkpoint in existing_checkpoints:
        print(f"  - Checkpoint: {checkpoint.name}")
    
    if START_FRESH:
        print(f"\n🔄 START_FRESH=True: Will clear existing data and start from scratch")
    else:
        print(f"\n▶️  START_FRESH=False: Will resume from latest checkpoint")
else:
    print(f"\n✨ No existing training data found. Starting fresh.")
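When `START_FRESH=False` the trainer resumes from the latest checkpoint. To see which file that is likely to be, a minimal sketch can pick the most recently modified `.npz` in `output_dir`. `latest_checkpoint` is hypothetical, and the trainer's own selection logic (for example by step number) may differ.

```python
from pathlib import Path
from typing import Optional

def latest_checkpoint(output_dir, pattern: str = "*.npz") -> Optional[Path]:
    """Return the most recently modified checkpoint in output_dir, or None."""
    output_dir = Path(output_dir)
    if not output_dir.is_dir():
        return None
    candidates = list(output_dir.glob(pattern))
    if not candidates:
        return None
    # Modification time avoids assuming any particular checkpoint naming scheme
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

In this notebook it would be called as `latest_checkpoint(output_dir)`.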

In [None]:
# Create custom training with PROPER patience-based LR scheduling
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from typing import NamedTuple, Any

class TrainingState(NamedTuple):
    """Extended training state with patience tracking"""
    params: Any
    opt_state: Any
    best_loss: float
    patience_count: int
    lr_index: int

# Learning rate schedule with patience
def create_patience_lr_schedule(base_lr, factor, patience, min_lr):
    """Create a patience-based learning rate schedule"""
    lr_values = [base_lr]
    current_lr = base_lr
    while current_lr > min_lr:
        current_lr *= factor
        lr_values.append(max(current_lr, min_lr))
    
    # Return a function that selects LR based on index
    def schedule(step, lr_index):
        return lr_values[min(lr_index, len(lr_values) - 1)]
    
    return schedule, lr_values

# Create the schedule
lr_schedule_fn, lr_values = create_patience_lr_schedule(
    config.learning_rate, 
    config.lr_reduction_factor,
    config.patience,
    config.min_lr
)

print(f"📈 Learning rate schedule: {[f'{lr:.2e}' for lr in lr_values[:5]]}...")

# Custom training functions
@jax.jit
def train_step_with_patience(state, batch, lr):
    """Training step with explicit learning rate"""
    inputs, targets = batch
    
    def loss_fn(params):
        # SIREN returns tuple (output, coords) - take first element
        output, _ = trainer.model.apply({'params': params}, inputs)
        
        # Reshape output and targets to (batch, 1) so their difference does not
        # broadcast to (batch, batch)
        output = output.reshape(-1, 1)
        targets_2d = targets.reshape(-1, 1)
            
        # MSE loss with scaling
        loss = jnp.mean((output - targets_2d) ** 2) * 1000.0
        return loss
    
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    
    # Rebuild Adam with the current learning rate; Adam's state layout does not
    # depend on the learning rate, so the existing opt_state can be reused
    optimizer = optax.adam(learning_rate=lr)
    
    # Update parameters
    updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)
    
    new_state = TrainingState(
        params=new_params,
        opt_state=new_opt_state,
        best_loss=state.best_loss,
        patience_count=state.patience_count,
        lr_index=state.lr_index
    )
    
    return new_state, loss

@jax.jit
def eval_step_with_patience(params, batch):
    """Evaluation step (uses trainer.model.apply from the enclosing scope,
    because a Python callable cannot be passed as a traced jit argument)"""
    inputs, targets = batch
    
    # SIREN returns tuple (output, coords) - take first element
    output, _ = trainer.model.apply({'params': params}, inputs)
    
    # Reshape output and targets to (batch, 1) before comparing
    output = output.reshape(-1, 1)
        
    # MSE loss with scaling
    loss = jnp.mean((output - targets.reshape(-1, 1)) ** 2) * 1000.0
    return loss

print("✅ Custom training steps with patience-based LR ladder ready!")
print("  • Optimizer state is reused across LR changes (Adam's state does not depend on LR)")
print("  • lr_index should advance only when validation plateaus (checked in the training loop)")
print("  • Implements CProfSiren-style loss scaling (MSE × 1000)")
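`TrainingState` carries `best_loss`, `patience_count` and `lr_index`, but this cell does not show the rule that updates them after each validation check. A minimal pure-Python sketch of a plateau rule follows; `PatienceState`, `update_patience` and the `min_delta` threshold are hypothetical, and the trainer's own logic may differ. The current learning rate would then be `lr_values[state.lr_index]`.

```python
from typing import NamedTuple

class PatienceState(NamedTuple):
    best_loss: float
    patience_count: int
    lr_index: int

def update_patience(state, val_loss, patience, num_lrs, min_delta=0.0):
    """Update plateau tracking after one validation check.

    Resets the counter on improvement; after `patience` checks without
    improvement, moves one step down the LR ladder (unless already at the end).
    """
    if val_loss < state.best_loss - min_delta:
        return PatienceState(val_loss, 0, state.lr_index)
    count = state.patience_count + 1
    if count >= patience and state.lr_index < num_lrs - 1:
        return PatienceState(state.best_loss, 0, state.lr_index + 1)
    return PatienceState(state.best_loss, count, state.lr_index)

# Usage: two improvements, then a plateau of 3 checks with patience=3
demo_state = PatienceState(best_loss=float("inf"), patience_count=0, lr_index=0)
for val_loss in [1.0, 0.9, 0.95, 0.95, 0.95]:
    demo_state = update_patience(demo_state, val_loss, patience=3, num_lrs=6)
print(demo_state)  # PatienceState(best_loss=0.9, patience_count=0, lr_index=1)
```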

In [None]:
# Initialize trainer with LINEAR dataset (not log)
trainer = SIRENTrainer(
    dataset=linear_dataset,  # Use linear dataset!
    config=config,
    output_dir=output_dir,
    resume_from_checkpoint=not START_FRESH
)

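# Override the training functions with our custom ones
# NOTE: this assumes SIRENTrainer builds its step functions lazily through these
# factory methods and accepts these signatures; check siren/training/trainer.py
trainer._create_train_step = lambda: train_step_with_patience
trainer._create_eval_step = lambda: eval_step_with_patience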

# Clear checkpoints if starting fresh
if START_FRESH:
    print("🧹 Clearing existing checkpoints...")
    trainer.clear_checkpoints()
    print("✅ Starting with clean slate")

print(f"✓ Trainer initialized with CProfSiren-style training")
print(f"✓ Output directory: {output_dir}")
print(f"✓ JAX device: {trainer.device}")

# Check if we're resuming
if trainer.start_step > 0:
    print(f"✓ Resuming from step {trainer.start_step}")
    print(f"✓ Training history loaded with {len(trainer.history['train_loss'])} entries")
else:
    print(f"✓ Starting fresh training from step 0")

## 4. Initialize Trainer and Monitor

Create the trainer and set up monitoring.

In [None]:
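# Re-initialize the trainer with resume option (this replaces the trainer above,
# without the custom step overrides); keep the linear dataset for consistency
trainer = SIRENTrainer(
    dataset=linear_dataset,
    config=config,
    output_dir=output_dir,
    resume_from_checkpoint=not START_FRESH  # Resume unless starting fresh
)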

# Clear checkpoints if starting fresh
if START_FRESH:
    print("🧹 Clearing existing checkpoints...")
    trainer.clear_checkpoints()
    print("✅ Starting with clean slate")

print(f"✓ Trainer initialized")
print(f"✓ Output directory: {output_dir}")
print(f"✓ JAX device: {trainer.device}")

# Check if we're resuming
if trainer.start_step > 0:
    print(f"✓ Resuming from step {trainer.start_step}")
    print(f"✓ Training history loaded with {len(trainer.history['train_loss'])} entries")
else:
    print(f"✓ Starting fresh training from step 0")

### Training Notes

With the CProfSiren-style configuration:
- Loss values will be ~1000× larger due to scaling (this is expected)
- Training on linear values captures the full dynamic range
- The SIREN model squares its output internally to ensure positive densities
- Large batch sizes provide stable gradients
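- The patience-based scheduler halves the learning rate after 20 consecutive validation checks without improvement (at least 1000 steps with `val_every=50`), down to a floor of 5e-6

Monitor for:
- Steady loss decrease early in training, before the first LR reduction
- Further improvement after each LR reduction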
- Validation loss tracking training loss
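The ~1000× loss scale interacts with the optimizer. For Adam, multiplying every gradient by a constant cancels in m̂/√v̂ except through ε, so the scale mainly matters when raw gradients are comparable to ε (1e-8 by default in `optax.adam`). A minimal NumPy sketch of Adam's first bias-corrected update shows this; it is a single-step illustration, not the full optimizer.

```python
import numpy as np

def adam_first_update(grad, lr=1e-4, b1=0.9, b2=0.999, eps=1e-8):
    """Magnitude of Adam's first (bias-corrected) update for gradient `grad`."""
    m = (1.0 - b1) * grad          # first-moment estimate after one step
    v = (1.0 - b2) * grad ** 2     # second-moment estimate after one step
    m_hat = m / (1.0 - b1)         # bias correction at step 1
    v_hat = v / (1.0 - b2)
    return lr * m_hat / (np.sqrt(v_hat) + eps)

# Large gradients: scaling by 1000 barely changes the step
ratio_large = adam_first_update(1000 * 1e-3) / adam_first_update(1e-3)
# Tiny gradients (comparable to eps): scaling by 1000 enlarges the step ~11x
ratio_tiny = adam_first_update(1000 * 1e-9) / adam_first_update(1e-9)
```

So the larger logged loss values are mostly a reporting scale, while the scaling still helps where linear densities make gradients very small.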

## 5. Train the Model

Start training with live monitoring.

In [None]:
# Set up monitoring with live plotting
monitor = TrainingMonitor(output_dir, live_plotting=True)

# Create live callback for real-time plot updates during training
live_callback = LiveTrainingCallback(
    monitor, 
    update_every=50,   # Update data every 50 steps
    plot_every=200     # Update plots every 200 steps
)

# Add callback to trainer for live monitoring
trainer.add_callback(live_callback)

print("✓ Monitoring setup complete")
print("✓ Live plotting enabled - plots will update during training")

In [None]:
# Start training
print("Starting SIREN training...")
history = trainer.train()

print("\n✓ Training completed!")
print(f"Final train loss: {history['train_loss'][-1]:.6f}")
if history['val_loss']:
    print(f"Final val loss: {history['val_loss'][-1]:.6f}")

In [None]:
# Plot training history
fig = trainer.plot_training_history(save_path=output_dir / 'final_training_progress.png')
plt.show()

# Also plot monitoring dashboard
print("\n📊 Training Progress Dashboard:")
monitor_fig = monitor.plot_progress(save_path=output_dir / 'training_dashboard.png')
plt.show()

## 7. Analyze Model Performance

Use the analyzer to evaluate model performance.

In [None]:
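# Create slice comparison plots
# NOTE: `analyzer` (a TrainingAnalyzer instance) is not created in the cells shown
# here and must be constructed before running this cell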
fig_slices = analyzer.plot_lookup_table_slices(save_path=output_dir / 'lookup_table_slices.png', figsize=(16,6))
plt.show()

## 9. Test Model Predictions

Test the model on specific energy/angle/distance combinations.

## 10. Save Trained Model with Metadata

Save the trained model with all necessary metadata for inference, including normalization parameters and dataset information.

In [None]:
# # Export analysis results
# analyzer.export_results(output_dir / 'analysis_results.json')

# # Export monitoring data
# monitor.export_data(output_dir / 'monitoring_data.json')

# print("✓ Results exported to:")
# print(f"  Model checkpoint: {output_dir / 'final_model.npz'}")
# print(f"  Training config: {output_dir / 'config.json'}")
# print(f"  Training history: {output_dir / 'training_history.json'}")
# print(f"  Analysis results: {output_dir / 'analysis_results.json'}")
# print(f"  Monitoring data: {output_dir / 'monitoring_data.json'}")
# print(f"  Plots: {output_dir / '*.png'}")

In [None]:
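# Create angular profile comparison plots
# Requires `density_table`, `energy_centers`, `angle_centers`, `distance_centers`
# (lookup-table loading cell) and `predictor` (inference cell) to be defined first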
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle('Angular Profiles: SIREN vs Lookup Table', fontsize=14)

energies_for_profiles = [300, 500, 800]
fixed_distance = 2000  # mm

for i, energy in enumerate(energies_for_profiles):
    ax = axes[i]
    
    # Find closest energy in lookup table
    energy_idx = np.argmin(np.abs(energy_centers - energy))
    actual_energy = energy_centers[energy_idx]
    
    if np.abs(actual_energy - energy) > 100:
        ax.text(0.5, 0.5, f'No data near\n{energy} MeV', 
               ha='center', va='center', transform=ax.transAxes)
        continue
    
    # Find closest distance index
    distance_idx = np.argmin(np.abs(distance_centers - fixed_distance))
    actual_distance = distance_centers[distance_idx]
    
    # Get lookup table angular profile at this energy and distance
    table_profile = density_table[energy_idx, :, distance_idx]
    
    # Create SIREN predictions for same points
    coords_1d = np.array([[actual_energy, angle, actual_distance] for angle in angle_centers])
    siren_profile = np.ravel(predictor.predict_batch(coords_1d))  # flatten in case predictions are (N, 1)
    
    # Convert angles to degrees for plotting
    angles_deg = np.degrees(angle_centers)
    
    # Plot both profiles
    valid_mask = (table_profile > 1e-10) & (siren_profile > 1e-10)
    
    if np.sum(valid_mask) > 5:
        ax.plot(angles_deg[valid_mask], table_profile[valid_mask], 
               'o-', alpha=0.8, label='Lookup Table', markersize=4, linewidth=2, color='blue')
        ax.plot(angles_deg[valid_mask], siren_profile[valid_mask], 
               's--', alpha=0.8, label='SIREN Model', markersize=3, linewidth=2, color='red')
        
        # Mark the Cherenkov angle for β → 1 in water (n ≈ 1.33): arccos(1/n) ≈ 41.2°
        ax.axvline(np.degrees(np.arccos(1 / 1.33)), color='green', linestyle=':', alpha=0.7, label='Cherenkov angle')
        
        ax.set_xlabel('Angle (degrees)')
        ax.set_ylabel('Photon Density')
        ax.set_title(f'{actual_energy:.0f} MeV\n(d = {actual_distance:.0f} mm)')
        ax.set_yscale('log')
        ax.set_xlim(angles_deg.min(), angles_deg.max())
        ax.grid(True, alpha=0.3)
        
        if i == 0:  # Show legend only on first plot
            ax.legend()
        
        # Calculate and display correlation
        from scipy.stats import pearsonr
        corr, _ = pearsonr(table_profile[valid_mask], siren_profile[valid_mask])
        ax.text(0.05, 0.95, f'R = {corr:.3f}', transform=ax.transAxes,
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    else:
        ax.text(0.5, 0.5, 'Insufficient\nvalid data', ha='center', va='center', 
               transform=ax.transAxes)
        ax.set_title(f'{energy} MeV')

plt.tight_layout()
plt.savefig(output_dir / 'angular_profile_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Angular profile comparison saved to: {output_dir / 'angular_profile_comparison.png'}")
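The Cherenkov reference line above is drawn at a single fixed angle, but for a muon the angle depends on energy through cos θ_c = 1/(nβ). The sketch below computes it assuming the table energies are muon kinetic energies (the table was generated from `data/mu-`) and a refractive index n = 1.33 for water; both are assumptions, and the table's angle coordinate may not be exactly the emission angle relative to the track.

```python
import numpy as np

MUON_MASS_MEV = 105.658  # muon rest mass in MeV/c^2

def cherenkov_angle_deg(kinetic_energy_mev, n=1.33, mass_mev=MUON_MASS_MEV):
    """Cherenkov angle from cos(theta_c) = 1 / (n * beta); NaN below threshold (n * beta <= 1)."""
    gamma = 1.0 + np.asarray(kinetic_energy_mev, dtype=float) / mass_mev
    beta = np.sqrt(1.0 - 1.0 / gamma ** 2)
    cos_theta = 1.0 / (n * beta)
    # Clip before arccos to avoid warnings; below-threshold entries become NaN
    theta = np.degrees(np.arccos(np.minimum(cos_theta, 1.0)))
    return np.where(cos_theta < 1.0, theta, np.nan)

angles = cherenkov_angle_deg([10.0, 300.0, 500.0, 800.0, 1e6])
```

Under these assumptions the angle rises from about 39° at 300 MeV toward the β → 1 limit of about 41.2°, so a per-energy reference line could replace the fixed one.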

In [2]:
# Load original lookup table for comparison
# Requires `photonsim_root` and `output_dir` (setup cells) and `predictor` (inference cell)
import h5py
from matplotlib.colors import LogNorm  # LogNorm lives in matplotlib.colors, not pyplot

# Define energies to compare
energies_to_plot = [200, 400, 600, 800, 1000]  # MeV

# Load the original HDF5 file
h5_path = photonsim_root / 'output' / 'photon_lookup_table.h5'

with h5py.File(h5_path, 'r') as f:
    # Load full density table and coordinates
    density_table = f['data/photon_table_density'][:]  # Shape: (n_energy, n_angle, n_distance)
    energy_centers = f['coordinates/energy_centers'][:]
    angle_centers = f['coordinates/angle_centers'][:]
    distance_centers = f['coordinates/distance_centers'][:]

print(f"📊 Loaded lookup table with shape: {density_table.shape}")
print(f"  Energy range: {energy_centers.min():.0f} to {energy_centers.max():.0f} MeV")
print(f"  Angle range: {np.degrees(angle_centers.min()):.1f} to {np.degrees(angle_centers.max()):.1f} degrees")
print(f"  Distance range: {distance_centers.min():.0f} to {distance_centers.max():.0f} mm")

# Create figure with 2 rows × 5 columns
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
fig.suptitle('SIREN vs Lookup Table Comparison: 2D Slices at Fixed Energies', fontsize=16)

for i, energy in enumerate(energies_to_plot):
    # Find closest energy index in lookup table
    energy_idx = np.argmin(np.abs(energy_centers - energy))
    actual_energy = energy_centers[energy_idx]
    
    if np.abs(actual_energy - energy) > 100:  # Skip if too far
        axes[0, i].text(0.5, 0.5, f'No data near\n{energy} MeV', 
                       ha='center', va='center', transform=axes[0, i].transAxes)
        axes[1, i].text(0.5, 0.5, f'No data near\n{energy} MeV', 
                       ha='center', va='center', transform=axes[1, i].transAxes)
        continue
    
    # Get lookup table slice
    table_slice = density_table[energy_idx, :, :]  # Shape: (n_angle, n_distance)
    
    # Create SIREN prediction for the same grid
    angle_mesh, distance_mesh = np.meshgrid(angle_centers, distance_centers, indexing='ij')
    energy_grid = np.full_like(angle_mesh, actual_energy)
    
    # Stack coordinates for SIREN prediction
    eval_coords = np.stack([
        energy_grid.flatten(),
        angle_mesh.flatten(),
        distance_mesh.flatten()
    ], axis=-1)
    
    # Get SIREN predictions
    siren_predictions = predictor.predict_batch(eval_coords)
    siren_slice = siren_predictions.reshape(angle_mesh.shape)
    
    # Convert angles to degrees for plotting
    angle_mesh_deg = np.degrees(angle_mesh)
    
    # Plot lookup table (top row)
    ax_table = axes[0, i]
    
    # Use log scale for better visualization
    table_slice_plot = np.where(table_slice > 1e-10, table_slice, np.nan)
    im1 = ax_table.pcolormesh(angle_mesh_deg, distance_mesh, table_slice_plot, 
                             cmap='viridis', shading='auto',
                             norm=LogNorm(vmin=1e-2, vmax=table_slice.max()))
    
    ax_table.set_title(f'Lookup Table\n{actual_energy:.0f} MeV')
    ax_table.set_xlabel('Angle (degrees)')
    ax_table.set_ylabel('Distance (mm)')
    
    if i == 4:  # Add colorbar to last plot
        cbar1 = plt.colorbar(im1, ax=ax_table)
        cbar1.set_label('Photon Density')
    
    # Plot SIREN predictions (bottom row)
    ax_siren = axes[1, i]
    
    siren_slice_plot = np.where(siren_slice > 1e-10, siren_slice, np.nan)
    im2 = ax_siren.pcolormesh(angle_mesh_deg, distance_mesh, siren_slice_plot, 
                             cmap='viridis', shading='auto',
                             norm=LogNorm(vmin=1e-2, vmax=siren_slice.max()))
    
    ax_siren.set_title(f'SIREN Model\n{actual_energy:.0f} MeV')
    ax_siren.set_xlabel('Angle (degrees)')
    ax_siren.set_ylabel('Distance (mm)')
    
    if i == 4:  # Add colorbar to last plot
        cbar2 = plt.colorbar(im2, ax=ax_siren)
        cbar2.set_label('Photon Density')
    
    # Print comparison stats
    print(f"\n🔍 Energy {actual_energy:.0f} MeV:")
    print(f"  Table range: {table_slice.min():.2e} to {table_slice.max():.2e}")
    print(f"  SIREN range: {siren_slice.min():.2e} to {siren_slice.max():.2e}")
    print(f"  SIREN/Table ratio: {siren_slice.mean()/table_slice.mean():.2f}")

plt.tight_layout()
plt.savefig(output_dir / 'siren_vs_lookup_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\n✅ Comparison plot saved to: {output_dir / 'siren_vs_lookup_comparison.png'}")

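The `SIREN/Table ratio` printed in the comparison cell compares slice means, which the brightest bins dominate. A complementary check is to summarize the per-bin log10 ratio over bins where both densities exceed the same `1e-10` floor used for plotting. `log_ratio_summary` is a hypothetical helper; inside the loop it could be called as `log_ratio_summary(table_slice, siren_slice)`.

```python
import numpy as np

def log_ratio_summary(table, pred, floor=1e-10):
    """Compare two density slices in log space over bins where both exceed `floor`."""
    table = np.asarray(table, dtype=float)
    pred = np.asarray(pred, dtype=float).reshape(table.shape)
    mask = (table > floor) & (pred > floor)
    if not mask.any():
        return {"n_bins": 0, "median_abs_log10_ratio": np.nan, "frac_within_factor_2": np.nan}
    log_ratio = np.log10(pred[mask] / table[mask])
    return {
        "n_bins": int(mask.sum()),
        # Typical multiplicative error, e.g. 0.301 means "off by a factor of 2"
        "median_abs_log10_ratio": float(np.median(np.abs(log_ratio))),
        # Fraction of bins where SIREN is within a factor of 2 of the table
        "frac_within_factor_2": float(np.mean(np.abs(log_ratio) <= np.log10(2.0))),
    }

toy_table = np.array([[1.0, 10.0], [0.0, 100.0]])
toy_pred = np.array([[1.0, 20.0], [5.0, 1000.0]])
summary = log_ratio_summary(toy_table, toy_pred)
```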

In [None]:
# Import the inference module
# Requires `model_save_dir` from the model-saving cell; run that cell first
sys.path.append(str(project_root / 'siren' / 'training'))
from inference import SIRENPredictor

# Load the saved model
model_base_path = model_save_dir / 'photonsim_siren'
predictor = SIRENPredictor(model_base_path)

print(f"✅ Model loaded successfully!")

# Test single prediction
energy = 500  # MeV
angle = np.radians(45)  # radians 
distance = 2000  # mm

density = predictor.predict(energy, angle, distance)
print(f"\n🔮 Single Prediction Test:")
print(f"  Input: E={energy} MeV, θ={np.degrees(angle):.1f}°, d={distance} mm")
print(f"  Predicted density: {density:.2e} photons/mm²")

# Test batch prediction
test_inputs = np.array([
    [400, np.radians(30), 1500],
    [500, np.radians(45), 2000], 
    [600, np.radians(60), 2500]
])

batch_densities = predictor.predict_batch(test_inputs)
print(f"\n📊 Batch Prediction Test:")
for i, (inp, dens) in enumerate(zip(test_inputs, batch_densities)):
    print(f"  Input {i+1}: E={inp[0]:.0f} MeV, θ={np.degrees(inp[1]):.1f}°, d={inp[2]:.0f} mm → {dens:.2e}")

# Display model info
info = predictor.get_info()
print(f"\n📋 Loaded Model Info:")
print(f"  Energy range: {info['dataset_info']['energy_range']} MeV")
print(f"  Angle range: {np.degrees(info['dataset_info']['angle_range'])} degrees")
print(f"  Distance range: {info['dataset_info']['distance_range']} mm")
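The predictor handles input normalization internally. For reference, a sketch of the mapping, assuming linear min-max scaling of each input to [-1, 1] (the notebook reports `[-1, 1]` normalization, but the exact transform inside `PhotonSimDataset`/`SIREN­Predictor` is not shown here). `to_unit_range`, `from_unit_range` and `normalize_coords` are hypothetical helpers; the ranges could come from `info['dataset_info']`.

```python
import numpy as np

def to_unit_range(x, lo, hi):
    """Min-max map values from [lo, hi] to [-1, 1]."""
    return 2.0 * (np.asarray(x, dtype=float) - lo) / (hi - lo) - 1.0

def from_unit_range(u, lo, hi):
    """Inverse of to_unit_range: map [-1, 1] back to [lo, hi]."""
    return lo + (np.asarray(u, dtype=float) + 1.0) * (hi - lo) / 2.0

def normalize_coords(coords, ranges):
    """Normalize an (N, 3) array of (energy, angle, distance) column by column.

    `ranges` is a sequence of (lo, hi) pairs. Raises on out-of-range inputs,
    since predictions outside the training domain are not constrained by data.
    """
    coords = np.atleast_2d(np.asarray(coords, dtype=float))
    out = np.empty_like(coords)
    for j, (lo, hi) in enumerate(ranges):
        col = coords[:, j]
        if np.any(col < lo) or np.any(col > hi):
            raise ValueError(f"column {j} outside training range [{lo}, {hi}]")
        out[:, j] = to_unit_range(col, lo, hi)
    return out

# Hypothetical ranges for illustration only (MeV, radians, mm)
example_ranges = [(100.0, 1000.0), (0.0, np.pi), (100.0, 5000.0)]
```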

In [None]:
# Save the trained model with complete metadata
model_save_dir = output_dir / 'trained_model'
weights_path, metadata_path = trainer.save_trained_model(model_save_dir, 'photonsim_siren')

print(f"✅ Model saved successfully!")
print(f"  Weights: {weights_path}")
print(f"  Metadata: {metadata_path}")

# Display the saved metadata
import json
with open(metadata_path, 'r') as f:
    metadata = json.load(f)

print(f"\n📋 Model Metadata:")
print(f"  Energy range: {metadata['dataset_info']['energy_range']} MeV")
print(f"  Angle range: {np.degrees(metadata['dataset_info']['angle_range'])} degrees") 
print(f"  Distance range: {metadata['dataset_info']['distance_range']} mm")
print(f"  Model architecture: {metadata['model_config']['hidden_layers']} layers × {metadata['model_config']['hidden_features']} features")
print(f"  Final training loss: {metadata['training_info']['final_train_loss']:.6f}")
print(f"  Final validation loss: {metadata['training_info']['final_val_loss']:.6f}")
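`save_trained_model` writes a weights file plus a JSON metadata file; its internal format is not shown here. As a hypothetical sketch of the same idea for a nested Flax-style parameter tree: flatten it to `'layer/param'` keys, store the arrays with `np.savez`, and keep metadata as JSON alongside. The function names and file layout below are illustrative only.

```python
import json
from collections.abc import Mapping
from pathlib import Path

import numpy as np

def flatten_params(tree, prefix=""):
    """Flatten a nested mapping of arrays into {'a/b/c': array}."""
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            flat.update(flatten_params(value, name))
        else:
            flat[name] = np.asarray(value)
    return flat

def unflatten_params(flat):
    """Inverse of flatten_params: rebuild nested dicts from '/'-joined keys."""
    tree = {}
    for name, value in flat.items():
        *parents, leaf = name.split("/")
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree

def save_model(directory, name, params, metadata):
    """Write <name>.npz (weights) and <name>.json (metadata) into `directory`."""
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    weights_path = directory / f"{name}.npz"
    metadata_path = directory / f"{name}.json"
    np.savez(weights_path, **flatten_params(params))
    metadata_path.write_text(json.dumps(metadata, indent=2))
    return weights_path, metadata_path

def load_model(directory, name):
    """Read back the weights tree and metadata written by save_model."""
    directory = Path(directory)
    with np.load(directory / f"{name}.npz") as data:
        params = unflatten_params({key: data[key] for key in data.files})
    metadata = json.loads((directory / f"{name}.json").read_text())
    return params, metadata

toy_tree = {"Dense_0": {"kernel": np.ones((3, 2)), "bias": np.zeros(2)}}
toy_flat = flatten_params(toy_tree)
```

Metadata values must be JSON-serializable, so NumPy ranges should be converted with `.tolist()` before saving.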