In [134]:
import uproot
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import awkward

import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [203]:
# Column layout of the per-interest summary table (filled after processing).
df_columns = ['bin', 'Truth', 'CKF', 'GNN']

# Each "interest" describes one histogram comparison:
#   hist_setup : bin edges passed to np.histogram
#   branches   : one branch per source (Truth, CKF, GNN) for "plain";
#                the reconstructed branches compared to `branch_base` for "resolution"
#   branches_up / branches_down : operands of (up - down) for "difference"
#   cuts       : pandas `query` strings, one per histogram
#   type       : "plain" | "difference" | "resolution"
#   x_label    : axis label (LaTeX rendered by plotly/MathJax)
interests = {
    "Tagging Region, No. of tracks": {
        "hist_setup": np.linspace(0, 20, 21),
        "branches": ["TagTrk2_track_No_truth", "TagTrk2_track_No", "gnn_tag_no_tracks"],
        "type": "plain",
        "x_label": "No. of tracks",
    },
    "Recoil Region, No. of tracks": {
        "hist_setup": np.linspace(0, 20, 21),
        "branches": ["RecTrk2_track_No_truth", "RecTrk2_track_No", "gnn_rec_no_tracks"],
        "type": "plain",
        "x_label": "No. of tracks",
    },
    "Before Target, Incident Energy": {
        "hist_setup": np.linspace(5000, 9000, 50),
        "branches": ["TagTrk2_pp_truth_fin", "TagTrk2_pp", "gnn_tag_pf"],
        "cuts": ["TagTrk2_track_No_truth == 1", "TagTrk2_track_No == 1", "gnn_tag_no_tracks == 1"],
        "type": "plain",
        "x_label": r"${P_{i}^{e}~\text{[MeV]}}$",
    },
    "After Target, Recoil Energy": {
        "hist_setup": np.linspace(5000, 9000, 50),
        "branches": ["RecTrk2_pp_truth_ini", "RecTrk2_pp", "gnn_rec_pi"],
        "cuts": ["RecTrk2_track_No_truth == 1", "RecTrk2_track_No == 1", "gnn_rec_no_tracks == 1"],
        "type": "plain",
        "x_label": r"${P_{f}^{e}~\text{[MeV]}}$",
    },
    "Before ECAL, Incident Energy": {
        "hist_setup": np.linspace(5000, 9000, 50),
        "branches": ["RecTrk2_pp_truth_fin", "RecTrk2_pp", "gnn_rec_pf"],
        "cuts": ["RecTrk2_track_No_truth == 1", "RecTrk2_track_No == 1", "gnn_rec_no_tracks == 1"],
        "type": "plain",
        "x_label": r"${P_{\text{ECAL}}^{e}~\text{[MeV]}}$",
    },
    "Missing Momentum": {
        "hist_setup": np.linspace(0, 9000, 100),
        "branches_up": ["TagTrk2_pp_truth_fin", "TagTrk2_pp", "gnn_tag_pf"],
        "branches_down": ["RecTrk2_pp_truth_ini", "RecTrk2_pp", "gnn_rec_pi"],
        "cuts": [
            "TagTrk2_track_No_truth == 1 & RecTrk2_track_No_truth == 1",
            "TagTrk2_track_No == 1 & RecTrk2_track_No == 1",
            "gnn_tag_no_tracks == 1 & gnn_rec_no_tracks == 1",
        ],
        "type": "difference",
        "x_label": r"${P_{i}^{e} - P_{f}^{e}~\text{[MeV]}}$",
    },
    "Resolution: Tagging Region": {
        # Filled with (reco - truth) / truth * 100, i.e. a relative deviation in percent.
        "hist_setup": np.linspace(-10, 10, 50),
        "branch_base": "TagTrk2_pp_truth_fin",
        "branches": ["TagTrk2_pp", "gnn_tag_pf"],
        "cuts": ["TagTrk2_track_No == 1", "gnn_tag_no_tracks == 1"],
        "type": "resolution",
        # The histogrammed quantity is a percentage, not MeV.
        "x_label": r"${\Delta P_{i}^{e} / P_{i}^{e}~[\%]}$",
    }
}

# Initialize empty histograms (one per source) and default cuts.
for cfg in interests.values():
    cfg['df'] = pd.DataFrame(columns=df_columns)
    hist_type = cfg["type"]
    if hist_type == "difference":
        # Both operands must be read from the tree, so `branches` lists all of them.
        cfg['branches'] = [*cfg['branches_up'], *cfg['branches_down']]
        per_hist_branches = cfg["branches_up"]
    elif hist_type in {"plain", "resolution"}:
        per_hist_branches = cfg["branches"]
    else:
        # Fail fast: an unknown type would otherwise surface later as a KeyError on "hists".
        raise ValueError(f"Unknown interest type {hist_type!r}")

    # len(edges) - 1 bins, same dtype as the edges.
    cfg["hists"] = [np.zeros_like(cfg["hist_setup"][:-1]) for _ in per_hist_branches]
    if "cuts" not in cfg:
        # `X == X` is False only for NaN, so the default cut just drops missing values.
        cfg["cuts"] = [f"{br} == {br}" for br in per_hist_branches]

# Plot styling keyed by source label and by "<source>/Truth" ratio name.
int_style = {
    "Truth": {
        "color": "rgba(0,0,0,0.8)",
    },
    "CKF": {
        "color": "rgba(0,100,80,0.6)",
        "marker": "None",
        "fillcolor": "rgba(0,100,80,0.4)",
    },
    "GNN": {
        "color": "rgba(0,176,246,0.6)",
        "marker": "None",
        "fillcolor": "rgba(0,176,246,0.4)",
    },
    "CKF/Truth": {
        "color": "rgba(0,100,80,0.4)",
        "marker": "circle",
    },
    "GNN/Truth": {
        "color": "rgba(0,176,246,0.4)",
        "marker": "diamond",
    },
}



In [204]:

# Fill every interest's histograms by streaming the ROOT tree in chunks,
# then convert the accumulated counts into per-interest DataFrames.

# NOTE(review): absolute, user-specific path — consider moving it to a config cell.
ROOT_FILE = "/Users/avencast/PycharmProjects/trkgnn/workspace/test/merged.root"
TREE_NAME = "dp"


def filter_list(x: pd.Series) -> pd.Series:
    """Reduce an awkward (jagged) column to its first entry per event.

    Empty entries become NaN. Non-awkward columns are returned unchanged.
    """
    # str() avoids comparing a numpy dtype against an arbitrary string.
    if str(x.dtype) == 'awkward':
        return x.map(lambda v: v[0] if len(v) > 0 else np.nan)
    return x


# Every branch any interest needs, de-duplicated while preserving order. The
# resolution base branch is included explicitly; previously it was loaded only
# because some other interest happened to list it.
needed_branches = list(dict.fromkeys(
    br
    for cfg in interests.values()
    for br in [*cfg["branches"], *([cfg["branch_base"]] if "branch_base" in cfg else [])]
))

# The total entry count is used only to size the progress bar.
with uproot.open(ROOT_FILE) as file:
    total_entries = file[TREE_NAME].num_entries

# The context manager closes the bar even if processing raises.
with tqdm(total=total_entries, unit="events", desc="Processing", initial=0, position=0, leave=True) as pbar:
    for df in uproot.iterate(
            f"{ROOT_FILE}:{TREE_NAME}",
            library="pd",
            step_size="500 KB",
            filter_name=needed_branches,
            cut=None,
    ):
        df = df.apply(filter_list)

        # Calculate histograms for this chunk and accumulate.
        for cfg in interests.values():
            hist_type = cfg["type"]
            bins = cfg["hist_setup"]
            if hist_type == "plain":
                for i, (branch, cut) in enumerate(zip(cfg["branches"], cfg["cuts"])):
                    cfg["hists"][i] += np.histogram(df.query(cut)[branch], bins=bins)[0]
            elif hist_type == "difference":
                for i, (br_up, br_down, cut) in enumerate(
                        zip(cfg["branches_up"], cfg["branches_down"], cfg["cuts"])):
                    # Evaluate the cut once rather than once per operand.
                    df_cut = df.query(cut)
                    cfg["hists"][i] += np.histogram(df_cut[br_up] - df_cut[br_down], bins=bins)[0]
            elif hist_type == "resolution":
                base = cfg["branch_base"]
                for i, (branch, cut) in enumerate(zip(cfg["branches"], cfg["cuts"])):
                    df_cut = df.query(cut)
                    # Relative deviation from the base (truth) branch, in percent.
                    rel_diff_pct = (df_cut[branch] - df_cut[base]) / df_cut[base] * 100
                    cfg["hists"][i] += np.histogram(rel_diff_pct, bins=bins)[0]

        pbar.update(len(df))

# Convert to DataFrames. "bin" holds the left bin edge. Bins that are empty in
# every column are dropped.
for cfg in interests.values():
    hists = cfg["hists"]
    if cfg["type"] in {"plain", "difference"}:
        value_columns = {"Truth": hists[0], "CKF": hists[1], "GNN": hists[2]}
    elif cfg["type"] == "resolution":
        value_columns = {"CKF/Truth": hists[0], "GNN/Truth": hists[1]}
    else:
        continue
    hist_df = pd.DataFrame({"bin": cfg["hist_setup"][:-1], **value_columns})
    cfg["df"] = hist_df.loc[(hist_df[list(value_columns)] != 0).any(axis=1)]


Processing: 100%|██████████| 50000/50000 [00:00<00:00, 66655.99events/s]


In [205]:
def plot_comparison(
        x, y1, y2,
        y0=None,
        labels=('Truth', 'CKF', 'GNN'),
        x_label: str = 'x',
        y_label: str = 'Counts',
        ratio_threshold: float = 5.0,
        annotation_text: str = None,
):
    """Draw a two-panel comparison: histograms on top, ratios underneath.

    Top panel: ``y0`` (optional reference), ``y1`` and ``y2`` against ``x``,
    with sqrt(N) error bars on a log y-axis. Bottom panel: ``y1/y0`` and
    ``y2/y0``, or the single ratio ``y2/y1`` when ``y0`` is None, with errors
    propagated in quadrature. Ratios above ``ratio_threshold`` are drawn as
    up-arrows pinned at the threshold.

    Parameters
    ----------
    x : array-like
        Bin positions. The notebook passes left bin edges.
    y1, y2 : pandas.Series or numpy array
        Counts for the two compared sources. Presumably these are non-negative
        histogram counts; confirm with the callers.
    y0 : same type as y1, optional
        Reference counts (e.g. Truth). When None, no reference trace is drawn.
    labels : tuple[str, str, str]
        Names for (y0, y1, y2). Used as legend names and as keys into the
        module-level ``int_style`` dict (also as "<num>/<den>" ratio keys).
    x_label, y_label : str
        Axis titles.
    ratio_threshold : float
        Upper limit of the ratio panel once any ratio exceeds it.
    annotation_text : str, optional
        Extra annotation line under the experiment banner.

    The figure is displayed with ``fig.show()``. Nothing is returned.
    """
    # Poisson (sqrt N) uncertainties on the counts.
    y1_error = np.sqrt(y1)
    y2_error = np.sqrt(y2)
    y0_error = np.sqrt(y0) if y0 is not None else None

    # Each entry is (ratio, ratio_error, ratio_name). Relative errors are added in quadrature.
    if y0 is not None:
        ratio1 = y1 / y0
        ratio2 = y2 / y0
        ratio1_error = ratio1 * np.sqrt((y1_error / y1) ** 2 + (y0_error / y0) ** 2)
        ratio2_error = ratio2 * np.sqrt((y2_error / y2) ** 2 + (y0_error / y0) ** 2)
        ratio_series = [
            (ratio1, ratio1_error, f"{labels[1]}/{labels[0]}"),
            (ratio2, ratio2_error, f"{labels[2]}/{labels[0]}"),
        ]
        ratio_axis_title = "Rec./Truth"
    else:
        # Without a reference there is only one meaningful ratio: y2 / y1.
        ratio1 = y2 / y1
        ratio1_error = ratio1 * np.sqrt((y2_error / y2) ** 2 + (y1_error / y1) ** 2)
        ratio_series = [(ratio1, ratio1_error, f"{labels[2]}/{labels[1]}")]
        ratio_axis_title = f"{labels[2]}/{labels[1]}"

    y_max = max(
        (y0 + y0_error).max() if y0 is not None else 1,
        (y1 + y1_error).max(),
        (y2 + y2_error).max(),
    )

    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.0,
        row_heights=[0.75, 0.25]
    )

    # Cap the ratio axis at the threshold as soon as any ratio overflows it.
    ratios = [ratio for ratio, _, _ in ratio_series]
    hit_ratio = any(bool((ratio > ratio_threshold).any()) for ratio in ratios)
    ratio_max = ratio_threshold if hit_ratio else max(ratio.max() for ratio in ratios)

    # Top panel: one line per source. The reference is skipped when absent.
    for y, y_err, label in ((y0, y0_error, labels[0]), (y1, y1_error, labels[1]), (y2, y2_error, labels[2])):
        if y is None:
            continue
        style = int_style[label]
        fig.add_trace(
            go.Scatter(
                x=x, y=y, name=label,
                line=dict(color=style["color"]),
                fillcolor=style.get("fillcolor"),
                # fill='tozeroy',
                error_y=dict(type='data', array=y_err, visible=True)
            ), row=1, col=1)

    # Bottom panel, pass 1: in-range ratios with their propagated error bars.
    for ratio, ratio_err, name in ratio_series:
        style = int_style.get(name, {})
        in_range = ratio <= ratio_threshold
        fig.add_trace(go.Scatter(
            x=x[in_range], y=ratio[in_range],
            name=name,
            mode='markers',
            marker=dict(
                size=10,
                color=style.get("color"),
                symbol=style.get("marker", "circle"),
            ),
            error_y=dict(
                type='data',
                array=ratio_err[in_range],
                color=style.get("color"),
                thickness=1.5,
                width=3)
        ), row=2, col=1)

    # Bottom panel, pass 2: overflowing ratios drawn as up-arrows at the threshold.
    for ratio, _, name in ratio_series:
        style = int_style.get(name, {})
        over = ratio > ratio_threshold
        fig.add_trace(go.Scatter(
            x=x[over],
            y=np.full(int(over.sum()), ratio_threshold),
            name=f"{name} >={ratio_threshold}",
            mode='markers',
            marker_color=style.get("color"),
            marker_symbol='arrow-up-open', marker_size=12, marker_line_width=2,
            showlegend=False,
        ), row=2, col=1)

    fig.add_hline(y=1.0, line_width=1, line_dash="dash", line_color="grey", row=2, col=1)

    y_axis_attr = dict(linecolor="#666666", zerolinecolor='rgba(0,0,0,0)', linewidth=2, mirror=True)
    # Log axis: plotly expects the range in log10 units.
    fig.update_yaxes(
        **y_axis_attr, title_text=y_label, type="log", row=1, col=1,
        range=[0.1, np.log10(y_max) + 2.5]
    )
    fig.update_yaxes(
        **y_axis_attr, title_text=ratio_axis_title, row=2, col=1,
        range=[0, ratio_max]
    )

    x_axis_attr = dict(
        linecolor="#666666", gridcolor='#d9d9d9', zerolinecolor='rgba(0,0,0,0)', linewidth=2,
        showline=True, showgrid=False
    )
    fig.update_xaxes(**x_axis_attr, mirror=True, row=1, col=1)
    fig.update_xaxes(**x_axis_attr, mirror=False, row=2, col=1, title_text=x_label)

    # Experiment banner and annotations, in paper coordinates.
    x_base, y_base = 0.05, 0.97
    fig.add_annotation(
        text=r'<i><b>DarkSHINE</b></i>', showarrow=False, xref='paper', x=x_base, yref='paper', y=y_base,
        font=dict(size=34, family='Cambria'),
    )
    fig.add_annotation(
        text=r'Simulation', showarrow=False, xref='paper', x=x_base + 0.405, yref='paper', y=y_base - 0.0037,
        font=dict(size=31, family='Cambria'),
    )

    y_base -= 0.075
    fig.add_annotation(
        text=r'$\Large{E^{e}_{0} = 8~\text{GeV},~10^{14}~\text{EOT}}$', showarrow=False, xref='paper', x=x_base,
        yref='paper', y=y_base,
    )

    y_base -= 0.062
    if annotation_text is not None:
        fig.add_annotation(
            text=f'{annotation_text}', showarrow=False, xref='paper', x=x_base,
            yref='paper', y=y_base,
            font=dict(size=20, family='Cambria'),
        )

    fig.update_layout(
        legend=dict(
            orientation="v",
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.98,
            font=dict(size=14),
        ),
        width=800,
        height=800,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
    )

    fig.show()


In [206]:
# One Truth / CKF / GNN comparison per count-style histogram.
# "resolution" tables have no Truth column, so they are skipped here.
for title, cfg in interests.items():
    if cfg['type'] not in {'plain', 'difference'}:
        continue
    hist_df = cfg["df"]
    plot_comparison(
        x=hist_df["bin"],
        y0=hist_df["Truth"],
        y1=hist_df["CKF"],
        y2=hist_df["GNN"],
        x_label=cfg["x_label"],
        annotation_text=title,
        ratio_threshold=2.0,
    )

In [69]:
# Scratch inspection: load only the first small (50 KB) chunk of the tree into `df`.
# The raw columns, including jagged ones before any flattening, can then be examined.
# Binds `df` to None if the tree yields no chunks, instead of silently keeping a
# stale `df` from an earlier cell.
# NOTE(review): this rebinds the global `df` used elsewhere in the notebook.
_inspect_branches = list(dict.fromkeys(
    br
    for cfg in interests.values()
    for br in [*cfg["branches"], *([cfg["branch_base"]] if "branch_base" in cfg else [])]
))
df = next(
    iter(uproot.iterate(
        "/Users/avencast/PycharmProjects/trkgnn/workspace/test/merged.root:dp",
        library="pd",
        step_size="50 KB",
        filter_name=_inspect_branches,
        cut=None,
    )),
    None,
)
