In [3]:
%%capture
!pip install torch
!pip install torch-geometric                      
!pip install ripser                               
!pip install networkx                             
!pip install scikit-learn                         
!pip install gudhi
!pip install giotto-tda
!pip install networkx

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GINConv, SAGEConv, GATConv, global_mean_pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import xgboost as xgb
from gudhi import SimplexTree
import networkx as nx
import gtda.homology as gph
from gtda.diagrams import BettiCurve
import scipy
import scipy.sparse.linalg


In [None]:
# ===== FILTRATION METHODS =====

# helper
def to_networkx(data):
    """Convert a PyG-style graph into an undirected ``networkx.Graph``.

    Every node index ``0 .. data.num_nodes - 1`` is added explicitly before the
    edges. As a result, nodes with no incident edge are kept, and ``G.nodes``
    iterates in node-id order. Node labels are plain Python ``int``s.

    Args:
        data: graph object exposing ``edge_index`` (2 x E tensor of node ids)
            and ``num_nodes``.

    Returns:
        networkx.Graph with ``data.num_nodes`` nodes. Repeated or reversed
        edges collapse into one undirected edge.
    """
    G = nx.Graph()
    # Without this, isolated nodes are silently dropped and node order follows
    # first appearance in edge_index, which misaligns per-node values with ids.
    G.add_nodes_from(range(data.num_nodes))
    src, dst = data.edge_index.cpu().tolist()
    G.add_edges_from(zip(src, dst))
    return G

def degree_filtration(data):
    """Node filtration given by how often each node appears as an edge source.

    Counts occurrences of each node id in ``data.edge_index[0]``. When the edge
    list stores both directions of every undirected edge (assumed here, so
    verify per dataset), this equals the node degree.

    Returns:
        numpy integer array of length ``data.num_nodes``; nodes that never
        appear as a source get 0.
    """
    sources = data.edge_index[0]
    counts = sources.bincount(minlength=data.num_nodes)
    return counts.numpy()


def clustering_coeff_filtration(data):
    """Node filtration given by the local clustering coefficient.

    Nodes absent from the ``nx.clustering`` result default to 0.

    Returns:
        numpy array of length ``data.num_nodes``; entry i belongs to node i.
    """
    coefficients = nx.clustering(to_networkx(data))
    return np.array([coefficients.get(node, 0) for node in range(data.num_nodes)])

def pagerank_filtration(data):
    """Node filtration given by PageRank (networkx default parameters).

    Nodes absent from the ``nx.pagerank`` result default to 0.

    Returns:
        numpy array of length ``data.num_nodes``; entry i belongs to node i.
    """
    scores = nx.pagerank(to_networkx(data))
    ordered = []
    for node in range(data.num_nodes):
        ordered.append(scores.get(node, 0))
    return np.array(ordered)

def heat_kernel_filtration(data, t=1.0, max_eigs=100):
    """Node filtration given by the heat-kernel diagonal of the normalized Laplacian.

    For node v: ``H_t(v, v) = sum_i exp(-t * lam_i) * phi_i(v) ** 2``, where
    ``(lam_i, phi_i)`` are eigenpairs of networkx's normalized Laplacian.

    Args:
        data: graph with ``edge_index`` and ``num_nodes``.
        t: diffusion time.
        max_eigs: graphs with at most ``max_eigs + 1`` nodes use the full dense
            spectrum, which is exact. Larger graphs use only the ``max_eigs``
            smallest eigenpairs from ARPACK. That is an approximation, but small
            eigenvalues dominate because ``exp(-t * lam)`` decays with ``lam``.
            If ARPACK fails, the dense solver is used instead. Must be >= 1 for
            the sparse path.

    Returns:
        numpy array of length ``data.num_nodes``; entry i belongs to node i.
    """
    n = data.num_nodes
    if n == 0:
        # networkx refuses to build a Laplacian for an empty node list.
        return np.zeros(0)
    G = to_networkx(data)
    # Ensure every node 0..n-1 exists, since isolated nodes have no edges.
    # Pin the Laplacian's row order to node ids. Otherwise rows follow graph
    # insertion order and the result can be misaligned or shorter than n.
    G.add_nodes_from(range(n))
    L = nx.normalized_laplacian_matrix(G, nodelist=list(range(n)))
    if n <= max_eigs + 1:
        # Exact and cheap for small graphs. `.toarray()` always gives an ndarray,
        # whereas `.todense()` may give np.matrix, on which `**`/`*` are matrix ops.
        eigvals, eigvecs = np.linalg.eigh(L.toarray())
    else:
        try:
            eigvals, eigvecs = scipy.sparse.linalg.eigsh(L, k=max_eigs, which='SM')
        except scipy.sparse.linalg.ArpackError:
            # Also covers ArpackNoConvergence. Other errors propagate instead of
            # being silently swallowed.
            eigvals, eigvecs = np.linalg.eigh(L.toarray())
    return (eigvecs ** 2) @ np.exp(-t * eigvals)

In [None]:
def build_simplex_tree(data, filtration_values):
    """Build the lower-star filtration of a graph's 1-skeleton as a gudhi SimplexTree.

    Vertex ``i`` enters at ``filtration_values[i]``. Each edge ``(u, v)`` in
    ``data.edge_index`` enters at
    ``max(filtration_values[u], filtration_values[v])``, so no edge appears
    before its endpoints.

    Args:
        data: graph with ``edge_index`` (2 x E tensor of node ids).
        filtration_values: per-node values indexed by node id.

    Returns:
        SimplexTree holding every vertex and edge with its filtration value.
    """
    tree = SimplexTree()
    for vertex in range(len(filtration_values)):
        tree.insert([vertex], filtration=filtration_values[vertex])
    sources, targets = data.edge_index.numpy()
    for u, v in zip(sources, targets):
        tree.insert([u, v], filtration=max(filtration_values[u], filtration_values[v]))
    return tree

def compute_betti_curve(st, dim=0, n_bins=100, filtration_range=None):
    """Sample the Betti-``dim`` curve of a simplex tree's persistence diagram.

    Args:
        st: gudhi ``SimplexTree``, or any object with ``compute_persistence()``
            and ``persistence_intervals_in_dimension(dim)``.
        dim: homology dimension to read.
        n_bins: number of sample points on the curve.
        filtration_range: optional ``(lo, hi)``.
            - When given, the curve is sampled on the fixed grid
              ``np.linspace(lo, hi, n_bins)``, counting intervals with
              ``birth <= s < death``. Infinite intervals stay alive to the end
              of the grid. Pass the same range for every graph so that feature
              column ``j`` means the same filtration value across graphs.
            - When ``None`` (default, original behaviour), giotto-tda's
              ``BettiCurve`` is fit on this single diagram. Its grid then
              depends on this graph's own births/deaths and differs from graph
              to graph.

    Returns:
        2-D array whose row holds the ``n_bins`` curve values; callers flatten
        it. An empty diagram gives ``np.zeros((1, n_bins))``.
    """
    st.compute_persistence()
    diag = np.array(st.persistence_intervals_in_dimension(dim), dtype=float).reshape(-1, 2)
    if diag.shape[0] == 0:
        # No intervals (e.g. a graph with no vertices): the curve is identically 0.
        # Return a fixed-length row so the caller's np.vstack still lines up.
        return np.zeros((1, n_bins))

    if filtration_range is not None:
        lo, hi = filtration_range
        grid = np.linspace(lo, hi, n_bins)
        births = diag[:, [0]]
        deaths = diag[:, [1]]
        alive = (births <= grid) & (grid < deaths)
        return alive.sum(axis=0)[None, :]

    # giotto-tda needs finite deaths, so cap infinite ones just past the last
    # finite value. NOTE(review): the +1 offset is absolute. For filtrations whose
    # values are much smaller than 1 (e.g. PageRank), it stretches the sampling
    # grid far beyond the informative range.
    max_finite = np.max(diag[np.isfinite(diag)])
    diag[np.isinf(diag)] = max_finite + 1
    dim_col = np.full((diag.shape[0], 1), fill_value=dim)
    diag_3d = np.hstack([diag, dim_col])
    bc = BettiCurve(n_bins=n_bins)
    return bc.fit_transform([diag_3d])[0]



def compute_betti_vector(data, filtration_fn):
    """Return the Betti-0 curve of ``data`` under the node filtration ``filtration_fn``.

    ``filtration_fn(data)`` supplies one value per node. Those values build a
    simplex tree, whose dimension-0 persistence is sampled with
    ``compute_betti_curve``'s default bin count.
    """
    simplex_tree = build_simplex_tree(data, filtration_fn(data))
    return compute_betti_curve(simplex_tree, dim=0)

def get_betti_vectors(dataset, filtration_fn):
    """Build a Betti-curve feature matrix and a label vector for a graph dataset.

    Args:
        dataset: iterable of graphs, each with a ``y`` label tensor.
        filtration_fn: node-filtration function passed to ``compute_betti_vector``.

    Returns:
        (X, y):
            - X is 2-D, with one flattened Betti vector per graph.
            - y has one label per graph. Single-element label tensors (shape
              ``[]`` or ``[1]``) are unwrapped to scalars, so ``y`` is 1-D as
              XGBoost/sklearn expect. Multi-element labels are kept as arrays.
            - An empty dataset yields ``(np.empty((0, 0)), np.array([]))``.
    """
    features, labels = [], []
    for data in dataset:
        features.append(np.ravel(compute_betti_vector(data, filtration_fn)))
        # Graph labels may be stored with shape [1] rather than [] (e.g.
        # TUDataset; verify). The old `.dim() == 0` test missed that case and
        # produced an (N, 1) label array.
        labels.append(data.y.item() if data.y.numel() == 1 else data.y.numpy())
    if not features:
        # np.vstack raises on an empty list.
        return np.empty((0, 0)), np.array([])
    return np.vstack(features), np.array(labels)

In [None]:
class GCN(torch.nn.Module):
    """Graph classifier: a stack of ``GCNConv`` layers, mean pooling, then a linear head.

    Args:
        in_channels: size of the input node features.
        hidden_channels: output width of every ``GCNConv`` layer.
        out_channels: number of logits produced per graph.
        num_layers: number of ``GCNConv`` layers; values below 1 still build one.
        dropout: feature-dropout probability between conv layers. It is active
            only in training mode and not applied after the last conv.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, dropout=0.5):
        super().__init__()
        # First layer maps in_channels -> hidden; every later layer is hidden -> hidden.
        input_widths = [in_channels] + [hidden_channels] * max(num_layers - 1, 0)
        self.convs = torch.nn.ModuleList(
            GCNConv(width, hidden_channels) for width in input_widths
        )
        self.lin = torch.nn.Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, data):
        """Return per-graph logits for a (batched) graph ``data``."""
        edge_index = data.edge_index
        *inner_convs, final_conv = self.convs
        h = data.x
        for conv in inner_convs:
            h = F.relu(conv(h, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.relu(final_conv(h, edge_index))
        pooled = global_mean_pool(h, data.batch)
        return self.lin(pooled)

class GIN(torch.nn.Module):
    """Graph classifier built from ``GINConv`` layers, mean pooling and a linear head.

    Each ``GINConv`` wraps a two-layer MLP (Linear -> ReLU -> Linear). The first
    MLP reads ``in_channels`` features; later ones read ``hidden_channels``.

    Args:
        in_channels: size of the input node features.
        hidden_channels: width of every MLP / conv output.
        out_channels: number of logits produced per graph.
        num_layers: number of ``GINConv`` layers (``forward`` needs at least 1).
        dropout: feature-dropout probability between conv layers (training only).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, dropout=0.5):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for depth in range(num_layers):
            width_in = hidden_channels if depth else in_channels
            self.convs.append(GINConv(self._make_mlp(width_in, hidden_channels)))
        self.lin = torch.nn.Linear(hidden_channels, out_channels)
        self.dropout = dropout

    @staticmethod
    def _make_mlp(width_in, width_out):
        """Two-layer MLP used inside each ``GINConv``."""
        return torch.nn.Sequential(
            torch.nn.Linear(width_in, width_out),
            torch.nn.ReLU(),
            torch.nn.Linear(width_out, width_out),
        )

    def forward(self, data):
        """Return per-graph logits for a (batched) graph ``data``."""
        x, edge_index = data.x, data.edge_index
        n_convs = len(self.convs)
        for i in range(n_convs - 1):
            x = self.convs[i](x, edge_index).relu()
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[n_convs - 1](x, edge_index).relu()
        return self.lin(global_mean_pool(x, data.batch))

class GraphSAGE(torch.nn.Module):
    """Graph classifier: ``SAGEConv`` layers, mean pooling, then a linear head.

    Args:
        in_channels: size of the input node features.
        hidden_channels: output width of every ``SAGEConv`` layer.
        out_channels: number of logits produced per graph.
        num_layers: number of ``SAGEConv`` layers; values below 1 still build one.
        dropout: feature-dropout probability between conv layers. It is active
            only in training mode and not applied after the last conv.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, dropout=0.5):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for idx in range(max(num_layers, 1)):
            width_in = in_channels if idx == 0 else hidden_channels
            self.convs.append(SAGEConv(width_in, hidden_channels))
        self.lin = torch.nn.Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, data):
        """Return per-graph logits for a (batched) graph ``data``."""
        edge_index = data.edge_index
        h = data.x
        last = len(self.convs) - 1
        for idx, conv in enumerate(self.convs):
            h = F.relu(conv(h, edge_index))
            if idx < last:
                h = F.dropout(h, p=self.dropout, training=self.training)
        h = global_mean_pool(h, data.batch)
        return self.lin(h)

class GAT(torch.nn.Module):
    """Graph classifier: multi-head ``GATConv`` layers, mean pooling, then a linear head.

    Every ``GATConv`` outputs ``hidden_channels * heads`` features (heads
    concatenated), so every layer after the first, and the head, read that width.

    Args:
        in_channels: size of the input node features.
        hidden_channels: per-head output width.
        out_channels: number of logits produced per graph.
        num_layers: number of ``GATConv`` layers; values below 1 still build one.
        heads: attention heads per layer.
        dropout: used both as the ``GATConv`` attention dropout and as the
            feature dropout between layers (training only; none after the last).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, heads=4, dropout=0.6):
        super().__init__()
        concat_width = hidden_channels * heads
        input_widths = [in_channels] + [concat_width] * (max(num_layers, 1) - 1)
        self.convs = torch.nn.ModuleList()
        for width in input_widths:
            self.convs.append(GATConv(width, hidden_channels, heads=heads, dropout=dropout))
        self.lin = torch.nn.Linear(concat_width, out_channels)
        self.dropout = dropout

    def forward(self, data):
        """Return per-graph logits for a (batched) graph ``data``."""
        edge_index = data.edge_index
        *inner_convs, final_conv = self.convs
        h = data.x
        for conv in inner_convs:
            h = F.elu(conv(h, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.elu(final_conv(h, edge_index))
        return self.lin(global_mean_pool(h, data.batch))

In [None]:
def train_gnn(model, loader, optimizer, criterion, device):
    """Run one training epoch and return the loss averaged over graphs.

    Each batch is moved to ``device``, followed by a forward/backward pass and
    an optimizer step. The batch loss is weighted by ``batch.num_graphs`` and
    the sum divided by ``len(loader.dataset)``. This assumes ``criterion``
    averages over the batch, so the result is a per-graph mean.
    """
    model.train()
    weighted_loss_sum = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(batch), batch.y)
        batch_loss.backward()
        optimizer.step()
        weighted_loss_sum += batch_loss.item() * batch.num_graphs
    return weighted_loss_sum / len(loader.dataset)

def test_gnn(model, loader, device):
    model.eval()
    correct = 0
    for batch in loader:
        batch = batch.to(device)
        out = model(batch)
        pred = out.argmax(dim=1)
        correct += int((pred == batch.y).sum())
    return correct / len(loader.dataset)

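The method below runs the full experiment for one dataset using the filtrations and GNNs defined above. For each filtration it trains XGBoost on Betti vectors, and it trains and evaluates each GNN directly on the graphs: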

In [None]:
def run_experiment(dataset_name, filtrations, gnn_models, device='cpu', root='/tmp',
                   epochs=30, hidden_channels=64, lr=0.01, batch_size=32):
    """Compare Betti-vector + XGBoost baselines with GNNs on one TUDataset.

    The dataset is shuffled and split 80/20 into train/test. Then:
    - for every entry of ``filtrations``, an ``XGBClassifier`` is trained on
      Betti vectors;
    - for every entry of ``gnn_models``, a GNN is trained for ``epochs`` epochs
      with Adam.
    Progress and accuracies are printed.

    Args:
        dataset_name: TUDataset name, e.g. ``'MUTAG'``.
        filtrations: mapping name -> node-filtration function.
        gnn_models: mapping name -> model class, called as
            ``cls(in_channels, hidden_channels, out_channels)``.
        device: torch device used for GNN training.
        root: parent directory for downloads; data goes to ``root/<dataset_name>``.
        epochs, hidden_channels, lr, batch_size: GNN training hyperparameters.
            Defaults equal the previously hard-coded values.

    Returns:
        dict mapping ``'xgb/<filtration>'`` and ``'gnn/<model>'`` to the final
        test accuracy (for GNNs: after the last epoch; NaN if ``epochs == 0``).
    """
    import os
    import torch_geometric.transforms as T

    print(f"Running experiment on dataset: {dataset_name}")
    dataset = TUDataset(root=os.path.join(root, dataset_name), name=dataset_name)
    if dataset.num_node_features == 0:
        # Datasets without node features (e.g. IMDB, REDDIT) still need some x for
        # the GNNs. The old `for data in dataset: data.x = torch.eye(...)` edited
        # per-access copies, not the stored graphs (PyG behaviour; verify for your
        # version). An identity matrix would also give every graph a different
        # feature width, which one first layer cannot accept. A transform runs on
        # every access and is carried over by shuffle()/slicing, giving each graph
        # x = ones(num_nodes, 1).
        dataset.transform = T.Constant(value=1.0)

    # split train/test: 80/20
    dataset = dataset.shuffle()
    split = int(0.8 * len(dataset))
    train_dataset = dataset[:split]
    test_dataset = dataset[split:]

    print(f"Dataset size: {len(dataset)}, train: {len(train_dataset)}, test: {len(test_dataset)}")

    results = {}

    # XGBoost on Betti vectors.
    # `use_label_encoder` is deprecated/removed in recent XGBoost. `logloss` is a
    # binary metric, so use `mlogloss` for multi-class datasets.
    eval_metric = 'mlogloss' if dataset.num_classes > 2 else 'logloss'
    for filt_name, filt_fn in filtrations.items():
        print(f"\nComputing Betti vectors with filtration: {filt_name}")
        X_train, y_train = get_betti_vectors(train_dataset, filt_fn)
        X_test, y_test = get_betti_vectors(test_dataset, filt_fn)

        clf = xgb.XGBClassifier(eval_metric=eval_metric)
        clf.fit(X_train, y_train)
        acc = accuracy_score(y_test, clf.predict(X_test))
        print(f"XGBoost accuracy with {filt_name} filtration Betti vectors: {acc:.4f}")
        results[f"xgb/{filt_name}"] = acc

    # GNN training.
    # NOTE(review): there is no validation split. Test accuracy is reported every
    # epoch, so do not choose epochs/hyperparameters from these numbers.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    for model_name, model_cls in gnn_models.items():
        print(f"\nTraining GNN model: {model_name}")
        model = model_cls(dataset.num_node_features, hidden_channels, dataset.num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = torch.nn.CrossEntropyLoss()

        acc = float('nan')
        for epoch in range(epochs):
            loss = train_gnn(model, train_loader, optimizer, criterion, device)
            acc = test_gnn(model, test_loader, device)
            if epoch % 10 == 0 or epoch == epochs - 1:
                print(f"Epoch {epoch+1:02d}, Loss: {loss:.4f}, Test Acc: {acc:.4f}")
        results[f"gnn/{model_name}"] = acc
    return results

In [None]:
filtrations = {
    'degree': degree_filtration,
    'clustering_coeff': clustering_coeff_filtration,
    'pagerank': pagerank_filtration,
    'heat_kernel': heat_kernel_filtration,
}

gnn_models = {
    'GCN': GCN,
    'GIN': GIN,
    'GraphSAGE': GraphSAGE,
    'GAT': GAT,
}

# Every entry ends with a comma. A missing comma after 'IMDB-MULTI' used to
# concatenate it with the next literal into the nonexistent 'IMDB-MULTIREDDIT-BINARY'.
datasets = [
    'BZR',
    'COX2',
    'MUTAG',
    'PROTEINS',
    'IMDB-BINARY',
    'IMDB-MULTI',
    'REDDIT-BINARY',
    'REDDIT-MULTI-5K',
]

# Seed the global RNGs so the dataset shuffle, weight init and dropout are
# repeatable across Restart & Run All. GPU kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for ds in datasets:
    run_experiment(ds, filtrations, gnn_models, device=device)