In [None]:
import os
import json
import hydra
import numpy as np
from math import isclose
import enreg.tools.general as g
import mplhep as hep
import awkward as ak
import matplotlib.pyplot as plt
from omegaconf import DictConfig
from sklearn.preprocessing import label_binarize
from matplotlib import cm
from sklearn import metrics
from sklearn.metrics import confusion_matrix, precision_score
from matplotlib.ticker import FormatStrFormatter

In [None]:
# Select mplhep's "CMS" style; this cell runs before any of the plotting functions below are called.
hep.style.use("CMS")

In [None]:
# Loaded ZH inputs, keyed by the algorithm name used throughout the plots below.
# For ParticleTransformer, LorentzNet and DeepSet only the `dm_multiclass` field
# of the loaded data is kept; the HPS entry keeps the whole loaded object.
data_zh_model = dict(
    ParticleTransformer=g.load_all_data("/home/laurits/ml-tau-en-reg/training-outputs/20240921_recoPtCut_removed_samples/v1/dm_multiclass/ParticleTransformer/zh_test.parquet").dm_multiclass,
    LorentzNet=g.load_all_data("/home/laurits/ml-tau-en-reg/training-outputs/20240921_recoPtCut_removed_samples/v1/dm_multiclass/LorentzNet/zh_test.parquet").dm_multiclass,
    DeepSet=g.load_all_data("/home/laurits/ml-tau-en-reg/training-outputs/20240921_recoPtCut_removed_samples/v1/dm_multiclass/DeepSet/zh_test.parquet").dm_multiclass,
    HPS=g.load_all_data("/home/laurits/HPS_recoCut0_ntuples/zh.parquet"),
)


In [None]:
# Loaded Z inputs, keyed by the algorithm name used throughout the plots below.
# For ParticleTransformer, LorentzNet and DeepSet only the `dm_multiclass` field
# of the loaded data is kept; the HPS entry keeps the whole loaded object.
data_z_model = dict(
    ParticleTransformer=g.load_all_data("/home/laurits/ml-tau-en-reg/training-outputs/20240921_recoPtCut_removed_samples/v1/dm_multiclass/ParticleTransformer/z_test.parquet").dm_multiclass,
    LorentzNet=g.load_all_data("/home/laurits/ml-tau-en-reg/training-outputs/20240921_recoPtCut_removed_samples/v1/dm_multiclass/LorentzNet/z_test.parquet").dm_multiclass,
    DeepSet=g.load_all_data("/home/laurits/ml-tau-en-reg/training-outputs/20240921_recoPtCut_removed_samples/v1/dm_multiclass/DeepSet/z_test.parquet").dm_multiclass,
    HPS=g.load_all_data("/home/laurits/HPS_recoCut0_ntuples/z.parquet"),
)

In [None]:
# Folder that receives every PDF written by the plotting functions below.
# The path is relative, so it resolves against the current working directory.
output_dir = "../outputs/20240923_decaymode_plots/"
# Create the folder (and any missing parents); an already existing folder is fine.
os.makedirs(output_dir, exist_ok=True)

In [None]:
## Functions

#---------------For getting the distribution of the decay modes-----------------------------------------------------------------------------

# Function for plotting the distribution of the decay modes
def plot_decay_modes(model):
    """Histogram the actual and predicted decay modes of one algorithm and save it as a PDF.

    The values are read from the module-level ``data_zh_model`` mapping. Its
    ``"HPS"`` entry is read through ``true_decay_mode`` / ``pred_decay_mode``;
    every other entry through ``target`` / ``pred``.

    Args:
        model: Key into ``data_zh_model``. It is also used as the legend label
            of the predicted histogram and in the output file name
            ``dm_<model>.pdf`` inside ``output_dir``.
    """
    sample = data_zh_model[model]
    if model == "HPS":
        actual, predicted = sample.true_decay_mode, sample.pred_decay_mode
    else:
        actual, predicted = sample.target, sample.pred

    # Bin edges 0, 1, ..., 16: sixteen unit-wide bins.
    mode_edges = np.arange(17)
    # A model called "SimpleDNN" is shown as "DeepSet" in the legend.
    predicted_label = 'DeepSet' if model == "SimpleDNN" else model

    plt.hist([actual, predicted], mode_edges, label=['Actual decaymode', predicted_label])
    plt.legend(loc='upper right')
    plt.title("Actual vs predicted decaymodes")
    plt.yscale('log')
    # Tick labels sit 0.4 to the right of each bin's lower edge.
    plt.xticks(mode_edges + 0.4, mode_edges)

    target_path = os.path.join(output_dir, f"dm_{model}.pdf")
    plt.savefig(target_path, bbox_inches='tight', format='pdf')
    plt.close("all")

#---------------For getting the confusion matrices-----------------------------------------------------------------------------

# Function for aggregating specific classes of the confusion matrix
def aggregate_classes(cm, classes_to_merge):
    """Merge a group of classes of a square confusion matrix into one class.

    Every row and every column that belongs to ``classes_to_merge`` is summed
    into a single row/column. The merged class is placed at the smallest of
    the merged indices; all other classes keep their relative order around it.
    For ``[2, 3, 4]`` the merged class therefore ends up at index 2 and for
    ``[4, 5]`` at index 4.

    Args:
        cm: Square 2-D array-like of shape ``(n, n)``.
        classes_to_merge: Iterable of integer class indices in ``[0, n)``.
            Any iterable is accepted (list, tuple, array); duplicates are
            ignored.

    Returns:
        A new float ``numpy.ndarray`` of shape ``(n - k + 1, n - k + 1)``,
        where ``k`` is the number of distinct merged classes. The input is
        not modified.

    Raises:
        ValueError: If ``cm`` is not a square 2-D matrix, if
            ``classes_to_merge`` is empty, or if it contains an index outside
            ``[0, n)``.
    """
    matrix = np.asarray(cm, dtype=float)
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"cm must be a square 2-D matrix, got shape {matrix.shape}")
    num_classes = matrix.shape[0]

    merged = sorted({int(cls) for cls in classes_to_merge})
    if not merged:
        raise ValueError("classes_to_merge must contain at least one class index")
    if merged[0] < 0 or merged[-1] >= num_classes:
        raise ValueError(
            f"classes_to_merge {merged} contains an index outside [0, {num_classes})"
        )

    # The merged class takes the slot of its smallest member.
    merge_index = merged[0]
    merged_set = set(merged)

    # index_map[old] -> position of that class in the aggregated matrix.
    index_map = np.empty(num_classes, dtype=int)
    next_free = 0
    for old_index in range(num_classes):
        if old_index in merged_set:
            index_map[old_index] = merge_index
            continue
        # Skip over the slot reserved for the merged class.
        if next_free == merge_index:
            next_free += 1
        index_map[old_index] = next_free
        next_free += 1

    new_num_classes = num_classes - len(merged) + 1
    new_cm = np.zeros((new_num_classes, new_num_classes))
    # Unbuffered accumulation: several old cells may land in the same new cell.
    np.add.at(new_cm, (index_map[:, None], index_map[None, :]), matrix)
    return new_cm

# Function for normalizing the confusion matrix
def normalize_confusion_matrix(cm):
    """Scale a confusion matrix so that every column sums to one.

    Args:
        cm: 2-D array-like of counts.

    Returns:
        A float ``numpy.ndarray`` with the shape of ``cm`` in which each entry
        is divided by the sum of its column. A column whose sum is zero is
        returned as all zeros instead of NaN (a plain division would give 0/0
        there and emit a RuntimeWarning).
    """
    cm = np.asarray(cm, dtype=float)
    column_sums = cm.sum(axis=0, keepdims=True)
    # Divide only where the column total is non-zero; the remaining entries
    # keep the zeros that `out` was initialised with.
    normalized_cm = np.divide(cm, column_sums, out=np.zeros_like(cm), where=column_sums != 0)
    return normalized_cm

# Function for plotting the DM confusion matrices
def CM_plot(dataset):
    """Plot and save one decay-mode confusion matrix per algorithm.

    For each of DeepSet, LorentzNet, ParticleTransformer and HPS the true and
    predicted decay modes are passed through ``g.get_reduced_decaymodes``,
    entries whose reduced prediction equals -1 are dropped, and a confusion
    matrix normalised with ``normalize='true'`` is drawn. Each figure is
    written to ``<dataset>_cm_<model>.pdf`` inside ``output_dir``.

    Args:
        dataset: ``"ZH"`` (reads ``data_zh_model``) or ``"Z"`` (reads
            ``data_z_model``).

    Raises:
        ValueError: If ``dataset`` is neither ``"ZH"`` nor ``"Z"``.
    """
    if dataset == "ZH":
        data_by_model = data_zh_model
    elif dataset == "Z":
        data_by_model = data_z_model
    else:
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'ZH' or 'Z'.")

    categories = [r'$h^\pm$', r'$h^\pm+\pi^0$', r'$h^\pm+\geq2\pi^0$', r'$h^\pm h^\mp h^\pm$', r'$h^\pm h^\mp h^\pm$' '\n' r'$+\geq\pi^0$', 'Rare']
    xbins = ybins = np.arange(len(categories) + 1)
    tick_values = np.arange(len(categories)) + 0.5

    models = ["DeepSet", "LorentzNet", "ParticleTransformer", "HPS"]
    for model_name in models:
        sample = data_by_model[model_name]
        # The HPS entry stores its labels under different field names.
        if model_name == "HPS":
            raw_true, raw_pred = sample.true_decay_mode, sample.pred_decay_mode
        else:
            raw_true, raw_pred = sample.target, sample.pred
        y_true = g.get_reduced_decaymodes(np.array(raw_true))
        y_pred = g.get_reduced_decaymodes(np.array(raw_pred))

        # Keep only the entries whose reduced prediction is not -1.
        has_prediction = y_pred != -1
        y_true = y_true[has_prediction]
        y_pred = y_pred[has_prediction]

        # Named `confusion` so the module-level `cm` import is not shadowed.
        confusion = confusion_matrix(y_true, y_pred, normalize='true')

        fig, ax = plt.subplots()
        hep.hist2dplot(confusion, xbins, ybins, cmap='GnBu', cbar=True, flow=None, ax=ax)

        # Write confusion[a, b] at the centre of the cell x = a, y = b.
        for (x_idx, y_idx), bin_value in np.ndenumerate(confusion):
            ax.text(
                x_idx + 0.5,
                y_idx + 0.5,
                f"{bin_value:.2f}",
                color='k',
                ha="center",
                va="center",
                fontweight="bold",
            )

        ax.set_xticks(tick_values, categories)
        ax.set_yticks(tick_values + 0.2, categories)
        ax.set_xticklabels(categories, rotation=45)
        ax.tick_params(axis='x', which='both', bottom=True, top=True, labeltop=False, labelbottom=True)
        ax.set_yticklabels(categories)
        ax.set_xlabel('Generated decay mode')
        ax.set_ylabel('Reconstructed decay mode')
        ax.set_title(f"{dataset} Confusion Matrix", y=1.05)
        fig.savefig(os.path.join(output_dir, f"{dataset}_cm_{model_name}.pdf"), bbox_inches='tight', format='pdf')
        # Close only this figure so repeated calls do not accumulate open figures.
        plt.close(fig)

#---------------For plotting the classification precision of the DMs-----------------------------------------------------------------------------

# Function to get precisions of the models
def get_precision(dataset):
    """Compute the per-class decay-mode precision of every algorithm.

    For each of DeepSet, LorentzNet, ParticleTransformer and HPS the true and
    predicted decay modes are passed through ``g.get_reduced_decaymodes``,
    entries whose reduced prediction equals -1 are dropped, and
    ``precision_score(..., average=None)`` is evaluated.

    Args:
        dataset: ``"ZH"`` (reads ``data_zh_model``) or ``"Z"`` (reads
            ``data_z_model``).

    Returns:
        A list with one per-class precision array per algorithm, in the order
        DeepSet, LorentzNet, ParticleTransformer, HPS.

    Raises:
        ValueError: If ``dataset`` is neither ``"ZH"`` nor ``"Z"``.
    """
    if dataset == "ZH":
        data_by_model = data_zh_model
    elif dataset == "Z":
        data_by_model = data_z_model
    else:
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'ZH' or 'Z'.")

    model_precisions = []
    models = ["DeepSet", "LorentzNet", "ParticleTransformer", "HPS"]
    for model_name in models:
        sample = data_by_model[model_name]
        # The HPS entry stores its labels under different field names.
        if model_name == "HPS":
            raw_true, raw_pred = sample.true_decay_mode, sample.pred_decay_mode
        else:
            raw_true, raw_pred = sample.target, sample.pred
        y_true = g.get_reduced_decaymodes(np.array(raw_true))
        y_pred = g.get_reduced_decaymodes(np.array(raw_pred))

        # Keep only the entries whose reduced prediction is not -1.
        has_prediction = y_pred != -1
        y_true = y_true[has_prediction]
        y_pred = y_pred[has_prediction]

        model_precisions.append(precision_score(y_true, y_pred, average=None))
    return model_precisions

# Function to plot the classification precision of the DMs
def plot_dm_prec(dataset):
    """Plot the per-decay-mode classification precision of all algorithms.

    The per-class precisions come from ``get_precision(dataset)``. An extra
    "Overall" point is appended for every algorithm: the sum of its per-class
    precisions weighted by ``PDG_ratios``. Each algorithm is drawn as a
    scatter series with a small horizontal offset, and every point is
    annotated with its value. The figure is written to
    ``<dataset>_dm_precision.pdf`` inside ``output_dir``.

    Args:
        dataset: ``"ZH"`` or ``"Z"``; forwarded to ``get_precision`` and used
            in the title and the output file name.

    Raises:
        ValueError: If ``dataset`` is neither ``"ZH"`` nor ``"Z"``.
    """
    if dataset not in ("ZH", "Z"):
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'ZH' or 'Z'.")
    precisions = get_precision(dataset)

    labels = [r'$h^\pm$', r'$h^\pm+\pi^0$', r'$h^\pm+\geq2\pi^0$', r'$h^\pm h^\mp h^\pm$', r'$h^\pm h^\mp h^\pm$' '\n' r'$+\geq\pi^0$', 'Rare', 'Overall']
    # Per-class weights for the "Overall" point; one weight per class, so
    # get_precision must deliver exactly six per-class values per algorithm.
    # NOTE(review): presumably PDG branching fractions -- confirm the source.
    PDG_ratios = np.array([0.1777, 0.4002, 0.1668, 0.1513, 0.0816, 0.0224])

    # (legend label, marker colour, horizontal offset), in the order in which
    # get_precision returns the algorithms.
    series_styles = [
        ('DeepSet', '#ff5b5b', -0.3),
        ('LorentzNet', '#ffc140', -0.1),
        ('ParticleTransformer', '#89cded', 0.1),
        ('HPS', 'green', 0.3),
    ]

    positions = np.arange(len(labels))
    fig, ax = plt.subplots()

    for per_class, (series_label, color, offset) in zip(precisions, series_styles):
        values = list(per_class)
        values.append(np.sum(np.array(values) * PDG_ratios))

        # The same shifted x positions are used for the markers and for their
        # value labels, so every annotation sits next to its own marker.
        xs = positions + offset
        ax.scatter(xs, values, label=series_label, color=color, marker='o', s=100)
        for x_pos, value in zip(xs, values):
            ax.annotate(f'{value:.3f}', (x_pos, value), textcoords="offset points", xytext=(-18, -4), ha='center', va='bottom', fontsize=9)

    ax.set_xlabel('Decay Modes')
    ax.set_ylabel('Precision', x=1.05)
    ax.set_title(f"Classification Precision of the DMs for {dataset}", y=1.05)
    ax.set_xticks(positions)
    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=True)
    ax.tick_params(axis='y', which='both', left=True, right=False, labelleft=True)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.set_xticklabels(labels, rotation=45)

    # Vertical lines between the DMs
    for i in range(len(labels)):
        ax.axvline(i - 0.5, color='gray', linestyle='--', linewidth=0.5, zorder=0)

    ax.legend(loc='lower left', shadow=True, fancybox=True, framealpha=1, borderpad=1)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{dataset}_dm_precision.pdf"), bbox_inches='tight', format='pdf')
    # Close only this figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
# Write one decay-mode histogram per listed model (HPS is not plotted here).
for _model in ("DeepSet", "LorentzNet", "ParticleTransformer"):
    plot_decay_modes(_model)

In [None]:
# Write the confusion-matrix PDFs, first for ZH and then for Z.
for _dataset in ("ZH", "Z"):
    CM_plot(_dataset)

In [None]:
# Write the precision-summary PDFs, first for ZH and then for Z.
for _dataset in ("ZH", "Z"):
    plot_dm_prec(_dataset)