In [None]:
from kernelsearch import KernelSearch
from kernels import OrnsteinUhlenbeck
from jaxtyping import install_import_hook
import jax.numpy as jnp
import jax.random as jr

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.kernels import RBF, Constant, Linear, Periodic, PoweredExponential

# Fixed seed so every Restart & Run All reproduces the same random draws.
SEED = 42
key = jr.PRNGKey(SEED)

# --- Kernel-search test run on synthetic data ---

In [None]:
# Synthetic 1-D regression data: n noisy samples of f drawn uniformly on
# [-3, 3], plus a noise-free test grid that extends slightly past that range.
n = 100  # number of training points
noise = 0.3  # std. dev. of the additive Gaussian noise applied to y below

# Draw one fresh key per random operation and keep `key` only as the unused
# carry, so no key that generated data here is ever reused by a later cell.
# `subkey` drives the observation noise.
key, x_key, subkey = jr.split(key, 3)
x = jr.uniform(key=x_key, minval=-3.0, maxval=3.0, shape=(n,)).reshape(-1, 1)


def f(x):
    """Latent function the GP should recover: sin(4x) + cos(2x), elementwise."""
    return jnp.sin(4 * x) + jnp.cos(2 * x)


signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise

xtest = jnp.linspace(-3.5, 3.5, 500).reshape(-1, 1)
ytest = f(xtest)

# Cheap invariants: column-vector inputs and targets of matching shape.
assert x.shape == (n, 1) and y.shape == x.shape
assert xtest.shape == ytest.shape == (500, 1)

In [None]:
# Base kernels the search may combine.
kernel_library = [Linear(), RBF(), OrnsteinUhlenbeck(), Periodic()]
# PoweredExponential gets an explicit power: per the original author, the
# default instance had an infinite parameter. NOTE(review): the root cause is
# unexplained; confirm against the installed gpjax version.
kernel_library.append(PoweredExponential(power=jnp.array(0.8)))

In [None]:
# Run the (project-local) KernelSearch over `kernel_library` on the training data.
# The observation noise is tied to `noise`, the std. dev. actually used to
# generate `y`, instead of a duplicated literal, so the two cannot drift apart.
tree = KernelSearch(
    kernel_library,
    X=x,
    y=y,
    obs_stddev=noise,
    verbosity=1,
)

# NOTE(review): the exact meaning of depth / n_leafs is defined by
# KernelSearch.search (presumably search rounds and candidates kept per round);
# confirm there before tuning.
model = tree.search(depth=5, n_leafs=3)