From 05a2c32e0deacfa4a9d428f518adfb6a4d85136e Mon Sep 17 00:00:00 2001 From: Yuxiang Yao Date: Thu, 8 Dec 2022 09:33:06 +0800 Subject: [PATCH 1/6] hetero data and inner message passing agrregation functions --- cogdl/data/__init__.py | 4 +- cogdl/data/data.py | 31 ++++ cogdl/utils/message_aggregate_utils.py | 166 ++++++++++++++++++ ..._data_inner_message_passing_aggregation.py | 19 ++ ..._data_inner_message_passing_aggregation.py | 31 ++++ 5 files changed, 250 insertions(+), 1 deletion(-) create mode 100644 cogdl/utils/message_aggregate_utils.py create mode 100644 tests/test_data_inner_message_passing_aggregation.py create mode 100644 tests/test_hetero_data_inner_message_passing_aggregation.py diff --git a/cogdl/data/__init__.py b/cogdl/data/__init__.py index 550ca06d..52624b18 100644 --- a/cogdl/data/__init__.py +++ b/cogdl/data/__init__.py @@ -2,5 +2,7 @@ from .batch import Batch, batch_graphs from .dataset import Dataset, MultiGraphDataset from .dataloader import DataLoader +from .hetero_data import HeteroGraph -__all__ = ["Graph", "Adjacency", "Batch", "Dataset", "DataLoader", "MultiGraphDataset", "batch_graphs"] + +__all__ = ["Graph", "Adjacency", "Batch", "Dataset", "DataLoader", "MultiGraphDataset", "batch_graphs", "HeteroGraph"] diff --git a/cogdl/data/data.py b/cogdl/data/data.py index 1614a9a0..6f90f2fb 100644 --- a/cogdl/data/data.py +++ b/cogdl/data/data.py @@ -1,6 +1,8 @@ import re import copy from contextlib import contextmanager +from inspect import isfunction + import scipy.sparse as sp import networkx as nx @@ -16,6 +18,7 @@ ) from cogdl.utils import RandomWalker from cogdl.operators.sample import sample_adj_c, subgraph_c +from cogdl.utils.message_aggregate_utils import MessageBuiltinFunction, AggregateBuiltinFunction subgraph_c = None # noqa: F811 @@ -948,3 +951,31 @@ def set_grb_adj(self, adj): # @requires_grad.setter # def requires_grad(self, x): # print(f"Set `requires_grad` to {x}") + + def message_passing(self, msg_func, x, edge_weight=None): + src_node_id = self.edge_index[0] + dst_node_id = self.edge_index[1] + if not isfunction(msg_func) and not isinstance(msg_func, str): + raise RuntimeError('Only Support Message Functions and String Prompt') + if isinstance(msg_func, str): + if not hasattr(MessageBuiltinFunction, msg_func): + raise NotImplementedError + msg_func = getattr(MessageBuiltinFunction, msg_func) + if edge_weight is None: + edge_weight = torch.ones(len(src_node_id), x.shape[1]) + m = msg_func(x, src_node_id, dst_node_id, edge_weight) + return m + + def aggregate(self, agg_func, x, m): + dst_node_id = self.edge_index[1] + out = torch.zeros(x.shape[0], m.shape[1], dtype=x.dtype).to(x.device) + index = dst_node_id.unsqueeze(1).expand(-1, m.shape[1]) + src = m + if not isfunction(agg_func) and not isinstance(agg_func, str): + raise RuntimeError('Only Support Message Functions and String Prompt') + if isinstance(agg_func, str): + if not hasattr(AggregateBuiltinFunction, agg_func): + raise NotImplementedError + agg_func = getattr(AggregateBuiltinFunction, agg_func) + h = agg_func(src, index, out=out) + return h diff --git a/cogdl/utils/message_aggregate_utils.py b/cogdl/utils/message_aggregate_utils.py new file mode 100644 index 00000000..02fabd94 --- /dev/null +++ b/cogdl/utils/message_aggregate_utils.py @@ -0,0 +1,166 @@ +import torch +from torch_scatter import scatter_max, scatter_min, scatter_add, scatter_mean + + +class MessageBuiltinFunction: + # keep original source features + @staticmethod + def copy_u(x, src_node_id, dst_node_id, edge_weight): + return x[src_node_id] + + @staticmethod + def copy_e(x, src_node_id, dst_node_id, edge_weight): + return edge_weight + + @staticmethod + def copy_v(x, src_node_id, dst_node_id, edge_weight): + return x[dst_node_id] + + # source & target + @staticmethod + def u_add_v(x, src_node_id, dst_node_id, edge_weight): + return x[src_node_id] + x[dst_node_id] + + @staticmethod + def u_sub_v(x, src_node_id, dst_node_id, edge_weight): + return x[src_node_id] - x[dst_node_id] + + @staticmethod + def u_mul_v(x, src_node_id, dst_node_id, edge_weight): + return torch.mul(x[src_node_id], x[dst_node_id]) + + @staticmethod + def u_div_v(x, src_node_id, dst_node_id, edge_weight): + return torch.div(x[src_node_id], x[dst_node_id]) + + # source & edge weight + @staticmethod + def u_add_e(x, src_node_id, dst_node_id, edge_weight): + return x[src_node_id] + edge_weight + + @staticmethod + def u_sub_e(x, src_node_id, dst_node_id, edge_weight): + return x[src_node_id] - edge_weight + + @staticmethod + def u_mul_e(x, src_node_id, dst_node_id, edge_weight): + return torch.mul(x[src_node_id], edge_weight) + + @staticmethod + def u_div_e(x, src_node_id, dst_node_id, edge_weight): + return torch.div(x[src_node_id], edge_weight) + + # target & source + @staticmethod + def v_add_u(x, src_node_id, dst_node_id, edge_weight): + return x[dst_node_id] + x[src_node_id] + + @staticmethod + def v_sub_u(x, src_node_id, dst_node_id, edge_weight): + return x[dst_node_id] - x[src_node_id] + + @staticmethod + def v_mul_u(x, src_node_id, dst_node_id, edge_weight): + return torch.mul(x[dst_node_id], x[src_node_id]) + + @staticmethod + def v_div_u(x, src_node_id, dst_node_id, edge_weight): + return torch.div(x[dst_node_id], x[src_node_id]) + + # target & edge weight + @staticmethod + def v_add_e(x, src_node_id, dst_node_id, edge_weight): + return x[dst_node_id] + edge_weight + + @staticmethod + def v_sub_e(x, src_node_id, dst_node_id, edge_weight): + return x[dst_node_id] - edge_weight + + @staticmethod + def v_mul_e(x, src_node_id, dst_node_id, edge_weight): + return torch.mul(x[dst_node_id], edge_weight) + + @staticmethod + def v_div_e(x, src_node_id, dst_node_id, edge_weight): + return torch.div(x[dst_node_id], edge_weight) + + # edge weight & source + @staticmethod + def e_add_u(x, src_node_id, dst_node_id, edge_weight): + return edge_weight + x[src_node_id] + + @staticmethod + def e_sub_u(x, src_node_id, dst_node_id, edge_weight): + return edge_weight - x[src_node_id] + + @staticmethod + def e_mul_u(x, src_node_id, dst_node_id, edge_weight): + return torch.mul(edge_weight, x[src_node_id]) + + @staticmethod + def e_div_u(x, src_node_id, dst_node_id, edge_weight): + return torch.div(edge_weight, x[src_node_id]) + + # edge weight & target + @staticmethod + def e_add_v(x, src_node_id, dst_node_id, edge_weight): + return edge_weight + x[dst_node_id] + + @staticmethod + def e_sub_v(x, src_node_id, dst_node_id, edge_weight): + return edge_weight - x[dst_node_id] + + @staticmethod + def e_mul_v(x, src_node_id, dst_node_id, edge_weight): + return torch.mul(edge_weight, x[dst_node_id]) + + @staticmethod + def e_div_v(x, src_node_id, dst_node_id, edge_weight): + return torch.div(edge_weight, x[dst_node_id]) + + # dot manipulation + @staticmethod + def u_dot_v(x, src_node_id, dst_node_id, edge_weight): + return torch.mm(x[src_node_id], x[dst_node_id]) + + @staticmethod + def u_dot_e(x, src_node_id, dst_node_id, edge_weight): + return torch.mm(x[src_node_id], edge_weight) + + @staticmethod + def v_dot_e(x, src_node_id, dst_node_id, edge_weight): + return torch.mm(x[dst_node_id], edge_weight) + + @staticmethod + def v_fot_u(x, src_node_id, dst_node_id, edge_weight): + return torch.mm(x[dst_node_id], x[src_node_id]) + + @staticmethod + def e_dot_u(x, src_node_id, dst_node_id, edge_weight): + return torch.mm(edge_weight, x[src_node_id]) + + @staticmethod + def e_dot_v(x, src_node_id, dst_node_id, edge_weight): + return torch.mm(edge_weight, x[dst_node_id]) + + +class AggregateBuiltinFunction: + @staticmethod + def sum(src, index, out): + out = scatter_add(src, index, out=out, dim=0) + return out + + @staticmethod + def mean(src, index, out): + out = scatter_mean(src, index, out=out, dim=0) + return out + + @staticmethod + def max(src, index, out): + out = scatter_max(src, index, out=out, dim=0) + return out + + @staticmethod + def min(src, index, out): + out = scatter_min(src, index, out=out, dim=0) + return out \ No newline at end of file diff --git a/tests/test_data_inner_message_passing_aggregation.py b/tests/test_data_inner_message_passing_aggregation.py new file mode 100644 index 00000000..180c60f0 --- /dev/null +++ b/tests/test_data_inner_message_passing_aggregation.py @@ -0,0 +1,19 @@ +import torch +from cogdl.data import Graph + + +def test_data_inner_message_passing_aggregate(node_feats, node_num, edge_num): + x = torch.rand(node_num, node_feats) + edge_index = (torch.randint(0, node_num, (edge_num, )), torch.randint(0, node_num, (edge_num, ))) + graph = Graph(x=x, edge_index=edge_index) + # m = graph.message_passing('u_add_v', x) + m = graph.message_passing('u_mul_e', x) + x = graph.aggregate('sum', x, m) + print(x) + + +if __name__ == '__main__': + node_feats = 512 + edge_num = 1000 + node_num = 500 + test_data_inner_message_passing_aggregate(node_feats, node_num, edge_num) diff --git a/tests/test_hetero_data_inner_message_passing_aggregation.py b/tests/test_hetero_data_inner_message_passing_aggregation.py new file mode 100644 index 00000000..379f5855 --- /dev/null +++ b/tests/test_hetero_data_inner_message_passing_aggregation.py @@ -0,0 +1,31 @@ +import torch +from cogdl.data import HeteroGraph + + +def test_hetero_data_inner_message_passing_aggregate(node_feats, node_num, edge_num): + x = torch.rand(node_num, node_feats) + edge_index = (torch.randint(0, node_num, (edge_num,)), torch.randint(0, node_num, (edge_num,))) + edge_type = { + 'l': torch.randint(0, edge_num, (int(0.5 * edge_num), )), + 'r': torch.randint(0, edge_num, (int(0.5 * edge_num), )), + } + node_type = { + 'x': torch.randint(0, node_num, (int(0.5 * node_num), )), + 'y': torch.randint(0, node_num, (int(0.4 * node_num),)), + } + hetero_graph = HeteroGraph(x=x, edge_index=edge_index, edge_type=edge_type, node_type=node_type) + # 基于边的异构 + # m = hetero_graph.message_passing('u_mul_e', x, edge_type='l') + # x = hetero_graph.aggregate('sum', x, m, edge_type='l') + # 基于点的异构 + m = hetero_graph.message_passing('u_mul_e', x, src_node_type='x', dst_node_type='y') + x = hetero_graph.aggregate('sum', x, m, src_node_type='x', dst_node_type='y') + print(x) + + + +if __name__ == '__main__': + node_feats = 512 + edge_num = 1000 + node_num = 500 + test_hetero_data_inner_message_passing_aggregate(node_feats, node_num, edge_num) \ No newline at end of file From e0defb64cdc2cb4cc0c3b026a15e91cc536593b7 Mon Sep 17 00:00:00 2001 From: Yuxiang Yao Date: Thu, 8 Dec 2022 09:52:22 +0800 Subject: [PATCH 2/6] add ignore file --- cogdl/data/hetero_data.py | 157 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 cogdl/data/hetero_data.py diff --git a/cogdl/data/hetero_data.py b/cogdl/data/hetero_data.py new file mode 100644 index 00000000..070fcffb --- /dev/null +++ b/cogdl/data/hetero_data.py @@ -0,0 +1,157 @@ +import re +import torch +from inspect import isfunction +from cogdl.data.data import Graph, is_read_adj_key, Adjacency, is_adj_key_train +from cogdl.utils.message_aggregate_utils import MessageBuiltinFunction, AggregateBuiltinFunction + + +class HeteroGraph(Graph): + def __init__(self, x=None, y=None, **kwargs): + super(Graph, self).__init__() + if x is not None: + if not torch.is_tensor(x): + raise ValueError("Node features must be Tensor") + self.x = x + self.y = y + self.grb_adj = None + + for key, item in kwargs.items(): + if key == "num_nodes": + self.__num_nodes__ = item + elif key == "grb_adj": + self.grb_adj = item + elif not is_read_adj_key(key): + self[key] = item + + num_nodes = x.shape[0] if x is not None else None + if "edge_index_train" in kwargs: + self._adj_train = Adjacency(num_nodes=num_nodes) + for key, item in kwargs.items(): + if is_adj_key_train(key): + _key = re.search(r"(.*)_train", key).group(1) + if _key.startswith("edge_"): + _key = _key.split("edge_")[1] + if _key == "index": + self._adj_train.edge_index = item + else: + self._adj_train[_key] = item + else: + self._adj_train = None + + self._adj_full = Adjacency(num_nodes=num_nodes) + for key, item in kwargs.items(): + if is_read_adj_key(key) and not is_adj_key_train(key): + if key.startswith("edge_"): + key = key.split("edge_")[-1] + if key == "index": + self._adj_full.edge_index = item + else: + self._adj_full[key] = item + + self._adj = self._adj_full + self.__is_train__ = False + self.__temp_adj_stack__ = list() + self.__temp_storage__ = dict() + + # 异构图上对点的定义 + if 'node_type' in kwargs.keys(): + self.node_type = kwargs['node_type'] + assert isinstance(self.node_type, dict) + for k, v in self.node_type.items(): + if not isinstance(v, torch.Tensor) or v.dtype != torch.int64: + raise Exception('Each value of node type must be tensor type and in int data type') + else: + self.node_type = None + + # 异构图上对边的定义 + if 'edge_type' in kwargs.keys(): + self.edge_type = kwargs['edge_type'] + assert isinstance(self.edge_type, dict) + for k, v in self.edge_type.items(): + if not isinstance(v, torch.Tensor) or v.dtype != torch.int64: + raise Exception('Each value of node type must be tensor type and in int data type') + else: + self.edge_type = None + + + def build_hetero_mask(self, src_node_id, dst_node_id, src_node_type, dst_node_type): + r""" + src_node_id, dst_node_id: 需要进行消息传递的边上的起点和终点 + src_node type, dst_node_type: 消息传递的节点的 + """ + src_node_index = [i for i in range(len(src_node_id)) if src_node_id[i] in self.node_type[src_node_type]] + dst_node_index = [i for i in range(len(dst_node_id)) if dst_node_id[i] in self.node_type[dst_node_type]] + select_node_index = torch.Tensor(list(set(src_node_index).intersection(set(dst_node_index)))).to(torch.int64) + if len(select_node_index) == 0: + raise Warning('No nodes are selected according to the selection condition') + src_node_id = src_node_id.index_select(0, select_node_index) + dst_node_id = dst_node_id.index_select(0, select_node_index) + return src_node_id, dst_node_id + + + def message_passing(self, msg_func, x, edge_weight=None, **kwargs): + src_node_id = self.edge_index[0] + dst_node_id = self.edge_index[1] + + if 'edge_type' in kwargs.keys(): + # 按照预定义的边进行聚合 + edge_type = kwargs['edge_type'] + if self.edge_type is None or edge_type not in self.edge_type.keys(): + raise Exception('This heterograph does not has predefined edge types') + edge_type_index = self.edge_type[edge_type] + src_node_id = self.edge_index[0].index_select(0, edge_type_index) + dst_node_id = self.edge_index[1].index_select(0, edge_type_index) + else: + # 按照预定义的点进行聚合 + if 'src_node_type' in kwargs.keys() and 'dst_node_type' in kwargs.keys(): + src_node_type = kwargs['src_node_type'] + dst_node_type = kwargs['dst_node_type'] + else: + raise Exception('Lack of Arguments: "src_node_type" or "dst_node_type"') + if self.node_type is None or src_node_type not in self.node_type.keys() or dst_node_type not in self.node_type.keys(): + raise Exception('This heterograph does not has predefined node types') + src_node_id, dst_node_id = self.build_hetero_mask(src_node_id, dst_node_id, src_node_type, dst_node_type) + + if not isfunction(msg_func) and not isinstance(msg_func, str): + raise RuntimeError('Only Support Message Functions and String Prompt') + if isinstance(msg_func, str): + if not hasattr(MessageBuiltinFunction, msg_func): + raise NotImplementedError + msg_func = getattr(MessageBuiltinFunction, msg_func) + if edge_weight is None: + edge_weight = torch.ones(len(src_node_id), x.shape[1]) + m = msg_func(x, src_node_id, dst_node_id, edge_weight) + return m + + def aggregate(self, agg_func, x, m, **kwargs): + src_node_id = self.edge_index[0] + dst_node_id = self.edge_index[1] + if 'edge_type' in kwargs.keys(): + # 按照预定义的边进行聚合 + edge_type = kwargs['edge_type'] + if self.edge_type is None or edge_type not in self.edge_type.keys(): + raise Exception('This heterograph does not has predefined edge types') + edge_type_index = self.edge_type[edge_type] + src_node_id = self.edge_index[0].index_select(0, edge_type_index) + dst_node_id = self.edge_index[1].index_select(0, edge_type_index) + else: + # 按照预定义的点进行聚合 + if 'src_node_type' in kwargs.keys() and 'dst_node_type' in kwargs.keys(): + src_node_type = kwargs['src_node_type'] + dst_node_type = kwargs['dst_node_type'] + else: + raise Exception('Lack of Arguments: "src_node_type" or "dst_node_type"') + if self.node_type is None or src_node_type not in self.node_type.keys() or dst_node_type not in self.node_type.keys(): + raise Exception('This heterograph does not has predefined node types') + src_node_id, dst_node_id = self.build_hetero_mask(src_node_id, dst_node_id, src_node_type, dst_node_type) + out = torch.zeros(x.shape[0], m.shape[1], dtype=x.dtype).to(x.device) + index = dst_node_id.unsqueeze(1).expand(-1, m.shape[1]) + src = m + if not isfunction(agg_func) and not isinstance(agg_func, str): + raise RuntimeError('Only Support Message Functions and String Prompt') + if isinstance(agg_func, str): + if not hasattr(AggregateBuiltinFunction, agg_func): + raise NotImplementedError + agg_func = getattr(AggregateBuiltinFunction, agg_func) + h = agg_func(src, index, out=out) + return h From 525ae6428e7ce5e1dfb8101f4935b40ab4e93ac3 Mon Sep 17 00:00:00 2001 From: Yuxiang Yao Date: Fri, 9 Dec 2022 16:19:30 +0800 Subject: [PATCH 3/6] refine __init__ in hetero data file --- cogdl/data/hetero_data.py | 50 ++------------------------------------- 1 file changed, 2 insertions(+), 48 deletions(-) diff --git a/cogdl/data/hetero_data.py b/cogdl/data/hetero_data.py index 070fcffb..c2df2e34 100644 --- a/cogdl/data/hetero_data.py +++ b/cogdl/data/hetero_data.py @@ -1,59 +1,13 @@ import re import torch from inspect import isfunction -from cogdl.data.data import Graph, is_read_adj_key, Adjacency, is_adj_key_train +from cogdl.data.data import Graph from cogdl.utils.message_aggregate_utils import MessageBuiltinFunction, AggregateBuiltinFunction class HeteroGraph(Graph): def __init__(self, x=None, y=None, **kwargs): - super(Graph, self).__init__() - if x is not None: - if not torch.is_tensor(x): - raise ValueError("Node features must be Tensor") - self.x = x - self.y = y - self.grb_adj = None - - for key, item in kwargs.items(): - if key == "num_nodes": - self.__num_nodes__ = item - elif key == "grb_adj": - self.grb_adj = item - elif not is_read_adj_key(key): - self[key] = item - - num_nodes = x.shape[0] if x is not None else None - if "edge_index_train" in kwargs: - self._adj_train = Adjacency(num_nodes=num_nodes) - for key, item in kwargs.items(): - if is_adj_key_train(key): - _key = re.search(r"(.*)_train", key).group(1) - if _key.startswith("edge_"): - _key = _key.split("edge_")[1] - if _key == "index": - self._adj_train.edge_index = item - else: - self._adj_train[_key] = item - else: - self._adj_train = None - - self._adj_full = Adjacency(num_nodes=num_nodes) - for key, item in kwargs.items(): - if is_read_adj_key(key) and not is_adj_key_train(key): - if key.startswith("edge_"): - key = key.split("edge_")[-1] - if key == "index": - self._adj_full.edge_index = item - else: - self._adj_full[key] = item - - self._adj = self._adj_full - self.__is_train__ = False - self.__temp_adj_stack__ = list() - self.__temp_storage__ = dict() - - # 异构图上对点的定义 + super(HeteroGraph, self).__init__(x, y, **kwargs) if 'node_type' in kwargs.keys(): self.node_type = kwargs['node_type'] assert isinstance(self.node_type, dict) From 3a6ff2e1adc857ed09ecdc86fbf31dc5b1153cf0 Mon Sep 17 00:00:00 2001 From: Yuxiang Yao Date: Fri, 9 Dec 2022 19:17:22 +0800 Subject: [PATCH 4/6] modify test file --- cogdl/data/hetero_data.py | 28 +++++++++++++++++ ..._data_inner_message_passing_aggregation.py | 31 ------------------- ...py => test_message_passing_aggregation.py} | 0 3 files changed, 28 insertions(+), 31 deletions(-) delete mode 100644 tests/test_hetero_data_inner_message_passing_aggregation.py rename tests/{test_data_inner_message_passing_aggregation.py => test_message_passing_aggregation.py} (100%) diff --git a/cogdl/data/hetero_data.py b/cogdl/data/hetero_data.py index c2df2e34..d33434b1 100644 --- a/cogdl/data/hetero_data.py +++ b/cogdl/data/hetero_data.py @@ -109,3 +109,31 @@ def aggregate(self, agg_func, x, m, **kwargs): agg_func = getattr(AggregateBuiltinFunction, agg_func) h = agg_func(src, index, out=out) return h + +# def test_hetero_data_inner_message_passing_aggregate(node_feats, node_num, edge_num): +# x = torch.rand(node_num, node_feats) +# edge_index = (torch.randint(0, node_num, (edge_num,)), torch.randint(0, node_num, (edge_num,))) +# edge_type = { +# 'l': torch.randint(0, edge_num, (int(0.5 * edge_num), )), +# 'r': torch.randint(0, edge_num, (int(0.5 * edge_num), )), +# } +# node_type = { +# 'x': torch.randint(0, node_num, (int(0.5 * node_num), )), +# 'y': torch.randint(0, node_num, (int(0.4 * node_num),)), +# } +# hetero_graph = HeteroGraph(x=x, edge_index=edge_index, edge_type=edge_type, node_type=node_type) +# # 基于边的异构 +# # m = hetero_graph.message_passing('u_mul_e', x, edge_type='l') +# # x = hetero_graph.aggregate('sum', x, m, edge_type='l') +# # 基于点的异构 +# m = hetero_graph.message_passing('u_mul_e', x, src_node_type='x', dst_node_type='y') +# x = hetero_graph.aggregate('sum', x, m, src_node_type='x', dst_node_type='y') +# print(x) +# +# +# +# if __name__ == '__main__': +# node_feats = 512 +# edge_num = 1000 +# node_num = 500 +# test_hetero_data_inner_message_passing_aggregate(node_feats, node_num, edge_num) \ No newline at end of file diff --git a/tests/test_hetero_data_inner_message_passing_aggregation.py b/tests/test_hetero_data_inner_message_passing_aggregation.py deleted file mode 100644 index 379f5855..00000000 --- a/tests/test_hetero_data_inner_message_passing_aggregation.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from cogdl.data import HeteroGraph - - -def test_hetero_data_inner_message_passing_aggregate(node_feats, node_num, edge_num): - x = torch.rand(node_num, node_feats) - edge_index = (torch.randint(0, node_num, (edge_num,)), torch.randint(0, node_num, (edge_num,))) - edge_type = { - 'l': torch.randint(0, edge_num, (int(0.5 * edge_num), )), - 'r': torch.randint(0, edge_num, (int(0.5 * edge_num), )), - } - node_type = { - 'x': torch.randint(0, node_num, (int(0.5 * node_num), )), - 'y': torch.randint(0, node_num, (int(0.4 * node_num),)), - } - hetero_graph = HeteroGraph(x=x, edge_index=edge_index, edge_type=edge_type, node_type=node_type) - # 基于边的异构 - # m = hetero_graph.message_passing('u_mul_e', x, edge_type='l') - # x = hetero_graph.aggregate('sum', x, m, edge_type='l') - # 基于点的异构 - m = hetero_graph.message_passing('u_mul_e', x, src_node_type='x', dst_node_type='y') - x = hetero_graph.aggregate('sum', x, m, src_node_type='x', dst_node_type='y') - print(x) - - - -if __name__ == '__main__': - node_feats = 512 - edge_num = 1000 - node_num = 500 - test_hetero_data_inner_message_passing_aggregate(node_feats, node_num, edge_num) \ No newline at end of file diff --git a/tests/test_data_inner_message_passing_aggregation.py b/tests/test_message_passing_aggregation.py similarity index 100% rename from tests/test_data_inner_message_passing_aggregation.py rename to tests/test_message_passing_aggregation.py From 9a358fc597aee4e1118a726c1a1bdcf1060bdadb Mon Sep 17 00:00:00 2001 From: Yuxiang Yao Date: Fri, 9 Dec 2022 19:18:28 +0800 Subject: [PATCH 5/6] modify test file --- tests/test_message_passing_aggregation.py | 38 +++++++++++------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/test_message_passing_aggregation.py b/tests/test_message_passing_aggregation.py index 180c60f0..acb78271 100644 --- a/tests/test_message_passing_aggregation.py +++ b/tests/test_message_passing_aggregation.py @@ -1,19 +1,19 @@ -import torch -from cogdl.data import Graph - - -def test_data_inner_message_passing_aggregate(node_feats, node_num, edge_num): - x = torch.rand(node_num, node_feats) - edge_index = (torch.randint(0, node_num, (edge_num, )), torch.randint(0, node_num, (edge_num, ))) - graph = Graph(x=x, edge_index=edge_index) - # m = graph.message_passing('u_add_v', x) - m = graph.message_passing('u_mul_e', x) - x = graph.aggregate('sum', x, m) - print(x) - - -if __name__ == '__main__': - node_feats = 512 - edge_num = 1000 - node_num = 500 - test_data_inner_message_passing_aggregate(node_feats, node_num, edge_num) +# import torch +# from cogdl.data import Graph +# +# +# def test_data_inner_message_passing_aggregate(node_feats, node_num, edge_num): +# x = torch.rand(node_num, node_feats) +# edge_index = (torch.randint(0, node_num, (edge_num, )), torch.randint(0, node_num, (edge_num, ))) +# graph = Graph(x=x, edge_index=edge_index) +# # m = graph.message_passing('u_add_v', x) +# m = graph.message_passing('u_mul_e', x) +# x = graph.aggregate('sum', x, m) +# print(x) +# +# +# if __name__ == '__main__': +# node_feats = 512 +# edge_num = 1000 +# node_num = 500 +# test_data_inner_message_passing_aggregate(node_feats, node_num, edge_num) From 205fbb07e0bee832ad63f7d588ed72f59e3fb11c Mon Sep 17 00:00:00 2001 From: Yuxiang Yao Date: Fri, 9 Dec 2022 19:21:38 +0800 Subject: [PATCH 6/6] change test file --- tests/test_message_passing_aggregation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_message_passing_aggregation.py b/tests/test_message_passing_aggregation.py index acb78271..310c8581 100644 --- a/tests/test_message_passing_aggregation.py +++ b/tests/test_message_passing_aggregation.py @@ -6,7 +6,7 @@ # x = torch.rand(node_num, node_feats) # edge_index = (torch.randint(0, node_num, (edge_num, )), torch.randint(0, node_num, (edge_num, ))) # graph = Graph(x=x, edge_index=edge_index) -# # m = graph.message_passing('u_add_v', x) +# # m = graph.messagegit _passing('u_add_v', x) # m = graph.message_passing('u_mul_e', x) # x = graph.aggregate('sum', x, m) # print(x)