In [None]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import logging

import jax
import jax.numpy as jnp

# Configure root logging at INFO level (a no-op if the root logger already has
# handlers) and grab a logger named after this module/notebook.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# Get project paths correctly.
# Paths are resolved from the current working directory, so this assumes the
# notebook is started one level below the diffCherenkov root (e.g. notebooks/),
# with the PhotonSim checkout living next to the diffCherenkov checkout.
current_dir = Path.cwd()
project_root = current_dir.parent  # diffCherenkov root (one level up from notebooks)
photonsim_root = project_root.parent / 'PhotonSim'  # PhotonSim root

import torch
# NOTE(review): `device` is not referenced by the other visible cells (training runs
# on JAX and reports `trainer.device`); kept in case other code expects it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def prepend_to_sys_path(path):
    """Make ``path`` the first entry of ``sys.path`` without creating duplicates.

    Existing occurrences are removed before re-inserting at index 0. This keeps the
    same "highest priority" placement as a plain ``sys.path.insert(0, ...)``, but
    re-running the cell no longer stacks duplicate entries.

    Args:
        path: Directory to add (``str`` or ``Path``; converted with ``str()``).
    """
    entry = str(path)
    while entry in sys.path:
        sys.path.remove(entry)
    sys.path.insert(0, entry)


# Add project paths. Order matters: 'tools' is added last, so it ends up ahead of
# the project root on sys.path (same precedence as before).
prepend_to_sys_path(project_root)
prepend_to_sys_path(project_root / 'tools')

print(f"Current dir: {current_dir}")
print(f"Project root: {project_root}")
print(f"PhotonSim root: {photonsim_root}")

# Verify paths exist
print(f"Project root exists: {project_root.exists()}")
print(f"Siren dir exists: {(project_root / 'siren').exists()}")
print(f"Training dir exists: {(project_root / 'siren' / 'training').exists()}")
print(f"PhotonSim root exists: {photonsim_root.exists()}")

In [None]:
# Import the refactored training modules directly from siren/training.
# Failure is fatal: the original ImportError is chained so its traceback is kept.
print("📦 Importing training modules...")

imported_successfully = False

try:
    training_path = project_root / 'siren' / 'training'
    # The modules have generic names (trainer, dataset, monitor, analyzer), so force
    # this directory to the front of sys.path (dropping any stale copies) to keep
    # same-named packages found elsewhere from shadowing them.
    training_entry = str(training_path)
    while training_entry in sys.path:
        sys.path.remove(training_entry)
    sys.path.insert(0, training_entry)

    from trainer import SIRENTrainer, TrainingConfig
    from dataset import PhotonSimDataset
    from monitor import TrainingMonitor, LiveTrainingCallback
    from analyzer import TrainingAnalyzer

    print("✅ Manual imports from individual files successful")
    imported_successfully = True

except ImportError as import_error:
    print(f"❌ Manual import failed: {import_error}")
    print("\n🚨 All import strategies failed!")
    print("Please check:")
    print(f"  1. Current working directory: {Path.cwd()}")
    print(f"  2. Project root: {project_root}")
    print(f"  3. Siren directory exists: {(project_root / 'siren').exists()}")
    print(f"  4. Training directory exists: {(project_root / 'siren' / 'training').exists()}")
    raise ImportError("Could not import training modules with any strategy") from import_error

# Reaching this point means the try-block succeeded (the except branch always raises).
print("✅ All training modules imported successfully!")
print("🚀 Ready to start training workflow")

In [None]:
# Path to the PhotonSim HDF5 lookup table
h5_path = photonsim_root / 'output' / 'photon_lookup_table.h5'

# Every later cell needs this table, so stop here when it has not been generated yet.
# The previous behaviour printed a hint and carried on, which only failed later on.
if not h5_path.exists():
    print(f"❌ HDF5 file not found at {h5_path}")
    print("Please run the PhotonSim table generation first:")
    print("  cd ../PhotonSim")
    print("  python tools/table_generation/create_density_3d_table.py --data-dir data/mu-")
    raise FileNotFoundError(f"PhotonSim HDF5 lookup table not found: {h5_path}")

print(f"✓ Found PhotonSim HDF5 file: {h5_path}")

# Load dataset (val_split=0.1 — presumably a 10% validation fraction; the train/val
# index counts printed below show the actual split)
dataset = PhotonSimDataset(h5_path, val_split=0.1)

print(f"\nDataset info:")
print(f"  Data type: {dataset.data_type}")
print(f"  Total samples: {len(dataset.data['inputs']):,}")
print(f"  Train samples: {len(dataset.train_indices):,}")
print(f"  Val samples: {len(dataset.val_indices):,}")
print(f"  Energy range: {dataset.energy_range[0]:.0f}-{dataset.energy_range[1]:.0f} MeV")
# angle_range is converted with np.degrees for display, i.e. it is stored in radians.
print(f"  Angle range: {np.degrees(dataset.angle_range[0]):.1f}-{np.degrees(dataset.angle_range[1]):.1f} degrees")
print(f"  Distance range: {dataset.distance_range[0]:.0f}-{dataset.distance_range[1]:.0f} mm")

In [None]:
# Reuse the dataset built in the previous cell (same h5_path, val_split=0.1) instead
# of reading the HDF5 lookup table from disk a second time.
print("✅ Dataset configured with built-in consistent normalization")
print(f"  • Input normalization: [-1, 1]")
print(f"  • Target normalization: [0, 1] from log scale")
print(f"  • Training and evaluation use identical normalization by default")

# Verify dataset is working correctly on a fixed-key batch of 100 training samples
rng = jax.random.PRNGKey(42)
sample_inputs, sample_targets = dataset.get_batch(100, rng, 'train', normalized=True)

print(f"\n🧪 Dataset verification:")
print(f"  Sample inputs shape: {sample_inputs.shape}")
print(f"  Sample targets shape: {sample_targets.shape}")
print(f"  Input range: [{sample_inputs.min():.3f}, {sample_inputs.max():.3f}]")
print(f"  Target range: [{sample_targets.min():.3f}, {sample_targets.max():.3f}]")

# Check the advertised ranges instead of reporting success unconditionally.
# NORM_TOL absorbs float rounding at the interval endpoints; NaNs fail the check.
NORM_TOL = 1e-5
input_lo, input_hi = float(sample_inputs.min()), float(sample_inputs.max())
target_lo, target_hi = float(sample_targets.min()), float(sample_targets.max())
assert -1.0 - NORM_TOL <= input_lo and input_hi <= 1.0 + NORM_TOL, (
    f"normalized inputs outside [-1, 1]: [{input_lo:.6f}, {input_hi:.6f}]"
)
assert -NORM_TOL <= target_lo and target_hi <= 1.0 + NORM_TOL, (
    f"normalized targets outside [0, 1]: [{target_lo:.6f}, {target_hi:.6f}]"
)
print(f"  ✅ Normalization working correctly")

In [None]:
# Training mode configuration: START_FRESH = True reports that existing data will be
# cleared; False reports that training resumes from the latest checkpoint.
import json

START_FRESH = False

# Set up output directory (relative to the current working directory)
output_dir = Path('output') / 'photonsim_siren_training'
output_dir.mkdir(exist_ok=True, parents=True)

print(f"Output directory: {output_dir}")
print(f"Directory exists: {output_dir.exists()}")

# Check for existing checkpoints. glob() order depends on the filesystem, so sort
# to get a stable, reproducible listing.
existing_checkpoints = sorted(output_dir.glob('*.npz'))
existing_history = output_dir / 'training_history.json'

if existing_checkpoints or existing_history.exists():
    print(f"\n🔍 Found existing training data:")
    if existing_history.exists():
        with open(existing_history, 'r') as f:
            history = json.load(f)
        # 'step' holds the logged step numbers; its max is the last recorded step.
        if history.get('step'):
            last_step = max(history['step'])
            print(f"  - Training history up to step {last_step}")

    for checkpoint in existing_checkpoints:
        print(f"  - Checkpoint: {checkpoint.name}")

    if START_FRESH:
        print(f"\n🔄 START_FRESH=True: Will clear existing data and start from scratch")
    else:
        print(f"\n▶️  START_FRESH=False: Will resume from latest checkpoint")
else:
    print(f"\n✨ No existing training data found. Starting fresh.")

In [None]:
# Training configuration. Architecture and optimizer mirror CProfSiren; the learning
# rate is lowered on validation plateaus by the patience-based scheduler.
config = TrainingConfig(
    # --- Model architecture (same as CProfSiren) ---
    hidden_features=256,
    hidden_layers=3,           # CProfSiren used 3 layers
    w0=30.0,                   # standard SIREN frequency

    # --- Optimisation (adapted from CProfSiren) ---
    learning_rate=1e-4,        # same as CProfSiren
    weight_decay=0.0,          # CProfSiren used no weight decay
    batch_size=10_000,         # large batches for stable gradients
    num_steps=20_000,          # total training steps

    # --- Patience-based LR scheduler ---
    # Per the settings below: after `patience` validations (one every `val_every`
    # steps) without improvement, LR is multiplied by `lr_reduction_factor`,
    # never dropping under `min_lr`.
    use_patience_scheduler=True,
    patience=20,
    lr_reduction_factor=0.5,
    min_lr=1e-7,

    # --- Optimizer ---
    optimizer='adam',          # same as CProfSiren
    grad_clip_norm=0.0,        # 0.0 = no gradient clipping

    # --- Cadence (in steps) for logging, validation and checkpoints ---
    log_every=10,
    val_every=50,
    checkpoint_every=500,

    seed=42,
)

summary_lines = [
    "📊 Training Configuration:",
    f"  • Architecture: {config.hidden_layers} layers × {config.hidden_features} features",
    f"  • Learning Rate: {config.learning_rate:.2e} (with patience-based scheduling)",
    f"  • Batch Size: {config.batch_size:,}",
    f"  • Total Steps: {config.num_steps:,}",
    f"  • Built-in CProfSiren-style loss scaling (×1000)",
    f"  • Consistent log-normalized targets by default",
    f"  ✅ Ready for training with all improvements built-in!",
]
print("\n".join(summary_lines))

In [None]:
# Build the trainer. Unless START_FRESH is set it is asked to resume from the
# checkpoints in output_dir; the dataset supplies its own normalization.
trainer = SIRENTrainer(
    dataset=dataset,
    config=config,
    output_dir=output_dir,
    resume_from_checkpoint=not START_FRESH,
)

# A fresh run additionally wipes whatever checkpoints are already on disk.
if START_FRESH:
    print("🧹 Clearing existing checkpoints...")
    trainer.clear_checkpoints()
    print("✅ Starting with clean slate")

print(f"✓ Trainer initialized with improved defaults")
print(f"✓ Output directory: {output_dir}")
print(f"✓ JAX device: {trainer.device}")

# A positive start_step means training state was restored from a checkpoint.
resuming = trainer.start_step > 0
if resuming:
    print(f"✓ Resuming from step {trainer.start_step}")
    print(f"✓ Training history loaded with {len(trainer.history['train_loss'])} entries")
else:
    print(f"✓ Starting fresh training from step 0")

for line in (
    "\n🎯 Built-in improvements:",
    "  • CProfSiren-style MSE loss with ×1000 scaling",
    "  • Proper SIREN output handling (first element of tuple)",
    "  • Consistent log-normalized targets throughout",
    "  • No custom training functions needed - it's all built-in!",
    "  ✅ Ready to train with trainer.train()",
):
    print(line)

In [None]:
# Reproducibility check: two get_batch calls with the same PRNG key must return
# identical normalized batches. Both calls use the same `dataset` object, so this
# verifies deterministic sampling/normalization — not agreement between datasets.
print("🔍 Quick normalization check...")

rng = jax.random.PRNGKey(42)
base_inputs, base_targets = dataset.get_batch(100, rng, 'train', normalized=True)
consistent_inputs, consistent_targets = dataset.get_batch(100, rng, 'train', normalized=True)

print(f"First batch targets: {base_targets.min():.6f} to {base_targets.max():.6f}")
print(f"Second batch targets: {consistent_targets.min():.6f} to {consistent_targets.max():.6f}")

# Compare inputs as well as targets, and fail loudly on a mismatch instead of printing
# a warning and then announcing that everything is ready.
batches_match = bool(jnp.allclose(base_inputs, consistent_inputs)) and bool(
    jnp.allclose(base_targets, consistent_targets)
)
if not batches_match:
    print("❌ Still inconsistent")
    raise AssertionError("dataset.get_batch returned different batches for the same PRNG key")

print("✅ SUCCESS! Same PRNG key → identical normalized batches")
print("   → Batch sampling and normalization are deterministic")

print("\n🚀 Ready to proceed with training and analysis!")

In [None]:
# Live monitoring: a callback connects TrainingMonitor to the trainer so metrics
# and figures refresh while trainer.train() runs.
LIVE_UPDATE_EVERY = 50   # steps between data updates
LIVE_PLOT_EVERY = 200    # steps between plot redraws

monitor = TrainingMonitor(output_dir, live_plotting=True)
live_callback = LiveTrainingCallback(
    monitor,
    update_every=LIVE_UPDATE_EVERY,
    plot_every=LIVE_PLOT_EVERY,
)

# NOTE(review): re-running this cell calls add_callback again on the same trainer;
# presumably that registers a second callback (duplicate live updates) — re-create
# the trainer first or confirm add_callback de-duplicates.
trainer.add_callback(live_callback)

for status in (
    "✓ Monitoring setup complete",
    "✓ Live plotting enabled - plots will update during training",
):
    print(status)

In [None]:
# Start (or resume) training. trainer.train() returns the metric history, indexed by
# metric name ('train_loss', 'val_loss').
print("Starting SIREN training...")
history = trainer.train()

print("\n✓ Training completed!")
# Guard against an empty loss list (presumably when no new steps were run) instead of
# failing with IndexError on [-1].
if history['train_loss']:
    print(f"Final train loss: {history['train_loss'][-1]:.6f}")
else:
    print("No training loss recorded (no training steps were run).")
if history['val_loss']:
    print(f"Final val loss: {history['val_loss'][-1]:.6f}")

In [None]:
# Plot the full training history and save a PNG copy alongside the checkpoints.
progress_plot_path = output_dir / 'final_training_progress.png'
fig = trainer.plot_training_history(save_path=progress_plot_path)
plt.show()

In [None]:
# Persist the trained weights plus a JSON metadata file, then echo the key
# metadata fields read back from disk.
model_save_dir = output_dir / 'trained_model'
weights_path, metadata_path = trainer.save_trained_model(model_save_dir, 'photonsim_siren')

print(f"✅ Model saved successfully!")
print(f"  Weights: {weights_path}")
print(f"  Metadata: {metadata_path}")

# Display the saved metadata
import json
with open(metadata_path, 'r') as f:
    metadata = json.load(f)

print(f"\n📋 Model Metadata:")

dataset_info = metadata['dataset_info']
print(f"  Energy range: {dataset_info['energy_range']} MeV")
# angle_range is shown via np.degrees, i.e. it is stored in radians.
print(f"  Angle range: {np.degrees(dataset_info['angle_range'])} degrees")
print(f"  Distance range: {dataset_info['distance_range']} mm")

model_config = metadata['model_config']
print(f"  Model architecture: {model_config['hidden_layers']} layers × {model_config['hidden_features']} features")

training_info = metadata['training_info']
print(f"  Final training loss: {training_info['final_train_loss']:.6f}")
final_val_loss = training_info['final_val_loss']
if final_val_loss:
    print(f"  Final validation loss: {final_val_loss:.6f}")
print(f"  ✅ Model saved with all metadata and robust parameter handling")