In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import pprint
import matplotlib.pyplot as plt
import sys
import pickle
import argparse
import logging
import os

# from ginkgo import invMass_ginkgo
# from invMass_ginkgo import *
from invMass_ginkgo_node import *
from ginkgo.utils import get_logger


# Notebook-wide logger, requested at WARNING level from ginkgo's get_logger helper.
logger = get_logger(level = logging.WARNING)
# Additionally route WARNING-and-above records to the file 'spam.log'
# (relative path, so it lands in the kernel's working directory).
# NOTE(review): every re-run of this cell builds a new FileHandler; if get_logger
# hands back the same logger object each time, records would be written to
# spam.log once per accumulated handler -- confirm.
fh = logging.FileHandler('spam.log')
fh.setLevel(logging.WARNING)
logger.addHandler(fh)


"""
create node class
understand original output
simplify recursion logic further

define typology of a tree
access data on each node

call algo to generate tree
have a fully defined tree with 4D vectors for each node

generate likelihood for each tree 
compute eq (12)
eq(14) simplifies to eq(8), apply 8 repeatedly when matching l and r

"""

  from .autonotebook import tqdm as notebook_tqdm


'\ncreate node class\nunderstand original output\nsimplify recursion logic further\n\ndefine typology of a tree\naccess data on each node\n\ncall algo to generate tree\nhave a fully defined tree with 4D vectors for each node\n\ngenerate likelihood for each tree \ncompute eq (12)\neq(14) simplifies to eq(8), apply 8 repeatedly when matching l and r\n\n'

In [3]:
"""Parameters"""


# NOTE(review): none of the names assigned in this group (rate2, pt_min, W_rate,
# QCD_rate, QCD_mass) are referenced by the cells visible in this notebook --
# presumably alternative settings kept for reference; confirm before removing.
rate2 = torch.tensor(8.)

# Parameters to get ~<10 constituents to test the trellis algorithm
# (stored as a squared value: 4.**2).
pt_min = torch.tensor(4.**2)

### Physics inspired parameters to get ~ between 20 and 50 constituents
W_rate = 3.
QCD_rate = 1.5

QCD_mass = 30.
class ginkgo_simulator():
    """Configure and run a Ginkgo ``Simulator`` for a single jet type.

    The constructor turns the scalar inputs into the quantities the
    ``Simulator`` is called with: the jet 4-vector, the starting mass and a
    two-entry rate tensor ``[root rate, rate for the rest of the tree]``.

    Parameters
    ----------
    rate : float
        Decay rate used for the splittings. For "QCD" jets it is also used
        as the root rate.
    pt_cut : float or 0-d tensor
        Cut-off forwarded to ``Simulator`` as ``float(pt_cut)``.
    M2start : float
        Mass squared to start the shower with.
    Nsamples : int
        Number of jets requested from the simulator.
    minLeaves, maxLeaves : int
        Range of leaf counts forwarded to the simulator.
    maxNTry : int
        Maximum number of tries forwarded to the simulator.
    jetType : str
        Either "W" or "QCD".
    jetP : float
        Magnitude of the jet 3-momentum.
    root_rate : float, optional
        Rate used for the root splitting of "W" jets (ignored for "QCD").

    Raises
    ------
    ValueError
        If ``jetType`` is neither "W" nor "QCD".
    """

    def __init__(self,
                 rate,
                 pt_cut,
                 M2start,
                 Nsamples,
                 minLeaves,
                 maxLeaves,
                 maxNTry,
                 jetType,
                 jetP,
                 root_rate= 1.5,
                ):

        # Validate first so an invalid jet type fails before any state is
        # computed or logged.
        if jetType not in ("W", "QCD"):
            raise ValueError("Choose a valid jet type between W or QCD")

        self.root_rate = root_rate
        self.pt_cut = pt_cut
        self.M2start = torch.tensor(M2start) # mass squared to start with
        self.Nsamples = Nsamples
        self.minLeaves = minLeaves
        self.maxLeaves = maxLeaves
        self.maxNTry = maxNTry
        self.jetType = jetType # W or QCD
        self.jetM = np.sqrt(M2start) # mass to start with

        # The jet points along (1, 1, 1); its 3-momentum has magnitude jetP.
        self.jetdir = np.array([1,1,1])
        self.jetP = jetP
        self.jetvec = self.jetP * self.jetdir / np.linalg.norm(self.jetdir)
        # 4-vector laid out as [E, px, py, pz] with E = sqrt(|p|^2 + M^2).
        self.jet4vec = np.concatenate(([np.sqrt(self.jetP ** 2 + self.jetM ** 2)], self.jetvec))
        logger.debug(f"jet4vec = {self.jet4vec}")

        if jetType == "W":
            # defined in paper, W jets have a different root rate
            self.rate = torch.tensor([self.root_rate, rate])
        else:
            # QCD jets maintain the same rate throughout
            self.rate = torch.tensor([rate, rate])

    def simulator(self):
        """Build and return a ``Simulator`` configured from this instance."""
        return Simulator(jet_p = self.jet4vec,
                         pt_cut = float(self.pt_cut),
                         Delta_0 = self.M2start,
                         M_hard = self.jetM,
                         num_samples = int(self.Nsamples),
                         minLeaves = int(self.minLeaves),
                         maxLeaves = int(self.maxLeaves),
                         maxNTry = int(self.maxNTry)
                         )

    def generate(self):
        """Run a freshly built simulator on ``self.rate`` and return its result."""
        simulator = self.simulator()
        jet_list = simulator(self.rate)

        logger.debug("---"*10)
        logger.debug(f"jet_list = {jet_list}")

        return jet_list

In [4]:
# Inputs for the ginkgo_simulator calls in the following cells.

# the range of leaves that you would consider as valid generations
# (forwarded to the Simulator as minLeaves / maxLeaves)
minLeaves = 3
maxLeaves = 100

# number of jets you wish to generate
Nsamples = 1

# exponential rate parameter
# (used for every splitting of a QCD jet; for a W jet the root uses root_rate instead)
rate = 1.5

# mass squared cut off to yield leaves
# (despite the name "pt_cut", the value is a squared mass: 1.1**2)
pt_cut =  torch.tensor(1.1**2)

# mass squared to start with
# larger alternative, currently disabled:
# M2start = 80.**2
M2start = 5.**2

# the maximum times you are willing to try to get Nsamples
maxNTry = 1

# magnitude of the jet 3-momentum
# larger alternative, currently disabled:
# jetP=400.
jetP=4.

In [5]:
# Generate QCD jets: the same `rate` is used for the root and for every other splitting.
jetType = "QCD"

# Keyword arguments keep the nine inputs from being mixed up positionally.
ginkgo = ginkgo_simulator(
    rate=rate,
    pt_cut=pt_cut,
    M2start=M2start,
    Nsamples=Nsamples,
    minLeaves=minLeaves,
    maxLeaves=maxLeaves,
    maxNTry=maxNTry,
    jetType=jetType,
    jetP=jetP,
)

QCD_jets = ginkgo.generate()

Node 0
 Vec4: [6.40312424 2.30940108 2.30940108 2.30940108]
 Decay Rate: 1
 Mass Squared: tensor(25.)
 Log Likelihood: -7.6007909530882225
 DIJ List: -7.6007909530882225 0.2078267172858753 2.166555308251764 3.4061457459831144

Node 1
 Vec4: [3.76697119 2.97761797 1.24844196 0.93091431]
 Decay Rate: tensor(0.2949)
 Mass Squared: tensor(2.8987)
 Log Likelihood: -3.4526740606213293
 DIJ List: -3.4526740606213293 0.01703905494276906 0.05902995735272885 0.11159882207418811

Node 2
 Vec4: [ 2.63615305 -0.66821689  1.06095912  1.37848676]
 Decay Rate: tensor(0.1391)
 Mass Squared: tensor(3.4769)
 Log Likelihood: -3.8926587398434362
 DIJ List: -3.8926587398434362 0.10216770637181595 0.0441424535803581 0.016247935756071325

Node 3
 Vec4: [1.93904408 1.75569092 0.61801368 0.35853349]
 Decay Rate: tensor(0.4041)
 Mass Squared: tensor(0.1670)
 Log Likelihood: 0
 DIJ List:

Node 4
 Vec4: [1.82792711 1.22192705 0.63042828 0.57238082]
 Decay Rate: tensor(0.3875)
 Mass Squared: tensor(1.1232)
 Log Lik

In [6]:
# Generate W jets: the root splitting uses `root_rate`, the rest of the tree uses `rate`.
root_rate = 4.
jetType = "W"

# Keyword arguments keep the ten inputs from being mixed up positionally.
ginkgo = ginkgo_simulator(
    rate=rate,
    pt_cut=pt_cut,
    M2start=M2start,
    Nsamples=Nsamples,
    minLeaves=minLeaves,
    maxLeaves=maxLeaves,
    maxNTry=maxNTry,
    jetType=jetType,
    jetP=jetP,
    root_rate=root_rate,
)

W_jets = ginkgo.generate()

Node 0
 Vec4: [6.40312424 2.30940108 2.30940108 2.30940108]
 Decay Rate: 1
 Mass Squared: tensor(25.)
 Log Likelihood: -7.498405058290433
 DIJ List: -7.498405058290433 0.23126889432394307 3.061664245193524 3.2475055451929404

Node 1
 Vec4: [ 2.5820282   0.54604914 -0.87322951  2.19896425]
 Decay Rate: tensor(0.0523)
 Mass Squared: tensor(0.7707)
 Log Likelihood: 0
 DIJ List:

Node 2
 Vec4: [3.82109604 1.76335194 3.18263059 0.11043683]
 Decay Rate: tensor(0.0540)
 Mass Squared: tensor(1.3500)
 Log Likelihood: -2.6358070667172155
 DIJ List: -2.6358070667172155 0.05740117478321504 0.5949560628099919 0.13327289778082477

Node 3
 Vec4: [0.5606329  0.38726256 0.2720888  0.28376467]
 Decay Rate: tensor(0.0214)
 Mass Squared: tensor(0.0098)
 Log Likelihood: 0
 DIJ List:

Node 4
 Vec4: [ 3.26046324  1.37608943  2.91054187 -0.17332784]
 Decay Rate: tensor(0.1746)
 Mass Squared: tensor(0.2357)
 Log Likelihood: 0
 DIJ List:

  ┌1
 0┤
  │ ┌3
  └2┤
    └4


In [13]:
def llh(pL, pR, t_cut, lam):
    """
    Take two nodes and return the splitting log likelihood

    Parameters
    ----------
    pL, pR : array-like of length 4
        4-vectors of the left and right child, indexed as [E, px, py, pz]
        (entry 0 squared minus the squared norm of entries 1.. gives the
        invariant mass squared).
    t_cut : float
        Mass-squared threshold; a child with t <= t_cut is treated as a leaf.
    lam : float
        Decay rate of the parent used in the splitting likelihood.

    Returns
    -------
    (logLH, tp) : tuple
        ``logLH`` is the log likelihood of the splitting (``-np.inf`` when the
        pairing is not allowed) and ``tp`` is the parent invariant mass
        squared. Every path returns this 2-tuple so callers can always unpack
        the result.
    """
    # Imported here so the function does not rely on a wildcard import
    # happening to provide `logsumexp`.
    from scipy.special import logsumexp

    tL = pL[0] ** 2 - np.linalg.norm(pL[1::]) ** 2
    tR = pR[0] ** 2 - np.linalg.norm(pR[1::]) ** 2

    pP = pR + pL ## eq (5)

    # Parent invariant mass squared
    tp = pP[0] ** 2 - np.linalg.norm(pP[1::]) ** 2

    if tp<=0 or tL<0 or tR<0:
        # Unphysical masses: return the same (logLH, tp) shape as every other
        # path instead of a bare scalar.
        return - np.inf, tp

    # We add a normalization factor -np.log(1 - np.exp(- lam))
    # because we need the mass squared to be strictly decreasing.
    # This way the likelihood integrates to 1 for 0<t<t_p.
    # All leaves should have t=0, this is a convention we are
    # taking (instead of keeping their value for t given that
    # it is below the threshold t_cut)
    def get_logp(tP_local, t, t_cut, lam):
        if t > t_cut:
            # Probability of the shower to stop F_s
            return -np.log(1 - np.exp(- (1. - 1e-3)*lam)) + np.log(lam) - np.log(tP_local) - lam * t / tP_local

        else: # For leaves we have t<t_cut
            t_upper = min(tP_local,t_cut) #There are cases where tp2 < t_cut
            log_F_s = -np.log(1 - np.exp(- (1. - 1e-3)*lam)) + np.log(1 - np.exp(-lam * t_upper / tP_local))
            return log_F_s

    if tp <= t_cut:
        #If the pairing is not allowed
        logLH = - np.inf

    elif tL >=(1 - 1e-3)* tp or tR >=(1 - 1e-3)* tp:
        # The pairing is not allowed because tL or tR are greater than tP
        logLH = - np.inf

    elif np.sqrt(tL) + np.sqrt(tR) > np.sqrt(tp):
        print("Breaking invariant mass inequality condition")
        logLH = - np.inf

    else:
        # We sample a unit vector uniformly over the 2-sphere, so the angular likelihood is 1/(4*pi)

        tpLR = (np.sqrt(tp) - np.sqrt(tL)) ** 2
        tpRL = (np.sqrt(tp) - np.sqrt(tR)) ** 2

        logpLR = np.log(1/2)+ get_logp(tp, tL, t_cut, lam) + get_logp(tpLR, tR, t_cut, lam) #First sample tL
        logpRL = np.log(1/2)+ get_logp(tp, tR, t_cut, lam) + get_logp(tpRL, tL, t_cut, lam) #First sample tR

        logp_split = logsumexp(np.asarray([logpLR, logpRL]))

        logLH = (logp_split + np.log(1 / (4 * np.pi)) ) ## eq (8)

    return logLH, tp

In [8]:
import pickle

def leafToDict(leaf):
    """Return a plain dict with the node's ``vec4`` and ``delta`` attributes."""
    return {"vec4": leaf.vec4, "delta": leaf.delta}
def nodeListToDictList(nodeList):
    """Convert every node in ``nodeList`` with ``leafToDict``, preserving order."""
    return [leafToDict(node) for node in nodeList]
def pickleDictList(dictList, path="data.p"):
    """Pickle ``dictList`` to ``path``.

    Parameters
    ----------
    dictList : list
        Object to serialise (in this notebook, a list of node dicts).
    path : str, optional
        Destination file; defaults to "data.p" so existing one-argument
        calls keep writing to the same place.
    """
    with open(path, "wb") as f:
        pickle.dump(dictList, f)


In [9]:
# Round-trip check: dump the generated QCD nodes to data.p, then read them back.
pickleDictList(nodeListToDictList(QCD_jets))

# NOTE: pickle.load can execute arbitrary code from the file it reads. That is
# acceptable here only because data.p was written by the line above; do not
# point this at files from an untrusted source.
with open('data.p', 'rb') as f:
    output = pickle.load(f)
    print(output)

[{'vec4': array([6.40312424, 2.30940108, 2.30940108, 2.30940108]), 'delta': tensor(25.)}, {'vec4': array([3.76697119, 2.97761797, 1.24844196, 0.93091431]), 'delta': tensor(2.8987)}, {'vec4': array([ 2.63615305, -0.66821689,  1.06095912,  1.37848676]), 'delta': tensor(3.4769)}, {'vec4': array([1.93904408, 1.75569092, 0.61801368, 0.35853349]), 'delta': tensor(0.1670)}, {'vec4': array([1.82792711, 1.22192705, 0.63042828, 0.57238082]), 'delta': tensor(1.1232)}, {'vec4': array([ 1.49480409, -0.41529857,  0.5094957 ,  0.63248879]), 'delta': tensor(1.4023)}, {'vec4': array([ 1.14134896, -0.25291833,  0.55146342,  0.74599797]), 'delta': tensor(0.3781)}, {'vec4': array([ 1.12558923, -0.36392265,  0.23320151,  0.6380379 ]), 'delta': tensor(0.6730)}, {'vec4': array([ 0.3692149 , -0.05137592,  0.27629421, -0.00554909]), 'delta': tensor(0.0573)}]


In [15]:
from node import *
import numpy as np

# finish function
# test function against actual values
# modify code to output 

"""


Why is the tree left-leaning?

3 cases:

2 leaves
1 leaf, 1 internal node
2 internal nodes


TensorFlow version (1.15 or similar)



"""
# Mass-squared cut-off reused from the parameter cell (pt_cut); passed to llh as t_cut.
cut_off = pt_cut

# Checking leaves (disabled): print the sorted indices of the leaf nodes
# leaves = [(index, i) for index, i in enumerate(QCD_jets) if i.left is None and i.right is None]
# print(*sorted([i[0] for i in leaves]))

# Getting Leaves: nodes with neither a left nor a right child.
# NOTE(review): `leaves` is not used by the code visible below -- confirm it is still needed.
leaves = [i for i in QCD_jets if i.left is None and i.right is None]


# Emits a single blank line in the cell output.
print()

def testRecontruct(parent, left, right, cut_off):
    """Rebuild a parent node from its two children and return the rebuilt node.

    Only ``parent.decay_rate`` is read from ``parent``; the 4-vector is the
    sum of the children's 4-vectors, and the mass squared and log likelihood
    come from ``llh`` evaluated on the children with ``cut_off``.
    Comparing the fields (vec4, decay_rate, delta, logLH) of ``parent`` with
    those of the returned node shows whether the reconstruction agrees with
    the generated tree.
    """
    decay_rate = parent.decay_rate
    summed_vec4 = left.vec4 + right.vec4
    split_logLH, split_delta = llh(left.vec4, right.vec4, cut_off, decay_rate)

    return jetNode(vec4 = summed_vec4,
                   left = left,
                   right = right,
                   decay_rate = decay_rate, # template value
                   delta = split_delta,
                   logLH = split_logLH
                   )

# testRecontruct(QCD_jets[0], QCD_jets[1], QCD_jets[2], cut_off)

def rec_test(node, cut_off):
    """Run ``testRecontruct`` on ``node`` and, recursively, on its descendants.

    Nodes lacking a right or a left child are skipped; nothing is returned.
    """
    if not (node.right and node.left):
        return
    testRecontruct(node, node.left, node.right, cut_off)
    for child in (node.left, node.right):
        rec_test(child, cut_off)


# Walk the generated QCD tree starting from QCD_jets[0] (presumably the root
# node -- confirm) and re-derive every internal node from its children.
rec_test(QCD_jets[0], cut_off)




    


[6.40312424 2.30940108 2.30940108 2.30940108] [6.40312424 2.30940108 2.30940108 2.30940108]
1 1
tensor(25.) 24.999999999999993
-7.6007909530882225 -7.6007909530882225

[3.76697119 2.97761797 1.24844196 0.93091431] [3.76697119 2.97761797 1.24844196 0.93091431]
tensor(0.2949) tensor(0.2949)
tensor(2.8987) 2.8986543578935713
-3.4526740606213293 -3.4526740606213293

[ 2.63615305 -0.66821689  1.06095912  1.37848676] [ 2.63615305 -0.66821689  1.06095912  1.37848676]
tensor(0.1391) tensor(0.1391)
tensor(3.4769) 3.4769290847612386
-3.8926587398434362 -3.8926587398434362

[ 1.49480409 -0.41529857  0.5094957   0.63248879] [ 1.49480412 -0.41529858  0.50949572  0.63248881]
tensor(0.4033) tensor(0.4033)
tensor(1.4023) 1.4023384831878412
-2.650988727369376 -2.650988727369376

