In [165]:
# Miscellaneous
from functools import partial 
from ticktack import load_presaved_model

# Hamiltonian monte-carlo
from numpyro.infer import NUTS, MCMC

import jax.numpy as np
import jax.scipy as sc
import jax.random as random
import matplotlib.pyplot as plt
from jax import jit, jacrev, jacfwd, vmap, grad

In [166]:
cbm = load_presaved_model("Guttler14", production_rate_units="atoms/cm^2/s")
cbm.compile()

In [178]:
def construct_model(cbm=cbm):
    template = np.where(cbm._matrix != 0, np.ones(cbm._matrix.shape), np.zeros(cbm._matrix.shape))
    indexes = np.where(cbm._matrix != 0)
    indexes = np.array(indexes).T

    @jit
    def parse_parameters(parameters, template=template, indexes=indexes):
        index = 0
        for i, j in indexes:
            template = template.at[i, j].set(parameters[index])
            index += 1
        return template
    
    return parse_parameters

In [185]:
parse_parameters = construct_model(cbm)
parameters = random.uniform(random.PRNGKey(1), (35,))

In [186]:
parse_parameters(parameters)

DeviceArray([[0.85644567, 0.47139387, 0.        , 0.        , 0.        ,
              0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        ],
             [0.91312016, 0.22212906, 0.08435855, 0.        , 0.        ,
              0.00102618, 0.        , 0.61726359, 0.98502445, 0.78420864,
              0.58387788],
             [0.        , 0.14211087, 0.51747562, 0.18148646, 0.07293947,
              0.        , 0.        , 0.11105859, 0.        , 0.        ,
              0.        ],
             [0.        , 0.        , 0.81487791, 0.59780625, 0.        ,
              0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        ],
             [0.        , 0.        , 0.42537515, 0.00488481, 0.01466132,
              0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        ],
             [0.        , 0.08158744, 0.        , 0.        , 0.        ,
              0.74505568, 0.        , 0.        , 0

In [170]:
@partial(jit, static_argnums=(0))
def load(filename: str):
    """
    A custom `JAX` file loading protocol designed to be very quick and return a value that is `JAX` transformable. 
    
    Parameters:
        filename: String -> The file address of the data
    Returns:
        DeviceArray -> The data in column major order
    """
    with open(filename) as data:    # Opening the data file
        header = next(data)
        data = np.array([row.strip().split(" ") for row in data], dtype=np.float64)
        return data.T


In [171]:
data = load("miyake12.csv")
data = data.at[1].add(-np.mean(data[1, 1:4]))

In [172]:
@partial(jit, static_argnums=(1, 2))
def loss(parameters, /, analytic_solution=analytic_solution, data=data):
    """
    Computes the log likelihood of a set of parameters in the parameter space
    """
    in_bounds = 0.0
    in_bounds = np.any((parameters[2:] < 0.0) | (parameters[2:] > 1.0)) * np.inf
    analytic_data = analytic_solution(parameters, data[0])
    chi_sq = np.sum((data[1] - analytic_data[:, 1]) ** 2 / data[2] ** 2)
    return in_bounds + chi_sq

In [173]:
parameters = np.concatenate([np.array([774.8, 40.0, 0.3], dtype=np.float64), parameters])

In [177]:
%%timeit
loss(parameters)

862 µs ± 108 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [58]:
gradient = jit(grad(loss))
hessian = jit(jacfwd(jacrev(loss)))

In [59]:
# Running the No U Turn sampling
nuts_kernel = NUTS(potential_fn=loss)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=500, progress_bar=True)
mcmc.run(random.PRNGKey(11), init_params=parameters)

sample: 100%|██████████| 600/600 [16:43<00:00,  1.67s/it, 982 steps of size 2.95e-13. acc. prob=1.00] 


In [60]:
test = mcmc.get_samples()

In [61]:
import seaborn as sns
sns.set()

In [None]:
fig, axes = plt.subplots(4, 4, figsize=(20, 20))
for index, variable in enumerate(test.T):
    axis = axes[index // 4][index % 4]
    sns.kdeplot(test.T[index], ax=axis)