In [1]:
import dgl
import torch

# Build a small example heterogeneous graph.
# Each forward relation below is paired with a reverse relation named with a
# leading "_", whose edges are the forward edges with source and destination
# swapped. Entries are inserted forward-then-reverse, in the order listed.
forward_relations = [
    (('protein', 'interacts_1', 'protein'), ([0, 1, 2], [1, 2, 0])),
    (('protein', 'interacts_0', 'go_annotation'), ([0, 1], [0, 1])),
    (('go_annotation', 'interacts_2', 'go_annotation'), ([0, 1], [1, 0])),
]

data_dict = {}
for (head_type, relation, tail_type), (head_ids, tail_ids) in forward_relations:
    data_dict[(head_type, relation, tail_type)] = (
        torch.tensor(head_ids), torch.tensor(tail_ids)
    )
    data_dict[(tail_type, '_' + relation, head_type)] = (
        torch.tensor(tail_ids), torch.tensor(head_ids)
    )

g = dgl.heterograph(data_dict)

# Seed nodes for this mini-batch, taken from the 'protein' node type.
batch_nodes = torch.tensor([0, 1])

# Per-relation fanout for sample_neighbors. Only relations whose source and
# destination are both 'protein' are sampled (up to PROTEIN_FANOUT edges per
# seed node); every relation touching 'go_annotation' gets 0, meaning no edges
# are sampled along it.
PROTEIN_FANOUT = 2
fanout = {
    cetype: PROTEIN_FANOUT if cetype[0] == cetype[2] == 'protein' else 0
    for cetype in (
        ('protein', 'interacts_0', 'go_annotation'),
        ('go_annotation', '_interacts_0', 'protein'),
        ('protein', 'interacts_1', 'protein'),
        ('protein', '_interacts_1', 'protein'),
        ('go_annotation', 'interacts_2', 'go_annotation'),
        ('go_annotation', '_interacts_2', 'go_annotation'),
    )
}

# Neighbor sampling with DGL. sample_neighbors may pick edges at random, so
# seed DGL's RNG first; otherwise a Restart & Run All is not guaranteed to
# reproduce the recorded output.
SEED = 0
dgl.seed(SEED)

# edge_dir='in' is DGL's default, made explicit here: for every seed node in
# batch_nodes, sample up to fanout[etype] of its *incoming* edges per relation.
protein_protein_subgraph = dgl.sampling.sample_neighbors(
    g, {'protein': batch_nodes}, fanout=fanout, edge_dir='in'
)

# Report the sampling result.
# NOTE: sample_neighbors returns a graph that keeps *every* node of `g` and
# only the sampled edges. num_nodes() / nodes() therefore describe the full
# node set of `g`, not the nodes reached by sampling.
print("子图节点数 (sample_neighbors 保留原图全部节点):", protein_protein_subgraph.num_nodes())
print("采样到的边数:", protein_protein_subgraph.num_edges())

# Collect the nodes actually involved in the sample, per node type: the seed
# nodes plus both endpoints of every sampled edge. Seeds are cast to the
# graph's id dtype so torch.cat below does not mix int32/int64.
touched_ids = {ntype: [] for ntype in protein_protein_subgraph.ntypes}
touched_ids['protein'].append(batch_nodes.to(protein_protein_subgraph.idtype))

print("采样到的边类型与索引:")
for etype in protein_protein_subgraph.canonical_etypes:
    src, dst = protein_protein_subgraph.edges(etype=etype)
    print(f"边类型 {etype}: {(src, dst)}")
    touched_ids[etype[0]].append(src)
    touched_ids[etype[2]].append(dst)
    # Sanity check: a relation with fanout k >= 0 yields at most k edges per seed.
    if fanout[etype] >= 0:
        assert len(src) <= fanout[etype] * len(batch_nodes), etype

# Skip node types that appear in no relation (nothing to concatenate).
sampled_nodes = {
    ntype: torch.unique(torch.cat(ids))
    for ntype, ids in touched_ids.items()
    if ids
}
print("采样实际涉及的 protein 节点 (种子 + 邻居):", sampled_nodes['protein'])


################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################

  from .autonotebook import tqdm as notebook_tqdm


采样到的节点数: 5
采样到的边数: 4
采样到的 protein 节点: tensor([0, 1, 2])
采样到的边类型与索引:
边类型 ('go_annotation', '_interacts_0', 'protein'): (tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
边类型 ('go_annotation', '_interacts_2', 'go_annotation'): (tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
边类型 ('go_annotation', 'interacts_2', 'go_annotation'): (tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
边类型 ('protein', '_interacts_1', 'protein'): (tensor([1, 2]), tensor([0, 1]))
边类型 ('protein', 'interacts_0', 'go_annotation'): (tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
边类型 ('protein', 'interacts_1', 'protein'): (tensor([2, 0]), tensor([0, 1]))


In [2]:
# Show the full graph and then the sampled subgraph, one summary per line group.
print(g, protein_protein_subgraph, sep="\n")

Graph(num_nodes={'go_annotation': 2, 'protein': 3},
      num_edges={('go_annotation', '_interacts_0', 'protein'): 2, ('go_annotation', '_interacts_2', 'go_annotation'): 2, ('go_annotation', 'interacts_2', 'go_annotation'): 2, ('protein', '_interacts_1', 'protein'): 3, ('protein', 'interacts_0', 'go_annotation'): 2, ('protein', 'interacts_1', 'protein'): 3},
      metagraph=[('go_annotation', 'protein', '_interacts_0'), ('go_annotation', 'go_annotation', '_interacts_2'), ('go_annotation', 'go_annotation', 'interacts_2'), ('protein', 'protein', '_interacts_1'), ('protein', 'protein', 'interacts_1'), ('protein', 'go_annotation', 'interacts_0')])
