Skip to content

Commit

Permalink
Merge pull request #72 from joaosferreira/fix-backends
Browse files Browse the repository at this point in the history
Fix __ua_function__ in backends
  • Loading branch information
hameerabbasi committed Aug 14, 2020
2 parents 9a42b49 + 986b492 commit 8f1d962
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 24 deletions.
10 changes: 7 additions & 3 deletions unumpy/cupy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ def _get_from_name_domain(name, domain):
module = cp
domain_hierarchy = domain.split(".")
for d in domain_hierarchy[1:]:
module = getattr(module, d)
if hasattr(module, d):
module = getattr(module, d)
else:
return NotImplemented
if hasattr(module, name):
return getattr(module, name)
else:
Expand All @@ -45,10 +48,11 @@ def __ua_function__(method, args, kwargs):
if len(args) != 0 and isinstance(args[0], unumpy.ClassOverrideMeta):
return NotImplemented

if not hasattr(cp, method.__name__):
cupy_method = _get_from_name_domain(method.__name__, method.domain)
if cupy_method is NotImplemented:
return NotImplemented

return getattr(cp, method.__name__)(*args, **kwargs)
return cupy_method(*args, **kwargs)

@wrap_single_convertor
def __ua_convert__(value, dispatch_type, coerce):
Expand Down
22 changes: 13 additions & 9 deletions unumpy/dask_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
set_state,
set_backend,
)
from unumpy import ufunc, ufunc_list, ndarray
from unumpy import ufunc, ufunc_list, ndarray, dtype
import unumpy
import functools
import sys
Expand All @@ -18,22 +18,25 @@

from typing import Dict

_class_mapping = {ndarray: da.Array, dtype: np.dtype, ufunc: da.ufunc.ufunc}


def overridden_class(self):
if self is ndarray:
return da.Array
if self is ufunc:
return da.ufunc.ufunc
if self in _class_mapping:
return _class_mapping[self]
module = self.__module__.split(".")
module = ".".join(m for m in module if m != "_multimethods")
return _get_from_name_domain(self.__name__, module)


def _get_from_name_domain(name, domain):
module = np
module = da
domain_hierarchy = domain.split(".")
for d in domain_hierarchy[1:]:
module = getattr(module, d)
if hasattr(module, d):
module = getattr(module, d)
else:
return NotImplemented
if hasattr(module, name):
return getattr(module, name)
else:
Expand Down Expand Up @@ -141,10 +144,11 @@ def __ua_function__(self, method, args, kwargs):
if len(args) != 0 and isinstance(args[0], unumpy.ClassOverrideMeta):
return NotImplemented

if not hasattr(da, method.__name__):
dask_method = _get_from_name_domain(method.__name__, method.domain)
if dask_method is NotImplemented:
return NotImplemented

return getattr(da, method.__name__)(*args, **kwargs)
return dask_method(*args, **kwargs)

@wrap_single_convertor_instance
def __ua_convert__(self, value, dispatch_type, coerce):
Expand Down
3 changes: 2 additions & 1 deletion unumpy/linalg/_multimethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
mark_dtype,
_first_argreplacer,
_first2argreplacer,
ndim,
)

__all__ = [
Expand Down Expand Up @@ -120,7 +121,7 @@ def cond_default(x, p=None):
raise ValueError("Array must be at least two-dimensional.")


@create_numpy(_self_argreplacer, default=cond_default)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def cond(x, p=None):
return (x,)
Expand Down
21 changes: 14 additions & 7 deletions unumpy/sparse_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import sparse
from uarray import Dispatchable, wrap_single_convertor
from unumpy import ufunc, ufunc_list, ndarray
from unumpy import ufunc, ufunc_list, ndarray, dtype
import unumpy
import functools

Expand All @@ -23,9 +23,12 @@ def array(x, *args, **kwargs):
return sparse.COO.from_numpy(np.asarray(x))


_class_mapping = {ndarray: sparse.SparseArray, dtype: np.dtype, ufunc: np.ufunc}


def overridden_class(self):
if self is ndarray:
return sparse.SparseArray
if self in _class_mapping:
return _class_mapping[self]
module = self.__module__.split(".")
module = ".".join(m for m in module if m != "_multimethods")
return _get_from_name_domain(self.__name__, module)
Expand All @@ -41,10 +44,13 @@ def overridden_class(self):


def _get_from_name_domain(name, domain):
module = np
module = sparse
domain_hierarchy = domain.split(".")
for d in domain_hierarchy[1:]:
module = getattr(module, d)
if hasattr(module, d):
module = getattr(module, d)
else:
return NotImplemented
if hasattr(module, name):
return getattr(module, name)
else:
Expand All @@ -58,10 +64,11 @@ def __ua_function__(method, args, kwargs):
if len(args) != 0 and isinstance(args[0], unumpy.ClassOverrideMeta):
return NotImplemented

if not hasattr(sparse, method.__name__):
sparse_method = _get_from_name_domain(method.__name__, method.domain)
if sparse_method is NotImplemented:
return NotImplemented

return getattr(sparse, method.__name__)(*args, **kwargs)
return sparse_method(*args, **kwargs)


@wrap_single_convertor
Expand Down
24 changes: 20 additions & 4 deletions unumpy/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def test_ufuncs_results(backend, method, args, kwargs, res):
),
{},
),
(np.linalg.lstsq, ([[3, 1], [1, 2]], [9, 8]), {"rcond": None}),
(np.linalg.lstsq, ([[3, 1], [1, 2]], [9, 8]), {}),
(np.linalg.inv, ([[1.0, 2.0], [3.0, 4.0]],), {}),
(np.linalg.pinv, ([[1.0, 2.0], [3.0, 4.0]],), {}),
(np.linalg.tensorinv, (np.eye(4 * 6).reshape((4, 6, 8, 3)),), {}),
Expand All @@ -494,15 +494,31 @@ def test_ufuncs_results(backend, method, args, kwargs, res):
def test_linalg(backend, method, args, kwargs):
backend, types = backend
try:
with ua.set_backend(NumpyBackend, coerce=True):
with ua.set_backend(backend, coerce=True):
ret = method(*args, **kwargs)
except ua.BackendNotImplementedError:
if backend in FULLY_TESTED_BACKENDS and (backend, method) not in EXCEPTIONS:
raise
pytest.xfail(reason="The backend has no implementation for this ufunc.")

if isinstance(ret, da.Array):
ret.compute()
if method in {
np.linalg.qr,
np.linalg.svd,
np.linalg.eig,
np.linalg.eigh,
np.linalg.slogdet,
np.linalg.lstsq,
}:
assert all(isinstance(arr, types) for arr in ret)

for arr in ret:
if isinstance(arr, da.Array):
arr.compute()
else:
assert isinstance(ret, types)

if isinstance(ret, da.Array):
ret.compute()


@pytest.mark.parametrize(
Expand Down

0 comments on commit 8f1d962

Please sign in to comment.