In [1]:
import uproot
import matplotlib.pyplot as plt
import awkward as ak
import numpy as np
import glob
import pandas as pd
from utils.readData import get2HDMaEvents, getIDMevents, getIDMnanoEvents
from utils.applyCuts import applyCuts, getLeptons, getDM, cutNleptons, cutDeltaPhiPTjetPTmiss, applyCuts, cutNjets, getJets
from collections import Counter

In [12]:
# Load the merged HiggsDNA output of the 2HDMa analysis ("nominal" per the file name)
# as an awkward array; presumably one record per event — confirm against the producer.
data = ak.from_parquet('/vols/cms/emc21/idmStudy/HiggsDNA/HDMa_Analysis/merged_nominal.parquet')
data


In [26]:
# Get electrons
# Events with at least one selected electron / muon. -999 is the dummy value written for
# missing objects (DUMMY_VALUE in the HiggsDNA tagger), so `x_1_pt != -999` means the
# leading object exists. The comparison is already a boolean mask.
has_electron = data.electron_1_pt != -999
has_muon = data.muon_1_pt != -999
electrons = data[has_electron]
muons = data[has_muon]
# Events with at least one lepton of either flavour. Concatenating `electrons` and
# `muons` would list every event containing both an electron and a muon twice.
leptons = data[has_electron | has_muon]
leptons.muon_1_pt

In [None]:
BP = 8
process_name = 'h2h2lPlM_lem'
CMSSW_version = 'CMSSW_10_6_19'
run_name = f'{process_name}_BP{BP}'

# Glob pattern for the NanoAOD files produced for this benchmark point.
file_pattern = f'/vols/cms/emc21/idmStudy/myFiles/gridpacks/{process_name}_{CMSSW_version}/{run_name}/fall17Data/nanoAOD*.root'
files = glob.glob(file_pattern)
if not files:
    # Fail with a clear message instead of an IndexError from indexing an empty list.
    raise FileNotFoundError(f'No NanoAOD files match {file_pattern}')

# Lexicographically last file (same as sorted(files, reverse=True)[0]), pointing uproot
# at its Events tree.
filename = max(files) + ':Events;1'
print(filename)

In [None]:
/vols/cms/emc21/idmStudy/myFiles/gridpacks/h2h2lPlM_lem_CMSSW_10_6_19/h2h2lPlM_lem_BP8/fall17Data/nanoAOD*.root
/vols/cms/emc21/idmStudy/myFiles/gridpacks/h2h2lPlM_lem_CMSSW_10_6_19/h2h2lPlM_lem_BP8/fall17Data/nanoAOD*.root

In [None]:
# Load the IDM sample (benchmark point 8, process 'h2h2lPlM_lem') and a 2HDMa sample via
# utils.readData. The 2HDMa arguments are presumably (mH, mA) = (500, 400), matching the
# plot label '2HDMa, mH=500, mA=400' used further down — TODO confirm.
events_idm = getIDMevents(8, 'h2h2lPlM_lem')
events_hdma = get2HDMaEvents(500, 400)

In [None]:
jets_idm = getJets(events_idm)
# partonFlavour 21 is the PDG ID of the gluon, so this keeps gluon-flavoured jets.
mask = jets_idm.GenJet_partonFlavour == 21
glu_jets = jets_idm[mask]
glu_pt = ak.flatten(glu_jets.GenJet_pt)

_ = plt.hist(glu_pt, density=True, histtype='step', bins=100, range=(0, 300))
plt.xlabel('PT (GeV)')
# density=True normalises the histogram to unit area
plt.ylabel('Density')
plt.title('GenJet PT for gluon jets (partonFlavour == 21), IDM BP8')

In [None]:
jets_idm = getJets(events_idm)
partonFlavour = ak.flatten(jets_idm.GenJet_partonFlavour)
# Tally each parton flavour once and iterate (flavour, count) pairs directly.
for pdgid, count in Counter(partonFlavour).items():
    print(f"pdgId = {pdgid}, count = {count}")

In [None]:
# Apply the selection from utils.applyCuts.applyCuts to the IDM sample.
# NOTE(review): 35 appears to be the DM pdgId (cutDeltaPhiPTjetPTmiss is called with
# dm_pdgId=35 for the same sample) and 1 presumably a jet-multiplicity setting — confirm.
evs = applyCuts(events_idm, 35, 1)

In [None]:
branches = ['GenJet_eta', 'GenJet_hadronFlavour', 'GenJet_mass', 'GenJet_partonFlavour', 'GenJet_phi', 'GenJet_pt']
jets = events_idm[branches]
# Jets with a non-zero parton flavour (presumably 0 means "not matched to a parton" — confirm).
mask = jets.GenJet_partonFlavour != 0
print(mask)
# Number of non-zero-flavour jets in each event
non0_jets = ak.sum(mask, axis=1)
print(non0_jets)
print(ak.min(non0_jets))
# The keys here are jet multiplicities, not PDG IDs.
for n_jets, count in Counter(non0_jets).items():
    print(f"n_jets(partonFlavour != 0) = {n_jets}, count = {count}")

In [None]:
# Inspect the parton flavours of the GenJets in 2HDMa event 9.
jets_hdma = getJets(events_hdma)
print(jets_hdma.GenJet_partonFlavour[9])

In [None]:
jets_idm = getJets(events_idm)
# Show the GenJet pT list of each of the first ten IDM events.
jet_pts = jets_idm.GenJet_pt
for event_index in range(10):
    print(jet_pts[event_index])

In [None]:
def leadingJetPT(jets):
    """Return the first-jet pT of every event that has at least one jet.

    ``jets.GenJet_pt`` is iterated event by event; for each non-empty event its first
    entry (assumed to be the highest-pT jet) is kept. Events without jets are skipped,
    so the returned plain list can be shorter than the number of events.
    """
    return [event_pts[0] for event_pts in jets.GenJet_pt if len(event_pts) >= 1]

jets_idm = getJets(events_idm)
jet_pt_idm = leadingJetPT(jets_idm)

jets_hdma = getJets(events_hdma)
jet_pt_hdma = leadingJetPT(jets_hdma)

# Common bin edges so the two normalised shapes are directly comparable.
shared_bins = np.histogram_bin_edges(
    np.concatenate([np.asarray(jet_pt_idm, dtype=float), np.asarray(jet_pt_hdma, dtype=float)]),
    bins=100,
)
_ = plt.hist(jet_pt_idm, density=True, histtype='step', bins=shared_bins, label='IDM, BP8')
_ = plt.hist(jet_pt_hdma, density=True, histtype='step', bins=shared_bins, label='2HDMa, mH=500, mA=400')
plt.legend()
plt.xlabel('PT (GeV)')
# density=True normalises each histogram to unit area, so the y-axis is a density, not a count.
plt.ylabel('Density')
# NOTE(review): this cell applies no partonFlavour mask; the title relies on getJets
# (from utils.applyCuts) doing that filtering — confirm.
plt.title('Leading Jet PT, for GenJet_partonFlavour != 0')

In [None]:
# leadingJetPT returns a plain list, and `list > 30` raises TypeError, so convert first.
leading_pt_idm = np.asarray(jet_pt_idm)
num_over_30 = np.sum(leading_pt_idm > 30)
# Guard against an empty selection instead of dividing by zero.
percentage = (num_over_30 / len(leading_pt_idm)) * 100 if len(leading_pt_idm) else float('nan')
print(percentage)

In [None]:
jets = getJets(events_idm)
mask = jets.GenJet_partonFlavour != 0
jets_no0 = jets[mask]

partonFlavour = ak.flatten(jets_no0.GenJet_partonFlavour)
# Build the flavour tally once and iterate its (flavour, count) pairs.
for pdgid, count in Counter(partonFlavour).items():
    print(f"pdgId = {pdgid}, count = {count}")

jet_pt = ak.flatten(jets_no0.GenJet_pt)
_ = plt.hist(jet_pt, density=True, histtype='step', bins=100)
plt.xlabel('PT (GeV)')
plt.ylabel('Density')
plt.title('GenJet PT, GenJet_partonFlavour != 0 (IDM)')

In [None]:
jets = getJets(events_idm)

partonFlavour = ak.flatten(jets.GenJet_partonFlavour)
# Build the flavour tally once and iterate its (flavour, count) pairs.
for pdgid, count in Counter(partonFlavour).items():
    print(f"pdgId = {pdgid}, count = {count}")

jet_pt = ak.flatten(jets.GenJet_pt)
_ = plt.hist(jet_pt, density=True, histtype='step', bins=100)
plt.xlabel('PT (GeV)')
plt.ylabel('Density')
plt.title('GenJet PT, all jets (IDM)')

In [None]:
# Recompute the jets so this cell does not depend on `jets` left over from an earlier cell.
jets = getJets(events_idm)
mask = jets.GenJet_partonFlavour == 0
jets_0 = jets[mask]
jet_pt = ak.flatten(jets_0.GenJet_pt)

leptons = getLeptons(events_idm)
lep_pt = ak.flatten(leptons.GenDressedLepton_pt)

_ = plt.hist(jet_pt, density=True, histtype='step', bins=100, label='jet, partonFlavour=0')
_ = plt.hist(lep_pt, density=True, histtype='step', bins=100, label='lepton')
plt.xlabel('PT (GeV)')
# density=True normalises each histogram to unit area
plt.ylabel("Density")
plt.legend()

In [None]:
from collections import Counter
jets = getJets(events_idm)
partonFlavour = ak.flatten(jets.GenJet_partonFlavour)

# Tally each flavour once and print "flavour count" pairs.
for pdgid, count in Counter(partonFlavour).items():
    print(pdgid, count)


In [None]:
# Show the flattened parton flavours computed in the previous cell.
print(partonFlavour)

In [None]:
jets = getJets(events_idm)
partonFlavour = jets.GenJet_partonFlavour
flat_flavour = ak.flatten(partonFlavour)
# One bin per integer flavour value; matplotlib's default 10 bins would merge distinct PDG IDs.
flavour_bins = np.arange(ak.min(flat_flavour), ak.max(flat_flavour) + 2) - 0.5
_ = plt.hist(flat_flavour, bins=flavour_bins)
plt.xlabel('GenJet_partonFlavour')
plt.ylabel('Number of jets')

In [None]:
# Reload the 2HDMa sample (same arguments as the initial load).
events_hdma = get2HDMaEvents(500, 400)

In [None]:
# Apply utils.applyCuts.cutNjets with argument 1 (presumably a jet-multiplicity requirement — confirm).
# NOTE(review): the result goes into `evs`; `events_hdma` is left as loaded unless
# cutNjets mutates its argument.
evs = cutNjets(events_hdma, 1)

In [None]:
jets = getJets(events_hdma)
print(jets.GenJet_pt)
print(jets.GenJet_eta)
Y = ak.flatten(jets.GenJet_pt)
X = ak.flatten(abs(jets.GenJet_eta))
# Many overlapping points: small, semi-transparent markers keep the density visible.
plt.scatter(X, Y, s=2, alpha=0.3)
# The x values are |eta|, not signed eta.
plt.xlabel('|eta|')
plt.ylabel('pt (GeV)')
plt.title('GenJet pt vs |eta| (2HDMa)')

In [None]:
# NOTE(review): duplicate of the earlier cutNjets(events_hdma, 1) cell. The result goes into
# `evs` (overwritten by later cells), so cells below still use the uncut `events_hdma`
# unless cutNjets mutates its argument.
evs = cutNjets(events_hdma, 1)

In [None]:
# Reload IDM BP8 and apply utils.applyCuts.cutNleptons with its default arguments.
# Reloading first keeps this cell safe to re-run (the cut is not applied twice).
events_idm = getIDMevents(8, 'h2h2lPlM_lem')
events_idm = cutNleptons(events_idm)

In [None]:
jets = getJets(events_idm)
idm_pt = jets.GenJet_pt
# `[:, 0]` raises on events with no jets, so keep only events with at least one jet.
# (events_hdma may not be jet-filtered: the earlier cutNjets(events_hdma, 1) result went into `evs`.)
leading_jet = idm_pt[ak.num(idm_pt, axis=1) >= 1][:, 0]

jets_hdma = getJets(events_hdma)
hdma_pt = jets_hdma.GenJet_pt
leading_jet_hdma = hdma_pt[ak.num(hdma_pt, axis=1) >= 1][:, 0]

# Common bin edges so the two normalised shapes are directly comparable.
shared_bins = np.histogram_bin_edges(
    np.concatenate([ak.to_numpy(leading_jet_hdma), ak.to_numpy(leading_jet)]), bins=100
)
_ = plt.hist(leading_jet_hdma, density=True, histtype='step', bins=shared_bins, label='2HDMa')
_ = plt.hist(leading_jet, density=True, histtype='step', bins=shared_bins, label='IDM')
plt.legend()
plt.title(r"PT of leading jet, $|\eta|<4.7$")
# density=True normalises each histogram to unit area
plt.ylabel('Density')
plt.xlabel('PT (GeV)')

In [None]:
jets = getJets(events_idm)
idm_pt = jets.GenJet_pt
# `[:, 1]` raises on events with fewer than two jets, so keep only events with at least two.
subleading_jet = idm_pt[ak.num(idm_pt, axis=1) >= 2][:, 1]

jets_hdma = getJets(events_hdma)
hdma_pt = jets_hdma.GenJet_pt
subleading_jet_hdma = hdma_pt[ak.num(hdma_pt, axis=1) >= 2][:, 1]

# Common bin edges so the two normalised shapes are directly comparable.
shared_bins = np.histogram_bin_edges(
    np.concatenate([ak.to_numpy(subleading_jet_hdma), ak.to_numpy(subleading_jet)]), bins=100
)
_ = plt.hist(subleading_jet_hdma, density=True, histtype='step', bins=shared_bins, label='2HDMa')
_ = plt.hist(subleading_jet, density=True, histtype='step', bins=shared_bins, label='IDM')
plt.legend()
plt.title(r"PT of subleading jet, $|\eta|<4.7$")
# density=True normalises each histogram to unit area
plt.ylabel('Density')
plt.xlabel('PT (GeV)')

In [None]:
# Re-apply utils.applyCuts.applyCuts, now to the lepton-cut IDM sample from the reload cell above.
# NOTE(review): 35 appears to be the DM pdgId; 1 presumably a jet-multiplicity setting — confirm.
evs = applyCuts(events_idm, 35, 1)

In [None]:
# Number of events surviving the applyCuts call in the previous cell.
print(len(evs))

In [None]:
# Event count before and after the Delta-phi(jet, pTmiss) cut with DM pdgId 35.
# NOTE(review): in top-to-bottom execution this is the cutDeltaPhiPTjetPTmiss imported from
# utils.applyCuts; a same-named function is redefined later in this notebook.
print(len(events_idm))
evs = cutDeltaPhiPTjetPTmiss(events_idm, 35)
print(len(evs))

In [None]:
filename = "root://cmsxrootd.fnal.gov///store/test/xrootd/T1_ES_PIC/store/mc/RunIIFall17NanoAODv7/Pseudoscalar2HDM_MonoZLL_mScan_mH-1400_ma-500/NANOAODSIM/PU2017_12Apr2018_Nano02Apr2020_102X_mc2017_realistic_v8-v1/120000/804FA4F0-379A-2148-A1A3-7E412433CEB6.root:Events"

options = {'timeout' : 180}
filters = ['GenDressedLepton*', 'GenPart*', 'GenJet*']
# Read inside a `with` block so the remote (XRootD) file is closed once the selected
# branches are in memory; `events` then holds only the awkward arrays.
with uproot.open(filename, **options) as events_tree:
    events = events_tree.arrays(filter_name=filters, library='ak')

In [None]:
# Show the arrays read in the previous cell.
print(events)

In [None]:
options = {'timeout' : 180}
filters = ['GenDressedLepton*', 'GenPart*', 'GenJet*']
# Read inside a `with` block so the file is closed once the arrays are in memory.
with uproot.open('/eos/user/e/ecurtis/idmStudy/myFiles/gridpacks/h2h2lPlM_lem/h2h2lPlM_lem_BP8/wmNANOAODGEN_new_conditions_6777251.0.root:Events;1', **options) as events_tree:
    events = events_tree.arrays(filter_name=filters, library='ak')

branches = ['GenDressedLepton_phi', 'GenDressedLepton_pt', 'GenDressedLepton_eta']
leptons = events[branches]
# Keep only events with exactly two dressed leptons (one boolean per event).
mask = ak.count(leptons.GenDressedLepton_eta, axis=1) == 2
leptons = leptons[mask]

events = events[mask]

branches = ['GenJet_eta', 'GenJet_hadronFlavour', 'GenJet_mass', 'GenJet_partonFlavour', 'GenJet_phi', 'GenJet_pt']
jets = events[branches]

def num_jet_over_30(jets):
    """Return, for each event, how many GenJets have pT above 30 GeV.

    ``jets`` must expose a per-event (jagged) ``GenJet_pt`` field; the result holds one
    count per event.
    """
    is_hard = jets.GenJet_pt > 30
    return ak.sum(is_hard, axis=1)

def cutJet(events):
    """Split event indices into 0-jet and 1-jet categories (jets with pT > 30 GeV).

    Returns ``(jet_0_idxs, jet_1_idxs)``: indices of events with exactly zero and exactly
    one GenJet above 30 GeV. Both come from ``np.argwhere`` and therefore have shape
    ``(n_matches, 1)`` rather than being flat index arrays.
    """
    jet_branches = ['GenJet_eta', 'GenJet_hadronFlavour', 'GenJet_mass', 'GenJet_partonFlavour', 'GenJet_phi', 'GenJet_pt']
    jets = events[jet_branches]
    # Per-event count of jets above the 30 GeV threshold
    num_over_30 = ak.sum(jets.GenJet_pt > 30, axis=1)
    return np.argwhere(num_over_30 == 0), np.argwhere(num_over_30 == 1)

In [None]:
def PTmiss(dm, return_vec = False):
    """Missing transverse momentum built from the first two DM particles of each event.

    Sums the (px, py) of entries 0 and 1 of ``dm['GenPart_pt']`` / ``dm['GenPart_phi']``.
    Returns the magnitude per event; with ``return_vec=True`` returns ``(MET_vec, tot_MET)``,
    where ``MET_vec`` has shape ``(2, n_events)`` (row 0 = px, row 1 = py).
    NOTE: redefined later in this notebook with an ``(n_events, 2)`` vector layout.
    """
    phi = dm['GenPart_phi']
    pt = dm['GenPart_pt']
    px = pt[:, 0] * np.cos(phi[:, 0]) + pt[:, 1] * np.cos(phi[:, 1])
    py = pt[:, 0] * np.sin(phi[:, 0]) + pt[:, 1] * np.sin(phi[:, 1])
    MET_vec = np.array([px, py])
    tot_MET = np.linalg.norm(MET_vec, axis=0)
    return (MET_vec, tot_MET) if return_vec else tot_MET
    
def dileptonPT(leptons, return_vec = False):
    """Transverse momentum of the system formed by the first two dressed leptons.

    Returns the magnitude per event; with ``return_vec=True`` returns
    ``(lep_PT_vec, PT_abs)``, where ``lep_PT_vec`` has shape ``(2, n_events)``
    (row 0 = px, row 1 = py).
    NOTE: redefined later in this notebook with an ``(n_events, 2)`` vector layout.
    """
    phi = leptons['GenDressedLepton_phi']
    pt = leptons['GenDressedLepton_pt']
    px = pt[:, 0] * np.cos(phi[:, 0]) + pt[:, 1] * np.cos(phi[:, 1])
    py = pt[:, 0] * np.sin(phi[:, 0]) + pt[:, 1] * np.sin(phi[:, 1])

    lep_PT_vec = np.array([px, py])
    PT_abs = np.linalg.norm(lep_PT_vec, axis=0)
    if not return_vec:
        return PT_abs
    return lep_PT_vec, PT_abs

def unit_vector(vector):
    """Return ``vector`` scaled to unit Euclidean length.

    Any array-like (tuple, list, ndarray) is accepted: it is converted to a float ndarray
    first, because ``tuple / float`` raises TypeError. A zero vector gives NaNs (0/0).
    """
    vector = np.asarray(vector, dtype=float)
    return vector / np.linalg.norm(vector)

def angle_between(v1, v2):
    """ Returns the angle in radians between vectors 'v1' and 'v2'::

            >>> angle_between((1, 0, 0), (0, 1, 0))
            1.5707963267948966
            >>> angle_between((1, 0, 0), (1, 0, 0))
            0.0
            >>> angle_between((1, 0, 0), (-1, 0, 0))
            3.141592653589793

    Inputs may be any array-like; they are converted to float arrays so the tuple
    examples above work. A zero-length input gives NaN.
    """
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    v1_u = v1 / np.linalg.norm(v1)
    v2_u = v2 / np.linalg.norm(v2)
    # Rounding can push the cosine marginally outside [-1, 1], where arccos returns NaN.
    return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))

def getDM(events, dm_pdgId):
    """Select final-state generator particles whose |pdgId| equals ``dm_pdgId``.

    Keeps the GenPart_* branches listed below, drops particles with GenPart_status != 1
    (status 1 is taken to mean a final-state particle), and returns the remaining
    particles with ``abs(GenPart_pdgId) == dm_pdgId``.
    """
    wanted = ['GenPart_eta', 'GenPart_mass', 'GenPart_phi', 'GenPart_pt', 'GenPart_genPartIdxMother', 'GenPart_pdgId', 'GenPart_status', 'GenPart_statusFlags']
    gen = events[wanted]
    final_state = gen[gen.GenPart_status == 1]
    is_dm = abs(final_state.GenPart_pdgId) == dm_pdgId
    return final_state[is_dm]

def getLeptons(events):
    """Return only the dressed-lepton branches (phi, pt, eta) of ``events``."""
    lepton_fields = ['GenDressedLepton_phi', 'GenDressedLepton_pt', 'GenDressedLepton_eta']
    return events[lepton_fields]




In [None]:
leptons = getLeptons(events)
dm = getDM(events, 35)
dilepton_PT, _ = dileptonPT(leptons, return_vec=True)
PT_miss, _ = PTmiss(dm, return_vec=True)
# With the (2, n_events) layout of the first PTmiss/dileptonPT definitions, column i
# is the (px, py) of event i.
n_events = len(dilepton_PT[0])
angles = np.array([
    angle_between(dilepton_PT[:, i], PT_miss[:, i]) for i in range(n_events)
])

print(angles)

In [None]:
def PTmiss(dm, return_vec = False):
    """Missing transverse momentum built from the first two DM particles of each event.

    Returns the magnitude per event; with ``return_vec=True`` returns ``(MET_vec, tot_MET)``,
    where ``MET_vec`` has shape ``(n_events, 2)`` (column 0 = px, column 1 = py).
    """
    phi = dm['GenPart_phi']
    pt = dm['GenPart_pt']
    px = np.array(pt[:, 0] * np.cos(phi[:, 0]) + pt[:, 1] * np.cos(phi[:, 1]))
    py = np.array(pt[:, 0] * np.sin(phi[:, 0]) + pt[:, 1] * np.sin(phi[:, 1]))
    MET_vec = np.column_stack((px, py))
    tot_MET = np.linalg.norm(MET_vec, axis=1)
    return (MET_vec, tot_MET) if return_vec else tot_MET
    
def dileptonPT(leptons, return_vec = False):
    """Transverse momentum of the system formed by the first two dressed leptons.

    Returns the magnitude per event; with ``return_vec=True`` returns ``(PT_vec, PT_abs)``,
    where ``PT_vec`` has shape ``(n_events, 2)`` (column 0 = px, column 1 = py).
    """
    phi = leptons['GenDressedLepton_phi']
    pt = leptons['GenDressedLepton_pt']
    components = []
    # x uses cos, y uses sin
    for projection in (np.cos, np.sin):
        components.append(np.array(pt[:, 0] * projection(phi[:, 0]) + pt[:, 1] * projection(phi[:, 1])))
    PT_vec = np.stack(components, axis=1)
    PT_abs = np.linalg.norm(PT_vec, axis=1)
    if not return_vec:
        return PT_abs
    return PT_vec, PT_abs

def unit_vector(vector):
    """Return ``vector`` scaled to unit Euclidean length.

    Any array-like (tuple, list, ndarray) is accepted: it is converted to a float ndarray
    first, because ``tuple / float`` raises TypeError. A zero vector gives NaNs (0/0).
    """
    vector = np.asarray(vector, dtype=float)
    return vector / np.linalg.norm(vector)

def angle_between(v1, v2):
    """ Returns the angle in radians between vectors 'v1' and 'v2'::

            >>> angle_between((1, 0, 0), (0, 1, 0))
            1.5707963267948966
            >>> angle_between((1, 0, 0), (1, 0, 0))
            0.0
            >>> angle_between((1, 0, 0), (-1, 0, 0))
            3.141592653589793

    Inputs may be any array-like; they are converted to float arrays so the tuple
    examples above work. A zero-length input gives NaN.
    """
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    v1_u = v1 / np.linalg.norm(v1)
    v2_u = v2 / np.linalg.norm(v2)
    # Rounding can push the cosine marginally outside [-1, 1], where arccos returns NaN.
    return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))

def getDM(events, dm_pdgId):
    """Return final-state (GenPart_status == 1) particles with |GenPart_pdgId| == ``dm_pdgId``.

    Only the GenPart_* branches listed below are kept. Status 1 is taken to mean a
    final-state particle.
    """
    gen = events[['GenPart_eta', 'GenPart_mass', 'GenPart_phi', 'GenPart_pt', 'GenPart_genPartIdxMother', 'GenPart_pdgId', 'GenPart_status', 'GenPart_statusFlags']]
    keep_final = gen.GenPart_status == 1
    gen_final = gen[keep_final]
    return gen_final[abs(gen_final.GenPart_pdgId) == dm_pdgId]

def getLeptons(events):
    """Return the GenDressedLepton phi/pt/eta branches of ``events`` (in that order)."""
    return events[['GenDressedLepton_phi', 'GenDressedLepton_pt', 'GenDressedLepton_eta']]

def getJets(events):
    """Return the GenJet eta/hadronFlavour/mass/partonFlavour/phi/pt branches of ``events``."""
    jet_fields = ['GenJet_eta', 'GenJet_hadronFlavour', 'GenJet_mass', 'GenJet_partonFlavour', 'GenJet_phi', 'GenJet_pt']
    return events[jet_fields]

def jetPT(jets, return_vec = False):
    """Transverse-momentum vector of the first GenJet of every event.

    Every event must contain at least one jet (``[:, 0]`` is taken unconditionally).
    Returns the magnitude per event; with ``return_vec=True`` returns
    ``(jet_vec, jet_vec_abs)``, where ``jet_vec`` has shape ``(n_events, 2)``.
    """
    lead_phi = jets['GenJet_phi'][:, 0]
    lead_pt = jets['GenJet_pt'][:, 0]
    components = [np.array(lead_pt * np.cos(lead_phi)), np.array(lead_pt * np.sin(lead_phi))]
    jet_vec = np.stack(components, axis=1)
    jet_vec_abs = np.linalg.norm(jet_vec, axis=1)
    if not return_vec:
        return jet_vec_abs
    return jet_vec, jet_vec_abs

def cutDeltaPhiPTjetPTmiss(events, dm_pdgId):
    """Keep events whose leading-jet pT vector and pTmiss vector are more than 0.5 rad apart.

    pTmiss is built from the DM particles selected by ``getDM(events, dm_pdgId)``; the jet
    vector is the first GenJet from ``jetPT``. Prints the percentage of events removed and
    returns the surviving events.
    """
    print(f'cutDeltaPhiPTjetPTmiss:')
    PT_miss, _ = PTmiss(getDM(events, dm_pdgId), return_vec=True)
    jet_PT, _ = jetPT(getJets(events), return_vec=True)
    angles = np.array([angle_between(jet_vec, met_vec) for jet_vec, met_vec in zip(jet_PT, PT_miss)])
    mask = angles > 0.5
    removed_fraction = (len(mask) - np.sum(mask)) / len(mask)
    print(f'Percentage of events taken out = {removed_fraction * 100:.2f}')
    return events[mask]


In [None]:
jets = getJets(events)

# Spot-check jetPT on event 1: print its first jet's phi and pt, then the (px, py)
# vector jetPT builds for that event — it should equal (pt*cos(phi), pt*sin(phi)).
phi1 = jets['GenJet_phi'][1][0]
pt1 = jets['GenJet_pt'][1][0]
print(phi1)
print(pt1)
jet_PT, jet_abs = jetPT(jets, return_vec=True)
print(jet_PT[1])

In [None]:
# Now I can find the number of jets over 30 GeV for each event
num_jets = num_jet_over_30(jets)
_ = plt.hist(num_jets, density=True, bins=100, histtype='step')

In [None]:
jet_pt = jets.GenJet_pt
mask = jet_pt > 30
print(mask)
num = ak.sum(mask, axis=1)
print(num)
# I only want the ones that have 0 or 1 jet over 30 GeV.
# np.flatnonzero gives a flat 1-D array of event indices. np.argwhere would give an
# (n, 1) array, and indexing `events` with it yields nested (n, 1) results.
jet_0 = np.flatnonzero(ak.to_numpy(num == 0))
print(jet_0)

jet_1 = np.flatnonzero(ak.to_numpy(num == 1))
print(jet_1)


In [None]:
# Events with no GenJet above 30 GeV, using the indices in jet_0 from the previous cell.
# NOTE(review): if jet_0 comes from np.argwhere it has shape (n, 1), and indexing with it
# presumably gives nested (n, 1) results; a flat index array is what is wanted here.
events0jet = events[jet_0]

In [None]:
# Inspect the dressed-lepton pTs of the zero-jet event at position 3.
print(events0jet.GenDressedLepton_pt[3])

In [None]:
# Spot-check the GenJet pTs of event 8020.
# NOTE(review): hard-coded index — confirm it is in range for the current `events`.
print(jet_pt[8020])

In [None]:
#filename = "root://cmsxrootd.fnal.gov///store/mc/RunIIFall17NanoAODv7/Pseudoscalar2HDM_MonoZLL_mScan_mH-800_ma-400/NANOAODSIM/PU2017_12Apr2018_Nano02Apr2020_102X_mc2017_realistic_v8-v1/260000/F14B27DB-21EF-B443-95C5-ECD3AC953083.root"
filename = "root://cmsxrootd.fnal.gov///store/test/xrootd/T1_ES_PIC/store/mc/RunIIFall17NanoAODv7/Pseudoscalar2HDM_MonoZLL_mScan_mH-1400_ma-500/NANOAODSIM/PU2017_12Apr2018_Nano02Apr2020_102X_mc2017_realistic_v8-v1/120000/804FA4F0-379A-2148-A1A3-7E412433CEB6.root:Events"

options = {'timeout' : 180}
filters = ['GenDressedLepton*', 'GenPart*', 'GenJet*']
# Read inside a `with` block so the remote file is closed once the arrays are in memory.
with uproot.open(filename, **options) as events_tree:
    events = events_tree.arrays(filter_name=filters, library='ak')

branches = ['GenDressedLepton_phi', 'GenDressedLepton_pt', 'GenDressedLepton_eta']
leptons = events[branches]
# First just get the events that have exactly two leptons (one boolean per event)
mask = ak.count(leptons.GenDressedLepton_eta, axis=1) == 2
leptons = leptons[mask]
branches = ['GenPart_eta', 'GenPart_mass', 'GenPart_phi', 'GenPart_pt', 'GenPart_genPartIdxMother', 'GenPart_pdgId', 'GenPart_status', 'GenPart_statusFlags']
gen = events[branches]
# Keep status-1 (taken as final-state) particles with |pdgId| == 52, presumably the DM
# candidate in this 2HDMa sample — confirm.
gen_final = gen[gen.GenPart_status == 1]
dm = gen_final[abs(gen_final.GenPart_pdgId) == 52]
# Use the same two-lepton mask so dm stays event-aligned with leptons
dm = dm[mask]

lep_pt = ak.flatten(leptons.GenDressedLepton_pt)

branches = ['GenJet_eta', 'GenJet_hadronFlavour', 'GenJet_mass', 'GenJet_partonFlavour', 'GenJet_phi', 'GenJet_pt']
# Apply the two-lepton mask to the jets too, so jets, leptons and dm describe the same events.
jets = events[branches][mask]


In [None]:
# Now I can find the number of jets over 30 GeV for each event
num_jets = num_jet_over_30(jets)
_ = plt.hist(num_jets, density=True, bins=100, histtype='step')

In [None]:
#filename = "root://cmsxrootd.fnal.gov///store/mc/RunIIFall17NanoAODv7/Pseudoscalar2HDM_MonoZLL_mScan_mH-800_ma-400/NANOAODSIM/PU2017_12Apr2018_Nano02Apr2020_102X_mc2017_realistic_v8-v1/260000/F14B27DB-21EF-B443-95C5-ECD3AC953083.root"
filename = "root://cmsxrootd.fnal.gov///store/test/xrootd/T1_ES_PIC/store/mc/RunIIFall17NanoAODv7/Pseudoscalar2HDM_MonoZLL_mScan_mH-1400_ma-500/NANOAODSIM/PU2017_12Apr2018_Nano02Apr2020_102X_mc2017_realistic_v8-v1/120000/804FA4F0-379A-2148-A1A3-7E412433CEB6.root:Events"

options = {'timeout' : 180}
# No GenJet branches are used in this cell, so only leptons and GenPart are read.
filters = ['GenDressedLepton*', 'GenPart*']
# Read inside a `with` block so the remote file is closed once the arrays are in memory.
with uproot.open(filename, **options) as events_tree:
    events = events_tree.arrays(filter_name=filters, library='ak')

branches = ['GenDressedLepton_phi', 'GenDressedLepton_pt', 'GenDressedLepton_eta']
leptons = events[branches]
# First just get the events that have exactly two leptons (one boolean per event)
mask = ak.count(leptons.GenDressedLepton_eta, axis=1) == 2
leptons = leptons[mask]
branches = ['GenPart_eta', 'GenPart_mass', 'GenPart_phi', 'GenPart_pt', 'GenPart_genPartIdxMother', 'GenPart_pdgId', 'GenPart_status', 'GenPart_statusFlags']
gen = events[branches]
# Keep status-1 (taken as final-state) particles with |pdgId| == 52, presumably the DM
# candidate in this 2HDMa sample — confirm.
gen_final = gen[gen.GenPart_status == 1]
dm = gen_final[abs(gen_final.GenPart_pdgId) == 52]
# Use the same two-lepton mask so dm stays event-aligned with leptons
dm = dm[mask]


In [None]:
BPs_scan = [8, 10, 12, 13, 14, 18, 19, 20, 21, 24]
lam2_vals = np.linspace(-3, 3, 10)
# Print every (benchmark point, lambda_2 index) combination of the scan grid.
for BP in BPs_scan:
    for i in range(len(lam2_vals)):
        lam2 = lam2_vals[i]
        print(f"{BP}, {i}")

In [None]:
import awkward
import vector
import numpy as np

vector.register_awkward()

import logging
logger = logging.getLogger(__name__)

from higgs_dna.taggers.tagger import Tagger, NOMINAL_TAG
from higgs_dna.selections import object_selections, lepton_selections, jet_selections, tau_selections, physics_utils
from higgs_dna.utils import awkward_utils, misc_utils

# Value passed as dummy_value to awkward_utils.add_object_fields (e.g. electron_1_pt);
# presumably written when an event has fewer objects than n_objects — confirm in awkward_utils.
DUMMY_VALUE = -999.
# Sentinel generator weight: calculate_selection builds gen_weight_cut as weight != this value.
GEN_WEIGHT_BAD_VAL = -99999.
# Default options for HDMaFullTagger. If the tagger receives options, they are combined
# with these via misc_utils.update_dict. Each object sub-dict is passed to the matching
# higgs_dna selection helper (select_photons / select_electrons / select_muons /
# select_taus / select_jets), which defines the exact meaning of its keys. The "dr_*"
# entries are the min_dr values used to clean each collection against other objects.
DEFAULT_OPTIONS = {
   # NOTE(review): not referenced in the visible part of this file.
   "jet_category" : 0,
   "photons" : {
        # False: calculate_selection looks up a rho field (fixedGridRhoAll,
        # fixedGridRhoFastjetAll or Rho_fixedGridRhoAll); True: rho is an array of ones.
        "use_central_nano" : True,
        "pt" : 25.0,
        # List of [min, max] eta ranges — presumably |eta| windows; confirm in select_photons.
        "eta" : [ 
            [0.0, 1.4442],
            [1.566, 2.5]
        ],
        "e_veto" : 0.5, 
        "e_veto_invert" : False,
        "hoe" : 0.08,
        "r9" : 0.8,
        "charged_iso" : 20.0,
        "charged_rel_iso" : 0.3,
        "hlt" : {
            "eta_rho_corr" : 1.5,
            "low_eta_rho_corr" : 0.16544,
            "high_eta_rho_corr" : 0.13212,
            "eb_high_r9" : {
                "r9" : 0.85
            },
            "eb_low_r9" : { 
                "r9" : 0.5, 
                "pho_iso" : 4.0, 
                "track_sum_pt" : 6.0,
                "sigma_ieie" : 0.015
            },
            "ee_high_r9" : {
                "r9" : 0.9
            },
            "ee_low_r9" : { 
                "r9" : 0.8, 
                "pho_iso" : 4.0, 
                "track_sum_pt" : 6.0,
                "sigma_ieie" : 0.035
            }
        }
    },
    "electrons" : {
        "pt" : 10.0,
        "eta" : 2.5,
        "dxy" : 0.045,
        "dz" : 0.2,
        "id" : "WP90",
        "dr_photons" : 0.2,
        "veto_transition" : True,
    },
    "muons" : {
        "pt" : 15.0,
        "eta" : 2.4,
        "dxy" : 0.045,
        "dz" : 0.2,
        "id" : "medium",
        "pfRelIso03_all" : 0.3,
        "dr_photons" : 0.2
    },
    "taus" : {
        "pt" : 20.0,
        "eta" : 2.3,
        "dz" : 0.2,
        "deep_tau_vs_ele" : 1,
        "deep_tau_vs_mu" : 0,
        "deep_tau_vs_jet" : 7,
        "dr_photons" : 0.2,
        "dr_electrons" : 0.2,
        "dr_muons" : 0.2
    },
    "jets" : {
        "pt" : 25.0,
        "eta" : 2.4,
        "looseID" : True,
        "dr_photons" : 0.4,
        "dr_electrons" : 0.4,
        "dr_muons" : 0.4,
        "dr_taus" : 0.4,
        "dr_iso_tracks" : 0.4,
        # Per-year btagDeepFlavB threshold, looked up with the tagger's year; selected jets
        # above it are counted in n_bjets.
        "bjet_thresh" : {
            "2016UL_postVFP" : 0.3093,
            "2016UL_preVFP": 0.3093,
            "2017" : 0.3033,
            "2018" : 0.2770
        }
    },
    # NOTE(review): not referenced in the visible part of this file.
    "photon_mvaID" : -0.7
}

class HDMaFullTagger(Tagger):
    """
    Preselection Tagger for dummy analysis
    """
    def __init__(self, name = "HDMa_full_tagger", options = {}, is_data = None, year = None):
        super(HDMaFullTagger, self).__init__(name, options, is_data, year)

        if not options:
            self.options = DEFAULT_OPTIONS 
        else:
            self.options = misc_utils.update_dict(
                    original = DEFAULT_OPTIONS,
                    new = options
            )


    def calculate_selection(self, events):
        #################################
        ### Dummy Preselection ###
        #################################

        ### Presel step 1 : select objects ###

        # Photons
        if not self.options["photons"]["use_central_nano"]:
            if "fixedGridRhoAll" in events.fields:
                rho = events.fixedGridRhoAll
            elif "fixedGridRhoFastjetAll" in events.fields:
                rho = events.fixedGridRhoFastjetAll
            elif "Rho_fixedGridRhoAll" in events.fields:
                rho = events.Rho_fixedGridRhoAll
            else:
                logger.exception("[DiphotonTagger : calculate_selection] Did not find valid 'rho' field.")
                raise RuntimeError()
        else:
            rho = awkward.ones_like(events.Photon)

        photon_cut = object_selections.select_photons(
                photons = events.Photon,
                rho = rho,
                options = self.options["photons"]
        )

        photons = awkward_utils.add_field(
                events = events,
                name = "SelectedPhoton",
                data = events.Photon[photon_cut]
        )

        
        # Electrons
        electron_cut = lepton_selections.select_electrons(
                electrons = events.Electron,
                options = self.options["electrons"],
                clean = {
                    "photons" : {
                        "objects" : events.SelectedPhoton,
                        "min_dr" : self.options["electrons"]["dr_photons"]
                    }
                },
                name = "SelectedElectron",
                tagger = self
        )

        electrons = awkward_utils.add_field(
                events = events,
                name = "SelectedElectron",
                data = events.Electron[electron_cut]
        )
        # awkward.to_parquet(electrons, 'electrons.parquet')
        # print(f'The electrons object is == {electrons}')
        # print(f'Keys = {awkward.fields(electrons)}')
        # print(f'PT = {electrons.pt}')
        # print(f'PT[:,0]>41 = {electrons.pt[:,0] > 41}')
        # print(f'PT[:,1]>41 = {electrons.pt[:,1] > 41}')
        # Muons
        muon_cut = lepton_selections.select_muons(
                muons = events.Muon,
                options = self.options["muons"],
                clean = {
                    "photons" : {
                        "objects" : events.SelectedPhoton,
                        "min_dr" : self.options["muons"]["dr_photons"]
                    }
                },
                name = "SelectedMuon",
                tagger = self
        )

        muons = awkward_utils.add_field(
                events = events,
                name = "SelectedMuon",
                data = events.Muon[muon_cut]
        )

        # Taus
        tau_cut = tau_selections.select_taus(
                taus = events.Tau,
                options = self.options["taus"],
                clean = {
                    "photons" : {
                        "objects" : events.SelectedPhoton,
                        "min_dr" : self.options["taus"]["dr_photons"]
                    },
                    "electrons" : {
                        "objects" : events.SelectedElectron,
                        "min_dr" : self.options["taus"]["dr_electrons"]
                    },
                    "muons" : {
                        "objects" : events.SelectedMuon,
                        "min_dr" : self.options["taus"]["dr_muons"]
                    }
                },
                name = "AnalysisTau",
                tagger = self
        )

        taus = awkward_utils.add_field(
                events = events,
                name = "AnalysisTau",
                data = events.Tau[tau_cut]
        )


        # Jets
        jet_cut = jet_selections.select_jets(
                jets = events.Jet,
                options = self.options["jets"],
                clean = {
                    "photons" : {
                        "objects" : events.SelectedPhoton,
                        "min_dr" : self.options["jets"]["dr_photons"]
                    },
                    "electrons" : {
                        "objects" : events.SelectedElectron,
                        "min_dr" : self.options["jets"]["dr_electrons"]
                    },
                    "muons" : {
                        "objects" : events.SelectedMuon,
                        "min_dr" : self.options["jets"]["dr_muons"]
                    },
                    "taus" : {
                        "objects" : events.AnalysisTau,
                        "min_dr" : self.options["jets"]["dr_taus"]
                    }
                },
                name = "SelectedJet",
                tagger = self
        )

        jets = awkward_utils.add_field(
                events = events,
                name = "SelectedJet",
                data = events.Jet[jet_cut]
        )

        bjets = jets[awkward.argsort(jets.btagDeepFlavB, axis = 1, ascending = False)]
        awkward_utils.add_object_fields(
                events = events,
                name = "b_jet",
                objects = bjets,
                n_objects = 2,
                fields = ["btagDeepFlavB"],
                dummy_value = DUMMY_VALUE
        )


        # Add object fields to events array
        for objects, name in zip([electrons, muons, taus, jets], ["electron", "muon", "tau", "jet"]):
            awkward_utils.add_object_fields(
                    events = events,
                    name = name,
                    objects = objects,
                    n_objects = 2,
                    dummy_value = DUMMY_VALUE
            )

        n_photons = awkward.num(photons)
        awkward_utils.add_field(events, "n_photons", n_photons, overwrite=True)

        n_electrons = awkward.num(electrons)
        awkward_utils.add_field(events, "n_electrons", n_electrons, overwrite=True)
        
        n_muons = awkward.num(muons)
        awkward_utils.add_field(events, "n_muons", n_muons, overwrite=True)
        
        n_leptons = n_electrons + n_muons
        awkward_utils.add_field(events, "n_leptons", n_leptons, overwrite=True)
        
        n_taus = awkward.num(taus)
        awkward_utils.add_field(events, "n_taus", n_taus, overwrite=True)
        
        n_jets = awkward.num(jets)
        awkward_utils.add_field(events, "n_jets", n_jets, overwrite=True)

        n_bjets = awkward.num(bjets[bjets.btagDeepFlavB > self.options["jets"]["bjet_thresh"][self.year]]) 
        awkward_utils.add_field(events, "n_bjets", n_bjets, overwrite=True)

        ### Presel step 2: event level cuts ###

        ### Presel step 3: define event channels e.g. by number of leptons ###

        if "weight" not in events.fields:
            events["weight"] = awkward.ones_like(n_photons)
        gen_weight_cut = events.weight != GEN_WEIGHT_BAD_VAL

        n_leptons_cut = n_leptons >= 2
        
        # ------------------------------------------------N_l + P^l_T cut------------------------------------------------------
        # Want events with two electrons/muons, with leading pt>25
        # subleading pt> 20, and none with a subsubleading pt>10
        # First bracket for >25 condition, second for >10 condition
        # and last for the extra lepton veto
        electron_cut = (awkward.num(electrons.pt[electrons.pt>=25])>=1)&(awkward.num(electrons.pt[electrons.pt>=20])>=2)&((awkward.num(electrons.pt[electrons.pt>=10])<3) | ~(awkward.num(muons.pt[muons.pt>10])<1))
        muon_cut = (awkward.num(muons.pt[muons.pt>=25])>=1)&(awkward.num(muons.pt[muons.pt>=20])>=2)&((awkward.num(muons.pt[muons.pt>=10])<3)  | ~(awkward.num(electrons.pt[electrons.pt>10])<1))

        # lep_veto_OS_cut_elec = []
        # for i , val in enumerate(electron_cut):
        #     # If it's val==True then I need to check that they have opposite sign
        #     # and that there is no extra lepton with pt>10 
        #     if val == True:
        #         lead_charge = electrons[i].charge[0]
        #         sublead_charge = electrons[i].charge[1]
        #         if lead_charge * sublead_charge == -1:
        #             if len(electrons[i].pt) > 2:
        #                 if electrons[i].pt[3] > 10:
        #                     lep_veto_OS_cut_elec.append(False)
        #                     continue
        #             if len(muons[i].pt) > 0:
        #                 if muons[i].pt[0] > 10:
        #                     lep_veto_OS_cut_elec.append(False)
        #                     continue
        #             lep_veto_OS_cut_elec.append(True)
        #         else:
        #             lep_veto_OS_cut_elec.append(False)
        #     else:
        #         lep_veto_OS_cut_elec.append(False)
        # lep_veto_OS_cut_elec = awkward.Array(lep_veto_OS_cut_elec)

        # lep_veto_OS_cut_muon = []
        # for i , val in enumerate(muon_cut):
        #     # If it's val==True then I need to check that they have opposite sign
        #     # and that there is no extra lepton with pt>10 
        #     if val == True:
        #         lead_charge = muons[i].charge[0]
        #         sublead_charge = muons[i].charge[1]
        #         if lead_charge * sublead_charge == -1:
        #             if len(muons[i].pt) > 2:
        #                 if muons[i].pt[3] > 10:
        #                     lep_veto_OS_cut_muon.append(False)
        #                     continue
        #             if len(electrons[i].pt) > 0:
        #                 if electrons[i].pt[0] > 10:
        #                     lep_veto_OS_cut_muon.append(False)
        #                     continue
        #             lep_veto_OS_cut_muon.append(True)
        #         else:
        #             lep_veto_OS_cut_muon.append(False)
        #     else:
        #         lep_veto_OS_cut_muon.append(False)
        # lep_veto_OS_cut_muon = awkward.Array(lep_veto_OS_cut_muon)

        # electron_cut = electron_cut & lep_veto_OS_cut_elec
        # muon_cut = muon_cut & lep_veto_OS_cut_muon


        # ------------------------------------------------Dilepton Mass Cut------------------------------------------------------
        # Construct di-electron/di-muon pairs
        electrons_4V = awkward.Array(electrons, with_name = "Momentum4D")
        muons_4V = awkward.Array(muons, with_name = "Momentum4D")
        ee_pairs = awkward.combinations(
                electrons_4V, # objects to make combinations out of
                2, # how many objects go in a combination
                fields = ["LeadLepton", "SubleadLepton"] # can access these as e.g. ee_pairs.LeadLepton.pt
        )
        mm_pairs = awkward.combinations(muons_4V, 2, fields = ["LeadLepton", "SubleadLepton"])
        # Concatenate these together
        z_cands = awkward.concatenate([ee_pairs, mm_pairs], axis = 1)
        # Want only events where the dilepton deltaR<1.4
        deltaR_cut = z_cands.LeadLepton.deltaR(z_cands.SubleadLepton) < 1.4
        z_cands[deltaR_cut]
        z_cands["ZCand"] = z_cands.LeadLepton + z_cands.SubleadLepton # these add as 4-vectors since we registered them as "Momentum4D" objects
        # Make Z candidate-level cuts
        os_cut = z_cands.LeadLepton.charge * z_cands.SubleadLepton.charge == -1
        mass_cut = (z_cands.ZCand.mass > (91.1876 - 15)) & (z_cands.ZCand.mass < (91.1876 + 15))
        lead_pt_cut = z_cands.LeadLepton.pt > 25
        sublead_pt_cut = z_cands.SubleadLepton.pt > 20
        total_lepton_pt_cut = z_cands.ZCand.pt > 60

        # Also need cuts for the dilepton system vs the missing energy
        MET = events[['MET_covXY', 'MET_covXX', 'MET_covYY', 'MET_phi', 'MET_pt']]
        # Want to add fields that are just called phi and pt
        MET["pt"] = MET['MET_pt']
        MET['phi'] = MET['MET_phi']
        MET4V =  awkward.Array(MET, with_name="Momentum4D")
        deltaPhi_pTl_pTmiss_cut = MET4V.deltaphi(z_cands.ZCand) > 2.6
        # balance ratio cut 
        BR_cut = abs(MET.pt - z_cands.ZCand.pt) / z_cands.ZCand.pt < 0.4

        # Transverse mass cut
        MT_cut = np.sqrt(2 * z_cands.ZCand.pt * MET.pt * (1 - np.cos(MET4V.deltaphi(z_cands.ZCand)))) > 200

        z_cut = os_cut & mass_cut & lead_pt_cut & sublead_pt_cut & total_lepton_pt_cut & deltaPhi_pTl_pTmiss_cut & BR_cut & MT_cut
        z_cands = z_cands[z_cut] # OSSF lepton pairs with m_ll [M_Z - 15 , M_Z + 15]

        # Make event level cut
        has_z_cand = awkward.num(z_cands) >= 1 # veto any event that has no dilep system satisfying the above conditions
        ee_event = awkward.num(electrons_4V) >= 2
        mm_event = awkward.num(muons_4V) >= 2
        z_veto = (has_z_cand & (ee_event | mm_event))

        # ------------------------------------------------N Jet Cut------------------------------------------------------
        # Let's start with the 0 Jet category 
        # Find jets that have pt over 30GeV
        num_jets = self.options['jet_category']
        if num_jets == 0:
            n_jet_cut = jets.pt > 30
            # If any are True, then label that event as True, then take the opposite to get
            # events without any Jets with over 30 pt (opposite using the ~)
            n_jet_cut = ~(awkward.any(n_jet_cut, axis=1))
        else:
            # This is if it equals 1 jet instead
            jet4V = awkward.Array(jets, with_name="Momentum4D")
            print(jet4V.pt)
            # # Delta phi cut between the jet and the missing energy 
            # deltaPhi_pTjet_pTmiss_cut = jet4V.deltaphi(MET4V) > 0.5
            # print(deltaPhi_pTjet_pTmiss_cut)
            # jet_cuts = n_jet_cut & deltaPhi_pTjet_pTmiss_cut


        # ------------------------------------------------MET Cut------------------------------------------------------
        MET_cut = events.MET_pt > 80


        presel_cut =  n_leptons_cut & gen_weight_cut & z_veto & (electron_cut | muon_cut) & n_jet_cut & MET_cut
        self.register_cuts(
            names = ["n leptons cut", "gen weight cut", "all cuts"],
            results = [n_leptons_cut, gen_weight_cut, presel_cut]
        )

        return presel_cut, events 
