In [1]:
import os
import pickle
import networkx as nx
from tqdm import tqdm
import torch
import numpy as np

from torch_geometric.datasets import TUDataset
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx, from_networkx, to_dense_adj
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def generate_shortest_path_graph(
    num_nodes: int,
    topology: str = "complete",
    rng: "np.random.Generator | None" = None,
) -> Data:
    """Build one synthetic sample for the shortest-path task.

    A graph with the requested topology is created, two distinct nodes are
    drawn at random and flagged through the node attribute ``feature``
    (1 for the two selected nodes, 0 for every other node), and the
    shortest-path length between the two selected nodes is stored as the
    graph label ``y``.

    Args:
        num_nodes: Number of nodes in the graph. Must be at least 2 so that
            two distinct nodes can be selected.
        topology: Name of the graph topology. Only ``"complete"`` is
            currently supported.
        rng: Optional NumPy random generator used to pick the two relevant
            nodes. When omitted, the global ``np.random`` state is used, so
            ``np.random.seed(...)`` still controls the result.

    Returns:
        The graph converted with ``from_networkx``, with the shortest-path
        length (a plain Python int) attached as ``y``.

    Raises:
        AssertionError: If ``num_nodes`` is not positive or ``topology`` is
            not a supported name.
        ValueError: If ``num_nodes`` is 1, because two distinct nodes cannot
            be sampled.
    """
    assert num_nodes > 0
    assert topology in ["complete"], "Error: unknown topology"  # Extend this list for other topologies
    if num_nodes < 2:
        raise ValueError("num_nodes must be at least 2 to select two distinct relevant nodes")

    # Build the graph directly with networkx so this function does not depend
    # on helper cells that are defined further down in the notebook (which
    # would raise NameError under "Restart Kernel -> Run All").
    if topology == "complete":
        raw_graph = nx.complete_graph(num_nodes)

    # Sample from an explicit list of node ids; `.tolist()` turns the NumPy
    # scalars back into native Python objects that match the graph's node keys.
    sampler = np.random if rng is None else rng
    node_ids = list(raw_graph.nodes)
    source, target = sampler.choice(node_ids, size=2, replace=False).tolist()
    relevant_nodes = {source, target}

    # Node feature: 1 for the two relevant nodes, 0 for all others.
    for node in raw_graph.nodes:
        raw_graph.nodes[node]['feature'] = 1 if node in relevant_nodes else 0

    # Convert the NetworkX graph to PyTorch Geometric's Data format.
    attributed_graph = from_networkx(raw_graph)

    # Graph label: hop distance between the two relevant nodes.
    # NOTE: on a complete graph any two distinct nodes are adjacent, so this
    # label is always 1 for the "complete" topology.
    attributed_graph.y = nx.shortest_path_length(raw_graph, source=source, target=target)

    return attributed_graph

In [4]:
# shortest path task on complete graphs

SEED = 42                      # fixed seed so the generated dataset is reproducible
NUM_GRAPHS = 1000              # number of graphs in the dataset
MIN_NODES, MAX_NODES = 5, 50   # inclusive bounds on the graph size

# Seed the global NumPy RNG: both the size sampling below and the node
# selection inside generate_shortest_path_graph draw from it.
np.random.seed(SEED)

# randint's upper bound is exclusive, hence MAX_NODES + 1.
random_integers = np.random.randint(MIN_NODES, MAX_NODES + 1, size=NUM_GRAPHS)
complete_graphs = [generate_shortest_path_graph(num_nodes=int(nodes)) for nodes in random_integers]

In [5]:
file_path = "synthetic_data/shortest_path_task/complete_graphs.pkl"

# Create the output directory if needed so the cell also works on a fresh
# checkout instead of failing with FileNotFoundError.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# Serialize the list of generated graphs. Only unpickle this file from a
# trusted source: pickle.load can execute arbitrary code.
with open(file_path, 'wb') as f:
    pickle.dump(complete_graphs, f)

In [3]:
# topologies

def create_complete_graph(num_nodes: int) -> nx.Graph:
    """Return an undirected complete graph with ``num_nodes`` nodes.

    Args:
        num_nodes: Number of nodes; passed straight to ``nx.complete_graph``.

    Returns:
        A new ``nx.Graph`` in which every pair of distinct nodes is connected.
    """
    # nx.complete_graph already builds an undirected Graph by default, so no
    # extra `.to_undirected()` copy is required.
    return nx.complete_graph(num_nodes)