In [1]:
from coffea.nanoevents import NanoEventsFactory, NanoAODSchema
from coffea import hist, processor
import numpy as np
import awkward as ak
import matplotlib.pyplot as plt
from pprint import pprint
import numba

In [2]:
def show(a, ievt=0):
    """Pretty-print a single entry of an array-like.

    Parameters
    ----------
    a : indexable whose items expose ``.tolist()``
        The collection to inspect.
    ievt : int, optional
        Index of the entry to print (defaults to the first one).
    """
    selected = a[ievt]
    as_python = selected.tolist()
    pprint(as_python)

In [161]:
# Input ROOT file, read with the NanoAOD schema.
# NOTE(review): this is an absolute, site-specific storage path; the notebook only runs
# where that location is mounted. Consider making it configurable.
filename = "/pnfs/psi.ch/cms/trivcat/store/user/mmarcheg/RunIIFall17NanoAODv7/ttHTobb_M125_TuneCP5_13TeV-powheg-pythia8/587E2464-42CA-3A45-BD49-D23E49F658E6.root"
# Alternative without `entry_stop` (kept commented out):
#events = NanoEventsFactory.from_root(filename, schemaclass=NanoAODSchema).events()
# `entry_stop=10000` is passed to the factory; presumably this limits the read to the
# first 10000 entries to keep the notebook fast (TODO confirm against the coffea docs).
events = NanoEventsFactory.from_root(filename, schemaclass=NanoAODSchema, entry_stop=10000).events()

# Dataset label; the gen-level b-quark selection cell branches on the value 'ttHTobb'.
dataset = "ttHTobb"

In [4]:
# Event index referenced by the commented-out debug prints inside get_valid_pairs_masks.
ievt = 0

In [5]:
# Load the snakeviz IPython extension; the profiling cells further down use its %%snakeviz cell magic.
%load_ext snakeviz

In [163]:
# Start from the LHE-level particles: keep those with status == 1 and |pdgId| == 5 (b quarks).
# NOTE(review): status == 1 presumably marks outgoing LHE particles, as the mask name suggests; confirm.
isOutgoing = events.LHEPart.status == 1
isB = abs(events.LHEPart.pdgId) == 5
bquarks = events.LHEPart[isB & isOutgoing]

# Select b-quarks at Gen level, coming from H->bb decay
if dataset == 'ttHTobb':
    # Higgs candidates: pdgId 25, flagged 'fromHardProcess', with exactly two children.
    isHiggs = events.GenPart.pdgId == 25
    isHard = events.GenPart.hasFlags(['fromHardProcess'])
    hasTwoChildren = ak.num(events.GenPart.childrenIdxG, axis=2) == 2
    higgs = events.GenPart[isHiggs & isHard & hasTwoChildren]
    # Append the children of the selected Higgs bosons to the LHE b quarks, event by event.
    # NOTE(review): the children are taken without a pdgId check; presumably they are
    # always b quarks for this sample. Verify.
    bquarks = ak.concatenate( (bquarks, ak.flatten(higgs.children, axis=2)), axis=1 )
    # Sort b-quarks by pt (descending) and tag the records as 'PtEtaPhiMCandidate'.
    # NOTE(review): sorting and renaming only happen inside this branch; for any other
    # dataset label `bquarks` stays in its original LHE order.
    bquarks = ak.with_name(bquarks[ak.argsort(bquarks.pt, ascending=False)], name='PtEtaPhiMCandidate')


In [164]:
%%time
# Compute deltaR(b, jet) and save the nearest jet (deltaR matching)
# Maximum separation for a (b-quark, jet) pair to be kept as a match candidate.
dr_min = 0.4
# One value per (b-quark, jet) combination in each event, flattened to a single list per event.
# NOTE(review): presumably metric_table's default metric is deltaR, as the variable name implies; confirm.
deltaR = ak.flatten(bquarks.metric_table(events.Jet), axis=2)
# Keep only the pairs closer than dr_min.
maskDR = deltaR < dr_min
# The same boolean mask is applied to the argcartesian index pairs below; this relies on
# argcartesian enumerating the (b-quark, jet) combinations in the same order as the
# flattened metric table.
deltaRcut = deltaR[maskDR]
# Per-event permutation that orders the surviving pairs by deltaR.
idx_pairs_sorted = ak.argsort(deltaRcut, axis=1)
# (b-quark index, jet index) for every surviving pair.
pairs = ak.argcartesian([bquarks, events.Jet])[maskDR]
pairs_sorted = pairs[idx_pairs_sorted]
# Split the sorted pairs into parallel per-event arrays of quark indices and jet indices.
idx_quark, idx_jets = ak.unzip(pairs_sorted)
       

CPU times: user 348 ms, sys: 4.99 ms, total: 353 ms
Wall time: 351 ms


In [10]:
# deltaR values of the pairs that passed the dr_min cut, for the event at index 1.
show(deltaRcut,1)

[0.057874202728271484, 0.24081091582775116, 0.06480082869529724]


In [11]:
# (b-quark index, jet index) pairs that passed the cut for the same event, before sorting by deltaR.
show(pairs,1)

[(0, 0), (1, 1), (2, 5)]


In [12]:
# Permutation returned by ak.argsort on the deltaR values of those pairs, for the same event.
show(idx_pairs_sorted,1)

[0, 2, 1]


In [96]:
# deltaR-ordered (b-quark index, jet index) pairs for the first ten events.
pairs_sorted[0:10].tolist()

[[(1, 2), (2, 4), (3, 5), (0, 1)],
 [(0, 0), (2, 5), (1, 1)],
 [(1, 2), (0, 0), (3, 8), (2, 5)],
 [(2, 4), (0, 0), (3, 5), (1, 2)],
 [(0, 1), (1, 3), (2, 1)],
 [(2, 2), (1, 0), (3, 5), (0, 2), (0, 1), (2, 3)],
 [(1, 1), (0, 2), (2, 4), (3, 5)],
 [(1, 2), (0, 1), (2, 3), (3, 6)],
 [(0, 2), (2, 5), (1, 6)],
 [(2, 1), (0, 0), (3, 2), (1, 5)]]

In [165]:
def get_valid_pairs_masks(idx_bquarks, idx_Jet):
    """Reduce per-event (b-quark, jet) candidate pairs to one-to-one matches using array masks.

    Parameters
    ----------
    idx_bquarks, idx_Jet : jagged integer arrays
        Parallel per-event lists: position k of both arrays describes one candidate
        pair (quark index, jet index). In this notebook they are passed in the
        deltaR-sorted order produced by the pairing cell.

    Returns
    -------
    (idx_matchedParton, idx_matchedJet)
        Per-event quark indices and jet indices of the accepted matches; entries for
        which no jet was found are dropped.

    How it works
    ------------
    The loop walks over pair positions. At each position it takes the quark found there,
    looks at all not-yet-excluded pairs of that same quark, and accepts the first one.
    Afterwards every pair that uses the accepted jet, and every pair of that quark, is
    excluded via `hasMatch`.

    NOTE(review): the early-exit test reads the notebook globals `bquarks` and `events`,
    which are not parameters of this function.
    NOTE(review): `idx_matchedPair` is accumulated but never returned or used here.
    NOTE(review): if no event contains any pair, the loop body never runs and
    `idx_matchedParton` is unbound at the filtering step.
    NOTE(review): the result is not always identical to the sequential matcher
    (get_valid_pairs_nonumba). The comparison cell in this notebook prints events where
    they differ, e.g. for sorted pairs [(0, 1), (1, 0), (2, 1), (3, 4), (2, 4)] this
    function assigns jet 4 to quark 2, while the sequential matcher assigns it to quark 3.
    A quark that appears early in the list claims its best remaining jet before closer
    pairs of quarks that first appear later are considered.
    """
    # True marks pairs that can no longer be used (their jet or quark is already handled).
    hasMatch = ak.zeros_like(idx_Jet, dtype=bool)
    # Largest number of candidate pairs in any event; sets the number of iterations.
    Npairmax = ak.max(ak.num(idx_bquarks))
    # Loop over the (parton, jet) pairs
    for idx_pair in range(Npairmax):
        # Quark index at this pair position (None for events with fewer pairs).
        idx_bquark = ak.pad_none(idx_bquarks, Npairmax)[:,idx_pair]
        # Jet indices / pair positions of the still-available pairs that involve this quark.
        idx_match_candidates = idx_Jet[ak.fill_none( (idx_bquarks == idx_bquark) & ~hasMatch, False)]
        idx_pair_candidates  = ak.local_index(idx_Jet)[ak.fill_none( (idx_bquarks == idx_bquark) & ~hasMatch, False)]
        #print(idx_bquark[ievt], idx_match_candidates[ievt])
        #print(idx_bquark[ievt], idx_pair_candidates[ievt])
        if idx_pair == 0:
            # First iteration: start the per-event result lists with one (possibly None) entry.
            idx_matchedJet    = ak.unflatten( ak.firsts(idx_match_candidates), 1 )
            idx_matchedParton = ak.unflatten( idx_bquark, 1 )
            idx_matchedPair   = ak.unflatten( ak.firsts(idx_pair_candidates), 1 )
        else:
            # Stop iterating once, in every event, either all b-quarks have a matched jet
            # or there are fewer jets than b-quarks (uses the globals `bquarks` and `events`).
            if ak.all( ( (ak.count(idx_matchedJet, axis=1) == ak.count(bquarks.pt, axis=1)) | (ak.count(events.Jet.pt, axis=1) < ak.count(bquarks.pt, axis=1) ) ) ): break
            # Append this iteration's (possibly None) match to the per-event result lists.
            idx_matchedJet    = ak.concatenate( (idx_matchedJet, ak.unflatten( ak.firsts(idx_match_candidates), 1 ) ), axis=1 )
            idx_matchedParton = ak.concatenate( (idx_matchedParton, ak.unflatten( idx_bquark, 1 )), axis=1 )
            idx_matchedPair   = ak.concatenate( (idx_matchedPair, ak.unflatten( ak.firsts(idx_pair_candidates), 1 ) ), axis=1 )
        # Update `hasMatch`: exclude from later iterations every pair that uses the jet just
        # accepted, and every pair of the quark just processed. The -99 fill keeps events
        # without a candidate from matching any jet index.
        hasMatch = hasMatch | ak.fill_none(idx_Jet == ak.fill_none(ak.firsts(idx_match_candidates), -99), False) | ak.fill_none(idx_bquarks == idx_bquark, False)
        #print(idx_pair, hasMatch[ievt].tolist(), end='\n\n')

    # Drop the entries where no jet was found; the parton list is filtered with the same
    # mask (computed before idx_matchedJet itself is filtered) so both stay aligned.
    idx_matchedParton = idx_matchedParton[~ak.is_none(idx_matchedJet, axis=1)]
    idx_matchedJet = idx_matchedJet[~ak.is_none(idx_matchedJet, axis=1)]
    return idx_matchedParton, idx_matchedJet

In [166]:
def get_valid_pairs_nonumba(idx_quark, idx_jets, builder):
    """Greedy one-to-one (quark, jet) matching in plain Python.

    For every event the candidate pairs are visited in the order given, and a pair is
    accepted only if neither its quark nor its jet was already used by an earlier
    accepted pair. With deltaR-sorted input this keeps the closest available jet for
    each quark.

    Parameters
    ----------
    idx_quark, idx_jets : iterable of iterables of int
        Parallel per-event sequences; position k of both describes one candidate pair.
    builder : ArrayBuilder-like
        Receives one list per event, holding one "pair" record with the fields
        "quark" and "jet" per accepted pair.

    Returns
    -------
    The same `builder` object that was passed in.
    """
    for ev_q, ev_j in zip(idx_quark, idx_jets):
        builder.begin_list()
        # Indices already consumed in this event; sets give constant-time membership tests.
        used_quarks = set()
        used_jets = set()
        for q, j in zip(ev_q, ev_j):
            if q in used_quarks or j in used_jets:
                continue
            builder.begin_record("pair")
            builder.field("quark").append(q)
            builder.field("jet").append(j)
            builder.end_record()
            used_quarks.add(q)
            used_jets.add(j)
        builder.end_list()
    return builder

In [167]:
@numba.jit
def get_valid_pairs_numba(idx_quark, idx_jets, builder):
    """Greedy one-to-one (quark, jet) matching, compiled with numba.

    Same algorithm as get_valid_pairs_nonumba: per event, candidate pairs are visited in
    the order given and a pair is accepted only if neither its quark nor its jet was
    used by an earlier accepted pair.

    Parameters
    ----------
    idx_quark, idx_jets : jagged integer arrays
        Parallel per-event sequences; position k of both describes one candidate pair.
    builder : ak.ArrayBuilder
        Receives one list per event, holding one "pair" record with the fields
        "quark" and "jet" per accepted pair.

    Returns
    -------
    The `builder` that was passed in.
    """
    for ev_q, ev_j in zip(idx_quark, idx_jets):
        builder.begin_list()
        # Indices already consumed in this event. Plain lists are kept on purpose so the
        # container types seen by numba's type inference are unchanged.
        q_done = []
        j_done = []
        for q,j in zip(ev_q, ev_j):
            if q not in q_done and j not in j_done:
                builder.begin_record("pair")
                builder.field("quark").append(q)
                builder.field("jet").append(j)
                builder.end_record()
                q_done.append(q)
                j_done.append(j)
        builder.end_list()
    return builder

In [176]:
%%time
# Mask-based matching; results1 is the tuple (matched quark indices, matched jet indices).
results1 = get_valid_pairs_masks(idx_quark, idx_jets)

CPU times: user 488 ms, sys: 13.1 ms, total: 501 ms
Wall time: 498 ms


In [174]:
%%time
# Sequential matching in plain Python; results2 is the ArrayBuilder returned by the function.
results2 = get_valid_pairs_nonumba(idx_quark, idx_jets, ak.ArrayBuilder())

CPU times: user 1.17 s, sys: 3.91 ms, total: 1.17 s
Wall time: 1.17 s


In [177]:
%%time
# Same sequential matching through the numba-decorated function; results3 is the returned ArrayBuilder.
results3 = get_valid_pairs_numba(idx_quark, idx_jets, ak.ArrayBuilder())

CPU times: user 43.2 ms, sys: 62 µs, total: 43.3 ms
Wall time: 42.5 ms


In [159]:
# Cross-check: the plain-Python matcher (results2) and the numba matcher (results3)
# must give the same pairs in every event. Nothing is printed when they agree.
qq, jj = results1  # not used in this cell; kept so these names stay defined as before
for i, (r, q) in enumerate(zip(results2,  results3)):
    # zip() stops at the shorter sequence, so a different number of matched pairs would
    # go unnoticed in the element-wise loop below; check it explicitly.
    if len(r) != len(q):
        print(i , "Mismatch")
        continue
    for R,R2 in zip(r,q):
        if R.quark !=R2.quark or R.jet != R2.jet:
            print(i , "Mismatch")

In [181]:
# Cross-check: sequential matcher (results2) against the mask-based matcher (results1).
# Events where the two disagree are printed together with their deltaR-sorted input pairs.
qq, jj = results1
for i, (r, q, j) in enumerate(zip(results2, qq, jj)):
    # zip() truncates to the shortest input, so a different number of matches would not
    # be seen by the element-wise comparison; treat it as a mismatch up front.
    mismatch = not (len(r) == len(q) == len(j))
    for R,Q,J in zip(r,q,j):
        if R.quark !=Q or R.jet != J:
            mismatch = True
    if mismatch:
        # "Nquark"/"Njets" in the message are the numbers of matched quarks/jets from results1.
        print(f"ev:{i}, Nquark: {len(q)}, Njets: {len(j)}, match simple: {r}, match complex: {q},{j}")
        print(f"\t original sorted pairs: {pairs_sorted[i]}")

ev:2340, Nquark: 3, Njets: 3, match simple: [{quark: 0, jet: 1}, {quark: 1, jet: 0}, {quark: 3, jet: 4}], match complex: [0, 1, 2],[1, 0, 4]
	 original sorted pairs: [(0, 1), (1, 0), (2, 1), (3, 4), (2, 4)]
ev:2854, Nquark: 3, Njets: 3, match simple: [{quark: 0, jet: 0}, {quark: 1, jet: 2}, {quark: 2, jet: 6}], match complex: [0, 2, 1],[0, 6, 2]
	 original sorted pairs: [(0, 0), (2, 0), (1, 2), (2, 6)]
ev:4295, Nquark: 3, Njets: 3, match simple: [{quark: 0, jet: 1}, {quark: 2, jet: 5}, {quark: 3, jet: 4}], match complex: [0, 2, 1],[1, 5, 4]
	 original sorted pairs: [(0, 1), (2, 5), (1, 5), (3, 4), (1, 4)]
ev:4357, Nquark: 4, Njets: 4, match simple: [{quark: 1, jet: 0}, {quark: 2, jet: 7}, {quark: 3, jet: 6}, {quark: 0, jet: 11}], match complex: [1, 0, 2, 3],[0, 11, 7, 6]
	 original sorted pairs: [(1, 0), (0, 0), (2, 7), (3, 6), (0, 11), (1, 16)]
ev:5106, Nquark: 4, Njets: 4, match simple: [{quark: 0, jet: 1}, {quark: 2, jet: 0}, {quark: 3, jet: 2}, {quark: 1, jet: 5}], match complex: [

### Check profiling

In [183]:
%%snakeviz 
# Profile the mask-based matcher; note that this re-runs it and overwrites results1.
results1 = get_valid_pairs_masks(idx_quark, idx_jets)

 
*** Profile stats marshalled to file '/tmp/tmp1q8k9o1s'. 
Embedding SnakeViz in this document...


In [182]:
%%snakeviz 
# Profile the numba matcher; note that this re-runs it and overwrites results3.
results3 = get_valid_pairs_numba(idx_quark, idx_jets, ak.ArrayBuilder())

 
*** Profile stats marshalled to file '/tmp/tmpif7oopc_'. 
Embedding SnakeViz in this document...


In [16]:
    
# NOTE(review): this cell looks like the tail of a function pasted into the notebook and
# cannot run as written:
#   * it ends with a top-level `return`, which is a syntax error outside a function;
#   * `evprint` is not defined anywhere in the visible notebook (possibly meant `events`; confirm);
#   * `idx_matchedPair`, `idx_matchedJet` and `idx_matchedParton` only exist as locals of
#     get_valid_pairs_masks, which returns the last two but not `idx_matchedPair`;
#   * `idx_pairs_sorted` was computed from `deltaRcut`, yet it indexes the uncut `deltaR` here.
# deltaR values reordered by idx_pairs_sorted, keeping positions where idx_matchedPair is not None.
dr_matchedJet = deltaR[idx_pairs_sorted][~ak.is_none(idx_matchedPair, axis=1)]
#print("idx_matchedPair", idx_matchedPair)
# Drop the None entries from the matched-pair positions.
idx_matchedPair = idx_matchedPair[~ak.is_none(idx_matchedPair, axis=1)]
# Pick the matched jets and b-quarks through their per-event indices.
matchedJet    = evprint.Jet[idx_matchedJet]
matchedParton = bquarks[idx_matchedParton]
#print("matchedJet", matchedJet)
# True for events in which the number of matched partons equals the number of b-quarks.
hasMatchedPartons = ak.count(idx_matchedParton, axis=1) == ak.count(bquarks.pt, axis=1)
#print(hasMatchedPartons)
#for cut in self._selections.keys():
#    print(events.metadata["dataset"], cut, "matched partons =", round(100*ak.sum(hasMatchedPartons[self._cuts.all(*self._selections[cut])])/ak.size(hasMatchedPartons[self._cuts.all(*self._selections[cut])]), 2), "%")
# Attach the collections to the events; the matched b-quarks also get their deltaR to the
# matched jet as the field "drMatchedJet".
events["BQuark"] = bquarks
events["JetMatched"] = matchedJet
events["BQuarkMatched"] = matchedParton
events["BQuarkMatched"] = ak.with_field(events.BQuarkMatched, dr_matchedJet, "drMatchedJet")
#print("deltaR", deltaR)
return bquarks, idx_matchedJet, idx_matchedParton, idx_matchedPair, dr_matchedJet