### This experiment introduces a geometry-adaptive GCN that learns to select message-passing behavior locally based on edge curvature. Instead of using static curvature as a weight, the Curvature-Gated GCN interprets curvature as a structural signal to dynamically combine homophilic and heterophilic reasoning.

**Goal** : test whether a curvature-based gating mechanism can switch between homophilic and heterophilic message-passing behaviors during graph learning, improving node classification on both types of graphs.

We now assign a curvature-based gating value $g_{uv}$ to each edge $(u, v)$:

$
g_{uv} = \sigma \left( \frac{\mathcal{F}(u, v)}{k} \right)
$

where $k$ is a scaling constant (we use $k = 5$).

This value lies in $(0, 1)$ and determines how much the model trusts homophilic vs heterophilic information.

- $g_{uv} \approx 1 \rightarrow$ Edge is strongly homophilic $\rightarrow$ propagate via GCNConv path
- $g_{uv} \approx 0 \rightarrow$ Edge is heterophilic $\rightarrow$ propagate via HeteroConv path


We can define the Curvature-Gated GCN by:

$$
H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}}\left[(G \odot A)H^{(l)}W^{(l)}_{\text{homo}} + ((1-G) \odot A)H^{(l)}W^{(l)}_{\text{hetero}}\right]\tilde{D}^{-\frac{1}{2}}\right)
$$

where:
- $G = \sigma(\mathcal{F}/k)$ is the curvature gate
- $\mathcal{F}$ is the Forman-Ricci curvature tensor over edges
- $A$ is the adjacency matrix
- $D$ is the degree matrix

In [14]:
!pip install torch_geometric



In [15]:
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
from matplotlib.colors import Normalize
from matplotlib import cm
from torch_geometric.datasets import WikipediaNetwork

In [16]:
def forman_curvature(G):
    fc = {}
    for u, v in G.edges():
        triangles = len(list(nx.common_neighbors(G, u, v)))
        degree_sum = G.degree[u] + G.degree[v]
        curvature = max(0.1, 4 - degree_sum + 3 * triangles)
        fc[(u, v)] = curvature
    return fc

In [17]:
def get_edge_curvature_tensor(data):
    G = nx.Graph()
    edge_index = data.edge_index.cpu().numpy()
    G.add_edges_from(zip(edge_index[0], edge_index[1]))
    forman = forman_curvature(G)

    curv_vals = []
    for i in range(data.edge_index.shape[1]):
        u, v = edge_index[0, i], edge_index[1, i]
        key = tuple(sorted((int(u), int(v))))
        val = forman.get(key, 0.0)
        curv_vals.append(val)

    return torch.tensor(curv_vals, dtype=torch.float, device=device)

In [19]:
from torch_geometric.nn import MessagePassing

In [20]:
class HeteroConv(MessagePassing):

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index, edge_weight=None):
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_i, x_j, edge_weight):

        diff = torch.abs(x_i - x_j)
        if edge_weight is not None:
            return edge_weight.view(-1, 1) * self.lin(diff)
        return self.lin(diff)

In [21]:
class CurvatureGatedGCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        # homophilic path
        self.homo_conv1 = GCNConv(in_dim, hidden_dim)
        self.homo_conv2 = GCNConv(hidden_dim, out_dim)

        # heterophilic path
        self.hetero_conv1 = HeteroConv(in_dim, hidden_dim)
        self.hetero_conv2 = HeteroConv(hidden_dim, out_dim)
        self.alpha = torch.nn.Parameter(torch.tensor(0.5))

    def forward(self, x, edge_index, edge_curvature):
        gate = torch.sigmoid(edge_curvature / 5.0).view(-1, 1)

        x_homo = self.homo_conv1(x, edge_index, edge_weight=gate.squeeze())
        x_homo = F.relu(x_homo)
        x_homo = F.dropout(x_homo, p=0.5, training=self.training)

        x_hetero = self.hetero_conv1(x, edge_index, edge_weight=(1 - gate).squeeze())
        x_hetero = F.relu(x_hetero)
        x_hetero = F.dropout(x_hetero, p=0.5, training=self.training)

        x = x_homo + x_hetero

        x_homo = self.homo_conv2(x, edge_index, edge_weight=gate.squeeze())
        x_hetero = self.hetero_conv2(x, edge_index, edge_weight=(1 - gate).squeeze())

        return x_homo + x_hetero


In [22]:
def train_gated_model(model, data, edge_curv, epochs=200):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    model.train()
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, edge_curv)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

    # Eval
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index, edge_curv)
        pred = out.argmax(dim=1)
        acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
    return acc



In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [24]:
from torch_geometric.transforms import NormalizeFeatures

In [25]:

def run_gated_experiment(dataset_name, dataset_class, name):
    print(f"Running Curvature-Gated GCN on {dataset_name}")
    dataset = dataset_class(root=f'/tmp/{name}', name=name, transform=NormalizeFeatures())
    data = dataset[0].to(device)

    if data.train_mask.dim() == 2:
        data.train_mask = data.train_mask[:, 0]
        data.val_mask = data.val_mask[:, 0]
        data.test_mask = data.test_mask[:, 0]

    edge_curv = get_edge_curvature_tensor(data)
    print(f"Curvature Range: [{edge_curv.min():.2f}, {edge_curv.max():.2f}]")

    model = CurvatureGatedGCN(data.num_features, 64, dataset.num_classes).to(device)
    test_acc = train_gated_model(model, data, edge_curv, epochs=200)

    print(f"{dataset_name} test accuracy: {test_acc:.4f}")
    return test_acc


In [26]:
acc_cora = run_gated_experiment("Cora", Planetoid, "Cora")

Running Curvature-Gated GCN on Cora
Curvature Range: [0.00, 6.00]
Cora test accuracy: 0.7040


In [27]:
acc_cham = run_gated_experiment("Chameleon", WikipediaNetwork, "chameleon")

Running Curvature-Gated GCN on Chameleon


Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/new_data/chameleon/out1_node_feature_label.txt
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/new_data/chameleon/out1_graph_edges.txt
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_0.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_1.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_2.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_3.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14

Curvature Range: [0.00, 652.00]
Chameleon test accuracy: 0.4605
