Skip to content

Commit

Permalink
Multimethods for functional programming routines (#65)
Browse files Browse the repository at this point in the history
* Add multimethods on functional programming

* Add default implementation for piecewise

* Add default implementation for apply_over_axes

* Add coersion of condlist's items to arrays in piecewise's default
  • Loading branch information
joaosferreira committed Jul 23, 2020
1 parent 78f3323 commit 3a96c7a
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 2 deletions.
91 changes: 91 additions & 0 deletions unumpy/_multimethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1770,3 +1770,94 @@ def roll(a, shift, axis=None):
@all_of_type(ndarray)
def rot90(m, k=1, axes=(0, 1)):
return (m,)


def _apply_along_axis_argreplacer(args, kwargs, dispatchables):
def replacer(func1d, axis, arr, *args, **kwargs):
return (func1d, axis, dispatchables[0]) + args, kwargs

return replacer(*args, **kwargs)


@create_numpy(_apply_along_axis_argreplacer)
@all_of_type(ndarray)
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
return (arr,)


def _apply_over_axes_argreplacer(args, kwargs, dispatchables):
def replacer(func, a, axes):
return (func, dispatchables[0], axes), dict()

return replacer(*args, **kwargs)


def _apply_over_axes_default(func, a, axes):
nd = ndim(a)
axes = _normalize_axis(nd, axes)

res = a
for axis in axes:
res = func(res, axis)

res_nd = ndim(res)
if res_nd < nd - 1:
raise ValueError(
"Function is not returning an array with the correct number of dimensions."
)
if res_nd == nd - 1:
res = expand_dims(res, axis)

return res


@create_numpy(_apply_over_axes_argreplacer, default=_apply_over_axes_default)
@all_of_type(ndarray)
def apply_over_axes(func, a, axes):
return (a,)


@create_numpy(_identity_argreplacer)
def frompyfunc(func, nin, nout, identity=None):
return ()


def _piecewise_default(x, condlist, funclist, *args, **kw):
if not isinstance(condlist, list):
condlist = [condlist]

condlist = [asarray(cond) for cond in condlist]

n1 = len(condlist)
n2 = len(funclist)

if n1 != n2:
if n1 + 1 == n2:
condelse = ~any(condlist, axis=0, keepdims=True)
condlist = concatenate([condlist, condelse], axis=0)
else:
raise ValueError(
"With %d condition(s), either %d or %d functions are expected."
% (n, n, n + 1)
)

y = zeros(x.shape, dtype=x.dtype)

for i, (cond, func) in enumerate(zip(condlist, funclist)):
if cond.shape != x.shape and ndim(cond) != 0:
raise ValueError(
"Condition at index %d doesn't have the same shape as x." % i
)

if isinstance(func, collections.abc.Callable):
y = where(cond, func(x, *args, **kw), y)
else:
y = where(cond, func, y)

return y


@create_numpy(_self_argreplacer, default=_piecewise_default)
@all_of_type(ndarray)
def piecewise(x, condlist, funclist, *args, **kw):
return (x,)
41 changes: 39 additions & 2 deletions unumpy/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

dtypes = ["int8", "int16", "int32", "float32", "float64"]
LIST_BACKENDS = [
(NumpyBackend, (onp.ndarray, onp.generic)),
(DaskBackend(), (da.Array, onp.generic)),
(NumpyBackend, (onp.ndarray, onp.generic, onp.ufunc)),
(DaskBackend(), (da.Array, onp.generic, da.ufunc.ufunc)),
(SparseBackend, (sparse.SparseArray, onp.ndarray, onp.generic)),
pytest.param(
(TorchBackend, (torch.Tensor,)),
Expand Down Expand Up @@ -469,3 +469,40 @@ def test_linalg(backend, method, args, kwargs):

if isinstance(ret, da.Array):
ret.compute()


@pytest.mark.parametrize(
"method, args, kwargs",
[
(
np.apply_along_axis,
(lambda a: (a[0] + a[-1]) * 0.5, 1, [[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
{},
),
(np.apply_over_axes, (np.sum, [[1, 2, 3], [4, 5, 6], [7, 8, 9]], [0, 1]), {}),
(np.frompyfunc, (bin, 1, 1), {}),
(
np.piecewise,
(
[0, 1, 2, 3],
[[True, False, True, False], [False, True, False, True]],
[0, 1],
),
{},
),
],
)
def test_functional(backend, method, args, kwargs):
backend, types = backend
try:
with ua.set_backend(backend, coerce=True):
ret = method(*args, **kwargs)
except ua.BackendNotImplementedError:
if backend in FULLY_TESTED_BACKENDS and (backend, method) not in EXCEPTIONS:
raise
pytest.xfail(reason="The backend has no implementation for this ufunc.")

assert isinstance(ret, types)

if isinstance(ret, da.Array):
ret.compute()

0 comments on commit 3a96c7a

Please sign in to comment.