In [1]:
import os
import random
import math
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import deque
from typing import List, Optional, Tuple, Union
from itertools import combinations

#torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Subset, Dataset, TensorDataset
from torch.utils.data import DataLoader as RegularDataLoader

# PyTorch Geometric
import torch_geometric.transforms as T
from torch_geometric.data import Data, HeteroData, Batch
from torch_geometric.loader import DataLoader 
from torch_geometric.datasets import TUDataset



# GNN Layers
from torch_geometric.nn import (
    GCNConv, GATConv, SAGEConv, GINConv, HeteroConv,
    global_mean_pool, global_max_pool, global_add_pool,
    DMoNPooling
)

# Graph Utilities
from torch_geometric.utils import (
    k_hop_subgraph, subgraph, from_networkx,
    to_networkx, to_dense_adj, to_dense_batch,
    degree
)

#sk-learn
from sklearn.linear_model import Ridge
from sklearn.multioutput import MultiOutputRegressor
from sklearn.model_selection import cross_val_score, StratifiedKFold, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Dense Graph Utilities
from torch_geometric.nn.dense.mincut_pool import _rank3_trace

# === Visualization & Analysis ===
import networkx as nx
from sklearn.manifold import TSNE
from sklearn.model_selection import StratifiedKFold, KFold
import torchviz
from tqdm import tqdm


In [2]:
# Run on the GPU when torch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

cuda


In [None]:
# TU benchmark to load. Other names that have been tried here:
#   'ENZYMES', 'MUTAG', 'REDDIT-BINARY', 'IMDB-BINARY', 'PROTEINS', 'DD'
dataset_name = 'NCI1'

# The dataset is constructed the same way whether or not a processed copy is
# already on disk; the existence check only selects the message that is printed.
dataset_root = "data/TUDataset"
dataset_path = f"{dataset_root}/{dataset_name}"
processed_file = os.path.join(dataset_path, "processed", "data.pt")
if os.path.exists(processed_file):
    print("Loading cached dataset.")
else:
    print("Dataset not found — downloading + processing.")
dataset = TUDataset(root=dataset_root, name=dataset_name)

# Clamp to at least one input feature — presumably so that feature-less
# datasets line up with the all-ones column the models substitute when
# `data.x` is None (TODO confirm).
num_features = max(dataset.num_features, 1)
num_classes = dataset.num_classes

Loading cached dataset.


# Experiment 1: Unsupervised Embedding Evaluation

This is a very simple experiment: we run an untrained GIN, extract its graph embeddings, and apply the linear evaluation protocol that many "GCL-style" papers use.

In [87]:
# This is the architecture that https://github.com/sunfanyunn/InfoGraph/blob/master/unsupervised/gin.py uses (and 99% of graph-classification-based GCL methods)
class GINBlock(nn.Module):
    """
    Single GIN layer.

    A GINConv wraps a two-layer MLP (Linear → ReLU → Linear); its output is
    passed through ReLU and then BatchNorm1d, in that order.

    Args:
        in_features: width of the incoming node features.
        hidden_dim: width of both MLP layers and of the block's output.
    """
    def __init__(self, in_features, hidden_dim):
        super().__init__()
        mlp_layers = [
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.mlp = nn.Sequential(*mlp_layers)
        self.gin = GINConv(self.mlp)
        self.bn = nn.BatchNorm1d(hidden_dim)

    def forward(self, x, edge_index):
        """Return the node features after GINConv → ReLU → BatchNorm."""
        aggregated = self.gin(x, edge_index)
        activated = F.relu(aggregated)
        return self.bn(activated)

class Encoder(nn.Module):
    """
    Stack of GINBlocks producing graph-level embeddings.

    Every block's node features are sum-pooled per graph and the pooled
    vectors of all blocks are concatenated, so the embedding width is
    `dim * num_layers`.

    Args:
        num_features: width of the input node features.
        dim: hidden width of every GINBlock.
        num_layers: number of stacked GINBlocks.
    """
    def __init__(self, num_features, dim, num_layers):
        super().__init__()
        self.blocks = nn.ModuleList()
        for i in range(num_layers):
            in_dim = num_features if i == 0 else dim
            self.blocks.append(GINBlock(in_dim, dim))

    def forward(self, x, edge_index, batch):
        """
        Args:
            x: node feature matrix, or None for graphs without node features.
            edge_index: edge list handed to every GINBlock.
            batch: per-node graph assignment used for pooling.

        Returns:
            (graph_embeddings, node_features_of_last_block)
        """
        if x is None:
            # Constant single-column input, created as (N, 1) on the same
            # device as `batch` so it matches the fallback in get_embeddings
            # and does not depend on a notebook-level `device` variable.
            x = torch.ones((batch.shape[0], 1), device=batch.device)

        xs = []
        for block in self.blocks:
            x = block(x, edge_index)       # GINConv → ReLU → BatchNorm
            xs.append(x)

        pooled = [global_add_pool(h, batch) for h in xs]
        return torch.cat(pooled, dim=1), x

    def get_embeddings(self, loader):
        """
        Embed every graph yielded by `loader` without tracking gradients.

        The train/eval mode of the module is left as the caller set it.

        Returns:
            (embeddings, labels) as NumPy arrays concatenated over all batches.
        """
        # Run on the device the weights live on; fall back to the original
        # CUDA-if-available choice only when the module has no parameters.
        first_param = next(self.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        ret = []
        y = []
        with torch.no_grad():
            for data in loader:
                data = data.to(device)
                x, edge_index, batch = data.x, data.edge_index, data.batch
                if x is None:
                    x = torch.ones((batch.shape[0], 1), device=device)
                x, _ = self.forward(x, edge_index, batch)
                ret.append(x.cpu().numpy())
                y.append(data.y.cpu().numpy())
        ret = np.concatenate(ret, 0)
        y = np.concatenate(y, 0)
        return ret, y


### A barebones way of doing this: (initializes a new random GNN each iteration)

In [61]:
# Hyper-parameters for the random-GIN embedding experiment.
batch_size = 64          # graphs per mini-batch handed to the DataLoader
hidden_dim = 32          # hidden width passed to the Encoder
num_pooling_layers = 3   # number of GIN blocks stacked in the Encoder

In [62]:
# Evaluate `num_trials` freshly initialised (untrained) encoders.
num_trials = 5
loader = DataLoader(dataset, batch_size=batch_size)
acc_vals, seeds = [], []

# NOTE(review): `evaluate_embedding` is not defined or imported in the cells
# visible here — it has to exist in the kernel namespace already.
for _trial in range(num_trials):
    # Draw a three-digit seed, then reseed every RNG with it before the
    # encoder is built; the seed is recorded so it can be reported afterwards.
    seed = random.randint(100, 999)
    for seed_fn in (random.seed, np.random.seed,
                    torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    seeds.append(seed)

    model = Encoder(num_features, hidden_dim, num_pooling_layers).to(device)
    model.eval()
    emb, y = model.get_embeddings(loader)
    acc, acc_val = evaluate_embedding(emb, y, device=device, search=True)
    acc_vals.append(acc_val)

In [63]:
# Mean and spread of the per-trial accuracies, reported in percent.
mean_acc_pct = np.mean(acc_vals) * 100
std_acc_pct = np.std(acc_vals) * 100
print(f"Dataset: {dataset_name}")
print(f"Classification Accuracy: {mean_acc_pct:.2f} ±  {std_acc_pct:.2f}")
print(seeds)

Dataset: MUTAG
Classification Accuracy: 88.85 ±  1.31
[276, 481, 794, 415, 893]


## Experiment 1.1: MLP benchmark — we borrow code from "A Fair Comparison of Graph Neural Networks for Graph Classification": https://github.com/diningphil/gnn-comparison/blob/master/models/graph_classifiers/MolecularFingerprint.py

#### ONLY WORKS FOR MOLECULAR DATASETS

In [93]:
class MolecularFingerprint(torch.nn.Module):
    """
    Structure-agnostic baseline: sum the node features of each graph and feed
    the result through a two-layer ReLU MLP. No edges are used.

    Args:
        dim_features: width of the input node features.
        hidden_dim: width of both MLP layers (and of the output embedding).
    """

    def __init__(self, dim_features, hidden_dim):
        super(MolecularFingerprint, self).__init__()

        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim_features, hidden_dim), nn.ReLU(),
                                       torch.nn.Linear(hidden_dim, hidden_dim), nn.ReLU())

    def forward(self, data):
        """Return one `hidden_dim`-wide embedding per graph in `data`."""
        x, batch = data.x, data.batch
        if x is None:
            # Constant single-column input. It must be 2-D so the pooled
            # result keeps a feature axis for the first Linear layer, and it
            # is created on the same device as `batch`.
            x = torch.ones((batch.shape[0], 1), device=batch.device)

        return self.mlp(global_add_pool(x, batch))

In [None]:
def extract_embeddings_simple(model, loader):
    """
    Run `model` over every batch of `loader` and collect outputs and labels.

    Each batch is moved to the device holding the model's parameters before
    the forward pass, and results are brought back to the CPU before the
    NumPy conversion, so this also works when the model lives on a GPU.
    Gradients are not tracked.

    Args:
        model: callable taking one batch and returning a tensor of embeddings.
        loader: iterable of batches exposing a `.y` label tensor.

    Returns:
        (emb, y): NumPy arrays concatenated along axis 0 over all batches.
    """
    # Work out where the model's weights are; a model without parameters
    # (or a plain callable) leaves the batches where they already are.
    params = model.parameters() if hasattr(model, "parameters") else ()
    first_param = next(iter(params), None)
    target_device = first_param.device if first_param is not None else None

    ret = []
    y = []

    with torch.no_grad():
        for data in loader:
            if target_device is not None:
                data = data.to(target_device)
            y.append(data.y.detach().cpu().numpy())
            ret.append(model(data).detach().cpu().numpy())

    emb = np.concatenate(ret, 0)
    y = np.concatenate(y, 0)

    return emb, y
    

In [97]:
# Same protocol as the random-GIN trials, with the sum-pool + MLP baseline.
batch_size = 64
num_trials = 5
loader = DataLoader(dataset, batch_size=batch_size)
acc_vals, seeds = [], []

# NOTE(review): `evaluate_embedding` is not defined or imported in the cells
# visible here — it has to exist in the kernel namespace already.
for _trial in range(num_trials):
    # Draw a three-digit seed and reseed every RNG with it before the model
    # is built; the seed is recorded so it can be reported afterwards.
    seed = random.randint(100, 999)
    for seed_fn in (random.seed, np.random.seed,
                    torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    seeds.append(seed)

    model = MolecularFingerprint(num_features, hidden_dim=256).to(device)
    model.eval()
    emb, y = extract_embeddings_simple(model, loader)
    acc, acc_val = evaluate_embedding(emb, y, device=device, search=True)
    acc_vals.append(acc_val)

In [98]:
# Mean and spread of the per-trial accuracies, reported in percent.
mean_acc_pct, std_acc_pct = np.mean(acc_vals) * 100, np.std(acc_vals) * 100
print(f"Dataset: {dataset_name}")
print(f"Classification Accuracy: {mean_acc_pct:.2f} ±  {std_acc_pct:.2f}")
print(seeds)

Dataset: NCI1
Classification Accuracy: 69.43 ±  0.10
[467, 358, 883, 669, 409]


# Experiment 2: Trivial Statistics

A very simple experiment - a powerful GNN can easily extract powerful features that correlate strongly with the feature statistics of graphs!

A classifier trained on these features alone can reach graph classification accuracy comparable to SOTA!

In [11]:
# Per-graph "trivial" statistics: [node count, mean degree, 6-bin degree histogram].
stats = []
for data in dataset:
    n = data.num_nodes
    e = data.edge_index.size(1)
    degs = degree(data.edge_index[0], num_nodes=n).cpu().numpy()
    # `e` counts the columns of edge_index and `degs` is computed from that
    # same column list, so the mean of `degs` is e / n. The former 2 * e / n
    # counted every entry twice.
    avg_deg = e / n
    # Degree buckets 0, 1, 2, 3, 4 and "5 or more".
    hist, _ = np.histogram(degs, bins=[0, 1, 2, 3, 4, 5, np.inf])

    stats.append([n, avg_deg] + hist.tolist())


In [32]:
# Column 0 of the statistics table holds the per-graph node count.
stats_table = np.array(stats)
stats1 = stats_table[:, 0]

In [12]:
# Freeze the collected statistics into one 2-D array (one row per graph)
# and gather each graph's class label as a plain Python scalar.
stats = np.array(stats)
labels_list = [graph.y.item() for graph in dataset]

In [13]:
def eval_trivial_baseline(stats_s, labs_s, seed):
    """
    Nested cross-validated accuracy of a standardised SVC on `stats_s`.

    Outer loop: 10-fold stratified, shuffled with `seed`.
    Inner loop: 5-fold grid search over the SVC's `C`, refit on each outer
    training split and scored on the matching held-out split.

    Args:
        stats_s: 2-D feature array, one row per graph.
        labs_s: class labels aligned with the rows of `stats_s`.
        seed: seeds the global `random` / NumPy RNGs and the outer split.

    Returns:
        Mean held-out accuracy over the 10 outer folds.
    """
    random.seed(seed)
    np.random.seed(seed)

    labels_arr = np.asarray(labs_s)
    outer_cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
    estimator = Pipeline([('scaler', StandardScaler()), ('svc', SVC())])
    c_grid = {'svc__C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}

    fold_accs = []
    for train_idx, test_idx in outer_cv.split(stats_s, labs_s):
        search = GridSearchCV(estimator, c_grid, cv=5, n_jobs=-1)
        search.fit(stats_s[train_idx], labels_arr[train_idx])
        predictions = search.predict(stats_s[test_idx])
        fold_accs.append(accuracy_score(labels_arr[test_idx], predictions))
    return np.mean(fold_accs)


In [17]:
# Per-graph "trivial" statistics: [node count, mean degree, 6-bin degree histogram].
stats = []
for data in dataset:
    # No networkx conversion here: none of the statistics below use it, and
    # building one graph object per sample only slowed this loop down.
    n = data.num_nodes
    e = data.edge_index.size(1)
    degs = degree(data.edge_index[0], num_nodes=n).cpu().numpy()
    # `e` counts the columns of edge_index and `degs` is computed from that
    # same column list, so the mean of `degs` is e / n. The former 2 * e / n
    # counted every entry twice.
    avg_deg = e / n
    # Degree buckets 0, 1, 2, 3, 4 and "5 or more".
    hist, _ = np.histogram(degs, bins=[0, 1, 2, 3, 4, 5, np.inf])

    stats.append([n, avg_deg] + hist.tolist())

stats = np.array(stats)
labels = np.array([d.y.item() for d in dataset])

# Column subsets of `stats` to ablate. Layout of a row:
#   [0] node count, [1] average degree, [2:] degree-histogram bins.
hist_cols = list(range(2, 2 + len(hist)))
all_cols = list(range(stats.shape[1]))
feature_sets = {
    "node_count": [0],
    "avg_degree": [1],
    "deg_hist": hist_cols,
    "all_trivial": all_cols,
    "random": None,  # no columns: eval_set substitutes Gaussian noise
}

def eval_set(name, X, y, seed):
    """
    Score one entry of `feature_sets` with `eval_trivial_baseline`.

    For "random", X is replaced by seeded standard-normal noise of the same
    shape; for every other name, the listed columns of X are selected.
    """
    if name != "random":
        features = X[:, feature_sets[name]]
    else:
        features = np.random.RandomState(seed).randn(*X.shape)
    return eval_trivial_baseline(features, y, seed)

# Ablations: score every feature subset once per seed, each time on a
# seed-specific shuffle of the graphs, and keep the mean / std over seeds.
results = {}
for name in feature_sets:
    accs = []
    for seed in seeds:
        random.seed(seed)
        np.random.seed(seed)
        idx = np.random.permutation(stats.shape[0])
        shuffled_stats, shuffled_labels = stats[idx], labels[idx]
        accs.append(eval_set(name, shuffled_stats, shuffled_labels, seed))
    results[name] = (np.mean(accs), np.std(accs))

# Tabulate: one line per feature subset, accuracy as mean ± std in percent.
for name, (mean, std) in results.items():
    mean_pct, std_pct = mean * 100, std * 100
    print(f"{name:15s} → {mean_pct:.2f}% ± {std_pct:.2f}%")


node_count      → 62.65% ± 0.18%
avg_degree      → 55.91% ± 0.20%
deg_hist        → 68.24% ± 0.36%
all_trivial     → 68.50% ± 0.47%
random          → 49.18% ± 1.06%


In [18]:
# Show the seed list the ablation loop iterated over. `seeds` is the list
# filled by the Experiment 1 trial loops, so its contents depend on which of
# those cells ran last.
print(seeds)

[237, 206, 982, 445, 207]
