In [2]:
# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)

from jax import jit, tree_map
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook, Array
from copy import deepcopy
from typing import Optional
import numpy as np
from gpjax.base import meta_leaves
from jax.flatten_util import ravel_pytree
from jax.stages import Wrapped
import warnings
import optax as ox

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.kernels import Constant, Linear, RBF, Periodic, PoweredExponential

key = jr.PRNGKey(42)

In [3]:
# data
n = 100
noise = 0.3

key, subkey = jr.split(key)
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).reshape(-1, 1)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise

D = gpx.Dataset(X=x, y=y)

xtest = jnp.linspace(-3.5, 3.5, 500).reshape(-1, 1)
ytest = f(xtest)

In [22]:
k = PoweredExponential().replace(power=jnp.array(0.8))
prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=k)
lh = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.3).replace_trainable(
    obs_stddev=False
)
post = lh * prior

In [24]:
ravel_pytree(PoweredExponential(power=jnp.array(0.8)))

(Array([1. , 0.8, 1. ], dtype=float64, weak_type=True),
 <jax._src.util.HashablePartial at 0x7f8c8ae6b430>)

In [20]:
meta_leaves(PoweredExponential().replace(power=jnp.array(0.8)).unconstrain())

[({'bijector': <tfp.bijectors.Softplus 'softplus' batch_shape=[] forward_min_event_ndims=0 inverse_min_event_ndims=0 dtype_x=? dtype_y=?>,
   'trainable': True,
   'pytree_node': True},
  Array(0.54132485, dtype=float64)),
 ({'bijector': <tfp.bijectors.Sigmoid 'sigmoid' batch_shape=[] forward_min_event_ndims=0 inverse_min_event_ndims=0 dtype_x=? dtype_y=?>,
   'trainable': True,
   'pytree_node': True},
  Array(1.38629436, dtype=float64, weak_type=True)),
 ({'bijector': <tfp.bijectors.Softplus 'softplus' batch_shape=[] forward_min_event_ndims=0 inverse_min_event_ndims=0 dtype_x=? dtype_y=?>,
   'trainable': True,
   'pytree_node': True},
  Array(0.54132485, dtype=float64))]

In [23]:
static_tree = tree_map(lambda x: not (x), post)
optim = ox.chain(
    ox.adam(learning_rate=0.003),
    ox.masked(
        ox.set_to_zero(),
        static_tree,
    ),
)
optimized_posterior, history = gpx.fit_scipy(
    model=post,
    objective=gpx.objectives.ConjugateMLL(negative=True),
    train_data=D,
    verbose=True,
)

(Array([0.3, 1. , 0.8, 1. ], dtype=float64, weak_type=True), <jax._src.util.HashablePartial object at 0x7f8c8b0cbf70>)
(Array([-1.05022573,  0.54132485,  1.38629436,  0.54132485], dtype=float64), <jax._src.util.HashablePartial object at 0x7f8cde0e2ce0>)
[-1.05022573  0.54132485  1.38629436  0.54132485] <jax._src.util.HashablePartial object at 0x7f8cde0e2ce0>
[ 0.         -2.83754736 -5.24593602  3.2500574 ]
[ 0.         -2.83754736 -5.24593602  3.2500574 ]
[ 0.          3.32161922 -1.45554478 -5.79452998]
[-1.05022573  0.9632662   2.16636146  0.05804354]
[ 0.          1.07510472 -0.799189   -1.4263123 ]
[-1.05022573  0.87173439  3.07109454  0.44422486]
[ 0.         -0.9466767  -0.47250745  0.86478403]
[-1.05022573  0.48428155  3.83270158  0.44467597]
[ 0.         -0.42055976 -0.32765332  0.5127217 ]
[-1.05022573  0.68156774  4.15284947  0.56992843]
[ 0.         -0.19762626 -0.16887522  0.39642081]
[-1.05022573  0.81031629  4.7902646   0.67337745]
[ 0.         -0.10720213 -0.09621533  0

In [6]:
history

Array([583.29764153], dtype=float64)