In [5]:
import torch
from torch_geometric.data import Data, Dataset
import networkx as nx
import random

def networkx_to_pyg_data(graph, num_node_features=1):
    """Convert a NetworkX graph into a PyTorch Geometric ``Data`` object.

    Args:
        graph: A NetworkX graph. Node labels may be arbitrary hashables; they
            are mapped to consecutive indices ``0..N-1`` in node-iteration
            order.
        num_node_features: Width of the constant (all-ones) node feature
            matrix.

    Returns:
        ``Data`` with ``x`` of shape ``[N, num_node_features]`` (all ones) and
        a ``torch.long`` ``edge_index`` of shape ``[2, E]``. For undirected
        graphs every edge is stored in both directions, which is the
        convention PyG uses to represent undirected connectivity.
    """
    # Tensor indices must be contiguous integers, so do not rely on the raw
    # node labels being exactly 0..N-1.
    node_index = {node: position for position, node in enumerate(graph.nodes())}
    num_nodes = len(node_index)

    pairs = [(node_index[u], node_index[v]) for u, v in graph.edges()]
    if not graph.is_directed():
        # NetworkX lists each undirected edge once; add the reverse direction
        # so messages can flow both ways. Self-loops are not duplicated.
        pairs += [(v, u) for u, v in pairs if u != v]

    # reshape(-1, 2) keeps the result two-dimensional even when there are no
    # edges, so the transposed tensor is [2, 0] rather than [0].
    edge_index = (
        torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )
    x = torch.ones(num_nodes, num_node_features)
    return Data(x=x, edge_index=edge_index)

def generate_random_tree(n, seed=None):
    """Return a random tree on ``n`` nodes as a NetworkX graph.

    Args:
        n: Number of nodes in the tree.
        seed: Optional seed / random state forwarded to NetworkX. ``None``
            (the default) keeps the previous behaviour of using NetworkX's
            default random source.

    Returns:
        A NetworkX graph that is a tree with ``n`` nodes.
    """
    # ``nx.random_tree`` was deprecated in NetworkX 3.2 and is absent from
    # later releases; prefer its replacement when available and fall back to
    # the old name on older installations.
    sampler = getattr(nx, "random_labeled_tree", None) or nx.random_tree
    return sampler(n, seed=seed)

# Generating the trees

MIN_NODES = 3            # smallest tree size sampled
MAX_NODES = 20           # largest tree size sampled (inclusive)
SAMPLES_PER_SIZE = 1000  # random trees drawn per size, before de-duplication
RANDOM_SEED = 0          # fixed so that Restart & Run All reproduces the dataset

# With no explicit seed, NetworkX's generators draw from Python's global
# `random` module, so seeding it here makes the sampling deterministic.
random.seed(RANDOM_SEED)

graph_list = []
for n in range(MIN_NODES, MAX_NODES + 1):  # From MIN_NODES to MAX_NODES nodes
    for _ in range(SAMPLES_PER_SIZE):  # SAMPLES_PER_SIZE draws for each size
        graph_list.append(generate_random_tree(n))

In [4]:
def filter_isomorphic(graphs):
    """Return one representative per isomorphism class, in first-seen order.

    Args:
        graphs: Iterable of NetworkX graphs.

    Returns:
        A list containing, for each isomorphism class present in ``graphs``,
        the first graph of that class encountered.
    """
    unique_graphs = []
    # Isomorphic graphs necessarily share node count, edge count and degree
    # sequence, so a full isomorphism test is only needed against earlier
    # graphs with the same signature instead of against every unique graph.
    buckets = {}
    for graph in graphs:
        signature = (
            graph.number_of_nodes(),
            graph.number_of_edges(),
            tuple(sorted(degree for _, degree in graph.degree())),
        )
        candidates = buckets.setdefault(signature, [])
        if any(nx.is_isomorphic(graph, seen) for seen in candidates):
            continue
        candidates.append(graph)
        unique_graphs.append(graph)
    return unique_graphs

In [3]:
# Keep one representative per isomorphism class. This rebinds `graph_list`, so
# the unfiltered samples are no longer referenced by that name afterwards.
graph_list = filter_isomorphic(graph_list)

726.7222222222222

In [None]:
# Convert every remaining NetworkX tree to a PyG `Data` object, using the
# converter's default number of node features.
pyg_dataset = list(map(networkx_to_pyg_data, graph_list))

class AcyclicGraphDataset(Dataset):
    """PyG ``Dataset`` that serves graphs from an already-built sequence.

    The sequence passed to the constructor is kept as-is on the instance;
    ``len`` and ``get`` simply delegate to it.
    """

    def __init__(self, pyg_dataset):
        """Run the base-class initialiser, then keep a reference to ``pyg_dataset``."""
        super().__init__()
        self.pyg_dataset = pyg_dataset

    def len(self):
        """Return the number of stored graphs."""
        return len(self.pyg_dataset)

    def get(self, idx):
        """Return the stored graph at position ``idx``."""
        return self.pyg_dataset[idx]

# Create the dataset
tree_dataset = AcyclicGraphDataset(pyg_dataset)

# Write the whole dataset object to disk in the current working directory.
# NOTE(review): torch.save pickles the object, so loading this file presumably
# requires the AcyclicGraphDataset class to be defined/importable — confirm
# against whatever code loads 'acyclic_dataset.pt'.
torch.save(tree_dataset, 'acyclic_dataset.pt')