From 63c30f8129ac21038b453fd62e367e129a5058bd Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Wed, 3 Jun 2026 13:16:40 +0200 Subject: [PATCH] perf: scatter groupby-sum terms directly instead of unstacking The fast path of LinearExpression.groupby(...).sum() used ds.unstack(group_dim, fill_value=...) followed by a stack, which materializes 2-3 intermediate copies of the padded result (n_groups x max_group_size x nterm) and goes through pandas MultiIndex machinery sized by the number of elements. Instead, factorize the groups and scatter coeffs/vars directly into the preallocated padded result arrays; constants are group-summed with np.add.at. Peak memory drops to input + result (the minimum for the padded layout) and the grouping itself gets considerably faster. The result is unchanged: same dims, coords, term ordering and padding. The unstack-based implementation is kept as _sum_by_unstack and still used for chunked (dask-backed) data, which cannot be scattered into numpy arrays. NaN group labels now raise an informative ValueError instead of failing inside unstack. Co-Authored-By: Claude Opus 4.8 (1M context) --- linopy/expressions.py | 138 ++++++++++++++++++++++++++++++--- test/test_linear_expression.py | 124 +++++++++++++++++++++++++++++ 2 files changed, 250 insertions(+), 12 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index b0515ea2..f7b71dd8 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -249,18 +249,13 @@ def sum(self, use_fallback: bool = False, **kwargs: Any) -> LinearExpression: # At this point, group is always a pandas Series assert isinstance(group, pd.Series) - group_dim = group.index.name - - arrays = [group, group.groupby(group).cumcount()] - idx = pd.MultiIndex.from_arrays(arrays, names=[GROUP_DIM, GROUPED_TERM_DIM]) - new_coords = Coordinates.from_pandas_multiindex(idx, group_dim) - coords = self.data.indexes[group_dim] - names_to_drop = [coords.name] - if isinstance(coords, pd.MultiIndex): - names_to_drop += list(coords.names) - ds = self.data.drop_vars(names_to_drop).assign_coords(new_coords) - ds = ds.unstack(group_dim, fill_value=LinearExpression._fill_value) - ds = LinearExpression._sum(ds, dim=GROUPED_TERM_DIM) + + if self._can_sum_by_scatter(group): + ds = self._sum_by_scatter(group) + else: + # chunked (e.g. dask-backed) data or exotic coordinates on the + # grouped dimension: use xarray's unstack machinery + ds = self._sum_by_unstack(group) if int_map is not None: index = ds.indexes[GROUP_DIM].map({v: k for k, v in int_map.items()}) @@ -279,6 +274,125 @@ def func(ds: Dataset) -> Dataset: return self.map(func, **kwargs, shortcut=True) + def _can_sum_by_scatter(self, group: pd.Series) -> bool: + """ + Whether :meth:`_sum_by_scatter` covers the structure of the data. + + The scatter kernel requires numpy-backed arrays (chunked data cannot be + scattered into preallocated numpy arrays) and no coordinates tied to + the grouped dimension besides its own index. Everything else falls + back to :meth:`_sum_by_unstack`. + """ + data = self.data + group_dim = group.index.name + + numpy_backed = all( + isinstance(data[k].data, np.ndarray) for k in ("coeffs", "vars", "const") + ) + if not numpy_backed: + return False + + index = data.indexes.get(group_dim) + index_names = {group_dim, *(index.names if index is not None else ())} + return all( + coord.dims == (group_dim,) and name in index_names + for name, coord in data.coords.items() + if group_dim in coord.dims + ) + + def _sum_by_scatter(self, group: pd.Series) -> Dataset: + """ + Sum groups by scattering all terms directly into the final padded arrays. + + Every group member keeps its block of ``nterm`` terms, so the resulting + term dimension has size ``max_group_size * nterm`` and smaller groups are + padded with fill values. In contrast to :meth:`_sum_by_unstack` only the + result arrays are allocated, without intermediate copies of that size. + + Only the term and constant values are computed with numpy; the result + structure (dimensions, coordinates and their order) is assembled by + xarray. :meth:`_can_sum_by_scatter` decides whether the data is simple + enough for this kernel. + """ + data = self.data + group_dim = group.index.name + fill_value = LinearExpression._fill_value + + codes, unique_groups = pd.factorize(group, sort=True) + if (codes == -1).any(): + raise ValueError( + "Cannot group by a pandas object containing NaN values. " + "Drop or fill the corresponding entries before grouping." + ) + + n_groups = len(unique_groups) + sizes = np.bincount(codes, minlength=n_groups) + max_size = int(sizes.max()) if n_groups else 0 + + # position of each element within its group (order of appearance) + positions = pd.Series(codes).groupby(codes).cumcount().to_numpy() + + def scatter( + da: DataArray, fill: Any + ) -> tuple[tuple[Hashable, ...], np.ndarray]: + """Scatter one term-array into its padded (group x term) layout.""" + rest_dims = [d for d in da.dims if d not in (group_dim, TERM_DIM)] + values = da.transpose(group_dim, *rest_dims, TERM_DIM).values + rest_shape = values.shape[1:-1] + nterm = values.shape[-1] + + out = np.full( + (n_groups, *rest_shape, nterm, max_size), fill, dtype=values.dtype + ) + locs = (codes, *(slice(None),) * (len(rest_shape) + 1), positions) + out[locs] = values + # collapsing (nterm, max_size) into one axis keeps all terms of one + # group member together, with padding at the end of each block + out = out.reshape((n_groups, *rest_shape, nterm * max_size)) + return (GROUP_DIM, *rest_dims, TERM_DIM), out + + coeffs_dims, coeffs = scatter(data.coeffs, fill_value["coeffs"]) + vars_dims, vars = scatter(data.vars, fill_value["vars"]) + + # constants are summed up within each group, skipping NaN values + const_dims = [d for d in data.const.dims if d != group_dim] + const_values = data.const.transpose(group_dim, *const_dims).values + const = np.zeros((n_groups, *const_values.shape[1:]), dtype=const_values.dtype) + np.add.at(const, codes, np.where(np.isnan(const_values), 0, const_values)) + + # only the values above are computed with numpy, the result structure + # (dimensions, coordinates and their order) is assembled by xarray + # itself and thereby matches a result of unstacking the group dimension + structure = data.drop_vars(["coeffs", "vars", "const"]) + structure = structure.drop_dims(group_dim) + structure = structure.expand_dims({GROUP_DIM: unique_groups}) + + return structure.assign( + coeffs=(coeffs_dims, coeffs), + vars=(vars_dims, vars), + const=((GROUP_DIM, *const_dims), const), + ) + + def _sum_by_unstack(self, group: pd.Series) -> Dataset: + """ + Sum groups by unstacking the group dimension into a padded helper + dimension and summing over it. + + Equivalent to :meth:`_sum_by_scatter` but goes through xarray's + unstack/stack machinery, which also supports chunked (dask) data. + """ + group_dim = group.index.name + arrays = [group, group.groupby(group).cumcount()] + idx = pd.MultiIndex.from_arrays(arrays, names=[GROUP_DIM, GROUPED_TERM_DIM]) + new_coords = Coordinates.from_pandas_multiindex(idx, group_dim) + coords = self.data.indexes[group_dim] + names_to_drop = [coords.name] + if isinstance(coords, pd.MultiIndex): + names_to_drop += list(coords.names) + ds = self.data.drop_vars(names_to_drop).assign_coords(new_coords) + ds = ds.unstack(group_dim, fill_value=LinearExpression._fill_value) + return LinearExpression._sum(ds, dim=GROUPED_TERM_DIM) + def roll(self, **kwargs: Any) -> LinearExpression: """ Roll the groupby object. diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 2580f033..85ba7270 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1625,6 +1625,130 @@ def test_linear_expression_groupby_from_variable(v: Variable) -> None: assert grouped.nterm == 10 +def test_linear_expression_groupby_skewed_unsorted_groups(v: Variable) -> None: + """ + The scatter-based fast path must match the xarray fallback for groups that + are unsorted, non-contiguous and of very different sizes. + """ + expr = 2 * v + 5 + # 'b' appears 14 times, 'c' 5 times, 'a' once, scattered over the dimension + labels = ["b"] * 4 + ["c", "a"] + ["b"] * 5 + ["c"] * 4 + ["b"] * 5 + groups = pd.Series(labels, index=v.indexes["dim_2"], name="letter") + + grouped = expr.groupby(groups).sum() + fallback = expr.groupby(groups.to_xarray()).sum(use_fallback=True) + + assert list(grouped.data.letter) == ["a", "b", "c"] + # padded to the largest group times the number of terms of the input + assert grouped.nterm == 14 * expr.nterm + assert_linequal(grouped, fallback) + + # every group must carry exactly the variables of its members, the rest is fill + for letter in ["a", "b", "c"]: + members = np.where(np.array(labels) == letter)[0] + vars_of_group = grouped.data.vars.sel(letter=letter).values + assert set(vars_of_group[vars_of_group >= 0]) == set(v.labels.values[members]) + assert (vars_of_group >= 0).sum() == len(members) * expr.nterm + assert grouped.const.sel(letter=letter).item() == 5 * len(members) + + +def test_linear_expression_groupby_chunked(v: Variable) -> None: + """Chunked (dask-backed) expressions group via xarray's unstack machinery.""" + pytest.importorskip("dask") + expr = 2 * v + 5 + groups = pd.Series([1] * 12 + [2] * 8, index=v.indexes["dim_2"], name="group") + + chunked = LinearExpression(expr.data.chunk({"dim_2": 5}), expr.model) + grouped_chunked = chunked.groupby(groups).sum() + grouped = expr.groupby(groups).sum() + + assert grouped_chunked.nterm == grouped.nterm + assert_linequal( + LinearExpression(grouped_chunked.data.compute(), expr.model), grouped + ) + + +def test_linear_expression_groupby_with_nan_groups(v: Variable) -> None: + expr = 1 * v + groups = pd.Series([1.0, np.nan] * 10, index=v.indexes["dim_2"], name="with_nans") + with pytest.raises(ValueError, match="NaN"): + expr.groupby(groups).sum() + + +@pytest.mark.parametrize( + "case", + [ + "skewed_int_groups", + "multidim_with_const", + "nan_const", + "masked_vars", + "quadratic", + "single_group", + "identity_groups", + ], +) +def test_linear_expression_groupby_scatter_equals_unstack(case: str) -> None: + """ + Lock the two groupby-sum kernels together. + + The fast path of groupby(...).sum() scatters terms into numpy arrays + (_sum_by_scatter); the xarray unstack implementation (_sum_by_unstack) is + kept for chunked data and exotic coordinates. Both must stay + interchangeable — if an xarray/pandas update changes the unstack output or + an edge case diverges, this fails. + """ + m = Model() + rng = np.random.default_rng(0) + idx = pd.RangeIndex(60, name="elem") + skewed = pd.Series(rng.choice(8, 60, p=[0.5] + [0.5 / 7] * 7), index=idx, name="g") + groups = skewed + + if case == "skewed_int_groups": + x = m.add_variables(coords=[idx], name="x") + expr: LinearExpression | QuadraticExpression = 3 * x - 2 * x + 7 + elif case == "multidim_with_const": + other = pd.Index(list("abc"), name="other") + y = m.add_variables(coords=[other, idx], name="y") + const = xr.DataArray(rng.normal(size=(3, 60)), coords=[other, idx]) + expr = 2 * y + 1 * y + const + elif case == "nan_const": + x = m.add_variables(coords=[idx], name="x") + expr = 1 * x + np.where(np.arange(60) % 3, np.nan, 5.0) + elif case == "masked_vars": + mask = xr.DataArray(np.arange(60) % 4 != 0, coords=[idx]) + x = m.add_variables(coords=[idx], name="x", mask=mask) + expr = 1 * x + elif case == "quadratic": + x = m.add_variables(coords=[idx], name="x") + expr = x * x + 2 * x + elif case == "single_group": + x = m.add_variables(coords=[idx], name="x") + expr = 1 * x + groups = pd.Series(1, index=idx, name="g") + else: # identity_groups + x = m.add_variables(coords=[idx], name="x") + expr = 1 * x + groups = pd.Series(np.arange(60), index=idx, name="g") + + gb = expr.groupby(groups) + assert gb._can_sum_by_scatter(groups) + scatter = LinearExpression(gb._sum_by_scatter(groups).rename(_group="g"), m) + unstack = LinearExpression(gb._sum_by_unstack(groups).rename(_group="g"), m) + + # identical structure: dims, dim order, coordinates + assert scatter.data.coeffs.dims == unstack.data.coeffs.dims + assert scatter.data.const.dims == unstack.data.const.dims + assert list(scatter.data.coords) == list(unstack.data.coords) + for name in scatter.data.coords: + assert_equal(scatter.data[name], unstack.data[name]) + + # identical values: vars and coeffs bit-exact, including padding positions + np.testing.assert_array_equal(scatter.vars.values, unstack.vars.values) + np.testing.assert_array_equal(scatter.coeffs.values, unstack.coeffs.values) + # constants may differ by floating-point summation order + np.testing.assert_allclose(scatter.const.values, unstack.const.values, rtol=1e-12) + + def test_linear_expression_rolling(v: Variable) -> None: expr = 1 * v rolled = expr.rolling(dim_2=2).sum()