Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 126 additions & 12 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,18 +249,13 @@ def sum(self, use_fallback: bool = False, **kwargs: Any) -> LinearExpression:

# At this point, group is always a pandas Series
assert isinstance(group, pd.Series)
group_dim = group.index.name

arrays = [group, group.groupby(group).cumcount()]
idx = pd.MultiIndex.from_arrays(arrays, names=[GROUP_DIM, GROUPED_TERM_DIM])
new_coords = Coordinates.from_pandas_multiindex(idx, group_dim)
coords = self.data.indexes[group_dim]
names_to_drop = [coords.name]
if isinstance(coords, pd.MultiIndex):
names_to_drop += list(coords.names)
ds = self.data.drop_vars(names_to_drop).assign_coords(new_coords)
ds = ds.unstack(group_dim, fill_value=LinearExpression._fill_value)
ds = LinearExpression._sum(ds, dim=GROUPED_TERM_DIM)

if self._can_sum_by_scatter(group):
ds = self._sum_by_scatter(group)
else:
# chunked (e.g. dask-backed) data or exotic coordinates on the
# grouped dimension: use xarray's unstack machinery
ds = self._sum_by_unstack(group)

if int_map is not None:
index = ds.indexes[GROUP_DIM].map({v: k for k, v in int_map.items()})
Expand All @@ -279,6 +274,125 @@ def func(ds: Dataset) -> Dataset:

return self.map(func, **kwargs, shortcut=True)

def _can_sum_by_scatter(self, group: pd.Series) -> bool:
"""
Whether :meth:`_sum_by_scatter` covers the structure of the data.

The scatter kernel requires numpy-backed arrays (chunked data cannot be
scattered into preallocated numpy arrays) and no coordinates tied to
the grouped dimension besides its own index. Everything else falls
back to :meth:`_sum_by_unstack`.
"""
data = self.data
group_dim = group.index.name

numpy_backed = all(
isinstance(data[k].data, np.ndarray) for k in ("coeffs", "vars", "const")
)
if not numpy_backed:
return False

index = data.indexes.get(group_dim)
index_names = {group_dim, *(index.names if index is not None else ())}
return all(
coord.dims == (group_dim,) and name in index_names
for name, coord in data.coords.items()
if group_dim in coord.dims
)

def _sum_by_scatter(self, group: pd.Series) -> Dataset:
"""
Sum groups by scattering all terms directly into the final padded arrays.

Every group member keeps its block of ``nterm`` terms, so the resulting
term dimension has size ``max_group_size * nterm`` and smaller groups are
padded with fill values. In contrast to :meth:`_sum_by_unstack` only the
result arrays are allocated, without intermediate copies of that size.

Only the term and constant values are computed with numpy; the result
structure (dimensions, coordinates and their order) is assembled by
xarray. :meth:`_can_sum_by_scatter` decides whether the data is simple
enough for this kernel.
"""
data = self.data
group_dim = group.index.name
fill_value = LinearExpression._fill_value

codes, unique_groups = pd.factorize(group, sort=True)
if (codes == -1).any():
raise ValueError(
"Cannot group by a pandas object containing NaN values. "
"Drop or fill the corresponding entries before grouping."
)

n_groups = len(unique_groups)
sizes = np.bincount(codes, minlength=n_groups)
max_size = int(sizes.max()) if n_groups else 0

# position of each element within its group (order of appearance)
positions = pd.Series(codes).groupby(codes).cumcount().to_numpy()

def scatter(
da: DataArray, fill: Any
) -> tuple[tuple[Hashable, ...], np.ndarray]:
"""Scatter one term-array into its padded (group x term) layout."""
rest_dims = [d for d in da.dims if d not in (group_dim, TERM_DIM)]
values = da.transpose(group_dim, *rest_dims, TERM_DIM).values
rest_shape = values.shape[1:-1]
nterm = values.shape[-1]

out = np.full(
(n_groups, *rest_shape, nterm, max_size), fill, dtype=values.dtype
)
locs = (codes, *(slice(None),) * (len(rest_shape) + 1), positions)
out[locs] = values
# collapsing (nterm, max_size) into one axis keeps all terms of one
# group member together, with padding at the end of each block
out = out.reshape((n_groups, *rest_shape, nterm * max_size))
return (GROUP_DIM, *rest_dims, TERM_DIM), out

coeffs_dims, coeffs = scatter(data.coeffs, fill_value["coeffs"])
vars_dims, vars = scatter(data.vars, fill_value["vars"])

# constants are summed up within each group, skipping NaN values
const_dims = [d for d in data.const.dims if d != group_dim]
const_values = data.const.transpose(group_dim, *const_dims).values
const = np.zeros((n_groups, *const_values.shape[1:]), dtype=const_values.dtype)
np.add.at(const, codes, np.where(np.isnan(const_values), 0, const_values))

# only the values above are computed with numpy, the result structure
# (dimensions, coordinates and their order) is assembled by xarray
# itself and thereby matches a result of unstacking the group dimension
structure = data.drop_vars(["coeffs", "vars", "const"])
structure = structure.drop_dims(group_dim)
structure = structure.expand_dims({GROUP_DIM: unique_groups})

return structure.assign(
coeffs=(coeffs_dims, coeffs),
vars=(vars_dims, vars),
const=((GROUP_DIM, *const_dims), const),
)

def _sum_by_unstack(self, group: pd.Series) -> Dataset:
"""
Sum groups by unstacking the group dimension into a padded helper
dimension and summing over it.

Equivalent to :meth:`_sum_by_scatter` but goes through xarray's
unstack/stack machinery, which also supports chunked (dask) data.
"""
group_dim = group.index.name
arrays = [group, group.groupby(group).cumcount()]
idx = pd.MultiIndex.from_arrays(arrays, names=[GROUP_DIM, GROUPED_TERM_DIM])
new_coords = Coordinates.from_pandas_multiindex(idx, group_dim)
coords = self.data.indexes[group_dim]
names_to_drop = [coords.name]
if isinstance(coords, pd.MultiIndex):
names_to_drop += list(coords.names)
ds = self.data.drop_vars(names_to_drop).assign_coords(new_coords)
ds = ds.unstack(group_dim, fill_value=LinearExpression._fill_value)
return LinearExpression._sum(ds, dim=GROUPED_TERM_DIM)

def roll(self, **kwargs: Any) -> LinearExpression:
"""
Roll the groupby object.
Expand Down
124 changes: 124 additions & 0 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,130 @@ def test_linear_expression_groupby_from_variable(v: Variable) -> None:
assert grouped.nterm == 10


def test_linear_expression_groupby_skewed_unsorted_groups(v: Variable) -> None:
"""
The scatter-based fast path must match the xarray fallback for groups that
are unsorted, non-contiguous and of very different sizes.
"""
expr = 2 * v + 5
# 'b' appears 14 times, 'c' 5 times, 'a' once, scattered over the dimension
labels = ["b"] * 4 + ["c", "a"] + ["b"] * 5 + ["c"] * 4 + ["b"] * 5
groups = pd.Series(labels, index=v.indexes["dim_2"], name="letter")

grouped = expr.groupby(groups).sum()
fallback = expr.groupby(groups.to_xarray()).sum(use_fallback=True)

assert list(grouped.data.letter) == ["a", "b", "c"]
# padded to the largest group times the number of terms of the input
assert grouped.nterm == 14 * expr.nterm
assert_linequal(grouped, fallback)

# every group must carry exactly the variables of its members, the rest is fill
for letter in ["a", "b", "c"]:
members = np.where(np.array(labels) == letter)[0]
vars_of_group = grouped.data.vars.sel(letter=letter).values
assert set(vars_of_group[vars_of_group >= 0]) == set(v.labels.values[members])
assert (vars_of_group >= 0).sum() == len(members) * expr.nterm
assert grouped.const.sel(letter=letter).item() == 5 * len(members)


def test_linear_expression_groupby_chunked(v: Variable) -> None:
"""Chunked (dask-backed) expressions group via xarray's unstack machinery."""
pytest.importorskip("dask")
expr = 2 * v + 5
groups = pd.Series([1] * 12 + [2] * 8, index=v.indexes["dim_2"], name="group")

chunked = LinearExpression(expr.data.chunk({"dim_2": 5}), expr.model)
grouped_chunked = chunked.groupby(groups).sum()
grouped = expr.groupby(groups).sum()

assert grouped_chunked.nterm == grouped.nterm
assert_linequal(
LinearExpression(grouped_chunked.data.compute(), expr.model), grouped
)


def test_linear_expression_groupby_with_nan_groups(v: Variable) -> None:
expr = 1 * v
groups = pd.Series([1.0, np.nan] * 10, index=v.indexes["dim_2"], name="with_nans")
with pytest.raises(ValueError, match="NaN"):
expr.groupby(groups).sum()


@pytest.mark.parametrize(
"case",
[
"skewed_int_groups",
"multidim_with_const",
"nan_const",
"masked_vars",
"quadratic",
"single_group",
"identity_groups",
],
)
def test_linear_expression_groupby_scatter_equals_unstack(case: str) -> None:
"""
Lock the two groupby-sum kernels together.

The fast path of groupby(...).sum() scatters terms into numpy arrays
(_sum_by_scatter); the xarray unstack implementation (_sum_by_unstack) is
kept for chunked data and exotic coordinates. Both must stay
interchangeable — if an xarray/pandas update changes the unstack output or
an edge case diverges, this fails.
"""
m = Model()
rng = np.random.default_rng(0)
idx = pd.RangeIndex(60, name="elem")
skewed = pd.Series(rng.choice(8, 60, p=[0.5] + [0.5 / 7] * 7), index=idx, name="g")
groups = skewed

if case == "skewed_int_groups":
x = m.add_variables(coords=[idx], name="x")
expr: LinearExpression | QuadraticExpression = 3 * x - 2 * x + 7
elif case == "multidim_with_const":
other = pd.Index(list("abc"), name="other")
y = m.add_variables(coords=[other, idx], name="y")
const = xr.DataArray(rng.normal(size=(3, 60)), coords=[other, idx])
expr = 2 * y + 1 * y + const
elif case == "nan_const":
x = m.add_variables(coords=[idx], name="x")
expr = 1 * x + np.where(np.arange(60) % 3, np.nan, 5.0)
elif case == "masked_vars":
mask = xr.DataArray(np.arange(60) % 4 != 0, coords=[idx])
x = m.add_variables(coords=[idx], name="x", mask=mask)
expr = 1 * x
elif case == "quadratic":
x = m.add_variables(coords=[idx], name="x")
expr = x * x + 2 * x
elif case == "single_group":
x = m.add_variables(coords=[idx], name="x")
expr = 1 * x
groups = pd.Series(1, index=idx, name="g")
else: # identity_groups
x = m.add_variables(coords=[idx], name="x")
expr = 1 * x
groups = pd.Series(np.arange(60), index=idx, name="g")

gb = expr.groupby(groups)
assert gb._can_sum_by_scatter(groups)
scatter = LinearExpression(gb._sum_by_scatter(groups).rename(_group="g"), m)
unstack = LinearExpression(gb._sum_by_unstack(groups).rename(_group="g"), m)

# identical structure: dims, dim order, coordinates
assert scatter.data.coeffs.dims == unstack.data.coeffs.dims
assert scatter.data.const.dims == unstack.data.const.dims
assert list(scatter.data.coords) == list(unstack.data.coords)
for name in scatter.data.coords:
assert_equal(scatter.data[name], unstack.data[name])

# identical values: vars and coeffs bit-exact, including padding positions
np.testing.assert_array_equal(scatter.vars.values, unstack.vars.values)
np.testing.assert_array_equal(scatter.coeffs.values, unstack.coeffs.values)
# constants may differ by floating-point summation order
np.testing.assert_allclose(scatter.const.values, unstack.const.values, rtol=1e-12)


def test_linear_expression_rolling(v: Variable) -> None:
expr = 1 * v
rolled = expr.rolling(dim_2=2).sum()
Expand Down
Loading