In [None]:
# Track Parameter 1D Scans - Gradients and Loss Analysis
# This notebook performs 1D scans for all track parameters (energy, pos x, pos y, pos z, theta, phi)
# plotting both gradients and loss values to understand the optimization landscape

import sys
sys.path.append('..')

import jax
import jax.numpy as jnp
import numpy as np
import h5py
import matplotlib.pyplot as plt
from pathlib import Path

# Import necessary modules
from tools.simulation import setup_event_simulator
from tools.geometry import generate_detector
from tools.utils import read_event_file, analyze_loaded_particle
from tools.optimization import create_multi_objective_optimizer

# Report the JAX backend so the recorded outputs show which devices were used.
for _env_label, _env_value in (("JAX devices:", jax.devices()), ("JAX version:", jax.__version__)):
    print(_env_label, _env_value)

In [None]:
# Configuration and setup
json_filename = '../config/IWCD_geom_config.json'
Nphot = 500_000  # Higher for better gradient stability

# Generate detector geometry
detector = generate_detector(json_filename)
detector_points = jnp.array(detector.all_points)
NUM_DETECTORS = len(detector_points)

print(f"Detector has {NUM_DETECTORS} PMTs")

# Setup event simulator
simulate_event = setup_event_simulator(json_filename, Nphot, K=5, is_data=False, temperature=0.5)

# Detector parameters.
# All entries are floating point: jnp.array(100) would create an int32 array,
# which is inconsistent with the other parameters, and JAX refuses to take
# gradients with respect to integer inputs. The numeric values are unchanged.
detector_params = (
    jnp.array(100.),         # scatter_length
    jnp.array(0.05),         # reflection_rate
    jnp.array(100000.),      # absorption_length
    jnp.array(0.001)         # gumbel_softmax_temp
)

# Optimization parameters, passed to create_multi_objective_optimizer.
# NOTE(review): exact meaning of lambda_time / tau is defined in
# tools.optimization — presumably a time-term weight and a temperature; confirm.
energy_lr = 2.0
spatial_lr = 0.1
lambda_time = 1.0
tau = 0.01

In [None]:
# Load data event (event 0 of the file) together with its stored truth information.
data_filename = 'output/event_0.h5'
if not Path(data_filename).exists():
    # Every later cell needs true_particle_info / true_event_data, so stop here
    # with a clear message instead of failing with a NameError several cells later.
    raise FileNotFoundError(f"Data file {data_filename} not found. Please create it first.")

with h5py.File(data_filename, 'r') as f:
    # Index the datasets directly so h5py reads only event 0 from disk, instead
    # of materialising every event in memory and slicing afterwards.
    data_charges = f['Q'][0]
    data_times = f['T'][0]
    true_mom = f['P'][0]
    true_vtx = f['V'][0]
    pdg_code = f['PDG'][0]

true_particle_info = analyze_loaded_particle(true_mom, true_vtx, pdg_code)
true_event_data = (data_charges, data_times)

print(f"True energy: {true_particle_info['kinetic_energy']:.2f} MeV")
print(f"True position: [{true_particle_info['vertex'][0]:.3f}, {true_particle_info['vertex'][1]:.3f}, {true_particle_info['vertex'][2]:.3f}] m")
print(f"True angles: theta={np.degrees(true_particle_info['theta_rad']):.1f}°, phi={np.degrees(true_particle_info['phi_rad']):.1f}°")
print(f"Total charge: {np.sum(data_charges):.2f}")
print(f"PMTs with signal: {np.sum(data_charges > 0)}")

In [None]:
# Build the gradient functions and optimizers for the two objectives
# (energy loss and spatial loss) from the simulator and detector configuration.
(
    energy_grad_fn,
    spatial_grad_fn,
    energy_optimizer,
    spatial_optimizer,
) = create_multi_objective_optimizer(
    simulate_event=simulate_event,
    detector_points=detector_points,
    detector_params=detector_params,
    energy_lr=energy_lr,
    spatial_lr=spatial_lr,
    lambda_time=lambda_time,
    tau=tau,
)
print("Multi-objective optimizer created successfully!")

# Baseline = the true particle parameters; each 1D scan varies exactly one
# component of this (energy, position, angles) tuple.
baseline_energy = true_particle_info['kinetic_energy']
baseline_position = jnp.array(true_particle_info['vertex'])
baseline_angles = jnp.array([true_particle_info[angle_key] for angle_key in ('theta_rad', 'phi_rad')])
baseline_params = (jnp.array(baseline_energy), baseline_position, baseline_angles)

print("Baseline parameters:")
print(f"  Energy: {baseline_energy:.2f} MeV")
print("  Position: [" + ", ".join(f"{baseline_position[axis]:.3f}" for axis in range(3)) + "] m")
print("  Angles: [" + ", ".join(f"{np.degrees(baseline_angles[k]):.1f}°" for k in range(2)) + "]")

In [None]:
# Define scan ranges for each parameter.
# Every grid has `points` samples. Positions span +/- POS_HALF_WIDTH (m) around
# the baseline. Angles span +/- ANGLE_HALF_WIDTH (rad), with theta clipped to
# [THETA_MARGIN, pi - THETA_MARGIN]. Phi is not wrapped.
points = 121
ENERGY_SCAN_MIN_FRACTION = 0.5   # lower energy bound, as a fraction of the baseline
ENERGY_SCAN_MAX = 1000.0         # MeV; default upper energy bound
POS_HALF_WIDTH = 3.0             # m
ANGLE_HALF_WIDTH = 0.5           # rad
THETA_MARGIN = 0.1               # rad

# The energy upper bound used to be a hard-coded 1000 MeV. For baselines above
# 1000 MeV the true value then fell outside the scan, and above 2000 MeV the
# grid ran backwards. Keep 1000 MeV as the floor, but always reach 1.5x baseline.
energy_scan_max = max(ENERGY_SCAN_MAX, float(baseline_energy) * 1.5)

scan_ranges = {
    'energy': np.linspace(baseline_energy * ENERGY_SCAN_MIN_FRACTION, energy_scan_max, points),
    'pos_x': np.linspace(baseline_position[0] - POS_HALF_WIDTH, baseline_position[0] + POS_HALF_WIDTH, points),
    'pos_y': np.linspace(baseline_position[1] - POS_HALF_WIDTH, baseline_position[1] + POS_HALF_WIDTH, points),
    'pos_z': np.linspace(baseline_position[2] - POS_HALF_WIDTH, baseline_position[2] + POS_HALF_WIDTH, points),
    'theta': np.linspace(max(THETA_MARGIN, baseline_angles[0] - ANGLE_HALF_WIDTH),
                         min(np.pi - THETA_MARGIN, baseline_angles[0] + ANGLE_HALF_WIDTH), points),
    'phi': np.linspace(baseline_angles[1] - ANGLE_HALF_WIDTH, baseline_angles[1] + ANGLE_HALF_WIDTH, points)
}

print("Scan ranges defined:")
for param, values in scan_ranges.items():
    print(f"  {param}: [{values[0]:.3f}, {values[-1]:.3f}] with {len(values)} points")

# Fixed random key so every scan point uses the same simulator randomness.
fixed_key = jax.random.PRNGKey(42)

In [None]:
# Function to create parameter variations
def create_param_variation(param_name, param_value, baseline_params):
    """
    Return a copy of ``baseline_params`` with a single component replaced.

    Args:
        param_name: 'energy', 'pos_x', 'pos_y', 'pos_z', 'theta' or 'phi'.
        param_value: new value for that component.
        baseline_params: ``(energy, position, angles)`` tuple; ``position`` and
            ``angles`` must support JAX ``.at[i].set`` (they are not mutated).

    Returns:
        A new ``(energy, position, angles)`` tuple.

    Raises:
        ValueError: if ``param_name`` is not one of the names above.
    """
    energy, position, angles = baseline_params
    position_keys = ('pos_x', 'pos_y', 'pos_z')
    angle_keys = ('theta', 'phi')

    if param_name == 'energy':
        return (jnp.array(param_value), position, angles)
    if param_name in position_keys:
        axis = position_keys.index(param_name)
        return (energy, position.at[axis].set(param_value), angles)
    if param_name in angle_keys:
        axis = angle_keys.index(param_name)
        return (energy, position, angles.at[axis].set(param_value))
    raise ValueError(f"Unknown parameter: {param_name}")

# Function to extract specific gradient component
def extract_gradient_component(grads, param_name):
    """
    Pick the scalar gradient entry that corresponds to ``param_name``.

    Args:
        grads: ``(energy_grad, position_grad, angles_grad)`` tuple, mirroring
            the parameter tuple built by ``create_param_variation``.
        param_name: 'energy', 'pos_x', 'pos_y', 'pos_z', 'theta' or 'phi'.

    Returns:
        The selected component as a Python ``float``.

    Raises:
        ValueError: if ``param_name`` is not one of the names above.
    """
    energy_grad, position_grad, angles_grad = grads
    position_keys = ('pos_x', 'pos_y', 'pos_z')
    angle_keys = ('theta', 'phi')

    if param_name == 'energy':
        component = energy_grad
    elif param_name in position_keys:
        component = position_grad[position_keys.index(param_name)]
    elif param_name in angle_keys:
        component = angles_grad[angle_keys.index(param_name)]
    else:
        raise ValueError(f"Unknown parameter: {param_name}")
    return float(component)

print("Helper functions defined successfully!")

In [None]:
# Perform 1D scans of the energy and spatial objectives.
# Both public scan functions share one loop (_run_1d_scan); they differ only in
# which gradient function they evaluate and in their progress banner.
def _run_1d_scan(grad_fn, param_name, param_values, banner):
    """
    Evaluate ``grad_fn`` along a 1D scan of ``param_name``.

    Every other track parameter is held at ``baseline_params``, and the same
    ``fixed_key`` is reused at every point. Changes along the scan therefore
    come only from the scanned parameter.

    Args:
        grad_fn: callable ``(params, event_data, key) -> (loss, grads)`` whose
            ``grads`` are accepted by ``extract_gradient_component``.
        param_name: 'energy', 'pos_x', 'pos_y', 'pos_z', 'theta' or 'phi'.
        param_values: 1D sequence of values for that parameter.
        banner: message printed before the scan starts.

    Returns:
        (losses, grads): float arrays of length ``len(param_values)``. A point
        whose evaluation raised is recorded as NaN in both arrays.
    """
    n_points = len(param_values)
    losses = np.full(n_points, np.nan)
    grad_components = np.full(n_points, np.nan)

    print(banner)

    for i, param_value in enumerate(param_values):
        params = create_param_variation(param_name, param_value, baseline_params)

        # Deliberately best-effort: a failing point is reported and left as
        # NaN in both outputs instead of aborting the whole scan.
        try:
            loss, grads = grad_fn(params, true_event_data, fixed_key)
            grad_value = extract_gradient_component(grads, param_name)
            loss_value = float(loss)
        except Exception as e:
            print(f"Error at {param_name}={param_value:.3f}: {e}")
        else:
            losses[i] = loss_value
            grad_components[i] = grad_value

        if (i + 1) % 10 == 0:
            print(f"  Progress: {i+1}/{n_points}")

    return losses, grad_components


def perform_energy_scan(param_name, param_values):
    """
    Perform 1D scan for energy loss and gradients.

    Returns ``(losses, grads)`` from ``energy_grad_fn``, with NaN at failed points.
    """
    return _run_1d_scan(energy_grad_fn, param_name, param_values,
                        f"Scanning {param_name}...")


def perform_spatial_scan(param_name, param_values):
    """
    Perform 1D scan for spatial loss and gradients.

    Returns ``(losses, grads)`` from ``spatial_grad_fn``, with NaN at failed points.
    """
    return _run_1d_scan(spatial_grad_fn, param_name, param_values,
                        f"Scanning {param_name} (spatial)...")

print("Scan functions defined successfully!")

In [None]:
# Run both objectives over the energy grid and report the value ranges.
banner = "=" * 60
print(banner)
print("SCANNING ENERGY PARAMETER")
print(banner)

energy_values = scan_ranges['energy']
energy_losses_energy, energy_grads_energy = perform_energy_scan('energy', energy_values)
spatial_losses_energy, spatial_grads_energy = perform_spatial_scan('energy', energy_values)

print("Energy scan completed!")
for range_label, range_values, range_fmt in (
    ("Energy loss", energy_losses_energy, ".6f"),
    ("Energy grad", energy_grads_energy, ".8f"),
    ("Spatial loss", spatial_losses_energy, ".6f"),
    ("Spatial grad", spatial_grads_energy, ".8f"),
):
    print(f"{range_label} range: [{np.nanmin(range_values):{range_fmt}}, {np.nanmax(range_values):{range_fmt}}]")

In [None]:
# Plot energy parameter scan results.
# Top row: energy objective (loss, gradient); bottom row: spatial objective.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# (axis, y-values, line format, legend label, y-axis label, title, draw y=0 line?)
energy_panels = [
    (ax1, energy_losses_energy, 'b-', 'Energy Loss', 'Energy Loss',
     'Energy Loss vs Energy', False),
    # y-label names the gradient component, consistent with every other gradient panel
    (ax2, energy_grads_energy, 'g-', 'Energy Gradient', 'Energy Gradient (Energy Component)',
     'Energy Gradient vs Energy', True),
    (ax3, spatial_losses_energy, 'purple', 'Spatial Loss', 'Spatial Loss',
     'Spatial Loss vs Energy', False),
    (ax4, spatial_grads_energy, 'orange', 'Spatial Gradient', 'Spatial Gradient (Energy Component)',
     'Spatial Gradient vs Energy', True),
]

for panel_ax, panel_y, panel_fmt, panel_label, panel_ylabel, panel_title, zero_line in energy_panels:
    panel_ax.plot(energy_values, panel_y, panel_fmt, linewidth=2, label=panel_label)
    panel_ax.axvline(baseline_energy, color='r', linestyle='--', alpha=0.7, label='True Value')
    if zero_line:
        panel_ax.axhline(0, color='k', linestyle='-', alpha=0.3)
    panel_ax.set_xlabel('Energy (MeV)')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.set_title(panel_title)
    panel_ax.grid(True, alpha=0.3)
    panel_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Run both objectives over the position-X grid and report the value ranges.
banner = "=" * 60
print(banner)
print("SCANNING POSITION X PARAMETER")
print(banner)

pos_x_values = scan_ranges['pos_x']
energy_losses_pos_x, energy_grads_pos_x = perform_energy_scan('pos_x', pos_x_values)
spatial_losses_pos_x, spatial_grads_pos_x = perform_spatial_scan('pos_x', pos_x_values)

print("Position X scan completed!")
for range_label, range_values, range_fmt in (
    ("Energy loss", energy_losses_pos_x, ".6f"),
    ("Energy grad", energy_grads_pos_x, ".8f"),
    ("Spatial loss", spatial_losses_pos_x, ".6f"),
    ("Spatial grad", spatial_grads_pos_x, ".8f"),
):
    print(f"{range_label} range: [{np.nanmin(range_values):{range_fmt}}, {np.nanmax(range_values):{range_fmt}}]")

In [None]:
# Plot position X parameter scan results.
# Top row: energy objective (loss, gradient); bottom row: spatial objective.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# (axis, y-values, line format, legend label, y-axis label, title, draw y=0 line?)
pos_x_panels = [
    (ax1, energy_losses_pos_x, 'b-', 'Energy Loss', 'Energy Loss',
     'Energy Loss vs Position X', False),
    (ax2, energy_grads_pos_x, 'g-', 'Energy Gradient', 'Energy Gradient (Pos X Component)',
     'Energy Gradient vs Position X', True),
    (ax3, spatial_losses_pos_x, 'purple', 'Spatial Loss', 'Spatial Loss',
     'Spatial Loss vs Position X', False),
    (ax4, spatial_grads_pos_x, 'orange', 'Spatial Gradient', 'Spatial Gradient (Pos X Component)',
     'Spatial Gradient vs Position X', True),
]

for panel_ax, panel_y, panel_fmt, panel_label, panel_ylabel, panel_title, zero_line in pos_x_panels:
    panel_ax.plot(pos_x_values, panel_y, panel_fmt, linewidth=2, label=panel_label)
    panel_ax.axvline(baseline_position[0], color='r', linestyle='--', alpha=0.7, label='True Value')
    if zero_line:
        panel_ax.axhline(0, color='k', linestyle='-', alpha=0.3)
    panel_ax.set_xlabel('Position X (m)')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.set_title(panel_title)
    panel_ax.grid(True, alpha=0.3)
    panel_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Run both objectives over the position-Y grid and report the value ranges.
banner = "=" * 60
print(banner)
print("SCANNING POSITION Y PARAMETER")
print(banner)

pos_y_values = scan_ranges['pos_y']
energy_losses_pos_y, energy_grads_pos_y = perform_energy_scan('pos_y', pos_y_values)
spatial_losses_pos_y, spatial_grads_pos_y = perform_spatial_scan('pos_y', pos_y_values)

print("Position Y scan completed!")
for range_label, range_values, range_fmt in (
    ("Energy loss", energy_losses_pos_y, ".6f"),
    ("Energy grad", energy_grads_pos_y, ".8f"),
    ("Spatial loss", spatial_losses_pos_y, ".6f"),
    ("Spatial grad", spatial_grads_pos_y, ".8f"),
):
    print(f"{range_label} range: [{np.nanmin(range_values):{range_fmt}}, {np.nanmax(range_values):{range_fmt}}]")

In [None]:
# Plot position Y parameter scan results.
# Top row: energy objective (loss, gradient); bottom row: spatial objective.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# (axis, y-values, line format, legend label, y-axis label, title, draw y=0 line?)
pos_y_panels = [
    (ax1, energy_losses_pos_y, 'b-', 'Energy Loss', 'Energy Loss',
     'Energy Loss vs Position Y', False),
    (ax2, energy_grads_pos_y, 'g-', 'Energy Gradient', 'Energy Gradient (Pos Y Component)',
     'Energy Gradient vs Position Y', True),
    (ax3, spatial_losses_pos_y, 'purple', 'Spatial Loss', 'Spatial Loss',
     'Spatial Loss vs Position Y', False),
    (ax4, spatial_grads_pos_y, 'orange', 'Spatial Gradient', 'Spatial Gradient (Pos Y Component)',
     'Spatial Gradient vs Position Y', True),
]

for panel_ax, panel_y, panel_fmt, panel_label, panel_ylabel, panel_title, zero_line in pos_y_panels:
    panel_ax.plot(pos_y_values, panel_y, panel_fmt, linewidth=2, label=panel_label)
    panel_ax.axvline(baseline_position[1], color='r', linestyle='--', alpha=0.7, label='True Value')
    if zero_line:
        panel_ax.axhline(0, color='k', linestyle='-', alpha=0.3)
    panel_ax.set_xlabel('Position Y (m)')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.set_title(panel_title)
    panel_ax.grid(True, alpha=0.3)
    panel_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Run both objectives over the position-Z grid and report the value ranges.
banner = "=" * 60
print(banner)
print("SCANNING POSITION Z PARAMETER")
print(banner)

pos_z_values = scan_ranges['pos_z']
energy_losses_pos_z, energy_grads_pos_z = perform_energy_scan('pos_z', pos_z_values)
spatial_losses_pos_z, spatial_grads_pos_z = perform_spatial_scan('pos_z', pos_z_values)

print("Position Z scan completed!")
for range_label, range_values, range_fmt in (
    ("Energy loss", energy_losses_pos_z, ".6f"),
    ("Energy grad", energy_grads_pos_z, ".8f"),
    ("Spatial loss", spatial_losses_pos_z, ".6f"),
    ("Spatial grad", spatial_grads_pos_z, ".8f"),
):
    print(f"{range_label} range: [{np.nanmin(range_values):{range_fmt}}, {np.nanmax(range_values):{range_fmt}}]")

In [None]:
# Plot position Z parameter scan results.
# Top row: energy objective (loss, gradient); bottom row: spatial objective.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# (axis, y-values, line format, legend label, y-axis label, title, draw y=0 line?)
pos_z_panels = [
    (ax1, energy_losses_pos_z, 'b-', 'Energy Loss', 'Energy Loss',
     'Energy Loss vs Position Z', False),
    (ax2, energy_grads_pos_z, 'g-', 'Energy Gradient', 'Energy Gradient (Pos Z Component)',
     'Energy Gradient vs Position Z', True),
    (ax3, spatial_losses_pos_z, 'purple', 'Spatial Loss', 'Spatial Loss',
     'Spatial Loss vs Position Z', False),
    (ax4, spatial_grads_pos_z, 'orange', 'Spatial Gradient', 'Spatial Gradient (Pos Z Component)',
     'Spatial Gradient vs Position Z', True),
]

for panel_ax, panel_y, panel_fmt, panel_label, panel_ylabel, panel_title, zero_line in pos_z_panels:
    panel_ax.plot(pos_z_values, panel_y, panel_fmt, linewidth=2, label=panel_label)
    panel_ax.axvline(baseline_position[2], color='r', linestyle='--', alpha=0.7, label='True Value')
    if zero_line:
        panel_ax.axhline(0, color='k', linestyle='-', alpha=0.3)
    panel_ax.set_xlabel('Position Z (m)')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.set_title(panel_title)
    panel_ax.grid(True, alpha=0.3)
    panel_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Run both objectives over the theta grid and report the value ranges.
banner = "=" * 60
print(banner)
print("SCANNING THETA PARAMETER")
print(banner)

theta_values = scan_ranges['theta']
energy_losses_theta, energy_grads_theta = perform_energy_scan('theta', theta_values)
spatial_losses_theta, spatial_grads_theta = perform_spatial_scan('theta', theta_values)

print("Theta scan completed!")
for range_label, range_values, range_fmt in (
    ("Energy loss", energy_losses_theta, ".6f"),
    ("Energy grad", energy_grads_theta, ".8f"),
    ("Spatial loss", spatial_losses_theta, ".6f"),
    ("Spatial grad", spatial_grads_theta, ".8f"),
):
    print(f"{range_label} range: [{np.nanmin(range_values):{range_fmt}}, {np.nanmax(range_values):{range_fmt}}]")

In [None]:
# Plot theta parameter scan results (angles shown in degrees).
# Top row: energy objective (loss, gradient); bottom row: spatial objective.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

theta_degrees = np.degrees(theta_values)
baseline_theta_degrees = np.degrees(baseline_angles[0])

# (axis, y-values, line format, legend label, y-axis label, title, draw y=0 line?)
theta_panels = [
    (ax1, energy_losses_theta, 'b-', 'Energy Loss', 'Energy Loss',
     'Energy Loss vs Theta', False),
    (ax2, energy_grads_theta, 'g-', 'Energy Gradient', 'Energy Gradient (Theta Component)',
     'Energy Gradient vs Theta', True),
    (ax3, spatial_losses_theta, 'purple', 'Spatial Loss', 'Spatial Loss',
     'Spatial Loss vs Theta', False),
    (ax4, spatial_grads_theta, 'orange', 'Spatial Gradient', 'Spatial Gradient (Theta Component)',
     'Spatial Gradient vs Theta', True),
]

for panel_ax, panel_y, panel_fmt, panel_label, panel_ylabel, panel_title, zero_line in theta_panels:
    panel_ax.plot(theta_degrees, panel_y, panel_fmt, linewidth=2, label=panel_label)
    panel_ax.axvline(baseline_theta_degrees, color='r', linestyle='--', alpha=0.7, label='True Value')
    if zero_line:
        panel_ax.axhline(0, color='k', linestyle='-', alpha=0.3)
    panel_ax.set_xlabel('Theta (degrees)')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.set_title(panel_title)
    panel_ax.grid(True, alpha=0.3)
    panel_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Run both objectives over the phi grid and report the value ranges.
banner = "=" * 60
print(banner)
print("SCANNING PHI PARAMETER")
print(banner)

phi_values = scan_ranges['phi']
energy_losses_phi, energy_grads_phi = perform_energy_scan('phi', phi_values)
spatial_losses_phi, spatial_grads_phi = perform_spatial_scan('phi', phi_values)

print("Phi scan completed!")
for range_label, range_values, range_fmt in (
    ("Energy loss", energy_losses_phi, ".6f"),
    ("Energy grad", energy_grads_phi, ".8f"),
    ("Spatial loss", spatial_losses_phi, ".6f"),
    ("Spatial grad", spatial_grads_phi, ".8f"),
):
    print(f"{range_label} range: [{np.nanmin(range_values):{range_fmt}}, {np.nanmax(range_values):{range_fmt}}]")

In [None]:
# Plot phi parameter scan results (angles shown in degrees).
# Top row: energy objective (loss, gradient); bottom row: spatial objective.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

phi_degrees = np.degrees(phi_values)
baseline_phi_degrees = np.degrees(baseline_angles[1])

# (axis, y-values, line format, legend label, y-axis label, title, draw y=0 line?)
phi_panels = [
    (ax1, energy_losses_phi, 'b-', 'Energy Loss', 'Energy Loss',
     'Energy Loss vs Phi', False),
    (ax2, energy_grads_phi, 'g-', 'Energy Gradient', 'Energy Gradient (Phi Component)',
     'Energy Gradient vs Phi', True),
    (ax3, spatial_losses_phi, 'purple', 'Spatial Loss', 'Spatial Loss',
     'Spatial Loss vs Phi', False),
    (ax4, spatial_grads_phi, 'orange', 'Spatial Gradient', 'Spatial Gradient (Phi Component)',
     'Spatial Gradient vs Phi', True),
]

for panel_ax, panel_y, panel_fmt, panel_label, panel_ylabel, panel_title, zero_line in phi_panels:
    panel_ax.plot(phi_degrees, panel_y, panel_fmt, linewidth=2, label=panel_label)
    panel_ax.axvline(baseline_phi_degrees, color='r', linestyle='--', alpha=0.7, label='True Value')
    if zero_line:
        panel_ax.axhline(0, color='k', linestyle='-', alpha=0.3)
    panel_ax.set_xlabel('Phi (degrees)')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.set_title(panel_title)
    panel_ax.grid(True, alpha=0.3)
    panel_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Summary analysis of all parameters
print("=" * 80)
print("SUMMARY ANALYSIS OF ALL TRACK PARAMETERS")
print("=" * 80)

# Collect gradient statistics for each parameter
all_scans = {
    'energy': (energy_grads_energy, spatial_grads_energy),
    'pos_x': (energy_grads_pos_x, spatial_grads_pos_x),
    'pos_y': (energy_grads_pos_y, spatial_grads_pos_y),
    'pos_z': (energy_grads_pos_z, spatial_grads_pos_z),
    'theta': (energy_grads_theta, spatial_grads_theta),
    'phi': (energy_grads_phi, spatial_grads_phi)
}


def _finite_grad_stats(values):
    """Return (max - min, max |value|) over the non-NaN entries; (0.0, 0.0) if there are none."""
    clean = values[~np.isnan(values)]
    if len(clean) == 0:
        return 0.0, 0.0
    return np.max(clean) - np.min(clean), np.max(np.abs(clean))


print("\nGradient Statistics:")
print(f"{'Parameter':<10} {'Energy Grad':<25} {'Spatial Grad':<25}")
print(f"{'':10} {'Range':<12} {'Max Abs':<12} {'Range':<12} {'Max Abs':<12}")
print("-" * 80)

for param_name, (energy_grads, spatial_grads) in all_scans.items():
    energy_range, energy_max_abs = _finite_grad_stats(energy_grads)
    spatial_range, spatial_max_abs = _finite_grad_stats(spatial_grads)
    # Scientific notation: fixed .6f printed small gradients as 0.000000.
    # ".4e" values fit the 12-character columns of the header above.
    print(f"{param_name:<10} {energy_range:<12.4e} {energy_max_abs:<12.4e} {spatial_range:<12.4e} {spatial_max_abs:<12.4e}")

print("\n" + "=" * 80)

In [None]:
# Create overview plot showing gradient magnitudes for all parameters.
# NOTE: gradients carry different units per parameter (per MeV, per m, per rad),
# so bar heights are not directly comparable across parameters.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Extract maximum absolute gradients (ignoring NaN points) for comparison
param_names = list(all_scans.keys())
energy_max_grads = []
spatial_max_grads = []

for param_name, (energy_grads, spatial_grads) in all_scans.items():
    energy_grads_clean = energy_grads[~np.isnan(energy_grads)]
    spatial_grads_clean = spatial_grads[~np.isnan(spatial_grads)]

    energy_max_grads.append(np.max(np.abs(energy_grads_clean)) if len(energy_grads_clean) > 0 else 0.0)
    spatial_max_grads.append(np.max(np.abs(spatial_grads_clean)) if len(spatial_grads_clean) > 0 else 0.0)


def _draw_magnitude_bars(ax, values, color, ylabel, title):
    """
    Draw one bar per parameter with height ``values[i]`` and annotate positive bars.

    A log y-axis is used only when at least one value is positive. With no
    positive data, matplotlib cannot autoscale a log axis, so the panel keeps a
    linear axis instead.
    """
    bars = ax.bar(param_names, values, color=color, alpha=0.7)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if any(value > 0 for value in values):
        ax.set_yscale('log')
    ax.grid(True, alpha=0.3)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # Add value labels on bars (zero-height bars are skipped)
    for bar, value in zip(bars, values):
        if value > 0:
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(), f'{value:.2e}',
                    ha='center', va='bottom', fontsize=9, rotation=90)
    return bars


bars1 = _draw_magnitude_bars(ax1, energy_max_grads, 'skyblue',
                             'Maximum Absolute Energy Gradient',
                             'Energy Gradient Magnitudes by Parameter')
bars2 = _draw_magnitude_bars(ax2, spatial_max_grads, 'lightcoral',
                             'Maximum Absolute Spatial Gradient',
                             'Spatial Gradient Magnitudes by Parameter')

plt.tight_layout()
plt.show()

print("\nOverview plot completed!")
print("This analysis shows the gradient landscape for all track parameters.")
print("Parameters with larger gradients will be easier to optimize.")