Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] FFI for argmax, argmin, indices #17843

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
3 changes: 3 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def prepare_workloads():
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("linalg.svd", pool['3x3'])
OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1)
OpArgMngr.add_workload("argmax", pool['3x2'], axis=-1)
OpArgMngr.add_workload("argmin", pool['3x2'], axis=-1)
OpArgMngr.add_workload("indices", dimensions=(1, 2, 3))
OpArgMngr.add_workload("subtract", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("multiply", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("mod", pool['2x2'], pool['2x2'])
Expand Down
10 changes: 6 additions & 4 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4529,7 +4529,7 @@ def argmax(a, axis=None, out=None):
>>> b
array([2., 2.])
"""
return _npi.argmax(a, axis=axis, keepdims=False, out=out)
return _api_internal.argmax(a, axis, False, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -4597,7 +4597,7 @@ def argmin(a, axis=None, out=None):
>>> b
array([0., 0.])
"""
return _npi.argmin(a, axis=axis, keepdims=False, out=out)
return _api_internal.argmin(a, axis, False, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -4945,8 +4945,10 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
"""
if isinstance(dimensions, (tuple, list)):
if ctx is None:
ctx = current_context()
return _npi.indices(dimensions=dimensions, dtype=dtype, ctx=ctx)
ctx = str(current_context())
else:
ctx = str(ctx)
return _api_internal.indices(dimensions, dtype, ctx)
else:
raise ValueError("The dimensions must be sequence of ints")
# pylint: enable=redefined-outer-name
Expand Down
98 changes: 98 additions & 0 deletions src/api/operator/numpy/np_broadcast_reduce_op_index.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_broadcast_reduce_op_index.cc
* \brief Implementation of the API of functions in
src/operator/numpy/np_broadcast_reduce_op_index.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/tensor/broadcast_reduce_op.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.argmax")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_argmax");
nnvm::NodeAttrs attrs;
op::ReduceAxisParam param;
// param.axis
if (args[1].type_code() == kNull) {
param.axis = dmlc::nullopt;
} else {
param.axis = args[1].operator int();
}
// param.keepdims
param.keepdims = args[2].operator bool();

attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::ReduceAxisParam>(&attrs);
// inputs
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
// outputs
NDArray* out = args[3].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(3);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});

MXNET_REGISTER_API("_npi.argmin")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_argmin");
nnvm::NodeAttrs attrs;
op::ReduceAxisParam param;
// param.axis
if (args[1].type_code() == kNull) {
param.axis = dmlc::nullopt;
} else {
param.axis = args[1].operator int();
}
// param.keepdims
param.keepdims = args[2].operator bool();

attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::ReduceAxisParam>(&attrs);
// inputs
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
// outputs
NDArray* out = args[3].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(3);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});

} // namespace mxnet
32 changes: 32 additions & 0 deletions src/api/operator/numpy/np_init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/tensor/init_op.h"
#include "../../../operator/numpy/np_init_op.h"

namespace mxnet {

Expand Down Expand Up @@ -88,4 +89,35 @@ MXNET_REGISTER_API("_npi.full_like")
*ret = ndoutputs[0];
});

// FFI handler for np.indices(dimensions, dtype, ctx).
//   args[0] — dimensions: a single int or a tuple/list of ints
//   args[1] — dtype: type-name string, or None for the int32 default
//   args[2] — ctx: device string (e.g. "cpu(0)"), or None
MXNET_REGISTER_API("_npi.indices")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_indices");
nnvm::NodeAttrs attrs;
op::IndicesOpParam param;
// param.dimensions: promote a bare int to a rank-1 shape; otherwise
// build the shape from the sequence object.
if (args[0].type_code() == kDLInt) {
param.dimensions = TShape(1, args[0].operator int64_t());
} else {
param.dimensions = TShape(args[0].operator ObjectRef());
}
// param.dtype: None falls back to int32, matching the Python default.
if (args[1].type_code() == kNull) {
param.dtype = mshadow::kInt32;
} else {
param.dtype = String2MXNetTypeWithBool(args[1].operator std::string());
}
attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::IndicesOpParam>(&attrs);
// param.ctx: written into attrs.dict AFTER SetAttrDict so it is not
// clobbered by the param-derived entries.
if (args[2].type_code() != kNull) {
attrs.dict["ctx"] = args[2].operator std::string();
}
// Pure initializer op: no input ndarrays and no user-supplied output.
int num_inputs = 0;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, nullptr, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

} // namespace mxnet
8 changes: 8 additions & 0 deletions src/operator/numpy/np_init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <string>
#include "../tensor/init_op.h"
#include "../tensor/elemwise_unary_op.h"
#include "../../api/operator/op_utils.h"


namespace mxnet {
Expand Down Expand Up @@ -79,6 +80,13 @@ struct IndicesOpParam : public dmlc::Parameter<IndicesOpParam> {
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream dimensions_s, dtype_s;
dimensions_s << dimensions;
dtype_s << dtype;
(*dict)["dimensions"] = dimensions_s.str();
(*dict)["dtype"] = MXNetTypeWithBool2String(dtype);
}
};

inline bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
Expand Down
7 changes: 7 additions & 0 deletions src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ struct ReduceAxisParam : public dmlc::Parameter<ReduceAxisParam> {
.describe("If this is set to `True`, the reduced axis is left "
"in the result as dimension with size one.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream axis_s, keepdims_s;
axis_s << axis;
keepdims_s << keepdims;
(*dict)["axis"] = axis_s.str();
(*dict)["keepdims"] = keepdims_s.str();
}
};

enum PickOpMode {kWrap, kClip};
Expand Down