Skip to content

feat(scripts): per-(robot_type, control_mode) normalization in fit_fast_tokenizer#348

Merged
shuheng-liu merged 3 commits into
mainfrom
feat/fit-fast-tokenizer-per-head-norm
May 28, 2026
Merged

feat(scripts): per-(robot_type, control_mode) normalization in fit_fast_tokenizer#348
shuheng-liu merged 3 commits into
mainfrom
feat/fit-fast-tokenizer-per-head-norm

Conversation

@shuheng-liu
Copy link
Copy Markdown
Member

What this does

Adds --per-head-norm (default on) to src/opentau/scripts/fit_fast_tokenizer.py so each chunk is min-max-normalized using its own (robot_type, control_mode) head's pooled stats — the same stats Normalize({"ACTION": NormalizationMode.MIN_MAX}) applies via stacked buffers at training time after PR #347.

Before this change, the manual sampler normalized every chunk with mixture.meta.aggregated_action_stats() (a single global min/max). Once per-head normalization landed in the policy, that mismatch caused the BPE fit to see chunks centered tighter than the policy's chunks, so a fit with global normalization produced 2-3× longer token sequences at inference than the fit-time roundtrip suggested — hitting discrete_action_max_length truncation on a meaningful fraction of training samples.

The new code path computes pooled (per-head) min/max with the same nanmin/nanmax semantics DatasetMixtureMetadata._build_norm_heads uses for the min/max fields, then normalizes each dataset's chunk slice independently. Datasets sharing a (robot_type, control_mode) get identical pooled stats by construction; fallback keys stay singletons; stats-load failures fall back to [-1, 1] (same sentinel the global path already used).

Notes:

  • --no-per-head-norm keeps the legacy global path for callers who need to reproduce a pre-PR-feat(policies): per-(robot_type, control_mode) normalization heads #347 fit.
  • --use-mixture-dataloader with default --per-head-norm now raises with a clear message instead of silently using global stats; threading per-sample dataset_index through the dataloader-drained chunks is a separate change.
  • The fit report now records "normalization" (per_robot_type_control_mode or global), per_dataset_norm_keys, and n_norm_heads for traceability.

How it was tested

Unit tests (added in this PR; all CPU, runs in 3s):

pytest -sx tests/scripts/test_fit_fast_tokenizer.py
  • TestPoolPerHeadStats — 4 cases covering pooling across a shared head, distinct heads kept separate, stats-load-failure fallback to [-1, 1], and all-NaN dim filling.
  • TestNormalizeChunksPerHead — 2 cases pinning the per-slice normalization contract and confirming bit-equality with the global path when all datasets share a single head.

End-to-end verified on the production pretrain-pi07-10percent mixture (393 datasets, 32 distinct norm heads, 278k chunks): the per-head fit ran in ~95s on a single CPU node, roundtrip MSE = 0.000083, and a 25k held-out token-length distribution gave mean=53.3 / p99=138 / max=238 — vs. the global-normalized fit's mean=50.1 / p99=65 / max=101 over the same chunks. The original mismatch (token lengths blowing past discrete_action_max_length=120 at training time) confirmed the global-fit analysis was systematically under-counting.

pre-commit run --all-files clean (ruff, ruff-format, pyupgrade, bandit, etc.).

How to checkout & try? (for the reviewer)

gh pr checkout feat/fit-fast-tokenizer-per-head-norm
pytest -sx tests/scripts/test_fit_fast_tokenizer.py

# Re-fit any mixture with per-head normalization (now the default):
python -m opentau.scripts.fit_fast_tokenizer \
    --mixture-json /path/to/mixture.json \
    --out-dir /path/to/out_dir/ \
    --chunk-size 30 \
    --action-dim 32 \
    --max-state-dim 118 \
    --total-chunks 278000 \
    --num-workers 64 \
    --seed 0

# Reproduce the legacy single-head behavior:
python -m opentau.scripts.fit_fast_tokenizer ... --no-per-head-norm

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

Note: Before submitting this PR, please read the contributor guideline.

…st_tokenizer

Adds --per-head-norm (default on) to fit_fast_tokenizer.py so each chunk is min-max-normalized using its own (robot_type, control_mode) head's pooled stats — matching what Normalize+ACTION:MIN_MAX does at training time after PR #347. Legacy global path stays available via --no-per-head-norm.
@shuheng-liu shuheng-liu added the feature New feature or request label May 28, 2026
@shuheng-liu shuheng-liu self-assigned this May 28, 2026
Copy link
Copy Markdown
Member Author

@shuheng-liu shuheng-liu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review of #348 — 13 inline findings, ranked by severity.

The PR's stated invariant — "fit-time normalization matches what the policy applies at training time via Normalize({\"ACTION\": NormalizationMode.MIN_MAX}) with per-head stacked buffers (PR #347)" — is not actually established. The per-head pooling here re-implements _build_norm_heads from scratch and diverges on at least five dimensions that each silently break the equivalence:

High severity (silent normalization mismatch):

  1. NaN-padding raw stats here vs pad_vector zero-padding in _to_standard_data_format — padded action-dim slots normalize to ~0 at fit but -1 at training. Affects every mixture with max_action_dim > native_dim (essentially every production config).
  2. Stats-fail / chunks-succeed mismatch — a dataset that errors during stats loading but produces chunks via direct parquet read gets its real chunks normalized with the [-1, 1] sentinel, even when peer datasets in the same (rt, cm) head loaded fine.
  3. Fallback name uses raw cfg.repo_id; training uses _make_dataset_names's deduplicated name#N. Duplicate-repo_id entries pool at fit but stay separate at training.
  4. No RuntimeError guard when ALL stats fail — script silently produces a tokenizer fit on a fully [-1, 1]-sentinel-normalized corpus and exit 0.

Medium severity:
5. np.nanmin/nanmax propagate ±Inf; aggregate_stats masks Inf→NaN first.
6. require_non_empty_robot_type / require_non_empty_control_mode not enforced — fit completes, training refuses to start.
7. compute_norm_key's fallback_fired discarded — operator never sees the "N datasets lack rt/cm" warning that training emits.
8. Default-flip + hard-raise breaks every existing --use-mixture-dataloader invocation.
9. Dropped over-dim warning — --action-dim misconfig is no longer surfaced with offending repo IDs at stats-aggregation time.
10. _normalize_chunks_per_head relies on an implicit cross-function ordering invariant — no assert sum(per_dataset_chunks) == stacked.shape[0]; a future refactor to _sample_via_manual silently corrupts the BPE corpus.

Low severity / cleanup:
11. Test coverage gap — no test pins fit-time output against the production Normalize layer. A single equivalence test would catch findings 1, 3, 5 at once.
12. global_action_min/max silently null in the report on default per-head fits; no per-head replacement field.
13. _load_dataset_info_and_stats near-clone of _load_dataset_stats (already drifted on <no-repo-id> sentinel + per_dataset diagnostic).

Altitude: the right depth for this PR is to construct a stub DatasetMixtureMetadata from the already-loaded metadata list and read per_norm_key_stats off it — every divergence above collapses to zero new code surface. The current re-implementation guarantees future drift between fit and training each time _build_norm_heads evolves.

The end-to-end roundtrip described in the PR body (95s fit, MSE = 8.3e-5, mean=53.3 / p99=138 token length) confirms the per-head path produces a sane tokenizer overall, but the specific failure modes above are not exercised by either the new unit tests or the production smoke run (393 datasets with max_action_dim=32 and varying native dims — finding 1 would silently apply in production today and the BPE token-length distribution still looks reasonable because padded-suffix bytes compress to a single repetitive token regardless).


Generated by Claude Code

Comment thread src/opentau/scripts/fit_fast_tokenizer.py Outdated
Comment thread src/opentau/scripts/fit_fast_tokenizer.py Outdated
Comment thread src/opentau/scripts/fit_fast_tokenizer.py Outdated
Comment thread src/opentau/scripts/fit_fast_tokenizer.py Outdated
Comment thread src/opentau/scripts/fit_fast_tokenizer.py Outdated
Comment thread src/opentau/scripts/fit_fast_tokenizer.py
Comment thread tests/scripts/test_fit_fast_tokenizer.py
Comment thread src/opentau/scripts/fit_fast_tokenizer.py Outdated
Comment thread src/opentau/scripts/fit_fast_tokenizer.py Outdated
Comment thread src/opentau/scripts/fit_fast_tokenizer.py
…ask)

Apply the high/medium-severity findings from the PR review:

- Zero-pad raw stats (was NaN-pad). Matches `pad_vector` in
  `_to_standard_data_format` so padded action-dim slots normalize to -1
  at both fit and training time. Without this, the BPE corpus saw
  0-padded suffixes while the policy feeds -1-padded suffixes at
  training -- the slight DCT/BPE divergence is one likely contributor
  to observed token-length tails exceeding the fit-time analysis.

- Raise on any stats-load failure in per-head path. `_to_standard_data_format`
  also raises on missing stats at training time, so silently producing a
  tokenizer that the policy would refuse to consume is sneaky. Subsumes
  the "stats-fail + chunks-succeed" mismatch.

- Mask +/-Inf to NaN before nanmin/nanmax pool, mirroring `aggregate_stats`
  (compute_stats.py:350). A single Inf entry no longer poisons the head.

- Enforce `require_non_empty_robot_type` / `require_non_empty_control_mode`
  at fit time, surfacing the same ValueError that
  `datasets.factory._validate_metadata_requirements` raises -- so the
  operator doesn't burn 90s on a fit the very next training launch
  refuses to start.

- Restore the over-action-dim warning from the global path (lost in the
  initial per-head implementation): datasets whose native action_dim
  exceeds --action-dim are now logged with offending repo IDs and the
  max overflow, so an --action-dim misconfig is visible at stats time.

- Surface fallback datasets (datasets with empty robot_type or
  control_mode after overrides). Without this, operators got
  singleton-per-dataset heads with no warning -- exactly what `_build_norm_heads`
  warns about at training time.

- Demote `--per-head-norm --use-mixture-dataloader` from hard-raise to
  warn-and-fall-back so existing `--use-mixture-dataloader` invocations
  (which never opt into per-head) keep working without a flag change.

- Defensive assert in `_normalize_chunks_per_head`:
  `sum(per_dataset_chunks) == stacked.shape[0]`. A future
  `_sample_via_manual` refactor that changes the concat order or drops
  chunks would now fail loudly rather than silently misalign the
  per-dataset normalization windows.

- Replace null `global_action_min` / `global_action_max` in the report
  on per-head fits with a `per_head_action_stats` map keyed by norm
  key; `n_norm_heads` derives from that.

- Extract the shared metadata-and-stats worker
  (`_load_metadata_stats_and_info`) so `_load_dataset_stats` and the
  per-head path no longer drift.

- Add `TestNormalizeEquivalenceVsProduction`: pins the manual per-head
  output against the production `Normalize({"ACTION": MIN_MAX}).forward`
  with per-head buffers (PR #347) on a synthetic 2-dataset / 2-head
  batch, including a zero-padded action-dim suffix that exercises the
  -1 fallback in both paths.

Push back / deferred:

- The "use `DatasetMixtureMetadata._build_norm_heads` directly" altitude
  fix is out of scope here; it requires either constructing a partial
  `DatasetMixtureMetadata` (whose `__init__` loads parquet + episode
  indices we don't need) or extracting `_build_norm_heads` to a free
  function. Tracked separately. The surgical fixes above close every
  silent-divergence finding.

- Duplicate-`repo_id` mixtures could still pool at fit while training
  separates via `_make_dataset_names`'s `#N` suffix. Real but rare;
  visible via the fallback-datasets warning. Deferred.
@shuheng-liu
Copy link
Copy Markdown
Member Author

Response to review (commit 85c535b)

Going through each finding with what I did or why I pushed back. Net: 11 of 13 fixed in this push; 2 deferred with a reason.

High severity — all fixed

  1. NaN-pad → zero-pad raw stats. Fixed. _aggregate_stats_per_head now np.zeros((action_dim,)) per dataset before pool, matching pad_vector in _to_standard_data_format. Padded-suffix slots now pool to min=max=0 and normalize to -1 at both fit and training, byte-for-byte (modulo the float64→float32 cast rounding, ~1 ULP). The new TestNormalizeEquivalenceVsProduction pins this against the production Normalize layer directly.

  2. Stats-fail / chunks-succeed mismatch. Fixed by subsumption into clean up #4_aggregate_stats_per_head now raises RuntimeError on any stats-load failure, with the offending repo_id: err pairs in the message. Training crashes on the same dataset via _to_standard_data_format's KeyError, so silently producing a fit the policy will refuse is the wrong behavior. The [-1, 1] sentinel path for failed datasets is gone; _pool_per_head_stats now asserts every row is populated.

  3. Fallback-name dedup (repo_id vs name#N). Deferred. Real but narrow: it only diverges when a mixture has both duplicate repo_id and missing (robot_type, control_mode) on those rows. The new fallback-datasets warning surfaces the second condition loudly; an operator hitting both at once will see the warning and notice. Replicating WeightedDatasetMixture._make_dataset_names's dedup at fit time is the right fix but pulls in TrainPipelineConfig/BaseDataset plumbing that doesn't fit this PR's scope. Tracked as a follow-up.

  4. No RuntimeError on all-stats-fail. Fixed via the same change as fix: record rollouts with multiple ranks #2.

Medium severity — all fixed

  1. Mask ±InfNaN before pool. Fixed. _pool_per_head_stats now does np.where(np.isfinite(...), ..., np.nan) on the stacked min/max before nanmin/nanmax, mirroring compute_stats.py:350. New test_inf_masked_before_pool + test_all_inf_dim_falls_back_to_minus_one_one pin both cases.

  2. require_non_empty_robot_type / _control_mode not enforced. Fixed. _aggregate_stats_per_head reads mixture_cfg.require_non_empty_* and raises the same ValueError text as datasets.factory._validate_metadata_requirements if any row is empty after overrides. Operator no longer wastes a fit on a mixture training will refuse.

  3. fallback_fired warning discarded. Fixed. Tracked per-dataset, then logged in the same shape _build_norm_heads uses (capped at 10 names, with , ... and N more suffix).

  4. Default-flip + hard-raise. Fixed. --use-mixture-dataloader --per-head-norm now warns and falls back to global, so existing --use-mixture-dataloader invocations keep working without a flag change. Help text + module docstring updated.

  5. Over---action-dim warning dropped. Fixed. Mirrors the old warning from _aggregate_stats_manual — logs N/total datasets have native action_dim > --action-dim=K (max M) with the first 5 offending repo IDs.

  6. No assert sum(per_dataset_chunks) == stacked.shape[0]. Fixed in _normalize_chunks_per_head with an explicit error message; new test_chunk_count_mismatch_raises_assert pins it.

Low severity — cleanup applied

  1. No equivalence test against production Normalize. Fixed. TestNormalizeEquivalenceVsProduction::test_per_head_normalize_matches_production_normalize builds a 2-dataset / 2-head batch (with zero-padded trailing slots) and asserts the manual output matches Normalize({\"ACTION\": MIN_MAX}).forward(...) to atol=2e-7. Caught the original float64-vs-float32 ULP drift; the comment explains why bit-equality isn't claimed.

  2. global_action_min/max silently null on per-head fits. Fixed. Report now has per_head_action_stats: {key: {min: [...], max: [...]}} and n_norm_heads derived from that. Global fields remain populated only on global fits.

  3. _load_dataset_info_and_stats near-clone of _load_dataset_stats. Fixed. Extracted _load_metadata_stats_and_info as the shared worker; _load_dataset_stats now thin-wraps it.

Altitude critique — deferred with reason

The "construct a stub DatasetMixtureMetadata and read per_norm_key_stats off it" path is the right long-term answer and would eliminate every divergence finding above by construction. The blocker is that DatasetMixtureMetadata.__init__ does much more than build heads (loads parquet, builds episode indices, computes sample weights), and _build_norm_heads is currently a method that reads self.dataset_names / self.per_dataset_stats / self.cfg.{max_action_dim, num_cams}. Doing it right means extracting _build_norm_heads (and the _to_standard_data_format step it depends on) to a free function — a cross-file refactor that touches dataset_mixture.py directly. Out of scope for this PR; tracked as a follow-up. The surgical fixes above close every concrete silent-divergence the review identified.

Impact on the currently-running fit

The published TensorAuto/fast-mixture-per-head was built with the pre-fix code, so the NaN-pad vs zero-pad mismatch (finding #1) applied. The empirical impact is small — the BPE token-length tail at training (max 288 observed) exceeded the fit-time prediction (max 238) by ~20%, with truncation rate ~3 / 152K samples = 0.002% on the current run. Not blocking the in-flight job, but a re-fit on the corrected code would tighten the prediction.

Copy link
Copy Markdown
Member Author

@shuheng-liu shuheng-liu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-review of 85c535b — 11/13 of the previous round's findings landed cleanly; 2 reasonably deferred. Two small follow-up notes inline.

Verified fixes (code-read):

# Finding Status
1 NaN-pad vs zero-pad (HIGH) ✅ Fix at line 542-543 (np.zeros not np.full(NaN)); pinned by new TestNormalizeEquivalenceVsProduction::test_per_head_normalize_matches_production_normalize which asserts elementwise equality vs the production Normalize layer at atol=2e-7.
2 Stats-fail / chunks-succeed (HIGH) ✅ Subsumed by #4 — no more [-1, 1] sentinel path for real chunks.
3 Fallback-name dedup (HIGH) ⏭️ Deferred with reasoning. See inline follow-up for a small mitigation at the warning site.
4 No all-fail guard (HIGH) ✅ Fail-fast at line 552 on ANY stats failure (stricter than the global path's "all fail" guard).
5 Inf masking (MED) np.where(np.isfinite(...), ..., np.nan) pre-mask at lines 695-696; pinned by two new tests.
6 require_non_empty_* (MED) ✅ Mirrored at lines 563-580.
7 fallback_fired warning (MED) ✅ Aggregated warning at lines 598-610 matches _build_norm_heads's shape.
8 --use-mixture-dataloader hard-raise (MED) ✅ Downgraded to warn-and-fall-back at lines 1394-1402.
9 Over-dim warning (MED) ✅ Restored at lines 618-635.
10 Implicit ordering invariant (MED) ✅ Two asserts at lines 744-752; pinned by test_chunk_count_mismatch_raises_assert.
11 Test gap vs production Normalize (LOW) ✅ Closed by the new equivalence test — would have caught findings 1, 3, and 5 by construction.
12 global_action_* null in report (LOW) ✅ New per_head_action_stats map carries {key: {"min": [...], "max": [...]}}.
13 _load_dataset_info_and_stats clone (LOW) ✅ Collapsed into shared _load_metadata_stats_and_info; _load_dataset_stats is now a 3-line wrapper.
Alt Construct stub DatasetMixtureMetadata ⏭️ Reasonably deferred — the right shape requires extracting _build_norm_heads to a free function (cross-file refactor). The surgical fixes here close every concrete divergence.

The new equivalence test is particularly load-bearing: it pins the invariant the PR was created to establish, and asserts the padded suffix lands at -1 in both paths (exactly the silent corruption finding #1 described). Solid.

New observations (low-priority):

  • One inline on [{}] * n aliasing footgun at line 515 — currently safe, but a future mutation would silently fan out.
  • One inline suggesting a tighter fit-time signal for the deferred #3 case (duplicate-repo_id pooled under fallback).

No new correctness issues found. Ready to mark non-draft once the deferred follow-ups are tracked elsewhere.


Generated by Claude Code

Comment thread src/opentau/scripts/fit_fast_tokenizer.py
Comment thread src/opentau/scripts/fit_fast_tokenizer.py
…ck warn)

Two low-priority items from the re-review of 85c535b:

- Replace `[{}] * n` with `[{} for _ in range(n)]` for `per_ds_info`. The
  multiplied-literal form creates N references to the same dict, so a future
  contributor doing `per_ds_info[idx]["override_X"] = ...` (a legitimate-
  looking accumulator pattern) would silently fan out across every slot.
  The `[None] * n` and `["<no-repo-id>"] * n` callsites stay as-is because
  their element types are immutable; comment notes the distinction so the
  pattern doesn't regress.

- Add a tighter fit-time warning for the specific divergence the deferred
  fallback-name finding (#3 in the original review) flagged. When a
  fallback-keyed `repo_id` appears more than once in the mixture, fit-time
  POOLS those rows under one shared key while training keeps them as
  separate singleton heads (via `_make_dataset_names`'s `#N` dedup). The
  new Counter-based warning surfaces the exact duplicate `repo_id -> count`
  map so an operator can fix it before the misalignment ships. Mechanical;
  no new infra. Closes the divergence-detection gap while the proper dedup
  follow-up (tracked separately) is pending.
@shuheng-liu
Copy link
Copy Markdown
Member Author

Response to re-review (commit c755cb8)

Re-review verified all 11 fixes from the prior round and flagged 2 new low-priority items. Both addressed; deferred items moved to dedicated issues.

New inline findings (re-review) — both fixed in c755cb8

  • [{}] * n aliasing footgun (line 515). Switched per_ds_info to a list-comprehension form so a future per_ds_info[idx]["override_X"] = ... doesn't fan out across slots. Added a comment noting why [None] * n / ["<no-repo-id>"] * n stay as multiplied-literals (immutable element types). Thread resolved.

  • Tighter signal for duplicate-repo_id in fallback (line 610). Added a Counter-based block inside the existing fallback-handling branch that detects when a fallback-keyed repo_id appears more than once and warns with the exact {repo_id -> count} map. Mechanical, no new infra. Closes the divergence-detection gap until the full dedup follow-up lands. Thread resolved.

Deferred items — now tracked

  • Deferred finding Fixing reward normalizer #3 (fallback-name dedup at fit time): #349.
  • Deferred altitude refactor (extract _build_norm_heads to a free function so fit_fast_tokenizer can call it directly): #350.

Both inline threads updated with issue links and resolved.

State

All 16 inline review threads on this PR are now resolved (14 from the original review + 2 from the re-review). 11 closed by direct fixes in commit 85c535b, 2 closed by direct fixes in commit c755cb8, 2 closed with cross-references to the new tracking issues.

Copy link
Copy Markdown
Member Author

@shuheng-liu shuheng-liu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-review of c755cb8 — both low-priority follow-ups from the last round are addressed.

Verified:

  1. [{}] * n[{} for _ in range(n)] (line 515) with a tight comment explaining why the adjacent [None] * n / ["<no-repo-id>"] * n lines stay safe. ✅
  2. Duplicate-fallback Counter warning (lines 612-638). The implementation matches the suggested shape: counts only fallback-fired entries, filters to count > 1, surfaces the divergence at fit time with a (repo_id → count) preview capped at 10. Correctly fires on the exact case the deferred dedup finding flagged (duplicate repo_id both lacking rt/cm → pooled at fit, singleton-per-entry at training) and doesn't fire on legitimate same-repo_id entries with distinct (rt, cm) overrides (which pool consistently at both fit and training). ✅

No new issues introduced. With the duplicate-detection warning in place, the fit-time signal for the deferred #3 case is now strong enough that the proper dedup refactor can land as a follow-up without operators silently shipping mismatched tokenizers in the meantime.

Nothing else to flag. The PR is in good shape to mark non-draft.


Generated by Claude Code

@shuheng-liu shuheng-liu marked this pull request as ready for review May 28, 2026 18:59
@shuheng-liu shuheng-liu merged commit 765a932 into main May 28, 2026
19 checks passed
@shuheng-liu shuheng-liu deleted the feat/fit-fast-tokenizer-per-head-norm branch May 28, 2026 19:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant