From 9b43edeae05d0a9419c787f0166b95f6a70ba4f7 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Mon, 11 Jun 2018 12:26:16 +0800 Subject: [PATCH] Polish arg_min_max_op * Remove unused arg_max/min_op.h * Remove reference parameter. Use pointer insteaded. * undef macro * Always set OutT as int64_t. --- paddle/fluid/operators/arg_max_op.cc | 28 +++++++++---------- paddle/fluid/operators/arg_max_op.cu | 19 ++++++------- paddle/fluid/operators/arg_max_op.h | 16 ----------- paddle/fluid/operators/arg_min_max_op_base.h | 29 +++++++++++--------- paddle/fluid/operators/arg_min_op.cc | 28 +++++++++---------- paddle/fluid/operators/arg_min_op.cu | 19 ++++++------- paddle/fluid/operators/arg_min_op.h | 16 ----------- 7 files changed, 60 insertions(+), 95 deletions(-) delete mode 100644 paddle/fluid/operators/arg_max_op.h delete mode 100644 paddle/fluid/operators/arg_min_op.h diff --git a/paddle/fluid/operators/arg_max_op.cc b/paddle/fluid/operators/arg_max_op.cc index 859cccd1b2dfc..8174d3735859b 100644 --- a/paddle/fluid/operators/arg_max_op.cc +++ b/paddle/fluid/operators/arg_max_op.cc @@ -12,24 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/arg_max_op.h" +#include "paddle/fluid/operators/arg_min_max_op_base.h" -REGISTER_OPERATOR(arg_max, paddle::operators::ArgMaxOp, +REGISTER_OPERATOR(arg_max, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMaxOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - arg_max, paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel); + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel); diff --git a/paddle/fluid/operators/arg_max_op.cu b/paddle/fluid/operators/arg_max_op.cu index c9c102bdccf32..a147d77a9e9c5 100644 --- a/paddle/fluid/operators/arg_max_op.cu +++ b/paddle/fluid/operators/arg_max_op.cu @@ -12,21 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/arg_max_op.h" +#include "paddle/fluid/operators/arg_min_max_op_base.h" REGISTER_OP_CUDA_KERNEL( arg_max, - paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, + paddle::operators::ArgMaxKernel, paddle::operators::ArgMaxKernel, + int32_t>, paddle::operators::ArgMaxKernel, + int16_t>, paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, + size_t>, paddle::operators::ArgMaxKernel); + uint8_t>); diff --git a/paddle/fluid/operators/arg_max_op.h b/paddle/fluid/operators/arg_max_op.h deleted file mode 100644 index d232a856992fb..0000000000000 --- a/paddle/fluid/operators/arg_max_op.h +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/operators/arg_min_max_op_base.h" diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h index 8c20461a345b8..6cbdaefeda099 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.h +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include #include #include "paddle/fluid/framework/ddim.h" @@ -37,9 +38,9 @@ struct ArgMinMaxFunctor {}; struct ArgMinMaxFunctor { \ void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \ - framework::LoDTensor& out, int64_t axis) { \ + framework::LoDTensor* out, int64_t axis) { \ auto in_eigen = framework::EigenTensor::From(in); \ - auto out_eigen = framework::EigenTensor::From(out); \ + auto out_eigen = framework::EigenTensor::From(*out); \ out_eigen.device(*(ctx.eigen_device())) = \ in_eigen.eigen_op_type(axis).template cast(); \ } \ @@ -62,7 +63,7 @@ class ArgMinMaxKernel : public framework::OpKernel { #define CALL_ARG_MINMAX_FUNCTOR(rank) \ ArgMinMaxFunctor \ functor##rank; \ - functor##rank(dev_ctx, x, out, axis) + functor##rank(dev_ctx, x, &out, axis) switch (x.dims().size()) { case 1: @@ -89,19 +90,20 @@ class ArgMinMaxKernel : public framework::OpKernel { "than 6.", (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax")); break; +#undef CALL_ARG_MINMAX_FUNCTOR } } }; -template +template using ArgMinKernel = - ArgMinMaxKernel; + ArgMinMaxKernel; -template +template using ArgMaxKernel = - ArgMinMaxKernel; + ArgMinMaxKernel; -typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel { +class ArgMinMaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -121,7 +123,7 @@ typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel { for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]); ctx->SetOutputDim("Out", framework::make_ddim(vec)); } -} ArgMinOp, ArgMaxOp; +}; class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker { protected: @@ -133,12 +135,13 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "Input tensor."); AddOutput("Out", "Output tensor."); AddAttr("axis", "The axis in which to compute the arg indics."); - AddComment(::paddle::string::Sprintf(R"DOC( - %s Operator. + AddComment(string::Sprintf(R"DOC( + %s Operator. - Computes the indices of the %s elements of the input tensor's element along the provided axis. + Computes the indices of the %s elements of the input tensor's element + along the provided axis. )DOC", - OpName(), Name())); + OpName(), Name())); } }; diff --git a/paddle/fluid/operators/arg_min_op.cc b/paddle/fluid/operators/arg_min_op.cc index 18c0884a04260..41f188029f17d 100644 --- a/paddle/fluid/operators/arg_min_op.cc +++ b/paddle/fluid/operators/arg_min_op.cc @@ -12,24 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/arg_min_op.h" +#include "paddle/fluid/operators/arg_min_max_op_base.h" -REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinOp, +REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMinOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - arg_min, paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel); + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel); diff --git a/paddle/fluid/operators/arg_min_op.cu b/paddle/fluid/operators/arg_min_op.cu index 6d5aaa9596101..4d020508505a6 100644 --- a/paddle/fluid/operators/arg_min_op.cu +++ b/paddle/fluid/operators/arg_min_op.cu @@ -12,21 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/arg_min_op.h" +#include "paddle/fluid/operators/arg_min_max_op_base.h" REGISTER_OP_CUDA_KERNEL( arg_min, - paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, + paddle::operators::ArgMinKernel, paddle::operators::ArgMinKernel, + int32_t>, paddle::operators::ArgMinKernel, + int16_t>, paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, + size_t>, paddle::operators::ArgMinKernel); + uint8_t>); diff --git a/paddle/fluid/operators/arg_min_op.h b/paddle/fluid/operators/arg_min_op.h deleted file mode 100644 index d232a856992fb..0000000000000 --- a/paddle/fluid/operators/arg_min_op.h +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/operators/arg_min_max_op_base.h"