In [37]:
import torch
from torch_geometric.data import Data, DataLoader
import networkx as nx
import numpy as np
import random

In [38]:
def generate_random_dag(n_nodes, edge_prob):
    G = nx.DiGraph()
    nodes = list(range(n_nodes))
    random.shuffle(nodes)
    
    for i in range(1, n_nodes):
        parent = random.randint(0, i-1)
        G.add_edge(nodes[parent], nodes[i])
    
    additional_edges = int(edge_prob * n_nodes * (n_nodes - 1) / 2)
    
    while additional_edges > 0:
        u, v = random.sample(nodes, 2)
        if u != v and not G.has_edge(u, v):
            G.add_edge(u, v)
            if nx.is_directed_acyclic_graph(G):
                additional_edges -= 1
            else:
                G.remove_edge(u, v)
    
    edge_index = torch.tensor(list(G.edges), dtype=torch.int).t().contiguous()
    x = torch.zeros(n_nodes, 1)
    return Data(x=x, edge_index=edge_index)

In [39]:
# 图算法
def generate_paths_and_expansion(G, nodes):
    paths = set(nodes)
    for i in range(len(nodes)):
        for j in range(i + 1, len(nodes)):
            try:
                if nodes[i] in G and nodes[j] in G:
                    try:
                        path = nx.shortest_path(G, source=nodes[i], target=nodes[j])
                    except nx.NetworkXNoPath:
                        path = nx.shortest_path(G, source=nodes[j], target=nodes[i])
                    paths.update(path)
            except nx.NetworkXNoPath:
                paths.update([nodes[i], nodes[j]])

    # 去重并扩展
    expanded_nodes = set(paths)
    
    for node in paths:
        neighbors = nx.single_source_shortest_path_length(G, node, cutoff=2).keys()
        expanded_nodes.update(neighbors)

    subgraph = G.subgraph(expanded_nodes).copy()
    edge_index = torch.tensor(list(subgraph.edges), dtype=torch.long).t().contiguous()
    if edge_index.numel() == 0:  # 如果没有边，返回一个空的边列表
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return edge_index

In [40]:
# 数据集生成
def generate_dataset(n_samples, min_nodes, max_nodes, edge_prob):
    dataset = []
    index = 0
    for _ in range(n_samples):
        print(index)
        n_nodes = random.randint(min_nodes, max_nodes)
        G_data = generate_random_dag(n_nodes, edge_prob)
        edgelist = [(row[0], row[1]) for row in G_data.edge_index.numpy().T]
        G = nx.DiGraph(edgelist)
        n_points = random.randint(1, 20)
        points = random.sample(range(n_nodes), n_points)
        G_data.x[points]=1
        subgraph_edge_index = generate_paths_and_expansion(G, points)
        
        y = torch.zeros(G_data.edge_index.size(1), dtype=torch.float)
        for edge in subgraph_edge_index.t().tolist():
            for i, orig_edge in enumerate(G_data.edge_index.t().tolist()):
                if edge == orig_edge:
                    y[i] = 1.0
        G_data.y = y
        print(G_data)
        dataset.append(G_data)
        index = index+1
    return dataset

In [43]:
n_samples=1000
min_nodes=400
max_nodes=500
edge_prob=0.05

In [44]:
G_data = generate_random_dag(5,0.3)

In [45]:
print(G_data['edge_index'])
edgelist = [(row[0], row[1]) for row in G_data.edge_index.numpy().T]
print(edgelist)

tensor([[4, 4, 2, 2, 2, 1, 0],
        [2, 1, 3, 1, 0, 0, 3]], dtype=torch.int32)
[(4, 2), (4, 1), (2, 3), (2, 1), (2, 0), (1, 0), (0, 3)]


In [None]:
# 生成数据集
dataset = generate_dataset(n_samples,min_nodes,max_nodes, edge_prob)
print("dataset:",dataset)

In [51]:
torch.save(dataset, './dataset1/train_dataset')