In [1]:
import dgl

g_path = '/root/autodl-tmp/source/pprogo-flg/data/bp/graph.dgl'
g, _ = dgl.load_graphs(g_path)
g = g[0]

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from dgl.dataloading import Sampler, DataLoader
import time

def get_edge_subgraph(g, sample_subgraph):
    all_edges = {}
    for etype in sample_subgraph.canonical_etypes:
        src, dst = sample_subgraph.edges(etype=etype)
        all_edges[etype] = (src, dst)

    # 合并所有边的源节点和目标节点，并为每种类型去重
    subgraph_nodes = {}
    for etype in all_edges:
        src, dst = all_edges[etype]
        subgraph_nodes[etype] = torch.cat([src, dst]).unique()

    # 创建一个新的子图，只包含采样边及其连接的节点
    subgraph = dgl.edge_subgraph(g, subgraph_nodes)
    
    return subgraph
    
class ProteinGOSampler(Sampler):
    def __init__(self, g, num_protein_protein, num_protein_go, num_go_go):
        super().__init__()
        self.g = g
        self.num_protein_protein = num_protein_protein
        self.num_protein_go = num_protein_go
        self.num_go_go = num_go_go

    def sample(self, g, batch_nodes):
        time_0 = time.time()
        batch_nodes = batch_nodes['protein']
        # print('batch_nodes:{}'.format(batch_nodes))
        # 采样蛋白质-蛋白质关系
        protein_protein = dgl.sampling.sample_neighbors(
            self.g, {'protein':batch_nodes}, 
            fanout={('protein', 'interacts_0', 'go_annotation'): 0,
                ('go_annotation', '_interacts_0', 'protein'): 0,
                ('protein', 'interacts_1', 'protein'): self.num_protein_protein,
                ('protein', '_interacts_1', 'protein'): self.num_protein_protein,
                ('go_annotation', 'interacts_2', 'go_annotation'): 0,
                ('go_annotation', '_interacts_2', 'go_annotation'): 0},
        )
        protein_protein_subgraph = get_edge_subgraph(g, protein_protein)
        
        # 采样蛋白质-GO标签关系
        protein_go = dgl.sampling.sample_neighbors(
            self.g, {'protein':batch_nodes}, 
            fanout={('protein', 'interacts_0', 'go_annotation'): self.num_protein_go,
                ('go_annotation', '_interacts_0', 'protein'): 0,
                ('protein', 'interacts_1', 'protein'): 0,
                ('protein', '_interacts_1', 'protein'): 0,
                ('go_annotation', 'interacts_2', 'go_annotation'): 0,
                ('go_annotation', '_interacts_2', 'go_annotation'): 0},
            edge_dir='out'
        )
        protein_go_subgraph = get_edge_subgraph(g, protein_go)
        
        
        # 采样GO-GO关系
        # 首先找到所有与蛋白质节点相关的GO节点
        go_nodes = g.get_ntype_id('go_annotation')
        # print(go_nodes)
        go_go = dgl.sampling.sample_neighbors(
            self.g, {'go_annotation':go_nodes}, 
            fanout={('protein', 'interacts_0', 'go_annotation'): 0,
                ('go_annotation', '_interacts_0', 'protein'): 0,
                ('protein', 'interacts_1', 'protein'): 0,
                ('protein', '_interacts_1', 'protein'): 0,
                ('go_annotation', 'interacts_2', 'go_annotation'): self.num_go_go,
                ('go_annotation', '_interacts_2', 'go_annotation'): self.num_go_go}
        )
        go_go_subgraph = get_edge_subgraph(g, go_go)
        # print('protein_protein:{}'.format(protein_protein_subgraph))
        # print('protein_go:{}'.format(protein_go_subgraph))
        # print('go_go:{}'.format(go_go_subgraph))
        # print('sample_time:{}'.format(time.time()-time_0))
        # 合并所有采样的子图
        subg = dgl.merge([protein_protein_subgraph, protein_go_subgraph, go_go_subgraph])
        return batch_nodes, subg

In [3]:
# 定义每种边类型的采样量
protein_protein_sample_size = 5  # 采样5个蛋白质-蛋白质的邻居
protein_go_sample_size = 10       # 采样3个蛋白质-GO关系
go_go_sample_size = 2            # 采样2个GO-GO的邻居

# 创建自定义采样器
sampler = ProteinGOSampler(g, protein_protein_sample_size, protein_go_sample_size, go_go_sample_size)

# 创建 DataLoader，使用 Mini-Batch Sampling
train_loader = DataLoader(
    g,  # 输入异构图
    {'protein': torch.arange(g.num_nodes('protein'))},  # 从蛋白质节点开始采样
    sampler,  # 自定义的采样器
    batch_size=8,  # 每批采样8个蛋白质节点
    shuffle=True,  # 打乱采样顺序
    drop_last=False,  # 不丢弃最后不满 batch 的数据
)

In [4]:
for seeds, subgraph in train_loader:
    # 输出子图的节点和边的信息
    print("子图的节点类型:", subgraph.ntypes)
    print("子图的边类型:", subgraph.etypes)

    # 输出子图的节点数量
    print("子图中 protein 类型的节点数量:", subgraph.num_nodes('protein'))  # 输出: 1
    print("子图中 go_annotation 类型的节点数量:", subgraph.num_nodes('go_annotation'))



子图的节点类型: ['go_annotation', 'protein']
子图的边类型: ['_interacts_0', '_interacts_2', 'interacts_2', '_interacts_1', 'interacts_0', 'interacts_1']
子图中 protein 类型的节点数量: 163
子图中 go_annotation 类型的节点数量: 33
子图的节点类型: ['go_annotation', 'protein']
子图的边类型: ['_interacts_0', '_interacts_2', 'interacts_2', '_interacts_1', 'interacts_0', 'interacts_1']
子图中 protein 类型的节点数量: 160
子图中 go_annotation 类型的节点数量: 73
子图的节点类型: ['go_annotation', 'protein']
子图的边类型: ['_interacts_0', '_interacts_2', 'interacts_2', '_interacts_1', 'interacts_0', 'interacts_1']
子图中 protein 类型的节点数量: 166
子图中 go_annotation 类型的节点数量: 54
子图的节点类型: ['go_annotation', 'protein']
子图的边类型: ['_interacts_0', '_interacts_2', 'interacts_2', '_interacts_1', 'interacts_0', 'interacts_1']
子图中 protein 类型的节点数量: 161
子图中 go_annotation 类型的节点数量: 43
子图的节点类型: ['go_annotation', 'protein']
子图的边类型: ['_interacts_0', '_interacts_2', 'interacts_2', '_interacts_1', 'interacts_0', 'interacts_1']
子图中 protein 类型的节点数量: 164
子图中 go_annotation 类型的节点数量: 21
子图的节点类型: ['go_annotation'

KeyboardInterrupt: 

In [2]:
import torch
import numpy as np
from torch.utils.data import DataLoader

class RandomWalkSampler(dgl.dataloading.BlockSampler):
    def __init__(self, walk_length, num_traces, metapath, return_eids=False):
        super().__init__()
        self.walk_length = walk_length
        self.num_traces = num_traces
        self.metapath = metapath
        self.return_eids = return_eids

    def sample(self, g, seed_nodes):
        traces_dict = {}
        sampled_nodes = {ntype: set() for ntype in g.ntypes}
        
        # 对每种节点类型进行随机游走
        for ntype, seeds in seed_nodes.items():
            print('seeds:{}'.format(seeds))
            traces, _ = dgl.sampling.random_walk(g, seeds, metapath=self.metapath, length=self.walk_length)
            traces_dict[ntype] = traces
            
            # 将采样到的节点添加到 sampled_nodes 字典中
            for trace in traces:
                sampled_nodes[ntype].update(trace.numpy())
        
        # 将每种类型的节点 ID 列表转换为张量，确保类型为 torch.int64
        for ntype in sampled_nodes:
            sampled_nodes[ntype] = torch.tensor(list(sampled_nodes[ntype]), dtype=torch.int64)
        
        # 构建子图
        subgraph = g.subgraph(sampled_nodes)
        
        if self.return_eids:
            return subgraph, sampled_nodes, traces_dict
        return sampled_nodes, subgraph

In [5]:
metapath_1 = [('protein', 'interacts_1', 'protein'), ('protein', '_interacts_1', 'protein')]
metapath_2 = [('protein', 'interacts_0', 'go_annotation'), ('go_annotation', '_interacts_0', 'protein')]
metapath_3 = [('protein', 'interacts_0', 'go_annotation'), ('go_annotation', 'interacts_2', 'go_annotation'), ('go_annotation', '_interacts_0', 'protein')]

metapath_full = [
    ('protein', 'interacts_0', 'go_annotation'), 
    ('go_annotation', 'interacts_2', 'go_annotation'), 
    ('go_annotation', '_interacts_0', 'protein'), 
    ('protein', 'interacts_1', 'protein')
]

# 实例化 RandomWalkSampler，假设 walk_length=2, num_traces=3
sampler = RandomWalkSampler(walk_length=10, num_traces=1, metapath=metapath_3)


# 将 sampler 集成到 DataLoader 中
dataloader = dgl.dataloading.DataLoader(
    g, {'protein': torch.arange(g.number_of_nodes('protein'))}, sampler,
    batch_size=2, shuffle=True, drop_last=False
)

In [25]:
train_loader = dgl.dataloading.DataLoader(
                    g, {'protein': torch.arange(g.number_of_nodes('protein'))}, sampler,
                    batch_size=8, shuffle=True)

for subgraph, seeds in train_loader:
    # 输出子图的节点和边的信息
    print("子图的节点类型:", subgraph.ntypes)
    print("子图的边类型:", subgraph.etypes)

    # 输出子图的节点数量
    print("子图中 protein 类型的节点数量:", subgraph.num_nodes('protein'))  # 输出: 1
    print("子图中 go_annotation 类型的节点数量:", subgraph.num_nodes('go_annotation'))



RuntimeError: Could not infer dtype of dict

In [8]:
import dgl.dataloading as dataloading
import torch
import time
from dgl.sampling import RandomWalkNeighborSampler
# sampler = dataloading.MultiLayerNeighborSampler([10, 10])  # 每层采样 10 个邻居
meta_paths_dict = [('protein', 'interacts_0', 'go_annotation'), ('go_annotation', 'interacts_2', 'go_annotation'), ('go_annotation', '_interacts_0', 'protein')]
sampler = RandomWalkNeighborSampler(g,
                  num_traversals=3,  # 每次随机游走的步数
                  termination_prob=0.3,  # 每步随机游走的提前终止概率
                  num_random_walks=3,  # 每个节点的随机游走次数
                  num_neighbors=3,  # 每个节点采样的邻居数量
                  metapath=meta_paths_dict)
# dataloader = dataloading.DataLoader(
#     g, 
#     {'protein': torch.arange(g.number_of_nodes('protein'))}, 
#     sampler,
#     batch_size=16, 
#     shuffle=True, 
#     num_workers=4
# )

# for input_nodes, output_nodes, blocks in dataloader:
#     print(input_nodes)
#     print(output_nodes)
#     print(blocks)  # 每个 batch 返回的采样块(block)信息
#     time.sleep(100)

print(sampler)
    

<dgl.sampling.pinsage.RandomWalkNeighborSampler object at 0x7efd18672250>


In [None]:
num_protein = g.num_nodes('protein')
num_go = g.num_nodes('go_annotation')

node_dict = {
    'protein':torch.randint(0, num_protein, (48,)),
    'go_annotation':torch.randint(0, num_go, (16,))
}
start_time = time.time()
subgraph = dgl.node_subgraph(g, node_dict)

print('时间:', time.time() - start_time)
# 输出子图的节点和边的信息
print("子图的节点类型:", subgraph.ntypes)
print("子图的边类型:", subgraph.etypes)

# 输出子图的节点数量
print("子图中 protein 类型的节点数量:", subgraph.num_nodes('protein'))  # 输出: 1
print("子图中 go_annotation 类型的节点数量:", subgraph.num_nodes('go_annotation'))

时间: 0.010119438171386719
子图的节点类型: ['go_annotation', 'protein']
子图的边类型: ['_interacts_0', '_interacts_2', 'interacts_2', '_interacts_1', 'interacts_0', 'interacts_1']
子图中 protein 类型的节点数量: 48
子图中 go_annotation 类型的节点数量: 16


In [None]:
for subgraph_node , seeds , _ in dataloader:
    subgraph = dgl.node_subgraph(g, subgraph_node)

    # 输出子图的节点和边的信息
    print("子图的节点类型:", subgraph.ntypes)
    print("子图的边类型:", subgraph.etypes)

    # 输出子图的节点数量
    print("子图中 protein 类型的节点数量:", subgraph.num_nodes('protein'))  # 输出: 1
    print("子图中 go_annotation 类型的节点数量:", subgraph.num_nodes('go_annotation'))

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/root/miniconda3/envs/pytorch2.4/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
    fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/pytorch2.4/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
    return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/pytorch2.4/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
    self.dataset_iter = iter(dataset)
                        ^^^^^^^^^^^^^
  File "/root/miniconda3/envs/pytorch2.4/lib/python3.11/site-packages/dgl/dataloading/dataloader.py", line 226, in __iter__
    indices = _divide_by_worker(
              ^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/pytorch2.4/lib/python3.11/site-packages/dgl/dataloading/dataloader.py", line 177, in _divide_by_worker
    num_samples + (0 if drop_last else batch_size - 1)
                                       ~~~~~~~~~~~^~~
TypeError: unsupported operand type(s) for -: 'dict' and 'int'
