# Imports

In [None]:
import glob
import json
import os
import pickle
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pylandau
from itables import init_notebook_mode
from matplotlib.colors import LinearSegmentedColormap, LogNorm, to_rgba
from matplotlib.ticker import AutoMinorLocator, MaxNLocator, ScalarFormatter
from scipy.optimize import curve_fit
from sklearn.cluster import DBSCAN, KMeans
from skspatial.objects import Cylinder, Line, Plane, Point, Triangle
from tqdm.auto import tqdm

# Configure itables notebook rendering with all_interactive=True for this session.
init_notebook_mode(all_interactive=True)

# Parameters

README 

- `non_track_keys = 5` must be updated whenever new non-track (string) keys, such as `Total_light` or `SiPM`, are added to the per-event entries in the metrics.

In [None]:
# Pickled metrics dictionary to analyse.
metrics_file = "combined/metrics.pkl"

# Saved filter-parameter JSON (e.g. "combined/filter_parameters_27786.json").
# When set (not None) and the file exists, it overrides every filter variable.
filter_file = None

In [None]:
# Save options
save_figures = True  # write figures as PDFs under file_label

# Plotting options
individual_plots = ["20230706_191437_7932"]  # [f"{i}" for i in np.arange(1, 101, 1)]
show_figures = True
label_font_size = 16
tick_font_size = 16
title_font_size = 18

# Light variable to consider
light_variable = "integral"

# Units for plot labels
q_unit = "e"  # After applying charge_gain
xy_unit = "mm"
z_unit = "mm"
# Track-direction unit: only well defined when z and xy share a unit.
dh_unit = "?" if z_unit != xy_unit else xy_unit
time_unit = "ns"
light_unit = "p.e." if light_variable == "integral" else "p.e./time bin"

# Conversion factors
# charge_gain = 245  # mV to e
detector_z = 300
detector_x = 128
detector_y = 160
quadrant_size = 32  # One SiPM + LArPix cell

# Filters for post processing if not using filter parameters file
score_cutoff = -1.0
max_score = 1.0
min_track_length = 160
max_track_length = np.inf
max_tracks = 1
max_light = np.inf
min_light = 0
max_z = np.inf

# Other
# Number of non-track (string) keys per event entry in the metrics.
non_track_keys = 5

In [None]:
# Label (and output directory) derived from the directory part of
# metrics_file, e.g. "a/b/metrics.pkl" -> "a_b". Only the file name's
# extension is ignored, so dots in directory names are preserved;
# "." and ".." components are dropped. Falls back to "combined".
file_label = "_".join(
    part
    for part in os.path.normpath(os.path.dirname(metrics_file)).split(os.sep)
    if part not in (".", "..")
)
if file_label == "":
    file_label = "combined"

# Functions

## Helpers

In [None]:
def get_track_stats(metrics, empty_ratio_lims=(0, 1), min_entries=2):
    """Collect per-track dQ/dx profiles and fit summaries into a DataFrame.

    Tracks are the positive, non-string keys of each event entry. A track is
    rejected when it has fewer than ``min_entries`` positive dQ bins, or when
    the fraction of zero bins between its first and last positive bin falls
    outside ``empty_ratio_lims`` (inclusive). Accepted tracks with any NaN in
    dQ/dx, length, score or z are dropped afterwards.

    Parameters
    ----------
    metrics : dict
        event id -> {track id or str key -> values}. Track values provide
        "dQ" (array), "dx", "Fit_line" (with ``to_point(t=...)`` and
        ``point``), "Fit_norm" and "RANSAC_score".
    empty_ratio_lims : tuple
        ``(low, high)`` bounds for the accepted empty-bin fraction.
    min_entries : int
        Minimum number of positive dQ bins.

    Returns
    -------
    pandas.DataFrame
        One row per track with object-dtype columns "track_dQdx" and
        "track_points" (Series indexed by distance along the trimmed track),
        "track_length", "track_score", "track_z" and "event".
    """
    collected = {
        "track_dQdx": [],
        "track_points": [],
        "track_length": [],
        "track_score": [],
        "track_z": [],
        "event": [],
    }
    rejected_empty = 0
    rejected_short = 0

    for event_id, entry in metrics.items():
        for track_id, track in entry.items():
            # String keys hold event-level data; only positive ids are tracks.
            if isinstance(track_id, str) or track_id <= 0:
                continue

            dQ = track["dQ"]
            step = track["dx"]
            hits = np.where(dQ > 0)[0]

            if len(hits) < min_entries:
                rejected_short += 1
                continue

            first, last = hits[0], hits[-1]
            span = dQ[first : last + 1]
            empty_ratio = np.count_nonzero(span == 0) / (last - first + 1)
            if empty_ratio > empty_ratio_lims[1] or empty_ratio < empty_ratio_lims[0]:
                rejected_empty += 1
                continue

            dQdx = span / step
            distance = np.arange(0, len(dQdx) * step, step)[: len(dQdx)]
            n_bins = len(dQ)
            # Bin centres along the fitted line, centred on the full profile.
            points = [
                track["Fit_line"].to_point(t=-(n_bins / 2) * step + t * step - step / 2)
                for t in range(first, last + 1)
            ]

            collected["track_dQdx"].append(pd.Series(dQdx, index=distance, name="dQdx"))
            collected["track_points"].append(
                pd.Series(points, index=distance, name="position")
            )
            collected["track_length"].append(track["Fit_norm"])
            collected["track_score"].append(track["RANSAC_score"])
            collected["track_z"].append(track["Fit_line"].point[2])
            collected["event"].append(event_id)

    print(f"Tracks with dead area outside {empty_ratio_lims} interval: {rejected_empty}")
    print(f"Tracks with less than {min_entries} entries: {rejected_short}")

    series = {name: pd.Series(values) for name, values in collected.items()}
    keep = (
        series["track_dQdx"].apply(lambda profile: profile.notna().all())
        & series["track_length"].notna()
        & series["track_score"].notna()
        & series["track_z"].notna()
    )

    print(f"Remaining tracks: {sum(keep)}")

    rows = [series[name][keep] for name in collected]
    return pd.DataFrame(rows, index=list(collected)).T

In [None]:
def max_std(array, ax=None, array_max=None, min_count_ratio=0.9, max_std_ratio=0.5):
    """Choose an upper cut on ``array`` that keeps most entries but tames the spread.

    For every integer threshold ``i`` from ``int(min(array))`` to
    ``int(array_max)`` (inclusive), the entries ``array < i`` are kept. Their
    standard deviation and count are compared to those of the full array.
    The result is the largest ``i`` that satisfies both conditions below
    (both ratios rounded to 3 decimals):

    - the kept fraction is at least ``min_count_ratio``;
    - the std ratio is at most ``max_std_ratio``.

    If no threshold qualifies, the first ``i`` whose kept fraction exceeds
    ``min_count_ratio`` is returned. If none exceeds it, the first threshold
    is returned.

    Parameters
    ----------
    array : array-like
        1D numeric values.
    ax : matplotlib Axes, optional
        If given, the std and count ratios are drawn against the threshold
        together with a vertical line at the chosen cut.
    array_max : float, optional
        Largest threshold to scan. Defaults to the 99th percentile, or the
        maximum when ``min_count_ratio >= 0.99``.
    min_count_ratio, max_std_ratio : float
        Selection criteria described above.

    Returns
    -------
    int
        The selected threshold.

    Raises
    ------
    ValueError
        If ``array_max`` is below ``int(min(array))`` (no threshold to scan).
    """
    array = np.asarray(array)
    full_std = array.std()
    full_count = len(array)
    if array_max is None:
        array_max = np.percentile(array, 99 + (min_count_ratio >= 0.99))

    thresholds = np.arange(int(min(array)), int(array_max) + 1)
    if thresholds.size == 0:
        raise ValueError(
            f"array_max ({array_max}) is below the array minimum ({min(array)})"
        )

    std = np.empty(thresholds.size)
    count = np.empty(thresholds.size, dtype=int)
    for k, threshold in enumerate(thresholds):
        cut = array[array < threshold]
        count[k] = cut.size
        # The std of an empty cut is NaN; computing it explicitly avoids
        # numpy's "mean of empty slice" RuntimeWarning.
        std[k] = cut.std() if cut.size else np.nan

    std_ratio = std / full_std
    count_ratio = count / full_count
    condition = (count_ratio.round(3) >= min_count_ratio) & (
        std_ratio.round(3) <= max_std_ratio
    )
    if np.any(condition):
        vline = int(thresholds[np.where(condition)[0][-1]])
    else:
        vline = int(thresholds[(count_ratio > min_count_ratio).argmax()])

    print(
        "Max STD ratio",
        max_std_ratio,
        "limited to",
        min_count_ratio * 100,
        "% of events:",
        vline,
        "\n",
    )
    if ax is not None:
        # Plot against the threshold values (not the array index) so the
        # curves line up with the vertical line drawn at `vline`.
        ax.plot(thresholds, std_ratio, label="STD ratio")
        ax.plot(thresholds, count_ratio, label="Event count ratio")
        ax.axvline(vline, ls="--", c="r", label=f"{min_count_ratio*100}% of events")
        ax.legend()
        ax.tick_params(
            axis="both", direction="inout", which="major", top=True, right=True
        )
        ax.xaxis.set_minor_locator(AutoMinorLocator())
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_minor_locator(AutoMinorLocator())
        ax.grid(alpha=0.25)

    return vline

In [None]:
def cluster_hot_bins(min_n_ratio, n, x_edges, y_edges, scale=(1, 1), eps=8):
    """Cluster the centres of the "hot" bins of a 2D histogram with DBSCAN.

    A bin is hot when its content is strictly above ``min_n_ratio`` times the
    maximum bin content (NaN bins count as empty). The input ``n`` is not
    modified.

    Parameters
    ----------
    min_n_ratio : float
        Fraction of the maximum bin content used as the hot-bin threshold.
    n : numpy.ndarray
        Bin contents indexed ``[x_bin, y_bin]``, as returned by ``hist2d``.
    x_edges, y_edges : numpy.ndarray
        Bin edges along x and y.
    scale : tuple of two floats
        Divisors applied to the x/y centres before clustering.
    eps : float
        DBSCAN neighbourhood radius (Chebyshev metric, in scaled units).

    Returns
    -------
    tuple
        ``(centers_x, centers_y, labels)`` for the hot bins. ``labels`` holds
        the DBSCAN cluster labels (-1 is noise). All three are empty when
        there are no hot bins.
    """
    counts = np.asarray(n, dtype=float)
    # Work on a copy so the caller's histogram keeps its NaN (below cmin) bins.
    counts = np.where(np.isnan(counts), 0.0, counts)
    min_n = min_n_ratio * counts.max()

    x_edges = np.asarray(x_edges)
    y_edges = np.asarray(y_edges)
    bin_centers_x = 0.5 * (x_edges[1:] + x_edges[:-1])
    bin_centers_y = 0.5 * (y_edges[1:] + y_edges[:-1])
    # meshgrid yields [y, x]-shaped grids, hence the transpose of counts.
    grid_x, grid_y = np.meshgrid(bin_centers_x, bin_centers_y)
    hot = counts.T > min_n
    filtered_centers_x = grid_x[hot]
    filtered_centers_y = grid_y[hot]

    if filtered_centers_x.size == 0:
        # DBSCAN cannot fit zero samples (and min_samples would be 0).
        return filtered_centers_x, filtered_centers_y, np.array([], dtype=int)

    dbscan = DBSCAN(
        eps=eps,
        min_samples=int(np.sqrt(filtered_centers_y.size)),
        metric="chebyshev",
    ).fit(np.c_[filtered_centers_x / scale[0], filtered_centers_y / scale[1]])
    return filtered_centers_x, filtered_centers_y, dbscan.labels_

In [None]:
def filter_metrics(
    metrics,
    min_score=0.5,
    max_score=np.inf,
    min_track_length=160,
    max_track_length=np.inf,
    max_tracks=1,
    min_light=0,
    max_light=100,
    max_z=300,
):
    """Select events and tracks passing light, score, length and z cuts.

    An event is considered when its entry has at most
    ``max_tracks + non_track_keys`` keys and ``min_light <= Total_light <=
    max_light``. String keys are always kept. Positive track ids are kept
    when all of the following hold:

    - ``min_score <= RANSAC_score <= max_score``;
    - ``min_track_length <= Fit_norm <= max_track_length``;
    - the fit line's point z is ``< max_z``.

    The event is kept when the number of keys left is in
    ``(non_track_keys, max_tracks + non_track_keys]``. This presumably means
    1..max_tracks tracks, assuming each event has exactly ``non_track_keys``
    string keys.

    The cut values are printed and written to
    ``{file_label}/filter_parameters_{n_events}.json``. The directory is
    created if needed.

    Returns
    -------
    dict
        event id -> filtered entry.
    """
    parameters = {
        "min_score": min_score,
        "max_score": max_score,
        "min_track_length": min_track_length,
        "max_track_length": max_track_length,
        "max_tracks": max_tracks,
        "min_light": min_light,
        "max_light": max_light,
        "max_z": max_z,
    }
    for name, value in parameters.items():
        print(f"{name} = {value}")

    def passes_track_cuts(values):
        # Score, length and anode-distance cuts for a single track.
        return (
            min_score <= values["RANSAC_score"] <= max_score
            and min_track_length <= values["Fit_norm"] <= max_track_length
            and values["Fit_line"].point[2] < max_z
        )

    filtered_metrics = {}

    for event_idx, metric in metrics.items():
        # Cheap event-level pre-selection before looking at individual tracks.
        if not (
            len(metric) <= max_tracks + non_track_keys
            and min_light <= metric["Total_light"] <= max_light
        ):
            continue

        candidate_metric = {
            track_idx: values
            for track_idx, values in metric.items()
            if isinstance(track_idx, str)
            or (track_idx > 0 and passes_track_cuts(values))
        }
        if non_track_keys < len(candidate_metric) <= max_tracks + non_track_keys:
            filtered_metrics[event_idx] = candidate_metric

    print(f"{len(filtered_metrics)} metrics remaining")

    # file_label may name a directory that does not exist yet (e.g. "a_b").
    os.makedirs(file_label, exist_ok=True)
    with open(
        f"{file_label}/filter_parameters_{len(filtered_metrics)}.json",
        "w",
        encoding="utf-8",
    ) as f:
        json.dump(parameters, f)

    return filtered_metrics

In [None]:
def combine_metrics():
    """Merge the per-folder metrics pickles into ``combined/metrics.pkl``.

    Every file matching ``*/*metrics*.pkl`` is loaded. Its entries are
    re-keyed as ``f"{folder}_{key}"``, where ``folder`` is the first path
    component. The glob has no ``recursive=True``, so it matches one
    directory level only. The ``combined`` folder itself is skipped, so
    re-running does not fold a previous combined output back in.

    Returns
    -------
    dict
        The combined metrics that were written.
    """
    output_dir = "combined"
    combined_metrics = {}

    def top_folder(path):
        # First path component, independent of the OS path separator.
        return os.path.normpath(path).split(os.sep)[0]

    files = [
        file
        for file in glob.glob("**/*metrics*.pkl")
        if top_folder(file) != output_dir
    ]

    for file in tqdm(files, leave=True):
        folder = top_folder(file)
        tqdm.write(folder)
        # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
        with open(file, "rb") as f:
            metric = pickle.load(f)
        for key, value in tqdm(metric.items(), leave=False):
            combined_metrics[f"{folder}_{key}"] = value

    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "metrics.pkl"), "wb") as o:
        pickle.dump(combined_metrics, o)

    print("Done")

    return combined_metrics

## Plotting

In [None]:
class OOMFormatter(ScalarFormatter):
    """ScalarFormatter with a fixed order of magnitude and number format.

    Parameters
    ----------
    order : int
        Exponent always used for the axis offset (e.g. 3 for x10^3).
    fformat : str
        %-style format applied to the scaled tick values.
    offset, mathText : bool
        Forwarded to ``ScalarFormatter`` as ``useOffset`` / ``useMathText``.
    """

    def __init__(self, order=0, fformat="%1.1f", offset=True, mathText=True):
        self.oom = order
        self.fformat = fformat
        super().__init__(useOffset=offset, useMathText=mathText)

    def _set_order_of_magnitude(self):
        # Ignore the data range; always use the configured exponent.
        self.orderOfMagnitude = self.oom

    def _set_format(self, vmin=None, vmax=None):
        tick_format = self.fformat
        if self._useMathText:
            tick_format = r"$\mathdefault{%s}$" % tick_format
        self.format = tick_format


def set_common_ax_options(ax=None, cbar=None):
    """Apply the notebook-wide styling to an axes and/or a colorbar, in place.

    For ``ax``:

    - major ticks on all four sides (``direction="inout"``);
    - a faint grid drawn below the artists;
    - title and label font sizes from the module-level settings;
    - for each non-log axis, minor ticks, plus integer major ticks and an
      ``OOMFormatter(3)`` once the upper limit is large enough.

    For ``cbar``, the tick and label font sizes are set. Pass None to skip
    either one.
    """

    def style_linear_axis(axis, get_limits):
        # Locators/formatter for one axis, driven by its upper view limit.
        axis.set_minor_locator(AutoMinorLocator())
        if get_limits()[1] <= 1.1:
            return
        axis.set_major_locator(MaxNLocator(integer=(get_limits()[1] > 2)))
        if get_limits()[1] > 1e3:
            axis.set_major_formatter(OOMFormatter(3, "%1.1f"))

    if ax is not None:
        ax.tick_params(
            axis="both",
            direction="inout",
            which="major",
            top=True,
            right=True,
            labelsize=tick_font_size,
        )
        ax.set_axisbelow(True)
        ax.grid(alpha=0.25)
        ax.set_title(ax.get_title(), fontsize=title_font_size)
        ax.set_ylabel(ax.get_ylabel(), fontsize=label_font_size)
        ax.set_xlabel(ax.get_xlabel(), fontsize=label_font_size)
        if hasattr(ax, "get_zlabel"):
            ax.set_zlabel(ax.get_zlabel(), fontsize=label_font_size)

        if ax.get_xscale() != "log":
            style_linear_axis(ax.xaxis, ax.get_xlim)
        if ax.get_yscale() != "log":
            style_linear_axis(ax.yaxis, ax.get_ylim)

    if cbar is not None:
        cbar.ax.tick_params(labelsize=tick_font_size)
        cbar.set_label(cbar.ax.get_ylabel(), fontsize=label_font_size)

### Tracks

In [None]:
# Plot dQ versus X
def plot_dQ(dQ_array, event_idx, track_idx, dh, interpolate=False):
    """Plot dQ/dx and the cumulative charge along a single track.

    The profile is trimmed to the span of positive bins plus two bins of
    padding on each side (clipped at the array ends). The mean over positive
    bins is drawn as a dashed reference line. The figure is saved as a PDF
    under ``{file_label}/{event_idx}/`` when ``save_figures`` is set.

    Parameters
    ----------
    dQ_array : numpy.ndarray
        Charge per bin along the track; bins have width ``dh``.
    event_idx, track_idx
        Identifiers used in the title and the output file name.
    dh : float
        Bin width along the track, in ``dh_unit``.
    interpolate : bool
        Replace empty (zero) bins between the first and last positive bin
        with the mean positive dQ. The padding bins are left untouched.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax_twinx = ax.twinx()

    fig.suptitle(
        rf"Event {event_idx} - Track {track_idx} - $dx = {round(dh,2)}$ {dh_unit}"
    )

    non_zero_indices = np.where(dQ_array > 0)[0]
    # Mean over positive bins only; NaN (without numpy's empty-mean warning)
    # when the track has no charge at all.
    mean_dQ = (
        dQ_array[non_zero_indices].mean() if non_zero_indices.size > 0 else np.nan
    )

    if non_zero_indices.size > 0:
        pad = 2
        first_hit, last_hit = non_zero_indices[0], non_zero_indices[-1]
        first_index = max(first_hit - pad, 0)
        # Slice ends are exclusive: +1 keeps `pad` bins after the last hit.
        last_index = min(last_hit + pad + 1, len(dQ_array))

        new_dQ_array = dQ_array[first_index:last_index].copy()

        if interpolate:
            # Fill gaps inside the track only, not the padding around it.
            start = first_hit - first_index
            stop = last_hit - first_index + 1
            inner = new_dQ_array[start:stop]
            new_dQ_array[start:stop] = np.where(inner == 0, mean_dQ, inner)

        dQ_array = new_dQ_array

    ax.axhline(
        mean_dQ / dh,
        ls="--",
        c="red",
        label=rf"Mean = ${round(mean_dQ/dh,2)}$ {q_unit} {dh_unit}$^{{-1}}$",
        lw=1,
    )
    x_range = np.arange(0, len(dQ_array) * dh, dh)[: len(dQ_array)]

    ax.step(x_range, dQ_array / dh, where="mid")
    ax.set_xlabel(rf"$x$ [{dh_unit}]")
    ax.set_ylabel(rf"$dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]")

    ax_twinx.step(x_range, np.cumsum(dQ_array), color="C1", where="mid")
    ax_twinx.set_ylabel(f"Q [{q_unit}]")

    for axes in [ax, ax_twinx]:
        set_common_ax_options(ax=axes)

    # The twin axes is drawn on top, so the legend with ax's handles is also
    # placed on it.
    h1, l1 = ax.get_legend_handles_labels()
    ax_twinx.legend(h1, l1, loc="lower center")

    ax.legend(loc="lower center")

    fig.tight_layout()
    if save_figures:
        os.makedirs(f"{file_label}/{event_idx}", exist_ok=True)
        fig.savefig(
            f"{file_label}/{event_idx}/dQ_E{event_idx}_T{track_idx}_{round(dh,2)}.pdf",
            dpi=300,
            bbox_inches="tight",
        )

In [None]:
def plot_track_stats(
    metrics,
    limit_xrange=True,
    min_score=0.5,
    empty_ratio_lims=(0, 1),
    min_entries=2,
    lognorm=True,
    profile=False,
    bins=(40, 40),
):
    """Plot 1D/2D summary statistics of reconstructed charge tracks.

    Track quantities come from ``get_track_stats`` and are shown as:

    - dQ/dx and track-length histograms, with a ``pylandau.langau`` fit of
      the dQ/dx histogram;
    - mean dQ/dx and dQ/dx CV vs. track length;
    - fit score vs. track length;
    - dQ/dx vs. residual range;
    - mean dQ/dx vs. mean anode distance.

    When some tracks fall below ``min_score``, the 2D views are repeated for
    the tracks that pass the score cut. Figures are saved as PDFs under
    ``file_label`` when ``save_figures`` is set.

    Parameters
    ----------
    metrics : dict
        Per-event metrics, forwarded to ``get_track_stats``.
    limit_xrange : bool
        Clip the dQ/dx histogram at its 99th percentile, and the x-axis of the
        other panels at the maximum possible track length + 10.
    min_score : float
        RANSAC score cut used for the "after cut" panels.
    empty_ratio_lims, min_entries
        Forwarded to ``get_track_stats``.
    lognorm : bool
        Use a logarithmic colour scale for the 2D histograms.
    profile : bool
        Draw per-x-bin mean +- std profiles instead of 2D histograms.
    bins : sequence of two ints
        Histogram binning; ``bins[0]`` is used for the 1D histograms.

    Returns
    -------
    pandas.DataFrame
        The track table returned by ``get_track_stats``.
    """
    df = get_track_stats(
        metrics, empty_ratio_lims=empty_ratio_lims, min_entries=min_entries
    )
    if df.empty:
        # Every plot below needs at least one track.
        print("No tracks to plot")
        return df

    track_dQdx = df["track_dQdx"]
    track_length = df["track_length"].astype(float)
    track_score = df["track_score"].astype(float)
    track_z = df["track_z"].astype(float)
    track_cv_dQdx = track_dQdx.apply(lambda x: x.std() / x.mean()).astype(float)
    track_mean_dQdx = track_dQdx.apply(lambda x: x.mean()).astype(float)

    score_mask = (track_score >= min_score).to_numpy()
    # True when at least one track fails the score cut, i.e. when the
    # "after cut" panels differ from the "all tracks" ones.
    score_bool = not score_mask.all()

    print(f"Tracks with score < {min_score}: {len(track_dQdx)-sum(score_mask)}")

    print(f"Remaining tracks: {sum(score_mask)}")

    def concat_positive(series_list):
        # Merge per-track dQ/dx series, keeping positive, non-NaN bins sorted
        # by residual range. An empty input gives an empty series instead of
        # letting pd.concat raise.
        if not series_list:
            return pd.Series(dtype=float)
        combined = pd.concat(series_list)
        return combined[combined > 0].dropna().sort_index()

    dQdx_series = concat_positive(track_dQdx.to_list())
    cut_dQdx_series = concat_positive(track_dQdx[score_mask].to_list())

    print("dQ/dx stats:")
    display(dQdx_series.describe())

    # 1D histograms
    fig1 = plt.figure(figsize=(14, 6))

    ax11 = fig1.add_subplot(121)
    ax12 = fig1.add_subplot(122)

    limit = np.percentile(dQdx_series.values, 99) if limit_xrange else np.inf

    n_all11, bins_all11, _ = ax11.hist(
        dQdx_series[dQdx_series <= limit].values, bins=bins[0], label="All tracks"
    )

    n_all12, bins_all12, _ = ax12.hist(
        track_length, bins=bins[0], label="All tracks"
    )

    if score_bool:
        ax11.hist(
            cut_dQdx_series[cut_dQdx_series <= limit].values,
            bins=bins_all11,
            label=rf"Score $\geq {min_score}$",
        )
        ax12.hist(
            track_length[score_mask],
            bins=bins_all12,
            label=rf"Score $\geq {min_score}$",
        )

    bin_centers_all11 = (bins_all11[1:] + bins_all11[:-1]) / 2
    p0 = (
        np.median(cut_dQdx_series),
        np.std(bin_centers_all11),
        np.std(bin_centers_all11),
        sum(n_all11),
    )

    try:
        popt, pcov = curve_fit(
            pylandau.langau,
            bin_centers_all11,
            n_all11,
            absolute_sigma=True,
            p0=p0,
            bounds=(0, np.inf),
        )
    except (RuntimeError, ValueError) as err:
        # curve_fit raises RuntimeError on non-convergence and ValueError on
        # invalid input (e.g. a NaN start value); skip the fit curve only.
        print(f"dQ/dx langau fit failed: {err}")
    else:
        fit_x = np.linspace(bins_all11[0], bins_all11[-1], 1000)
        ax11.plot(
            fit_x,
            pylandau.langau(fit_x, *popt),
            "r-",
            label=r"fit: $\mu$=%5.1f, $\eta$=%5.1f, $\sigma$=%5.1f, A=%5.1f"
            % tuple(popt),
        )

    ax11.set_xlabel(rf"$dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]", fontsize=label_font_size)
    ax11.set_title(f"{len(track_dQdx)} tracks", fontsize=title_font_size)
    ax12.set_title(f"{len(track_length)} tracks", fontsize=title_font_size)

    # 2D histograms
    def hist2d(x, y, ax, bins, lognorm, fit="Log", profile=False):
        # Draw y vs. x on `ax`, as a 2D histogram or (profile=True) as per-x-bin
        # mean +- std error bars. Then overlay a first-order polynomial fit
        # in log(x) ("Log") or x ("Linear"); any other `fit` skips the fit.
        if profile:
            _, x_edges, _ = np.histogram2d(x, y, bins=bins)

            y_means = [
                np.mean(y[(x >= x_edges[i]) & (x < x_edges[i + 1])])
                for i in range(len(x_edges) - 1)
            ]
            y_stds = [
                np.std(y[(x >= x_edges[i]) & (x < x_edges[i + 1])])
                for i in range(len(x_edges) - 1)
            ]
            x_values = (x_edges[1:] + x_edges[:-1]) / 2
            bin_widths = [
                (x_edges[i + 1] - x_edges[i]) / 2 for i in range(len(x_edges) - 1)
            ]
            ax.errorbar(x_values, y_means, yerr=y_stds, xerr=bin_widths, fmt="o")

        else:
            ax.hist2d(
                x,
                y,
                bins=bins,
                cmin=1,
                norm=LogNorm() if lognorm else None,
            )
        if fit == "Log":
            x_fit = np.log(x)
        elif fit == "Linear":
            x_fit = x
        else:
            return

        try:
            fit_p = np.polyfit(x_fit, y, 1)
        except Exception:
            # Best effort: retry with the other x transform, then give up.
            if fit == "Log":
                x_fit = x
                fit = "Linear"
            elif fit == "Linear":
                x_fit = np.log(x)
                fit = "Log"
            try:
                fit_p = np.polyfit(x_fit, y, 1)
            except Exception:
                return

        p = np.poly1d(fit_p)
        x_plot = np.arange(min(x), max(x), 1)

        if fit == "Log":
            y_plot = p(np.log(x_plot))
        else:
            y_plot = p(x_plot)

        ax.plot(x_plot, y_plot, c="salmon", ls="-", label=f"{fit} fit")

    fig2 = plt.figure(figsize=(14, 6))
    ax21 = fig2.add_subplot(121)
    ax22 = fig2.add_subplot(122)

    fig2.suptitle(f"{len(track_dQdx)} tracks", fontsize=title_font_size)
    ax21.set_ylabel(
        rf"Mean $dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]", fontsize=label_font_size
    )
    ax21.set_title("Mean dQ/dx vs. Track length", fontsize=title_font_size)
    ax22.set_ylabel(r"$dQ/dx$ CV", fontsize=label_font_size)
    ax22.set_title("dQ/dx CV vs. Track length", fontsize=title_font_size)

    hist2d(
        track_length, track_mean_dQdx, ax21, bins, lognorm, fit="Log", profile=profile
    )

    hist2d(
        track_length, track_cv_dQdx, ax22, bins, lognorm, fit="Linear", profile=profile
    )

    fig4 = plt.figure(figsize=(7, 6))
    ax4 = fig4.add_subplot(111)
    ax4.set_ylabel("Fit score")
    ax4.set_title("Fit score vs. Track length")

    hist2d(
        track_length,
        track_score,
        ax4,
        [bins[0], 40],
        lognorm,
        fit="Log",
        profile=profile,
    )

    fig5 = plt.figure(figsize=(7 + 7 * score_bool, 6))
    ax51 = fig5.add_subplot(111 + 10 * score_bool)
    ax51.set_ylabel(rf"$dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]", fontsize=label_font_size)
    ax51.set_xlabel(rf"Residual range [{dh_unit}]", fontsize=label_font_size)
    ax51.set_title(rf"{len(track_dQdx)} tracks", fontsize=title_font_size)

    hist2d(
        dQdx_series.index,
        dQdx_series,
        ax51,
        bins,
        lognorm,
        fit="Linear",
        profile=profile,
    )

    fig6 = plt.figure(figsize=(7 + 7 * score_bool, 6))
    ax61 = fig6.add_subplot(111 + 10 * score_bool)
    ax61.set_ylabel(
        rf"Mean $dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]", fontsize=label_font_size
    )
    ax61.set_xlabel(rf"Mean anode distance [{z_unit}]", fontsize=label_font_size)
    ax61.set_title(rf"{len(track_z)} tracks", fontsize=title_font_size)

    hist2d(track_z, track_mean_dQdx, ax61, bins, lognorm, fit="Linear", profile=profile)

    # fig7 = plt.figure(figsize=(7 + 7 * score_bool, 6))
    # ax71 = fig7.add_subplot(111 + 10 * score_bool)
    # ax71.set_ylabel(rf"$dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]", fontsize=label_size)
    # ax71.set_xlabel(rf"Anode distance [{z_unit}]", fontsize=label_size)
    # ax71.set_title(rf"{len(track_z)} tracks", fontsize=title_size)

    # dq_z_series = pd.concat(dq_z_list)
    # dq_z_series = dq_z_series[dq_z_series > 0].dropna().sort_index()

    # hist2d(
    #     dq_z_series.index,
    #     dq_z_series,
    #     ax71,
    #     bins,
    #     lognorm,
    #     fit="Linear",
    #     profile=profile,
    # )

    # def exp_decay(x, tau, init):
    #     return init * np.exp(-x / tau)

    # popt, pcov = curve_fit(
    #     exp_decay,
    #     track_z,
    #     track_mean_dQdx,
    #     p0=[23, 5000],
    # )

    # plt.plot(
    #     track_z,
    #     exp_decay(track_z, *popt),
    #     "r-",
    #     label="fit: tau=%5.3f, init=%5.3f" % tuple(popt),
    # )

    # print(popt)

    axes = [ax11, ax12, ax21, ax22, ax4, ax51, ax61]  # , ax71]
    figs = [fig1, fig2, fig4, fig5, fig6]  # , fig7]

    if score_bool:
        # 2D histograms after RANSAC score cut
        fig3 = plt.figure(figsize=(14, 6))
        ax31 = fig3.add_subplot(121)
        ax32 = fig3.add_subplot(122)
        ax31.set_ylabel(rf"Mean $dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]")
        ax31.set_title("Mean dQ/dx vs. Track length")
        ax32.set_ylabel(r"$dQ/dx$ CV")
        ax32.set_title("dQ/dx CV vs. Track length")
        fig3.suptitle(
            rf"Fit score $\geq {min_score}$ ({round(sum(score_mask)/len(score_mask)*100)}% of tracks)",
            fontsize=title_font_size,
        )

        figs.append(fig3)
        axes.extend([ax31, ax32])

        hist2d(
            track_length[score_mask],
            track_mean_dQdx[score_mask],
            ax31,
            bins,
            lognorm,
            fit="Log",
            profile=profile,
        )

        hist2d(
            track_length[score_mask],
            track_cv_dQdx[score_mask],
            ax32,
            bins,
            lognorm,
            fit="Linear",
            profile=profile,
        )

        ax52 = fig5.add_subplot(122)
        axes.append(ax52)
        ax52.set_ylabel(
            rf"$dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]", fontsize=label_font_size
        )
        ax52.set_xlabel(rf"Residual range [{dh_unit}]", fontsize=label_font_size)
        ax52.set_title(
            rf"Fit score $\geq {min_score}$ ({round(sum(score_mask)/len(score_mask)*100)}% of tracks)",
            fontsize=title_font_size,
        )
        fig5.suptitle("dQ/dx vs. Residual range", fontsize=title_font_size)

        hist2d(
            cut_dQdx_series.index,
            cut_dQdx_series,
            ax52,
            bins,
            lognorm,
            fit="Linear",
            profile=profile,
        )

        ax62 = fig6.add_subplot(122)
        axes.append(ax62)
        ax62.set_ylabel(
            rf"Mean $dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]", fontsize=label_font_size
        )
        ax62.set_xlabel(rf"Mean anode distance [{z_unit}]", fontsize=label_font_size)
        ax62.set_title(
            rf"Fit score $\geq {min_score}$ ({round(sum(score_mask)/len(score_mask)*100)}% of tracks)",
            fontsize=title_font_size,
        )
        fig6.suptitle("Mean dQ/dx vs. Mean anode distance", fontsize=title_font_size)

        hist2d(
            track_z[score_mask],
            track_mean_dQdx[score_mask],
            ax62,
            bins,
            lognorm,
            fit="Linear",
            profile=profile,
        )

        # ax72 = fig7.add_subplot(122)
        # axes.append(ax72)
        # ax72.set_ylabel(rf"$dQ/dx$ [{q_unit} {dh_unit}$^{{-1}}$]", fontsize=label_size)
        # ax72.set_xlabel(rf"Anode distance [{z_unit}]", fontsize=label_size)
        # ax72.set_title(
        #     rf"Fit score $\geq {min_score}$ ({round(sum(score_mask)/len(score_mask)*100)}% of tracks)", fontsize=title_size
        # )
        # fig7.suptitle("dQ/dx vs. Anode distance", fontsize=title_size)

        # cut_dq_z_series = pd.concat(
        #     [series for i, series in enumerate(dq_z_list) if score_mask[i]]
        # )
        # cut_dq_z_series = cut_dq_z_series[cut_dq_z_series > 0].dropna().sort_index()

        # hist2d(
        #     cut_dq_z_series.index,
        #     cut_dq_z_series,
        #     ax72,
        #     bins,
        #     lognorm,
        #     fit="Linear",
        #     profile=profile,
        # )

    max_track_legth = np.sqrt(detector_x**2 + detector_y**2 + detector_z**2)
    max_track_legth_xy = np.sqrt(detector_x**2 + detector_y**2)
    print("Max possible track length", round(max_track_legth, 2), "mm")
    print("Max possible track length on xy plane", round(max_track_legth_xy, 2), "mm")
    print("Max possible vertical track length", detector_y, "mm")

    longest_track = max(track_length)
    for ax in axes:
        if ax == ax11 or ax == ax12:
            ax.set_ylabel("Counts")
        if ax != ax11:
            if not (
                ax == ax51
                or ax == ax61
                # or ax == ax71
                or (score_bool and (ax == ax52 or ax == ax62))  # or ax == ax72))
            ):
                ax.set_xlabel(f"Track length [{dh_unit}]")
            if longest_track > detector_y:
                ax.axvline(detector_y, c="g", ls="--", label="Max vertical length")
            if longest_track > max_track_legth_xy:
                ax.axvline(
                    max_track_legth_xy, c="orange", ls="--", label=r"Max length in $xy$"
                )
            if longest_track > max_track_legth:
                ax.axvline(max_track_legth, c="r", ls="--", label="Max length")

            if ax != ax12:
                if limit_xrange:
                    xlim = ax.get_xlim()
                    ax.set_xlim(xlim[0], min(max_track_legth + 10, xlim[1]))

                # Profile panels hold error-bar LineCollections rather than a
                # count mesh, so a "Counts" colorbar would be meaningless.
                if not profile:
                    cbar = ax.get_figure().colorbar(ax.collections[0])
                    cbar.set_label("Counts" + (" [log]" if lognorm else ""))
                    set_common_ax_options(cbar=cbar)
        if score_bool or ax != ax11:
            ax.legend(loc="lower right" if ax == ax4 else "upper right")

        set_common_ax_options(ax=ax)

    for fig in figs:
        fig.tight_layout()

    if save_figures:
        entries = len(track_dQdx)
        fig1.savefig(
            f"{file_label}/track_stats_1D_hist_{file_label}_{entries}.pdf",
            dpi=300,
            bbox_inches="tight",
        )
        fig2.savefig(
            f"{file_label}/track_stats_2D_hist_{file_label}_{entries}{'_profile' if profile else ''}.pdf",
            dpi=300,
            bbox_inches="tight",
        )
        fig4.savefig(
            f"{file_label}/track_stats_score_{file_label}_{entries}{'_profile' if profile else ''}.pdf",
            dpi=300,
            bbox_inches="tight",
        )
        fig5.savefig(
            f"{file_label}/track_stats_dQdx_{file_label}_{entries}{'_profile' if profile else ''}.pdf",
            dpi=300,
            bbox_inches="tight",
        )
        fig6.savefig(
            f"{file_label}/track_stats_dQdx_z_{file_label}_{entries}{'_profile' if profile else ''}.pdf",
            dpi=300,
            bbox_inches="tight",
        )
        # fig7.savefig(
        #     f"{file_label}/track_stats_dQ_z_{file_label}_{entries}{'_profile' if profile else ''}.pdf",
        #     dpi=300,
        #     bbox_inches="tight",
        # )
        if score_bool:
            fig3.savefig(
                f"{file_label}/track_stats_2D_hist_cut_{file_label}_{entries}{'_profile' if profile else ''}.pdf",
                dpi=300,
                bbox_inches="tight",
            )
    return df

### Tracks and light

In [None]:
def plot_light_geo_stats(
    metrics,
    limit_xrange=False,
    light_max=None,
    min_count_ratio=0.99,
    max_std_ratio=0.2,
    single_track=True,
    lognorm=True,
):
    """Plot SiPM light against the track-SiPM geometry (distance and angle).

    For every track, each SiPM entry in ``values["SiPM"]`` contributes its
    "distance", "angle" and ``light_variable`` value. Entries with
    NaN/non-positive light are dropped. An upper light cut is chosen with
    ``max_std`` and entries above it are dropped too.

    Figures produced:

    - fig1: the ``max_std`` optimisation curves;
    - fig2: a light-weighted 2D histogram of opening angle vs. distance. An
      isosceles-triangle model (apex angle from height = distance and
      base = track length) is fitted to each DBSCAN cluster of hot bins;
    - fig3: light vs. distance and light vs. angle.

    Figures are saved as PDFs under ``file_label`` when ``save_figures`` is
    set.

    Parameters
    ----------
    metrics : dict
        Per-event metrics.
    limit_xrange : bool
        Clip the distance axes at the maximum possible distance + 10.
    light_max, min_count_ratio, max_std_ratio
        Forwarded to ``max_std`` (as ``array_max`` and the two ratios).
    single_track : bool
        Skip events with more than ``1 + non_track_keys`` keys.
    lognorm : bool
        Use a logarithmic colour scale.
    """
    sipm_distance = []
    sipm_angle = []
    sipm_light = []

    for metric in metrics.values():
        if single_track and len(metric.keys()) > 1 + non_track_keys:
            continue
        for track_idx, values in metric.items():
            # String keys hold event-level data; only positive ids are tracks.
            if isinstance(track_idx, str) or track_idx <= 0:
                continue
            for light in values["SiPM"].values():
                sipm_distance.append(light["distance"])
                sipm_angle.append(light["angle"])
                sipm_light.append(light[light_variable])

    sipm_distance = np.array(sipm_distance)
    sipm_angle = np.array(sipm_angle)
    sipm_light = np.array(sipm_light)

    max_distance = np.sqrt(detector_x**2 + detector_y**2 + detector_z**2)
    print("Max possible distance to track", round(max_distance, 2), "mm")
    print("Drift distance", detector_z, "mm")

    valid = ~np.isnan(sipm_light) & (sipm_light > 0)
    sipm_distance = sipm_distance[valid]
    sipm_angle = sipm_angle[valid]
    sipm_light = sipm_light[valid]
    if sipm_light.size == 0:
        print("No SiPM entries with positive light to plot")
        return

    fig1 = plt.figure(figsize=(7, 6))
    ax1 = fig1.add_subplot(111)

    vline = max_std(
        sipm_light,
        ax1,
        array_max=light_max,
        max_std_ratio=max_std_ratio,
        min_count_ratio=min_count_ratio,
    )
    # NOTE(review): the light cut value is also used as the 2D bin count below
    # (kept from the original); confirm this is intended.
    bins = vline

    ax1.set_xlabel("Max light integral")
    ax1.set_ylabel("Normalized value")

    fig1.suptitle("Light integral distribution")

    kept = sipm_light <= vline
    sipm_distance = sipm_distance[kept]
    sipm_angle = np.degrees(sipm_angle[kept])
    sipm_light = sipm_light[kept]

    log_suffix = " - log" if lognorm else ""
    counts_label = "Counts [Log]" if lognorm else "Counts"

    fig2, ax2 = plt.subplots(1, 1, figsize=(7, 6))

    n2, x_edges2, y_edges2, image2 = ax2.hist2d(
        sipm_distance,
        sipm_angle,
        weights=abs(sipm_light),
        bins=bins,
        cmin=1,
        norm=LogNorm() if lognorm else None,
    )

    def triangle_calc(height, base):
        # Apex angle (deg) of an isosceles triangle with this height and base.
        return np.degrees(2 * np.arctan((base / 2) / height))

    def inverse_triangle_calc(angle, height):
        # Base of an isosceles triangle with this apex angle (deg) and height.
        return 2 * height * np.tan(np.radians(angle / 2))

    filtered_centers_x2, filtered_centers_y2, cluster_labels = cluster_hot_bins(
        0.35, n2, x_edges2, y_edges2, scale=(3, 1), eps=8
    )

    # One colour per DBSCAN cluster (label -1 is noise and is not fitted).
    # Keep at least one colour so the colormap is valid for a single cluster.
    cluster_ids = [label for label in np.unique(cluster_labels) if label != -1]
    salmon_cmap = LinearSegmentedColormap.from_list(
        "salmon_cmap",
        [
            to_rgba("darkred"),
            to_rgba("salmon"),
        ],
        N=max(len(cluster_ids), 1),
    )

    x2 = np.arange(min(sipm_distance), max(sipm_distance), 1)
    for cluster_label in cluster_ids:
        in_cluster = cluster_labels == cluster_label
        try:
            parameters, cov = curve_fit(
                triangle_calc,
                filtered_centers_x2[in_cluster],
                filtered_centers_y2[in_cluster],
                p0=[
                    inverse_triangle_calc(sipm_angle.mean(), sipm_distance.mean()),
                ],
            )
        except (RuntimeError, ValueError) as err:
            # A non-converging cluster fit should not abort the other plots.
            print(f"Fit of cluster {cluster_label} failed: {err}")
            continue
        print(f"Fit mean track length: {parameters[0]}")

        ax2.plot(
            x2,
            triangle_calc(x2, *parameters),
            ls="-",
            c=salmon_cmap(cluster_label),
            label=rf"Fit: {parameters[0]:.0f}{dh_unit} track length",
        )
    ax2.set_ylabel("SiPM opening angle to track centre [deg]")
    cbar2 = plt.colorbar(image2)
    cbar2.set_label(rf"Light {light_variable} [{light_unit}{log_suffix}]")
    set_common_ax_options(cbar=cbar2)

    fig2.suptitle(f"SiPM level light distribution - {len(sipm_light)} entries")

    fig3, axes3 = plt.subplots(1, 2, figsize=(14, 6))

    hist30 = axes3[0].hist2d(
        sipm_distance,
        sipm_light,
        bins=bins,
        cmin=1,
        norm=LogNorm() if lognorm else None,
    )
    axes3[0].set_ylabel(f"Light {light_variable} [{light_unit}]")
    cbar30 = plt.colorbar(hist30[3])
    cbar30.set_label(counts_label)
    set_common_ax_options(cbar=cbar30)

    hist31 = axes3[1].hist2d(
        sipm_angle, sipm_light, bins=bins, cmin=1, norm=LogNorm() if lognorm else None
    )
    axes3[1].set_xlabel("SiPM opening angle to track [deg]")
    axes3[1].set_ylabel(f"Light {light_variable} [{light_unit}]")
    cbar31 = plt.colorbar(hist31[3])
    cbar31.set_label(counts_label)
    set_common_ax_options(cbar=cbar31)

    fig3.suptitle(f"SiPM level light distribution - {len(sipm_light)} entries")

    longest_distance = max(sipm_distance)
    for ax in [ax2, *axes3]:
        set_common_ax_options(ax=ax)
        # Only the distance-on-x panels get reference lines and x label.
        if ax is ax2 or ax is axes3[0]:
            if limit_xrange:
                xlim = ax.get_xlim()
                ax.set_xlim(xlim[0], min(max_distance + 10, xlim[1]))
            if longest_distance > detector_z:
                ax.axvline(detector_z, c="orange", ls="--", label="Drift distance")
            if longest_distance > max_distance:
                ax.axvline(max_distance, c="r", ls="--", label="Max distance")

            ax.set_xlabel(f"Distance from track centre [{dh_unit}]")

            ax.legend()

    for fig in [fig1, fig2, fig3]:
        fig.tight_layout()

    if save_figures:
        entries = len(sipm_light)
        fig1.savefig(
            f"{file_label}/light_geo_optimization_{file_label}_{entries}.pdf",
            dpi=300,
            bbox_inches="tight",
        )
        fig2.savefig(
            f"{file_label}/light_geo_2D_hist_{file_label}_{entries}.pdf",
            dpi=300,
            bbox_inches="tight",
        )
        fig3.savefig(
            f"{file_label}/light_geo_1D_hist_{file_label}_{entries}.pdf",
            dpi=300,
            bbox_inches="tight",
        )

In [None]:
def plot_light_fit_stats(metrics):
    """Histogram how well light-fitted and charge-fitted track directions agree.

    For every single-track event whose light data carries a fitted line
    (``metric["SiPM"]["Fit_line"]``), the absolute cosine similarity between
    the charge-track direction and the light-track direction is collected
    together with the light-fit threshold (``metric["SiPM"]["Fit_threshold"]``).
    The cosine distribution is then overlaid for threshold cuts 0, 10, 20, ...
    (below 80 and below the largest threshold). Cuts that keep 0.5 % of the
    entries or fewer are skipped. The figure is saved when ``save_figures`` is
    set.

    Args:
        metrics: Mapping of event id -> metrics dict. String keys hold
            event-level entries (including "SiPM"); other keys hold tracks.
    """
    events = []
    rows = []
    for event, metric in metrics.items():
        light_fit = metric["SiPM"]
        if "Fit_line" not in light_fit:
            continue
        # Single-track events only: exactly one key beyond the event-level ones.
        if len(metric) != non_track_keys + 1:
            continue
        light_direction = light_fit["Fit_line"].direction
        for idx, track in metric.items():
            if isinstance(idx, str):
                continue
            charge_direction = track["Fit_line"].direction
            cosine = abs(charge_direction.cosine_similarity(light_direction))
            events.append(event)
            rows.append(
                [cosine, light_fit["Fit_threshold"], light_direction, charge_direction]
            )

    # Built once from the collected rows instead of growing the frame row by row.
    cosine_df = pd.DataFrame(
        rows, index=events, columns=["cosine", "threshold", "Light", "Charge"]
    )
    if cosine_df.empty:
        print("No single-track events with a light fit line")
        return

    entries = len(cosine_df)
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(111)
    max_cut = min(int(cosine_df["threshold"].max()), 80)
    for cut in range(0, max_cut, 10):
        selected = cosine_df[cosine_df["threshold"] > cut]
        # Skip cuts that leave a negligible fraction of the entries.
        if len(selected) > 0.005 * entries:
            selected.hist(
                "cosine",
                ax=ax,
                bins=np.linspace(0, 1, 11),
                label=f"Threshold: {cut} {light_unit} - {len(selected)} entries",
            )
    ax.set(
        title="Cosine similarity between charge and light tracks",
        xlabel="Cosine similarity",
        ylabel="Counts",
    )
    set_common_ax_options(ax=ax)
    ax.legend()
    fig.tight_layout()
    if save_figures:
        fig.savefig(
            f"{file_label}/light_fit_{file_label}_{entries}.pdf",
            dpi=300,
            bbox_inches="tight",
        )

In [None]:
def plot_voxel_data(metrics, bins=50, log=(False, False, False), lognorm=False):
    """Plot per-voxel light, charge and anode distance against each other.

    Every tuple-keyed entry of ``metric["SiPM"]`` contributes its
    "charge_q", "charge_z" and "integral" values. An entry is kept only when:
    light and charge are positive, light is below the cut-off returned by
    ``max_std``, and charge and z are below their 99th percentiles.

    Figures (saved when ``save_figures`` is set):
      * fig1 - light vs. z;
      * fig2 - charge vs. z, light vs. charge, and charge vs. z weighted by
        light.

    Args:
        metrics: Mapping of event id -> metrics dict with a "SiPM" sub-dict.
        bins: Bins per axis; on a log axis this is the number of log-spaced
            bin edges.
        log: ``(z, charge, light)`` flags for log binning and axis scale.
        lognorm: Use a logarithmic colour scale.
    """
    z = []
    q = []
    light = []
    for metric in metrics.values():
        for key, sipm in metric["SiPM"].items():
            # Only tuple keys are per-voxel entries; string keys such as
            # "Fit_line" hold the event-level light fit.
            if isinstance(key, tuple):
                q.append(sipm["charge_q"])
                z.append(sipm["charge_z"])
                # NOTE(review): always reads "integral", while the labels use
                # `light_variable` - confirm these agree.
                light.append(sipm["integral"])

    z = np.array(z)
    q = np.array(q)
    light = np.array(light)
    if light.size == 0:
        print("No voxel entries to plot")
        return

    light_cut = max_std(
        light,
        ax=None,
        min_count_ratio=0.98,
        max_std_ratio=0.1,
    )
    charge_cut = np.percentile(q, 99)
    z_cut = np.percentile(z, 99)

    keep = (light < light_cut) & (light > 0) & (q < charge_cut) & (q > 0) & (z < z_cut)
    if not keep.any():
        print("No voxel entries left after the cuts")
        return
    z = z[keep]
    q = q[keep]
    light = light[keep]

    # Log-spaced z edges start at exp(0) = 1.
    # NOTE(review): entries with z < 1 then fall outside the log-binned plots.
    bins_z = np.exp(np.linspace(0, np.log(max(z)), bins)) if log[0] else bins
    bins_q = (
        np.exp(np.linspace(np.log(min(q)), np.log(max(q)), bins)) if log[1] else bins
    )
    bins_l = (
        np.exp(np.linspace(np.log(min(light)), np.log(max(light)), bins))
        if log[2]
        else bins
    )

    def color_norm():
        # A separate norm instance per histogram so autoscaled colour limits
        # are not shared between plots.
        return LogNorm() if lognorm else None

    count_label = "Counts - log" if lognorm else "Counts"

    fig1 = plt.figure(figsize=(10, 6))
    ax1 = fig1.add_subplot(111)
    hist = ax1.hist2d(z, light, bins=[bins_z, bins_l], cmin=1, norm=color_norm())
    cbar1 = fig1.colorbar(hist[3], ax=ax1)
    ax1.set_title("Light vs. z distance")
    cbar1.set_label(count_label)
    set_common_ax_options(cbar=cbar1)

    fig2, axes2 = plt.subplots(3, 1, figsize=(10, 18))
    figs = [fig1, fig2]
    axes = [ax1, *axes2]

    hist21 = axes2[0].hist2d(z, q, bins=[bins_z, bins_q], cmin=1, norm=color_norm())
    cbar21 = fig2.colorbar(hist21[3], ax=axes2[0])
    axes2[0].set_title("Charge vs. Anode distance")
    cbar21.set_label(count_label)
    set_common_ax_options(cbar=cbar21)

    hist22 = axes2[1].hist2d(
        q, light, bins=[bins_q, bins_l], cmin=1, norm=color_norm()
    )
    cbar22 = fig2.colorbar(hist22[3], ax=axes2[1])
    axes2[1].set(
        title="Light vs. Charge",
        xlabel=rf"Charge [{q_unit} - log]" if log[1] else rf"Charge [{q_unit}]",
        xscale="log" if log[1] else "linear",
    )
    cbar22.set_label(count_label)
    set_common_ax_options(cbar=cbar22)

    hist23 = axes2[2].hist2d(
        z,
        q,
        weights=light,
        bins=[bins_z, bins_q],
        cmin=1,
        norm=color_norm(),
    )
    cbar23 = fig2.colorbar(hist23[3], ax=axes2[2])
    axes2[2].set_title("Charge vs. Anode distance with light weights")
    cbar23.set_label(
        rf"Light {light_variable} [{light_unit} - log]"
        if lognorm
        else rf"Light {light_variable} [{light_unit}]"
    )
    set_common_ax_options(cbar=cbar23)

    z_label = f"Anode distance [{z_unit} - log]" if log[0] else f"Anode distance [{z_unit}]"
    light_label = (
        rf"Light {light_variable} [{light_unit} - log]"
        if log[2]
        else rf"Light {light_variable} [{light_unit}]"
    )
    charge_label = rf"Charge [{q_unit} - log] " if log[1] else rf"Charge [{q_unit}]"
    # axes = [light vs z, charge vs z, light vs charge, weighted charge vs z]:
    # index 2 already has its charge x-axis; even indices have light on y.
    for idx, ax in enumerate(axes):
        if idx != 2:
            ax.set_xlabel(z_label)
            ax.set_xscale("log" if log[0] else "linear")
        if idx % 2 == 0:
            ax.set_ylabel(light_label)
            ax.set_yscale("log" if log[2] else "linear")
        else:
            ax.set_ylabel(charge_label)
            ax.set_yscale("log" if log[1] else "linear")

        set_common_ax_options(ax=ax)

    for fig in figs:
        fig.tight_layout()

    if save_figures:
        events = sum(keep)
        fig1.savefig(
            f"{file_label}/voxel_light_vs_z_{file_label}_{events}.pdf",
            dpi=300,
            bbox_inches="tight",
        )
        fig2.savefig(
            f"{file_label}/voxel_charge_vs_z_hist_{file_label}_{events}.pdf",
            dpi=300,
            bbox_inches="tight",
        )

### Light

In [None]:
def light_vs_charge(
    metrics,
    light_max=None,
    min_count_ratio=0.99,
    max_std_ratio=0.5,
    clusters=None,
    bin_density=1,
    log=(True, False),
    p0=True,
):
    """Plot event-level total light against total charge.

    First, events with non-positive total charge or light are dropped. Then
    events whose light exceeds the cut-off found by ``max_std`` (drawn on
    fig1) are dropped. The remaining events are shown as:

    * fig2 - 2D histogram of total light vs. total charge, with a 2D Gaussian
      fit overlaid as 1/2/3 sigma contours;
    * fig3 - charge/light ratio histogram with a Landau-Gaussian fit;
    * fig4 - total charge and total light histograms;
    * fig22/fig32 - the same 2D and ratio histograms per KMeans population,
      drawn only when more than one population is requested or detected.

    fig1-fig3 are saved when ``save_figures`` is set.

    Args:
        metrics: Mapping of event id -> metrics dict with "Total_light" and
            "Total_charge" entries.
        light_max: Forwarded to ``max_std`` as ``array_max``.
        min_count_ratio: Forwarded to ``max_std``.
        max_std_ratio: Forwarded to ``max_std``.
        clusters: Number of KMeans populations in (log charge, log light).
            When None it is estimated with ``cluster_hot_bins`` on the 2D
            histogram counts.
        bin_density: Multiplier on the 50 bins of the 1D histograms.
        log: ``(log_2d, log_1d)`` flags for log binning/scales of the 2D and
            the 1D histograms; a single bool applies to both.
        p0: Initial ``(mu, eta, sigma, A)`` for the Landau-Gaussian fit, or
            ``True`` to derive one from each histogram.
    """
    if isinstance(log, bool):
        log = [log, log]

    light_array = np.array([metric["Total_light"] for metric in metrics.values()])
    charge_array = np.array([metric["Total_charge"] for metric in metrics.values()])

    # Ratios and logarithms below need strictly positive values.
    positive = (charge_array > 0) & (light_array > 0)
    charge_array = charge_array[positive]
    light_array = light_array[positive]

    fig1 = plt.figure(figsize=(8, 6))
    ax1 = fig1.add_subplot(111)
    ax1.set_xlabel("Max light integral")
    ax1.set_ylabel("Normalized value")
    fig1.suptitle("Light integral distribution")
    vline = max_std(
        light_array,
        ax1,
        array_max=light_max,
        min_count_ratio=min_count_ratio,
        max_std_ratio=max_std_ratio,
    )
    # One 2D bin per 20 light units; at least 2 so the bin edges stay valid.
    bins = max(int(vline / 20), 2)

    below_cut = light_array <= vline
    charge_array = charge_array[below_cut]
    light_array = light_array[below_cut]

    def gauss_2d(xy, amplitude, xo, yo, sigma_x, sigma_y):
        """Axis-aligned 2D Gaussian (product of normalised 1D Gaussians) times amplitude."""
        x, y = xy
        return (
            amplitude
            * np.exp(-0.5 * ((x - xo) / sigma_x) ** 2)
            / (sigma_x * np.sqrt(2 * np.pi))
            * np.exp(-0.5 * ((y - yo) / sigma_y) ** 2)
            / (sigma_y * np.sqrt(2 * np.pi))
        )

    def hist2d(x, y, ax, bins, log):
        """Draw a 2D histogram of y vs. x on ``ax`` and overlay a 2D Gaussian fit.

        Returns the ``(counts, x_edges, y_edges, image)`` tuple of ``Axes.hist2d``.
        """
        if log:
            log_bins_x = np.exp(np.linspace(np.log(min(x) - 1), np.log(max(x)), bins))
            log_bins_y = np.exp(np.linspace(np.log(min(y)), np.log(max(y)), bins))
            bins = [log_bins_x, log_bins_y]
            ax.set_xscale("log")
            ax.set_yscale("log")

        n, x_edges, y_edges, image = ax.hist2d(x, y, bins=bins, cmin=1)

        try:
            # Empty bins are NaN (cmin=1); the fit treats them as zero counts.
            # Fortran order makes x vary fastest, matching the flattened meshgrid.
            bin_counts = n.ravel(order="F").copy()
            bin_counts[np.isnan(bin_counts)] = 0
            x_centers = 0.5 * (x_edges[1:] + x_edges[:-1])
            y_centers = 0.5 * (y_edges[1:] + y_edges[:-1])
            # Fit in units of the largest bin centre so parameters stay O(1);
            # the same scale converts the fit back to data units for drawing.
            x_scale = x_centers.max()
            y_scale = y_centers.max()
            x_grid, y_grid = np.meshgrid(x_centers / x_scale, y_centers / y_scale)

            parameters, _ = curve_fit(
                gauss_2d,
                (x_grid.ravel(), y_grid.ravel()),
                bin_counts,
                bounds=(0, np.inf),
            )

            # Evaluation grid 5x finer than the histogram along each axis
            # (per-axis bin counts, not the flattened total).
            plot_x, plot_y = np.meshgrid(
                np.linspace(min(x) / x_scale, max(x) / x_scale, len(x_centers) * 5),
                np.linspace(min(y) / y_scale, max(y) / y_scale, len(y_centers) * 5),
            )
            z_plot = gauss_2d((plot_x, plot_y), *parameters)

            print("Parameters:")
            print(
                "\n".join(
                    f"{name}: {value}"
                    for name, value in zip(
                        ("amplitude", "mu_x", "mu_y", "sigma_x", "sigma_y"),
                        parameters,
                    )
                ),
                "\n",
            )

            _, mu_x, mu_y, sigma_x, sigma_y = parameters
            # Fitted density 3, 2 and 1 sigma below the peak on both axes;
            # increasing order, as contour levels require.
            levels = [
                gauss_2d((mu_x - k * sigma_x, mu_y - k * sigma_y), *parameters)
                for k in (3, 2, 1)
            ]
            contour = ax.contour(
                plot_x * x_scale,
                plot_y * y_scale,
                z_plot,
                norm="log",
                cmap="autumn",
                linewidths=1,
                levels=levels,
            )
            fmt = dict(
                zip(contour.levels, (r"$3\sigma$", r"$2\sigma$", r"$1\sigma$"))
            )
            ax.clabel(contour, contour.levels, inline=True, fmt=fmt, fontsize=10)
        except Exception as err:  # best effort: keep the histogram without the fit
            print(f"Fit failed: {err}\n")

        ax.set_xlabel(f"Total charge [{q_unit}{' - Log' if log else ''}]")
        ax.set_ylabel(
            f"Total Light {light_variable} [{light_unit}{' - Log' if log else ''}]"
        )
        # Attach the colorbar to this axes' own figure, not the current figure.
        cbar = ax.figure.colorbar(image, ax=ax)
        cbar.set_label("Counts")
        set_common_ax_options(ax=ax)

        return n, x_edges, y_edges, image

    def hist1d(array, ax, bin_density, log, p0):
        """Histogram ``array`` below its 95th percentile with peak/median/mean lines.

        On a linear axis a Landau-Gaussian (``pylandau.langau``) is fitted to
        the non-empty bins; ``p0=True`` derives the initial guess from the
        histogram itself.
        """
        if array.size == 0:
            print("No entries to histogram\n")
            return
        upper_bound = np.percentile(array, 95)
        array = array[array < upper_bound]
        if array.size == 0:
            print("No entries below the 95th percentile\n")
            return

        n_bins = int(50 * bin_density)
        if log:
            ax.set_xscale("log")
            bins = np.exp(
                np.linspace(np.log(min(array) - 1), np.log(upper_bound), n_bins)
            )
        else:
            bins = np.linspace(min(array), upper_bound, n_bins)

        n, edges, _ = ax.hist(array, bins=bins, fill=False, ec="C0", histtype="bar")

        peak_bin = n.argmax()
        peak_y = n.max()
        # Centre of the most populated bin: mean of its two edges.
        peak_x = edges[peak_bin : peak_bin + 2].mean()
        mean = array.mean()
        median = np.median(array)

        set_common_ax_options(ax=ax)
        if not log:
            bin_centers = 0.5 * (edges[1:] + edges[:-1])
            filled = n > 0
            bin_centers = bin_centers[filled]
            bin_counts = n[filled]

            try:
                if p0 is True:
                    p0 = (
                        peak_x,
                        np.std(array) / 20,
                        np.std(array) / 10,
                        peak_y,
                    )

                parameters, _ = curve_fit(
                    pylandau.langau,
                    bin_centers,
                    bin_counts,
                    absolute_sigma=True,
                    p0=p0,
                    bounds=(0, np.inf),
                )

                x_plot = np.linspace(min(array), max(array), len(bins) * 10)
                y_plot = pylandau.langau(x_plot, *parameters)

                print("Parameters:")
                print(
                    "\n".join(
                        f"{name}: {value}"
                        for name, value in zip(("mu", "eta", "sigma", "A"), parameters)
                    ),
                    "\n",
                )
                # parameters[0] is the fitted mu (first langau parameter).
                ax.plot(
                    x_plot,
                    y_plot,
                    "m",
                    label=rf"Fit ($\mu={parameters[0]:.2f}$)",
                )
            except Exception as err:  # best effort: keep the histogram without the fit
                print(f"Fit failed: {err}\n")

        ax.axvline(peak_x, c="r", ls="--", label=f"Peak: {peak_x:.2f}")
        ax.axvline(median, c="orange", ls="--", label=f"Median: {median:.2f}")
        ax.axvline(mean, c="g", ls="--", label=f"Mean: {mean:.2f}")

        ax.set_ylabel("Counts")
        ax.set_ylim(0, peak_y * 1.1)
        ax.legend()

    ratio_label = (
        f"Event total charge / Light [{q_unit}/{light_unit}{' - Log' if log[1] else ''}]"
    )

    fig2 = plt.figure(figsize=(8, 6))
    ax2 = fig2.add_subplot(111)
    n2d, xedges2d, yedges2d, _ = hist2d(charge_array, light_array, ax2, bins, log[0])
    fig2.suptitle(f"Event level Light vs. Charge - {len(charge_array)} events")

    fig3 = plt.figure(figsize=(8, 6))
    ax3 = fig3.add_subplot(111)
    ratio = charge_array / light_array
    hist1d(ratio, ax3, bin_density, log[1], p0)
    ax3.set_xlabel(ratio_label)
    fig3.suptitle(f"Event level Charge vs. Light - {len(charge_array)} events")

    fig4, axes4 = plt.subplots(1, 2, figsize=(14, 6))
    hist1d(charge_array, axes4[0], bin_density, log[1], p0)
    axes4[0].set_xlabel(f"Event total charge [{q_unit}{' - Log' if log[1] else ''}]")
    hist1d(light_array, axes4[1], bin_density, log[1], p0)
    axes4[1].set_xlabel(f"Event total Light [{light_unit}{' - Log' if log[1] else ''}]")
    fig4.suptitle(f"Event level Charge and Light - {len(charge_array)} events")

    figs = [fig1, fig2, fig3, fig4]
    if clusters is None:
        if log[0]:
            _, _, cluster_labels = cluster_hot_bins(
                0.3, n2d, np.log(xedges2d), np.log(yedges2d), eps=1
            )
        else:
            _, _, cluster_labels = cluster_hot_bins(
                0.3,
                n2d,
                xedges2d,
                yedges2d,
                eps=2,
                scale=(np.diff(xedges2d).mean(), np.diff(yedges2d).mean()),
            )
        # NOTE(review): "- 1" presumably discounts a noise label returned by
        # cluster_hot_bins; confirm.
        clusters = np.unique(cluster_labels).size - 1

    if clusters > 1:
        # Two features per event: log charge and log light.
        features = np.column_stack((np.log(charge_array), np.log(light_array)))
        labels = KMeans(n_clusters=clusters).fit_predict(features) + 1
        populations = np.unique(labels)

        # squeeze=False keeps a 2D axes array even for a single population.
        fig22, axes22 = plt.subplots(
            len(populations), 1, figsize=(8, 6 * len(populations)), squeeze=False
        )
        fig32, axes32 = plt.subplots(
            len(populations), 1, figsize=(10, 6 * len(populations)), squeeze=False
        )
        figs.extend([fig22, fig32])

        for row, label in enumerate(populations):
            in_population = labels == label
            title = f"Population {label} - {np.count_nonzero(in_population)} events"

            hist2d(
                charge_array[in_population],
                light_array[in_population],
                axes22[row, 0],
                bins,
                log[0],
            )
            axes22[row, 0].set_title(title)

            print(f"population {label}:")

            population_ratio = charge_array[in_population] / light_array[in_population]
            hist1d(population_ratio, axes32[row, 0], bin_density, log[1], p0)
            axes32[row, 0].set_xlabel(ratio_label)
            axes32[row, 0].set_title(title)

    for fig in figs:
        fig.tight_layout()

    if save_figures:
        # All saved figures show every event that passed the cuts.
        events = len(charge_array)
        fig1.savefig(
            f"{file_label}/light_vs_charge_optmization_{file_label}_{events}.pdf",
            dpi=300,
            bbox_inches="tight",
        )
        fig2.savefig(
            f"{file_label}/light_vs_charge_2D_hist_{file_label}_{events}.pdf",
            dpi=300,
            bbox_inches="tight",
        )
        fig3.savefig(
            f"{file_label}/light_vs_charge_ratio_{file_label}_{events}.pdf",
            dpi=300,
            bbox_inches="tight",
        )

# File loading

In [None]:
# Load the metrics: read the pickled metrics when the file exists,
# otherwise build them with combine_metrics().
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
if os.path.isfile(metrics_file):
    with open(metrics_file, "rb") as f:
        metrics = pickle.load(f)
else:
    metrics = combine_metrics()

In [None]:
# Keep a handle on the unfiltered metrics, then apply the post-processing
# filters: the saved parameter set from `filter_file` when it exists,
# otherwise the filter values defined in the Parameters section.
cached_metrics = metrics
if filter_file is None or not os.path.isfile(filter_file):
    metrics = filter_metrics(
        metrics,
        min_score=score_cutoff,
        max_score=max_score,
        min_track_length=min_track_length,
        max_track_length=max_track_length,
        max_tracks=max_tracks,
        max_light=max_light,
        min_light=min_light,
        max_z=max_z,
    )
else:
    with open(filter_file, "r") as f:
        filter_settings = json.load(f)
    metrics = filter_metrics(metrics, **filter_settings)
    # Expose the loaded settings as notebook variables of the same names.
    # NOTE(review): the keys are filter_metrics keyword names (e.g. min_score),
    # so a differently named variable such as score_cutoff is not overwritten.
    globals().update(filter_settings)

# Analysis

In [None]:
# Display the filtered metrics mapping as the cell output.
metrics

## dQ/dx

### Statistical plots

In [None]:
# Track statistics with a linear colour scale and no profile overlay.
plot_track_stats(
    metrics,
    bins=[40, 40],
    profile=False,
    lognorm=False,
    min_score=0.5,
    min_entries=2,
    empty_ratio_lims=(0.0, 1),
    limit_xrange=True,
)
if not show_figures:
    plt.close("all")
else:
    plt.show()

In [None]:
# Profile variant with a log colour scale; its figures are closed without
# being shown.
plot_track_stats(
    metrics,
    bins=[40, 40],
    profile=True,
    lognorm=True,
    min_score=0.5,
    min_entries=2,
    empty_ratio_lims=(0.0, 1),
    limit_xrange=True,
)
plt.close("all")

### Individual plots

In [None]:
# dQ profile of every track in the hand-picked events listed in `individual_plots`.
for event_idx in tqdm(individual_plots, leave=False):
    if event_idx not in metrics:
        continue
    for track_idx, values in metrics[event_idx].items():
        # Skip event-level entries (string keys) and non-positive track indices.
        if isinstance(track_idx, str) or not track_idx > 0:
            continue
        dQ_array = values["dQ"]
        dh = values["dx"]
        plot_dQ(dQ_array, event_idx, track_idx, dh, interpolate=False)

        if not show_figures:
            plt.close("all")
        else:
            plt.show()

## Light to track Geometry

In [None]:
# Silence warnings raised from numpy while the geometry plots are drawn.
warnings.filterwarnings(action="ignore", module="numpy", category=Warning)

In [None]:
# Light-to-track geometry plots (single_track=True, log colour scale).
plot_light_geo_stats(
    metrics,
    lognorm=True,
    limit_xrange=True,
    single_track=True,
)
if not show_figures:
    plt.close("all")
else:
    plt.show()

In [None]:
# Put "default" handling back in front for every warning category.
warnings.filterwarnings(action="default", category=Warning)

## Total Light vs. Charge

In [None]:
# Silence warnings raised from numpy during the light vs. charge analysis.
warnings.filterwarnings(action="ignore", module="numpy", category=Warning)

In [None]:
# Event-level light vs. charge with a single population (no KMeans split),
# log-binned 2D histogram and linear 1D histograms.
light_vs_charge(
    metrics,
    p0=True,
    log=(True, False),
    bin_density=1,
    clusters=1,
)

if not show_figures:
    plt.close("all")
else:
    plt.show()

In [None]:
# Optional: put "default" handling back in front for every warning category.
warnings.filterwarnings(action="default", category=Warning)

## Voxelized Light vs. Charge

In [None]:
# Silence warnings raised from numpy while the voxel plots are drawn.
warnings.filterwarnings(action="ignore", module="numpy", category=Warning)

In [None]:
# Voxel-level plots; the log flags are ordered (z, charge, light).
plot_voxel_data(metrics, lognorm=False, log=(False, False, False))

if not show_figures:
    plt.close("all")
else:
    plt.show()

In [None]:
# Put "default" handling back in front for every warning category.
warnings.filterwarnings(action="default", category=Warning)

## Light fit vs. Charge fit

In [None]:
# Agreement between light-fitted and charge-fitted track directions.
plot_light_fit_stats(metrics)

# Honour show_figures like the other plotting cells.
if show_figures:
    plt.show()
else:
    plt.close("all")

# Other

## Track angles

In [None]:
# Direction components and |cos| of fitted charge-track directions.
def plot_track_angles(metrics, min_score=0.5, min_length=1):
    """Histogram the fitted direction of every charge track.

    Top row: x, y and z components of each track's fitted direction vector.
    Bottom row: absolute cosine similarity of that direction to the x, y and
    z axes.

    Only non-string keys (tracks) are used. String keys hold event-level
    data, and the "SiPM" entry carries its own "Fit_line" (the light fit),
    which must not be counted as a charge track. A track is also skipped when
    its "RANSAC_score" is below ``min_score`` or its "Fit_norm" is below
    ``min_length``; a missing key never excludes a track.

    Args:
        metrics: Mapping of event id -> metrics dict.
        min_score: Minimum RANSAC score of a track.
        min_length: Minimum "Fit_norm" of a track.

    Returns:
        (fig, ax): the figure and its 2x3 array of axes.
    """
    directions = []
    for metric in metrics.values():
        for track_idx, track in metric.items():
            if isinstance(track_idx, str) or not isinstance(track, dict):
                continue
            if "RANSAC_score" in track and track["RANSAC_score"] < min_score:
                continue
            if "Fit_norm" in track and track["Fit_norm"] < min_length:
                continue
            if "Fit_line" in track:
                directions.append(track["Fit_line"].direction.to_array())

    # reshape keeps an (n, 3) array even when no track passed the cuts.
    vectors = np.array(directions, dtype=float).reshape(-1, 3)
    # |cos| to each axis = |component| / vector norm.
    cosines = np.abs(vectors) / np.linalg.norm(vectors, axis=1, keepdims=True)

    fig, ax = plt.subplots(2, 3, figsize=(18, 12))
    for col, axis_name in enumerate("XYZ"):
        ax[0, col].hist(vectors[:, col], bins=20)
        ax[0, col].set_xlabel(f"{axis_name} vector component")
        ax[1, col].hist(cosines[:, col], bins=20)
        ax[1, col].set_xlabel(f"Cosine similarity to {axis_name.lower()}-axis")

    return fig, ax

In [None]:
# Track direction histograms; honour show_figures like the other plotting cells.
plot_track_angles(metrics)
if show_figures:
    plt.show()
else:
    plt.close("all")

## Heat map

In [None]:
def sectorize_dqdx(metrics, bin_size=(32, 32), min_score=0.5):
    """Assign every dQ/dx sample of well-scored tracks to an (x, y) sector.

    Tracks from ``get_track_stats`` with ``track_score > min_score`` are
    exploded to one row per dQ/dx sample together with its x, y and z
    position. Samples with non-positive dQ/dx are dropped. The detector face
    (``detector_x`` by ``detector_y``, centred on 0) is cut into ``bin_size``
    sectors plus one extra sector beyond each edge. The extra sectors are
    folded into the adjacent edge sector, so points just outside the detector
    are kept.

    Args:
        metrics: Per-event metrics, passed to ``get_track_stats``.
        bin_size: ``(x, y)`` sector size in the units of the track points.
        min_score: Tracks need a ``track_score`` above this value.

    Returns:
        DataFrame with one row per sample and ``x``, ``y``, ``z``, ``x_bin``,
        ``y_bin`` and ``sector`` (``(x_bin, y_bin)``) columns. Every
        combination of occupied x and y bins that has no samples gets one
        placeholder row with its other columns filled with 0.
    """
    track_df = get_track_stats(metrics)
    track_df = track_df[track_df["track_score"] > min_score]
    dQdx_df = track_df.explode("track_dQdx")
    # Exploding the parallel list column gives the same index as long as each
    # track's dQ/dx and point lists have equal length (assumed), so the
    # positions line up with their samples.
    dQdx_df["position"] = track_df["track_points"].explode()
    for axis, column in enumerate(("x", "y", "z")):
        dQdx_df[column] = dQdx_df["position"].apply(
            lambda point, axis=axis: point[axis]
        )
    dQdx_df = dQdx_df[dQdx_df["track_dQdx"] > 0].drop("position", axis=1)

    # Edges from one sector beyond the negative edge to one beyond the positive edge.
    x_bins = np.arange(
        -detector_x / 2 - bin_size[0], detector_x / 2 + 2 * bin_size[0], bin_size[0]
    )
    y_bins = np.arange(
        -detector_y / 2 - bin_size[1], detector_y / 2 + 2 * bin_size[1], bin_size[1]
    )

    dQdx_df["x_bin"] = pd.cut(dQdx_df["x"], bins=x_bins, labels=False)
    dQdx_df["y_bin"] = pd.cut(dQdx_df["y"], bins=y_bins, labels=False)
    # Fold the outermost bins (0 and len(edges) - 2) into their neighbours.
    # The limits come from the bin edges, not from the data, so the last
    # in-detector sector is not merged when no point lies beyond the edge.
    dQdx_df["x_bin"] = dQdx_df["x_bin"].clip(1, len(x_bins) - 3)
    dQdx_df["y_bin"] = dQdx_df["y_bin"].clip(1, len(y_bins) - 3)

    # Placeholder rows for every combination of occupied x and y bins.
    x_values = dQdx_df["x_bin"].unique()
    y_values = dQdx_df["y_bin"].unique()
    all_bins = pd.DataFrame(
        [(x, y) for x in x_values for y in y_values],
        columns=["x_bin", "y_bin"],
    )

    dQdx_df = pd.merge(
        all_bins, dQdx_df.reset_index(), on=["x_bin", "y_bin"], how="left"
    ).fillna(0)

    dQdx_df["sector"] = list(zip(dQdx_df["x_bin"], dQdx_df["y_bin"]))

    return dQdx_df

In [None]:
# One row per dQ/dx sample, tagged with its 32 x 32 sector.
dQdx_df = sectorize_dqdx(metrics=metrics, bin_size=(32, 32))

In [None]:
# Every dQ/dx sample at its (x, y) position, coloured by dQ/dx, with ticks
# every 32 units across the detector face.
plt.scatter(x=dQdx_df["x"], y=dQdx_df["y"], s=0.1, c=dQdx_df["track_dQdx"])
plt.xticks(ticks=np.arange(-detector_x / 2, detector_x / 2 + 32, 32))
plt.yticks(ticks=np.arange(-detector_y / 2, detector_y / 2 + 32, 32))
plt.grid()
plt.colorbar()

In [None]:
# Per-sector fraction of dQ/dx samples lying strictly between 1500 and 3000.
counts = dQdx_df.pivot_table(
    values="track_dQdx", index="y_bin", columns="x_bin", aggfunc="count", fill_value=0
)
counts_cut = dQdx_df.loc[
    (dQdx_df["track_dQdx"] > 1500) & (dQdx_df["track_dQdx"] < 3000)
].pivot_table(
    values="track_dQdx", index="y_bin", columns="x_bin", aggfunc="count", fill_value=0
)

plt.pcolormesh(counts_cut / counts)
cbar = plt.colorbar(label="Ratio per sector")
plt.gca().set_aspect("equal", adjustable="box")
plt.xticks(counts.columns)
plt.yticks(counts.index)
plt.title("Ratio of dQdx counts per sector")
plt.xlabel("X bin")
plt.ylabel("Y bin")
plt.show()

In [None]:
# Events whose mean dQ/dx lies strictly between 1500 and 3000.
t = dQdx_df.groupby("event")["track_dQdx"].mean()
t[t.gt(1500) & t.lt(3000)].index

In [None]:
# One dQ/dx histogram per sector, laid out like the detector face: x bins
# left to right, y bins bottom to top. Positions are derived from the bins
# present instead of a hard-coded bin count; squeeze=False keeps `ax` 2D.
x_bins_present = sorted(dQdx_df["x_bin"].unique())
y_bins_present = sorted(dQdx_df["y_bin"].unique())
fig, ax = plt.subplots(
    len(y_bins_present), len(x_bins_present), figsize=(20, 20), squeeze=False
)
for sector in dQdx_df["sector"].unique():
    x = x_bins_present.index(sector[0])
    y = len(y_bins_present) - 1 - y_bins_present.index(sector[1])
    dQdx_df[dQdx_df["sector"] == sector]["track_dQdx"].hist(
        ax=ax[y, x], bins=np.arange(0, 12e3, 400)
    )

In [None]:
# Overlay one event's fitted track line (track 1, drawn over its fitted
# length centred on t = 0) with that event's dQ/dx sample positions.
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
example_event = "20230706_203024_267"
example_track = metrics[example_event][1]
half_length = example_track["Fit_norm"] / 2
example_track["Fit_line"].plot_2d(ax, t_1=-half_length, t_2=half_length)
dQdx_df[dQdx_df["event"] == example_event].plot.scatter(ax=ax, x="x", y="y")

In [None]:
# Display the per-sample event ids as the cell output.
dQdx_df["event"]