-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
graph_reindex.py
127 lines (110 loc) · 5.52 KB
/
graph_reindex.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid import core
from paddle import _C_ops
def graph_reindex(x,
                  neighbors,
                  count,
                  value_buffer=None,
                  index_buffer=None,
                  flag_buffer_hashtable=False,
                  name=None):
    """
    Graph Reindex API.

    This API is mainly used in Graph Learning domain, which should be used
    in conjunction with `graph_sample_neighbors` API. And the main purpose
    is to reindex the ids information of the input nodes, and return the
    corresponding graph edges after reindex.

    Take input nodes x = [0, 1, 2] as an example.
    If we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2],
    then we know that the neighbors of 0 is [8, 9], the neighbors of 1
    is [0, 4, 7], and the neighbors of 2 is [6, 7].

    Args:
        x (Tensor): The input nodes which we sample neighbors for. The available
                    data type is int32, int64.
        neighbors (Tensor): The neighbors of the input nodes `x`. The data type
                            should be the same with `x`.
        count (Tensor): The neighbor count of the input nodes `x`. And the
                        data type should be int32.
        value_buffer (Tensor|None): Value buffer for hashtable. The data type should
                                    be int32, and should be filled with -1.
        index_buffer (Tensor|None): Index buffer for hashtable. The data type should
                                    be int32, and should be filled with -1.
        flag_buffer_hashtable (bool): Whether to use buffer for hashtable to speed up.
                                      Default is False. Only useful for gpu version currently.
        name (str, optional): Name for the operation (optional, default is None).
                              For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        reindex_src (Tensor): The source node index of graph edges after reindex.
        reindex_dst (Tensor): The destination node index of graph edges after reindex.
        out_nodes (Tensor): The index of unique input nodes and neighbors before reindex,
                            where we put the input nodes `x` in the front, and put neighbor
                            nodes in the back.

    Raises:
        ValueError: If `flag_buffer_hashtable` is True while `value_buffer`
                    or `index_buffer` is None.

    Examples:

        .. code-block:: python

            import paddle

            x = [0, 1, 2]
            neighbors = [8, 9, 0, 4, 7, 6, 7]
            count = [2, 3, 2]
            x = paddle.to_tensor(x, dtype="int64")
            neighbors = paddle.to_tensor(neighbors, dtype="int64")
            count = paddle.to_tensor(count, dtype="int32")

            reindex_src, reindex_dst, out_nodes = paddle.incubate.graph_reindex(
                x, neighbors, count)
            # reindex_src: [3, 4, 0, 5, 6, 7, 6]
            # reindex_dst: [0, 0, 1, 1, 1, 2, 2]
            # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6]

    """
    # Both hashtable buffers are mandatory once the buffered path is requested.
    if flag_buffer_hashtable:
        if value_buffer is None or index_buffer is None:
            raise ValueError("`value_buffer` and `index_buffer` should not "
                             "be None if `flag_buffer_hashtable` is True.")

    # Dynamic-graph mode: dispatch straight to the C++ op and return its outputs.
    if _non_static_mode():
        reindex_src, reindex_dst, out_nodes = \
            _C_ops.graph_reindex(x, neighbors, count, value_buffer, index_buffer,
                                 "flag_buffer_hashtable", flag_buffer_hashtable)
        return reindex_src, reindex_dst, out_nodes

    # Static-graph mode: validate the inputs before appending the op.
    check_variable_and_dtype(x, "X", ("int32", "int64"), "graph_reindex")
    check_variable_and_dtype(neighbors, "Neighbors", ("int32", "int64"),
                             "graph_reindex")
    # NOTE: ("int32",) is a real one-element tuple; ("int32") is just a string.
    check_variable_and_dtype(count, "Count", ("int32",), "graph_reindex")
    if flag_buffer_hashtable:
        check_variable_and_dtype(value_buffer, "HashTable_Value", ("int32",),
                                 "graph_reindex")
        check_variable_and_dtype(index_buffer, "HashTable_Index", ("int32",),
                                 "graph_reindex")

    # NOTE: `locals()` is forwarded to LayerHelper as keyword arguments, so no
    # extra local variables may be introduced above this line.
    helper = LayerHelper("graph_reindex", **locals())
    reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype)
    reindex_dst = helper.create_variable_for_type_inference(dtype=x.dtype)
    out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type="graph_reindex",
        inputs={
            "X": x,
            "Neighbors": neighbors,
            "Count": count,
            # The buffers are only wired in when the buffered path is enabled.
            "HashTable_Value": value_buffer if flag_buffer_hashtable else None,
            "HashTable_Index": index_buffer if flag_buffer_hashtable else None,
        },
        outputs={
            "Reindex_Src": reindex_src,
            "Reindex_Dst": reindex_dst,
            "Out_Nodes": out_nodes
        },
        attrs={"flag_buffer_hashtable": flag_buffer_hashtable})
    return reindex_src, reindex_dst, out_nodes