In [1]:
import uproot
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [93]:
# Column layout of the per-quantity summary tables: bin centre plus one count
# column per reconstruction flavour.
df_columns = ['bin', 'Truth', 'CKF', 'GNN']


def _new_interest(branches, bin_edges):
    """Build the bookkeeping record for one histogrammed quantity.

    Parameters
    ----------
    branches : sequence of str
        Branch names to histogram. The table-building code reads them
        positionally as Truth, CKF, GNN.
    bin_edges : numpy.ndarray
        Histogram bin edges; there are ``len(bin_edges) - 1`` bins.

    Returns
    -------
    dict
        Keys ``hist_setup`` (bin edges), ``branches``, ``hists`` (one zeroed
        integer accumulator per branch) and ``df`` (empty summary table).
    """
    return {
        "hist_setup": bin_edges,
        "branches": list(branches),
        # One *independent* accumulator per branch. `[array] * 3` would store the
        # same array object three times, so an in-place `+=` on any entry would
        # add into all of them.
        "hists": [np.zeros(len(bin_edges) - 1, dtype=int) for _ in branches],
        "df": pd.DataFrame(columns=df_columns),
    }


# Quantities to compare, keyed by the title used as the plot annotation.
interests = {
    "Tag No_of_tracks": _new_interest(
        ["TagTrk2_track_No_truth", "TagTrk2_track_No", "gnn_tag_no_tracks"],
        np.linspace(0, 20, 21),
    ),
    "Rec No_of_tracks": _new_interest(
        ["RecTrk2_track_No_truth", "RecTrk2_track_No", "gnn_rec_no_tracks"],
        np.linspace(0, 20, 21),
    ),
}

# Trace styling looked up by label when plotting.
#   "color"     - line / marker colour
#   "fillcolor" - area fill under the count curves (CKF and GNN only)
#   "marker"    - marker symbol; the plotting code reads it for the ratio
#                 ("X/Truth") entries
int_style = {
    "Truth": dict(color="black"),
    "CKF": dict(color="red", marker="None", fillcolor="rgba(0,100,80,0.4)"),
    "GNN": dict(color="blue", marker="None", fillcolor="rgba(0,176,246,0.4)"),
    "CKF/Truth": dict(color="red", marker="circle"),
    "GNN/Truth": dict(color="blue", marker="diamond"),
}



In [94]:
# Stream the tree in chunks, histogram every branch of interest, then turn the
# accumulated counts into one summary table per quantity.
input_tree = "/Users/avencast/PycharmProjects/trkgnn/workspace/test/merged.root:dp"

# Every branch needed by any quantity, de-duplicated with order preserved.
wanted_branches = list(dict.fromkeys(
    br for spec in interests.values() for br in spec["branches"]
))

# Start from fresh, independent accumulators so that (a) re-running this cell
# does not add the same events a second time and (b) each branch owns its own
# array -- a shared array would receive the in-place `+=` of all branches.
for spec in interests.values():
    spec["hists"] = [
        np.zeros(len(spec["hist_setup"]) - 1, dtype=int) for _ in spec["branches"]
    ]

for df in tqdm(uproot.iterate(
        input_tree,
        # `expressions` is the uproot.iterate parameter that selects what to read;
        # a `branches=` keyword is not part of its signature and would be
        # swallowed by **options instead of restricting the read.
        expressions=wanted_branches,
        library="pd",
        step_size="100 MB",
        cut=None,
)):
    # Calculate histograms and accumulate
    for spec in interests.values():
        for i, branch in enumerate(spec["branches"]):
            spec["hists"][i] += np.histogram(df[branch], bins=spec["hist_setup"])[0]

# convert to dataframe
for spec in interests.values():
    edges = spec["hist_setup"]
    table = pd.DataFrame({
        "bin": 0.5 * (edges[:-1] + edges[1:]),  # bin centres
        "Truth": spec["hists"][0],
        "CKF": spec["hists"][1],
        "GNN": spec["hists"][2],
    })
    # Keep only bins populated in all three histograms; the ratio panel divides
    # by these counts.
    populated = (table["Truth"] != 0) & (table["CKF"] != 0) & (table["GNN"] != 0)
    spec["df"] = table.loc[populated]


1it [00:00,  2.09it/s]


In [95]:
def _ratio_with_error(num, num_error, den, den_error):
    """Return ``num / den`` and its uncertainty (relative errors added in quadrature).

    Bins with a zero numerator or denominator yield NaN/inf rather than a
    floating-point warning; callers are expected to mask them.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = num / den
        ratio_error = ratio * np.sqrt((num_error / num) ** 2 + (den_error / den) ** 2)
    return ratio, ratio_error


def plot_comparison(
        x, y0, y1, y2,
        labels=('Truth', 'CKF', 'GNN'),
        y_label: str = 'Counts',
        ratio_threshold: float = 5.0,
        annotation_text: str = None,
):
    """Draw three count curves plus a ratio-to-reference panel and show the figure.

    The top panel (log y-axis) shows ``y0``, ``y1`` and ``y2`` against ``x``
    with sqrt(count) error bars. The bottom panel shows ``y1 / y0`` and
    ``y2 / y0`` with propagated errors. Ratios above ``ratio_threshold`` are
    drawn as open up-arrows at the threshold instead of at their true value.

    Parameters
    ----------
    x : array-like
        Bin positions.
    y0, y1, y2 : array-like
        Counts per bin for the reference and the two compared series, each the
        same length as ``x``.
    labels : sequence of three str
        Legend names for ``y0``, ``y1``, ``y2``. Styling is read from the
        module-level ``int_style``, which must contain ``labels[0]``,
        ``labels[1]``, ``labels[2]`` and the keys ``"<labels[1]>/<labels[0]>"``
        and ``"<labels[2]>/<labels[0]>"``.
    y_label : str
        Title of the top panel's y-axis.
    ratio_threshold : float
        Largest ratio drawn at its true value.
    annotation_text : str, optional
        Text placed in the upper-left of the figure; omitted when ``None``.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.
    """
    # Work on plain arrays so pandas Series (with any index), lists and ndarrays
    # all behave the same under boolean masking.
    x = np.asarray(x)
    y0, y1, y2 = (np.asarray(counts, dtype=float) for counts in (y0, y1, y2))

    # Compute errors
    y0_error, y1_error, y2_error = np.sqrt(y0), np.sqrt(y1), np.sqrt(y2)
    # Compute ratios to the reference and their errors
    ratio1, ratio1_error = _ratio_with_error(y1, y1_error, y0, y0_error)
    ratio2, ratio2_error = _ratio_with_error(y2, y2_error, y0, y0_error)

    count_series = (
        (labels[1], y1, y1_error),
        (labels[2], y2, y2_error),
    )
    ratio_series = (
        (f"{labels[1]}/{labels[0]}", ratio1, ratio1_error),
        (f"{labels[2]}/{labels[0]}", ratio2, ratio2_error),
    )

    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.0,
        row_heights=[0.75, 0.25]
    )

    # Ratio-axis limits. When some points overflow they are drawn AT the
    # threshold, so the axis only has to reach the threshold; otherwise it is
    # fitted to the data.
    all_ratios = np.concatenate([ratio1, ratio2])
    with np.errstate(invalid="ignore"):
        hit_ratio = bool((all_ratios > ratio_threshold).any())
    finite_ratios = all_ratios[np.isfinite(all_ratios)]
    if finite_ratios.size:
        ratio_max = float(ratio_threshold) if hit_ratio else float(finite_ratios.max())
        # Guard against an inverted range when every ratio overflows.
        ratio_min = min(float(finite_ratios.min()), ratio_max)
    else:
        # Nothing to plot (empty input or no valid bins): fall back to a fixed range.
        ratio_min, ratio_max = 0.0, float(ratio_threshold)

    # --- top panel: counts -------------------------------------------------
    fig.add_trace(
        go.Scatter(
            x=x, y=y0, name=labels[0], line=dict(color=int_style[labels[0]]["color"]),
            error_y=dict(type='data', array=y0_error, visible=True)
        ), row=1, col=1)
    for label, counts, counts_error in count_series:
        fig.add_trace(
            go.Scatter(
                x=x, y=counts, name=label,
                line=dict(color=int_style[label]["color"]),
                fillcolor=int_style[label]["fillcolor"],
                fill='tozeroy',
                # Each series carries its own error bars.
                error_y=dict(type='data', array=counts_error, visible=True)
            ), row=1, col=1)

    # --- bottom panel: ratios within the threshold -------------------------
    for ratio_label, ratio, ratio_error in ratio_series:
        with np.errstate(invalid="ignore"):
            in_range = ratio <= ratio_threshold  # NaN compares False and is dropped
        fig.add_trace(go.Scatter(
            x=x[in_range], y=ratio[in_range],
            name=ratio_label,
            mode='markers',
            marker=dict(
                size=10,  # Setting marker size
                color=int_style[ratio_label]["color"],
                symbol=int_style[ratio_label]["marker"]
            ),
            error_y=dict(
                type='data',
                array=ratio_error[in_range],
                color=int_style[ratio_label]["color"],
                thickness=1.5,
                width=3)
        ), row=2, col=1)

    # --- bottom panel: overflow markers pinned at the threshold ------------
    for ratio_label, ratio, _ in ratio_series:
        with np.errstate(invalid="ignore"):
            overflow = ratio > ratio_threshold
        fig.add_trace(go.Scatter(
            x=x[overflow],
            y=np.full(int(overflow.sum()), ratio_threshold),
            name=f"{ratio_label} >={ratio_threshold}",
            mode='markers',
            marker_color=int_style[ratio_label]["color"],
            marker_symbol='arrow-up-open', marker_size=12, marker_line_width=2,
        ), row=2, col=1)

    # Reference line at ratio == 1.
    fig.add_hline(y=1.0, line_width=1, line_dash="dash", line_color="grey", row=2, col=1)

    y_axis_attr = dict(linecolor="#666666", zerolinecolor='rgba(0,0,0,0)', linewidth=2, mirror=True)
    fig.update_yaxes(**y_axis_attr, title_text=y_label, type="log", row=1, col=1)
    # Extra room (35% of the span) is left below the lowest ratio.
    fig.update_yaxes(**y_axis_attr, title_text="Rec./Truth", row=2, col=1,
                     range=[ratio_min - (ratio_max - ratio_min) * 0.35, ratio_max * 1.05])

    x_axis_attr = dict(
        linecolor="#666666", gridcolor='#d9d9d9', zerolinecolor='rgba(0,0,0,0)', linewidth=2,
        showline=True, showgrid=False
    )
    fig.update_xaxes(**x_axis_attr, mirror=True, row=1, col=1)
    fig.update_xaxes(**x_axis_attr, mirror=False, row=2, col=1)

    # annotation
    if annotation_text is not None:
        fig.add_annotation(
            text=f'{annotation_text}', showarrow=False, xref='paper', x=0.15,
            yref='paper', y=0.97,
            font=dict(size=18, family='Arial'),
        )

    fig.update_layout(
        legend=dict(
            orientation="h",
            yanchor="top",
            y=1.08,
            xanchor="right",
            x=1,
            font=dict(size=14),
        ),
        width=800,
        height=700,
    )

    fig.show()


In [96]:
# One comparison figure per configured quantity, annotated with its title.
for interest, spec in interests.items():
    counts = spec["df"]
    plot_comparison(
        x=counts["bin"],
        y0=counts["Truth"],
        y1=counts["CKF"],
        y2=counts["GNN"],
        annotation_text=interest,
    )

In [195]:
# Display the summary table (bin centre and Truth/CKF/GNN counts) stored for the
# "Tag No_of_tracks" quantity.
# NOTE(review): the output saved under this cell shows columns A/B, which does not
# match this table's columns -- it looks stale; re-run the notebook top to bottom
# to refresh it.
interests["Tag No_of_tracks"]["df"]

           A             B
0  [1, 2, 3]  [10, 20, 30]
1  [4, 5, 6]  [40, 50, 60]
2  [7, 8, 9]  [70, 80, 90]
