In [None]:
import torch
# Torch device selection: use CUDA when it is available, otherwise the CPU.
# NOTE(review): `device` is not referenced in the cells visible here — confirm it is still needed.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
import sys
sys.path.append('..')

import jax
import jax.numpy as jnp
import numpy as np
import os
import random

from tools.generate import read_photon_data_from_photonsim, generate_events_from_photonsim
from tools.simulation import setup_event_simulator
from tools.geometry import generate_detector
from tools.visualization import create_detector_display, create_detector_comparison_display

# Detector geometry / simulator configuration (HK geometry config file).
json_filename = '../config/HK_geom_config.json'
Nphot = 5_000_000

detector = generate_detector(json_filename)

# Simulator built with is_data=True and temperature=0.; it is used to generate
# the reference events from the PhotonSim file further down.
simulate_data_event = setup_event_simulator(
    json_filename,
    Nphot,
    K=5,
    is_data=True,
    temperature=0.,
)

In [None]:
# PhotonSim ROOT file used as the event source for generate_events_from_photonsim.
photonsim_file = "../muons_50_with_photons.root"

# Detector/optical parameters, in the order passed to the simulators.
# Every entry is a floating-point scalar. An integer literal such as
# jnp.array(100) would create an int32 array, which mismatches the other
# entries and is rejected by jax.grad if these parameters are ever
# differentiated.
detector_params = (
    jnp.array(100.),         # scatter_length
    jnp.array(0.05),         # reflection_rate
    jnp.array(100000.),      # absorption_length
    jnp.array(0.001),        # gumbel_softmax_temp
)

In [None]:
# Run the is_data=True simulator over events from the PhotonSim ROOT file,
# using the detector parameters defined above, with output_dir='output/'.
# The return value is kept as `saved_files` (presumably the written paths —
# TODO confirm; it is not used by the later cells).
saved_files = generate_events_from_photonsim(
    root_file_path=photonsim_file,   # input ROOT file
    output_dir='output/',            # destination directory
    n_events=5,
    batch_size=5,
    event_simulator=simulate_data_event,
    detector_params=detector_params,
)

In [None]:
# HDF5 file inspected and loaded below.
# NOTE(review): hard-coded rather than taken from `saved_files`; confirm the
# generation cell actually writes this name into output/.
filename = 'output/event_0.h5'

In [None]:
from tools.utils import read_event_file

In [None]:
# Inspect the generated file; as the cell's last expression, any return value
# is rendered by Jupyter.
read_event_file(filename)

In [None]:
import h5py

from tools.utils import extract_particle_properties

# Load row 0 (the first stored event) of each dataset. Indexing the h5py
# dataset directly (f[key][0]) reads only that row; the previous
# np.array(f[key])[0] pulled every stored event into memory first and then
# discarded all but one.
with h5py.File(filename, 'r') as f:
    loaded_charges_data = f['Q'][0]
    loaded_times_data   = f['T'][0]
    loaded_mom_data     = f['P'][0]
    loaded_vtx_data     = f['V'][0]

# The arrays are in memory now, so the file does not need to stay open.
# pdg_code=13 is the PDG Monte Carlo code for the muon.
theta_data, phi_data, energy_data = extract_particle_properties(loaded_mom_data, pdg_code=13)

In [None]:
# Render the loaded event with the sparse=False display, which is called with
# charge and time arrays only (no hit indices), and save the charge (Q) and
# time (T) views as PDFs.
# Ensure the output directory exists before anything is written into it.
# NOTE(review): create_detector_display may already create it — harmless if so.
os.makedirs('figures', exist_ok=True)
detector_display = create_detector_display(json_filename, sparse=False)
detector_display(loaded_charges_data, loaded_times_data, file_name='figures/simulated_HK_muon_Q.pdf', plot_time=False, log_scale=True)
detector_display(loaded_charges_data, loaded_times_data, file_name='figures/simulated_HK_muon_T.pdf', plot_time=True, log_scale=True)

In [None]:
# Simulator built with is_data=False; otherwise configured exactly like the
# data simulator (same geometry file, Nphot, K and temperature).
simulate_event = setup_event_simulator(json_filename, Nphot, K=5, is_data=False, temperature=0.)

# Fixed PRNG seed (71900) for the simulation below.
key = jax.random.PRNGKey(71900)

# Track parameters in the order (energy, vertex, [theta, phi]), taken from the
# event loaded from the HDF5 file.
trk_params = (
    energy_data,
    jnp.array(loaded_vtx_data, dtype=jnp.float32),
    jnp.array([theta_data, phi_data], dtype=jnp.float32),
)

In [None]:
from tools.utils import load_single_event, save_single_event, print_particle_params, print_detector_params#, full_to_sparse, sparse_to_full, print_particle_params, print_detector_params

# Destination of the simulated ("predicted") event; it is read back with
# load_single_event in the next cell.
event_location = '../events/test_event_data.h5'
# Create the parent directory before the (expensive) simulation so the save
# cannot fail on a missing ../events directory. `or '.'` guards a bare
# filename, for which dirname() is '' and makedirs('') would raise.
# NOTE(review): save_single_event may already do this — harmless if so.
os.makedirs(os.path.dirname(event_location) or '.', exist_ok=True)

# stop_gradient is the identity in the forward pass; it only matters if this
# expression is ever traced under a JAX autodiff transform.
single_event_data = jax.lax.stop_gradient(simulate_event(trk_params, detector_params, key))
save_single_event(single_event_data, trk_params, detector_params, filename=event_location, calibration_mode=False)

In [None]:
from tools.geometry import generate_detector
import jax.numpy as jnp
# Reload the saved event and display it.
# This cell reuses `json_filename` and `detector` from the setup cell instead
# of re-assigning the same path and regenerating an identical geometry.
detector_points = jnp.array(detector.all_points)
NUM_DETECTORS = len(detector_points)

loaded_trk_params, loaded_detector_params, loaded_indices_pred, loaded_charges_pred, loaded_times_pred = load_single_event(
    event_location, NUM_DETECTORS, calibration_mode=False
)
print_particle_params(loaded_trk_params)
print_detector_params(loaded_detector_params)

# Default display (no sparse=False), called with (indices, charges, times).
os.makedirs('figures', exist_ok=True)
detector_display = create_detector_display(json_filename)
detector_display(loaded_indices_pred, loaded_charges_pred, loaded_times_pred, file_name='figures/predicted_HK_muon_Q.pdf', plot_time=False, log_scale=True)
detector_display(loaded_indices_pred, loaded_charges_pred, loaded_times_pred, file_name='figures/predicted_HK_muon_T.pdf', plot_time=True, log_scale=True)

In [None]:
# Optional check (disabled): overlay histograms of data vs predicted charges in
# the range (1.1, 30).
# NOTE(review): loaded_charges_data comes from the HDF5 file used with the
# sparse=False display, while loaded_charges_pred is paired with hit indices
# from load_single_event, so the two arrays may differ in length/contents —
# confirm the comparison is like-for-like before re-enabling.
# import matplotlib.pyplot as plt
# _ = plt.hist(loaded_charges_data, bins=200, range=(1.1,30), alpha=0.5, label='data')
# _ = plt.hist(loaded_charges_pred, bins=200, range=(1.1,30), alpha=0.5, label='pred')
# plt.legend()