In [1]:
from kernelsearch import KernelSearch
from kernels import OrnsteinUhlenbeck
from jaxtyping import install_import_hook
import jax.numpy as jnp
import jax.random as jr

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.kernels import RBF, Constant, Linear, Periodic, PoweredExponential

# random seed
key = jr.PRNGKey(42)

# TESTING

In [2]:
# data
n = 100
noise = 0.3

key, subkey = jr.split(key)
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).reshape(-1, 1)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
signal = f(x)
y = signal + jr.normal(subkey, shape=signal.shape) * noise

xtest = jnp.linspace(-3.5, 3.5, 500).reshape(-1, 1)
ytest = f(xtest)

In [3]:
kernel_library = [
    Linear(),
    RBF(),
    OrnsteinUhlenbeck(),
    Periodic(),
    PoweredExponential(power=jnp.array(0.8)),
]  # default powered exponential has infinite parameter for some reason

In [4]:
tree = KernelSearch(
    kernel_library,
    X=x,
    y=y,
    obs_stddev=0.3,
    verbosity=1,
)

model = tree.search(depth=5, n_leafs=3)

Fitting Layer 1:   0%|          | 0/5 [00:00<?, ?it/s]

Fitting Layer 1: 100%|██████████| 5/5 [00:06<00:00,  1.23s/it]


Layer 1 || Current BICs: [45.179920166196084, 73.87658262301538, 78.48175280948402, 292.263906637831, 1171.2001302012086]


Fitting Layer 2: 100%|██████████| 30/30 [00:42<00:00,  1.41s/it]

Layer 2 || Current BICs: [49.78471830601025, 49.785090352226824, 49.785090508079335, 53.49132175600381, 54.389885606479254, 54.38988769754202, 54.38989865663895, 54.39026061738552, 58.49776879647675, 58.995055792596034, 58.99507131500635, 66.39961456175199, 71.00083024796925, 78.4817528089967, 78.52125743663228, 78.63514731053223, 83.08692299525445, 83.08692301720465, 83.12642746887882, 83.12642758543072, 83.2396413587263, 87.65849638287614, 87.69209319138844, 87.73159765492437, 87.73159766392808, 92.2636665695964, 92.3367699927049, 113.93587662137641, 160.6658596270523, 165.27102988730678]
No more improvements found! Terminating early..

Terminated on layer: 2.
Final log likelihood: -17.984789897109952
Final number of model paramter: 2



