In [None]:
import torch

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _has_cuda else torch.device('cpu')

In [None]:
import os
import random
import sys

# Make the project's `tools` package (one directory up) importable. The guard keeps
# re-runs of this cell from appending a duplicate '..' entry to sys.path each time.
if '..' not in sys.path:
    sys.path.append('..')

import jax
import jax.numpy as jnp
import numpy as np

from tools.generate import read_photon_data_from_photonsim, generate_events_from_photonsim
from tools.simulation import setup_event_simulator
from tools.geometry import generate_detector
from tools.visualization import create_detector_display, create_detector_comparison_display

# Detector geometry configuration used to build the detector and the simulator.
json_filename = '../config/HK_geom_config.json'

detector = generate_detector(json_filename)

# Photon count passed to setup_event_simulator (5 million).
Nphot = 5_000_000

# NOTE(review): K, is_data and temperature are forwarded verbatim to setup_event_simulator;
# their meaning lives in tools.simulation — confirm there before tuning them.
simulate_event = setup_event_simulator(json_filename, Nphot, K=5, is_data=True, temperature=0.)

In [None]:
# PhotonSim ROOT file used as input to generate_events_from_photonsim.
photonsim_file = "../muons_50_with_photons.root"

# Detector parameters passed to generate_events_from_photonsim, in the order labelled below.
# Every entry is a floating-point scalar. A bare `100` would create an integer array,
# which would not match the other parameters' dtype. It would also break jax.grad if
# these parameters are ever differentiated, because JAX rejects integer inputs to grad.
detector_params = (
    jnp.array(100.),         # scatter_length
    jnp.array(0.05),         # reflection_rate
    jnp.array(100000.),      # absorption_length
    jnp.array(0.001),        # gumbel_softmax_temp
)

In [None]:
# Run the configured simulator over PhotonSim events and write the results to disk.
# Output files presumably land under OUTPUT_DIR as event_<i>.h5 — confirm in tools.generate.
OUTPUT_DIR = 'output/'
N_EVENTS = 5
BATCH_SIZE = 5

saved_files = generate_events_from_photonsim(
    event_simulator=simulate_event,
    root_file_path=photonsim_file,
    detector_params=detector_params,
    output_dir=OUTPUT_DIR,
    n_events=N_EVENTS,
    batch_size=BATCH_SIZE,
)

In [None]:
# HDF5 file of a single generated event, selected by its index.
EVENT_INDEX = 0
filename = f'output/event_{EVENT_INDEX}.h5'

In [None]:
from tools.utils import read_event_file

In [None]:
# Inspect the event file with the project helper; its return value is the cell output.
event_contents = read_event_file(filename)
event_contents

In [None]:
import h5py

# Read the event straight from HDF5. Only the first row of 'Q' is kept as the charges.
# NOTE(review): 'T' is loaded without the [0] that 'Q' gets — confirm whether 'T' also
# carries a leading axis that should be dropped before both go to the display.
with h5py.File(filename, 'r') as h5_file:
    loaded_charges = np.array(h5_file['Q'])[0]
    loaded_times = np.array(h5_file['T'])

# One index per entry of the charge vector.
loaded_indices = list(range(len(loaded_charges)))

# Print each shape on its own line. Plain prints avoid echoing a tuple of print return
# values and avoid clobbering IPython's `_` output variable.
print(np.shape(loaded_indices))
print(np.shape(loaded_charges))
print(np.shape(loaded_times))

In [None]:
# Non-sparse display: charges and times are passed as arrays, presumably one value per sensor.
# Times are passed but not drawn (plot_time=False); charges are shown on a log scale.
# NOTE(review): the sparse form of the display presumably takes (indices, charges, times);
# check tools.visualization.create_detector_display before switching to sparse=True.
detector_display = create_detector_display(json_filename, sparse=False)
display_options = dict(file_name=None, plot_time=False, log_scale=True)
detector_display(loaded_charges, loaded_times, **display_options)

In [None]:
import jax.numpy as jnp
from tools.geometry import generate_detector
from tools.utils import (
    full_to_sparse,
    load_single_event,
    print_detector_params,
    print_particle_params,
    save_single_event,
    sparse_to_full,
)
from tools.visualization import create_detector_display

# Reference event on disk, and the geometry configuration used to rebuild the detector.
event_location = '../events/test_event_data.h5'
default_json_filename = '../config/HK_geom_config.json'

# Rebuild the geometry and count its points; the count is passed to load_single_event.
detector = generate_detector(default_json_filename)
detector_points = jnp.array(detector.all_points)
NUM_DETECTORS = len(detector_points)

# NOTE(review): this rebinds loaded_indices / loaded_charges / loaded_times, replacing
# the values read from output/event_0.h5 earlier in the notebook.
(
    loaded_trk_params,
    loaded_detector_params,
    loaded_indices,
    loaded_charges,
    loaded_times,
) = load_single_event(event_location, NUM_DETECTORS, calibration_mode=False)

In [None]:
# Shapes of the loaded indices, charges and times, one per line. A plain loop replaces
# the tuple of print calls, whose value `(None, None, None)` was echoed as cell output.
for arr in (loaded_indices, loaded_charges, loaded_times):
    print(np.shape(arr))

In [None]:
# Inspect the times returned by load_single_event (displayed as the cell output).
loaded_times