# Imports

In [None]:
import pickle
import warnings
import json

import awkward as ak
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import uproot
from itables import init_notebook_mode
from matplotlib.colors import LogNorm
from matplotlib.ticker import AutoMinorLocator, MaxNLocator
from scipy import stats
from skimage.measure import LineModelND, ransac
from sklearn.cluster import DBSCAN
from sklearn.linear_model import RANSACRegressor
from skspatial.objects import Cylinder, Line, Triangle, Point
from tqdm.auto import tqdm

# Render every pandas DataFrame output of the notebook as an interactive itables table
init_notebook_mode(all_interactive=True)

# Parameters

In [None]:
# File paths
file_label = "20230706_191437"  # label for files generated by this notebook
light_file = "../Data/rwf_0cd913fa_20230706_191437.data_deco_2.root"  # light-readout ROOT file, read by load_light
charge_file = "../Data/evd_self_trigger-packets-2023_07_06_19_14_CEST_validated.root"  # charge ROOT file, read by load_charge
sipm_map_file = "sipm_sn_ch_to_xy.json"  # JSON {sn: {ch: [x, y]}} map, read lazily by sipm_to_xy

# Plotting options
individual_plots = np.arange(1, 21, 1)  # NOTE(review): not used in this part of the notebook; presumably event IDs to plot one by one
show_figures = True

# Events to process (None = every event left in charge_df)
event_list = None
# event_list = np.array([30, 50])
# event_list = np.arange(0, 101)

# Noisy Pixels: charge channel IDs whose hits are masked out before fitting
channel_disable_list = [7]

# Light variable to consider: a per-SiPM column computed by load_light ("peak" or "integral")
light_variable = "integral"

# Units for plot labels
q_unit = "ke"  # After applying charge_gain
xy_unit = "mm"
z_unit = "mm"
# Track-length unit is only meaningful when z and xy share the same unit
dh_unit = "?" if z_unit != xy_unit else xy_unit
time_unit = "ns"
light_unit = "p.e."

# Conversion factors
charge_gain = 245  # mV to ke
detector_z = 300  # extent along z (z_unit); used as the z range of the fake-data preview
detector_x = 128  # width along x (xy_unit); the active area is centred on x = 0
detector_y = 160  # height along y (xy_unit); the active area is centred on y = 0

# DBSCAN parameters for clustering
min_samples = 2  # also used as RANSAC min_samples and as the minimum cluster size in fitClusters
xy_epsilon = 8  # 8 ideal; neighbourhood radius in the xy plane (xy_unit)
z_epsilon = 8  # 8 ideal; neighbourhood radius along z (z_unit), also the z step of fake hits

# RANSAC parameters for line fitting
residual_threshold = 6  # 6 ideal; maximum residual for a hit to count as an inlier
max_trials = 1000

# Force parameters for cylinder: when not None they override get_dh / get_dr
force_dh = None
force_dr = None

# Functions

## Parameter calculators

In [None]:
# SiPMs mapping
sipm_map = None  # {sn: {ch: [x, y]}} cache, loaded on the first sipm_to_xy call


def sipm_to_xy(sn, ch):
    """Return the (x, y) position of SiPM channel ``ch`` on board ``sn``.

    The map is read from ``sipm_map_file`` on first use and cached in the
    module-level ``sipm_map``. Keys are looked up as strings, so ``sn`` and
    ``ch`` may be given as ints or strings. The stored coordinates are shifted
    by (+64, -16); presumably this aligns the map origin with the charge
    readout frame (TODO confirm).

    Returns None when the (sn, ch) pair is not in the map.
    """
    global sipm_map
    if sipm_map is None:
        with open(sipm_map_file, "r") as map_file:
            sipm_map = json.load(map_file)

    board = sipm_map.get(str(sn), {})
    position = board.get(str(ch), None)
    if position is None:
        return None
    return (position[0] + 64, position[1] - 16)


# Check if SiPMs on anode area
def get_sipm_mask(sn, ch):
    """Return True if SiPM (sn, ch) is mapped and lies strictly inside the anode area.

    The anode area is the open rectangle |x| < detector_x / 2,
    |y| < detector_y / 2. Channels without a map entry (sipm_to_xy returns
    None) are rejected.
    """
    position = sipm_to_xy(sn, ch)
    if position is None:
        return False

    x, y = position
    half_x = detector_x / 2
    half_y = detector_y / 2
    return -half_x < x < half_x and -half_y < y < half_y

In [None]:
# Cylinder parameters for dQ/dx
def get_dh(unit_vector):
    """Return the cylinder length (slice width along the track) for dQ/dx.

    The length is |unit_vector . v|, where v = (xy_epsilon, xy_epsilon,
    z_epsilon) is the DBSCAN neighbourhood vector. It is clamped to
    [|v| / 4, 2 |v|]. A non-None ``force_dh`` is returned unchanged instead.
    """
    if force_dh is not None:
        return force_dh

    neighbourhood = np.array([xy_epsilon, xy_epsilon, z_epsilon])
    reach = np.linalg.norm(neighbourhood)
    projection = abs(np.dot(unit_vector, neighbourhood))
    return min(max(projection, reach / 4), 2 * reach)


def get_dr(standard_deviation):
    """Return the cylinder radius for dQ/dx.

    The radius is ``standard_deviation``, but never less than |v| / 4 with
    v = (xy_epsilon, xy_epsilon, z_epsilon). A non-None ``force_dr`` is
    returned unchanged instead.
    """
    if force_dr is not None:
        return force_dr

    lower_bound = np.linalg.norm(np.array([xy_epsilon, xy_epsilon, z_epsilon])) / 4
    return max(standard_deviation, lower_bound)

## Uproot functions

In [None]:
# Uproot
def load_charge(file_name, events=None):
    """Read the ``events`` tree of a charge ROOT file into a DataFrame indexed by eventID.

    When ``events`` is given, only those eventIDs are selected with ``.loc``
    (so IDs that are not present raise a KeyError).
    """
    with uproot.open(file_name) as root_file:
        events_tree = root_file["events"]
        charge_df = events_tree.arrays(library="pd").set_index("eventID")

    if events is not None:
        charge_df = charge_df.loc[events]

    return charge_df


def _peak_integral(waveform, half_width=10):
    """Trapezoidal integral of ``waveform`` over +-half_width samples around its maximum.

    The window start is clipped at 0. Without the clip, a negative start index
    wraps around to the end of the array, which yields an empty window
    (integral 0) for peaks within the first ``half_width`` samples.
    """
    waveform = np.asarray(waveform)
    peak_idx = int(waveform.argmax())
    window = waveform[max(peak_idx - half_width, 0) : peak_idx + half_width]
    # numpy >= 2.0 renamed trapz to trapezoid; fall back for older versions
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid(window)


def load_light(file_name, deco=True, events=None, mask=True, keep_rwf=False):
    """Read SiPM waveforms from a light ROOT file into one row per SiPM per event.

    Parameters
    ----------
    file_name : str
        Path to the ROOT file.
    deco : bool
        Read the ``decowave`` tree (waveform column ``decwfm``) instead of the
        ``rwf_array`` tree (waveform column ``rwf``).
    events : array-like or None
        If given, keep only rows whose ``event`` is in this collection.
    mask : bool
        Keep only SiPMs that ``get_sipm_mask`` accepts (mapped and on the anode).
    keep_rwf : bool
        Also return the waveform column ``rwf``.

    Returns
    -------
    pandas.DataFrame
        Columns event, tai_ns, sn, ch, peak, integral, x, y (and rwf).
        ``peak`` is max - min of the waveform. ``integral`` is the trapezoidal
        integral of +-10 samples around its maximum. x/y are NaN for unmapped
        SiPMs, which can only occur with ``mask=False``. The result is an empty
        DataFrame when no row survives.
    """
    columns = ["event", "tai_ns", "sn", "ch", "peak", "integral", "x", "y"]
    if keep_rwf:
        columns.append("rwf")

    chunks = []
    with uproot.open(file_name) as f:
        tree = f["decowave"] if deco else f["rwf_array"]

        for arrays in tree.iterate(library="np"):
            df = pd.DataFrame.from_dict(arrays, orient="index").T
            # dropna() returns a new frame, so the result must be reassigned
            df = df.dropna()
            if events is not None:
                df = df[df["event"].isin(events)]

            if mask and not df.empty:
                # Label-based access: positional row[0] on a labelled Series is deprecated
                on_anode = df[["sn", "ch"]].apply(
                    lambda row: get_sipm_mask(row["sn"], row["ch"]), axis=1
                )
                df = df[on_anode.astype(bool)]

            if df.empty:
                continue

            # Work on a copy to avoid chained-assignment warnings on the filtered slice
            df = df.copy()
            positions = [sipm_to_xy(sn, ch) for sn, ch in zip(df["sn"], df["ch"])]
            df["x"] = [np.nan if p is None else p[0] for p in positions]
            df["y"] = [np.nan if p is None else p[1] for p in positions]

            if deco:
                df["rwf"] = df["decwfm"]

            df["peak"] = df["rwf"].apply(lambda wf: np.max(wf) - np.min(wf))
            df["integral"] = df["rwf"].apply(_peak_integral)

            chunks.append(df[columns])

    # Concatenate once at the end instead of growing a DataFrame per chunk
    if not chunks:
        return pd.DataFrame()
    return pd.concat(chunks, ignore_index=True)

## Data handling

In [None]:
match_dict = {}  # filled by the matching cell further down the notebook


def match_events(charge_df, light_df, window=10, light_offset_us=36000000):
    """Match charge events to light events by timestamp.

    Charge time: ``event_unix_ts * 1e6 + event_start_t * 1e-1``.
    Light time: ``tai_ns * 1e-3 - light_offset_us``.
    Assuming event_unix_ts is in s, event_start_t in 0.1 us ticks and tai_ns
    in ns, both times are in us (TODO confirm units). A light event matches
    when the two times differ by at most ``window``.

    Parameters
    ----------
    charge_df : pandas.DataFrame
        Indexed by eventID, with ``event_unix_ts`` and ``event_start_t`` columns.
    light_df : pandas.DataFrame
        With ``tai_ns`` and ``event`` columns (one row per SiPM).
    window : float
        Maximum allowed time difference.
    light_offset_us : float
        Constant subtracted from the light time (default 36000000).

    Returns
    -------
    dict
        ``{charge eventID: [light event numbers]}``, containing only charge
        events with at least one match. The values come from the ``event``
        column of ``light_df``, so they can be used with
        ``light_df["event"].isin(...)``. Row labels would be wrong here,
        because after drop_duplicates they point only at the first SiPM row.
    """
    matches = {}

    charge_events = charge_df[["event_unix_ts", "event_start_t"]].drop_duplicates()
    light_events = light_df[["tai_ns", "event"]].drop_duplicates()

    # The light timestamps do not depend on the charge event: convert them once.
    light_ts = light_events["tai_ns"].astype(float).to_numpy() * 1e-3 - light_offset_us
    # Numeric dtype so .tolist() yields plain Python numbers (JSON serialisable)
    light_ids = pd.to_numeric(light_events["event"]).to_numpy()

    for event, row in tqdm(charge_events.iterrows(), total=len(charge_events)):
        charge_ts = (float(row["event_unix_ts"]) * 1e6) + (
            float(row["event_start_t"]) * 1e-1
        )
        in_window = np.abs(light_ts - charge_ts) <= window

        if in_window.any():
            # extend (not append) keeps a flat list if a charge eventID repeats
            matches.setdefault(event, []).extend(light_ids[in_window].tolist())

    return matches

In [None]:
# Create a list of fake data
def generate_dead_area(z_range):
    """Return fake hits that cover the known dead regions at every z in ``z_range``.

    ``cluster`` uses these points to bridge tracks across inactive areas.
    Regions (xy_unit):

    * x in [36, 60], three 24-wide y bands (chips 44, 54 and detector_x/2);
    * chip 33: x in [4, 28], y in [20, 44];
    * chip 42: points of x in [-28, -4], y in [-12, 12] with x + y + 16 <= 0;
    * a single point at (-14, 2);
    * a 2 x 2 patch per SiPM on a 32-pitch grid, except five skipped positions.

    Returns
    -------
    ndarray, shape (n_points, 3)
        Columns x, y, z. Regions are stacked in the order listed above.
    """
    grids = []

    # Dead area on chips 44, 54, detector_x/2
    band_y = np.concatenate(
        [np.linspace(-76, -52, 6), np.linspace(-44, -20, 6), np.linspace(-12, 12, 6)]
    )
    grids.append(np.meshgrid(np.linspace(36, 60, 6), band_y, z_range))

    # Dead area on chip 33
    grids.append(np.meshgrid(np.linspace(4, 28, 6), np.linspace(20, 44, 6), z_range))

    # Dead area on chip 42: only the triangle on or below the line x + y + 16 = 0
    tri_x, tri_y, tri_z = np.meshgrid(
        np.linspace(-28, -4, 6), np.linspace(-12, 12, 6), z_range
    )
    below = tri_x + (tri_y + 16) <= 0
    grids.append((tri_x[below], tri_y[below], tri_z[below]))

    # Single isolated dead point
    grids.append(np.meshgrid([-14], [2], z_range))

    # Dead area on SiPMs; (column, row) grid positions that get no patch
    skipped = {(3, 0), (3, 1), (3, 2), (2, 3), (1, 2)}
    for col in range(4):
        for row in range(5):
            if (col, row) in skipped:
                continue
            grids.append(
                np.meshgrid(
                    np.array([-50, -46]) + 32 * col,
                    np.array([-66, -62]) + 32 * row,
                    z_range,
                )
            )

    # Flatten every region (C order) and stack the coordinates column-wise
    x, y, z = (
        np.concatenate([np.ravel(grid[axis]) for grid in grids]) for axis in range(3)
    )
    return np.column_stack([x, y, z])

In [None]:
# Apply DBSCAN clustering
def cluster(hitArray, max_gap=40):
    """Cluster 3D hits into track candidates with a three-stage DBSCAN.

    1. DBSCAN in xy (eps=xy_epsilon) gives the z-extent of every xy cluster.
    2. Fake hits are placed on the dead areas (generate_dead_area) at z values
       filling the empty z-gaps between those extents. Real + fake hits are
       then re-clustered in xy with min_samples=1, so tracks crossing a dead
       area are bridged.
    3. Each stage-2 cluster is split along z by a Chebyshev DBSCAN
       (eps=z_epsilon). The stage-2 label is scaled by 1e3 so that hits of
       different xy clusters are never within eps of each other.

    Parameters
    ----------
    hitArray : ndarray, shape (n, 3)
        Hit coordinates x, y, z.
    max_gap : float
        Only z-gaps shorter than this are filled with fake hits (default 40).

    Returns
    -------
    ndarray of int, shape (n,)
        Label per real hit, shifted by +1: 0 marks outliers and positive
        values mark clusters.
    """
    hitArray = np.asarray(hitArray)
    n_hits = len(hitArray)

    # First stage: z-extent of every xy cluster (noise, label -1, is ignored).
    # Iterate unique labels rather than every hit to avoid O(n^2) work.
    first_labels = DBSCAN(eps=xy_epsilon, min_samples=min_samples).fit(
        hitArray[:, 0:2]
    ).labels_
    z_intervals = []
    for label in np.unique(first_labels[first_labels > -1]):
        z = hitArray[first_labels == label, 2]
        z_intervals.append((z.min(), z.max()))

    # Empty z-space between clusters. Track the running maximum end so a gap
    # is only reported where no earlier, longer interval already covers z.
    empty_space_ranges = []
    covered_end = None
    for start, end in sorted(z_intervals):
        if covered_end is not None and covered_end < start < covered_end + max_gap:
            empty_space_ranges.append(np.arange(covered_end, start, z_epsilon))
        covered_end = end if covered_end is None else max(covered_end, end)

    if not empty_space_ranges:
        z_mean = np.mean(hitArray[:, 2])
        z_std = np.std(hitArray[:, 2])
        empty_space_ranges.append(np.arange(z_mean - z_std, z_mean + z_std, z_epsilon))

    z_range = np.concatenate(empty_space_ranges)

    # Create a list of holes
    fake_data = generate_dead_area(z_range)

    # Second stage clustering on real + fake hits (real hits come first)
    second_stage_data = np.concatenate([hitArray, fake_data])
    second_stage = DBSCAN(eps=xy_epsilon, min_samples=1).fit(second_stage_data[:, 0:2])

    # Third stage clustering on (scaled stage-2 label, z)
    third_stage_z = np.c_[second_stage.labels_ * 1e3, second_stage_data[:, 2]]
    # With min_samples=1 DBSCAN marks no point as noise, so every row is kept
    flag = second_stage.labels_ > -1
    third_stage = DBSCAN(
        eps=z_epsilon, min_samples=min_samples, metric="chebyshev"
    ).fit(third_stage_z[flag])

    # Remove fake data by keeping the first n_hits labels. [:-len(fake_data)]
    # would return nothing when no fake hits exist (e.g. all hits share one z,
    # so z_std == 0 and z_range is empty).
    # Shift labels by 1 so that non-positive values are reserved for outliers
    labels = third_stage.labels_[:n_hits] + 1

    return labels


# Apply Ransac Fit
def ransacFit(hitArray, weightArray=None):
    """Split hits into RANSAC inliers and outliers.

    With ``weightArray``, z is regressed on (x, y) by scikit-learn's
    RANSACRegressor (weights passed as ``sample_weight``). ``score`` is then
    its ``score`` (R^2 for the default linear estimator) evaluated on all hits.
    Without weights, a 3D line is fitted with scikit-image's
    ``ransac``/``LineModelND``. LineModelND has no ``score`` method, so
    ``score`` is the fraction of hits classified as inliers.

    Returns
    -------
    inliers : ndarray of bool
    outliers : ndarray of bool
        ``~inliers``.
    score : float
    """
    if weightArray is not None:
        estimator = RANSACRegressor(
            min_samples=min_samples,
            max_trials=max_trials,
            residual_threshold=residual_threshold,
        )
        inliers = estimator.fit(
            hitArray[:, 0:2],
            hitArray[:, 2],
            sample_weight=weightArray,
        ).inlier_mask_

        score = estimator.score(hitArray[:, 0:2], hitArray[:, 2])
    else:
        _, inliers = ransac(
            hitArray,
            LineModelND,
            min_samples=min_samples,
            residual_threshold=residual_threshold,
            max_trials=max_trials,
        )
        if inliers is None:
            # scikit-image may return (None, None) when no model could be fitted
            inliers = np.zeros(len(hitArray), dtype=bool)
        inliers = np.asarray(inliers, dtype=bool)
        score = float(inliers.mean()) if inliers.size else 0.0

    outliers = ~inliers
    return inliers, outliers, score


# Apply best line fit
def lineFit(hitArray):
    """Fit a 3D line to the hits with skspatial ``Line.best_fit``.

    Returns the fitted Line and the sum of squared perpendicular hit-to-line
    distances. That sum is reported as chi-squared, although no per-hit
    uncertainties are applied.
    """
    line = Line.best_fit(hitArray)
    distances = np.array([line.distance_point(point) for point in hitArray])
    return line, np.sum(distances**2)


# Calculate dQ/dx from a line fit
def dqdx(hitArray, q, line_fit, dh, dr, h, ax=None):
    """Slice a fitted track into cylinders and sum the hit charge in each one.

    For every step s in ``arange(-3 dh, h + 2 dh, dh)``, a cylinder of radius
    ``dr`` starts at ``line_fit.to_point(h / 2 - s)`` and extends by ``dh``
    along ``-line_fit.direction``. Each hit contributes to the first cylinder
    (in step order) that contains it, and to no later one. When ``ax`` is
    given, every cylinder is drawn on it.

    Returns
    -------
    ndarray of float
        Charge per cylinder, one entry per step.
    """
    steps = np.arange(-3 * dh, h + 2 * dh, dh)
    already_counted = np.zeros(len(q), dtype=bool)
    charge_per_step = np.zeros(len(steps), dtype=float)

    for i, step in enumerate(steps):
        slice_cylinder = Cylinder(
            line_fit.to_point(h / 2 - step),
            -line_fit.direction.unit() * dh,
            dr,
        )
        if ax is not None:
            slice_cylinder.plot_3d(ax)

        for j, point in enumerate(hitArray):
            if already_counted[j]:
                continue
            if slice_cylinder.is_point_within(point):
                already_counted[j] = True
                charge_per_step[i] += q[j]

    return charge_per_step


# Fit clusters with Ransac method
def fitClusters(
    hitArray, q, labels, ax2d=None, ax3d=None, plot_cyl=False, refit_outliers=True
):
    """Fit every cluster with RANSAC plus a least-squares line and measure dQ/dx.

    Labels are visited in increasing order. For each positive label with more
    than ``min_samples`` hits:

    * ``ransacFit`` separates inliers from outliers. Hit weights are the
      charges shifted so the smallest one is 1.
    * Outliers are handled in one of two ways:
      - If ``refit_outliers`` is set and there are more than ``min_samples``
        of them, they are re-clustered with ``cluster``. Clustered outliers get
        new labels above the current maximum and are fitted later in the same
        loop; unclustered ones get a negative label.
      - Otherwise each outlier's label is simply negated.
    * If more than ``min_samples`` inliers remain, a line is fitted
      (``lineFit``) and optionally drawn on ``ax2d`` / ``ax3d``. The inlier
      charge is then sliced into cylinders by ``dqdx``.

    NOTE: ``labels`` (expected to be a numpy array) is modified in place.

    Returns
    -------
    dict
        ``{label: {"Fit_line", "Fit_norm", "Fit_p_value", "RANSAC_score",
        "q_eff", "dQ", "dx"}}`` for each fitted track, plus the key
        ``"Fit_labels"`` holding the updated label array.
    """
    metrics = {}
    # Fit clusters
    idx = 0
    # unique labels are recomputed each pass because new labels can be added inside the loop
    condition = lambda: idx < len(np.unique(labels))
    while condition():
        label = np.unique(labels)[idx]
        mask = labels == label
        if label > 0 and mask.sum() > min_samples:
            xyz_c = hitArray[mask]
            x_c, y_c, z_c = xyz_c[:, 0], xyz_c[:, 1], xyz_c[:, 2]
            q_c = np.array(q)[mask]

            # Diagonal of the cluster's bounding box, used as the track length
            norm = np.linalg.norm(
                [
                    x_c.max() - x_c.min(),
                    y_c.max() - y_c.min(),
                    z_c.max() - z_c.min(),
                ]
            )

            # Fit the model
            inliers, outliers, score = ransacFit(xyz_c, q_c - min(q_c) + 1)

            # Refit outliers
            # Map outlier positions within the cluster back to indices in `labels`
            level_1 = np.where(mask)[0]
            level_2 = np.where(outliers)[0]
            level_3 = level_1[level_2]

            if refit_outliers and sum(outliers) > min_samples:
                outlier_labels = cluster(xyz_c[outliers])
                last_label = max(labels) + 1
                # Assign positive labels to clustered outliers and negative labels to unclustered outliers
                # (cluster() returns 0 for noise, which maps to -last_label)
                for i, j in enumerate(level_3):
                    labels[j] = (outlier_labels[i] + last_label) * (
                        1 if outlier_labels[i] > 0 else -1
                    )
            else:
                # Assign negative labels to outliers
                for j in level_3:
                    labels[j] = -labels[j]

            if sum(inliers) > min_samples:
                line_fit, chi_squared = lineFit(xyz_c[inliers])
                # Degrees of freedom (number of points - number of parameters in the line fit)
                # NOTE(review): only 1 parameter is subtracted, although a 3D line has 4
                # free parameters, and chi_squared uses unit uncertainties — confirm intent
                degrees_of_freedom = sum(inliers) - 1
                # Calculate the reduced chi squared
                reduced_chi_squared = chi_squared / degrees_of_freedom
                # Calculate the p-value (assuming chi-squared distribution)
                p_value = 1 - stats.chi2.cdf(chi_squared, degrees_of_freedom)

                if ax2d is not None:
                    # 2D plot
                    line_fit.plot_2d(
                        ax2d,
                        t_1=-norm / 2,
                        t_2=norm / 2,
                        c="red",
                        label=f"Track {label}",
                        zorder=10,
                    )
                if ax3d is not None:
                    # 3D plot
                    line_fit.plot_3d(
                        ax3d,
                        t_1=-norm / 2,
                        t_2=norm / 2,
                        c="red",
                        label=f"Track {label}",
                    )

                # Calculate dQ/dx
                dh = get_dh(line_fit.direction)
                # RMS distance to the line sets the cylinder radius
                dr = get_dr(np.sqrt(reduced_chi_squared))

                dq_i = dqdx(
                    xyz_c[inliers],
                    q_c[inliers],
                    line_fit,
                    dh=dh,
                    dr=dr,
                    h=norm,
                    ax=ax3d if ax3d is not None and plot_cyl else None,
                )

                # Fraction of the inlier charge captured by the cylinders
                q_eff = dq_i.sum() / q_c[inliers].sum()
                # NOTE: a scalar 0 (not an array) is stored when no charge was captured
                if dq_i.sum() != 0:
                    dq = dq_i
                else:
                    dq = 0

                metrics[label] = {
                    "Fit_line": line_fit,
                    "Fit_norm": norm,
                    "Fit_p_value": p_value,
                    "RANSAC_score": score,
                    "q_eff": q_eff,
                    "dQ": dq,
                    "dx": dh,
                }

        # Continue with the label after the current one in the (possibly updated) label set
        idx = np.unique(labels).tolist().index(label) + 1

    metrics["Fit_labels"] = labels

    return metrics

In [None]:
def light_geometry(track_line, track_norm, sipm_df, light_variable="integral"):
    """Collect per-SiPM geometry relative to a fitted track.

    Each SiPM row of ``sipm_df`` is placed at (x, y, 0); rows containing any
    NaN are skipped. Results are keyed by (sn, ch) and stored as:

    * ``Light_distance``: distance from the SiPM to ``track_line.point``;
    * ``Light_angle``: angle at vertex C (the SiPM) of the triangle formed
      with the track end points ``track_line.to_point(+-track_norm / 2)``;
    * ``Light_<light_variable>``: the SiPM's ``light_variable`` value.
    """
    track_start = track_line.to_point(-track_norm / 2)
    track_end = track_line.to_point(track_norm / 2)
    track_centre = track_line.point

    distances = {}
    angles = {}
    light_values = {}
    for _, sipm in sipm_df.dropna().iterrows():
        key = (sipm["sn"], sipm["ch"])
        sipm_point = Point([sipm["x"], sipm["y"], 0])
        angles[key] = Triangle(track_start, track_end, sipm_point).angle("C")
        distances[key] = sipm_point.distance_point(track_centre)
        light_values[key] = sipm[light_variable]

    return {
        "Light_distance": distances,
        "Light_angle": angles,
        f"Light_{light_variable}": light_values,
    }

## Plotting

In [None]:
def create_axes(event_idx, charge):
    """Create the event-display figure: a 3D axis (left) and a 2D xy axis (right).

    Dead-area outlines are drawn on the 2D axis in the grid colour. Both axes
    cover +-detector_x/2 by +-detector_y/2 with equal aspect and a grid.

    Returns
    -------
    fig, (ax2d, ax3d)
    """
    fig = plt.figure(figsize=(14, 6))
    ax3d = fig.add_subplot(121, projection="3d")
    ax2d = fig.add_subplot(122)
    fig.suptitle(f"Event {event_idx} - Charge = {charge} {q_unit}")
    outline_color = plt.rcParams["grid.color"]

    # Draw dead areas: an "X" over four 32 x 32 tiles stacked in y at
    # x = 32..64. The fourth tile is shifted left by 32.
    rising_x = np.arange(32, 64 + 1, 1)
    falling_x = np.arange(64, 32 - 1, -1)
    base_y = np.arange(-80, -48 + 1, 1)
    for tile in range(4):
        x_shift = 32 if tile == 3 else 0
        tile_y = base_y + 32 * tile
        ax2d.plot(rising_x - x_shift, tile_y, c=outline_color, lw=1)
        ax2d.plot(falling_x - x_shift, tile_y, c=outline_color, lw=1)

    # Dead triangle outline (boundary line x + y + 16 = 0 and its second edge)
    ax2d.plot(np.linspace(-32, 0), np.linspace(16, -16), c=outline_color, lw=1, zorder=-1)
    ax2d.plot(np.linspace(-32, -16), np.linspace(-16, 0), c=outline_color, lw=1, zorder=-1)

    # Common axis setup
    half_x = detector_x / 2
    half_y = detector_y / 2
    for ax in (ax3d, ax2d):
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlim([-half_x, half_x])
        ax.set_ylim([-half_y, half_y])
        ax.set_xlabel(f"x [{xy_unit}]")
        ax.set_ylabel(f"y [{xy_unit}]")
        ax.set_xticks(np.linspace(-half_x, half_x, 5))
        ax.set_yticks(np.linspace(-half_y, half_y, 6))
        ax.grid()

    for axis in (ax2d.xaxis, ax2d.yaxis):
        axis.set_minor_locator(AutoMinorLocator(8))
    ax2d.tick_params(axis="both", which="both", right=True, top=True)

    ax3d.set_zlabel(f"z [{z_unit}]")

    return fig, (ax2d, ax3d)


def event_display(
    event_idx,
    charge_x_array,
    charge_y_array,
    charge_z_array,
    charge_time_array,
    charge_array,
    light_x_array=None,
    light_y_array=None,
    light_array=None,
    plot_cyl=False,
):
    """Draw one event, cluster and fit its hits, and save ``event_<idx>.png``.

    Parameters
    ----------
    event_idx : int-like
        Event identifier used in the title and the file name.
    charge_x_array, charge_y_array, charge_z_array, charge_array : array-like
        Per-hit position and charge.
    charge_time_array : array-like
        Per-hit time (currently unused).
    light_x_array, light_y_array, light_array : array-like, optional
        SiPM positions and their ``light_variable`` values (default: empty).
    plot_cyl : bool
        Also draw the dQ/dx cylinders on the 3D axis.

    Returns
    -------
    dict or None
        The ``fitClusters`` metrics. None (and no figure) when the event has
        fewer than 2 hits.
    """
    if len(charge_x_array) < 2:
        return None

    # None defaults instead of [] so no list object is shared between calls
    if light_x_array is None:
        light_x_array = []
    if light_y_array is None:
        light_y_array = []
    if light_array is None:
        light_array = []

    # Plot the hits
    fig, axes = create_axes(event_idx, round(sum(charge_array)))
    ax2d = axes[0]
    ax3d = axes[1]

    # Group hits sharing an (x, y) pixel and sum their charge
    data2d = np.c_[charge_x_array, charge_y_array, charge_array]
    unique_points, indices = np.unique(data2d[:, :2], axis=0, return_inverse=True)
    # ravel(): numpy 2.0.0 returned a 2D inverse for axis=0
    q_sum = np.bincount(indices.ravel(), weights=data2d[:, 2])

    # Plot the hits
    plot3d = ax3d.scatter(
        charge_x_array,
        charge_y_array,
        charge_z_array,
        c=charge_array,
        marker="s",
        s=30,
        vmin=q_sum.min(),
        vmax=q_sum.max(),
    )
    plot2d = ax2d.scatter(
        unique_points[:, 0],
        unique_points[:, 1],
        c=q_sum,
        marker="s",
        s=40,
        vmin=q_sum.min(),
        vmax=q_sum.max(),
    )
    cbar = plt.colorbar(plot2d)
    cbar.set_label(f"charge [{q_unit}]")

    # Create a design matrix
    xyz = np.c_[charge_x_array, charge_y_array, charge_z_array]

    # Cluster the hits
    labels = cluster(xyz)

    # Fit clusters
    metrics = fitClusters(xyz, charge_array, labels, ax2d, ax3d, plot_cyl)

    # Draw missing SiPMs
    grid_color = plt.rcParams["grid.color"]

    vertices_x = np.array([3, 3, -3, -3, 3])
    vertices_y = np.array([3, -3, -3, 3, 3])
    # Build a set: a bare zip() iterator is consumed by the first membership
    # test, after which every later SiPM would wrongly be drawn as missing.
    light_xy = set(zip(light_x_array, light_y_array))
    for missing_index in range(20):
        col = 48 - (missing_index % 4) * 32
        row = -64 + (missing_index // 4) * 32
        if (col, row) not in light_xy:
            # NOTE(review): membership is tested at y = row but the patch is
            # drawn at y = -row; confirm which sign matches the SiPM map.
            ax2d.fill(col + vertices_x, vertices_y - row, c=grid_color, zorder=5)

    # Draw SiPMs
    sipm_plot = ax2d.scatter(
        light_x_array,
        light_y_array,
        c=light_array,
        marker="s",
        s=200,
        linewidths=1.5,
        edgecolors=grid_color,
        zorder=6,
    )
    sipm_cbar = plt.colorbar(sipm_plot)
    sipm_cbar.set_label(rf"Light {light_variable} [{light_unit}]")

    ax3d.view_init(30, 20, 100)
    fig.tight_layout()

    fig.savefig(f"event_{event_idx}.png", dpi=300)

    return metrics

In [None]:
# Plot dQ versus X
def plot_dQ(dQ_array, event_idx, track_idx, dh, interpolate=False):
    """Plot dQ/dx and the cumulative charge along a track.

    The profile is trimmed to the non-empty slices plus up to two empty
    slices of margin on each side. With ``interpolate``, empty interior
    slices are filled with the mean of the non-empty ones. The figure is
    saved as ``dQ_<event>_<track>_<dh>.png``.

    Parameters
    ----------
    dQ_array : array-like or scalar
        Charge per slice (``dqdx`` output). The scalar 0 that ``fitClusters``
        stores for empty tracks is accepted.
    event_idx, track_idx : int-like
        Used in the title and file name.
    dh : float
        Slice width.
    interpolate : bool
        Fill empty interior slices with the mean slice charge.
    """
    # Accept lists and scalars; float so interpolation is not truncated
    dQ_array = np.atleast_1d(np.asarray(dQ_array, dtype=float))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax_twinx = ax.twinx()

    fig.suptitle(
        rf"Event {event_idx} - Track {track_idx} - $dx = {round(dh,2)}$ {dh_unit}"
    )

    non_zero_indices = np.where(dQ_array > 0)[0]
    # nan for an empty track (same value np.mean would give, without the warning)
    mean_dQ = dQ_array[non_zero_indices].mean() if non_zero_indices.size > 0 else np.nan

    # Check if there are non-zero values in dQ_array
    if non_zero_indices.size > 0:
        # Keep up to two slices of margin before the first non-empty one ...
        first_index = max(non_zero_indices[0] - 2, 0)
        # ... and after the last one (+3 because the slice end is exclusive)
        last_index = min(non_zero_indices[-1] + 3, len(dQ_array))

        new_dQ_array = dQ_array[first_index:last_index].copy()

        if interpolate:
            new_dQ_array[1:-1] = np.where(
                new_dQ_array[1:-1] == 0,
                mean_dQ,
                new_dQ_array[1:-1],
            )

        dQ_array = new_dQ_array

    ax.axhline(
        mean_dQ / dh,
        ls="--",
        c="red",
        label=rf"Mean = ${round(mean_dQ/dh,2)}$ {q_unit} {dh_unit}$^{{-1}}$",
        lw=1,
    )
    # Slice positions along the track: 0, dh, 2 dh, ...
    x_range = np.arange(len(dQ_array)) * dh

    ax.step(x_range, dQ_array / dh, where="mid")
    ax.set_xlabel(rf"$x$ [{dh_unit}]")
    ax.set_ylabel(rf"$dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]")

    ax_twinx.step(x_range, np.cumsum(dQ_array), color="C1", where="mid")
    ax_twinx.set_ylabel(f"Q [{q_unit}]")

    for axes in [ax, ax_twinx]:
        axes.xaxis.set_minor_locator(AutoMinorLocator())
        axes.yaxis.set_minor_locator(AutoMinorLocator())
        axes.xaxis.set_major_locator(MaxNLocator(integer=True))
        axes.yaxis.set_major_locator(MaxNLocator(integer=True))
        axes.tick_params(axis="both", direction="inout", which="major", top=True)

    # Draw the legend on the twin axis too, so it sits above the cumulative curve
    h1, l1 = ax.get_legend_handles_labels()
    ax_twinx.legend(h1, l1, loc="lower center")

    ax.legend(loc="lower center")

    fig.tight_layout()
    fig.savefig(f"dQ_{event_idx}_{track_idx}_{round(dh,2)}.png", dpi=300)

In [None]:
def plot_track_stats(
    metrics, limit_xrange=True, p_value_limit=0.05, empty_ratio_lims=(0, 1)
):
    """Plot summary histograms of mean dQ/dx, track length and fit p-value.

    Parameters
    ----------
    metrics : dict
        ``{event: fitClusters(...) output}``. Within each event, positive
        integer keys are tracks with "dQ", "dx", "Fit_norm" and "Fit_p_value".
        String keys (e.g. "Fit_labels") and non-positive labels are skipped.
    limit_xrange : bool
        Cap the x-axis of the 2D histograms at ``max_track_legth + 10``.
    p_value_limit : float
        Tracks with fit p-value <= this value are highlighted and form the cut
        plot.
    empty_ratio_lims : (float, float)
        Allowed range of the fraction of empty dQ slices between the first and
        last non-empty slice. Tracks outside the range are rejected, as are
        tracks without any non-empty slice.

    Uses the notebook globals ``max_track_legth`` and ``max_track_legth_xy``.
    NOTE(review): they are not defined in this part of the notebook; confirm
    they are set before calling.

    Saves three PNG files whose names include the number of kept tracks.
    """
    track_mean_dQdx = []
    track_length = []
    track_p_value = []
    # Labels follow the actual cut instead of a hard-coded 0.05
    p_label = rf"p_value $\leq {p_value_limit}$"

    empty_count = 0
    for entry in metrics.values():
        for track, values in entry.items():
            if isinstance(track, str) or track <= 0:
                continue

            # fitClusters stores dQ = 0 (a scalar) when no charge was captured;
            # such tracks have no non-empty slice and are rejected here.
            dQ = np.atleast_1d(np.asarray(values["dQ"], dtype=float))
            non_zero_mask = np.where(dQ > 0)[0]
            if non_zero_mask.size == 0:
                empty_count += 1
                continue

            span = dQ[non_zero_mask[0] : non_zero_mask[-1] + 1]
            empty_ratio = np.sum(span == 0) / span.size

            if empty_ratio > empty_ratio_lims[1] or empty_ratio < empty_ratio_lims[0]:
                empty_count += 1
                continue

            dQdx = dQ / values["dx"]
            track_mean_dQdx.append(np.mean(dQdx[dQdx > 0]))
            track_length.append(values["Fit_norm"])
            track_p_value.append(values["Fit_p_value"])

    print(f"Tracks with dead area outside {empty_ratio_lims} interval: {empty_count}")

    track_mean_dQdx = pd.Series(track_mean_dQdx, dtype=float)
    track_length = pd.Series(track_length, dtype=float)
    track_p_value = pd.Series(track_p_value, dtype=float)
    mask = track_mean_dQdx.notna() * track_length.notna() * track_p_value.notna()

    print(f"Remaining tracks: {sum(mask)}")
    if sum(mask) == 0:
        # Nothing to histogram or fit; the code below would fail on empty input
        return

    track_mean_dQdx = track_mean_dQdx[mask]
    track_length = track_length[mask]
    track_p_value = track_p_value[mask]

    p_mask = track_p_value <= p_value_limit

    print(f"Tracks with p_value > {p_value_limit}: {sum(mask)-sum(p_mask)}")

    print(f"remaining_tracks: {sum(p_mask)}")

    # 1D histograms
    fig1 = plt.figure(figsize=(14, 6))

    ax11 = fig1.add_subplot(121)
    n_all1, bins_all1, patches_all1 = ax11.hist(track_mean_dQdx, label="All tracks")
    ax11.hist(
        track_mean_dQdx[p_mask],
        bins=bins_all1,
        label=p_label,
    )
    ax11.set_xlabel(rf"$dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]")
    ax11.set_title(f"{len(track_mean_dQdx)} tracks")

    ax12 = fig1.add_subplot(122)
    n_all2, bins_all2, patches_all2 = ax12.hist(track_length, label="All tracks")
    ax12.hist(track_length[p_mask], bins=bins_all2, label=p_label)
    ax12.set_title(f"{len(track_length)} tracks")

    for ax in [ax11, ax12]:
        ax.set_ylabel("Counts")

    # 2D histograms
    fig2 = plt.figure(figsize=(14, 6))
    ax21 = fig2.add_subplot(121)

    ax21.hist2d(track_length, track_mean_dQdx, bins=[40, 20], norm=LogNorm())
    ax21.set_ylabel(rf"$dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]")
    ax21.set_title("dQ/dx vs. Track length")

    # Logarithmic trend: dQ/dx = a * log(length) + b
    fit21 = np.polyfit(np.log(track_length), track_mean_dQdx, 1)
    p21 = np.poly1d(fit21)
    x = np.arange(min(track_length), max(track_length), 1)
    ax21.plot(x, p21(np.log(x)), c="salmon", ls="-", label="Log fit")

    ax22 = fig2.add_subplot(122)
    ax22.hist2d(track_length, track_p_value, bins=[40, 20], norm=LogNorm())
    ax22.set_ylabel("Fit p_value")
    ax22.set_title("Fit p_value vs. Track length")

    # 2D histograms after p_value cut
    fig3 = plt.figure(figsize=(8, 6))
    ax3 = fig3.add_subplot(111)

    ax3.hist2d(
        track_length[p_mask], track_mean_dQdx[p_mask], bins=[40, 20], norm=LogNorm()
    )
    ax3.set_ylabel(rf"$dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]")

    # polyfit needs at least one point
    if p_mask.any():
        fit3 = np.polyfit(np.log(track_length[p_mask]), track_mean_dQdx[p_mask], 1)
        p3 = np.poly1d(fit3)
        x = np.arange(min(track_length[p_mask]), max(track_length[p_mask]), 1)
        ax3.plot(x, p3(np.log(x)), c="salmon", ls="-", label="Log fit")
    ax3.set_title(
        f"dQ/dx vs. Track length - Fit {p_label} ({round(sum(p_mask)/len(p_mask)*100)}% of tracks)"
    )

    for ax in [ax11, ax12, ax21, ax22, ax3]:
        ax.tick_params(
            axis="both", direction="inout", which="major", right=True, top=True
        )
        if ax != ax11:
            ax.set_xlabel(f"Track length [{dh_unit}]")
            if max(track_length) > detector_y:
                ax.axvline(detector_y, c="g", ls="--", label="Max vertical length")
            if max(track_length) > max_track_legth_xy:
                ax.axvline(
                    max_track_legth_xy, c="orange", ls="--", label=r"Max length in $xy$"
                )
            if max(track_length) > max_track_legth:
                ax.axvline(max_track_legth, c="r", ls="--", label="Max length")

            if ax != ax12:
                if limit_xrange:
                    xlim = ax.get_xlim()
                    ax.set_xlim(xlim[0], min(max_track_legth + 10, xlim[1]))
                ax.set_xlabel(f"Track length [{dh_unit}]")
                cbar = plt.colorbar(ax.collections[0])
                cbar.set_label("Counts [Log]")

        ax.legend(loc="upper right")

    for fig in [fig1, fig2, fig3]:
        fig.tight_layout()

    fig1.savefig(f"track_stats_1D_hist_{len(track_mean_dQdx)}.png", dpi=300)
    fig2.savefig(f"track_stats_2D_hist_{len(track_mean_dQdx)}.png", dpi=300)
    fig3.savefig(f"track_stats_2D_hist_cut_{len(track_mean_dQdx)}.png", dpi=300)

# File loading

In [None]:
# Load every charge event from the ROOT file (indexed by eventID)
charge_df = load_charge(charge_file)
# Alternative: reload the cleaned charge table written by the "File saving" cell
# charge_df = pd.read_csv(f"charge_df_{file_label}.bz2")

In [None]:
# Clean up charge dataframe

# Remove events with non-positive charge hits and without light trigger:
# - per event, the minimum hit charge must be > 0 (events with no hits give NaN -> False);
# - the trigID list must be non-empty.
# `*` on the two boolean Series acts as a logical AND after aligning on eventID.
charge_mask = (
    charge_df["event_hits_q"].apply(tuple).explode().groupby("eventID").min() > 0
) * (charge_df["trigID"].apply(len) > 0)
charge_df = charge_df[charge_mask]

print(f"Removed events: {charge_mask.count()-charge_mask.sum()}/{charge_mask.count()}")

In [None]:
# Read the deconvolved tree when the file name contains "deco", the raw tree otherwise
light_df = load_light(light_file, deco="deco" in light_file)
# Alternative: reload the cleaned light table written by the "File saving" cell
# light_df = pd.read_csv(f"light_df_{file_label}.bz2")

In [None]:
# Clean up light dataframe
# match_dict: {charge eventID: [matched light events]}
match_dict = match_events(charge_df, light_df)
# NOTE: JSON turns the integer keys into strings; convert them back before
# using them with charge_df.loc if this reload path is enabled.
# with open(f"match_dict_{file_label}.json", "r") as f:
#     match_dict = json.load(f)

# Remove events without charge match
# A plain flatten also works when match_dict is empty (ak.flatten raises there)
light_events = [
    light_event for matched in match_dict.values() for light_event in matched
]
light_df = light_df[light_df["event"].isin(light_events)]

charge_df = charge_df.loc[list(match_dict.keys())]

In [None]:
# Summary statistics of the two per-SiPM light variables computed by load_light
light_df[["peak","integral"]].describe()

# File saving

In [None]:
# pandas infers bz2 compression from the ".bz2" extension
charge_df.to_csv(f"charge_df_{file_label}.bz2")
light_df.to_csv(f"light_df_{file_label}.bz2")

# JSON object keys are always strings, so eventIDs are reloaded as str
with open(f"match_dict_{file_label}.json", "w") as f:
    json.dump(match_dict, f)

# File verification

In [None]:
# List the branches loaded from the charge file
charge_df.columns

In [None]:
# List the columns produced by load_light
light_df.columns

## Histograms

In [None]:
print("Trigger time distribution")
# One entry per event: the mean of its trig_time values
charge_df["trig_time"].apply(np.mean).hist()

In [None]:
print(f"Event duration in {time_unit}")
# Distribution of the per-event duration branch
charge_df["event_duration"].hist()

In [None]:
print(f"Charge per hit in {q_unit}")
# Mean hit charge of each event (total charge / number of hits)
(charge_df["event_q"] / charge_df["event_nhits"]).hist(bins=50)

In [None]:
print(f"Charge per hit per event in {q_unit}")
# The unnamed ratio Series becomes column 0 after to_frame(); reset_index exposes eventID
(charge_df["event_q"] / charge_df["event_nhits"]).to_frame().reset_index().plot.scatter(
    x="eventID", y=0
)

In [None]:
print(f"Event charge in {q_unit}")
# Distribution of the total charge per event
charge_df["event_q"].hist()

In [None]:
print(f"Hits q in {q_unit}")
# Flatten the per-event hit lists into one entry per hit
charge_df["event_hits_q"].apply(tuple).explode().hist()

In [None]:
print(f"Hits z in {z_unit}")
# Flatten the per-event hit lists into one entry per hit
charge_df["event_hits_z"].apply(tuple).explode().hist()

# Event display

## Fake data map

In [None]:
def plot_fake_data(z_range):
    """Draw and save an x-y map of the synthetic dead-area points.

    Parameters
    ----------
    z_range : array-like
        Passed unchanged to ``generate_dead_area`` (defined elsewhere in this notebook).

    Notes
    -----
    The axes span ``[-detector_x/2, detector_x/2]`` by ``[-detector_y/2, detector_y/2]``.
    The figure is saved to ``fake_data_map.png`` at 300 dpi and is not closed here.
    """
    points = generate_dead_area(z_range)
    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]  # zs is not plotted

    half_width = detector_x / 2
    half_height = detector_y / 2

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_aspect("equal", adjustable="box")
    ax.scatter(xs, ys, marker="s", s=20)

    # Limits, labels and ticks cover the detector footprint.
    ax.set_xlim([-half_width, half_width])
    ax.set_ylim([-half_height, half_height])
    ax.set_xlabel(f"x [{xy_unit}]")
    ax.set_ylabel(f"y [{xy_unit}]")
    ax.set_xticks(np.linspace(-half_width, half_width, 5))
    ax.set_yticks(np.linspace(-half_height, half_height, 6))
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_minor_locator(AutoMinorLocator(8))
    ax.grid()
    ax.tick_params(axis="both", which="both", top=True, right=True)

    fig.savefig("fake_data_map.png", dpi=300)


# Dead-area map sampled every 10 units along the full drift length.
plot_fake_data(np.arange(-detector_z, 0, 10))

plt.show()

## Data fit

In [None]:
# Silence warnings during the fit loop (intended for sklearn's
# UndefinedMetricWarning). NOTE: this ignores *every* Warning subclass,
# not only UndefinedMetricWarning.
warnings.simplefilter("ignore", category=Warning)

In [None]:
# Per-event fit. For each selected charge event:
#   1. select the good hits;
#   2. for events in `individual_plots`, run the full event display;
#      for all others, cluster and fit the hits;
#   3. attach light metrics from the SiPMs of the matched light event.
metrics = {}

if event_list is None:
    index_list = charge_df.index
else:
    index_list = charge_df.index.intersection(event_list)

for idx in tqdm(index_list):
    row = charge_df.loc[idx]

    charge_channelid_array = np.array(row["event_hits_channelid"])
    charge_x_array = np.array(row["event_hits_x"])
    charge_y_array = np.array(row["event_hits_y"])
    charge_z_array = np.array(row["event_hits_z"])
    charge_time_array = np.array(row["event_hits_ts"])
    charge_array = np.array(row["event_hits_q"])

    # Drop hits with x == 0 or y == 0. This removes the (0, 0) entries.
    # NOTE(review): it also drops hits where only one coordinate is 0. If only
    # exact (0, 0) hits should be removed, use `|` instead of `&`. Confirm.
    non_zero_mask = (charge_x_array != 0) & (charge_y_array != 0)
    # Drop hits on the channels listed in `channel_disable_list`.
    noisy_channels_mask = np.isin(
        charge_channelid_array, channel_disable_list, invert=True
    )
    mask = non_zero_mask & noisy_channels_mask  # Full hits mask

    # Apply the hit mask to every per-hit array
    charge_x_array = charge_x_array[mask]
    charge_y_array = charge_y_array[mask]
    charge_z_array = charge_z_array[mask]
    charge_time_array = charge_time_array[mask]
    charge_array = charge_array[mask] * charge_gain  # Convert mV to ke

    # Light: SiPM rows of the first light event matched to this charge event.
    # If the event has no match_dict entry, the charge event number is used.
    light_event = match_dict.get(idx, [idx])[0]
    light_values = light_df[light_df["event"] == light_event]
    light_x_array = light_values["x"]
    light_y_array = light_values["y"]
    light_array = light_values[light_variable]

    # Events with fewer than three good hits are skipped and get no metrics entry.
    if len(charge_x_array) <= 2:
        # print(f"Event {idx}: Not enough hits to fit")
        continue

    if idx in individual_plots:
        metrics[idx] = event_display(
            idx,
            charge_x_array,
            charge_y_array,
            charge_z_array,
            charge_time_array,
            charge_array,
            light_x_array,
            light_y_array,
            light_array,
            plot_cyl=True,
        )
        if show_figures:
            plt.show()
        else:
            plt.close()
    else:
        # Design matrix of hit positions, one row per hit
        xyz = np.c_[charge_x_array, charge_y_array, charge_z_array]
        labels = cluster(xyz)
        metrics[idx] = fitClusters(xyz, charge_array, labels)

    # Light metrics, computed only for entries that carry a fitted line
    for values in metrics[idx].values():
        if "Fit_line" not in values:
            continue
        light_metrics = light_geometry(
            track_line=values["Fit_line"],
            track_norm=values["Fit_norm"],
            sipm_df=light_values,
            light_variable=light_variable,
        )

        values["Light_distance"] = light_metrics["Light_distance"]
        values["Light_angle"] = light_metrics["Light_angle"]
        values[f"Light_{light_variable}"] = light_metrics[f"Light_{light_variable}"]

    # Keep the hit mask so results can be related back to the original hit arrays
    metrics[idx]["Pixel_mask"] = mask

In [None]:
# Re-enable warnings (optional). This adds a "default" action for every
# Warning subclass, so each warning is shown once per location. NOTE: it does
# not restore Python's original filter list; for example, DeprecationWarning
# will now be shown as well.
warnings.simplefilter("default", category=Warning)

## Metrics

In [None]:
# Display the per-event metrics dict built by the fit loop (notebook cell output)
metrics

In [None]:
# Persist the per-event metrics dict so later analysis can be rerun without refitting.
with open(f"metrics_{file_label}.pkl", "wb") as metrics_file:
    pickle.dump(obj=metrics, file=metrics_file)

print(f"Metrics saved to metrics_{file_label}.pkl")

In [None]:
# # Load metrics from pickle file
# with open(f"metrics_{file_label}.pkl", "rb") as f:
#     metrics = pickle.load(f)

# dQ/dx

In [None]:
# Longest straight track that fits in the detector volume (space diagonal)
# and in the x-y plane (face diagonal). The misspelled variable names are
# kept because later cells may refer to them.
max_track_legth = np.sqrt(detector_x**2 + detector_y**2 + detector_z**2)
max_track_legth_xy = np.sqrt(detector_x**2 + detector_y**2)

print("Max possible track length", round(max_track_legth, 2), "mm")
print("Max possible track length on xy plane", round(max_track_legth_xy, 2), "mm")
# NOTE(review): treats y as the vertical axis. Confirm against the detector geometry.
print("Max possible vertical track length", detector_y, "mm")

In [None]:
# Summary plots of the fitted-track statistics
plot_track_stats(metrics, limit_xrange=False, empty_ratio_lims=(0.0, 1))
if not show_figures:
    plt.close()
else:
    plt.show()

In [None]:
# dQ profile of every track, for the events selected in `individual_plots`
for event_idx in tqdm(individual_plots, leave=False):
    if event_idx not in metrics:
        continue
    for track_idx, values in metrics[event_idx].items():
        # Skip string keys (per-event entries such as "Pixel_mask") and labels <= 0.
        if isinstance(track_idx, str) or not track_idx > 0:
            continue
        dQ_array, dh = values["dQ"], values["dx"]
        plot_dQ(dQ_array, event_idx, track_idx, dh, interpolate=False)

        if not show_figures:
            plt.close()
        else:
            plt.show()

# Light Geometry

In [None]:
# Light geometry. For every fitted track, gather each SiPM's distance and
# opening angle to the track, its light value, and the track's total charge.
# Then plot their correlations as 2D histograms.


def _show_or_close():
    """Show the current figure when `show_figures` is set, otherwise close it."""
    if show_figures:
        plt.show()
    else:
        plt.close()


distance = []
angle = []
light = []
charge = []

for event_idx in tqdm(index_list, leave=False):
    if event_idx not in metrics:
        continue
    for track_idx, values in metrics[event_idx].items():
        # Only non-string keys > 0 are treated as tracks. String keys hold
        # per-event data such as "Pixel_mask".
        if isinstance(track_idx, str) or not track_idx > 0:
            continue
        # The fit loop attaches light metrics only to entries with a
        # "Fit_line". Skip the others instead of raising KeyError.
        if "Light_distance" not in values:
            continue
        distance.extend(values["Light_distance"].values())
        angle.extend(values["Light_angle"].values())
        light.extend(values[f"Light_{light_variable}"].values())
        # Repeat the track's total charge once per SiPM entry.
        charge.extend(np.full(len(values["Light_angle"]), sum(values["dQ"])))

distance = np.array(distance)
angle = np.array(angle)
charge = np.array(charge)
light = np.array(light)

# Drop SiPM entries whose light value is NaN (mask computed once, before filtering `light`).
valid_light = ~np.isnan(light)
distance = distance[valid_light]
angle = angle[valid_light]
charge = charge[valid_light]
light = light[valid_light]

# distance = distance[(light < 1) * (light > 0)]
# angle = angle[(light < 1) * (light > 0)]
# charge = charge[(light < 1) * (light > 0)]
# light = light[(light < 1) * (light > 0)]

# Light-weighted distance vs opening angle
hist1 = plt.hist2d(
    distance, np.degrees(angle), weights=np.abs(light), bins=[20, 20], norm=LogNorm()
)
cbar = plt.colorbar(hist1[3])
cbar.set_label(rf"Light_{light_variable} [{light_unit}]")
plt.xlabel(f"Distance from track [{dh_unit}]")
plt.ylabel(f"SiPM opening angle to track centre [deg]")
_show_or_close()

# Counts: opening angle vs light
hist2 = plt.hist2d(np.degrees(angle), light, bins=[20, 20], cmin=0, norm=LogNorm())
plt.xlabel(f"SiPM opening angle to track centre [deg]")
plt.ylabel(f"Light {light_variable}")
cbar = plt.colorbar(hist2[3])
cbar.set_label(rf"Counts [Log]")
_show_or_close()

# Counts: distance vs light
hist3 = plt.hist2d(distance, light, bins=[20, 20], norm=LogNorm())
plt.xlabel(f"Distance from track [{dh_unit}]")
plt.ylabel(f"Light_{light_variable}")
# plt.yscale("log")
cbar = plt.colorbar(hist3[3])
cbar.set_label(rf"Counts [Log]")
_show_or_close()

# Counts: track total charge vs light
hist4 = plt.hist2d(charge, light, bins=[20, 20], norm=LogNorm())
plt.xlabel(f"Track total charge [{q_unit}]")
plt.ylabel(f"Light_{light_variable}")
cbar = plt.colorbar(hist4[3])
cbar.set_label(rf"Counts [Log]")
_show_or_close()

# Light waveforms

In [None]:
# Plot the raw waveform of every SiPM channel for the events in `individual_plots`.
# This only runs when the light data has an "rwf" column.
if "rwf" in light_df.columns:
    for event_idx in tqdm(individual_plots, leave=True):
        # Light event matched to this charge event (same number if there is no match entry)
        matched_light_event = match_dict.get(event_idx, [event_idx])[0]
        values = light_df.loc[light_df["event"] == matched_light_event].sort_values(
            by=["sn", "ch"]
        )
        for row_idx, row in values.iterrows():
            sn_idx, ch_idx, rwf = row["sn"], row["ch"], row["rwf"]

            fig = plt.figure()
            ax = plt.subplot(111)
            fig.suptitle(f"Event {event_idx} - ADC {sn_idx} - Channel {ch_idx}")

            # Draw all samples but the last; the dashed line marks the median of the full waveform.
            ax.plot(rwf[:-1])
            ax.axhline(np.median(rwf), ls="--", c="red")

            ax.set_xlabel("Time [Arbitrary units]")
            ax.set_ylabel(f"Light waveform [{light_unit}]")

            fig.tight_layout()

            if not show_figures:
                plt.close()
            else:
                plt.show()