In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.pyplot import rcParams
import warnings
warnings.filterwarnings('ignore')

import re
import jax.numpy as jnp
from jax import grad, jit, partial
import jax
import ticktack
from ticktack import fitting
from tqdm import tqdm

import blackjax.hmc as hmc
import blackjax.nuts as nuts
import blackjax.stan_warmup as stan_warmup
import jax.scipy.stats as stats
import blackjax.nuts as nuts
# Default size (width, height) used for every matplotlib figure in this notebook.
rcParams.update({'figure.figsize': (10.0, 5.0)})

In [2]:
# Load the pre-saved 'Guttler14' model, passing 'atoms/cm^2/s' as the
# production-rate units.
cbm = ticktack.load_presaved_model(
    'Guttler14',
    production_rate_units='atoms/cm^2/s',
)

# Wrap the model in a CarbonFitter, load the data file, then prepare the
# fitting function (control points enabled, gap_years=5).
cf = fitting.CarbonFitter(cbm)
cf.load_data('400_BCE_Data_processed.csv', time_oversample=50)
cf.prepare_function(use_control_points=True, gap_years=5)



In [3]:
# Initial parameter vector (18 values); used below as the sampler's starting
# position and as the reference curve in the comparison plots.
params = jnp.array(
    [
        1.89636017, 1.44580507, 1.8811576,
        2.1507674, 1.80276473, 2.23406508,
        1.87885852, 1.92915198, 2.21036406,
        2.06033825, 2.44716087, 2.20314291,
        2.14255201, 2.20198795, 2.06176215,
        2.12256311, 2.0582041, 1.63628971,
    ]
)

In [4]:
%%time 
# Time one evaluation of the gradient of the log-likelihood at `params`.
# NOTE(review): the output recorded for this cell is
#   ValueError: xp and fp must be one-dimensional arrays of equal size
# so the cell does not currently run cleanly. Possibly the length of `params`
# does not match what `cf` expects after prepare_function(...) - TODO confirm.
cf.grad_log_like(params=params)

ValueError: xp and fp must be one-dimensional arrays of equal size

In [None]:
def potential(x):
    """Evaluate ``cf.log_prob`` on a position dict, unpacked as keyword arguments.

    NOTE(review): confirm the sign convention - check whether the sampler
    expects a negative log-probability here and what ``cf.log_prob`` returns.
    """
    return cf.log_prob(**x)


def kernel_generator(step_size, inverse_mass_matrix):
    """Build an HMC kernel for the given step size and inverse mass matrix.

    The trailing positional argument to ``hmc.kernel`` is fixed at 30.
    """
    return hmc.kernel(potential, step_size, inverse_mass_matrix, 30)

In [None]:
# Start the sampler at `params`; the dict key matches the keyword that
# `potential` forwards to cf.log_prob.
initial_position = dict(params=params)
state = nuts.new_state(initial_position, potential)

In [None]:
# Run the Stan-style warmup from `state` to obtain a step size and an inverse
# mass matrix. The trailing positional argument is 1000.
# NOTE(review): `final_state` is not read again in this notebook; the sampling
# loops below continue from `state` - confirm this is intended.
rng_key = jax.random.PRNGKey(0)
final_state, (step_size, inverse_mass_matrix), info = stan_warmup.run(rng_key, kernel_generator, state, 1000)

In [None]:
# Build the NUTS kernel from the warmup results and JIT-compile it in one step.
kernel = jit(nuts.kernel(potential, step_size, inverse_mass_matrix))

In [None]:
# Advance the chain by 1000 kernel steps. Only the last `state` is kept; the
# per-step info returned by the kernel is discarded.
rng_key = jax.random.PRNGKey(0)
for _ in tqdm(range(1000)):
    # Split into a carry key (used only for the next split) and a one-off key
    # for this step. A JAX key must not be consumed twice; previously the key
    # handed to `kernel` was also the one split on the next iteration.
    rng_key, step_key = jax.random.split(rng_key)
    state, _ = kernel(step_key, state)

In [None]:
# Display the "params" entry of the sampler's current position.
state.position["params"]

In [None]:
# Compare the sampler's current position against the initial `params`.
# The x values are cf.time_data with cf.start - 1 prepended.
t = jnp.append(jnp.array([cf.start - 1]), cf.time_data)
plt.plot(t, state.position["params"], ".", label='state.position["params"] (sampled)')
plt.plot(t, params, label="params (initial)")
plt.xlabel("t")
plt.ylabel("params")
# Trailing semicolon suppresses the text repr of the returned artist.
plt.legend();

In [None]:
# Log-likelihood evaluated at the sampler's current position.
cf.log_like(params=state.position["params"])

In [None]:
# Log-likelihood evaluated at the initial `params`, for comparison.
cf.log_like(params=params)

In [None]:
# Overlay cf.dc14 evaluated at the sampled position (plotted against
# cf.time_data without its last element) on top of cf.d14c_data.
plt.plot(
    cf.time_data[:-1],
    cf.dc14(state.position["params"]),
    ".",
    label='cf.dc14(state.position["params"])',
)
plt.plot(cf.time_data, cf.d14c_data, ".", label="cf.d14c_data")
plt.xlabel("cf.time_data")
# Trailing semicolon suppresses the text repr of the returned artist.
plt.legend();

In [None]:
# Unconditionally raises, so a "Run All" halts at this cell and the cells
# below only execute when run by hand.
# NOTE(review): presumably a deliberate manual stop - confirm.
raise TypeError

In [None]:
# Rebuild the sampler state from the position reached so far.
second_position = dict(params=state.position["params"])
state = nuts.new_state(second_position, potential)

In [None]:
# Advance the chain by another 1000 kernel steps, keeping only the last state.
# Seeded with 1 rather than 0: seed 0 is already used by the earlier warmup and
# sampling cells, and reseeding with it would replay the same random stream.
rng_key = jax.random.PRNGKey(1)
for _ in tqdm(range(1000)):
    # Split into a carry key (used only for the next split) and a one-off key
    # for this step. A JAX key must not be consumed twice; previously the key
    # handed to `kernel` was also the one split on the next iteration.
    rng_key, step_key = jax.random.split(rng_key)
    state, _ = kernel(step_key, state)

In [None]:
# Compare the sampler's current position against the initial `params`.
# The x values are cf.time_data with cf.start - 1 prepended.
t = jnp.append(jnp.array([cf.start - 1]), cf.time_data)
plt.plot(t, state.position["params"], ".", label='state.position["params"] (sampled)')
plt.plot(t, params, label="params (initial)")
plt.xlabel("t")
plt.ylabel("params")
# Trailing semicolon suppresses the text repr of the returned artist.
plt.legend();

In [None]:
# Log-likelihood evaluated at the sampler's current position.
cf.log_like(params=state.position["params"])