# PhotonSim SIREN Ray Generation Validation
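This notebook validates ray generation with the newly trained PhotonSim SIREN model. For a sweep of energies, it histograms each generated photon's distance from the emission origin against its opening angle relative to the track direction.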

In [None]:
import sys
import os
from pathlib import Path

# Add parent directory to path
sys.path.append('..')
sys.path.append('../tools')

# Add training modules
training_path = Path('../siren/training')
sys.path.append(str(training_path))

import numpy as np
import jax
import jax.numpy as jnp
from jax import random

# Import PhotonSim training modules
from inference import SIRENPredictor

# Import tools
from tools.siren import SIREN
from tools.table import Table
from tools.simulation import create_siren_grid
from tools.generate import generate_random_cone_vectors, normalize, photonsim_differentiable_get_rays
from tools.simulation import create_photonsim_siren_grid

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

from functools import partial

In [None]:
# Load the trained PhotonSim SIREN model and tabulate it on a grid. The 500 is
# passed straight through to create_photonsim_siren_grid (presumably the grid
# resolution — TODO confirm against that helper).
model_base_path = (
    Path('../notebooks/output/photonsim_siren_training')
    / 'trained_model'
    / 'photonsim_siren'
)
photonsim_predictor = SIRENPredictor(model_base_path)
table_data = create_photonsim_siren_grid(photonsim_predictor, 500)

# Simulation parameters. The original note said these are "same as original" —
# presumably the earlier, non-PhotonSim validation setup (TODO confirm).
origin = jnp.array((0.5, 0.0, -0.5))      # emission origin
direction = jnp.array((1.0, -1.0, 0.2))   # track direction (not normalized here)
Nphot = 10**6                             # photons requested per call
key = random.PRNGKey(0)                   # fixed seed for reproducibility
# NOTE: the energy sweep below uses `energy` as its loop variable, so this
# value is overwritten once that cell runs.
energy = 500
model_params = photonsim_predictor.params

In [None]:
def calculate_opening_angles(ray_vectors, direction):
    """Return the opening angle (radians) between each ray and a reference direction.

    Parameters
    ----------
    ray_vectors : array_like, shape (..., D)
        Ray vectors with components on the last axis. They need not be unit
        length. A single vector of shape (D,) is also accepted.
    direction : array_like, shape (D,)
        Reference direction. It need not be unit length.

    Returns
    -------
    jax.Array, shape (...)
        Angles in [0, pi]. Zero-length vectors yield NaN, since their angle is
        undefined.
    """
    ray_vectors = jnp.asarray(ray_vectors)
    direction = jnp.asarray(direction)

    # Normalize both sides; keepdims lets this broadcast over any batch shape.
    unit_rays = ray_vectors / jnp.linalg.norm(ray_vectors, axis=-1, keepdims=True)
    unit_dir = direction / jnp.linalg.norm(direction)

    # Half-angle formula: theta = 2 * atan2(|u - v|, |u + v|) for unit u, v.
    # Unlike arccos(u . v), it stays accurate for nearly parallel and
    # antiparallel vectors, where the cosine rounds to +/-1 and arccos loses
    # almost all precision. It also needs no clipping and works in any dimension.
    return 2.0 * jnp.arctan2(
        jnp.linalg.norm(unit_rays - unit_dir, axis=-1),
        jnp.linalg.norm(unit_rays + unit_dir, axis=-1),
    )

    
# Energy sweep: one subplot per energy. Each panel is a weighted 2-D histogram
# of every generated photon's distance from the emission origin versus its
# opening angle relative to the track direction.
n_rows, n_cols = 4, 5
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 16))
# One energy per subplot, evenly spaced from 200 to 1000 MeV. The original
# comment describes this as the PhotonSim training range (not verifiable here).
energies = np.linspace(200, 1000, n_rows * n_cols)

# axes.flat walks the grid row-major, matching the original row/col layout.
for ax, energy in zip(axes.flat, energies):
    print(f"Generating rays for energy {energy:.0f} MeV using PhotonSim ranges...")

    # Both the tabulated SIREN grid (table_data) and the raw model parameters
    # are passed in. The same PRNG key is reused for every energy, so all
    # panels share one random stream. This is presumably intentional for a
    # like-for-like comparison; split `key` per energy for independent samples.
    ray_vectors, ray_origins, photon_weights = photonsim_differentiable_get_rays(
        origin, direction, energy, Nphot, table_data, model_params, key
    )
    ranges = jnp.linalg.norm(ray_origins - origin, axis=1)
    angles = calculate_opening_angles(ray_vectors, direction)

    # Seed count, used only for the panel title. The linear formula presumably
    # mirrors the one inside photonsim_differentiable_get_rays — keep the two
    # in sync (TODO confirm). int() makes the ',' format spec reliable.
    num_seeds = int(energy * 11.136 - 720.3)

    ax.hist2d(
        np.asarray(ranges),
        np.asarray(angles),
        weights=np.asarray(photon_weights).squeeze(),
        bins=[200, 200],
        cmap='gnuplot',
        norm=LogNorm(vmin=1),
        # The upper angle edge is pi, not 3.14, so photons emitted almost
        # straight back are not silently dropped from the histogram.
        range=[[0, 6], [0, np.pi]],
    )
    ax.set_ylabel('Angle (radians)')
    ax.set_xlabel('Distance to Origin (m)')
    ax.set_title(f'PhotonSim SIREN\nEnergy: {energy:.0f} MeV\nSeeds: {num_seeds:,}')

fig.patch.set_facecolor('white')
fig.suptitle('PhotonSim SIREN Ray Generation - Actual Training Ranges', fontsize=16)
fig.tight_layout()
plt.show()
print(f"✅ {len(energies)} energy comparison completed using actual PhotonSim ranges")