In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import nxviz as nv

In [None]:
import numpy as np
import seaborn as sns
from matplotlib.colors import Normalize
from vsup import VSUP


def grid_plot(
    G,
    attribute=None,
    uncertainty=None,
    *,
    label=False,
    cmap=None,
    vmin=None,
    vmax=None,
    umin=None,
    umax=None,
    quantization="linear",
    ax=None,
    figsize=(6, 4),
):
    """Draw a grid-like graph with nodes coloured by value and uncertainty.

    Each node is indexed as ``node[0]`` / ``node[1]`` and drawn at
    ``(node[1], -node[0])``, so nodes must be sequences with at least two
    entries (e.g. the ``(i, j)`` tuples of ``nx.grid_2d_graph``). Nodes that
    carry ``attribute`` are coloured through a VSUP built from ``cmap``; every
    other node is drawn white. ``G`` itself is never modified.

    Parameters
    ----------
    G : networkx graph
        Graph to draw.
    attribute : hashable, optional
        Name of the node attribute holding the value to display. With
        ``None`` (or when no node has the attribute) all nodes are white.
    uncertainty : hashable, optional
        Name of the node attribute holding the uncertainty. When given, it
        must be present on exactly the nodes that have ``attribute``. When
        ``None``, every valued node gets uncertainty 0, ``umin``/``umax`` are
        forced to 0/1 and ``quantization`` is forced to ``None``.
    label : bool
        If true, write each valued node's attribute value on top of it.
    cmap : colormap or str, optional
        ``None`` selects ``sns.diverging_palette(220, 20, center="dark")``;
        a string is resolved with ``sns.color_palette(..., as_cmap=True)``;
        anything else is handed to ``VSUP`` unchanged.
    vmin, vmax : float, optional
        Value range; default to the min/max of the attribute values.
    umin, umax : float, optional
        Uncertainty range; default to the min/max of the uncertainty values.
    quantization : str or None
        Passed to ``VSUP``.
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure of size ``figsize`` is created if
        omitted.
    figsize : tuple
        Figure size used only when ``ax`` is ``None``.

    Returns
    -------
    (fig, ax)
        The figure and axes that were drawn on.

    Raises
    ------
    AssertionError
        If ``uncertainty`` is given but is not defined on exactly the nodes
        that define ``attribute``.
    """
    # Without an attribute name there is nothing to look up: no valued nodes.
    attributes = nx.get_node_attributes(G, attribute) if attribute is not None else {}

    if uncertainty is not None:
        uncertainties = nx.get_node_attributes(G, uncertainty)
        assert uncertainties.keys() == attributes.keys(), (
            "All nodes with either attribute and uncertainty values must have both."
        )
    else:
        uncertainties = dict.fromkeys(attributes, 0)
        umin = 0
        umax = 1
        quantization = None

    vsup = None
    if attributes:
        # Ranges can only be inferred (and a palette is only needed) when at
        # least one node actually carries a value.
        if vmin is None:
            vmin = min(attributes.values())
        if vmax is None:
            vmax = max(attributes.values())
        if umin is None:
            umin = min(uncertainties.values())
        if umax is None:
            umax = max(uncertainties.values())

        # Prepare color map
        if cmap is None:
            cmap = sns.diverging_palette(220, 20, center="dark", as_cmap=True)
        elif isinstance(cmap, str):
            cmap = sns.color_palette(cmap, as_cmap=True)
        # Prepare Value-Suppressing Uncertainty Palette (VSUP)
        vsup = VSUP(
            palette=cmap,
            vmin=vmin,
            vmax=vmax,
            umin=umin,
            umax=umax,
            quantization=quantization,
        )

    # One fill colour per node, in G.nodes order; unvalued nodes stay white.
    node_colors = [
        vsup(attributes[node], uncertainties[node]) if node in attributes else "white"
        for node in G.nodes
    ]
    # Exactly one outline colour per node.
    node_edgecolors = ["black"] * len(node_colors)

    # Get positions so that (i, j) is at (j, -i) for grid-like display
    pos = {n: (n[1], -n[0]) for n in G.nodes}

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure
    nx.draw(
        G,
        pos=pos,
        node_color=node_colors,
        edgecolors=node_edgecolors,
        with_labels=False,
        node_size=200,
        ax=ax,
    )
    # Optionally, mark the valued nodes with their value
    if label:
        for node, val in attributes.items():
            x, y = pos[node]
            ax.text(
                x,
                y,
                str(val),
                color="black",
                ha="center",
                va="center",
                fontsize=10,
                fontweight="bold",
            )
    return fig, ax

In [None]:
import botorch
import geometric_kernels as gk
import geometric_kernels.torch
import gpytorch
import numpy as np
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize as NormalizeInput
from botorch.models.transforms.outcome import Standardize as StandardizeOutput
from geometric_kernels.frontends.gpytorch import GPyTorchGeometricKernel
from geometric_kernels.kernels import MaternGeometricKernel
from geometric_kernels.spaces import Graph
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
from gpytorch.priors.torch_priors import GammaPrior

In [None]:
# Reproducibility helper: seed every RNG the notebook touches
def set_seeds(seed=42):
    """Seed NumPy and PyTorch (CPU and CUDA) and put cuDNN in deterministic mode."""
    # cuDNN: disable autotuning and request deterministic kernels
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Apply the same seed to each random number generator
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)


# Seed once at notebook start so every run produces the same results
set_seeds()


# Define the Gaussian Process model on the graph
class GraphGP:
    """
    Gaussian process over the nodes of a graph.

    Wraps a BoTorch ``SingleTaskGP`` whose covariance is a scaled
    ``MaternGeometricKernel`` on the graph. Inputs to the model are integer
    node positions in ``graph.nodes`` iteration order; targets are the values
    of a node attribute on the nodes that carry it.

    Attributes set by ``__init__``:
        model: the ``SingleTaskGP`` instance.
        mll: the ``ExactMarginalLogLikelihood`` used for fitting.
    """

    def __init__(
        self,
        graph,
        attribute,
        normalize_laplacian=False,
        trainable_nu=False,
        nu=0.5,
        scale_prior=(2.0, 0.15),
        noise_prior=(1.1, 0.05),
        noise_min=1e-8,
    ):
        """
        Build the model from the nodes of ``graph`` that carry ``attribute``.

        Args:
            graph: networkx graph; its adjacency matrix defines the kernel.
            attribute: name of the node attribute holding the observations.
            normalize_laplacian: forwarded to the geometric-kernels ``Graph``.
            trainable_nu: forwarded to ``GPyTorchGeometricKernel``.
            nu: smoothness of the Matérn kernel.
            scale_prior: ``(concentration, rate)`` of the Gamma prior on the
                kernel output scale.
            noise_prior: ``(concentration, rate)`` of the Gamma prior on the
                observation noise.
            noise_min: lower bound of the noise constraint.

        Raises:
            ValueError: if no node of ``graph`` has ``attribute``.
        """
        attr_dict = nx.get_node_attributes(graph, attribute)
        if not attr_dict:
            raise ValueError(
                f"No node in the graph carries the attribute {attribute!r}; "
                "there is nothing to train on."
            )

        graph_torch = torch.tensor(nx.to_numpy_array(graph))
        graph_space = Graph(graph_torch, normalize_laplacian=normalize_laplacian)
        mean_module = gpytorch.means.ConstantMean()

        geo_kernel = MaternGeometricKernel(graph_space, nu=nu)
        # Start from the kernel's default parameters, overriding nu and
        # converting both to float64 torch tensors.
        params = geo_kernel.init_params()
        params["nu"] = torch.tensor([nu], dtype=torch.float64)
        params["lengthscale"] = torch.tensor(params["lengthscale"], dtype=torch.float64)
        base_kernel = GPyTorchGeometricKernel(
            geo_kernel,
            nu=params["nu"],
            trainable_nu=trainable_nu,
            lengthscale=params["lengthscale"],
            # lengthscale_prior=GammaPrior(3.0, 6.0),
        )
        covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel,
            outputscale_prior=GammaPrior(*scale_prior),
        )

        noise_prior_obj = GammaPrior(*noise_prior)
        # Mode of Gamma(concentration, rate)
        noise_prior_mode = (noise_prior_obj.concentration - 1) / noise_prior_obj.rate
        # NOTE(review): `initial_value` is given to GaussianLikelihood, not to
        # the constraint object; confirm that the likelihood actually uses it
        # to initialise the noise rather than discarding it.
        lik = GaussianLikelihood(
            noise_prior=noise_prior_obj,
            noise_constraint=gpytorch.constraints.GreaterThan(noise_min),
            initial_value=noise_prior_mode,
        )

        # Position of every node in graph.nodes iteration order (single pass).
        node_index = {node: idx for idx, node in enumerate(graph.nodes)}
        # Prepare training data: one (node position, observed value) row per
        # observed node, then split the two columns.
        train_x, train_y = torch.tensor(
            [[node_index[node], val] for node, val in attr_dict.items()], dtype=float
        ).T

        self.model = SingleTaskGP(
            train_x[:, None],
            train_y[:, None],
            mean_module=mean_module,
            covar_module=covar_module,
            likelihood=lik,
            # input_transform=NormalizeInput(d=1),
            outcome_transform=StandardizeOutput(m=1),
        )
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(
            self.model.likelihood, self.model
        )

    def fit(self):
        """Fit the model hyperparameters with ``fit_gpytorch_mll`` and return its result."""
        return fit_gpytorch_mll(mll=self.mll)

    def predict(
        self,
        graph,
        attribute="prediction",
        mean_name="mean",
        var_name="var",
        std_name="std",
    ):
        """
        Predict at every node of ``graph`` and store the results on the graph.

        Nodes are addressed by their position in ``graph.nodes``, so ``graph``
        must list its nodes in the same order as the graph used to build the
        model.

        Args:
            graph: networkx graph whose nodes receive the predictions
                (modified in place).
            attribute: prefix of the attribute names written to each node.
            mean_name, var_name, std_name: suffixes, giving the attribute
                names ``f"{attribute}_{suffix}"``.

        Returns:
            dict mapping each node to a dict with its posterior mean,
            variance and standard deviation (as Python floats).
        """
        # Ensure the model is in evaluation mode
        self.model.eval()

        # Prepare input for prediction: one row per node position
        all_nodes = torch.arange(graph.number_of_nodes()).long()[:, None]
        idx2node = dict(enumerate(graph.nodes))

        # Convert predictions to a dictionary
        posterior = self.model.posterior(all_nodes)
        node_posteriors = {
            idx2node[idx]: {
                f"{attribute}_{mean_name}": mn.item(),
                f"{attribute}_{var_name}": vr.item(),
                f"{attribute}_{std_name}": float(np.sqrt(vr.item())),
            }
            for idx, (mn, vr) in enumerate(zip(posterior.mean, posterior.variance))
        }

        nx.set_node_attributes(graph, node_posteriors)

        return node_posteriors


class SpatialGP:
    """
    Standard (Euclidean) Gaussian process that uses node labels as coordinates.

    Wraps a BoTorch ``SingleTaskGP`` with a scaled RBF or Matérn kernel. Every
    node label is converted with ``list(node)`` and used as the input vector,
    so nodes must be sequences of numbers (e.g. ``(x1, x2)`` tuples).

    Attributes set by ``__init__``:
        model: the ``SingleTaskGP`` instance.
        mll: the ``ExactMarginalLogLikelihood`` used for fitting.
    """

    def __init__(
        self,
        graph,
        attribute,
        nu=0.5,
        scale_prior=(2.0, 0.15),
        noise_prior=(1.1, 0.05),
        noise_min=1e-8,
        normalize_laplacian=None,
    ):
        """
        Build the model from the nodes of ``graph`` that carry ``attribute``.

        Args:
            graph: networkx graph whose node labels are coordinates.
            attribute: name of the node attribute holding the observations.
            nu: Matérn smoothness; ``np.inf`` selects an RBF kernel instead.
            scale_prior: ``(concentration, rate)`` of the Gamma prior on the
                kernel output scale.
            noise_prior: ``(concentration, rate)`` of the Gamma prior on the
                observation noise.
            noise_min: lower bound of the noise constraint.
            normalize_laplacian: accepted so the call signature matches
                ``GraphGP``; not used.

        Raises:
            ValueError: if no node of ``graph`` has ``attribute``.
        """
        # Extract node coordinates (assume node names are (x1, x2) tuples)
        attr_dict = nx.get_node_attributes(graph, attribute)
        if not attr_dict:
            raise ValueError(
                f"No node in the graph carries the attribute {attribute!r}; "
                "there is nothing to train on."
            )
        X = torch.tensor([list(node) for node in attr_dict], dtype=torch.float64)
        y = torch.tensor(list(attr_dict.values()), dtype=torch.float64)

        mean_module = gpytorch.means.ConstantMean()

        # Only the stationary kernel depends on nu; the scale wrapper is shared.
        lengthscale_prior = GammaPrior(3.0, 6.0)
        if nu == np.inf:
            stationary_kernel = gpytorch.kernels.RBFKernel(
                lengthscale_prior=lengthscale_prior
            )
        else:
            stationary_kernel = gpytorch.kernels.MaternKernel(
                nu=nu, lengthscale_prior=lengthscale_prior
            )
        covar_module = gpytorch.kernels.ScaleKernel(
            stationary_kernel,
            outputscale_prior=GammaPrior(*scale_prior),
        )

        noise_prior_obj = GammaPrior(*noise_prior)
        # Mode of Gamma(concentration, rate)
        noise_prior_mode = (noise_prior_obj.concentration - 1) / noise_prior_obj.rate
        # NOTE(review): `initial_value` is given to GaussianLikelihood, not to
        # the constraint object; confirm that the likelihood actually uses it
        # to initialise the noise rather than discarding it.
        lik = GaussianLikelihood(
            noise_prior=noise_prior_obj,
            noise_constraint=gpytorch.constraints.GreaterThan(noise_min),
            initial_value=noise_prior_mode,
        )

        self.model = SingleTaskGP(
            X,
            y[:, None],
            mean_module=mean_module,
            covar_module=covar_module,
            likelihood=lik,
            input_transform=NormalizeInput(d=X.shape[1]),
            outcome_transform=StandardizeOutput(m=1),
        )
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(
            self.model.likelihood, self.model
        )

    def fit(self):
        """Fit the model hyperparameters with ``fit_gpytorch_mll`` and return its result."""
        return fit_gpytorch_mll(mll=self.mll)

    def predict(
        self,
        graph,
        attribute="prediction",
        mean_name="mean",
        var_name="var",
        std_name="std",
    ):
        """
        Predict at every node of ``graph`` (using the node labels as
        coordinates) and store the results on the graph.

        Args:
            graph: networkx graph whose nodes receive the predictions
                (modified in place). It does not have to be the training
                graph; only the node coordinates are used.
            attribute: prefix of the attribute names written to each node.
            mean_name, var_name, std_name: suffixes, giving the attribute
                names ``f"{attribute}_{suffix}"``.

        Returns:
            dict mapping each node to a dict with its posterior mean,
            variance and standard deviation (as Python floats).
        """
        self.model.eval()
        nodes = list(graph.nodes)
        # Same dtype as the training inputs, so coordinates are not rounded to
        # single precision before being fed to a double-precision model.
        all_coords = torch.tensor([list(node) for node in nodes], dtype=torch.float64)

        posterior = self.model.posterior(all_coords)
        node_posteriors = {
            node: {
                f"{attribute}_{mean_name}": mn.item(),
                f"{attribute}_{var_name}": vr.item(),
                f"{attribute}_{std_name}": float(np.sqrt(vr.item())),
            }
            for node, mn, vr in zip(nodes, posterior.mean, posterior.variance)
        }

        nx.set_node_attributes(graph, node_posteriors)
        return node_posteriors

# Trivial "1D" Graph

In [None]:
# A 1 x 11 grid, i.e. a chain of nodes (0, 0) ... (0, 10)
G1 = nx.grid_2d_graph(1, 11)

# Observed values; every node not listed here stays unobserved
G1_observations = {
    (0, 1): 1,
    (0, 2): -1,
    (0, 6): 2,
    # (0, 9): -2,
}
nx.set_node_attributes(G1, values=G1_observations, name="observation")

# Show the raw observations, printing each value on its node
fig, ax = grid_plot(G1, "observation", label=True, figsize=(6, 1))

In [None]:
# VSUP is already imported next to grid_plot; repeating the import here only
# lets this cell run on its own.
from vsup import VSUP

# Same diverging palette that grid_plot falls back to when no cmap is given.
# NOTE: `cmap` and `vsup` become notebook-level names.
cmap = sns.diverging_palette(220, 20, center="dark", as_cmap=True)
# Stand-alone palette over values in [-1, 1] and uncertainties in [0, 1],
# using "tree" quantization (grid_plot defaults to "linear").
vsup = VSUP(cmap, vmin=-1, vmax=1, umin=0, umax=1, quantization="tree")
# Presumably renders the arc-shaped legend for this palette - TODO confirm
# against the vsup documentation.
vsup.create_arcmap_legend()

In [None]:
# Graph GP on the 1D chain. nu=np.inf is passed straight to the geometric
# Matérn kernel; all other arguments repeat the GraphGP defaults.
G1_model_graph = GraphGP(
    G1,
    "observation",
    normalize_laplacian=False,
    nu=np.inf,
    scale_prior=(2.0, 0.15),
    noise_prior=(1.1, 0.05),
    noise_min=1e-8,
)
G1_model_graph.fit()
# Writes prediction_mean / prediction_var / prediction_std onto every node of
# G1 (in place), which is what the plot below reads.
G1_model_graph.predict(G1)

# Colour = posterior mean, suppression = posterior standard deviation
grid_plot(G1, "prediction_mean", "prediction_std", figsize=(6, 1))
# Debugging aids for inspecting the fitted hyperparameters:
# {n:v for n, v in G1_model_graph.model.named_parameters()}
# G1_model_graph.model.get_parameter('covar_module.base_kernel.raw_lengthscale').item()

In [None]:
# Euclidean GP on the same 1D chain, using the node labels as coordinates.
# normalize_laplacian is accepted by SpatialGP only for signature parity with
# GraphGP and has no effect here.
G1_model_spatial = SpatialGP(
    G1,
    "observation",
    normalize_laplacian=False,
    nu=0.5,
    scale_prior=(2.0, 0.15),
    noise_prior=(1.1, 0.05),
    noise_min=1e-8,
)
G1_model_spatial.fit()
# Overwrite the prediction_* attributes on G1 with this model's posterior;
# without this step the plot below would still show whatever predictions
# were last written to G1.
G1_model_spatial.predict(G1)

# Colour = posterior mean, suppression = posterior standard deviation
grid_plot(G1, "prediction_mean", "prediction_std", figsize=(6, 1))
# G1_model_spatial.model.get_parameter('covar_module.base_kernel.raw_lengthscale').item()

## Comparing different nu values

In [None]:
# One column per nu. Rows: Graph GP on the chain, Spatial GP on the chain,
# and a line plot comparing both along the chain.
fig, axs = plt.subplots(3, 3, figsize=(16, 4), sharex=True, height_ratios=[1, 1, 5])

# Create a new grid graph to illustrate "spatial" predictions
# This graph has nodes at (i, j/10) to simulate a continuous space
G1_ = nx.grid_2d_graph(1, 101)
old2new = {(i, j): (i, j / 10) for i, j in G1_.nodes}
G1_ = nx.relabel_nodes(G1_, old2new)

for col, nu in zip(axs.T, [0.5, 1.5, 2.5]):
    graph_ax, spat_grid_ax, spat_ax = col
    graph_ax.set_title(f"nu={int(nu * 2)}/2")

    # --- Graph GP on the chain -------------------------------------------
    G1_model_graph = GraphGP(
        G1,
        "observation",
        normalize_laplacian=False,
        nu=nu,
        scale_prior=(2.0, 0.15),
        noise_prior=(1.1, 0.05),
        noise_min=1e-8,
    )
    G1_model_graph.fit()
    # Note: this is the raw (unconstrained) lengthscale parameter.
    print(
        f"Fitted Graph GP with nu={nu}:",
        G1_model_graph.model.get_parameter(
            "covar_module.base_kernel.raw_lengthscale"
        ).item(),
    )

    # predict() overwrites the prediction_* attributes on G1, so the plot must
    # follow immediately, before the spatial model writes its own values.
    graph_preds = G1_model_graph.predict(G1)
    grid_plot(G1, "prediction_mean", "prediction_std", ax=graph_ax)

    # --- Spatial GP on the chain -----------------------------------------
    G1_model_spatial = SpatialGP(
        G1,
        "observation",
        normalize_laplacian=False,
        nu=nu,
        scale_prior=(2.0, 0.15),
        noise_prior=(1.1, 0.05),
        noise_min=1e-8,
    )
    G1_model_spatial.fit()
    print(
        f"Fitted Spatial GP with nu={nu}:",
        G1_model_spatial.model.get_parameter(
            "covar_module.base_kernel.raw_lengthscale"
        ).item(),
    )

    G1_model_spatial.predict(G1)
    grid_plot(G1, "prediction_mean", "prediction_std", ax=spat_grid_ax)

    # --- Line plot: spatial GP on the dense grid vs. graph GP at nodes ----
    spat_preds = G1_model_spatial.predict(G1_)
    spat_x = [node[1] for node in G1_.nodes]
    spat_mean = np.array([spat_preds[node]["prediction_mean"] for node in G1_.nodes])
    spat_std = np.array([spat_preds[node]["prediction_std"] for node in G1_.nodes])
    spat_ax.plot(spat_x, spat_mean, label="Mean")
    spat_ax.fill_between(
        spat_x,
        spat_mean - 2 * spat_std,
        spat_mean + 2 * spat_std,
        alpha=0.2,
        label="Spatial GP\nμ ± 2σ",
    )

    node_x, node_mean, node_std = zip(
        *[
            (node[1], preds["prediction_mean"], preds["prediction_std"])
            for node, preds in graph_preds.items()
        ]
    )
    spat_ax.errorbar(
        node_x,
        node_mean,
        yerr=2 * np.array(node_std),
        fmt="o",
        label="Graph GP\nμ ± 2σ",
    )

    obs_x, obs_y = zip(
        *[
            (node[1], val)
            for node, val in nx.get_node_attributes(G1, "observation").items()
        ]
    )
    spat_ax.scatter(obs_x, obs_y, color="black", label="Observations", zorder=5)

# spat_ax.legend()

# Gridded "2D" Graph

In [None]:
# An 11 x 11 lattice with nodes (i, j)
G2 = nx.grid_2d_graph(11, 11)

# Observed values; every node not listed here stays unobserved
G2_observations = {
    (1, 2): 1,
    (7, 7): 1,
    (8, 1): -1,
    (3, 7): -1,
    (5, 4): 0,
}
nx.set_node_attributes(G2, values=G2_observations, name="observation")

# Show the raw observations, printing each value on its node
fig, ax = grid_plot(G2, "observation", label=True)
ax.set_aspect("equal")

In [None]:
# Graph GP on the 2D lattice with Matérn smoothness nu=2.5; the remaining
# arguments repeat the GraphGP defaults.
G2_model_graph = GraphGP(
    G2,
    "observation",
    normalize_laplacian=False,
    nu=2.5,
    scale_prior=(2.0, 0.15),
    noise_prior=(1.1, 0.05),
    noise_min=1e-8,
)
G2_model_graph.fit()
# Writes prediction_mean / prediction_var / prediction_std onto every node of
# G2 (in place), which is what the plot below reads.
G2_model_graph.predict(G2)

# Colour = posterior mean, suppression = posterior standard deviation
fig, ax = grid_plot(G2, "prediction_mean", "prediction_std")
ax.set_aspect("equal")

In [None]:
# Euclidean GP on the 2D lattice, using the (i, j) node labels as coordinates.
# normalize_laplacian is accepted by SpatialGP but not used by it.
G2_model_spatial = SpatialGP(
    G2,
    "observation",
    normalize_laplacian=False,
    nu=2.5,
    scale_prior=(2.0, 0.15),
    noise_prior=(1.1, 0.05),
    noise_min=1e-8,
)
G2_model_spatial.fit()
# Overwrites the prediction_* attributes on G2 with this model's posterior,
# which is what the plot below reads.
G2_model_spatial.predict(G2)

# Colour = posterior mean, suppression = posterior standard deviation
fig, ax = grid_plot(G2, "prediction_mean", "prediction_std")
ax.set_aspect("equal")

In [None]:
# One column per nu. Rows: Graph GP on the lattice, Spatial GP on the lattice,
# and a 1D slice through one lattice row comparing both.
fig, axs = plt.subplots(3, 4, figsize=(12, 12), sharex=True, sharey="row")

# 0-based index of the lattice row shown in the slice plots
SLICE_ROW = 8

# Prepare a "continuous" grid for the spatial GP (like G1_): 101 points along
# the slice row, spaced 0.1 apart
G2_dense = nx.grid_2d_graph(1, 101)
old2new_G2 = {(0, j): (SLICE_ROW, j / 10) for (_, j) in G2_dense.nodes}
G2_dense = nx.relabel_nodes(G2_dense, old2new_G2)

for col, nu in zip(axs.T, [0.5, 1.5, 2.5, np.inf]):
    graph_ax, spat_ax, slice_ax = col

    # Graph GP
    graph_ax.set_title(f"Graph GP (nu={nu})")
    G2_model_graph = GraphGP(
        G2,
        "observation",
        normalize_laplacian=False,
        nu=nu,
        trainable_nu=False,
        scale_prior=(2.0, 0.15),
        noise_prior=(1.1, 0.05),
        noise_min=1e-8,
    )
    G2_model_graph.fit()
    # predict() overwrites the prediction_* attributes on G2, so the plot must
    # follow immediately, before the spatial model writes its own values.
    graph_preds = G2_model_graph.predict(G2)
    grid_plot(G2, "prediction_mean", "prediction_std", ax=graph_ax, umin=0.3, umax=1.6)
    # print(f"Fitted Graph GP with nu={nu}:",
    #       G2_model_graph.model.get_parameter('covar_module.base_kernel.raw_nu').item())
    graph_ax.set_aspect("equal")

    # Spatial GP
    spat_ax.set_title(f"Spatial GP (nu={nu})")
    G2_model_spatial = SpatialGP(
        G2,
        "observation",
        normalize_laplacian=False,
        nu=nu,
        scale_prior=(2.0, 0.15),
        noise_prior=(1.1, 0.05),
        noise_min=1e-8,
    )
    G2_model_spatial.fit()
    G2_model_spatial.predict(G2)
    grid_plot(G2, "prediction_mean", "prediction_std", ax=spat_ax, umin=0.3, umax=1.6)
    spat_ax.set_aspect("equal")

    # 1D slice through the chosen row (title uses 1-based numbering)
    slice_ax.set_title(f"Row {SLICE_ROW + 1} slice (nu={nu})")
    # Graph GP predictions at the lattice nodes of that row
    slice_x = sorted(j for (i, j) in G2.nodes if i == SLICE_ROW)
    slice_nodes = [(SLICE_ROW, j) for j in slice_x]
    slice_mean = [graph_preds[node]["prediction_mean"] for node in slice_nodes]
    slice_std = [graph_preds[node]["prediction_std"] for node in slice_nodes]
    slice_ax.errorbar(
        slice_x,
        slice_mean,
        yerr=2 * np.array(slice_std),
        fmt="o",
        label="Graph GP μ±2σ",
        color="C1",
    )

    # Spatial GP predictions for dense slice
    dense_preds = G2_model_spatial.predict(G2_dense)
    dense_x = [node[1] for node in G2_dense.nodes]
    dense_mean = np.array(
        [dense_preds[node]["prediction_mean"] for node in G2_dense.nodes]
    )
    dense_std = np.array(
        [dense_preds[node]["prediction_std"] for node in G2_dense.nodes]
    )
    slice_ax.plot(dense_x, dense_mean, label="Spatial GP mean", color="C0")
    slice_ax.fill_between(
        dense_x,
        dense_mean - 2 * dense_std,
        dense_mean + 2 * dense_std,
        alpha=0.2,
        color="C0",
        label="Spatial GP ±2σ",
    )

    # Observations that lie on the slice row
    obs = nx.get_node_attributes(G2, "observation")
    obs_x = [j for (i, j) in obs if i == SLICE_ROW]
    obs_y = [obs[(SLICE_ROW, j)] for j in obs_x]
    if obs_x:
        slice_ax.scatter(obs_x, obs_y, color="black", label="Observations", zorder=5)

    slice_ax.set_xlabel("Column (j)")
    slice_ax.set_ylabel("Prediction")
    # slice_ax.legend()

plt.tight_layout()
plt.show()