In [None]:
from pathlib import Path
import numpy as np

import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
from rtal.data.detector import Detector
from rtal.data.particle import Particles

## Load demo data

In [None]:
dataset_folder = Path('demo_dataset')
fnames = sorted(list(dataset_folder.glob('*npz')))

In [None]:
# Pick one sample and pull its raw arrays out of the .npz archive. The arrays
# are materialised inside the `with`, so the archive can be closed afterwards.
sample_idx = 1
fname = fnames[sample_idx]

with np.load(fname) as handle:
    (detector_start,  detector_curr,
     readout_start,   readout_curr,
     particle_vertex, particle_direction) = (handle[key] for key in ('detector_start',
                                                                     'detector_curr',
                                                                     'readout_start',
                                                                     'readout_curr',
                                                                     'particle_vertex',
                                                                     'particle_direction'))

# Misaligned detector parameters are printed with an explicit sign and 8 decimals
# so small offsets relative to the starting parameters are easy to spot.
signed_fmt = dict(sign="+", formatter={"float_kind": lambda value: f"{value:.8f}"})
print(f'detector start:\n{detector_start}\n')
print(f'detector misaligned:\n{np.array2string(detector_curr, **signed_fmt)}\n')
print(f'readout shape: {readout_start.shape}\n')

# Only a five-element slice of each larger array is previewed.
for heading, preview in (('readout start: ',      readout_start[:, :5]),
                         ('readout misaligned: ', readout_curr[:, :5]),
                         ('particle vertex:',     particle_vertex[:5]),
                         ('particle direction:',  particle_direction[:5])):
    print(f'{heading}\n{np.array2string(preview, precision=4)}\n')

## Visualize

In [None]:
def get_corners(detector, state):
    """Return the four corners of a detector plane in global coordinates.

    Parameters
    ----------
    detector : Detector
        Must provide ``center``, ``local_x`` and ``local_y`` (each indexable by
        ``state``) and the half-extents ``_RANGE_X`` / ``_RANGE_Y`` along the
        local axes.
    state : str
        Which geometry to use, e.g. ``'start'`` or ``'curr'``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(4, k)``, where ``k`` is the length of ``local_x``.
        The corners are ordered (+x,+y), (-x,+y), (-x,-y), (+x,-y) in local
        coordinates, i.e. counter-clockwise around the plane.
    """
    # Unit-square corners in local (x, y) coordinates. The float dtype matters:
    # scaling an integer array in place by non-integer half-extents raises a
    # numpy casting error.
    corner_basis = np.array([[ 1,  1],
                             [-1,  1],
                             [-1, -1],
                             [ 1, -1]], dtype=float)
    corner_basis = corner_basis * np.array([[detector._RANGE_X, detector._RANGE_Y]])

    # Map local offsets onto the global axes and shift by the plane center.
    local_axes = np.array([detector.local_x[state], detector.local_y[state]])
    return detector.center[state] + corner_basis @ local_axes

def get_trace(detector, state, color='lightblue', opacity=.7):
    """Build the Plotly traces that draw one detector plane in 3-D.

    Parameters
    ----------
    detector : Detector
        Provides ``center``, ``local_x`` and ``local_y`` (indexed by ``state``)
        and the half-extents ``_RANGE_X`` / ``_RANGE_Y``.
    state : str
        Geometry to draw, e.g. ``'start'`` or ``'curr'``.
    color : str
        Fill colour of the plane.
    opacity : float
        Opacity of the plane.

    Returns
    -------
    list
        ``[surface, local_x_axis, local_y_axis]``: a translucent ``Mesh3d``
        rectangle plus two gray ``Scatter3d`` lines along the local axes.
    """
    # The parameter was previously misspelled (``detecor``), so the body read
    # whatever notebook-global ``detector`` happened to exist instead of the
    # argument. It now always uses the detector that was passed in.
    corners = get_corners(detector, state)

    x, y, z = corners.T
    surface_trace = go.Mesh3d(x=x, y=y, z=z,
                              color=color,
                              opacity=opacity,
                              # two triangles (0,1,2) and (0,2,3) tile the quad
                              i=[0, 0],
                              j=[1, 2],
                              k=[2, 3])

    # Local axes are drawn at 90% of the plane's half-extent so they end inside it.
    axes_scale = .9
    center = detector.center[state]
    axis_traces = []
    for half_extent, direction in ((detector._RANGE_X, detector.local_x[state]),
                                   (detector._RANGE_Y, detector.local_y[state])):
        displacement = axes_scale * half_extent * direction
        x, y, z = np.vstack([center - displacement,
                             center + displacement]).T
        axis_traces.append(go.Scatter3d(x=x, y=y, z=z,
                                        mode='lines',
                                        showlegend=False,
                                        line=dict(color='gray', width=2)))

    return [surface_trace, *axis_traces]

In [None]:
# Shared 3-D scene layout: fixed axis ranges keep the start/curr figures
# directly comparable, and the y axis is stretched to follow the detector stack.
SCENE = {
    'aspectmode':  'manual',
    'aspectratio': {'x': 1, 'y': 1.85, 'z': 1},
    'xaxis':       {'title': 'X', 'range': [-10, 10]},
    'yaxis':       {'title': 'Y', 'range': [-2, 35]},
    'zaxis':       {'title': 'Z', 'range': [-10, 10]},
}

In [None]:
def get_ray_traces(detectors, readout, state, num_rays=None, color='blue',
                   random_seed=None, name=None, vertices=None):
    """Build one 3-D polyline per particle ray: vertex -> hit on each detector.

    Parameters
    ----------
    detectors : sequence of Detector
        ``detectors[d].to_global(readout[d], state)`` maps detector ``d``'s
        readout to global coordinates.
    readout : numpy.ndarray
        Indexed first by detector, then by ray (axis 1 is the ray axis that is
        subsampled below).
    state : str
        Detector geometry used to reconstruct the hits, e.g. ``'start'`` or ``'curr'``.
    num_rays : int, optional
        Number of rays to draw, chosen at random without replacement.
        ``None`` draws every ray.
    color : str
        Line colour for all rays.
    random_seed : int, optional
        Seed for the ray selection; the same seed selects the same ray indices
        as ``np.random.seed(random_seed); np.random.choice(...)`` would, but the
        global numpy RNG is left untouched.
    name : str, optional
        Legend label. It is shown once, and every ray shares its legend group so
        that clicking the entry toggles the whole set.
    vertices : numpy.ndarray, optional
        Particle vertices indexed by ray along the first axis. Defaults to the
        notebook-level ``particle_vertex`` for backward compatibility.

    Returns
    -------
    list of plotly.graph_objects.Scatter3d

    Raises
    ------
    ValueError
        If ``num_rays`` exceeds the number of available rays.
    """
    if vertices is None:
        # Backward-compatible fallback to the array loaded earlier in the notebook.
        vertices = particle_vertex
    vertex = np.asarray(vertices)[np.newaxis, :]

    if num_rays is not None:
        ray_count = readout.shape[1]
        if num_rays > ray_count:
            raise ValueError(f'num_rays={num_rays} exceeds the {ray_count} available rays')
        # Private legacy RandomState: same draws as seeding the global RNG,
        # without the side effect on other cells.
        rng = np.random.RandomState(random_seed)
        selected_ray_indices = rng.choice(ray_count, num_rays, replace=False)
        readout = readout[:, selected_ray_indices]
        vertex = vertex[:, selected_ray_indices]

    global_points = np.stack([detector.to_global(det_readout, state)
                              for detector, det_readout in zip(detectors, readout)])

    # Stack the vertex in front of the detector hits along axis 0, then transpose
    # so each of Xs / Ys / Zs holds one row per ray (the three-way unpack
    # requires a last axis of length 3).
    Xs, Ys, Zs = np.concatenate([vertex, global_points], axis=0).T

    ray_traces = []
    for ray_id, (x, y, z) in enumerate(zip(Xs, Ys, Zs)):
        show_in_legend = (ray_id == 0) and (name is not None)
        ray_traces.append(go.Scatter3d(x=x, y=y, z=z,
                                       mode='lines',
                                       line=dict(color=color, width=2),
                                       legendgroup=name,
                                       showlegend=show_in_legend,
                                       name=name if show_in_legend else None))
    return ray_traces

## Construct detectors

In [None]:
# One Detector per parameter row, pairing the starting (correct) parameters
# with the current (misaligned) ones.
detectors = [Detector.from_dict({'start': param_start, 'curr': param_curr})
             for param_start, param_curr in zip(detector_start, detector_curr)]

In [None]:
# Reconstruct a random subset of rays twice: once with the starting detector
# geometry and once with the current (misaligned) geometry. Both use the same
# seed, so (assuming both readouts hold the same rays) the same rays are drawn.
num_rays = 10
ray_seed = 2025

ray_settings = {
    # state: (readout, colour, legend label)
    'start': (readout_start, 'blue', 'Rays recovered from starting state'),
    'curr':  (readout_curr,  'red',  'Rays recovered from current (misaligned) state'),
}

ray_traces_by_state = {}
for state, (readout, color, label) in ray_settings.items():
    ray_traces_by_state[state] = get_ray_traces(detectors,
                                                readout,
                                                state=state,
                                                num_rays=num_rays,
                                                color=color,
                                                random_seed=ray_seed,
                                                name=f'Sample {sample_idx}, {label}')

ray_traces_start = ray_traces_by_state['start']
ray_traces_curr  = ray_traces_by_state['curr']

# Horizontal legend placed just below the plot area.
legend = dict(x = 0.0,
              y = -0.07,
              orientation = "h",
              xanchor     = "left",
              yanchor     = "top")

In [None]:
# Detector planes in both states (light blue = start, pink = misaligned),
# overlaid with the rays reconstructed from the starting geometry.
Path('plots').mkdir(exist_ok=True)  # write_image / write_html do not create the folder

fig = go.Figure()

for detector in detectors:
    fig.add_traces(get_trace(detector, 'start'))
    fig.add_traces(get_trace(detector, 'curr', color='pink'))

fig.add_traces(ray_traces_start)

fig.update_layout(width  = 600,
                  height = 500,
                  scene  = SCENE,
                  legend = legend)

fig.show()
fig.write_image(f'plots/plot_3d_{sample_idx}_start.png')
fig.write_html(f'plots/plot_3d_{sample_idx}_start.html')

In [None]:
# Same detector planes, now overlaid with the rays reconstructed from the
# current (misaligned) geometry.
Path('plots').mkdir(exist_ok=True)  # write_image / write_html do not create the folder

fig = go.Figure()

for detector in detectors:
    fig.add_traces(get_trace(detector, 'start'))
    fig.add_traces(get_trace(detector, 'curr', color='pink'))

fig.add_traces(ray_traces_curr)

fig.update_layout(width  = 600,
                  height = 500,
                  scene  = SCENE,
                  legend = legend)

fig.show()
fig.write_image(f'plots/plot_3d_{sample_idx}_curr.png')
fig.write_html(f'plots/plot_3d_{sample_idx}_curr.html')

In [None]:
def load(fname):
    """Load one sample archive and build the objects used for plotting.

    Parameters
    ----------
    fname : str or path-like
        ``.npz`` archive containing ``detector_start``, ``detector_curr``,
        ``readout_start``, ``readout_curr``, ``particle_vertex`` and
        ``particle_direction``.

    Returns
    -------
    detectors : list of Detector
        One per (start, curr) pair of detector parameter rows.
    particles : Particles
        Built from the stored vertices and directions.
    readout_start, readout_curr : numpy.ndarray
        The stored readouts, returned as-is.
    """
    with np.load(fname) as handle:
        detector_start     = handle['detector_start']
        detector_curr      = handle['detector_curr']
        readout_start      = handle['readout_start']
        readout_curr       = handle['readout_curr']
        particle_vertex    = handle['particle_vertex']
        particle_direction = handle['particle_direction']

    detectors = [Detector.from_dict({'start': start, 'curr': curr})
                 for start, curr in zip(detector_start, detector_curr)]

    particles = Particles(vertex    = particle_vertex,
                          direction = particle_direction)

    return detectors, particles, readout_start, readout_curr

In [None]:
def _readout_grid_xy(readout_range):
    """Return (xs, ys) coordinate lists for a detector's interior pixel-grid lines.

    One horizontal and one vertical line is drawn at every integer ``i`` with
    ``|i| < int(readout_range)``, each spanning ``[-readout_range, readout_range]``.
    Segments are separated by ``None`` so Plotly renders them all in a single
    trace without joining consecutive segments.
    """
    xs, ys = [], []
    for i in range(-int(readout_range) + 1, int(readout_range)):
        # horizontal line at y = i
        xs += [-readout_range, readout_range, None]
        ys += [i, i, None]
        # vertical line at x = i
        xs += [i, i, None]
        ys += [-readout_range, readout_range, None]
    return xs, ys


def plot_readout(detectors,
                 particles         = None,
                 readout_start     = None,
                 readout_curr      = None,
                 calculate_readout = False,
                 rounded           = False,
                 title             = None,
                 return_fig        = False):
    """Plot starting vs. misaligned readouts side by side, one subplot per detector.

    Each subplot shows the detector outline (black), the interior pixel grid
    (light gray), the starting-state readout (blue, "Correct readout") and the
    current-state readout (red, "misaligned readout"). The figure is always shown.

    Parameters
    ----------
    detectors : sequence of Detector
    particles : Particles, optional
        Required when ``calculate_readout`` is True; passed to ``Detector.get_readout``.
    readout_start, readout_curr : array-like, optional
        Saved readouts. ``readout_*[d].T`` must unpack into x and y coordinates
        for detector ``d``. Required when ``calculate_readout`` is False.
    calculate_readout : bool
        If True, recompute the readouts from ``particles`` instead of using the
        saved arrays.
    rounded : bool
        Forwarded to ``Detector.get_readout`` (only used when recomputing).
    title : str, optional
        Figure title; defaults to "Readout from the detectors".
    return_fig : bool
        If True, return the figure.

    Returns
    -------
    plotly.graph_objects.Figure or None

    Raises
    ------
    ValueError
        If the inputs needed by the selected readout source are missing.
    """
    if calculate_readout and particles is None:
        raise ValueError('particles must be given when calculate_readout=True')
    if not calculate_readout and (readout_start is None or readout_curr is None):
        raise ValueError('readout_start and readout_curr must be given '
                         'when calculate_readout=False')

    subplot_titles = [f"detector {detector_idx}" for detector_idx in range(len(detectors))]

    fig = make_subplots(rows               = 1,
                        cols               = len(detectors),
                        shared_yaxes       = True,
                        subplot_titles     = subplot_titles,
                        horizontal_spacing = 0.03)

    for detector_idx, detector in enumerate(detectors):
        cell = dict(row=1, col=detector_idx + 1)
        # Legend entries are emitted only from the first subplot.
        first = detector_idx == 0

        # Half-width of the readout plane in pixel units.
        # NOTE(review): the y extent also uses _RANGE_X / _PITCH_X, i.e. a
        # square readout is assumed — confirm for non-square detectors.
        readout_range = detector._RANGE_X / detector._PITCH_X

        # Detector outline as a closed black square.
        fig.add_trace(go.Scatter(x          = readout_range * np.array([-1, 1, 1, -1, -1]),
                                 y          = readout_range * np.array([-1, -1, 1, 1, -1]),
                                 mode       = 'lines',
                                 line       = dict(color='black', width=2),
                                 name       = "Detector",
                                 showlegend = first),
                      **cell)

        # Pixel grid: all horizontal and vertical lines in one gap-separated
        # trace instead of two traces per line, which keeps the figure light.
        grid_x, grid_y = _readout_grid_xy(readout_range)
        if grid_x:
            fig.add_trace(go.Scatter(x          = grid_x,
                                     y          = grid_y,
                                     mode       = 'lines',
                                     line       = dict(color='lightgray', width=.5),
                                     showlegend = False),
                          **cell)

        # Readout coordinates: recomputed from the particles, or taken from the
        # saved arrays. get_readout's first returned element holds the hits.
        if calculate_readout:
            scatter_x_start, scatter_y_start = detector.get_readout(particles, 'start', rounded=rounded)[0].T
            scatter_x_curr,  scatter_y_curr  = detector.get_readout(particles, 'curr',  rounded=rounded)[0].T
        else:
            scatter_x_start, scatter_y_start = readout_start[detector_idx].T
            scatter_x_curr,  scatter_y_curr  = readout_curr[detector_idx].T

        # Readout under the starting (correct) geometry.
        fig.add_trace(go.Scatter(x          = scatter_x_start,
                                 y          = scatter_y_start,
                                 mode       = 'markers',
                                 marker     = dict(size=4, color='blue'),
                                 name       = "Correct readout",
                                 showlegend = first),
                      **cell)

        # Readout under the current (misaligned) geometry.
        fig.add_trace(go.Scatter(x          = scatter_x_curr,
                                 y          = scatter_y_curr,
                                 mode       = 'markers',
                                 marker     = dict(size=4, color='red'),
                                 name       = "misaligned readout",
                                 showlegend = first),
                      **cell)

        # 20% margin around the detector; equal x/y scaling keeps pixels square.
        axis_range = np.array([-1.2, 1.2]) * readout_range
        fig.update_xaxes(range=axis_range, scaleanchor="y", **cell)
        fig.update_yaxes(range=axis_range, scaleanchor="x", **cell)

    if title is None:
        title = "Readout from the detectors"
    fig.update_layout(title      = title,
                      showlegend = True,
                      height     = 500,
                      width      = 1100,
                      legend     = dict(x = 0.0,
                                        y = -0.07,
                                        orientation = "h",
                                        xanchor     = "left",
                                        yanchor     = "top"))

    fig.show()

    if return_fig:
        return fig

In [None]:
# Compare the saved readout with the readout recomputed (unrounded) from the
# particles; both figures are written to plots/.
Path('plots').mkdir(exist_ok=True)  # write_image / write_html do not create the folder

detectors, particles, readout_start, readout_curr = load(fnames[sample_idx])

fig = plot_readout(detectors,
                   particles,
                   readout_start,
                   readout_curr,
                   calculate_readout=False,
                   title=f'Detector readout sample {sample_idx}, from saved readout',
                   return_fig=True)

fig.write_image(f'plots/plot_2d_{sample_idx}_from-readout.png')
fig.write_html(f'plots/plot_2d_{sample_idx}_from-readout.html')

fig = plot_readout(detectors,
                   particles,
                   calculate_readout=True,
                   rounded=False,
                   title=f'Detector readout sample {sample_idx}, calculated no rounding',
                   return_fig=True)
fig.write_image(f'plots/plot_2d_{sample_idx}_calculated-no-rounding.png')
fig.write_html(f'plots/plot_2d_{sample_idx}_calculated-no-rounding.html')