# Examining Physics Using the Interpolation Method

This notebook demonstrates the final phase of Spyral: solving. To run this notebook, the EstimatePhase of Spyral *must* already have been run on the data. Now that we have generated our clusters and estimated the physical observables of interest, we are ready to start the solving phase of the analysis, where we attempt to extract the exact physics observables by fitting solutions of the equations of motion to the data. This works by pre-generating a set of solutions to the ODEs and then interpolating on these solutions to fit the data. It has the advantage of being very fast: the ODEs only ever need to be solved once, and all the remaining calculation is simple bilinear interpolation. For more information on the method, see the Spyral [documentation](https://attpc.github.io/Spyral).
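
To make the interpolation step concrete, here is a minimal sketch of bilinear interpolation on a regular two-dimensional grid. It is an illustration only and not Spyral's implementation: the function name `bilinear`, the grid axes, and the toy values are all hypothetical. Returning `None` for a point outside the grid is likewise a choice made for this sketch.

```python
from bisect import bisect_right


def bilinear(xs, ys, values, x, y):
    """Bilinearly interpolate values[i][j], sampled at (xs[i], ys[j]), at (x, y).

    xs and ys must be sorted ascending. Returns None if (x, y) is off the grid.
    """
    if not (xs[0] <= x <= xs[-1] and ys[0] <= y <= ys[-1]):
        return None
    # Index of the lower corner of the grid cell that contains (x, y)
    i = min(bisect_right(xs, x) - 1, len(xs) - 2)
    j = min(bisect_right(ys, y) - 1, len(ys) - 2)
    # Fractional position inside the cell, each in [0, 1]
    tx = (x - xs[i]) / (xs[i + 1] - xs[i])
    ty = (y - ys[j]) / (ys[j + 1] - ys[j])
    # Interpolate along x on both cell edges, then along y between them
    lower = (1.0 - tx) * values[i][j] + tx * values[i + 1][j]
    upper = (1.0 - tx) * values[i][j + 1] + tx * values[i + 1][j + 1]
    return (1.0 - ty) * lower + ty * upper


# Toy grid sampling f(x, y) = 2x + 3y, which bilinear interpolation reproduces
xs = [0.0, 1.0, 2.0]
ys = [0.0, 10.0]
values = [[2.0 * x + 3.0 * y for y in ys] for x in xs]
inside = bilinear(xs, ys, values, 0.5, 4.0)   # point inside the grid
outside = bilinear(xs, ys, values, 5.0, 4.0)  # point outside the grid
```

Each lookup costs only a handful of multiplications, which is why the fit is fast once the grid of ODE solutions exists.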

## Setup
First let's take care of all of our imports.

In [None]:
from spyral.core.cluster import Cluster
from spyral.interpolate.track_interpolator import create_interpolator

# Uncomment exactly one of the two solver import lines below
# By default we use the L-BFGS-B solver
from spyral.solvers.solver_interp import fit_model_interp, Guess, interpolate_trajectory
# from spyral.solvers.solver_interp_leastsq import fit_model_interp, Guess, interpolate_trajectory
from spyral.phases.interp_solver_phase import DEFAULT_PID_XAXIS, DEFAULT_PID_YAXIS

from spyral.core.run_stacks import form_run_string
from spyral import SolverParameters, DetectorParameters, InterpSolverPhase

from spyral_utils.nuclear import NuclearDataMap
from spyral_utils.nuclear.particle_id import deserialize_particle_id

import polars as pl
import h5py as h5
from pathlib import Path
import matplotlib.pyplot as plt

%matplotlib widget

Now with all of our code imported, we will set up the configuration.

In [None]:
# Set some parameters
workspace_path = Path("/path/to/your/workspace/")

solver_params = SolverParameters(
    gas_data_path=Path("/path/to/your/gas.json"),
    particle_id_filename=Path("/path/to/your/pid.json"),
    ic_min_val=900.0,
    ic_max_val=1300.0,
    n_time_steps=1000,
    interp_ke_min=0.1,
    interp_ke_max=70.0,
    interp_ke_bins=700,
    interp_polar_min=5.0,
    interp_polar_max=85.0,
    interp_polar_bins=160,
    fit_vertex_rho=True,
    fit_vertex_phi=True,
    fit_azimuthal=True,
    fit_method="lbfgsb" # has no impact here, we choose the solver ourselves at the imports
)

det_params = DetectorParameters(
    magnetic_field=2.85,
    electric_field=45000.0,
    detector_length=1000.0,
    beam_region_radius=25.0,
    micromegas_time_bucket=10.0,
    window_time_bucket=560.0,
    get_frequency=6.25,
    garfield_file_path=Path("/path/to/some/garfield.txt"),
    do_garfield_correction=False,
)

cluster_path = workspace_path / "Cluster" # this may change if you add custom phases
estimate_path = workspace_path / "Estimation" # this may change if you add custom phases

Now we need to load our interpolation mesh. Note that if you don't have one created, the below cell will create it for you and this can take some time.

In [None]:
nuc_map = NuclearDataMap()
pid = deserialize_particle_id(solver_params.particle_id_filename, nuc_map)
if pid is None:
    raise Exception("Particle ID error!")
pid_xaxis = DEFAULT_PID_XAXIS
pid_yaxis = DEFAULT_PID_YAXIS
if not pid.cut.is_default_x_axis() and not pid.cut.is_default_y_axis():
    pid_xaxis = pid.cut.get_x_axis()
    pid_yaxis = pid.cut.get_y_axis()
solver = InterpSolverPhase(solver_params, det_params)
success = solver.create_assets(workspace_path)
if not success:
    raise Exception("Could not setup interpolation mesh!")
tracks = create_interpolator(solver.track_path)

Now we'll load a specific run file (both the estimation result and clustering result) and make an iterator to run through the events in the data. Here we also apply the particle ID gate to the dataset.

In [None]:
run_number = 16
cluster_file_path = cluster_path / f"{form_run_string(run_number)}.h5"
cluster_file = h5.File(cluster_file_path, "r")
estimate_file_path = estimate_path / f"{form_run_string(run_number)}.parquet"
estimate_df = pl.scan_parquet(estimate_file_path)
estimate_gated = estimate_df.filter(pl.struct([pid_xaxis, pid_yaxis]).map_batches(pid.cut.is_cols_inside)).collect().to_dict()
cluster_group = cluster_file["cluster"]
nrows = len(estimate_gated["event"])
row_iter = iter(range(nrows))
print(f"Number of rows: {nrows}")

## Analysis

Re-running the cells below will walk through the events in the dataset in order, as long as you don't re-run the cells above.
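
The reason for that caveat is that the row iterator keeps its position between cell executions, and re-running the setup cell replaces it with a fresh one. A minimal sketch of this behavior, using a hypothetical three-row dataset and a hypothetical helper `load_next`:

```python
def load_next(row_iter):
    """Return the next row index, or None once the iterator is exhausted."""
    return next(row_iter, None)


rows = iter(range(3))          # created once, like row_iter in the setup cell
first = load_next(rows)        # 0
second = load_next(rows)       # 1

rows = iter(range(3))          # re-running the setup cell builds a new iterator
after_reset = load_next(rows)  # 0 again: the walk restarts from the first row
```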

Now we'll load the next event from the dataset. You can always use a hardcoded value to select a specific event.

In [None]:
plt.close()
row = None
# Can always override with a hardcoded row if needed
# row = 1
if row is None:
    try:
        row = next(row_iter)
    except StopIteration:
        raise Exception("You ran out of rows for this run! Open a new run")
    

print(f"Row: {row}")
event = estimate_gated['event'][row]
cluster_index = estimate_gated['cluster_index'][row]
print(f"Event: {event}")
print(f"Cluster index: {cluster_index}")
event_group = cluster_group[f"event_{event}"]
local_cluster = event_group[f"cluster_{cluster_index}"]
print(f'Direction: {estimate_gated["direction"][row]}')
cluster = Cluster(event, local_cluster.attrs["label"], local_cluster["cloud"][:].copy())

With our cluster and estimated observables loaded, we are ready to fit the data. We set up our Guess object from our estimates and then pass it along to the fit_model_interp function. Sometimes this will return None when a given trajectory has estimates that are outside the interpolation table (these typically correspond to bad events). If this happens, an error will occur. Simply re-run the previous cell to load the next event, and repeat until you reach a good one. Note that the first time you run this block it might take a couple of seconds. This is because the interpolation method uses a just-in-time (JIT) compiler to speed up the calculations. The first time you call the code, it gets compiled (resulting in a slowdown). Every call after that uses the compiled program, resulting in enormous performance gains.

In [None]:
guess = Guess(
    estimate_gated["polar"][row],
    estimate_gated["azimuthal"][row],
    estimate_gated["brho"][row],
    estimate_gated["vertex_x"][row],
    estimate_gated["vertex_y"][row],
    estimate_gated["vertex_z"][row]
)
print(guess)
result = fit_model_interp(cluster, guess, pid.nucleus, tracks, det_params, solver_params)
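if result is None:
    # Stop here with a clear message instead of failing later on a None result
    raise Exception("Guess outside of interpolation range! Re-run the previous cell to load the next event.")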
best_fit_trajectory = interpolate_trajectory(result, tracks, pid.nucleus)

If a good event was chosen, you should see a printout of the fit results above. Key values are the chi-square (which should be small) and the variable values, which are the fitted observables. Also important are the correlations, which tell you if any of the parameters are co-dependent. If two parameters have a correlation of 1.0, they are effectively degenerate to the fitter, which is very bad.
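
As a rough illustration of what the correlations mean, the sketch below converts a covariance matrix into a correlation matrix and flags strongly correlated parameter pairs. Everything in it is hypothetical: the covariance values are made up rather than taken from a real fit, the helper names are invented for this example, and the 0.95 threshold is an arbitrary choice, not a Spyral convention.

```python
import math


def correlation_matrix(cov):
    """Convert a covariance matrix (list of lists) into a correlation matrix."""
    n = len(cov)
    sigma = [math.sqrt(cov[i][i]) for i in range(n)]
    return [[cov[i][j] / (sigma[i] * sigma[j]) for j in range(n)] for i in range(n)]


def degenerate_pairs(names, corr, threshold=0.95):
    """List parameter pairs whose |correlation| meets or exceeds the threshold."""
    return [
        (names[i], names[j])
        for i in range(len(names))
        for j in range(i + 1, len(names))
        if abs(corr[i][j]) >= threshold
    ]


# Made-up covariance for three fit parameters (illustrative values only)
names = ["brho", "polar", "vertex_z"]
cov = [
    [4.0, 1.98, 0.2],
    [1.98, 1.0, 0.1],
    [0.2, 0.1, 1.0],
]
corr = correlation_matrix(cov)
flagged = degenerate_pairs(names, corr)  # only the brho/polar pair is flagged
```

A pair flagged this way means the data cannot distinguish a change in one parameter from a compensating change in the other, so neither fitted value should be trusted on its own.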

We can also plot the results of the fit against the data to visualize the performance.

In [None]:
fig, axs = plt.subplot_mosaic(
    """
    AB
    AC
    """,
    constrained_layout=True,
    figsize=(16,8)
)

axs["A"].scatter(cluster.data[:, 0]*0.001, cluster.data[:, 1]*0.001, s=3, label="Cluster")
axs["A"].scatter(best_fit_trajectory[:, 0], best_fit_trajectory[:, 1], s=3, label="Best Fit")
axs["A"].scatter([result["vertex_x"]], [result["vertex_y"]], s=3, label="Best Fit Vertex")
axs["A"].set_xlabel("X (m)")
axs["A"].set_ylabel("Y (m)")
axs["A"].set_xlim(-0.3, 0.3)
axs["A"].set_ylim(-0.3, 0.3)
axs["B"].scatter(cluster.data[:, 2]*0.001, cluster.data[:, 1]*0.001, s=3, label="Cluster")
axs["B"].scatter(best_fit_trajectory[:, 2], best_fit_trajectory[:, 1], s=3, label="Best Fit")
axs["B"].scatter([result["vertex_z"]], [result["vertex_y"]], s=3, label="Best Fit Vertex")
axs["B"].set_xlabel("Z (m)")
axs["B"].set_ylabel("Y (m)")
axs["B"].set_xlim(.0, 1.0)
axs["B"].set_ylim(-0.3, 0.3)
axs["C"].scatter(cluster.data[:, 2]*0.001, cluster.data[:, 0]*0.001, s=3, label="Cluster")
axs["C"].scatter(best_fit_trajectory[:, 2], best_fit_trajectory[:, 0], s=3, label="Best Fit")
axs["C"].scatter([result["vertex_z"]], [result["vertex_x"]], s=3, label="Best Fit Vertex")
axs["C"].set_xlabel("Z (m)")
axs["C"].set_ylabel("X (m)")
axs["C"].set_xlim(.0, 1.0)
axs["C"].set_ylim(-0.3, 0.3)
axs["A"].legend()
axs["B"].legend()
axs["C"].legend()

Hopefully you see a nice fit to the data! If the fit looks bad, there are several things to check. First is the particle ID gate; if the wrong particle group is selected, the fit will fail spectacularly. Another is the coarseness of the interpolation scheme. If there are too few bins in the polar angle or the particle kinetic energy, the interpolation may not generate good values. Finally, it is also good to make sure that the target gas is correctly defined with the right pressure and chemistry.
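
One quick way to judge the coarseness is to compute the grid spacing implied by the solver parameters. The sketch below does this for the values used in this notebook. It assumes the range is divided into equal-width intervals, which may differ slightly from how Spyral lays out its mesh, and the helper `grid_step` is hypothetical.

```python
def grid_step(minimum, maximum, bins):
    """Spacing between neighboring grid values, assuming equal-width intervals."""
    return (maximum - minimum) / bins


# Values copied from the SolverParameters cell in this notebook
ke_step = grid_step(0.1, 70.0, 700)     # kinetic energy spacing
polar_step = grid_step(5.0, 85.0, 160)  # polar angle spacing

print(f"KE step: {ke_step:.4f}, polar step: {polar_step:.2f}")
```

If these steps are large compared with the resolution you expect from the fit, increasing `interp_ke_bins` or `interp_polar_bins` is a reasonable first thing to try.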

To walk through more events, you can re-run the cells under the Analysis heading, and you will walk through the events in the order they were written.