In [23]:
# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

key = jr.PRNGKey(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

In [24]:
n = 100
noise = 0.3

key, subkey = jr.split(key)
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).reshape(-1, 1)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise

D = gpx.Dataset(X=x, y=y)

xtest = jnp.linspace(-3.5, 3.5, 500).reshape(-1, 1)
ytest = f(xtest)

In [25]:
kernel = gpx.kernels.ProductKernel(kernels=[gpx.kernels.Constant(), gpx.kernels.RBF()])
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)

In [26]:
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=jnp.array(noise))
likelihood = likelihood.replace_trainable(obs_stddev=False)

In [27]:
posterior = prior * likelihood

In [28]:
# negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))
negative_mll = gpx.objectives.ConjugateMLL(negative=True)

In [36]:
opt_posterior, history = gpx.fit_scipy(
    model=posterior,
    objective=negative_mll,
    train_data=D,
    verbose=True,
)

Initial loss is 489.0740769876317
Optimization was successful
Final loss is 55.46842746389798 after 18 iterations


In [8]:
from jax.flatten_util import ravel_pytree
from gpjax.base import meta_leaves, meta_map

In [10]:
ravel_pytree(opt_posterior.prior.kernel)

(Array([1.01898955, 0.45282094, 1.23487555], dtype=float64),
 <jax._src.util.HashablePartial at 0x7f8d9997a4a0>)

In [54]:
opt_posterior.prior.kernel.kernels

[Constant(compute_engine=DenseKernelComputation(), active_dims=None, name='AbstractKernel', constant=Array(1.01898955, dtype=float64)),
 RBF(compute_engine=DenseKernelComputation(), active_dims=None, name='RBF', lengthscale=Array(0.45282094, dtype=float64), variance=Array(1.23487555, dtype=float64))]