使用步骤1里处理好的节点的ID，来构建DGL的graph所需要的边列表。

In [1]:
import pandas as pd
import numpy as np
import os

import dgl

Using backend: pytorch


In [2]:
# path
base_path = './final_data'
publish_path = ''

link_p1_path = os.path.join(base_path, publish_path, 'link_phase1.csv')
link_p2_path = os.path.join(base_path, publish_path, 'link_phase2.csv')
nodes_path = os.path.join(base_path, publish_path, 'IDandLabels.csv')

### 读取节点列表

In [3]:
nodes_df = pd.read_csv(nodes_path, dtype={'Label':str})
print(nodes_df.shape)
nodes_df.tail(4)

(5346177, 4)


Unnamed: 0,node_idx,paper_id,Label,Split_ID
5346173,5346173,1b8ab3d079dca59f31b846fd79e5ebb5,,1
5346174,5346174,38684c9ad0cbb959bbfd66c12938b227,,1
5346175,5346175,613fbc81d975a8d604ad71c48036b02e,,1
5346176,5346176,f58fbe42664299820e3b3b50b9a5983f,,1


### 读取边列表

In [4]:
# edge_df1 = pd.read_csv(link_p1_path)
# edge_df2 = pd.read_csv(link_p2_path)
edges_df = pd.read_csv(link_p2_path)
print(edges_df.shape)
edges_df.head()

(48159841, 3)


Unnamed: 0,paper_id,reference_paper_id,phase
0,65679d473736e72a8f12984d4c42e955,05f085e5ba961ce715912b6cb4836e96,phase2
1,c90b0164d72bd4f00b4c5efb7c216dc5,33c1a5ee63412996aed098f0a8f01c53,phase2
2,d5945e460d2d63118c6b601c7330e649,50603e852362bde38fa27ea2c9b7039b,phase2
3,f10da75ad1eaf16eb2ffe0d85b76b332,711ef25bdb2c2421c0131af77b3ede1d,phase1
4,9ac5a4327bd4f3dcb424c93ca9b84087,2d91c73304c5e8a94a0e5b4956093f71,phase1


## Join点列表和边列表以生成从0开始的边列表

DGL默认节点是从0开始，并以最大的ID为容量构建Graph，因此这里我们先构建从0开始的边列表。

In [5]:
# Merge paper_id列
edges = edges_df.merge(nodes_df, on='paper_id', how='left')
# Merge reference_paper_id列
edges = edges.merge(nodes_df, left_on='reference_paper_id', right_on='paper_id', how='left')
print(edges.shape)
edges.head(4)

(48159841, 10)


Unnamed: 0,paper_id_x,reference_paper_id,phase,node_idx_x,Label_x,Split_ID_x,node_idx_y,paper_id_y,Label_y,Split_ID_y
0,65679d473736e72a8f12984d4c42e955,05f085e5ba961ce715912b6cb4836e96,phase2,3926192,,1,985918,05f085e5ba961ce715912b6cb4836e96,,0
1,c90b0164d72bd4f00b4c5efb7c216dc5,33c1a5ee63412996aed098f0a8f01c53,phase2,3790465,,1,1188633,33c1a5ee63412996aed098f0a8f01c53,,0
2,d5945e460d2d63118c6b601c7330e649,50603e852362bde38fa27ea2c9b7039b,phase2,4780332,,1,3503589,50603e852362bde38fa27ea2c9b7039b,,0
3,f10da75ad1eaf16eb2ffe0d85b76b332,711ef25bdb2c2421c0131af77b3ede1d,phase1,529879,,0,2364950,711ef25bdb2c2421c0131af77b3ede1d,,0


#### 修改node_idx_* 列的名称作为新的node id，并只保留需要的列

In [6]:
edges.rename(columns={'paper_id_x': 'paper_id', 'node_idx_x':'src_nid', 'node_idx_y':'dst_nid'}, inplace=True)
edges = edges[['src_nid', 'dst_nid', 'paper_id', 'reference_paper_id']]
edges.head(4)

Unnamed: 0,src_nid,dst_nid,paper_id,reference_paper_id
0,3926192,985918,65679d473736e72a8f12984d4c42e955,05f085e5ba961ce715912b6cb4836e96
1,3790465,1188633,c90b0164d72bd4f00b4c5efb7c216dc5,33c1a5ee63412996aed098f0a8f01c53
2,4780332,3503589,d5945e460d2d63118c6b601c7330e649,50603e852362bde38fa27ea2c9b7039b
3,529879,2364950,f10da75ad1eaf16eb2ffe0d85b76b332,711ef25bdb2c2421c0131af77b3ede1d


## 构建DGL的Graph

In [7]:
# 讲源节点和目标节点转换成Numpy的NDArray
src_nid = edges.src_nid.to_numpy()
dst_nid = edges.dst_nid.to_numpy()

In [8]:
# 构建一个DGL的graph
graph = dgl.graph((src_nid, dst_nid), num_nodes=5346177)
print(graph)

Graph(num_nodes=5346177, num_edges=48159841,
      ndata_schemes={}
      edata_schemes={})


In [9]:
# 保存Graph为二进制格式方便后面建模时的快速读取
graph_path = os.path.join(base_path, publish_path, 'graph.bin')
dgl.data.utils.save_graphs(graph_path, [graph])



In [10]:
# 构建异构图
graph_data = {
   ('id', 'cite', 'id'): (src_nid, dst_nid),
    ('id', 'cite_by', 'id'): (dst_nid, src_nid),
    ('id', 'self_loop', 'id'): (nodes_df.node_idx, nodes_df.node_idx),
}
g = dgl.heterograph(graph_data)
graph_path = os.path.join(base_path, publish_path, 'heteo_graph.bin')
dgl.data.utils.save_graphs(graph_path, [g])