Fix pooled global/groupby lag transforms to use RANGE semantics#641

Open
simonez-tuidi wants to merge 1 commit into Nixtla:main from simonez-tuidi:feature/groupby_with_range_semantics

Conversation


simonez-tuidi (Contributor) commented May 8, 2026

PR Description

Summary

This PR reworks global_ and groupby lag transforms so pooled features are computed over the underlying observations in each time range, matching SQL-style RANGE BETWEEN ... PRECEDING semantics.

This implements Option A from issue #640: change the default global_ / groupby behavior to RANGE semantics instead of preserving the current sum-then-roll behavior.

Previously, global and grouped transforms were backed by separate ad hoc state paths that aggregated each timestamp using sum before applying the lag transform. That made transforms such as:

```python
RollingMean(window_size=2, global_=True)
RollingMean(window_size=2, groupby=["brand"])
```

behave like transforms over per-timestamp sums rather than transforms over all rows in the relevant time window. This branch introduces a shared pooled state representation and computes pooled transforms directly from bucketed observation arrays.

Problem Addressed

When multiple series share a pooled bucket, the old implementation first collapsed the data by timestamp:

| ds | a | b  | summed y |
|----|---|----|----------|
| 1  | 1 | 10 | 11       |
| 2  | 2 | 20 | 22       |
| 3  | 3 | 30 | 33       |
| 4  | 4 | 40 | 44       |

Then RollingMean(window_size=2, lag=1, global_=True) operated on [11, 22, 33, 44], producing values such as mean(11, 22) = 16.5.

That has two practical problems:

  • RollingMean scales with the number of series in the group, because it is effectively averaging sums.
  • Results can jump when the membership of a group changes, even if the target distribution is stable.

With this PR, the same transform operates over the individual observations in the RANGE window. For example, at ds=3 the window contains [1, 10, 2, 20], so the mean is 8.25.

One detail worth making explicit: min_samples is still evaluated over observations. With multiple series in a bucket, a window containing one timestamp can satisfy min_samples=2 if that timestamp has two observed rows.
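The before/after arithmetic above can be reproduced with a small sketch (illustrative plain NumPy, not mlforecast's internals):

```python
import numpy as np

# Two aligned series a and b observed at the same timestamps, as in the
# table above. Window bounds follow RANGE BETWEEN ... PRECEDING semantics.
ds = np.array([1, 2, 3, 4])
a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([10.0, 20.0, 30.0, 40.0])

def old_pooled_mean(t, window_size=2, lag=1):
    """Old behavior: sum by timestamp, then roll over the sums."""
    sums = a + b
    lo, hi = t - lag - window_size + 1, t - lag  # inclusive time bounds
    mask = (ds >= lo) & (ds <= hi)
    return sums[mask].mean()

def new_pooled_mean(t, window_size=2, lag=1):
    """New behavior: mean over every observation in the RANGE window."""
    lo, hi = t - lag - window_size + 1, t - lag
    mask = (ds >= lo) & (ds <= hi)
    return np.concatenate([a[mask], b[mask]]).mean()

print(old_pooled_mean(3))  # 16.5 = mean(11, 22)
print(new_pooled_mean(3))  # 8.25 = mean(1, 10, 2, 20)
```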

What Changed

Added PooledState

Added mlforecast/pooled.py with a PooledState object that owns the state needed by pooled transforms:

  • flat observation arrays: bucket id, timestamp, time index, and target value
  • a GroupedArray for existing transform state initialization
  • bucket metadata used to join computed features back to the original dataframe
  • group key mappings for groupby transforms
  • update logic for predictions, new observations, new groups, and new series

This replaces the previous separate _global_ga / _global_times and _group_states code paths with a single state model.
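As a rough sketch of the shape of this state (simplified; the real PooledState in mlforecast/pooled.py also carries the GroupedArray, bucket metadata, and group key mappings listed above):

```python
from dataclasses import dataclass

import numpy as np

@dataclass
class PooledStateSketch:
    """Simplified stand-in for PooledState: flat, parallel observation arrays."""
    bucket_ids: np.ndarray   # pooled bucket id per observation
    times: np.ndarray        # timestamp per observation
    time_index: np.ndarray   # position on the regular time grid
    values: np.ndarray       # target value per observation

    def append(self, bucket_ids, times, time_index, values):
        """Keep all flat arrays in sync when new observations arrive."""
        self.bucket_ids = np.concatenate([self.bucket_ids, bucket_ids])
        self.times = np.concatenate([self.times, times])
        self.time_index = np.concatenate([self.time_index, time_index])
        self.values = np.concatenate([self.values, values])

    def bucket_values(self, bucket_id):
        """Observations belonging to one pooled bucket, time-ordered."""
        mask = self.bucket_ids == bucket_id
        order = np.argsort(self.time_index[mask], kind="stable")
        return self.values[mask][order], self.time_index[mask][order]

state = PooledStateSketch(
    bucket_ids=np.array([0, 0]),
    times=np.array([1, 2]),
    time_index=np.array([0, 1]),
    values=np.array([1.0, 2.0]),
)
state.append(np.array([0]), np.array([3]), np.array([2]), np.array([3.0]))
vals, idx = state.bucket_values(0)
print(vals.tolist())  # [1.0, 2.0, 3.0]
```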

Compute Pooled Features Directly

Lag transforms already expose _compute_bucket_feature(...) (added in earlier commits on this branch) for rolling, seasonal, expanding, EWM, Offset, and Combine transforms. This PR routes all pooled computation through those methods via a new compute_pooled_features() function, and removes the previous silent fallback to positional GroupedArray behavior.

The old fallback was problematic: when a transform did not implement _compute_bucket_feature, the code silently fell back to GA positional semantics, which produced incorrect results under RANGE window bounds. Unsupported pooled transforms now raise a clear NotImplementedError instead.
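The dispatch described above can be sketched as follows (names and signatures are illustrative, not mlforecast's actual API):

```python
import numpy as np

def compute_pooled_feature(transform, values, time_index):
    """Route pooled computation through the per-transform hook; no silent fallback."""
    hook = getattr(transform, "_compute_bucket_feature", None)
    if hook is None:
        raise NotImplementedError(
            f"{type(transform).__name__} does not support pooled "
            "(global_/groupby) computation with RANGE semantics."
        )
    return hook(values, time_index)

class RollingMeanSketch:
    """Toy rolling mean over all observations in each RANGE window."""
    def __init__(self, window_size, lag):
        self.window_size, self.lag = window_size, lag

    def _compute_bucket_feature(self, values, time_index):
        out = np.full(time_index.max() + 1, np.nan)
        for t in range(len(out)):
            lo = t - self.lag - self.window_size + 1
            hi = t - self.lag
            mask = (time_index >= lo) & (time_index <= hi)
            if mask.any():
                out[t] = values[mask].mean()
        return out

class UnsupportedTransform:
    pass  # no _compute_bucket_feature hook -> raises instead of falling back

# Two interleaved series sharing one bucket: values at time indexes 0..3.
values = np.array([1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0])
time_index = np.array([0, 0, 1, 1, 2, 2, 3, 3])
out = compute_pooled_feature(RollingMeanSketch(2, 1), values, time_index)
print(out[2])  # 8.25: mean over all observations at time indexes 0 and 1
```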

SQL-Like Range Semantics

Pooled transforms now use a per-bucket time_index derived from the validated regular time grid. For global and groupby transforms, this gives interval-style window bounds while preserving the existing codebase assumption that non-partitioned series do not contain gaps.

This means a grouped feature behaves like:

```sql
AVG(y) OVER (
  PARTITION BY brand
  RANGE BETWEEN 2 PERIODS PRECEDING AND 1 PERIOD PRECEDING
)
```

rather than first aggregating y by (brand, timestamp).

The intended equivalence model is:

| mlforecast configuration | SQL mental model |
|---|---|
| `RollingMean(w, lag=l)` | `AVG(y) OVER (PARTITION BY unique_id RANGE BETWEEN ...)` |
| `RollingMean(w, lag=l, global_=True)` | `AVG(y) OVER (RANGE BETWEEN ...)` |
| `RollingMean(w, lag=l, groupby=["brand"])` | `AVG(y) OVER (PARTITION BY brand RANGE BETWEEN ...)` |
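The groupby row of this table can be checked by hand with plain pandas (made-up data; this demonstrates the mental model, not mlforecast code):

```python
import pandas as pd

# Two series of the same brand; the pooled feature at ds=3 should cover
# every observation of the brand with ds in [1, 2].
df = pd.DataFrame({
    "unique_id": ["a", "a", "a", "b", "b", "b"],
    "brand": ["x", "x", "x", "x", "x", "x"],
    "ds": [1, 2, 3, 1, 2, 3],
    "y": [1.0, 2.0, 3.0, 10.0, 20.0, 30.0],
})

# PARTITION BY brand RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING at ds=3:
window = df[(df["brand"] == "x") & df["ds"].between(1, 2)]
feature = window["y"].mean()
print(feature)  # 8.25 = mean(1, 2, 10, 20)
```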

Update Path Fixes

The pooled state update path now keeps all related arrays and metadata in sync:

  • appends observations to bucket_df
  • recomputes/extends time indexes after update()
  • handles new series and new groups
  • updates series-to-bucket mappings after static features change
  • preserves prediction-time feature computation after updates

This also fixes a pre-existing bug in TimeSeries.update() where new series received wrong static features: ufp.take_rows(df, ...) was indexing into the full DataFrame instead of the new-series subset, causing incorrect bucket assignments for series introduced via update().
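A minimal pandas reproduction of the indexing mistake (illustrative column names; the real code path uses ufp.take_rows):

```python
import pandas as pd

# `idx` holds positions *within the new-series subset*, but the buggy
# code applied them to the full DataFrame.
df = pd.DataFrame({
    "unique_id": ["a", "b", "c", "d"],
    "brand": ["x", "x", "y", "y"],
})
new_mask = df["unique_id"].isin(["c", "d"])
new_df = df[new_mask].reset_index(drop=True)
idx = [0, 1]  # rows of the new-series subset

wrong = df.iloc[idx]      # indexes the full frame: picks a and b
right = new_df.iloc[idx]  # indexes the subset: picks c and d
print(wrong["brand"].tolist())  # ['x', 'x'] -> wrong bucket assignment
print(right["brand"].tolist())  # ['y', 'y']
```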

Categorical Group Key Support

Grouped buckets are represented internally by numeric _bucket_ids, but public group keys such as brand or subcategory may be categorical. The new helpers reconcile pandas and Polars categoricals before joins and concatenations, including when updates introduce a new group value.
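One way such reconciliation can work in pandas (a hedged sketch; the helper name is illustrative, and the actual helpers also cover Polars):

```python
import pandas as pd

def align_categories(old: pd.Series, new: pd.Series) -> pd.CategoricalDtype:
    """Union the category levels so both sides share one categorical dtype."""
    cats = old.cat.categories.union(new.cat.categories)
    return pd.CategoricalDtype(categories=cats)

old = pd.Series(["a", "b"], dtype="category")
new = pd.Series(["b", "c"], dtype="category")  # 'c' is a new group value

# Without alignment, pandas drops the categorical dtype on concat:
naive = pd.concat([old, new], ignore_index=True)
print(naive.dtype)  # object

# With alignment, the result stays categorical and 'c' is a valid level:
dtype = align_categories(old, new)
merged = pd.concat([old.astype(dtype), new.astype(dtype)], ignore_index=True)
print(merged.dtype)  # category
```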

Tests

Added tests/test_pooled.py covering:

  • global and grouped update state preservation
  • sequential updates
  • staggered series starts
  • new series in new groups
  • categorical group keys with new group values
  • prediction-time feature computation after updates
  • unsupported pooled transform errors

Updated existing core tests to assert the new RANGE-style semantics for global and grouped rolling/expanding transforms.

The previous test_group_lag_transform used one series per group, which made sum-by-timestamp and RANGE semantics indistinguishable. The updated tests include multiple series in the same group so this behavior is covered directly.
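Why one series per group cannot distinguish the two semantics, in a few lines of NumPy:

```python
import numpy as np

# With a single series in the group, the per-timestamp "sum" is just the
# observation itself, so sum-then-roll and roll-over-observations coincide.
y = np.array([1.0, 2.0, 3.0, 4.0])  # the group's only series
sums = y                             # sum by timestamp == y
old_style = sums[[0, 1]].mean()      # roll over the per-timestamp sums
new_style = y[[0, 1]].mean()         # roll over the raw observations
print(old_style == new_style)  # True: such tests cannot tell them apart
```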

Verification

Ran:

```shell
python -m pytest tests/test_pooled.py -x -q
python -m pytest tests/test_core.py -x -q
ruff check mlforecast/core.py mlforecast/pooled.py mlforecast/lag_transforms.py tests/test_pooled.py tests/test_core.py
```

The full branch suite was also verified with:

```shell
python -m pytest tests/test_core.py tests/test_lag_transforms.py tests/test_forecast.py tests/test_pooled.py -x -q
```

Result:

  • 279 passed
  • 2 skipped
  • lint clean
  • mlforecast/pooled.py at 98% coverage

Compatibility / Breaking Change

The public transform API is unchanged: the existing global_ and groupby arguments are still accepted with the same signatures.

This is nevertheless a necessary breaking change for users who already rely on global_ or groupby transforms with more than one series in a pooled bucket. The numeric output changes from "sum by timestamp, then apply the transform" to "apply the transform over all observations in the RANGE window".

For example, RollingMean(window_size=2, lag=1, global_=True) over two aligned series changes from:

ds=3: mean(sum(ds=1), sum(ds=2)) = mean(11, 22) = 16.5

to:

ds=3: mean(all observations at ds=1 and ds=2) = mean(1, 10, 2, 20) = 8.25

This change is intentional because the previous behavior made means scale with the number of series in the group and diverged from the SQL RANGE mental model used by the rest of the pooled/partitioned transform design. Preserving the old behavior would require adding a separate aggregation mode (for example, "sum by timestamp before transforming"), which would keep the incorrect default and add API complexity. This PR chooses correctness and consistency instead.

Users who depended on the old sum-then-transform behavior will need to reproduce that aggregation explicitly before fitting or use a future explicit aggregation option if one is added.
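For example, the old behavior can be reproduced explicitly with a pre-aggregation step like this (a user-side sketch using mlforecast's column conventions, not part of the library):

```python
import pandas as pd

# Collapse all series to per-timestamp sums before fitting, recreating
# the old "sum by timestamp, then transform" inputs as a single series.
df = pd.DataFrame({
    "unique_id": ["a", "a", "b", "b"],
    "ds": [1, 2, 1, 2],
    "y": [1.0, 2.0, 10.0, 20.0],
})
pooled = (
    df.groupby("ds", as_index=False)["y"].sum()
      .assign(unique_id="pooled")[["unique_id", "ds", "y"]]
)
print(pooled["y"].tolist())  # [11.0, 22.0]: the old per-timestamp sums
```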

The internal fitted TimeSeries state shape changed, so previously pickled fitted TimeSeries objects that depend on the old private pooled state are not expected to be compatible.

Introduce PooledState for global/groupby transform state, compute pooled features from raw bucketed observations instead of summed timestamps, and add update/new-group/categorical coverage.
codspeed-hq Bot commented May 8, 2026

Merging this PR will not alter performance

✅ 12 untouched benchmarks


Comparing simonez-tuidi:feature/groupby_with_range_semantics (3267a2f) with main (bed599b)

