# Preliminaries

In [1]:
import pyprob
import numpy as np
import ot
import torch
import cProfile

from pyprob.dis import ModelDIS
from showerSim import invMass_ginkgo
from torch.utils.data import DataLoader
from pyprob.nn.dataset import OnlineDataset
from pyprob.util import InferenceEngine
from pyprob.util import to_tensor
from pyprob import Model
from pyprob.model import Parallel_Generator
import math
from pyprob.distributions import Normal
from pyprob.distributions.delta import Delta


import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as mpl_cm
plt.ion()

import sklearn as skl
from sklearn.linear_model import LinearRegression

from geomloss import SamplesLoss
# geomloss SamplesLoss configured as a Sinkhorn loss with p=1 and blur=0.05;
# shared by sinkhorn_t below.
sinkhorn = SamplesLoss(loss="sinkhorn", p=1, blur=.05)
def sinkhorn_t(x, y):
    """Evaluate the module-level ``sinkhorn`` loss between ``x`` and ``y``.

    ``x`` is converted with ``to_tensor``; ``y`` is combined with
    ``torch.stack``, so it must be a sequence of equally shaped tensors.
    """
    # NOTE(review): presumably x is the observed set of leaves and y the
    # simulated one -- confirm against the caller.
    return sinkhorn(to_tensor(x), torch.stack(y))

def ot_dist(x, y):
    """Return ``ot.emd2`` between ``x`` and ``y`` under uniform weights.

    ``x`` is converted with ``np.array``; ``y`` is first combined with
    ``torch.stack`` (so it must be a sequence of tensors) and then converted
    to a NumPy array. The cost matrix is the pairwise Euclidean distance
    computed by ``ot.dist``. The result is wrapped with ``to_tensor``.
    """
    source_points = np.array(x)
    target_points = np.array(torch.stack(y))
    # Every point carries equal weight on both sides.
    source_weights = ot.unif(len(source_points))
    target_weights = ot.unif(len(target_points))
    cost_matrix = ot.dist(source_points, target_points, metric='euclidean')
    return to_tensor(ot.emd2(source_weights, target_weights, cost_matrix))

def leaf_dist(x, y):
    """Return the absolute difference between the lengths of ``x`` and ``y``.

    Args:
        x: Any sized collection (must support ``len``).
        y: Any sized collection (must support ``len``).

    Returns:
        A 0-dim floating-point ``torch.Tensor`` holding ``|len(x) - len(y)|``.
    """
    # len() yields plain Python ints, and torch.abs only accepts tensors
    # (it raises TypeError on an int), so take the absolute value in Python
    # and wrap the result as a tensor afterwards.
    return torch.tensor(float(abs(len(x) - len(y))))


# Device name handed to pyprob through set_device below.
device = "cpu"

from pyprob.util import set_device
set_device(device)

# Define Ginkgo Simulators

In [2]:
# Observed leaves passed to the first simulator instance: 35 rows of 4 values.
# NOTE(review): presumably each row is a leaf 4-momentum (E, px, py, pz) --
# confirm against the Ginkgo simulator.
obs_leaves = to_tensor([[44.57652381, 26.16169856, 25.3945314 , 25.64598258],
                           [18.2146321 , 10.70465096, 10.43553391, 10.40449709],
                           [ 6.47106713,  4.0435395,  3.65545951,  3.48697568],
                           [ 8.43764314,  5.51040615,  4.60990593,  4.42270416],
                           [26.61664145, 16.55894826, 14.3357362 , 15.12215264],
                           [ 8.62925002,  3.37121204,  5.19699   ,  6.00480461],
                           [ 1.64291837,  0.74506775,  1.01003622,  1.05626017],
                           [ 0.75525072,  0.3051808 ,  0.45721085,  0.51760643],
                           [39.5749915 , 18.39638928, 24.24717939, 25.29349408],
                           [ 4.18355659,  2.11145474,  2.82071304,  2.25221316],
                           [ 0.82932922,  0.29842766,  0.5799056 ,  0.509021  ],
                           [ 3.00825023,  1.36339397,  1.99203677,  1.79428211],
                           [ 7.20024308,  4.03280868,  3.82379277,  4.57441754],
                           [ 2.09953618,  1.28473579,  1.03554351,  1.29769683],
                           [12.21401828,  6.76059035,  6.94920042,  7.42823701],
                           [ 6.91438054,  3.68417135,  3.83782514,  4.41656731],
                           [ 1.97218904,  1.01632927,  1.08008339,  1.27454585],
                           [ 8.58164301,  5.06157833,  4.79691164,  4.99553141],
                           [ 5.97809522,  3.26557958,  3.4253764 ,  3.64894791],
                           [ 5.22842301,  2.94437891,  3.10292633,  3.00551074],
                           [15.40023764,  9.10884407,  8.93836964,  8.61970667],
                           [ 1.96101346,  1.24996337,  1.06923988,  1.06743143],
                           [19.81054106, 11.90268453, 11.60989346, 10.76953856],
                           [18.79470876, 11.429855  , 10.8377334 , 10.25112761],
                           [25.74331932, 15.63430056, 14.83860792, 14.07189108],
                           [ 9.98357576,  6.10090721,  5.68664128,  5.48748692],
                           [12.34604239,  7.78770185,  6.76075998,  6.78498685],
                           [21.24998531, 12.95180254, 11.9511704 , 11.87319933],
                           [ 7.80693733,  4.83117128,  4.27443559,  4.39602348],
                           [16.28983576,  9.66683929,  9.24891886,  9.28970032],
                           [ 2.50706736,  1.53153206,  1.36060018,  1.43002765],
                           [ 3.73938645,  2.06006639,  2.31013974,  2.09378969],
                           [20.2174725 , 11.88622367, 12.05106468, 11.05325362],
                           [ 9.48660008,  5.53665456,  5.54171966,  5.34966654],
                           [ 2.65812987,  1.64102742,  1.67392209,  1.25083707]])


# Jet kinematics shared by every simulator instance below.
QCD_mass = to_tensor(30.)
#rate=to_tensor([QCD_rate,QCD_rate]) #Entries: [root node, every other node] decaying rates. Choose same values for a QCD jet
jetdir = to_tensor([1.,1.,1.])
jetP = to_tensor(400.)
# Scale the unit vector along jetdir by the momentum magnitude jetP.
jetvec = jetP * jetdir / torch.linalg.norm(jetdir) ## Jetvec is 3-momentum. JetP is relativistic p.


# Actual parameters
pt_min = to_tensor(0.3**2)
M2start = to_tensor(QCD_mass**2)
jetM = torch.sqrt(M2start) ## Mass of initial jet
# 4-vector laid out as (sqrt(jetP^2 + jetM^2), jetvec): energy first, then the 3-momentum.
jet4vec = torch.cat((torch.sqrt(jetP**2 + jetM**2).reshape(-1), jetvec))
minLeaves = 1
maxLeaves = 10000 # unachievable, to prevent rejections
maxNTry = 100



class SimulatorModelDIS(invMass_ginkgo.SimulatorModel, ModelDIS):
    """Ginkgo simulator combined with pyprob's ``ModelDIS``.

    ``forward`` draws the two decay rates from uniform priors, runs the
    underlying Ginkgo simulator with them, and registers a placeholder
    Bernoulli observation named ``"dummy"``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def dummy_bernoulli(self, jet):
        """Return ``True`` regardless of ``jet`` (placeholder)."""
        return True

    def forward(self, inputs=None):
        """Sample both decay rates, run the simulator and return the jet.

        Args:
            inputs: Must be ``None``; the rates are always sampled here.

        Returns:
            Whatever the parent simulator's ``forward`` returns for the
            sampled ``[root_rate, decay_rate]`` pair.
        """
        assert inputs is None # Modify code if this ever not met?
        # Sample the parameters of interest from Unif(0.01, 10) priors.
        # Each sample gets its own name: previously both shared
        # "decay_rate_parameter", so the root rate could not be looked up by
        # name. "decay_rate_parameter" still refers to the non-root rate.
        root_rate = pyprob.sample(pyprob.distributions.Uniform(0.01, 10.),
                                  name="root_rate_parameter")
        decay_rate = pyprob.sample(pyprob.distributions.Uniform(0.01, 10.),
                                   name="decay_rate_parameter")
        # Simulator code needs two decay rates for (1) root node (2) all others.
        # The two are sampled independently from the same prior.
        inputs = [root_rate, decay_rate]
        jet = super().forward(inputs)
        delta_val = self.dummy_bernoulli(jet)
        bool_func_dist = pyprob.distributions.Bernoulli(delta_val)
        pyprob.observe(bool_func_dist, name = "dummy")
        return jet

# Make instance of the simulator
# DIS simulator instance for the first observation set (obs_leaves), using
# leaf_dist as the distance function.
ginkgodis = SimulatorModelDIS(
    jet_p=jet4vec,  # parent particle 4-vector
    pt_cut=float(pt_min),  # minimum pT for resulting jet
    Delta_0=M2start,  # parent particle mass squared -> needs tensor
    M_hard=jetM,  # parent particle mass
    minLeaves=1,  # minimum number of jet constituents
    maxLeaves=10000,  # maximum number of jet constituents (a large value to stop expensive simulator runs)
    suppress_output=True,
    obs_leaves=obs_leaves,
    dist_fun=leaf_dist,
)

In [3]:
# Second set of observed leaves: 13 rows of 4 values.
# NOTE(review): presumably each row is a leaf 4-momentum (E, px, py, pz) --
# confirm against the Ginkgo simulator.
obs_leaves2=to_tensor([[ 25.8005,  15.8486,  13.7905,  14.9743],
        [ 64.6767,  39.8886,  34.8096,  37.1519],
        [112.1183,  68.9756,  59.4660,  65.3958],
        [ 17.2854,   9.8922,   9.2324,  10.7553],
        [  8.0760,   4.8046,   4.1940,   4.9540],
        [ 32.6218,  17.6533,  21.0402,  17.6020],
        [ 78.1670,  42.4337,  50.0361,  42.4953],
        [ 23.3093,  12.4575,  14.6925,  13.1246],
        [  8.5133,   4.5174,   5.0757,   5.1276],
        [ 17.5119,   8.7775,  10.6228,  10.8064],
        [ 11.6579,   5.0303,   7.2342,   7.6330],
        [  0.8138,   0.3167,   0.5029,   0.5491],
        [  0.5715,   0.3441,   0.2433,   0.3707]])

# DIS simulator instance for the second observation set (obs_leaves2), using
# leaf_dist as the distance function.
ginkgodis2 = SimulatorModelDIS(
    jet_p=jet4vec,  # parent particle 4-vector
    pt_cut=float(pt_min),  # minimum pT for resulting jet
    Delta_0=M2start,  # parent particle mass squared -> needs tensor
    M_hard=jetM,  # parent particle mass
    minLeaves=1,  # minimum number of jet constituents
    maxLeaves=10000,  # maximum number of jet constituents (a large value to stop expensive simulator runs)
    suppress_output=True,
    obs_leaves=obs_leaves2,
    dist_fun=leaf_dist,
)

In [4]:
class SimulatorModelIC(invMass_ginkgo.SimulatorModel):
    """Ginkgo simulator variant used for inference-network training.

    ``forward`` draws the two decay rates from uniform priors, runs the
    underlying Ginkgo simulator with them, and observes the number of leaves
    of the resulting jet under the name ``'leaf_count'``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, inputs=None):
        """Sample both decay rates, run the simulator and return the jet.

        Args:
            inputs: Must be ``None``; the rates are always sampled here.

        Returns:
            The jet returned by the parent simulator's ``forward``; its
            ``'leaves'`` entry is read to build the observation.
        """
        assert inputs is None # Modify code if this ever not met?
        # Sample the parameters of interest from Unif(0.01, 10) priors.
        # Each sample gets its own name: previously both shared
        # "decay_rate_parameter", so the root rate could not be looked up by
        # name. "decay_rate_parameter" still refers to the non-root rate.
        root_rate = pyprob.sample(pyprob.distributions.Uniform(0.01, 10.),
                                  name="root_rate_parameter")
        decay_rate = pyprob.sample(pyprob.distributions.Uniform(0.01, 10.),
                                   name="decay_rate_parameter")
        # Simulator code needs two decay rates for (1) root node (2) all others.
        # The two are sampled independently from the same prior.
        inputs = [root_rate, decay_rate]
        jet = super().forward(inputs)
        # Observe the leaf count through a Delta distribution at that count.
        leaf_delta = Delta(len(jet['leaves']))
        pyprob.observe(leaf_delta, name='leaf_count')
        return jet

# Simulator instance used below for inference-network training.
model = SimulatorModelIC(
    jet_p=jet4vec,  # parent particle 4-vector
    pt_cut=float(pt_min),  # minimum pT for resulting jet
    Delta_0=M2start,  # parent particle mass squared -> needs tensor
    M_hard=jetM,  # parent particle mass
    minLeaves=1,  # minimum number of jet constituents
    maxLeaves=10000,  # maximum number of jet constituents (a large value to stop expensive simulator runs)
    suppress_output=True,
)

In [None]:
# Learn an LSTM inference network from 81920 traces; the 'leaf_count'
# observation is embedded in 32 dimensions.
model.learn_inference_network(num_traces=81920,
                              observe_embeddings={'leaf_count' : {'dim' : 32}},
                              inference_network=pyprob.InferenceNetwork.LSTM)

Creating new inference network...
Observable leaf_count: reshape not specified, using shape torch.Size([]).
Observable leaf_count: using embedding dim torch.Size([32]).
Observable leaf_count: observe embedding not specified, using the default FEEDFORWARD.
Observable leaf_count: embedding depth not specified, using the default 2.
Observe embedding dimension: 32
Train. time | Epoch| Trace     | Init. loss| Min. loss | Curr. loss| T.since min | Learn.rate| Traces/sec
New layers, address: 32__forward__root_rate__Uniform__1, distribution: Uniform
New layers, address: 56__forward__decay_rate__Uniform__1, distribution: Uniform
New layers, address: 216__forward___traverse___traverse_rec__phi_CM__Un..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec__theta_CM_U..., distribution: Uniform
New layers, address: 460__forward___traverse___traverse_rec__draw_decay..., distribution: TruncatedExponential
New layers, address: 510__forward___traverse___traverse_rec__draw

New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 510__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 510__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 510__forward___traverse___traver

New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 510__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 510__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec___trav

New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
0d:00:00:00 | +1.00e-03 | 8.2
New layers, address: 256__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 510__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 510__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256_

New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 510__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 510__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec___trav

New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential                            
New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 510__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 216__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 256__forward___traverse___traverse_rec___traverse_..., distribution: Uniform
New layers, address: 460__forward___traverse___traverse_rec___traverse_..., distribution: TruncatedExponential
New layers, address: 510_

In [None]:
# Save the model's inference network to 'inference_networks/Ginkgo_IC'.
model.save_inference_network('inference_networks/Ginkgo_IC')

# Training