In [5]:
import json
import numpy as np
import torch
from torch_geometric.data import Data
import torch_geometric.datasets as datasets


In [6]:


def get_resdist(graph: Data) -> np.ndarray:
    """
    Per-node sorted resistance distances of ``graph``.

    Uses the closed form the original author cites as Theorem E.3 of
    'Rethinking Expressiveness via Biconnectivity': with graph Laplacian ``L``
    and ``M = inv(L + (1/n) * ones((n, n)))``, the resistance distance is
    ``R[i, j] = M[i, i] + M[j, j] - 2 * M[i, j]``.

    Args:
        graph: object exposing ``num_nodes`` and ``edge_index``. ``edge_index``
            may be a torch tensor or a numpy array whose two rows are the
            source and target node ids; every edge is treated as undirected
            and unweighted.

    Returns:
        An ``(n, n)`` float array. Row ``i`` contains the resistance distances
        from node ``i`` to all nodes, rounded to 5 decimals and sorted in
        ascending order -- so column ``j`` does NOT correspond to node ``j``.
        For ``n <= 1`` an all-zero ``(n, n)`` array is returned.

    Raises:
        numpy.linalg.LinAlgError: if ``L + (1/n) * ones`` is singular.
            NOTE(review): presumably this happens for disconnected graphs --
            confirm before feeding such graphs in.
    """
    num_nodes = graph.num_nodes

    # Nothing to invert for the empty / single-node graph.
    if num_nodes <= 1:
        return np.zeros((num_nodes, num_nodes), dtype=float)

    edges = graph.edge_index
    if isinstance(edges, torch.Tensor):
        # Tensors may live on an accelerator or carry autograd state.
        edges = edges.detach().cpu().numpy()

    # Symmetric 0/1 adjacency, filled in one shot via integer-array indexing.
    sources, targets = edges
    adjacency = np.zeros((num_nodes, num_nodes), dtype=float)
    adjacency[sources, targets] = 1
    adjacency[targets, sources] = 1

    # Laplacian L = D - A, shifted by the constant matrix with entries 1/n.
    laplacian = np.diag(adjacency.sum(axis=1)) - adjacency
    shifted = laplacian + np.ones((num_nodes, num_nodes), dtype=float) / num_nodes
    inverse = np.linalg.inv(shifted)

    # R[i, j] = M[i, i] + M[j, j] - 2 * M[i, j], built by broadcasting the diagonal.
    inv_diag = np.diag(inverse)
    distances = (inv_diag[:, None] + inv_diag[None, :]) - 2 * inverse

    # Round away floating-point noise, then sort each row independently.
    rounded = np.round(distances, 5)
    return np.sort(rounded, axis=1)


In [8]:
def convert_edge_index(edge_index_tensor):
    """Return ``edge_index_tensor`` as nested Python lists via its ``tolist()`` method."""
    nested_edges = edge_index_tensor.tolist()
    return nested_edges

def convert_features(x_tensor):
    """
    Convert a node-feature tensor to plain nested Python lists.

    Args:
        x_tensor: object with a ``tolist()`` method (e.g. a torch tensor or
            numpy array), or ``None`` when the graph carries no node features.

    Returns:
        The result of ``x_tensor.tolist()``, or ``None`` if ``x_tensor`` is
        ``None`` (which ``json.dump`` writes as ``null``).
    """
    # A graph without node features would otherwise crash on ``None.tolist()``
    # and abort the export of every other graph.
    if x_tensor is None:
        return None
    return x_tensor.tolist()

def convert_resdist(resdist_array):
    """
    Map each row of ``resdist_array`` to a list, keyed by the row index as a string.

    String keys are used so the mapping can be written directly as a JSON object.
    """
    return {
        str(node_id): distances.tolist()
        for node_id, distances in enumerate(resdist_array)
    }

def compute_resdist(data, output_file='_resdist.json'):
    """
    Export every graph in ``data`` with its resistance distances to a JSON file.

    The file holds one object keyed by the graph's position in ``data`` (as a
    string); each entry has the keys ``'edge_index'``, ``'features'`` and
    ``'resdist'``, the last one mapping node index (as a string) to that
    node's row returned by ``get_resdist``.

    Args:
        data: iterable of graphs exposing ``edge_index``, ``x`` and whatever
            ``get_resdist`` reads from a graph.
        output_file: path of the JSON file to write; overwritten if it exists.

    Returns:
        None. Progress and a completion message are printed to stdout.
    """
    data_dict = {}

    for graph_idx, graph in enumerate(data):
        # Assemble graph data
        data_dict[str(graph_idx)] = {
            'edge_index': convert_edge_index(graph.edge_index),
            'features': convert_features(graph.x),
            'resdist': convert_resdist(get_resdist(graph)),
        }

        # Report progress once every 500 graphs.
        if (graph_idx + 1) % 500 == 0:
            print(f"Processed {graph_idx + 1} graphs")

    # Everything is accumulated in memory and written in a single dump.
    with open(output_file, 'w') as f:
        json.dump(data_dict, f)

    # The f-prefix must sit outside the quotes; the previous form
    # ("f"..."") was a SyntaxError that prevented the cell from being defined.
    print(f"Data successfully written to {output_file}")

    


In [2]:
# Load the train / val / test splits of the ZINC subset from the same root
# directory; the generator is consumed in order, so train is built first.
zinc_train, zinc_val, zinc_test = (
    datasets.ZINC(root="./data/", subset=True, split=split_name)
    for split_name in ("train", "val", "test")
)

# Concatenate the three splits and report the combined number of graphs.
zinc_data = zinc_train + zinc_val + zinc_test
print(len(zinc_data))

# Preview the first five graphs only.
for i, graph in enumerate(zinc_data):
    print(graph)
    if i >= 4:
        break

12000
Data(x=[29, 1], edge_index=[2, 64], edge_attr=[64], y=[1])
Data(x=[26, 1], edge_index=[2, 56], edge_attr=[56], y=[1])
Data(x=[16, 1], edge_index=[2, 34], edge_attr=[34], y=[1])
Data(x=[27, 1], edge_index=[2, 60], edge_attr=[60], y=[1])
Data(x=[21, 1], edge_index=[2, 44], edge_attr=[44], y=[1])


In [22]:
# Export resistance distances for the combined ZINC splits.
compute_resdist(zinc_data, output_file='zinc_resdist.json')

Processed 500 graphs
Processed 1000 graphs
Processed 1500 graphs
Processed 2000 graphs
Processed 2500 graphs
Processed 3000 graphs
Processed 3500 graphs
Processed 4000 graphs
Processed 4500 graphs
Processed 5000 graphs
Processed 5500 graphs
Processed 6000 graphs
Processed 6500 graphs
Processed 7000 graphs
Processed 7500 graphs
Processed 8000 graphs
Processed 8500 graphs
Processed 9000 graphs
Processed 9500 graphs
Processed 10000 graphs
Processed 10500 graphs
Processed 11000 graphs
Processed 11500 graphs
Processed 12000 graphs
Data successfully written to zinc_resdist.json


In [11]:
import pickle

# NOTE(review): unpickling can execute arbitrary code stored in the file --
# only load 'biconn_dataset.pkl' when it comes from a trusted source.
with open('biconn_dataset.pkl', 'rb') as pkl_file:
    loaded_biconn_dataset = pickle.load(pkl_file)


In [12]:
# Export resistance distances for the unpickled biconnectivity dataset.
compute_resdist(loaded_biconn_dataset, output_file='biconn_resdist.json')

Processed 500 graphs
Processed 1000 graphs
Data successfully written to biconn_resdist.json
