In [None]:
import os
import json
import hydra
import numpy as np
import enreg.tools.general as g
import mplhep as hep
import awkward as ak
import matplotlib.pyplot as plt
from omegaconf import DictConfig

In [None]:
# hep.style.use(hep.styles.CMS)

In [None]:
# Load the ZH and Z test samples (ZH first, then Z). Each call hands
# `g.load_all_data` a one-element list holding the merged parquet file.
data_zh, data_z = (
    g.load_all_data([sample_path])
    for sample_path in (
        "/scratch/persistent/joosep/ml-tau/20240402_full_stats_merged/zh_test.parquet",
        "/scratch/persistent/joosep/ml-tau/20240402_full_stats_merged/z_test.parquet",
    )
)


In [None]:
# Result file of each model for the ZH test sample, keyed by model name.
paths_zh_model = {
    "ParticleTransformer": "/local/joosep/ml-tau-en-reg/results/240517_fullstats/dm_multiclass/ParticleTransformer/zh_test.parquet",
    "LorentzNet": "/local/joosep/ml-tau-en-reg/results/240517_fullstats/dm_multiclass/LorentzNet/zh_test.parquet",
    "SimpleDNN": "/local/joosep/ml-tau-en-reg/results/240517_fullstats/dm_multiclass/SimpleDNN/zh_test.parquet",
}

# For every model keep only the ["dm_multiclass"]["pred"] field of its result file
# (presumably the predicted decay-mode class per entry -- TODO confirm).
data_zh_model = {
    model_name: g.load_all_data([parquet_path])["dm_multiclass"]["pred"]
    for model_name, parquet_path in paths_zh_model.items()
}

In [None]:
# Result file of each model for the Z test sample, keyed by model name.
paths_z_model = {
    "ParticleTransformer": "/local/joosep/ml-tau-en-reg/results/240517_fullstats/dm_multiclass/ParticleTransformer/z_test.parquet",
    "LorentzNet": "/local/joosep/ml-tau-en-reg/results/240517_fullstats/dm_multiclass/LorentzNet/z_test.parquet",
    "SimpleDNN": "/local/joosep/ml-tau-en-reg/results/240517_fullstats/dm_multiclass/SimpleDNN/z_test.parquet",
}

# For every model keep only the ["dm_multiclass"]["pred"] field of its result file
# (presumably the predicted decay-mode class per entry -- TODO confirm).
data_z_model = {
    model_name: g.load_all_data([parquet_path])["dm_multiclass"]["pred"]
    for model_name, parquet_path in paths_z_model.items()
}

In [None]:
# Folder (relative to the working directory) that receives every figure saved
# below; it is created up front and an already existing folder is not an error.
output_dir = "../outputs/plots/"
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Histogram of the `gen_jet_tau_decaymode` values of the ZH test sample
# (the "actual" decay modes, per the title), saved to <output_dir>/DM_real.
dms = np.arange(17)  # bin edges 0..16 -> 16 unit-wide bins, one per integer 0..15

# Draw on a dedicated figure/axes instead of whatever pyplot considers current.
fig, ax = plt.subplots()
ax.set_title("Actual decaymodes")
ax.hist(
    data_zh["gen_jet_tau_decaymode"], bins=dms,
    width=0.8
)
ax.set_yscale('log')

# One tick per bin, shifted by 0.4 (half of the 0.8 bar width) as before so the
# label sits under its bar. Only the 16 left edges are used: the final edge (16)
# has no bar of its own and previously produced a stray "16" tick past the data.
ax.set_xticks(dms[:-1] + 0.4)
ax.set_xticklabels(dms[:-1])
ax.set_xlabel("Decay mode")
ax.set_ylabel("Entries")

fig.savefig(os.path.join(output_dir, "DM_real"))

In [None]:
# Histogram of the decay modes predicted by the "SimpleDNN" model on the ZH
# test sample, saved to <output_dir>/DM_sdnn.
dms = np.arange(17)  # bin edges 0..16 -> 16 unit-wide bins, one per integer 0..15

# Draw on a dedicated figure/axes instead of whatever pyplot considers current.
fig, ax = plt.subplots()
ax.set_title("DeepSet Predicted decaymodes")
ax.hist(
    data_zh_model["SimpleDNN"], bins=dms,
    width=0.8
)
ax.set_yscale('log')

# One tick per bin, shifted by 0.4 (half of the 0.8 bar width) as before so the
# label sits under its bar. Only the 16 left edges are used: the final edge (16)
# has no bar of its own and previously produced a stray "16" tick past the data.
ax.set_xticks(dms[:-1] + 0.4)
ax.set_xticklabels(dms[:-1])
ax.set_xlabel("Decay mode")
ax.set_ylabel("Entries")

fig.savefig(os.path.join(output_dir, "DM_sdnn"))

In [None]:
# Inputs for the confusion matrices: the `gen_jet_tau_decaymode` column of each
# test sample as the "actual" labels, and the "SimpleDNN" model output as the
# predicted labels.
actual_zh, predicted_zh = data_zh["gen_jet_tau_decaymode"], data_zh_model["SimpleDNN"]
actual_z, predicted_z = data_z["gen_jet_tau_decaymode"], data_z_model["SimpleDNN"]

# Human-readable class names used as tick labels on the confusion matrices.
# NOTE(review): a disabled alternative here listed the numeric values
# [0, 1, 2, 3, 4, 10, 11, 12, 15] -- presumably the decay-mode codes that
# correspond to these names in the same order; confirm before relying on it.
DM_labels = [
    "OneProng0PiZero",
    "OneProng1PiZero",
    "OneProng2PiZero",
    "OneProng3PiZero",
    "OneProngNPiZero",
    "ThreeProng0PiZero",
    "ThreeProng1PiZero",
    "ThreeProng2PiZero",
    "RareDecayMode",
]

In [None]:
from sklearn import metrics
import seaborn as sns

In [None]:
def CM_plot(actual, predicted, title, output_path):
    """Draw a row-normalised confusion matrix as an annotated heatmap and save it.

    Every row of the count matrix is divided by its row total, so a cell holds
    the fraction of entries of that actual class that were assigned to the
    given predicted class. A row whose total is zero (a class that occurs only
    among the predictions) is reported as all zeros instead of NaN.

    Args:
        actual: True class label of each entry.
        predicted: Predicted class label of each entry, aligned with ``actual``.
        title: Title drawn above the heatmap.
        output_path: File path the figure is written to.

    NOTE(review): tick labels come from the module-level ``DM_labels``; this
    assumes the matrix has exactly one row/column per entry of that list, in
    the same order as the sorted class values -- confirm for samples in which
    a class may be absent.
    """
    cm = metrics.confusion_matrix(actual, predicted)

    # Safe row normalisation: cells of empty rows keep the 0.0 they start with,
    # avoiding the divide-by-zero warning and the NaN annotations.
    row_totals = cm.sum(axis=1, keepdims=True)
    cmn = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )

    fig, ax = plt.subplots(figsize=(17, 15))
    # Pass the axes explicitly so the heatmap cannot land on another figure.
    sns.heatmap(cmn, annot=True, fmt='.2f', xticklabels=DM_labels, yticklabels=DM_labels, ax=ax)
    ax.set_ylabel('Actual')
    ax.set_xlabel('Predicted')
    ax.set_title(title)
    fig.savefig(output_path)

    # Show first, then release only this figure. Closing before showing (as was
    # done previously) left nothing to display and also closed unrelated figures.
    plt.show(block=False)
    plt.close(fig)

In [None]:
# Confusion matrix of the SimpleDNN predictions on the ZH test sample,
# written to <output_dir>/CM_zh.png.
CM_plot(
    actual=actual_zh,
    predicted=predicted_zh,
    title="DeepSet DM Confusion matrix ZH",
    output_path=os.path.join(output_dir, "CM_zh.png"),
)

In [None]:
# Confusion matrix of the SimpleDNN predictions on the Z test sample,
# written to <output_dir>/CM_z.png.
CM_plot(
    actual=actual_z,
    predicted=predicted_z,
    title="DeepSet DM Confusion matrix Z",
    output_path=os.path.join(output_dir, "CM_z.png"),
)