In [None]:
import itertools
import os

import uproot
import numpy as np
import matplotlib.pyplot as plt
import awkward as ak
import selection
from helpers import *
%load_ext autoreload

In [None]:
# Branch names to load, written without the leading "rec." prefix (the loader adds it).
# Coordinate triples and the per-plane chi2 branches are expanded with comprehensions;
# the resulting order is exactly the order in which the branches are listed.
branches = [
    # slc.* : per-slice quantities
    *["slc.vertex." + c for c in "xyz"],
    "slc.truth.index",
    "slc.tmatch.pur",
    "slc.tmatch.eff",
    "slc.truth.iscc",
    "slc.truth.pdg",
    "slc.truth.E",
    *["slc.truth.position." + c for c in "xyz"],
    "slc.self",
    "slc.nu_score",
    "slc.fmatch.score",
    "slc.fmatch.time",
    "slc.fmatch_a.score",
    "slc.is_clear_cosmic",
    "reco.ntrk",
    "reco.nshw",

    # slc.reco.trk.* : per-track quantities within each slice
    "slc.reco.trk.len",
    "slc.reco.trk.costh",
    "slc.reco.trk.phi",
    *["slc.reco.trk.start." + c for c in "xyz"],
    *["slc.reco.trk.end." + c for c in "xyz"],
    "slc.reco.trk.ID",
    "slc.reco.trk.slcID",
    "slc.reco.trk.parent_is_primary",
    *["slc.reco.trk.chi2pid%d.chi2_%s" % (plane, hyp) for plane in range(3) for hyp in ("muon", "proton")],
    "slc.reco.trk.bestplane",
    "slc.reco.trk.mcsP.fwdP_muon",
    "slc.reco.trk.rangeP.p_muon",
    "slc.reco.trk.rangeP.p_proton",
    "slc.reco.trk.crthit.hit.time",
    "slc.reco.trk.crthit.distance",
    "slc.reco.trk.crttrack.angle",
    "slc.reco.trk.crttrack.time",
    "slc.reco.trk.truth.bestmatch.G4ID",
    "slc.reco.trk.truth.bestmatch.energy",
    "slc.reco.trk.truth.p.pdg",
    "slc.reco.trk.truth.p.planeVisE",
    *["slc.reco.trk.truth.p.gen." + c for c in "xyz"],
    *["slc.reco.trk.truth.p.end." + c for c in "xyz"],
    "slc.reco.trk.truth.p.length",
    "slc.reco.trk.truth.p.contained",
    "slc.reco.trk.truth.p.end_process",
    "slc.reco.trk.truth.total_deposited_energy",

    # mc.* : per-neutrino MC truth and its primaries
    "mc.nnu",
    "mc.nu.E",
    "mc.nu.pdg",
    "mc.nu.iscc",
    "mc.nu.index",
    *["mc.nu.position." + c for c in "xyz"],
    "mc.nu.prim.genE",
    *["mc.nu.prim.genp." + c for c in "xyz"],
    "mc.nu.prim.pdg",
    "mc.nu.prim.contained",
    "mc.nu.prim.length",

    # true_particles.*
    "true_particles.G4ID",
    "true_particles.pdg",

    # trigger flag and CRT hits
    "pass_flashtrig",
    "ncrt_hits",
    "crt_hits.time",

    # counters used to regroup flat arrays
    "nslc",
    "slc.reco.ntrk",

    # hdr.* : header information
    *["hdr." + field for field in ("pot", "evt", "fno", "subrun", "run", "ngenevt")],
]

In [None]:
# Trees under "recTree" that are searched for each branch, in this order: the bare "rec"
# tree first, then the nested "rec.*" sub-trees.
treenames = ["rec"] + ["rec." + sub for sub in (
    "slc",
    "slc.reco.trk",
    "true_particles",
    "crt_hits",
    "mc.nu",
    "mc.nu.prim",
    "reco.trk",
    "reco.trk.chi2pid",
    "reco.trk.truth.matches",
)]

In [None]:
# Whether to write the figures to disk, and the output directory prefix (must end in "/").
do_save, savedir = True, "./plots/"

In [None]:
fname = "sbnd-overlay.flat.root"
rootf = uproot.open(fname)

# Load each requested branch. The full branch name is "rec.<branch>"; it is read from the
# first tree in `treenames` whose name prefixes it AND which actually contains the branch.
data = {}
for branch in branches:
    full_name = "rec." + branch
    for tree_name in treenames:
        if not full_name.startswith(tree_name):
            continue
        try:
            data[branch] = rootf["recTree"][tree_name].array(full_name)
        except KeyError:
            continue  # not in this tree; try the next candidate
        break
    else:
        # no candidate tree held the branch
        raise KeyError(full_name)

# Regroup the flat per-track / per-CRT-hit arrays with `group` (from helpers), using the
# matching count branch: "slc.reco.trk.*" -> "slc.reco.ntrk", "crt_hits.*" -> "ncrt_hits".
groupings = ["slc.reco.trk", "crt_hits"]
for key in list(data):
    for prefix in groupings:
        if key.startswith(prefix):
            leaf = prefix.split(".")[-1]
            data[key] = group(data[key], data[prefix.replace(leaf, "n" + leaf)])

# Broadcast per-spill values over the slices of each spill (`broadcast` from helpers;
# presumably repeats each value "nslc" times — confirm against helpers).
to_broadcast = ["pass_flashtrig"]
broadcast_over = "nslc"
for key in to_broadcast:
    data[key] = broadcast(data[key], data[broadcast_over])
    
            

In [None]:
# Per-neutrino truth flags: nu_mu / anti-nu_mu charged current, and true vertex inside the FV.
nu_is_numu_cc = data["mc.nu.iscc"] & (np.abs(data["mc.nu.pdg"]) == 14)
nu_is_fid = selection.InFV(*(data["mc.nu.position." + c] for c in "xyz"))

In [None]:
# Momentum magnitude of each MC primary particle.
data["mc.nu.prim.genp"] = np.sqrt(data["mc.nu.prim.genp.x"]**2 + data["mc.nu.prim.genp.y"]**2 + data["mc.nu.prim.genp.z"]**2)

# Primary-particle variables copied onto a per-neutrino "mc.nu.lep.*" variable.
mc_groups = ["mc.nu.prim.genE", "mc.nu.prim.pdg", "mc.nu.prim.contained", "mc.nu.prim.genp", "mc.nu.prim.length"]

# Primaries that are muons or anti-muons.
is_muon = np.abs(data["mc.nu.prim.pdg"]) == 13

for g in mc_groups:
    lep_key = g.replace("prim", "lep")
    # Per-neutrino clone of mc.nu.E filled with NaN; only numu CC neutrinos get a value.
    # Use np.nan: the np.NaN alias was removed in NumPy 2.0.
    data[lep_key] = data["mc.nu.E"] * 0.
    data[lep_key][:] = np.nan
    # NOTE(review): assumes exactly one primary (anti)muon per numu CC neutrino, stored in
    # the same order as the neutrinos; a count mismatch makes this assignment raise.
    data[lep_key][nu_is_numu_cc] = data[g][is_muon]

# NaN == 1 is False, so neutrinos without a muon are flagged as not contained.
data["mc.nu.lep.contained"] = data["mc.nu.lep.contained"] == 1

In [None]:
# Per-neutrino muon categories: contained muon, and "long" muon (length > 50 when
# contained, > 100 when exiting; units as stored in mc.nu.prim.length).
nu_has_contained_muon = data["mc.nu.lep.contained"]
nu_has_long_muon = (
    (nu_has_contained_muon & (data["mc.nu.lep.length"] > 50.))
    | (~nu_has_contained_muon & (data["mc.nu.lep.length"] > 100.))
)


In [None]:
# Scale factor from the overlay sample's POT (NeutrinoPOT, from helpers) to the target POT.
goal_pot = 6.6e20
ovrl_pot = NeutrinoPOT(data)
pot_scale = goal_pot / ovrl_pot

In [None]:
# Per-slice truth flags: matched to a neutrino (non-negative truth index) or not, and
# matched to a nu_mu / anti-nu_mu CC interaction.
is_nu = data["slc.truth.index"] >= 0
is_cosmic = ~is_nu
is_numu_cc = is_nu & data["slc.truth.iscc"] & (np.abs(data["slc.truth.pdg"]) == 14)

def dist(x1, y1, z1, x2, y2, z2):
    """Return the elementwise Euclidean distance between (x1, y1, z1) and (x2, y2, z2)."""
    dx, dy, dz = x1 - x2, y1 - y2, z1 - z2
    return np.sqrt(dx**2 + dy**2 + dz**2)

# Per-slice fiducial flags from the true and the reconstructed vertex.
is_fid_true = selection.InFV(data["slc.truth.position.x"], data["slc.truth.position.y"], data["slc.truth.position.z"])
is_fid_reco = selection.InFV(data["slc.vertex.x"], data["slc.vertex.y"], data["slc.vertex.z"])

# Per-track derived quantities (same jagged structure as the slc.reco.trk.* inputs).
# Matched true particle starts and ends inside the FV.
data["slc.reco.trk.truth.p.fiducial"] = selection.InFV(data["slc.reco.trk.truth.p.gen.x"], data["slc.reco.trk.truth.p.gen.y"], data["slc.reco.trk.truth.p.gen.z"]) &\
            selection.InFV(data["slc.reco.trk.truth.p.end.x"], data["slc.reco.trk.truth.p.end.y"], data["slc.reco.trk.truth.p.end.z"])
# Reconstructed track end inside the FV.
data["slc.reco.trk.contained"] = selection.InFV(data["slc.reco.trk.end.x"], data["slc.reco.trk.end.y"], data["slc.reco.trk.end.z"])
# Track start within a distance of 10 of its slice vertex.
data["slc.reco.trk.atslc"] = dist(data["slc.reco.trk.start.x"], data["slc.reco.trk.start.y"], data["slc.reco.trk.start.z"],
                                 data["slc.vertex.x"], data["slc.vertex.y"], data["slc.vertex.z"]) < 10

# end_process codes counted as "stopping". TODO: document which processes 1, 2, 3, 41 denote.
data["slc.reco.trk.truth.p.is_stopping"] = (data["slc.reco.trk.truth.p.end_process"] == 1) |\
                                    (data["slc.reco.trk.truth.p.end_process"] == 2) |\
                                    (data["slc.reco.trk.truth.p.end_process"] == 3) |\
                                    (data["slc.reco.trk.truth.p.end_process"] == 41)

# Reco momentum: range-based muon momentum for contained tracks, MCS forward fit otherwise.
data["slc.reco.trk.recop"] = data["slc.reco.trk.rangeP.p_muon"] + 0. # clone
data["slc.reco.trk.recop"][np.invert(data["slc.reco.trk.contained"])] = data["slc.reco.trk.mcsP.fwdP_muon"][np.invert(data["slc.reco.trk.contained"])] 

# chi2 values taken from each track's best plane (plane 2 unless bestplane is 0 or 1).
for chi2 in ["chi2_muon", "chi2_proton"]:
    # "+ 0." clones the array: without it the masked writes below would overwrite the
    # "slc.reco.trk.chi2pid2.<chi2>" arrays in place (and everything derived from them).
    data["slc.reco.trk.bestplane." + chi2] = data["slc.reco.trk.chi2pid2." + chi2] + 0.
    data["slc.reco.trk.bestplane." + chi2][data["slc.reco.trk.bestplane"] == 0] = data["slc.reco.trk.chi2pid0." + chi2][data["slc.reco.trk.bestplane"] == 0]
    data["slc.reco.trk.bestplane." + chi2][data["slc.reco.trk.bestplane"] == 1] = data["slc.reco.trk.chi2pid1." + chi2][data["slc.reco.trk.bestplane"] == 1]

# For the cases where two slices match to the same neutrino, figure out which match is "primary":
# use whichever slice gets more of the neutrino's deposited energy (highest slc.tmatch.eff).
# Ties for the maximum flag every tied slice as primary.
data["slc.truth.match_is_primary"] = data["slc.truth.iscc"] & False # clone, all False
# Group each slice by spill so the comparison happens within a spill.
index_spill = group(data["slc.truth.index"], data["nslc"])
eff_spill = group(data["slc.tmatch.eff"], data["nslc"])
primary_spill = group(data["slc.truth.match_is_primary"], data["nslc"])

for i in range(data["slc.truth.index"].max()+1): # consider up to the maximum neutrino index
    is_match_i = index_spill == i
    max_eff_i = eff_spill[is_match_i].max()
    # Only slices matched to neutrino i may become its primary match. Without the is_match_i
    # term, a slice matched to another neutrino (or a cosmic) whose efficiency happened to
    # equal this maximum would be flagged too.
    primary_spill = primary_spill | (is_match_i & (eff_spill == max_eff_i) & (max_eff_i > 0.))

data["slc.truth.match_is_primary"] = primary_spill.flatten()
# all cosmic matches are primary
data["slc.truth.match_is_primary"] = data["slc.truth.match_is_primary"] | (data["slc.truth.index"] < 0)

In [None]:
# Per-slice index of the selected primary track (negative when the slice has none), and the
# same information as a jagged array holding 0 or 1 track index per slice.
primary_track = selection.get_primary_tracks(data)
primary_track_ind = ak.JaggedArray.fromcounts(
    np.where(primary_track >= 0, 1, 0),
    primary_track[primary_track >= 0],
)
# Same construction for the truth-based primary track.
true_primary_track = selection.get_true_primary_track(data)
true_primary_track_ind = ak.JaggedArray.fromcounts(
    np.where(true_primary_track >= 0, 1, 0),
    true_primary_track[true_primary_track >= 0],
)

In [None]:
# Set all of the primary track slice variables:
# for every per-track variable "slc.reco.trk.X", build a per-slice "slc.ptrk.X" holding the
# value of that slice's primary track. Slices without a primary track get NaN, which becomes
# False for boolean variables after the "== 1" conversion.
keys = list(data.keys())  # snapshot: new keys are added inside the loop
for k in keys:
    if not k.startswith("slc.reco.trk."):
        continue
    slc_key = k.replace(".reco.trk.", ".ptrk.")
    is_bool = data[k][0].dtype == "bool"
    d = data[k] + 0  # + 0 turns booleans into numbers so they fit the float output
    # np.nan, not the np.NaN alias that was removed in NumPy 2.0.
    ptrk_values = np.full(data["slc.nu_score"].shape, np.nan)
    ptrk_values[primary_track >= 0] = d[primary_track_ind[primary_track_ind >= 0]].flatten()
    data[slc_key] = (ptrk_values == 1) if is_bool else ptrk_values
data["slc.has_ptrk"] = primary_track >= 0

In [None]:
# Define the per-slice cut masks, grouped by the selection stage that uses them.
(fid, nu_score,      # TPC pre-selection
 f_time, f_score,    # flash matching
 ptrk,               # muon track identification
 crttrack, crthit,   # CRT matching
 crtveto) = (        # CRT veto
    selection.fid(data), selection.nu_score(data),
    selection.f_time(data), selection.f_score(data),
    selection.ptrk(data),
    selection.crttrack(data), selection.crthit(data),
    selection.crtveto(data),
)

In [None]:
# Re-express per-slice quantities per neutrino: each neutrino gets the list of all slices in
# its spill, so we can ask whether any of them is its matching slice.
def broadcast_to_nu(d):
    """Group per-slice values `d` by spill and broadcast each spill's list over its neutrinos.

    Relies on the global `data` ("nslc" slices per spill, "mc.nnu" neutrinos per spill)
    and on `group` / `broadcast_ak` from helpers.
    """
    slices_per_spill = group(d, data["nslc"])
    neutrinos_per_spill = data["mc.nnu"].astype(np.int64)
    return broadcast_ak(slices_per_spill, neutrinos_per_spill)

# Truth-match information for every slice in the neutrino's spill.
nu_slc_match_index = broadcast_to_nu(data["slc.truth.index"])
nu_slc_is_primary = broadcast_to_nu(data["slc.truth.match_is_primary"])

# The per-slice cuts, in the same per-neutrino layout.
nu_fid = broadcast_to_nu(fid)
nu_nu_score = broadcast_to_nu(nu_score)
nu_f_time = broadcast_to_nu(f_time)
nu_f_score = broadcast_to_nu(f_score)
nu_ptrk = broadcast_to_nu(ptrk)
nu_crttrack = broadcast_to_nu(crttrack)
nu_crthit = broadcast_to_nu(crthit)
nu_crtveto = broadcast_to_nu(crtveto)

In [None]:
def effplot(var, cut, when_plot, bins, whenname="", xlabel="", text=""):
    """Histogram selected vs. all true numu CC neutrinos in `var` and overlay the efficiency.

    Parameters
    ----------
    var : array
        Per-neutrino quantity histogrammed on the x axis.
    cut : bool array
        Per-neutrino flag: True if the neutrino passes the selection being evaluated.
    when_plot : bool array
        Extra per-neutrino condition defining the denominator (ANDed with nu_is_numu_cc).
    bins : array
        Histogram bin edges.
    whenname : str
        Word inserted into the two legend labels.
    xlabel : str
        x-axis label.
    text : str
        Prefix for the integrated-efficiency annotation.

    Returns
    -------
    tuple
        The two legends, for use as ``bbox_extra_artists`` when saving.

    Uses the globals ``nu_is_numu_cc`` and ``pot_scale``.
    """
    plot = plt.subplot(111)

    # Histogram the quantity passed in `var`. It must not be replaced by a fixed variable,
    # otherwise every plot (e.g. muon momentum) would silently show neutrino energy.
    when = nu_is_numu_cc & when_plot
    selected = when & cut

    n, bins, _ = plot.hist([var[selected], var[when]],
                           weights=[np.repeat(pot_scale, np.sum(selected)), np.repeat(pot_scale, np.sum(when))],
                           label=["Reco %s $\\nu_\\mu$ CC" % whenname, "All %s $\\nu_\\mu$ CC" % whenname],
                           histtype="step",
                           bins=bins)

    # Only draw the efficiency in bins with more than 10 unscaled denominator entries.
    populated = n[-1] / pot_scale > 10.
    eff_reco = n[0][populated] / n[-1][populated]

    bin_centers = (bins[:-1] + bins[1:]) / 2.
    eff_y = plot.twinx()
    eff_y.plot(bin_centers[populated], eff_reco, label="Reco Efficiency")

    l1 = plot.legend(bbox_to_anchor=(1.06,0), loc="lower left")
    plot.set_xlabel(xlabel)
    plot.set_ylabel("Entries / 6.6e20 POT")
    l2 = eff_y.legend(bbox_to_anchor=(1.06,1), loc="upper left")
    eff_y.set_ylabel("Efficiency")
    eff_y.set_ylim([0, 1.05])

    # Guard the empty-denominator case instead of dividing by zero.
    n_when = np.sum(when)
    eff = np.sum(selected) / n_when if n_when > 0 else np.nan

    plot.text(0.75, 0.5, '%sIntegrated Eff.: %.1f%%' % (text, eff*100), verticalalignment="center", horizontalalignment='center',fontsize=14, transform=plot.transAxes)
    return (l1, l2)

In [None]:
# Per-neutrino (via the broadcast slice lists): the slice is truth-matched to this neutrino
# and is its primary match.
nu_has_match = ((data["mc.nu.index"] == nu_slc_match_index) & nu_slc_is_primary)
cuts_ind = [
    nu_has_match,
    nu_fid & nu_nu_score,
    nu_f_time & nu_f_score,
    nu_ptrk,
    nu_crttrack & nu_crthit,
    nu_crtveto
]
cutnames_ind = ["Pandora Identification\n", "TPC Pre-Selection\n", "Flash Matching\n", "Muon Track Identification\n", "CRT Matching\n", "CRT Veto\n"]

# Cumulative selection: stage i applies stage i and every earlier stage (per slice), and its
# label concatenates the names of all stages applied so far.
cuts = list(itertools.accumulate(cuts_ind, lambda prev, cur: cur & prev))
cutnames = list(itertools.accumulate(cutnames_ind))
# A neutrino passes a stage if any slice in its spill passes all cuts up to that stage.
cuts = [c.any() for c in cuts]
cutsave = ["pandora", "tpc_pre_sel", "fmatch", "prim_track", "crt_match", "crt_veto"]
cuts = list(zip(cuts, cutnames, cutsave))

xvars = [
    (data["mc.nu.E"], "Neutrino Energy [GeV]", np.linspace(0,3,21), "nuE"),
    (data["mc.nu.lep.genp"], "Muon Momentum [GeV]", np.linspace(0,3,21), "muP")
]

when_plots = [
    (nu_is_fid, "Fid", "fid"),
    (nu_is_fid & nu_has_contained_muon, "Contained", "contmu"),
    (nu_is_fid & nu_has_long_muon, "Long", "longmu"),
]

# plt.savefig does not create directories; make sure the output directory exists.
if do_save:
    os.makedirs(savedir + "eff", exist_ok=True)

ifig = 0
for when_plot, whenname, whensave in when_plots:
    for var, xlabel, bins, varsave in xvars:
        # `cut_tag` (not `cutsave`) so the list of save names above is not shadowed.
        for c, name, cut_tag in cuts:
            plt.figure(ifig)
            artists = effplot(var, c, when_plot, bins, whenname=whenname, xlabel=xlabel, text=name)
            ifig += 1
            if do_save:
                plt.savefig(savedir + "eff/eff_%s_%s_%s.png" % (whensave, varsave, cut_tag), bbox_extra_artists=artists, bbox_inches='tight')