In [84]:
# Miscellaneous
from functools import partial 
from ticktack import load_presaved_model

# Hamiltonian monte-carlo
from numpyro.infer import NUTS, MCMC

import jax.numpy as np
import jax.scipy as sc
import jax.random as random
import matplotlib.pyplot as plt
from jax import jit, jacrev, jacfwd, vmap, grad

In [85]:
cbm = load_presaved_model("Guttler14", production_rate_units="atoms/cm^2/s")
cbm.compile()

In [86]:
def construct_model(cbm=cbm):
    template = np.where(cbm._matrix != 0, np.ones(cbm._matrix.shape), np.zeros(cbm._matrix.shape))

    def parse_parameters(parameters, transfer_matrix=template):
        index = 0
        for i, row in enumerate(transfer_matrix):
            for j, col in enumerate(row):
                if col != 0:
                    parameter = parameters[index]
                    transfer_matrix = transfer_matrix.at[i, j].set(*parameter)
                    index += 1
        return transfer_matrix
    
    return parse_parameters

In [112]:
parse_parameters = construct_model(cbm)
parameters = random.uniform(random.PRNGKey(0), (35,))

In [113]:
@jit
def analytic_solution(parameters, time_out, /, decay=cbm._decay_matrix, parser=parse_parameters):
    """
    This is the analytic solution itself.
    Parameters: 
        
    Returns:
    """
    start_date = parameters[0]
    event_area = parameters[1]
    production = parameters[2]
    production = np.array([1 - production, production, *np.zeros(9)], dtype=np.float64)
    flow_sizes = parameters[3:]

    initial_position = event_area * production
    transfer_matrix = parse_parameters(flow_sizes)

    @vmap
    def vmap_util(t, /, transfer_matrix=transfer_matrix, y0=initial_position, start=start_date):
        return sc.linalg.expm((t - start) * transfer_matrix) @ y0

    impulse_solution = vmap_util(time_out)
    steady_solution = np.zeros((impulse_solution.shape))
    condition = (time_out > start_date).reshape(-1, 1)
    
    return np.where(condition, impulse_solution, steady_solution)

In [114]:
@partial(jit, static_argnums=(0))
def load(filename: str):
    """
    A custom `JAX` file loading protocol designed to be very quick and return a value that is `JAX` transformable. 
    
    Parameters:
        filename: String -> The file address of the data
    Returns:
        DeviceArray -> The data in column major order
    """
    with open(filename) as data:    # Opening the data file
        header = next(data)
        data = np.array([row.strip().split(" ") for row in data], dtype=np.float64)
        return data.T


In [115]:
data = load("miyake12.csv")
data = data.at[1].add(-np.mean(data[1, 1:4]))

In [116]:
@partial(jit, static_argnums=(1, 2))
def loss(parameters, /, analytic_solution=analytic_solution, data=data):
    """
    Computes the log likelihood of a set of parameters in the parameter space
    """
    in_bounds = 0.0
    in_bounds = np.any((parameters[2:] < 0.0) | (parameters[2:] > 1.0)) * np.inf
    analytic_data = analytic_solution(parameters, data[0])
    chi_sq = np.sum((data[1] - analytic_data[:, 1]) ** 2 / data[2] ** 2)
    return in_bounds + chi_sq

In [118]:
parameters = np.concatenate([np.array([774.8, 40.0, 0.3], dtype=np.float64), parameters])

In [120]:
loss(parameters)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/2)>
The problem arose with the `bool` function. 
While tracing the function loss at /tmp/ipykernel_4072/1474160984.py:1 for jit, this value became a tracer due to JAX operations on these lines:

  operation a:bool[] = lt b c
    from line /tmp/ipykernel_4072/4259135557.py:9 (analytic_solution)

  operation a:i64[] = add b c
    from line /tmp/ipykernel_4072/4259135557.py:9 (analytic_solution)

  operation a:i64[] = select_n b c d
    from line /tmp/ipykernel_4072/4259135557.py:9 (analytic_solution)

  operation a:f64[11] b:f64[11] c:f64[11] d:f64[11] e:f64[11] f:f64[11] g:f64[11] h:f64[11] i:f64[11]
  j:f64[11] k:f64[11] = xla_call[
  call_jaxpr={ lambda ; l:f64[11,11]. let
      m:f64[1,11] = slice[
        limit_indices=(1, 11)
        start_indices=(0, 0)
        strides=(1, 1)
      ] l
      n:f64[11] = squeeze[dimensions=(0,)] m
      o:f64[1,11] = slice[
        limit_indices=(2, 11)
        start_indices=(1, 0)
        strides=(1, 1)
      ] l
      p:f64[11] = squeeze[dimensions=(0,)] o
      q:f64[1,11] = slice[
        limit_indices=(3, 11)
        start_indices=(2, 0)
        strides=(1, 1)
      ] l
      r:f64[11] = squeeze[dimensions=(0,)] q
      s:f64[1,11] = slice[
        limit_indices=(4, 11)
        start_indices=(3, 0)
        strides=(1, 1)
      ] l
      t:f64[11] = squeeze[dimensions=(0,)] s
      u:f64[1,11] = slice[
        limit_indices=(5, 11)
        start_indices=(4, 0)
        strides=(1, 1)
      ] l
      v:f64[11] = squeeze[dimensions=(0,)] u
      w:f64[1,11] = slice[
        limit_indices=(6, 11)
        start_indices=(5, 0)
        strides=(1, 1)
      ] l
      x:f64[11] = squeeze[dimensions=(0,)] w
      y:f64[1,11] = slice[
        limit_indices=(7, 11)
        start_indices=(6, 0)
        strides=(1, 1)
      ] l
      z:f64[11] = squeeze[dimensions=(0,)] y
      ba:f64[1,11] = slice[
        limit_indices=(8, 11)
        start_indices=(7, 0)
        strides=(1, 1)
      ] l
      bb:f64[11] = squeeze[dimensions=(0,)] ba
      bc:f64[1,11] = slice[
        limit_indices=(9, 11)
        start_indices=(8, 0)
        strides=(1, 1)
      ] l
      bd:f64[11] = squeeze[dimensions=(0,)] bc
      be:f64[1,11] = slice[
        limit_indices=(10, 11)
        start_indices=(9, 0)
        strides=(1, 1)
      ] l
      bf:f64[11] = squeeze[dimensions=(0,)] be
      bg:f64[1,11] = slice[
        limit_indices=(11, 11)
        start_indices=(10, 0)
        strides=(1, 1)
      ] l
      bh:f64[11] = squeeze[dimensions=(0,)] bg
    in (n, p, r, t, v, x, z, bb, bd, bf, bh) }
  name=_unstack
] bi
    from line /tmp/ipykernel_4072/929770895.py:6 (parse_parameters)

  operation a:f64[] = convert_element_type[new_dtype=float64 weak_type=False] b
    from line /tmp/ipykernel_4072/929770895.py:8 (parse_parameters)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [58]:
gradient = jit(grad(loss))
hessian = jit(jacfwd(jacrev(loss)))

In [59]:
# Running the No U Turn sampling
nuts_kernel = NUTS(potential_fn=loss)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=500, progress_bar=True)
mcmc.run(random.PRNGKey(11), init_params=parameters)

sample: 100%|██████████| 600/600 [16:43<00:00,  1.67s/it, 982 steps of size 2.95e-13. acc. prob=1.00] 


In [60]:
test = mcmc.get_samples()

In [61]:
import seaborn as sns
sns.set()

In [None]:
fig, axes = plt.subplots(4, 4, figsize=(20, 20))
for index, variable in enumerate(test.T):
    axis = axes[index // 4][index % 4]
    sns.kdeplot(test.T[index], ax=axis)