# Short Tank 4-Class Results Analysis - Hichem Lahiouel

In this notebook, I analyze the results of training runs on four particle classes: electrons ($e$), gammas ($\gamma$), muons ($\mu$), and neutral pions ($\pi^0$). One 3-class run is also included for comparison.

In [None]:
%load_ext autoreload
%matplotlib inline
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys

In [None]:
# Make the WatChMaL package importable. Assumes the notebook is run from a
# directory whose parent contains the WatChMaL package — TODO confirm.
# The membership check keeps re-runs of this cell from appending ".." again.
if ".." not in sys.path:
    sys.path.append("..")

from WatChMaL.analysis.multi_plot_utils import multi_disp_learn_hist, multi_compute_roc, multi_plot_roc
from WatChMaL.analysis.comparison_utils import multi_get_masked_data, multi_collapse_test_output
from WatChMaL.analysis.plot_utils import plot_confusion_matrix

In [None]:
############# define plotting params #############

# Default matplotlib colour cycle; runs pick their colour from it by index.
c = plt.rcParams['axes.prop_cycle'].by_key()['color']
# Line styles picked by index. NOTE: ':' is the same style as 'dotted' and
# '-.' the same as 'dashdot', so l[4] draws like l[3] and l[5] like l[1].
l = ['solid','dashdot','dashed','dotted',':','-.']

# Class name (LaTeX, used in plot labels) -> integer class label.
# Raw strings: "\g", "\m" and "\p" are invalid escape sequences in ordinary
# string literals (warnings on newer Pythons); the key values are unchanged.
label_dict = {r"$\gamma$": 0, r"$e$": 1, r"$\mu$": 2, r"$\pi^0$": 3}
# Derived from label_dict so the two mappings can never disagree.
inverse_label_dict = {index: name for name, index in label_dict.items()}

In [None]:
############# define run locations #############

# Root of the WatChMaL training outputs.
# TODO: absolute, user-specific path — make configurable before sharing.
SHORT_RUN_ROOT = '/home/hlahiouel/WatChMaL/outputs'

# <date>/<time> run directories, in plotting order. Each run's
# softmax.npy / labels.npy / predictions.npy live under <run>/outputs.
short_run_ids = [
    '2021-02-11/08-36-41',
    '2021-02-04/12-47-05',
    '2021-02-05/07-22-59',
    '2021-02-07/20-01-02',
    '2021-02-07/20-04-38',
    '2021-02-10/14-54-54',
    '2021-02-10/15-07-17',
    '2021-02-10/15-20-39',
    '2021-02-10/15-33-06',
]
short_locs = [SHORT_RUN_ROOT + '/' + run_id + '/outputs' for run_id in short_run_ids]

# Legend label per run (parallel to short_locs).
short_titles = [
    'Short Tank - 3 Class - OD Veto',
    'Short Tank - 4 Class - No Veto',
    'Short Tank - 4 Class - No Veto',
    'Short Tank - 4 Class - No Veto',
    'Short Tank - 4 Class - No Veto',
    'Short Tank - 4 Class - OD Veto',
    'Short Tank - 4 Class - OD Veto',
    'Short Tank - 4 Class - OD Veto',
    'Short Tank - 4 Class - OD Veto',
]

# Colour / line style per run: index 3 for the 3-class run, 4 for the
# 4-class no-veto group, 5 for the 4-class OD-veto group.
# NOTE(review): runs within a group share title, colour and style, so their
# curves cannot be told apart in the legend — confirm this is intended.
short_linecolor = [c[3],c[4],c[4],c[4],c[4],c[5],c[5],c[5],c[5]]
short_linestyle = [l[3],l[4],l[4],l[4],l[4],l[5],l[5],l[5],l[5]]

# These lists are passed to the plotting helpers next to the per-run data,
# so they must all line up with short_locs (replaces the print(len(...)) checks).
assert len(short_titles) == len(short_linecolor) == len(short_linestyle) == len(short_locs), \
    (len(short_titles), len(short_linecolor), len(short_linestyle), len(short_locs))

In [None]:
############# load short tank run data #############

# One softmax array per run, in the same order as short_locs.
short_raw_output_softmax = [np.load("{}/softmax.npy".format(run_dir)) for run_dir in short_locs]

In [None]:
#print(short_raw_output_softmax)

In [None]:
#short_raw_output_softmax = np.array(short_raw_output_softmax)

In [None]:
# NOTE(review): disabled attempt to renormalise the class-0/class-1
# (gamma/e) softmax scores over just those two classes. If the cell above
# stacks the runs into one array — assuming each run's softmax is
# (events, classes), giving (runs, events, classes) — then [:,0] and [:,1]
# select events, not classes; the class axis would need [..., 0] / [..., 1].
#norm_sum = short_raw_output_softmax[:,0] + short_raw_output_softmax[:,1]

#new_softmax = short_raw_output_softmax[:,0:2] / norm_sum[:,None]

In [None]:
# Per-run true labels and saved predictions, both in short_locs order.
short_raw_actual_labels = [np.load("{}/labels.npy".format(run_dir)) for run_dir in short_locs]
short_raw_predictions = [np.load("{}/predictions.npy".format(run_dir)) for run_dir in short_locs]

In [None]:
############# compute short tank multi e/gamma ROC #############

# e is passed as the "true" (signal) class and gamma as the "false" class.
# Raw strings avoid the invalid "\g" escape in the LaTeX key.
short_fprs, short_tprs, short_thrs = multi_compute_roc(short_raw_output_softmax, short_raw_actual_labels,
                                     true_label=label_dict[r"$e$"],
                                     false_label=label_dict[r"$\gamma$"])

In [None]:
############# combine short results #############

# Only short-tank runs are analysed, so the curves to plot are just the short-tank ones.
fprs, tprs, thrs = short_fprs, short_tprs, short_thrs

In [None]:
# e/gamma ROC curves for every short-tank run on one figure.
# Raw strings avoid the invalid "\g" escape in the LaTeX label.
# NOTE(review): ylims=[[0,3e6]] is interpreted by multi_plot_roc — confirm which axis it bounds.
figs = multi_plot_roc(fprs, tprs, thrs, r"$e$", r"$\gamma$",
                      fig_list=[1], ylims=[[0,3e6]],
                      linestyles=short_linestyle, linecolors=short_linecolor,
                      plot_labels=short_titles, show=False)

In [None]:
############# compute short multi e/mu ROC #############

# e is passed as the "true" (signal) class and mu as the "false" class.
# Raw strings avoid the invalid "\m" escape in the LaTeX key.
short_fprs, short_tprs, short_thrs = multi_compute_roc(short_raw_output_softmax, short_raw_actual_labels,
                                     true_label=label_dict[r"$e$"],
                                     false_label=label_dict[r"$\mu$"])

In [None]:
############# combine short results #############

# Only short-tank runs are analysed, so the curves to plot are just the short-tank ones.
fprs, tprs, thrs = short_fprs, short_tprs, short_thrs

In [None]:
# e/mu ROC curves for every short-tank run on one figure.
# Raw strings avoid the invalid "\m" escape in the LaTeX label.
figs = multi_plot_roc(fprs, tprs, thrs, r"$e$", r"$\mu$", fig_list=[1],
                      linestyles=short_linestyle, linecolors=short_linecolor,
                      plot_labels=short_titles, show=False)

In [None]:
# Class index in the collapsed (non-mu vs mu) scores/labels.
muon_softmax_index_dict = dict(zip(("non-mu", "mu"), range(2)))

In [None]:
############# compute short collapsed ROC #############

# Presumably excludes the gamma and pi0 classes, leaving e and mu events.
# Raw strings avoid the invalid "\g" / "\p" escapes in the LaTeX keys.
short_collapsed_class_scores_list, short_collapsed_class_labels_list = multi_collapse_test_output(
    short_raw_output_softmax, short_raw_actual_labels, label_dict,
    ignore_type=r'$\gamma$', ignore_type2=r'$\pi^0$')

# NOTE(review): assumes the remaining labels are e=1 and mu=2 (label_dict), so
# subtracting 1 maps them onto muon_softmax_index_dict (non-mu=0, mu=1).
short_collapsed_class_labels_list = [collapsed_class_labels - 1 for collapsed_class_labels in short_collapsed_class_labels_list]
# NOTE(review): assumes dropping score column 0 (gamma) leaves e in column 0
# and mu in column 1 — confirm against multi_collapse_test_output.
short_collapsed_class_scores_list = [collapsed_class_scores[:,1:] for collapsed_class_scores in short_collapsed_class_scores_list]

In [None]:
# ROC with non-muon events as the "true" class and muons as the "false" class.
short_fpr_list, short_tpr_list, short_thr_list = multi_compute_roc(
    short_collapsed_class_scores_list,
    short_collapsed_class_labels_list,
    true_label=muon_softmax_index_dict["non-mu"],
    false_label=muon_softmax_index_dict["mu"],
)

In [None]:
############# combine short results #############

# Only short-tank runs are analysed, so the curves to plot are just the short-tank ones.
fpr_list, tpr_list, thr_list = short_fpr_list, short_tpr_list, short_thr_list

In [None]:
# Non-mu vs mu ROC curves for every short-tank run on one figure.
# Alternative option kept from the original cell: xlims=[[0.0,1.0]]
figs = multi_plot_roc(
    fpr_list, tpr_list, thr_list, "non-mu", "mu",
    fig_list=[1], ylims=[[0,3e6]],
    linestyles=short_linestyle, linecolors=short_linecolor,
    plot_labels=short_titles, show=False,
)