Skip to content

Commit

Permalink
Fix normalisation.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Jun 11, 2019
1 parent 9ae6841 commit f21afc5
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 45 deletions.
30 changes: 15 additions & 15 deletions uarray/backend.py
Expand Up @@ -93,6 +93,7 @@ def generate_multimethod(
uarray
See the module documentation for how to override the method by creating backends.
"""
defaults, opts = get_defaults(argument_extractor)

@functools.wraps(argument_extractor)
def inner(*args, **kwargs):
Expand Down Expand Up @@ -164,6 +165,12 @@ def replace_dispatchables(
)

args, kwargs = argument_replacer(args, kwargs, tuple(replaced_args))

kwargs = {k: kwargs[k] for k in opts if k in kwargs}
for k, v in defaults.items():
if k in kwargs and kwargs[k] is v:
del kwargs[k]

return args, kwargs, filtered_args

inner._coerce_args = replace_dispatchables # type: ignore
Expand Down Expand Up @@ -277,23 +284,16 @@ def skip_backend(backend):
skip.reset(token)


def _canonicalize(f, args, kwargs):
def get_defaults(f):
sig = inspect.signature(f)
bargs = sig.bind(*args, **kwargs)
# Pop out the named kwargs variable defaulting to {}
ret_kwargs = bargs.arguments.pop(inspect.getfullargspec(f).varkw, {})
# For all possible signature values
defaults = {}
opts = set()
for k, v in sig.parameters.items():
# If the name exists in the bound arguments and has a default value
if k in bargs.arguments and v.default is not v.empty:
# Remove from the bound arguments dict
val = bargs.arguments.pop(k)
# If the value isn't the same as the default value add it to ret_kwargs
if val is not v.default:
ret_kwargs[k] = val

# bargs.args here will be made up of what's left in bargs.arguments
return bargs.args, ret_kwargs
if v.default is not inspect.Parameter.empty:
defaults[k] = v.default
opts.add(v)

return defaults, opts


def set_global_backend(domain: str, backend):
Expand Down
68 changes: 38 additions & 30 deletions unumpy/multimethods.py
Expand Up @@ -9,32 +9,37 @@ def _identity_argreplacer(args, kwargs, arrays):


def _self_argreplacer(args, kwargs, dispatchables):
return dispatchables + args[1:], kwargs
def self_method(a, *args, **kwargs):
return dispatchables + args, kwargs

return self_method(*args, **kwargs)

def _ureduce_argreplacer(args, kwargs, arrays):
out_args = list(args)
out_args[:2] = arrays[:2]

out_kwargs = {**kwargs, "out": arrays[2]}
def _ureduce_argreplacer(args, kwargs, dispatchables):
def ureduce(self, a, axis=0, dtype=None, out=None, keepdims=False):
return (
(dispatchables[0], dispatchables[1]),
dict(axis=axis, dtype=dtype, out=dispatchables[2], keepdims=keepdims),
)

return tuple(out_args), out_kwargs
return ureduce(*args, **kwargs)


def _reduce_argreplacer(args, kwargs, arrays):
out_args = list(args)
out_args[0] = arrays[0]
def reduce(a, axis=None, dtype=None, out=None, keepdims=False):
return (
(arrays[0],),
dict(axis=axis, dtype=dtype, out=arrays[1], keepdims=keepdims),
)

if "out" in kwargs:
out_kwargs = {**kwargs, "out": arrays[1]}
else:
out_kwargs = kwargs

return tuple(out_args), out_kwargs
return reduce(*args, **kwargs)


def _first2argreplacer(args, kwargs, arrays):
return arrays + args[2:], kwargs
def func(a, b, **kwargs):
return arrays, kwargs

return func(*args, **kwargs)


class ndarray:
Expand Down Expand Up @@ -412,7 +417,7 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):


# set routines
@create_numpy(_reduce_argreplacer)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def unique(a, return_index=False, return_inverse=False, return_counts=False, axis=None):
return (a,)
Expand Down Expand Up @@ -469,7 +474,7 @@ def union1d(ar1, ar2):
return (ar1, ar2)


@create_numpy(_reduce_argreplacer)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def sort(a, axis=None, kind=None, order=None):
return (a,)
Expand Down Expand Up @@ -501,14 +506,17 @@ def broadcast_arrays(*args, subok=False):
return args


@create_numpy(_reduce_argreplacer)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def broadcast_to(array, shape, subok=False):
return (array,)


def _first_argreplacer(args, kwargs, arrays):
return (arrays,) + args[1:], kwargs
def _first_argreplacer(args, kwargs, arrays1):
def func(arrays, axis=0, out=None):
return (arrays1,), dict(axis=0, out=None)

return func(*args, **kwargs)


@create_numpy(_first_argreplacer)
Expand All @@ -523,34 +531,34 @@ def stack(arrays, axis=0, out=None):
return arrays


@create_numpy(_first_argreplacer)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def argsort(a, axis=-1, kind="quicksort", order=None):
return a
return (a,)


@create_numpy(_first_argreplacer, default=lambda a: sort(a, axis=0))
@create_numpy(_self_argreplacer, default=lambda a: sort(a, axis=0))
@all_of_type(ndarray)
def msort(a):
return a
return (a,)


@create_numpy(_first_argreplacer, default=lambda a: sort(a))
@create_numpy(_self_argreplacer, default=lambda a: sort(a))
@all_of_type(ndarray)
def sort_complex(a):
return a
return (a,)


@create_numpy(_first_argreplacer)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def partition(a, kth, axis=-1, kind="introselect", order=None):
return a
return (a,)


@create_numpy(_first_argreplacer)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def argpartition(a, kth, axis=-1, kind="introselect", order=None):
return a
return (a,)


del ufunc_name

0 comments on commit f21afc5

Please sign in to comment.