Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2a503df
commit 37ac6a0
Showing
10 changed files
with
159 additions
and
201 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,120 +1,81 @@ | ||
# import unumpy.multimethods as multimethods | ||
# from .multimethods import ufunc, ufunc_list, ndarray | ||
# import torch | ||
# from typing import Dict, Callable | ||
# import functools | ||
import unumpy.multimethods as multimethods | ||
from .multimethods import ufunc, ufunc_list, ndarray | ||
import torch | ||
from uarray import DispatchableInstance | ||
|
||
# from uarray.backend import ( | ||
# Backend, | ||
# register_backend, | ||
# register_implementation, | ||
# DispatchableInstance, | ||
# ) | ||
__ua_domain__ = "numpy" | ||
|
||
# TorchBackend = Backend() | ||
# register_backend(TorchBackend) | ||
|
||
def compat_check(args):
    """Return True when every non-None argument can be handled by this backend.

    Each argument is first unwrapped from its ``DispatchableInstance``
    wrapper (if any); an unwrapped argument is acceptable when it is a
    ``torch.Tensor`` or a callable.  ``None`` arguments are ignored.
    """
    unwrapped = [
        a.value if isinstance(a, DispatchableInstance) else a for a in args
    ]
    for arg in unwrapped:
        if arg is None:
            continue
        if not (isinstance(arg, torch.Tensor) or callable(arg)):
            return False
    return True
|
||
# def compat_check(args): | ||
# args = [arg.value if isinstance(arg, DispatchableInstance) else arg for arg in args] | ||
# return all(isinstance(arg, torch.Tensor) for arg in args if arg is not None) | ||
|
||
def asarray(a, dtype=None, order=None):
    """Convert *a* to a ``torch.Tensor``.

    Parameters
    ----------
    a : torch.Tensor, numpy.ndarray, or any value ``torch.tensor`` accepts
        The value to convert.
    dtype : torch.dtype, optional
        Requested element type.  When omitted, the input's type is kept.
    order : ignored
        Accepted only for numpy ``asarray`` signature compatibility.

    Returns
    -------
    torch.Tensor
        The input itself when it is already a tensor of the requested
        (or unspecified) dtype; a zero-copy ``from_numpy`` view for numpy
        arrays; otherwise a newly constructed tensor.
    """
    if torch.is_tensor(a):
        # asarray semantics: avoid copying when no conversion is needed.
        # (Previously ``dtype is None`` also fell into the copy branch,
        # so every call cloned the tensor.)
        if dtype is None or a.dtype == dtype:
            return a
        ret = torch.tensor(a, dtype=dtype)
        # torch.tensor() detaches; restore leaf-with-grad status.
        if a.requires_grad:
            ret.requires_grad_()
        return ret

    try:
        import numpy as np

        if isinstance(a, np.ndarray):
            # Zero-copy conversion for numpy arrays.
            return torch.from_numpy(a)
    except ImportError:
        # numpy is optional; fall through to the generic constructor.
        pass

    return torch.tensor(a, dtype=dtype)
|
||
# _ufunc_mapping: Dict[ufunc, Callable] = {} | ||
|
||
# Explicit overrides: multimethods whose torch implementation is not simply
# the same-named attribute on the ``torch`` module.
def _call_ufunc(f, *args, **kwargs):
    # ``ufunc.__call__`` dispatch: the first argument is the ufunc itself.
    return f(*args, **kwargs)


def _arange(start, stop, step, **kwargs):
    return torch.arange(start, stop, step, **kwargs)


_implementations = {
    multimethods.ufunc.__call__: _call_ufunc,
    multimethods.asarray: asarray,
    multimethods.array: torch.Tensor,
    multimethods.arange: _arange,
}
|
||
# @register_torch(ufunc.__call__) | ||
# def __call__(self, *args, out=None): | ||
# if self not in _ufunc_mapping: | ||
# return NotImplemented | ||
# return _ufunc_mapping[self](*args, out=out) | ||
|
||
def __ua_function__(method, args, kwargs, dispatchable_args):
    """uarray protocol hook: dispatch *method* to this torch backend.

    Returns ``NotImplemented`` when the arguments are not compatible with
    torch or when no torch equivalent of *method* exists, so uarray can
    try another backend.
    """
    if not compat_check(dispatchable_args):
        return NotImplemented

    # Explicit override table first.
    impl = _implementations.get(method)
    if impl is not None:
        return impl(*args, **kwargs)

    # Fall back to the same-named function on the torch module, if any.
    torch_func = getattr(torch, method.__name__, None)
    if torch_func is None:
        return NotImplemented
    return torch_func(*args, **kwargs)
|
||
# assert not arg or isinstance(ret, tuple) | ||
# if isinstance(ret, tuple): | ||
# ret = ret[int(arg)] | ||
|
||
# if out is not None: | ||
# out[...] = ret | ||
# ret = out | ||
def __ua_coerce__(value, dispatch_type):
    """uarray protocol hook: coerce *value* to this backend's native form.

    ndarray values become torch tensors (``None`` passes through); known
    ufuncs map to their torch counterparts.  Anything else is declined
    with ``NotImplemented``.
    """
    if dispatch_type is ndarray:
        if value is None:
            return None
        return asarray(value)

    if dispatch_type is ufunc and value in _ufunc_mapping:
        return _ufunc_mapping[value]

    return NotImplemented
|
||
# Build the ufunc translation table: unumpy multimethod -> torch function.
# numpy spells inverse trig functions "arc*" while torch spells them "a*"
# (e.g. ``arcsin`` -> ``asin``); everything else shares its name.
_ufunc_mapping = {}

for _np_name in ufunc_list:
    _torch_name = (
        _np_name.replace("arc", "a") if _np_name.startswith("arc") else _np_name
    )
    if hasattr(torch, _torch_name):
        _ufunc_mapping[getattr(multimethods, _np_name)] = getattr(torch, _torch_name)
Oops, something went wrong.