[numpy] array ufunc and array function protocols #16097
[numpy] array ufunc and array function protocols #16097
Conversation
return decorator | ||
|
||
|
||
_NUMPY_ARRAY_FUNCTION_LIST = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this the list of operators where our operator definition complies with the official numpy behavior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it is the operators in our codebase that is dispatchable through the array function protocol. Others will be dispatched through either fluent methods or array ufunc protocol, which will be implemented later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome. We can use this list to track the full-compatibility with numpy operators once we start measuring it.
@szha This PR is ready for merge if there are no more comments. Thanks. |
LGTM. @leezu would you mind taking a look too? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @reminisce ! I noticed some issues, consider the following example
[ins] In [1]: import mxnet as mx
impor
[ins] In [2]: import numpy as np
[ins] In [3]: a1 = mx.np.arange(10)
[ins] In [4]: np.dot(a1.asnumpy(),a1.asnumpy())
Out[4]: 285.0
[ins] In [5]: mx.np.dot(a1,a1)
Out[5]: array(285.)
[ins] In [6]: np.dot(a1,a1)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-6-be002a006ef6> in <module>
----> 1 np.dot(a1,a1)
ValueError: dot: too many dimensions in result
[ins] In [7]:
Would it make sense to fix it in this PR? My biggest concern is that this was not caught by the tests, so there may be more issues that we are not yet aware of. Could the tests be improved?
It's also OK to do in a follow-up PR, but if you can integrate it in this PR it's of course best.
if cur_np_ver >= np_1_17_ver: | ||
try: | ||
func(*args, **kwargs) | ||
except: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is it necessary to catch all possible exceptions here and to overwrite their error message with a generic 'Running ... failed'?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I will append the generic message to the real exception.
@leezu Thanks for the review. I tried the commands you showed, but it works for me. Could you confirm that you are using NumPy with version >= 1.17? >>> from mxnet import npx, np
>>> a = np.arange(10)
>>> npx.set_np()
>>> a = np.arange(10)
>>> import numpy as onp
>>> print(onp.__version__)
1.17.1
>>> onp.dot(a, a)
array(285.)
>>> type(onp.dot(a, a))
<class 'mxnet.numpy.ndarray'> |
@leezu Re: Can test be improved? Yes, absolutely. This PR is for building the skeleton of testing suite. It enables |
You're right, I was using numpy 1.16 and it ran into the well-known issue (I guess there is some github issue tracking it, but not sure) that numpy doesn't correctly convert ndarrays to numpy arrays. Would it make sense to support Numpy 1.16 when |
@leezu I prefer to bump the required version to 1.17 to use this feature for two reasons. However, I'm not sure if we should bump the numpy version requirement in this or a separate PR. Maybe @szha can provide some insights.
|
The downside of upgrading is the potential conflict of the numpy dependency version between mxnet’s and other packages’. Shall we just produce proper error message for older numpy versions instead? |
I will add proper error message when it is used with numpy 1.16 versions. Thanks. |
@szha this assumes that people can't upgrade to numpy 1.17 if they are using 1.16, even though there are no breaking changes https://docs.scipy.org/doc/numpy/release.html If we properly declare the 1.17 dependency, users that install mxnet via pip will automatically get numpy 1.17 and everything will work. If we don't bump the dependency, users have to read through error messages to understand they need to upgrade.. |
Add test Fix Fix sanity Fix build failure Use array_function protocol only when np.version >= 1.17 Add unit tests for array ufunc protocol Fix pylint Reformat Fix build Fix Refactor test suite for numpy interoperability
32efa10
to
f3548b2
Compare
e72720d
to
39ba49d
Compare
@leezu I'm referring to the situation where another package on an end user's environment declares a dependency on numpy 1.16, in which case the control of upgrading numpy isn't in the hands of the end user. |
This seems quite hypothetical / uncommon to me. I'm not convinced if we need to take consideration of this, given the user experience impact of not declaring the correct numpy version dependency and our current strategy of throwing runtime errors or having numpy throw errors such as After all we have bumped numpy dependency for 1.5 release as well.. 1e6a1ab |
I'm OK with bumping the version and merging the PR. The error message during installation is sufficiently clear. |
Description
This PR implemented NumPy's array ufunc and array function protocols so that
mxnet.numpy.ndarray
can be fed into a subset of official NumPy operators and get dispatched to the corresponding MXNet operators for execution. See the following for examples.Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments