In [4]:
import uproot as up
import pandas as pd
import os
import numpy as np
import awkward as ak
import mplhep as hep
import matplotlib.pyplot as plt
import yaml
from coffea.nanoevents import TreeMakerSchema, BaseSchema, NanoEventsFactory
import argparse
import sys

hep.style.use(hep.style.CMS)

In [5]:
class ParticleSelection:
    """
    A class to define and apply particle selection criteria.
    """
    def __init__(self, events):
        self.events = events
        self.cut_list = self.make_cut_list()

    def make_cut_list(self, elePt_low=5, elePt_high=5, eleEta=2.5, ymass_upper=10):
        """
        Define a list of cuts to be applied to the events.
        """

        cut_fake = (self.events['B_Z_mass'] > -1)
        cut_unOrdered_Z = self.events['B_Z_pt1'] > self.events['B_Z_pt2']
        cut_unOrdered_J = self.events['B_J_pt1'] > self.events['B_J_pt2']
        cut_unOrdered = cut_unOrdered_Z & cut_unOrdered_J

        cut_EleTrigger = self.events['B_Z_TriggerPath']
        cut_Jsoft = self.events['B_J_soft1'] & self.events['B_J_soft2']
        cut_EleTrigEnforce = self.events['B_Z_pt1'] > 27
        cut_dilepton_prob = (self.events['B_J_VtxProb'] > 0.01) & (self.events['B_Z_VtxProb'] > 0.01)
        cut_FourL_prob = self.events['FourL_VtxProb'] > 0.01

        cut_Pt = (self.events['B_Z_pt1'] > 27.0) & (self.events['B_Z_pt2'] > 5.0) & (self.events['B_J_pt1'] > 3.0) & (self.events['B_J_pt2'] > 3.0)
        cut_eta = (abs(self.events['B_Z_eta1']) < 2.4) & (abs(self.events['B_Z_eta2']) < 2.4) & (abs(self.events['B_J_eta1']) < 2.5) & (abs(self.events['B_J_eta2']) < 2.5)
        cut_detector = cut_Pt & cut_eta

        cut_Jmass = (self.events['B_J_mass'] > 3.0) & (self.events['B_J_mass'] < 3.2)
        cut_Zmass = (self.events['B_Z_mass'] > 70) & (self.events['B_Z_mass'] < 110)

        cut_ZmvaIsoHighpT = self.events['B_Z_mvaIsoWP90_1']
        cut_ZmvaIsoLowpT = self.events['B_Z_mvaIsoWP90_2']
        cut_ZmvaIsoBoth = (self.events['B_Z_mvaIsoWP90_1']) & (self.events['B_Z_mvaIsoWP90_2'])

        cut_FourLmass = (self.events['FourL_mass'] > 112) & (self.events['FourL_mass'] < 162)

        cut_list = {
            "0": {"name": "Preselection", "mask": cut_fake, "plot": False},
            "1": {"name": "UnOrdered pT", "mask": cut_unOrdered},
            "2": {"name": "Electron Trigger", "mask": cut_EleTrigger},
            "3": {"name": "Soft Muons", "mask": cut_Jsoft},
            "4": {"name": "Electron Trigger Enforce", "mask": cut_EleTrigEnforce},
            "5": {"name": "Dilepton Vtx > 1%", "mask": cut_dilepton_prob, "var": ["B_J_VtxProb", "B_Z_VtxProb"]},
            "6": {"name": "FourL Vtx > 1%", "mask": cut_FourL_prob, "var": ["FourL_VtxProb"]},
            "7": {"name": "Detector acceptance", "mask": cut_detector},
            "8": {"name": "J mass", "mask": cut_Jmass, "var": ["B_J_mass"]},
            "9": {"name": "Z mass", "mask": cut_Zmass, "var": ["B_Z_mass"]},
            "10": {"name": "eleID High pT", "mask": cut_ZmvaIsoHighpT, "var": ["B_J_pt1"]},
            "11": {"name": "eleID Low pT", "mask": cut_ZmvaIsoLowpT, "var": ["B_J_pt2"]},
            "12": {"name": "eleID either", "mask": cut_ZmvaIsoBoth, "var": ["B_J_pt1", "B_J_pt2"]},
            "13": {"name": "FourL mass", "mask": cut_FourLmass, "var": ["FourL_mass"]}
        }

        return cut_list

In [6]:
run_data = True
run_mc = False

def load_events(file_path, schema_class, entry_stop):
    """
    Load events from a ROOT file with error handling.
    """
    try:
        events = NanoEventsFactory.from_root(
            {file_path: "ntuple"}, 
            schemaclass=schema_class, 
            entry_stop=entry_stop
        ).events()
        return events
    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.")
        return None
    except Exception as e:
        print(f"Error loading events from '{file_path}': {e}")
        return None

def load_config(config_path):
    """
    Load configuration from a YAML file.
    """
    try:
        with open(config_path, 'r') as config_file:
            return yaml.safe_load(config_file)
    except FileNotFoundError:
        print(f"Error: Configuration file '{config_path}' not found.")
        sys.exit(1)
    except yaml.YAMLError as e:
        print(f"Error parsing YAML configuration file: {e}")
        sys.exit(1)

config = load_config('config.yaml')

events_data = None
events_mc = None

# Load events
if run_data:
    events_data = load_events(config['data_path'], BaseSchema, config['entry_stop'])
    if events_data is None:
        print("Failed to load data events. Skipping data analysis.")
        run_data = False

if run_mc:
    events_mc = load_events(config['mc_path'], BaseSchema, config['entry_stop'])
    if events_mc is None:
        print("Failed to load MC events. Skipping MC analysis.")
        run_mc = False

if not run_data and not run_mc:
    print("No valid data or MC events loaded. Exiting.")
    sys.exit(1)

# Select columns
if run_data:
    events_data = events_data[config['columns']].compute()
if run_mc:
    events_mc = events_mc[config['columns']].compute()

In [7]:
save_path = config['savepath_base']

In [8]:
save_path
os.makedirs(save_path, exist_ok=True)

In [9]:
dilepton_mass_cuts = [8, 9]
eleID_cuts = [10, 11, 12]

eleID_both = [0, 1, 3, 12, 7, 5, 8, 9, 6, 13]
cuts_to_show = [0, 3, 12, 7, 5, 8, 9, 6, 13]

In [47]:
class Cut:
    def __init__(self, name, long_name, mask, variables=None, plot=True, bins=50, x_range=None, labels=None, xlabel=None, cut_line=None):
        self.name = name
        self.long_name = long_name
        self.mask = mask
        self.variables = variables if variables else []
        self.plot = plot
        self.bins = bins
        self.x_range = x_range
        self.labels = labels if labels else []
        self.xlabel = xlabel
        self.cut_line = cut_line

    def __str__(self):
        return f"Cut: {self.name} ({self.long_name})"

In [48]:
class CutList:
    def __init__(self):
        self.cuts = {}

    def add_cut(self, cut_id, cut):
        self.cuts[cut_id] = cut

    def get_cut(self, cut_id):
        return self.cuts.get(cut_id)
    
    def __str__(self):
        return "\n".join([f"{cut_id}: {cut}" for cut_id, cut in self.cuts.items()])


In [135]:
cut_list = CutList()
cut_list.add_cut(0,
                 Cut(name="0",
                     long_name="Preselection",
                     mask=events_data['B_Z_mass'] > -1, 
                     variables=['B_Z_mass'], 
                     plot=False))

cut_list.add_cut(1,
                 Cut(name="1",
                     long_name="UnOrdered pT",
                     mask=(events_data['B_Z_pt1'] > events_data['B_Z_pt2']) & (events_data['B_J_pt1'] > events_data['B_J_pt2']), 
                     variables=['B_Z_pt1', 'B_Z_pt2', 'B_J_pt1', 'B_J_pt2'], 
                     plot=False))

cut_list.add_cut(2,
                    Cut(name="2",
                        long_name="Electron Trigger",
                        mask=events_data['B_Z_TriggerPath'] == 1, 
                        variables='B_Z_TriggerPath', 
                        plot=False))

cut_list.add_cut(3, 
                    Cut(name="3",
                        long_name="Soft Muons",
                        mask=(events_data['B_J_soft1'] == 1) & (events_data['B_J_soft2'] == 1),
                        variables=['B_J_soft1', 'B_J_soft2'],
                        plot=False))

cut_list.add_cut(4,
                    Cut(name="4",
                        long_name="Electron Trigger Enforce",
                        mask=(events_data['B_Z_pt1'] > 27),
                        variables=['B_Z_pt1'],
                        plot=True,
                        bins=100,
                        x_range=(0, 100),
                        labels=["Z_pt1"],
                        xlabel="Z pt1 (GeV)"))

cut_list.add_cut(5,
                    Cut(name="5",
                        long_name="Dilepton Vtx > 1%",
                        mask=(events_data['B_J_VtxProb'] > 0.01) & (events_data['B_Z_VtxProb'] > 0.01),
                        variables=['B_J_VtxProb', 'B_Z_VtxProb'],
                        plot=True,
                        bins=100,
                        x_range=(0, 1),
                        labels=["J_VtxProb", "Z_VtxProb"],
                        xlabel="Vtx Prob"))

cut_list.add_cut(6,
                    Cut(name="6",
                        long_name="FourL Vtx > 1%",
                        mask=(events_data['FourL_VtxProb'] > 0.01),
                        variables=['FourL_VtxProb'],
                        plot=True,
                        bins=100,
                        x_range=(0, 1),
                        labels=["FourL_VtxProb"],
                        xlabel="Vtx Prob"))

cut_list.add_cut(7,
                    Cut(name="7",
                        long_name="Detector acceptance",
                        mask=(abs(events_data['B_Z_eta1']) < 2.4) & (abs(events_data['B_Z_eta2']) < 2.4) & (abs(events_data['B_J_eta1']) < 2.5) & (abs(events_data['B_J_eta2']) < 2.5),
                        variables=['B_Z_eta1', 'B_Z_eta2', 'B_J_eta1', 'B_J_eta2'],
                        plot=True,
                        bins=100,
                        x_range=(-3, 3),
                        labels=["Z_eta1", "Z_eta2", "J_eta1", "J_eta2"],
                        xlabel="Eta"))

cut_list.add_cut(8,
                    Cut(name="8",
                        long_name="J mass",
                        mask=(events_data['B_J_mass'] > 3.0) & (events_data['B_J_mass'] < 3.2),
                        variables=['B_J_mass'],
                        plot=True,
                        bins=100,
                        x_range=(2.5, 3.5),
                        labels=["J_mass"],
                        xlabel="J mass (GeV)"))

cut_list.add_cut(9,
                    Cut(name="9",
                        long_name="Z mass",
                        mask=(events_data['B_Z_mass'] > 70) & (events_data['B_Z_mass'] < 110),
                        variables=['B_Z_mass'],
                        plot=True,
                        bins=100,
                        x_range=(60, 120),
                        labels=["Z_mass"],
                        xlabel="Z mass (GeV)"))

cut_list.add_cut(10,
                    Cut(name="10",
                        long_name="eleID High pT",
                        mask=(events_data['B_Z_mvaIsoWP90_1']),
                        variables=['B_J_pt1'],
                        plot=True,
                        bins=100,
                        x_range=(0, 100),
                        labels=["J_pt1"],
                        xlabel="J pt1 (GeV)"))

cut_list.add_cut(11,
                    Cut(name="11",
                        long_name="eleID Low pT",
                        mask=(events_data['B_Z_mvaIsoWP90_2']),
                        variables=['B_J_pt2'],
                        plot=True,
                        bins=100,
                        x_range=(0, 100),
                        labels=["J_pt2"],
                        xlabel="J pt2 (GeV)"))

cut_list.add_cut(12,
                    Cut(name="12",
                        long_name="eleID either",
                        mask=(events_data['B_Z_mvaIsoWP90_1']) & (events_data['B_Z_mvaIsoWP90_2']),
                        variables=['B_J_pt1', 'B_J_pt2'],
                        plot=True,
                        bins=100,
                        x_range=(0, 100),
                        labels=["J_pt1", "J_pt2"],
                        xlabel="J pt (GeV)"))

cut_list.add_cut(13,
                    Cut(name="13",
                        long_name="FourL mass",
                        mask=(events_data['FourL_mass'] > 112) & (events_data['FourL_mass'] < 162),
                        variables=['FourL_mass'],
                        plot=True,
                        bins=100,
                        x_range=(100, 200),
                        labels=["FourL_mass"],
                        xlabel="FourL mass (GeV)"))


In [136]:
for key, item in cut_list.cuts.items():
    print(key, item)

0 Cut: 0 (Preselection)
1 Cut: 1 (UnOrdered pT)
2 Cut: 2 (Electron Trigger)
3 Cut: 3 (Soft Muons)
4 Cut: 4 (Electron Trigger Enforce)
5 Cut: 5 (Dilepton Vtx > 1%)
6 Cut: 6 (FourL Vtx > 1%)
7 Cut: 7 (Detector acceptance)
8 Cut: 8 (J mass)
9 Cut: 9 (Z mass)
10 Cut: 10 (eleID High pT)
11 Cut: 11 (eleID Low pT)
12 Cut: 12 (eleID either)
13 Cut: 13 (FourL mass)


In [137]:
myorder = [0, 1, 3, 12, 7, 5, 8, 9, 6, 13]

In [161]:
class CutAnalysis:
    def __init__(self, events, cut_list, save_path):
        self.events = events
        self.cut_list = cut_list
        self.save_path = save_path

    def get_stats(self, data):
        array = data.B_J_mass
        nevents = len(array[ak.num(array, axis=1) > 0])
        ncandidates = ak.sum(ak.num(array, axis=1))
        return nevents, ncandidates
    
    def prepare_masks(self, myorder):
        masks_list = []
        cutobj_list = []
        ncandidates_list = []
        nevents_list = []

        aggregate_mask = self.cut_list.cuts[0].mask
        
        for cut_id in myorder:
            cut = self.cut_list.cuts[cut_id]
            aggregate_mask = aggregate_mask & cut.mask
            masks_list.append(aggregate_mask)
            cutobj_list.append(cut)
            ncandidates, nevents = self.get_stats(self.events[aggregate_mask])
            ncandidates_list.append(ncandidates)
            nevents_list.append(nevents)

            print(f"Cut {cut_id}: {cut.long_name} - {nevents} events, {ncandidates} candidates")


        summary_dict = {
            "cut_id": myorder,
            "cut_name": [self.cut_list.cuts[cut_id].long_name for cut_id in myorder],
            "nevents": nevents_list,
            "ncandidates": ncandidates_list,
            "mask": masks_list,
            "cutobj": cutobj_list
        }

        return summary_dict

    def plot_preselection(self):
        for cut_id in self.cut_list.cuts:
            cut = self.cut_list.get_cut(cut_id)
            if not cut.plot:
                print(f"Cut {cut_id} does not have a plot.")
                continue
            
            print(f"Plotting cut {cut_id}: {cut.long_name}")
            if len(cut.variables) == 0:
                print(f"Cut {cut_id} does not have any variables to plot.")
                continue

            nbins = cut.bins
            xrange = cut.x_range
            labels = cut.labels
            xlabel = cut.xlabel
            variables = cut.variables
            unit = 'GeV' if 'eta' not in xlabel else ''
            
            plt.figure()
            for i, variable in enumerate(variables):
                plt.hist(ak.flatten(self.events[variable][cut.mask]), bins=nbins, range=xrange, label=labels[i], alpha=0.5)
            plt.xlabel(cut.labels[0])
            plt.ylabel("Events")
            plt.title(f"{cut.long_name}")
            plt.legend()
            plt.savefig(f"{self.save_path}/pre_{cut.long_name}.png")
            plt.close()

    def plot_summary_at(self, cut_id, summary):
        
        idx_after_plot = summary["cut_id"].index(cut_id)
        name_after_plot = summary["cut_name"][idx_after_plot]
        n_eve_after = summary["nevents"][idx_after_plot]
        print(f"Next cut: {idx_after_plot} {name_after_plot} ({n_eve_after})")

        cut_obj = summary["cutobj"][idx_after_plot]

        if not cut_obj.plot:
            print(f"Cut {cut_id} does not have a plot.")
            return
        
        variables = cut_obj.variables
        nbins = cut_obj.bins
        xlow, xhigh = cut_obj.x_range
        labels = cut_obj.labels
        xlabel = cut_obj.xlabel
        unit = 'GeV' if 'eta' not in xlabel else ''
        fileName = f"{self.save_path}/cut_{name_after_plot}.png"

        idx_at_plot = idx_after_plot - 1
        name_at_plot = summary["cut_name"][idx_at_plot]
        n_eve_at_plot = summary["nevents"][idx_at_plot]
        mask = summary["mask"][idx_at_plot]
        print(f"Drawn at: {idx_at_plot} {name_at_plot} ({n_eve_at_plot})")

        for i, variable in enumerate(variables):
            plt.hist(ak.flatten(self.events[variable][mask]), bins=nbins, range=(xlow, xhigh), label=labels[i])

        plt.text(0.5, 0.5, f"Drawn at: {name_at_plot} ({n_eve_at_plot})", fontsize=12, transform=plt.gca().transAxes)
        plt.text(0.5, 0.45, f"Next cut: {name_after_plot} ({n_eve_after})", fontsize=12, transform=plt.gca().transAxes)

        plt.xlabel(xlabel)
        plt.ylabel(f"Counts / {(xhigh-xlow)/nbins:.3f} {unit}")
        plt.legend()
        plt.tight_layout()
        plt.savefig(fileName)
        plt.close()

    def plot_for_single_variable(self, variable, cuts, summary):
        # find the cut which has the variable
        cut_id = None
        for cut_id in self.cut_list.cuts:
            if variable in self.cut_list.cuts[cut_id].variables:
                break

        if cut_id is None:
            print(f"Variable {variable} not found in any cuts.")
            return
        
        print(f"Plotting variable {variable} for cuts: {cuts}")
        
        nbins = self.cut_list.cuts[cut_id].bins
        xlow, xhigh = self.cut_list.cuts[cut_id].x_range
        xlabel = self.cut_list.cuts[cut_id].xlabel
        unit = 'GeV' if 'eta' not in variable else ''

        cut_names = [self.cut_list.get_cut(cut_id).long_name for cut_id in cuts]
        print(f"Plotting variable {variable} for cuts: {cut_names}")

        plt.figure()
        for cut_id in cuts:
            local_idx = summary["cut_id"].index(cut_id)
            mask = summary["mask"][local_idx]
            nevents = summary["nevents"][local_idx]
            label = f"{summary['cut_name'][local_idx]}({nevents})"
            plt.hist(ak.flatten(self.events[variable][mask]), bins=50, label=label, alpha=0.5)

        plt.xlabel(xlabel)
        plt.ylabel(f"Counts / {(xhigh-xlow)/nbins:.3f} {unit}")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"{self.save_path}/var_{variable}.png")
        plt.close()

In [165]:
cut_analysis = CutAnalysis(events_data, cut_list, save_path)
summary = cut_analysis.prepare_masks(myorder)

Cut 0: Preselection - 439851 events, 253773 candidates
Cut 1: UnOrdered pT - 439677 events, 253662 candidates
Cut 3: Soft Muons - 86821 events, 69235 candidates
Cut 12: eleID either - 11399 events, 10410 candidates
Cut 7: Detector acceptance - 10578 events, 9686 candidates
Cut 5: Dilepton Vtx > 1% - 9782 events, 8983 candidates
Cut 8: J mass - 283 events, 283 candidates
Cut 9: Z mass - 283 events, 283 candidates
Cut 6: FourL Vtx > 1% - 259 events, 259 candidates
Cut 13: FourL mass - 100 events, 100 candidates


In [166]:
summary_df = {k: v for k, v in summary.items() if k not in ['mask', 'cutobj']}
pd.DataFrame(summary_df)

Unnamed: 0,cut_id,cut_name,nevents,ncandidates
0,0,Preselection,439851,253773
1,1,UnOrdered pT,439677,253662
2,3,Soft Muons,86821,69235
3,12,eleID either,11399,10410
4,7,Detector acceptance,10578,9686
5,5,Dilepton Vtx > 1%,9782,8983
6,8,J mass,283,283
7,9,Z mass,283,283
8,6,FourL Vtx > 1%,259,259
9,13,FourL mass,100,100


In [168]:
cut_analysis.plot_preselection()

Cut 0 does not have a plot.
Cut 1 does not have a plot.
Cut 2 does not have a plot.
Cut 3 does not have a plot.
Plotting cut 4: Electron Trigger Enforce
Plotting cut 5: Dilepton Vtx > 1%
Plotting cut 6: FourL Vtx > 1%
Plotting cut 7: Detector acceptance
Plotting cut 8: J mass
Plotting cut 9: Z mass
Plotting cut 10: eleID High pT
Plotting cut 11: eleID Low pT
Plotting cut 12: eleID either
Plotting cut 13: FourL mass


In [169]:
cut_analysis.plot_summary_at(8, summary)

Next cut: 6 J mass (283)
Drawn at: 5 Dilepton Vtx > 1% (9782)


In [170]:
cuts_to_show = [0, 1, 3, 12, 7, 5, 8, 9, 6, 13]
cut_analysis.plot_for_single_variable("B_J_mass", cuts_to_show, summary)

Plotting variable B_J_mass for cuts: [0, 1, 3, 12, 7, 5, 8, 9, 6, 13]
Plotting variable B_J_mass for cuts: ['Preselection', 'UnOrdered pT', 'Soft Muons', 'eleID either', 'Detector acceptance', 'Dilepton Vtx > 1%', 'J mass', 'Z mass', 'FourL Vtx > 1%', 'FourL mass']


In [147]:
summary

{'cut_id': [0, 1, 3, 12, 7, 5, 8, 9, 6, 13],
 'cut_name': ['Preselection',
  'UnOrdered pT',
  'Soft Muons',
  'eleID either',
  'Detector acceptance',
  'Dilepton Vtx > 1%',
  'J mass',
  'Z mass',
  'FourL Vtx > 1%',
  'FourL mass'],
 'nevents': [439851, 439677, 86821, 11399, 10578, 9782, 283, 283, 259, 100],
 'ncandidates': [253773, 253662, 69235, 10410, 9686, 8983, 283, 283, 259, 100],
 'mask': [<Array [[True, True, True], ..., [True]] type='253773 * [var * bool[paramet...'>,
  <Array [[True, True, True], [True], ..., [True]] type='253773 * var * bool'>,
  <Array [[False, False, False], ..., [False]] type='253773 * var * bool'>,
  <Array [[False, False, False], ..., [False]] type='253773 * var * bool'>,
  <Array [[False, False, False], ..., [False]] type='253773 * var * bool'>,
  <Array [[False, False, False], ..., [False]] type='253773 * var * bool'>,
  <Array [[False, False, False], ..., [False]] type='253773 * var * bool'>,
  <Array [[False, False, False], ..., [False]] type='25

In [None]:
class Analysis:
    """
    Main class to perform the analysis.
    """
    def __init__(self, events, savepath):
        self.events = events
        self.particle_selection = ParticleSelection(events)
        # self.summary_dict, self.summary_table = self.get_summary_of_cuts(cuts_to_show)
        # self.plotter = Plotter(savepath)


    def get_stats(self, data):
        """
        Get the number of candidates and events in the given data.
        """
        array = data.B_Z_mass
        nevents = len(array[ak.num(array, axis=1) > 0])
        ncandidates = ak.sum(ak.num(array, axis=1))
        return ncandidates, nevents
    
    def get_summary_of_cuts(self, cut_order):
        """
        Generate a summary of cuts applied to the events.
        """
        key_list = []
        ncandidates_list = []
        nevents_list = []
        aggMask_list = []
        var_list = []

        # give a warning if there are duplicate cuts
        if len(cut_order) != len(set(cut_order)):
            print("Warning: Duplicate cuts detected. Please remove duplicates.")
            raise ValueError

        # start with the first cut
        agg_mask = self.particle_selection.cut_list["0"]["mask"]

        for i in cut_order:
            mask = self.particle_selection.cut_list[str(i)]["mask"]
            agg_mask = agg_mask & mask
            ncandidates, nevents = self.get_stats(self.events[agg_mask])

            key_list.append(self.particle_selection.cut_list[str(i)]["name"])
            ncandidates_list.append(ncandidates)
            nevents_list.append(nevents)
            aggMask_list.append(agg_mask)
            var_list.append(self.particle_selection.cut_list[str(i)].get("var", None))

        summary_dict = {
            "Cut": key_list, 
            "Candidates": ncandidates_list, 
            "Events": nevents_list, 
            "Aggregated mask": aggMask_list, 
            "Var": var_list
        }
        
        # summary_table = pd.DataFrame({key: summary_dict[key] for key in ["Cut", "Candidates", "Events"]})

        # return summary_dict, summary_table
        return summary_dict
    
    def get_view_at(self, mycut, summary_dict):
        """
        Get the view of the data at a specific cut.
        """
        mycut_name = self.particle_selection.cut_list[str(mycut)]['name']
        cut_index = summary_dict['Cut'].index(mycut_name)
        n_events_after_cut = summary_dict['Events'][cut_index]
        view_index = cut_index - 1
        view_index_name = summary_dict['Cut'][view_index]
        view_index_mask = summary_dict['Aggregated mask'][view_index]
        n_events_before_cut = summary_dict['Events'][view_index]
        cut_of_interest = self.events[view_index_mask]
        text_array = [view_index_name, mycut_name, n_events_before_cut, n_events_after_cut]
        return cut_of_interest, text_array

In [None]:
analysis_data = Analysis(events_data, save_path)
summary_dict = analysis_data.get_summary_of_cuts(eleID_both)
summary_table = pd.DataFrame({key: summary_dict[key] for key in ["Cut", "Candidates", "Events"]})
count = summary_table[summary_table['Cut'] == 'FourL mass']['Events'].values[0]

print(count)
print(summary_table)

In [None]:
cut_of_interest, text_array = analysis_data.get_view_at(5, summary_dict)

In [None]:
Plotter(save_path).plot_dilepton_vertexing(cut_of_interest, text_array)
    #.plot_dilepton_vertexing(cut_of_interest, text_array)

In [None]:
class Plotter:
    """
    A class to handle all plotting functions.
    """
    def __init__(self, savepath):
        self.savepath = savepath

    def make_hist(self, nbins, xlow, xhigh, values, labels, lines, fileName, xlabel, text_array):
        """
        Create and save a histogram.
        """
        cut_name_at_plot, cut_name_after_plot, n_eve_at_plot, n_eve_after = text_array
        unit = 'GeV' if 'eta' not in xlabel else ''
        
        plt.figure(figsize=(8, 8))
        for i, value in enumerate(values):
            plt.hist(value, bins=nbins, range=(xlow, xhigh), label=labels[i], alpha=0.5)
        for line in lines:
            plt.axvline(x=line, color='r')

        plt.text(0.5, 0.5, f"Drawn at: {cut_name_at_plot} ({n_eve_at_plot})", fontsize=12, transform=plt.gca().transAxes)
        plt.text(0.5, 0.45, f"Next cut: {cut_name_after_plot} ({n_eve_after})", fontsize=12, transform=plt.gca().transAxes)

        plt.xlabel(xlabel)
        plt.ylabel(f"Counts / {(xhigh-xlow)/nbins:.2f} {unit}")
        plt.legend(fontsize=13)
        plt.tight_layout()
        plt.savefig(f"{self.savepath}/{fileName}.png")
        plt.close()

    def make_hist2D(self, xvar, yvar, xlabel, ylabel, fileName, text_array):
        """
        Create and save a 2D histogram.
        """
        plt.figure(figsize=(8, 8))
        plt.hist2d(xvar, yvar, bins=(50, 50), range=((-3, 3), (0, 50)))
        plt.colorbar()
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.tight_layout()
        plt.savefig(f"{self.savepath}/{fileName}.png")
        plt.close()

    # Add other plotting methods here (plot_dilepton_vertexing, plot_dielectron_inv_mass, etc.)
    def plot_dilepton_vertexing(self, cut_of_interest, text_array):
        J_vtx_prob = ak.flatten(cut_of_interest['B_J_VtxProb'])
        Z_vtx_prob = ak.flatten(cut_of_interest['B_Z_VtxProb'])

        nbins, xlow, xhigh = 100, 0, 1
        fileName = "vtx_prob_dilepton"
        values = [J_vtx_prob, Z_vtx_prob]
        labels = ["J", "Z"]
        lines = [0.01]
        xlabel = "Dilepton Vtx Prob"

        self.make_hist(nbins, xlow, xhigh, values, labels, lines, fileName, xlabel, text_array)

    def plot_dielectron_inv_mass(cut_of_interest, text_array, isFitData=False):
        dielectron_mass = ak.flatten(cut_of_interest['B_Z_mass'])

        if isFitData:
            nameExt = 'fitData'
            lines = []
        else:
            nameExt = ''
            lines = [8, 10]

        nbins, xlow, xhigh = 20, 0, 12
        fileName = f"Z_mass {nameExt}"
        values = [dielectron_mass]
        labels = ["Z_mass"]
        xlabel = "Dielectron inv mass [GeV]"

        self.make_hist(nbins, xlow, xhigh, values, labels, lines, fileName, xlabel, text_array)

    def plot_dimuon_inv_mass(cut_of_interest, text_array, isFitData=False):
        dimuon_mass = ak.flatten(cut_of_interest['B_J_mass'])

        if isFitData:
            nameExt = 'fitData'
            lines = []
        else:
            nameExt = '' 
            lines = [70, 110]

        nbins, xlow, xhigh = 20, 60, 120
        fileName = f"J_mass {nameExt}"
        values = [dimuon_mass]
        labels = ["J_mass"]
        xlabel = "Dimuon inv mass [GeV]"

        self.make_hist(nbins, xlow, xhigh, values, labels, lines, fileName, xlabel, text_array)

    def plot_fourL_vertexing(cut_of_interest, text_array):
        FourL_vtx_prob = ak.flatten(cut_of_interest['FourL_VtxProb'])

        nbins, xlow, xhigh = 100, 0, 1
        fileName = "vtx_prob_fourL"
        values = [FourL_vtx_prob]
        labels = ["FourL"]
        lines = [0.01]
        xlabel = "FourL Vtx Prob"

        self.make_hist(nbins, xlow, xhigh, values, labels, lines, fileName, xlabel, text_array)

    def plot_muon_pt(cut_of_interest, text_array):
        muon_pt1 = ak.flatten(cut_of_interest['B_J_pt1'])
        muon_pt2 = ak.flatten(cut_of_interest['B_J_pt2'])

        nbins, xlow, xhigh = 75, 0, 20
        fileName = "Mu_pt"
        values = [muon_pt1, muon_pt2]
        labels = ["Mu_pt1", "Mu_pt2"]
        lines = [3.0]
        xlabel = "Mu pT [GeV]"

        self.make_hist(nbins, xlow, xhigh, values, labels, lines, fileName, xlabel, text_array)

    def plot_electron_pt(cut_of_interest, text_array):
        ele_pt1 = ak.flatten(cut_of_interest['B_Z_pt1'])
        ele_pt2 = ak.flatten(cut_of_interest['B_Z_pt2'])

        nbins, xlow, xhigh = 50, 0, 50
        fileName = "Ele_pt"
        values = [ele_pt1, ele_pt2]
        labels = ["Ele_pt1", "Ele_pt2"]
        lines = [5.0]
        xlabel = "Ele pT [GeV]"

        self.make_hist(nbins, xlow, xhigh, values, labels, lines, fileName, xlabel, text_array)

    def plot_muon_eta(cut_of_interest, text_array):
        muon_eta1 = ak.flatten(cut_of_interest['B_J_eta1'])
        muon_eta2 = ak.flatten(cut_of_interest['B_J_eta2'])

        nbins, xlow, xhigh = 30, -3, 3
        fileName = "Mu_eta"
        values = [muon_eta1, muon_eta2]
        labels = ["Mu_eta1", "Mu_eta2"]
        lines = [2.4]
        xlabel = "Mu eta"

        self.make_hist(nbins, xlow, xhigh, values, labels, lines, fileName, xlabel, text_array)

    def plot_electron_eta(cut_of_interest, text_array):
        ele_eta1 = ak.flatten(cut_of_interest['B_Z_eta1'])
        ele_eta2 = ak.flatten(cut_of_interest['B_Z_eta2'])

        nbins, xlow, xhigh = 30, -3, 3
        fileName = "Ele_eta"
        values = [ele_eta1, ele_eta2]
        labels = ["Ele_eta1", "Ele_eta2"]
        lines = [2.5]
        xlabel = "Ele eta"

        self.make_hist(nbins, xlow, xhigh, values, labels, lines, fileName, xlabel, text_array)

    def plot_electron_pt_vs_eta(cut_of_interest, text_array):
        xvar = ak.flatten(cut_of_interest['B_Z_eta1']).tolist()
        yvar = ak.flatten(cut_of_interest['B_Z_pt1']).tolist()

        fileName = "Ele1_eta_pt"
        xlabel = "Ele eta1"
        ylabel = "Ele pT1"
        self.make_hist2D(xvar, yvar, xlabel, ylabel, fileName, text_array)

        xvar = ak.flatten(cut_of_interest['B_Z_eta2']).tolist()
        yvar = ak.flatten(cut_of_interest['B_Z_pt2']).tolist()

        fileName = "Ele2_eta_pt"
        xlabel = "Ele eta2"
        ylabel = "Ele pT2"
        self.make_hist2D(xvar, yvar, xlabel, ylabel, fileName, text_array)

    def plot_fourL_inv_mass(cut_of_interest, text_array):
        fourL_mass = ak.flatten(cut_of_interest['FourL_mass'])

        nameExt = text_array[0] if 'eleID' in text_array[0] else ''

        nbins, xlow, xhigh = 10, 70, 170
        fileName = f"FourL_mass_with {nameExt} cut"
        values = [fourL_mass]
        labels = ["FourL_mass"]
        lines = [112, 162]
        xlabel = "FourL inv mass [GeV]"

        self.make_hist(nbins, xlow, xhigh, values, labels, lines, fileName, xlabel, text_array)

In [None]:


class Analysis:
    """
    Main class to perform the analysis.
    """
    def __init__(self, events, savepath):
        self.events = events
        self.particle_selection = ParticleSelection(events)
        self.plotter = Plotter(savepath)


    def get_count(self, data):
        """
        Count the number of candidates and events in the given data.
        """
        array = data.B_J_mass
        nevents = len(array[ak.num(array, axis=1) > 0])
        ncandidates = ak.sum(ak.num(array, axis=1))
        return ncandidates, nevents
    

    def get_summary_of_cuts(self, cut_order):
        """
        Generate a summary of cuts applied to the events.
        """
        key_list = []
        ncandidates_list = []
        nevents_list = []
        aggMask_list = []
        var_list = []

        agg_mask = self.particle_selection.cut_list["0"]["mask"]

        for i in cut_order:
            mask = self.particle_selection.cut_list[str(i)]["mask"]
            agg_mask = agg_mask & mask
            ncandidates, nevents = self.get_count(self.events[agg_mask])

            key_list.append(self.particle_selection.cut_list[str(i)]["name"])
            ncandidates_list.append(ncandidates)
            nevents_list.append(nevents)   
            aggMask_list.append(agg_mask)
            var_list.append(self.particle_selection.cut_list[str(i)].get("var", None))

        summary_dict = {
            "Cut": key_list, 
            "Candidates": ncandidates_list, 
            "Events": nevents_list, 
            "Aggregated mask": aggMask_list, 
            "Var": var_list
        }
        
        summary_table = pd.DataFrame({key: summary_dict[key] for key in ["Cut", "Candidates", "Events"]})

        return summary_dict, summary_table

    def get_view_at(self, mycut, summary_dict):
        """
        Get the view of the data at a specific cut.
        """
        mycut_name = self.particle_selection.cut_list[str(mycut)]['name']
        cut_index = summary_dict['Cut'].index(mycut_name)
        n_events_after_cut = summary_dict['Events'][cut_index]
        view_index = cut_index - 1
        view_index_name = summary_dict['Cut'][view_index]
        view_index_mask = summary_dict['Aggregated mask'][view_index]
        n_events_before_cut = summary_dict['Events'][view_index]
        cut_of_interest = self.events[view_index_mask]
        text_array = [view_index_name, mycut_name, n_events_before_cut, n_events_after_cut]
        return cut_of_interest, text_array

    def apply_cut_progression(self, cut_progression):
        """
        Apply a series of cuts and create plots at each stage.
        """
        summary_dict, summary_table = self.get_summary_of_cuts(self.events, cut_progression)
        print("Summary of cuts")
        print(summary_table)
        self.show_plots_at_each_cut(summary_dict)
        return summary_dict, summary_table

    def show_plots_at_each_cut(self, summary_dict):
        """
        Create plots at each cut stage.
        """
        # Implement the logic to create plots at each cut stage
        pass

    def get_count(self, elePt_low, elePt_high, eleEta, ymass_upper):
        """
        Apply cuts and return the count of events passing all cuts.
        """
        self.particle_selection.cut_list = self.particle_selection.make_cut_list(elePt_low, elePt_high, eleEta, ymass_upper)
        eleID_both = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13]
        summary_dict, summary_table = self.apply_cut_progression(eleID_both)
        count = summary_table[summary_table['Cut'] == 'FourL mass']['Events'].values[0]
        return count

def get_FOM(events_data, events_mc, savepath_data, savepath_mc, elePt_low, elePt_high, eleEta, ymass_upper):
    """
    Calculate the Figure of Merit.
    """
    analysis_data = Analysis(events_data, savepath_data)
    analysis_mc = Analysis(events_mc, savepath_mc)
    
    nb = analysis_data.get_count(elePt_low, elePt_high, eleEta, ymass_upper)
    nEff = analysis_mc.get_count(elePt_low, elePt_high, eleEta, ymass_upper) / 100
    FOM = nEff / np.sqrt(nb)
    return nEff, nb, FOM

def load_config(config_path):
    """
    Load configuration from a YAML file.
    """
    try:
        with open(config_path, 'r') as config_file:
            return yaml.safe_load(config_file)
    except FileNotFoundError:
        print(f"Error: Configuration file '{config_path}' not found.")
        sys.exit(1)
    except yaml.YAMLError as e:
        print(f"Error parsing YAML configuration file: {e}")
        sys.exit(1)

def load_events(file_path, schema_class, entry_stop):
    """
    Load events from a ROOT file with error handling.
    """
    try:
        events = NanoEventsFactory.from_root(
            {file_path: "ntuple"}, 
            schemaclass=schema_class, 
            entry_stop=entry_stop
        ).events()
        return events
    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.")
        return None
    except Exception as e:
        print(f"Error loading events from '{file_path}': {e}")
        return None

def main(config_path='config.yaml', run_data=True, run_mc=True):
    # Load configuration
    config = load_config(config_path)

    events_data = None
    events_mc = None

    # Load events
    if run_data:
        events_data = load_events(config['data_path'], BaseSchema, config['entry_stop'])
        if events_data is None:
            print("Failed to load data events. Skipping data analysis.")
            run_data = False

    if run_mc:
        events_mc = load_events(config['mc_path'], BaseSchema, config['entry_stop'])
        if events_mc is None:
            print("Failed to load MC events. Skipping MC analysis.")
            run_mc = False

    if not run_data and not run_mc:
        print("No valid data or MC events loaded. Exiting.")
        sys.exit(1)

    # Select columns
    if run_data:
        events_data = events_data[config['columns']].compute()
    if run_mc:
        events_mc = events_mc[config['columns']].compute()

    # Create DataFrame to store results
    df = pd.DataFrame(columns=['elePt_low', 'elePt_high', 'eleEta', 'ymass_upper', 'Efficiency', 'Background', 'FOM'])

    # Loop over different cut values
    for elePt_low in config['cut_parameters']['elePt_low']:
        for elePt_high in config['cut_parameters']['elePt_high']:
            for eleEta in config['cut_parameters']['eleEta']:
                for ymass_upper in config['cut_parameters']['ymass_upper']:
                    savepath_data = f"{config['savepath_data_base']}el_{elePt_low}_eh_{elePt_high}_eE_{eleEta}_yup_{ymass_upper}"
                    savepath_mc = f"{config['savepath_mc_base']}el_{elePt_low}_eh_{elePt_high}_eE_{eleEta}_yup_{ymass_upper}"

                    os.makedirs(savepath_data, exist_ok=True)
                    os.makedirs(savepath_mc, exist_ok=True)

                    eff, nb, FOM = 0, 0, 0
                    
                    if run_data and run_mc:
                        eff, nb, FOM = get_FOM(events_data, events_mc, savepath_data, savepath_mc, elePt_low, elePt_high, eleEta, ymass_upper)
                    elif run_data:
                        analysis_data = Analysis(events_data, savepath_data)
                        nb = analysis_data.get_count(elePt_low, elePt_high, eleEta, ymass_upper)
                    elif run_mc:
                        analysis_mc = Analysis(events_mc, savepath_mc)
                        eff = analysis_mc.get_count(elePt_low, elePt_high, eleEta, ymass_upper) / 100

                    df.loc[len(df)] = {'elePt_low': elePt_low, 'elePt_high': elePt_high, 'eleEta': eleEta, 'ymass_upper': ymass_upper, 'Efficiency': eff, 'Background': nb, 'FOM': FOM}
                    print(f"elePt_low: {elePt_low}, elePt_high: {elePt_high}, eleEta: {eleEta}, ymass_upper: {ymass_upper}")
                    print(f"Efficiency: {eff}")
                    print(f"Background: {nb}")
                    print(f"FOM: {FOM}")
                    print(df)
                    print("\n")

    print(df)
    df.to_csv(config['output_file'], index=False)

# if __name__ == "__main__":
#     parser = argparse.ArgumentParser(description="Run particle physics analysis on data and/or MC events.")
#     parser.add_argument("--config", default="config.yaml", help="Path to the configuration file")
#     parser.add_argument("--data", action="store_true", help="Run analysis on data events")
#     parser.add_argument("--mc", action="store_true", help="Run analysis on MC events")
#     args = parser.parse_args()

#     if not args.data and not args.mc:
#         print("Error: You must specify at least one of --data or --mc")
#         sys.exit(1)

    # main(config_path=args.config, run_data=args.data, run_mc=args.mc)

In [None]:
main(config_path="config.yaml", run_data=True, run_mc=False)