
Commit

Fix backends for protocol approach.
hameerabbasi committed May 13, 2019
1 parent 2a503df commit 37ac6a0
Showing 10 changed files with 159 additions and 201 deletions.
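
Under the protocol approach a backend is just a module exposing `__ua_domain__`, `__ua_function__`, and `__ua_coerce__`, and the change in this commit moves `__ua_coerce__` from receiving the `DispatchableInstance` wrapper to receiving the unwrapped `value` plus its `dispatch_type`. A minimal sketch of what such a backend module looks like after this commit (`mylib` is a hypothetical array library; only the protocol shape is taken from the diff below):

import mylib  # hypothetical array library standing in for numpy/cupy/torch

from unumpy.multimethods import ndarray

__ua_domain__ = "numpy"


def __ua_function__(method, args, kwargs, dispatchable_args):
    # Mirror the getattr(np, method.__name__) pattern the real backends use.
    if not hasattr(mylib, method.__name__):
        return NotImplemented
    return getattr(mylib, method.__name__)(*args, **kwargs)


def __ua_coerce__(value, dispatch_type):
    # New signature: the dispatcher unwraps DispatchableInstance beforehand,
    # so only the raw value and its dispatch type arrive here.
    if dispatch_type is ndarray:
        return mylib.asarray(value) if value is not None else None
    return NotImplemented
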
4 changes: 1 addition & 3 deletions setup.py
@@ -54,9 +54,7 @@ def parse_requires():
maintainer_email="habbasi@quansight.com",
license="BSD 3-Clause License (Revised)",
keywords="uarray,numpy,scipy,pytorch,cupy,tensorflow",
packages=find_packages(
include=["uarray", "uarray.*", "unumpy", "unumpy.*"]
),
packages=find_packages(include=["uarray", "uarray.*", "unumpy", "unumpy.*"]),
long_description=long_desc,
long_description_content_type="text/markdown",
install_requires=reqs,
2 changes: 1 addition & 1 deletion uarray/backend.py
@@ -73,7 +73,7 @@ def replace_dispatchables(backend, args, kwargs, coerce: Optional[bool] = False)
filtered_args: List = []
for arg in dispatchable_args:
replaced_arg = (
backend.__ua_coerce__(arg)
backend.__ua_coerce__(arg.value, arg.dispatch_type)
if isinstance(arg, DispatchableInstance)
else NotImplemented
)
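
On the dispatch side, `replace_dispatchables` now opens the `DispatchableInstance` wrapper itself before consulting the backend. A rough, hypothetical helper showing just that call pattern, assuming only the `value` and `dispatch_type` attributes visible in this hunk:

from uarray import DispatchableInstance  # exposes .value and .dispatch_type


def coerce_arg(backend, arg):
    # Hypothetical stand-in for the loop body above: unwrap first, then hand
    # the backend only the raw value and its dispatch type.
    if isinstance(arg, DispatchableInstance):
        return backend.__ua_coerce__(arg.value, arg.dispatch_type)
    return NotImplemented
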
10 changes: 5 additions & 5 deletions unumpy/cupy_backend.py
@@ -36,12 +36,12 @@ def __ua_function__(method, args, kwargs, dispatchable_args):

return getattr(cp, method.__name__)(*args, **kwargs)

def __ua_coerce__(arg):
if isinstance(arg, DispatchableInstance) and arg.dispatch_type is ndarray:
return cp.asarray(arg.value) if arg.value is not None else None
def __ua_coerce__(value, dispatch_type):
if dispatch_type is ndarray:
return cp.asarray(value) if value is not None else None

if isinstance(arg, DispatchableInstance) and arg.dispatch_type is ufunc:
return getattr(np, arg.value.name)
if dispatch_type is ufunc and hasattr(cp, value.name):
return getattr(cp, value.name)

return NotImplemented

12 changes: 6 additions & 6 deletions unumpy/dask_backend.py
@@ -1,6 +1,6 @@
import numpy as np
import dask.array as da
from uarray.backend import DispatchableInstance
from uarray import DispatchableInstance
from .multimethods import ufunc, ufunc_list, ndarray
import unumpy.multimethods as multimethods
import functools
@@ -42,12 +42,12 @@ def __ua_function__(method, args, kwargs, dispatchable_args):
return getattr(da, method.__name__)(*args, **kwargs)


def __ua_coerce__(arg):
if isinstance(arg, DispatchableInstance) and arg.dispatch_type is ndarray:
return da.asarray(arg.value) if arg.value is not None else None
def __ua_coerce__(value, dispatch_type):
if dispatch_type is ndarray:
return da.asarray(value) if value is not None else None

if isinstance(arg, DispatchableInstance) and arg.dispatch_type is ufunc:
return getattr(np, arg.value.name)
if dispatch_type is ufunc:
return getattr(np, value.name)

return NotImplemented

6 changes: 4 additions & 2 deletions unumpy/multimethods.py
@@ -74,9 +74,11 @@ def _ufunc_argreplacer(args, kwargs, arrays):
out_arrays = arrays[self.nin + 1 :]
if self.nout == 1:
out_arrays = out_arrays[0]
out_kwargs = {**kwargs, "out": out_arrays}

return (arrays[0], *in_arrays), out_kwargs
if "out" in kwargs:
kwargs = {**kwargs, "out": out_arrays}

return (arrays[0], *in_arrays), kwargs

@create_numpy(_ufunc_argreplacer)
@all_of_type(ndarray)
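
The `out` handling in `_ufunc_argreplacer` changes from always injecting `out=out_arrays` into the keyword arguments to doing so only when the caller passed `out` in the first place, so backends no longer receive a spurious `out=None`. A simplified, self-contained stand-in for that logic (the argument layout is inferred from the hunk; `nin`/`nout` replace `self.nin`/`self.nout`):

def replace_ufunc_args(nin, nout, args, kwargs, arrays):
    # arrays holds the ufunc itself, then the inputs, then the outputs.
    in_arrays = arrays[1:nin + 1]
    out_arrays = arrays[nin + 1:]
    if nout == 1:
        out_arrays = out_arrays[0]
    if "out" in kwargs:  # only rewrite "out" if the caller supplied one
        kwargs = {**kwargs, "out": out_arrays}
    return (arrays[0], *in_arrays), kwargs


# No explicit out= means kwargs stays empty instead of gaining out=None:
print(replace_ufunc_args(2, 1, ("x", "y"), {}, ("add", "X", "Y", None)))
# (('add', 'X', 'Y'), {})
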
10 changes: 5 additions & 5 deletions unumpy/numpy_backend.py
@@ -39,12 +39,12 @@ def __ua_function__(method, args, kwargs, dispatchable_args):
return getattr(np, method.__name__)(*args, **kwargs)


def __ua_coerce__(arg):
if isinstance(arg, DispatchableInstance) and arg.dispatch_type is ndarray:
return np.asarray(arg.value) if arg.value is not None else None
def __ua_coerce__(value, dispatch_type):
if dispatch_type is ndarray:
return np.asarray(value) if value is not None else None

if isinstance(arg, DispatchableInstance) and arg.dispatch_type is ufunc:
return getattr(np, arg.value.name)
if dispatch_type is ufunc:
return getattr(np, value.name)

return NotImplemented

12 changes: 5 additions & 7 deletions unumpy/sparse_backend.py
@@ -40,20 +40,18 @@ def __ua_function__(method, args, kwargs, dispatchable_args):
return getattr(sparse, method.__name__)(*args, **kwargs)


def __ua_coerce__(arg):
if isinstance(arg, DispatchableInstance) and arg.dispatch_type is ndarray:
value = arg.value

if arg.value is None:
def __ua_coerce__(value, dispatch_type):
if dispatch_type is ndarray:
if value is None:
return None

if isinstance(value, sparse.SparseArray):
return value

return sparse.as_coo(np.asarray(value))

if isinstance(arg, DispatchableInstance) and arg.dispatch_type is ufunc:
return getattr(np, arg.value.name)
if dispatch_type is ufunc:
return getattr(np, value.name)

return NotImplemented

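With the sparse backend, coercion now passes existing `SparseArray`s through untouched and wraps anything dense as COO via `sparse.as_coo(np.asarray(value))`. A quick usage sketch of that same conversion outside the backend:

import numpy as np
import sparse

dense = np.array([[0, 1], [2, 0]])
coerced = sparse.as_coo(np.asarray(dense))      # same call the backend makes
print(isinstance(coerced, sparse.SparseArray))  # True
print(coerced.todense())                        # round-trips the dense values
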
6 changes: 5 additions & 1 deletion unumpy/tests/test_numpy.py
@@ -8,7 +8,7 @@
import sparse
import unumpy.numpy_backend as NumpyBackend

# from unumpy.torch_backend import TorchBackend
import unumpy.torch_backend as TorchBackend
import unumpy.xnd_backend as XndBackend
import unumpy.dask_backend as DaskBackend
import unumpy.sparse_backend as SparseBackend
@@ -20,6 +20,10 @@
),
(DaskBackend, (da.core.Array, onp.generic)),
(SparseBackend, (sparse.SparseArray, onp.generic)),
pytest.param(
(TorchBackend, torch.Tensor),
marks=pytest.mark.xfail(reason="PyTorch not fully NumPy compatible."),
),
]

try:
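
The Torch backend re-enters the test matrix as an expected failure: wrapping the entry in `pytest.param(..., marks=pytest.mark.xfail(...))` keeps it parametrized while recording failures as xfail rather than hard errors. A standalone illustration of the same pattern with made-up test data:

import pytest

CASES = [
    pytest.param(2, id="even"),
    pytest.param(
        3,
        id="odd",
        marks=pytest.mark.xfail(reason="odd inputs are expected to fail"),
    ),
]


@pytest.mark.parametrize("n", CASES)
def test_is_even(n):
    assert n % 2 == 0
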
153 changes: 57 additions & 96 deletions unumpy/torch_backend.py
@@ -1,120 +1,81 @@
# import unumpy.multimethods as multimethods
# from .multimethods import ufunc, ufunc_list, ndarray
# import torch
# from typing import Dict, Callable
# import functools
import unumpy.multimethods as multimethods
from .multimethods import ufunc, ufunc_list, ndarray
import torch
from uarray import DispatchableInstance

# from uarray.backend import (
# Backend,
# register_backend,
# register_implementation,
# DispatchableInstance,
# )
__ua_domain__ = "numpy"

# TorchBackend = Backend()
# register_backend(TorchBackend)

def compat_check(args):
args = [arg.value if isinstance(arg, DispatchableInstance) else arg for arg in args]
return all(
isinstance(arg, torch.Tensor) or callable(arg)
for arg in args
if arg is not None
)

# def compat_check(args):
# args = [arg.value if isinstance(arg, DispatchableInstance) else arg for arg in args]
# return all(isinstance(arg, torch.Tensor) for arg in args if arg is not None)

def asarray(a, dtype=None, order=None):
if torch.is_tensor(a):
if dtype is None or a.dtype != dtype:
ret = torch.tensor(a, dtype=dtype)
if a.requires_grad:
ret.requires_grad_()
return ret

# register_torch = functools.partial(
# register_implementation, backend=TorchBackend, compat_check=compat_check
# )
return a
try:
import numpy as np

if isinstance(a, np.ndarray):
return torch.from_numpy(a)
except ImportError:
pass

# _reduce_mapping = {
# multimethods.add: torch.sum, # type: ignore
# multimethods.multiply: torch.prod, # type: ignore
# multimethods.minimum: torch.min, # type: ignore
# multimethods.maximum: torch.max, # type: ignore
# }
return torch.tensor(a, dtype=dtype)

# _ufunc_mapping: Dict[ufunc, Callable] = {}

_implementations = {
multimethods.ufunc.__call__: lambda x, *a, **kw: x(*a, **kw),
multimethods.asarray: asarray,
multimethods.array: torch.Tensor,
multimethods.arange: lambda start, stop, step, **kwargs: torch.arange(
start, stop, step, **kwargs
),
}

# @register_torch(ufunc.__call__)
# def __call__(self, *args, out=None):
# if self not in _ufunc_mapping:
# return NotImplemented
# return _ufunc_mapping[self](*args, out=out)

def __ua_function__(method, args, kwargs, dispatchable_args):
if not compat_check(dispatchable_args):
return NotImplemented

# @register_torch(ufunc.reduce)
# def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False, arg=False):
# if self not in _reduce_mapping:
# return NotImplemented
if method in _implementations:
return _implementations[method](*args, **kwargs)

# if axis is None:
# axis = tuple(range(a.dim()))
# elif not isinstance(axis, tuple):
# axis = (axis,)
if not hasattr(torch, method.__name__):
return NotImplemented

# ret = a
# for dim in tuple(reversed(sorted(axis))):
# ret = _reduce_mapping[self](ret, dim=dim, keepdim=keepdims)
return getattr(torch, method.__name__)(*args, **kwargs)

# assert not arg or isinstance(ret, tuple)
# if isinstance(ret, tuple):
# ret = ret[int(arg)]

# if out is not None:
# out[...] = ret
# ret = out
def __ua_coerce__(value, dispatch_type):
if dispatch_type is ndarray:
return asarray(value) if value is not None else None

# return ret
if dispatch_type is ufunc and value in _ufunc_mapping:
return _ufunc_mapping[value]

return NotImplemented

# for ufunc_name in ufunc_list:
# if ufunc_name.startswith("arc"):
# torch_name = ufunc_name.replace("arc", "a")
# else:
# torch_name = ufunc_name

# if hasattr(torch, torch_name):
# _ufunc_mapping[getattr(multimethods, ufunc_name)] = getattr(torch, torch_name)
_ufunc_mapping = {}

# register_torch(multimethods.arange)(
# lambda start, stop, step, **kwargs: torch.arange(start, stop, step, **kwargs)
# )
# register_torch(multimethods.array)(torch.tensor)

for ufunc_name in ufunc_list:
if ufunc_name.startswith("arc"):
torch_name = ufunc_name.replace("arc", "a")
else:
torch_name = ufunc_name

# @register_torch(multimethods.asarray)
# def asarray(a, dtype=None, order=None):
# if torch.is_tensor(a):
# if dtype is None or a.dtype != dtype:
# ret = torch.tensor(a, dtype=dtype)
# if a.requires_grad:
# ret.requires_grad_()
# return ret

# return a
# try:
# import numpy as np

# if isinstance(a, np.ndarray):
# return torch.from_numpy(a)
# except ImportError:
# pass

# return torch.tensor(a, dtype=dtype)


# register_torch(multimethods.zeros)(torch.zeros)
# register_torch(multimethods.ones)(torch.ones)


# @register_torch(multimethods.argmax)
# def argmax(a, axis=None, out=None):
# return reduce(getattr(multimethods, "max"), a, axis=axis, out=out, arg=True)


# @register_torch(multimethods.argmin)
# def argmin(a, axis=None, out=None):
# return reduce(getattr(multimethods, "min"), a, axis=axis, out=out, arg=True)


# ndarray.register_convertor(TorchBackend, asarray)
if hasattr(torch, torch_name):
_ufunc_mapping[getattr(multimethods, ufunc_name)] = getattr(torch, torch_name)
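
The closing loop fills `_ufunc_mapping` by translating NumPy's `arc*` spellings to PyTorch's `a*` ones before looking the name up on `torch`, so (assuming `arcsin` appears in `ufunc_list`) `multimethods.arcsin` resolves to `torch.asin`. The renaming rule in isolation:

import torch


def torch_ufunc_for(numpy_name):
    # NumPy writes inverse trig ufuncs as arcsin/arccos/..., PyTorch as
    # asin/acos/...; every other name is used unchanged.
    torch_name = numpy_name.replace("arc", "a") if numpy_name.startswith("arc") else numpy_name
    return getattr(torch, torch_name, None)


print(torch_ufunc_for("arcsin"))  # torch.asin
print(torch_ufunc_for("add"))     # torch.add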
