In [1]:
from jaxns.nested_sampling import NestedSampler
from jaxns.prior_transforms import PriorChain, UniformPrior
from jaxns.plotting import plot_cornerplot, plot_diagnostics
from jaxns.utils import summary

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import rcParams
import warnings
warnings.filterwarnings('ignore')

import re
import jax.numpy as jnp
from jax import grad, jit, partial, vmap, random
import jax
import ticktack
from ticktack import fitting
from tqdm import tqdm

import corner
rcParams['figure.figsize'] = (10.0, 5.0)

In [3]:
cbm = ticktack.load_presaved_model('Guttler14', production_rate_units = 'atoms/cm^2/s')
cf = fitting.CarbonFitter(cbm)
cf.load_data('inject_recovery_sine.csv', time_oversample=50)
cf.prepare_function(use_control_points=True, interp="gp")

INFO[2021-08-23 16:18:33,912]: Starting the local TPU driver.
INFO[2021-08-23 16:18:33,913]: Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO[2021-08-23 16:18:33,913]: Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
INFO[2021-08-23 16:18:33,913]: Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.


In [4]:
params = jnp.array([1.73481592, 1.57313381, 1.51460236, 1.54205597, 1.69296061,
       1.92638044, 2.02390008, 2.02901486, 2.00823464, 1.96971271,
       1.94535871, 1.90460333, 1.75737994, 1.53488088, 1.49278517,
       1.6187574 , 1.80735385, 1.99521077, 2.1339268 , 2.18531979,
       2.13359721, 1.99344069, 1.80863156, 1.6380658 , 1.53642917,
       1.53598935, 1.63580367, 1.80392992, 1.98935698, 2.13788857,
       2.20402716, 2.14735618, 1.94022293, 1.69127506, 1.60565663,
       1.61456084, 1.68773984, 1.80642376, 1.85882893, 1.88320408,
       1.95647752, 2.06277574, 2.10620188, 2.08541465, 1.9720432 ,
       1.79474694, 1.71987311, 1.71031913, 1.70771597, 1.7151711 ,
       1.76149483, 1.93698132, 2.18883256, 2.19172478, 2.00511724,
       1.81703802, 1.6195444 , 1.53343735, 1.54617536, 1.64473931,
       1.81939341, 1.98520243, 2.13971682, 2.18423755, 2.13067923,
       1.99959428, 1.80333424, 1.65078193, 1.54131796, 1.54586649,
       1.6504936 , 1.80679469, 2.00354288, 2.1292661 , 2.1889818 ,
       2.13753896, 1.99097801, 1.82271049, 1.64132326, 1.55295448,
       1.54857706, 1.64330932, 1.82059429, 1.98683634, 2.13881441,
       2.18482024, 2.13284557, 2.00154459, 1.80802085, 1.65596087,
       1.54559262, 1.54976504, 1.65201871, 1.80601477, 2.00401875,
       2.13558226, 2.19640488, 2.11421172, 1.94221986, 1.87415851,
       1.88410201])

In [5]:
default_params = tuple(params)
bounds = {str(i):(0, 3) for i in np.arange(len(params))}

In [6]:
ndim = len(default_params)
low = jnp.array([bounds[key][0] for key in bounds.keys()])
high = jnp.array([bounds[key][1] for key in bounds.keys()])

In [7]:
prior_chain = PriorChain() \
    .push(UniformPrior('params', low=low, high=high))

In [8]:
%%time
params = vmap(lambda key: prior_chain(prior_chain.compactify_U(prior_chain.sample_U(key))))(
    random.split(random.PRNGKey(0), 100))

NameError: name 'random' is not defined

In [9]:
params["params"].shape

TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[tuple(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.

In [None]:
%%time
ns = NestedSampler(cf.gp_likelihood, prior_chain, num_live_points=3*prior_chain.U_ndims, sampler_name='multi_ellipsoid')

In [None]:
%%time
results = jit(ns)(key=random.PRNGKey(42), termination_frac=0.5)

In [None]:
summary(results)
plot_diagnostics(results)

In [None]:
results.samples['params'][:results.num_samples, :]

In [None]:
arr = results.samples['params'][:results.num_samples, :]

In [None]:
jnp.mean(arr, axis=0)