# Rejecting High Efficiency Muons 

In this notebook, I want to reject high-efficiency muons and re-plot the $e$ vs $\mu$ ROC curve for the 3-class case, alongside the 4-class model and the 2-class $e$/$\mu$ model for reference.

In [None]:
%load_ext autoreload
%matplotlib inline
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys

In [None]:
# Put the parent of the kernel's working directory on the import path so the
# `WatChMaL` package can be imported (assumes it lives there — TODO confirm).
sys.path.append("..")

from WatChMaL.analysis.multi_plot_utils import (
    multi_disp_learn_hist,
    multi_compute_roc,
    multi_plot_roc,
)
from WatChMaL.analysis.comparison_utils import (
    multi_get_masked_data,
    multi_collapse_test_output,
)
from WatChMaL.analysis.plot_utils import (
    plot_classifier_response,
    plot_reduced_classifier_response,
)

In [None]:
# Class-name -> label-index mapping for the 2-class (e / mu) model outputs.
# Raw strings keep the LaTeX backslash without Python's invalid-escape-sequence
# warning for "\m".
label_dict         = {"$e$": 0, r"$\mu$": 1}
# Index -> class-name lookup, derived so the two mappings cannot drift apart.
inverse_label_dict = {index: name for name, index in label_dict.items()}

# Default matplotlib colour cycle; each model family below picks one entry by index.
c = plt.rcParams["axes.prop_cycle"].by_key()["color"]

## Load 2-class models and the associated data

In [None]:
# Output directory of each 2-class training run to compare.
locs_2_class = ['/home/hlahiouel/WatChMaL/outputs/2021-03-18/06-50-14/outputs'] 

titles_2_class = ['2 Class Model - Electrons and Muons']

# One plot-style entry per run: first colour of the default cycle, solid line.
linecolor_2_class = [c[0]] * len(locs_2_class)
linestyle_2_class = ['-'] * len(locs_2_class)

# Load the saved evaluation arrays of every run, one list per file type
# (softmax.npy, then labels.npy, then indices.npy).
# NOTE(review): presumably per-event softmax outputs, true labels and dataset indices — confirm.
raw_output_softmax_2_class, raw_actual_labels_2_class, raw_actual_indices_2_class = (
    [np.load(f"{loc}/{fname}") for loc in locs_2_class]
    for fname in ("softmax.npy", "labels.npy", "indices.npy")
)

In [None]:
# Electron-vs-muon ROC for the 2-class run(s), normalized.
# The 2-class label indices are pinned here instead of being read from
# `label_dict`: a later cell rebinds `label_dict` to the 4-label mapping
# (e=1, mu=2), so re-running this cell afterwards would otherwise silently
# compute the ROC with the wrong class indices.
label_dict_2_class = {"$e$": 0, r"$\mu$": 1}

two_fprs, two_tprs, two_thrs = multi_compute_roc(
    raw_output_softmax_2_class,
    raw_actual_labels_2_class,
    true_label=label_dict_2_class["$e$"],
    false_label=label_dict_2_class[r"$\mu$"],
    normalize=True,
)

In [None]:
# 2-class e-vs-mu ROC with the y-axis limited to [0, 3e6].
# Raw string avoids the invalid "\m" escape-sequence warning.
figs = multi_plot_roc(two_fprs, two_tprs, two_thrs, "$e$", r"$\mu$",
                      fig_list=[1], ylims=[[0, 3e6]],
                      linestyles=linestyle_2_class, linecolors=linecolor_2_class,
                      plot_labels=titles_2_class, show=True)

In [None]:
# Zoomed 2-class e-vs-mu ROC: x in [0.6, 1.0], y in [1, 1e6].
# Raw string avoids the invalid "\m" escape-sequence warning.
figs = multi_plot_roc(two_fprs, two_tprs, two_thrs, "$e$", r"$\mu$",
                      fig_list=[1],
                      xlims=[[0.6, 1.0]], ylims=[[1e0, 1e6]],
                      linestyles=linestyle_2_class, linecolors=linecolor_2_class,
                      plot_labels=titles_2_class, show=True)

In [None]:
# Class-name -> label-index mapping for the 3-/4-class model outputs.
# This rebinds `label_dict` / `inverse_label_dict`: cells below use e=1, mu=2.
# Raw strings avoid invalid-escape-sequence warnings for "\g", "\m" and "\p".
label_dict         = {r"$\gamma$": 0, "$e$": 1, r"$\mu$": 2, r"$\pi^0$": 3}
# Index -> class-name lookup, derived so the two mappings cannot drift apart.
inverse_label_dict = {index: name for name, index in label_dict.items()}

# NOTE(review): not referenced by the cells visible here — confirm it is still needed.
muon_softmax_index_dict = {"non-mu": 0, "mu": 1}

## Load 3-class model and the associated data

In [None]:
# Output directory of each 3-class training run to compare.
locs_3_class = ['/home/hlahiouel/WatChMaL/outputs/2021-03-10/12-23-38/outputs'] 

titles_3_class = ['3 Class - Barrel Fix']

# One plot-style entry per run: second colour of the default cycle, solid line.
linecolor_3_class = [c[1]] * len(locs_3_class)
linestyle_3_class = ['-'] * len(locs_3_class)

# Load the saved evaluation arrays of every run, one list per file type
# (softmax.npy, then labels.npy, then indices.npy).
# NOTE(review): presumably per-event softmax outputs, true labels and dataset indices — confirm.
raw_output_softmax_3_class, raw_actual_labels_3_class, raw_actual_indices_3_class = (
    [np.load(f"{loc}/{fname}") for loc in locs_3_class]
    for fname in ("softmax.npy", "labels.npy", "indices.npy")
)

## Load 4-class model and the associated data

In [None]:
# Output directory of each 4-class training run to compare.
locs_4_class = ['/home/hlahiouel/WatChMaL/outputs/2021-03-01/12-52-40/outputs']

titles_4_class = ['4 Class Run 1 - Barrel Fix ']

# One plot-style entry per run: third colour of the default cycle, solid line.
linecolor_4_class = [c[2]] * len(locs_4_class)
linestyle_4_class = ['-'] * len(locs_4_class)

# Load the saved evaluation arrays of every run, one list per file type
# (softmax.npy, then labels.npy, then indices.npy).
# NOTE(review): presumably per-event softmax outputs, true labels and dataset indices — confirm.
raw_output_softmax_4_class, raw_actual_labels_4_class, raw_actual_indices_4_class = (
    [np.load(f"{loc}/{fname}") for loc in locs_4_class]
    for fname in ("softmax.npy", "labels.npy", "indices.npy")
)

In [None]:
# Concatenate the 3-class runs followed by the 4-class runs into single
# per-run lists, so every list below is index-aligned.
softmaxes = [*raw_output_softmax_3_class, *raw_output_softmax_4_class]
labels    = [*raw_actual_labels_3_class, *raw_actual_labels_4_class]

titles    = [*titles_3_class, *titles_4_class]
linecolor = [*linecolor_3_class, *linecolor_4_class]
linestyle = [*linestyle_3_class, *linestyle_4_class]

In [None]:
# Quick sanity check: shape of each run's softmax array.
for q, softmax in enumerate(softmaxes):
    print(softmax.shape)

## Plot ROC $e$ vs $\mu$ with Normalization

In [None]:
# Normalized e-vs-mu ROC for the 3- and 4-class runs, using the label indices
# from the current `label_dict` (the 4-label mapping: e=1, mu=2).
# Raw string avoids the invalid "\m" escape-sequence warning.
fprs, tprs, thrs = multi_compute_roc(softmaxes, labels,
                                     true_label=label_dict["$e$"],
                                     false_label=label_dict[r"$\mu$"],
                                     normalize=True)

In [None]:
# Append the 2-class e-vs-mu curves after the 3- and 4-class curves so all
# models share one plot.
# The lists are rebuilt from their sources rather than extended with `+=`, so
# re-running this cell does not append the 2-class entries a second time.
# Slicing keeps only the 3-/4-class ROC entries; this assumes multi_compute_roc
# returned one entry per element of `softmaxes` (TODO confirm).
n_multi_class = len(softmaxes)
fprs = list(fprs[:n_multi_class]) + list(two_fprs)
tprs = list(tprs[:n_multi_class]) + list(two_tprs)
thrs = list(thrs[:n_multi_class]) + list(two_thrs)

linestyle = linestyle_3_class + linestyle_4_class + linestyle_2_class
linecolor = linecolor_3_class + linecolor_4_class + linecolor_2_class
titles    = titles_3_class + titles_4_class + titles_2_class

In [None]:
# e-vs-mu ROC for all models (3-class, 4-class, 2-class) with the y-axis limited to [0, 3e6].
# Raw string avoids the invalid "\m" escape-sequence warning.
figs = multi_plot_roc(fprs, tprs, thrs, "$e$", r"$\mu$",
                      fig_list=[1], ylims=[[0, 3e6]],
                      linestyles=linestyle, linecolors=linecolor,
                      plot_labels=titles, show=True)

In [None]:
# Zoomed e-vs-mu ROC for all models: x in [0.5, 1.0], y in [1, 1e6].
# Raw string avoids the invalid "\m" escape-sequence warning.
figs = multi_plot_roc(fprs, tprs, thrs, "$e$", r"$\mu$",
                      fig_list=[1],
                      xlims=[[0.5, 1.0]], ylims=[[1e0, 1e6]],
                      linestyles=linestyle, linecolors=linecolor,
                      plot_labels=titles, show=False)