From d48172f22a5a60f3d66e0d90a0ccc5880c03f953 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 7 Jun 2018 20:48:34 +0800 Subject: [PATCH] split reduce op into multiple libraries, accelerate the compiling (#11029) * "split into multiple .ccl" * "refine file structure" * "refine files" * "remove the cmakelist" * "fix typo" * "fix typo" * fix ci --- paddle/fluid/framework/op_registry.h | 12 +- paddle/fluid/operators/CMakeLists.txt | 2 - paddle/fluid/operators/reduce_max_op.cc | 34 ++ paddle/fluid/operators/reduce_max_op.cu | 34 ++ paddle/fluid/operators/reduce_mean_op.cc | 35 ++ paddle/fluid/operators/reduce_mean_op.cu | 34 ++ paddle/fluid/operators/reduce_mean_op.h | 39 +++ paddle/fluid/operators/reduce_min_max_op.h | 50 +++ paddle/fluid/operators/reduce_min_op.cc | 34 ++ paddle/fluid/operators/reduce_min_op.cu | 34 ++ paddle/fluid/operators/reduce_op.cc | 186 ----------- paddle/fluid/operators/reduce_op.cu | 41 --- paddle/fluid/operators/reduce_op.h | 352 +++++++++----------- paddle/fluid/operators/reduce_op_function.h | 109 ++++++ paddle/fluid/operators/reduce_prod_op.cc | 35 ++ paddle/fluid/operators/reduce_prod_op.cu | 34 ++ paddle/fluid/operators/reduce_prod_op.h | 39 +++ paddle/fluid/operators/reduce_sum_op.cc | 34 ++ paddle/fluid/operators/reduce_sum_op.cu | 34 ++ paddle/fluid/operators/reduce_sum_op.h | 39 +++ 20 files changed, 789 insertions(+), 422 deletions(-) create mode 100644 paddle/fluid/operators/reduce_max_op.cc create mode 100644 paddle/fluid/operators/reduce_max_op.cu create mode 100644 paddle/fluid/operators/reduce_mean_op.cc create mode 100644 paddle/fluid/operators/reduce_mean_op.cu create mode 100644 paddle/fluid/operators/reduce_mean_op.h create mode 100644 paddle/fluid/operators/reduce_min_max_op.h create mode 100644 paddle/fluid/operators/reduce_min_op.cc create mode 100644 paddle/fluid/operators/reduce_min_op.cu delete mode 100644 paddle/fluid/operators/reduce_op.cc delete mode 100644 paddle/fluid/operators/reduce_op.cu create mode 100644 paddle/fluid/operators/reduce_op_function.h create mode 100644 paddle/fluid/operators/reduce_prod_op.cc create mode 100644 paddle/fluid/operators/reduce_prod_op.cu create mode 100644 paddle/fluid/operators/reduce_prod_op.h create mode 100644 paddle/fluid/operators/reduce_sum_op.cc create mode 100644 paddle/fluid/operators/reduce_sum_op.cu create mode 100644 paddle/fluid/operators/reduce_sum_op.h diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index e57c2ff3d0006..43ab227a94787 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -156,15 +156,15 @@ class OpKernelRegistrar : public Registrar { /** * Macro to register OperatorKernel. */ -#define REGISTER_OP_KERNEL(op_type, LIBRARY_TYPE, place_class, ...) \ +#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_op_kernel_##op_type##_##LIBRARY_TYPE##__, \ + __reg_op_kernel_##op_type##_##library_type##__, \ "REGISTER_OP_KERNEL must be called in global namespace"); \ static ::paddle::framework::OpKernelRegistrar \ - __op_kernel_registrar_##op_type##_##LIBRARY_TYPE##__(#op_type, \ - #LIBRARY_TYPE); \ - int TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE() { \ - __op_kernel_registrar_##op_type##_##LIBRARY_TYPE##__.Touch(); \ + __op_kernel_registrar_##op_type##_##library_type##__(#op_type, \ + #library_type); \ + int TouchOpKernelRegistrar_##op_type##_##library_type() { \ + __op_kernel_registrar_##op_type##_##library_type##__.Touch(); \ return 0; \ } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index f75b7c70d60e7..5e86b16ba1ff6 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -166,8 +166,6 @@ function(op_library TARGET) # NOTE(*): activation use macro to regist the kernels, set use_op manually. if(${TARGET} STREQUAL "activation") file(APPEND ${pybind_file} "USE_OP(relu);\n") - elseif(${TARGET} STREQUAL "reduce") - file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n") elseif(${TARGET} STREQUAL "fake_dequantize") file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n") else() diff --git a/paddle/fluid/operators/reduce_max_op.cc b/paddle/fluid/operators/reduce_max_op.cc new file mode 100644 index 0000000000000..95d3768e1fdf6 --- /dev/null +++ b/paddle/fluid/operators/reduce_max_op.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_min_max_op.h" + +REGISTER_REDUCE_OP(reduce_max); +REGISTER_OP_CPU_KERNEL( + reduce_max, ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL( + reduce_max_grad, ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_max_op.cu b/paddle/fluid/operators/reduce_max_op.cu new file mode 100644 index 0000000000000..0d86b3127e42f --- /dev/null +++ b/paddle/fluid/operators/reduce_max_op.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_min_max_op.h" + +REGISTER_OP_CUDA_KERNEL(reduce_max, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CUDA_KERNEL( + reduce_max_grad, ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_mean_op.cc b/paddle/fluid/operators/reduce_mean_op.cc new file mode 100644 index 0000000000000..fc258c2496340 --- /dev/null +++ b/paddle/fluid/operators/reduce_mean_op.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_mean_op.h" + +REGISTER_REDUCE_OP(reduce_mean); +REGISTER_OP_CPU_KERNEL(reduce_mean, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL(reduce_mean_grad, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_mean_op.cu b/paddle/fluid/operators/reduce_mean_op.cu new file mode 100644 index 0000000000000..960cb3235be7f --- /dev/null +++ b/paddle/fluid/operators/reduce_mean_op.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_mean_op.h" + +REGISTER_OP_CUDA_KERNEL(reduce_mean, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CUDA_KERNEL( + reduce_mean_grad, ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_mean_op.h b/paddle/fluid/operators/reduce_mean_op.h new file mode 100644 index 0000000000000..1359679c4767d --- /dev/null +++ b/paddle/fluid/operators/reduce_mean_op.h @@ -0,0 +1,39 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/reduce_op.h" + +namespace paddle { +namespace operators { + +struct MeanFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->mean(dim); + } +}; + +struct MeanGradFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, + const Dim& dim, int size) { + dx->device(place) = dy->broadcast(dim) / dx->constant(size); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reduce_min_max_op.h b/paddle/fluid/operators/reduce_min_max_op.h new file mode 100644 index 0000000000000..ec59f3e71c1c7 --- /dev/null +++ b/paddle/fluid/operators/reduce_min_max_op.h @@ -0,0 +1,50 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/fluid/operators/reduce_op.h" + +namespace paddle { +namespace operators { + +struct MaxFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->maximum(dim); + } +}; + +struct MinFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->minimum(dim); + } +}; + +struct MaxOrMinGradFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, + const Dim& dim, int size) { + auto equals = (*x) == y->broadcast(dim); + auto ones = dx->constant(1); + auto zeros = dx->constant(0); + // If there are multiple minimum or maximum elements, the subgradient of + // each is the set [0, 1], and we pass gradient to all of them here. + dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reduce_min_op.cc b/paddle/fluid/operators/reduce_min_op.cc new file mode 100644 index 0000000000000..330a86d2e4237 --- /dev/null +++ b/paddle/fluid/operators/reduce_min_op.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_min_max_op.h" + +REGISTER_REDUCE_OP(reduce_min); +REGISTER_OP_CPU_KERNEL( + reduce_min, ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL( + reduce_min_grad, ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_min_op.cu b/paddle/fluid/operators/reduce_min_op.cu new file mode 100644 index 0000000000000..da466f805eff4 --- /dev/null +++ b/paddle/fluid/operators/reduce_min_op.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_min_max_op.h" + +REGISTER_OP_CUDA_KERNEL(reduce_min, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CUDA_KERNEL( + reduce_min_grad, ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_op.cc b/paddle/fluid/operators/reduce_op.cc deleted file mode 100644 index e293fd5e410b2..0000000000000 --- a/paddle/fluid/operators/reduce_op.cc +++ /dev/null @@ -1,186 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/reduce_op.h" - -#include -#include -#include - -namespace paddle { -namespace operators { - -using framework::Tensor; - -class ReduceOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of ReduceOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of ReduceOp should not be null."); - auto x_dims = ctx->GetInputDim("X"); - auto x_rank = x_dims.size(); - PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); - auto dims = ctx->Attrs().Get>("dim"); - for (size_t i = 0; i < dims.size(); ++i) { - if (dims[i] < 0) dims[i] = x_rank + dims[i]; - PADDLE_ENFORCE_LT( - dims[i], x_rank, - "The dim should be in the range [-rank(input), rank(input))."); - } - sort(dims.begin(), dims.end()); - bool reduce_all = ctx->Attrs().Get("reduce_all"); - bool keep_dim = ctx->Attrs().Get("keep_dim"); - if (reduce_all) { - if (keep_dim) - ctx->SetOutputDim( - "Out", framework::make_ddim(std::vector(x_rank, 1))); - else - ctx->SetOutputDim("Out", {1}); - } else { - auto dims_vector = vectorize(x_dims); - if (keep_dim) { - for (size_t i = 0; i < dims.size(); ++i) { - dims_vector[dims[i]] = 1; - } - } else { - const int kDelFlag = -2; - for (size_t i = 0; i < dims.size(); ++i) { - dims_vector[dims[i]] = kDelFlag; - } - dims_vector.erase( - remove(dims_vector.begin(), dims_vector.end(), kDelFlag), - dims_vector.end()); - } - auto out_dims = framework::make_ddim(dims_vector); - ctx->SetOutputDim("Out", out_dims); - if (dims[0] != 0) { - // Only pass LoD when not reducing on the first dim. - ctx->ShareLoD("X", /*->*/ "Out"); - } - } - } -}; - -class ReduceGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null."); - auto x_dims = ctx->GetInputDim("X"); - auto x_rank = x_dims.size(); - PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); - auto dims = ctx->Attrs().Get>("dim"); - for (size_t i = 0; i < dims.size(); ++i) { - if (dims[i] < 0) dims[i] = x_rank + dims[i]; - PADDLE_ENFORCE_LT( - dims[i], x_rank, - "The dim should be in the range [-rank(input), rank(input))."); - } - sort(dims.begin(), dims.end()); - auto x_grad_name = framework::GradVarName("X"); - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); - ctx->ShareLoD("X", /*->*/ x_grad_name); - } - } -}; - -class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() final { - AddInput("X", - "(Tensor) The input tensor. Tensors with rank at most 6 are " - "supported."); - AddOutput("Out", "(Tensor) The result tensor."); - AddAttr>( - "dim", - "(list, default {0}) The dimensions to reduce. " - "Must be in the range [-rank(input), rank(input)). " - "If `dim[i] < 0`, the dims[i] to reduce is `rank + dims[i]`. " - "Note that reducing on the first dim will make the LoD info lost.") - .SetDefault({0}); - AddAttr("keep_dim", - "(bool, default false) " - "If true, retain the reduced dimension with length 1.") - .SetDefault(false); - AddAttr("reduce_all", - "(bool, default false) " - "If true, output a scalar reduced along all dimensions.") - .SetDefault(false); - AddComment(string::Sprintf(R"DOC( -%s Operator. - -This operator computes the %s of input tensor along the given dimension. -The result tensor has 1 fewer dimension than the input unless keep_dim is true. -If reduce_all is true, just reduce along all dimensions and output a scalar. - -)DOC", - GetOpType(), GetName())); - } - - protected: - virtual std::string GetName() const = 0; - virtual std::string GetOpType() const = 0; -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -#define REGISTER_REDUCE_OP(op_name) \ - class __##op_name##Maker__ : public ops::ReduceOpMaker { \ - protected: \ - virtual std::string GetName() const { return #op_name; } \ - virtual std::string GetOpType() const { return "Reduce " #op_name; } \ - }; \ - REGISTER_OPERATOR(reduce_##op_name, ops::ReduceOp, __##op_name##Maker__, \ - paddle::framework::DefaultGradOpDescMaker); \ - REGISTER_OPERATOR(reduce_##op_name##_grad, ops::ReduceGradOp) - -REGISTER_REDUCE_OP(sum); -REGISTER_REDUCE_OP(mean); -REGISTER_REDUCE_OP(max); -REGISTER_REDUCE_OP(min); -REGISTER_REDUCE_OP(prod); - -#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \ - REGISTER_OP_CPU_KERNEL(reduce_type, \ - ops::ReduceKernel, \ - ops::ReduceKernel, \ - ops::ReduceKernel, \ - ops::ReduceKernel); \ - REGISTER_OP_CPU_KERNEL( \ - reduce_type##_grad, \ - ops::ReduceGradKernel, \ - ops::ReduceGradKernel, \ - ops::ReduceGradKernel, \ - ops::ReduceGradKernel); - -FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_CPU_KERNEL); diff --git a/paddle/fluid/operators/reduce_op.cu b/paddle/fluid/operators/reduce_op.cu deleted file mode 100644 index ae29587f55847..0000000000000 --- a/paddle/fluid/operators/reduce_op.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#define EIGEN_USE_GPU -#include "paddle/fluid/operators/reduce_op.h" - -namespace ops = paddle::operators; - -#define REGISTER_REDUCE_GPU_KERNEL(reduce_type, functor, grad_functor) \ - REGISTER_OP_CUDA_KERNEL( \ - reduce_type, ops::ReduceKernel, \ - ops::ReduceKernel, \ - ops::ReduceKernel, \ - ops::ReduceKernel); \ - REGISTER_OP_CUDA_KERNEL( \ - reduce_type##_grad, \ - ops::ReduceGradKernel, \ - ops::ReduceGradKernel, \ - ops::ReduceGradKernel, \ - ops::ReduceGradKernel); - -FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_GPU_KERNEL); diff --git a/paddle/fluid/operators/reduce_op.h b/paddle/fluid/operators/reduce_op.h index 7df47f316c30b..72b6cf1773d5b 100644 --- a/paddle/fluid/operators/reduce_op.h +++ b/paddle/fluid/operators/reduce_op.h @@ -14,105 +14,20 @@ limitations under the License. */ #pragma once +#include +#include #include -#include "glog/logging.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" + +#include "paddle/fluid/operators/reduce_op_function.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using DDim = framework::DDim; -template -using EigenTensor = framework::EigenTensor; -template -using EigenScalar = framework::EigenScalar; -template -using EigenVector = framework::EigenVector; - -struct SumFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { - y->device(place) = x->sum(dim); - } -}; - -struct SumGradFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, - const Dim& dim, int size) { - dx->device(place) = dy->broadcast(dim); - } -}; - -struct MeanFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { - y->device(place) = x->mean(dim); - } -}; - -struct MeanGradFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, - const Dim& dim, int size) { - dx->device(place) = dy->broadcast(dim) / dx->constant(size); - } -}; - -struct MaxFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { - y->device(place) = x->maximum(dim); - } -}; - -struct MinFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { - y->device(place) = x->minimum(dim); - } -}; - -struct MaxOrMinGradFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, - const Dim& dim, int size) { - auto equals = (*x) == y->broadcast(dim); - auto ones = dx->constant(1); - auto zeros = dx->constant(0); - // If there are multiple minimum or maximum elements, the subgradient of - // each is the set [0, 1], and we pass gradient to all of them here. - dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros); - } -}; - -struct ProdFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { - y->device(place) = x->prod(dim); - } -}; - -struct ProdGradFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, - const Dim& dim, int size) { - dx->device(place) = dy->broadcast(dim) * y->broadcast(dim) * x->inverse(); - } -}; - -#define HANDLE_DIM(NDIM, RDIM) \ - if (ndim == NDIM && rdim == RDIM) { \ - ReduceCompute(context); \ +#define HANDLE_DIM(NDIM, RDIM) \ + if (ndim == NDIM && rdim == RDIM) { \ + ReduceFunctor( \ + context.template device_context(), *input, output, \ + dims, keep_dim); \ } template @@ -120,11 +35,15 @@ class ReduceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { bool reduce_all = context.Attr("reduce_all"); + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + output->mutable_data(context.GetPlace()); + + auto dims = context.Attr>("dim"); + bool keep_dim = context.Attr("keep_dim"); + if (reduce_all) { // Flatten and reduce 1-D tensor - auto* input = context.Input("X"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); auto x = EigenVector::Flatten(*input); auto out = EigenScalar::From(*output); auto& place = @@ -133,8 +52,8 @@ class ReduceKernel : public framework::OpKernel { Functor functor; functor(place, &x, &out, reduce_dim); } else { - int ndim = context.Input("X")->dims().size(); - int rdim = context.Attr>("dim").size(); + int ndim = input->dims().size(); + int rdim = dims.size(); // comments for accelerating compiling temporarily. // HANDLE_DIM(6, 5); // HANDLE_DIM(6, 4); @@ -154,48 +73,6 @@ class ReduceKernel : public framework::OpKernel { HANDLE_DIM(1, 1); } } - - private: - template - void ReduceCompute(const framework::ExecutionContext& context) const { - auto* input = context.Input("X"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); - - auto x = EigenTensor::From(*input); - auto x_rank = static_cast(x.dimensions().size()); - auto dims = context.Attr>("dim"); - auto reduce_dim = Eigen::array(); - for (size_t i = 0; i < dims.size(); ++i) { - if (dims[i] < 0) dims[i] = x_rank + dims[i]; - reduce_dim[i] = dims[i]; - } - // construct the squeezed output tensor - bool keep_dim = context.Attr("keep_dim"); - DDim out_dims = output->dims(); - if (keep_dim && x_rank > 1) { - const int kDelFlag = -2; - auto dims_vector = vectorize(out_dims); - for (size_t i = 0; i < dims.size(); ++i) { - dims_vector[dims[i]] = kDelFlag; - } - dims_vector.erase( - remove(dims_vector.begin(), dims_vector.end(), kDelFlag), - dims_vector.end()); - out_dims = framework::make_ddim(dims_vector); - } - auto& place = - *context.template device_context().eigen_device(); - Functor functor; - - if (D == 1) { - auto out = EigenScalar::From(*output); - functor(place, &x, &out, reduce_dim); - } else { - auto out = EigenTensor::From(*output, out_dims); - functor(place, &x, &out, reduce_dim); - } - } }; template @@ -203,12 +80,15 @@ class ReduceGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { bool reduce_all = context.Attr("reduce_all"); + auto dims = context.Attr>("dim"); + + auto* input0 = context.Input("X"); + auto* input1 = context.Input("Out"); + auto* input2 = context.Input(framework::GradVarName("Out")); + auto* output = context.Output(framework::GradVarName("X")); + output->mutable_data(context.GetPlace()); + if (reduce_all) { - auto* input0 = context.Input("X"); - auto* input1 = context.Input("Out"); - auto* input2 = context.Input(framework::GradVarName("Out")); - auto* output = context.Output(framework::GradVarName("X")); - output->mutable_data(context.GetPlace()); auto x = EigenVector::Flatten(*input0); auto x_reduce = EigenVector::From(*input1); auto x_reduce_grad = EigenVector::From(*input2); @@ -221,74 +101,172 @@ class ReduceGradKernel : public framework::OpKernel { functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim, broadcast_dim[0]); } else { - int rank = context.Input("X")->dims().size(); + int rank = input0->dims().size(); switch (rank) { case 1: - ReduceGradCompute<1>(context); + ReduceGradFunctor( + context.template device_context(), *input0, + *input1, *input2, output, dims); break; case 2: - ReduceGradCompute<2>(context); + ReduceGradFunctor( + context.template device_context(), *input0, + *input1, *input2, output, dims); break; case 3: - ReduceGradCompute<3>(context); + ReduceGradFunctor( + context.template device_context(), *input0, + *input1, *input2, output, dims); break; case 4: - ReduceGradCompute<4>(context); + ReduceGradFunctor( + context.template device_context(), *input0, + *input1, *input2, output, dims); break; case 5: - ReduceGradCompute<5>(context); + ReduceGradFunctor( + context.template device_context(), *input0, + *input1, *input2, output, dims); break; case 6: - ReduceGradCompute<6>(context); + ReduceGradFunctor( + context.template device_context(), *input0, + *input1, *input2, output, dims); break; } } } +}; - private: - template - void ReduceGradCompute(const framework::ExecutionContext& context) const { - auto* input0 = context.Input("X"); - auto* input1 = context.Input("Out"); - auto* input2 = context.Input(framework::GradVarName("Out")); - auto* output = context.Output(framework::GradVarName("X")); +class ReduceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; - output->mutable_data(context.GetPlace()); - auto x = EigenTensor::From(*input0); - auto x_grad = EigenTensor::From(*output); - auto x_rank = static_cast(x.dimensions().size()); - auto dims = context.Attr>("dim"); - auto x_dims = input0->dims(); - auto reduced_dims_v = vectorize(x_dims); - Eigen::array broadcast_dim; - for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ReduceOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ReduceOp should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto x_rank = x_dims.size(); + PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); + auto dims = ctx->Attrs().Get>("dim"); + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) dims[i] = x_rank + dims[i]; + PADDLE_ENFORCE_LT( + dims[i], x_rank, + "The dim should be in the range [-rank(input), rank(input))."); + } + sort(dims.begin(), dims.end()); + bool reduce_all = ctx->Attrs().Get("reduce_all"); + bool keep_dim = ctx->Attrs().Get("keep_dim"); + if (reduce_all) { + if (keep_dim) + ctx->SetOutputDim( + "Out", framework::make_ddim(std::vector(x_rank, 1))); + else + ctx->SetOutputDim("Out", {1}); + } else { + auto dims_vector = vectorize(x_dims); + if (keep_dim) { + for (size_t i = 0; i < dims.size(); ++i) { + dims_vector[dims[i]] = 1; + } + } else { + const int kDelFlag = -2; + for (size_t i = 0; i < dims.size(); ++i) { + dims_vector[dims[i]] = kDelFlag; + } + dims_vector.erase( + remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + } + auto out_dims = framework::make_ddim(dims_vector); + ctx->SetOutputDim("Out", out_dims); + if (dims[0] != 0) { + // Only pass LoD when not reducing on the first dim. + ctx->ShareLoD("X", /*->*/ "Out"); + } + } + } +}; + +class ReduceGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; - int broad_cats_times = 1; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto x_rank = x_dims.size(); + PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); + auto dims = ctx->Attrs().Get>("dim"); for (size_t i = 0; i < dims.size(); ++i) { if (dims[i] < 0) dims[i] = x_rank + dims[i]; - reduced_dims_v[dims[i]] = 1; - broadcast_dim[dims[i]] = x_dims[dims[i]]; - broad_cats_times *= x_dims[dims[i]]; + PADDLE_ENFORCE_LT( + dims[i], x_rank, + "The dim should be in the range [-rank(input), rank(input))."); + } + sort(dims.begin(), dims.end()); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + ctx->ShareLoD("X", /*->*/ x_grad_name); } - auto reduced_dims = framework::make_ddim(reduced_dims_v); - auto x_reduce = EigenTensor::From(*input1, reduced_dims); - auto x_reduce_grad = EigenTensor::From(*input2, reduced_dims); + } +}; + +class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() final { + AddInput("X", + "(Tensor) The input tensor. Tensors with rank at most 6 are " + "supported."); + AddOutput("Out", "(Tensor) The result tensor."); + AddAttr>( + "dim", + "(list, default {0}) The dimensions to reduce. " + "Must be in the range [-rank(input), rank(input)). " + "If `dim[i] < 0`, the dims[i] to reduce is `rank + dims[i]`. " + "Note that reducing on the first dim will make the LoD info lost.") + .SetDefault({0}); + AddAttr("keep_dim", + "(bool, default false) " + "If true, retain the reduced dimension with length 1.") + .SetDefault(false); + AddAttr("reduce_all", + "(bool, default false) " + "If true, output a scalar reduced along all dimensions.") + .SetDefault(false); + AddComment(string::Sprintf(R"DOC( +%s Operator. - auto& place = - *context.template device_context().eigen_device(); +This operator computes the %s of input tensor along the given dimension. +The result tensor has 1 fewer dimension than the input unless keep_dim is true. +If reduce_all is true, just reduce along all dimensions and output a scalar. - Functor functor; - functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim, - broad_cats_times); +)DOC", + GetOpType(), GetName())); } + + protected: + virtual std::string GetName() const = 0; + virtual std::string GetOpType() const = 0; }; } // namespace operators } // namespace paddle -#define FOR_EACH_KERNEL_FUNCTOR(__macro) \ - __macro(reduce_sum, SumFunctor, SumGradFunctor); \ - __macro(reduce_mean, MeanFunctor, MeanGradFunctor); \ - __macro(reduce_max, MaxFunctor, MaxOrMinGradFunctor); \ - __macro(reduce_min, MinFunctor, MaxOrMinGradFunctor); \ - __macro(reduce_prod, ProdFunctor, ProdGradFunctor); +namespace ops = paddle::operators; + +#define REGISTER_REDUCE_OP(op_name) \ + class __##op_name##Maker__ : public ops::ReduceOpMaker { \ + protected: \ + virtual std::string GetName() const { return #op_name; } \ + virtual std::string GetOpType() const { return "Reduce " #op_name; } \ + }; \ + REGISTER_OPERATOR(op_name, ops::ReduceOp, __##op_name##Maker__, \ + paddle::framework::DefaultGradOpDescMaker); \ + REGISTER_OPERATOR(op_name##_grad, ops::ReduceGradOp) diff --git a/paddle/fluid/operators/reduce_op_function.h b/paddle/fluid/operators/reduce_op_function.h new file mode 100644 index 0000000000000..3da27bc8ac8d4 --- /dev/null +++ b/paddle/fluid/operators/reduce_op_function.h @@ -0,0 +1,109 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; +template +using EigenTensor = framework::EigenTensor; +template +using EigenScalar = framework::EigenScalar; +template +using EigenVector = framework::EigenVector; + +template +void ReduceFunctor(const DeviceContext& context, const framework::Tensor& input, + framework::Tensor* output, const std::vector& dims, + bool keep_dim) { + auto x = EigenTensor::From(input); + auto x_rank = static_cast(x.dimensions().size()); + auto reduce_dim = Eigen::array(); + std::vector dims_ref = dims; + for (size_t i = 0; i < dims_ref.size(); ++i) { + if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i]; + reduce_dim[i] = dims_ref[i]; + } + // construct the squeezed output tensor + DDim out_dims = output->dims(); + if (keep_dim && x_rank > 1) { + const int kDelFlag = -2; + auto dims_vector = framework::vectorize(out_dims); + for (size_t i = 0; i < dims_ref.size(); ++i) { + dims_vector[dims_ref[i]] = kDelFlag; + } + dims_vector.erase(remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + out_dims = framework::make_ddim(dims_vector); + } + auto& place = *context.eigen_device(); + Functor functor; + + if (D == 1) { + auto out = EigenScalar::From(*output); + functor(place, &x, &out, reduce_dim); + } else { + auto out = EigenTensor::From(*output, out_dims); + functor(place, &x, &out, reduce_dim); + } +} + +template +void ReduceGradFunctor(const DeviceContext& context, + const framework::Tensor& input0, + const framework::Tensor& input1, + const framework::Tensor& input2, + framework::Tensor* output, + const std::vector& dims) { + auto x = EigenTensor::From(input0); + auto x_grad = EigenTensor::From(*output); + auto x_rank = static_cast(x.dimensions().size()); + auto x_dims = input0.dims(); + auto reduced_dims_v = framework::vectorize(x_dims); + std::vector dims_ref = dims; + Eigen::array broadcast_dim; + for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1; + + int broad_cats_times = 1; + for (size_t i = 0; i < dims_ref.size(); ++i) { + if (dims_ref[i] < 0) { + dims_ref[i] = x_rank + dims_ref[i]; + } + reduced_dims_v[dims_ref[i]] = 1; + broadcast_dim[dims_ref[i]] = x_dims[dims_ref[i]]; + broad_cats_times *= x_dims[dims_ref[i]]; + } + auto reduced_dims = framework::make_ddim(reduced_dims_v); + auto x_reduce = EigenTensor::From(input1, reduced_dims); + auto x_reduce_grad = EigenTensor::From(input2, reduced_dims); + + auto& place = *context.eigen_device(); + + Functor functor; + functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim, + broad_cats_times); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reduce_prod_op.cc b/paddle/fluid/operators/reduce_prod_op.cc new file mode 100644 index 0000000000000..713728b99757a --- /dev/null +++ b/paddle/fluid/operators/reduce_prod_op.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_prod_op.h" + +REGISTER_REDUCE_OP(reduce_prod); +REGISTER_OP_CPU_KERNEL(reduce_prod, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL(reduce_prod_grad, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_prod_op.cu b/paddle/fluid/operators/reduce_prod_op.cu new file mode 100644 index 0000000000000..d62e677d92cff --- /dev/null +++ b/paddle/fluid/operators/reduce_prod_op.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_prod_op.h" + +REGISTER_OP_CUDA_KERNEL(reduce_prod, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CUDA_KERNEL( + reduce_prod_grad, ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_prod_op.h b/paddle/fluid/operators/reduce_prod_op.h new file mode 100644 index 0000000000000..97748113e0927 --- /dev/null +++ b/paddle/fluid/operators/reduce_prod_op.h @@ -0,0 +1,39 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/reduce_op.h" + +namespace paddle { +namespace operators { + +struct ProdFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->prod(dim); + } +}; + +struct ProdGradFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, + const Dim& dim, int size) { + dx->device(place) = dy->broadcast(dim) * y->broadcast(dim) * x->inverse(); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reduce_sum_op.cc b/paddle/fluid/operators/reduce_sum_op.cc new file mode 100644 index 0000000000000..c5b5398787b44 --- /dev/null +++ b/paddle/fluid/operators/reduce_sum_op.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_sum_op.h" + +REGISTER_REDUCE_OP(reduce_sum); +REGISTER_OP_CPU_KERNEL( + reduce_sum, ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CPU_KERNEL(reduce_sum_grad, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_sum_op.cu b/paddle/fluid/operators/reduce_sum_op.cu new file mode 100644 index 0000000000000..f2e16955a50dc --- /dev/null +++ b/paddle/fluid/operators/reduce_sum_op.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_sum_op.h" + +REGISTER_OP_CUDA_KERNEL(reduce_sum, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel, + ops::ReduceKernel); +REGISTER_OP_CUDA_KERNEL( + reduce_sum_grad, ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel, + ops::ReduceGradKernel); diff --git a/paddle/fluid/operators/reduce_sum_op.h b/paddle/fluid/operators/reduce_sum_op.h new file mode 100644 index 0000000000000..e67d7e1da5f02 --- /dev/null +++ b/paddle/fluid/operators/reduce_sum_op.h @@ -0,0 +1,39 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/reduce_op.h" + +namespace paddle { +namespace operators { + +struct SumFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->sum(dim); + } +}; + +struct SumGradFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, + const Dim& dim, int size) { + dx->device(place) = dy->broadcast(dim); + } +}; + +} // namespace operators +} // namespace paddle