Skip to content

Commit

Permalink
Support first, last with datetime, timedelta (#402)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian authored Nov 5, 2024
1 parent 672be8c commit 4f6164f
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 27 deletions.
4 changes: 3 additions & 1 deletion flox/aggregate_numbagg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
"nanmean": {np.int_: np.float64},
"nanvar": {np.int_: np.float64},
"nanstd": {np.int_: np.float64},
"nanfirst": {np.datetime64: np.int64, np.timedelta64: np.int64},
"nanlast": {np.datetime64: np.int64, np.timedelta64: np.int64},
}


Expand All @@ -51,7 +53,7 @@ def _numbagg_wrapper(
if cast_to:
for from_, to_ in cast_to.items():
if np.issubdtype(array.dtype, from_):
array = array.astype(to_)
array = array.astype(to_, copy=False)

func_ = getattr(numbagg.grouped, f"group_{func}")

Expand Down
33 changes: 32 additions & 1 deletion flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
)
from .cache import memoize
from .xrutils import (
_contains_cftime_datetimes,
_datetime_nanmin,
_to_pytimedelta,
datetime_to_numeric,
is_chunked_array,
is_duck_array,
is_duck_cubed_array,
Expand Down Expand Up @@ -2473,7 +2477,8 @@ def groupby_reduce(
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)

if _is_first_last_reduction(func):
is_first_last = _is_first_last_reduction(func)
if is_first_last:
if has_dask and nax != 1:
raise ValueError(
"For dask arrays: first, last, nanfirst, nanlast reductions are "
Expand All @@ -2486,6 +2491,24 @@ def groupby_reduce(
"along a single axis or when reducing across all dimensions of `by`."
)

# Flox's count works with non-numeric and its faster than converting.
is_npdatetime = array.dtype.kind in "Mm"
is_cftime = _contains_cftime_datetimes(array)
requires_numeric = (
(func not in ["count", "any", "all"] and not is_first_last)
or (func == "count" and engine != "flox")
or (is_first_last and is_cftime)
)
if requires_numeric:
if is_npdatetime:
offset = _datetime_nanmin(array)
# xarray always uses np.datetime64[ns] for np.datetime64 data
dtype = "timedelta64[ns]"
array = datetime_to_numeric(array, offset)
elif is_cftime:
offset = array.min()
array = datetime_to_numeric(array, offset, datetime_unit="us")

if nax == 1 and by_.ndim > 1 and expected_ is None:
# When we reduce along all axes, we are guaranteed to see all
# groups in the final combine stage, so everything works.
Expand Down Expand Up @@ -2671,6 +2694,14 @@ def groupby_reduce(

if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)):
result = result.astype(bool)

# Output of count has an int dtype.
if requires_numeric and func != "count":
if is_npdatetime:
return result.astype(dtype) + offset
elif is_cftime:
return _to_pytimedelta(result, unit="us") + offset

return (result, *groups)


Expand Down
25 changes: 0 additions & 25 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pandas as pd
import xarray as xr
from packaging.version import Version
from xarray.core.duck_array_ops import _datetime_nanmin

from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
from .core import (
Expand All @@ -18,7 +17,6 @@
)
from .core import rechunk_for_blockwise as rechunk_array_for_blockwise
from .core import rechunk_for_cohorts as rechunk_array_for_cohorts
from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric

if TYPE_CHECKING:
from xarray.core.types import T_DataArray, T_Dataset
Expand Down Expand Up @@ -366,22 +364,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
if "nan" not in func and func not in ["all", "any", "count"]:
func = f"nan{func}"

# Flox's count works with non-numeric and its faster than converting.
requires_numeric = func not in ["count", "any", "all"] or (
func == "count" and kwargs["engine"] != "flox"
)
if requires_numeric:
is_npdatetime = array.dtype.kind in "Mm"
is_cftime = _contains_cftime_datetimes(array)
if is_npdatetime:
offset = _datetime_nanmin(array)
# xarray always uses np.datetime64[ns] for np.datetime64 data
dtype = "timedelta64[ns]"
array = datetime_to_numeric(array, offset)
elif is_cftime:
offset = array.min()
array = datetime_to_numeric(array, offset, datetime_unit="us")

result, *groups = groupby_reduce(array, *by, func=func, **kwargs)

# Transpose the new quantile dimension to the end. This is ugly.
Expand All @@ -395,13 +377,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
# output dim order: (*broadcast_dims, *group_dims, quantile_dim)
result = np.moveaxis(result, 0, -1)

# Output of count has an int dtype.
if requires_numeric and func != "count":
if is_npdatetime:
return result.astype(dtype) + offset
elif is_cftime:
return _to_pytimedelta(result, unit="us") + offset

return result

# These data variables do not have any of the core dimension,
Expand Down
22 changes: 22 additions & 0 deletions flox/xrutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,28 @@ def _contains_cftime_datetimes(array) -> bool:
return False


def _datetime_nanmin(array):
"""nanmin() function for datetime64.
Caveats that this function deals with:
- In numpy < 1.18, min() on datetime64 incorrectly ignores NaT
- numpy nanmin() don't work on datetime64 (all versions at the moment of writing)
- dask min() does not work on datetime64 (all versions at the moment of writing)
"""
from .xrdtypes import is_datetime_like

dtype = array.dtype
assert is_datetime_like(dtype)
# (NaT).astype(float) does not produce NaN...
array = np.where(pd.isnull(array), np.nan, array.astype(float))
array = min(array, skipna=True)
if isinstance(array, float):
array = np.array(array)
# ...but (NaN).astype("M8") does produce NaT
return array.astype(dtype)


def _select_along_axis(values, idx, axis):
other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
sl = other_ind[:axis] + (idx,) + other_ind[axis:]
Expand Down
16 changes: 16 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2006,3 +2006,19 @@ def test_blockwise_avoid_rechunk():
actual, groups = groupby_reduce(array, by, func="first")
assert_equal(groups, ["", "0", "1"])
assert_equal(actual, np.array([0, 0, 0], dtype=np.int64))


@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
def test_datetime_timedelta_first_last(engine, func):
import flox

idx = 0 if "first" in func else -1

dt = pd.date_range("2001-01-01", freq="d", periods=5).values
by = np.ones(dt.shape, dtype=int)
actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine)
assert_equal(actual, dt[[idx]])

dt = dt - dt[0]
actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine)
assert_equal(actual, dt[[idx]])

0 comments on commit 4f6164f

Please sign in to comment.