
[Numpy] FFI: random.shuffle, equal, not_equal, less_equal, greater_equal, less, maximum and minimum #17896

Merged: 5 commits, Apr 18, 2020
Changes from 3 commits
8 changes: 8 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
@@ -56,6 +56,14 @@ def prepare_workloads():
OpArgMngr.add_workload("nan_to_num", pool['2x2'])
OpArgMngr.add_workload("tensordot", pool['2x2'], pool['2x2'], ((1, 0), (0, 1)))
OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2'])
OpArgMngr.add_workload("random.shuffle", pool['3'])
OpArgMngr.add_workload("equal", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("not_equal", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("less", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("greater_equal", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("less_equal", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("maximum", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("minimum", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("linalg.svd", pool['3x3'])
OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1)
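As a rough illustration of what these new workloads exercise (the real measurements are done by the harness in benchmark/python/ffi/benchmark_ffi.py; the loop count below is an arbitrary illustrative value, not the harness' setting), a minimal per-call timing sketch might look like:

```python
# Minimal sketch of measuring per-call FFI dispatch overhead for one of the
# newly benchmarked ops; N_ITER is illustrative only.
import timeit
from mxnet import numpy as np

a = np.ones((2, 2))
b = np.ones((2, 2))

N_ITER = 10000
elapsed = timeit.timeit(lambda: np.equal(a, b), number=N_ITER)
print("np.equal: %.2f us per call" % (elapsed / N_ITER * 1e6))
```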
32 changes: 23 additions & 9 deletions python/mxnet/ndarray/numpy/_op.py
@@ -4363,7 +4363,9 @@ def maximum(x1, x2, out=None, **kwargs):
-------
out : mxnet.numpy.ndarray or scalar
The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
-    return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.maximum(x1, x2, out=out)
+    return _api_internal.maximum(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -4382,7 +4384,9 @@ def minimum(x1, x2, out=None, **kwargs):
-------
out : mxnet.numpy.ndarray or scalar
The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
-    return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.minimum(x1, x2, out=out)
+    return _api_internal.minimum(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -6143,7 +6147,9 @@ def equal(x1, x2, out=None):
>>> np.equal(1, np.ones(1))
array([ True])
"""
-    return _ufunc_helper(x1, x2, _npi.equal, _np.equal, _npi.equal_scalar, None, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.equal(x1, x2, out=out)
+    return _api_internal.equal(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -6175,7 +6181,10 @@ def not_equal(x1, x2, out=None):
>>> np.not_equal(1, np.ones(1))
array([False])
"""
-    return _ufunc_helper(x1, x2, _npi.not_equal, _np.not_equal, _npi.not_equal_scalar, None, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.not_equal(x1, x2, out=out)
+    return _api_internal.not_equal(x1, x2, out)



@set_module('mxnet.ndarray.numpy')
@@ -6240,7 +6249,9 @@ def less(x1, x2, out=None):
>>> np.less(1, np.ones(1))
array([False])
"""
-    return _ufunc_helper(x1, x2, _npi.less, _np.less, _npi.less_scalar, _npi.greater_scalar, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.less(x1, x2, out=out)
+    return _api_internal.less(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@@ -6272,8 +6283,10 @@ def greater_equal(x1, x2, out=None):
>>> np.greater_equal(1, np.ones(1))
array([True])
"""
-    return _ufunc_helper(x1, x2, _npi.greater_equal, _np.greater_equal, _npi.greater_equal_scalar,
-                         _npi.less_equal_scalar, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.greater_equal(x1, x2, out=out)
+    return _api_internal.greater_equal(x1, x2, out)



@set_module('mxnet.ndarray.numpy')
@@ -6305,8 +6318,9 @@ def less_equal(x1, x2, out=None):
>>> np.less_equal(1, np.ones(1))
array([True])
"""
-    return _ufunc_helper(x1, x2, _npi.less_equal, _np.less_equal, _npi.less_equal_scalar,
-                         _npi.greater_equal_scalar, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.less_equal(x1, x2, out=out)
+    return _api_internal.less_equal(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
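Taken together, the changes above give every one of these frontend functions the same dispatch shape: when both operands are plain Python numerics the call falls through to official NumPy, and in every other case it goes through the new `_api_internal` FFI entry point. A quick sketch of the observable behaviour (values assume default dtypes):

```python
# Two Python scalars -> plain NumPy result; any ndarray operand -> FFI path
# returning an MXNet ndarray.
from mxnet import numpy as np

print(np.maximum(2, 3))                 # scalar path, handled by official NumPy
print(np.maximum(np.ones((2, 2)), 3))   # FFI path, ndarray of 3s
print(np.less_equal(np.arange(3), 1))   # FFI path, boolean ndarray
```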
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/numpy/random.py
@@ -1068,7 +1068,7 @@ def shuffle(x):
[3., 4., 5.],
[0., 1., 2.]])
"""
-    _npi.shuffle(x, out=x)
+    _api_internal.shuffle(x, x)


def laplace(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None):
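For reference, `np.random.shuffle` keeps its in-place, first-axis-only semantics after the FFI switch; a small usage sketch:

```python
# Shuffle permutes the array in place along the first axis and returns None;
# the contents of each row are left untouched.
from mxnet import numpy as np

arr = np.arange(9).reshape(3, 3)
np.random.shuffle(arr)   # rows reordered in place
print(arr)
```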
71 changes: 71 additions & 0 deletions src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_elemwise_broadcast_logic_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_elemwise_broadcast_logic_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../ufunc_helper.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.equal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_equal");
const nnvm::Op* op_scalar = Op::Get("_npi_equal_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.not_equal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_not_equal");
const nnvm::Op* op_scalar = Op::Get("_npi_not_equal_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.less")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_less");
const nnvm::Op* op_scalar = Op::Get("_npi_less_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.greater_equal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_greater_equal");
const nnvm::Op* op_scalar = Op::Get("_npi_greater_equal_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.less_equal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_less_equal");
const nnvm::Op* op_scalar = Op::Get("_npi_less_equal_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

} // namespace mxnet
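Each registration above binds a `_npi.<name>` FFI endpoint to the broadcasting tensor-tensor operator and its `_scalar` variant via `UFuncHelper`; these are the endpoints the Python comparison functions shown earlier call into. A brief, hedged illustration of the behaviour they expose at the Python level (values assume default dtypes):

```python
# The registered comparison ops broadcast their operands and return boolean
# ndarrays; a scalar operand is routed to the *_scalar variant.
from mxnet import numpy as np

a = np.array([[0., 1.], [1., 0.]])
b = np.array([[1., 1.]])            # broadcasts against a's rows
print(np.equal(a, b))
print(np.not_equal(a, 1))           # scalar operand
```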
57 changes: 57 additions & 0 deletions src/api/operator/random/shuffle_op.cc
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file shuffle_op.cc
* \brief Implementation of the API of functions in src/operator/random/shuffle_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/elemwise_op_common.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.shuffle")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_shuffle");
nnvm::NodeAttrs attrs;

NDArray* inputs[1];
int num_inputs = 1;

if (args[0].type_code() != kNull) {
inputs[0] = args[0].operator mxnet::NDArray *();
}

attrs.op = op;

NDArray* out = args[1].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(1);
} else {
*ret = ndoutputs[0];
}
});

} // namespace mxnet
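Since the frontend calls `_api_internal.shuffle(x, x)`, the array passed in the output slot is the one whose buffer gets permuted, and the `PythonArg(1)` branch hands that same array back rather than allocating a new one. A minimal check of the in-place contract through the public API:

```python
# shuffle writes the permutation back into x's own buffer; the element set is
# unchanged, only the order along the first axis differs.
from mxnet import numpy as np

x = np.arange(5)
np.random.shuffle(x)                   # frontend returns None; x is permuted
print(sorted(x.asnumpy().tolist()))    # sorting recovers the original sequence
```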
47 changes: 47 additions & 0 deletions src/api/operator/tensor/elemwise_binary_broadcast_op_extended.cc
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file elemwise_binary_broadcast_op_extended.cc
* \brief Implementation of the API of functions in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../ufunc_helper.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.maximum")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_maximum");
const nnvm::Op* op_scalar = Op::Get("_npi_maximum_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.minimum")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_minimum");
const nnvm::Op* op_scalar = Op::Get("_npi_minimum_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

} // namespace mxnet
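These two registrations mirror the comparison-op ones: `_npi_maximum`/`_npi_minimum` handle the tensor-tensor case and their `_scalar` variants handle a numeric operand. A short sketch of the resulting behaviour, including the `out=` argument the frontend forwards:

```python
# maximum/minimum broadcast like their NumPy counterparts and honour an
# explicit output array.
from mxnet import numpy as np

a = np.array([[1., 5.], [3., 2.]])
out = np.zeros((2, 2))
np.maximum(a, 2.5, out=out)   # scalar operand -> _npi_maximum_scalar
print(out)                    # [[2.5, 5. ], [3. , 2.5]]
print(np.minimum(a, a.T))     # tensor-tensor -> _npi_minimum
```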