fix(datasets): NaN-tolerant aggregate_feature_stats + load-time WARN#288
Conversation
A single dataset whose `min`/`max`/`mean`/`std` contained NaN at any dim poisoned the global mixture buffer because np.min / np.max / weighted-mean propagate NaN. That NaN buffer then turned every sample's normalized `discrete_actions` into NaN regardless of its source repo, blowing up the FAST tokenizer at the first batch (`OverflowError: Python int too large to convert to C int` from `chr()` on garbage cast from NaN). `aggregate_feature_stats` now masks out per-dim NaN contributors via `np.nanmin`/`np.nanmax` for min/max and a NaN-aware weighted mean/variance that zeroes contributor weight where its mean (or variance) is NaN. Clean dims aggregate exactly as before; a dim that is NaN across every contributor stays NaN so the downstream loader can still surface it. `DatasetMixtureMetadata._to_standard_data_format` now logs a WARNING with the offending `repo_id`, key, stat name, and bad flat-dim indices so a contaminator can be identified from the slurm log without a separate diagnostic patch. Tests: two regressions added in `tests/datasets/test_compute_stats.py` covering per-dim NaN exclusion and the all-NaN-dim path.
shuheng-liu
left a comment
There was a problem hiding this comment.
Review summary
Solid, well-scoped fix. The aggregator change is mathematically correct, the per-dim masking preserves clean dims from other contributors so a single bad dataset can no longer poison the global mixture, and the regression tests cover both the partial-NaN and all-NaN-dim cases. The load-time WARN gives the on-call a concrete breadcrumb (offending repo_id + stat name + flat dim indices) instead of needing a one-off diagnostic patch — that's a real operational win.
CI on this PR: CPU Tests green, Pre-commit green, check-checklist green. (The review check failure is the Claude auto-review bot job, not a test failure.)
Findings
-
Mismatch between the WARN predicate and what the aggregator actually masks — inline on
dataset_mixture.py. The WARN uses~np.isfinite(NaN + Inf) but the aggregator only handles NaN. A dataset with+Infin stats would get logged as "excluded" while still poisoningmax/mean/stdat that dim. Either tighten the WARN tonp.isnanor extend the aggregator — option 1 keeps this PR scoped, option 2 is more defensive. -
Asymmetric NaN masking between mean and variance — inline on
compute_stats.py. Intentional (you can't compute a variance contribution withoutstd), but the twonp.where(...)blocks have different predicates and that asymmetry isn't called out anywhere. Suggest one extra line in the docstring. Same paragraph could clarify thattotal_countis summed unconditionally across contributors (i.e. it's a scalar sample count, not a per-dim effective contributor count). -
Defensive fix, not root-cause fix — the PR description acknowledges this explicitly, and that's fine: the band-aid stops the bleeding (no more poisoned global buffer), and the WARN names the offender for the follow-up. Worth filing the upstream "why does dataset X have NaN in
actions.{min,max}?" issue once the WARN identifies it on real infra, so this doesn't drift.
Nits
dataset_mixture.py:215-224—datais interpolated twice in the format string; the trailingaggregated %r statsrepeats info already in the preamble. Minor.- The original variance comment ("parallel algorithm for variance") was deleted in the rewrite of
aggregate_feature_stats. Not load-bearing, but the algorithmic context — that the per-step parallel/Chan-style variance update is what's being preserved here — is no longer visible in the function body. A one-line# Chan-style parallel weighted variance, NaN-masked per dimwould put it back cheaply.
Tests
test_aggregate_feature_stats_nan_tolerant_per_dim and test_aggregate_feature_stats_all_nan_dim_stays_nan are the right two cases, the existing test_aggregate_feature_stats confirms the clean-input path is bit-identical, and pytest -m "not gpu" -n auto tests/datasets/ passed at 378/378. No suggested additions.
Overall: ship this once the WARN/aggregator predicate mismatch (finding 1) is resolved one way or the other. Findings 2 and the nits are docstring / message polish and can land either in this PR or a follow-up.
Generated by Claude Code
Addresses review feedback on #288. - aggregator: broaden the per-dim mask from `np.isnan` to `~np.isfinite`, matching the `_to_standard_data_format` WARN. Without this, a contributor with +/-Inf in its stats would trigger the WARN ("excluded from aggregation") while still poisoning the aggregated `max`/`mean`/ `std` -- `np.nanmin`/`np.nanmax` skip NaN but not Inf, and the weighted mean leaves Inf in the numerator. Also masks `mins`/`maxs` against +/-Inf before `nanmin`/`nanmax` so the per-dim exclusion applies symmetrically across all four stats. - docstring: spell out the asymmetric mean/variance masking (variance drops a contributor at a dim iff either its mean or its std is non-finite; mean only requires the mean to be finite) and that `count` is summed unconditionally (scalar sample-count, not a per-dim effective contributor count). - restore the Chan-style parallel weighted variance comment that the original rewrite dropped. - WARN message: tighten by removing the duplicated `data` interpolation (was reading "non-finite values in 'actions'.'min' ... excluded from aggregated 'actions' stats at those dims"). - tests: add `test_aggregate_feature_stats_inf_masked_per_dim`. Locks in the `~np.isfinite` predicate at the aggregator boundary.
shuheng-liu
left a comment
There was a problem hiding this comment.
Re-review of 689546f
All three findings from the prior review are addressed in 689546f. Verification against the new diff:
1. Predicate mismatch (was: WARN catches non-finite, aggregator only masks NaN) — ✅ resolved
Author went with option 2 (broaden the aggregator). compute_stats.py:314-315 now defines means_bad = ~np.isfinite(means) / variances_bad = ~np.isfinite(variances) and uses those as the per-dim masks in np.where. Min/max are also protected: compute_stats.py:340-341 masks ±Inf to NaN before nanmin/nanmax (mins_safe = np.where(np.isfinite(mins), mins, np.nan)), so +Inf no longer leaks into max and -Inf no longer leaks into min. WARN text and aggregator behavior are now consistent. New regression test_aggregate_feature_stats_inf_masked_per_dim locks in the +/-Inf path.
2. Asymmetric NaN handling between mean and variance — ✅ resolved
Two new docstring paragraphs in aggregate_feature_stats explicitly call out (a) which contributors get dropped from the mean vs. the variance and (b) that count is unconditional ("sum of contributors' counts, not a per-dim effective contributor count"). Plus the algorithmic note # Chan-style parallel weighted variance, non-finite-masked per dim. is back inline above the variance block — that addresses the dropped-comment nit from the last review too.
3. Duplicate data interpolation in WARN — ✅ resolved
dataset_mixture.py:215-222 — trailing data interpolation removed. Message now reads Dataset 'TensorAuto/foo-overlay': non-finite values in 'actions'.'min' at flat indices [0, 1, 2] (shape=(32,)); these dims are excluded from aggregation. Clean.
LGTM to ship
The fix is materially better than the v1: the aggregator now genuinely excludes every non-finite contributor (NaN and ±Inf), the docstring explains the per-dim semantics including the mean/variance asymmetry, and the test suite covers NaN per-dim, ±Inf per-dim, and all-non-finite-dim-stays-NaN. CPU tests, pre-commit, and check-checklist were all green on the prior commit and the new diff is additive (extra masks + new test + docstring), so I'd expect CI to stay green on 689546f.
Note on thread resolution: I tried to programmatically resolve the three review threads I opened, but the available MCP tooling doesn't expose PullRequestReviewThread node IDs in the get_review_comments response (only is_resolved/is_outdated/is_collapsed/comments), and resolve_review_thread requires the GraphQL node ID. Please resolve the three threads via the PR UI — all three are addressed.
Generated by Claude Code
What this does
Fixes a NaN propagation bug in
aggregate_feature_statswhere a single dataset whose actionmin/maxstats contained NaN at any dim poisoned the global mixture buffer.np.min/np.max/ weighted-mean all propagate NaN, so the aggregated buffer's NaN dims turned every sample's normalizeddiscrete_actionsinto NaN regardless of source repo, blowing up the FAST tokenizer at the first batch (OverflowError: Python int too large to convert to C intfromchr()on garbage cast from NaN).aggregate_feature_statsnow masks per-dim NaN contributors vianp.nanmin/np.nanmaxfor min/max, plus a NaN-aware weighted mean/variance that zeroes a contributor's weight where its mean (or variance) is NaN. Clean dims aggregate exactly as before; a dim that is NaN across every contributor stays NaN so the downstream loader (orNormalizebuffer) can still surface it.DatasetMixtureMetadata._to_standard_data_formatnow logs aWARNINGwith the offendingrepo_id, key, stat name, and bad flat-dim indices, so a contaminator can be identified from training logs without a separate diagnostic patch.Label: bug.
How it was tested
tests/datasets/test_compute_stats.py:test_aggregate_feature_stats_nan_tolerant_per_dim— a contributor with NaN at one dim must not affect aggregation at clean dims for other contributors.test_aggregate_feature_stats_all_nan_dim_stays_nan— if every contributor is NaN at a dim, the aggregated result stays NaN there (downstream is responsible for surfacing it).pytest -m "not gpu" -n auto tests/datasets/test_compute_stats.py→ 17/17 pass.pytest -m "not gpu" -n auto tests/datasets/→ 378 passed, 7 skipped (gpu-only), 0 failures.test_aggregate_feature_statscovers the no-NaN happy path and continues to pass — clean inputs aggregate bit-identically to the prior implementation.[NaN-trace]diagnostic patch (already reverted) which showed 14 of 32 padded action dims had NaN innormalize.actions.{min,max}after_to_standard_data_format+aggregate_stats, while pre-normalize rawbatch['actions']was finite — proving the contamination came from the aggregator buffer, not runtime parquet data.End-to-end verification on real GPU infrastructure is left to a follow-up dispatch by the reviewer; the unit tests cover the aggregator path that was broken.
How to checkout & try? (for the reviewer)
git checkout claude/sad-sanderson-957544 pytest -m "not gpu" -n auto tests/datasets/test_compute_stats.pyTo exercise the load-time WARN, point any pi07/pi05/pi06 training config at a dataset whose action stats contain NaN — the warning lands at
DatasetMixtureMetadataconstruction time, naming the offendingrepo_id, key, stat name, and bad flat-dim indices.Checklist
Note: Before submitting this PR, please read the contributor guideline.