In [8]:

import dataclasses

# import NEST
import tensorflow_probability.substrates.jax as tfp

from jaxns import Categorical
from jaxns import Prior, Model

tfpd = tfp.distributions

import numpy as np

# JAX imports
import jax
import jax.numpy as jnp
from jax.scipy.special import erf


@dataclasses.dataclass(eq=False)
class NEST:
    def __post_init__(self):
        # Global vars
        self.sqrt2 = np.sqrt(2)
        self.inv_sqrt2_PI = 1 / np.sqrt(2 * jnp.pi)

    def fano_ER(self, mean_N_q, E_field, density):
        """
        Calculates the a Fano-like factor to determine the variance of quanta generation.
    
        Parameters
        ----------
        mean_N_q : scalar or array
            Mean number of quanta produced per event.
        E_field : scalar or array
            Electric field strength in the detector.
        density : scalar or array
            Density of the detector material.
    
        Returns
        -------
        Fano factor : scalar or array
            The calculated Fano factor based on the input parameters.
        
        """

        return (0.13 - 0.030 * density - 0.0057 * density ** 2 + 0.0016 * density ** 3 + 0.0015 * jnp.sqrt(
            mean_N_q * E_field))

    def workfunction(self, density, molar_mass=131.293, avo=6.0221409e+23, atom_num=54.0, old_13eV=1.1716263232):
        alpha = 0.067366 + density * 0.039693
        eDensity = (density / molar_mass) * avo * atom_num
        return old_13eV * (18.7263 - 1.01e-23 * eDensity), alpha

    def calculate_yield_parameters(self, E_field):
        m01 = 30.66 + (6.1978 - 30.66) / (1. + (E_field / 73.855) ** 2.0318) ** 0.41883
        m02 = 77.2931084
        m03 = jnp.log10(E_field) * 0.13946236 + 0.52561312
        m04 = 1.82217496 + (2.82528809 - 1.82217496) / (1. + (E_field / 144.65029656) ** -2.80532006)
        # m05 = Nq / energy / (1. + alpha * erf(0.05 * E)) - m01
        m07 = 7.02921301 + (98.27936794 - 7.02921301) / (1. + (E_field / 256.48156448) ** 1.29119251)
        m08 = 4.285781736
        m09 = 0.3344049589
        m10 = 0.0508273937 + (0.1166087199 - 0.0508273937) / (1. + (E_field / 1.39260460e+02) ** -0.65763592)
        return m01, m02, m03, m04, m07, m08, m09, m10

    def get_yields_beta(self, E, density, exciton_ion_ratio, W, m01=7.096208, m02=77.2931084, m03=0.7155229,
                        m04=1.8279102,
                        m07=94.39740941928082, m08=4.285781736, m09=0.3344049589, m10=0.06623858):
        """
        Calculates the yields of quanta, electrons and photons for beta radiation in a detector.
    
        Parameters
        ----------
        E : scalar or array
            Energy of the incident beta radiation.
        density : scalar or array
            Density of the detector material.
        exciton_ion_ratio : scalar
            Ratio of excitons to ions produced in the detector.
        W : scalar
            Average energy required to produce an electron-ion pair.
        m01, m02, m03, m04, m07, m08, m09, m10 : scalars
            Parameter values used in the empirical formula for calculating yields.
    
        Returns
        -------
        mean_N_q : scalar or array
            Mean number of quanta produced per event.
        Ne : scalar or array
            Number of electrons produced.
        Nph : scalar or array
            Number of photons produced.
    
        Notes
        -----
        This function calculates the yields of quanta, electrons and photons produced by beta radiation 
        in LXe based on the energy of the incident radiation, the density of the detector material, 
        and the ratio of excitons to ions produced.
        """
        DENSITY = 2.9
        mean_N_q = E / W * 1e3
        m05 = mean_N_q / E / (1 + exciton_ion_ratio) - m01
        Qy = m01 + (m02 - m01) / ((1. + (E / m03) ** m04)) ** m09 + m05 + (0.0 - m05) / ((1. + (E / m07) ** m08)) ** m10
        coeff_TI = (1. / DENSITY) ** 0.3
        coeff_Ni = (1. / DENSITY) ** 1.4
        coeff_OL = (1. / DENSITY) ** -1.7 / jnp.log(1. + coeff_TI * coeff_Ni * (DENSITY ** 1.7))
        Qy *= coeff_OL * jnp.log(1. + coeff_TI * coeff_Ni * (density ** 1.7)) * (density ** -1.7)
        Ly = mean_N_q / E - Qy
        Ne = Qy * E
        Nph = Ly * E
        return mean_N_q, Ne, Nph

    def Nei_ratio(self, E, density, alpha):
        """
        Calculates the ratio of the number of electrons to ions for a given energy and material density.
    
        Parameters
        ----------
        E : scalar or array
            Energy of the incident radiation.
        density : scalar or array
            Density of the detector material.
    
        Returns
        -------
        Nei_ratio : scalar or array
            The ratio of the number of electrons to ions.
    
        Notes
        -----
        This function calculates the ratio of the number of electrons to ions produced by 
        ERs.
        """

        return alpha * erf(0.05 * E)

    def recom_omega_ER(self, E_field, elecFrac, width_param_7=0.046452, width_param_8=0.205, width_param_9=0.45,
                       width_param_10=-0.2):
        '''
        This function calculates the omega parameter for electron recoil (ER) events based on certain input parameters.
        
        Parameters
        ----------
        E_field : float 
            The electric field value.
        elecFrac : float 
            The fraction of energy deposited by electrons.
        width_param_7 : float, optional 
            Parameter A from Table VI of https://arxiv.org/pdf/2211.10726.pdf; default is 0.046452.
        width_param_8 : float, optional 
            Omega parameter from Table VI of https://arxiv.org/pdf/2211.10726.pdf; default is 0.205.
        width_param_9 : float, optional 
            Xi parameter from Table VI of https://arxiv.org/pdf/2211.10726.pdf; default is 0.45.
        width_param_10 : float, optional 
            Skewness parameter; default is -0.2.
    
        Returns
        -------
        omega : float 
            The calculated omega parameter for ER events.
        width_param_10 : float 
            The skewness parameter.
    
        Notes
        -----
        The function uses the provided parameters to calculate the value of omega for ER events. It also returns the skewness parameter.
        '''
        A = 0.086036 + (width_param_7 - 0.086036) / (1 + (E_field / 295.2) ** 251.6) ** 0.0069114
        wide = width_param_8
        cntr = width_param_9
        skew = width_param_10
        mode = cntr + 2 * (self.inv_sqrt2_PI) * skew * wide / jnp.sqrt(1. + skew * skew)
        norm = 1. / (jnp.exp(-0.5 * (mode - cntr) ** 2 / (wide * wide)) * (
                1. + erf(skew * (mode - cntr) / (wide * self.sqrt2))))
        omega = norm * A * jnp.exp(-0.5 * (elecFrac - cntr) ** 2 / (wide * wide)) * (
                1. + erf(skew * (elecFrac - cntr) / (wide * self.sqrt2)))
        return omega, width_param_10

    def ER_skew(self, E, E_field, alpha0=1.39, cc0=4.0, cc1=22.1, E0=7.7, E1=54., E2=26.7, E3=6.4, F0=225., F1=71.):
        """
        LUX ER Skewness model for computing the recombination skewness. Default parameters from NEST 2.3.12. 
        
        Parameters:
        E (float): Energy in keV.
        E_field (float): Electric field strength in V/cm.
        alpha0 (float): Baseline scale factor for the ER response. Default value is 1.39.
        cc0 (float): Scale factor for the first exponential term. Default value is 4.0.
        cc1 (float): Scale factor for the second exponential term. Default value is 22.1.
        E0 (float): Decay constant for the first exponential term. Default value is 7.7.
        E1 (float): Decay constant for the second exponential term. Default value is 54.
        E2 (float): Inflection point for the sigmoid function. Default value is 26.7.
        E3 (float): Scale factor for the sigmoid function. Default value is 6.4.
        F0 (float): Decay constant for the field dependence of the first exponential term. Default value is 225.
        F1 (float): Decay constant for the field dependence of the second exponential term. Default value is 71.
    
        Returns:
        float: The ER response given the input parameters. If E_field is less than 50 or greater than 1e4, it returns 0.0.
        """
        mask = jnp.bitwise_or(E_field < 50, E_field > 1e4)
        return jnp.where(mask,
                         0.,
                         1. / (1. + jnp.exp((E - E2) / E3)) * (
                                 alpha0 + cc0 * jnp.exp(-1. * E_field / F0) * (1. - jnp.exp(-1. * E / E0))) + 1. / (
                                 1. + jnp.exp(-1. * (E - E2) / E3)) * cc1 * jnp.exp(-1. * E / E1) * jnp.exp(
                             -1. * jnp.sqrt(E_field) / jnp.sqrt(F1)))


@dataclasses.dataclass(eq=False)
class XenonModel:
    data_S1: jax.Array
    data_S2: jax.Array
    e_survival_prob: jax.Array
    relative_LCE: jax.Array
    density: jax.Array

    E_field = 60.
    S2_sigma: float = 8.
    g1: float = 0.105
    g2: float = 17.3
    SE_gain: float = 31.8
    PMT_sigma: float = 0.35
    p_dpe: float = 0.2

    def prior_model(self):
        nest = NEST()
        W, alpha = nest.workfunction(self.density)
        m01, m02, m03, m04, m07, m08, m09, m10 = nest.calculate_yield_parameters(
            self.E_field)

        signal_rate = yield Prior(tfpd.Uniform(0.0, 2000.), name='signal_rate')
        bg_rate = yield Prior(tfpd.Uniform(0.0, 2000.), name='bg_rate')
        iodine_rate = yield Prior(tfpd.Uniform(0.0, 2000.), name='iodine_rate')
        total_rate = signal_rate + bg_rate + iodine_rate

        categorical_dist_for_E = yield Categorical(
            parametrisation='gumbel_max',
            probs=jnp.array([jnp.abs(bg_rate), jnp.abs(signal_rate), jnp.abs(iodine_rate)]) / (
                    jnp.abs(bg_rate) + jnp.abs(signal_rate) + jnp.abs(iodine_rate)),
            name='categorical_dist_for_E'
        )

        E_component_distributions = [
            (yield Prior(tfpd.Uniform(50., 100.), name='E_component_distributions_0')),
            (yield Prior(tfpd.TruncatedCauchy(64.3, 0.5, low=0., high=2000.), name='E_component_distributions_1')),
            (yield Prior(tfpd.TruncatedCauchy(67.3, 0.5, low=0., high=2000.), name='E_component_distributions_2'))
        ]

        E_component_distributions = jnp.stack(E_component_distributions, axis=0)

        E = E_component_distributions[categorical_dist_for_E]
        exciton_ion_ratio = nest.Nei_ratio(E, self.density, alpha)
        mean_N_q, mean_Ne, mean_Nph = nest.get_yields_beta(E,
                                                           self.density,
                                                           exciton_ion_ratio,
                                                           W,
                                                           m01=m01,
                                                           m02=m02,
                                                           m03=m03,
                                                           m04=m04,
                                                           m07=m07,
                                                           m08=m08,
                                                           m09=m09,
                                                           m10=m10)

        fano_factor = 1.
        Nq_scaled = yield Prior(tfpd.TruncatedCauchy(
            1.0, jnp.sqrt(fano_factor / mean_N_q), low=-1e5, high=10), name='Nq_scaled')
        Nq = Nq_scaled * mean_N_q
        alf = 1 / (1 + exciton_ion_ratio)
        # probs = alpha / total_count = alf
        alpha = Nq * alf + 1
        # beta + alpha = total_count
        beta = Nq - Nq * alf + 1
        Ni = Nq * (yield Prior(tfpd.Beta(concentration1=alpha, concentration0=beta), name='Ni'))
        e_frac = mean_Ne / mean_N_q
        recombProb = 1. - (exciton_ion_ratio + 1.) * e_frac
        recomb_omega, _ = nest.recom_omega_ER(E, self.E_field, e_frac)
        skewness = nest.ER_skew(E, self.E_field)

        recomb_variance = recombProb * (1 - recombProb) * Ni + (recomb_omega * Ni) ** 2
        widthCorrection = jnp.sqrt(1. - (2. / jnp.pi) * skewness * skewness / (1. + skewness * skewness))
        muCorrection = (jnp.sqrt(recomb_variance) / widthCorrection) * (
                skewness / jnp.sqrt(1. + skewness * skewness)) * 2. * nest.inv_sqrt2_PI
        Ne_mu = (1. - recombProb) * Ni - muCorrection
        Ne_sigma = jnp.sqrt(recomb_variance) / widthCorrection
        Ne = yield Prior(tfpd.TruncatedCauchy(Ne_mu, Ne_sigma, low=0., high=1e5), name='Ne')

        print(Nq, Ne)
        Nph = Nq - Ne
        probs = self.g1 * self.relative_LCE / (1 + self.p_dpe)
        # Nph_det = yield Prior(tfpd.Binomial(total_count=Nph, probs=self.g1 * self.relative_LCE / (1 + self.p_dpe)),
        #                       name='Nph_det')
        # Use Beta
        alpha = Nph * probs + 1
        beta = Nph - Nph * probs + 1
        Nph_det = Nph * (yield Prior(tfpd.Beta(concentration1=alpha, concentration0=beta), name='Nph_det'))

        # Nph_det_dpe = yield Prior(tfpd.Binomial(total_count=Nph_det, probs=self.p_dpe), name='Nph_det_dpe')
        # use Beta
        alpha = Nph_det * self.p_dpe + 1
        beta = Nph_det - Nph_det * self.p_dpe + 1
        Nph_det_dpe = Nph_det * (yield Prior(tfpd.Beta(concentration1=alpha, concentration0=beta), name='Nph_det_dpe'))

        # Ne_det = yield Prior(tfpd.Binomial(total_count=Ne, probs=self.e_survival_prob), name='Ne_det')
        # use Beta
        alpha = Ne * self.e_survival_prob + 1
        beta = Ne - Ne * self.e_survival_prob + 1
        Ne_det = Ne * (yield Prior(tfpd.Beta(concentration1=alpha, concentration0=beta), name='Ne_det'))

        return Nph_det, Nph_det_dpe, Ne_det, total_rate

    def log_likelihood(self, Nph_det, Nph_det_dpe, Ne_det, total_rate):
        count = self.data_S1.shape[0]
        log_prob_count = tfpd.Poisson(total_rate).log_prob(count).sum()
        log_prob_S1 = tfpd.Normal(jnp.abs(Nph_det + Nph_det_dpe),
                                  jnp.sqrt(Nph_det + Nph_det_dpe) * self.PMT_sigma).log_prob(self.data_S1).sum()

        log_prob_S2 = tfpd.Normal(jnp.abs(Ne_det) * self.SE_gain, jnp.sqrt(Ne_det) * self.S2_sigma).log_prob(
            self.data_S2).sum()

        return log_prob_count + log_prob_S1 + log_prob_S2

    def build_model(self) -> Model:
        return Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood)


model = XenonModel(
    data_S1=jnp.array([1., 2., 3.]),
    data_S2=jnp.array([1., 2., 3.]),
    e_survival_prob=jnp.asarray(0.5),
    relative_LCE=jnp.asarray(1.),
    density=jnp.asarray(2.8)
).build_model()

model.sanity_check(jax.random.PRNGKey(42), S=1000)


-364074180.0 nan
-364074180.0 nan
-364074180.0 nan
-364074180.0 nan
Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(float32[1000])>with<DynamicJaxprTrace(level=1/0)>
  batch_dim = 0 Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(float32[1000])>with<DynamicJaxprTrace(level=1/0)>
  batch_dim = 0


INFO[2024-03-05 14:45:54,932]: Sanity check...


5073.8296 14536.662


INFO[2024-03-05 14:45:55,980]: Found bad point: [0.00658643 0.615165   0.65633774 0.5147824  0.825768   0.98090184
 0.46996582 0.3895197  0.6139971  0.8687408  0.9155092  0.9991051
 0.1941179  0.05444241 0.93784034] -> {'signal_rate': Array(13.172865, dtype=float32), 'bg_rate': Array(1230.33, dtype=float32), 'iodine_rate': Array(1312.6755, dtype=float32), 'categorical_dist_for_E': Array(2, dtype=int32), 'E_component_distributions_0': Array(73.49829, dtype=float32), 'E_component_distributions_1': Array(64.121765, dtype=float32), 'E_component_distributions_2': Array(67.488686, dtype=float32), 'Nq_scaled': Array(1.0324882, dtype=float32), 'Ni': Array(0.85527396, dtype=float32), 'Ne': Array(14536.662, dtype=float32), 'Nph_det': Array(nan, dtype=float32), 'Nph_det_dpe': Array(nan, dtype=float32), 'Ne_det': Array(0.5063825, dtype=float32)}


6374.5996 8940.412


INFO[2024-03-05 14:45:56,999]: Found bad point: [0.9353961  0.7511554  0.5853338  0.94716334 0.85100925 0.39135242
 0.78641284 0.6188415  0.27138674 0.17762768 0.10887718 0.9982023
 0.4394921  0.9360069  0.7860557 ] -> {'signal_rate': Array(1870.7921, dtype=float32), 'bg_rate': Array(1502.3108, dtype=float32), 'iodine_rate': Array(1170.6676, dtype=float32), 'categorical_dist_for_E': Array(0, dtype=int32), 'E_component_distributions_0': Array(89.32064, dtype=float32), 'E_component_distributions_1': Array(64.49748, dtype=float32), 'E_component_distributions_2': Array(66.86774, dtype=float32), 'Nq_scaled': Array(0.98012453, dtype=float32), 'Ni': Array(0.8428392, dtype=float32), 'Ne': Array(8940.412, dtype=float32), 'Nph_det': Array(nan, dtype=float32), 'Nph_det_dpe': Array(nan, dtype=float32), 'Ne_det': Array(0.50418836, dtype=float32)}


4376.9453 7933.252


INFO[2024-03-05 14:45:57,995]: Found bad point: [0.73486686 0.01864588 0.14880204 0.3777262  0.6433649  0.6921903
 0.68583894 0.0933305  0.6851871  0.11016905 0.561115   0.9982666
 0.03954649 0.8049656  0.63736415] -> {'signal_rate': Array(1469.7338, dtype=float32), 'bg_rate': Array(37.291763, dtype=float32), 'iodine_rate': Array(297.6041, dtype=float32), 'categorical_dist_for_E': Array(1, dtype=int32), 'E_component_distributions_0': Array(84.29195, dtype=float32), 'E_component_distributions_1': Array(62.684967, dtype=float32), 'E_component_distributions_2': Array(67.63041, dtype=float32), 'Nq_scaled': Array(0.9589321, dtype=float32), 'Ni': Array(0.8492298, dtype=float32), 'Ne': Array(7933.252, dtype=float32), 'Nph_det': Array(nan, dtype=float32), 'Nph_det_dpe': Array(nan, dtype=float32), 'Ne_det': Array(0.5019986, dtype=float32)}


4896.732 6125.5044


INFO[2024-03-05 14:45:59,017]: Found bad point: [0.53153217 0.21612597 0.8912518  0.12068129 0.7840611  0.99708676
 0.5106522  0.56216013 0.4200827  0.5251063  0.43208838 0.997453
 0.47975826 0.46095204 0.43860602] -> {'signal_rate': Array(1063.0643, dtype=float32), 'bg_rate': Array(432.25192, dtype=float32), 'iodine_rate': Array(1782.5037, dtype=float32), 'categorical_dist_for_E': Array(2, dtype=int32), 'E_component_distributions_0': Array(75.53261, dtype=float32), 'E_component_distributions_1': Array(64.4006, dtype=float32), 'E_component_distributions_2': Array(67.173996, dtype=float32), 'Nq_scaled': Array(1.0011181, dtype=float32), 'Ni': Array(0.84756106, dtype=float32), 'Ne': Array(6125.5044, dtype=float32), 'Nph_det': Array(nan, dtype=float32), 'Nph_det_dpe': Array(nan, dtype=float32), 'Ne_det': Array(0.49897572, dtype=float32)}


-47406.883 nan


INFO[2024-03-05 14:46:00,147]: Found bad point: [7.0854235e-01 8.4835613e-01 6.0063398e-01 6.6411400e-01 3.2021821e-01
 4.7820663e-01 7.4464083e-01 5.2524757e-01 4.6370864e-01 4.7206879e-04
 4.2859972e-01 1.5435648e-01 6.6739655e-01 1.7577899e-01 5.0672293e-03] -> {'signal_rate': Array(1417.0847, dtype=float32), 'bg_rate': Array(1696.7123, dtype=float32), 'iodine_rate': Array(1201.268, dtype=float32), 'categorical_dist_for_E': Array(0, dtype=int32), 'E_component_distributions_0': Array(87.23204, dtype=float32), 'E_component_distributions_1': Array(64.34154, dtype=float32), 'E_component_distributions_2': Array(67.244705, dtype=float32), 'Nq_scaled': Array(-7.4635515, dtype=float32), 'Ni': Array(nan, dtype=float32), 'Ne': Array(nan, dtype=float32), 'Nph_det': Array(nan, dtype=float32), 'Nph_det_dpe': Array(nan, dtype=float32), 'Ne_det': Array(nan, dtype=float32)}


-182.88918 nan


INFO[2024-03-05 14:46:01,148]: Found bad point: [0.95632267 0.5956876  0.39256895 0.567359   0.8192885  0.08967352
 0.31124794 0.37546933 0.84953964 0.00448561 0.21437812 0.49151063
 0.0371505  0.75123656 0.11409247] -> {'signal_rate': Array(1912.6454, dtype=float32), 'bg_rate': Array(1191.3752, dtype=float32), 'iodine_rate': Array(785.1379, dtype=float32), 'categorical_dist_for_E': Array(1, dtype=int32), 'E_component_distributions_0': Array(65.56239, dtype=float32), 'E_component_distributions_1': Array(64.09654, dtype=float32), 'E_component_distributions_2': Array(68.279976, dtype=float32), 'Nq_scaled': Array(-0.03918624, dtype=float32), 'Ni': Array(nan, dtype=float32), 'Ne': Array(nan, dtype=float32), 'Nph_det': Array(nan, dtype=float32), 'Nph_det_dpe': Array(nan, dtype=float32), 'Ne_det': Array(nan, dtype=float32)}


AssertionError: 