feat(scripts): per-(robot_type, control_mode) normalization in fit_fast_tokenizer#348
Conversation
…st_tokenizer Adds --per-head-norm (default on) to fit_fast_tokenizer.py so each chunk is min-max-normalized using its own (robot_type, control_mode) head's pooled stats — matching what Normalize+ACTION:MIN_MAX does at training time after PR #347. Legacy global path stays available via --no-per-head-norm.
shuheng-liu
left a comment
There was a problem hiding this comment.
Review of #348 — 13 inline findings, ranked by severity.
The PR's stated invariant — "fit-time normalization matches what the policy applies at training time via Normalize({\"ACTION\": NormalizationMode.MIN_MAX}) with per-head stacked buffers (PR #347)" — is not actually established. The per-head pooling here re-implements _build_norm_heads from scratch and diverges on at least five dimensions that each silently break the equivalence:
High severity (silent normalization mismatch):
- NaN-padding raw stats here vs
pad_vectorzero-padding in_to_standard_data_format— padded action-dim slots normalize to~0at fit but-1at training. Affects every mixture withmax_action_dim > native_dim(essentially every production config). - Stats-fail / chunks-succeed mismatch — a dataset that errors during stats loading but produces chunks via direct parquet read gets its real chunks normalized with the
[-1, 1]sentinel, even when peer datasets in the same(rt, cm)head loaded fine. - Fallback name uses raw
cfg.repo_id; training uses_make_dataset_names's deduplicatedname#N. Duplicate-repo_identries pool at fit but stay separate at training. - No
RuntimeErrorguard when ALL stats fail — script silently produces a tokenizer fit on a fully[-1, 1]-sentinel-normalized corpus andexit 0.
Medium severity:
5. np.nanmin/nanmax propagate ±Inf; aggregate_stats masks Inf→NaN first.
6. require_non_empty_robot_type / require_non_empty_control_mode not enforced — fit completes, training refuses to start.
7. compute_norm_key's fallback_fired discarded — operator never sees the "N datasets lack rt/cm" warning that training emits.
8. Default-flip + hard-raise breaks every existing --use-mixture-dataloader invocation.
9. Dropped over-dim warning — --action-dim misconfig is no longer surfaced with offending repo IDs at stats-aggregation time.
10. _normalize_chunks_per_head relies on an implicit cross-function ordering invariant — no assert sum(per_dataset_chunks) == stacked.shape[0]; a future refactor to _sample_via_manual silently corrupts the BPE corpus.
Low severity / cleanup:
11. Test coverage gap — no test pins fit-time output against the production Normalize layer. A single equivalence test would catch findings 1, 3, 5 at once.
12. global_action_min/max silently null in the report on default per-head fits; no per-head replacement field.
13. _load_dataset_info_and_stats near-clone of _load_dataset_stats (already drifted on <no-repo-id> sentinel + per_dataset diagnostic).
Altitude: the right depth for this PR is to construct a stub DatasetMixtureMetadata from the already-loaded metadata list and read per_norm_key_stats off it — every divergence above collapses to zero new code surface. The current re-implementation guarantees future drift between fit and training each time _build_norm_heads evolves.
The end-to-end roundtrip described in the PR body (95s fit, MSE = 8.3e-5, mean=53.3 / p99=138 token length) confirms the per-head path produces a sane tokenizer overall, but the specific failure modes above are not exercised by either the new unit tests or the production smoke run (393 datasets with max_action_dim=32 and varying native dims — finding 1 would silently apply in production today and the BPE token-length distribution still looks reasonable because padded-suffix bytes compress to a single repetitive token regardless).
Generated by Claude Code
…ask)
Apply the high/medium-severity findings from the PR review:
- Zero-pad raw stats (was NaN-pad). Matches `pad_vector` in
`_to_standard_data_format` so padded action-dim slots normalize to -1
at both fit and training time. Without this, the BPE corpus saw
0-padded suffixes while the policy feeds -1-padded suffixes at
training -- the slight DCT/BPE divergence is one likely contributor
to observed token-length tails exceeding the fit-time analysis.
- Raise on any stats-load failure in per-head path. `_to_standard_data_format`
also raises on missing stats at training time, so silently producing a
tokenizer that the policy would refuse to consume is sneaky. Subsumes
the "stats-fail + chunks-succeed" mismatch.
- Mask +/-Inf to NaN before nanmin/nanmax pool, mirroring `aggregate_stats`
(compute_stats.py:350). A single Inf entry no longer poisons the head.
- Enforce `require_non_empty_robot_type` / `require_non_empty_control_mode`
at fit time, surfacing the same ValueError that
`datasets.factory._validate_metadata_requirements` raises -- so the
operator doesn't burn 90s on a fit the very next training launch
refuses to start.
- Restore the over-action-dim warning from the global path (lost in the
initial per-head implementation): datasets whose native action_dim
exceeds --action-dim are now logged with offending repo IDs and the
max overflow, so an --action-dim misconfig is visible at stats time.
- Surface fallback datasets (datasets with empty robot_type or
control_mode after overrides). Without this, operators got
singleton-per-dataset heads with no warning -- exactly what `_build_norm_heads`
warns about at training time.
- Demote `--per-head-norm --use-mixture-dataloader` from hard-raise to
warn-and-fall-back so existing `--use-mixture-dataloader` invocations
(which never opt into per-head) keep working without a flag change.
- Defensive assert in `_normalize_chunks_per_head`:
`sum(per_dataset_chunks) == stacked.shape[0]`. A future
`_sample_via_manual` refactor that changes the concat order or drops
chunks would now fail loudly rather than silently misalign the
per-dataset normalization windows.
- Replace null `global_action_min` / `global_action_max` in the report
on per-head fits with a `per_head_action_stats` map keyed by norm
key; `n_norm_heads` derives from that.
- Extract the shared metadata-and-stats worker
(`_load_metadata_stats_and_info`) so `_load_dataset_stats` and the
per-head path no longer drift.
- Add `TestNormalizeEquivalenceVsProduction`: pins the manual per-head
output against the production `Normalize({"ACTION": MIN_MAX}).forward`
with per-head buffers (PR #347) on a synthetic 2-dataset / 2-head
batch, including a zero-padded action-dim suffix that exercises the
-1 fallback in both paths.
Push back / deferred:
- The "use `DatasetMixtureMetadata._build_norm_heads` directly" altitude
fix is out of scope here; it requires either constructing a partial
`DatasetMixtureMetadata` (whose `__init__` loads parquet + episode
indices we don't need) or extracting `_build_norm_heads` to a free
function. Tracked separately. The surgical fixes above close every
silent-divergence finding.
- Duplicate-`repo_id` mixtures could still pool at fit while training
separates via `_make_dataset_names`'s `#N` suffix. Real but rare;
visible via the fallback-datasets warning. Deferred.
Response to review (commit 85c535b)Going through each finding with what I did or why I pushed back. Net: 11 of 13 fixed in this push; 2 deferred with a reason. High severity — all fixed
Medium severity — all fixed
Low severity — cleanup applied
Altitude critique — deferred with reasonThe "construct a stub Impact on the currently-running fitThe published |
shuheng-liu
left a comment
There was a problem hiding this comment.
Re-review of 85c535b — 11/13 of the previous round's findings landed cleanly; 2 reasonably deferred. Two small follow-up notes inline.
Verified fixes (code-read):
| # | Finding | Status |
|---|---|---|
| 1 | NaN-pad vs zero-pad (HIGH) | ✅ Fix at line 542-543 (np.zeros not np.full(NaN)); pinned by new TestNormalizeEquivalenceVsProduction::test_per_head_normalize_matches_production_normalize which asserts elementwise equality vs the production Normalize layer at atol=2e-7. |
| 2 | Stats-fail / chunks-succeed (HIGH) | ✅ Subsumed by #4 — no more [-1, 1] sentinel path for real chunks. |
| 3 | Fallback-name dedup (HIGH) | ⏭️ Deferred with reasoning. See inline follow-up for a small mitigation at the warning site. |
| 4 | No all-fail guard (HIGH) | ✅ Fail-fast at line 552 on ANY stats failure (stricter than the global path's "all fail" guard). |
| 5 | Inf masking (MED) | ✅ np.where(np.isfinite(...), ..., np.nan) pre-mask at lines 695-696; pinned by two new tests. |
| 6 | require_non_empty_* (MED) |
✅ Mirrored at lines 563-580. |
| 7 | fallback_fired warning (MED) |
✅ Aggregated warning at lines 598-610 matches _build_norm_heads's shape. |
| 8 | --use-mixture-dataloader hard-raise (MED) |
✅ Downgraded to warn-and-fall-back at lines 1394-1402. |
| 9 | Over-dim warning (MED) | ✅ Restored at lines 618-635. |
| 10 | Implicit ordering invariant (MED) | ✅ Two asserts at lines 744-752; pinned by test_chunk_count_mismatch_raises_assert. |
| 11 | Test gap vs production Normalize (LOW) |
✅ Closed by the new equivalence test — would have caught findings 1, 3, and 5 by construction. |
| 12 | global_action_* null in report (LOW) |
✅ New per_head_action_stats map carries {key: {"min": [...], "max": [...]}}. |
| 13 | _load_dataset_info_and_stats clone (LOW) |
✅ Collapsed into shared _load_metadata_stats_and_info; _load_dataset_stats is now a 3-line wrapper. |
| Alt | Construct stub DatasetMixtureMetadata |
⏭️ Reasonably deferred — the right shape requires extracting _build_norm_heads to a free function (cross-file refactor). The surgical fixes here close every concrete divergence. |
The new equivalence test is particularly load-bearing: it pins the invariant the PR was created to establish, and asserts the padded suffix lands at -1 in both paths (exactly the silent corruption finding #1 described). Solid.
New observations (low-priority):
- One inline on
[{}] * naliasing footgun at line 515 — currently safe, but a future mutation would silently fan out. - One inline suggesting a tighter fit-time signal for the deferred #3 case (duplicate-
repo_idpooled under fallback).
No new correctness issues found. Ready to mark non-draft once the deferred follow-ups are tracked elsewhere.
Generated by Claude Code
…ck warn) Two low-priority items from the re-review of 85c535b: - Replace `[{}] * n` with `[{} for _ in range(n)]` for `per_ds_info`. The multiplied-literal form creates N references to the same dict, so a future contributor doing `per_ds_info[idx]["override_X"] = ...` (a legitimate- looking accumulator pattern) would silently fan out across every slot. The `[None] * n` and `["<no-repo-id>"] * n` callsites stay as-is because their element types are immutable; comment notes the distinction so the pattern doesn't regress. - Add a tighter fit-time warning for the specific divergence the deferred fallback-name finding (#3 in the original review) flagged. When a fallback-keyed `repo_id` appears more than once in the mixture, fit-time POOLS those rows under one shared key while training keeps them as separate singleton heads (via `_make_dataset_names`'s `#N` dedup). The new Counter-based warning surfaces the exact duplicate `repo_id -> count` map so an operator can fix it before the misalignment ships. Mechanical; no new infra. Closes the divergence-detection gap while the proper dedup follow-up (tracked separately) is pending.
Response to re-review (commit c755cb8)Re-review verified all 11 fixes from the prior round and flagged 2 new low-priority items. Both addressed; deferred items moved to dedicated issues. New inline findings (re-review) — both fixed in c755cb8
Deferred items — now tracked
Both inline threads updated with issue links and resolved. StateAll 16 inline review threads on this PR are now resolved (14 from the original review + 2 from the re-review). 11 closed by direct fixes in commit 85c535b, 2 closed by direct fixes in commit c755cb8, 2 closed with cross-references to the new tracking issues. |
shuheng-liu
left a comment
There was a problem hiding this comment.
Re-review of c755cb8 — both low-priority follow-ups from the last round are addressed.
Verified:
[{}] * n→[{} for _ in range(n)](line 515) with a tight comment explaining why the adjacent[None] * n/["<no-repo-id>"] * nlines stay safe. ✅- Duplicate-fallback
Counterwarning (lines 612-638). The implementation matches the suggested shape: counts only fallback-fired entries, filters tocount > 1, surfaces the divergence at fit time with a(repo_id → count)preview capped at 10. Correctly fires on the exact case the deferred dedup finding flagged (duplicaterepo_idboth lackingrt/cm→ pooled at fit, singleton-per-entry at training) and doesn't fire on legitimate same-repo_identries with distinct(rt, cm)overrides (which pool consistently at both fit and training). ✅
No new issues introduced. With the duplicate-detection warning in place, the fit-time signal for the deferred #3 case is now strong enough that the proper dedup refactor can land as a follow-up without operators silently shipping mismatched tokenizers in the meantime.
Nothing else to flag. The PR is in good shape to mark non-draft.
Generated by Claude Code
What this does
Adds
--per-head-norm(default on) tosrc/opentau/scripts/fit_fast_tokenizer.pyso each chunk is min-max-normalized using its own(robot_type, control_mode)head's pooled stats — the same statsNormalize({"ACTION": NormalizationMode.MIN_MAX})applies via stacked buffers at training time after PR #347.Before this change, the manual sampler normalized every chunk with
mixture.meta.aggregated_action_stats()(a single global min/max). Once per-head normalization landed in the policy, that mismatch caused the BPE fit to see chunks centered tighter than the policy's chunks, so a fit with global normalization produced 2-3× longer token sequences at inference than the fit-time roundtrip suggested — hittingdiscrete_action_max_lengthtruncation on a meaningful fraction of training samples.The new code path computes pooled (per-head) min/max with the same
nanmin/nanmaxsemanticsDatasetMixtureMetadata._build_norm_headsuses for themin/maxfields, then normalizes each dataset's chunk slice independently. Datasets sharing a(robot_type, control_mode)get identical pooled stats by construction; fallback keys stay singletons; stats-load failures fall back to[-1, 1](same sentinel the global path already used).Notes:
--no-per-head-normkeeps the legacy global path for callers who need to reproduce a pre-PR-feat(policies): per-(robot_type, control_mode) normalization heads #347 fit.--use-mixture-dataloaderwith default--per-head-normnow raises with a clear message instead of silently using global stats; threading per-sampledataset_indexthrough the dataloader-drained chunks is a separate change."normalization"(per_robot_type_control_modeorglobal),per_dataset_norm_keys, andn_norm_headsfor traceability.How it was tested
Unit tests (added in this PR; all CPU, runs in 3s):
TestPoolPerHeadStats— 4 cases covering pooling across a shared head, distinct heads kept separate, stats-load-failure fallback to[-1, 1], and all-NaN dim filling.TestNormalizeChunksPerHead— 2 cases pinning the per-slice normalization contract and confirming bit-equality with the global path when all datasets share a single head.End-to-end verified on the production pretrain-pi07-10percent mixture (393 datasets, 32 distinct norm heads, 278k chunks): the per-head fit ran in ~95s on a single CPU node, roundtrip MSE = 0.000083, and a 25k held-out token-length distribution gave mean=53.3 / p99=138 / max=238 — vs. the global-normalized fit's mean=50.1 / p99=65 / max=101 over the same chunks. The original mismatch (token lengths blowing past
discrete_action_max_length=120at training time) confirmed the global-fit analysis was systematically under-counting.pre-commit run --all-filesclean (ruff, ruff-format, pyupgrade, bandit, etc.).How to checkout & try? (for the reviewer)
Checklist
Note: Before submitting this PR, please read the contributor guideline.