### This experiment tests whether giving a GCN structural awareness of graph curvature, and especially the ability to learn it, improves node classification performance across different graph types.

We try to test the transitions from flat message passing (Vanilla GCN) → geometry-aware (Fixed Curv-GCN) → geometry-adaptive (Learnable Curv-GCN).

**Goal** : compare how graph curvature information, both fixed and learnable, affects a GCN’s ability to classify nodes in different types of graphs.

In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0


In [2]:
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
from matplotlib.colors import Normalize
from matplotlib import cm
from torch_geometric.datasets import WikipediaNetwork

In [3]:
def forman_curvature(G):
    fc = {}
    for u, v in G.edges():
        triangles = len(list(nx.common_neighbors(G, u, v)))
        degree_sum = G.degree[u] + G.degree[v]
        curvature = max(0.1, 4 - degree_sum + 3 * triangles)
        fc[(u, v)] = curvature
    return fc

In [4]:
def add_forman_curvature_as_feature(data):
    G = to_networkx(data, to_undirected=True)
    forman = forman_curvature(G)
    edges = data.edge_index.t().tolist()
    curvature_vals = []
    for e in edges:
        key = tuple(sorted(e))
        curvature_vals.append(forman.get(key, 0.0))
    curvature_tensor = torch.tensor(curvature_vals, dtype=torch.float).unsqueeze(1)  # [E, 1]

    c_min = curvature_tensor.min()
    c_max = curvature_tensor.max()
    if c_max > c_min:
        curvature_tensor = (curvature_tensor - c_min) / (c_max - c_min)
    else:
        curvature_tensor = torch.zeros_like(curvature_tensor)

    data.edge_curvature = curvature_tensor
    return data

In [5]:
class LearnableCurvGCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)
        self.curv_mlp = torch.nn.Sequential(
            torch.nn.Linear(1, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 1),
            torch.nn.Sigmoid()
        )
    def forward(self, x, edge_index, edge_curvature):

        edge_weight = self.curv_mlp(edge_curvature).squeeze(-1)
        edge_weight = 0.1 + 0.9 * edge_weight
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        return x


In [6]:
class VanillaGCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def forward(self, x, edge_index, edge_curvature=None):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x


In [7]:

def train_and_evaluate(model, data, optimizer, criterion, epochs=200):
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        if isinstance(model, LearnableCurvGCN):
            out = model(data.x, data.edge_index, data.edge_curvature)
        else:
            out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        if isinstance(model, LearnableCurvGCN):
            out = model(data.x, data.edge_index, data.edge_curvature)
        else:
            out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        test_acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
    return test_acc


In [8]:

def run_experiment_2(dataset_name, data, num_classes, num_runs=5):
    results = {
        'vanilla_gcn': [],
        'fixed_curv_gcn': [],
        'learnable_curv_gcn': []
    }

    for run in range(num_runs):
        print(f"[{dataset_name}] Run {run+1}/{num_runs}")

        # vvanilla GCN
        model = VanillaGCN(data.num_features, 64, num_classes).to(device)
        data_d = data.to(device)
        opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
        crit = torch.nn.CrossEntropyLoss()
        acc = train_and_evaluate(model, data_d, opt, crit)
        results['vanilla_gcn'].append(acc)

        # fixed Curvature GCN
        data_fixed = add_forman_curvature_as_feature(data.clone())
        curv = data_fixed.edge_curvature.squeeze().cpu()
        weights = (curv - curv.min()) / (curv.max() - curv.min() + 1e-9)
        weights = 0.1 + 0.9 * weights
        data_fixed.edge_weight = weights.to(device)
        data_fixed = data_fixed.to(device)

        model_fixed = VanillaGCN(data_fixed.num_features, 64, num_classes).to(device)
        opt_fixed = torch.optim.Adam(model_fixed.parameters(), lr=0.01, weight_decay=5e-4)
        def forward_fixed(x, edge_index, _=None):
            x = model_fixed.conv1(x, edge_index, data_fixed.edge_weight)
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=model_fixed.training)
            x = model_fixed.conv2(x, edge_index, data_fixed.edge_weight)
            return x
        model_fixed.forward = forward_fixed

        model_fixed.train()
        for _ in range(200):
            opt_fixed.zero_grad()
            out = model_fixed(data_fixed.x, data_fixed.edge_index)
            loss = crit(out[data_fixed.train_mask], data_fixed.y[data_fixed.train_mask])
            loss.backward()
            opt_fixed.step()

        model_fixed.eval()
        with torch.no_grad():
            out = model_fixed(data_fixed.x, data_fixed.edge_index)
            pred = out.argmax(dim=1)
            acc_fixed = (pred[data_fixed.test_mask] == data_fixed.y[data_fixed.test_mask]).sum().item() / data_fixed.test_mask.sum().item()
        results['fixed_curv_gcn'].append(acc_fixed)

        #learnable Curvature GCN
        data_learn = add_forman_curvature_as_feature(data.clone())
        data_learn.edge_curvature = data_learn.edge_curvature.detach().to(device)
        data_learn = data_learn.to(device)

        model_learn = LearnableCurvGCN(data_learn.num_features, 64, num_classes).to(device)
        opt_learn = torch.optim.Adam(model_learn.parameters(), lr=0.01, weight_decay=5e-4)
        acc_learn = train_and_evaluate(model_learn, data_learn, opt_learn, crit)
        results['learnable_curv_gcn'].append(acc_learn)

    for key in results:
        mean, std = np.mean(results[key]), np.std(results[key])
        print(f"{key:20s}: {mean:.4f} ± {std:.4f}")
    print()
    return results


In [9]:
from torch_geometric.transforms import NormalizeFeatures

Training Fixed Curv-GCN and Learnable Curv-GCN using Cora dataset (homophilic graph)

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
print("test on Cora")
cora = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = cora[0]


num_classes = int(data.y.max().item()) + 1
print(f"Detected {num_classes} classes from labels.")


for key in ['train_mask', 'val_mask', 'test_mask']:
    if hasattr(data, key):
        delattr(data, key)

n = data.num_nodes
idx = torch.randperm(n)
train_idx = idx[:int(0.6 * n)]
val_idx = idx[int(0.6 * n):int(0.8 * n)]
test_idx = idx[int(0.8 * n):]

data.train_mask = torch.zeros(n, dtype=torch.bool)
data.val_mask = torch.zeros(n, dtype=torch.bool)
data.test_mask = torch.zeros(n, dtype=torch.bool)

data.train_mask[train_idx] = True
data.val_mask[val_idx] = True
data.test_mask[test_idx] = True

assert data.train_mask.dim() == 1, f"train_mask is {data.train_mask.shape}"
assert data.y.dim() == 1, f"y is {data.y.shape}"

cham_exp2 = run_experiment_2("Cora", data, num_classes)

test on Cora
Detected 7 classes from labels.
[Cora] Run 1/5
[Cora] Run 2/5
[Cora] Run 3/5
[Cora] Run 4/5
[Cora] Run 5/5
vanilla_gcn         : 0.8731 ± 0.0050
fixed_curv_gcn      : 0.8590 ± 0.0019
learnable_curv_gcn  : 0.8646 ± 0.0019



Training Fixed Curv-GCN and Learnable Curv-GCN using Chameleon dataset (Heterophilic graph)

In [14]:
print("test on Chameleon")
dataset = WikipediaNetwork(root='/tmp/Chameleon', name='chameleon', transform=NormalizeFeatures())
data = dataset[0]


num_classes = int(data.y.max().item()) + 1
print(f"Detected {num_classes} classes from labels.")


for key in ['train_mask', 'val_mask', 'test_mask']:
    if hasattr(data, key):
        delattr(data, key)

n = data.num_nodes
idx = torch.randperm(n)
train_idx = idx[:int(0.6 * n)]
val_idx = idx[int(0.6 * n):int(0.8 * n)]
test_idx = idx[int(0.8 * n):]

data.train_mask = torch.zeros(n, dtype=torch.bool)
data.val_mask = torch.zeros(n, dtype=torch.bool)
data.test_mask = torch.zeros(n, dtype=torch.bool)

data.train_mask[train_idx] = True
data.val_mask[val_idx] = True
data.test_mask[test_idx] = True

assert data.train_mask.dim() == 1, f"train_mask is {data.train_mask.shape}"
assert data.y.dim() == 1, f"y is {data.y.shape}"

cham_exp2 = run_experiment_2("Chameleon", data, num_classes)

test on Chameleon


Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/new_data/chameleon/out1_node_feature_label.txt
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/new_data/chameleon/out1_graph_edges.txt
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_0.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_1.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_2.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_3.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14

Detected 5 classes from labels.
[Chameleon] Run 1/5
[Chameleon] Run 2/5
[Chameleon] Run 3/5
[Chameleon] Run 4/5
[Chameleon] Run 5/5
vanilla_gcn         : 0.3465 ± 0.0059
fixed_curv_gcn      : 0.4193 ± 0.0038
learnable_curv_gcn  : 0.4202 ± 0.0045

