In [2]:
# Miscellaneous
from functools import partial 
from ticktack import load_presaved_model

# Hamiltonian monte-carlo
from numpyro.infer import NUTS, MCMC

import jax.numpy as np
import jax.scipy as sc
import jax.random as random
import matplotlib.pyplot as plt
from jax import jit, jacrev, jacfwd, vmap, grad

In [3]:
cbm = load_presaved_model("Guttler14", production_rate_units="atoms/cm^2/s")
cbm.compile()

In [17]:
def construct_model(cbm=cbm):
    template = np.where(cbm._matrix != 0, np.ones(cbm._matrix.shape), np.zeros(cbm._matrix.shape))
    
    def parse_parameters(parameters, transfer_matrix=template):
        parameters = iter(parameters)
        for i, row in enumerate(transfer_matrix):
            for j, col in enumerate(row):
                if col != 0:
                    transfer_matrix = transfer_matrix.at[i, j].set(next(parameters))
        return transfer_matrix
    
    return parse_parameters

In [12]:
parse_parameters = construct_model(cbm)

In [14]:
template = np.where(cbm._matrix != 0, np.ones(cbm._matrix.shape), np.zeros(cbm._matrix.shape))

In [19]:
parameters = random.uniform(random.PRNGKey(0), (35, 1))

In [25]:
parameters = iter(parameters)
for i, row in enumerate(template):
    for j, col in enumerate(row):
        if col != 0:
            parameter = next(parameters)
            template = template.at[i, j].set(*parameter)
            

[0.32422284] 0 0


AssertionError: length mismatch: [1, 0]

In [13]:
parse_parameters()

AssertionError: length mismatch: [1, 0]

In [15]:
@jit
def analytic_solution(parameters, time_out, /, fluxes=cbm._fluxes, 
    contents=cbm._reservoir_content, decay=cbm._decay_matrix):
    """
    This is the analytic solution itself.
    Parameters: 
        
    Returns:
    """


    @vmap
    def vmap_util(t, /, transfer_matrix=transfer_matrix, y0=initial_position, start=start_date):
        return sc.linalg.expm((t - start) * transfer_matrix) @ y0

    impulse_solution = vmap_util(time_out)
    steady_solution = np.zeros((impulse_solution.shape))
    condition = (time_out > start_date).reshape(-1, 1)
    
    return np.where(condition, impulse_solution, steady_solution)

In [16]:
@partial(jit, static_argnums=(0))
def load(filename: str):
    """
    A custom `JAX` file loading protocol designed to be very quick and return a value that is `JAX` transformable. 
    
    Parameters:
        filename: String -> The file address of the data
    Returns:
        DeviceArray -> The data in column major order
    """
    with open(filename) as data:    # Opening the data file
        header = next(data)
        data = np.array([row.strip().split(" ") for row in data], dtype=np.float64)
        return data.T


In [17]:
data = load("miyake12.csv")
data = data.at[1].add(-np.mean(data[1, 1:4]))

In [56]:
@partial(jit, static_argnums=(1, 2))
def loss(parameters, /, analytic_solution=analytic_solution, data=data):
    """
    Computes the log likelihood of a set of parameters in the parameter space
    """
    in_bounds = 0.0
    in_bounds = np.any((parameters[2:] < 0.0) | (parameters[2:] > 1.0)) * np.inf
    analytic_data = analytic_solution(parameters, data[0])
    chi_sq = np.sum((data[1] - analytic_data[:, 1]) ** 2 / data[2] ** 2)
    return in_bounds + chi_sq

In [57]:
parameters = np.array([774.86, 40.0, 0.3, *[1.0 for _ in range(10)]])

In [58]:
gradient = jit(grad(loss))
hessian = jit(jacfwd(jacrev(loss)))

In [59]:
# Running the No U Turn sampling
nuts_kernel = NUTS(potential_fn=loss)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=500, progress_bar=True)
mcmc.run(random.PRNGKey(11), init_params=parameters)

sample: 100%|██████████| 600/600 [16:43<00:00,  1.67s/it, 982 steps of size 2.95e-13. acc. prob=1.00] 


In [60]:
test = mcmc.get_samples()

In [61]:
import seaborn as sns
sns.set()

In [None]:
fig, axes = plt.subplots(4, 4, figsize=(20, 20))
for index, variable in enumerate(test.T):
    axis = axes[index // 4][index % 4]
    sns.kdeplot(test.T[index], ax=axis)