# Reference
[uproot documentation](https://uproot.readthedocs.io/en/latest/)

In [None]:
def make_sequential(obj):
    """Replace ``obj.event`` with running event indices and fill ``obj.file``.

    ``obj.event`` is read as a flat sequence of per-entry event numbers.
    Consecutive entries carrying the same number share one index; every change
    of number advances the event index by one.  A *decrease* in the number
    advances the file index by one (the code treats it as the point where the
    numbering restarted).

    Both indices start at 0 for the first entry, so the first entry always
    agrees with the following entries of the same event.  An empty
    ``obj.event`` yields empty outputs instead of raising.

    Args:
        obj: Any object with an ``event`` attribute convertible to a 1-D
            array.  Its ``event`` and ``file`` attributes are overwritten.
    """
    events = np.asarray(obj.event)
    seq_events = np.zeros_like(events)
    seq_files = np.zeros_like(events)
    if len(events) > 1:
        previous, current = events[:-1], events[1:]
        # Running count of "number changed" / "number went down" transitions.
        seq_events[1:] = np.cumsum(current != previous)
        seq_files[1:] = np.cumsum(current < previous)
    obj.event = seq_events
    obj.file = seq_files

In [None]:
import uproot, numpy as np

class MCValidation:
    """Branches of a validation tree, loaded into memory as NumPy arrays.

    Every branch is read with ``library="np"`` and stored as an attribute.
    After loading, ``make_sequential`` rewrites ``event`` as a running index
    and fills ``file``; ``orig_event`` keeps the numbers as stored in the tree.
    """

    def __init__(self, filename, treename):
        """Read all branches of ``treename`` from the ROOT file ``filename``.

        Args:
            filename: Path (or URL) handed to ``uproot.open``.
            treename: Key of the tree inside the file.

        Raises:
            Whatever ``uproot`` raises for a missing file, tree or branch.
        """
        # The context manager closes the file even when a branch lookup fails.
        with uproot.open(filename) as file:
            tree = file[treename]

            def read(branch):
                return tree[branch].array(library="np")

            self.event = read("event")  # linear list of events (re-indexed below)
            self.orig_event = read("event")  # actual event number
            self.file = np.zeros_like(self.event)
            self.mc_id = read("mcId")  # number of mc particles in event
            self.mc_pdg = read("mcPDG")  # pdg code of particles in each event
            self.mc_tier = read("mcTier")  # which tier each event is folded back to
            self.mc_nhits = read("mcNHits")  # number of mc particle hits in event
            self.mc_momentum = read("mcMomentum")  # momentum of mc particles in event

            # The original author marked the environment flags as "not functioning".
            # NOTE(review): "isNuInteration" looks like a misspelling of
            # "isNuInteraction" - confirm against the branch name in the tree.
            is_nu_int = read("isNuInteration")
            is_cr_int = read("isCosmicRay")
            is_tb_int = read("isTestBeam")
            self.environment = np.full(is_nu_int.shape, "??")
            self.environment[np.where(is_nu_int)] = "nu"
            # Cosmic ray -> "cr", test beam -> "tb" (these two were swapped).
            self.environment[np.where(is_cr_int)] = "cr"
            self.environment[np.where(is_tb_int)] = "tb"

            self.is_leading_lepton = read("isLeadingLepton")
            self.is_michel = read("isMichel")
            self.n_matches = read("nMatches")
            self.reco_id_list = read("recoIdVector")
            self.reco_nhits_list = read("nRecoHitsVector")
            self.shared_nhits_list = read("nSharedHitsVector")
            self.purity_adc_list = read("purityAdcVector")
            self.purity_list = read("purityVector")
            self.purity_list_u = read("purityVectorU")
            self.purity_list_v = read("purityVectorV")
            self.purity_list_w = read("purityVectorW")
            self.purity_adc_list_u = read("purityAdcVectorU")
            self.purity_adc_list_v = read("purityAdcVectorV")
            self.purity_adc_list_w = read("purityAdcVectorW")
            self.completeness_list = read("completenessVector")
            self.completeness_adc_list = read("completenessAdcVector")
            self.completeness_list_u = read("completenessVectorU")
            self.completeness_list_v = read("completenessVectorV")
            self.completeness_list_w = read("completenessVectorW")
            self.completeness_adc_list_u = read("completenessAdcVectorU")
            self.completeness_adc_list_v = read("completenessAdcVectorV")
            self.completeness_adc_list_w = read("completenessAdcVectorW")
            # Element-wise product of the ADC-weighted purity and completeness.
            self.pc_metric = self.purity_adc_list * self.completeness_adc_list
        make_sequential(self)

import matplotlib.pyplot as plt

class PlotFormat:
    """Plain container for the presentation settings of one plot.

    Holds the title and axis labels, optional ``(low, high)`` axis limits,
    font sizes for title and tick labels, and four flags requesting
    logarithmic axes.  The constructor only stores the values it is given.
    """

    def __init__(
        self,
        title="",
        xlabel="x",
        ylabel="y",
        xlim=None,
        ylim=None,
        titlesize=24,
        labelsize=18,
        is_logx=False,
        is_semilogx=False,
        is_logy=False,
        is_semilogy=False,
    ):
        # Text
        self.title, self.xlabel, self.ylabel = title, xlabel, ylabel
        # Axis ranges; None leaves the range to the plotting code.
        self.xlim, self.ylim = xlim, ylim
        # Font sizes
        self.titlesize, self.labelsize = titlesize, labelsize
        # Axis-scale flags
        self.is_logx, self.is_semilogx = is_logx, is_semilogx
        self.is_logy, self.is_semilogy = is_logy, is_semilogy

class Metric:
    """Binned curve: ``x`` positions, ``y`` values and a low/high band.

    ``x`` holds ``n_bins + 1`` evenly spaced points from 0 to the maximum of
    the binning variable; ``y``, ``y_err_low`` and ``y_err_high`` start as
    zeros of the same length.  ``y_errs`` is a 2-row stack of the two band
    arrays, rebuilt by every method that changes them.
    """

    def __init__(self, bin_var, n_bins):
        self.low = 0
        self.high = np.max(bin_var)
        self.n_bins = n_bins
        self.x = np.linspace(self.low, self.high, self.n_bins + 1)
        self.y = np.zeros_like(self.x)
        self.y_err_low = np.zeros_like(self.y)
        self.y_err_high = np.zeros_like(self.y)
        self.y_errs = self._stack(self.y_err_low, self.y_err_high)

    @staticmethod
    def _stack(first, second):
        """Return the two sequences as rows of one array (row 0 is ``first``)."""
        return np.array(list(zip(first, second))).T

    def rebin(self, bin_width):
        """Keep every ``bin_width``-th point; the values are not merged."""
        stride = slice(None, None, bin_width)
        self.x = self.x[stride]
        self.y = self.y[stride]
        self.y_err_low = self.y_err_low[stride]
        self.y_err_high = self.y_err_high[stride]
        self.y_errs = self._stack(self.y_err_low, self.y_err_high)

    def suppress_empty(self, empty_val=-1):
        """Drop every point whose ``y`` equals ``empty_val``."""
        keep = np.flatnonzero(self.y != empty_val)
        self.x = self.x[keep]
        self.y = self.y[keep]
        self.y_err_low = self.y_err_low[keep]
        self.y_err_high = self.y_err_high[keep]
        # NOTE(review): here the rows are (high, low), the reverse of __init__
        # and rebin.  plot_metric in this file labels row 0 as the 90th
        # percentile, so it relies on this order; align both together if the
        # order is ever made consistent.
        self.y_errs = self._stack(self.y_err_high, self.y_err_low)

In [None]:
def get_efficiency(efficiency, bin_var, n_bins):
    """Average ``efficiency`` in ``n_bins`` equal-width bins of ``bin_var``.

    Bins span 0 to ``max(bin_var)``.  Every bin is half-open ``[lo, hi)``
    except the last, which is closed so the maximum value is counted; a value
    sitting exactly on an inner edge is therefore counted once, not twice.

    Args:
        efficiency: Per-entry pass/fail flags (or weights), parallel to
            ``bin_var``.
        bin_var: Per-entry value of the binning variable.
        n_bins: Number of bins.

    Returns:
        A ``Metric`` holding only the populated bins: ``x`` is the lower bin
        edge, ``y`` the mean of ``efficiency`` in the bin, and both error
        arrays ``sqrt(y * (1 - y) / n)`` for ``n`` entries in the bin.
    """
    efficiency = np.asarray(efficiency)
    bin_var = np.asarray(bin_var)
    metric = Metric(bin_var, n_bins)

    bins = np.linspace(metric.low, metric.high, n_bins + 1)
    for i in range(n_bins):
        in_bin = bin_var >= bins[i]
        if i < n_bins - 1:
            in_bin &= bin_var < bins[i + 1]
        else:
            # Last bin also takes the upper edge, i.e. the maximum itself.
            in_bin &= bin_var <= bins[i + 1]
        selection = efficiency[in_bin]
        if len(selection) > 0:
            metric.y[i] = np.mean(selection)
            metric.y_err_low[i] = metric.y_err_high[i] = np.sqrt(metric.y[i] * (1 - metric.y[i]) / len(selection))
        else:
            metric.y[i] = -1

    # Metric allocates one slot per bin *edge*; the slot past the last bin is
    # never filled, so flag it empty rather than leaving a spurious y = 0 point.
    metric.y[n_bins:] = -1
    metric.suppress_empty()

    return metric

def get_purity_or_completeness(p_or_c, bin_var, n_bins):
    """Summarise ``p_or_c`` in ``n_bins`` equal-width bins of ``bin_var``.

    Bins span 0 to ``max(bin_var)``.  Every bin is half-open ``[lo, hi)``
    except the last, which is closed so the maximum value is counted; a value
    sitting exactly on an inner edge is therefore counted once, not twice.

    Args:
        p_or_c: Per-entry purity or completeness, parallel to ``bin_var``.
        bin_var: Per-entry value of the binning variable.
        n_bins: Number of bins.

    Returns:
        A ``Metric`` holding only the populated bins: ``x`` is the lower bin
        edge, ``y`` the median in the bin, and ``y_err_low`` / ``y_err_high``
        the 10% / 90% quantiles themselves (absolute values, not distances
        from the median).
    """
    p_or_c = np.asarray(p_or_c)
    bin_var = np.asarray(bin_var)
    metric = Metric(bin_var, n_bins)

    bins = np.linspace(metric.low, metric.high, n_bins + 1)
    for i in range(n_bins):
        in_bin = bin_var >= bins[i]
        if i < n_bins - 1:
            in_bin &= bin_var < bins[i + 1]
        else:
            # Last bin also takes the upper edge, i.e. the maximum itself.
            in_bin &= bin_var <= bins[i + 1]
        selection = p_or_c[in_bin]
        if len(selection) > 0:
            metric.y[i] = np.median(selection)
            metric.y_err_low[i], metric.y_err_high[i] = np.quantile(selection, [0.1, 0.9])
        else:
            metric.y[i] = -1

    # Metric allocates one slot per bin *edge*; the slot past the last bin is
    # never filled, so flag it empty rather than leaving a spurious y = 0 point.
    metric.y[n_bins:] = -1
    metric.suppress_empty()

    return metric

def get_purity(p, bin_var, n_bins):
    """Binned summary of purity ``p``; forwards to ``get_purity_or_completeness``."""
    purity_metric = get_purity_or_completeness(p, bin_var, n_bins)
    return purity_metric

def get_completeness(c, bin_var, n_bins):
    """Binned summary of completeness ``c``; forwards to ``get_purity_or_completeness``."""
    completeness_metric = get_purity_or_completeness(c, bin_var, n_bins)
    return completeness_metric

def get_purity_or_completeness_fraction(p_or_c, metric):
    """Plot the fraction of entries falling in each of 20 bins over [0, 1].

    Only populated bins are drawn.  The vertical error bar is the Poisson
    uncertainty of the bin count expressed on the fraction scale,
    ``sqrt(count) / total``; the horizontal bar is half a bin width.

    Args:
        p_or_c: Per-entry purity or completeness values.
        metric: Name of the quantity; used lower-cased as the x label and
            title-cased as the plot title.
    """
    bins = np.linspace(0, 1, 21)
    counts, _ = np.histogram(p_or_c, bins=bins)
    # Guard the division; with no entries there are no populated bins anyway.
    n_total = max(len(p_or_c), 1)
    populated = counts > 0

    x = ((bins[1:] + bins[:-1]) / 2)[populated]
    y = counts[populated] / n_total
    # Absolute error on the fraction (the relative error would be 1/sqrt(count)).
    e = np.sqrt(counts[populated]) / n_total
    half_width = (bins[1] - bins[0]) / 2

    plt.figure(figsize=(20, 15))
    plt.errorbar(x, y, xerr=half_width, yerr=e, fmt='o', markersize=2, color='r', elinewidth=0.5)

    plt.xlabel(metric.lower())
    plt.ylabel("fraction")
    plt.title(metric.title())
    plt.xlim(0, 1.01)
    plt.ylim(0, 1.01)
    # Kept from the original: changes the default size of figures created later.
    plt.rcParams["figure.figsize"] = (13, 8)
    plt.show()

def get_purity_fraction(purity):
    """Draw the purity fraction plot via ``get_purity_or_completeness_fraction``."""
    get_purity_or_completeness_fraction(p_or_c=purity, metric="purity")
    
def get_completeness_fraction(completeness):
    """Draw the completeness fraction plot via ``get_purity_or_completeness_fraction``."""
    get_purity_or_completeness_fraction(p_or_c=completeness, metric="completeness")
    
def plot_metric(x, y, err, plot_format, metric="completeness"):
    """Draw one binned metric on a new figure and show it.

    Args:
        x: Point positions.
        y: Point values.
        err: For ``"purity"`` / ``"completeness"``, two rows drawn as extra
            curves (row 0 is labelled "90th", row 1 "10th").  For
            ``"efficiency"``, passed to ``errorbar`` as ``yerr``.
        plot_format: ``PlotFormat`` with title, labels, font sizes, limits
            and axis-scale flags.
        metric: ``"purity"``, ``"completeness"`` or ``"efficiency"``
            (case-insensitive).  Any other value draws empty axes.
    """
    # Kept from the original: also changes the default size of later figures.
    plt.rcParams["figure.figsize"] = (20, 15)
    # One figure only; a preceding plt.figure() would leave a blank extra one.
    _, ax = plt.subplots()

    ax.set_title(plot_format.title, fontsize=plot_format.titlesize)
    ax.tick_params(axis='x', labelsize=plot_format.labelsize)
    ax.tick_params(axis='y', labelsize=plot_format.labelsize)
    ax.set_xlabel(plot_format.xlabel, fontsize=plot_format.titlesize)
    ax.set_ylabel(plot_format.ylabel, fontsize=plot_format.titlesize)

    # pyplot has no logx()/logy(); either flag of a pair selects a log axis.
    if plot_format.is_logx or plot_format.is_semilogx:
        ax.set_xscale("log")
    if plot_format.is_logy or plot_format.is_semilogy:
        ax.set_yscale("log")
    if plot_format.xlim is not None:
        ax.set_xlim(plot_format.xlim[0], plot_format.xlim[1])
    if plot_format.ylim is not None:
        ax.set_ylim(plot_format.ylim[0], plot_format.ylim[1])

    kind = metric.lower()
    if kind in ('purity', 'completeness'):
        ax.plot(x, y, color='r', label="median")
        ax.plot(x, err[0], color='g', label="90th")
        ax.plot(x, err[1], color='b', label="10th")
        ax.legend(fontsize=plot_format.titlesize)
    elif kind == 'efficiency':
        # Empty bins may have been removed, so neighbouring points can be more
        # than one bin apart; the smallest gap is the best bin-width estimate.
        half_width = np.min(np.diff(x)) / 2 if len(x) > 1 else None
        ax.errorbar(x, y, xerr=half_width, yerr=err, fmt='o', markersize=2, color='r', elinewidth=0.5)

    plt.show()

In [None]:
import os

def save_plot(fig, filename, subdir=None):
    """Save ``fig`` as PNG, SVG and EPS under ``images/<type>/<subdir>/``.

    Missing directories are created, including nested ones such as
    ``subdir="a/b"``.

    Args:
        fig: Object with a ``savefig(path, dpi=...)`` method.
        filename: File name without extension.
        subdir: Optional sub-directory below each image-type directory;
            leading slashes are ignored.
    """
    subdir = (subdir or "").lstrip("/")

    for img_type in ("png", "svg", "eps"):
        target_dir = os.path.join("images", img_type, subdir)
        # exist_ok avoids the check-then-create race and builds parent dirs.
        os.makedirs(target_dir, exist_ok=True)
        fig.savefig(os.path.join(target_dir, f"{filename}.{img_type}"), dpi=200)

# Reading data

In [None]:
# Load every branch of the "mc" tree in validation.root into memory.
validation = MCValidation(filename="validation.root", treename="mc")

In [None]:
# PDG particle code to select; 11 is the electron. The selection cell compares
# it against abs(mc_pdg), so antiparticles (-11) are included as well.
pdg = 11

In [None]:
# Entries whose |PDG code| equals the requested species.
idx = np.where(np.abs(validation.mc_pdg) == pdg)
# temp to test
#idx = np.where((abs(validation.mc_pdg) == pdg) & ((validation.mc_nhits >= 14) & (validation.mc_nhits <= 15)))

# Per-entry quantities restricted to that selection.
completeness = validation.completeness_list[idx]
purity = validation.purity_list[idx]
mc_nhits = validation.mc_nhits[idx]
n_matches = validation.n_matches[idx]
mc_mom = validation.mc_momentum[idx]

# For each entry pick the match with the highest completeness (-1 when there
# is no match), then take the completeness and purity of that same match;
# entries without any match get 0 for both.
index_array = [-1 if len(val) == 0 else np.argmax(val) for val in completeness]
completeness_unique_array = np.array([0 if len(val) == 0 else np.max(val) for val in completeness])
purity_unique_array = np.array([0 if best == -1 else val[best] for val, best in zip(purity, index_array)])

In [None]:
# NOTE(review): plot_format is assigned here, but neither fraction helper
# called below accepts it; they set their own labels and limits.
plot_format = PlotFormat(title="", xlabel="purity", ylabel="fraction")
get_purity_fraction(purity=purity_unique_array)
get_completeness_fraction(completeness=completeness_unique_array)

In [None]:
# Completeness versus number of true hits (log-scaled x axis).
n_bins = 3000
plot_format = PlotFormat(
    title="Completeness",
    xlabel="num true hits",
    ylabel="completeness",
    xlim=(1, 12000),
    ylim=(0, 1.01),
    is_semilogx=True,
)
metric = get_completeness(completeness_unique_array, mc_nhits, n_bins)
plot_metric(metric.x, metric.y, metric.y_errs, plot_format, metric="completeness")

# Completeness versus true momentum.
n_bins = 60
plot_format = PlotFormat(
    title="Completeness",
    xlabel="true momentum (GeV)",
    ylabel="completeness",
    xlim=(0, 30),
    ylim=(0, 1.01),
)
metric = get_completeness(completeness_unique_array, mc_mom, n_bins)
plot_metric(metric.x, metric.y, metric.y_errs, plot_format, metric="completeness")

In [None]:
# Purity versus number of true hits (log-scaled x axis).
# The purity array is binned here; the completeness array was passed before.
n_bins = 3000
plot_format = PlotFormat(title="Purity", xlabel="num true hits", ylabel="purity", xlim=(1, 12000), ylim=(0, 1.01),
                         is_semilogx=True)
metric = get_purity(purity_unique_array, mc_nhits, n_bins)
plot_metric(metric.x, metric.y, metric.y_errs, plot_format, "purity")

# Purity versus true momentum.
n_bins = 60
plot_format = PlotFormat(title="Purity", xlabel="true momentum (GeV)", ylabel="purity", xlim=(0, 30), ylim=(0, 1.01))
metric = get_purity(purity_unique_array, mc_mom, n_bins)
plot_metric(metric.x, metric.y, metric.y_errs, plot_format, "purity")

In [None]:
# An entry counts as efficiently reconstructed when its best match reaches
# at least 0.5 in both completeness and purity.
efficiency_unique_array = np.logical_and(completeness_unique_array >= 0.5, purity_unique_array >= 0.5)

In [None]:
# Efficiency versus number of true hits (log-scaled x axis).
# The pass/fail flags are binned here (the completeness array was passed
# before) and the y axis is labelled accordingly.
n_bins = 1000
plot_format = PlotFormat(title="Efficiency", xlabel="num true hits", ylabel="efficiency", xlim=(1, 12000), ylim=(0, 1.01),
                         is_semilogx=True)
metric = get_efficiency(efficiency_unique_array, mc_nhits, n_bins)
plot_metric(metric.x, metric.y, metric.y_errs, plot_format, "efficiency")

# Efficiency versus true momentum.
n_bins = 60
plot_format = PlotFormat(title="Efficiency", xlabel="true momentum (GeV)", ylabel="efficiency", xlim=(0, 30), ylim=(0, 1.01))
metric = get_efficiency(efficiency_unique_array, mc_mom, n_bins)
plot_metric(metric.x, metric.y, metric.y_errs, plot_format, "efficiency")