In [28]:
%%capture
!pip install torch
!pip install torch-geometric                      
!pip install ripser                               
!pip install networkx                             
!pip install scikit-learn                         
!pip install gudhi
!pip install giotto-tda
!pip install networkx

In [54]:
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GINConv, SAGEConv, GATConv, global_mean_pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import xgboost as xgb
from gudhi import SimplexTree
import networkx as nx
from gtda.diagrams import BettiCurve
import scipy
import scipy.sparse.linalg
import random
from torch_geometric.transforms import BaseTransform


In [30]:
# ===== FILTRATION METHODS =====

# helper
def to_networkx(data):
    """Convert a graph with an ``edge_index`` into an undirected ``networkx.Graph``.

    Nodes ``0 .. data.num_nodes - 1`` are inserted first, in index order, so
    ``G.nodes()`` iterates in node-index order and isolated nodes are present.
    Anything derived from ``G``'s node order (e.g. the rows of
    ``nx.normalized_laplacian_matrix(G)``) therefore lines up with ``data``'s
    node indexing.

    Args:
        data: object exposing ``edge_index`` (row 0 = sources, row 1 = targets)
            and ``num_nodes``.

    Returns:
        nx.Graph with plain-int node labels. Repeated or reversed edges collapse
        into a single undirected edge.
    """
    G = nx.Graph()
    # Nodes first: otherwise insertion (and iteration) order is decided by the
    # order edges appear in, which need not be 0..n-1.
    G.add_nodes_from(range(data.num_nodes))
    # .tolist() yields plain Python ints instead of numpy scalars.
    src, dst = data.edge_index.tolist()
    G.add_edges_from(zip(src, dst))
    return G


def degree_filtration(data):
    """Count how often each node index occurs in ``edge_index``.

    Both rows of ``edge_index`` are counted, and ``minlength`` guarantees one
    entry per node, so isolated nodes get 0.

    NOTE(review): if ``edge_index`` stores every undirected edge in both
    directions (the usual PyG convention — verify for your data), each
    value is twice the node's degree. The ordering, and hence the
    filtration, is unaffected.
    """
    endpoint_ids = data.edge_index.reshape(-1)
    counts = torch.bincount(endpoint_ids, minlength=data.num_nodes)
    return counts.numpy()


def clustering_coeff_filtration(data):
    """Local clustering coefficient of every node, as an array indexed by node id.

    Node ids ``0 .. data.num_nodes - 1`` missing from ``nx.clustering``'s
    result default to 0.
    """
    coeff_by_node = nx.clustering(to_networkx(data))
    return np.array([coeff_by_node.get(node_id, 0) for node_id in range(data.num_nodes)])

def pagerank_filtration(data):
    """PageRank score of every node, as an array indexed by node id.

    Uses networkx's default PageRank settings. Node ids absent from the result
    default to 0.
    """
    score_by_node = nx.pagerank(to_networkx(data))
    return np.array([score_by_node.get(node_id, 0) for node_id in range(data.num_nodes)])

def heat_kernel_filtration(data, t=1.0):
    """Per-node heat-kernel diagonal (heat kernel signature) at time ``t``.

    Computes ``h[i] = sum_k exp(-t * lam_k) * phi_k[i] ** 2`` over eigenpairs
    ``(lam_k, phi_k)`` of the normalized graph Laplacian. When the sparse
    solver succeeds, only the ``min(100, n - 1)`` smallest eigenpairs are used,
    so the result is a truncated approximation. Otherwise a full dense
    eigendecomposition is used.

    Args:
        data: graph with ``edge_index`` and ``num_nodes``.
        t: diffusion time.

    Returns:
        1-D array of length ``data.num_nodes``; entry ``i`` belongs to node ``i``.
        Graphs with more than 500 nodes are skipped and get all zeros.
    """
    # skip if there's too many nodes
    if data.num_nodes > 500:
        print("Skipped HKS filtration due to size")
        return np.zeros(data.num_nodes)
    if data.num_nodes == 0:
        # networkx cannot build a Laplacian for a graph without nodes.
        return np.zeros(0)

    G = to_networkx(data)
    # Pin row/column order to node ids so diag[i] belongs to node i, whatever
    # order the nodes were inserted into G.
    L = nx.normalized_laplacian_matrix(G, nodelist=list(range(data.num_nodes)))
    try:
        eigvals, eigvecs = scipy.sparse.linalg.eigsh(L, k=min(100, data.num_nodes - 1), which='SM')
    except Exception:
        # eigsh needs 0 < k < n and may fail to converge (e.g. single-node or
        # disconnected graphs), so fall back to a dense solve. toarray() yields
        # an ndarray. todense() returns np.matrix for scipy sparse *matrices*,
        # and then `**` and `*` below would be matrix products, not elementwise.
        eigvals, eigvecs = np.linalg.eigh(L.toarray())
    diag = np.sum((eigvecs ** 2) * np.exp(-t * eigvals)[None, :], axis=1)
    return diag

In [31]:
def build_simplex_tree(data, filtration_values):
    """Build the lower-star filtration of a graph's 1-skeleton as a gudhi SimplexTree.

    Vertex ``i`` enters at ``filtration_values[i]``. Each edge enters at the
    larger of its two endpoint values, i.e. once both endpoints are present.

    Args:
        data: graph whose ``edge_index`` is transposed into (u, v) pairs.
        filtration_values: per-node values, indexable by node id.

    Returns:
        The populated ``SimplexTree``.
    """
    st = SimplexTree()
    # 0-simplices
    for node, value in enumerate(filtration_values):
        st.insert([node], filtration=value)
    # 1-simplices; vertices are passed in ascending order
    for u, v in data.edge_index.t().numpy():
        edge_value = max(filtration_values[u], filtration_values[v])
        st.insert([min(u, v), max(u, v)], filtration=edge_value)
    return st

def compute_betti_curve(st, dim=0, n_bins=100):
    """Vectorise the dimension-``dim`` persistence of ``st`` as a Betti curve.

    Persistence is computed on ``st``. Infinite interval endpoints (essential
    classes) are replaced by one unit past the largest finite value in the
    diagram, so gtda can bin them. The diagram is then handed to a freshly
    fitted ``BettiCurve``.

    Args:
        st: gudhi SimplexTree.
        dim: homology dimension whose intervals are used.
        n_bins: number of sampling points of the Betti curve.

    Returns:
        ``BettiCurve(...).fit_transform([diagram])[0]``. Per gtda's documented
        output layout this is (n_homology_dims, n_bins), presumably (1, n_bins)
        here — confirm.

    NOTE(review): a new BettiCurve is fitted on every single diagram, so its
    sampling grid spans only *this* graph's filtration range. Index j of two
    graphs' curves generally corresponds to different filtration values.
    Fitting one BettiCurve on all diagrams would make features comparable.
    The "+1" padding is also scale-unaware (large relative to PageRank-sized
    values).
    """
    st.compute_persistence()
    intervals = np.array(st.persistence_intervals_in_dimension(dim))

    if intervals.size == 0:
        # No intervals at all: hand gtda an empty (0, 3) diagram.
        diagram = np.empty((0, 3))
    else:
        finite_vals = intervals[np.isfinite(intervals)]
        # With finite vertex filtration values, births are finite, so the 1.0
        # fallback should be unreachable in practice.
        cap = np.max(finite_vals) + 1 if finite_vals.size else 1.0
        intervals[np.isinf(intervals)] = cap
        # gtda diagrams are (birth, death, homology_dimension) triples.
        dims = np.full((intervals.shape[0], 1), fill_value=dim)
        diagram = np.hstack([intervals, dims])

    return BettiCurve(n_bins=n_bins).fit_transform([diagram])[0]


def compute_betti_vector(data, filtration_fn):
    """Featurise one graph: node filtration -> lower-star simplex tree -> H0 Betti curve.

    Uses ``compute_betti_curve``'s default resolution.
    """
    simplex_tree = build_simplex_tree(data, filtration_fn(data))
    return compute_betti_curve(simplex_tree, dim=0)

def get_betti_vectors(dataset, filtration_fn, n_bins=100):
    """Featurise every graph in ``dataset`` as an H0 Betti curve.

    Args:
        dataset: iterable of graphs, each exposing ``edge_index``, ``num_nodes``
            and a single-element label tensor ``y``.
        filtration_fn: maps a graph to its per-node filtration values.
        n_bins: Betti-curve resolution (default 100, the previously fixed value).

    Returns:
        ``(X, y)``. ``X`` stacks one Betti curve per graph with ``np.vstack``;
        ``y`` is a 1-D array of labels. For an empty dataset ``X`` has shape
        ``(0, n_bins)`` and ``y`` has shape ``(0,)``.
    """
    features, labels = [], []
    for data in dataset:
        simplex_tree = build_simplex_tree(data, filtration_fn(data))
        features.append(compute_betti_curve(simplex_tree, dim=0, n_bins=n_bins))
        # .item() works for both 0-d and 1-element label tensors, so every
        # label becomes a plain Python number.
        labels.append(data.y.item())

    if not features:
        return np.empty((0, n_bins)), np.empty((0,), dtype=np.int64)

    # vstack yields a 2-D (num_graphs, n_bins) matrix even if each curve
    # carries a leading homology-dimension axis.
    return np.vstack(features), np.array(labels)

In [32]:
class GCN(torch.nn.Module):
    """Graph classifier: stacked GCNConv layers -> mean pooling -> linear head.

    Args:
        in_channels: node feature width.
        hidden_channels: output width of every GCNConv and of the pooled vector.
        out_channels: number of output logits.
        num_layers: number of GCNConv layers; values below 1 still build one.
        dropout: dropout probability after each inner conv and after pooling.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, dropout=0.5):
        super().__init__()
        # Layer 0 maps in_channels -> hidden_channels; every later layer is
        # hidden -> hidden. At least one layer is always created.
        self.convs = torch.nn.ModuleList(
            GCNConv(hidden_channels if depth else in_channels, hidden_channels)
            for depth in range(max(num_layers, 1))
        )
        self.lin = torch.nn.Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, data):
        """Return unnormalised logits, one row per graph in ``data.batch``."""
        edge_index = data.edge_index
        *inner_convs, final_conv = self.convs
        h = data.x
        # Inner layers: conv -> ReLU -> dropout.
        for conv in inner_convs:
            h = F.dropout(F.relu(conv(h, edge_index)), p=self.dropout, training=self.training)
        # Last layer: conv -> ReLU, then a graph-level mean readout.
        h = F.relu(final_conv(h, edge_index))
        pooled = global_mean_pool(h, data.batch)
        return self.lin(F.dropout(pooled, p=self.dropout, training=self.training))

class GIN(torch.nn.Module):
    """Graph classifier built from GINConv layers, mean-pooled into a linear head.

    Each GINConv wraps a two-layer MLP (Linear -> ReLU -> Linear). The first
    MLP reads ``in_channels`` and the others read ``hidden_channels``; all of
    them output ``hidden_channels``.

    NOTE(review): GIN is usually paired with a sum readout; this model uses
    ``global_mean_pool``. Consider ``global_add_pool`` if graph size is
    informative. ``num_layers < 1`` leaves ``convs`` empty, and ``forward``
    then fails on ``self.convs[-1]``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, dropout=0.5):
        super().__init__()

        def make_mlp(width_in):
            # Update function applied by GINConv after neighbourhood aggregation.
            return torch.nn.Sequential(
                torch.nn.Linear(width_in, hidden_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_channels, hidden_channels),
            )

        self.convs = torch.nn.ModuleList(
            GINConv(make_mlp(in_channels if depth == 0 else hidden_channels))
            for depth in range(num_layers)
        )
        self.lin = torch.nn.Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, data):
        """Return unnormalised logits, one row per graph in ``data.batch``."""
        h = data.x
        for conv in self.convs[:-1]:
            h = conv(h, data.edge_index).relu()
            h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.convs[-1](h, data.edge_index).relu()
        h = global_mean_pool(h, data.batch)
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.lin(h)

class GraphSAGE(torch.nn.Module):
    """Graph classifier: stacked SAGEConv layers -> mean pooling -> linear head.

    Args:
        in_channels: node feature width.
        hidden_channels: output width of every SAGEConv and of the pooled vector.
        out_channels: number of output logits.
        num_layers: number of SAGEConv layers; values below 1 still build one.
        dropout: dropout probability after each inner conv and after pooling.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, dropout=0.5):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for depth in range(max(num_layers, 1)):
            width_in = hidden_channels if depth > 0 else in_channels
            self.convs.append(SAGEConv(width_in, hidden_channels))

        self.lin = torch.nn.Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, data):
        """Return unnormalised logits, one row per graph in ``data.batch``."""
        edge_index = data.edge_index
        *inner_convs, final_conv = self.convs
        h = data.x
        for conv in inner_convs:
            h = conv(h, edge_index).relu()
            h = F.dropout(h, p=self.dropout, training=self.training)
        h = final_conv(h, edge_index).relu()
        graph_repr = F.dropout(global_mean_pool(h, data.batch), p=self.dropout, training=self.training)
        return self.lin(graph_repr)

class GAT(torch.nn.Module):
    """Graph classifier: stacked multi-head GATConv layers -> mean pooling -> linear head.

    The first and any middle layers use ``heads`` attention heads. Their
    outputs are concatenated, giving width ``hidden_channels * heads``. With
    more than one layer, the last GATConv uses a single head, so the pooled
    vector has width ``hidden_channels``. With one layer it keeps
    ``hidden_channels * heads``.

    Args:
        in_channels: node feature width.
        hidden_channels: per-head output width.
        out_channels: number of output logits.
        num_layers: number of GATConv layers; values below 1 still build one.
        heads: attention heads in all but the last layer (of a multi-layer stack).
        dropout: attention dropout inside GATConv and feature dropout between layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, heads=4, dropout=0.6):
        super().__init__()
        stacked = num_layers > 1
        concat_width = hidden_channels * heads
        layers = [GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)]
        layers.extend(
            GATConv(concat_width, hidden_channels, heads=heads, dropout=dropout)
            for _ in range(num_layers - 2)
        )
        if stacked:
            # Single-head output layer gives a clean hidden_channels-wide output.
            layers.append(GATConv(concat_width, hidden_channels, heads=1, dropout=dropout))
        self.convs = torch.nn.ModuleList(layers)
        self.lin = torch.nn.Linear(hidden_channels if stacked else concat_width, out_channels)

        self.dropout = dropout

    def forward(self, data):
        """Return unnormalised logits, one row per graph in ``data.batch``."""
        edge_index = data.edge_index
        *inner_convs, final_conv = self.convs
        h = data.x
        for conv in inner_convs:
            h = F.dropout(F.elu(conv(h, edge_index)), p=self.dropout, training=self.training)
        h = F.elu(final_conv(h, edge_index))
        pooled = F.dropout(global_mean_pool(h, data.batch), p=self.dropout, training=self.training)
        return self.lin(pooled)

In [33]:
def train_gnn(model, loader, optimizer, criterion, device):
    """Run one training epoch and return the loss averaged over graphs.

    Each batch's loss is weighted by ``batch.num_graphs`` and the sum is divided
    by ``len(loader.dataset)``. This is a per-graph mean assuming ``criterion``
    averages over the batch (CrossEntropyLoss does by default).
    """
    model.train()
    loss_sum = 0
    for graphs in loader:
        graphs = graphs.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(graphs), graphs.y)
        batch_loss.backward()
        optimizer.step()
        loss_sum = loss_sum + batch_loss.item() * graphs.num_graphs
    return loss_sum / len(loader.dataset)

def test_gnn(model, loader, device):
    model.eval()
    correct = 0
    for batch in loader:
        batch = batch.to(device)
        out = model(batch)
        pred = out.argmax(dim=1)
        correct += int((pred == batch.y).sum())
    return correct / len(loader.dataset)

In [58]:
class AddIdentityFeatures(BaseTransform):
    """Give featureless graphs a constant one-dimensional node feature.

    Graphs that already have ``x`` are returned untouched. Otherwise ``x`` is
    set to an all-ones tensor of shape ``(num_nodes, 1)``. Despite the name,
    this is a constant feature, not a one-hot identity encoding.

    NOTE(review): with identical node features, mean-aggregating layers (e.g.
    SAGEConv's default) see the same input at every node. That may explain
    the near-chance GNN accuracies on the social datasets. Degree-based
    features (e.g. PyG's OneHotDegree) are worth trying.
    NOTE(review): this overrides ``__call__`` directly and assigns to the
    object it receives. Recent PyG ``BaseTransform.__call__`` shallow-copies
    before calling ``forward`` — confirm skipping that copy is intended.
    """

    def __call__(self, data):
        if data.x is not None:
            return data
        data.x = torch.ones((data.num_nodes, 1))
        return data

Experiment driver: on a single dataset, train XGBoost on Betti-curve features for each filtration, then train each GNN, reporting test accuracy for all of them:

In [59]:
def run_experiment(dataset_name, filtrations, gnn_models, device='cpu', *,
                   epochs=30, hidden_channels=64, lr=0.01, batch_size=32,
                   train_frac=0.8, seed=None):
    """Benchmark Betti-curve + XGBoost baselines against GNNs on one TU dataset.

    Steps:
    - Load the dataset under ``/tmp/<dataset_name>``.
    - Shuffle it and split it into a train prefix and a test suffix.
    - For each filtration, fit an XGBoost classifier on H0 Betti-curve features.
    - Train each GNN with Adam and cross-entropy, printing test accuracy every
      10 epochs and at the final epoch.

    Args:
        dataset_name: TU dataset name, e.g. ``'MUTAG'``.
        filtrations: mapping name -> node filtration function.
        gnn_models: mapping name -> model class, called as
            ``cls(in_channels, hidden_channels, num_classes)``.
        device: device the GNNs are trained on.
        epochs, hidden_channels, lr, batch_size: GNN training hyper-parameters
            (defaults are the previously hard-coded values).
        train_frac: fraction of the shuffled graphs used for training.
        seed: if given, seeds ``random``, NumPy and PyTorch first. The split,
            weight init, dropout and batch order then repeat across runs (GPU
            kernels may still be nondeterministic).

    Returns:
        ``{'xgboost': {filtration_name: test_acc},
        'gnn': {model_name: final-epoch test_acc}}``.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    print(f"Running experiment on dataset: {dataset_name}")
    dataset = TUDataset(root='/tmp/'+dataset_name, name=dataset_name, transform=AddIdentityFeatures())

    # split train/test (80/20 by default)
    dataset = dataset.shuffle()
    n_train = int(train_frac * len(dataset))
    train_dataset = dataset[:n_train]
    test_dataset = dataset[n_train:]

    print(f"Dataset size: {len(dataset)}, train: {len(train_dataset)}, test: {len(test_dataset)}")

    results = {'xgboost': {}, 'gnn': {}}

    # XGBoost on Betti vectors
    for filt_name, filt_fn in filtrations.items():
        print(f"\nComputing Betti vectors with filtration: {filt_name}")
        X_train, y_train = get_betti_vectors(train_dataset, filt_fn)
        X_test, y_test = get_betti_vectors(test_dataset, filt_fn)

        if len(X_train) == 0:
            print(f"Skipping XGBoost for {filt_name} due to empty dataset.")
            continue

        # `use_label_encoder` was deprecated in XGBoost 1.6 and removed in 2.0,
        # where passing it only triggers an "unused parameter" warning.
        clf = xgb.XGBClassifier(eval_metric='logloss')
        clf.fit(X_train, y_train)
        acc = accuracy_score(y_test, clf.predict(X_test))
        print(f"XGBoost accuracy with {filt_name} filtration Betti vectors: {acc:.4f}")
        results['xgboost'][filt_name] = acc

    # GNN training. The loaders do not depend on the model, so build them once.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    for model_name, model_fn in gnn_models.items():
        print(f"\nTraining GNN model: {model_name}")
        num_node_features = dataset[0].x.shape[1]
        model = model_fn(num_node_features, hidden_channels, dataset.num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = torch.nn.CrossEntropyLoss()

        acc = None
        for epoch in range(epochs):
            loss = train_gnn(model, train_loader, optimizer, criterion, device)
            acc = test_gnn(model, test_loader, device)
            if epoch % 10 == 0 or epoch == epochs - 1:
                print(f"Epoch {epoch+1:02d}, Loss: {loss:.4f}, Test Acc: {acc:.4f}")
        results['gnn'][model_name] = acc

    return results

In [61]:
# Node filtrations whose H0 Betti curves feed the XGBoost baseline.
# heat_kernel_filtration is not included; add heat_kernel=heat_kernel_filtration
# to enable it (it returns zeros for graphs with more than 500 nodes).
filtrations = dict(
    degree=degree_filtration,
    clustering_coeff=clustering_coeff_filtration,
    pagerank=pagerank_filtration,
)

# GNN classes, each instantiated by run_experiment as cls(in, hidden, num_classes).
gnn_models = dict(GCN=GCN, GIN=GIN, GraphSAGE=GraphSAGE, GAT=GAT)

# TU benchmark datasets, processed in this order.
datasets = [
    'BZR', 'COX2', 'MUTAG', 'PROTEINS',
    'IMDB-BINARY', 'IMDB-MULTI',
    'REDDIT-BINARY', 'REDDIT-MULTI-5K',
]

# For a quick run, shrink the grid, e.g. datasets = ['REDDIT-BINARY'] and/or
# filtrations = {} to skip the Betti-curve baselines.

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

for ds in datasets:
    run_experiment(ds, filtrations, gnn_models, device=device)

Running experiment on dataset: BZR
Dataset size: 405, train: 324, test: 81

Computing Betti vectors with filtration: degree
XGBoost accuracy with degree filtration Betti vectors: 0.8025

Computing Betti vectors with filtration: clustering_coeff
XGBoost accuracy with clustering_coeff filtration Betti vectors: 0.7531

Computing Betti vectors with filtration: pagerank
XGBoost accuracy with pagerank filtration Betti vectors: 0.8148

Training GNN model: GCN
Epoch 01, Loss: 0.5628, Test Acc: 0.7531
Epoch 11, Loss: 0.4372, Test Acc: 0.7531
Epoch 21, Loss: 0.4127, Test Acc: 0.7654
Epoch 30, Loss: 0.4000, Test Acc: 0.7654

Training GNN model: GIN
Epoch 01, Loss: 0.5780, Test Acc: 0.7531
Epoch 11, Loss: 0.4253, Test Acc: 0.7531
Epoch 21, Loss: 0.4197, Test Acc: 0.7531
Epoch 30, Loss: 0.4341, Test Acc: 0.7531

Training GNN model: GraphSAGE
Epoch 01, Loss: 0.5244, Test Acc: 0.7531
Epoch 11, Loss: 0.4368, Test Acc: 0.7531
Epoch 21, Loss: 0.4018, Test Acc: 0.7531
Epoch 30, Loss: 0.4023, Test Acc: 0.

Downloading https://www.chrsmrrs.com/graphkerneldatasets/IMDB-MULTI.zip
Processing...
Done!


Dataset size: 1500, train: 1200, test: 300

Computing Betti vectors with filtration: degree
XGBoost accuracy with degree filtration Betti vectors: 0.3800

Computing Betti vectors with filtration: clustering_coeff
XGBoost accuracy with clustering_coeff filtration Betti vectors: 0.3033

Computing Betti vectors with filtration: pagerank
XGBoost accuracy with pagerank filtration Betti vectors: 0.4233

Training GNN model: GCN
Epoch 01, Loss: 1.1039, Test Acc: 0.3500
Epoch 11, Loss: 1.0991, Test Acc: 0.3033
Epoch 21, Loss: 1.0990, Test Acc: 0.3500
Epoch 30, Loss: 1.0988, Test Acc: 0.3033

Training GNN model: GIN
Epoch 01, Loss: 1.3705, Test Acc: 0.3033
Epoch 11, Loss: 1.0996, Test Acc: 0.3500
Epoch 21, Loss: 1.1000, Test Acc: 0.3033
Epoch 30, Loss: 1.0995, Test Acc: 0.3033

Training GNN model: GraphSAGE
Epoch 01, Loss: 1.1240, Test Acc: 0.3500
Epoch 11, Loss: 1.0997, Test Acc: 0.3033
Epoch 21, Loss: 1.1002, Test Acc: 0.3033
Epoch 30, Loss: 1.0993, Test Acc: 0.3467

Training GNN model: GAT
Ep

Downloading https://www.chrsmrrs.com/graphkerneldatasets/REDDIT-MULTI-5K.zip
Processing...
Done!


Dataset size: 4999, train: 3999, test: 1000

Computing Betti vectors with filtration: degree
XGBoost accuracy with degree filtration Betti vectors: 0.4350

Computing Betti vectors with filtration: clustering_coeff
XGBoost accuracy with clustering_coeff filtration Betti vectors: 0.4520

Computing Betti vectors with filtration: pagerank
XGBoost accuracy with pagerank filtration Betti vectors: 0.4310

Training GNN model: GCN
Epoch 01, Loss: 1.6128, Test Acc: 0.1960
Epoch 11, Loss: 1.5516, Test Acc: 0.3750
Epoch 21, Loss: 1.5188, Test Acc: 0.4520
Epoch 30, Loss: 1.4943, Test Acc: 0.4060

Training GNN model: GIN
Epoch 01, Loss: 1.6523, Test Acc: 0.1780
Epoch 11, Loss: 1.6103, Test Acc: 0.2000
Epoch 21, Loss: 1.6104, Test Acc: 0.1780
Epoch 30, Loss: 1.6105, Test Acc: 0.1960

Training GNN model: GraphSAGE
Epoch 01, Loss: 1.6174, Test Acc: 0.1780
Epoch 11, Loss: 1.6106, Test Acc: 0.1780
Epoch 21, Loss: 1.6098, Test Acc: 0.1800
Epoch 30, Loss: 1.6099, Test Acc: 0.1810

Training GNN model: GAT
E