In [None]:
import os
from functools import partial

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from jax import jit

import appletree as apt
from appletree import flex
from appletree import ipm
from appletree import imm
from appletree.flex import randgen
# NOTE(review): star import supplies the plugin classes used below (EnergySpectra, ..., S1ReconEff);
# consider importing them explicitly.
from appletree.flex.plugin import *

In [None]:
# Value handed to apt.utils.set_gpu_memory_usage (presumably a fraction of device memory — confirm).
GPU_MEMORY_USAGE = 0.2
# Parameter configuration consumed by the parameter manager.
PAR_CONFIG_PATH = './appletree/ipm/par_config/par_config.json'

apt.utils.set_gpu_memory_usage(GPU_MEMORY_USAGE)
par_manager = ipm.ParManager(PAR_CONFIG_PATH)
map_manager = imm.MapManager()

In [None]:
# Initialise the parameter sample, then show what get_all_parameter() returns.
par_manager.sample_init()
all_parameters = par_manager.get_all_parameter()
all_parameters

In [None]:
# Maps to register with map_manager, in registration order.
# Each entry: (file name inside imm.MAPPATH, coord_type, map_name).
MAP_SPECS = [
    ('s1_correction_map_regbin.json', 'regbin', 's1_lce'),
    ('s2_correction_map_regbin.json', 'regbin', 's2_lce'),
    ('elife.json', 'point', 'elife'),
    ('s1_bias.json', 'point', 's1_bias'),
    ('s1_smearing.json', 'point', 's1_smear'),
    ('s2_bias.json', 'point', 's2_bias'),
    ('s2_smearing.json', 'point', 's2_smear'),
    ('3fold_recon_eff.json', 'point', 's1_eff'),
]

for file_name, coord_type, map_name in MAP_SPECS:
    map_manager.register_json_map(
        os.path.join(imm.MAPPATH, file_name),
        coord_type=coord_type,
        map_name=map_name,
    )

In [None]:
# Display map_manager.registration (presumably the maps registered in the previous cell — confirm).
map_manager.registration

In [None]:
# Shared constructor arguments: every plugin below is built as Plugin(par_manager, map_manager).
args = par_manager, map_manager

#### Microphysics plugins

In [None]:
# Instantiate the microphysics stage in pipeline order (plugin1 ... plugin7).
(plugin1, plugin2, plugin3, plugin4,
 plugin5, plugin6, plugin7) = (
    plugin_cls(*args)
    for plugin_cls in (
        EnergySpectra,
        Quenching,
        Ionization,
        mTI,
        RecombFluct,
        TrueRecomb,
        Recombination,
    )
)

#### Detector plugins

In [None]:
# Instantiate the detector stage in pipeline order (plugin8 ... plugin15).
(plugin8, plugin9, plugin10, plugin11,
 plugin12, plugin13, plugin14, plugin15) = (
    plugin_cls(*args)
    for plugin_cls in (
        PositionSpectra,
        S1Correction,
        S2Correction,
        PhotonDetection,
        S1PE,
        DriftLoss,
        ElectronDrifted,
        S2PE,
    )
)

#### Reconstruction plugins

In [None]:
# Instantiate the reconstruction stage: raw S1/S2, then corrected cS1/cS2 (plugin16 ... plugin19).
plugin16, plugin17, plugin18, plugin19 = (
    plugin_cls(*args) for plugin_cls in (S1, S2, cS1, cS2)
)

#### Efficiency plugins

In [None]:
# Instantiate the efficiency stage (plugin20, plugin21); evaluated left to right.
plugin20, plugin21 = S2Threshold(*args), S1ReconEff(*args)

# Get binning (equiprobable cS1–cS2 bins from the Rn220 data file)

In [None]:
# Rn220 data file whose (cs1, cs2) columns define the histogram binning.
RN220_DATA_PATH = './appletree/bbf/data/data_XENONnT_Rn220_v8_strax_v1.2.2_straxen_v1.7.1_cutax_v1.9.0.csv'
# Settings passed to apt.utils.get_equiprob_bins_2d. Presumably: number of bins per axis,
# and clip ranges for the x (cs1) and y (cs2) axes — confirm against apt.utils.
N_BINS = [15, 15]
CS1_CLIP = [0, 100]
CS2_CLIP = [1e2, 1e4]

data = pd.read_csv(RN220_DATA_PATH)
x_bins, y_bins = apt.utils.get_equiprob_bins_2d(
    data[['cs1', 'cs2']].to_numpy(),
    N_BINS,
    order=[0, 1],
    x_clip=CS1_CLIP,
    y_clip=CS2_CLIP,
)

# Simulation pipeline

In [None]:
# NOTE(review): `apt.utils.timeit` wraps the already-jitted function. JAX dispatches jitted
# calls asynchronously, so the reported time may cover only dispatch, not device execution —
# confirm before reading these numbers as throughput.
@apt.utils.timeit
@partial(jit, static_argnums=(1, ))
def sim(key, n):
    """Simulate `n` events through the plugin chain and histogram them in (cS1, cS2).

    The PRNG `key` is threaded through every plugin call in sequence, so the order of
    the statements below determines which random stream each plugin consumes.

    Args:
        key: PRNG key, e.g. from `flex.randgen.get_key()` or a previous `sim` call.
        n: number of events. Static under `jit` (`static_argnums=(1, )`), so each new
            value of `n` triggers a fresh trace and compile.

    Returns:
        (key, hist): the advanced key and the 2D histogram of (cs1, cs2) over
        `x_bins` x `y_bins`, weighted per event by `eff`.

    NOTE(review): `x_bins` / `y_bins` are read from notebook scope at trace time and end
    up as constants in the compiled function; re-running the binning cell does not
    change an already-compiled `sim` for the same `n` — retrace if the bins change.
    """
    # Microphysics: energy -> n_q, n_i -> recombination fraction r -> n_ph, n_e.
    key, energy = plugin1(key, n)
    key, n_q = plugin2(key, energy)
    key, n_i = plugin3(key, n_q)
    key, r_mean = plugin4(key, energy)
    key, r_std = plugin5(key, energy)
    key, r = plugin6(key, r_mean, r_std)
    key, n_ph, n_e = plugin7(key, n_q, n_i, r)
    
    # Detector: positions, position-dependent S1/S2 corrections, S1 photon/PE counts,
    # drift survival probability, drifted electrons, S2 PE counts.
    key, x, y, z = plugin8(key, n)
    key, s1_correction = plugin9(key, x, y, z)
    key, s2_correction = plugin10(key, x, y)
    key, n_s1_phd = plugin11(key, n_ph, s1_correction)
    key, n_s1_pe = plugin12(key, n_s1_phd)
    key, surv_prob = plugin13(key, z)
    key, n_e_drifted = plugin14(key, n_e, surv_prob)
    key, n_s2_pe = plugin15(key, n_e_drifted, s2_correction)
    
    # Reconstruction: raw S1/S2, then corrected cS1/cS2.
    key, s1 = plugin16(key, n_s1_phd, n_s1_pe)
    key, s2 = plugin17(key, n_s2_pe)
    key, cs1 = plugin18(key, s1, s1_correction)
    key, cs2 = plugin19(key, s2, s2_correction, surv_prob)
    
    # Efficiency: per-event acceptances from the S2 threshold and S1 reconstruction.
    key, acc_s2_threshold = plugin20(key, s2)
    key, acc_s1_recon_eff = plugin21(key, n_s1_phd)
    
    # Combined per-event weight used for the histogram.
    eff = acc_s2_threshold*acc_s1_recon_eff
    
    # Stack (cs1, cs2) column-wise; presumably cs1/cs2 are 1-D of length n, giving (n, 2).
    hist = flex.hist.make_hist_irreg_bin_2d(
        jnp.asarray([cs1, cs2]).T,
        x_bins, y_bins, weights=eff
    )
    
    return key, hist

In [None]:
# Events per `sim` call. `n` is a static jit argument, so changing this recompiles `sim`.
batch_size = 1_000_000
# NOTE(review): check whether get_key() uses a fixed seed; pin one for reproducible histograms.
key = flex.randgen.get_key()

#### Build (the first call traces and JIT-compiles `sim`)

In [None]:
# First call traces and compiles `sim` for this `batch_size`; later calls with the same `n`
# reuse the compiled function.
key, hist = sim(key, batch_size)

In [None]:
# Efficiency-weighted simulated histogram on the (cs1, cs2) binning; raw weighted counts (density=False).
apt.utils.plot_irreg_histogram_2d(x_bins, y_bins, hist, density=False)
plt.xlabel('cS1')
plt.ylabel('cS2')
plt.title('Simulated cS1-cS2 (efficiency-weighted)')
plt.yscale('log')
plt.ylim(5e2, 1e4)
# Render explicitly so the cell does not end on plt.ylim's tuple repr.
plt.show()

#### Speed test (repeated `sim` calls with the compiled function)

In [64]:
@apt.utils.timeit
def benchmark(n_iter=100, n=int(1e6)):
    """Time `n_iter` consecutive `sim` calls of `n` events each.

    JAX dispatches jitted calls asynchronously, so the loop alone can return before
    the device has finished. Blocking on the last outputs makes the timed region
    include the actual computation.

    Args:
        n_iter: number of `sim` calls (default 100, as before).
        n: events per call. `n` is a static argument of `sim`, so use the value `sim`
            was already compiled with (as here) to avoid timing a recompilation.
    """
    key = randgen.get_key()
    hist = None
    for _ in range(n_iter):
        key, hist = sim(key, n)
    if hist is not None:
        # Wait for the last dispatched computation before the timer stops.
        key.block_until_ready()
        hist.block_until_ready()

In [65]:
# Run the speed test. Each inner `sim` call is also wrapped by timeit, hence the per-call log lines.
benchmark()

 Function <benchmark> starts. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 1.127958 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 1.078367 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 9.681702 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 13.568640 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 13.511658 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 19.915819 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 19.883871 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 17.416954 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 13.598919 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 10.976315 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 17.527580 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 15.602827 msec. 
 Function <sim> starts. 
 Function <sim> ends! Time cost = 1