In [None]:
import sys
sys.path.append('..')

from spyral.core.cluster import Cluster
from spyral.interpolate.track_interpolator import create_interpolator
from spyral.solvers.solver_interp import fit_model_interp, Guess, interpolate_trajectory
from spyral.core.run_stacks import form_run_string

from spyral_utils.nuclear import NuclearDataMap
from spyral_utils.nuclear.particle_id import deserialize_particle_id

from e20009_phases.InterpSolverPhase import InterpSolverPhase
from e20009_phases.config import SolverParameters, DetectorParameters

import polars as pl
import numpy as np
import h5py as h5
import numpy as np
from pathlib import Path
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
# Load config
# `workspace_path` holds the pipeline outputs (Cluster/, Estimation/, ...);
# every e20009 parameter file referenced below lives in `parameters_dir`.
workspace_path = Path("/Volumes/e20009/e20009_analysis")
parameters_dir = Path("/Users/attpc/Desktop/e20009_analysis/e20009_analysis/e20009_parameters")

# Solver settings. The three file paths are handed over as plain strings
# (str of the joined Path), exactly as literal strings were before.
solver_params = SolverParameters(
    gas_data_path=str(parameters_dir / "e20009_target.json"),
    gain_match_factors_path=str(parameters_dir / "gain_match_factors.csv"),
    particle_id_filename=str(parameters_dir / "pid.json"),
    ic_min_val=450.0,
    ic_max_val=850.0,
    n_time_steps=1000,
    interp_ke_min=0.1,  # TODO: consider lowering this to 0.05
    interp_ke_max=40.0,
    interp_ke_bins=200,
    interp_polar_min=0.1,
    interp_polar_max=179.9,
    interp_polar_bins=340,
)

# Detector settings. These two file paths are passed as Path objects.
det_params = DetectorParameters(
    magnetic_field=3.0,
    electric_field=60000.0,
    detector_length=1000.0,
    beam_region_radius=20.0,
    drift_velocity_path=parameters_dir / "drift_velocity.csv",
    get_frequency=3.125,
    garfield_file_path=parameters_dir / "e20009_efield_correction.txt",
    do_garfield_correction=False,
)

# Per-phase output directories inside the workspace.
cluster_path, estimate_path = workspace_path / "Cluster", workspace_path / "Estimation"

In [None]:
# Make interpolation mesh
# The particle ID gate is loaded first: its nucleus is used by the fit below,
# and its cut is used to gate the estimates in the load-data cell.
nuc_map = NuclearDataMap()
pid = deserialize_particle_id(solver_params.particle_id_filename, nuc_map)
if pid is None:
    raise Exception("Particle ID error!")

# Build (or reuse) the solver's assets in the workspace, then load the track
# interpolator from the file the solver points at.
solver = InterpSolverPhase(solver_params, det_params)
if not solver.create_assets(workspace_path):
    raise Exception("Could not setup interpolation mesh!")
tracks = create_interpolator(solver.track_path)

In [None]:
# Load data
# Pick one PID-gated (event, cluster) row from this run's estimate output and
# load that cluster's point cloud from the clustering output.
run_number = 348
# Set to a fixed row index (e.g. 6990) to revisit a specific event; None picks one at random.
FIXED_ROW = None

run_string = form_run_string(run_number)
cluster_file_path = cluster_path / f"{run_string}.h5"
estimate_file_path = estimate_path / f"{run_string}.parquet"

# Keep only estimates whose (dEdx, brho) falls inside the particle ID cut.
estimate_df = pl.scan_parquet(estimate_file_path)
estimate_gated = (
    estimate_df
    .filter(pl.struct(['dEdx', 'brho']).map_batches(pid.cut.is_cols_inside))
    .collect()
    .to_dict()
)
nrows = len(estimate_gated['event'])
if nrows == 0:
    # np.random.randint(0, 0) would otherwise fail with an unhelpful ValueError.
    raise Exception(f"No estimates in run {run_number} pass the particle ID gate!")

row = np.random.randint(0, nrows) if FIXED_ROW is None else FIXED_ROW
print(f'row: {row}')
event = estimate_gated['event'][row]
cluster_index = estimate_gated['cluster_index'][row]
print(f'event: {event}')
print(f'cluster index: {cluster_index}')
print(f'Direction: {estimate_gated["direction"][row]}')

# Hold the HDF5 file open only while reading. The label is read here and the
# cloud is copied into memory, so nothing below needs the handle. This keeps
# repeated runs of the cell from leaking open files.
with h5.File(cluster_file_path, "r") as cluster_file:
    local_cluster = cluster_file['cluster'][f'event_{event}'][f'cluster_{cluster_index}']
    cluster = Cluster(event, local_cluster.attrs['label'], local_cluster['cloud'][:].copy())

In [None]:
# Fit data
# Seed the fit with this row's estimated angles, Brho and vertex.
guess = Guess(
    estimate_gated['polar'][row],
    estimate_gated['azimuthal'][row],
    estimate_gated['brho'][row],
    estimate_gated['vertex_x'][row],
    estimate_gated['vertex_y'][row],
    estimate_gated['vertex_z'][row],
)
print(guess)
result = fit_model_interp(cluster, guess, pid.nucleus, tracks, det_params)
if result is None:
    # Stop here. interpolate_trajectory and the plot cell both need a fit
    # result; continuing with None would only fail later with an unrelated error.
    raise Exception('Guess outside of interpolation range!')
best_fit_trajectory = interpolate_trajectory(result, tracks, pid.nucleus)
# Scale the cloud's first three columns by 0.001 for plotting. This is presumably
# mm -> m, since the plot axes are labelled in m (TODO confirm).
# NOTE: this mutates cluster.data in place. Re-running this cell without first
# re-running the load-data cell fits and plots already-scaled data.
cluster.data[:, :3] *= 0.001

In [None]:
# Plot fit
def _plot_fit(data, trajectory, fit_result):
    """Build a three-panel figure comparing the cluster, the best fit and the fitted vertex.

    Panels are XY (left, spanning both rows), XZ (top right) and YZ (bottom right).
    Each panel overlays data points (blue), best-fit trajectory points (red) and the
    fitted vertex (green). Only the XY panel contributes legend entries.

    data, trajectory: arrays whose columns 0, 1, 2 are plotted as x, y, z.
    fit_result: mapping providing "vertex_x", "vertex_y" and "vertex_z".
    """
    figure = make_subplots(
        2,
        2,
        subplot_titles=["XY Projection", "XZ Projection", "YZ Projection"],
        specs=[[{"rowspan": 2}, {}], [None, {}]],
    )
    vertex_keys = ["vertex_x", "vertex_y", "vertex_z"]
    # (subplot row, subplot col, column on horizontal axis, column on vertical axis)
    panels = [(1, 1, 0, 1), (1, 2, 2, 0), (2, 2, 2, 1)]

    for panel_index, (sub_row, sub_col, h_col, v_col) in enumerate(panels):
        # None leaves showlegend unset, so the first panel keeps the default legend.
        legend_flag = None if panel_index == 0 else False
        for trace_name, points, colour in (("Data", data, "blue"), ("Fit", trajectory, "red")):
            figure.add_trace(
                go.Scatter(
                    x=points[:, h_col],
                    y=points[:, v_col],
                    mode="markers",
                    marker={"size": 3, "color": colour},
                    name=trace_name,
                    showlegend=legend_flag,
                ),
                row=sub_row,
                col=sub_col,
            )
        figure.add_trace(
            go.Scatter(
                x=[fit_result[vertex_keys[h_col]]],
                y=[fit_result[vertex_keys[v_col]]],
                mode="markers",
                marker={"color": "green", "size": 4},
                name="Fit Vertex",
                showlegend=legend_flag,
            ),
            row=sub_row,
            col=sub_col,
        )

    # (x title, x range, y title, y range) for xaxis1/yaxis1 .. xaxis3/yaxis3.
    axis_settings = [
        ("X (m)", [-0.3, 0.3], "Y (m)", [-0.3, 0.3]),
        ("Z (m)", [0.0, 1.0], "X (m)", [-0.3, 0.3]),
        ("Z (m)", [0.0, 1.0], "Y (m)", [-0.3, 0.3]),
    ]
    for axis_number, (x_title, x_range, y_title, y_range) in enumerate(axis_settings, start=1):
        figure["layout"][f"xaxis{axis_number}"]["title"] = x_title
        figure["layout"][f"xaxis{axis_number}"]["range"] = x_range
        figure["layout"][f"yaxis{axis_number}"]["title"] = y_title
        figure["layout"][f"yaxis{axis_number}"]["range"] = y_range
    return figure


fig = _plot_fit(cluster.data, best_fit_trajectory, result)
fig.update_layout(
    width=1450,
    height=725
)