# Track Reconstruction from Data Events

This notebook reconstructs track parameters from simulated data events by combining:
- Event loading from `generate_muon_data_events.ipynb`
- Optimization techniques from `track_optimization.ipynb`

The goal is to find the track parameters (kinetic energy, vertex position, and direction angles θ and φ) that best reproduce the observed per-PMT charge and timing patterns.

In [None]:
import torch

# Report which device PyTorch would use on this machine.
# NOTE(review): `device` appears unused by the cells shown — the reconstruction
# itself runs through JAX. Confirm whether this cell is still needed.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"Using device: {device}")

In [None]:
import sys
# Add the parent directory to the module search path so the project's `tools`
# package (imported below) can be found; this must run before those imports.
# NOTE(review): assumes the notebook is executed from one level below the
# directory containing `tools` — confirm if the layout changes.
sys.path.append('..')

import jax
import jax.numpy as jnp
import numpy as np
import h5py
import matplotlib.pyplot as plt
from pathlib import Path

# Import necessary modules
# NOTE(review): `extract_particle_properties` is not used in the visible cells,
# and `create_multi_objective_optimizer` is imported again in the optimizer cell.
from tools.simulation import setup_event_simulator
from tools.geometry import generate_detector
from tools.visualization import create_detector_display, create_detector_comparison_display
from tools.utils import read_event_file, extract_particle_properties, analyze_loaded_particle
from tools.optimization import create_multi_objective_optimizer

## Setup Detector and Simulator

In [None]:
# Configuration
json_filename = '../config/IWCD_geom_config.json'  # detector geometry config
Nphot = 100_000  # Reduced for faster reconstruction

# Generate detector geometry
detector = generate_detector(json_filename)
detector_points = jnp.array(detector.all_points)
NUM_DETECTORS = len(detector_points)

print(f"Detector has {NUM_DETECTORS} PMTs")

# Setup event simulator for reconstruction
# NOTE(review): K, is_data and temperature are forwarded to the project
# simulator; their exact meaning is defined in tools.simulation.
simulate_event = setup_event_simulator(json_filename, Nphot, K=1, is_data=False, temperature=0.)

# Detector parameters for simulation.
# Every entry is a float literal so the tuple has a uniform floating dtype;
# an integer leaf (e.g. `jnp.array(100)`) would not be differentiable under
# jax.grad and is inconsistent with its siblings.
detector_params = (
    jnp.array(100.),         # scatter_length
    jnp.array(0.05),         # reflection_rate
    jnp.array(100000.),      # absorption_length
    jnp.array(0.001)         # gumbel_softmax_temp
)

## Load Data Event

In [None]:
# Load the data event to reconstruct
data_filename = 'output/event_0.h5'

# Every following cell reads this file, so stop here with an actionable
# message instead of letting the next cell fail deep inside h5py.
if not Path(data_filename).exists():
    raise FileNotFoundError(
        f"Data file {data_filename} not found. "
        "Please run generate_muon_data_events.ipynb first to create test data."
    )

print(f"Loading data from {data_filename}")

# Read the event file to see its structure
data_dict = read_event_file(data_filename, verbose=True)

In [None]:
# Extract the first stored event from the data file. Indexing each h5py
# dataset directly reads only that row, instead of loading every event into
# memory and then discarding all but the first.
with h5py.File(data_filename, 'r') as f:
    data_charges = f['Q'][0]  # Shape: (N_detectors,)
    data_times = f['T'][0]    # Shape: (N_detectors,)
    true_mom = f['P'][0]      # True momentum for comparison
    true_vtx = f['V'][0]      # True vertex for comparison
    pdg_code = f['PDG'][0]    # Particle type

# The reconstruction compares these per-PMT arrays element-wise with simulated
# output, so reject a file produced for a different detector geometry up front.
if len(data_charges) != NUM_DETECTORS or len(data_times) != NUM_DETECTORS:
    raise ValueError(
        f"Event has {len(data_charges)} charges / {len(data_times)} times, "
        f"but the detector geometry has {NUM_DETECTORS} PMTs"
    )

# Extract true parameters for comparison
true_particle_info = analyze_loaded_particle(true_mom, true_vtx, pdg_code)

print("\nTrue Track Parameters:")
print(f"Energy: {true_particle_info['kinetic_energy']:.2f} MeV")
print(f"Theta: {true_particle_info['theta_deg']:.2f}°")
print(f"Phi: {true_particle_info['phi_deg']:.2f}°")
print(f"Vertex: [{true_vtx[0]:.3f}, {true_vtx[1]:.3f}, {true_vtx[2]:.3f}] m")
print(f"Total charge: {np.sum(data_charges):.2f}")
print(f"PMTs with signal: {np.sum(data_charges > 0)}")

## Visualize Data Event

In [None]:
# Create visualization
detector_display = create_detector_display(json_filename, sparse=False)

# This notebook writes all of its figures under figures/; make sure the
# directory exists so the first save does not fail on a fresh checkout.
Path('figures').mkdir(parents=True, exist_ok=True)

# Display charge pattern
detector_display(data_charges, data_times, 
                file_name='figures/data_event_charge.pdf', 
                plot_time=False, log_scale=True)

# Display time pattern
detector_display(data_charges, data_times, 
                file_name='figures/data_event_time.pdf', 
                plot_time=True, log_scale=True)

## Set up Optimization for Track Reconstruction

In [None]:
# Also imported at the top of the notebook; repeated so this cell is
# self-contained.
from tools.optimization import create_multi_objective_optimizer

print("Setting up multi-objective optimization...")

# Settings forwarded unchanged to create_multi_objective_optimizer.
energy_lr = 2.0      # learning rate for the energy parameter
spatial_lr = 0.1     # learning rate for vertex position and direction angles
lambda_time = 0.0    # weight of the timing term (0 => fit on charge only)
tau = 0.01           # temperature for soft assignments

(
    energy_grad_fn,
    spatial_grad_fn,
    energy_optimizer,
    spatial_optimizer,
) = create_multi_objective_optimizer(
    simulate_event=simulate_event,
    detector_points=detector_points,
    detector_params=detector_params,
    energy_lr=energy_lr,
    spatial_lr=spatial_lr,
    lambda_time=lambda_time,
    tau=tau,
)

# Echo the configuration that was used.
for label, value in (
    ("Energy learning rate", energy_lr),
    ("Spatial learning rate", spatial_lr),
    ("Lambda time", lambda_time),
    ("Tau (soft assignment temperature)", tau),
):
    print(f"{label}: {value}")

## Run Track Reconstruction

In [None]:
# Import initial guess methods
from tools.optimization import grid_scan_initial_guess_vectorized, generate_random_initial_guess

# Choose initial guess method.
# Options: 'grid_scan', 'random', or 'fixed' (hard-coded starting point below).
initial_guess_method = 'grid_scan'

# Prepare true event data in the format expected by the optimizer
true_event_data = (data_charges, data_times)

# Generate initial guess using the selected method
event_key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(event_key)

print(f"Generating initial guess using {initial_guess_method} method...")

if initial_guess_method == 'grid_scan':
    # Grid scan over angles and energy to find best initial guess
    initial_params = grid_scan_initial_guess_vectorized(
        spatial_grad_fn, energy_grad_fn, true_event_data, subkey
    )
elif initial_guess_method == 'random':
    # Random initial guess within detector bounds
    initial_params = generate_random_initial_guess(
        spatial_grad_fn, energy_grad_fn, true_event_data, subkey, detector_points
    )
else:
    # Fixed starting point. Any unrecognized method name also lands here, so
    # make that visible instead of silently ignoring a misspelled option.
    if initial_guess_method != 'fixed':
        print(
            f"WARNING: unknown initial_guess_method {initial_guess_method!r}; "
            "using the fixed initial guess instead."
        )
    initial_energy = 500.0
    initial_vertex = jnp.array([0.0, 0.0, 0.0])
    initial_angles = jnp.array([jnp.pi/4, 0.0])
    initial_params = (initial_energy, initial_vertex, initial_angles)

# Extract initial parameters
initial_energy, initial_vertex, initial_angles = initial_params

print("\nInitial guess parameters:")
print(f"Energy: {initial_energy:.2f} MeV")
print(f"Vertex: [{initial_vertex[0]:.3f}, {initial_vertex[1]:.3f}, {initial_vertex[2]:.3f}] m")
print(f"Theta: {np.degrees(initial_angles[0]):.2f}°")
print(f"Phi: {np.degrees(initial_angles[1]):.2f}°")

print("\nTrue event statistics:")
print(f"Total charge: {np.sum(data_charges):.2f}")
print(f"PMTs with signal: {np.sum(data_charges > 0)}")

In [None]:
# Optimizer starting point: the initial guess chosen above.
current_params = initial_params
position_scale = 2.0  # Scale factor for position updates

# Empty history buffers (the optimization cell rebinds these from its results).
loss_history, energy_loss_history, spatial_loss_history, param_history = [], [], [], []

# Optimization parameters (same as track_optimization)
n_iterations, patience = 400, 250
# NOTE(review): min_improvement is defined here but is not passed to
# run_multi_objective_optimization — confirm whether it should be.
min_improvement = 1e-6

print(
    f"\nStarting multi-objective optimization with {n_iterations} iterations...",
    f"Position scale: {position_scale}",
    f"Patience: {patience}",
    sep="\n",
)

In [None]:
# Run multi-objective optimization loop
from tools.optimization import run_multi_objective_optimization

# Use a fixed random key for reproducibility
event_key = jax.random.PRNGKey(43)

print("Starting optimization...")

# Run the optimization. Returns the final (energy, vertex, angles) tuple plus
# dictionaries of per-iteration losses and per-component parameter values.
optimized_params, loss_history_dict, param_history_dict = run_multi_objective_optimization(
    params=current_params,
    energy_grad_fn=energy_grad_fn,
    spatial_grad_fn=spatial_grad_fn,
    energy_optimizer=energy_optimizer,
    spatial_optimizer=spatial_optimizer,
    event_key=event_key,
    true_event_data=true_event_data,
    n_iterations=n_iterations,
    position_scale=position_scale,
    patience=patience
)

# Extract final parameters
current_energy, current_vertex, current_angles = optimized_params

# Extract loss histories
loss_history = loss_history_dict['total']
energy_loss_history = loss_history_dict['energy']
spatial_loss_history = loss_history_dict['spatial']

print(f"\nOptimization completed after {len(loss_history)} iterations")
if len(loss_history) > 0:
    print(f"Final total loss: {loss_history[-1]:.6f}")
    print(f"Final energy loss: {energy_loss_history[-1]:.6f}")
    print(f"Final spatial loss: {spatial_loss_history[-1]:.6f}")
else:
    print("WARNING: No iterations completed successfully!")

# Rebuild per-iteration (energy, vertex, angles) tuples, in the same layout as
# optimized_params, from the per-component histories for plotting.
param_history = [
    (energy, jnp.array([x, y, z]), jnp.array([theta, phi]))
    for energy, x, y, z, theta, phi in zip(
        param_history_dict['energy'],
        param_history_dict['position_x'],
        param_history_dict['position_y'],
        param_history_dict['position_z'],
        param_history_dict['theta'],
        param_history_dict['phi'],
    )
]

## Compare Reconstructed vs True Parameters

In [None]:
# Final reconstructed parameters
reconstructed_energy = float(current_energy)
reconstructed_vertex = np.array(current_vertex)
reconstructed_angles = np.array(current_angles)
reconstructed_theta_deg = np.degrees(reconstructed_angles[0])
reconstructed_phi_deg = np.degrees(reconstructed_angles[1])


def _angle_diff_deg(a, b):
    """Return |a - b| for angles in degrees, respecting 360° periodicity.

    The raw difference is folded into [-180, 180) before taking the magnitude,
    so the result lies in [0, 180] (e.g. 179° vs -179° gives 2°, not 358°).
    """
    return abs((a - b + 180.0) % 360.0 - 180.0)


print("\n" + "="*60)
print("TRACK RECONSTRUCTION RESULTS")
print("="*60)

print("\nTrue Parameters:")
print(f"  Energy:     {true_particle_info['kinetic_energy']:.2f} MeV")
print(f"  Theta:      {true_particle_info['theta_deg']:.2f}°")
print(f"  Phi:        {true_particle_info['phi_deg']:.2f}°")
print(f"  Vertex:     [{true_vtx[0]:.3f}, {true_vtx[1]:.3f}, {true_vtx[2]:.3f}] m")

print("\nReconstructed Parameters:")
print(f"  Energy:     {reconstructed_energy:.2f} MeV")
print(f"  Theta:      {reconstructed_theta_deg:.2f}°")
print(f"  Phi:        {reconstructed_phi_deg:.2f}°")
print(f"  Vertex:     [{reconstructed_vertex[0]:.3f}, {reconstructed_vertex[1]:.3f}, {reconstructed_vertex[2]:.3f}] m")

print("\nReconstruction Errors:")
energy_error = abs(reconstructed_energy - true_particle_info['kinetic_energy'])
theta_error = abs(reconstructed_theta_deg - true_particle_info['theta_deg'])
# Phi is an azimuth and wraps around; a plain abs() difference would report
# ~360° for nearly identical directions on either side of the ±180° seam.
phi_error = _angle_diff_deg(reconstructed_phi_deg, true_particle_info['phi_deg'])
vertex_error = np.linalg.norm(reconstructed_vertex - true_vtx)

print(f"  Energy:     {energy_error:.2f} MeV ({energy_error/true_particle_info['kinetic_energy']*100:.1f}%)")
print(f"  Theta:      {theta_error:.2f}°")
print(f"  Phi:        {phi_error:.2f}°")
print(f"  Vertex:     {vertex_error:.3f} m")

## Generate Reconstructed Event and Compare

In [None]:
# Simulate the event implied by the reconstructed parameters, with a fixed
# PRNG key so the comparison is repeatable across runs.
final_track_params = (reconstructed_energy, reconstructed_vertex, reconstructed_angles)
key = jax.random.PRNGKey(42)
reconstructed_charges, reconstructed_times = simulate_event(final_track_params, detector_params, key)

# Summary counts for the observed and the re-simulated event.
for label, charges in (("Data event", data_charges), ("Reconstructed", reconstructed_charges)):
    print(f"{label} - Total charge: {np.sum(charges):.2f}, PMTs hit: {np.sum(charges > 0)}")

# Pattern-level agreement between observed and simulated per-PMT charges.
charge_residual = data_charges - reconstructed_charges
charge_rmse = np.sqrt(np.mean(charge_residual**2))
charge_correlation = np.corrcoef(data_charges, reconstructed_charges)[0, 1]

print(f"\nCharge pattern correlation: {charge_correlation:.4f}")
print(f"Charge RMSE: {charge_rmse:.4f}")

## Visualize Comparison

In [None]:
# Side-by-side displays of observed (first tuple) vs reconstructed (second
# tuple) detector response: one for charge, one for hit times.
detector_comparison = create_detector_comparison_display(json_filename, sparse=False)

observed = (data_charges, data_times)
simulated = (reconstructed_charges, reconstructed_times)

comparison_plots = (
    ('figures/reconstruction_charge_comparison.pdf', dict(plot_time=False)),
    # align_time=True aligns the mean times of the two patterns
    ('figures/reconstruction_time_comparison.pdf', dict(plot_time=True, align_time=True)),
)
for pdf_path, display_kwargs in comparison_plots:
    detector_comparison(observed, simulated, file_name=pdf_path, **display_kwargs)

In [None]:
# Plot convergence history: 2x3 grid of losses, energy, vertex error, theta,
# phi and angular errors over the optimizer iterations.
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# Extract history from param_history (tuples of (energy, vertex, angles))
energy_history = [params[0] for params in param_history]
vertex_history = [params[1] for params in param_history]
angle_history = [params[2] for params in param_history]

# Total loss history
axes[0, 0].plot(loss_history, label='Total')
axes[0, 0].plot(energy_loss_history, label='Energy', alpha=0.7)
axes[0, 0].plot(spatial_loss_history, label='Spatial', alpha=0.7)
axes[0, 0].set_xlabel('Iteration')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Loss Convergence')
axes[0, 0].set_yscale('log')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Energy history
axes[0, 1].plot(energy_history, label='Reconstructed')
axes[0, 1].axhline(true_particle_info['kinetic_energy'], color='red', linestyle='--', label='True')
axes[0, 1].set_xlabel('Iteration')
axes[0, 1].set_ylabel('Energy (MeV)')
axes[0, 1].set_title('Energy Convergence')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Vertex distance from true
vertex_distances = [np.linalg.norm(v - true_vtx) for v in vertex_history]
axes[0, 2].plot(vertex_distances)
axes[0, 2].set_xlabel('Iteration')
axes[0, 2].set_ylabel('Distance from True Vertex (m)')
axes[0, 2].set_title('Vertex Error Convergence')
axes[0, 2].grid(True, alpha=0.3)

# Angle history
theta_history = [np.degrees(angles[0]) for angles in angle_history]
phi_history = [np.degrees(angles[1]) for angles in angle_history]

axes[1, 0].plot(theta_history, label='Theta (Reconstructed)')
axes[1, 0].axhline(true_particle_info['theta_deg'], color='red', linestyle='--', label='Theta (True)')
axes[1, 0].set_xlabel('Iteration')
axes[1, 0].set_ylabel('Angle (degrees)')
axes[1, 0].set_title('Theta Convergence')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

axes[1, 1].plot(phi_history, label='Phi (Reconstructed)')
axes[1, 1].axhline(true_particle_info['phi_deg'], color='red', linestyle='--', label='Phi (True)')
axes[1, 1].set_xlabel('Iteration')
axes[1, 1].set_ylabel('Angle (degrees)')
axes[1, 1].set_title('Phi Convergence')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

# Angular error. Phi is an azimuth and wraps at 360°, so fold the difference
# into [-180, 180) before taking the magnitude; a plain abs() would report
# e.g. 179° vs -179° as 358° instead of 2°.
true_theta_deg = true_particle_info['theta_deg']
true_phi_deg = true_particle_info['phi_deg']
theta_errors = [abs(t - true_theta_deg) for t in theta_history]
phi_errors = [abs((p - true_phi_deg + 180.0) % 360.0 - 180.0) for p in phi_history]
axes[1, 2].plot(theta_errors, label='Theta Error')
axes[1, 2].plot(phi_errors, label='Phi Error')
axes[1, 2].set_xlabel('Iteration')
axes[1, 2].set_ylabel('Angular Error (degrees)')
axes[1, 2].set_title('Angular Error Convergence')
axes[1, 2].legend()
axes[1, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('figures/reconstruction_convergence.pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Compact 2x2 convergence summary: loss, energy, theta and phi per iteration.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Rebuild the histories from param_history (tuples of (energy, vertex, angles))
# so this cell does not depend on the 2x3 convergence cell having run first.
energy_history = [params[0] for params in param_history]
angle_history = [params[2] for params in param_history]

# Loss history
axes[0, 0].plot(loss_history)
axes[0, 0].set_xlabel('Iteration')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Loss Convergence')
axes[0, 0].set_yscale('log')
axes[0, 0].grid(True, alpha=0.3)

# Energy history
axes[0, 1].plot(energy_history, label='Reconstructed')
axes[0, 1].axhline(true_particle_info['kinetic_energy'], color='red', linestyle='--', label='True')
axes[0, 1].set_xlabel('Iteration')
axes[0, 1].set_ylabel('Energy (MeV)')
axes[0, 1].set_title('Energy Convergence')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Angle history
theta_history = [np.degrees(angles[0]) for angles in angle_history]
phi_history = [np.degrees(angles[1]) for angles in angle_history]

axes[1, 0].plot(theta_history, label='Theta (Reconstructed)')
axes[1, 0].axhline(true_particle_info['theta_deg'], color='red', linestyle='--', label='Theta (True)')
axes[1, 0].set_xlabel('Iteration')
axes[1, 0].set_ylabel('Angle (degrees)')
axes[1, 0].set_title('Theta Convergence')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

axes[1, 1].plot(phi_history, label='Phi (Reconstructed)')
axes[1, 1].axhline(true_particle_info['phi_deg'], color='red', linestyle='--', label='Phi (True)')
axes[1, 1].set_xlabel('Iteration')
axes[1, 1].set_ylabel('Angle (degrees)')
axes[1, 1].set_title('Phi Convergence')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
# Distinct file name: the 2x3 convergence figure is saved as
# figures/reconstruction_convergence.pdf and would otherwise be overwritten.
plt.savefig('figures/reconstruction_convergence_summary.pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Plot vertex position convergence: one panel per Cartesian coordinate, the
# reconstructed trajectory against the true vertex (dashed red line).
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

vertex_x = [v[0] for v in vertex_history]
vertex_y = [v[1] for v in vertex_history]
vertex_z = [v[2] for v in vertex_history]

coordinate_traces = (vertex_x, vertex_y, vertex_z)
for idx, coord in enumerate('XYZ'):
    ax = axes[idx]
    ax.plot(coordinate_traces[idx], label=f'Reconstructed {coord}')
    ax.axhline(true_vtx[idx], color='red', linestyle='--', label=f'True {coord}')
    ax.set_xlabel('Iteration')
    ax.set_ylabel(f'{coord} Position (m)')
    ax.set_title(f'{coord} Position Convergence')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('figures/vertex_reconstruction_convergence.pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Debug check: the distinct charge values in the re-simulated event, shown as
# this cell's output.
unique_reconstructed_charges = np.unique(reconstructed_charges)
unique_reconstructed_charges