Skip to content

Commit

Permalink
Add overridden_class to CuPy backend (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
joaosferreira committed Aug 6, 2020
1 parent 8d1ca5c commit cc124c2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
19 changes: 18 additions & 1 deletion unumpy/cupy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,24 @@

__ua_domain__ = "numpy"

_implementations: Dict = {}
def overridden_class(self):
module = self.__module__.split(".")
module = ".".join(m for m in module if m != "_multimethods")
return _get_from_name_domain(self.__name__, module)

_implementations: Dict = {
unumpy.ClassOverrideMeta.overridden_class.fget: overridden_class
}

def _get_from_name_domain(name, domain):
module = cp
domain_hierarchy = domain.split(".")
for d in domain_hierarchy[1:]:
module = getattr(module, d)
if hasattr(module, name):
return getattr(module, name)
else:
return NotImplemented

def _implements(np_func):
def inner(func):
Expand Down
8 changes: 8 additions & 0 deletions unumpy/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,11 @@ def test_class_overriding():
assert np.dtype("float64") == onp.float64
assert isinstance(np.dtype("float64"), onp.dtype)
assert issubclass(onp.ufunc, np.ufunc)

if hasattr(CupyBackend, "__ua_function__"):
with ua.set_backend(CupyBackend, coerce=True):
assert isinstance(cp.add, np.ufunc)
assert isinstance(cp.dtype("float64"), np.dtype)
assert np.dtype("float64") == cp.float64
assert isinstance(np.dtype("float64"), cp.dtype)
assert issubclass(cp.ufunc, np.ufunc)

0 comments on commit cc124c2

Please sign in to comment.