In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import rcParams
import warnings
warnings.filterwarnings('ignore')

import re
import jax.numpy as jnp
from jax import grad, jit, partial
import ticktack
from ticktack import fitting3 as fitting
from tqdm import tqdm
import scipy

# Default figure size (width, height) in inches for every plot in this notebook
rcParams.update({'figure.figsize': (8.0, 4.0)})

In [2]:
# Carbon box model + GP control-point fitter for the injection-recovery dataset
cbm = ticktack.load_presaved_model(
    'Guttler14',
    production_rate_units='atoms/cm^2/s',
)
cf = fitting.CarbonFitter(cbm)
cf.load_data('inject_recovery_gp.csv')
cf.prepare_function(
    use_control_points=True,
    interp='gp',
)

In [3]:
@jit
def gauss(t, *, baseline=1.87, amplitude=0.2, t0=250, width=10):
    """Injected "true" production rate: a constant baseline plus a Gaussian spike.

    Computes ``baseline + amplitude * baseline * exp(-((t - t0) / width)**2 / 2)``,
    i.e. a spike whose peak is ``amplitude`` (as a fraction) above ``baseline``.
    The defaults reproduce the signal used throughout this notebook.

    Parameters
    ----------
    t : scalar or array of times; the formula is applied elementwise.
    baseline : constant production level away from the spike
        (presumably in the model's atoms/cm^2/s units, since this function is
        passed to the fitter as its production function — confirm).
    amplitude : peak height of the spike relative to ``baseline``.
    t0 : centre of the spike.
    width : standard deviation of the spike, in the same units as ``t``.

    The shape parameters are keyword-only, so a caller that passes extra
    positional arguments still fails loudly instead of silently reshaping the spike.
    """
    return baseline + amplitude * baseline * jnp.exp(-0.5 * ((t - t0) / width) ** 2.)

In [4]:
# Display the fitter's `offset` attribute as it stands right after setup
# (the next cell overrides it).
cf.offset

DeviceArray(-0.44436304, dtype=float64)

In [5]:
# Force the fitter's offset to zero; it was non-zero after setup (see the output above).
# NOTE(review): presumably because the synthetic inject-recovery data carries no
# offset — confirm. This mutates the shared `cf` used by later cells; the fresh
# fitters built inside diagnostic() do NOT receive this override.
cf.offset = 0
cf.offset

0

In [6]:
# Print the source of the GP interpolator the fitter uses, to check its kernel/mean setup
from inspect import getsource
print(getsource(cf.interp_gp))

    @partial(jit, static_argnums=(0,))
    def interp_gp(self, tval, *args):
        tval = tval.reshape(-1)
        params = jnp.squeeze(jnp.array(list(args)))
        control_points = params

        kernel = jax_terms.Matern32Term(sigma=2.0, rho=2.)
        gp = celerite2.jax.GaussianProcess(kernel, mean=1.87)
        gp.compute(self.control_points_time)
        mu = gp.predict(control_points, t=tval, return_var=False)
        mu = (tval > self.start) * mu + (tval <= self.start) * 1.87
        return mu



In [7]:
%%time
# Fit the GP control points; `low_bound=0.` presumably bounds every control point
# below by zero (no negative production) — confirm against ticktack.fitting3.
# In the saved run this cell was interrupted (KeyboardInterrupt), so `soln` was never
# defined and every later cell that reads `soln` fails — let it finish before running them.
soln = cf.fit_cp(low_bound=0.)

KeyboardInterrupt: 

In [8]:
# Optimizer status message for the control-point fit (raises NameError if the
# fit cell did not complete, as in the saved output).
soln.message

NameError: name 'soln' is not defined

In [None]:
# Break the objective at the fitted control points down into its components
for label, objective in (("total likelihood: ", cf.gp_likelihood),
                         ("gp likelihood: ", cf.gp_neg_log_likelihood),
                         ("chi2: ", cf.loss_chi2)):
    print(label, objective(soln.x))

In [None]:
# True (injected) production rate evaluated at the fitter's control-point times.
# `t` and `true_cp` are reused by the control-point comparison plot.
t = cf.control_points_time
true_cp = gauss(t)
plt.plot(t, true_cp, ".")  # reuse true_cp rather than evaluating gauss(t) twice
plt.title('gauss production rate')
plt.xlabel("time")
plt.ylabel("production rate [atoms/cm^2/s]");

In [None]:
# Recovered control points vs. the injected truth at the same control-point times
plt.plot(t, soln.x, ".", label="recovered")
plt.plot(t, true_cp, label='true')
plt.title("control points")
plt.xlabel("time")
plt.ylabel("production rate [atoms/cm^2/s]")
# Legend sits below the axes; anchor it far enough down to clear tick labels and xlabel
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=2,
           fancybox=True);

In [None]:
# Modelled d14c from the recovered control points vs. the synthetic d14c data.
# NOTE(review): cf.dc14 output is plotted against time_data[:-1], so it presumably
# has one fewer sample than the data — confirm against ticktack.fitting3.
plt.plot(cf.time_data[:-1], cf.dc14(soln.x), ".k", label="recovered", markersize=10)
plt.plot(cf.time_data, cf.d14c_data, '--r')
plt.plot(cf.time_data, cf.d14c_data, 'or', fillstyle="none", label="true", alpha=0.5)
plt.title("d14c")
plt.xlabel("time")
plt.ylabel("d14c")
# Legend sits below the axes; anchor it far enough down to clear tick labels and xlabel
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=2,
           fancybox=True);

In [None]:
# Fine 0.1-step time grid over [cf.start - 10, cf.start + 10) for zooming in on cf.start
t = np.arange(cf.start - 10,
              cf.start + 10,
              0.1)

In [None]:
# GP production rate from the fitted control points vs. the injected Gaussian on the
# zoom grid `t`; the control points are passed as a single packed argument.
mu = cf.production(t, (soln.x,))
plt.plot(t, mu, ".", markersize=2, label="recovered")
plt.plot(t, gauss(t), ".", markersize=2, label="true")
plt.title("production rate near cf.start")
plt.xlabel("time")
plt.ylabel("production rate [atoms/cm^2/s]")
plt.legend();

In [None]:
# Fitted control-point values; diagnostic() reads this notebook-global `params`.
params = soln.x

In [None]:
# Injected production rate at t = 200, the time marked by the red line in the diagnostic plot.
gauss(200)

In [None]:
# `control_points` only exists as a local variable inside cf.interp_gp (see the
# printed source above), so referencing it bare raised NameError. The fitted
# control-point values the interpolator receives are `params` (= soln.x).
control_points = params
control_points

In [None]:
def _make_fitter(**prepare_kwargs):
    """Build a fresh Guttler14 CarbonFitter on the injection-recovery data.

    ``prepare_kwargs`` are forwarded unchanged to ``CarbonFitter.prepare_function``.
    NOTE(review): the returned fitter does not get the notebook's ``cf.offset = 0``
    override — confirm whether the offset matters for ``run``.
    """
    model = ticktack.load_presaved_model('Guttler14', production_rate_units='atoms/cm^2/s')
    fitter = fitting.CarbonFitter(model)
    fitter.load_data('inject_recovery_gp.csv')
    fitter.prepare_function(**prepare_kwargs)
    return fitter


def _d14c_after_burn_in(fitter, t, params=None):
    """Run ``fitter`` through its burn-in, continue on grid ``t``, and return d14c.

    When ``params`` is None the production function is called without parameters
    (used for the fixed ``gauss`` truth); otherwise ``params`` is passed both to the
    burn-in run and to the model run.
    """
    if params is None:
        burn_in = fitter.run(fitter.burn_in_time, fitter.steady_state_y0)
        data, _ = fitter.cbm.run(t, production=fitter.production, y0=burn_in[-1, :])
    else:
        burn_in = fitter.run(fitter.burn_in_time, fitter.steady_state_y0, params=params)
        data, _ = fitter.cbm.run(t, production=fitter.production, args=params,
                                 y0=burn_in[-1, :])
    return fitter.cbm._to_d14c(data, fitter.steady_state_y0)


def diagnostic(t, fit_params=None):
    """Overlay d14c simulated from the fitted GP control points and from the true ``gauss``.

    Each curve starts from the state at the end of its own fitter's burn-in and is
    simulated on the times ``t``, so the plot shows whether the two agree right after
    the burn-in ends.

    Parameters
    ----------
    t : array of times to simulate and plot.
    fit_params : control-point values for the GP fitter. Defaults to the
        notebook-global ``params`` (the original behaviour); pass explicitly to avoid
        depending on hidden notebook state.
    """
    if fit_params is None:
        fit_params = params  # notebook global, set from soln.x

    fitted = _make_fitter(use_control_points=True, interp='gp')
    plt.plot(t, _d14c_after_burn_in(fitted, t, fit_params), ".", label='fitted')

    truth = _make_fitter(custom_function=True, f=gauss)
    plt.plot(t, _d14c_after_burn_in(truth, t), ".", label='true')

    plt.xlabel("time")
    plt.ylabel("d14c")
    plt.legend()  # single call, after both series exist

In [None]:
# Coarse 0.5-step grid over [cf.start - 30, cf.start + 5), passed to diagnostic()
t = np.arange(cf.start - 30, cf.start + 5, 0.5)
# Finer 0.1-step grid over [cf.start - 20, cf.start + 0.01).
# NOTE(review): t1 is not used by any visible cell — remove it if nothing downstream needs it.
t1 = np.arange(cf.start - 20, cf.start + 0.01, 0.1)

In [None]:
# Compare fitted vs. true d14c on the coarse grid and mark t = 200.
# NOTE(review): 200 is hard-coded; if it is meant to mark cf.start, use cf.start instead.
diagnostic(t)
plt.axvline(x=200, color='r')