Skip to content

Commit

Permalink
Merge pull request #8305 from mhvk/v2-units-introduce-matmul
Browse files Browse the repository at this point in the history
2.0.x: Add support for matmul as a ufunc (new in numpy 1.16).
  • Loading branch information
bsipocz committed Dec 30, 2018
2 parents 809c0d4 + e97b42f commit 9dc315e
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ astropy.time
astropy.units
^^^^^^^^^^^^^

- Add support for ``np.matmul`` as a ``ufunc`` (new in numpy 1.16). [#8305]

astropy.utils
^^^^^^^^^^^^^

Expand Down
21 changes: 11 additions & 10 deletions astropy/units/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .core import (Unit, dimensionless_unscaled, get_current_unit_registry,
UnitBase, UnitsError, UnitTypeError)
from .format.latex import Latex
from ..utils.compat import NUMPY_LT_1_13, NUMPY_LT_1_14
from ..utils.compat import NUMPY_LT_1_13, NUMPY_LT_1_14, NUMPY_LT_1_16
from ..utils.compat.misc import override__dir__
from ..utils.compat.numpy import matmul
from ..utils.misc import isiterable, InheritDocstrings
Expand Down Expand Up @@ -1105,15 +1105,16 @@ def __pow__(self, other):
return super(Quantity, self).__pow__(other)

# For Py>=3.5
def __matmul__(self, other, reverse=False):
result_unit = self.unit * getattr(other, 'unit', dimensionless_unscaled)
result_array = matmul(self.value, getattr(other, 'value', other))
return self._new_view(result_array, result_unit)

def __rmatmul__(self, other):
result_unit = self.unit * getattr(other, 'unit', dimensionless_unscaled)
result_array = matmul(getattr(other, 'value', other), self.value)
return self._new_view(result_array, result_unit)
if NUMPY_LT_1_16:
def __matmul__(self, other, reverse=False):
result_unit = self.unit * getattr(other, 'unit', dimensionless_unscaled)
result_array = matmul(self.value, getattr(other, 'value', other))
return self._new_view(result_array, result_unit)

def __rmatmul__(self, other):
result_unit = self.unit * getattr(other, 'unit', dimensionless_unscaled)
result_array = matmul(getattr(other, 'value', other), self.value)
return self._new_view(result_array, result_unit)

if NUMPY_LT_1_13:
# Pre-numpy 1.13, there was no np.positive ufunc and the copy done
Expand Down
2 changes: 2 additions & 0 deletions astropy/units/quantity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ def helper_multiplication(f, unit1, unit2):


UFUNC_HELPERS[np.multiply] = helper_multiplication
if isinstance(getattr(np, 'matmul', None), np.ufunc):
UFUNC_HELPERS[np.matmul] = helper_multiplication


def helper_division(f, unit1, unit2):
Expand Down
20 changes: 20 additions & 0 deletions astropy/units/tests/test_quantity_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,26 @@ def test_multiply_array(self):
assert np.all(np.multiply(np.arange(3.) * u.m, 2. / u.s) ==
np.arange(0, 6., 2.) * u.m / u.s)

@pytest.mark.skipif(not isinstance(getattr(np, 'matmul', None), np.ufunc),
reason="np.matmul is not yet a gufunc")
def test_matmul(self):
q = np.arange(3.) * u.m
r = np.matmul(q, q)
assert r == 5. * u.m ** 2
# less trivial case.
q1 = np.eye(3) * u.m
q2 = np.array([[[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]],
[[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.]],
[[0., 0., 1.],
[1., 0., 0.],
[0., 1., 0.]]]) / u.s
r2 = np.matmul(q1, q2)
assert np.all(r2 == np.matmul(q1.value, q2.value) * q1.unit * q2.unit)

@pytest.mark.parametrize('function', (np.divide, np.true_divide))
def test_divide_scalar(self, function):
assert function(4. * u.m, 2. * u.s) == function(4., 2.) * u.m / u.s
Expand Down
3 changes: 2 additions & 1 deletion astropy/utils/compat/numpycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

__all__ = ['NUMPY_LT_1_9_1', 'NUMPY_LT_1_10', 'NUMPY_LT_1_10_4',
'NUMPY_LT_1_11', 'NUMPY_LT_1_12', 'NUMPY_LT_1_13', 'NUMPY_LT_1_14',
'NUMPY_LT_1_14_1']
'NUMPY_LT_1_14_1', 'NUMPY_LT_1_16']

# TODO: It might also be nice to have aliases to these named for specific
# features/bugs we're checking for (ex:
Expand All @@ -23,3 +23,4 @@
NUMPY_LT_1_13 = not minversion('numpy', '1.13')
NUMPY_LT_1_14 = not minversion('numpy', '1.14')
NUMPY_LT_1_14_1 = not minversion('numpy', '1.14.1')
NUMPY_LT_1_16 = not minversion('numpy', '1.16')

0 comments on commit 9dc315e

Please sign in to comment.