# In [None]:
import torch
from torch_geometric.data import Data, Dataset
import networkx as nx
import random

def networkx_to_pyg_data(graph, num_node_features=1):
    """Convert a NetworkX graph into a PyTorch Geometric ``Data`` object.

    Every node gets the same constant feature vector of ones, so the result
    carries graph structure only.

    Args:
        graph: NetworkX graph. Its node labels are written into
            ``edge_index`` as-is, so they must be integers; they are
            presumably expected to be ``0 .. number_of_nodes() - 1`` so that
            they line up with the rows of ``x`` -- verify for graphs that do
            not come from ``generate_random_tree``.
        num_node_features: Width of the all-ones feature vector per node.

    Returns:
        ``Data`` whose ``x`` has shape ``(num_nodes, num_node_features)`` and
        whose ``edge_index`` has shape ``(2, num_edges)``. Each edge appears
        once, in the orientation reported by ``graph.edges()``; reverse edges
        are not added.
    """
    edges = list(graph.edges())
    if edges:
        # (num_edges, 2) -> (2, num_edges), the layout used for edge_index.
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # An empty edge list would otherwise become a 1-D tensor of shape
        # (0,), which transposing cannot turn into the (2, 0) layout.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    num_nodes = graph.number_of_nodes()
    x = torch.ones(num_nodes, num_node_features)
    return Data(x=x, edge_index=edge_index)

# Function to generate a random tree
def generate_random_tree(n, seed=None):
    """Return a random labeled tree with ``n`` nodes as a NetworkX graph.

    Args:
        n: Number of nodes in the tree.
        seed: Optional random seed / random state forwarded to NetworkX so a
            run can be reproduced. ``None`` keeps the previous unseeded
            behavior.

    Returns:
        The tree produced by NetworkX's random tree generator.
    """
    # ``nx.random_tree`` was deprecated in NetworkX 3.2 and removed in 3.4 in
    # favor of ``nx.random_labeled_tree``. Prefer the new name and fall back
    # to the old one so the function works on both sides of that change.
    make_tree = getattr(nx, "random_labeled_tree", None)
    if make_tree is None:
        make_tree = nx.random_tree
    return make_tree(n, seed=seed)

# Build the raw list of graphs: for every tree size from 3 to 20 nodes,
# draw 5 random trees and convert each one to a PyG Data object.
pyg_dataset = []
for n in range(3, 21):
    pyg_dataset.extend(
        networkx_to_pyg_data(generate_random_tree(n)) for _ in range(5)
    )

class AcyclicGraphDataset(Dataset):
    """Dataset subclass that serves graphs from an already-built collection."""

    def __init__(self, pyg_dataset):
        """Store ``pyg_dataset``, the indexable collection of graphs to serve."""
        super().__init__()
        self.pyg_dataset = pyg_dataset

    def len(self):
        """Return the number of graphs held by the dataset."""
        graphs = self.pyg_dataset
        return len(graphs)

    def get(self, idx):
        """Return the graph stored at position ``idx``."""
        graphs = self.pyg_dataset
        return graphs[idx]

# Wrap the generated graphs in the Dataset subclass defined above.
tree_dataset = AcyclicGraphDataset(pyg_dataset=pyg_dataset)

# Persist the dataset object to disk.
# NOTE(review): this serializes the whole Dataset object rather than the plain
# list of graphs; loading it back presumably requires the AcyclicGraphDataset
# class to be available -- confirm against the code that reads this file.
torch.save(obj=tree_dataset, f='acyclic_dataset.pt')