In [1]:
import numpy as np
import ticktack
from ticktack import fitting
import matplotlib.pyplot as plt

In [2]:
import jax.numpy as jnp
from jax import jit, grad, hessian, jacrev
import jax
from tinygp import kernels, GaussianProcess

In [3]:
# Build the ticktack single-event fitter. Both positional arguments are
# "Brehm21" — presumably carbon-box-model name and model variant; TODO confirm
# what each argument selects in ticktack's SingleFitter signature.
sf2 = fitting.SingleFitter("Brehm21", "Brehm21")
# Load the measurement file; oversample=12 presumably controls the time-grid
# oversampling used when integrating the model — verify against ticktack docs.
sf2.load_data("miyake12.csv", oversample=12)
# Select the control-point production model. Its production function is
# overwritten with the GP interpolant (`sf2.production = interp_gp`) later on.
sf2.compile_production_model(model="control_points")

INFO[2022-03-04 13:24:34,297]: Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: 
INFO[2022-03-04 13:24:34,306]: Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Host Interpreter
INFO[2022-03-04 13:24:34,311]: Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.


In [4]:
def build_gp(*args):
    """Construct the GP prior placed on the control-point times.

    Parameters
    ----------
    *args : array-like
        Control-point values. Everything passed is packed into a single array
        and flattened, so ``build_gp(values)`` and ``build_gp(v0, v1, ...)``
        give the same flattened vector.

    Returns
    -------
    tinygp.GaussianProcess
        A GP over ``sf2.control_points_time`` with a ``Matern32(1)`` kernel and
        ``mean`` set to the first control-point value.
    """
    flat_values = jnp.array(list(args)).reshape(-1)
    return GaussianProcess(
        kernels.Matern32(1),
        sf2.control_points_time,
        mean=flat_values[0],
    )

@jit
def interp_gp(tval, *args):
    """Interpolate control-point values to arbitrary times with the GP mean.

    Parameters
    ----------
    tval : array
        Query times; flattened to 1-D before conditioning.
    *args : array-like
        Control-point values, packed and flattened exactly as in ``build_gp``.

    Returns
    -------
    array
        ``.loc`` of the GP (from ``build_gp``) after conditioning on the
        control-point values and evaluating at ``tval``.
    """
    query_times = tval.reshape(-1)
    control_values = jnp.array(list(args)).reshape(-1)
    prior = build_gp(*args)
    # Element 1 of the condition() result is the conditioned GP.
    posterior = prior.condition(control_values, query_times)[1]
    return posterior.loc

@jit
def log_likelihood_gp(params):
    """Log-density of the control-point values under the GP prior.

    Parameters
    ----------
    params : array
        Control-point values (one per entry of ``sf2.control_points_time``).

    Returns
    -------
    scalar array
        ``log_probability(params)`` of the GP returned by ``build_gp(params)``.
    """
    # Delegate to build_gp rather than re-creating the kernel/mean here, so the
    # prior scored in the likelihood is always the same GP that interp_gp
    # conditions on (a duplicated kernel could silently drift out of sync).
    return build_gp(params).log_probability(params)

@jit
def log_joint_likelihood_gp(params, low_bounds, up_bounds):
    lp = jnp.any((params < low_bounds) | (params > up_bounds)) * -jnp.inf
    return sf2.log_likelihood(params=params) + log_likelihood_gp(params) + lp

# Replace the fitter's production function with the GP interpolant defined
# above. NOTE(review): presumably ticktack calls it as production(t, *params) —
# confirm against SingleFitter's usage of `self.production`.
sf2.production = interp_gp
# Display the shape of the fine time grid as the cell output.
sf2.control_points_time_fine.shape

(324,)

In [10]:
# Starting point: every control point at 1.0; box-prior bounds [0, 100]
# for use with log_joint_likelihood_gp.
params = np.full(sf2.control_points_time.shape, 1.0)
low = np.full_like(params, 0.0)
high = np.full_like(params, 100.0)

In [5]:
# Standard-normal draw, one value per control point. A seeded Generator makes
# the values reproducible on Restart & Run All (np.random.randn was unseeded).
# NOTE(review): the name `random` shadows the stdlib module and is unused in
# the visible cells; it is kept for any later cell that may reference it.
random = np.random.default_rng(42).standard_normal(sf2.control_points_time.size)

In [7]:
@jit
def grad_log_joint_likelihood_gp(params, low_bounds, up_bounds):
    """Gradient of ``log_joint_likelihood_gp`` with respect to ``params`` only.

    The bounds are passed through unchanged and are not differentiated.
    """
    d_dparams = grad(log_joint_likelihood_gp, argnums=0)
    return d_dparams(params, low_bounds, up_bounds)

In [8]:
@jit
def hess_log_joint_likelihood_gp(params, low_bounds, up_bounds):
    """Hessian of ``log_joint_likelihood_gp`` w.r.t. ``params`` (reverse-over-reverse)."""
    # NOTE(review): jax.hessian (forward-over-reverse) is usually cheaper, but
    # forward mode cannot pass through custom_vjp-only primitives (e.g. some
    # ODE solvers) — presumably why jacrev∘jacrev is used; confirm before switching.
    first_derivative = jacrev(log_joint_likelihood_gp)
    second_derivative = jacrev(first_derivative)
    return second_derivative(params, low_bounds, up_bounds)

In [11]:
%%timeit
grad_log_joint_likelihood_gp(params, low, high)

703 ms ± 55.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%timeit
hess_log_joint_likelihood_gp(params, low, high)

In [11]:
%%timeit
interp_gp(sf2.control_points_time_fine, params)
# plt.plot(sf2.control_points_time_fine, interp_gp(sf2.control_points_time_fine, params))

650 µs ± 65.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%%timeit
# log_joint_likelihood_gp(params)
log_joint_likelihood_gp(params, low, high)

23.2 ms ± 1.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
# Joint log-likelihood at the all-ones starting point (inside the [0, 100] box).
log_joint_likelihood_gp(params, low, high)

DeviceArray(-124980.6345283, dtype=float64)

In [14]:
# GP-prior term alone at the same point, for comparison with the joint value.
log_likelihood_gp(params)

DeviceArray(-21.15613452, dtype=float64)