Skip to content

Commit

Permalink
Polish arg_min_max_op
Browse files Browse the repository at this point in the history
* Remove unused arg_max/min_op.h
* Remove reference parameter. Use pointer insteaded.
* undef macro
* Always set OutT as int64_t.
  • Loading branch information
reyoung committed Jun 11, 2018
1 parent 6d32e96 commit 9b43ede
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 95 deletions.
28 changes: 13 additions & 15 deletions paddle/fluid/operators/arg_max_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_max_op.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"

REGISTER_OPERATOR(arg_max, paddle::operators::ArgMaxOp,
REGISTER_OPERATOR(arg_max, paddle::operators::ArgMinMaxOp,
paddle::operators::ArgMaxOpMaker,
paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(
arg_max, paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
float, int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double,
arg_max,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int64_t,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int32_t,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, int16_t,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, uint8_t,
int64_t>);
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
19 changes: 9 additions & 10 deletions paddle/fluid/operators/arg_max_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_max_op.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"

REGISTER_OP_CUDA_KERNEL(
arg_max,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, double,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int64_t, int64_t>,
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int32_t, int64_t>,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int16_t, int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, size_t,
int64_t>,
size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
uint8_t, int64_t>);
uint8_t>);
16 changes: 0 additions & 16 deletions paddle/fluid/operators/arg_max_op.h

This file was deleted.

29 changes: 16 additions & 13 deletions paddle/fluid/operators/arg_min_max_op_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
Expand All @@ -37,9 +38,9 @@ struct ArgMinMaxFunctor {};
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
enum_argminmax_value> { \
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
framework::LoDTensor& out, int64_t axis) { \
framework::LoDTensor* out, int64_t axis) { \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in); \
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(out); \
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
Expand All @@ -62,7 +63,7 @@ class ArgMinMaxKernel : public framework::OpKernel<T> {
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
functor##rank; \
functor##rank(dev_ctx, x, out, axis)
functor##rank(dev_ctx, x, &out, axis)

switch (x.dims().size()) {
case 1:
Expand All @@ -89,19 +90,20 @@ class ArgMinMaxKernel : public framework::OpKernel<T> {
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
}
};

template <typename DeviceContext, typename T, typename Tout>
template <typename DeviceContext, typename T>
using ArgMinKernel =
ArgMinMaxKernel<DeviceContext, T, Tout, ArgMinMaxType::kArgMin>;
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMin>;

template <typename DeviceContext, typename T, typename Tout>
template <typename DeviceContext, typename T>
using ArgMaxKernel =
ArgMinMaxKernel<DeviceContext, T, Tout, ArgMinMaxType::kArgMax>;
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMax>;

typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel {
class ArgMinMaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

Expand All @@ -121,7 +123,7 @@ typedef class BaseArgMinMaxOp : public framework::OperatorWithKernel {
for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
ctx->SetOutputDim("Out", framework::make_ddim(vec));
}
} ArgMinOp, ArgMaxOp;
};

class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
protected:
Expand All @@ -133,12 +135,13 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "Input tensor.");
AddOutput("Out", "Output tensor.");
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
AddComment(::paddle::string::Sprintf(R"DOC(
%s Operator.
AddComment(string::Sprintf(R"DOC(
%s Operator.
Computes the indices of the %s elements of the input tensor's element along the provided axis.
Computes the indices of the %s elements of the input tensor's element
along the provided axis.
)DOC",
OpName(), Name()));
OpName(), Name()));
}
};

Expand Down
28 changes: 13 additions & 15 deletions paddle/fluid/operators/arg_min_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_min_op.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"

REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinOp,
REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinMaxOp,
paddle::operators::ArgMinOpMaker,
paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(
arg_min, paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
float, int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double,
arg_min,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int64_t,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int32_t,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, int16_t,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, uint8_t,
int64_t>);
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
19 changes: 9 additions & 10 deletions paddle/fluid/operators/arg_min_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_min_op.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"

REGISTER_OP_CUDA_KERNEL(
arg_min,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, float,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, double,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int64_t, int64_t>,
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int32_t, int64_t>,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int16_t, int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, size_t,
int64_t>,
size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
uint8_t, int64_t>);
uint8_t>);
16 changes: 0 additions & 16 deletions paddle/fluid/operators/arg_min_op.h

This file was deleted.

0 comments on commit 9b43ede

Please sign in to comment.