In [261]:
# Miscellaneous
from functools import partial 
from ticktack import load_presaved_model

# Hamiltonian monte-carlo
from numpyro.infer import NUTS, MCMC

import jax.numpy as np
import jax.scipy as sc
import jax.random as random
import jax.lax as lax
import matplotlib.pyplot as plt
from jax import jit, jacrev, jacfwd, vmap, grad

In [262]:
cbm = load_presaved_model("Guttler14", production_rate_units="atoms/cm^2/s")
cbm.compile()

In [263]:
def construct_model(cbm=cbm):
    """
    Build a jitted function that writes a parameter vector into the sparsity
    pattern of `cbm._matrix`.

    Parameters:
        cbm: Model object whose `_matrix` attribute supplies the pattern of
            non-zero entries (defaults to the notebook-level `cbm`).
    Returns:
        Callable -> `parse_parameters(parameters)`, returning an array shaped like
        `cbm._matrix` whose k-th non-zero position (in the order produced by
        `np.where`) holds `parameters[k]`; every other entry is 0.
    """
    is_flow = cbm._matrix != 0
    # 1 where the matrix is non-zero, 0 elsewhere; the 1s are overwritten by the scan.
    template = np.where(is_flow, np.ones(cbm._matrix.shape), np.zeros(cbm._matrix.shape))
    # (n_nonzero, 2) array of (row, column) pairs.
    indexes = np.array(np.where(is_flow)).T

    @jit
    def scan_fun(template, parameter_and_index):
        param, [i, j] = parameter_and_index
        template = template.at[i, j].set(*param)
        # Only the final carry is used. Returning the carry as the per-step output
        # as well would make `lax.scan` stack a full copy of the matrix per entry.
        return template, None

    @jit
    def parse_parameters(parameters, indexes=indexes, template=template, scan_fun=scan_fun):
        # Accept (n_nonzero,) or (n_nonzero, 1); each scan step needs a length-1 slice.
        parameters = np.reshape(parameters, (indexes.shape[0], 1))
        model, _ = lax.scan(scan_fun, template, [parameters, indexes])
        return model

    return parse_parameters

In [265]:
def load(filename: str):
    """
    Read a whitespace- (or comma-) delimited numeric text file into a JAX array.

    The first line is treated as a header and skipped; blank lines are ignored.
    This is deliberately *not* wrapped in `jit`: under `jit` the file would only be
    read while tracing and the result cached per filename, so later edits to the
    file would be silently ignored.

    Parameters:
        filename: String -> The file address of the data
    Returns:
        Array -> The data transposed, so each row of the result is one column of the file
    """
    with open(filename) as handle:
        next(handle, None)  # skip the header line
        rows = [
            # Commas can never appear inside a Python float, so treating them as
            # separators is safe; `split()` also tolerates runs of spaces.
            [float(field) for field in line.replace(",", " ").split()]
            for line in handle
            if line.strip()
        ]
    return np.array(rows, dtype=np.float64).T


In [None]:
@partial(jit, static_argnums=(1,), static_argnames=("analytic_solution",))
def loss(parameters, /, analytic_solution=None, data=None):
    """
    Chi-squared misfit between observed data and an analytic model.

    Parameters:
        parameters: Array -> Point in parameter space, passed straight to `analytic_solution`.
        analytic_solution: Callable -> Called as `analytic_solution(parameters, data[0])`;
            column 1 of its result is compared with the observations. Static under
            `jit`, so it must be hashable (plain functions and lambdas are).
        data: Array -> Row 0 is forwarded to `analytic_solution` (presumably sample
            times), row 1 holds the observations and row 2 their uncertainties.
            Traced under `jit` (arrays are unhashable and cannot be static).
    Returns:
        Scalar -> sum(((observed - model) / uncertainty) ** 2). This is -2 x the
        Gaussian log likelihood up to an additive constant, not the log likelihood.
    Raises:
        ValueError: if `analytic_solution` or `data` is not supplied.
    """
    if analytic_solution is None or data is None:
        raise ValueError("loss requires both `analytic_solution` and `data`")
    observed, sigma = data[1], data[2]
    analytic_data = analytic_solution(parameters, data[0])
    return np.sum((observed - analytic_data[:, 1]) ** 2 / sigma ** 2)

In [264]:
parse_parameters = construct_model(cbm)
# parse_parameters scans over exactly one parameter per non-zero entry of the
# transfer matrix, so size the initial draw from the matrix instead of hard-coding it.
n_flow_parameters = int(np.count_nonzero(cbm._matrix))
parameters = random.uniform(random.PRNGKey(1), (n_flow_parameters, 1))

In [266]:
data = load("miyake12.csv")
# Re-reference row 1 to the mean of its entries at columns 1-3.
data = data.at[1].set(data[1] - np.mean(data[1, 1:4]))

In [270]:
# NOTE(review): magic starting values; 774.8 is presumably the onset time of the
# 774 CE event modelled by the analytic solution — confirm all three with the author.
param_init = np.array([[774.8], [40.0], [0.3]], dtype=np.float64)
# Not idempotent: re-running this cell prepends `param_init` to `parameters` again.
parameters = np.vstack((param_init, parameters))

In [272]:
# NOTE(review): `loss` is called without `analytic_solution` or `data`, which both
# default to None in its signature, so this call cannot succeed as written (the
# recorded output below is an error). TODO: pass the analytic model and `data`.
loss(parameters)

ValueError: All input arrays must have the same shape.

In [58]:
# Derivatives of `loss` with respect to its first argument (`parameters`).
gradient = jit(grad(loss, argnums=0))
hessian = jit(jacfwd(jacrev(loss, argnums=0), argnums=0))  # forward-over-reverse Hessian

In [59]:
# Running the No U Turn sampling
nuts_kernel = NUTS(potential_fn=loss)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=500, progress_bar=True)
mcmc.run(random.PRNGKey(11), init_params=parameters)

sample: 100%|██████████| 600/600 [16:43<00:00,  1.67s/it, 982 steps of size 2.95e-13. acc. prob=1.00] 


In [60]:
# Posterior draws with chains merged (group_by_chain=False is the default).
test = mcmc.get_samples(group_by_chain=False)

In [61]:
import seaborn as sns

sns.set_theme()  # `sns.set()` is an alias for `set_theme()`

In [None]:
# Marginal posterior density for every sampled parameter, four panels per row.
# NOTE(review): assumes `test` is an array whose leading axis indexes samples (it comes
# from `init_params=parameters`, which has shape (n_params, 1)) — confirm.
samples_flat = np.reshape(test, (test.shape[0], -1))  # (num_samples, n_params)
n_params = samples_flat.shape[1]
n_cols = 4
n_rows = max(1, -(-n_params // n_cols))  # ceiling division; a fixed 4x4 grid overflows past 16
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
for index, axis in enumerate(axes.flat):
    if index < n_params:
        sns.kdeplot(samples_flat[:, index], ax=axis)
        axis.set_title(f"parameter {index}")
    else:
        axis.set_visible(False)  # hide unused panels in the last row