In [None]:
import geometric_kernels.torch  # noqa: F401
import gpytorch
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import seaborn as sns
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from geometric_kernels.frontends.gpytorch import GPyTorchGeometricKernel
from geometric_kernels.kernels import MaternGeometricKernel
from geometric_kernels.spaces import Graph
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
from gpytorch.priors.torch_priors import GammaPrior
from matplotlib.colors import Normalize

In [None]:
# Helper function to ensure reproducibility
def set_seeds(seed=42):
    """Seed the NumPy and PyTorch random generators and pin cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to NumPy's global RNG, PyTorch's RNG and
            PyTorch's CUDA RNG.
    """
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Deterministic cuDNN kernels, and no per-run autotuning that could pick
    # different algorithms between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Set seeds for consistent results
set_seeds()


# Define the Gaussian Process model on the graph
class GraphGP:
    """
    A GPyTorch/BoTorch Gaussian-process model over the nodes of a NetworkX graph.

    The covariance is a ``MaternGeometricKernel`` defined on the graph (via
    ``geometric_kernels``), wrapped in a GPyTorch ``ScaleKernel``; the model is a
    BoTorch ``SingleTaskGP`` with a constant mean. GP inputs are the positional
    indices of nodes in ``graph.nodes`` order, shaped ``(N, 1)``.
    """

    def __init__(
        self,
        graph,
        attribute,
        normalize_laplacian=False,
        trainable_nu=False,
        nu=np.inf,
        scale_prior=(2.0, 0.15),
        noise_prior=(1.1, 0.05),
        noise_min=1e-8,
    ):
        """
        Build the GP model and its exact marginal log-likelihood.

        Args:
            graph: NetworkX graph. Its node order defines the integer indices
                used as GP inputs.
            attribute: Name of the node attribute holding the observations.
                Only nodes that carry this attribute are used for training.
            normalize_laplacian: Forwarded to ``geometric_kernels.spaces.Graph``.
            trainable_nu: Whether the kernel's ``nu`` is a trainable parameter.
            nu: Matérn smoothness passed to the kernel (default ``np.inf``).
            scale_prior: ``(concentration, rate)`` of the Gamma prior on the
                kernel output scale.
            noise_prior: ``(concentration, rate)`` of the Gamma prior on the
                observation noise. The noise is initialised at this prior's
                mode when that mode exceeds ``noise_min``.
            noise_min: Lower bound enforced on the observation noise.

        Raises:
            ValueError: If no node of ``graph`` carries ``attribute``.
        """
        # Use a constant mean function
        mean_module = gpytorch.means.ConstantMean()

        # Geometric kernel on the graph. The adjacency matrix from
        # ``nx.to_numpy_array`` is ordered like ``graph.nodes``, which is the
        # same order used for the node indices below.
        graph_torch = torch.tensor(nx.to_numpy_array(graph))
        graph_space = Graph(graph_torch, normalize_laplacian=normalize_laplacian)
        geo_kernel = MaternGeometricKernel(graph_space, nu=nu)

        # Convert the geometric kernel to a GPyTorch kernel with float64 hyperparameters
        params = geo_kernel.init_params()
        params["nu"] = torch.tensor([nu], dtype=torch.float64)
        params["lengthscale"] = torch.tensor(params["lengthscale"], dtype=torch.float64)
        base_kernel = GPyTorchGeometricKernel(
            geo_kernel,
            nu=params["nu"],
            trainable_nu=trainable_nu,
            lengthscale=params["lengthscale"],
        )

        # Scale the kernel with a Gamma prior on the output scale
        covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel,
            outputscale_prior=GammaPrior(*scale_prior),
        )

        # Likelihood with a Gamma prior on the noise and a lower bound on it
        noise_gamma = GammaPrior(*noise_prior)
        # Mode of a Gamma(concentration, rate) distribution
        noise_prior_mode = float((noise_gamma.concentration - 1) / noise_gamma.rate)
        lik = GaussianLikelihood(
            noise_prior=noise_gamma,
            noise_constraint=gpytorch.constraints.GreaterThan(noise_min),
        )
        # GaussianLikelihood does not forward extra keyword arguments such as
        # ``initial_value`` to its noise model, so the noise has to be set
        # explicitly. The mode is only admissible when it exceeds the lower
        # bound (it is negative for concentration < 1).
        if noise_prior_mode > noise_min:
            lik.noise = noise_prior_mode

        # Training data: one (node index, value) pair per node carrying ``attribute``
        node_index = {node: idx for idx, node in enumerate(graph.nodes)}
        observations = nx.get_node_attributes(graph, attribute)
        if not observations:
            raise ValueError(
                f"no node in the graph has the attribute {attribute!r}; "
                "the GP needs at least one observation to train on"
            )
        # train_x and train_y must be 2D (N, 1) float64 tensors for GPyTorch/BoTorch
        train_x = torch.tensor(
            [[node_index[node]] for node in observations], dtype=torch.float64
        )
        train_y = torch.tensor(
            [[val] for val in observations.values()], dtype=torch.float64
        )

        # Build the model
        self.model = SingleTaskGP(
            train_x,
            train_y,
            mean_module=mean_module,
            covar_module=covar_module,
            likelihood=lik,
        )
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(
            self.model.likelihood, self.model
        )

    def fit(self):
        """Fit the hyperparameters by maximising the exact marginal log-likelihood.

        Returns:
            The value returned by ``botorch.fit.fit_gpytorch_mll``.
        """
        self.model.train()
        return fit_gpytorch_mll(mll=self.mll)

    def predict(
        self,
        graph,
        attribute="prediction",
        mean_name="mean",
        var_name="var",
        std_name="std",
    ):
        """
        Compute the posterior at every node of ``graph`` and store it on the nodes.

        Nodes are indexed by their position in ``graph.nodes``, so ``graph``
        should have the same node order as the graph the model was built on.

        Args:
            graph: NetworkX graph whose nodes receive the predictions.
            attribute: Prefix for the node-attribute names written.
            mean_name, var_name, std_name: Suffixes for the posterior mean,
                variance and standard deviation attributes.

        Returns:
            Dict mapping each node to ``{f"{attribute}_{mean_name}": mean,
            f"{attribute}_{var_name}": variance, f"{attribute}_{std_name}": std}``
            as Python floats. The same values are written onto the nodes of
            ``graph`` with ``nx.set_node_attributes``.
        """
        self.model.eval()

        all_nodes = torch.arange(graph.number_of_nodes()).long()[:, None]

        # Inference only: no autograd graph is needed for the posterior
        with torch.no_grad():
            posterior = self.model.posterior(all_nodes)
            means = posterior.mean
            variances = posterior.variance

        node_posteriors = {}
        for node, mn, vr in zip(graph.nodes, means, variances):
            variance = vr.item()
            node_posteriors[node] = {
                f"{attribute}_{mean_name}": mn.item(),
                f"{attribute}_{var_name}": variance,
                f"{attribute}_{std_name}": float(np.sqrt(variance)),
            }

        nx.set_node_attributes(graph, node_posteriors)
        return node_posteriors

In [None]:
def plot_graph(G, attribute, cmap, ax=None, cbar=True, vmin=None, vmax=None):
    """
    Draw ``G`` in a circular layout, colouring nodes by a node attribute.

    Args:
        G: NetworkX graph to draw.
        attribute: Node-attribute name used for colouring, or ``None`` to draw
            every node white. Nodes lacking the attribute are drawn white.
        cmap: Matplotlib colormap (callable mapping a normalised value to a colour).
        ax: Axes to draw on. A new figure and axes are created when ``None``.
        cbar: Whether to add a colorbar for the attribute's colour scale.
        vmin, vmax: Colour-scale limits. A limit left as ``None`` is taken from
            the attribute values present in ``G``.

    Returns:
        ``(fig, ax)``: the figure and axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    norm = Normalize(vmin=vmin, vmax=vmax)
    if attribute is None:
        node_colors = "white"
    else:
        values = nx.get_node_attributes(G, attribute)
        if values:
            # Fill in whichever of vmin/vmax was left unset from the observed values
            norm.autoscale_None(list(values.values()))
        node_colors = [
            cmap(norm(values[node])) if node in values else "white"
            for node in G.nodes
        ]

    nx.draw_circular(
        G,
        node_color=node_colors,
        edge_color="black",
        edgecolors="black",
        ax=ax,
    )
    ax.set_aspect("equal")

    # Only draw a colorbar when requested and when there is a colour scale to show.
    # Use the axes' own figure, not whatever figure happens to be current.
    if cbar and attribute is not None and norm.scaled():
        fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax)
    return fig, ax

In [None]:
# Toy example: a 1 x 11 grid (an 11-node path) with two observed nodes
G1 = nx.grid_2d_graph(1, 11)
nx.set_node_attributes(G1, {(0, 1): 1, (0, 6): -1}, name="observation")

# Fit the graph GP and write the posterior summaries back onto G1's nodes
G1_model_graph = GraphGP(G1, "observation")
G1_model_graph.fit()
G1_model_graph.predict(G1)

# Diverging palette for signed values, sequential palette for uncertainty
value_cmap = sns.diverging_palette(220, 20, as_cmap=True)
unc_cmap = sns.color_palette("cividis", as_cmap=True)

fig, ax = plt.subplots(1, 3, figsize=(15, 4))

# One panel per quantity: (axes, node attribute, colormap, colour limits, title)
for panel, attr_name, colormap, limits, title in (
    (ax[0], "observation", value_cmap, {"vmin": -1, "vmax": 1}, "Observations"),
    (ax[1], "prediction_mean", value_cmap, {"vmin": -1, "vmax": 1}, "Prediction Means"),
    (ax[2], "prediction_std", unc_cmap, {}, "Prediction Stds"),
):
    plot_graph(G1, attr_name, colormap, ax=panel, **limits)
    panel.set_title(title)