Skip to content

Commit

Permalink
Argmin + argmax + tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Apr 23, 2019
1 parent 75caa26 commit b80198f
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 12 deletions.
12 changes: 12 additions & 0 deletions unumpy/multimethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,16 @@ def all(a, axis=None, out=None, keepdims=False):
return (a, out)


@create_multimethod(_reduce_argreplacer)
@all_of_type(ndarray)
def argmin(a, axis=None, out=None):
return (a, out)


@create_multimethod(_reduce_argreplacer)
@all_of_type(ndarray)
def argmax(a, axis=None, out=None):
return (a, out)


del ufunc_name
4 changes: 4 additions & 0 deletions unumpy/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,8 @@ def inner(self, *args, **kwargs):
register_numpy(multimethods.zeros)(np.zeros)
register_numpy(multimethods.ones)(np.ones)
register_numpy(multimethods.asarray)(np.asarray)

register_numpy(multimethods.argmin)(np.argmin)
register_numpy(multimethods.argmax)(np.argmax)

NumpyBackend.register_convertor(ndarray, np.asarray)
2 changes: 2 additions & 0 deletions unumpy/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def replace_args_kwargs(method, backend, args, kwargs):
(np.all, ([True, False],), {}),
(np.min, ([1, 3, 2],), {}),
(np.max, ([1, 3, 2],), {}),
(np.argmin, ([1, 3, 2],), {}),
(np.argmax, ([1, 3, 2],), {}),
])
def test_ufunc_reductions(backend, method, args, kwargs):
backend, types = backend
Expand Down
36 changes: 24 additions & 12 deletions unumpy/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,26 @@ def __call__(self, *args, out=None):


@register_torch(ufunc.reduce)
def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False):
def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False, arg=False):
if self not in _reduce_mapping:
return NotImplemented

if axis is None:
axis = tuple(range(a.dim()))
elif not isinstance(axis, tuple):
axis = (axis,)

if isinstance(axis, tuple):
ret = a
for dim in tuple(reversed(sorted(axis))):
ret = _reduce_mapping[self](ret, dim=dim, keepdim=keepdims)
ret = a
for dim in tuple(reversed(sorted(axis))):
ret = _reduce_mapping[self](ret, dim=dim, keepdim=keepdims)

if out is not None:
out[...] = ret
ret = out
assert not arg or isinstance(ret, tuple)
if isinstance(ret, tuple):
ret = ret[int(arg)]

ret = _reduce_mapping[self](a, dim=axis, keepdim=keepdims, out=out)

if isinstance(ret, tuple):
ret = ret[0]
if out is not None:
out[...] = ret
ret = out

return ret

Expand Down Expand Up @@ -96,4 +96,16 @@ def asarray(a, dtype=None, order=None):

register_torch(multimethods.zeros)(torch.zeros)
register_torch(multimethods.ones)(torch.ones)


@register_torch(multimethods.argmax)
def argmax(a, axis=None, out=None):
return reduce(getattr(multimethods, 'max'), a, axis=axis, out=out, arg=True)


@register_torch(multimethods.argmin)
def argmin(a, axis=None, out=None):
return reduce(getattr(multimethods, 'min'), a, axis=axis, out=out, arg=True)


TorchBackend.register_convertor(ndarray, asarray)

0 comments on commit b80198f

Please sign in to comment.