Skip to content
Permalink
master
Switch branches/tags
Go to file
 
 
Cannot retrieve contributors at this time
import functools
import itertools
import collections
import numbers
import operator
from uarray import create_multimethod, mark_as, all_of_type, Dispatchable
import builtins
# Factory for multimethods registered under the "numpy" uarray domain.
create_numpy = functools.partial(create_multimethod, domain="numpy")
# Mathematical constants (the literals carry more digits than a float keeps).
e = 2.718281828459045235360287471352662498
pi = 3.141592653589793238462643383279502884
euler_gamma = 0.577215664901532860606512090082402431
# Floating-point special values and signed zeros.
nan = float("nan")
inf = float("inf")
NINF = float("-inf")
PZERO = 0.0
NZERO = -0.0
# ``newaxis`` is simply an alias for ``None``.
newaxis = None
# Alternative spellings bound to the very same objects as ``nan`` / ``inf``.
NaN = NAN = nan
Inf = Infinity = PINF = infty = inf
def _identity_argreplacer(args, kwargs, arrays):
    """Argument replacer that ignores ``arrays`` and hands the call back as-is.

    Returns the very same ``args`` and ``kwargs`` objects it was given.
    """
    unchanged_call = (args, kwargs)
    return unchanged_call
def _dtype_argreplacer(args, kwargs, dispatchables):
    """Rebuild a call whose dispatchables are ``(dtype[, out])``.

    Positional arguments pass through untouched.  The ``dtype`` keyword is
    always set to ``dispatchables[0]``; ``out`` is swapped for
    ``dispatchables[1]`` only when the caller supplied it by keyword.
    """

    def _rebuild(*positional, dtype=None, **keywords):
        # ``dtype`` was captured by name above, so it is never in ``keywords``.
        replaced = dict(keywords, dtype=dispatchables[0])
        if "out" in keywords:
            replaced["out"] = dispatchables[1]
        return positional, replaced

    return _rebuild(*args, **kwargs)
def _self_argreplacer(args, kwargs, dispatchables):
    """Rebuild a call whose dispatchables are ``(first_argument[, out])``.

    The first argument (positional, or passed as keyword ``a``) becomes
    ``dispatchables[0]``; a keyword ``out`` becomes ``dispatchables[1]``.
    Everything else is forwarded unchanged.
    """

    def _swap(a, *rest, **keywords):
        replaced = dict(keywords)
        if "out" in replaced:
            replaced["out"] = dispatchables[1]
        return (dispatchables[0], *rest), replaced

    return _swap(*args, **kwargs)
def _skip_self_argreplacer(args, kwargs, dispatchables):
    """Keep the first argument and replace all other positionals.

    The result is ``(self,) + dispatchables``; the original remaining
    positional arguments are discarded and keywords pass through.
    """

    def _rebuild(self, *_discarded, **keywords):
        positional = (self,) + dispatchables
        return positional, keywords

    return _rebuild(*args, **kwargs)
def _ureduce_argreplacer(args, kwargs, dispatchables):
    """Rebuild a ``ufunc.reduce``-style call from its dispatchables.

    ``dispatchables`` is indexed as ``(ufunc, array, dtype, out)``.  The
    rebuilt call always spells out ``axis``, ``dtype``, ``out`` and
    ``keepdims`` as keywords, whether or not the caller passed them.
    """

    def ureduce(self, a, axis=0, dtype=None, out=None, keepdims=False):
        positional = (dispatchables[0], dispatchables[1])
        keywords = {
            "axis": axis,
            "dtype": dispatchables[2],
            "out": dispatchables[3],
            "keepdims": keepdims,
        }
        return positional, keywords

    return ureduce(*args, **kwargs)
class ClassOverrideMeta(type):
    """Metaclass whose classes can be overridden through numpy-domain multimethods.

    Every class built by this metaclass carries ``_unwrapped``, a plain
    ``type``-built twin with the same name, bases and namespace.
    ``isinstance`` / ``issubclass`` checks are routed through multimethods
    whose defaults test against ``overridden_class``.
    """

    def __new__(cls, name, bases, namespace):
        # Replace each base that was itself produced by this metaclass with
        # its plain ``_unwrapped`` twin, and remember whether that happened.
        saw_override_base = False
        plain_bases = []
        for base in bases:
            if isinstance(base, cls):
                saw_override_base = True
                base = base._unwrapped
            plain_bases.append(base)

        if not saw_override_base:
            return super().__new__(cls, name, bases, namespace)
        # Subclasses of an override class become ordinary ``type`` classes.
        return type(name, tuple(plain_bases), namespace)

    def __init__(self, name, bases, namespace):
        # Plain twin; it is what ``overridden_class`` returns by default.
        self._unwrapped = type(name, bases, namespace)
        return super().__init__(name, bases, namespace)

    @property  # type: ignore
    @create_numpy(_identity_argreplacer, default=lambda self: self._unwrapped)
    def overridden_class(self):
        """Multimethod with no dispatchables; defaults to ``self._unwrapped``."""
        return ()

    @create_numpy(
        _identity_argreplacer,
        default=lambda self, value: isinstance(value, self.overridden_class),
    )
    def __instancecheck__(self, value):
        """Multimethod; by default tests ``value`` against ``overridden_class``."""
        return ()

    @create_numpy(
        _identity_argreplacer,
        default=lambda self, value: issubclass(value, self.overridden_class),
    )
    def __subclasscheck__(self, value):
        """Multimethod; by default tests ``value`` against ``overridden_class``."""
        return ()
class ClassOverrideMetaWithConstructor(ClassOverrideMeta):
    """``ClassOverrideMeta`` variant that also routes construction through a multimethod."""

    @create_numpy(
        _identity_argreplacer,
        default=lambda self, *a, **kw: self.overridden_class(*a, **kw),
    )
    def __call__(self, *args, **kwargs):
        """Multimethod; by default forwards the call to ``overridden_class``.

        NOTE(review): this body assigns ``NotImplemented`` to ``_unwrapped``
        before reporting no dispatchables -- confirm that is intended.
        """
        self._unwrapped = NotImplemented
        no_dispatchables = ()
        return no_dispatchables
class ClassOverrideMetaWithGetAttr(ClassOverrideMeta):
    """``ClassOverrideMeta`` variant that routes missing class attributes through a multimethod."""

    @create_numpy(
        _identity_argreplacer,
        default=lambda self, name: getattr(self.overridden_class, name),
    )
    def __getattr__(self, name):
        """Multimethod; by default looks ``name`` up on ``overridden_class``."""
        no_dispatchables = ()
        return no_dispatchables
class ClassOverrideMetaWithConstructorAndGetAttr(
    ClassOverrideMetaWithConstructor, ClassOverrideMetaWithGetAttr
):
    """Combines the constructor and attribute-lookup metaclass variants."""
def _call_first_argreplacer(args, kwargs, dispatchables):
    """Replace only the second positional argument with ``dispatchables[0]``.

    The leading ``self`` and every remaining positional / keyword argument
    are forwarded unchanged.
    """

    def _rebuild(self, a, *rest, **keywords):
        head = (self, dispatchables[0])
        return head + rest, keywords

    return _rebuild(*args, **kwargs)
def _first2argreplacer(args, kwargs, arrays):
    """Replace the first two positional arguments (and a keyword ``out``).

    ``arrays[:2]`` stand in for the first two arguments; when the caller
    passed ``out`` by keyword it is replaced with ``arrays[2]``.
    """

    def _rebuild(a, b, *rest, **keywords):
        replaced = dict(keywords)
        if "out" in keywords:
            replaced["out"] = arrays[2]
        return arrays[:2] + rest, replaced

    return _rebuild(*args, **kwargs)
def getattr_impl(attr):
    """Build a one-argument function that reads attribute ``attr``.

    The returned function yields ``NotImplemented`` when its argument has no
    such attribute.
    """

    def func(a):
        return getattr(a, attr) if hasattr(a, attr) else NotImplemented

    return func
def method_impl(method):
    """Build a function that calls the method named ``method`` on its first argument.

    The returned function forwards all remaining positional and keyword
    arguments to that method and yields ``NotImplemented`` when the object
    does not provide it.
    """

    def func(self, *a, **kw):
        # Look the method up on the receiver, not on the tuple of extra
        # positional arguments (which never has it).
        if hasattr(self, method):
            return getattr(self, method)(*a, **kw)
        return NotImplemented

    return func
def _ufunc_argreplacer(args, kwargs, arrays):
    """Rebuild a ``ufunc.__call__`` invocation from its dispatchables.

    ``arrays`` is laid out as ``(ufunc, *inputs, *outputs, dtype)``, where the
    number of inputs is ``args[0].nin``.  Keyword ``out`` / ``dtype`` are
    replaced only when the caller supplied them.  The caller's ``kwargs``
    dict is never modified.

    Returns:
        ``((ufunc, *inputs), new_kwargs)``.
    """
    self = args[0]
    in_arrays = arrays[1 : self.nin + 1]
    out_arrays = arrays[self.nin + 1 : -1]
    dtype = arrays[-1]
    if self.nout == 1:
        # Single-output ufuncs take ``out`` as a bare array, not a 1-tuple.
        out_arrays = out_arrays[0]

    # Work on a copy so the replacement never leaks into the caller's dict.
    new_kwargs = dict(kwargs)
    if "out" in new_kwargs:
        new_kwargs["out"] = out_arrays
    if "dtype" in new_kwargs:
        new_kwargs["dtype"] = dtype
    return (arrays[0], *in_arrays), new_kwargs
def _math_op(name, inplace=True, reverse=True):
    """Build binary-operator methods that call the module-level object ``name``.

    The callable is looked up in this module's globals at call time, so it
    may be defined after the methods are created.

    Returns:
        The forward method alone when neither variant is requested, otherwise
        a list ``[forward, reflected?, in_place?]`` in that order.
    """

    def forward(self, other):
        return globals()[name](self, other)

    def reflected(self, other):
        return globals()[name](other, self)

    def in_place(self, other):
        return globals()[name](self, other, out=self)

    variants = [forward]
    if reverse:
        variants.append(reflected)
    if inplace:
        variants.append(in_place)

    if len(variants) == 1:
        return variants[0]
    return variants
def _unary_op(name):
    """Build a unary-operator method that calls the module-level object ``name``.

    The callable is resolved from this module's globals each time the
    operator is invoked.
    """

    def operator_method(self):
        target = globals()[name]
        return target(self)

    return operator_method
class ndarray(metaclass=ClassOverrideMetaWithConstructorAndGetAttr):
    """Overridable array class whose operators forward to this module's ufuncs.

    Each operator method looks the matching ufunc up by name at call time
    (see ``_math_op`` / ``_unary_op``).
    """

    # Arithmetic operators: forward, reflected and in-place variants.
    __add__, __radd__, __iadd__ = _math_op("add")
    __sub__, __rsub__, __isub__ = _math_op("subtract")
    __mul__, __rmul__, __imul__ = _math_op("multiply")
    __truediv__, __rtruediv__, __itruediv__ = _math_op("true_divide")
    __floordiv__, __rfloordiv__, __ifloordiv__ = _math_op("floor_divide")
    __matmul__, __rmatmul__, __imatmul__ = _math_op("matmul")
    __mod__, __rmod__, __imod__ = _math_op("mod")
    # ``divmod`` has no in-place form, so request forward + reflected only.
    # (Asking for ``reverse=False`` would bind the in-place variant to
    # ``__rdivmod__``.)
    __divmod__, __rdivmod__ = _math_op("divmod", inplace=False)
    __lshift__, __rlshift__, __ilshift__ = _math_op("left_shift")
    __rshift__, __rrshift__, __irshift__ = _math_op("right_shift")
    __pow__, __rpow__, __ipow__ = _math_op("power")
    __and__, __rand__, __iand__ = _math_op("bitwise_and")
    __or__, __ror__, __ior__ = _math_op("bitwise_or")
    __xor__, __rxor__, __ixor__ = _math_op("bitwise_xor")

    # Unary operators.
    __neg__ = _unary_op("negative")
    __pos__ = _unary_op("positive")
    __abs__ = _unary_op("absolute")
    __invert__ = _unary_op("invert")

    # Comparisons: forward variant only.
    __lt__ = _math_op("less", inplace=False, reverse=False)
    __gt__ = _math_op("greater", inplace=False, reverse=False)
    __le__ = _math_op("less_equal", inplace=False, reverse=False)
    __ge__ = _math_op("greater_equal", inplace=False, reverse=False)
    __eq__ = _math_op("equal", inplace=False, reverse=False)
    __ne__ = _math_op("not_equal", inplace=False, reverse=False)

    def __array_ufunc__(self, method, *inputs, **kwargs):
        """Always decline by returning ``NotImplemented``."""
        return NotImplemented
class dtype(metaclass=ClassOverrideMetaWithConstructorAndGetAttr):
    """Overridable dtype class; all behaviour comes from its metaclass."""
def _uaccumulate_argreplacer(args, kwargs, dispatchables):
    """Rebuild a ``ufunc.accumulate`` call from its dispatchables.

    Same layout as the reduce replacer -- ``(ufunc, array, dtype, out)`` --
    but without a ``keepdims`` keyword, which ``accumulate`` does not take.
    """

    def uaccumulate(self, a, axis=0, dtype=None, out=None):
        positional = (dispatchables[0], dispatchables[1])
        keywords = {
            "axis": axis,
            "dtype": dispatchables[2],
            "out": dispatchables[3],
        }
        return positional, keywords

    return uaccumulate(*args, **kwargs)


class ufunc(metaclass=ClassOverrideMeta):
    """Description of a universal function: its name and input/output counts.

    Calling, reducing and accumulating are numpy-domain multimethods; the
    methods below only report which arguments are dispatchable.
    """

    def __init__(self, name, nin, nout):
        self.name = name
        self.nin = nin
        self.nout = nout

    def __str__(self):
        return "<ufunc '{}'>".format(self.name)

    __repr__ = __str__

    @property  # type: ignore
    @create_numpy(_self_argreplacer)
    def types(self):
        """Multimethod property; the ufunc itself is the only dispatchable."""
        return (mark_ufunc(self),)

    @property  # type: ignore
    @create_numpy(_self_argreplacer)
    def identity(self):
        """Multimethod property; the ufunc itself is the only dispatchable."""
        return (mark_ufunc(self),)

    @property
    def nargs(self):
        """Total number of arguments: inputs plus outputs."""
        return self.nin + self.nout

    @property
    def ntypes(self):
        """Number of entries in ``types``."""
        return len(self.types)

    @create_numpy(_ufunc_argreplacer)
    @all_of_type(ndarray)
    def __call__(self, *args, out=None, dtype=None):
        """Report dispatchables as ``(ufunc, *inputs, *outputs, dtype)``."""
        in_args = tuple(args)
        dtype = mark_dtype(dtype)
        # Normalise ``out`` so single and multiple outputs are handled alike.
        if not isinstance(out, tuple):
            out = (out,)
        return (
            (mark_ufunc(self),)
            + in_args
            + tuple(mark_non_coercible(o) for o in out)
            + (dtype,)
        )

    @create_numpy(_ureduce_argreplacer)
    @all_of_type(ndarray)
    def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False):
        """Report dispatchables as ``(ufunc, a, dtype, out)``."""
        return (mark_ufunc(self), a, mark_dtype(dtype), mark_non_coercible(out))

    # ``accumulate`` has no ``keepdims`` parameter, so it needs a replacer
    # that does not inject one into the rebuilt call.
    @create_numpy(_uaccumulate_argreplacer)
    @all_of_type(ndarray)
    def accumulate(self, a, axis=0, dtype=None, out=None):
        """Report dispatchables as ``(ufunc, a, dtype, out)``."""
        return (mark_ufunc(self), a, mark_dtype(dtype), mark_non_coercible(out))
# Marker helpers built with uarray's ``mark_as`` for the ``ufunc`` and
# ``dtype`` classes defined above.
mark_ufunc = mark_as(ufunc)
mark_dtype = mark_as(dtype)
# Wraps a value as an ``ndarray`` dispatchable with ``coercible=False``;
# in this file it is applied to ``out`` arguments.
mark_non_coercible = lambda x: Dispatchable(x, ndarray, coercible=False)
# Math operations
# Each entry is ``ufunc(name, nin, nout)``: the number of inputs and outputs
# must match the NumPy ufunc of the same name, because the call replacer
# slices the dispatchables using ``nin``.
add = ufunc("add", 2, 1)
subtract = ufunc("subtract", 2, 1)
multiply = ufunc("multiply", 2, 1)
matmul = ufunc("matmul", 2, 1)
divide = ufunc("divide", 2, 1)
logaddexp = ufunc("logaddexp", 2, 1)
logaddexp2 = ufunc("logaddexp2", 2, 1)
true_divide = ufunc("true_divide", 2, 1)
floor_divide = ufunc("floor_divide", 2, 1)
float_power = ufunc("float_power", 2, 1)
negative = ufunc("negative", 1, 1)
positive = ufunc("positive", 1, 1)
power = ufunc("power", 2, 1)
remainder = ufunc("remainder", 2, 1)
mod = ufunc("mod", 2, 1)
divmod = ufunc("divmod", 2, 2)
absolute = ufunc("absolute", 1, 1)
fabs = ufunc("fabs", 1, 1)
rint = ufunc("rint", 1, 1)
sign = ufunc("sign", 1, 1)
# heaviside(x1, x2) takes two inputs.
heaviside = ufunc("heaviside", 2, 1)
conjugate = ufunc("conjugate", 1, 1)
conj = conjugate
exp = ufunc("exp", 1, 1)
exp2 = ufunc("exp2", 1, 1)
log = ufunc("log", 1, 1)
log2 = ufunc("log2", 1, 1)
log10 = ufunc("log10", 1, 1)
expm1 = ufunc("expm1", 1, 1)
log1p = ufunc("log1p", 1, 1)
sqrt = ufunc("sqrt", 1, 1)
square = ufunc("square", 1, 1)
cbrt = ufunc("cbrt", 1, 1)
reciprocal = ufunc("reciprocal", 1, 1)
# gcd(x1, x2) and lcm(x1, x2) take two inputs.
gcd = ufunc("gcd", 2, 1)
lcm = ufunc("lcm", 2, 1)
# Trigonometric functions
# Each entry is ``ufunc(name, nin, nout)``; only ``arctan2`` and ``hypot``
# are declared with two inputs.
sin = ufunc("sin", 1, 1)
cos = ufunc("cos", 1, 1)
tan = ufunc("tan", 1, 1)
arcsin = ufunc("arcsin", 1, 1)
arccos = ufunc("arccos", 1, 1)
arctan = ufunc("arctan", 1, 1)
arctan2 = ufunc("arctan2", 2, 1)
hypot = ufunc("hypot", 2, 1)
degrees = ufunc("degrees", 1, 1)
radians = ufunc("radians", 1, 1)
sinh = ufunc("sinh", 1, 1)
cosh = ufunc("cosh", 1, 1)
tanh = ufunc("tanh", 1, 1)
arcsinh = ufunc("arcsinh", 1, 1)
arccosh = ufunc("arccosh", 1, 1)
arctanh = ufunc("arctanh", 1, 1)
deg2rad = ufunc("deg2rad", 1, 1)
rad2deg = ufunc("rad2deg", 1, 1)
# Bit-twiddling functions
# Each entry is ``ufunc(name, nin, nout)``; ``invert`` is the only unary one.
bitwise_and = ufunc("bitwise_and", 2, 1)
bitwise_or = ufunc("bitwise_or", 2, 1)
bitwise_xor = ufunc("bitwise_xor", 2, 1)
invert = ufunc("invert", 1, 1)
left_shift = ufunc("left_shift", 2, 1)
right_shift = ufunc("right_shift", 2, 1)
# Comparison functions
# Each entry is ``ufunc(name, nin, nout)``; ``logical_not`` is the only
# unary one.
greater = ufunc("greater", 2, 1)
greater_equal = ufunc("greater_equal", 2, 1)
less = ufunc("less", 2, 1)
less_equal = ufunc("less_equal", 2, 1)
not_equal = ufunc("not_equal", 2, 1)
equal = ufunc("equal", 2, 1)
logical_and = ufunc("logical_and", 2, 1)
logical_or = ufunc("logical_or", 2, 1)
logical_xor = ufunc("logical_xor", 2, 1)
logical_not = ufunc("logical_not", 1, 1)
maximum = ufunc("maximum", 2, 1)
minimum = ufunc("minimum", 2, 1)
fmax = ufunc("fmax", 2, 1)
fmin = ufunc("fmin", 2, 1)
# Floating functions
# Each entry is ``ufunc(name, nin, nout)``; ``modf`` and ``frexp`` are
# declared with two outputs.
isfinite = ufunc("isfinite", 1, 1)
isinf = ufunc("isinf", 1, 1)
isnan = ufunc("isnan", 1, 1)
isnat = ufunc("isnat", 1, 1)
signbit = ufunc("signbit", 1, 1)
copysign = ufunc("copysign", 2, 1)
nextafter = ufunc("nextafter", 2, 1)
spacing = ufunc("spacing", 1, 1)
modf = ufunc("modf", 1, 2)
ldexp = ufunc("ldexp", 2, 1)
frexp = ufunc("frexp", 1, 2)
fmod = ufunc("fmod", 2, 1)
floor = ufunc("floor", 1, 1)
ceil = ufunc("ceil", 1, 1)
trunc = ufunc("trunc", 1, 1)
@create_numpy(_dtype_argreplacer)
def empty(shape, dtype="float64", order="C"):
    """Multimethod signature for ``empty``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
def _self_dtype_argreplacer(args, kwargs, dispatchables):
    """Rebuild a call whose dispatchables are ``(first_argument, dtype[, out])``.

    The first argument becomes ``dispatchables[0]``, the ``dtype`` keyword is
    always set to ``dispatchables[1]`` and a keyword ``out`` (if supplied) is
    replaced by ``dispatchables[2]``.
    """

    def _rebuild(a, *rest, dtype=None, **keywords):
        # ``dtype`` was captured by name, so it is never in ``keywords``.
        replaced = dict(keywords, dtype=dispatchables[1])
        if "out" in keywords:
            replaced["out"] = dispatchables[2]
        return (dispatchables[0], *rest), replaced

    return _rebuild(*args, **kwargs)
def _empty_like_default(prototype, dtype=None, order="K", subok=True, shape=None):
    """Fallback for ``empty_like`` expressed through ``empty``.

    Declines with ``NotImplemented`` unless ``order`` and ``subok`` keep
    their default values.  Shape and dtype fall back to those of
    ``prototype`` when not given.
    """
    uses_defaults = order == "K" and subok == True
    if not uses_defaults:
        return NotImplemented

    if shape is None:
        out_shape = _shape(prototype)
    else:
        out_shape = shape
    if dtype is None:
        out_dtype = prototype.dtype
    else:
        out_dtype = dtype
    return empty(out_shape, dtype=out_dtype)
@create_numpy(_self_dtype_argreplacer, default=_empty_like_default)
@all_of_type(ndarray)
def empty_like(prototype, dtype=None, order="K", subok=True, shape=None):
    """Multimethod signature for ``empty_like``; dispatches on ``prototype`` and ``dtype``."""
    dtype_marker = mark_dtype(dtype)
    return (prototype, dtype_marker)
@create_numpy(_dtype_argreplacer)
def full(shape, fill_value, dtype=None, order="C"):
    """Multimethod signature for ``full``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
def _full_like_default(a, fill_value, dtype=None, order="K", subok=True, shape=None):
    """Fallback for ``full_like`` expressed through ``full``.

    Declines with ``NotImplemented`` unless ``order`` and ``subok`` keep
    their default values.  Shape and dtype fall back to those of ``a`` when
    not given.
    """
    uses_defaults = order == "K" and subok == True
    if not uses_defaults:
        return NotImplemented

    if shape is None:
        out_shape = _shape(a)
    else:
        out_shape = shape
    if dtype is None:
        out_dtype = a.dtype
    else:
        out_dtype = dtype
    return full(out_shape, fill_value, dtype=out_dtype)
@create_numpy(_self_dtype_argreplacer, default=_full_like_default)
@all_of_type(ndarray)
def full_like(a, fill_value, dtype=None, order="K", subok=True, shape=None):
    """Multimethod signature for ``full_like``; dispatches on ``a`` and ``dtype``."""
    dtype_marker = mark_dtype(dtype)
    return (a, dtype_marker)
@create_numpy(_dtype_argreplacer)
def arange(start, stop=None, step=None, dtype=None):
    """Multimethod signature for ``arange``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(_dtype_argreplacer)
def array(object, dtype=None, copy=True, order="K", subok=False, ndmin=0):
    """Multimethod signature for ``array``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(
    _dtype_argreplacer,
    # The fallback must accept the same calls as ``zeros`` itself, so
    # ``dtype`` needs the same default; otherwise ``zeros(shape)`` fails.
    default=lambda shape, dtype=float, order="C": full(shape, 0, dtype, order),
)
def zeros(shape, dtype=float, order="C"):
    """Multimethod signature for ``zeros``; ``dtype`` is the only dispatchable.

    Falls back to ``full(shape, 0, dtype, order)`` by default.
    """
    return (mark_dtype(dtype),)
@create_numpy(
    _self_dtype_argreplacer,
    default=lambda a, dtype=None, order="K", subok=True, shape=None: full_like(
        a, 0, dtype, order, subok, shape
    ),
)
@all_of_type(ndarray)
def zeros_like(a, dtype=None, order="K", subok=True, shape=None):
    """Multimethod signature for ``zeros_like``; dispatches on ``a`` and ``dtype``.

    Falls back to ``full_like`` with a fill value of 0.
    """
    dtype_marker = mark_dtype(dtype)
    return (a, dtype_marker)
@create_numpy(
    _dtype_argreplacer,
    # The fallback must accept the same calls as ``ones`` itself, so
    # ``dtype`` needs the same default; otherwise ``ones(shape)`` fails.
    default=lambda shape, dtype=float, order="C": full(shape, 1, dtype, order),
)
def ones(shape, dtype=float, order="C"):
    """Multimethod signature for ``ones``; ``dtype`` is the only dispatchable.

    Falls back to ``full(shape, 1, dtype, order)`` by default.
    """
    return (mark_dtype(dtype),)
@create_numpy(
    _self_dtype_argreplacer,
    default=lambda a, dtype=None, order="K", subok=True, shape=None: full_like(
        a, 1, dtype, order, subok, shape
    ),
)
@all_of_type(ndarray)
def ones_like(a, dtype=None, order="K", subok=True, shape=None):
    """Multimethod signature for ``ones_like``; dispatches on ``a`` and ``dtype``.

    Falls back to ``full_like`` with a fill value of 1.
    """
    dtype_marker = mark_dtype(dtype)
    return (a, dtype_marker)
@create_numpy(_dtype_argreplacer)
def eye(N, M=None, k=0, dtype=float, order="C"):
    """Multimethod signature for ``eye``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(_dtype_argreplacer, default=lambda n, dtype=None: eye(n, dtype=dtype))
def identity(n, dtype=None):
    """Multimethod signature for ``identity``; falls back to ``eye(n, dtype=dtype)``."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(_dtype_argreplacer)
def asarray(a, dtype=None, order=None):
    """Multimethod signature for ``asarray``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(_self_dtype_argreplacer)
@all_of_type(ndarray)
def asanyarray(a, dtype=None, order=None):
    """Multimethod signature for ``asanyarray``; dispatches on ``a`` and ``dtype``."""
    dtype_marker = mark_dtype(dtype)
    return (a, dtype_marker)
def _asfarray_default(a, dtype=float):
    """Fallback for ``asfarray`` built on ``asarray``.

    Converts once with the requested ``dtype``; if the resulting dtype name
    does not start with ``"float"``, converts again to ``float``.
    """
    converted = asarray(a, dtype=dtype)
    is_float = converted.dtype.name.startswith("float")
    target = dtype if is_float else float
    return asarray(converted, dtype=target)
@create_numpy(_dtype_argreplacer, default=_asfarray_default)
def asfarray(a, dtype=float):
    """Multimethod signature for ``asfarray``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(
    _dtype_argreplacer,
    default=lambda a, dtype=None: asarray(a, dtype=dtype, order="F"),
)
def asfortranarray(a, dtype=None):
    """Multimethod signature for ``asfortranarray``.

    Falls back to ``asarray`` with ``order="F"``.
    """
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
def _asarray_chkfinite_default(a, dtype=None, order=None):
    """Fallback for ``asarray_chkfinite``: convert, then reject non-finite data.

    Raises:
        ValueError: if the converted array is not finite everywhere.
    """
    converted = asarray(a, dtype=dtype, order=order)
    # ``all`` here is this module's multimethod, not the builtin.
    if all(isfinite(converted)):
        return converted
    raise ValueError("Array must not contain infs or NaNs.")
@create_numpy(_dtype_argreplacer, default=_asarray_chkfinite_default)
def asarray_chkfinite(a, dtype=None, order=None):
    """Multimethod signature for ``asarray_chkfinite``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(_self_dtype_argreplacer)
@all_of_type(ndarray)
def require(a, dtype=None, requirements=None):
    """Multimethod signature for ``require``; dispatches on ``a`` and ``dtype``."""
    dtype_marker = mark_dtype(dtype)
    return (a, dtype_marker)
@create_numpy(
    _dtype_argreplacer,
    default=lambda a, dtype=None: asarray(a, dtype=dtype, order="C"),
)
@all_of_type(ndarray)
def ascontiguousarray(a, dtype=None):
    """Multimethod signature for ``ascontiguousarray``.

    Falls back to ``asarray`` with ``order="C"``.
    """
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def copy(a, order="K"):
    """Multimethod signature for ``copy``; ``a`` is the only dispatchable."""
    dispatchables = (a,)
    return dispatchables
@create_numpy(_dtype_argreplacer)
def frombuffer(buffer, dtype=float, count=-1, offset=0):
    """Multimethod signature for ``frombuffer``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(_dtype_argreplacer)
def fromfile(file, dtype=float, count=-1, sep="", offset=0):
    """Multimethod signature for ``fromfile``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(_dtype_argreplacer)
def fromfunction(function, shape, **kwargs):
    """Multimethod signature for ``fromfunction``.

    The dispatchable dtype is the ``dtype`` keyword when present and
    ``float`` otherwise.
    """
    dtype = kwargs.get("dtype", float)
    return (mark_dtype(dtype),)
def _fromiter_default(iterable, dtype, count=-1):
    """Fallback for ``fromiter``: materialise the iterable and call ``array``.

    A non-negative ``count`` limits how many items are consumed.

    Raises:
        TypeError: if ``iterable`` is not an ``Iterable``.
    """
    if not isinstance(iterable, collections.abc.Iterable):
        raise TypeError("'%s' object is not iterable" % type(iterable).__name__)

    source = itertools.islice(iterable, 0, count) if count >= 0 else iterable
    items = list(source)
    return array(items, dtype=dtype)
@create_numpy(_dtype_argreplacer, default=_fromiter_default)
def fromiter(iterable, dtype, count=-1):
    """Multimethod signature for ``fromiter``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(_dtype_argreplacer)
def fromstring(string, dtype=float, count=-1, sep=""):
    """Multimethod signature for ``fromstring``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
@create_numpy(_dtype_argreplacer)
def loadtxt(
    fname,
    dtype=float,
    comments="#",
    delimiter=None,
    converters=None,
    skiprows=0,
    usecols=None,
    unpack=False,
    ndmin=0,
    encoding="bytes",
    max_rows=None,
):
    """Multimethod signature for ``loadtxt``; ``dtype`` is the only dispatchable."""
    dtype_marker = mark_dtype(dtype)
    return (dtype_marker,)
def reduce_impl(red_ufunc: "ufunc"):
    """Build a fallback that implements a reduction via ``red_ufunc.reduce``.

    The reduction functions using this fallback (``sum``, ``prod``, ``min``
    and friends) default to ``axis=None``, whereas ``ufunc.reduce`` defaults
    to ``axis=0``.  The fallback therefore passes ``axis=None`` explicitly
    whenever the caller did not choose an axis.
    """

    def inner(a, **kwargs):
        kwargs.setdefault("axis", None)
        return red_ufunc.reduce(a, **kwargs)

    return inner
@create_numpy(_self_dtype_argreplacer, default=reduce_impl(add))
@all_of_type(ndarray)
def sum(a, axis=None, dtype=None, out=None, keepdims=False):
    """Multimethod signature for ``sum``; dispatches on ``a``, ``dtype`` and ``out``.

    Falls back to reducing with the ``add`` ufunc.
    """
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
@create_numpy(_self_dtype_argreplacer, default=reduce_impl(multiply))
@all_of_type(ndarray)
def prod(a, axis=None, dtype=None, out=None, keepdims=False):
    """Multimethod signature for ``prod``; dispatches on ``a``, ``dtype`` and ``out``.

    Falls back to reducing with the ``multiply`` ufunc.
    """
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
@create_numpy(_self_argreplacer, default=reduce_impl(minimum))
@all_of_type(ndarray)
def min(a, axis=None, out=None, keepdims=False):
    """Multimethod signature for ``min``; dispatches on ``a`` and ``out``.

    Falls back to reducing with the ``minimum`` ufunc.
    """
    out_marker = mark_non_coercible(out)
    return (a, out_marker)
@create_numpy(_self_argreplacer, default=reduce_impl(maximum))
@all_of_type(ndarray)
def max(a, axis=None, out=None, keepdims=False):
    """Multimethod signature for ``max``; dispatches on ``a`` and ``out``.

    Falls back to reducing with the ``maximum`` ufunc.
    """
    out_marker = mark_non_coercible(out)
    return (a, out_marker)
@create_numpy(_self_argreplacer, default=reduce_impl(logical_or))
@all_of_type(ndarray)
def any(a, axis=None, out=None, keepdims=False):
    """Multimethod signature for ``any``; dispatches on ``a`` and ``out``.

    Falls back to reducing with the ``logical_or`` ufunc.
    """
    out_marker = mark_non_coercible(out)
    return (a, out_marker)
@create_numpy(_self_argreplacer, default=reduce_impl(logical_and))
@all_of_type(ndarray)
def all(a, axis=None, out=None, keepdims=False):
    """Multimethod signature for ``all``; dispatches on ``a`` and ``out``.

    Falls back to reducing with the ``logical_and`` ufunc.
    """
    out_marker = mark_non_coercible(out)
    return (a, out_marker)
@create_numpy(_self_argreplacer, default=lambda x, out=None: equal(x, inf, out=out))
@all_of_type(ndarray)
def isposinf(x, out=None):
    """Multimethod signature for ``isposinf``; falls back to ``equal(x, inf)``."""
    out_marker = mark_non_coercible(out)
    return (x, out_marker)
@create_numpy(_self_argreplacer, default=lambda x, out=None: equal(x, NINF, out=out))
@all_of_type(ndarray)
def isneginf(x, out=None):
    """Multimethod signature for ``isneginf``; falls back to ``equal(x, NINF)``."""
    out_marker = mark_non_coercible(out)
    return (x, out_marker)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def iscomplex(x):
    """Multimethod signature for ``iscomplex``; ``x`` is the only dispatchable."""
    dispatchables = (x,)
    return dispatchables
@create_numpy(_identity_argreplacer)
def iscomplexobj(x):
    """Multimethod signature for ``iscomplexobj``; it has no dispatchables."""
    no_dispatchables = ()
    return no_dispatchables
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def isreal(x):
    """Multimethod signature for ``isreal``; ``x`` is the only dispatchable."""
    dispatchables = (x,)
    return dispatchables
@create_numpy(_identity_argreplacer)
def isrealobj(x):
    """Multimethod signature for ``isrealobj``; it has no dispatchables."""
    no_dispatchables = ()
    return no_dispatchables
@create_numpy(_identity_argreplacer)
def isscalar(element):
    """Multimethod signature for ``isscalar``; it has no dispatchables."""
    no_dispatchables = ()
    return no_dispatchables
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def argmin(a, axis=None, out=None):
    """Multimethod signature for ``argmin``; dispatches on ``a`` and ``out``."""
    out_marker = mark_non_coercible(out)
    return (a, out_marker)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def nanargmin(a, axis=None):
    """Multimethod signature for ``nanargmin``; ``a`` is the only dispatchable."""
    dispatchables = (a,)
    return dispatchables
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def argmax(a, axis=None, out=None):
    """Multimethod signature for ``argmax``; dispatches on ``a`` and ``out``."""
    out_marker = mark_non_coercible(out)
    return (a, out_marker)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def nanargmax(a, axis=None):
    """Multimethod signature for ``nanargmax``; ``a`` is the only dispatchable."""
    dispatchables = (a,)
    return dispatchables
# ``amin`` / ``amax`` are bound to the same multimethod objects as this
# module's ``min`` / ``max`` (not the builtins).
amin = min
amax = max
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def nanmin(a, axis=None, out=None, keepdims=False):
    """Multimethod signature for ``nanmin``; dispatches on ``a`` and ``out``.

    Accepts ``keepdims`` like ``nanmax`` does; keywords are forwarded to the
    backend unchanged by the argument replacer.
    """
    return (a, mark_non_coercible(out))
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def nanmax(a, axis=None, out=None, keepdims=False):
    """Multimethod signature for ``nanmax``; dispatches on ``a`` and ``out``."""
    out_marker = mark_non_coercible(out)
    return (a, out_marker)
@create_numpy(_self_dtype_argreplacer)
@all_of_type(ndarray)
def nansum(a, axis=None, dtype=None, out=None, keepdims=False):
    """Multimethod signature for ``nansum``; dispatches on ``a``, ``dtype`` and ``out``."""
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
@create_numpy(_self_dtype_argreplacer)
@all_of_type(ndarray)
def nanprod(a, axis=None, dtype=None, out=None, keepdims=False):
    """Multimethod signature for ``nanprod``; dispatches on ``a``, ``dtype`` and ``out``."""
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
@create_numpy(_self_dtype_argreplacer)
@all_of_type(ndarray)
def cumprod(a, axis=None, dtype=None, out=None):
    """Multimethod signature for ``cumprod``; dispatches on ``a``, ``dtype`` and ``out``."""
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
@create_numpy(_self_dtype_argreplacer)
@all_of_type(ndarray)
def cumsum(a, axis=None, dtype=None, out=None):
    """Multimethod signature for ``cumsum``; dispatches on ``a``, ``dtype`` and ``out``."""
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
@create_numpy(
    _self_dtype_argreplacer,
    default=lambda a, axis=None, dtype=None, out=None: cumprod(
        where(isnan(a), 1, a), axis=axis, dtype=dtype, out=out  # type: ignore
    ),
)
@all_of_type(ndarray)
def nancumprod(a, axis=None, dtype=None, out=None):
    """Multimethod signature for ``nancumprod``.

    Falls back to ``cumprod`` after replacing NaN entries with 1.
    """
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
@create_numpy(
    _self_dtype_argreplacer,
    default=lambda a, axis=None, dtype=None, out=None: cumsum(
        where(isnan(a), 0, a), axis=axis, dtype=dtype, out=out  # type: ignore
    ),
)
@all_of_type(ndarray)
def nancumsum(a, axis=None, dtype=None, out=None):
    """Multimethod signature for ``nancumsum``.

    Falls back to ``cumsum`` after replacing NaN entries with 0.
    """
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
@create_numpy(_first2argreplacer)
@all_of_type(ndarray)
def percentile(
    a,
    q,
    axis=None,
    out=None,
    overwrite_input=False,
    interpolation="linear",
    keepdims=False,
):
    """Multimethod signature for ``percentile``; dispatches on ``a``, ``q`` and ``out``."""
    out_marker = mark_non_coercible(out)
    return (a, q, out_marker)
@create_numpy(_first2argreplacer)
@all_of_type(ndarray)
def nanpercentile(
    a,
    q,
    axis=None,
    out=None,
    overwrite_input=False,
    interpolation="linear",
    keepdims=False,
):
    """Multimethod signature for ``nanpercentile``; dispatches on ``a``, ``q`` and ``out``."""
    out_marker = mark_non_coercible(out)
    return (a, q, out_marker)
@create_numpy(_first2argreplacer)
@all_of_type(ndarray)
def quantile(
    a,
    q,
    axis=None,
    out=None,
    overwrite_input=False,
    interpolation="linear",
    keepdims=False,
):
    """Multimethod signature for ``quantile``; dispatches on ``a``, ``q`` and ``out``."""
    out_marker = mark_non_coercible(out)
    return (a, q, out_marker)
@create_numpy(_first2argreplacer)
@all_of_type(ndarray)
def nanquantile(
    a,
    q,
    axis=None,
    out=None,
    overwrite_input=False,
    interpolation="linear",
    keepdims=False,
):
    """Multimethod signature for ``nanquantile``; dispatches on ``a``, ``q`` and ``out``."""
    out_marker = mark_non_coercible(out)
    return (a, q, out_marker)
def _ureduce(a, axis):
    """Move the axes being reduced to the end and flatten them into one.

    Args:
        a: Input array.
        axis: ``None`` (reduce everything), a single axis, or a sequence of
            axes.

    Returns:
        ``(reshaped, dims)`` where ``reshaped`` has the reduced axes merged
        into a single trailing axis and ``dims`` is the shape of ``a`` with
        every reduced axis set to 1 (all ones when ``axis`` is ``None``).
    """
    nd = ndim(a)
    if axis is None:
        a = ravel(a)
        dims = (1,) * nd
    else:
        if not isinstance(axis, collections.abc.Sequence):
            axis = (axis,)
        # Force a tuple so the concatenation below works whatever sequence
        # type the normaliser hands back.
        axis = tuple(_normalize_axis(nd, axis))
        # Sort explicitly: the kept axes must stay in their original order,
        # and set iteration order is not something to rely on.
        unselected_axis = tuple(sorted(set(range(nd)) - set(axis)))
        dims = list(a.shape)
        for ax in axis:
            dims[ax] = 1
        dims = tuple(dims)
        a = transpose(a, unselected_axis + axis)
        a = a.reshape(a.shape[: len(unselected_axis)] + (-1,))
    return a, dims
def _median_default(a, axis=None, out=None, overwrite_input=False, keepdims=False):
    """Fallback for ``median`` built from ``partition`` and ``mean``.

    The reduced axes are flattened into one trailing axis, rows containing a
    NaN are turned into all-NaN, and the middle element (or the two middle
    elements, for an even count) is averaged.  ``overwrite_input`` is
    accepted but not used.

    Raises:
        ValueError: if ``out`` is given and its shape differs from the result.
    """
    flat, kept_dims = _ureduce(a, axis)
    # ``any`` is this module's multimethod, not the builtin.
    row_has_nan = any(isnan(flat), axis=-1, keepdims=True)
    flat = where(row_has_nan, nan, flat)

    count = flat.shape[-1]
    selector = [slice(None)] * ndim(flat)
    middle = count // 2
    if count % 2 == 0:
        flat = partition(flat, [middle - 1, middle])
        selector[-1] = slice(middle - 1, middle + 1)
    else:
        flat = partition(flat, middle)
        selector[-1] = slice(middle, middle + 1)

    result = mean(flat[tuple(selector)], axis=-1)
    if keepdims:
        result = result.reshape(kept_dims)

    if out is not None:
        if result.shape != out.shape:
            raise ValueError("out parameter must have the same shape as the output")
        copyto(out, result, casting="unsafe")
        return out
    return result
@create_numpy(_self_argreplacer, default=_median_default)
@all_of_type(ndarray)
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
    """Multimethod signature for ``median``; dispatches on ``a`` and ``out``."""
    out_marker = mark_non_coercible(out)
    return (a, out_marker)
def _self_weights_argreplacer(args, kwargs, dispatchables):
    """Rebuild a call whose last dispatchable is the ``weights`` keyword.

    All dispatchables except the last become the positional arguments and
    the last one is passed as ``weights``.  Note that the caller's original
    positional arguments are not forwarded; only keywords other than
    ``weights`` are.
    """

    def _rebuild(*_positional, weights=None, **keywords):
        # ``weights`` was captured by name, so it is never in ``keywords``.
        replaced = {"weights": dispatchables[-1], **keywords}
        return dispatchables[0:-1], replaced

    return _rebuild(*args, **kwargs)
def _average_default(a, axis=None, weights=None, returned=False):
    """Fallback for ``average`` built from ``mean`` and ``sum``.

    Args:
        a: Input array.
        axis: Axis to average over; required when ``weights`` is 1-D and
            shaped differently from ``a``.
        weights: Optional weights, either shaped like ``a`` or 1-D along
            ``axis``.
        returned: When true, also return the sum of weights.

    Returns:
        The average, or ``(average, sum_of_weights)`` when ``returned``.

    Raises:
        TypeError: if the shapes differ and ``axis`` is missing or
            ``weights`` is not 1-D.
        ValueError: if the 1-D ``weights`` length does not match ``axis``.
        ZeroDivisionError: if the weights sum to zero anywhere.
    """
    if weights is None:
        avg = mean(a, axis=axis)
        if not returned:
            return avg
        # Without explicit weights every element weighs 1, so the sum of
        # weights is the number of inputs that went into each output.
        return avg, full(avg.shape, a.size // avg.size)

    if a.shape != weights.shape:
        if axis is None:
            raise TypeError(
                "Axis must be specified when shapes of a and weights differ."
            )
        if ndim(weights) != 1:
            raise TypeError("1D weights expected when shapes of a and weights differ.")
        if weights.shape[0] != a.shape[axis]:
            raise ValueError("Length of weights not compatible with specified axis.")
        # Lift the 1-D weights to the rank of ``a`` and line them up with
        # ``axis``, then expand them to the full shape of ``a``.
        weights = broadcast_to(weights, (1,) * (ndim(a) - 1) + weights.shape)
        weights = swapaxes(weights, -1, axis)
        weights = weights * ones(a.shape)

    sum_of_weights = sum(weights, axis=axis)
    # ``any`` / ``sum`` are this module's multimethods, not the builtins.
    if any(sum_of_weights == 0):
        raise ZeroDivisionError("Weights sum to zero, can't be normalized.")

    avg = sum(a * weights, axis=axis) / sum_of_weights
    if returned:
        return avg, sum_of_weights
    return avg
@create_numpy(_self_weights_argreplacer, default=_average_default)
@all_of_type(ndarray)
def average(a, axis=None, weights=None, returned=False):
    """Multimethod signature for ``average``; dispatches on ``a`` and ``weights``."""
    dispatchables = (a, weights)
    return dispatchables
def _axis_size(axis, shape):
    """Return the product of ``shape`` over the given axes.

    ``axis`` may be ``None`` (all axes), a single index, or a sequence of
    indices.  An empty selection yields 1.
    """
    if axis is None:
        selected = tuple(range(len(shape)))
    elif isinstance(axis, collections.abc.Sequence):
        selected = axis
    else:
        selected = (axis,)
    extents = (shape[ax] for ax in selected)
    return functools.reduce(operator.mul, extents, 1)
def _mean_default(a, axis=None, dtype=None, out=None, keepdims=False):
    """Fallback for ``mean``: sum over ``axis`` divided by the element count.

    When ``dtype`` is not given, signed-integer input (dtype kind ``"i"``)
    is accumulated as ``float``; any other input keeps its own dtype.

    Raises:
        ValueError: if ``out`` is given and its shape differs from the result.
    """
    if dtype is None:
        # ``kind`` is the one-character type code; ``type`` is a class and
        # would never compare equal to ``"i"``.
        if a.dtype.kind == "i":
            dtype = float
        else:
            dtype = a.dtype

    N = _axis_size(axis, a.shape)
    # ``sum`` is this module's multimethod, not the builtin.
    a = sum(a, axis=axis, dtype=dtype, keepdims=keepdims) / N

    if out is None:
        return a
    if a.shape != out.shape:
        raise ValueError("out parameter must have the same shape as the output")
    copyto(out, a, casting="unsafe")
    return out
@create_numpy(_self_dtype_argreplacer, default=_mean_default)
@all_of_type(ndarray)
def mean(a, axis=None, dtype=None, out=None, keepdims=False):
    """Multimethod signature for ``mean``; dispatches on ``a``, ``dtype`` and ``out``."""
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
def _std_default(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Fallback for ``std``: the square root of ``var``.

    When ``out`` is supplied the square root is written back into it, so the
    caller's buffer ends up holding the standard deviation rather than the
    variance.
    """
    variance = var(a, axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims)  # type: ignore
    if out is None:
        return sqrt(variance)
    return sqrt(variance, out=out)


@create_numpy(_self_dtype_argreplacer, default=_std_default)
@all_of_type(ndarray)
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Multimethod signature for ``std``; dispatches on ``a``, ``dtype`` and ``out``."""
    return (a, mark_dtype(dtype), mark_non_coercible(out))
def _var_default(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Fallback for ``var`` computed as ``E[a**2] - E[a]**2``.

    The result is rescaled by ``N / (N - ddof)`` where ``N`` is the number
    of elements along ``axis``.  When ``dtype`` is not given,
    signed-integer input (dtype kind ``"i"``) is accumulated as ``float``.

    Raises:
        ValueError: if ``out`` is given and its shape differs from the result.
    """
    if dtype is None:
        # ``kind`` is the one-character type code; ``type`` is a class and
        # would never compare equal to ``"i"``.
        if a.dtype.kind == "i":
            dtype = float
        else:
            dtype = a.dtype

    N = _axis_size(axis, a.shape)
    x = mean(a ** 2, axis=axis, dtype=dtype, keepdims=keepdims)
    y = mean(a, axis=axis, dtype=dtype, keepdims=keepdims)
    a = (x - y ** 2) * (N / (N - ddof))

    if out is None:
        return a
    if a.shape != out.shape:
        raise ValueError("out parameter must have the same shape as the output")
    copyto(out, a, casting="unsafe")
    return out
@create_numpy(_self_dtype_argreplacer, default=_var_default)
@all_of_type(ndarray)
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Multimethod signature for ``var``; dispatches on ``a``, ``dtype`` and ``out``."""
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
@create_numpy(_self_argreplacer)
@all_of_type(ndarray)
def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False):
    """Multimethod signature for ``nanmedian``; dispatches on ``a`` and ``out``."""
    out_marker = mark_non_coercible(out)
    return (a, out_marker)
def _nanmean_default(a, axis=None, dtype=None, out=None, keepdims=False):
    """Fallback for ``nanmean``: ``nansum`` divided by the count of non-NaN entries.

    When ``dtype`` is not given, input of dtype kind ``"i"`` is accumulated
    as ``float``; anything else keeps its own dtype.

    Raises:
        ValueError: if ``out`` is given and its shape differs from the result.
    """
    if dtype is None:
        dtype = float if a.dtype.kind == "i" else a.dtype

    # ``sum`` is this module's multimethod, not the builtin.
    valid_count = sum(~isnan(a), axis=axis, keepdims=keepdims)
    result = nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / valid_count

    if out is not None:
        if result.shape != out.shape:
            raise ValueError("out parameter must have the same shape as the output")
        copyto(out, result, casting="unsafe")
        return out
    return result
@create_numpy(_self_dtype_argreplacer, default=_nanmean_default)
@all_of_type(ndarray)
def nanmean(a, axis=None, dtype=None, out=None, keepdims=False):
    """Multimethod signature for ``nanmean``; dispatches on ``a``, ``dtype`` and ``out``."""
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
def _nanstd_default(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Fallback for ``nanstd``: the square root of ``nanvar``.

    When ``out`` is supplied the square root is written back into it, so the
    caller's buffer ends up holding the standard deviation rather than the
    variance.
    """
    variance = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims)  # type: ignore
    if out is None:
        return sqrt(variance)
    return sqrt(variance, out=out)


@create_numpy(_self_dtype_argreplacer, default=_nanstd_default)
@all_of_type(ndarray)
def nanstd(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Multimethod signature for ``nanstd``; dispatches on ``a``, ``dtype`` and ``out``."""
    return (a, mark_dtype(dtype), mark_non_coercible(out))
def _nanvar_default(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Fallback for ``nanvar`` computed as ``E[a**2] - E[a]**2`` over non-NaN entries.

    ``N`` is the per-output count of non-NaN entries and the result is
    rescaled by ``N / (N - ddof)``.  When ``dtype`` is not given, input of
    dtype kind ``"i"`` is accumulated as ``float``.

    Raises:
        ValueError: if ``out`` is given and its shape differs from the result.
    """
    if dtype is None:
        dtype = float if a.dtype.kind == "i" else a.dtype

    # ``sum`` is this module's multimethod, not the builtin.
    valid_count = sum(~isnan(a), axis=axis, keepdims=keepdims)
    mean_of_squares = (
        nansum(a ** 2, axis=axis, dtype=dtype, keepdims=keepdims) / valid_count
    )
    plain_mean = nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / valid_count
    scale = valid_count / (valid_count - ddof)
    result = (mean_of_squares - plain_mean ** 2) * scale

    if out is not None:
        if result.shape != out.shape:
            raise ValueError("out parameter must have the same shape as the output")
        copyto(out, result, casting="unsafe")
        return out
    return result
@create_numpy(_self_dtype_argreplacer, default=_nanvar_default)
@all_of_type(ndarray)
def nanvar(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Multimethod signature for ``nanvar``; dispatches on ``a``, ``dtype`` and ``out``."""
    dtype_marker = mark_dtype(dtype)
    out_marker = mark_non_coercible(out)
    return (a, dtype_marker, out_marker)
def _corrcoef_argreplacer(args, kwargs, dispatchables):
    """Rebuild a ``corrcoef`` call from dispatchables ``(x, y)``.

    ``x`` is passed positionally; ``y`` and the remaining options are always
    spelled out as keywords.
    """

    def _rebuild(x, y=None, rowvar=True, bias=None, ddof=None):
        positional = (dispatchables[0],)
        keywords = {
            "y": dispatchables[1],
            "rowvar": rowvar,
            "bias": bias,
            "ddof": ddof,
        }
        return positional, keywords

    return _rebuild(*args, **kwargs)
@create_numpy(_corrcoef_argreplacer)
@all_of_type(ndarray)
def corrcoef(x, y=None, rowvar=True, bias=None, ddof=None):
    """Multimethod signature for ``corrcoef``; dispatches on ``x`` and ``y``."""
    dispatchables = (x, y)
    return dispatchables
@create_numpy(_first2argreplacer)
@all_of_type(ndarray)
def correlate(a, v, mode="valid"):
    """Multimethod signature for ``correlate``; dispatches on ``a`` and ``v``."""
    dispatchables = (a, v)
    return dispatchables
def _cov_argreplacer(args, kwargs, dispatchables):
    """Rebuild a ``cov``-style call from dispatchables ``(m, y, fweights, aweights)``.

    ``m`` is passed positionally; every other parameter is spelled out as a
    keyword, with ``y``, ``fweights`` and ``aweights`` taken from
    ``dispatchables``.
    """

    def _rebuild(
        m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None
    ):
        positional = (dispatchables[0],)
        keywords = {
            "y": dispatchables[1],
            "rowvar": rowvar,
            "bias": bias,
            "ddof": ddof,
            "fweights": dispatchables[2],
            "aweights": dispatchables[3],
        }
        return positional, keywords

    return _rebuild(*args, **kwargs)