In [None]:
import sys
from datetime import time  # NOTE(review): `time` is not used in the visible cells

import torch

# Make the parent directory importable, presumably so the project-local `tools`
# package imported in the following cells can be resolved.
sys.path.append('..')

# Use the GPU when torch reports one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:

from tools.optimization import create_multi_objective_optimizer, filter_inf_results, plot_simple_multi_event_convergence
from tools.optimization import run_multi_event_optimization, plot_single_event_comparison, plot_multi_event_convergence, plot_single_event_convergence

In [None]:
import jax.numpy as jnp
from tools.geometry import generate_detector

# Geometry configuration file (IWCD) handed to generate_detector.
json_filename = '../config/IWCD_geom_config.json'

detector = generate_detector(json_filename)
# Convert the detector's `all_points` into a JAX array for later use.
detector_points = jnp.array(detector.all_points)

In [None]:
import pickle
from pathlib import Path

# Configuration flag: True forces a fresh optimization run even if a cached
# result file exists; False reuses the cached pickle when it is present.
start_fresh = True

# Location of the cached optimization results
output_dir = Path('output/optimization/')
output_file = output_dir / 'optimization_results.pkl'

# Run when explicitly requested, or when there is no cached result to load.
should_run = start_fresh or not output_file.exists()

if should_run:
    print("Running optimization...")

    results = run_multi_event_optimization(
        N_events=50,
        Nphot=100_000,
        # Reuse the geometry config path the detector was built from (the
        # `json_filename` defined in the detector cell) so the two cannot drift.
        json_filename=json_filename,
        K=2,
        loss_function='multi_objective',
        energy_lr=2.0,
        spatial_lr=0.1,
        position_scale=2.0,
        lambda_time=0.0,
        n_iterations=400,
        patience=250,
        base_seed=150,
        verbose=False,
        initial_guess_method='random',  # the other option is 'grid_scan'
    )

    # Create output directory if it doesn't exist
    output_dir.mkdir(parents=True, exist_ok=True)

    # Dump to a temporary file in the same directory, then move it into place.
    # An interrupted dump therefore never leaves a truncated pickle behind that
    # a later start_fresh=False run would try (and fail) to load.
    tmp_file = output_file.with_name(output_file.name + '.tmp')
    with open(tmp_file, 'wb') as f:
        pickle.dump(results, f)
    tmp_file.replace(output_file)

    print(f"Results saved to {output_file}")

else:
    print(f"Optimization results already exist at {output_file}")
    print("Set start_fresh=True to run optimization again")

    # Load the cached results; this branch is only reached when output_file exists.
    # NOTE: pickle.load can execute arbitrary code — only load files you produced.
    with open(output_file, 'rb') as f:
        results = pickle.load(f)
    print("Loaded existing results")

In [None]:
# Read the cached results via `output_file` (set in the optimization cell), so
# changing the cache location there automatically applies here as well.
# NOTE: pickle.load can execute arbitrary code — only load files you produced.
with open(output_file, 'rb') as f:
    loaded_results = pickle.load(f)
# Presumably drops events whose optimization produced infinite values — see
# tools.optimization.filter_inf_results to confirm.
new_results = filter_inf_results(loaded_results)

In [None]:
import matplotlib.pyplot as plt

# Global figure styling: matplotlib's built-in text rendering (no external
# LaTeX), serif font family, and an 8 pt base font size.
plt.rcParams.update({
    'text.usetex': False,
    'font.family': 'serif',
    'font.size': 8,
})

In [None]:
# Directory for saved figures; create it up front so writing into it cannot fail
# on a missing directory.
# NOTE(review): this rebinds `output_dir`, which previously pointed at the
# optimization cache directory.
output_dir = Path('figures/')
output_dir.mkdir(parents=True, exist_ok=True)

# Derive the save path from output_dir, so the directory created above and the
# save location always agree. Assigning to `_` keeps the return value from
# being echoed as the cell output.
_ = plot_simple_multi_event_convergence(
    new_results,
    show_individual=True,
    show_statistics=True,
    show_histograms=True,
    figsize=(8, 5),
    save_path=str(output_dir / 'multi_evt_opt_summary_tmp.pdf'),
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tools.optimization import plot_simple_event_convergence

# Number of events = number of per-event loss histories. The code uses len()
# rather than np.shape(...)[0]: if histories have different lengths (e.g. some
# events stopped early via `patience` — assumption, confirm in tools.optimization),
# NumPy >= 1.24 raises ValueError when coercing the ragged list to an array.
n_events = len(new_results['loss_histories'])
for i in range(n_events):
    plot_simple_event_convergence(
        i, new_results, save_path=f'figures/opt_history_evt_{i}.pdf', figsize=(7, 2)
    )
    # Flush this event's figure before drawing the next. With the default inline
    # backend, plt.show() also closes it, so dozens of figures do not accumulate
    # in memory (matplotlib warns once more than 20 figures are open).
    plt.show()