In [1]:
import sys
from datetime import time

import torch

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Add the parent directory to the import path so the `tools.*` imports in the next cell resolve.
sys.path.append('..')

In [2]:
from tools.geometry import generate_detector
from tools.utils import generate_random_params
from tools.utils import load_single_event, save_single_event
import jax
import jax.numpy as jnp
from tools.simulation import setup_event_simulator

# Generate and save a single event.
# A fixed PRNG key is passed to every simulator call below, so re-running gives the same events.
key = jax.random.PRNGKey(6)

# Detector/optics parameters passed to the simulator. Every entry is an explicit
# float32 scalar: `jnp.array(50)` would be int32, which is inconsistent with the
# other entries and cannot be differentiated through (jax.grad rejects integer inputs).
detector_params = (
    jnp.array(50.0, dtype=jnp.float32),    # scattering_length
    jnp.array(0.00, dtype=jnp.float32),    # reflection_rate
    jnp.array(999., dtype=jnp.float32),    # absorption_length
    jnp.array(0.001, dtype=jnp.float32)    # gumbel_softmax_temp
)

# Track parameters passed to the simulator: energy, start position, direction angles.
track_params = (
    jnp.array(800.0, dtype=jnp.float32),                # energy
    jnp.array([0.0, 0.0, 0.0], dtype=jnp.float32),      # position
    jnp.array([jnp.pi/3, jnp.pi/4], dtype=jnp.float32)  # angles (theta, phi)
)

In [3]:
# Detector configurations handled by the cells below; each name is used to build
# ../config/<name>_geom_config.json and ../events/<name>_event_data.h5.
detector_names = [
    'EOS',
    'WCTE',
    'IWCD',
    'SK',
    'HK',
    'JUNO',
    'TAO',
]

In [None]:
# Simulate one event per detector and save it to ../events/<name>_event_data.h5.
# Set `generate_event = False` to skip this expensive cell and reuse the saved files.
import os

generate_event = True

# Settings shared by every detector (loop-invariant).
Nphot = 5_000_000
temperature = 0.0
SPHERICAL_DETECTORS = ('TAO', 'JUNO')

# Only used by the disabled calibration-source variant inside the loop.
source_params = (
    jnp.array([0.0, 0.0, 0.0], dtype=jnp.float32),
    jnp.array(1.0, dtype=jnp.float32)
)

if generate_event:
    # h5 writers generally do not create parent directories; make sure the target exists.
    os.makedirs('../events', exist_ok=True)

    for name in detector_names:
        json_filename = f'../config/{name}_geom_config.json'
        detector_type = 'Sphere' if name in SPHERICAL_DETECTORS else 'Cylinder'
        simulator = setup_event_simulator(json_filename, Nphot, temperature=temperature, K=2, is_data=False, is_calibration=False, detector_type=detector_type, max_detectors_per_cell=10)
        single_event = jax.lax.stop_gradient(simulator(track_params, detector_params, key))

        # Calibration-source variant (disabled):
        # simulator = setup_event_simulator(json_filename, Nphot, temperature=temperature, K=1, is_data=False, is_calibration=True, detector_type=detector_type, max_detectors_per_cell=10)
        # single_event = jax.lax.stop_gradient(simulator(source_params, detector_params, key))

        save_single_event(single_event, track_params, detector_params, filename=f'../events/{name}_event_data.h5', calibration_mode=False)

    # All events written; leave the flag off so the state reflects that generation already ran.
    generate_event = False

In [None]:
def visualize_3D_event_for_detector(name, colorscale='viridis', surface_color='gray',
                                    events_dir='../events', config_dir='../config',
                                    figures_dir='figures'):
    """Draw the saved single event of detector `name` and write it to a PDF.

    Loads ``{events_dir}/{name}_event_data.h5`` and builds the detector from
    ``{config_dir}/{name}_geom_config.json``. It then calls
    ``detector.visualize_event_data_plotly_discs`` with figname
    ``{figures_dir}/{name}_3D_evt_display.pdf``.

    Args:
        name: detector identifier, e.g. 'SK' or 'JUNO'.
        colorscale: colour scale forwarded to the plotting method.
        surface_color: surface colour forwarded to the plotting method.
        events_dir: directory holding the ``*_event_data.h5`` files.
        config_dir: directory holding the ``*_geom_config.json`` files.
        figures_dir: output directory for the PDF; created if it does not exist.
    """
    import os

    _, _, indices, charges, times = load_single_event(f'{events_dir}/{name}_event_data.h5', None, calibration_mode=False)
    detector = generate_detector(f'{config_dir}/{name}_geom_config.json')

    # Create the output directory so writing the figure does not fail on a fresh checkout.
    os.makedirs(figures_dir, exist_ok=True)
    figname = f'{figures_dir}/{name}_3D_evt_display.pdf'

    detector.visualize_event_data_plotly_discs(indices, charges, times, show_all_detectors=True, log_scale=True, show_colorbar=False, dark_theme=False, plot_time=False, colorscale=colorscale, surface_color=surface_color, figname=figname)

In [None]:
def check_missing_sensors(name, events_dir='../events', config_dir='../config'):
    """Check that the saved event for `name` has one entry per photosensor.

    Compares ``len(detector.all_points)`` from the geometry config with the number
    of sensor indices stored in the event file. Prints the result prefixed with the
    detector name, so output from the per-detector loop can be attributed.

    Args:
        name: detector identifier, e.g. 'SK'.
        events_dir: directory holding the ``*_event_data.h5`` files.
        config_dir: directory holding the ``*_geom_config.json`` files.

    Returns:
        bool: True when both counts match.
    """
    _, _, indices, _, _ = load_single_event(f'{events_dir}/{name}_event_data.h5', None, calibration_mode=False)
    detector = generate_detector(f'{config_dir}/{name}_geom_config.json')

    n_geometry = len(detector.all_points)
    n_event = len(indices)
    counts_match = n_geometry == n_event

    if counts_match:
        print(f'{name}: Success!')
    else:
        print(f'{name}: Problems? {n_geometry} sensors in geometry vs {n_event} in event')
    return counts_match

for name in detector_names:
    check_missing_sensors(name)

In [None]:
cmap = 'inferno'
# Draw a subset of the detectors with a shared colour scheme, in a fixed order.
for shown_detector in ('SK', 'JUNO', 'IWCD', 'TAO'):
    visualize_3D_event_for_detector(shown_detector, colorscale=cmap, surface_color='dimgray')

In [None]:
import matplotlib.pyplot as plt

name = 'SK'
json_filename = f'../config/{name}_geom_config.json'

# Build the geometry once; calling generate_detector a second time was redundant work.
detector = generate_detector(json_filename)
detector_points = jnp.array(detector.all_points)
photosensor_radius = detector.S_radius
sphere_radius = detector.r

# 3D scatter of every photosensor position in the geometry.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

ax.scatter(detector.all_points[:, 0], detector.all_points[:, 1], detector.all_points[:, 2], s=0.05)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z');