Fix pooled global/groupby lag transforms to use RANGE semantics#641
Open
simonez-tuidi wants to merge 1 commit intoNixtla:mainfrom
Open
Fix pooled global/groupby lag transforms to use RANGE semantics#641simonez-tuidi wants to merge 1 commit intoNixtla:mainfrom
simonez-tuidi wants to merge 1 commit intoNixtla:mainfrom
Conversation
Introduce PooledState for global/groupby transform state, compute pooled features from raw bucketed observations instead of summed timestamps, and add update/new-group/categorical coverage.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
PR Description
Summary
This PR reworks
global_andgroupbylag transforms so pooled features are computed over the underlying observations in each time range, matching SQL-styleRANGE BETWEEN ... PRECEDINGsemantics.This implements Option A from the issue #640 : change the default
global_/groupbybehavior to RANGE semantics instead of preserving the current sum-then-roll behavior.Previously, global and grouped transforms were backed by separate ad hoc state paths that aggregated each timestamp using
sumbefore applying the lag transform. That made transforms such as:behave like transforms over per-timestamp sums rather than transforms over all rows in the relevant time window. This branch introduces a shared pooled state representation and computes pooled transforms directly from bucketed observation arrays.
Problem Addressed
When multiple series share a pooled bucket, the old implementation first collapsed the data by timestamp:
Then
RollingMean(window_size=2, lag=1, global_=True)operated on[11, 22, 33, 44], producing values such asmean(11, 22) = 16.5.That has two practical problems:
RollingMeanscales with the number of series in the group, because it is effectively averaging sums.With this PR, the same transform operates over the individual observations in the RANGE window. For example, at
ds=3the window contains[1, 10, 2, 20], so the mean is8.25.One detail worth making explicit:
min_samplesis still evaluated over observations. With multiple series in a bucket, a window containing one timestamp can satisfymin_samples=2if that timestamp has two observed rows.What Changed
Added
PooledStateAdded
mlforecast/pooled.pywith aPooledStateobject that owns the state needed by pooled transforms:GroupedArrayfor existing transform state initializationgroupbytransformsThis replaces the previous separate
_global_ga/_global_timesand_group_statescode paths with a single state model.Compute Pooled Features Directly
Lag transforms already expose
_compute_bucket_feature(...)(added in earlier commits on this branch) for rolling, seasonal, expanding, EWM,Offset, andCombinetransforms. This PR routes all pooled computation through those methods via a newcompute_pooled_features()function, and removes the previous silent fallback to positionalGroupedArraybehavior.The old fallback was problematic: when a transform did not implement
_compute_bucket_feature, the code silently fell back to GA positional semantics, which produced incorrect results under RANGE window bounds. Unsupported pooled transforms now raise a clearNotImplementedErrorinstead.SQL-Like Range Semantics
Pooled transforms now use a per-bucket
time_indexderived from the validated regular time grid. For global and groupby transforms, this gives interval-style window bounds while preserving the existing codebase assumption that non-partitioned series do not contain gaps.This means a grouped feature behaves like:
rather than first aggregating
yby(brand, timestamp).The intended equivalence model is:
RollingMean(w, lag=l)AVG(y) OVER (PARTITION BY unique_id RANGE BETWEEN ...)RollingMean(w, lag=l, global_=True)AVG(y) OVER (RANGE BETWEEN ...)RollingMean(w, lag=l, groupby=["brand"])AVG(y) OVER (PARTITION BY brand RANGE BETWEEN ...)Update Path Fixes
The pooled state update path now keeps all related arrays and metadata in sync:
bucket_dfupdate()This also fixes a pre-existing bug in
TimeSeries.update()where new series received wrong static features:ufp.take_rows(df, ...)was indexing into the full DataFrame instead of the new-series subset, causing incorrect bucket assignments for series introduced viaupdate().Categorical Group Key Support
Grouped buckets are represented internally by numeric
_bucket_ids, but public group keys such asbrandorsubcategorymay be categorical. The new helpers reconcile pandas and Polars categoricals before joins and concatenations, including when updates introduce a new group value.Tests
Added
tests/test_pooled.pycovering:Updated existing core tests to assert the new RANGE-style semantics for global and grouped rolling/expanding transforms.
The previous
test_group_lag_transformused one series per group, which made sum-by-timestamp and RANGE semantics indistinguishable. The updated tests include multiple series in the same group so this behavior is covered directly.Verification
Ran:
The full branch suite was also verified with:
Result:
mlforecast/pooled.pyat 98% coverageCompatibility / Breaking Change
The public transform API is unchanged. Existing
global_andgroupbyarguments continue to be used.This is nevertheless a necessary breaking change for users who already rely on
global_orgroupbytransforms with more than one series in a pooled bucket. The numeric output changes from "sum by timestamp, then apply the transform" to "apply the transform over all observations in the RANGE window".For example,
RollingMean(window_size=2, lag=1, global_=True)over two aligned series changes from:to:
This change is intentional because the previous behavior made means scale with the number of series in the group and diverged from the SQL
RANGEmental model used by the rest of the pooled/partitioned transform design. Preserving the old behavior would require adding a separate aggregation mode (for example, "sum by timestamp before transforming"), which would keep the incorrect default and add API complexity. This PR chooses correctness and consistency instead.Users who depended on the old sum-then-transform behavior will need to reproduce that aggregation explicitly before fitting or use a future explicit aggregation option if one is added.
The internal fitted
TimeSeriesstate shape changed, so previously pickled fittedTimeSeriesobjects that depend on the old private pooled state are not expected to be compatible.