In [4]:
import os
import json
import hydra
import numpy as np
import enreg.tools.general as g
import mplhep as hep
import awkward as ak
import matplotlib.pyplot as plt
from omegaconf import DictConfig
from sklearn.preprocessing import label_binarize
from matplotlib import cm
from sklearn import metrics
from sklearn.metrics import confusion_matrix

In [5]:
# Select mplhep's "CMS" style for the plots produced in this notebook.
hep.style.use("CMS")

In [6]:
# Test samples; later cells read the `gen_jet_tau_decaymode` field from them.
zh_test_files = ["/scratch/persistent/joosep/ml-tau/20240520_qq_zh_2m_merged/zh_test.parquet"]
z_test_files = ["/scratch/persistent/joosep/ml-tau/20240520_qq_zh_2m_merged/z_test.parquet"]

data_zh = g.load_all_data(zh_test_files)
data_z = g.load_all_data(z_test_files)


[1/1] Loading from /scratch/persistent/joosep/ml-tau/20240520_qq_zh_2m_merged/zh_test.parquet
Input data loaded
[1/1] Loading from /scratch/persistent/joosep/ml-tau/20240520_qq_zh_2m_merged/z_test.parquet
Input data loaded


In [7]:
# Prediction files for the ZH test sample, keyed by model name.
paths_zh_model = dict(
    ParticleTransformer="/local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/ParticleTransformer/zh_test.parquet",
    LorentzNet="/local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/LorentzNet/zh_test.parquet",
    SimpleDNN="/local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/SimpleDNN/zh_test.parquet",
)

# Keep only the `dm_multiclass` -> `pred` field of each loaded file.
data_zh_model = dict(
    (model_name, g.load_all_data([pred_path])["dm_multiclass"]["pred"])
    for model_name, pred_path in paths_zh_model.items()
)

[1/1] Loading from /local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/ParticleTransformer/zh_test.parquet
Input data loaded
[1/1] Loading from /local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/LorentzNet/zh_test.parquet
Input data loaded
[1/1] Loading from /local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/SimpleDNN/zh_test.parquet
Input data loaded


In [8]:
# Prediction files for the Z test sample, keyed by model name.
paths_z_model = dict(
    ParticleTransformer="/local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/ParticleTransformer/z_test.parquet",
    LorentzNet="/local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/LorentzNet/z_test.parquet",
    SimpleDNN="/local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/SimpleDNN/z_test.parquet",
)

# Keep only the `dm_multiclass` -> `pred` field of each loaded file.
data_z_model = dict(
    (model_name, g.load_all_data([pred_path])["dm_multiclass"]["pred"])
    for model_name, pred_path in paths_z_model.items()
)

[1/1] Loading from /local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/ParticleTransformer/z_test.parquet
Input data loaded
[1/1] Loading from /local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/LorentzNet/z_test.parquet
Input data loaded
[1/1] Loading from /local/joosep/ml-tau-en-reg/results/240524_cosinescheduler/dm_multiclass/SimpleDNN/z_test.parquet
Input data loaded


In [9]:
# Directory the plotting functions below save their PDFs into; created on demand.
output_dir = "../outputs/plots/"
os.makedirs(output_dir, exist_ok=True)

In [10]:
## Functions

#---------------For getting the distribution of the decay modes-----------------------------------------------------------------------------

# Function for plotting the distribution of the decay modes
def plot_decay_modes(model):
    """Histogram the true and predicted decay modes of the ZH test sample.

    Reads the module-level ``data_zh`` (field ``gen_jet_tau_decaymode``) and
    ``data_zh_model`` and saves the figure as ``dm_<model>.pdf`` in
    ``output_dir``.

    Args:
        model: Key into ``data_zh_model``. "SimpleDNN" is shown as "DeepSet"
            in the legend; any other name is shown as-is.
    """
    dms = np.arange(17)
    # One unit-wide bin per decay-mode index needs len(dms) + 1 edges. Using
    # `dms` itself as the edges yields only 16 bins, and because the last
    # histogram bin is closed on the right ([15, 16]) decay modes 15 and 16
    # would be counted in the same bar.
    bin_edges = np.arange(len(dms) + 1)

    truth = data_zh["gen_jet_tau_decaymode"]
    predicted = data_zh_model[model]
    legend_name = 'DeepSet' if model == "SimpleDNN" else model

    fig, ax = plt.subplots()
    ax.hist([truth, predicted], bin_edges, label=['Actual decaymode', legend_name])
    ax.legend(loc='upper right')
    ax.set_title("Actual vs predicted decaymodes")
    ax.set_yscale('log')
    ax.set_xticks(dms + 0.4)
    ax.set_xticklabels(dms)
    fig.savefig(os.path.join(output_dir, f"dm_{model}.pdf"), bbox_inches='tight', format='pdf')
    # Close only this figure so repeated calls do not accumulate open figures.
    plt.close(fig)

#---------------For getting the confusion matrices-----------------------------------------------------------------------------

# Function for aggregating specific classes of the confusion matrix
def aggregate_classes(cm, classes_to_merge):
    """Merge several classes of a square confusion matrix into one class.

    The merged class is placed at the smallest of the merged indices; all
    other classes keep their relative order around it.

    Args:
        cm: Square 2-D array-like (n x n).
        classes_to_merge: Iterable of row/column indices of ``cm`` to combine.
            Duplicates are ignored.

    Returns:
        Float ``numpy.ndarray`` of shape (n - k + 1, n - k + 1), where k is
        the number of distinct merged indices. Each entry is the sum of all
        ``cm`` entries whose row and column map onto it.

    Raises:
        ValueError: If ``cm`` is not a square 2-D array, if
            ``classes_to_merge`` is empty, or if an index is out of range.
    """
    cm = np.asarray(cm)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(f"cm must be a square 2-D array, got shape {cm.shape}")
    num_classes = cm.shape[0]

    merged = sorted({int(cls) for cls in classes_to_merge})
    if not merged:
        raise ValueError("classes_to_merge must contain at least one class index")
    if merged[0] < 0 or merged[-1] >= num_classes:
        raise ValueError(
            f"classes_to_merge {merged} out of range for {num_classes} classes"
        )

    # The merged class takes the slot of its smallest member. For the merge
    # sets [2, 3, 4] and [4, 5] this gives 2 and 4 respectively.
    merge_index = merged[0]
    merged_lookup = set(merged)
    new_num_classes = num_classes - len(merged) + 1

    # Map every old class index to its position in the aggregated matrix.
    index_map = np.empty(num_classes, dtype=int)
    next_free = 0
    for old_index in range(num_classes):
        if old_index in merged_lookup:
            index_map[old_index] = merge_index
            continue
        if next_free == merge_index:
            # Leave the slot reserved for the merged class untouched.
            next_free += 1
        index_map[old_index] = next_free
        next_free += 1

    # Unbuffered scatter-add: several (row, col) pairs may land on the same
    # target cell, so a plain fancy-index `+=` would drop contributions.
    new_cm = np.zeros((new_num_classes, new_num_classes))
    np.add.at(new_cm, (index_map[:, None], index_map[None, :]), cm)
    return new_cm

# Function for normalizing the confusion matrix
def normalize_confusion_matrix(cm):
    """Scale each column of a confusion matrix so that it sums to one.

    Args:
        cm: 2-D array-like of counts.

    Returns:
        Float ``numpy.ndarray`` of the same shape as ``cm`` with every entry
        divided by the sum of its column. A column whose sum is zero is
        returned as all zeros rather than NaN.
    """
    cm = np.asarray(cm, dtype=float)
    column_sums = cm.sum(axis=0, keepdims=True)
    normalized_cm = np.zeros_like(cm)
    # Divide only where the column total is non-zero; the remaining cells keep
    # the zero they were initialised with, avoiding 0/0 -> NaN and the
    # accompanying RuntimeWarning.
    np.divide(cm, column_sums, out=normalized_cm, where=column_sums != 0)
    return normalized_cm

# Function for plotting the DM confusion matrices
def CM_plot(dataset):
    """Plot and save the decay-mode confusion matrix of every model.

    For each of "SimpleDNN", "LorentzNet" and "ParticleTransformer" the
    confusion matrix of true vs. predicted decay mode is computed, classes are
    merged via ``aggregate_classes``, the result is normalised with
    ``normalize_confusion_matrix`` and drawn flipped vertically. Each figure is
    saved as ``<dataset>_cm_<model>.pdf`` in ``output_dir``.

    Args:
        dataset: "ZH" or "Z"; selects which module-level truth/prediction
            pair (``data_zh``/``data_zh_model`` or ``data_z``/``data_z_model``)
            is used.

    Raises:
        ValueError: If ``dataset`` is neither "ZH" nor "Z".
    """
    if dataset == "ZH":
        truth_sample, predictions = data_zh, data_zh_model
    elif dataset == "Z":
        truth_sample, predictions = data_z, data_z_model
    else:
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'ZH' or 'Z'")

    # The truth labels do not depend on the model, so convert them once.
    y_true = np.array(truth_sample["gen_jet_tau_decaymode"])

    new_labels = [r'$h^\pm$', r'$h^\pm+\pi^0$', r'$h^\pm+\geq2\pi^0$', r'$h^\pm h^\mp h^\pm$', r'$h^\pm h^\mp h^\pm$' '\n' r'$+\geq\pi^0$', 'Rare']

    models = ["SimpleDNN", "LorentzNet", "ParticleTransformer"]
    for model_name in models:
        y_pred = np.array(predictions[model_name])

        # NOTE(review): confusion_matrix infers the class set from the values
        # present in y_true/y_pred, so the positional merge indices below
        # assume every expected class occurs in the data -- confirm.
        raw_cm = confusion_matrix(y_true, y_pred)

        # Aggregate classes in the confusion matrix
        partly_merged_cm = aggregate_classes(raw_cm, [2, 3, 4])  # These are 1h2pi0, 1h3pi0, 1hNpi0
        merged_cm = aggregate_classes(partly_merged_cm, [4, 5])  # These are 3h1pi0, 3h2pi0 in the new aggregated matrix

        # Normalize the matrix, then flip it vertically so the first class
        # ends up in the bottom row (the y tick labels are reversed to match).
        normalized_cm = normalize_confusion_matrix(merged_cm)
        mirrored_cm = np.flipud(normalized_cm)

        fig, ax = plt.subplots()
        cax = ax.matshow(mirrored_cm, cmap='GnBu')  # color scheme

        # Print each cell's value on top of its coloured square.
        for (i, j), val in np.ndenumerate(mirrored_cm):
            ax.text(j, i, f'{val:.2f}', ha='center', va='center', color='black')

        ax.set_xticks(np.arange(len(new_labels)))
        ax.set_yticks(np.arange(len(new_labels)))
        ax.set_xticklabels(new_labels, rotation=45)
        ax.tick_params(axis='x', which='both', bottom=True, top=True, labeltop=False, labelbottom=True)
        ax.set_yticklabels(new_labels[::-1])
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(f"{dataset} Confusion Matrix", y=1.05)
        fig.colorbar(cax)
        fig.savefig(os.path.join(output_dir, f"{dataset}_cm_{model_name}.pdf"), bbox_inches='tight', format='pdf')
        # Close this figure before the next model so figures do not pile up.
        plt.close(fig)

#---------------For plotting the classification precision of the DMs-----------------------------------------------------------------------------

# Function to get precisions of the models
def get_precision(dataset):
    """Compute the per-class precision of every model on one dataset.

    The confusion matrix of each model is aggregated with
    ``aggregate_classes`` and normalised with ``normalize_confusion_matrix``
    before the precision is taken as diagonal / column sum.

    Args:
        dataset: "ZH" or "Z"; selects which module-level truth/prediction
            pair is used.

    Returns:
        List with one array of per-class precisions per model, in the order
        "SimpleDNN", "LorentzNet", "ParticleTransformer".

    Raises:
        ValueError: If ``dataset`` is neither "ZH" nor "Z".
    """
    if dataset == "ZH":
        truth_sample, predictions = data_zh, data_zh_model
    elif dataset == "Z":
        truth_sample, predictions = data_z, data_z_model
    else:
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'ZH' or 'Z'")

    # The truth labels do not depend on the model, so convert them once.
    y_true = np.array(truth_sample["gen_jet_tau_decaymode"])

    model_precisions = []
    models = ["SimpleDNN", "LorentzNet", "ParticleTransformer"]
    for model_name in models:
        y_pred = np.array(predictions[model_name])

        raw_cm = confusion_matrix(y_true, y_pred)
        partly_merged_cm = aggregate_classes(raw_cm, [2, 3, 4])
        merged_cm = aggregate_classes(partly_merged_cm, [4, 5])
        normalized_cm = normalize_confusion_matrix(merged_cm)

        # Precision = TP / (TP + FP): the diagonal entry over its column sum.
        tp = normalized_cm.diagonal()
        tp_and_fp = normalized_cm.sum(0)
        model_precisions.append(tp / tp_and_fp)

    return model_precisions

# Function to plot the classification precision of the DMs
def plot_dm_prec(dataset):
    """Scatter-plot the per-class precision of the three models.

    Precisions come from ``get_precision(dataset)``; the figure is saved as
    ``<dataset>_dm_precision.pdf`` in ``output_dir``.

    Args:
        dataset: "ZH" or "Z".

    Raises:
        ValueError: If ``dataset`` is neither "ZH" nor "Z".
    """
    if dataset not in ("ZH", "Z"):
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'ZH' or 'Z'")

    # get_precision returns the models in the order SimpleDNN, LorentzNet,
    # ParticleTransformer.
    pr_sdnn, pr_ln, pr_pt = get_precision(dataset)

    labels = [r'$h^\pm$', r'$h^\pm+\pi^0$', r'$h^\pm+\geq2\pi^0$', r'$h^\pm h^\mp h^\pm$', r'$h^\pm h^\mp h^\pm$' '\n' r'$+\geq\pi^0$', 'Rare']

    # One x position per decay-mode label.
    x = range(len(labels))

    fig, ax = plt.subplots()

    # Plotting the data as points
    ax.scatter(x, pr_sdnn, label='DeepSet', color='#ff5b5b', marker='o', s=100)
    ax.scatter(x, pr_ln, label='LorentzNet', color='#ffc140', marker='o', s=100)
    ax.scatter(x, pr_pt, label='ParticleTransformer', color='#89cded', marker='o', s=100)

    ax.set_xlabel('Decay Modes')
    ax.set_ylabel('Precision', x=1.05)
    ax.set_title(f"Classification Precision of the DMs for {dataset}", y=1.05)
    ax.set_xticks(x)
    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=True)
    ax.tick_params(axis='y', which='both', left=True, right=False, labelleft=True)
    ax.set_xticklabels(labels, rotation=45)

    # Vertical lines between the DMs
    for position in x:
        ax.axvline(position - 0.5, color='gray', linestyle='--', linewidth=0.5, zorder=0)

    def annotate_points(xs, ys, offset):
        """Write each y value next to its point, shifted by ``offset`` (in points)."""
        for (i, j) in zip(xs, ys):
            ax.annotate(f'{j:.2f}', (i, j), textcoords="offset points", xytext=offset, ha='center', va='bottom', fontsize=9)

    # NOTE(review): ParticleTransformer uses the same offset as DeepSet, so the
    # two value labels coincide when the precisions are close -- confirm this
    # is intended.
    annotate_points(x, pr_sdnn, (-16, -4))
    annotate_points(x, pr_ln, (16, -4))
    annotate_points(x, pr_pt, (-16, -4))

    # The legend is created once, after all artists have been added.
    ax.legend(loc='lower left', shadow=True, fancybox=True, framealpha=1, borderpad=1)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{dataset}_dm_precision.pdf"), bbox_inches='tight', format='pdf')
    plt.close(fig)

In [11]:
# One decay-mode histogram per model.
for model_name in ("SimpleDNN", "LorentzNet", "ParticleTransformer"):
    plot_decay_modes(model_name)

In [12]:
# Confusion matrices for both test samples.
for dataset_name in ("ZH", "Z"):
    CM_plot(dataset_name)

In [13]:
# Precision summary plots for both test samples.
for dataset_name in ("ZH", "Z"):
    plot_dm_prec(dataset_name)