# Debug Data Pipeline: Scale Consistency Check

This notebook systematically checks the scales of:
- A) Values in the original lookup table
- B) Values sampled and used for training
- C) SIREN predictions during training
- D) Values used during plotting (analyzer)

Goal: Identify where the scale mismatch occurs

In [1]:
import sys
import numpy as np
import h5py
import jax
import jax.numpy as jnp
from pathlib import Path

# needed to avoid crash related to DNN library
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Add project paths (now we're in notebooks/ folder)
sys.path.insert(0, str(Path.cwd().parent))
sys.path.insert(0, str(Path.cwd().parent / 'siren' / 'training'))

from siren.training.dataset import PhotonSimDataset
from siren.training.trainer import SIRENTrainer, TrainingConfig

## A) Original Lookup Table Values

In [2]:
# Load original HDF5 file
data_path = Path('/sdf/home/c/cjesus/Dev/PhotonSim/output/photon_lookup_table.h5')

print("=== A) ORIGINAL LOOKUP TABLE VALUES ===")
with h5py.File(data_path, 'r') as f:
    density_table = f['data/photon_table_density'][:]
    energy_centers = f['coordinates/energy_centers'][:]
    angle_centers = f['coordinates/angle_centers'][:]
    distance_centers = f['coordinates/distance_centers'][:]

print(f"Table shape: {density_table.shape}")
print(f"Table min: {density_table.min():.6e}")
print(f"Table max: {density_table.max():.6e}")
print(f"Table mean: {density_table.mean():.6e}")
print(f"Table median: {np.median(density_table):.6e}")
print(f"Non-zero values: {np.sum(density_table > 0):,}/{density_table.size:,}")

# Sample specific test point for tracking
test_energy_idx = len(energy_centers) // 2  # 500 MeV
test_angle_idx = np.argmin(np.abs(angle_centers - np.radians(43)))  # ~43 degrees (Cherenkov)
test_dist_idx = len(distance_centers) // 2   # Middle distance

test_coords = [energy_centers[test_energy_idx], angle_centers[test_angle_idx], distance_centers[test_dist_idx]]
original_test_value = density_table[test_energy_idx, test_angle_idx, test_dist_idx]

print(f"\nTest point: E={test_coords[0]:.0f} MeV, θ={np.degrees(test_coords[1]):.1f}°, d={test_coords[2]:.0f} mm")
print(f"Original test value: {original_test_value:.6e}")

=== A) ORIGINAL LOOKUP TABLE VALUES ===
Table shape: (91, 500, 500)
Table min: 0.000000e+00
Table max: 9.871154e+02
Table mean: 1.354046e+00
Table median: 0.000000e+00
Non-zero values: 5,659,770/22,750,000

Test point: E=550 MeV, θ=43.0°, d=5010 mm
Original test value: 0.000000e+00


## B) Dataset Loading and Filtering

In [3]:
print("\n=== B) DATASET AFTER LOADING AND FILTERING ===")

# Load dataset (same as training)
dataset = PhotonSimDataset(data_path)

print(f"Filtered data shape: {dataset.data['inputs'].shape}")
print(f"Raw targets min: {dataset.data['targets'].min():.6e}")
print(f"Raw targets max: {dataset.data['targets'].max():.6e}")
print(f"Raw targets mean: {dataset.data['targets'].mean():.6e}")
print(f"Raw targets median: {np.median(dataset.data['targets']):.6e}")

# Find our test point in the filtered dataset
test_coords_array = np.array(test_coords)
distances = np.linalg.norm(dataset.data['inputs'] - test_coords_array, axis=1)
closest_idx = np.argmin(distances)
closest_coords = dataset.data['inputs'][closest_idx]
filtered_test_value = dataset.data['targets'][closest_idx, 0]

print(f"\nClosest point in dataset: E={closest_coords[0]:.0f} MeV, θ={np.degrees(closest_coords[1]):.1f}°, d={closest_coords[2]:.0f} mm")
print(f"Filtered test value: {filtered_test_value:.6e}")
print(f"Value preservation ratio: {filtered_test_value/original_test_value:.6f} (should be ~1.0)")

# Check log-normalized targets
print(f"\nLog-normalized targets min: {dataset.data['targets_log'].min():.3f}")
print(f"Log-normalized targets max: {dataset.data['targets_log'].max():.3f}")
print(f"Log-normalized targets mean: {dataset.data['targets_log'].mean():.3f}")

log_test_value = dataset.data['targets_log'][closest_idx, 0]
print(f"Test point log value: {log_test_value:.6f}")
print(f"Back to linear: {10**log_test_value:.6e}")


=== B) DATASET AFTER LOADING AND FILTERING ===
Filtered data shape: (5659770, 3)
Raw targets min: 2.533042e-03
Raw targets max: 9.871154e+02
Raw targets mean: 5.442721e+00
Raw targets median: 2.536646e-01

Closest point in dataset: E=780 MeV, θ=43.0°, d=4970 mm
Filtered test value: 7.425478e-03
Value preservation ratio: inf (should be ~1.0)

Log-normalized targets min: -2.596
Log-normalized targets max: 2.994
Log-normalized targets mean: -0.491
Test point log value: -2.129276
Back to linear: 7.425478e-03


  print(f"Value preservation ratio: {filtered_test_value/original_test_value:.6f} (should be ~1.0)")


## LinearPhotonSimDataset Wrapper

In [4]:
print("\n=== LinearPhotonSimDataset WRAPPER ===")

# Create the wrapper used in training
class LinearPhotonSimDataset:
    """Wrapper to make dataset return linear values instead of log values"""
    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        # Copy necessary attributes directly
        self.data = base_dataset.data
        self.train_indices = base_dataset.train_indices
        self.val_indices = base_dataset.val_indices
        self.normalized_bounds = base_dataset.normalized_bounds
        self.metadata = base_dataset.metadata
        self.energy_range = base_dataset.energy_range
        self.angle_range = base_dataset.angle_range
        self.distance_range = base_dataset.distance_range
        self.data_type = base_dataset.data_type
        self.data_path = base_dataset.data_path
    
    def get_batch(self, batch_size, rng, split='train', normalized=True):
        # Get normalized inputs but LINEAR targets
        if split == 'train':
            indices = self.train_indices
        else:
            indices = self.val_indices
            
        # Random sampling
        batch_indices = jax.random.choice(rng, indices, shape=(batch_size,))
        
        # Get normalized inputs
        inputs = self.data['inputs_normalized'][batch_indices]
        # Get LINEAR targets (not log!)
        targets = self.data['targets'][batch_indices]
        
        return jnp.array(inputs), jnp.array(targets)
    
    def get_sample_input(self):
        """Get a sample input for model initialization."""
        return self.base_dataset.get_sample_input()
    
    def get_full_data(self, split='train', normalized=True):
        """Get full dataset for a given split."""
        return self.base_dataset.get_full_data(split=split, normalized=normalized)
    
    def denormalize_inputs(self, inputs):
        """Convert normalized inputs back to original scale."""
        return self.base_dataset.denormalize_inputs(inputs)
    
    def denormalize_targets(self, targets_log):
        """Convert log-normalized targets back to original scale."""
        return self.base_dataset.denormalize_targets(targets_log)
    
    @property
    def has_validation(self):
        """Check if dataset has validation split."""
        return self.base_dataset.has_validation

linear_dataset = LinearPhotonSimDataset(dataset)

# Test a batch from LinearPhotonSimDataset
rng = jax.random.PRNGKey(42)
test_inputs, test_targets = linear_dataset.get_batch(100, rng, split='train')

print(f"LinearDataset batch inputs shape: {test_inputs.shape}")
print(f"LinearDataset batch targets shape: {test_targets.shape}")
print(f"LinearDataset targets min: {float(test_targets.min()):.6e}")
print(f"LinearDataset targets max: {float(test_targets.max()):.6e}")
print(f"LinearDataset targets mean: {float(test_targets.mean()):.6e}")

# Check if our test point is in this batch
test_target_in_batch = float(test_targets[0, 0])  # First sample, first element
print(f"\nSample target from linear dataset: {test_target_in_batch:.6e}")
print(f"Ratio to original table scale: {test_target_in_batch/dataset.data['targets'].mean():.6f}")

# Test the wrapper methods
print(f"\nWrapper method tests:")
print(f"  get_sample_input() works: {linear_dataset.get_sample_input() is not None}")
print(f"  has_validation: {linear_dataset.has_validation}")
print(f"  data_type: {linear_dataset.data_type}")


=== LinearPhotonSimDataset WRAPPER ===
LinearDataset batch inputs shape: (100, 3)
LinearDataset batch targets shape: (100, 1)
LinearDataset targets min: 3.056105e-03
LinearDataset targets max: 2.081016e+02
LinearDataset targets mean: 7.606782e+00

Sample target from linear dataset: 1.450210e+00
Ratio to original table scale: 0.266449

Wrapper method tests:
  get_sample_input() works: True
  has_validation: True
  data_type: h5_lookup


## C) Load Trained Model and Check Training-time Predictions

In [5]:
print("\n=== C) TRAINED MODEL PREDICTIONS ===")

# Load the trained model (same as in the main notebook)
config = TrainingConfig(
    hidden_features=256,
    hidden_layers=3,
    w0=30.0,
    learning_rate=1e-4,
    weight_decay=0.0,
    batch_size=65536,
    num_steps=5000,
    use_patience_scheduler=True,
    patience=20,
    lr_reduction_factor=0.5,
    min_lr=1e-7
)

output_dir = Path('output') / 'photonsim_siren_training'

try:
    print(f"Output directory exists: {output_dir.exists()}")
    if output_dir.exists():
        checkpoint_files = list(output_dir.glob('*.npz'))
        print(f"Checkpoint files found: {len(checkpoint_files)}")
        for f in checkpoint_files[:5]:  # Show first 5
            print(f"  - {f.name}")
    
    # Initialize trainer with linear dataset
    trainer = SIRENTrainer(
        linear_dataset,
        config,
        output_dir=output_dir,
        resume_from_checkpoint=True
    )
    
    print("✅ Trainer loaded successfully")
    print(f"Trainer state is None: {trainer.state is None}")
    
    if trainer.state is not None:
        # Test SIREN predictions on the same batch
        siren_predictions = trainer.predict(test_inputs)
        
        print(f"\nSIREN predictions shape: {siren_predictions.shape}")
        print(f"SIREN predictions min: {float(siren_predictions.min()):.6e}")
        print(f"SIREN predictions max: {float(siren_predictions.max()):.6e}")
        print(f"SIREN predictions mean: {float(siren_predictions.mean()):.6e}")
        
        # Compare SIREN vs targets for the same batch
        siren_sample = float(siren_predictions[0, 0])  # First prediction
        target_sample = float(test_targets[0, 0])      # First target
        
        print(f"\nSample comparison:")
        print(f"  Target: {target_sample:.6e}")
        print(f"  SIREN:  {siren_sample:.6e}")
        print(f"  Ratio (SIREN/Target): {siren_sample/target_sample:.6f}")
        
        # Check training loss on this batch
        mse_loss = jnp.mean((siren_predictions - test_targets) ** 2)
        scaled_loss = mse_loss * 1000.0  # Same scaling as training
        
        print(f"\nLoss on this batch:")
        print(f"  MSE: {float(mse_loss):.6e}")
        print(f"  Scaled (×1000): {float(scaled_loss):.6e}")
    else:
        print("❌ Trainer state is None - checkpoint loading failed")
        siren_predictions = None
    
except Exception as e:
    print(f"❌ Could not load trainer: {e}")
    import traceback
    traceback.print_exc()
    siren_predictions = None


=== C) TRAINED MODEL PREDICTIONS ===
Output directory exists: True
Checkpoint files found: 31
  - checkpoint_step_2500.npz
  - checkpoint_step_0.npz
  - checkpoint_step_8500.npz
  - checkpoint_step_9500.npz
  - checkpoint_step_6000.npz
✅ Trainer loaded successfully
Trainer state is None: False

SIREN predictions shape: (100, 1)
SIREN predictions min: 6.966487e-05
SIREN predictions max: 9.999993e-01
SIREN predictions mean: 1.485371e-01

Sample comparison:
  Target: 1.450210e+00
  SIREN:  9.450511e-02
  Ratio (SIREN/Target): 0.065166

Loss on this batch:
  MSE: 9.511873e+02
  Scaled (×1000): 9.511872e+05


## D) Analyzer Plotting Values

In [6]:
print("\n=== D) ANALYZER PLOTTING VALUES ===")

if siren_predictions is not None:
    # Simulate what the analyzer does
    
    # 1. Load original table again (as analyzer does)
    with h5py.File(data_path, 'r') as f:
        analyzer_table = f['data/photon_table_density'][:]
        analyzer_energy_centers = f['coordinates/energy_centers'][:]
        analyzer_angle_centers = f['coordinates/angle_centers'][:]
        analyzer_distance_centers = f['coordinates/distance_centers'][:]
    
    print(f"Analyzer table range: {analyzer_table.min():.6e} to {analyzer_table.max():.6e}")
    print(f"Analyzer table mean: {analyzer_table.mean():.6e}")
    
    # 2. Pick same energy slice as our test
    energy_500_idx = np.argmin(np.abs(analyzer_energy_centers - 500))  # 500 MeV
    table_slice = analyzer_table[energy_500_idx, :, :]
    
    print(f"\n500 MeV slice from analyzer:")
    print(f"  Range: {table_slice.min():.6e} to {table_slice.max():.6e}")
    print(f"  Mean: {table_slice.mean():.6e}")
    
    # 3. Create grid for SIREN evaluation (as analyzer does)
    angle_mesh, distance_mesh = np.meshgrid(
        analyzer_angle_centers, 
        analyzer_distance_centers, 
        indexing='ij'
    )
    
    # Create coordinate grid for a small subset
    n_test = 100  # Just test a small grid
    angle_flat = angle_mesh.flatten()[:n_test]
    distance_flat = distance_mesh.flatten()[:n_test]
    energy_flat = np.full_like(angle_flat, 500.0)  # 500 MeV
    
    eval_coords = np.stack([energy_flat, angle_flat, distance_flat], axis=-1)
    
    # 4. Normalize coordinates (as analyzer does)
    input_min = dataset.normalized_bounds['input_min']
    input_max = dataset.normalized_bounds['input_max']
    eval_coords_norm = 2 * ((eval_coords - input_min) / (input_max - input_min)) - 1
    
    # 5. Get SIREN predictions on this grid
    analyzer_siren_predictions = trainer.predict(eval_coords_norm)
    
    print(f"\nAnalyzer SIREN predictions on 500 MeV grid:")
    print(f"  Range: {analyzer_siren_predictions.min():.6e} to {analyzer_siren_predictions.max():.6e}")
    print(f"  Mean: {analyzer_siren_predictions.mean():.6e}")
    
    # 6. Compare with corresponding table values
    table_subset = table_slice.flatten()[:n_test]
    
    print(f"\nCorresponding table values:")
    print(f"  Range: {table_subset.min():.6e} to {table_subset.max():.6e}")
    print(f"  Mean: {table_subset.mean():.6e}")
    
    # 7. Scale comparison
    scale_ratio = analyzer_siren_predictions.mean() / table_subset.mean()
    print(f"\nSCALE COMPARISON (SIREN/Table): {scale_ratio:.6e}")
    print(f"Scaling factor needed: {1/scale_ratio:.2f}")


=== D) ANALYZER PLOTTING VALUES ===
Analyzer table range: 0.000000e+00 to 9.871154e+02
Analyzer table mean: 1.354046e+00

500 MeV slice from analyzer:
  Range: 0.000000e+00 to 8.824773e+02
  Mean: 1.190471e+00

Analyzer SIREN predictions on 500 MeV grid:
  Range: 3.716993e-05 to 9.499058e-01
  Mean: 7.441159e-02

Corresponding table values:
  Range: 0.000000e+00 to 4.031448e+00
  Mean: 1.169120e+00

SCALE COMPARISON (SIREN/Table): 6.364752e-02
Scaling factor needed: 15.71


## Summary and Diagnosis

In [7]:
print("\n=== CRITICAL: WHAT TRAINING ACTUALLY SEES ===")
print("Checking what happens when normalized=True is passed to get_batch()")

# Test what the LinearPhotonSimDataset returns vs base dataset
rng = jax.random.PRNGKey(42)

# 1. Base dataset with normalized=True
base_inputs, base_targets = dataset.get_batch(100, rng, split='train', normalized=True)
print(f"\nBase dataset with normalized=True:")
print(f"  Inputs range: {base_inputs.min():.3f} to {base_inputs.max():.3f}")
print(f"  Targets range: {base_targets.min():.3f} to {base_targets.max():.3f}")
print(f"  Targets mean: {base_targets.mean():.3f}")

# 2. Base dataset with normalized=False  
base_inputs_raw, base_targets_raw = dataset.get_batch(100, rng, split='train', normalized=False)
print(f"\nBase dataset with normalized=False:")
print(f"  Inputs range: {base_inputs_raw.min():.1f} to {base_inputs_raw.max():.1f}")
print(f"  Targets range: {base_targets_raw.min():.3e} to {base_targets_raw.max():.3e}")
print(f"  Targets mean: {base_targets_raw.mean():.3e}")

# 3. LinearPhotonSimDataset wrapper 
linear_inputs, linear_targets = linear_dataset.get_batch(100, rng, split='train', normalized=True)
print(f"\nLinearPhotonSimDataset wrapper:")
print(f"  Inputs range: {linear_inputs.min():.3f} to {linear_inputs.max():.3f}")
print(f"  Targets range: {linear_targets.min():.3e} to {linear_targets.max():.3e}")
print(f"  Targets mean: {linear_targets.mean():.3e}")

print(f"\n🔍 KEY INSIGHT:")
if base_targets.max() <= 1.0:
    print(f"   ❌ Base dataset with normalized=True returns targets ≤ 1.0!")
    print(f"   ❌ This explains why SIREN predictions max out at 1.0")
    print(f"   ❌ Training sees targets in range [0, 1] but analyzer expects [0, 1000]")
else:
    print(f"   ✅ Base dataset with normalized=True returns targets > 1.0")

if linear_targets.max() > 1.0:
    print(f"   ❌ BUT LinearPhotonSimDataset returns targets > 1.0!")
    print(f"   ❌ Training sees larger targets but SIREN still caps at 1.0")
    print(f"   ❌ This suggests architectural limitation in SIREN model")
else:
    print(f"   ✅ LinearPhotonSimDataset also caps targets ≤ 1.0")


=== CRITICAL: WHAT TRAINING ACTUALLY SEES ===
Checking what happens when normalized=True is passed to get_batch()

Base dataset with normalized=True:
  Inputs range: -1.000 to 1.000
  Targets range: -2.515 to 2.318
  Targets mean: -0.535

Base dataset with normalized=False:
  Inputs range: 0.1 to 4490.0
  Targets range: 3.056e-03 to 2.081e+02
  Targets mean: 7.607e+00

LinearPhotonSimDataset wrapper:
  Inputs range: -1.000 to 1.000
  Targets range: 3.056e-03 to 2.081e+02
  Targets mean: 7.607e+00

🔍 KEY INSIGHT:
   ✅ Base dataset with normalized=True returns targets > 1.0
   ❌ BUT LinearPhotonSimDataset returns targets > 1.0!
   ❌ Training sees larger targets but SIREN still caps at 1.0
   ❌ This suggests architectural limitation in SIREN model


In [8]:
print("\n" + "="*60)
print("SUMMARY: DATA PIPELINE SCALE ANALYSIS")
print("="*60)

print(f"\nA) Original lookup table scale:")
print(f"   Range: {density_table.min():.2e} to {density_table.max():.2e}")
print(f"   Mean: {density_table.mean():.2e}")

print(f"\nB) Dataset after filtering:")
print(f"   Range: {dataset.data['targets'].min():.2e} to {dataset.data['targets'].max():.2e}")
print(f"   Mean: {dataset.data['targets'].mean():.2e}")
print(f"   Preservation: {dataset.data['targets'].mean()/density_table.mean():.3f}")

if 'test_targets' in locals():
    print(f"\nLinear dataset training targets:")
    print(f"   Range: {test_targets.min():.2e} to {test_targets.max():.2e}")
    print(f"   Mean: {test_targets.mean():.2e}")

if 'siren_predictions' in locals() and siren_predictions is not None:
    print(f"\nC) SIREN predictions during training:")
    print(f"   Range: {siren_predictions.min():.2e} to {siren_predictions.max():.2e}")
    print(f"   Mean: {siren_predictions.mean():.2e}")
    print(f"   SIREN/Target ratio: {siren_predictions.mean()/test_targets.mean():.3f}")

if 'analyzer_siren_predictions' in locals():
    print(f"\nD) Analyzer plotting values:")
    print(f"   Table: {table_subset.mean():.2e}")
    print(f"   SIREN: {analyzer_siren_predictions.mean():.2e}")
    print(f"   Scale mismatch: {scale_ratio:.2e}")

print(f"\n" + "="*60)
print("DIAGNOSIS:")
if 'siren_predictions' in locals() and siren_predictions is not None:
    training_ratio = siren_predictions.mean()/test_targets.mean()
    if abs(training_ratio - 1.0) < 0.1:
        print("✅ SIREN matches training targets well")
    else:
        print(f"❌ SIREN vs training targets mismatch: {training_ratio:.3f}")
        
    if 'analyzer_siren_predictions' in locals():
        if abs(scale_ratio) < 0.1:
            print("❌ MAJOR scale mismatch in analyzer plotting")
            print("   Issue: Analyzer uses different scale than training")
        else:
            print("✅ Analyzer scale seems reasonable")
            
print("="*60)


SUMMARY: DATA PIPELINE SCALE ANALYSIS

A) Original lookup table scale:
   Range: 0.00e+00 to 9.87e+02
   Mean: 1.35e+00

B) Dataset after filtering:
   Range: 2.53e-03 to 9.87e+02
   Mean: 5.44e+00
   Preservation: 4.020

Linear dataset training targets:
   Range: 3.06e-03 to 2.08e+02
   Mean: 7.61e+00

C) SIREN predictions during training:
   Range: 6.97e-05 to 1.00e+00
   Mean: 1.49e-01
   SIREN/Target ratio: 0.020

D) Analyzer plotting values:
   Table: 1.17e+00
   SIREN: 7.44e-02
   Scale mismatch: 6.36e-02

DIAGNOSIS:
❌ SIREN vs training targets mismatch: 0.020
❌ MAJOR scale mismatch in analyzer plotting
   Issue: Analyzer uses different scale than training
