-
Notifications
You must be signed in to change notification settings - Fork 10
/
numpy_backend.py
87 lines (62 loc) · 2.24 KB
/
numpy_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
from uarray import Dispatchable, wrap_single_convertor
from unumpy import ufunc, ufunc_list, ndarray, dtype
import unumpy
import functools
from typing import Dict
# uarray domain this backend registers under.
__ua_domain__ = "numpy"
# Lookup table from unumpy ufunc objects to NumPy ufuncs, read by
# replace_self(). NOTE(review): no code in this module adds entries to it —
# confirm whether it is populated elsewhere.
_ufunc_mapping: Dict[ufunc, np.ufunc] = dict()
def overridden_class(self):
    """Look up the NumPy object that the unumpy class ``self`` stands in for.

    ``self`` is expected to be a class: its ``__name__`` and ``__module__``
    are read, and this function is installed as the ``fget`` of
    ``unumpy.ClassOverrideMeta.overridden_class``. Every ``_multimethods``
    component is dropped from the module path before the lookup. For example,
    a class living in ``unumpy._multimethods`` is resolved against the
    ``unumpy`` domain.

    Returns whatever ``_get_from_name_domain`` yields for that name/domain.
    """
    kept_parts = [
        part for part in self.__module__.split(".") if part != "_multimethods"
    ]
    return _get_from_name_domain(self.__name__, ".".join(kept_parts))
def _count_nonzero(a, axis=None):
    # Routing the result through np.asarray(...)[()] turns a 0-d result
    # (the axis=None case) into a NumPy scalar; when an axis is given the
    # array result stays an array.
    return np.asarray(np.count_nonzero(a, axis))[()]


# Explicit multimethod -> implementation overrides. __ua_function__ checks
# this table before falling back to name-based lookup inside NumPy.
_implementations: Dict = {
    unumpy.ufunc.__call__: np.ufunc.__call__,
    unumpy.ufunc.reduce: np.ufunc.reduce,
    unumpy.count_nonzero: _count_nonzero,
    unumpy.ClassOverrideMeta.overridden_class.fget: overridden_class,
}
def _get_from_name_domain(name, domain):
module = np
name_hierarchy = name.split(".")
domain_hierarchy = domain.split(".") + name_hierarchy[0:-1]
for d in domain_hierarchy[1:]:
module = getattr(module, d)
if hasattr(module, name_hierarchy[-1]):
return getattr(module, name_hierarchy[-1])
else:
return NotImplemented
def __ua_function__(method, args, kwargs):
    """Run the unumpy multimethod ``method`` using NumPy.

    Resolution order:

    1. an explicit override registered in ``_implementations``;
    2. ``NotImplemented`` when the first positional argument is a
       ``unumpy.ClassOverrideMeta`` instance;
    3. the NumPy attribute found from ``method.__qualname__`` and
       ``method.domain`` via ``_get_from_name_domain``, called with the
       original arguments.

    ``NotImplemented`` is returned whenever that lookup finds nothing.
    """
    override = _implementations.get(method)
    if override is not None:
        return override(*args, **kwargs)

    first_is_override_class = bool(args) and isinstance(
        args[0], unumpy.ClassOverrideMeta
    )
    if first_is_override_class:
        return NotImplemented

    target = _get_from_name_domain(method.__qualname__, method.domain)
    return NotImplemented if target is NotImplemented else target(*args, **kwargs)
@wrap_single_convertor
def __ua_convert__(value, dispatch_type, coerce):
    """Convert a single dispatchable ``value`` to its NumPy representation.

    uarray's ``wrap_single_convertor`` adapts this per-value converter to the
    backend's ``__ua_convert__`` protocol.

    Parameters
    ----------
    value : object
        The dispatchable value; ``None`` is passed through unchanged.
    dispatch_type : type
        The unumpy marker type (``ufunc``, ``ndarray``, ``dtype``, or other).
    coerce : bool
        When false, only values that already are ``np.ndarray`` are accepted
        for the ``ndarray`` dispatch type.

    Returns
    -------
    object
        The converted value, or ``NotImplemented`` when this backend cannot
        handle it.
    """
    # Check None first so every dispatch type passes it through. Previously a
    # None ufunc argument crashed with AttributeError on ``None.name``.
    if value is None:
        return None
    if dispatch_type is ufunc:
        # A name NumPy lacks yields NotImplemented rather than AttributeError.
        return getattr(np, value.name, NotImplemented)
    if dispatch_type is ndarray:
        if not coerce and not isinstance(value, np.ndarray):
            return NotImplemented
        return np.asarray(value)
    if dispatch_type is dtype:
        # Prefer the string spelling; fall back to the object itself when
        # NumPy rejects the string with TypeError.
        try:
            return np.dtype(str(value))
        except TypeError:
            return np.dtype(value)
    return value
def replace_self(func):
    """Wrap ``func`` so its first argument is swapped for the mapped NumPy ufunc.

    The returned wrapper looks up its first argument in ``_ufunc_mapping``.
    On a hit it calls ``func`` with the mapped value in that position, passing
    all other arguments through. On a miss it returns ``NotImplemented``
    without calling ``func``.
    """

    @functools.wraps(func)
    def _wrapper(self, *args, **kwargs):
        try:
            numpy_ufunc = _ufunc_mapping[self]
        except KeyError:
            return NotImplemented
        return func(numpy_ufunc, *args, **kwargs)

    return _wrapper