Skip to content

Commit

Permalink
Ufunc dispatch (#136)
Browse files Browse the repository at this point in the history
* Fix ufunc dispatch.
  • Loading branch information
hameerabbasi committed Apr 25, 2019
1 parent 22e096c commit c975172
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
2 changes: 1 addition & 1 deletion uarray/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def outer(func):
@functools.wraps(func)
def inner(*args, **kwargs):
extracted_args = func(*args, **kwargs)
return tuple(arg_type(arg) for arg in extracted_args if not isinstance(arg, DispatchableInstance))
return tuple(arg_type(arg) if not isinstance(arg, DispatchableInstance) else arg for arg in extracted_args)

return inner

Expand Down
39 changes: 23 additions & 16 deletions unumpy/multimethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ def _identity_argreplacer(args, kwargs, arrays):
return args, kwargs


def _self_argreplacer(args, kwargs, dispatchables):
return dispatchables + args[1:], kwargs


def _ureduce_argreplacer(args, kwargs, arrays):
out_args = list(args)
out_args[1] = arrays[0]
out_args[:2] = arrays[:2]

out_kwargs = {**kwargs, 'out': arrays[1]}
out_kwargs = {**kwargs, 'out': arrays[2]}

return tuple(out_args), out_kwargs

Expand All @@ -28,11 +32,14 @@ class ndarray(DispatchableInstance):


class ufunc(DispatchableInstance):
def __init__(self, name, nin, nout):
def __init__(self, name, *args):
if isinstance(name, ufunc) and not len(args):
super().__init__(name)
return

self.name = name
self._nin = nin
self._nout = nout
super().__init__(self)
self._nin, self._nout = args
super().__init__(None)

def __str__(self):
return f"<ufunc '{self.name}'>"
Expand All @@ -46,14 +53,14 @@ def nout(self):
return self._nout

@property # type: ignore
@create_multimethod(_identity_argreplacer)
@create_multimethod(_self_argreplacer)
def types(self):
return ()
return (ufunc(self),)

@property # type: ignore
@create_multimethod(_identity_argreplacer)
@create_multimethod(_self_argreplacer)
def identity(self):
return ()
return (ufunc(self),)

@property
def nargs(self):
Expand All @@ -66,13 +73,13 @@ def ntypes(self):
def _ufunc_argreplacer(args, kwargs, arrays):
self = args[0]
args = args[1:]
in_arrays = arrays[:self.nin]
out_arrays = arrays[self.nin:]
in_arrays = arrays[1:self.nin+1]
out_arrays = arrays[self.nin+1:]
if self.nout == 1:
out_arrays = out_arrays[0]
out_kwargs = {**kwargs, 'out': out_arrays}

return (self, *in_arrays), out_kwargs
return (arrays[0], *in_arrays), out_kwargs

@create_multimethod(_ufunc_argreplacer)
@all_of_type(ndarray)
Expand All @@ -81,17 +88,17 @@ def __call__(self, *args, out=None):
if not isinstance(out, tuple):
out = (out,)

return in_args + out
return (ufunc(self),) + in_args + out

@create_multimethod(_ureduce_argreplacer)
@all_of_type(ndarray)
def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False):
return (a, out)
return (ufunc(self), a, out)

@create_multimethod(_ureduce_argreplacer)
@all_of_type(ndarray)
def accumulate(self, a, axis=0, dtype=None, out=None):
return (a, out)
return (ufunc(self), a, out)


ufunc_list = [
Expand Down

0 comments on commit c975172

Please sign in to comment.