# Rejecting High Efficiency Muons

In this notebook, I want to reject muons with a cut that keeps a high (98%) Electron/Gamma efficiency, and then re-plot the ROC curve for e vs gamma for the 4-class models.

Here is an outline for how to do this:
1. Determine the thr, tpr, and fpr for the Electron/Gamma vs Muon case
2. Find the index in the tpr list which corresponds to 98% efficiency 
3. Use the index from the tpr list to find the threshold in the thr list 
4. Remove events which have an Electron/Gamma score below that threshold
5. Recalculate the tpr, thr, and fpr for the Electron vs Gamma case 
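6. Determine the adjustment constants (the fraction of electron events and of gamma events that survive the cut)
7. Multiply the tpr and the fpr by their respective adjustment constants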
8. Re-plot the Electron vs Gamma ROC curve 

In [None]:
# Reload edited project modules automatically (autoreload mode 2) and render
# matplotlib figures inline in the notebook.
%load_ext autoreload
%matplotlib inline
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys

In [None]:
sys.path.append("..")

from WatChMaL.analysis.multi_plot_utils import multi_disp_learn_hist, multi_compute_roc, multi_plot_roc
from WatChMaL.analysis.comparison_utils import multi_get_masked_data, multi_collapse_test_output
from WatChMaL.analysis.plot_utils import plot_classifier_response, plot_reduced_classifier_response

In [None]:
# Particle label (LaTeX string) -> integer class index.
# Raw strings keep the LaTeX backslashes literal without relying on invalid
# escape sequences such as "\g", "\m" or "\p", which recent Python versions
# warn about. The resulting string values are unchanged.
label_dict         = {r"$\gamma$":0, r"$e$":1, r"$\mu$":2, r"$\pi^0$":3}
# Reverse lookup (class index -> label), derived so the two maps cannot drift apart.
inverse_label_dict = {index: name for name, index in label_dict.items()}

# Class indices passed as true/false labels when computing the e/gamma-vs-muon ROC.
muon_softmax_index_dict = {"non-mu":0, "mu":1}

# Colors of matplotlib's default property cycle; indexed below to pick per-curve line colors.
c = plt.rcParams['axes.prop_cycle'].by_key()['color']

##  Load 2-class models and the associated data

In [None]:
# Output directories of the two 2-class runs, listed in the same order as their plot titles.
locs_2_class = [
    '/home/hlahiouel/WatChMaL/outputs/2021-03-10/12-07-03/outputs',
    '/home/hlahiouel/WatChMaL/outputs/2021-03-10/12-14-45/outputs',
]
titles_2_class = ['2 Class - Barrel Fix', '2 Class - Extra Indices - Barrel Fix']

# Every 2-class curve uses the first color of the cycle and a solid line.
linecolor_2_class = [c[0]] * len(locs_2_class)
linestyle_2_class = ['-'] * len(locs_2_class)

# One list per stored file, each holding one array per run directory.
# Load order matches the original: all softmax files, then labels, then indices.
(raw_output_softmax_2_class,
 raw_actual_labels_2_class,
 raw_actual_indices_2_class) = (
    [np.load(run_dir + "/" + file_name) for run_dir in locs_2_class]
    for file_name in ("softmax.npy", "labels.npy", "indices.npy")
)

## Load 3-class model and the associated data

In [None]:
# Output directory of the single 3-class run and its plot title.
locs_3_class = [
    '/home/hlahiouel/WatChMaL/outputs/2021-03-10/12-23-38/outputs',
]
titles_3_class = ['3 Class - Barrel Fix']

# The 3-class curve uses the second color of the cycle and a solid line.
linecolor_3_class = [c[1]] * len(locs_3_class)
linestyle_3_class = ['-'] * len(locs_3_class)

# One list per stored file, each holding one array per run directory.
(raw_output_softmax_3_class,
 raw_actual_labels_3_class,
 raw_actual_indices_3_class) = (
    [np.load(run_dir + "/" + file_name) for run_dir in locs_3_class]
    for file_name in ("softmax.npy", "labels.npy", "indices.npy")
)

## Load 4-class model and the associated data

In [None]:
# Output directory of the 4-class run and its plot title.
locs_4_updated_class = [
    '/home/hlahiouel/WatChMaL/outputs/2021-03-01/12-52-40/outputs',
]
titles_4_updated_class = ['4 Class Run 1 - Barrel Fix ']

# The 4-class curve uses the third color of the cycle and a solid line.
linecolor_4_updated_class = [c[2]] * len(locs_4_updated_class)
linestyle_4_updated_class = ['-'] * len(locs_4_updated_class)

# Softmax outputs and true labels, one array per run directory.
raw_output_softmax_4_updated_class, raw_actual_labels_4_updated_class = (
    [np.load(run_dir + "/" + file_name) for run_dir in locs_4_updated_class]
    for file_name in ("softmax.npy", "labels.npy")
)

# Quick look at the loaded labels before the indices are read.
print(raw_actual_labels_4_updated_class)

raw_actual_indices_4_updated_class = [
    np.load(run_dir + "/indices.npy") for run_dir in locs_4_updated_class
]

In [None]:
# Curves to compare: the two 2-class runs first, then the 4-class run (position 2).
titles = [*titles_2_class, *titles_4_updated_class]

linecolor = [*linecolor_2_class, *linecolor_4_updated_class]
linestyle = [*linestyle_2_class, *linestyle_4_updated_class]

softmaxes = [*raw_output_softmax_2_class, *raw_output_softmax_4_updated_class]
labels = [*raw_actual_labels_2_class, *raw_actual_labels_4_updated_class]

In [None]:
# Show the shape of each model's softmax array.
for softmax in softmaxes:
    print(softmax.shape)

## Find fpr, tpr, and thr for Electron/Gamma vs Muon 

In [None]:
# One-element list holding the third entry of `softmaxes`; it is passed as a list
# to multi_collapse_test_output below.
four_class_softmax = [softmaxes[2]] # Softmaxes for the 4-class model

In [None]:
# One-element list holding the matching third entry of `labels`.
four_class_labels = [labels[2]] # Labels for the 4-class model 

In [None]:
# Collapse the 4-class output for the e/gamma-vs-muon comparison.
# The LaTeX labels are written as raw strings so that "\g" and "\m" are not parsed
# as (invalid) escape sequences; the values passed to the helper are unchanged.
collapsed_class_scores_list, collapsed_class_labels_list = multi_collapse_test_output(
    four_class_softmax, four_class_labels, label_dict,
    ignore_type=r'$\gamma$', threshold=r'$\mu$')

# Shift every collapsed label down by one.
# NOTE(review): presumably this re-bases the labels to start at 0 after the ignored
# class is dropped -- confirm against multi_collapse_test_output.
collapsed_class_labels_list = [collapsed_class_labels - 1 for collapsed_class_labels in collapsed_class_labels_list]

# Drop the first score column so the remaining columns line up with the shifted labels.
collapsed_class_scores_list = [collapsed_class_scores[:, 1:] for collapsed_class_scores in collapsed_class_scores_list]

In [None]:
# Shape of the collapsed score array once its first column has been removed.
print(collapsed_class_scores_list[0].shape)

In [None]:
# ROC curve for the e/gamma vs muon case: "non-mu" is the true label, "mu" the false label.
non_muon_label = muon_softmax_index_dict["non-mu"]
muon_label = muon_softmax_index_dict["mu"]

fprs, tprs, thrs = multi_compute_roc(
    collapsed_class_scores_list,
    collapsed_class_labels_list,
    true_label=non_muon_label,
    false_label=muon_label,
    normalize=True,
)

In [None]:
# First (and only) entry of the ROC results computed for the e/gamma-vs-muon case.
four_class_tpr = tprs[0] # Loads the true-positive rate (aka. efficiency) for the e/gamma vs muon 

In [None]:
# Thresholds belonging to the same ROC; the same index is used for both arrays below.
four_class_thr = thrs[0] # Loads the thresholds for the e/gamma vs muon

## Find the index corresponding to the efficiency we want 

In [None]:
efficiency = 0.98    # Target efficiency (true-positive rate)
tolerance = 0.00001  # Half-width of the accepted window around the target

# List every ROC point whose efficiency falls inside the window,
# printing the efficiency followed by its position in the tpr array.
lower_bound = efficiency - tolerance
upper_bound = efficiency + tolerance
for i, tpr_value in enumerate(four_class_tpr):
    if lower_bound <= tpr_value <= upper_bound:
        print(tpr_value, i)

In [None]:
# Index of the ROC point whose efficiency is closest to the target.
# It is computed rather than hard-coded: an index copied from one run's printout
# silently points at the wrong efficiency as soon as the model outputs change.
threshold_index = int(np.argmin(np.abs(np.asarray(four_class_tpr) - efficiency)))

In [None]:
# Score threshold at the selected ROC point (same index in the threshold array as in the tpr array).
threshold = four_class_thr[threshold_index] # Finds the threshold we need in the threshold data 

In [None]:
print(threshold)

## Remove events which are below the threshold 

In [None]:
# NOTE: rebinds four_class_softmax from the one-element list used earlier to the
# third entry of `softmaxes` itself, so it can be indexed per event below.
four_class_softmax = softmaxes[2] # Softmaxes for the 4-class model

In [None]:
# Likewise rebinds four_class_labels to the third entry of `labels` itself.
four_class_labels = labels[2] # Labels for the 4-class model 

In [None]:
print(four_class_labels)

In [None]:
# Positions of the events labelled 1 (the electron index in label_dict).
print(np.where(four_class_labels == 1)[0])

In [None]:
# Positions of the events labelled 0 (the gamma index in label_dict).
print(np.where(four_class_labels == 0)[0])

In [None]:
# Indices of the events whose normalized e+gamma score is at or below the threshold.
# The score is (col0 + col1) / (col0 + col1 + col2); column 3 is left out of the
# normalization. (Assumes the softmax columns follow label_dict, i.e. gamma, e, mu,
# pi0 -- TODO confirm.)
# Computed for all events at once instead of a per-event Python loop. The summation
# order matches the original expression, so the scores are the same values.
gamma_plus_e = four_class_softmax[:, 0] + four_class_softmax[:, 1]
norm_factor = gamma_plus_e + four_class_softmax[:, 2]
normalized_probability = gamma_plus_e / norm_factor

# np.flatnonzero always returns an integer index array, even when no event is
# selected; np.array([]) built from an empty list would be float and unusable as indices.
idxs = np.flatnonzero(normalized_probability <= threshold)

In [None]:
# Number of events selected for removal.
print(len(idxs))

In [None]:
# This cell keeps track of the kind of events that are removed

# True labels of the removed events. The index array is cast to int so that an
# empty selection is still a valid index.
event_list = np.asarray(four_class_labels)[np.asarray(idxs, dtype=int)]

# Count the removed events per class with a vectorized comparison, instead of
# filling one throwaway list per particle type just to take its length.
# Keys are the class indices 0-3 (gamma, e, mu, pi0 in label_dict).
removed_counts = {
    class_index: int(np.count_nonzero(event_list == class_index))
    for class_index in range(4)
}

print("Removed gamma events =", removed_counts[0])
print("Removed electron events =", removed_counts[1])
print("Removed muon events =", removed_counts[2])
print("Removed pion events =", removed_counts[3])

## Filter out Electron and Gamma events from 2-class model 

In [None]:
# Position of the last event labelled 0 (gamma) in the 4-class label array.
# NOTE(review): this value is only printed; nothing else in the visible notebook reads it.
last_gamma_event = np.where(four_class_labels == 0)[0][-1]

In [None]:
print(last_gamma_event)

In [None]:
# First entry of `softmaxes`.
# NOTE(review): this is the first 2-class run only if `softmaxes` has not yet been
# overwritten by the later cells that rebuild it -- run the notebook top to bottom.
two_class_softmax = softmaxes[0]

In [None]:
# Keep only the removed-event indices that are valid positions in the 2-class arrays.
# A valid index is strictly less than the array length; the earlier `<=` comparison
# also admitted an index equal to len(two_class_softmax), which is out of bounds
# for np.delete.
# NOTE(review): this assumes the 2-class events sit at the same leading positions
# as in the 4-class arrays; the loaded indices.npy arrays are not consulted -- confirm.
two_class_idxs = np.asarray(idxs, dtype=int)
two_class_idxs = two_class_idxs[two_class_idxs < len(two_class_softmax)]

In [None]:
# Number of removal indices that apply to the 2-class arrays.
print(len(two_class_idxs))

In [None]:
# Drop the selected rows (events) from the first 2-class softmax array.
new_two_class_softmax = np.delete(two_class_softmax,two_class_idxs,axis=0)

In [None]:
two_class_labels = labels[0]

In [None]:
# Drop the same events from the matching labels.
new_two_class_labels = np.delete(two_class_labels,two_class_idxs)

In [None]:
# Second 2-class entry; the same index list is reused for it below.
# NOTE(review): two_class_idxs was bounded by the length of the first 2-class array
# only -- confirm that this array is at least as long.
two_class_extra_softmax = softmaxes[1]

In [None]:
new_two_class_extra_softmax = np.delete(two_class_extra_softmax,two_class_idxs,axis=0)

In [None]:
two_class_extra_labels = labels[1]

In [None]:
new_two_class_extra_labels = np.delete(two_class_extra_labels,two_class_idxs)

## Create new softmax data and labels 

In [None]:
# Remove the events whose e+gamma score is at or below the threshold from the
# 4-class arrays: rows of the softmax output, and the matching label entries.
new_softmax, new_labels = (
    np.delete(four_class_softmax, idxs, axis=0),
    np.delete(four_class_labels, idxs),
)

In [None]:
# Shapes of the filtered 4-class softmax and label arrays.
print(new_softmax.shape)
print(new_labels.shape)

In [None]:
# Full comparison set: the three original curves followed by their filtered counterparts.
# The lists are rebuilt from the per-model inputs instead of being extended with `+=`,
# so re-running this cell cannot append the filtered entries a second time.
softmaxes = (raw_output_softmax_2_class + raw_output_softmax_4_updated_class
             + [new_softmax, new_two_class_softmax, new_two_class_extra_softmax])
labels = (raw_actual_labels_2_class + raw_actual_labels_4_updated_class
          + [new_labels, new_two_class_labels, new_two_class_extra_labels])
titles = (titles_2_class + titles_4_updated_class
          + ["4-Class - Rejected Muons with 98% efficiency threshold", "2-class - Removed Events", "2-class - Extra Data - Removed Events"])
linecolor = linecolor_2_class + linecolor_4_updated_class + [c[3],c[4],c[5]]
linestyle = linestyle_2_class + linestyle_4_updated_class + ['-','-','-']

In [None]:
# Restrict the comparison to the three filtered curves, replacing any earlier contents
# of these lists. Each row is (softmax, labels, plot title, line color).
filtered_curves = [
    (new_softmax, new_labels,
     "4-Class - Rejected Muons with 98% efficiency threshold", c[3]),
    (new_two_class_softmax, new_two_class_labels,
     "2-class - Removed Events", c[4]),
    (new_two_class_extra_softmax, new_two_class_extra_labels,
     "2-class - Extra Data - Removed Events", c[5]),
]
softmaxes, labels, titles, linecolor = (list(column) for column in zip(*filtered_curves))
linestyle = ['-'] * len(filtered_curves)

## Determine adjustment factor

In [None]:
# NOTE: in each pair of cells below, the name is first bound to the tuple returned by
# np.where and then rebound to an int count. Re-running only the second cell of a
# pair fails, so always run the pair together.

# Positions of electron-labelled (1) events that survive the cut.
num_electron_events_subset = np.where(new_labels == 1)

In [None]:
num_electron_events_subset = len(num_electron_events_subset[0]) # Number of electron events in the new subset 

In [None]:
print(num_electron_events_subset)

In [None]:
# Positions of electron-labelled (1) events in the unfiltered 4-class labels.
num_electron_events_overall = np.where(four_class_labels == 1)

In [None]:
num_electron_events_overall = len(num_electron_events_overall[0]) # Number of electron events before removal  

In [None]:
print(num_electron_events_overall)

In [None]:
# Positions of gamma-labelled (0) events that survive the cut.
num_gamma_events_subset = np.where(new_labels == 0)

In [None]:
# Number of gamma events in the new subset.
num_gamma_events_subset = len(num_gamma_events_subset[0])

In [None]:
print(num_gamma_events_subset)

In [None]:
# Positions of gamma-labelled (0) events in the unfiltered 4-class labels.
num_gamma_events_overall = np.where(four_class_labels == 0)

In [None]:
# Number of gamma events before removal.
num_gamma_events_overall = len(num_gamma_events_overall[0])

In [None]:
print(num_gamma_events_overall)

In [None]:
# Fraction of the original electron events that survive the cut (4-class counts).
# It is multiplied into the tpr values further down.
    
tpr_adjustment_constant = num_electron_events_subset / num_electron_events_overall

In [None]:
print(tpr_adjustment_constant)

In [None]:
# Fraction of the original gamma events that survive the cut (4-class counts).
# It is multiplied into the fpr values further down.

fpr_adjustment_constant = num_gamma_events_subset / num_gamma_events_overall

In [None]:
print(fpr_adjustment_constant)

## Modify the tpr and fpr 

In [None]:
# Calculates the ROC curve for e vs gamma based on the updated softmaxes and labels.
# The gamma key is written as a raw string so that "\g" is not parsed as an (invalid)
# escape sequence; the dictionary key looked up is unchanged.
fprs, tprs, thrs = multi_compute_roc(softmaxes, labels,
                                     true_label=label_dict["$e$"],
                                     false_label=label_dict[r"$\gamma$"],
                                     normalize=True)

In [None]:
# Rescale only the curves from position 3 onward. With three or fewer curves in
# `softmaxes` this loop does nothing.
# NOTE(review): the next cell rescales every curve, so running both cells applies
# the constants twice to the curves handled here.
for curve_index in range(3, len(softmaxes)):
    tprs[curve_index] = tprs[curve_index] * tpr_adjustment_constant
    fprs[curve_index] = fprs[curve_index] * fpr_adjustment_constant

In [None]:
# Rescale every curve by the surviving-event fractions.
# NOTE(review): the constants come from the 4-class event counts but are applied to
# all curves, and re-running this cell multiplies them in again -- run it once only.
for curve_index in range(len(softmaxes)):
    tprs[curve_index] = tprs[curve_index] * tpr_adjustment_constant
    fprs[curve_index] = fprs[curve_index] * fpr_adjustment_constant

In [None]:
# Give the second curve the fourth color of the cycle.
# NOTE(review): when the three-curve lists are in use, the first curve already has
# this color, so the two curves become indistinguishable -- confirm this is intended.
linecolor[1] = c[3]

In [None]:
# Recolor the sixth curve, which exists only when the six-curve lists are in use.
# With the three-curve lists an unguarded assignment raises IndexError and stops a
# top-to-bottom run of the notebook.
if len(linecolor) > 5:
    linecolor[5] = c[6]

## Plot ROC e vs gamma with Normalization 

In [None]:
# e-vs-gamma ROC over the full range for all selected curves.
# The gamma label is a raw string so that "\g" is not parsed as an (invalid) escape
# sequence; the text passed to the plotting helper is unchanged.
figs = multi_plot_roc(fprs, tprs, thrs, "$e$", r"$\gamma$",
                      fig_list=[1], ylims=[[0,3e6]],
                      linestyles=linestyle,linecolors=linecolor,
                      plot_labels=titles, show=False)

In [None]:
# Same e-vs-gamma ROC, zoomed to x in [0.8, 1.0] and y in [1.0, 1.8].
# The gamma label is a raw string so that "\g" is not parsed as an (invalid) escape
# sequence; the text passed to the plotting helper is unchanged.
figs = multi_plot_roc(fprs, tprs, thrs, "$e$", r"$\gamma$",
                      fig_list=[1],
                      xlims=[[0.8,1.0]],ylims=[[1e0,1.8e0]],
                      linestyles=linestyle,linecolors=linecolor,
                      plot_labels=titles, show=False)