# GDL for main graph rewiring and metrics comparison

## Import libraries

In [19]:
# Environment check and imports for the whole notebook.
import sys
import os
from pathlib import Path
# Show which interpreter is executing this notebook.
print(sys.executable)
import torch
import torch_geometric
# Print the library versions so they are recorded next to the results.
print(torch.__version__)
print(torch_geometric.__version__)

import matplotlib.pyplot as plt
from torch_geometric.datasets import TUDataset
# from GraphRicciCurvature.FormanRicci import FormanRicci

import networkx as nx
import numpy as np
from scipy.sparse.csgraph import laplacian
from scipy.linalg import pinv, eigvalsh

# Project-local helpers, pulled in with star imports.
# NOTE(review): names that later cells use without an explicit import
# (e.g. GraphDatasetLoader, GraphMetrics, pd, to_networkx) presumably come
# from these modules or from rewiring.rewiring_call — confirm, and prefer
# explicit imports so the origin of each name is visible.
from utils.load_data import *
from evaluation.metrics import *
from evaluation.metrics_distance import *
from evaluation.curvature import *
from visualization.plots import *
from visualization.networkx_plot import *


/usr/local/bin/python3
2.2.2
2.7.0


# Rewiring our graphs

🚩 We need to follow the experimental hyperparameter settings described on page 15.

## Dynamic Graph Rewiring Method Code

In [None]:
from ipywidgets import widgets, Tab, VBox, Output

# Module-level selections read by the later cells. The observer callbacks
# defined further down in this cell overwrite them whenever a control changes.
dataset_names = ["MUTAG"]
REWIRING_METHOD = "BORF"
ENTIRE_GRAPH = False
metrics = "Normal"

# Shared output area that the callbacks print their confirmation into.
output = Output()

# Each control starts on the matching default above, so the widgets and the
# globals agree before the user touches anything.
data = widgets.Dropdown(
    description="Dataset:",
    options=["REDDIT-BINARY", "IMDB-BINARY", "MUTAG", "ENZYMES", "PROTEINS"],
    value=dataset_names[0],
)

rewired_method = widgets.Dropdown(
    description="Rewiring Method:",
    options=["BORF", "SDRF", "FOSR", "DES", "PPR", "LASER", "UNREWIRED"],
    value=REWIRING_METHOD,
)

entire_graph = widgets.Checkbox(
    description="Entire Graph",
    value=ENTIRE_GRAPH,
)

metric = widgets.Dropdown(
    description="Metrics:",
    options=["Distance", "Normal"],
    value=metrics,
)


def update_dataset(change):
    """Store the newly selected dataset and echo it in the output widget.

    Args:
        change: Change payload handed over by the widget's ``observe`` hook;
            only its ``"new"`` entry is read.
    """
    global dataset_names
    selected = change["new"]
    # Kept as a one-element list, the same shape as the module-level default.
    dataset_names = [selected]
    with output:
        output.clear_output()
        print(f"Dataset Selected: {dataset_names}")

def update_rewiring_method(change):
    """Store the newly selected rewiring method and echo it in the output widget.

    Args:
        change: Change payload handed over by the widget's ``observe`` hook;
            only its ``"new"`` entry is read.
    """
    global REWIRING_METHOD
    selected = change["new"]
    REWIRING_METHOD = selected
    with output:
        # Replace the previous confirmation instead of stacking messages.
        output.clear_output()
        print(f"Rewiring Method Selected: {REWIRING_METHOD}")

def update_entire_graph(change):
    """Store the new "Entire Graph" checkbox state and echo it in the output widget.

    Args:
        change: Change payload handed over by the widget's ``observe`` hook;
            only its ``"new"`` entry is read.
    """
    global ENTIRE_GRAPH
    checked = change["new"]
    ENTIRE_GRAPH = checked
    with output:
        # Replace the previous confirmation instead of stacking messages.
        output.clear_output()
        print(f"Entire Graph Selected: {ENTIRE_GRAPH}")

def update_metrics(change):
    """Store the newly selected metrics mode and echo it in the output widget.

    Args:
        change: Change payload handed over by the widget's ``observe`` hook;
            only its ``"new"`` entry is read.
    """
    global metrics
    selected = change["new"]
    metrics = selected
    with output:
        # Replace the previous confirmation instead of stacking messages.
        output.clear_output()
        print(f"Metrics Selected: {metrics}")


# Hook every control up to its callback; names="value" subscribes to changes
# of the widget's value trait only.
for control, handler in (
    (data, update_dataset),
    (rewired_method, update_rewiring_method),
    (entire_graph, update_entire_graph),
    (metric, update_metrics),
):
    control.observe(handler, names="value")

# One tab per control, each wrapped in its own VBox.
tab_contents = [
    VBox([control]) for control in (data, rewired_method, entire_graph, metric)
]
tab = Tab(children=tab_contents)

# Titles are assigned by position, so they must follow the order of tab_contents.
tab_titles = ["Dataset", "Rewiring Method", "Entire Graph", "Metrics"]
for index, label in enumerate(tab_titles):
    tab.set_title(index, label)

display(tab, output)


Tab(children=(VBox(children=(Dropdown(description='Dataset:', index=2, options=('REDDIT-BINARY', 'IMDB-BINARY'…

Output()

In [25]:

# Echo the selections currently stored in the module-level globals, one per
# line, so the configuration is visible before the rewiring cell is run.
print(
    f"Selected dataset: {dataset_names}",
    f"Selected rewiring method: {REWIRING_METHOD}",
    f"Selected entire graph: {ENTIRE_GRAPH}",
    f"Selected metrics: {metrics}",
    sep="\n",
)

Selected dataset: ['MUTAG']
Selected rewiring method: SDRF
Selected entire graph: True
Selected metrics: Normal


In [24]:
from rewiring.rewiring_call import *
from tqdm import tqdm
import tkinter as tk
from tkinter import ttk, messagebox

# Imported explicitly so this cell does not rely on a star import for `pd`.
import pandas as pd

# Rewire the selected dataset(s) with the selected method, compute graph
# metrics for every rewired graph, and write the results under results/.
#
# Reads the module-level selections dataset_names, REWIRING_METHOD and
# ENTIRE_GRAPH:
#   * ENTIRE_GRAPH False -> only the loader's "first graph" is processed and
#     the per-graph metrics are saved as-is.
#   * ENTIRE_GRAPH True  -> every graph of the dataset is processed and a
#     mean/std summary per metric is saved.

dataset_loader = GraphDatasetLoader(dataset_names)
loaded_datasets = dataset_loader.get_loaded_dataset_names()

# One single-row DataFrame per processed graph.
all_metrics_df = []

for dataset_name in loaded_datasets:
    for rewiring_name in [REWIRING_METHOD]:
        print("Rewiring methods being used:", rewiring_name)
        print(f"\n🚀 Processing dataset: {dataset_name}")

        # For testing purposes a single graph can be processed instead of the
        # full dataset.
        if not ENTIRE_GRAPH:
            graphs = dataset_loader.first_graphs[dataset_name]
        else:
            graphs = dataset_loader.datasets[dataset_name]

        # Iterating a single PyG `Data` object walks its attributes as
        # (name, value) pairs, not graphs. Wrap it so the loop below runs
        # exactly once and always receives the graph itself; otherwise the same
        # graph is processed once per attribute and the UNREWIRED branch is
        # handed a tuple.
        if isinstance(graphs, torch_geometric.data.Data):
            graphs = [graphs]

        for graph in tqdm(graphs):
            rewiring_method = rewiring_call(graph, dataset_name)

            if rewiring_name == "BORF":
                rewired_graph = rewiring_method.borf_rewiring()
            elif rewiring_name == "SDRF":
                rewired_graph = rewiring_method.sdrf_rewiring()
            elif rewiring_name == "FOSR":
                rewired_graph = rewiring_method.fosr_rewiring()
            elif rewiring_name == "LASER":
                rewired_graph = rewiring_method.laser_rewiring()
            elif rewiring_name == "DES":
                rewired_graph = rewiring_method.des_rewiring(dataset_loader)
            elif rewiring_name == "PPR":
                rewired_graph = rewiring_method.ppr_rewiring()
            elif rewiring_name == "UNREWIRED":
                # Baseline: convert the PyG graph to NetworkX without rewiring.
                G_nx = to_networkx(graph, to_undirected=True)
                print("G_nx", G_nx)
                rewired_graph = G_nx
            else:
                raise ValueError(f"Invalid rewiring method: {rewiring_name}")

            # Compute the metrics of the rewired graph and tag the row with
            # the method and dataset it belongs to.
            metrics_rewired = GraphMetrics(rewired_graph, dataset_name)
            df_metrics = pd.DataFrame([metrics_rewired.get_all_metrics()])
            df_metrics["Rewiring Method"] = rewiring_name
            df_metrics["Dataset"] = dataset_name

            all_metrics_df.append(df_metrics)

# pd.concat rejects an empty list with a generic message; fail with one that
# names the actual cause instead.
if not all_metrics_df:
    raise ValueError(
        "No metrics were computed: no dataset was loaded or the selected "
        "dataset contains no graphs."
    )

# All per-graph rows, across every processed dataset.
final_df = pd.concat(all_metrics_df, ignore_index=True)

# Make sure the output directory exists before writing any file.
Path("results").mkdir(parents=True, exist_ok=True)

# Save per dataset: each file name embeds the dataset name, so rows of
# different datasets must not be pooled under the name of the last one.
for dataset_name, dataset_df in final_df.groupby("Dataset", sort=False):
    if ENTIRE_GRAPH:
        # Mean and standard deviation over the numeric metric columns only.
        numeric_cols = dataset_df.select_dtypes(include=["number"])
        avg_metrics = numeric_cols.mean().to_frame(name="Mean")
        std_metrics = numeric_cols.std().to_frame(name="Std")

        # Combine into a single DataFrame indexed by metric name.
        summary_df = pd.concat([avg_metrics, std_metrics], axis=1)

        # Human-readable "mean ± std" column.
        summary_df["Formatted"] = [
            f"{mean:.6f} ± {std:.6f}"
            for mean, std in zip(summary_df["Mean"], summary_df["Std"])
        ]

        summary_output_csv = f"results/rewired_graph_avg_std_metrics_{dataset_name}_{REWIRING_METHOD}.csv"
        # to_csv overwrites an existing file, so no prior removal is needed.
        summary_df.to_csv(summary_output_csv, index=True)

        print(f"\n📂 Summary (Mean & Std) results saved to {summary_output_csv}.")
    else:
        # NOTE(review): this file name does not include the rewiring method,
        # so runs with different methods overwrite each other — confirm that
        # this is intended.
        output_csv = f"results/rewired_graph_metrics_{dataset_name}.csv"

        dataset_df.to_csv(output_csv, index=False)
        print(f"\n📂 All rewiring results saved to {output_csv}.")

✅ Dataset MUTAG already exists. Loading from disk...
✅ Converted 188 graphs from MUTAG into NetworkX format.
Rewiring methods being used: SDRF

🚀 Processing dataset: MUTAG


  0%|          | 0/188 [00:00<?, ?it/s]

🔄 Applying SDRF on MUTAG...


  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge

✅ Rewiring complete! MUTAG now has 20 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 15 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 15 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 23 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 12 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 32 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 17 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 22 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 14 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 19 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 20 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 24 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 26 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 14 edges.
🔄 Appl

  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge_index)
  new_edge_index = torch.tensor(new_edge

✅ Rewiring complete! MUTAG now has 23 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 14 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 29 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 25 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 18 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 12 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 23 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 23 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 23 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 18 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 15 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 11 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 21 edges.
🔄 Applying SDRF on MUTAG...
✅ Rewiring complete! MUTAG now has 12 edges.
🔄 Appl


