Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 151 additions & 3 deletions csrc/fastdeploy/function/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "fastdeploy/function/reduce.h"

#include <limits>
#include <set>

#include "fastdeploy/function/eigen.h"
Expand Down Expand Up @@ -215,9 +216,139 @@ void Reduce(const FDTensor& x, FDTensor* out, const std::vector<int64_t>& dims,
}
reduce_all = (reduce_all || full_dim);

FD_VISIT_ALL_TYPES(x.dtype, "ReduceKernelImpl", ([&] {
ReduceKernelImpl<data_t, Functor>(x, out, dims, keep_dim,
reduce_all);
FD_VISIT_INT_FLOAT_TYPES(x.dtype, "ReduceKernelImpl", ([&] {
ReduceKernelImpl<data_t, Functor>(
x, out, dims, keep_dim, reduce_all);
}));
}

enum ArgMinMaxType { kArgMin, kArgMax };

template <typename T, typename Tout, int64_t Rank, ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};

#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<T, Tout, Rank, enum_argminmax_value> { \
void operator()(const FDTensor& in, FDTensor* out, \
const std::vector<int64_t>& x_dims, int64_t axis, \
bool keepdims, bool flatten) { \
const auto& dev = *EigenDeviceWrapper::GetInstance()->GetDevice(); \
auto in_eigen = EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
if (!flatten) { \
auto out_eigen = EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(dev) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = EigenScalar<Tout>::From(*out); \
out_eigen.device(dev) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} else { \
auto out_eigen = EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(dev) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
}

DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);

template <typename T, typename Tout, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMaxKernel(const FDTensor& x, FDTensor* out, int64_t axis,
bool keepdims, bool flatten) {
bool new_keepdims = keepdims | flatten;
// if flatten, will construct the new dims for the cacluate
std::vector<int64_t> x_dims;
int new_axis = axis;
if (flatten) {
x_dims = {x.Numel()};
// if flatten, the axis just as 0
new_axis = 0;
} else {
x_dims = x.shape;
if (axis < 0) new_axis = axis + x_dims.size();
}
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
functor##rank(x, out, x_dims, new_axis, new_keepdims, flatten)

switch (x_dims.size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
case 2:
CALL_ARG_MINMAX_FUNCTOR(2);
break;
case 3:
CALL_ARG_MINMAX_FUNCTOR(3);
break;
case 4:
CALL_ARG_MINMAX_FUNCTOR(4);
break;
case 5:
CALL_ARG_MINMAX_FUNCTOR(5);
break;
case 6:
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
FDASSERT(x_dims.size() <= 6,
"%s operator doesn't supports tensors whose ranks are greater "
"than 6.",
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
break;
#undef CALL_ARG_MINMAX_FUNCTOR
}
}

template <typename T, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMax(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype, bool keepdims, bool flatten) {
const auto& x_dims = x.shape;
int64_t x_rank = x_dims.size();
FDASSERT(axis >= -x_rank,
"'axis'(%d) must be greater than or equal to -Rank(X)(%d).", axis,
-x_rank);
FDASSERT(axis < x_rank,
"'axis'(%d) must be less than or equal to Rank(X)(%d).", axis,
x_rank);
FDASSERT(output_dtype == FDDataType::INT32 || FDDataType::INT64,
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s].",
Str(FDDataType::INT32), Str(FDDataType::INT64), Str(output_dtype));
if (axis < 0) axis += x_rank;
if (output_dtype == FDDataType::INT32) {
int64_t all_element_num = 0;
if (flatten) {
all_element_num = x.Numel();

} else {
all_element_num = x_dims[axis];
}
FDASSERT(all_element_num <= std::numeric_limits<int>::max(),
"The element num of the argmin/argmax input at axis is "
"%d, is larger than int32 maximum value:%d, you must "
"set the dtype of argmin/argmax to 'int64'.",
all_element_num, std::numeric_limits<int>::max());
}
std::vector<int64_t> vec;
if (flatten) {
vec.emplace_back(static_cast<int64_t>(1));
} else {
for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
}
out->Allocate(vec, output_dtype);

FD_VISIT_INT_TYPES(output_dtype, "ArgMinMaxKernel", ([&] {
ArgMinMaxKernel<T, data_t, EnumArgMinMaxValue>(
x, out, axis, keepdims, flatten);
}));
}

Expand Down Expand Up @@ -255,6 +386,23 @@ void Prod(const FDTensor& x, FDTensor* out, const std::vector<int64_t>& dims,
bool keep_dim, bool reduce_all) {
Reduce<ProdFunctor>(x, out, dims, keep_dim, reduce_all);
}

void ArgMax(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype, bool keep_dim, bool flatten) {
FD_VISIT_INT_FLOAT_TYPES(x.dtype, "ArgMaxKernel", ([&] {
ArgMinMax<data_t, kArgMax>(
x, out, axis, output_dtype, keep_dim, flatten);
}));
}

void ArgMin(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype, bool keep_dim, bool flatten) {
FD_VISIT_INT_FLOAT_TYPES(x.dtype, "ArgMaxKernel", ([&] {
ArgMinMax<data_t, kArgMin>(
x, out, axis, output_dtype, keep_dim, flatten);
}));
}

#endif

} // namespace fastdeploy
28 changes: 28 additions & 0 deletions csrc/fastdeploy/function/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,33 @@ FASTDEPLOY_DECL void Prod(const FDTensor& x, FDTensor* out,
const std::vector<int64_t>& dims,
bool keep_dim = false, bool reduce_all = false);

/** Excute the argmax operation for input FDTensor along given dims.
@param x The input tensor.
@param out The output tensor which stores the result.
@param axis The axis which will be reduced.
@param output_dtype The data type of output FDTensor, INT64 or INT32,
default to INT64.
@param keep_dim Whether to keep the reduced dims, default false.
@param flatten Whether to flatten FDTensor to get the argmin index, default
false.
*/
FASTDEPLOY_DECL void ArgMax(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype = FDDataType::INT64,
bool keep_dim = false, bool flatten = false);

/** Excute the argmin operation for input FDTensor along given dims.
@param x The input tensor.
@param out The output tensor which stores the result.
@param axis The axis which will be reduced.
@param output_dtype The data type of output FDTensor, INT64 or INT32,
default to INT64.
@param keep_dim Whether to keep the reduced dims, default false.
@param flatten Whether to flatten FDTensor to get the argmin index, default
false.
*/
FASTDEPLOY_DECL void ArgMin(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype = FDDataType::INT64,
bool keep_dim = false, bool flatten = false);

#endif
} // namespace fastdeploy
20 changes: 20 additions & 0 deletions csrc/fastdeploy/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,26 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
} \
}()

#define FD_VISIT_INT_FLOAT_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \
__VA_ARGS__) \
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \
__VA_ARGS__) \
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \
__VA_ARGS__) \
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \
__VA_ARGS__) \
default: \
FDASSERT(false, \
"Invalid enum data type. Expect to accept data type INT32, " \
"INT64, FP32, FP64, but receive type %s.", \
Str(__dtype__)); \
} \
}()

#define FD_VISIT_FLOAT_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
Expand Down
67 changes: 67 additions & 0 deletions tests/function/test_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,5 +305,72 @@ TEST(fastdeploy, reduce_any) {
check_data(reinterpret_cast<const bool*>(output.Data()),
expected_result_noaxis.data(), expected_result_noaxis.size());
}

TEST(fastdeploy, reduce_argmax) {
FDTensor input, output;
CheckShape check_shape;
CheckData check_data;

std::vector<int> inputs = {2, 4, 3, 7, 1, 5};
std::vector<int64_t> expected_result_axis0 = {1, 0, 1};
std::vector<int64_t> expected_result_axis1 = {1, 0};
std::vector<int64_t> expected_result_noaxis = {3};
input.SetExternalData({2, 3}, FDDataType::INT32, inputs.data());

// axis = 0, output_dtype = FDDataType::INT64, keep_dim = false, flatten =
// false
ArgMax(input, &output, 0);
check_shape(output.shape, {3});
check_data(reinterpret_cast<const int64_t*>(output.Data()),
expected_result_axis0.data(), expected_result_axis0.size());

// axis = -1, output_dtype = FDDataType::INT64, keep_dim = false, flatten =
// false
ArgMax(input, &output, -1);
check_shape(output.shape, {2});
check_data(reinterpret_cast<const int64_t*>(output.Data()),
expected_result_axis1.data(), expected_result_axis1.size());

// axis = -1, output_dtype = FDDataType::INT64, keep_dim = false, flatten =
// true
ArgMax(input, &output, -1, FDDataType::INT64, false, true);
check_shape(output.shape, {1});
check_data(reinterpret_cast<const int64_t*>(output.Data()),
expected_result_noaxis.data(), expected_result_noaxis.size());
}

TEST(fastdeploy, reduce_argmin) {
FDTensor input, output;
CheckShape check_shape;
CheckData check_data;

std::vector<int> inputs = {2, 4, 3, 7, 1, 5};
std::vector<int64_t> expected_result_axis0 = {0, 1, 0};
std::vector<int64_t> expected_result_axis1 = {0, 1};
std::vector<int64_t> expected_result_noaxis = {4};
input.SetExternalData({2, 3}, FDDataType::INT32, inputs.data());

// axis = 0, output_dtype = FDDataType::INT64, keep_dim = false, flatten =
// false
ArgMin(input, &output, 0);
check_shape(output.shape, {3});
check_data(reinterpret_cast<const int64_t*>(output.Data()),
expected_result_axis0.data(), expected_result_axis0.size());

// axis = -1, output_dtype = FDDataType::INT64, keep_dim = false, flatten =
// false
ArgMin(input, &output, -1);
check_shape(output.shape, {2});
check_data(reinterpret_cast<const int64_t*>(output.Data()),
expected_result_axis1.data(), expected_result_axis1.size());

// axis = -1, output_dtype = FDDataType::INT64, keep_dim = false, flatten =
// true
ArgMin(input, &output, -1, FDDataType::INT64, false, true);
check_shape(output.shape, {1});
check_data(reinterpret_cast<const int64_t*>(output.Data()),
expected_result_noaxis.data(), expected_result_noaxis.size());
}

#endif
} // namespace fastdeploy