-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
topk and argtopk #10086
base: main
Are you sure you want to change the base?
topk and argtopk #10086
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -27,7 +27,12 @@ | |||||
from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils | ||||||
from xarray.core.array_api_compat import get_array_namespace | ||||||
from xarray.core.options import OPTIONS | ||||||
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available | ||||||
from xarray.core.utils import ( | ||||||
is_duck_array, | ||||||
is_duck_dask_array, | ||||||
module_available, | ||||||
to_0d_object_array, | ||||||
) | ||||||
from xarray.namedarray.parallelcompat import get_chunked_array_type | ||||||
from xarray.namedarray.pycompat import array_type, is_chunked_array | ||||||
|
||||||
|
@@ -229,7 +234,7 @@ | |||||
xp = get_array_namespace(data) | ||||||
if xp == np: | ||||||
# numpy currently doesn't have a astype: | ||||||
return data.astype(dtype, **kwargs) | ||||||
Check warning on line 237 in xarray/core/duck_array_ops.py
|
||||||
return xp.astype(data, dtype, **kwargs) | ||||||
return data.astype(dtype, **kwargs) | ||||||
|
||||||
|
@@ -875,3 +880,74 @@ | |||||
|
||||||
def chunked_nanlast(darray, axis): | ||||||
return _chunked_first_or_last(darray, axis, op=nputils.nanlast) | ||||||
|
||||||
|
||||||
def argtopk(values, k, axis=None, skipna=None):
    """Indices of the k largest (or, for negative k, the -k smallest)
    elements of ``values`` along ``axis``.

    NaNs are skipped when ``skipna`` is True, or by default for float,
    complex and object dtypes.  Raises ValueError if any slice is all-NaN.
    """
    func = dask_array_ops.argtopk if is_chunked_array(values) else nputils.argtopk

    # Borrowed from nanops
    xp = get_array_namespace(values)
    skip = skipna
    if skip is None:
        # Only these dtypes can actually hold NaN/missing values.
        skip = dtypes.isdtype(
            values.dtype, ("complex floating", "real floating"), xp=xp
        ) or dtypes.is_object(values.dtype)
    if not skip:
        return func(values, k=k, axis=axis)

    valid_count = count(values, axis=axis)

    # Replace NaNs with a sentinel that can never win the selection:
    # +inf when taking the smallest, -inf when taking the largest.
    if k < 0:
        fill_value = dtypes.get_pos_infinity(values.dtype)
    else:
        fill_value = dtypes.get_neg_infinity(values.dtype)

    data = func(fillna(values, fill_value), k=k, axis=axis)

    # TODO This will evaluate dask arrays and might be costly.
    if array_any(valid_count == 0):
        raise ValueError("All-NaN slice encountered")
    return data
|
||||||
|
||||||
def topk(values, k, axis=None, skipna=None):
    """The k largest (or, for negative k, the -k smallest) elements of
    ``values`` along ``axis``.

    NaNs are skipped when ``skipna`` is True, or by default for float,
    complex and object dtypes.  All-NaN slices yield NaN in the result.
    """
    func = dask_array_ops.topk if is_chunked_array(values) else nputils.topk

    # Borrowed from nanops
    xp = get_array_namespace(values)
    skip = skipna
    if skip is None:
        # Only these dtypes can actually hold NaN/missing values.
        skip = dtypes.isdtype(
            values.dtype, ("complex floating", "real floating"), xp=xp
        ) or dtypes.is_object(values.dtype)
    if not skip:
        return func(values, k=k, axis=axis)

    valid_count = count(values, axis=axis)

    # Replace NaNs with a sentinel that can never win the selection:
    # +inf when taking the smallest, -inf when taking the largest.
    if k < 0:
        fill_value = dtypes.get_pos_infinity(values.dtype)
    else:
        fill_value = dtypes.get_neg_infinity(values.dtype)

    data = func(fillna(values, fill_value), k=k, axis=axis)

    if not hasattr(data, "dtype"):  # scalar case
        data = fill_value if valid_count == 0 else data
        # we've computed a single min, max value of type object.
        # don't let np.array turn a tuple back into an array
        return to_0d_object_array(data)

    return where_method(data, valid_count != 0)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -302,6 +302,59 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): | |
return coeffs, residuals | ||
|
||
|
||
def topk(values, k: int, axis: int):
    """Return the k largest elements of ``values`` along ``axis``.

    If ``k`` is negative, return the ``-k`` smallest elements instead.
    The result is sorted: descending for the largest, ascending for the
    smallest.
    """
    ndim = values.ndim
    if axis < 0:
        axis += ndim

    def _along_axis(sl):
        # Indexer that applies ``sl`` on ``axis`` and keeps other axes whole.
        return tuple(sl if i == axis else slice(None) for i in range(ndim))

    if abs(k) >= values.shape[axis]:
        # Asking for at least as many elements as exist: plain full sort.
        selected = np.sort(values, axis=axis)
    else:
        # Partition so the k largest sit at the end (or the -k smallest at
        # the front), slice them out, then sort just that slab in place.
        partitioned = np.partition(values, -k, axis=axis)
        wanted = slice(-k, None) if k > 0 else slice(-k)
        selected = partitioned[_along_axis(wanted)]
        selected.sort(axis=axis)

    if k < 0:
        return selected
    # Positive k: present the largest first.
    return selected[_along_axis(slice(None, None, -1))]
|
||
|
||
def argtopk(values, k: int, axis: int):
    """Extract the indices of the k largest elements from values on the given axis.

    If k is negative, extract the indices of the -k smallest elements instead.
    The returned indices order the selected elements: descending values for
    positive k, ascending values for negative k.

    Parameters
    ----------
    values : np.ndarray
        Array to select from.
    k : int
        Number of elements to index; sign chooses largest (+) vs smallest (-).
    axis : int
        Axis along which to select; may be negative.

    Returns
    -------
    np.ndarray
        Integer indices into ``values`` along ``axis``.
    """
    if axis < 0:
        axis = values.ndim + axis

    if abs(k) >= values.shape[axis]:
        # Requested at least the whole axis: a full argsort suffices.
        idx3 = np.argsort(values, axis=axis)
    else:
        idx = np.argpartition(values, -k, axis=axis)
        k_slice = slice(-k, None) if k > 0 else slice(-k)
        idx = idx[
            tuple(k_slice if i == axis else slice(None) for i in range(values.ndim))
        ]
        # Order the selected indices by their values along the axis.
        a = np.take_along_axis(values, idx, axis)
        idx2 = np.argsort(a, axis=axis)
        idx3 = np.take_along_axis(idx, idx2, axis)
    if k < 0:
        return idx3
    # BUG FIX: this reverse previously iterated ``range(idx.ndim)``, but
    # ``idx`` is unbound when the full-argsort branch above ran, raising
    # NameError for k > 0 with abs(k) >= shape[axis].  ``values.ndim``
    # equals the result's ndim in both branches.
    return idx3[
        tuple(
            slice(None, None, -1) if i == axis else slice(None)
            for i in range(values.ndim)
        )
    ]
|
||
|
||
nanmin = _create_method("nanmin") | ||
nanmax = _create_method("nanmax") | ||
nanmean = _create_method("nanmean") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2511,6 +2511,146 @@ def argmax( | |
""" | ||
return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) | ||
|
||
def _topk_stack(
    self,
    topk_funcname: str,
    dim: Dims,
) -> Variable:
    """Stack the dimensions in ``dim`` into one temporary dimension.

    The temporary dimension's name is chosen so it cannot collide with
    any existing dimension of this variable.
    """
    suffix = 0
    newdimname = f"_unravel_{topk_funcname}_dim_{suffix}"
    while newdimname in self.dims:
        suffix += 1
        newdimname = f"_unravel_{topk_funcname}_dim_{suffix}"
    return self.stack({newdimname: dim})
|
||
def _topk_helper(
    self,
    topk_funcname: str,
    k: int,
    dim: str,
    dtype: Any,
    keep_attrs: bool | None = None,
    skipna: bool | None = None,
) -> Variable:
    """Apply ``duck_array_ops.<topk_funcname>`` along ``dim`` via apply_ufunc.

    The reduced dimension is replaced by a new output dimension named after
    ``topk_funcname``.
    """
    from xarray.core.computation import apply_ufunc

    func = getattr(duck_array_ops, topk_funcname)
    # apply_ufunc moves the dimension to the back.
    func_kwargs = {"k": k, "axis": -1, "skipna": skipna}

    # NOTE(review): output_sizes uses ``k`` directly — presumably abs(k) is
    # the intended size when k is negative; confirm for chunked inputs.
    result = apply_ufunc(
        func,
        self,
        input_core_dims=[[dim]],
        exclude_dims={dim},
        output_core_dims=[[topk_funcname]],
        output_dtypes=[dtype],
        dask_gufunc_kwargs=dict(output_sizes={topk_funcname: k}),
        dask="allowed",
        kwargs=func_kwargs,
    )

    if keep_attrs is None:
        keep_attrs = _get_keep_attrs(default=False)
    if keep_attrs:
        result.attrs = self._attrs
    return result
|
||
def topk(
    self,
    k: int,
    dim: Dims = None,
    keep_attrs: bool | None = None,
    skipna: bool | None = None,
) -> Variable | dict[Hashable, Variable]:
    """Return the k largest elements along ``dim``.

    If ``k`` is negative, the ``-k`` smallest elements are returned
    instead.  When ``dim`` is None or ``...`` and the variable is
    multi-dimensional, all dimensions are stacked and reduced together.
    """
    # topk accepts only an integer axis like argmin or argmax,
    # not tuples, so we need to stack multiple dimensions.
    if dim is None or dim is ...:
        # Reduce over the single dimension for 1D data, otherwise over all.
        dim = self.dims[0] if self.ndim == 1 else self.dims

    if isinstance(dim, str):
        stacked = self
        target_dim = dim
    else:
        stacked = self._topk_stack("topk", dim)
        target_dim = stacked.dims[-1]

    return stacked._topk_helper(
        "topk",
        k=k,
        dim=target_dim,
        dtype=self.dtype,
        keep_attrs=keep_attrs,
        skipna=skipna,
    )
|
||
def argtopk(
    self,
    k: int,
    dim: Dims = None,
    keep_attrs: bool | None = None,
    skipna: bool | None = None,
) -> Variable | dict[Hashable, Variable]:
    """Return indices of the k largest elements along ``dim``.

    If ``k`` is negative, indices of the ``-k`` smallest elements are
    returned instead.  For a single string dimension a Variable of
    indices is returned; for multiple dimensions, a dict mapping each
    reduced dimension to its unravelled index Variable.
    """
    # argtopk accepts only an integer axis like argmin or argmax,
    # not tuples, so we need to stack multiple dimensions.
    if dim is None or dim is ...:
        # Reduce over the single dimension for 1D data, otherwise over all.
        dim = self.dims[0] if self.ndim == 1 else self.dims

    if isinstance(dim, str):
        # Single named dimension: indices refer to it directly.
        return self._topk_helper(
            "argtopk",
            k=k,
            dim=dim,
            dtype=np.intp,
            keep_attrs=keep_attrs,
            skipna=skipna,
        )

    stacked = self._topk_stack("topk", dim)
    flat_dim = stacked.dims[-1]

    flat_indices = stacked._topk_helper(
        "argtopk",
        k=k,
        dim=flat_dim,
        dtype=np.intp,
        keep_attrs=keep_attrs,
        skipna=skipna,
    )

    # Translate flat indices into one index array per reduced dimension.
    reduce_shape = tuple(self.sizes[d] for d in dim)
    unravelled = duck_array_ops.unravel_index(flat_indices.data, reduce_shape)

    result_dims = [d for d in stacked.dims if d != flat_dim] + ["argtopk"]
    result = {
        d: Variable(dims=result_dims, data=i)
        for d, i in zip(dim, unravelled, strict=True)
    }

    if keep_attrs is None:
        keep_attrs = _get_keep_attrs(default=False)
    if keep_attrs:
        for v in result.values():
            v.attrs = self.attrs

    return result
|
||
def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable: | ||
""" | ||
Use sparse-array as backend. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
generate_aggregations.py
would make sense here so that we add it everywhere with the same docstring. That would get us groupby support for example, and I can eventually plug in flox when https://github.com/xarray-contrib/flox/pull/374/files is ready