In [1]:
import re
import warnings
# `from jax import partial` no longer works: jax.partial was just a re-export of
# functools.partial and has been removed from JAX. Import the original instead.
from functools import partial

# NOTE(review): this silences *every* warning (including JAX/numpyro deprecation
# warnings) for the whole session; consider narrowing it to specific categories.
# Kept ahead of the third-party imports so import-time warnings stay suppressed,
# as before.
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import scipy
import matplotlib.pyplot as plt
from matplotlib.pyplot import rcParams
from astropy.table import Table
from tqdm import tqdm

import jax
# Enable 64-bit floats before any JAX array is created. JAX documents that this
# flag should be set at startup; setting it in a later cell is too late for
# arrays built before that cell runs.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import grad, jit

import ticktack
# Alternative backend: swap in the plain `fitting` module instead of the tinygp one.
# from ticktack import fitting
from ticktack import fitting_tinygp as fitting

rcParams['figure.figsize'] = (10.0, 5.0)

In [2]:
# Load the presaved 'Guttler14' model with production rates in atoms/cm^2/s,
# wrap it in a CarbonFitter, attach the 400 BCE dataset (oversampled in time by
# a factor of 50), and configure the fitter to use control points with GP
# interpolation.
cbm = ticktack.load_presaved_model(
    'Guttler14',
    production_rate_units='atoms/cm^2/s',
)
cf = fitting.CarbonFitter(cbm)
cf.load_data(
    '400_BCE_Data_processed.csv',
    time_oversample=50,
)
cf.prepare_function(
    use_control_points=True,
    interp="gp",
)

In [3]:
# Initial guess for the likelihood parameters: every entry set to the fitter's
# steady-state production value (presumably a flat production-rate history).
# Length is len(cf.control_points_time) + 1.
# NOTE(review): confirm what the extra (+1) entry represents in gp_likelihood.
# An alternative starting vector is defined as `params2` in a later cell.
steady_state = cf.steady_state_production * jnp.ones((len(cf.control_points_time) + 1,))
params = steady_state

In [4]:
%%time
cf.grad_gp_likelihood(params).block_until_ready()

CPU times: user 37.3 s, sys: 58.4 s, total: 1min 35s
Wall time: 19.1 s


DeviceArray([-9.77456012e+01, -5.38679694e+00, -7.13010645e+01,
             -3.05135033e+01, -5.67498014e+01, -4.12510658e+01,
             -5.16314260e+01, -4.68397206e+01, -5.15517314e+01,
             -5.09225007e+01, -5.35204949e+01, -5.41590674e+01,
             -5.63853320e+01, -5.84434909e+01, -6.05062064e+01,
             -6.16307311e+01, -6.30460794e+01, -6.41418131e+01,
             -6.45016294e+01, -6.39857597e+01, -6.54034347e+01,
             -6.57248284e+01, -6.54052294e+01, -6.47888033e+01,
             -6.42547376e+01, -6.24851178e+01, -6.22723142e+01,
             -6.21901630e+01, -6.16561032e+01, -6.11288120e+01,
             -5.83376140e+01, -5.65918399e+01, -5.56497535e+01,
             -5.52192782e+01, -5.63734392e+01, -5.77194648e+01,
             -5.78601735e+01, -5.62863353e+01, -5.73160246e+01,
             -5.87836371e+01, -5.93541029e+01, -5.87335278e+01,
             -5.99655544e+01, -6.15089579e+01, -6.20640957e+01,
             -6.11745742e+01, -6.2772038

In [None]:
from jax.config import config

config.update("jax_enable_x64", True)

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import celerite2.jax
from celerite2.jax import terms as jax_terms


nuts_kernel = NUTS(potential_fn=cf.gp_likelihood, dense_mass=True)
mcmc = MCMC(nuts_kernel, num_warmup=50, num_samples=100, num_chains=1)
rng_key = random.PRNGKey(12)
%time mcmc.run(rng_key, init_params=params)

warmup:   5%|█▎                          | 7/150 [11:19<7:02:18, 177.19s/it, 127 steps of size 2.98e-03. acc. prob=0.43]

In [None]:
# Keep the posterior draws under a name later cells can reuse, and preview only
# the first few draws instead of dumping every sample. tree_map slices each leaf,
# so this works whether get_samples() returns one array or a dict of arrays.
posterior_samples = mcmc.get_samples()
jax.tree_util.tree_map(lambda draws: draws[:5], posterior_samples)

In [None]:
# Candidate parameter vector (103 float64 values), presumably an alternative
# starting point to the steady-state `params`.
# NOTE(review): provenance not recorded here — document where these values came from.
params2 = np.asarray(
    [
        1.6544328, 1.90452021, 2.02977887, 1.99072169, 1.82493273,
        1.61591751, 1.48993472, 1.470403, 1.52646415, 1.61680469,
        1.69502108, 1.807466, 2.03711738, 2.26791059, 2.23266949,
        2.05979814, 2.09392307, 2.17064044, 1.98973673, 1.91675929,
        2.2641053, 2.39127884, 2.10741525, 1.88661205, 1.90116188,
        1.87844174, 1.87194416, 1.93044831, 2.13083373, 2.41960777,
        2.26026764, 1.97124401, 2.0972239, 2.35833611, 2.46362548,
        2.41907182, 2.24627259, 2.02711906, 1.9076207, 1.90912451,
        1.98666302, 2.08556198, 2.13841073, 2.13013591, 2.07467497,
        2.00286285, 1.95621838, 1.93731388, 1.9328552, 1.94929622,
        2.01912518, 2.17674413, 2.35208921, 2.4086832, 2.37435805,
        2.36878007, 2.39915424, 2.39977033, 2.30949079, 2.09497775,
        1.91648161, 2.04587247, 2.36567808, 2.52043957, 2.47333743,
        2.36461715, 2.22235458, 2.02314915, 1.83516595, 1.79931475,
        1.91979754, 2.1183396, 2.33886932, 2.51699745, 2.59543204,
        2.50430744, 2.24856097, 1.91025668, 1.66154043, 1.73657921,
        2.01041907, 2.14568526, 2.15795714, 2.2540714, 2.38722364,
        2.35090722, 2.15511825, 1.95806696, 1.85671087, 1.91504819,
        2.04967867, 2.09645538, 2.10139398, 2.21140801, 2.30824818,
        2.12406349, 1.77889989, 1.64043906, 1.7195221, 1.81553967,
        1.86562715, 1.88559886, 1.89248243,
    ]
)

In [None]:
# Interactive help lookup on the `mcmc` object.
# TODO(review): leftover exploration — remove before sharing, or replace with a
# convergence diagnostic for the run above.
help(mcmc)