# In [230]:
class Node:
    """A generic tree node holding a value, a link to its parent and its children.

    Args:
        value: Payload stored in the node.
        parent: Parent node, or ``None`` for the root.
    """

    def __init__(self, value, parent=None):
        self.value = value
        self.parent = parent
        self.children = []

    def add_child(self, child_value):
        """Create a child holding ``child_value`` and append it to ``children``.

        The child is constructed with ``type(self)`` rather than the global
        name ``Node``: the global name is resolved at call time, so if it is
        rebound to another class later, a global lookup would build the wrong
        type. ``type(self)`` also keeps subclasses consistent.
        """
        self.children.append(type(self)(child_value, self))


class SearchTree:
    """Layer-by-layer beam search that keeps the best-scoring nodes per layer.

    Args:
        root_value: Value stored in the root node.
        func: Scoring function applied to each node's ``value``; larger scores
            rank higher.
        n_leafs: Number of top-scoring nodes expanded in each layer.
    """

    def __init__(self, root_value, func, n_leafs=3):
        self.func = func
        self.n_leafs = n_leafs
        self.root = Node(root_value)
        self.current_layer = [self.root]

    def expand_node(self, node):
        """Attach children to ``node``. Placeholder: currently does nothing."""
        # TODO: implement the expansion logic (e.g. via node.add_child(...)).
        return None

    def compute_layer(self):
        """Pair every node of the current layer with its score ``func(node.value)``."""
        scored = []
        for candidate in self.current_layer:
            scored.append((candidate, self.func(candidate.value)))
        return scored

    def select_top_nodes(self, layer_values):
        """Return the nodes of the ``n_leafs`` highest-scoring ``(node, score)`` pairs.

        Ties keep their original relative order (the sort is stable).
        """
        ranked = list(layer_values)
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        best_pairs = ranked[: self.n_leafs]
        return [candidate for candidate, _score in best_pairs]

    def expand_layer(self):
        """Expand the best nodes of the current layer; their children become the new layer."""
        survivors = self.select_top_nodes(self.compute_layer())
        children = []
        for parent_node in survivors:
            self.expand_node(parent_node)
            children += parent_node.children
        self.current_layer = children

    def search(self, depth):
        """Advance the search by ``depth`` layers."""
        for _layer_index in range(depth):
            self.expand_layer()
            # Per-layer processing (logging, pruning, ...) can be added here.


# Example usage
def example_function(value):
    """Score a node value; this placeholder returns the value unchanged."""
    # Replace with a real scoring function for your search problem.
    score = value
    return score


# Build a tree rooted at 0 that keeps the 3 best nodes per layer, then search.
tree = SearchTree(0, example_function, n_leafs=3)
tree.search(5)  # search depth; adjust as needed

# In [21]:
# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
from copy import deepcopy
from typing import Optional
import numpy as np
from gpjax.base import meta_leaves

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.kernels import Constant, Linear, RBF, Periodic, PoweredExponential

key = jr.PRNGKey(seed=123)  # fixed seed for reproducible randomness

# In [31]:
# Candidate base kernel classes (presumably passed as KernelSearch's
# ``kernel_library``).
kernel_library = [
    Constant,
    Linear,
    RBF,
    Periodic,
    PoweredExponential,
]

# In [30]:
class Node:
    """A node of the kernel search tree, wrapping one GP posterior and its BIC.

    Args:
        posterior: Posterior represented by this node. It is deep-copied, so
            later changes to the caller's object do not affect the node.
        max_log_likelihood: Log-likelihood reached after fitting, if known.
            Together with ``n_data`` it is needed to compute ``bic``.
        n_data: Number of training points. Only ``log(n_data)`` is stored
            (in ``self.n_data``) because that is the form the BIC uses.
        parent: Node this one was expanded from, or ``None`` for the root.

    Attributes ``n_data``, ``max_log_likelihood`` and ``bic`` are always
    defined; they are ``None`` when the inputs needed to compute them were
    not supplied.
    """

    def __init__(
        self,
        posterior: gpx.gps.AbstractPosterior,
        max_log_likelihood: Optional[float] = None,
        n_data: Optional[int] = None,
        parent: Optional["Node"] = None,
    ):
        self.posterior = deepcopy(posterior)

        # Number of trainable parameter leaves: one count per leaf whose
        # metadata dict has "trainable" set, regardless of the leaf's size.
        # (meta_leaves presumably yields (metadata, value) pairs — verify
        # against the installed gpjax version.)
        self.n_parameter = sum(
            leaf[0]["trainable"]
            for leaf in meta_leaves(posterior)
            if isinstance(leaf[0], dict)
        )

        # Defaults ensure the checks below, and callers such as a sort on
        # ``bic``, never hit AttributeError when an argument was omitted.
        self.n_data: Optional[float] = None  # holds log(n_data), not n_data
        self.max_log_likelihood: Optional[float] = None
        self.bic: Optional[float] = None

        if n_data is not None:
            self.n_data = np.log(n_data)

        if max_log_likelihood is not None:
            self.max_log_likelihood = max_log_likelihood
            if self.n_data is not None:
                # BIC = k * ln(n) - 2 * ln(L_max); lower is better.
                self.bic = self.n_parameter * self.n_data - 2 * self.max_log_likelihood

        self.parent = parent
        self.children: list["Node"] = []

    def add_child(
        self,
        node: "Node",
    ) -> None:
        """Append ``node`` to this node's children (its ``parent`` is not modified)."""
        self.children.append(node)


class KernelSearch:
    """Greedy beam search over compositional GP kernels, scored by BIC.

    Starting from ``likelihood * root_prior``, each layer expands every kept
    node by combining its kernel with each kernel class in ``kernel_library``
    (once by product, once by sum). Each candidate is fitted with
    ``gpx.fit_scipy``. At most ``n_leafs`` children are kept, and only those
    whose BIC is lower than the best BIC of the previous layer.

    Args:
        root_prior: Prior whose kernel and mean function seed the search.
        likelihood: Likelihood combined with every candidate prior. Children
            always use this original likelihood object, not the one fitted
            in their parent.
        objective: Objective minimised by ``gpx.fit_scipy``.
        data: Training data used for fitting and as ``n`` in the BIC.
        kernel_library: Kernel classes, instantiated with no arguments, that
            are combined with the current kernel.
        n_leafs: Maximum number of nodes kept per layer.
        n_layers: Maximum number of layers to expand.
        max_iters: Iteration cap forwarded to ``gpx.fit_scipy``.
        verbose: Forwarded to ``gpx.fit_scipy``.
    """

    def __init__(
        self,
        root_prior: gpx.gps.Prior,
        likelihood: gpx.likelihoods.AbstractLikelihood,
        objective: gpx.objectives.AbstractObjective,
        data: gpx.Dataset,
        kernel_library: list[type[gpx.kernels.AbstractKernel]],
        n_leafs: int = 3,
        n_layers: int = 10,
        max_iters: int = 500,
        verbose: bool = True,
    ):
        self.likelihood = likelihood
        self.objective = objective
        self.data = data

        self.kernel_library = kernel_library
        self.n_leafs = n_leafs
        self.n_layer = n_layers
        self.max_iters = max_iters
        self.verbose = verbose

        self.root = Node(likelihood * root_prior)

    def fit(self, posterior) -> tuple[gpx.gps.AbstractPosterior, float]:
        """Optimise ``posterior``; return it with its final log-likelihood."""
        optimized_posterior, history = gpx.fit_scipy(
            model=posterior,
            objective=self.objective,
            train_data=self.data,
            max_iters=self.max_iters,
            verbose=self.verbose,
        )
        final_objective = float(history[-1])
        # fit_scipy *minimises* the objective, so it is normally built with
        # ``negative=True`` (e.g. ConjugateMLL(negative=True)). Its history
        # then holds the NEGATIVE log-likelihood. Undo that sign so the BIC
        # (k ln n - 2 ln L) rewards better fits instead of penalising them.
        if getattr(self.objective, "negative", False):
            final_objective = -final_objective
        max_log_likelihood = final_objective
        return optimized_posterior, max_log_likelihood

    def expand_node(self, node: "Node") -> None:
        """Fit and attach one child per (sum/product, library kernel) pair."""
        # Node stores the full posterior; the prior lives on it.
        parent_prior = node.posterior.prior
        for kernel_operation in [gpx.kernels.ProductKernel, gpx.kernels.SumKernel]:
            for ker in self.kernel_library:
                kernel = deepcopy(parent_prior.kernel)
                new_kernel = kernel_operation(kernels=[kernel, ker()])  # type: ignore

                new_prior = gpx.gps.Prior(
                    mean_function=parent_prior.mean_function, kernel=new_kernel
                )
                new_posterior = self.likelihood * new_prior
                node.add_child(Node(*self.fit(new_posterior), self.data.n, parent=node))

    def select_top_nodes(self, layer, bic_threshold):
        """Return up to ``n_leafs`` lowest-BIC nodes of ``layer`` with BIC below ``bic_threshold``.

        The result is sorted by ascending BIC.
        """
        # Sort on the BIC alone: sorting (bic, node) tuples would compare
        # Node objects on a BIC tie and raise TypeError.
        ranked = sorted(layer, key=lambda node: node.bic)
        return [node for node in ranked[: self.n_leafs] if node.bic < bic_threshold]

    def expand_layer(self, layer):
        """Expand every node of ``layer`` and return all of their children."""
        next_layer = []
        for node in layer:
            self.expand_node(node)
            next_layer.extend(node.children)
        return next_layer

    def search(self):
        """Run the search and return the posterior with the lowest BIC found.

        The result comes from the last layer that still improved on its
        predecessor. If no layer improved, it is the fitted root.
        """
        self.root = Node(*self.fit(self.root.posterior), self.data.n)
        layer = [self.root]

        i = 0
        for i in range(self.n_layer):
            bic_threshold = layer[0].bic  # min bic of current layer (sorted)
            next_layer = self.select_top_nodes(self.expand_layer(layer), bic_threshold)

            if len(next_layer) == 0:
                print("No more improvements found! Terminating early..\n")
                break
            layer = next_layer
        print("Terminated on layer: ", i + 1)

        # ``layer`` is never empty here: an empty selection breaks out of the
        # loop before replacing the previous (non-empty) layer.
        best_model = layer[0].posterior
        return best_model