In [6]:
#Import required packages
import jax                                                #numpy on CPU/GPU/TPU
import sys                                                #allows for command line arguments to be run with the script
import arviz as az                                        #for saving Bayesian Data
from funcs_ub_globalfit_BL import *                          #physical recombination models (self written)
import jax.numpy as jnp                                   #jnp regularly called (ease of use)
from jax.random import PRNGKey                            #pseudo-random number generator (ease of use)

import numpyro                                            #Bayesian inference package
from numpyro.infer import MCMC, NUTS, Predictive          #MCMC with NUTS to make it Hamiltonian MC
from numpyro.distributions import TruncatedNormal, Normal #To define prior distributions
import pickle                                             #For loading data from .py files

In [7]:
#Load Data
with open('MAPI_230424_Sorted/iCCD_MAPI__150uW_FILM_iCCD.py', "rb") as file:
        data = pickle.load(file)

time = data[0]-data[0][0] 
signal = data[1]/data[1][0]
s_signal = standardise(signal)[0]
print(time.shape,s_signal.shape)

(65,) (65,)


In [16]:
num_chains = 10                       #number of chains to run (number of cores to use) - might need to run this on linux
numpyro.set_host_device_count(12)     #number of cores on computer

#Set all floats to 64 bit
jax.config.update("jax_enable_x64", True) #Needed for precision in jax - never touch
#warm up the JIT
TRPL_HERTZ(time,jnp.array([0, 1e-15, 1e-18, 1e-3, 0.0, 1e14, 0.0]), 10.0)              #These are from the funcs_ub_globalfit file - these numbers dont mean anything, 
standardise(TRPL_HERTZ(time,jnp.array([0.0, 1e-15, 1e-18, 1e-3, 0.0, 1e14, 0.0]), 10.0)) #it is literally to get the JIT warmed up with numbers

1


(Array([ 3.69266607,  3.3340325 ,  2.97503575,  2.62212288,  2.27560069,
         1.92731774,  1.59310518,  1.28355212,  0.98205252,  0.69760128,
         0.45006741,  0.24579386,  0.0678273 , -0.07882191, -0.18846719,
        -0.22061881, -0.2654415 , -0.32267499, -0.36578   , -0.39544284,
        -0.41454448, -0.41971455, -0.42674589, -0.43580708, -0.44243335,
        -0.44688669, -0.44965183, -0.45035571, -0.45414715, -0.45464385,
        -0.45468709, -0.45472939, -0.45471321, -0.45469231, -0.45469806,
        -0.45470042, -0.45470741, -0.45471102, -0.45471192, -0.45471122,
        -0.45471079, -0.45469979, -0.45470085, -0.45470437, -0.45470598,
        -0.45470609, -0.45470514, -0.45470355, -0.4547031 , -0.45470173,
        -0.4547001 , -0.45469978, -0.4546991 , -0.4546991 , -0.45469919,
        -0.45469975, -0.45470006, -0.45470017, -0.45470013, -0.45469997,
        -0.45469974, -0.45469949, -0.45469927, -0.45469911, -0.45469906],      dtype=float64),
 Array([-1.7845763], dtype=fl

In [10]:
#Bayesian model
def model(dev, ydata = None):
    """
    
    Bayesian model for the BTDP model.

    Parameters
    ----------
    dev: float
        stdev chosen to be less than the bounds of the truncated guassians of priors (fac and theta, defined below)
    
    y0: float
        Initial counts in the TRPL signal.

    ydata: array
        standardised log10 of the experimental TRPL signal.

    """

    std_dev = dev

    N0 = 1.75

    fac = numpyro.sample(                                 # Creates a truncated normal that we use to create a distribution for the N0s later
        "fac",
        TruncatedNormal(
            low   = jnp.array([0.950, 1.950, 2.950]),
            high  = jnp.array([1.050, 2.050, 3.150]),
            loc   = jnp.array([1.000, 2.000, 3.000]),
            scale = jnp.array([0.001, 0.001, 0.001]),
        ),
    ) 

    theta = numpyro.sample(                            #Truncated normal distributions of priors in special units. See NOTE for detailed conversion,
        "theta",                                       #but 1 units = 1e15 cm-3. Consider each parameters conversion with this in mind (Auger = cm6s-1)
        TruncatedNormal(
            low   = jnp.array([-7.00, -2.50, -3.95, -3.00, -4.00, N0 - fac[2] * jnp.log10(4.0)]),
            high  = jnp.array([-4.90,  0.00, -2.00, -1.00, -1.00, N0 - fac[0] * jnp.log10(4.0)]),
            loc   = jnp.array([-5.90, -1.20, -2.30, -2.00, -2.70,  1.00]),
            scale = jnp.array([std_dev, std_dev, std_dev, std_dev, std_dev, std_dev]),
        ),
    ) #in log form, not physically understandable - these units = 1e15 cm-3

    std_dev1 = 0.1
    noise = numpyro.sample(
        "noise",
        TruncatedNormal(
            low   = jnp.array([0.01, 0.01, 0.01, 0.01]),
            high  = jnp.array([0.50, 0.50, 0.50, 0.50]),
            loc   = jnp.array([0.10, 0.10, 0.10, 0.10]),
            scale = jnp.array([std_dev1, std_dev1, std_dev1, std_dev1]),
        ),
    )

    ka    = theta[0]
    kt    = theta[1]
    kb    = theta[2]
    kdt   = theta[3]
    kdp   = theta[4]
    NT    = theta[5]

    #Calculate the TRPL signal and standardise
    #signal0 = TRPL_HERTZ_HIGH(jnp.array([10**ka, 10**kt, 10**kb, 10**kdt, 10**kdp, 10**NT, 0.0]),  10**(N0 - fac[5] * jnp.log10(4.0)))
    #signal1 = TRPL_HERTZ_HIGH(jnp.array([10**ka, 10**kt, 10**kb, 10**kdt, 10**kdp, 10**NT, 0.0]),  10**(N0 - fac[4] * jnp.log10(4.0)))
    #signal2 = TRPL_HERTZ_HIGH(jnp.array([10**ka, 10**kt, 10**kb, 10**kdt, 10**kdp, 10**NT, 0.0]),  10**(N0 - fac[3] * jnp.log10(4.0)))
    #signal3 = TRPL_HERTZ_HIGH(jnp.array([10**ka, 10**kt, 10**kb, 10**kdt, 10**kdp, 10**NT, 0.0]), 10**(N0 - fac[2] * jnp.log10(4.0)))
    signal4 = TRPL_HERTZ_HIGH(jnp.array([10**ka, 10**kt, 10**kb, 10**kdt, 10**kdp, 10**NT, 0.0]), 10**(N0 - fac[1] * jnp.log10(10.0)))
    signal5 = TRPL_HERTZ_HIGH(jnp.array([10**ka, 10**kt, 10**kb, 10**kdt, 10**kdp, 10**NT, 0.0]), 10**(N0 - fac[0] * jnp.log10(10.0)))
    #signal4 = TRPL_HERTZ_HIGH(jnp.array([10**ka, 0.0, 10**kb, 0.0, 0.0, 0.0, 0.0]), 10**(N0 - fac[1] * jnp.log10(4.0)))
    #signal5 = TRPL_HERTZ_HIGH(jnp.array([10**ka, 0.0, 10**kb, 0.0, 0.0, 0.0, 0.0]), 10**(N0 - fac[0] * jnp.log10(4.0)))
    signal6 = TRPL_HERTZ_HIGH(jnp.array([10**ka, 0.0, 10**kb, 0.0, 0.0, 0.0, 0.0]), 10**(N0))
    #signal6 = TRPL_HERTZ_HIGH(jnp.array([10**ka, 10**kt, 10**kb, 10**kdt, 10**kdp, 10**NT, 0.0]), 10**(N0))
    
    signal = jnp.stack([signal4, signal5, signal6])
    signal_s = (signal - means)/stds

    #Define the likelihood
    numpyro.sample("ydata", Normal(signal_s, noise[:, None]), obs=ydata)

In [12]:
#Setting up the MCMC
num_warmup, num_samples = 5000, 10000 #number of steps you want to do (2x samples than warmup)

rndint      = int(4)         #Seeds for picking of starting points in priors
rndint1     = int(3)         #Seeds picking from generated priors
rndint2     = int(6)         #Picking from posteriors
accept_prob = float(0.85)
std_dev     = float(0.1)

key1 = PRNGKey(rndint)            #Generating the random numbers from numbers above
key2 = PRNGKey(rndint1)
key3 = PRNGKey(rndint2)

In [13]:
#Define the MCMC - see Barney's notes for an explanation of each argument
mcmc = MCMC(
        NUTS(model, adapt_step_size=True, max_tree_depth=6, find_heuristic_step_size=False, dense_mass=True,target_accept_prob=accept_prob),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        chain_method="parallel",
        progress_bar=True,
    )

  mcmc = MCMC(
