Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix the implementation of binary ops
Browse files Browse the repository at this point in the history
  • Loading branch information
AntiZpvoh committed Mar 24, 2020
1 parent 8a51fd2 commit 757e6a4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
14 changes: 7 additions & 7 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4364,7 +4364,7 @@ def maximum(x1, x2, out=None, **kwargs):
out : mxnet.numpy.ndarray or scalar
The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
_np.maximum(x1, x2, out=out)
return _np.maximum(x1, x2, out=out)
return _api_internal.maximum(x1, x2, out)


Expand All @@ -4385,7 +4385,7 @@ def minimum(x1, x2, out=None, **kwargs):
out : mxnet.numpy.ndarray or scalar
The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
_np.minimum(x1, x2, out=out)
return _np.minimum(x1, x2, out=out)
return _api_internal.minimum(x1, x2, out)


Expand Down Expand Up @@ -6148,7 +6148,7 @@ def equal(x1, x2, out=None):
array([ True])
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
_np.equal(x1, x2, out=out)
return _np.equal(x1, x2, out=out)
return _api_internal.equal(x1, x2, out)


Expand Down Expand Up @@ -6182,7 +6182,7 @@ def not_equal(x1, x2, out=None):
array([False])
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
_np.not_equal(x1, x2, out=out)
return _np.not_equal(x1, x2, out=out)
return _api_internal.not_equal(x1, x2, out)


Expand Down Expand Up @@ -6250,7 +6250,7 @@ def less(x1, x2, out=None):
array([False])
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
_np.less(x1, x2, out=out)
return _np.less(x1, x2, out=out)
return _api_internal.less(x1, x2, out)


Expand Down Expand Up @@ -6284,7 +6284,7 @@ def greater_equal(x1, x2, out=None):
array([True])
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
_np.greater_equal(x1, x2, out=out)
return _np.greater_equal(x1, x2, out=out)
return _api_internal.greater_equal(x1, x2, out)


Expand Down Expand Up @@ -6319,7 +6319,7 @@ def less_equal(x1, x2, out=None):
array([True])
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
_np.less_equal(x1, x2, out=out)
return _np.less_equal(x1, x2, out=out)
return _api_internal.less_equal(x1, x2, out)


Expand Down
3 changes: 1 addition & 2 deletions src/api/operator/random/shuffle_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,14 @@ MXNET_REGISTER_API("_npi.shuffle")
const nnvm::Op* op = Op::Get("_npi_shuffle");
nnvm::NodeAttrs attrs;

NDArray** inputs = new NDArray*[1]();
NDArray* inputs[1];
int num_inputs = 1;

if (args[0].type_code() != kNull) {
inputs[0] = args[0].operator mxnet::NDArray *();
}

attrs.op = op;
inputs = inputs == nullptr ? nullptr : inputs;

NDArray* out = args[1].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
Expand Down

0 comments on commit 757e6a4

Please sign in to comment.