Skip to content

Commit

Permalink
Multimethods for mathematical functions (#64)
Browse files Browse the repository at this point in the history
* Add multimethods for mathematical functions

* Add some default implementations

* Refactor _ediff1d_default

* Add default implementation for interp

* Add support for args left and right to _interp_default

* Add default implementation for unwrap

* Add default implementation for fix

* Refactor fix's default in terms of trunc
  • Loading branch information
joaosferreira committed Jul 23, 2020
1 parent 3a96c7a commit 6edccfd
Show file tree
Hide file tree
Showing 2 changed files with 316 additions and 1 deletion.
283 changes: 282 additions & 1 deletion unumpy/_multimethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def accumulate(self, a, axis=0, dtype=None, out=None):
logaddexp2 = ufunc("logaddexp2", 2, 1)
true_divide = ufunc("true_divide", 2, 1)
floor_divide = ufunc("floor_divide", 2, 1)
float_power = ufunc("float_power", 2, 1)
negative = ufunc("negative", 1, 1)
positive = ufunc("positive", 1, 1)
power = ufunc("power", 2, 1)
Expand All @@ -258,7 +259,8 @@ def accumulate(self, a, axis=0, dtype=None, out=None):
rint = ufunc("rint", 1, 1)
sign = ufunc("sign", 1, 1)
heaviside = ufunc("heaviside", 1, 1)
conj = ufunc("conj", 1, 1)
conjugate = ufunc("conjugate", 1, 1)
conj = conjugate
exp = ufunc("exp", 1, 1)
exp2 = ufunc("exp2", 1, 1)
log = ufunc("log", 1, 1)
Expand All @@ -282,6 +284,8 @@ def accumulate(self, a, axis=0, dtype=None, out=None):
arctan = ufunc("arctan", 1, 1)
arctan2 = ufunc("arctan2", 2, 1)
hypot = ufunc("hypot", 2, 1)
degrees = ufunc("degrees", 1, 1)
radians = ufunc("radians", 1, 1)
sinh = ufunc("sinh", 1, 1)
cosh = ufunc("cosh", 1, 1)
tanh = ufunc("tanh", 1, 1)
Expand Down Expand Up @@ -704,6 +708,50 @@ def nanprod(a, axis=None, dtype=None, out=None, keepdims=False):
return (a, mark_non_coercible(out))


def _self_dtype_out_argreplacer(args, kwargs, dispatchables):
def replacer(a, *args, dtype=None, out=None, **kwargs):
return (
(dispatchables[0],) + args,
dict(dtype=dispatchables[1], out=dispatchables[2], **kwargs),
)

return replacer(*args, **kwargs)


@create_numpy(_self_dtype_out_argreplacer)
@all_of_type(ndarray)
def cumprod(a, axis=None, dtype=None, out=None):
return (a, mark_dtype(dtype), mark_non_coercible(out))


@create_numpy(_self_dtype_out_argreplacer)
@all_of_type(ndarray)
def cumsum(a, axis=None, dtype=None, out=None):
return (a, mark_dtype(dtype), mark_non_coercible(out))


@create_numpy(
_self_dtype_out_argreplacer,
default=lambda a, axis=None, dtype=None, out=None: cumprod(
where(isnan(a), 1, a), axis=axis, dtype=dtype, out=out # type: ignore
),
)
@all_of_type(ndarray)
def nancumprod(a, axis=None, dtype=None, out=None):
return (a, mark_dtype(dtype), mark_non_coercible(out))


@create_numpy(
_self_dtype_out_argreplacer,
default=lambda a, axis=None, dtype=None, out=None: cumsum(
where(isnan(a), 0, a), axis=axis, dtype=dtype, out=out # type: ignore
),
)
@all_of_type(ndarray)
def nancumsum(a, axis=None, dtype=None, out=None):
return (a, mark_dtype(dtype), mark_non_coercible(out))


@create_numpy(_reduce_argreplacer)
@all_of_type(ndarray)
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
Expand Down Expand Up @@ -1334,12 +1382,64 @@ def diff(a, n=1, axis=-1):
return (a,)


def _ediff1d_argreplacer(args, kwargs, dispatchables):
return (
(dispatchables[0],),
dict(to_end=dispatchables[1], to_begin=dispatchables[2]),
)


def _ediff1d_default(ary, to_end=None, to_begin=None):
ary = ravel(ary)

diffs = ary[1:] - ary[:-1]

if to_end is None and to_begin is None:
return diffs

arrays = []
if to_begin is not None:
arrays.append(ravel(to_begin))

arrays.append(diff)

if to_end is not None:
arrays.append(ravel(to_end))

return concatenate(arrays)


@create_numpy(_ediff1d_argreplacer, default=_ediff1d_default)
@all_of_type(ndarray)
def ediff1d(ary, to_end=None, to_begin=None):
return (ary, to_end, to_begin)


@create_numpy(_args_argreplacer)
@all_of_type(ndarray)
def gradient(a, *varargs, edge_order=1, axis=None):
return (a,) + varargs


@create_numpy(_first2argreplacer)
@all_of_type(ndarray)
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
return (a, b)


def _trapz_argreplacer(args, kwargs, dispatchables):
def replacer(y, x=None, dx=1.0, axis=-1):
return (dispatchables[0],), dict(x=dispatchables[1], dx=dx, axis=axis)

return replacer(*args, **kwargs)


@create_numpy(_trapz_argreplacer)
@all_of_type(ndarray)
def trapz(y, x=None, dx=1.0, axis=-1):
return (y, x)


class _Recurser(object):
def __init__(self, recurse_if):
self.recurse_if = recurse_if
Expand Down Expand Up @@ -1861,3 +1961,184 @@ def _piecewise_default(x, condlist, funclist, *args, **kw):
@all_of_type(ndarray)
def piecewise(x, condlist, funclist, *args, **kw):
return (x,)


def _unwrap_default(p, discont=3.141592653589793, axis=-1):
nd = ndim(p)

dd = diff(p, axis=axis)

slice0 = [slice(None)] * nd
slice0[axis] = slice(0, 1)
slice0 = tuple(slice0)

slice1 = [slice(None, None)] * nd
slice1[axis] = slice(1, None)
slice1 = tuple(slice1)

ddmod = mod(dd + pi, 2 * pi) - pi
ddmod = where((ddmod == -pi) & (dd > 0), pi, ddmod)

ph_correct = ddmod - dd
ph_correct = where(absolute(dd) < discont, 0, ph_correct)

up_slice0 = p[slice0]
up_slice1 = p[slice1] + cumsum(ph_correct, axis=axis)

up = concatenate([up_slice0, up_slice1], axis=axis)

return up


@create_numpy(_self_argreplacer, default=_unwrap_default)
@all_of_type(ndarray)
def unwrap(p, discont=3.141592653589793, axis=-1):
return (p,)


@create_numpy(_self_out_argreplacer)
@all_of_type(ndarray)
def around(a, decimals=0, out=None):
return (a, mark_non_coercible(out))


round_ = around


@create_numpy(_self_out_argreplacer, default=lambda x, out=None: trunc(x, out=out))
@all_of_type(ndarray)
def fix(x, out=None):
return (x, mark_non_coercible(out))


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def i0(x):
return (x,)


@create_numpy(
_self_argreplacer, default=lambda x: where(x != 0, sin(pi * x) / (pi * x), 1)
)
@all_of_type(ndarray)
def sinc(x):
return (x,)


def _angle_default(z, deg=False):
angles = arctan2(z.imag, z.real)
if deg:
angles *= 180 / pi

return angles


@create_numpy(_self_argreplacer, default=_angle_default)
@all_of_type(ndarray)
def angle(z, deg=False):
return (z,)


@create_numpy(_self_argreplacer, default=lambda val: val.real)
@all_of_type(ndarray)
def real(val):
return (val,)


@create_numpy(_self_argreplacer, default=lambda val: val.imag)
@all_of_type(ndarray)
def imag(val):
return (val,)


@create_numpy(_first2argreplacer)
@all_of_type(ndarray)
def convolve(a, v, mode="full"):
return (a, v)


def _clip_argreplacer(args, kwargs, dispatchables):
def replacer(a, a_min, a_max, out=None, **kwargs):
return dispatchables[:3], dict(out=dispatchables[3], **kwargs)

return replacer(*args, **kwargs)


def _clip_default(a, a_min, a_max, out=None, **kwargs):
if a_min is None and a_max is None:
raise ValueError("One of max or min must be given.")

if a_min is not None:
a = where(a < a_min, a_min, a)
if a_max is not None:
a = where(a > a_max, a_max, a)

if out is None:
return a
else:
copyto(out, a)


@create_numpy(_clip_argreplacer, default=_clip_default)
@all_of_type(ndarray)
def clip(a, a_min, a_max, out=None, **kwargs):
return (a, a_min, a_max, mark_non_coercible(out))


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
return (x,)


@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def real_if_close(a, tol=100):
return (a,)


def _interp_default(x, xp, fp, left=None, right=None, period=None):
if ndim(xp) != 1 or ndim(fp) != 1:
raise ValueError("Data points must be 1-D sequences.")
if len(xp) != len(fp):
raise ValueError("fp and xp are not of the same length.")

sorted_idxs = argsort(xp)
xp = xp[sorted_idxs]
fp = fp[sorted_idxs]

if period is not None:
if period == 0:
raise ValueError("period must be a non-zero value.")

period = abs(period)
x = x % period
xp = xp % period

xp = concatenate([xp[-1:] - period, xp, xp[0:1] + period])
fp = concatenate([fp[-1:], fp, fp[0:1]])

idxs = searchsorted(xp, x)
idxs = where(idxs >= len(xp), len(xp) - 1, idxs)

y0 = fp[idxs - 1]
y1 = fp[idxs]
x0_dist = x - xp[idxs - 1]
x1_dist = xp[idxs] - x
x_dist = xp[idxs] - xp[idxs - 1]

result = (y0 * x1_dist + y1 * x0_dist) / x_dist

left = fp[0] if left is None else left
right = fp[-1] if right is None else right

result = where(x < xp[0], left, result)
result = where(x > xp[-1], right, result)

return result


@create_numpy(_self_argreplacer, default=_interp_default)
@all_of_type(ndarray)
def interp(x, xp, fp, left=None, right=None, period=None):
return (x,)
34 changes: 34 additions & 0 deletions unumpy/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,20 @@ def replace_args_kwargs(method, backend, args, kwargs):
(np.count_nonzero, ([True, False, True, False],), {}),
(np.linspace, (0, 100, 200), {}),
(np.logspace, (0, 4, 200), {}),
(np.unwrap, ([0.0, 0.78539816, 1.57079633, 5.49778714, 6.28318531],), {}),
(np.around, ([0.5, 1.5, 2.5, 3.5, 4.5],), {}),
(np.round_, ([0.5, 1.5, 2.5, 3.5, 4.5],), {}),
(np.fix, ([2.1, 2.9, -2.1, -2.9],), {}),
(np.cumprod, ([1, 2, 3],), {}),
(np.cumsum, ([1, 2, 3],), {}),
(np.nancumprod, ([1, np.nan],), {"axis": 0}),
(np.nancumsum, ([1, np.nan],), {"axis": 0}),
(np.diff, ([1, 3, 2],), {}),
(np.ediff1d, ([1, 2, 4, 7, 0],), {}),
(np.cross, ([1, 2, 3], [4, 5, 6]), {}),
(np.trapz, ([1, 2, 3],), {}),
(np.i0, ([0.0, 1.0 + 2j],), {}),
(np.sinc, ([0, 1, 2],), {}),
(np.isclose, ([1, 3, 2], [3, 2, 1]), {}),
(np.allclose, ([1, 3, 2], [3, 2, 1]), {}),
(np.isposinf, ([np.NINF, 0.0, np.inf],), {}),
Expand Down Expand Up @@ -193,6 +206,13 @@ def replace_args_kwargs(method, backend, args, kwargs):
(np.flipud, ([[1, 2], [3, 4]],), {}),
(np.roll, ([1, 2, 3], 1), {}),
(np.rot90, ([[1, 2], [3, 4]],), {}),
(np.angle, ([1.0, 1.0j, 1 + 1j],), {}),
(np.real, ([1 + 2j, 3 + 4j, 5 + 6j],), {}),
(np.imag, ([1 + 2j, 3 + 4j, 5 + 6j],), {}),
(np.convolve, ([1, 2, 3], [0, 1, 0.5]), {}),
(np.nan_to_num, ([np.inf, np.NINF, np.nan],), {}),
(np.real_if_close, ([2.1 + 4e-14j, 5.2 + 3e-15j],), {}),
(np.interp, (2.5, [1, 2, 3], [3, 2, 0]), {}),
],
)
def test_functions_coerce(backend, method, args, kwargs):
Expand Down Expand Up @@ -391,8 +411,10 @@ def test_array_creation(backend, method, args, kwargs):
(np.divide, ([6, 1], [3, 2]), {}, [2.0, 0.5]),
(np.true_divide, ([6, 1], [3, 2]), {}, [2.0, 0.5]),
(np.power, ([2, 3], [3, 2]), {}, [8, 9]),
(np.float_power, ([2, 3], [3, 2]), {}, [8, 9]),
(np.positive, ([1, -2],), {}, [1, -2]),
(np.negative, ([-2, 3],), {}, [2, -3]),
(np.conjugate, ([1.0 + 2.0j, -1.0 - 1j],), {}, [1.0 - 2.0j, -1.0 + 1j]),
(np.conj, ([1.0 + 2.0j, -1.0 - 1j],), {}, [1.0 - 2.0j, -1.0 + 1j]),
(np.exp, ([0, 1, 2],), {}, [1.0, 2.718281828459045, 7.38905609893065]),
(np.exp2, ([3, 4],), {}, [8, 16]),
Expand All @@ -409,6 +431,18 @@ def test_array_creation(backend, method, args, kwargs):
{},
np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]]),
),
(
np.degrees,
([0.0, 0.52359878, 1.04719755, 1.57079633],),
{},
[0.0, 30.0, 60.0, 90.0],
),
(
np.radians,
([0.0, 30.0, 60.0, 90.0],),
{},
[0.0, 0.52359878, 1.04719755, 1.57079633],
),
],
)
def test_ufuncs_results(backend, method, args, kwargs, res):
Expand Down

0 comments on commit 6edccfd

Please sign in to comment.