Skip to content

fix(datasets): NaN-tolerant aggregate_feature_stats + load-time WARN#288

Merged
shuheng-liu merged 2 commits into
mainfrom
claude/sad-sanderson-957544
May 12, 2026
Merged

fix(datasets): NaN-tolerant aggregate_feature_stats + load-time WARN#288
shuheng-liu merged 2 commits into
mainfrom
claude/sad-sanderson-957544

Conversation

@shuheng-liu
Copy link
Copy Markdown
Member

What this does

Fixes a NaN propagation bug in aggregate_feature_stats where a single dataset whose action min/max stats contained NaN at any dim poisoned the global mixture buffer. np.min / np.max / weighted-mean all propagate NaN, so the aggregated buffer's NaN dims turned every sample's normalized discrete_actions into NaN regardless of source repo, blowing up the FAST tokenizer at the first batch (OverflowError: Python int too large to convert to C int from chr() on garbage cast from NaN).

  • aggregate_feature_stats now masks per-dim NaN contributors via np.nanmin/np.nanmax for min/max, plus a NaN-aware weighted mean/variance that zeroes a contributor's weight where its mean (or variance) is NaN. Clean dims aggregate exactly as before; a dim that is NaN across every contributor stays NaN so the downstream loader (or Normalize buffer) can still surface it.
  • DatasetMixtureMetadata._to_standard_data_format now logs a WARNING with the offending repo_id, key, stat name, and bad flat-dim indices, so a contaminator can be identified from training logs without a separate diagnostic patch.

Label: bug.

How it was tested

  • Added two regressions in tests/datasets/test_compute_stats.py:
    • test_aggregate_feature_stats_nan_tolerant_per_dim — a contributor with NaN at one dim must not affect aggregation at clean dims for other contributors.
    • test_aggregate_feature_stats_all_nan_dim_stays_nan — if every contributor is NaN at a dim, the aggregated result stays NaN there (downstream is responsible for surfacing it).
  • pytest -m "not gpu" -n auto tests/datasets/test_compute_stats.py → 17/17 pass.
  • pytest -m "not gpu" -n auto tests/datasets/ → 378 passed, 7 skipped (gpu-only), 0 failures.
  • The original test_aggregate_feature_stats covers the no-NaN happy path and continues to pass — clean inputs aggregate bit-identically to the prior implementation.
  • Root cause was independently reproduced and confirmed via a temporary [NaN-trace] diagnostic patch (already reverted) which showed 14 of 32 padded action dims had NaN in normalize.actions.{min,max} after _to_standard_data_format + aggregate_stats, while pre-normalize raw batch['actions'] was finite — proving the contamination came from the aggregator buffer, not runtime parquet data.

End-to-end verification on real GPU infrastructure is left to a follow-up dispatch by the reviewer; the unit tests cover the aggregator path that was broken.

How to checkout & try? (for the reviewer)

git checkout claude/sad-sanderson-957544
pytest -m "not gpu" -n auto tests/datasets/test_compute_stats.py

To exercise the load-time WARN, point any pi07/pi05/pi06 training config at a dataset whose action stats contain NaN — the warning lands at DatasetMixtureMetadata construction time, naming the offending repo_id, key, stat name, and bad flat-dim indices.

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

Note: Before submitting this PR, please read the contributor guideline.

A single dataset whose `min`/`max`/`mean`/`std` contained NaN at any dim
poisoned the global mixture buffer because np.min / np.max / weighted-mean
propagate NaN. That NaN buffer then turned every sample's normalized
`discrete_actions` into NaN regardless of its source repo, blowing up the
FAST tokenizer at the first batch (`OverflowError: Python int too large
to convert to C int` from `chr()` on garbage cast from NaN).

`aggregate_feature_stats` now masks out per-dim NaN contributors via
`np.nanmin`/`np.nanmax` for min/max and a NaN-aware weighted mean/variance
that zeroes contributor weight where its mean (or variance) is NaN. Clean
dims aggregate exactly as before; a dim that is NaN across every
contributor stays NaN so the downstream loader can still surface it.

`DatasetMixtureMetadata._to_standard_data_format` now logs a WARNING with
the offending `repo_id`, key, stat name, and bad flat-dim indices so a
contaminator can be identified from the slurm log without a separate
diagnostic patch.

Tests: two regressions added in `tests/datasets/test_compute_stats.py`
covering per-dim NaN exclusion and the all-NaN-dim path.
@shuheng-liu shuheng-liu added the bug Something isn't working label May 8, 2026
@shuheng-liu shuheng-liu self-assigned this May 8, 2026
@shuheng-liu shuheng-liu marked this pull request as ready for review May 8, 2026 21:13
Copy link
Copy Markdown
Member Author

@shuheng-liu shuheng-liu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review summary

Solid, well-scoped fix. The aggregator change is mathematically correct, the per-dim masking preserves clean dims from other contributors so a single bad dataset can no longer poison the global mixture, and the regression tests cover both the partial-NaN and all-NaN-dim cases. The load-time WARN gives the on-call a concrete breadcrumb (offending repo_id + stat name + flat dim indices) instead of needing a one-off diagnostic patch — that's a real operational win.

CI on this PR: CPU Tests green, Pre-commit green, check-checklist green. (The review check failure is the Claude auto-review bot job, not a test failure.)

Findings

  1. Mismatch between the WARN predicate and what the aggregator actually masks — inline on dataset_mixture.py. The WARN uses ~np.isfinite (NaN + Inf) but the aggregator only handles NaN. A dataset with +Inf in stats would get logged as "excluded" while still poisoning max/mean/std at that dim. Either tighten the WARN to np.isnan or extend the aggregator — option 1 keeps this PR scoped, option 2 is more defensive.

  2. Asymmetric NaN masking between mean and variance — inline on compute_stats.py. Intentional (you can't compute a variance contribution without std), but the two np.where(...) blocks have different predicates and that asymmetry isn't called out anywhere. Suggest one extra line in the docstring. Same paragraph could clarify that total_count is summed unconditionally across contributors (i.e. it's a scalar sample count, not a per-dim effective contributor count).

  3. Defensive fix, not root-cause fix — the PR description acknowledges this explicitly, and that's fine: the band-aid stops the bleeding (no more poisoned global buffer), and the WARN names the offender for the follow-up. Worth filing the upstream "why does dataset X have NaN in actions.{min,max}?" issue once the WARN identifies it on real infra, so this doesn't drift.

Nits

  • dataset_mixture.py:215-224data is interpolated twice in the format string; the trailing aggregated %r stats repeats info already in the preamble. Minor.
  • The original variance comment ("parallel algorithm for variance") was deleted in the rewrite of aggregate_feature_stats. Not load-bearing, but the algorithmic context — that the per-step parallel/Chan-style variance update is what's being preserved here — is no longer visible in the function body. A one-line # Chan-style parallel weighted variance, NaN-masked per dim would put it back cheaply.

Tests

test_aggregate_feature_stats_nan_tolerant_per_dim and test_aggregate_feature_stats_all_nan_dim_stays_nan are the right two cases, the existing test_aggregate_feature_stats confirms the clean-input path is bit-identical, and pytest -m "not gpu" -n auto tests/datasets/ passed at 378/378. No suggested additions.

Overall: ship this once the WARN/aggregator predicate mismatch (finding 1) is resolved one way or the other. Findings 2 and the nits are docstring / message polish and can land either in this PR or a follow-up.


Generated by Claude Code

Comment thread src/opentau/datasets/dataset_mixture.py
Comment thread src/opentau/datasets/compute_stats.py Outdated
Comment thread src/opentau/datasets/dataset_mixture.py Outdated
Addresses review feedback on #288.

- aggregator: broaden the per-dim mask from `np.isnan` to `~np.isfinite`,
  matching the `_to_standard_data_format` WARN. Without this, a
  contributor with +/-Inf in its stats would trigger the WARN ("excluded
  from aggregation") while still poisoning the aggregated `max`/`mean`/
  `std` -- `np.nanmin`/`np.nanmax` skip NaN but not Inf, and the weighted
  mean leaves Inf in the numerator. Also masks `mins`/`maxs` against
  +/-Inf before `nanmin`/`nanmax` so the per-dim exclusion applies
  symmetrically across all four stats.
- docstring: spell out the asymmetric mean/variance masking (variance
  drops a contributor at a dim iff either its mean or its std is
  non-finite; mean only requires the mean to be finite) and that
  `count` is summed unconditionally (scalar sample-count, not a per-dim
  effective contributor count).
- restore the Chan-style parallel weighted variance comment that the
  original rewrite dropped.
- WARN message: tighten by removing the duplicated `data` interpolation
  (was reading "non-finite values in 'actions'.'min' ... excluded from
  aggregated 'actions' stats at those dims").
- tests: add `test_aggregate_feature_stats_inf_masked_per_dim`. Locks
  in the `~np.isfinite` predicate at the aggregator boundary.
Copy link
Copy Markdown
Member Author

@shuheng-liu shuheng-liu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-review of 689546f

All three findings from the prior review are addressed in 689546f. Verification against the new diff:

1. Predicate mismatch (was: WARN catches non-finite, aggregator only masks NaN) — ✅ resolved

Author went with option 2 (broaden the aggregator). compute_stats.py:314-315 now defines means_bad = ~np.isfinite(means) / variances_bad = ~np.isfinite(variances) and uses those as the per-dim masks in np.where. Min/max are also protected: compute_stats.py:340-341 masks ±Inf to NaN before nanmin/nanmax (mins_safe = np.where(np.isfinite(mins), mins, np.nan)), so +Inf no longer leaks into max and -Inf no longer leaks into min. WARN text and aggregator behavior are now consistent. New regression test_aggregate_feature_stats_inf_masked_per_dim locks in the +/-Inf path.

2. Asymmetric NaN handling between mean and variance — ✅ resolved

Two new docstring paragraphs in aggregate_feature_stats explicitly call out (a) which contributors get dropped from the mean vs. the variance and (b) that count is unconditional ("sum of contributors' counts, not a per-dim effective contributor count"). Plus the algorithmic note # Chan-style parallel weighted variance, non-finite-masked per dim. is back inline above the variance block — that addresses the dropped-comment nit from the last review too.

3. Duplicate data interpolation in WARN — ✅ resolved

dataset_mixture.py:215-222 — trailing data interpolation removed. Message now reads Dataset 'TensorAuto/foo-overlay': non-finite values in 'actions'.'min' at flat indices [0, 1, 2] (shape=(32,)); these dims are excluded from aggregation. Clean.

LGTM to ship

The fix is materially better than the v1: the aggregator now genuinely excludes every non-finite contributor (NaN and ±Inf), the docstring explains the per-dim semantics including the mean/variance asymmetry, and the test suite covers NaN per-dim, ±Inf per-dim, and all-non-finite-dim-stays-NaN. CPU tests, pre-commit, and check-checklist were all green on the prior commit and the new diff is additive (extra masks + new test + docstring), so I'd expect CI to stay green on 689546f.

Note on thread resolution: I tried to programmatically resolve the three review threads I opened, but the available MCP tooling doesn't expose PullRequestReviewThread node IDs in the get_review_comments response (only is_resolved/is_outdated/is_collapsed/comments), and resolve_review_thread requires the GraphQL node ID. Please resolve the three threads via the PR UI — all three are addressed.


Generated by Claude Code

@shuheng-liu shuheng-liu merged commit 7b528cf into main May 12, 2026
9 of 10 checks passed
@shuheng-liu shuheng-liu deleted the claude/sad-sanderson-957544 branch May 12, 2026 04:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants