In [None]:
import numpy as np
import ticktack
from ticktack import fitting
import jax.numpy as jnp
import matplotlib.pyplot as plt
plt.rcParams.update({'figure.figsize': (8.0, 6.0)})  # default size for figures that don't set their own
from tqdm import tqdm

In [None]:
# Build the carbon box model and a single-record fitter reading the northern troposphere box.
model = 'Guttler14'
cbm = ticktack.load_presaved_model(model, production_rate_units='atoms/cm^2/s')
sf = fitting.SingleFitter(cbm, model, hemisphere="north", box="Troposphere")
sf.load_data("Sakurai20_CedarLw.csv", burnin_time=1000)
sf.compile_production_model(model="simple_sinusoid")

# Initial guess, in the order used by the trace-plot labels:
# start date, duration, phase (presumably radians, given the +/-pi bounds), spike production.
start_guess = -660
duration_guess = 1. / 12
phase_guess = np.pi / 2.
spike_guess = 81. / 12
params = jnp.array([start_guess, duration_guess, phase_guess, spike_guess])

In [None]:
# Lower / upper bounds handed to the log-joint through `args`.
# NOTE(review): these arrays have 5 entries while `params` has 4 -- confirm how
# log_joint_simple_sinusoid consumes them (possibly left over from a 5-parameter model).
low_bounds = jnp.array([-660 - 5, 1 / 365., -jnp.pi, 0., 0.])
high_bounds = jnp.array([-660 + 5, 5., jnp.pi, 15., 2.])

chain = sf.MarkovChainSampler(
    params,
    sf.log_joint_simple_sinusoid,
    burnin=2000,
    production=1000,
    args=(low_bounds, high_bounds),
)

In [None]:
# Trace plots: one panel per sampled parameter, samples drawn in chain order.
# The phase is initialised at pi/2 and bounded by +/-pi, so it is in radians, not years.
# The production rate units were set to atoms/cm^2/s when loading the model. The spike
# term is presumably that rate integrated over the duration in years -- confirm against
# ticktack's simple_sinusoid.
labels = ["start date (yr)", "duration (yr)", "phi (rad)", "spike production (atoms/cm^2/s * yr)"]
fig, axs = plt.subplots(2, 2, figsize=(12, 8), sharex=True)
axs = axs.flatten()
assert chain.shape[1] <= len(labels), "chain has more parameters than labelled panels"
for ax, samples, label in zip(axs, chain.T, labels):
    ax.plot(samples, 'b.', markersize=1, alpha=0.5)
    ax.set_title(label)
    ax.get_xaxis().set_visible(False)  # x is just the sample index
fig.tight_layout()

## Plot the binned d14c curve for every chain sample

In [None]:
# Evaluate the binned d14c model curve for each sample of the chain: one row per sample,
# one column per entry of sf.time_data.
size = chain.shape[0]
d14cs_bin = np.zeros((size, sf.time_data.size))
for j, sample in enumerate(tqdm(chain, total=size)):
    d14cs_bin[j] = sf.dc14(params=sample)

In [None]:
# Overlay the modelled d14c curve of every chain sample. Dense regions show where most
# samples agree. d14cs_bin.T has one column per sample, so a single Axes.plot call draws
# all curves instead of one pyplot call per sample.
fig, ax = plt.subplots()
ax.plot(sf.time_data, d14cs_bin.T, 'g', alpha=0.05)
ax.set_xlabel("time (yr)")
ax.set_ylabel("d14c")
ax.set_title("Binned d14c for each chain sample");

In [None]:
# Split the chain samples by the modelled d14c value in two hand-picked time bins
# (column indices into d14cs_bin, i.e. into sf.time_data).
# NOTE(review): the bin indices and thresholds were read off the plot above -- re-check
# them if the input CSV changes.
FLAT_BIN, FLAT_MAX = 9, 5      # curves staying below FLAT_MAX at this bin
SPIKE_BIN, SPIKE_MIN = 19, 9   # curves rising above SPIKE_MIN at this bin

selected_chain = chain[d14cs_bin[:, FLAT_BIN] < FLAT_MAX]  # parameters that produce flat curves
spiked_chain = chain[d14cs_bin[:, SPIKE_BIN] > SPIKE_MIN]
spiked_chain

In [None]:
np.shape(selected_chain)  # (number of flat-curve samples, number of parameters)

## Check the log-likelihood of the selected (flat-curve) samples

In [None]:
# Log-likelihood of each selected (flat-curve) sample, in chain order.
# enumerate replaces the hand-maintained counter and does not leave a stale `i` behind.
like = np.zeros((selected_chain.shape[0],))
for idx, param in enumerate(tqdm(selected_chain)):
    like[idx] = sf.log_likelihood(param)

In [None]:
np.asarray(like)  # one log-likelihood per selected sample