In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import seaborn as sns
from vsup import VSUP

%config InlineBackend.figure_format = 'retina'

In [None]:
import botorch
import geometric_kernels as gk
import geometric_kernels.torch
import gpytorch
import numpy as np
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize as NormalizeInput
from botorch.models.transforms.outcome import Standardize as StandardizeOutput
from geometric_kernels.frontends.gpytorch import GPyTorchGeometricKernel
from geometric_kernels.kernels import MaternGeometricKernel
from geometric_kernels.spaces import Graph
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
from gpytorch.priors.torch_priors import GammaPrior

In [None]:
def grid_plot(
    G,
    attribute=None,
    uncertainty=None,
    *,
    label=False,
    cmap=None,
    vmin=None,
    vmax=None,
    umin=None,
    umax=None,
    quantization="linear",
    ax=None,
    figsize=(6, 4),
    pos=None,
    node_size=200,
    **kwargs,
):
    """Draw a graph with nodes colored by a value (and optional uncertainty) via a VSUP.

    Parameters
    ----------
    G : networkx.Graph
        Graph to draw. It is not modified.
    attribute : str, optional
        Node attribute holding the value encoded as color. Nodes without this
        attribute are drawn white. If None, every node is drawn white.
    uncertainty : str, optional
        Node attribute holding the uncertainty of each value. Every node that has
        ``attribute`` must also have ``uncertainty`` and vice versa. If None, the
        palette is built with a constant zero uncertainty and no quantization.
    label : bool
        If True, write each node's ``attribute`` value on top of the node.
    cmap : colormap, str, or None
        Base palette. Strings are resolved with ``sns.color_palette(..., as_cmap=True)``;
        None selects a seaborn diverging palette.
    vmin, vmax : float, optional
        Value range of the palette. Default: min / max of the node values
        (0 / 1 when there are none).
    umin, umax : float, optional
        Uncertainty range of the palette. Default: min / max of the node
        uncertainties (0 / 1 when there are none). Overridden to 0 / 1 when
        ``uncertainty`` is None.
    quantization : str or None
        Passed to ``VSUP``; forced to None when ``uncertainty`` is None.
    ax : matplotlib Axes, optional
        Axes to draw into. A new figure of size ``figsize`` is created if None.
    figsize : tuple
        Size of the new figure when ``ax`` is None.
    pos : dict, optional
        Node positions. The default places node ``(i, j)`` at ``(j, -i)``, which
        assumes grid-style tuple node names.
    node_size : int
        Passed to ``nx.draw``.
    **kwargs
        Forwarded to ``nx.draw``.

    Returns
    -------
    VSUP
        The palette used to color the nodes (e.g. for drawing a legend).

    Raises
    ------
    ValueError
        If the nodes carrying ``attribute`` and ``uncertainty`` differ.
    """
    # Without an attribute there are no values to color: every node is white.
    values = nx.get_node_attributes(G, attribute) if attribute is not None else {}
    # `default` keeps the palette constructible when no node carries a value.
    if vmin is None:
        vmin = min(values.values(), default=0)
    if vmax is None:
        vmax = max(values.values(), default=1)

    if uncertainty is not None:
        uncertainties = nx.get_node_attributes(G, uncertainty)
        if uncertainties.keys() != values.keys():
            raise ValueError(
                "All nodes with either attribute and uncertainty values must have both."
            )
        if umin is None:
            umin = min(uncertainties.values(), default=0)
        if umax is None:
            umax = max(uncertainties.values(), default=1)
    else:
        # Constant zero uncertainty and no quantization: the VSUP then colors by
        # value only.
        uncertainties = dict.fromkeys(values, 0)
        umin = 0
        umax = 1
        quantization = None

    # Prepare color map
    if cmap is None:
        cmap = sns.diverging_palette(220, 20, center="dark", as_cmap=True)
    elif isinstance(cmap, str):
        cmap = sns.color_palette(cmap, as_cmap=True)
    # Prepare Value-Suppressing Uncertainty Palette (VSUP)
    vsup = VSUP(
        palette=cmap,
        vmin=vmin,
        vmax=vmax,
        umin=umin,
        umax=umax,
        quantization=quantization,
    )

    # Exactly one fill color and one edge color per node, in G.nodes order.
    node_colors = [
        vsup(values[node], uncertainties[node]) if node in values else "white"
        for node in G.nodes
    ]
    node_edgecolors = ["black"] * len(node_colors)

    if pos is None:
        # Get positions so that (i, j) is at (j, -i) for grid-like display
        pos = {n: (n[1], -n[0]) for n in G.nodes}

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    nx.draw(
        G,
        pos=pos,
        node_color=node_colors,
        edgecolors=node_edgecolors,
        with_labels=False,
        node_size=node_size,
        ax=ax,
        **kwargs,
    )
    # Optionally, mark the valued nodes with their value
    if label:
        for node, val in values.items():
            ax.text(
                pos[node][0],
                pos[node][1],
                str(val),
                color="black",
                ha="center",
                va="center",
                fontsize=10,
                fontweight="bold",
            )
    return vsup

In [None]:
# Helper function to ensure reproducibility
def set_seeds(seed=42):
    """Seed NumPy and PyTorch (CPU and CUDA) and make cuDNN deterministic.

    Parameters
    ----------
    seed : int
        Seed handed to every random number generator touched here.
    """
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


# Set seeds for consistent results
set_seeds()


# Define the Gaussian Process model on the graph
class GraphGP:
    """
    Gaussian Process regression over the nodes of a NetworkX graph.

    Wraps a BoTorch ``SingleTaskGP`` whose covariance is a scaled Matérn kernel on
    the graph (``geometric_kernels.MaternGeometricKernel``). GP inputs are node
    *indices*, i.e. positions in ``graph.nodes`` order.
    """

    def __init__(
        self,
        graph,
        attribute,
        normalize_laplacian=False,
        trainable_nu=False,
        nu=0.5,
        scale_prior=(2.0, 0.15),
        noise_prior=(1.1, 0.05),
        noise_min=1e-8,
    ):
        """
        Build the model from the observed node attribute values.

        Parameters
        ----------
        graph : networkx.Graph
            Graph whose adjacency matrix defines the kernel's space.
        attribute : str
            Node attribute holding the observations; nodes without it are unobserved.
        normalize_laplacian : bool
            Passed to ``geometric_kernels.spaces.Graph``.
        trainable_nu : bool
            Passed to ``GPyTorchGeometricKernel``: whether ``nu`` is optimized.
        nu : float
            Matérn smoothness of the graph kernel.
        scale_prior : tuple(float, float)
            (concentration, rate) of the Gamma prior on the kernel output scale.
        noise_prior : tuple(float, float)
            (concentration, rate) of the Gamma prior on the likelihood noise; the
            noise is initialized at this prior's mode.
        noise_min : float
            Lower bound on the likelihood noise.
        """
        graph_torch = torch.tensor(nx.to_numpy_array(graph))
        graph_space = Graph(graph_torch, normalize_laplacian=normalize_laplacian)
        mean_module = gpytorch.means.ConstantMean()

        geo_kernel = MaternGeometricKernel(graph_space, nu=nu)
        params = geo_kernel.init_params()
        params["nu"] = torch.tensor([nu], dtype=torch.float64)
        params["lengthscale"] = torch.tensor(params["lengthscale"], dtype=torch.float64)
        base_kernel = GPyTorchGeometricKernel(
            geo_kernel,
            nu=params["nu"],
            trainable_nu=trainable_nu,
            lengthscale=params["lengthscale"],
        )
        covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel,
            outputscale_prior=GammaPrior(*scale_prior),
        )

        noise_prior_obj = GammaPrior(*noise_prior)
        # Mode of a Gamma(concentration, rate) distribution.
        noise_prior_mode = (noise_prior_obj.concentration - 1) / noise_prior_obj.rate
        lik = GaussianLikelihood(
            noise_prior=noise_prior_obj,
            noise_constraint=gpytorch.constraints.GreaterThan(noise_min),
            initial_value=noise_prior_mode,
        )

        # Node -> row of the adjacency matrix (nx.to_numpy_array uses graph.nodes order).
        node_index = {node: idx for idx, node in enumerate(graph.nodes)}
        # Remember the size so predict() can reject graphs with a different node set size.
        self._n_nodes = len(node_index)
        attr_dict = nx.get_node_attributes(graph, attribute)
        # Prepare training data: one (node index, observed value) row per observed node
        train_x, train_y = torch.tensor(
            [[node_index[node], val] for node, val in attr_dict.items()], dtype=float
        ).T

        self.model = SingleTaskGP(
            train_x[:, None],
            train_y[:, None],
            mean_module=mean_module,
            covar_module=covar_module,
            likelihood=lik,
        )
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(
            self.model.likelihood, self.model
        )

    def fit(self):
        """Fit hyperparameters by maximizing the exact marginal log likelihood."""
        return fit_gpytorch_mll(mll=self.mll)

    def predict(
        self,
        graph,
        store_pred_as="prediction",
        mean_name="mean",
        var_name="var",
        std_name="std",
    ):
        """
        Predicts the values of the attribute for all nodes in the graph.

        ``graph`` must list the same nodes, in the same order, as the graph the
        model was built from: nodes are identified by their position in
        ``graph.nodes``.

        Parameters
        ----------
        graph : networkx.Graph
            Graph whose nodes are predicted.
        store_pred_as : str or None
            Key prefix. If not None, the predictions are also written to ``graph``
            as node attributes; if None, nothing is stored and the prefix
            ``"prediction"`` is used for the returned keys.
        mean_name, var_name, std_name : str
            Key suffixes for the posterior mean, variance and standard deviation.

        Returns
        -------
        dict
            ``{node: {f"{prefix}_{mean_name}": float, f"{prefix}_{var_name}": float,
            f"{prefix}_{std_name}": float}}``.

        Raises
        ------
        ValueError
            If ``graph`` has a different number of nodes than the training graph.
        """
        nodes = list(graph.nodes)
        if len(nodes) != self._n_nodes:
            raise ValueError(
                f"GraphGP was built on a graph with {self._n_nodes} nodes, "
                f"but predict() got {len(nodes)}."
            )

        store_pred = store_pred_as is not None
        prefix = store_pred_as if store_pred else "prediction"

        # Ensure the model is in evaluation mode; no gradients are needed here.
        self.model.eval()
        all_nodes = torch.arange(len(nodes)).long()[:, None]
        with torch.no_grad():
            posterior = self.model.posterior(all_nodes)
            means = posterior.mean.squeeze(-1).tolist()
            variances = posterior.variance.squeeze(-1).tolist()

        node_posteriors = {
            node: {
                f"{prefix}_{mean_name}": mn,
                f"{prefix}_{var_name}": vr,
                f"{prefix}_{std_name}": float(np.sqrt(vr)),
            }
            for node, mn, vr in zip(nodes, means, variances)
        }

        if store_pred:
            # Store predictions as node attributes in the graph
            nx.set_node_attributes(graph, node_posteriors)

        return node_posteriors


class SpatialGP:
    """
    Euclidean Gaussian Process regression that uses node names as coordinates.

    Baseline for ``GraphGP``. The graph's edges are ignored: each node is a point
    whose coordinates are its tuple name (e.g. ``(i, j)`` for
    ``nx.grid_2d_graph``). Wraps a BoTorch ``SingleTaskGP`` with a scaled Matérn
    kernel (RBF for ``nu=np.inf``), input normalization and outcome
    standardization.
    """

    def __init__(
        self,
        graph,
        attribute,
        nu=0.5,
        scale_prior=(2.0, 0.15),
        noise_prior=(1.1, 0.05),
        noise_min=1e-8,
        normalize_laplacian=None,
        lengthscale_prior=(3.0, 6.0),
    ):
        """
        Build the model from the observed node attribute values.

        Parameters
        ----------
        graph : networkx.Graph
            Graph whose node names are numeric tuples used as coordinates.
        attribute : str
            Node attribute holding the observations; nodes without it are unobserved.
        nu : float
            ``np.inf`` selects ``RBFKernel``; any other value is passed to
            ``gpytorch.kernels.MaternKernel`` (which presumably only accepts 0.5,
            1.5 and 2.5 — confirm against the installed gpytorch).
        scale_prior : tuple(float, float)
            (concentration, rate) of the Gamma prior on the kernel output scale.
        noise_prior : tuple(float, float)
            (concentration, rate) of the Gamma prior on the likelihood noise; the
            noise is initialized at this prior's mode.
        noise_min : float
            Lower bound on the likelihood noise.
        normalize_laplacian : ignored
            Accepted only so ``SpatialGP`` takes the same keywords as ``GraphGP``.
        lengthscale_prior : tuple(float, float)
            (concentration, rate) of the Gamma prior on the kernel lengthscale.
        """
        # Training inputs: coordinates (node names) of the observed nodes only.
        attr_dict = nx.get_node_attributes(graph, attribute)
        X = torch.tensor([list(node) for node in attr_dict], dtype=torch.float64)
        y = torch.tensor(list(attr_dict.values()), dtype=torch.float64)

        mean_module = gpytorch.means.ConstantMean()
        ls_prior = GammaPrior(*lengthscale_prior)
        if nu == np.inf:
            inner_kernel = gpytorch.kernels.RBFKernel(lengthscale_prior=ls_prior)
        else:
            inner_kernel = gpytorch.kernels.MaternKernel(
                nu=nu, lengthscale_prior=ls_prior
            )
        base_kernel = gpytorch.kernels.ScaleKernel(
            inner_kernel,
            outputscale_prior=GammaPrior(*scale_prior),
        )

        noise_prior_obj = GammaPrior(*noise_prior)
        # Mode of a Gamma(concentration, rate) distribution.
        noise_prior_mode = (noise_prior_obj.concentration - 1) / noise_prior_obj.rate
        lik = GaussianLikelihood(
            noise_prior=noise_prior_obj,
            noise_constraint=gpytorch.constraints.GreaterThan(noise_min),
            initial_value=noise_prior_mode,
        )

        self.model = SingleTaskGP(
            X,
            y[:, None],
            mean_module=mean_module,
            covar_module=base_kernel,
            likelihood=lik,
            input_transform=NormalizeInput(d=X.shape[1]),
            outcome_transform=StandardizeOutput(m=1),
        )
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(
            self.model.likelihood, self.model
        )

    def fit(self):
        """Fit hyperparameters by maximizing the exact marginal log likelihood."""
        return fit_gpytorch_mll(mll=self.mll)

    def predict(
        self,
        graph,
        store_pred_as="prediction",
        mean_name="mean",
        var_name="var",
        std_name="std",
    ):
        """
        Predicts the values of the attribute for all nodes in the graph using their coordinates.

        ``graph`` may be any graph whose node names are coordinate tuples of the
        same dimension as the training nodes (e.g. a denser grid).

        Parameters
        ----------
        graph : networkx.Graph
            Graph whose nodes are predicted.
        store_pred_as : str or None
            Key prefix. If not None, the predictions are also written to ``graph``
            as node attributes; if None, nothing is stored and the prefix
            ``"prediction"`` is used for the returned keys.
        mean_name, var_name, std_name : str
            Key suffixes for the posterior mean, variance and standard deviation.

        Returns
        -------
        dict
            ``{node: {f"{prefix}_{mean_name}": float, f"{prefix}_{var_name}": float,
            f"{prefix}_{std_name}": float}}``.
        """
        nodes = list(graph.nodes)
        store_pred = store_pred_as is not None
        prefix = store_pred_as if store_pred else "prediction"

        self.model.eval()
        # float64, matching the training inputs built in __init__.
        all_coords = torch.tensor([list(node) for node in nodes], dtype=torch.float64)
        with torch.no_grad():
            posterior = self.model.posterior(all_coords)
            means = posterior.mean.squeeze(-1).tolist()
            variances = posterior.variance.squeeze(-1).tolist()

        node_posteriors = {
            node: {
                f"{prefix}_{mean_name}": mn,
                f"{prefix}_{var_name}": vr,
                f"{prefix}_{std_name}": float(np.sqrt(vr)),
            }
            for node, mn, vr in zip(nodes, means, variances)
        }

        if store_pred:
            # Store predictions as node attributes in the graph
            nx.set_node_attributes(graph, node_posteriors)

        return node_posteriors

# Trivial "1D" Graph

In [None]:
# Figure layout: three stacked strips (a, c, d) on the left and one tall
# panel (b) spanning all rows on the right.
fig, axs = plt.subplot_mosaic([["a", "b"], ["c", "b"], ["d", "b"]], figsize=(8, 4))


# A 1 x 11 grid graph is a path of 11 nodes named (0, 0) ... (0, 10).
G1 = nx.grid_2d_graph(1, 11)

# Observed values at three nodes; every other node stays unobserved.
nx.set_node_attributes(
    G1,
    {(0, 1): 1, (0, 2): -1, (0, 6): 2},
    name="observation",
)


def label_graph(ax, label):
    """Show ``label`` as the only visible decoration of ``ax``: a y-tick label at y=0.

    The axes are switched back on, then all tick marks, the x tick labels and
    every spine are hidden so that only the label text remains.
    """
    ax.set_yticks([0])
    ax.axis("on")
    ax.set_yticklabels([label], color="k", fontsize=15)
    tick_flags = {
        "x": dict(bottom=False, top=False, labelbottom=False),
        "y": dict(left=False, right=False),
    }
    for axis_name, flags in tick_flags.items():
        ax.tick_params(axis=axis_name, which="both", **flags)
    for spine in ax.spines.values():
        spine.set_visible(False)


# Define the parameters for the prediction plots.
# vmin/vmax bound the value axis of the VSUP palette and umin/umax its
# uncertainty axis, so every panel below uses the same color scale.
# "hide_ticks" is not a grid_plot option: grid_plot forwards it to nx.draw via
# **kwargs (presumably so label_graph's tick label stays visible — confirm
# against the installed networkx version).
pred_plot_kwargs = {
    "vmin": -2,
    "vmax": +2,
    "umin": 0.8,
    "umax": 2.1,
    "hide_ticks": False,
}

# Plot the graph with observations (no uncertainty argument, so grid_plot
# ignores umin/umax here and colors by value only)
ax = axs["a"]
vsup = grid_plot(G1, "observation", label=True, ax=ax, **pred_plot_kwargs)

label_graph(ax, "Observation")

# Define and fit the GraphGP model
G1_model_graph = GraphGP(
    G1,
    "observation",
    nu=2.5,
)
G1_model_graph.fit()
# Make predictions. This will also store the predictions in the graph.
G1_model_graph.predict(G1)

# Plot the mean and uncertainty of the GraphGP predictions.
# This must run before SpatialGP.predict below, which overwrites the same
# "prediction_mean"/"prediction_std" node attributes on G1.
ax = axs["c"]
vsup = grid_plot(G1, "prediction_mean", "prediction_std", ax=ax, **pred_plot_kwargs)
label_graph(ax, "GraphGP Predictions")

# Define and fit the SpatialGP model
# (SpatialGP accepts normalize_laplacian but does not use it.)
G1_model_spatial = SpatialGP(
    G1,
    "observation",
    normalize_laplacian=False,
    nu=2.5,
    scale_prior=(2.0, 0.15),
    noise_prior=(1.1, 0.05),
    noise_min=1e-8,
)
G1_model_spatial.fit()
# Make predictions. This will also store the predictions in the graph.
G1_model_spatial.predict(G1)

# Plot the mean and uncertainty of the SpatialGP predictions.
ax = axs["d"]
vsup = grid_plot(G1, "prediction_mean", "prediction_std", ax=ax, **pred_plot_kwargs)
label_graph(ax, "SpatialGP Predictions")

# Plot a Value-Suppressing Uncertainty Palette (VSUP) arcmap legend. This is the
# palette used for the GraphGP and SpatialGP predictions (both panels were drawn
# with the same pred_plot_kwargs). It visualizes the prediction mean via hue and
# uncertainty via saturation and lightness.
ax = axs["b"]
vsup.create_arcmap_legend(ax)
ax.set_title("VSUP Arcmap Legend", fontsize=15)

## Comparing different nu values

In [None]:
fig, axs = plt.subplots(3, 4, figsize=(16, 4), sharex=True, height_ratios=[1, 1, 5])

# Create a new grid graph to illustrate "spatial" predictions.
# Nodes (0, j) for j = 0..100 are renamed to (0, j / 10): 101 points spanning the
# same [0, 10] interval as G1, simulating a continuous space.
G1_ = nx.grid_2d_graph(1, 101)
old2new = {(i, j): (i, j / 10) for i, j in G1_.nodes}
G1_ = nx.relabel_nodes(G1_, old2new)

# The observations are the same for every column.
G1_observations = nx.get_node_attributes(G1, "observation")

# One column per Matérn smoothness nu. Rows: GraphGP on G1's 11 nodes, SpatialGP
# on the same 11 nodes, and a line plot comparing both GPs' mean ± 2σ.
for col, nu in zip(axs.T, [0.5, 1.5, 2.5, np.inf]):
    graph_ax, spat_grid_ax, spat_ax = col
    graph_ax.set_title("nu=∞" if nu == np.inf else f"nu={int(nu * 2)}/2")

    # GraphGP. graph_preds keeps its predictions after SpatialGP overwrites G1.
    G1_model_graph = GraphGP(G1, "observation", nu=nu)
    G1_model_graph.fit()
    graph_preds = G1_model_graph.predict(G1)
    vsup = grid_plot(
        G1, "prediction_mean", "prediction_std", ax=graph_ax, **pred_plot_kwargs
    )

    # SpatialGP on the same nodes (overwrites G1's prediction_* attributes).
    G1_model_spatial = SpatialGP(G1, "observation", nu=nu)
    G1_model_spatial.fit()
    G1_model_spatial.predict(G1)
    vsup = grid_plot(
        G1, "prediction_mean", "prediction_std", ax=spat_grid_ax, **pred_plot_kwargs
    )

    # SpatialGP evaluated on the dense grid: mean curve with a ±2σ band.
    spat_preds = G1_model_spatial.predict(G1_)
    spat_x = np.array([node[1] for node in G1_.nodes])
    spat_mean = np.array([spat_preds[node]["prediction_mean"] for node in G1_.nodes])
    spat_std = np.array([spat_preds[node]["prediction_std"] for node in G1_.nodes])
    spat_ax.plot(spat_x, spat_mean, label="Spatial GP mean")
    spat_ax.fill_between(
        spat_x,
        spat_mean - 2 * spat_std,
        spat_mean + 2 * spat_std,
        alpha=0.2,
        label="Spatial GP\nμ ± 2σ",
    )

    # GraphGP node-wise mean ± 2σ.
    node_x = [node[1] for node in graph_preds]
    node_mean = [preds["prediction_mean"] for preds in graph_preds.values()]
    node_std = np.array([preds["prediction_std"] for preds in graph_preds.values()])
    spat_ax.errorbar(
        node_x,
        node_mean,
        yerr=2 * node_std,
        fmt="o",
        label="Graph GP\nμ ± 2σ",
    )

    spat_ax.scatter(
        [node[1] for node in G1_observations],
        list(G1_observations.values()),
        color="black",
        label="Observations",
        zorder=5,
    )

# Every column uses the same labels and styles, so one legend is enough.
spat_ax.legend(fontsize="small")

# Gridded "2D" Graph

In [None]:
# Three panels: observations, GraphGP predictions, and the VSUP legend.
fig, axs = plt.subplots(1, 3, figsize=(12, 4))

# 11 x 11 grid graph; nodes are named (i, j) with i, j in 0..10.
G2 = nx.grid_2d_graph(11, 11)

# Assign values to specific nodes
nx.set_node_attributes(
    G2,
    {
        (1, 2): 1,
        (7, 7): 1,
        (8, 1): -1,
        (3, 7): -1,
        (5, 4): 0,
    },
    "observation",
)

# Shared color ranges for the 2D plots: values in [-1, 1] and uncertainty
# (passed to grid_plot as "prediction_std" below) in [0.4, 1.6].
plot_kwargs_2D = {
    "vmin": -1,
    "vmax": 1,
    "umin": 0.4,
    "umax": 1.6,
}

ax = axs[0]
vsup = grid_plot(G2, "observation", label=True, ax=ax, **plot_kwargs_2D)
ax.set_aspect("equal")
ax.set_title("Observations")

G2_model_graph = GraphGP(
    G2,
    "observation",
    nu=2.5,
)
G2_model_graph.fit()
# Writes prediction_mean / prediction_var / prediction_std onto G2's nodes,
# which the next grid_plot reads.
G2_model_graph.predict(G2)

ax = axs[1]
vsup = grid_plot(G2, "prediction_mean", "prediction_std", ax=ax, **plot_kwargs_2D)
ax.set_aspect("equal")
ax.set_title("Graph GP Predictions")

# Legend for the palette returned by the prediction panel.
ax = axs[2]
vsup.create_arcmap_legend(ax)
ax.set_title("VSUP Arcmap Legend")

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(8, 4))

# 2x2 rotation matrix for an angle of -pi/4. It is applied below to row
# vectors (v @ R), which turns the square grid by 45 degrees into a diamond.
R = np.array(
    [
        [np.cos(-np.pi / 4), -np.sin(-np.pi / 4)],
        [np.sin(-np.pi / 4), np.cos(-np.pi / 4)],
    ]
)

# Node (i, j) -> (i, -j) rotated by R.
diamond_layout = {n: np.array([n[0], -n[1]]) @ R for n in G2.nodes}

# Rescale by 1 / (10 * cos(pi/4)) and shift left by 1.
diamond_layout = {
    n: (pos / 10 / np.cos(-np.pi / 4) - np.array([1, 0]))
    for n, pos in diamond_layout.items()
}

# Refit with the same settings as the previous cell, then write this model's
# predictions onto G2 for the left panel.
G2_model_graph = GraphGP(
    G2,
    "observation",
    nu=2.5,
)
G2_model_graph.fit()
G2_model_graph.predict(G2)

ax = axs[0]
grid_plot(
    G2,
    "prediction_mean",
    "prediction_std",
    pos=diamond_layout,
    node_size=50,
    ax=ax,
    **plot_kwargs_2D,
)
ax.set_aspect("equal")
ax.set_title("Graph GP Predictions\nGrid Layout")

# SpatialGP on the same data; its predict() overwrites G2's prediction_*
# attributes, so it must come after the GraphGP panel is drawn.
G2_model_spatial = SpatialGP(
    G2,
    "observation",
    nu=2.5,
)
G2_model_spatial.fit()
G2_model_spatial.predict(G2)

ax = axs[1]
grid_plot(
    G2,
    "prediction_mean",
    "prediction_std",
    pos=diamond_layout,
    node_size=50,
    ax=ax,
    **plot_kwargs_2D,
)
ax.set_aspect("equal")
ax.set_title("Spatial GP Predictions\nGrid Layout")

To help explain why we see different behavior near the edges of the graph when we consider it as a true graph rather than as an image, let's plot the predictions in a more natural space for each.

In [None]:
# Position nodes by graph-Laplacian eigenvectors (networkx spectral layout).
spectral_layout = nx.spectral_layout(G2)
spectral_layout = {
    n: (-pos[0], -pos[1]) for n, pos in spectral_layout.items()
}  # Flip to match orientation of diamond layout

fig, axs = plt.subplots(1, 2, figsize=(8, 4))

# The last predict() in the previous cell was SpatialGP's, so re-predict to put
# the GraphGP values back onto G2 before plotting them.
G2_model_graph.predict(G2)
ax = axs[0]
# Mean only (no uncertainty argument): grid_plot colors by value alone.
grid_plot(
    G2,
    "prediction_mean",
    vmin=-1,
    vmax=+1,
    pos=spectral_layout,
    ax=ax,
    node_size=50,
)
ax.set_aspect("equal")
ax.set_title("Graph GP Mean\nwith Spectral Layout")

# Switch G2's prediction_* attributes back to the SpatialGP values.
G2_model_spatial.predict(G2)
ax = axs[1]
grid_plot(
    G2,
    "prediction_mean",
    # "prediction_std",
    vmin=-1,
    vmax=+1,
    pos=diamond_layout,
    ax=ax,
    node_size=50,
)
ax.set_aspect("equal")
ax.set_title("Spatial GP Mean\nwith Grid Layout")

Here, we've plotted the SpatialGP predictions with nodes positioned according to adjacency (that is, "image-like" Euclidean space, albeit rotated), but the GraphGP predictions have been plotted with the nodes positioned according to the first two (non-trivial) eigenvectors of the graph Laplacian, which is what `nx.spectral_layout` uses. These eigenvectors define the orientations which capture the most and second-most variation in *connectivity* (both adjacency and degree) between all nodes in the graph. In the middle of the graph there isn't much variation, so the node positioning (and the "propagation" of the observations) looks similar to Euclidean space. Near the edges, and especially near the corners, however, there is significant variation in connectivity between neighboring nodes, which "squishes" the edge nodes towards the middle.

Now we can start to see why the GraphGP behaves a bit differently. The GraphGP operates in *eigenspace* rather than Euclidean space. That is, the GP kernel is defined on the eigenvalues and eigenvectors of the graph's Laplacian. Predictions for nodes on the periphery are more similar to those of their neighbors than predictions for nodes towards the middle are, because peripheral nodes are actually *closer* to their neighbors *in eigenspace*!

Now let's look closer at the variance behavior, which also differs from the SpatialGP treatment.

In [None]:
from scipy.optimize import minimize

# Second-order centrality (SOC) of every node of G2.
soc = nx.second_order_centrality(G2)

# Fix one edge ordering and use it for both the target lengths and the stress.
edges = list(G2.edges)
for u, v in edges:
    # |ΔSOC| across the edge, also kept as an edge attribute for inspection.
    G2.edges[u, v]["delta_soc"] = np.abs(soc[u] - soc[v])
target_lengths = np.array([G2.edges[u, v]["delta_soc"] for u, v in edges])

# Normalize target lengths for better scaling (optional)
if np.max(target_lengths) > 0:
    target_lengths = target_lengths / np.max(target_lengths)

# Get initial positions from diamond_layout and flatten them for optimization:
# x0 = [x_0, y_0, x_1, y_1, ...] in `nodes` order.
init_pos = {n: np.array(pos) for n, pos in diamond_layout.items()}
nodes = list(G2.nodes)
node_idx = {n: i for i, n in enumerate(nodes)}
x0 = np.concatenate([init_pos[n] for n in nodes])

# Endpoint indices of every edge, so the stress is evaluated without a Python
# loop (L-BFGS-B estimates the gradient numerically, calling it many times).
edge_i = np.array([node_idx[u] for u, _ in edges])
edge_j = np.array([node_idx[v] for _, v in edges])


def stress(x):
    """Sum over edges of (drawn edge length - target length) ** 2."""
    coords = x.reshape(-1, 2)
    dist = np.linalg.norm(coords[edge_i] - coords[edge_j], axis=1)
    return np.sum((dist - target_lengths) ** 2)


res = minimize(stress, x0, method="L-BFGS-B")
soc_layout = {n: res.x[2 * i : 2 * i + 2] for i, n in enumerate(nodes)}

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(8, 4))

# Put the GraphGP predictions back onto G2 (the previous plotting cell left the
# SpatialGP values there).
G2_model_graph.predict(G2)
ax = axs[0]
# Color nodes by predictive std alone (no uncertainty argument), using a
# grayscale map with a fixed [0.4, 1.6] range shared by both panels.
grid_plot(
    G2,
    # "prediction_mean",
    "prediction_std",
    cmap="binary_r",
    vmin=0.4,
    vmax=1.6,
    pos=soc_layout,
    ax=ax,
    node_size=50,
)
ax.set_aspect("equal")
ax.set_title("Graph GP Uncertainty\nwith SOC Layout")

# Overwrite G2's prediction_* attributes with the SpatialGP values.
G2_model_spatial.predict(G2)
ax = axs[1]
grid_plot(
    G2,
    # "prediction_mean",
    "prediction_std",
    cmap="binary_r",
    vmin=0.4,
    vmax=1.6,
    pos=diamond_layout,
    ax=ax,
    node_size=50,
)
ax.set_aspect("equal")
ax.set_title("Spatial GP Uncertainty\nwith Grid Layout")
# Stand-alone colorbar matching the colormap and range used in both panels.
plt.colorbar(
    plt.cm.ScalarMappable(
        cmap="binary_r",
        norm=plt.Normalize(vmin=0.4, vmax=1.6),
    ),
    ax=ax,
    label="Uncertainty",
    orientation="vertical",
    fraction=0.046,
    pad=0.04,
)

Unlike that of the SpatialGP, the predictive uncertainty of the GraphGP varies near the periphery of the graph, increasing sharply at the corner nodes. In the plots above, this shows up as increasing paleness near the corners. As noted in the [GeometricKernels documentation](https://geometric-kernels.github.io/GeometricKernels/examples/backends/PyTorch_Graph.html#A-Note-on-Prior-Variance) and their [tutorial notebook](https://github.com/spbu-math-cs/Graph-Gaussian-Processes/blob/main/examples/graph_variance.ipynb), this comes from the variation in *expected return time* of a random walk over the nodes of the graph. The corners are more "difficult" to reach on a random walk because there are fewer paths to them. This leads to a wide range of return times (short, via the immediate neighbors, and long, via the opposite corners), and so the uncertainty of predictions at these nodes is higher.

Above, we visualize this discrepancy by choosing the Graph GP layout so that each edge's length approximates the (normalized) difference in *second-order centrality* (SOC) between its two nodes. The layout comes from minimizing the squared mismatch between drawn and target edge lengths. The SOC of a node is the standard deviation of the return times to that node of a random walk on the graph. As we can see, nodes with similar SOC have similar predictive variances, and the corners have both the highest predictive variance and the biggest difference in SOC compared to their neighbors.

In [None]:
fig, axs = plt.subplots(3, 4, figsize=(16, 12), sharex=True, sharey="row")

# Row of G2 (0-based index i) shown in the 1D slice panels; i = 8 is the 9th row.
SLICE_ROW = 8

# Prepare a "continuous" copy of that row for the spatial GP (like G1_):
# nodes (SLICE_ROW, j / 10) for j = 0..100.
G2_dense = nx.grid_2d_graph(1, 101)
old2new_G2 = {(0, j): (SLICE_ROW, j / 10) for (_, j) in G2_dense.nodes}
G2_dense = nx.relabel_nodes(G2_dense, old2new_G2)

# Loop-invariant inputs for the slice panels.
slice_nodes = sorted(node for node in G2.nodes if node[0] == SLICE_ROW)
slice_x = [j for (_, j) in slice_nodes]
obs = nx.get_node_attributes(G2, "observation")
obs_x = [j for (i, j) in obs if i == SLICE_ROW]
obs_y = [obs[(SLICE_ROW, j)] for j in obs_x]

for col, nu in zip(axs.T, [0.5, 1.5, 2.5, np.inf]):
    nu_str = "∞" if nu == np.inf else f"{int(nu * 2)}/2"
    graph_ax, spat_ax, slice_ax = col

    # Graph GP. plot_kwargs_2D fixes both the value and the uncertainty range,
    # so colors are comparable across every panel in the figure.
    graph_ax.set_title(f"Graph GP (nu={nu_str})")
    G2_model_graph = GraphGP(
        G2,
        "observation",
        normalize_laplacian=False,
        nu=nu,
        trainable_nu=False,
        scale_prior=(2.0, 0.15),
        noise_prior=(1.1, 0.05),
        noise_min=1e-8,
    )
    G2_model_graph.fit()
    # graph_preds keeps these values after SpatialGP overwrites G2's attributes.
    graph_preds = G2_model_graph.predict(G2)
    grid_plot(G2, "prediction_mean", "prediction_std", ax=graph_ax, **plot_kwargs_2D)
    graph_ax.set_aspect("equal")

    # Spatial GP
    spat_ax.set_title(f"Spatial GP (nu={nu_str})")
    G2_model_spatial = SpatialGP(
        G2,
        "observation",
        normalize_laplacian=False,
        nu=nu,
        scale_prior=(2.0, 0.15),
        noise_prior=(1.1, 0.05),
        noise_min=1e-8,
    )
    G2_model_spatial.fit()
    G2_model_spatial.predict(G2)
    grid_plot(G2, "prediction_mean", "prediction_std", ax=spat_ax, **plot_kwargs_2D)
    spat_ax.set_aspect("equal")

    # 1D slice through row SLICE_ROW: Graph GP node-wise mean ± 2σ
    slice_ax.set_title(f"Row {SLICE_ROW + 1} slice (nu={nu_str})")
    slice_mean = [graph_preds[node]["prediction_mean"] for node in slice_nodes]
    slice_std = np.array([graph_preds[node]["prediction_std"] for node in slice_nodes])
    slice_ax.errorbar(
        slice_x,
        slice_mean,
        yerr=2 * slice_std,
        fmt="o",
        label="Graph GP μ±2σ",
        color="C1",
    )

    # Spatial GP predictions on the dense slice: mean curve with a ±2σ band
    G2_model_spatial.predict(G2_dense)
    dense_x = [node[1] for node in G2_dense.nodes]
    dense_mean = np.array(
        [G2_dense.nodes[node]["prediction_mean"] for node in G2_dense.nodes]
    )
    dense_std = np.array(
        [G2_dense.nodes[node]["prediction_std"] for node in G2_dense.nodes]
    )
    slice_ax.plot(dense_x, dense_mean, label="Spatial GP mean", color="C0")
    slice_ax.fill_between(
        dense_x,
        dense_mean - 2 * dense_std,
        dense_mean + 2 * dense_std,
        alpha=0.2,
        color="C0",
        label="Spatial GP ±2σ",
    )

    # Observations in the slice row
    if obs_x:
        slice_ax.scatter(obs_x, obs_y, color="black", label="Observations", zorder=5)

    slice_ax.set_xlabel("Column (j)")
    slice_ax.set_ylabel("Prediction")

# Every column uses the same labels and styles, so one legend is enough.
slice_ax.legend(fontsize="small")

plt.tight_layout()
plt.show()