In [None]:
# NOTE(review): this binds `time` to the datetime.time class. A later cell runs
# `import time`, rebinding the name to the stdlib module that the benchmark
# functions rely on for timing; re-running this cell afterwards would break them.
from datetime import time
from pathlib import Path

import torch
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import sys
# Make the parent directory and the SIREN training directory importable.
sys.path.append('..')
training_path = Path('../siren/training')
sys.path.append(str(training_path))

from tools.siren import *
from inference import SIRENPredictor

from tools.simulation import create_photonsim_siren_grid
from tools.generate import generate_random_cone_vectors, normalize, photonsim_differentiable_get_rays

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
import numpy as np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import sys
sys.path.append('..')

from tools.propagate import create_photon_propagator
from tools.geometry import generate_detector
from tools.utils import generate_random_point_inside_cylinder, generate_random_params
from tools.losses import compute_simplified_loss
from tools.simulation import setup_event_simulator

# --- Benchmark sweep settings -------------------------------------------
WARMUP_RUNS = 10     # untimed calls made before measuring each (K, N) pair
TIMING_RUNS = 100    # timed calls per (K, N) pair
K_VALUES = [1, 2, 3, 4]
N_VALUES = [
    100_000, 250_000, 500_000, 750_000,
    1_000_000, 1_500_000, 2_000_000,
]

# --- Detector geometry ----------------------------------------------------
default_json_filename = '../config/IWCD_geom_config.json'
detector = generate_detector(default_json_filename)
detector_points = jnp.array(detector.all_points)

def benchmark_simulation(is_calibration=False):
    """Benchmark the jitted event simulator for every (K, N) combination.

    For each K in K_VALUES and each photon count in N_VALUES a simulator is
    built with setup_event_simulator and wrapped in jit. It is then called
    WARMUP_RUNS times without recording the time (the first call includes JIT
    compilation) and TIMING_RUNS times with the wall-clock duration recorded.

    Args:
        is_calibration: Forwarded to setup_event_simulator. When True, each
            call receives ``(source_origin, 1000)`` with a freshly sampled
            origin; when False, a fixed track-parameter tuple is used.

    Returns:
        dict mapping each K to ``{'N': [...], 'mean_time': [...],
        'std_time': [...]}``; times are in seconds, one entry per N.
    """
    results = {k: {'N': [], 'mean_time': [], 'std_time': []} for k in K_VALUES}

    print("Benchmarking simulation performance...")

    # Inputs that are identical for every call are built once.
    track_params = (
        jnp.array(800.0, dtype=jnp.float32),
        jnp.array([0.0, 0.0, 0.0], dtype=jnp.float32),
        jnp.array([jnp.pi/3, jnp.pi/4], dtype=jnp.float32)
    )
    detector_params = (
        jnp.array(4.), jnp.array(0.2), jnp.array(6.), jnp.array(0.001)
    )

    def timed_call(simulate_event, key):
        """Run one simulation with fresh inputs; return (advanced key, seconds)."""
        key, origin_key = jax.random.split(key)
        if is_calibration:
            # Generate a new source position for each call
            source_origin = generate_random_point_inside_cylinder(origin_key, r=4, h=6)
            other_params = (source_origin, 1000)
        else:
            other_params = track_params
        key, sim_key = jax.random.split(key)

        # JAX dispatches work asynchronously: wait for the inputs so that only
        # the simulator call itself falls inside the timed region.
        jax.block_until_ready((other_params, sim_key))

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        result = simulate_event(other_params, detector_params, sim_key)
        jax.block_until_ready(result)
        return key, time.perf_counter() - start

    for K in K_VALUES:
        print(f"\nK = {K}")

        for Nphot in N_VALUES:
            print(f"  N = {Nphot:,}")

            # Setup simulator
            simulate_event = jit(setup_event_simulator(
                default_json_filename, Nphot, temperature=None, K=K, is_calibration=is_calibration
            ))

            # Initial key
            key = jax.random.PRNGKey(42)

            # Warmup (durations discarded; first call pays for compilation)
            for _ in range(WARMUP_RUNS):
                key, _ = timed_call(simulate_event, key)

            # Timing
            times = []
            for _ in range(TIMING_RUNS):
                key, elapsed = timed_call(simulate_event, key)
                times.append(elapsed)

            results[K]['N'].append(Nphot)
            results[K]['mean_time'].append(np.mean(times))
            results[K]['std_time'].append(np.std(times))

    return results

def benchmark_gradient(is_calibration=False):
    """Benchmark loss-and-gradient evaluation for every (K, N) combination.

    A reference ("true") event is simulated once with 100000 photons and K=2.
    Then, for each K in K_VALUES and photon count in N_VALUES, a jitted
    function computes the loss against that reference together with its
    gradient with respect to the event parameters and the detector parameters.
    It is called WARMUP_RUNS times untimed (the first call includes JIT
    compilation) and TIMING_RUNS times with the wall-clock duration recorded.

    Args:
        is_calibration: Forwarded to setup_event_simulator. When True the
            event input is ``(source_origin, 1000)`` and the gradient is taken
            with respect to the source origin; when False the event input
            comes from generate_random_params and the gradient is taken with
            respect to all of it.

    Returns:
        dict mapping each K to ``{'N': [...], 'mean_time': [...],
        'std_time': [...]}``; times are in seconds, one entry per N.
    """
    results = {k: {'N': [], 'mean_time': [], 'std_time': []} for k in K_VALUES}

    print("\n\nBenchmarking gradient computation...")

    # Generate true data once with fixed N
    Nphot_true = 100000
    key = jax.random.PRNGKey(42)

    detector_params = (
        jnp.array(4.), jnp.array(0.2), jnp.array(6.), jnp.array(0.001)
    )

    # Split before sampling so no key is both consumed and split again.
    key, origin_key = jax.random.split(key)
    if is_calibration:
        true_params = (generate_random_point_inside_cylinder(origin_key, r=4, h=6), 1000)
    else:
        true_params = (
            jnp.array(800.0, dtype=jnp.float32),
            jnp.array([0.0, 0.0, 0.0], dtype=jnp.float32),
            jnp.array([jnp.pi/3, jnp.pi/4], dtype=jnp.float32)
        )

    # Generate true data with fixed simulator
    simulate_true = setup_event_simulator(
        default_json_filename, Nphot_true, temperature=None, K=2, is_calibration=is_calibration
    )
    key, subkey = jax.random.split(key)
    true_data = jax.lax.stop_gradient(simulate_true(true_params, detector_params, subkey))

    def sample_event_params(param_key):
        """Draw the per-call event parameters for the selected mode."""
        if is_calibration:
            return (generate_random_point_inside_cylinder(param_key, r=4, h=6), 1000)
        return generate_random_params(param_key)

    for K in K_VALUES:
        print(f"\nK = {K}")

        for Nphot in N_VALUES:
            print(f"  N = {Nphot:,}")

            # Setup simulator for this N
            simulate_event = setup_event_simulator(
                default_json_filename, Nphot, temperature=None, K=K, is_calibration=is_calibration
            )

            # Loss and gradient w.r.t. the event parameters AND the detector
            # parameters. Everything that varies between calls (including the
            # PRNG key) is an explicit argument, so nothing is frozen into the
            # compiled function at trace time.
            @jit
            def loss_and_grad_fn(other_params, detector_params, sim_key):
                def loss_of(event_params, d_params):
                    simulated_data = simulate_event(event_params, d_params, sim_key)
                    return compute_simplified_loss(detector_points, *true_data, *simulated_data, lambda_time=0.0)

                if is_calibration:
                    # Only the source origin is differentiated; the remaining
                    # entry is an integer, which value_and_grad rejects.
                    origin, *fixed = other_params
                    return value_and_grad(
                        lambda o, d: loss_of((o, *fixed), d), argnums=(0, 1)
                    )(origin, detector_params)

                # NOTE(review): assumes generate_random_params returns a pytree
                # of floating-point arrays (like the track tuple above) — confirm.
                return value_and_grad(loss_of, argnums=(0, 1))(other_params, detector_params)

            def timed_call(key):
                """Evaluate loss and gradients once; return (advanced key, seconds)."""
                key, param_key, sim_key = jax.random.split(key, 3)
                # Generate new event parameters each time
                other_params = sample_event_params(param_key)

                # Wait for asynchronously dispatched inputs so that only the
                # gradient evaluation falls inside the timed region.
                jax.block_until_ready((other_params, sim_key))

                # perf_counter is monotonic and high-resolution, unlike time.time().
                start = time.perf_counter()
                outputs = loss_and_grad_fn(other_params, detector_params, sim_key)
                jax.block_until_ready(outputs)
                return key, time.perf_counter() - start

            # Warmup (durations discarded; first call pays for compilation)
            for _ in range(WARMUP_RUNS):
                key, _ = timed_call(key)

            # Timing
            times = []
            for _ in range(TIMING_RUNS):
                key, elapsed = timed_call(key)
                times.append(elapsed)

            results[K]['N'].append(Nphot)
            results[K]['mean_time'].append(np.mean(times))
            results[K]['std_time'].append(np.std(times))

    return results

# Run all four benchmarks: full event simulation first, then calibration mode.
sim_results_full = benchmark_simulation(is_calibration=False)
grad_results_full = benchmark_gradient(is_calibration=False)

sim_results_calib = benchmark_simulation(is_calibration=True)
grad_results_calib = benchmark_gradient(is_calibration=True)


# Print summary: one banner per mode, then simulation and gradient tables.
for heading, sim_table, grad_table in (
    ("SUMMARY CALIBRATION", sim_results_calib, grad_results_calib),
    ("SUMMARY FULL", sim_results_full, grad_results_full),
):
    print("\n" + "="*50)
    print(heading)
    print("="*50)

    for caption, table in (
        ("\nSimulation times (milliseconds):", sim_table),
        ("\nGradient computation times (milliseconds):", grad_table),
    ):
        print(caption)
        for K in K_VALUES:
            print(f"\nK = {K}:")
            row = table[K]
            # Stored times are in seconds; convert to milliseconds for display.
            for N, mean_s, std_s in zip(row['N'], row['mean_time'], row['std_time']):
                print(f"  N = {N:>9,}: {mean_s * 1000:.2f} ± {std_s * 1000:.2f} ms")

In [None]:
figures_dir = Path('figures')
figures_dir.mkdir(parents=True, exist_ok=True)

# Plot results
def plot_results(sim_results, grad_results, tag=''):
    """Plot simulation and gradient timings against photon count and save both figures.

    Each figure shows one line per K in K_VALUES (mean time in ms) with a
    shaded band of plus/minus one standard deviation, on a logarithmic x axis.

    Args:
        sim_results: dict keyed by K whose values hold 'N', 'mean_time' and
            'std_time' sequences (times in seconds).
        grad_results: same layout as sim_results, for the gradient benchmark.
        tag: optional label appended to the output file names (for example
            'calib' gives 'simulation_timing_calib.png'). With the default
            empty string the files are 'simulation_timing.png' and
            'gradient_timing.png', so repeated untagged calls overwrite each
            other's output.
    """
    # Use same colors for both plots
    colors = plt.cm.viridis(np.linspace(0, 1, len(K_VALUES)))
    suffix = f'_{tag}' if tag else ''

    panels = (
        (sim_results, 'Simulation Performance',
         f'simulation_timing{suffix}.png'),
        (grad_results, 'Gradient Computation Performance\n(Source + Detector Parameters)',
         f'gradient_timing{suffix}.png'),
    )

    for timings, title, filename in panels:
        fig, ax = plt.subplots(figsize=(6, 4))

        for color, K in zip(colors, K_VALUES):
            N = np.asarray(timings[K]['N'])
            mean_ms = np.asarray(timings[K]['mean_time']) * 1000  # Convert to ms
            std_ms = np.asarray(timings[K]['std_time']) * 1000    # Convert to ms

            ax.plot(N, mean_ms, 'o-', color=color, label=f'K={K}')
            ax.fill_between(N, mean_ms - std_ms, mean_ms + std_ms, alpha=0.3, color=color)

        ax.set_xlabel('Number of Photons (N)')
        ax.set_ylabel('Time (ms)')
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xscale('log')
        fig.tight_layout()
        fig.savefig(figures_dir / filename, dpi=300)
        plt.show()

In [None]:
# Timing plots for the calibration-mode benchmarks (the is_calibration=True runs).
plot_results(sim_results_calib, grad_results_calib)

In [None]:
# Timing plots for the full-simulation benchmarks (the is_calibration=False runs).
# NOTE(review): this call and the calibration call save to the same file paths,
# so the PNGs written by the calibration plots are overwritten here.
plot_results(sim_results_full, grad_results_full)