Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Update fallback.py #19457

Merged
merged 7 commits into from Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 10 additions & 4 deletions python/mxnet/numpy/fallback.py
Expand Up @@ -19,6 +19,7 @@
"""Operators that fallback to official NumPy implementation."""

import sys
from functools import wraps
import numpy as onp

fallbacks = [
Expand Down Expand Up @@ -114,11 +115,17 @@

fallback_mod = sys.modules[__name__]

def get_func(obj, doc):
"""Get new numpy function with object and doc"""
@wraps(obj)
def wrapper(*args, **kwargs):
return obj(*args, **kwargs)
wrapper.__doc__ = doc
leezu marked this conversation as resolved.
Show resolved Hide resolved
return wrapper

for obj_name in fallbacks:
onp_obj = getattr(onp, obj_name)
if callable(onp_obj):
def fn(*args, **kwargs):
return onp_obj(*args, **kwargs)
new_fn_doc = onp_obj.__doc__
if obj_name in {'divmod', 'float_power', 'frexp', 'heaviside', 'modf', 'signbit', 'spacing'}:
# remove reference of kwargs doc and the reference to ufuncs
Expand All @@ -128,8 +135,7 @@ def fn(*args, **kwargs):
# remove unused reference
new_fn_doc = new_fn_doc.replace(
'.. [1] Wikipedia page: https://en.wikipedia.org/wiki/Trapezoidal_rule', '')
fn.__doc__ = new_fn_doc
setattr(fallback_mod, obj_name, fn)
setattr(fallback_mod, obj_name, get_func(onp_obj, new_fn_doc))
else:
setattr(fallback_mod, obj_name, onp_obj)

Expand Down
21 changes: 18 additions & 3 deletions python/mxnet/numpy/multiarray.py
Expand Up @@ -11382,7 +11382,12 @@ def atleast_1d(*arys):
>>> np.atleast_1d(np.array(1), np.array([3, 4]))
[array([1.]), array([3., 4.])]
"""
return _mx_nd_np.atleast_1d(*arys)
res = []
for ary in arys:
if not isinstance(ary, NDArray):
ary = array(ary)
res.append(ary)
return _mx_nd_np.atleast_1d(*res)


@set_module('mxnet.numpy')
Expand Down Expand Up @@ -11414,7 +11419,12 @@ def atleast_2d(*arys):
>>> np.atleast_2d(np.array(1), np.array([1, 2]), np.array([[1, 2]]))
[array([[1.]]), array([[1., 2.]]), array([[1., 2.]])]
"""
return _mx_nd_np.atleast_2d(*arys)
res = []
for ary in arys:
if not isinstance(ary, NDArray):
ary = array(ary)
res.append(ary)
return _mx_nd_np.atleast_2d(*res)


@set_module('mxnet.numpy')
Expand Down Expand Up @@ -11457,7 +11467,12 @@ def atleast_3d(*arys):
[2.]]] (1, 2, 1)
[[[1. 2.]]] (1, 1, 2)
"""
return _mx_nd_np.atleast_3d(*arys)
res = []
for ary in arys:
if not isinstance(ary, NDArray):
ary = array(ary)
res.append(ary)
return _mx_nd_np.atleast_3d(*res)


@set_module('mxnet.numpy')
Expand Down
4 changes: 3 additions & 1 deletion tests/python/unittest/test_numpy_interoperability.py
Expand Up @@ -3255,13 +3255,15 @@ def _check_interoperability_helper(op_name, rel_tol, abs_tol, *args, **kwargs):
strs = op_name.split('.')
if len(strs) == 1:
onp_op = getattr(_np, op_name)
mxnp_op = getattr(np, op_name)
elif len(strs) == 2:
onp_op = getattr(getattr(_np, strs[0]), strs[1])
mxnp_op = getattr(getattr(np, strs[0]), strs[1])
else:
assert False
if not is_op_runnable():
return
out = onp_op(*args, **kwargs)
out = mxnp_op(*args, **kwargs)
expected_out = _get_numpy_op_output(onp_op, *args, **kwargs)
if isinstance(out, (tuple, list)):
assert type(out) == type(expected_out)
Expand Down