Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

topk and argtopk #10086

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
@@ -134,3 +134,15 @@ def push(array, n, axis, method="blelloch"):
pushed_array = da.where(valid_positions, pushed_array, np.nan)

return pushed_array


def topk(a, k, axis):
    """Chunk-aware top-k: delegate to ``dask.array.topk``.

    Per the dask docs, a negative ``k`` returns the ``-k`` smallest
    elements instead of the ``k`` largest.
    """
    from dask import array as da

    return da.topk(a, k=k, axis=axis)


def argtopk(a, k, axis):
    """Chunk-aware argtopk: delegate to ``dask.array.argtopk``.

    Per the dask docs, a negative ``k`` yields the indices of the ``-k``
    smallest elements instead of the ``k`` largest.
    """
    from dask import array as da

    return da.argtopk(a, k=k, axis=axis)
31 changes: 31 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -6288,6 +6288,37 @@ def argmax(
else:
return self._replace_maybe_drop_dims(result)

def argtopk(
    self,
    k: int,
    dim: Dims = None,
    *,
    keep_attrs: bool | None = None,
    skipna: bool | None = None,
) -> Self | dict[Hashable, Self]:
    """Indices of the ``k`` top values along ``dim``.

    Thin wrapper that delegates to ``Variable.argtopk``. When a single
    dimension is reduced a DataArray is returned; when several are
    reduced, a dict mapping each reduced dimension name to its index
    array is returned instead.
    """
    result = self.variable.argtopk(k, dim, keep_attrs, skipna)
    if not isinstance(result, dict):
        return self._replace_maybe_drop_dims(result)
    # Multi-dimension case: wrap each per-dimension index Variable.
    return {
        name: self._replace_maybe_drop_dims(indices)
        for name, indices in result.items()
    }

def topk(
    self,
    k: int,
    dim: Dims = None,
    *,
    keep_attrs: bool | None = None,
    skipna: bool | None = None,
) -> Self:
    """The ``k`` top values along ``dim``.

    Thin wrapper that delegates to ``Variable.topk`` and rewraps the
    result as a DataArray.
    """
    return self._replace_maybe_drop_dims(
        self.variable.topk(k, dim, keep_attrs, skipna)
    )

def query(
self,
queries: Mapping[Any, Any] | None = None,
20 changes: 20 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
@@ -9727,6 +9727,26 @@ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self:
"Dataset.argmin() with a sequence or ... for dim"
)

def argtopk(self, k: int, dim: Hashable | None = None, **kwargs) -> Self:
    """Indices of the ``k`` top values over ``dim``, applied per variable.

    Delegates to ``duck_array_ops.argtopk`` via
    ``_apply_over_vars_with_dim``.
    """
    # NOTE(review): **kwargs is accepted but never forwarded — confirm intent.
    from xarray.core.missing import _apply_over_vars_with_dim

    return _apply_over_vars_with_dim(
        duck_array_ops.argtopk, self, dim=dim, k=k
    )

def topk(self, k: int, dim: Hashable | None = None, **kwargs) -> Self:
    """The ``k`` top values over ``dim``, applied per variable.

    Delegates to ``duck_array_ops.topk`` via
    ``_apply_over_vars_with_dim``.
    """
    # NOTE(review): **kwargs is accepted but never forwarded — confirm intent.
    from xarray.core.missing import _apply_over_vars_with_dim

    return _apply_over_vars_with_dim(
        duck_array_ops.topk, self, dim=dim, k=k
    )

def eval(
self,
statement: str,
78 changes: 77 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
@@ -27,7 +27,12 @@
from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils
from xarray.core.array_api_compat import get_array_namespace
from xarray.core.options import OPTIONS
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
from xarray.core.utils import (
is_duck_array,
is_duck_dask_array,
module_available,
to_0d_object_array,
)
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import array_type, is_chunked_array

@@ -229,7 +234,7 @@
xp = get_array_namespace(data)
if xp == np:
# numpy currently doesn't have a astype:
return data.astype(dtype, **kwargs)

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.13

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.13

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.13

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.13

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13 all-but-numba

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13 all-but-numba

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.10

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.10

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / ubuntu-latest py3.13

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / windows-latest py3.10

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / windows-latest py3.10

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / windows-latest py3.13

invalid value encountered in cast

Check warning on line 237 in xarray/core/duck_array_ops.py

GitHub Actions / windows-latest py3.13

invalid value encountered in cast
return xp.astype(data, dtype, **kwargs)
return data.astype(dtype, **kwargs)

@@ -875,3 +880,74 @@

def chunked_nanlast(darray, axis):
return _chunked_first_or_last(darray, axis, op=nputils.nanlast)


def argtopk(values, k, axis=None, skipna=None):
    """Indices of the ``k`` largest entries of ``values`` along ``axis``.

    A negative ``k`` selects the indices of the ``-k`` smallest entries
    instead. When ``skipna`` is true — or is None and the dtype is
    float/complex/object — NaNs are excluded by filling them with an
    infinity that can never be selected; an all-NaN slice raises.

    Raises
    ------
    ValueError
        If skipna handling is active and any slice contains only NaNs.
    """
    # Dispatch on backend. Use is_duck_dask_array (not is_chunked_array):
    # dask_array_ops.argtopk wraps dask.array.argtopk, which only supports
    # genuine dask arrays, not other chunked array types.
    if is_duck_dask_array(values):
        func = dask_array_ops.argtopk
    else:
        func = nputils.argtopk

    # Borrowed from nanops
    xp = get_array_namespace(values)
    if skipna or (
        skipna is None
        and (
            dtypes.isdtype(values.dtype, ("complex floating", "real floating"), xp=xp)
            or dtypes.is_object(values.dtype)
        )
    ):
        valid_count = count(values, axis=axis)

        # Fill NaNs with a value that cannot be selected: +inf when we
        # want the smallest entries (k < 0), -inf for the largest.
        if k < 0:
            fill_value = dtypes.get_pos_infinity(values.dtype)
        else:
            fill_value = dtypes.get_neg_infinity(values.dtype)

        filled_values = fillna(values, fill_value)
    else:
        return func(values, k=k, axis=axis)

    data = func(filled_values, k=k, axis=axis)

    # TODO This will evaluate dask arrays and might be costly.
    if array_any(valid_count == 0):
        raise ValueError("All-NaN slice encountered")
    return data


def topk(values, k, axis=None, skipna=None):
    """The ``k`` largest entries of ``values`` along ``axis``.

    A negative ``k`` returns the ``-k`` smallest entries instead. NaN
    handling mirrors ``argtopk``: when ``skipna`` is true — or is None
    and the dtype is float/complex/object — NaNs are filled with an
    infinity that can never be selected, and positions whose slice was
    all-NaN are masked in the result.
    """
    # Dispatch on backend. Use is_duck_dask_array (not is_chunked_array):
    # dask_array_ops.topk wraps dask.array.topk, which only supports
    # genuine dask arrays, not other chunked array types.
    if is_duck_dask_array(values):
        func = dask_array_ops.topk
    else:
        func = nputils.topk

    # Borrowed from nanops
    xp = get_array_namespace(values)
    if skipna or (
        skipna is None
        and (
            dtypes.isdtype(values.dtype, ("complex floating", "real floating"), xp=xp)
            or dtypes.is_object(values.dtype)
        )
    ):
        valid_count = count(values, axis=axis)

        # Fill NaNs with a value that cannot be selected: +inf when we
        # want the smallest entries (k < 0), -inf for the largest.
        if k < 0:
            fill_value = dtypes.get_pos_infinity(values.dtype)
        else:
            fill_value = dtypes.get_neg_infinity(values.dtype)

        filled_values = fillna(values, fill_value)
    else:
        return func(values, k=k, axis=axis)

    data = func(filled_values, k=k, axis=axis)

    if not hasattr(data, "dtype"):  # scalar case
        data = fill_value if valid_count == 0 else data
        # we've computed a single min, max value of type object.
        # don't let np.array turn a tuple back into an array
        return to_0d_object_array(data)

    return where_method(data, valid_count != 0)
53 changes: 53 additions & 0 deletions xarray/core/nputils.py
Original file line number Diff line number Diff line change
@@ -302,6 +302,59 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
return coeffs, residuals


def topk(values, k: int, axis: int):
    """Return the ``k`` largest elements of ``values`` along ``axis``.

    The result is sorted in descending order. If ``k`` is negative, the
    ``-k`` smallest elements are returned instead, sorted ascending. If
    ``abs(k)`` is at least the axis length, every element is returned
    (fully sorted).
    """
    ndim = values.ndim
    if axis < 0:
        axis += ndim

    def _along_axis(sl):
        # Build an indexer applying ``sl`` on ``axis`` only.
        return tuple(sl if i == axis else slice(None) for i in range(ndim))

    if abs(k) >= values.shape[axis]:
        # More elements requested than exist: just sort everything.
        ordered = np.sort(values, axis=axis)
    else:
        # Partition so the k extreme elements land at one end, slice them
        # out, then sort only that small window.
        partitioned = np.partition(values, -k, axis=axis)
        window = slice(-k, None) if k > 0 else slice(-k)
        ordered = partitioned[_along_axis(window)]
        ordered.sort(axis=axis)

    if k < 0:
        return ordered
    # Largest first for positive k.
    return ordered[_along_axis(slice(None, None, -1))]


def argtopk(values, k: int, axis: int):
    """Indices of the ``k`` largest elements of ``values`` along ``axis``.

    Indices are ordered so the largest element comes first. If ``k`` is
    negative, the indices of the ``-k`` smallest elements are returned
    instead (smallest first). If ``abs(k)`` is at least the axis length,
    indices for every element are returned (full argsort).
    """
    if axis < 0:
        axis = values.ndim + axis

    if abs(k) >= values.shape[axis]:
        # More elements requested than exist: full argsort, no partition.
        idx3 = np.argsort(values, axis=axis)
    else:
        # Partition the indices, slice the k extreme positions, then sort
        # just that window by the values it points at.
        idx = np.argpartition(values, -k, axis=axis)
        k_slice = slice(-k, None) if k > 0 else slice(-k)
        idx = idx[
            tuple(k_slice if i == axis else slice(None) for i in range(values.ndim))
        ]
        a = np.take_along_axis(values, idx, axis)
        idx2 = np.argsort(a, axis=axis)
        idx3 = np.take_along_axis(idx, idx2, axis)
    if k < 0:
        return idx3
    # BUG FIX: this must use ``idx3.ndim`` — ``idx`` is undefined when the
    # full-argsort branch above was taken, which previously raised
    # NameError for k > 0 with k >= values.shape[axis].
    return idx3[
        tuple(
            slice(None, None, -1) if i == axis else slice(None)
            for i in range(idx3.ndim)
        )
    ]


nanmin = _create_method("nanmin")
nanmax = _create_method("nanmax")
nanmean = _create_method("nanmean")
140 changes: 140 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
@@ -2511,6 +2511,146 @@ def argmax(
"""
return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna)

def _topk_stack(
    self,
    topk_funcname: str,
    dim: Dims,
) -> Variable:
    """Stack ``dim`` into one new dimension whose name does not collide
    with any existing dimension of this variable."""
    base = f"_unravel_{topk_funcname}_dim"
    suffix = 0
    candidate = f"{base}_{suffix}"
    # Bump the numeric suffix until the name is free.
    while candidate in self.dims:
        suffix += 1
        candidate = f"{base}_{suffix}"
    return self.stack({candidate: dim})

def _topk_helper(
    self,
    topk_funcname: str,
    k: int,
    dim: str,
    dtype: Any,
    keep_attrs: bool | None = None,
    skipna: bool | None = None,
) -> Variable:
    """Apply ``duck_array_ops.<topk_funcname>`` over a single dimension.

    Parameters
    ----------
    topk_funcname : str
        Name of the reduction ("topk" or "argtopk"), looked up on
        ``duck_array_ops``; also used as the name of the new output
        dimension.
    k : int
        Number of elements to select along ``dim``.
    dim : str
        Single dimension to reduce.
    dtype : Any
        Output dtype (e.g. ``np.intp`` for argtopk).
    keep_attrs : bool or None, optional
        Copy this variable's attrs onto the result; None falls back to
        the global option (default False).
    skipna : bool or None, optional
        Forwarded to the duck-array implementation.
    """
    from xarray.core.computation import apply_ufunc

    topk_func = getattr(duck_array_ops, topk_funcname)
    # apply_ufunc moves the dimension to the back.
    kwargs = {"k": k, "axis": -1, "skipna": skipna}

    result = apply_ufunc(
        topk_func,
        self,
        input_core_dims=[[dim]],
        # ``dim`` is replaced, not preserved, so it must be excluded.
        exclude_dims={dim},
        output_core_dims=[[topk_funcname]],
        output_dtypes=[dtype],
        # NOTE(review): output size is given as ``k`` directly; for
        # negative ``k`` (smallest elements) the true size is ``abs(k)``
        # — confirm against the dask code path.
        dask_gufunc_kwargs=dict(output_sizes={topk_funcname: k}),
        dask="allowed",
        kwargs=kwargs,
    )

    keep_attrs_ = (
        _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs
    )

    if keep_attrs_:
        result.attrs = self._attrs
    return result

def topk(
    self,
    k: int,
    dim: Dims = None,
    keep_attrs: bool | None = None,
    skipna: bool | None = None,
) -> Variable | dict[Hashable, Variable]:
    """The ``k`` top values along ``dim``.

    Like argmin/argmax, the underlying implementation accepts a single
    integer axis only, so reducing over several dimensions stacks them
    into one temporary dimension first.
    """
    if dim is None or dim is ...:
        # For 1D data, default to that single dimension; otherwise
        # reduce over all dimensions.
        dim = self.dims[0] if self.ndim == 1 else self.dims

    if isinstance(dim, str):
        stacked = self
    else:
        stacked = self._topk_stack("topk", dim)
        dim = stacked.dims[-1]

    return stacked._topk_helper(
        "topk", k=k, dim=dim, dtype=self.dtype, keep_attrs=keep_attrs, skipna=skipna
    )

def argtopk(
    self,
    k: int,
    dim: Dims = None,
    keep_attrs: bool | None = None,
    skipna: bool | None = None,
) -> Variable | dict[Hashable, Variable]:
    """Indices of the ``k`` top values along the given dimension(s).

    Parameters
    ----------
    k : int
        Number of elements to select (negative ``k`` selects the
        smallest — see ``duck_array_ops.argtopk``).
    dim : str, sequence of Hashable, None or ..., optional
        Dimension(s) to reduce. A single name returns one Variable of
        indices; several names return a dict mapping each reduced
        dimension to its unravelled indices.
    keep_attrs : bool or None, optional
        Copy this variable's attrs onto the result(s).
    skipna : bool or None, optional
        Forwarded to the duck-array implementation.
    """
    # argtopk accepts only an integer axis like argmin or argmax,
    # not tuples, so we need to stack multiple dimensions.
    if dim is ... or dim is None:
        # Return dimension for 1D data.
        if self.ndim == 1:
            dim = self.dims[0]
        else:
            dim = self.dims

    if isinstance(dim, str):
        return self._topk_helper(
            "argtopk",
            k=k,
            dim=dim,
            dtype=np.intp,
            keep_attrs=keep_attrs,
            skipna=skipna,
        )

    # Consistency fix: name the temporary stacked dimension after
    # "argtopk" (it previously said "topk"); the name is internal and
    # dropped from the result, so this is externally invisible.
    stacked = self._topk_stack("argtopk", dim)
    newdimname = stacked.dims[-1]

    result_flat_indices = stacked._topk_helper(
        "argtopk",
        k=k,
        dim=newdimname,
        dtype=np.intp,
        keep_attrs=keep_attrs,
        skipna=skipna,
    )

    # Convert flat indices over the stacked dimension back into one
    # index array per original dimension.
    reduce_shape = tuple(self.sizes[d] for d in dim)

    result_unravelled_indices = duck_array_ops.unravel_index(
        result_flat_indices.data, reduce_shape
    )

    result_dims = [d for d in stacked.dims if d != newdimname] + ["argtopk"]
    result = {
        d: Variable(dims=result_dims, data=i)
        for d, i in zip(dim, result_unravelled_indices, strict=True)
    }

    if keep_attrs is None:
        keep_attrs = _get_keep_attrs(default=False)
    if keep_attrs:
        for v in result.values():
            v.attrs = self.attrs

    return result

def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable:
"""
Use sparse-array as backend.
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.