From 2b0bd9ae1ee8597900dc164fd419fc505760db65 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Thu, 11 Aug 2022 14:14:20 +0000 Subject: [PATCH] change api attribute name, move pool_type to reduce_op, move compute_type to message_op --- paddle/fluid/operators/graph_send_recv_op.cc | 12 ++-- .../fluid/operators/graph_send_ue_recv_op.cc | 16 ++--- paddle/phi/api/yaml/legacy_api.yaml | 4 +- paddle/phi/api/yaml/legacy_backward.yaml | 8 +-- paddle/phi/infermeta/multiary.cc | 6 +- paddle/phi/infermeta/multiary.h | 4 +- paddle/phi/infermeta/ternary.cc | 4 +- paddle/phi/infermeta/ternary.h | 2 +- .../cpu/graph_send_recv_grad_kernel.cc | 28 ++++---- .../phi/kernels/cpu/graph_send_recv_kernel.cc | 32 ++++----- .../cpu/graph_send_ue_recv_grad_kernel.cc | 72 +++++++++---------- .../kernels/cpu/graph_send_ue_recv_kernel.cc | 36 +++++----- .../gpu/graph_send_recv_grad_kernel.cu | 14 ++-- .../phi/kernels/gpu/graph_send_recv_kernel.cu | 22 +++--- .../gpu/graph_send_ue_recv_grad_kernel.cu | 72 +++++++++---------- .../kernels/gpu/graph_send_ue_recv_kernel.cu | 42 +++++------ .../phi/kernels/graph_send_recv_grad_kernel.h | 2 +- paddle/phi/kernels/graph_send_recv_kernel.h | 2 +- .../kernels/graph_send_ue_recv_grad_kernel.h | 4 +- .../phi/kernels/graph_send_ue_recv_kernel.h | 4 +- paddle/phi/ops/compat/graph_send_recv_sig.cc | 6 +- .../phi/ops/compat/graph_send_ue_recv_sig.cc | 6 +- .../unittests/test_graph_send_recv_op.py | 24 +++---- .../unittests/test_graph_send_ue_recv_op.py | 16 ++--- .../geometric/message_passing/send_recv.py | 18 ++--- .../incubate/operators/graph_send_recv.py | 6 +- 26 files changed, 231 insertions(+), 231 deletions(-) diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc index e9ba861c3b88b..b954ecab704b4 100644 --- a/paddle/fluid/operators/graph_send_recv_op.cc +++ b/paddle/fluid/operators/graph_send_recv_op.cc @@ -64,9 +64,9 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddOutput("Out", "Output tensor of graph_send_recv op."); AddOutput("Dst_count", - "Count tensor of Dst_index, mainly for MEAN pool_type.") + "Count tensor of Dst_index, mainly for MEAN reduce_op.") .AsIntermediate(); - AddAttr("pool_type", + AddAttr("reduce_op", "(string, default 'SUM')" "Define different pool types to receive the result " "tensors of Dst_index.") @@ -81,7 +81,7 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Graph Learning Send_Recv combine operator. -$Out = Recv(Send(X, Src_index), Dst_index, pool_type)$ +$Out = Recv(Send(X, Src_index), Dst_index, reduce_op)$ This operator is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory consumption in the process of message passing. @@ -105,12 +105,12 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("Dst_index", this->Input("Dst_index")); op->SetInput("X", this->Input("X")); - if (PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") { + if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MEAN") { op->SetInput("Dst_count", this->Output("Dst_count")); } - if (PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" || - PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") { + if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MIN" || + PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MAX") { op->SetInput("Out", this->Output("Out")); } diff --git a/paddle/fluid/operators/graph_send_ue_recv_op.cc b/paddle/fluid/operators/graph_send_ue_recv_op.cc index 696b2656a7052..af16609df3ebd 100644 --- a/paddle/fluid/operators/graph_send_ue_recv_op.cc +++ b/paddle/fluid/operators/graph_send_ue_recv_op.cc @@ -68,14 +68,14 @@ class GraphSendUERecvOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddOutput("Out", "Output tensor of graph_send_ue_recv op."); AddOutput("Dst_count", - "Count tensor of Dst_index, mainly for MEAN pool_type.") + "Count tensor of Dst_index, mainly for MEAN reduce_op.") .AsIntermediate(); - AddAttr("compute_type", + AddAttr("message_op", "(string, default 'ADD')" "Define differenct computation types between X and E.") .SetDefault("ADD") .InEnum({"ADD", "MUL"}); - AddAttr("pool_type", + AddAttr("reduce_op", "(string, default 'SUM')" "Define different pool types to receive the result " "tensors of Dst_index.") @@ -90,13 +90,13 @@ class GraphSendUERecvOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Graph Learning Send_UE_Recv combine operator. -$Out = Recv(Compute(Send(X, Src_index), Y, compute_type), Dst_index, pool_type)$ +$Out = Recv(Compute(Send(X, Src_index), Y, message_op), Dst_index, reduce_op)$ This operator is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory consumption in the process of message passing. Take `X` as the input tensor, we first use `src_index` to gather corresponding data. -Then the gather data should compute with `Y` in different compute_types, like add, sub, mul, and div, +Then the gather data should compute with `Y` in different message_ops, like add, sub, mul, and div, and get the computation result. Then, use `dst_index` to update the corresponding position of output tensor in different pooling types, like sum, mean, max, or min. @@ -117,12 +117,12 @@ class GraphSendUERecvGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("Src_index", this->Input("Src_index")); op->SetInput("Dst_index", this->Input("Dst_index")); - if (PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") { + if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MEAN") { op->SetInput("Dst_count", this->Output("Dst_count")); } - if (PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" || - PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") { + if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MIN" || + PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MAX") { op->SetInput("Out", this->Output("Out")); } diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index c7c2f6f0152f9..4f2ea4a6b6655 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1060,7 +1060,7 @@ func : generate_proposals_v2 - api : graph_send_recv - args : (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", IntArray out_size = {0}) + args : (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0}) output : Tensor(out), Tensor(dst_count) infer_meta : func : GraphSendRecvInferMeta @@ -1071,7 +1071,7 @@ backward : graph_send_recv_grad - api : graph_send_ue_recv - args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str compute_type, str pool_type, IntArray out_size) + args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op, str reduce_op, IntArray out_size) output : Tensor(out), Tensor(dst_count) infer_meta : func : GraphSendUERecvInferMeta diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 31ddba838f512..ba9f306faecc4 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -941,8 +941,8 @@ func : gelu_grad - backward_api : graph_send_recv_grad - forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count) - args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str pool_type = "SUM") + forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count) + args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str reduce_op = "SUM") output : Tensor(x_grad) infer_meta : func : GeneralUnaryGradInferMeta @@ -953,8 +953,8 @@ optional: out, dst_count - backward_api : graph_send_ue_recv_grad - forward : graph_send_ue_recv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str compute_type, str pool_type, IntArray out_size) -> Tensor(out), Tensor(dst_count) - args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str compute_type, str pool_type) + forward : graph_send_ue_recv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op, str reduce_op, IntArray out_size) -> Tensor(out), Tensor(dst_count) + args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str message_op, str reduce_op) output : Tensor(x_grad), Tensor(y_grad) infer_meta : func : GeneralBinaryGradInferMeta diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 53076d1a5d127..7ccd52bb6ff39 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2603,8 +2603,8 @@ void GraphSendUERecvInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& src_index, const MetaTensor& dst_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, const IntArray& out_size, MetaTensor* out, MetaTensor* dst_count) { @@ -2658,7 +2658,7 @@ void GraphSendUERecvInferMeta(const MetaTensor& x, y_dims[0])); auto x_dims = x.dims(); - if (pool_type == "MEAN") { + if (reduce_op == "MEAN") { dst_count->set_dims({-1}); dst_count->set_dtype(DataType::INT32); } diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 66d8ad84a4378..660121b844d10 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -470,8 +470,8 @@ void GraphSendUERecvInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& src_index, const MetaTensor& dst_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, const IntArray& out_size, MetaTensor* out, MetaTensor* dst_count); diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index a919a955a541a..342c9e4602309 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -411,7 +411,7 @@ void InstanceNormInferMeta(const MetaTensor& x, void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, const IntArray& out_size, MetaTensor* out, MetaTensor* dst_count) { @@ -460,7 +460,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, out->set_dims(phi::make_ddim(dims_)); out->set_dtype(x.dtype()); - if (pool_type == "MEAN") { + if (reduce_op == "MEAN") { dst_count->set_dims({-1}); dst_count->set_dtype(DataType::INT32); } diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 466bd3df5de2d..5314b8f45affe 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -75,7 +75,7 @@ void InstanceNormInferMeta(const MetaTensor& x, void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, const IntArray& out_size, MetaTensor* out, MetaTensor* dst_count); diff --git a/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc index ad04bd258e141..d4131a1ffb5e3 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc @@ -29,10 +29,10 @@ void GraphSendRecvCpuGradLoop(const int& index_size, const DenseTensor& src, const DenseTensor& input, DenseTensor* dst, - const std::string& pool_type, + const std::string& reduce_op, const int* dst_count = nullptr, const DenseTensor* output = nullptr) { - if (pool_type == "SUM") { + if (reduce_op == "SUM") { Functor functor; for (int i = 0; i < index_size; ++i) { const IndexT& src_idx = s_index[i]; @@ -40,7 +40,7 @@ void GraphSendRecvCpuGradLoop(const int& index_size, ElementwiseInnerOperation( src, dst, src_idx, dst_idx, false, functor); } - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { for (int i = 0; i < index_size; ++i) { const IndexT& src_idx = s_index[i]; const IndexT& dst_idx = d_index[i]; @@ -50,7 +50,7 @@ void GraphSendRecvCpuGradLoop(const int& index_size, auto eigen_dst = phi::EigenVector::Flatten(dst_slice); eigen_dst += (eigen_src / static_cast(dst_count[src_idx])); } - } else if (pool_type == "MIN" || pool_type == "MAX") { + } else if (reduce_op == "MIN" || reduce_op == "MAX") { for (int i = 0; i < index_size; ++i) { const IndexT& forward_src_idx = d_index[i]; const IndexT& forward_dst_idx = s_index[i]; @@ -75,7 +75,7 @@ void GraphSendRecvGradOpKernelLaunchHelper( const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, DenseTensor* x_grad, const DenseTensor* dst_count = nullptr, const DenseTensor* out = nullptr) { @@ -94,15 +94,15 @@ void GraphSendRecvGradOpKernelLaunchHelper( const IndexT* s_index = src_index.data(); const IndexT* d_index = dst_index.data(); - if (pool_type == "SUM") { + if (reduce_op == "SUM") { GraphSendRecvCpuGradLoop>( - index_size, d_index, s_index, out_grad, x, x_grad, pool_type); - } else if (pool_type == "MEAN") { + index_size, d_index, s_index, out_grad, x, x_grad, reduce_op); + } else if (reduce_op == "MEAN") { const int* s_count = dst_count->data(); // Functor not used here. GraphSendRecvCpuGradLoop>( - index_size, d_index, s_index, out_grad, x, x_grad, pool_type, s_count); - } else if (pool_type == "MIN" || pool_type == "MAX") { + index_size, d_index, s_index, out_grad, x, x_grad, reduce_op, s_count); + } else if (reduce_op == "MIN" || reduce_op == "MAX") { // Functor not used here. GraphSendRecvCpuGradLoop>(index_size, d_index, @@ -110,7 +110,7 @@ void GraphSendRecvGradOpKernelLaunchHelper( out_grad, x, x_grad, - pool_type, + reduce_op, nullptr, out); } @@ -124,7 +124,7 @@ void GraphSendRecvGradKernel(const Context& ctx, const paddle::optional& out, const paddle::optional& dst_count, const DenseTensor& out_grad, - const std::string& pool_type, + const std::string& reduce_op, DenseTensor* x_grad) { auto index_type = src_index.dtype(); if (index_type == phi::DataType::INT32) { @@ -134,7 +134,7 @@ void GraphSendRecvGradKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, x_grad, dst_count.get_ptr(), out.get_ptr()); @@ -145,7 +145,7 @@ void GraphSendRecvGradKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, x_grad, dst_count.get_ptr(), out.get_ptr()); diff --git a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc index d4b9c8c60e3f8..7985a65a20053 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc @@ -32,17 +32,17 @@ void GraphSendRecvCpuLoop(const int& input_size, const IndexT* d_index, const DenseTensor& src, DenseTensor* dst, - const std::string& pool_type, + const std::string& reduce_op, int* dst_count = nullptr) { Functor functor; - if (pool_type == "SUM") { + if (reduce_op == "SUM") { for (int i = 0; i < index_size; ++i) { const IndexT& src_idx = s_index[i]; const IndexT& dst_idx = d_index[i]; ElementwiseInnerOperation( src, dst, src_idx, dst_idx, false, functor); } - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { for (int i = 0; i < index_size; ++i) { const IndexT& src_idx = s_index[i]; const IndexT& dst_idx = d_index[i]; @@ -59,7 +59,7 @@ void GraphSendRecvCpuLoop(const int& input_size, auto eigen_dst = phi::EigenVector::Flatten(dst_slice); eigen_dst = eigen_dst / static_cast(*(dst_count + i)); } - } else if (pool_type == "MIN" || pool_type == "MAX") { + } else if (reduce_op == "MIN" || reduce_op == "MAX") { std::set existed_dst; for (int i = 0; i < index_size; ++i) { const IndexT& src_idx = s_index[i]; @@ -82,7 +82,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, int64_t out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { @@ -117,16 +117,16 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, const IndexT* s_index = src_index.data(); const IndexT* d_index = dst_index.data(); - if (pool_type == "SUM") { + if (reduce_op == "SUM") { GraphSendRecvCpuLoop>( - src_dims[0], index_size, s_index, d_index, x, out, pool_type); - } else if (pool_type == "MIN") { + src_dims[0], index_size, s_index, d_index, x, out, reduce_op); + } else if (reduce_op == "MIN") { GraphSendRecvCpuLoop>( - src_dims[0], index_size, s_index, d_index, x, out, pool_type); - } else if (pool_type == "MAX") { + src_dims[0], index_size, s_index, d_index, x, out, reduce_op); + } else if (reduce_op == "MAX") { GraphSendRecvCpuLoop>( - src_dims[0], index_size, s_index, d_index, x, out, pool_type); - } else if (pool_type == "MEAN") { + src_dims[0], index_size, s_index, d_index, x, out, reduce_op); + } else if (reduce_op == "MEAN") { int64_t input_size = out_size <= 0 ? src_dims[0] : out_size; dst_count->Resize({input_size}); ctx.template Alloc(dst_count); @@ -138,7 +138,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, d_index, x, out, - pool_type, + reduce_op, p_dst_count); } } @@ -148,7 +148,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, const IntArray& out_size, DenseTensor* out, DenseTensor* dst_count) { @@ -159,7 +159,7 @@ void GraphSendRecvKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, out_size_data[0], out, dst_count); @@ -168,7 +168,7 @@ void GraphSendRecvKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, out_size_data[0], out, dst_count); diff --git a/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc index c3ae8563370f8..95fdc6ff0a9cc 100644 --- a/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc @@ -39,8 +39,8 @@ void CalculateXGrad(const Context& ctx, const phi::DDim& e_dims, const IndexT* s_index, const IndexT* d_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, int64_t index_size, T* x_grad, const DenseTensor& out_grad_tensor, @@ -50,8 +50,8 @@ void CalculateXGrad(const Context& ctx, std::vector reduce_idx; bool reduce = ReduceGrad(out_grad_dims, x_dims, reduce_idx); - if (pool_type == "SUM") { - if (compute_type == "ADD") { + if (reduce_op == "SUM") { + if (message_op == "ADD") { GraphSendRecvSumFunctor sum_functor; if (!reduce) { for (int64_t i = 0; i < index_size; i++) { @@ -78,7 +78,7 @@ void CalculateXGrad(const Context& ctx, true); memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); } - } else if (compute_type == "MUL") { + } else if (message_op == "MUL") { const auto& bcast = phi::CalcBCastInfo(out_grad_dims, e_dims); if (!reduce) { #ifdef PADDLE_WITH_MKLML @@ -137,9 +137,9 @@ void CalculateXGrad(const Context& ctx, memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); } } - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { const int* s_count = dst_count->data(); - if (compute_type == "ADD") { + if (message_op == "ADD") { if (!reduce) { for (int64_t i = 0; i < index_size; i++) { IndexT src = s_index[i]; @@ -171,7 +171,7 @@ void CalculateXGrad(const Context& ctx, true); memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); } - } else if (compute_type == "MUL") { + } else if (message_op == "MUL") { const auto& bcast = phi::CalcBCastInfo(out_grad_dims, e_dims); if (!reduce) { #ifdef PADDLE_WITH_MKLML @@ -237,13 +237,13 @@ void CalculateEGrad(const T* out_grad_data, const phi::DDim& e_dims, const IndexT* s_index, const IndexT* d_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, int64_t index_size, T* e_grad, const DenseTensor* dst_count = nullptr) { const auto& bcast = phi::CalcBCastInfo(x_dims, e_dims); - if (pool_type == "SUM") { + if (reduce_op == "SUM") { #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif @@ -256,12 +256,12 @@ void CalculateEGrad(const T* out_grad_data, for (int64_t j = 0; j < bcast.out_len; j++) { int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j; int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j; - if (compute_type == "ADD") { + if (message_op == "ADD") { #ifdef PADDLE_WITH_MKLML #pragma omp atomic #endif e_grad_off[e_add] += out_grad_off[j]; - } else if (compute_type == "MUL") { + } else if (message_op == "MUL") { #ifdef PADDLE_WITH_MKLML #pragma omp atomic #endif @@ -269,7 +269,7 @@ void CalculateEGrad(const T* out_grad_data, } } } - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { const int* s_count = dst_count->data(); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for @@ -283,12 +283,12 @@ void CalculateEGrad(const T* out_grad_data, for (int64_t j = 0; j < bcast.out_len; j++) { int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j; int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j; - if (compute_type == "ADD") { + if (message_op == "ADD") { #ifdef PADDLE_WITH_MKLML #pragma omp atomic #endif e_grad_off[e_add] += (out_grad_off[j] / s_count[dst]); - } else if (compute_type == "MUL") { + } else if (message_op == "MUL") { #ifdef PADDLE_WITH_MKLML #pragma omp atomic #endif @@ -307,8 +307,8 @@ void CalculateXEGradForMinMax(const T* out_grad, const phi::DDim& e_dims, const IndexT* s_index, const IndexT* d_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, int64_t index_size, T* x_grad, T* e_grad, @@ -330,14 +330,14 @@ void CalculateXEGradForMinMax(const T* out_grad, for (int64_t j = 0; j < bcast.out_len; j++) { int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j; int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j; - if (compute_type == "ADD") { + if (message_op == "ADD") { T val = x_off[x_add] + e_off[e_add]; #ifdef PADDLE_WITH_MKLML #pragma omp critical #endif x_grad_off[x_add] += (out_grad_off[j] * (val == out_off[j])); e_grad_off[e_add] += (out_grad_off[j] * (val == out_off[j])); - } else if (compute_type == "MUL") { + } else if (message_op == "MUL") { T val = x_off[x_add] * e_off[e_add]; #ifdef PADDLE_WITH_MKLML #pragma omp critical @@ -359,8 +359,8 @@ void GraphSendUERecvGradOpKernelLaunchHelper( const DenseTensor& y, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, DenseTensor* x_grad, DenseTensor* y_grad, const DenseTensor* dst_count = nullptr, @@ -395,7 +395,7 @@ void GraphSendUERecvGradOpKernelLaunchHelper( const IndexT* s_index = src_index.data(); const IndexT* d_index = dst_index.data(); - if (pool_type == "SUM" || pool_type == "MEAN") { + if (reduce_op == "SUM" || reduce_op == "MEAN") { CalculateXGrad(ctx, out_grad_data, x_data, @@ -405,8 +405,8 @@ void GraphSendUERecvGradOpKernelLaunchHelper( y_dims, d_index, s_index, - compute_type, - pool_type, + message_op, + reduce_op, index_size, x_grad_data, out_grad, @@ -420,12 +420,12 @@ void GraphSendUERecvGradOpKernelLaunchHelper( y_dims, s_index, d_index, - compute_type, - pool_type, + message_op, + reduce_op, index_size, y_grad_data, dst_count); - } else if (pool_type == "MIN" || pool_type == "MAX") { + } else if (reduce_op == "MIN" || reduce_op == "MAX") { CalculateXEGradForMinMax(out_grad_data, x_data, y_data, @@ -433,8 +433,8 @@ void GraphSendUERecvGradOpKernelLaunchHelper( y_dims, d_index, s_index, - compute_type, - pool_type, + message_op, + reduce_op, index_size, x_grad_data, y_grad_data, @@ -451,8 +451,8 @@ void GraphSendUERecvGradKernel(const Context& ctx, const paddle::optional& out, const paddle::optional& dst_count, const DenseTensor& out_grad, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, DenseTensor* x_grad, DenseTensor* y_grad) { auto index_type = src_index.dtype(); @@ -464,8 +464,8 @@ void GraphSendUERecvGradKernel(const Context& ctx, y, src_index, dst_index, - compute_type, - pool_type, + message_op, + reduce_op, x_grad, y_grad, dst_count.get_ptr(), @@ -478,8 +478,8 @@ void GraphSendUERecvGradKernel(const Context& ctx, y, src_index, dst_index, - compute_type, - pool_type, + message_op, + reduce_op, x_grad, y_grad, dst_count.get_ptr(), diff --git a/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc index 5c3760657be86..74fca002294db 100644 --- a/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc @@ -110,8 +110,8 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, const DenseTensor& y, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, int64_t out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { @@ -140,8 +140,8 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, const T* y_data = y.data(); const IndexT* s_index = src_index.data(); const IndexT* d_index = dst_index.data(); - if (pool_type == "SUM" || pool_type == "MEAN") { - if (compute_type == "ADD") { + if (reduce_op == "SUM" || reduce_op == "MEAN") { + if (message_op == "ADD") { GraphAddFunctor add_functor; GraphSendUERecvSumCpuKernel>(bcast_info, x_data, @@ -151,7 +151,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, out_data, index_size, add_functor); - } else if (compute_type == "MUL") { + } else if (message_op == "MUL") { GraphMulFunctor mul_functor; GraphSendUERecvSumCpuKernel>(bcast_info, x_data, @@ -162,7 +162,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, index_size, mul_functor); } - if (pool_type == "MEAN") { + if (reduce_op == "MEAN") { int64_t input_size = out_size <= 0 ? x.dims()[0] : out_size; dst_count->Resize({input_size}); int* dst_count_data = ctx.template Alloc(dst_count); @@ -178,9 +178,9 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, eigen_out = eigen_out / static_cast(dst_count_data[i]); } } - } else if (pool_type == "MIN") { + } else if (reduce_op == "MIN") { GraphMinFunctor min_functor; - if (compute_type == "ADD") { + if (message_op == "ADD") { GraphAddFunctor add_functor; GraphSendUERecvMinMaxCpuKernel mul_functor; GraphSendUERecvMinMaxCpuKernel max_functor; - if (compute_type == "ADD") { + if (message_op == "ADD") { GraphAddFunctor add_functor; GraphSendUERecvMinMaxCpuKernel mul_functor; GraphSendUERecvMinMaxCpuKernel functor; GraphSendRecvCUDAKernel> <<>>( p_src, d_index, s_index, p_output, index_size, slice_size, functor); - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { const int32_t* s_count = dst_count->data(); ManipulateMeanGradCUDAKernel<<>>( p_src, d_index, s_index, p_output, index_size, slice_size, s_count); - } else if (pool_type == "MAX" || pool_type == "MIN") { + } else if (reduce_op == "MAX" || reduce_op == "MIN") { const T* ptr_input = x.data(); const T* ptr_output = out->data(); ManipulateMinMaxGradCUDAKernel @@ -105,7 +105,7 @@ void GraphSendRecvGradKernel(const Context& ctx, const paddle::optional& out, const paddle::optional& dst_count, const DenseTensor& out_grad, - const std::string& pool_type, + const std::string& reduce_op, DenseTensor* x_grad) { auto index_type = src_index.dtype(); if (index_type == phi::DataType::INT32) { @@ -115,7 +115,7 @@ void GraphSendRecvGradKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, x_grad, dst_count.get_ptr(), out.get_ptr()); @@ -126,7 +126,7 @@ void GraphSendRecvGradKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, x_grad, dst_count.get_ptr(), out.get_ptr()); diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index e696960f800d0..055d4888e3f56 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -32,7 +32,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, int64_t out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { @@ -59,19 +59,19 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ctx.template Alloc(out); T* p_output = out->data(); const size_t& memset_bytes = memset_size * sizeof(T); - if (pool_type == "SUM" || pool_type == "MEAN") { + if (reduce_op == "SUM" || reduce_op == "MEAN") { #ifdef PADDLE_WITH_HIP hipMemset(p_output, 0, memset_bytes); #else cudaMemset(p_output, 0, memset_bytes); #endif - } else if (pool_type == "MAX") { + } else if (reduce_op == "MAX") { thrust::device_ptr p_output_ptr(p_output); thrust::fill(thrust::device, p_output_ptr, p_output_ptr + memset_size, std::numeric_limits::lowest()); - } else if (pool_type == "MIN") { + } else if (reduce_op == "MIN") { thrust::device_ptr p_output_ptr(p_output); thrust::fill(thrust::device, p_output_ptr, @@ -99,12 +99,12 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, int64_t grid_tmp = (n + block - 1) / block; int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; int64_t input_size = out_size <= 0 ? src_dims[0] : out_size; - if (pool_type == "SUM") { + if (reduce_op == "SUM") { GraphSendRecvSumCUDAFunctor functor; GraphSendRecvCUDAKernel> <<>>( p_src, s_index, d_index, p_output, index_size, slice_size, functor); - } else if (pool_type == "MAX") { + } else if (reduce_op == "MAX") { GraphSendRecvMaxCUDAFunctor functor; GraphSendRecvCUDAKernel> <<>>( @@ -115,7 +115,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx; InputResetMaxCUDAKernel<<>>( p_output, input_size, slice_size); - } else if (pool_type == "MIN") { + } else if (reduce_op == "MIN") { GraphSendRecvMinCUDAFunctor functor; GraphSendRecvCUDAKernel> <<>>( @@ -126,7 +126,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx; InputResetMinCUDAKernel<<>>( p_output, input_size, slice_size); - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { GraphSendRecvSumCUDAFunctor functor; GraphSendRecvCUDAKernel> <<>>( @@ -158,7 +158,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, const IntArray& out_size, DenseTensor* out, DenseTensor* dst_count) { @@ -169,7 +169,7 @@ void GraphSendRecvKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, out_size_data[0], out, dst_count); @@ -178,7 +178,7 @@ void GraphSendRecvKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, out_size_data[0], out, dst_count); diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu index 7d89a1bc7d82e..cb3d5591a7be6 100644 --- a/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu @@ -35,8 +35,8 @@ void CalculateXEGradForMinMax(const Context& ctx, const phi::DDim& e_dims, const IndexT* s_index, const IndexT* d_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, int64_t index_size, T* x_grad, T* e_grad, @@ -56,7 +56,7 @@ void CalculateXEGradForMinMax(const Context& ctx, const dim3 grid(nbx, nby); const dim3 block(ntx, nty); - if (compute_type == "ADD") { + if (message_op == "ADD") { ManipulateMinMaxGradCUDAKernelForAdd <<>>( x_data, @@ -74,7 +74,7 @@ void CalculateXEGradForMinMax(const Context& ctx, bcast_info.r_len, out_len, bcast_info.use_bcast); - } else if (compute_type == "MUL") { + } else if (message_op == "MUL") { ManipulateMinMaxGradCUDAKernelForMul <<>>( x_data, @@ -105,8 +105,8 @@ void CalculateXGrad(const Context& ctx, const phi::DDim& e_dims, const IndexT* s_index, const IndexT* d_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, int64_t index_size, int64_t slice_size, T* x_grad, @@ -124,8 +124,8 @@ void CalculateXGrad(const Context& ctx, int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; std::vector reduce_idx; bool reduce = ReduceGrad(out_grad_dims, x_dims, reduce_idx); - if (pool_type == "SUM") { - if (compute_type == "ADD") { + if (reduce_op == "SUM") { + if (message_op == "ADD") { GraphSendRecvSumCUDAFunctor functor; if (!reduce) { GraphSendRecvCUDAKernel l_bcastoff, r_bcastoff; if (bcast_info.use_bcast) { @@ -251,9 +251,9 @@ void CalculateXGrad(const Context& ctx, #endif } } - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { const int* s_count = dst_count->data(); - if (compute_type == "ADD") { + if (message_op == "ADD") { if (!reduce) { ManipulateMeanGradCUDAKernel <<>>(out_grad, @@ -296,7 +296,7 @@ void CalculateXGrad(const Context& ctx, cudaMemcpyDeviceToDevice); #endif } - } else if (compute_type == "MUL") { + } else if (message_op == "MUL") { const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, e_dims); thrust::device_vector l_bcastoff, r_bcastoff; if (bcast_info.use_bcast) { @@ -378,8 +378,8 @@ void CalculateEGrad(const Context& ctx, const phi::DDim& e_dims, const IndexT* s_index, const IndexT* d_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, int64_t index_size, T* e_grad, const DenseTensor* dst_count = nullptr) { @@ -395,8 +395,8 @@ void CalculateEGrad(const Context& ctx, const int nby = (index_size + nty - 1) / nty; const dim3 grid(nbx, nby); const dim3 block(ntx, nty); - if (pool_type == "SUM") { - if (compute_type == "ADD") { + if (reduce_op == "SUM") { + if (message_op == "ADD") { ManipulateSumGradCUDAKernelForAddE <<>>( out_grad, @@ -407,7 +407,7 @@ void CalculateEGrad(const Context& ctx, bcast_info.r_len, out_len, bcast_info.use_bcast); - } else if (compute_type == "MUL") { + } else if (message_op == "MUL") { ManipulateSumGradCUDAKernelForMulE <<>>( x_data, @@ -423,9 +423,9 @@ void CalculateEGrad(const Context& ctx, out_len, bcast_info.use_bcast); } - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { const int* s_count = dst_count->data(); - if (compute_type == "ADD") { + if (message_op == "ADD") { ManipulateMeanGradCUDAKernelForAddE <<>>( out_grad, @@ -437,7 +437,7 @@ void CalculateEGrad(const Context& ctx, bcast_info.r_len, out_len, bcast_info.use_bcast); - } else if (compute_type == "MUL") { + } else if (message_op == "MUL") { ManipulateMeanGradCUDAKernelForMulE <<>>( x_data, @@ -465,8 +465,8 @@ void GraphSendUERecvGradOpCUDAKernelLaunchHelper( const DenseTensor& e, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, DenseTensor* x_grad, DenseTensor* e_grad, const DenseTensor* dst_count = nullptr, @@ -506,7 +506,7 @@ void GraphSendUERecvGradOpCUDAKernelLaunchHelper( const IndexT* s_index = src_index.data(); const IndexT* d_index = dst_index.data(); - if (pool_type == "SUM" || pool_type == "MEAN") { + if (reduce_op == "SUM" || reduce_op == "MEAN") { CalculateXGrad(ctx, out_grad_data, x_data, @@ -516,8 +516,8 @@ void GraphSendUERecvGradOpCUDAKernelLaunchHelper( e_dims, s_index, d_index, - compute_type, - pool_type, + message_op, + reduce_op, index_size, slice_size, x_grad_data, @@ -532,12 +532,12 @@ void GraphSendUERecvGradOpCUDAKernelLaunchHelper( e_dims, s_index, d_index, - compute_type, - pool_type, + message_op, + reduce_op, index_size, e_grad_data, dst_count); - } else if (pool_type == "MIN" || pool_type == "MAX") { + } else if (reduce_op == "MIN" || reduce_op == "MAX") { CalculateXEGradForMinMax(ctx, out_grad_data, x_data, @@ -546,8 +546,8 @@ void GraphSendUERecvGradOpCUDAKernelLaunchHelper( e_dims, s_index, d_index, - compute_type, - pool_type, + message_op, + reduce_op, index_size, x_grad_data, e_grad_data, @@ -564,8 +564,8 @@ void GraphSendUERecvGradKernel(const Context& ctx, const paddle::optional& out, const paddle::optional& dst_count, const DenseTensor& out_grad, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, DenseTensor* x_grad, DenseTensor* y_grad) { auto index_type = src_index.dtype(); @@ -577,8 +577,8 @@ void GraphSendUERecvGradKernel(const Context& ctx, y, src_index, dst_index, - compute_type, - pool_type, + message_op, + reduce_op, x_grad, y_grad, dst_count.get_ptr(), @@ -591,8 +591,8 @@ void GraphSendUERecvGradKernel(const Context& ctx, y, src_index, dst_index, - compute_type, - pool_type, + message_op, + reduce_op, x_grad, y_grad, dst_count.get_ptr(), diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu index 28e304266dabd..f339387f0bbfc 100644 --- a/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu @@ -35,8 +35,8 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx, const DenseTensor& e, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, int64_t out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { @@ -57,20 +57,20 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx, ctx.template Alloc(out); T* out_data = out->data(); const size_t& memset_bytes = memset_size * sizeof(T); - if (pool_type == "SUM" || pool_type == "MEAN") { + if (reduce_op == "SUM" || reduce_op == "MEAN") { #ifdef PADDLE_WITH_HIP hipMemset(out_data, 0, memset_bytes); #else cudaMemset(out_data, 0, memset_bytes); #endif - } else if (pool_type == "MAX") { + } else if (reduce_op == "MAX") { thrust::device_ptr out_data_ptr(out_data); thrust::fill(thrust::device, out_data_ptr, out_data_ptr + memset_size, std::numeric_limits::lowest()); - } else if (pool_type == "MIN") { + } else if (reduce_op == "MIN") { thrust::device_ptr out_data_ptr(out_data); thrust::fill(thrust::device, out_data_ptr, @@ -104,9 +104,9 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx, #else int block_ = 1024; #endif - if (pool_type == "SUM" || pool_type == "MEAN") { + if (reduce_op == "SUM" || reduce_op == "MEAN") { GraphSendUERecvSumCUDAFunctor sum_functor; - if (compute_type == "ADD") { + if (message_op == "ADD") { funcs::AddFunctor add_funtor; GraphSendUERecvCUDAKernel mul_functor; GraphSendUERecvCUDAKernelResize({input_size}); ctx.template Alloc(dst_count); @@ -171,9 +171,9 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx, ManipulateMeanCUDAKernel<<>>( out_data, dst_count_data, input_size, out_len); } - } else if (pool_type == "MAX") { + } else if (reduce_op == "MAX") { GraphSendUERecvMaxCUDAFunctor max_functor; - if (compute_type == "ADD") { + if (message_op == "ADD") { funcs::AddFunctor add_funtor; GraphSendUERecvCUDAKernel mul_functor; GraphSendUERecvCUDAKernel <<>>(out_data, input_size, out_len); - } else if (pool_type == "MIN") { + } else if (reduce_op == "MIN") { GraphSendUERecvMinCUDAFunctor min_functor; - if (compute_type == "ADD") { + if (message_op == "ADD") { funcs::AddFunctor add_funtor; GraphSendUERecvCUDAKernel mul_functor; GraphSendUERecvCUDAKernel& out, const paddle::optional& dst_count, const DenseTensor& out_grad, - const std::string& pool_type, + const std::string& reduce_op, DenseTensor* x_grad); } // namespace phi diff --git a/paddle/phi/kernels/graph_send_recv_kernel.h b/paddle/phi/kernels/graph_send_recv_kernel.h index cd625c92b93ea..023e86064ff51 100644 --- a/paddle/phi/kernels/graph_send_recv_kernel.h +++ b/paddle/phi/kernels/graph_send_recv_kernel.h @@ -26,7 +26,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, const IntArray& out_size, DenseTensor* out, DenseTensor* dst_count); diff --git a/paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h b/paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h index f5c7ce9a8937e..74050d126259d 100644 --- a/paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h +++ b/paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h @@ -29,8 +29,8 @@ void GraphSendUERecvGradKernel(const Context& ctx, const paddle::optional& out, const paddle::optional& dst_count, const DenseTensor& out_grad, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, DenseTensor* x_grad, DenseTensor* y_grad); } // namespace phi diff --git a/paddle/phi/kernels/graph_send_ue_recv_kernel.h b/paddle/phi/kernels/graph_send_ue_recv_kernel.h index efb93ab47c93c..a308a78800f3a 100644 --- a/paddle/phi/kernels/graph_send_ue_recv_kernel.h +++ b/paddle/phi/kernels/graph_send_ue_recv_kernel.h @@ -26,8 +26,8 @@ void GraphSendUERecvKernel(const Context& ctx, const DenseTensor& y, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& compute_type, - const std::string& pool_type, + const std::string& message_op, + const std::string& reduce_op, const IntArray& out_size, DenseTensor* out, DenseTensor* dst_count); diff --git a/paddle/phi/ops/compat/graph_send_recv_sig.cc b/paddle/phi/ops/compat/graph_send_recv_sig.cc index c8c15619d5d39..0ca1a3fae0230 100644 --- a/paddle/phi/ops/compat/graph_send_recv_sig.cc +++ b/paddle/phi/ops/compat/graph_send_recv_sig.cc @@ -21,12 +21,12 @@ KernelSignature GraphSendRecvOpArgumentMapping( if (ctx.HasInput("Out_size")) { return KernelSignature("graph_send_recv", {"X", "Src_index", "Dst_index"}, - {"pool_type", "Out_size"}, + {"reduce_op", "Out_size"}, {"Out", "Dst_count"}); } else { return KernelSignature("graph_send_recv", {"X", "Src_index", "Dst_index"}, - {"pool_type", "out_size"}, + {"reduce_op", "out_size"}, {"Out", "Dst_count"}); } } @@ -36,7 +36,7 @@ KernelSignature GraphSendRecvGradOpArgumentMapping( return KernelSignature( "graph_send_recv_grad", {"X", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"}, - {"pool_type"}, + {"reduce_op"}, {"X@GRAD"}); } diff --git a/paddle/phi/ops/compat/graph_send_ue_recv_sig.cc b/paddle/phi/ops/compat/graph_send_ue_recv_sig.cc index a4cd6f4a150b1..0b2ddcc07e1bb 100644 --- a/paddle/phi/ops/compat/graph_send_ue_recv_sig.cc +++ b/paddle/phi/ops/compat/graph_send_ue_recv_sig.cc @@ -21,12 +21,12 @@ KernelSignature GraphSendUERecvOpArgumentMapping( if (ctx.HasInput("Out_size")) { return KernelSignature("graph_send_ue_recv", {"X", "Y", "Src_index", "Dst_index"}, - {"compute_type", "pool_type", "Out_size"}, + {"message_op", "reduce_op", "Out_size"}, {"Out", "Dst_count"}); } else { return KernelSignature("graph_send_ue_recv", {"X", "Y", "Src_index", "Dst_index"}, - {"compute_type", "pool_type", "out_size"}, + {"message_op", "reduce_op", "out_size"}, {"Out", "Dst_count"}); } } @@ -36,7 +36,7 @@ KernelSignature GraphSendUERecvGradOpArgumentMapping( return KernelSignature( "graph_send_ue_recv_grad", {"X", "Y", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"}, - {"compute_type", "pool_type"}, + {"message_op", "reduce_op"}, {"X@GRAD", "Y@GRAD"}); } diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py index 1b7d8213e75ac..81fcf06167e13 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py @@ -46,7 +46,7 @@ def setUp(self): self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} - self.attrs = {'pool_type': 'MAX'} + self.attrs = {'reduce_op': 'MAX'} out, self.gradient = compute_graph_send_recv_for_min_max( self.inputs, self.attrs) @@ -76,7 +76,7 @@ def setUp(self): self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} - self.attrs = {'pool_type': 'MIN'} + self.attrs = {'reduce_op': 'MIN'} out, self.gradient = compute_graph_send_recv_for_min_max( self.inputs, self.attrs) @@ -107,7 +107,7 @@ def setUp(self): self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} - self.attrs = {'pool_type': 'SUM'} + self.attrs = {'reduce_op': 'SUM'} out, _ = compute_graph_send_recv_for_sum_mean(self.inputs, self.attrs) @@ -134,7 +134,7 @@ def setUp(self): self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} - self.attrs = {'pool_type': 'MEAN'} + self.attrs = {'reduce_op': 'MEAN'} out, dst_count = compute_graph_send_recv_for_sum_mean( self.inputs, self.attrs) @@ -153,15 +153,15 @@ def compute_graph_send_recv_for_sum_mean(inputs, attributes): src_index = inputs['Src_index'] dst_index = inputs['Dst_index'] - pool_type = attributes['pool_type'] + reduce_op = attributes['reduce_op'] gather_x = x[src_index] target_shape = list(x.shape) results = np.zeros(target_shape, dtype=x.dtype) - if pool_type == 'SUM': + if reduce_op == 'SUM': for index, s_id in enumerate(dst_index): results[s_id, :] += gather_x[index, :] - elif pool_type == 'MEAN': + elif reduce_op == 'MEAN': count = np.zeros(target_shape[0], dtype=np.int32) for index, s_id in enumerate(dst_index): results[s_id, :] += gather_x[index, :] @@ -169,7 +169,7 @@ def compute_graph_send_recv_for_sum_mean(inputs, attributes): results = results / count.reshape([-1, 1]) results[np.isnan(results)] = 0 else: - raise ValueError("Invalid pool_type, only SUM, MEAN supported!") + raise ValueError("Invalid reduce_op, only SUM, MEAN supported!") count = np.zeros(target_shape[0], dtype=np.int32) for index, s_id in enumerate(dst_index): @@ -183,7 +183,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): src_index = inputs['Src_index'] dst_index = inputs['Dst_index'] - pool_type = attributes['pool_type'] + reduce_op = attributes['reduce_op'] gather_x = x[src_index] target_shape = list(x.shape) @@ -191,7 +191,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): gradient = np.zeros_like(x) # Calculate forward output - if pool_type == "MAX": + if reduce_op == "MAX": first_set = set() for index, s_id in enumerate(dst_index): if s_id not in first_set: @@ -200,7 +200,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): else: results[s_id, :] = np.maximum(results[s_id, :], gather_x[index, :]) - elif pool_type == "MIN": + elif reduce_op == "MIN": first_set = set() for index, s_id in enumerate(dst_index): if s_id not in first_set: @@ -210,7 +210,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): results[s_id, :] = np.minimum(results[s_id, :], gather_x[index, :]) else: - raise ValueError("Invalid pool_type, only MAX, MIN supported!") + raise ValueError("Invalid reduce_op, only MAX, MIN supported!") # Calculate backward gradient index_size = len(src_index) diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_ue_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_ue_recv_op.py index 25f4d3cb660f0..e8b5bdc7bb8f8 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_ue_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_ue_recv_op.py @@ -103,7 +103,7 @@ def compute_graph_send_ue_recv_for_sum(inputs, attributes): y = inputs['Y'] src_index = inputs['Src_index'] dst_index = inputs['Dst_index'] - message_op = attributes['compute_type'] + message_op = attributes['message_op'] gather_x = x[src_index] out_shp = [ @@ -126,7 +126,7 @@ def compute_graph_send_ue_recv_for_mean(inputs, attributes): y = inputs['Y'] src_index = inputs['Src_index'] dst_index = inputs['Dst_index'] - message_op = attributes['compute_type'] + message_op = attributes['message_op'] gather_x = x[src_index] out_shp = [ @@ -155,8 +155,8 @@ def compute_graph_send_ue_recv_for_max_min(inputs, attributes): y = inputs['Y'] src_index = inputs['Src_index'] dst_index = inputs['Dst_index'] - message_op = attributes['compute_type'] - reduce_op = attributes['pool_type'] + message_op = attributes['message_op'] + reduce_op = attributes['reduce_op'] gather_x = x[src_index] out_shp = [ @@ -277,7 +277,7 @@ def setUp(self): 'Src_index': self.src_index, 'Dst_index': self.dst_index } - self.attrs = {'compute_type': self.message_op, 'pool_type': 'SUM'} + self.attrs = {'message_op': self.message_op, 'reduce_op': 'SUM'} out = compute_graph_send_ue_recv_for_sum(self.inputs, self.attrs) @@ -389,7 +389,7 @@ def setUp(self): 'Src_index': self.src_index, 'Dst_index': self.dst_index } - self.attrs = {'compute_type': self.message_op, 'pool_type': 'MEAN'} + self.attrs = {'message_op': self.message_op, 'reduce_op': 'MEAN'} out, dst_count = compute_graph_send_ue_recv_for_mean( self.inputs, self.attrs) @@ -502,7 +502,7 @@ def setUp(self): 'Src_index': self.src_index, 'Dst_index': self.dst_index } - self.attrs = {'compute_type': self.message_op, 'pool_type': 'MAX'} + self.attrs = {'message_op': self.message_op, 'reduce_op': 'MAX'} out, self.gradients = compute_graph_send_ue_recv_for_max_min( self.inputs, self.attrs) @@ -618,7 +618,7 @@ def setUp(self): 'Src_index': self.src_index, 'Dst_index': self.dst_index } - self.attrs = {'compute_type': self.message_op, 'pool_type': 'MIN'} + self.attrs = {'message_op': self.message_op, 'reduce_op': 'MIN'} out, self.gradients = compute_graph_send_ue_recv_for_max_min( self.inputs, self.attrs) diff --git a/python/paddle/geometric/message_passing/send_recv.py b/python/paddle/geometric/message_passing/send_recv.py index cebd927566c97..bfe63f1f04d73 100644 --- a/python/paddle/geometric/message_passing/send_recv.py +++ b/python/paddle/geometric/message_passing/send_recv.py @@ -119,7 +119,7 @@ def send_u_recv(x, if _in_legacy_dygraph(): out_size = convert_out_size_to_list(out_size) out, tmp = _C_ops.graph_send_recv(x, src_index, - dst_index, None, 'pool_type', + dst_index, None, 'reduce_op', reduce_op.upper(), 'out_size', out_size) return out @@ -148,7 +148,7 @@ def send_u_recv(x, stop_gradient=True) inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index} - attrs = {"pool_type": reduce_op.upper()} + attrs = {"reduce_op": reduce_op.upper()} get_out_size_tensor_inputs(inputs=inputs, attrs=attrs, out_size=out_size, @@ -178,8 +178,8 @@ def send_ue_recv(x, This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index` - to gather the corresponding data, after computing with `y` in different compute types like add/sub/mul/div, then use `dst_index` to - update the corresponding position of output tensor in different pooling types, like sum, mean, max, or min. + to gather the corresponding data, after computing with `y` in different message ops like add/sub/mul/div, then use `dst_index` to + update the corresponding position of output tensor in different reduce ops, like sum, mean, max, or min. Besides, we can use `out_size` to set necessary output shape. .. code-block:: text @@ -215,8 +215,8 @@ def send_ue_recv(x, src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. The available data type is int32, int64. - message_op (str): Different compute types for x and e, including `add`, `sub`, `mul`, `div`. - reduce_op (str): Different pooling types, including `sum`, `mean`, `max`, `min`. + message_op (str): Different message ops for x and e, including `add`, `sub`, `mul`, `div`. + reduce_op (str): Different reduce ops, including `sum`, `mean`, `max`, `min`. Default value is `sum`. out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or out_size is smaller or equal to 0, then this input will not be used. @@ -287,8 +287,8 @@ def send_ue_recv(x, if _in_legacy_dygraph(): out_size = convert_out_size_to_list(out_size) out, tmp = _C_ops.graph_send_ue_recv(x, y, src_index, dst_index, - None, 'compute_type', - message_op.upper(), 'pool_type', + None, 'message_op', + message_op.upper(), 'reduce_op', reduce_op.upper(), 'out_size', out_size) return out @@ -322,7 +322,7 @@ def send_ue_recv(x, stop_gradient=True) inputs = {"X": x, "Y": y, "Src_index": src_index, "Dst_index": dst_index} - attrs = {"compute_type": message_op.upper(), "pool_type": reduce_op.upper()} + attrs = {"message_op": message_op.upper(), "reduce_op": reduce_op.upper()} get_out_size_tensor_inputs(inputs=inputs, attrs=attrs, out_size=out_size, diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index 132a6d4657ca1..4181885d419af 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -69,7 +69,7 @@ def graph_send_recv(x, src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. The available data type is int32, int64. - pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`. + pool_type (str): The pooling types of graph_send_recv, including `sum`, `mean`, `max`, `min`. Default value is `sum`. out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or out_size is smaller or equal to 0, then this input will not be used. @@ -123,7 +123,7 @@ def graph_send_recv(x, if _in_legacy_dygraph(): out_size = convert_out_size_to_list(out_size) out, tmp = _C_ops.graph_send_recv(x, src_index, - dst_index, None, 'pool_type', + dst_index, None, 'reduce_op', pool_type.upper(), 'out_size', out_size) return out @@ -151,7 +151,7 @@ def graph_send_recv(x, stop_gradient=True) inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index} - attrs = {"pool_type": pool_type.upper()} + attrs = {"reduce_op": pool_type.upper()} get_out_size_tensor_inputs(inputs=inputs, attrs=attrs, out_size=out_size,