In [1]:
import os
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.utils import to_networkx
import numpy as np
import os
import random
import time
from torch_geometric.datasets import TUDataset
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, TopKPooling, GeneralConv
from torch_geometric.nn import MLP
from mag_edge_pool.src.make_splits import make_splits
import json
from torch.nn import PReLU
from torch_geometric.nn import global_add_pool
from mag_edge_pool.src.model import set_seed
from mag_edge_pool.mag_edge_pool import mag_edge_pool_transform

In [None]:
max_nodes = 10000
max_degree = 10000
dataset_name = "MUTAG"
metric = "diffusion_distance"
ratio = 0.5
model_name = "GeneralConv"
splits = "stratified"
early_stop_patience = 50
tolerance = 1e-6
runs=10
learning_rate = 5e-4
patience = early_stop_patience
batch_size = 32
ratio = 0.5
n_hidden = 64
seeds = [41] 
seed = 41

for method in ["MagEdgePool","SpreadEdgePool", "EdgePooling", "TopKPooling",  "SAGPooling", "NoPooling"]:
    print(f"Pooling method: {method}")
    print(f"Dataset: {dataset_name}")
    
    data_path = f"../data/{dataset_name}/{method}/"
    pooling_method = method

    if not os.path.exists(data_path):
        os.makedirs(data_path)

    ### Load the dataset
    if dataset_name in ["DD", "COLLAB", "IMDB-BINARY", "IMDB-MULTI", "REDDIT-BINARY", "REDDIT-MULTI-5K"]:
        dataset_sparse = TUDataset(root=data_path, name=dataset_name, pre_filter=lambda data: data.num_nodes <= max_nodes, transform=T.Compose([T.OneHotDegree(max_degree)]), use_node_attr=True)
    else:
        dataset_sparse = TUDataset(root=data_path, name=dataset_name, pre_filter=lambda data: data.num_nodes <= max_nodes, use_node_attr=True)

    num_classes = dataset_sparse.num_classes
    in_channels = dataset_sparse.num_features
    num_features = dataset_sparse.num_features

    ### Determine the method for magnitude computation based on the pooling method
    if "Mag" in method:
        mag_method = "cholesky"
    elif "Spread" in method:
        mag_method = "spread"
    
    ### Define the GNN using PyTorch Geometric
    class MainModelTorch(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels, num_classes, pool):
            super(MainModelTorch, self).__init__()
            self.pre = MLP(in_channels=in_channels, hidden_channels=hidden_channels, num_layers=2, out_channels=hidden_channels)
            self.conv1 = GeneralConv(hidden_channels, hidden_channels, aggr="add")
            self.bn = PReLU()
            self.bn2 = PReLU()
            self.conv2 = GeneralConv(hidden_channels*2, hidden_channels, aggr="add")
            
            self.post = MLP(in_channels=out_channels*3, hidden_channels=hidden_channels, out_channels=num_classes, num_layers=2)

            self.pool = pool

        def forward(self, data):
            if ("MagEdgePool" in method) or ("SpreadEdgePool" in method):
                cluster = data.cluster
            x, edge_index, batch = data.x, data.edge_index, data.batch
            x = self.pre(x)

            x = self.bn(x)

            gnn1_out = self.conv1(x, edge_index)
            x = torch.cat([gnn1_out, x], dim=-1)

            if ("MagEdgePool" in method) or ("SpreadEdgePool" in method):
                x, edge_index, batch, unpool = self.pool(x, edge_index, batch=batch, cluster=cluster)
            elif method == "TopKPooling" or method == "SAGPooling":
                x, edge_index, _, batch, _, _ = self.pool(x, edge_index, batch=batch)
            else:
                x, edge_index, batch, unpool = self.pool(x, edge_index, batch=batch)
            
            x = self.bn2(x)
            
            gnn2_out = self.conv2(x, edge_index)
            x = torch.cat([gnn2_out, x], dim=-1)
            
            x = global_add_pool(x, batch)
            
            x = self.post(x)
            
            return x

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ### Preprocess the dataset with the specified pooling method
    if ("MagEdgePool" in pooling_method) or ("SpreadEdgePool" in pooling_method):
        data_list = mag_edge_pool_transform(dataset_sparse, pooling_method, ratio, metric, mag_method)

    ### Set random seed for reproducibility and shuffle the dataset
    set_seed(seed)
    if ("MagEdgePool" in pooling_method) or ("SpreadEdgePool" in pooling_method):
        dataset_sparse, perm = dataset_sparse.shuffle(return_perm=True)
        data_list = [data_list[p] for p in perm]
    else:
        dataset_sparse = dataset_sparse.shuffle()
    

    ### Train and evaluate the model across multiple runs using different cross-validation splits
    best_val_accs = []
    best_test_accs = []

    for run in range(runs):
        num_total = len(dataset_sparse)

        labels = np.array([data.y.item() for data in dataset_sparse])

        split_list = make_splits(np.array([int(ni) for ni in range(num_total)]), labels, outer_k=10, inner_k=None, holdout_test_size=0.1, seed=seed)
        idx_tr, idx_va, idx_te = split_list[run][0], split_list[run][1], split_list[run][2]

        if ("MagEdgePool" in pooling_method) or ("SpreadEdgePool" in pooling_method):
            train_dataset = [data_list[i] for i in idx_tr]
            val_dataset = [data_list[i] for i in idx_va]
            test_dataset = [data_list[i] for i in idx_te]
        else:
            train_dataset = dataset_sparse[idx_tr]
            val_dataset = dataset_sparse[idx_va]
            test_dataset = dataset_sparse[idx_te]
            
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        valid_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        if ("MagEdgePool" in pooling_method) or ("SpreadEdgePool" in pooling_method):
            from mag_edge_pool.mag_edge_pool import MagEdgePooling 
            pool = MagEdgePooling(n_hidden*2)
        elif pooling_method == "TopKPooling":
            pool = TopKPooling(n_hidden*2, ratio=0.5)
        elif pooling_method == "EdgePooling":
            from torch_geometric.nn.pool import EdgePooling
            pool = EdgePooling(n_hidden*2)
        elif pooling_method == "SAGPooling":
            from torch_geometric.nn.pool import SAGPooling
            pool = SAGPooling(n_hidden*2, ratio=0.5)
        elif pooling_method == "NoPooling":
            def NoPooling(x, edge_index, batch, cluster=None):
                return x, edge_index, batch, None
            pool = NoPooling
            ratio = 1
        else:
            raise ValueError(f"Not implemented yet: {pooling_method}")

        model = MainModelTorch(in_channels=dataset_sparse.num_features, hidden_channels=n_hidden, out_channels=n_hidden, num_classes=dataset_sparse.num_classes, pool=pool).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
        criterion = torch.nn.CrossEntropyLoss()

        def train():
            model.train()
            total_loss = 0
            for data in train_loader:
                data = data.to(device)
                optimizer.zero_grad()
                out = model(data)
                loss = F.nll_loss(out, data.y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item() * data.num_graphs
            return total_loss / len(train_loader.dataset)

        def test(loader):
            model.eval()
            correct = 0
            for data in loader:
                data = data.to(device)
                out = model(data)
                pred = out.argmax(dim=1)
                correct += (pred == data.y).sum().item()
            return correct / len(loader.dataset)

        model = MainModelTorch(in_channels=dataset_sparse.num_features, hidden_channels=n_hidden, out_channels=n_hidden, num_classes=dataset_sparse.num_classes, pool=pool).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
        
        best_val_acc = 0
        epochs_no_improve = 0
        for epoch in range(1, 201):
            loss = train()
            val_acc = test(valid_loader)
            test_acc = test(test_loader)
            if val_acc > best_val_acc + tolerance:
                best_val_acc = val_acc
                best_test_acc = test_acc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
            if epochs_no_improve >= early_stop_patience:
                print(f'Early stopping at epoch {epoch} for seed {seed}')
                break
        
        best_val_accs.append(best_val_acc)
        best_test_accs.append(best_test_acc)
        torch.cuda.empty_cache()
    
    print(f'Average Best Val Acc: {np.mean(best_val_accs):.4f}')
    print(f'Std Best Test Acc: {np.std(best_test_accs):.4f}')
    print(f'Average Test Acc: {np.mean(best_test_accs):.4f}')

    ### Save experiment data to a JSON file
    experiment_data = {
        "dataset": dataset_name,
        "method": method,
        "model": model_name,
        "experiment": "graph_classification",
        "runs": runs,
        "split_strategy": "stratified",
        "learning_rate": learning_rate,
        "es_patience": patience,
        "batch_size": batch_size,
        "seeds": seeds,
        "loss": "categorical_crossentropy",
        "results": {
            "accuracy": [float(r) for r in best_test_accs]
        },
        "ratio": ratio
    }

    log_dir = f"./results/{dataset_name}/"
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    json_file_path = os.path.join(log_dir, f"{model_name}_{method}_{dataset_name}_{splits}_{seed}.json")

    with open(json_file_path, "w") as json_file:
        json.dump(experiment_data, json_file, indent=4)
    print(f"Experiment data saved to {json_file_path}")
    print("-----------------------------------------------------")

Pooling method: MagEdgePool
Dataset: MUTAG
Early stopping at epoch 52 for seed 41
Early stopping at epoch 89 for seed 41
Early stopping at epoch 53 for seed 41
Early stopping at epoch 53 for seed 41
Early stopping at epoch 123 for seed 41
Early stopping at epoch 54 for seed 41
Early stopping at epoch 89 for seed 41
Early stopping at epoch 52 for seed 41
Early stopping at epoch 54 for seed 41
Early stopping at epoch 53 for seed 41
Total time for 10 runs: 66.22 seconds
Average Best Val Acc: 0.9118
Std Best Test Acc: 0.0802
Average Test Acc: 0.8673
Experiment data saved to ./results/MUTAG/GeneralConv_MagEdgePool_MUTAG_stratified_41.json
-----------------------------------------------------
Pooling method: SpreadEdgePool
Dataset: MUTAG
Early stopping at epoch 52 for seed 41
Early stopping at epoch 89 for seed 41
Early stopping at epoch 53 for seed 41
Early stopping at epoch 53 for seed 41
Early stopping at epoch 123 for seed 41
Early stopping at epoch 54 for seed 41
Early stopping at epoch