Skip to content

Commit

Permalink
Merge pull request #11313 from sneaxiy/argmin_argmax
Browse files Browse the repository at this point in the history
Add argmin and argmax op
  • Loading branch information
sneaxiy committed Jun 11, 2018
2 parents 69b5a62 + 9b43ede commit 831909c
Show file tree
Hide file tree
Showing 7 changed files with 434 additions and 0 deletions.
33 changes: 33 additions & 0 deletions paddle/fluid/operators/arg_max_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_min_max_op_base.h"

REGISTER_OPERATOR(arg_max, paddle::operators::ArgMinMaxOp,
paddle::operators::ArgMaxOpMaker,
paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(
arg_max,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
31 changes: 31 additions & 0 deletions paddle/fluid/operators/arg_max_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_min_max_op_base.h"

REGISTER_OP_CUDA_KERNEL(
arg_max,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
uint8_t>);
160 changes: 160 additions & 0 deletions paddle/fluid/operators/arg_min_max_op_base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace operators {

enum ArgMinMaxType { kArgMin, kArgMax };

template <typename DeviceContext, typename T, typename Tout, int64_t Rank,
ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};

#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename DeviceContext, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
enum_argminmax_value> { \
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
framework::LoDTensor* out, int64_t axis) { \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in); \
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
}

DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);

template <typename DeviceContext, typename T, typename Tout,
ArgMinMaxType EnumArgMinMaxValue>
class ArgMinMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& x = *(ctx.Input<framework::LoDTensor>("X"));
auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
out.mutable_data<Tout>(ctx.GetPlace());
auto axis = ctx.Attr<int64_t>("axis");
auto& dev_ctx = ctx.template device_context<DeviceContext>();

#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
functor##rank; \
functor##rank(dev_ctx, x, &out, axis)

switch (x.dims().size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
case 2:
CALL_ARG_MINMAX_FUNCTOR(2);
break;
case 3:
CALL_ARG_MINMAX_FUNCTOR(3);
break;
case 4:
CALL_ARG_MINMAX_FUNCTOR(4);
break;
case 5:
CALL_ARG_MINMAX_FUNCTOR(5);
break;
case 6:
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
PADDLE_THROW(
"%s operator doesn't supports tensors whose ranks are greater "
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
}
};

template <typename DeviceContext, typename T>
using ArgMinKernel =
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMin>;

template <typename DeviceContext, typename T>
using ArgMaxKernel =
ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMax>;

class ArgMinMaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
const auto& x_dims = ctx->GetInputDim("X");
int64_t axis = ctx->Attrs().Get<int64_t>("axis");
PADDLE_ENFORCE(axis >= -x_dims.size() && axis < x_dims.size(),
"'axis' must be inside [-Rank(X), Rank(X))");

auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;

std::vector<int64_t> vec;
for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]);
for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
ctx->SetOutputDim("Out", framework::make_ddim(vec));
}
};

class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
protected:
virtual const char* OpName() const = 0;
virtual const char* Name() const = 0;

public:
void Make() override {
AddInput("X", "Input tensor.");
AddOutput("Out", "Output tensor.");
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
AddComment(string::Sprintf(R"DOC(
%s Operator.
Computes the indices of the %s elements of the input tensor's element
along the provided axis.
)DOC",
OpName(), Name()));
}
};

class ArgMinOpMaker : public BaseArgMinMaxOpMaker {
protected:
const char* OpName() const override { return "ArgMin"; }
const char* Name() const override { return "min"; }
};

class ArgMaxOpMaker : public BaseArgMinMaxOpMaker {
protected:
const char* OpName() const override { return "ArgMax"; }
const char* Name() const override { return "max"; }
};
} // namespace operators
} // namespace paddle
33 changes: 33 additions & 0 deletions paddle/fluid/operators/arg_min_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_min_max_op_base.h"

REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinMaxOp,
paddle::operators::ArgMinOpMaker,
paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(
arg_min,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, double>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
31 changes: 31 additions & 0 deletions paddle/fluid/operators/arg_min_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/arg_min_max_op_base.h"

REGISTER_OP_CUDA_KERNEL(
arg_min,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
uint8_t>);
64 changes: 64 additions & 0 deletions python/paddle/fluid/layers/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
'assign',
'fill_constant_batch_size_like',
'fill_constant',
'argmin',
'argmax',
'ones',
'zeros',
]
Expand Down Expand Up @@ -315,6 +317,68 @@ def fill_constant_batch_size_like(input,
return out


def argmin(x, axis=0):
"""
**argmin**
This function computes the indices of the min elements
of the input tensor's element along the provided axis.
Args:
x(Variable): The input to compute the indices of
the min elements.
axis(int): Axis to compute indices along.
Returns:
Variable: The tensor variable storing the output
Examples:
.. code-block:: python
out = fluid.layers.argmin(x=in, axis=0)
out = fluid.layers.argmin(x=in, axis=-1)
"""
helper = LayerHelper("arg_min", **locals())
out = helper.create_tmp_variable(VarDesc.VarType.INT64)
helper.append_op(
type='arg_min',
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis})
return out


def argmax(x, axis=0):
"""
**argmax**
This function computes the indices of the max elements
of the input tensor's element along the provided axis.
Args:
x(Variable): The input to compute the indices of
the max elements.
axis(int): Axis to compute indices along.
Returns:
Variable: The tensor variable storing the output
Examples:
.. code-block:: python
out = fluid.layers.argmax(x=in, axis=0)
out = fluid.layers.argmax(x=in, axis=-1)
"""
helper = LayerHelper("arg_max", **locals())
out = helper.create_tmp_variable(VarDesc.VarType.INT64)
helper.append_op(
type='arg_max',
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis})
return out


def ones(shape, dtype, force_cpu=False):
"""
**ones**
Expand Down
Loading

0 comments on commit 831909c

Please sign in to comment.