
feat(policies): per-dataset normalization + skip-stats safetensors flag #336

Merged
shuheng-liu merged 6 commits into main from claude/serene-euclid-ce1787 on May 27, 2026

@shuheng-liu
Member

@shuheng-liu shuheng-liu commented May 27, 2026

What this does

Tears down the mixture-level stats aggregation in DatasetMixtureMetadata
and routes per-dataset stats into stacked Normalize/Unnormalize buffers
indexed per sample. Adds a configurable flag to omit those buffers from
checkpoint safetensors.

Per-dataset normalization. DatasetMixtureMetadata no longer calls
aggregate_stats(...) across datasets. Instead it exposes
per_dataset_stats: list[dict] and dataset_names: list[str]. A new
_TaggedDataset wrapper around every underlying BaseDataset injects
dataset_repo_id: str and dataset_index: torch.long into each
sample (default collate batches them into list[str] and (B,) long
respectively). Normalize/Unnormalize buffers are now shaped
(num_datasets, *feat_shape); forward(batch, dataset_index) gathers
the right row per sample via index_select and broadcasts over any
extra temporal/spatial axes. Each of the eight in-scope pi policies
(pi0, pi05, pi05_mem, pi06, pi07/low_level,
pi07/high_level_planner, pi07_paligemma/low_level,
pi07_paligemma/high_level_planner) threads
dataset_index = self._resolve_dataset_index(batch) through every
forward/select_action/sample_actions. The value policy is out
of scope; it keeps its single-dict dataset_stats external API but
wraps internally into a singleton list before calling the new
Normalize.
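
A minimal sketch of the tagging and per-row lookup described above, written in plain Python so it runs without torch. `TaggedDataset` and `normalize_sample` are hypothetical stand-ins, not the PR's `_TaggedDataset` or `Normalize.forward`; the real code carries `dataset_index` as a `torch.long` tensor and gathers rows with `index_select` on stacked buffers.

```python
from typing import Any, Dict, List, Sequence


class TaggedDataset:
    """Hypothetical stand-in for the PR's `_TaggedDataset` wrapper."""

    def __init__(self, base: Sequence[Dict[str, Any]], repo_id: str, index: int):
        self.base = base
        self.repo_id = repo_id
        self.index = index

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, i: int) -> Dict[str, Any]:
        # Copy so the injected keys never leak into the underlying sample.
        sample = dict(self.base[i])
        sample["dataset_repo_id"] = self.repo_id
        sample["dataset_index"] = self.index
        return sample


def normalize_sample(
    sample: Dict[str, Any], mean: List[List[float]], std: List[List[float]]
) -> List[float]:
    """MEAN_STD normalization using the stats row selected by `dataset_index`."""
    row = sample["dataset_index"]
    return [(x - m) / s for x, m, s in zip(sample["state"], mean[row], std[row])]


# Two datasets, so the stacked stats have shape (num_datasets=2, feat=2).
mean = [[0.0, 0.0], [10.0, 10.0]]
std = [[1.0, 1.0], [2.0, 2.0]]
ds_b = TaggedDataset([{"state": [12.0, 8.0]}], repo_id="org/b", index=1)
normalized = normalize_sample(ds_b[0], mean, std)
```

Here the sample from the second dataset is normalized with row 1 of the stacked stats, not with a mixture-level aggregate.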

Inference-time callers pass batch["dataset_repo_id"] (str or
list[str]); the base PreTrainedPolicy._resolve_dataset_index helper
maps strings through config.dataset_names to an integer index.
Single-dataset configs (num_datasets <= 1) default to all-zero
indices when both batch keys are absent, so legacy single-dataset
callers don't need to know about the per-sample plumbing. Multi-dataset
policies still raise a clear KeyError when the caller forgets, so
silent misuse stays loud.
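
The resolution rules above can be sketched roughly as follows. This is an illustration in plain Python under stated assumptions: `resolve_dataset_index` is a hypothetical free function returning a list of ints, whereas the PR's `_resolve_dataset_index` is a method that returns a long tensor on the batch's device.

```python
from typing import Any, Dict, List, Optional, Sequence


def resolve_dataset_index(
    batch: Dict[str, Any],
    dataset_names: Optional[Sequence[str]],
    batch_size: int,
) -> List[int]:
    """Return one dataset index per sample (hypothetical helper)."""
    # 1. An explicit index in the batch wins.
    if "dataset_index" in batch:
        return [int(i) for i in batch["dataset_index"]]

    # 2. Otherwise map repo ids through the configured names.
    if "dataset_repo_id" in batch:
        repo_ids = batch["dataset_repo_id"]
        if isinstance(repo_ids, str):
            repo_ids = [repo_ids] * batch_size
        name_to_index = {name: i for i, name in enumerate(dataset_names or [])}
        return [name_to_index[r] for r in repo_ids]

    # 3. Single-dataset configs fall back to all zeros.
    num_datasets = len(dataset_names) if dataset_names else 1
    if num_datasets <= 1:
        return [0] * batch_size

    # 4. Multi-dataset configs must be told which dataset each sample is from.
    raise KeyError(
        "Per-dataset normalization requires `dataset_index` or "
        "`dataset_repo_id` in the batch."
    )
```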

Skip-stats safetensors flag. Two new fields on PreTrainedConfig:
save_normalization_stats: bool = True and
dataset_names: list[str] | None = None. _save_pretrained accepts an
include_norm_stats override that wins over the config field;
accelerator.save_state is gated via a register_save_state_pre_hook
in train.py so the buffers are stripped from the on-disk safetensors
in both code paths. Reloading a stats-less checkpoint requires
make_policy(..., ds_meta=...) so _inject_stats(...) can repopulate;
otherwise _check_norm_stats_loaded raises a clear error pointing the
user at the right knob.
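
As a rough illustration of the flag and its override, the sketch below filters a plain dict of keys. The key layout (`<module>.buffer_<feature>.<stat>`) and the four module names are assumptions taken from the review comments in this thread; the real `is_norm_buffer_key` may match differently, and `select_state_for_save` is a hypothetical name.

```python
from typing import Any, Dict, Optional

# Assumed names of the normalization submodules (not confirmed by the diff).
NORM_MODULE_NAMES = (
    "normalize_inputs",
    "normalize_targets",
    "normalize_discrete_actions",
    "unnormalize_outputs",
)


def is_norm_buffer_key(key: str) -> bool:
    """True for keys such as `normalize_inputs.buffer_state.mean`."""
    parts = key.split(".")
    return (
        len(parts) >= 2
        and parts[0] in NORM_MODULE_NAMES
        and parts[1].startswith("buffer_")
    )


def select_state_for_save(
    state_dict: Dict[str, Any],
    save_normalization_stats: bool,
    include_norm_stats: Optional[bool] = None,
) -> Dict[str, Any]:
    """Apply the config flag, letting the method-level override win."""
    include = save_normalization_stats if include_norm_stats is None else include_norm_stats
    if include:
        return dict(state_dict)
    return {k: v for k, v in state_dict.items() if not is_norm_buffer_key(k)}
```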

Other touches. fit_fast_tokenizer.py switches from the deleted
mixture.meta.stats["actions"] to a new
DatasetMixtureMetadata.aggregated_action_stats() helper that wraps
aggregate_stats(...) on demand (BPE codec still needs one global
range). export_to_onnx.py passes an explicit dataset_index=0 to
unnormalize_outputs since the ONNX tracer bypasses
_resolve_dataset_index.

How it was tested

  • Added tests/policies/test_normalize_per_dataset.py (6 tests): D=3-4
    stacked buffers, per-row indexing on MEAN_STD and MIN_MAX,
    Unnormalize-inverts-Normalize round-trip, temporal-axis broadcasting,
    image (C, 1, 1) broadcasting, length-mismatch validation.
  • Added tests/policies/test_save_pretrained_skip_stats.py (6 tests):
    save with stats keeps normalize_*.buffer_* keys, save without
    strips them, method-level include_norm_stats override wins over the
    config field, safetensors.safe_open round-trip, is_norm_buffer_key
    predicate.
  • Added tests/datasets/test_tagged_dataset.py (5 tests):
    _TaggedDataset.__getitem__ injects the two keys, preserves
    .meta, default collate produces list[str] and (B,) long tensor.
  • Updated tests/datasets/test_dataset_mixture.py to assert on
    per_dataset_stats / dataset_names instead of the deleted stats,
    to allow mixture.datasets to be the wrapped list, and to strip the
    wrapper-injected keys before the
    test_integration_basic_functionality_with_same_fps_as_dataset
    per-sample equality check.
  • Updated tests/policies/test_policies.py::test_normalize to thread
    dataset_index and wrap singleton stats.
  • Updated every GPU policy test (test_pi05.py, test_pi05_mem_gpu.py,
    test_pi06.py, test_pi07_high_level_planner.py,
    test_pi07_low_level.py, test_pi07_paligemma_high_level_planner.py,
    test_pi07_paligemma_low_level.py) to use the new
    per_dataset_stats=[...] constructor signature.
  • Full non-slow CPU suite: 1188 passed, 14 skipped, 1 pre-existing
    unrelated robosuite import error in tests/utils/test_libero_utils.py.
  • GPU pytest subset (pytest -m "gpu" -n 0): 19 passed, 10 skipped
    (unrelated), 1359 deselected. Run on an internal GPU dev box (RTX
    5090, CUDA 13.0, ~5 min). Two iterations were needed to land — the
    first uncovered stale dataset_stats=... kwargs in seven test
    fixtures and one Normalize submodule call that bypassed the
    policy-level _resolve_dataset_index helper.
  • pre-commit clean (ruff lint + format, license header, gitleaks, bandit).

Not yet covered (tracked in #335)

The following still need attention beyond a single-GPU pytest pass:

  • Smoke training run on configs/dev/dev_config.json or the
    pi05 smoke config to confirm forward doesn't trip the new
    inf-assertion at step 1 and the per-dataset val dataloader iterates
    clean.
  • End-to-end save_normalization_stats=false round-trip on a real
    checkpoint dir, including the safe_open(...).keys() filter check
    and make_policy(..., ds_meta=...) reload.
  • Seeded determinism check (per CLAUDE.md rule #3): two runs with
    seed=0, diff the per-step loss series, assert bit-identical.
  • Distributed sanity under DDP / DeepSpeed ZeRO-2 — the new
    index_select is a local op with no new collectives, so the risk
    is low, but worth confirming on real backends.
  • Nightly regression suite (regression_test.yml on g6.12xlarge).

These checks are tracked in #335.

How to checkout & try? (for the reviewer)

gh pr checkout 336
uv sync --extra dev --extra libero
pre-commit run --all-files
pytest -m "not gpu" -n auto tests/policies/test_normalize_per_dataset.py tests/policies/test_save_pretrained_skip_stats.py tests/datasets/test_tagged_dataset.py tests/datasets/test_dataset_mixture.py tests/policies/test_policies.py

On a CUDA box (matches the run already done on the dev box):

pytest -m "gpu" -n 0

To exercise the safetensors round-trip on a real training run (still
tracked in #335):

opentau-train \
    --accelerate-config configs/examples/accelerate_ddp_config.yaml \
    --config_path=configs/dev/dev_config.json \
    --steps=2 --save_freq=2 \
    --policy.save_normalization_stats=false

Then open the saved safetensors file (from safetensors import safe_open; safe_open(...).keys()) and confirm that the normalize_*.buffer_* keys are absent.

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

Note: Before submitting this PR, please read the contributor guideline.

Tears down the mixture-level stats aggregation in `DatasetMixtureMetadata`
and pushes per-dataset stats into stacked Normalize/Unnormalize buffers
indexed per sample. Each dataloader sample now carries `dataset_repo_id`
(str) and `dataset_index` (long) — the policy gathers the matching row
inside `Normalize.forward`. Adds a configurable
`PreTrainedConfig.save_normalization_stats` flag (with method-level
override) to strip `normalize_*.buffer_*` keys from the on-disk
safetensors, and an Accelerate save-state hook so `accelerator.save_state`
respects the same flag.

Scope: all pi variants (pi0/pi05/pi05_mem/pi06/pi07/pi07_paligemma,
low_level + high_level_planner). `value` policy is unchanged.
@shuheng-liu shuheng-liu added feature New feature or request refactor labels May 27, 2026
@shuheng-liu shuheng-liu self-assigned this May 27, 2026
`_TaggedDataset` inside `WeightedDatasetMixture` adds `dataset_repo_id`
and `dataset_index` to every sample, but the assertion in
`test_integration_basic_functionality_with_same_fps_as_dataset`
compares dataloader samples against direct
`datasets[0][idx]` lookups (which bypass the wrapper). Strip the
wrapper-added keys before the content comparison so the test reflects
intent.
Member Author

@shuheng-liu shuheng-liu left a comment

Reviewing as a draft. The shape of the refactor — per-dataset stacked buffers indexed per sample, a skip-stats flag stripped at the safetensors write boundary — is the right one. But the change is incomplete in a few places that the CPU test suite cannot see, and at least four of those gaps will hard-block a real run after merge. Findings inline; this is the summary view.

Hard blockers

| # | Where | What breaks |
|---|---|---|
| 1 | pretrained.py:167 | First save_state with save_normalization_stats=False on any PaliGemma-backed PI policy raises SafetensorError: Some tensors share memory. save_file doesn't dedup tied lm_head/embed_tokens, despite the docstring promising a clone step the code never performs. |
| 2 | normalize.py:172 | New (num_datasets, *feat_shape) buffers are incompatible with every pre-PR checkpoint on disk (including everything under TensorAuto/*). load_state_dict(strict=False) still raises on size mismatch, so make_policy(cfg, pretrained_path='TensorAuto/pi05') fails after merge with no migration path. |
| 3 | pretrained.py:293 (= _resolve_dataset_index) | Every eval rollout and every gRPC inference call crashes with KeyError. The eval observation pipeline (scripts/eval.py:150-168) and scripts/grpc/server.py:107,246 never inject dataset_index / dataset_repo_id; only _TaggedDataset does, and only on training batches. |
| 4 | factory.py:250 | cfg.policy.type='value' raises TypeError immediately because ValueFunction.__init__ still has the old dataset_stats= kwarg. The PR description says "value is unchanged"; in practice it's broken. |
| 5 | train.py:568 | Resume with save_normalization_stats=False either (a) raises on missing buffer keys, or (b) silently retains +inf init and asserts on step 1. The save hook has no matching load hook, and make_policy's _inject_stats call is gated on cfg.pretrained_path, which isn't set during resume. |

GPU test suite is not updated (PR's own checklist box is unchecked)

tests/policies/test_policies.py was updated cleanly, but the per-policy GPU-marked tests still use the old API and will fail the nightly gpu_test.yml run:

  • tests/policies/test_pi05.py:265, 366-367, 593: PI05Policy(config, dataset_stats=...) and bare policy.normalize_targets(action).
  • tests/policies/test_pi07_low_level.py:362, 460, 638, 747: same pattern.
  • tests/policies/test_pi07_paligemma_low_level.py:344, 444: same pattern.

Each will raise TypeError: ... unexpected keyword argument 'dataset_stats' at construction, or TypeError: forward() missing 1 required positional argument: 'dataset_index' at the roundtrip assertion. Worth updating in this PR rather than discovering them in nightly.

Higher-severity edge cases (inline above)

  • pretrained.py:286: CPU dataset_index + GPU policy = device mismatch in index_select; the string-list branch handles device, the tensor branch doesn't.
  • pretrained.py:370: _inject_stats(stats, dataset_names=None) silently corrupts the buffer↔name mapping when the caller's order differs from the existing config.
  • pretrained.py:404 (_check_norm_stats_loaded): only called from make_policy; direct from_pretrained users (gRPC, notebooks) silently get +inf and crash mid-forward.
  • normalize.py:233: _gather_and_broadcast silently broadcasts to a wrong shape when batch_val.ndim < gathered.ndim instead of raising.
  • dataset_mixture.py:236: aggregated_action_stats slices _dataset_weights[:N] instead of zipping; misaligned when any non-trailing dataset lacks 'actions'.

A few cleanups I'm leaving as suggestions, not blockers

  • The 4-tuple ("normalize_inputs", "normalize_targets", "normalize_discrete_actions", "unnormalize_outputs") is hard-coded at pretrained.py:54-59, 337-342, 385-390 — extract once.
  • Normalize.__init__ and Unnormalize.__init__ duplicate ~18 lines verbatim — pull into a shared base.
  • The 8 policy __init__ files repeat the same num_datasets = _num_datasets(...) + 3-4 Normalize(...) calls; a PreTrainedPolicy._build_normalize_modules(...) helper would absorb ~200 lines of boilerplate and prevent future per-policy drift.
  • The training-time _TaggedDataset.__getitem__ stores both dataset_repo_id: str and dataset_index: long on every sample — the string is never read in the training path (only at inference, where the caller supplies it explicitly). Worth dropping to skip the per-batch list[str] collate cost.
  • Single-dataset configs (D=1) still pay an index_select allocating (B, *feat_shape) per stat per forward; broadcasting stat[0] would be free.
  • _gather_and_broadcast uses .reshape(...) with a computed axis count — CLAUDE.md rule #4 specifically flags this kind of unsqueeze-chain reshape as a target for einops.rearrange / indexed None.

The blockers above (1-5) plus the GPU test gap are the things I'd hold the merge on. The rest are inline.


Generated by Claude Code

Comment thread src/opentau/policies/pretrained.py Outdated
# avoid safetensors' "duplicated tensors" rejection.
state_dict = model_to_save.state_dict()
filtered = {k: v for k, v in state_dict.items() if not is_norm_buffer_key(k)}
save_safetensor_file(filtered, out_path)
Member Author

Blocker: skip-stats save crashes on every PaliGemma-backed PI policy.

The docstring at lines 162-164 promises a "clone tensors that share storage with another retained tensor" step, but the code never performs it: filtered is just {k: v for k, v in state_dict.items() if not is_norm_buffer_key(k)}, so the retained tensors still alias each other.

PaliGemma ties lm_head.weight and embed_tokens.weight (handled explicitly in pi0/modeling_pi0.py:274 "lm_head.weight and embed_tokens.weight share memory"). At save time both keys appear in state_dict() pointing to the same storage. save_model deduplicates automatically; save_file does not — it raises:

safetensors.SafetensorError: Some tensors share memory, this will lead to duplicate memory on disk

So the very first accelerator.save_state / policy.save_pretrained call on any pi0/pi05/pi05_mem/pi07_paligemma run with save_normalization_stats=False will crash before any disk write happens. The unit test in test_save_pretrained_skip_stats.py uses a toy model with no tied weights, so it doesn't catch this.

Suggested change
- save_safetensor_file(filtered, out_path)
+ # `save_model` doesn't expose a filter argument, so build the
+ # filtered state_dict ourselves. Use `_remove_duplicate_names`
+ # (the helper `save_model` calls internally) to drop alias
+ # entries. PaliGemma ties `lm_head.weight` and
+ # `embed_tokens.weight`, so a naive filter trips
+ # `safetensors`' shared-memory rejection.
+ from safetensors.torch import _remove_duplicate_names
+ state_dict = model_to_save.state_dict()
+ to_removes = _remove_duplicate_names(state_dict, preferred_names=set())
+ for kept, aliases in to_removes.items():
+     for alias in aliases:
+         state_dict.pop(alias, None)
+ filtered = {k: v for k, v in state_dict.items() if not is_norm_buffer_key(k)}
+ save_safetensor_file(filtered, out_path)
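
To illustrate the aliasing problem without torch, the sketch below uses Python object identity as a stand-in for shared tensor storage. `drop_aliases` is a hypothetical helper, not a safetensors API; it only shows why two keys that point at one object must collapse to a single entry before a flat save.

```python
from typing import Any, Dict, Set


def drop_aliases(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Keep the first key seen for each underlying object, drop later aliases."""
    seen: Set[int] = set()
    kept: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if id(value) in seen:
            continue  # same object already retained under another key
        seen.add(id(value))
        kept[key] = value
    return kept


shared = [0.1, 0.2]  # stands in for one tied weight storage
state = {
    "embed_tokens.weight": shared,
    "lm_head.weight": shared,  # alias of the entry above
    "proj.weight": [0.3],
}
deduped = drop_aliases(state)
```

A real fix also has to restore the tie at load time, which is one reason to prefer the library's own dedup path over a hand-rolled one.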

Generated by Claude Code

# Note: we initialize mean, std, min, max to infinity. They should be overwritten
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
# we assert they are not infinity anymore.
stacked_shape = (num_ds, *shape)
Member Author

Blocker — every existing checkpoint becomes unloadable.

Pre-PR checkpoints (everything currently published under TensorAuto/*) saved each normalize_*.buffer_*.{mean,std,min,max} with shape (*feat_shape) — e.g. (32,) for state, (3, 1, 1) for a camera. After this PR every buffer is allocated with stacked_shape = (num_ds, *shape), and for a legacy config (no dataset_names set) resolve_num_datasets returns 1, so the new shape is (1, 32) / (1, 3, 1, 1).

PreTrainedPolicy._load_as_safetensor calls load_model_as_safetensor(..., strict=False), but PyTorch's load_state_dict(strict=False) only tolerates missing and unexpected keys — size mismatches on present keys are always raised. The result:

RuntimeError: Error(s) in loading state_dict for PI05Policy:
  size mismatch for normalize_inputs.buffer_state.mean:
    copying a param with shape torch.Size([32]) from checkpoint,
    the shape in current model is torch.Size([1, 32]).

_tile_linear_input_weight only rewrites linear weights — no equivalent hook reshapes buffers. So make_policy(cfg) for any cfg.pretrained_path = 'TensorAuto/pi05' will break after merge. Needs either (a) a load-side hook in _load_as_safetensor that unsqueezes legacy 1-D buffers when the saved rank is feat_rank instead of 1 + feat_rank, or (b) an explicit migration script + release note.
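
Option (a) could look roughly like the sketch below, shown on nested lists instead of tensors so it runs standalone. `shape_of` and `promote_legacy_buffer` are hypothetical helpers; a torch version would compare `tensor.shape` against the feature shape and call `unsqueeze(0)` on legacy buffers.

```python
from typing import List, Tuple


def shape_of(value) -> Tuple[int, ...]:
    """Shape of a regular nested list, e.g. [[1, 2, 3]] -> (1, 3)."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        if not value:
            break
        value = value[0]
    return tuple(shape)


def promote_legacy_buffer(value: List, feat_shape: Tuple[int, ...]) -> List:
    """Add a leading dataset axis to a pre-PR buffer; pass stacked ones through."""
    saved = shape_of(value)
    if saved == tuple(feat_shape):
        return [value]  # legacy (*feat_shape) -> (1, *feat_shape)
    if saved[1:] == tuple(feat_shape):
        return value  # already (num_datasets, *feat_shape)
    raise ValueError(f"Unexpected buffer shape {saved} for feature shape {feat_shape}")
```

Because the legacy and stacked layouts differ in rank, the two branches cannot be confused for a given feature shape.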


Generated by Claude Code

"Per-dataset normalization requires either `dataset_index` "
"(LongTensor of shape (B,)) or `dataset_repo_id` (str or "
"list[str] of length B) in the batch."
)
Member Author

Blocker — every eval rollout and the gRPC inference server crash on the first step.

The eval observation pipeline (scripts/eval.py:150-168: preprocess_observation → add_envs_task → add_eval_metadata → add_subgoal_images) never injects dataset_index or dataset_repo_id. Likewise scripts/grpc/server.py:107, 246 builds an observation dict by hand from gRPC inputs.

So every policy.select_action(observation) / policy.sample_actions(observation, ...) on a per-dataset-normalized policy will raise this KeyError on the very first call — both opentau-eval and any cfg.eval_freq > 0 in-training eval block. The factory path is fine because _TaggedDataset injects both keys, but the eval/inference paths weren't updated to mirror it.

Three places need a fix to keep this PR's invariant ("every batch has a dataset_index") true end-to-end:

  • scripts/eval.py::add_eval_metadata (or a new add_dataset_repo_id helper) needs to attach batch['dataset_repo_id'] per env, looked up from cfg.env / cfg.dataset_mixture so it resolves through _dataset_name_to_index.
  • scripts/grpc/server.py's warmup + serving path must do the same.
  • The single-dataset legacy fallback (config.dataset_names is None) should default to dataset_index = 0 rather than raising, otherwise loading any non-mixture checkpoint also fails — the message at lines 296-300 is descriptive, but a no-stats single-dataset reload is the common case and shouldn't need the caller to construct an index tensor.

Generated by Claude Code

dataset_names = dataset_names or [f"external_{i}" for i in range(len(per_dataset_stats))]

if per_dataset_stats is not None:
    kwargs["per_dataset_stats"] = per_dataset_stats
Member Author

Blocker — cfg.policy.type = "value" immediately raises TypeError.

The PR description says the value policy is "out of scope and unchanged", but the factory now unconditionally passes per_dataset_stats= and dataset_names= to every policy class. ValueFunction.__init__ (policies/value/modeling_value.py:124-138) still has the old signature:

def __init__(
    self,
    config: ValueConfig,
    dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):

So make_policy(cfg, ds_meta=...) with cfg.type='value' triggers:

TypeError: ValueFunction.__init__() got an unexpected keyword argument 'per_dataset_stats'

Even if that were fixed, ValueFunction.__init__ still calls Normalize(config.input_features, config.normalization_mapping, dataset_stats). Here dataset_stats is the 3rd positional arg, which is now per_dataset_stats: list[...]. Passing a dict there would either crash in create_stats_buffers (which iterates len(per_dataset_stats) as if it's a list) or, worse, silently produce nonsense buffer shapes. And self.normalize_inputs(batch) inside forward is now missing the dataset_index positional arg.

Either: (a) update ValueFunction to the new API so "unchanged" becomes true, (b) carve value out of the per_dataset_stats kwarg passing in this factory and keep the legacy path, or (c) explicitly raise NotImplementedError when cfg.type=='value' lands in this branch so it fails loudly with a useful message instead of a generic TypeError.
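
For option (a), one possible shape for the argument handling is sketched below. `coerce_stats_args` is a hypothetical helper, not code from the PR; it accepts either the legacy single dict or the new list form and rejects ambiguous calls.

```python
from typing import Dict, List, Optional, Tuple


def coerce_stats_args(
    dataset_stats: Optional[Dict] = None,
    per_dataset_stats: Optional[List[Dict]] = None,
    dataset_names: Optional[List[str]] = None,
) -> Tuple[Optional[List[Dict]], Optional[List[str]]]:
    """Normalize both constructor styles to (per_dataset_stats, dataset_names)."""
    if dataset_stats is not None and per_dataset_stats is not None:
        raise ValueError("Pass either `dataset_stats` or `per_dataset_stats`, not both.")
    if dataset_stats is not None:
        # Legacy API: wrap the single dict into a singleton list.
        return [dataset_stats], dataset_names
    if (
        per_dataset_stats is not None
        and dataset_names is not None
        and len(dataset_names) != len(per_dataset_stats)
    ):
        raise ValueError("`dataset_names` and `per_dataset_stats` must have the same length.")
    return per_dataset_stats, dataset_names
```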


Generated by Claude Code

if is_norm_buffer_key(k):
    del sd[k]

accelerator.register_save_state_pre_hook(_strip_norm_buffers_pre_save)
Member Author

Blocker: resume with save_normalization_stats=False raises on load or asserts at step 1.

You register a save_state_pre_hook that strips norm buffers on disk, but there is no matching register_load_state_pre_hook (or post-load injection). The resume flow at line 573 (accelerator.load_state(cfg.checkpoint_path)) loads the model state_dict from the same on-disk file the hook stripped:

  1. make_policy(cfg, ds_meta=ds_meta) constructs the policy. Since cfg.pretrained_path is NOT set during a resume (it's only set when bootstrapping from a hub checkpoint), the _inject_stats(...) branch at factory.py:277 does not fire. Buffers stay at +inf.
  2. accelerator.load_state(...) then loads the on-disk safetensors. With strict mode it raises on missing buffer keys; with strict=False the buffers stay at their +inf initial values.
  3. First forward fires assert not torch.isinf(mean).any() — training crashes on step 1, after potentially hours of resume-setup wall time.

This needs either:

  • A symmetric register_load_state_pre_hook that calls policy._inject_stats(ds_meta.per_dataset_stats, dataset_names=ds_meta.dataset_names) after load, or
  • make_policy unconditionally calls _inject_stats whenever per_dataset_stats is not None, not just under cfg.pretrained_path (move the guard at factory.py:277-280 so the resume path also benefits).

Either way, this PR's checklist explicitly defers smoke-running the round-trip ("End-to-end save_normalization_stats=false round-trip on a real checkpoint dir") to #335 — which means the bug above will land first and break the next person to flip the flag.


Generated by Claude Code

f"buffers are still +inf: {bad}. Either re-save the checkpoint "
"with `save_normalization_stats=True`, or pass `ds_meta=` to "
"`make_policy(...)` so stats are injected after load."
)
Member Author

Footgun — from_pretrained bypasses both _inject_stats and _check_norm_stats_loaded.

Only factory.make_policy (lines 277-280) calls these. Anyone using PI05Policy.from_pretrained('path/to/skip-stats-ckpt') directly — and there are real callers:

  • scripts/grpc/server.py::_load_policy (the production inference server),
  • notebooks/pi05_evaluation_only.ipynb,
  • Anything in external / downstream code that doesn't go through factory.

…will silently get +inf buffers and crash mid-forward instead of at construction time, with no clean message pointing the user at the save_normalization_stats=False round-trip mistake.

The cheapest fix is to call self._check_norm_stats_loaded() at the bottom of from_pretrained right before the policy.eval() (or at the bottom of _load_as_safetensor). It's a no-op when the buffers loaded fine, and a clear RuntimeError when they didn't — the same invariant make_policy already enforces, just hoisted to the load boundary where it actually belongs.
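
The invariant being hoisted is small. A sketch on plain lists of floats follows; `find_unloaded_buffers` and `check_norm_stats_loaded` here are hypothetical free functions, whereas the PR's `_check_norm_stats_loaded` is a method that inspects the policy's registered buffers.

```python
import math
from typing import Dict, List


def find_unloaded_buffers(buffers: Dict[str, List[float]]) -> List[str]:
    """Names of buffers that still contain their +inf initial value."""
    return sorted(
        name for name, values in buffers.items() if any(math.isinf(v) for v in values)
    )


def check_norm_stats_loaded(buffers: Dict[str, List[float]]) -> None:
    """No-op when all stats are populated; raise a clear error otherwise."""
    bad = find_unloaded_buffers(buffers)
    if bad:
        raise RuntimeError(
            f"Normalization buffers are still +inf: {bad}. Re-save the checkpoint "
            "with `save_normalization_stats=True`, or inject stats after load."
        )
```

Running this at the load boundary turns a mid-forward assertion into an error at construction time.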


Generated by Claude Code

Comment thread src/opentau/policies/pretrained.py Outdated
idx = batch["dataset_index"]
if not isinstance(idx, Tensor):
    idx = torch.as_tensor(idx, dtype=torch.long)
return idx.to(dtype=torch.long)
Member Author

Bug — CPU-tensor dataset_index + GPU policy raises device mismatch in index_select.

idx.to(dtype=torch.long) preserves the existing device. The training path is safe because _TaggedDataset returns a scalar tensor that accelerator.prepare's send_to_device will move. But any direct-inference caller — gRPC server, ROS client, the notebooks/pi05_evaluation_only.ipynb cell that builds an observation dict by hand, or an accelerator-less unit test — typically constructs a CPU LongTensor and hands it to a CUDA-resident policy. Then inside normalize.py::_gather_and_broadcast:

RuntimeError: Expected all tensors to be on the same device,
  but found at least two devices, cuda:0 and cpu!

The dataset_repo_id (string) branch handles this correctly at line 312 — torch.tensor(indices, dtype=torch.long, device=device) — but the tensor branch above doesn't. The fix is symmetric:

Suggested change
- return idx.to(dtype=torch.long)
+ if "dataset_index" in batch:
+     idx = batch["dataset_index"]
+     if not isinstance(idx, Tensor):
+         idx = torch.as_tensor(idx, dtype=torch.long)
+     _, device = self._infer_batch_size_and_device(batch)
+     return idx.to(dtype=torch.long, device=device)

Generated by Claude Code

)
gathered = stat.index_select(0, dataset_index) # (B, *feat_shape)
extra = batch_val.ndim - gathered.ndim
if extra > 0:
Member Author

Bug — extra < 0 silently broadcasts wrong instead of raising.

The current branch only pads when batch_val.ndim > gathered.ndim. If a feature has shape (1,) (a scalar state dim) and a caller squeezed the trailing axis, batch_val.shape = (B,), gathered.shape = (B, 1), and extra = -1. No reshape happens, then (batch_val:(B,) - gathered:(B,1)) broadcasts to (B, B) — a silently wrong result with no diagnostic until something downstream crashes on the wrong-shape tensor.

The contract of this helper is "make gathered broadcastable to batch_val" — if it can't, that's a bug worth surfacing:

Suggested change
- if extra > 0:
+ gathered = stat.index_select(0, dataset_index)  # (B, *feat_shape)
+ extra = batch_val.ndim - gathered.ndim
+ if extra < 0:
+     raise ValueError(
+         f"batch_val has fewer dims ({batch_val.ndim}) than the gathered stat "
+         f"({gathered.ndim}); can't broadcast {tuple(batch_val.shape)} against "
+         f"{tuple(gathered.shape)}."
+     )
+ if extra > 0:
+     gathered = gathered.reshape(gathered.shape[0], *((1,) * extra), *gathered.shape[1:])
+ return gathered
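
To make the shape arithmetic concrete, here is a worked example on shapes only, with no tensors involved. `broadcast_shape` and `padded_stat_shape` are hypothetical helpers that mirror standard NumPy/PyTorch broadcasting and the reshape in the suggestion above.

```python
from typing import Tuple


def broadcast_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Result shape of broadcasting `a` against `b` (trailing axes aligned)."""
    out = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"cannot broadcast {a} against {b}")
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


def padded_stat_shape(
    batch_shape: Tuple[int, ...], gathered_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
    """Shape of the gathered stat after inserting singleton axes behind B."""
    extra = len(batch_shape) - len(gathered_shape)
    if extra < 0:
        raise ValueError(
            f"batch has fewer dims than the gathered stat: {batch_shape} vs {gathered_shape}"
        )
    return (gathered_shape[0], *((1,) * extra), *gathered_shape[1:])


# The silent failure: (B,) against (B, 1) broadcasts to (B, B).
silent_bug = broadcast_shape((4,), (4, 1))

# The intended case: a temporal axis between batch and feature.
temporal = padded_stat_shape((4, 10, 32), (4, 32))
```

With `B = 4`, the squeezed batch silently yields a `(4, 4)` result, while the padded stat of shape `(4, 1, 32)` broadcasts cleanly against `(4, 10, 32)`.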

Generated by Claude Code

Comment thread src/opentau/datasets/dataset_mixture.py Outdated
# pull just the "actions" sub-dicts so the call is cheap.
agg = aggregate_stats(
[{"actions": s["actions"]} for s in stats_with_actions],
weights=self._dataset_weights[: len(stats_with_actions)],
Member Author

Bug — weights misaligned when any non-trailing dataset lacks 'actions'.

self._dataset_weights[: len(stats_with_actions)] takes the first N weights, not the weights for the datasets that survived the "actions" in s filter.

Concrete failure: a mixture of [A (VQA-only, no actions), B (has), C (has)] with weights [wA, wB, wC]:

  • stats_with_actions = [B, C]
  • weights = self._dataset_weights[:2] = [wA, wB] ← wrong
  • aggregate_stats then weights B by wA and C by wB instead of the intended (wB, wC).

fit_fast_tokenizer.py consumes this to fit one BPE codec over a global action range — the result is a silently misaligned discrete action vocab. Fix by zipping & filtering both together:

Suggested change
- weights=self._dataset_weights[: len(stats_with_actions)],
+ pairs = [
+     (s, w)
+     for s, w in zip(self.per_dataset_stats, self._dataset_weights, strict=True)
+     if "actions" in s
+ ]
+ if not pairs:
+     raise ValueError(
+         "No dataset in the mixture exposes 'actions' stats; aggregated_action_stats() is undefined."
+     )
+ agg = aggregate_stats(
+     [{"actions": s["actions"]} for s, _ in pairs],
+     weights=[w for _, w in pairs],
+ )
+ return agg["actions"]
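
The A/B/C example can be reproduced with plain lists. The weights below are made-up values for illustration only; the point is which weights end up paired with B and C.

```python
# Mixture [A (no actions), B, C] with illustrative weights.
per_dataset_stats = [{"state": {}}, {"actions": {}}, {"actions": {}}]
weights = [0.5, 0.3, 0.2]
assert len(per_dataset_stats) == len(weights)

stats_with_actions = [s for s in per_dataset_stats if "actions" in s]

# Slicing takes the first N weights, which belong to A and B.
sliced = weights[: len(stats_with_actions)]

# Filtering stats and weights together keeps the weights of B and C.
zipped = [w for s, w in zip(per_dataset_stats, weights) if "actions" in s]
```

Slicing pairs B with A's weight and C with B's weight, while the zipped filter keeps each dataset with its own weight.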

Generated by Claude Code

buffer[stat].data.copy_(new_tensor)
# Refresh the module's own dataset_names cache.
if dataset_names is not None:
    module.dataset_names = list(dataset_names)
Member Author

Footgun — silent corruption when dataset_names=None and caller's order differs from the existing config.

When dataset_names is None, the loop above overwrites buffer[stat] rows in the order of per_dataset_stats, but the if dataset_names is not None guards leave self.config.dataset_names and self._dataset_name_to_index pointing at the old order. After injection, inference resolves batch['dataset_repo_id'] = 'robotA' through the old name→index map, but row 0 of the buffer is now whatever was at per_dataset_stats[0] in the caller's order. No error — just silently wrong stats applied per sample.

make_policy always passes dataset_names=dataset_names, so this is safe from the factory path. But the method is a public-shaped helper (documented in the docstring and used by the suggested-fix path for resume), and anyone calling policy._inject_stats(stats) directly hits a quiet correctness bug.

Either require dataset_names always (drop the Optional), or when None, verify the new rows match the existing config.dataset_names length AND require the caller to acknowledge the existing order is being preserved (e.g. an explicit assert_existing_order=True flag).
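
A sketch of the stricter variant, where names are always required and must match the existing order. `validate_injection_order` is a hypothetical helper, not the PR's `_inject_stats`; it only shows the checks that would run before any buffer row is overwritten.

```python
from typing import List, Optional


def validate_injection_order(
    existing_names: List[str],
    new_names: Optional[List[str]],
    num_rows: int,
) -> List[str]:
    """Reject injections whose row order cannot be tied to the config's names."""
    if new_names is None:
        raise ValueError("`dataset_names` is required so row order can be verified.")
    if len(new_names) != num_rows:
        raise ValueError(
            f"Got {len(new_names)} names for {num_rows} rows of per-dataset stats."
        )
    if list(new_names) != list(existing_names):
        raise ValueError(
            f"Injected order {list(new_names)} differs from config order {list(existing_names)}."
        )
    return list(new_names)
```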


Generated by Claude Code

…esolve_dataset_index

Single-dataset policies (`num_datasets <= 1`) no longer require the
caller to inject `dataset_index` / `dataset_repo_id` into the batch —
`_resolve_dataset_index` returns zeros when both are absent. Multi-
dataset policies still raise a clear `KeyError` so silent misuse stays
loud.

Also:
- `ValueFunction` keeps its external `dataset_stats: dict | None` API
  but wraps into a singleton list internally before calling the new
  `Normalize`. Forward calls go through the shared
  `_resolve_dataset_index` helper.
- Update GPU pytest fixtures across pi0/pi05/pi05_mem/pi06/pi07/
  pi07_paligemma to use `per_dataset_stats=[...]` matching the new
  constructor signature.
…irectly

`test_complete_pi05_pipeline_integration` reaches into
`policy.normalize_targets` / `policy.unnormalize_outputs` directly to
verify the round-trip, which bypasses the policy-level
`_resolve_dataset_index` helper. The new per-dataset Normalize forward
needs an explicit `(B,)` long index — pass a zeros tensor matching
the single-dataset construction of this fixture.
…set normalize PR

Hard blockers (1-5 from review):
1. `_save_pretrained` with stats stripped: detach norm-buffer submodules and
   reuse `save_model_as_safetensor` so tied `lm_head`/`embed_tokens` on
   PaliGemma-backed policies dedup correctly. The raw `save_file` path
   crashed with `SafetensorError: Some tensors share memory`.
2. Legacy checkpoint migration: new
   `_promote_legacy_norm_buffers_in_state_dict` shim unsqueezes pre-PR
   `(*feat_shape,)` buffers to the new `(1, *feat_shape)` layout at load
   time. Called from base `_load_as_safetensor` and from every pi
   `from_pretrained` override so anything under `TensorAuto/*` still loads.
3. Inference dataset_index plumbing: added `EvalConfig.dataset_repo_id` and
   `ServerConfig.dataset_repo_id`. `scripts/eval.py::rollout` and
   `scripts/grpc/server.py` (both warmup and per-request batch) inject the
   field when set; otherwise the single-dataset `_resolve_dataset_index`
   fallback handles things.
4. `ValueFunction.__init__` now accepts both the legacy
   `dataset_stats: dict` and the new `per_dataset_stats: list[dict] +
   dataset_names: list[str]` API, with mutual-exclusion validation. Lets
   `make_policy(cfg.policy.type='value', ds_meta=...)` work after the
   factory plumbing change.
5. Resume with `save_normalization_stats=False`: pass `strict=False` to
   `accelerator.load_state` so the stripped buffer keys don't crash the
   load. Buffers stay at the `_inject_stats`-populated values.

Higher-severity edges:
- `_resolve_dataset_index` now moves the dataset_index tensor to the
  batch's device (was missing in the tensor branch).
- `_inject_stats` validates that any caller-supplied `dataset_names`
  matches `config.dataset_names` element-for-element (was silently
  corrupting the name->index map on reorder).
- `_check_norm_stats_loaded` now also runs at the end of base
  `from_pretrained` so direct callers (notebooks, gRPC server,
  downstream scripts) surface the missing-stats mistake at construction
  time instead of mid-forward. Uses the shared `NORM_MODULE_NAMES`
  tuple.
- `Normalize._gather_and_broadcast` raises `ValueError` when the batch
  tensor has fewer dims than the stats buffer (was silently broadcasting
  to a wrong shape).
- `aggregated_action_stats` zips weights with the filtered stats list
  instead of slicing `[:N]` — a non-trailing dataset lacking `actions`
  would otherwise misalign the BPE codec weights.
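
The weight-alignment point above can be illustrated with a toy mixture. This is a minimal sketch, not the PR's `aggregated_action_stats`: `align_stats_with_weights` and `weighted_mean` are hypothetical helpers, and the stats dicts are reduced to a single scalar `mean` for readability.

```python
def align_stats_with_weights(per_dataset_stats, weights, key="actions"):
    """Pair each dataset's `key` stats with that dataset's own weight.

    Filtering happens on the zipped pairs, so a dataset that lacks `key`
    is dropped together with its weight, wherever it sits in the list.
    """
    if len(per_dataset_stats) != len(weights):
        raise ValueError("need exactly one weight per dataset")
    return [(stats[key], w) for stats, w in zip(per_dataset_stats, weights) if key in stats]


def weighted_mean(pairs):
    """Weighted average of the `mean` entries, renormalized over kept weights."""
    total = sum(w for _, w in pairs)
    return sum(s["mean"] * w for s, w in pairs) / total


# Toy mixture (hypothetical values): the middle dataset has no `actions` stats.
per_dataset_stats = [
    {"actions": {"mean": 1.0}},
    {"state": {"mean": 9.0}},
    {"actions": {"mean": 3.0}},
]
weights = [0.5, 0.3, 0.2]

aligned = align_stats_with_weights(per_dataset_stats, weights)
aligned_weights = [w for _, w in aligned]  # weights of datasets 0 and 2

# The slicing approach pairs dataset 2 with dataset 1's weight instead.
filtered = [s["actions"] for s in per_dataset_stats if "actions" in s]
sliced_weights = weights[: len(filtered)]
```

With slicing, the third dataset silently inherits the second dataset's weight; zipping before filtering keeps each weight attached to its own dataset.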

Test fixes:
- `test_pi07_low_level.py` and `test_pi07_paligemma_low_level.py`: pass
  explicit `dataset_index` when calling `normalize_targets`/
  `unnormalize_outputs` submodules directly (bypassing the policy-level
  helper).

Cleanups:
- Extract `NORM_MODULE_NAMES` constant in `pretrained.py` and use it in
  `_save_pretrained`, `_inject_stats`, `_check_norm_stats_loaded`.
- Drop the unused `packaging` and `safetensors` version imports from
  `pretrained.py` (the new `_load_as_safetensor` path doesn't need
  them).
@shuheng-liu
Member Author

Thanks for the thorough review — all five blockers and the higher-severity edge cases are addressed in f8c7492. GPU pytest is green again on a CUDA dev box: 19 passed, 10 skipped (unrelated), 1359 deselected, ~3 min.

Blockers

| # | Fix |
| --- | --- |
| 1 (tied-weight save) | _save_pretrained now detaches normalize_*.buffer_* submodules from _modules and reuses save_model_as_safetensor, which runs _remove_duplicate_names for the tied lm_head/embed_tokens pair. The original save_file direct path is gone. |
| 2 (legacy checkpoint shape) | New _promote_legacy_norm_buffers_in_state_dict walks the loaded state_dict and unsqueeze(0)s any norm-buffer entry that's one rank shy of the target buffer. Called from the base _load_as_safetensor and injected into every pi from_pretrained override before its model.load_state_dict(...). Logs a one-time warning. |
| 3 (eval/grpc inference) | Added EvalConfig.dataset_repo_id and ServerConfig.dataset_repo_id (both None-defaulted). scripts/eval.py::rollout and scripts/grpc/server.py (both _load_policy warmup and per-request _prepare_observation) inject the field when set. Single-dataset checkpoints keep working through _resolve_dataset_index's num_datasets <= 1 fallback. Multi-dataset users get a clear KeyError with the field name in the message. |
| 4 (value policy) | ValueFunction.__init__ now accepts both the legacy `dataset_stats: dict |
| 5 (resume strict) | accelerator.load_state(cfg.checkpoint_path, strict=False) when save_normalization_stats=False. Buffers are repopulated by make_policy's _inject_stats call upstream, so the missing keys pass through cleanly. |

Edges

  • _resolve_dataset_index device mismatch (tensor branch): now does idx.to(dtype=torch.long, device=device) where device is inferred from the first tensor in the batch.
  • _inject_stats reorder safety: if a caller passes dataset_names, we now require it to equal config.dataset_names element-for-element. Silent reorder corruption is no longer possible; pass dataset_names=None if you trust the existing config order.
  • _check_norm_stats_loaded reach: now runs at the end of base from_pretrained so direct callers (notebooks, gRPC server) surface the missing-stats mistake at construction time. Uses the shared NORM_MODULE_NAMES tuple.
  • _gather_and_broadcast ndim guard: raises ValueError when batch_val.ndim < gathered.ndim instead of silently broadcasting to a wrong shape.
  • aggregated_action_stats weights alignment: now zips weights with the filtered stats list. A non-trailing dataset lacking actions no longer misaligns the BPE codec's weighted mean.
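
The `_gather_and_broadcast` ndim guard can be reasoned about from shapes alone. The sketch below is an illustration in pure Python, not the PR's code: `broadcast_shape_for` is a hypothetical helper, and it assumes the gathered stats have shape `(B, *feat_shape)` while the batch tensor has shape `(B, <extra axes>, *feat_shape)`.

```python
def broadcast_shape_for(batch_shape, feat_shape):
    """Return the shape the gathered per-sample stats should be viewed as.

    Assumption (not confirmed by the PR): gathered stats are (B, *feat_shape)
    and the batch value is (B, <extra axes>, *feat_shape).
    """
    batch_shape = tuple(batch_shape)
    feat_shape = tuple(feat_shape)
    gathered_ndim = 1 + len(feat_shape)

    # Fewer dims than the gathered stats: broadcasting would produce a
    # wrongly shaped result, so fail loudly instead.
    if len(batch_shape) < gathered_ndim:
        raise ValueError(
            f"batch value has {len(batch_shape)} dims, "
            f"need at least {gathered_ndim} for feature shape {feat_shape}"
        )
    if feat_shape and batch_shape[-len(feat_shape):] != feat_shape:
        raise ValueError(f"trailing dims {batch_shape} do not match {feat_shape}")

    # Insert singleton axes for any temporal/spatial dims between B and feat.
    extra_axes = len(batch_shape) - gathered_ndim
    return (batch_shape[0],) + (1,) * extra_axes + feat_shape
```

For a batch of shape `(4, 10, 7)` and features of shape `(7,)`, the stats would be viewed as `(4, 1, 7)` so that they broadcast over the time axis.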

Test gap

test_pi07_low_level.py and test_pi07_paligemma_low_level.py had bare policy.unnormalize_outputs(policy.normalize_targets(action)) calls bypassing the policy-level helper; both now pass an explicit zero-row dataset_index. Full audit of tests/policies/ is clean — no remaining dataset_stats= or single-arg normalize_targets( / unnormalize_outputs( calls.

Deferred (cleanups, marked as suggestions)

Skipping for this PR — happy to do as follow-ups:

  • Pulling Normalize.__init__ / Unnormalize.__init__'s shared body into a base.
  • A PreTrainedPolicy._build_normalize_modules(...) helper for the 8 policy ctors.
  • Dropping dataset_repo_id from _TaggedDataset training-time samples (it's only read at inference).
  • Skipping index_select when D=1.
  • Rewriting _gather_and_broadcast's .reshape(...) as einops.rearrange — would need a dynamic pattern string, which einops doesn't accept; sticking with .reshape for now.

I extracted the NORM_MODULE_NAMES tuple (suggestion #1) since I touched all three call sites anyway.

Member Author

@shuheng-liu shuheng-liu left a comment

Second pass after f8c7492. The five blockers from the first review are all addressed cleanly: _save_pretrained's detach/reattach approach is the right call for the tied-weights problem, the legacy-buffer promotion shim handles the migration path well, the eval/gRPC dataset_repo_id plumbing is clean, and ValueFunction now bridges both APIs. Most of the cleanup follow-ups from the prior review (the shared NORM_MODULE_NAMES constant, the _check_norm_stats_loaded invariant in from_pretrained, the aggregated_action_stats zip-fix, the _gather_and_broadcast ValueError, the explicit _inject_stats cross-check) landed too.

A few new things surface this pass — three real correctness bugs introduced or missed by the fix commit, two train.py resume edge cases, and an indentation drift across six modeling files.

Correctness

  • pretrained.py:331 — The device-mismatch fix is incomplete. _infer_batch_size_and_device iterates batch.values() in dict-insertion order; if a caller hand-builds batch = {'dataset_index': cpu_idx, 'state': gpu_state, ...}, the function returns dataset_index's own device, idx.to(device=cpu) becomes a no-op, and index_select later crashes with the same device-mismatch error the fix was supposed to prevent. Inline above with a 4-line patch (skip dataset_index when inferring) or use next(self.parameters()).device.

  • value/modeling_value.py:160 — Multi-dataset value config silently corrupts. super().__init__(config) populates _dataset_name_to_index with all N names from config.dataset_names (length 3 in a typical mixture), but ValueFunction then truncates per_dataset_stats to length 1. Buffer has 1 row; the name→index map has 3 entries. dataset_repo_id='b' → index 1 → index_select(0, 1) on a 1-row buffer → CUDA-side IndexError. The warning at lines 153-157 is not enough; either refuse the multi-dataset case loudly or also truncate cfg.dataset_names.

  • pretrained.py:339 — The single-dataset fallback derives num_datasets from len(self._dataset_name_to_index), not from the actual buffer leading dim. Anyone constructing the policy directly (PI05Policy(config, per_dataset_stats=[s1, s2]) without setting config.dataset_names — common in notebooks/tests) hits _dataset_name_to_index = None, the <= 1 branch returns torch.zeros(B), and every sample silently normalizes against s1 even when the caller intended s2. Make the buffer's leading dim the source of truth, or always populate _dataset_name_to_index when stats are passed.

Resume / strict=False

  • train.py:582: strict=False is too broad. It correctly lets stripped norm-buffer keys through, but it also silently swallows any other genuine missing key (renamed module, forgotten buffer, accidental refactor). A symmetric register_load_state_pre_hook that drops only is_norm_buffer_key-matching keys would let real missing-key bugs still raise.

  • train.py:579 — Asymmetric: trained-with-False then resumed-with-True hard-crashes. The condition reads the current config, not what was on disk. A user who flips save_normalization_stats between the initial run and the resume gets RuntimeError: Missing key(s) after a long load. Persist the as-saved flag into the checkpoint metadata, or always use strict=False on resume paired with the targeted skip above.

Style / clarity

  • pi05_mem/modeling_pi05.py:364 (and pi06, both pi07/*, both pi07_paligemma/* — six files) — Indentation drift in the legacy-promotion comment block. The first comment line is at 16 spaces (inside if remap_count > 0 and is_main_process:); the continuation comment and the actual _promote_legacy_norm_buffers_in_state_dict call are at 12 spaces (outside the if). The code is functionally correct, but the visual mismatch invites a future maintainer to "fix" the indentation by moving the call into the if block — at which point promotion silently no-ops on every load that didn't need a key remap, and skips entirely on non-main ranks.

Minor nits not posted inline

  • pretrained.py:562: device_arg = map_location if map_location != "cpu" else "cpu" — both branches evaluate to map_location; the ternary is dead code.
  • tests/policies/test_value.py:155 still passes dataset_stats=... (legacy API). Today this is fine because ValueFunction.__init__ accepts both, but it leaves the test corpus inconsistent — when the legacy alias is eventually retired the test silently regresses.
  • The make_policy factory now calls _inject_stats(...) (line 278) after from_pretrained already populated buffers via the per_dataset_stats=... kwargs (passed through **kwargs to the constructor inside from_pretrained). The injection is now redundant but harmless — the new equality check in _inject_stats catches reorder bugs, so this is essentially a defensive double-check. Worth a one-line comment noting the redundancy is intentional.
  • _promote_legacy_norm_buffers_in_state_dict only handles N=1 promotion. Loading a legacy single-dataset checkpoint into a new multi-dataset policy (target shape (N, *feat), loaded (*feat)) still raises a shape mismatch — the shim unsqueezes to (1, *feat) and the subsequent load_state_dict rejects. Edge case (user has to actively migrate single→multi), but a clearer error message pointing at _inject_stats would help.
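
The last nit (a clearer error when a legacy single-dataset buffer meets an N>1 target) could look roughly like this. It is a shape-only sketch under stated assumptions: `promote_legacy_shape` is a hypothetical helper that works on shape tuples rather than tensors, and the error wording is illustrative.

```python
def promote_legacy_shape(loaded_shape, target_shape, key):
    """Return the shape a loaded norm buffer should have after promotion.

    Legacy buffers are `(*feat_shape,)`; the stacked layout is
    `(num_datasets, *feat_shape)`. Only the num_datasets == 1 case can be
    promoted by adding a leading axis.
    """
    loaded_shape, target_shape = tuple(loaded_shape), tuple(target_shape)
    if loaded_shape == target_shape:
        return loaded_shape  # already in the stacked layout

    one_rank_shy = len(loaded_shape) + 1 == len(target_shape)
    if not one_rank_shy or loaded_shape != target_shape[1:]:
        raise ValueError(f"{key}: cannot promote {loaded_shape} to {target_shape}")

    if target_shape[0] != 1:
        # A single legacy row cannot fill several dataset rows; point the
        # user at stats re-injection instead of a bare shape mismatch.
        raise ValueError(
            f"{key}: legacy single-dataset buffer {loaded_shape} cannot fill "
            f"{target_shape[0]} dataset rows; repopulate stats via _inject_stats"
        )
    return (1,) + loaded_shape
```

The same decision applied to real tensors would replace the final line with an `unsqueeze(0)` on the loaded buffer.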

The three correctness items above are the only things I'd hold the merge on. Everything else is style or future-proofing.


Generated by Claude Code

if per_dataset_stats is None and dataset_stats is not None:
    per_dataset_stats = [dataset_stats]
if per_dataset_stats is not None and len(per_dataset_stats) > 1:
    logging.warning(
Member Author

Multi-dataset value config silently corrupts at inference.

super().__init__(config) runs before this block and populates self._dataset_name_to_index from config.dataset_names. When make_policy(value_cfg, ds_meta=<DatasetMixtureMetadata with N>1 datasets>) calls this constructor, the factory has already set cfg.dataset_names = ['a', 'b', 'c'] (length 3) at factory.py:255. So:

  1. super().__init__(config) builds _dataset_name_to_index = {'a':0, 'b':1, 'c':2} (length 3).
  2. The truncation below sets per_dataset_stats = per_dataset_stats[:1] → 1-row buffer.
  3. At inference, batch['dataset_repo_id'] = 'b' resolves to index 1, then Normalize.forward calls stat.index_select(0, tensor([1])) on a (1, *feat) buffer → IndexError: index 1 is out of bounds for dimension 0 with size 1 (CUDA-side, hard to debug).
  4. At training, _resolve_dataset_index's single-dataset fallback won't fire either (num_datasets = len(_dataset_name_to_index) = 3 > 1), so missing-key batches raise KeyError instead of defaulting.

Two options:

  • Refuse the multi-dataset case loudly (raise ValueError("ValueFunction is single-dataset only; got len(per_dataset_stats)=N")) instead of silently truncating.
  • If the policy genuinely is single-dataset by design, also override cfg.dataset_names = dataset_names[:1] before super().__init__ to keep the name map and buffer in sync.

The warning at lines 153-157 is not enough — the user gets a runtime crash they can't easily trace back to "value can only handle one dataset".
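
The invariant at stake is that every index in the name map must address an existing buffer row. A minimal sketch of such a construction-time check, assuming a hypothetical helper `check_name_map_matches_buffer` (not part of the PR):

```python
def check_name_map_matches_buffer(name_to_index, num_rows):
    """Raise if any dataset name resolves to a row the buffer does not have."""
    out_of_range = {
        name: idx for name, idx in name_to_index.items() if not 0 <= idx < num_rows
    }
    if out_of_range:
        raise ValueError(
            f"stats buffer has {num_rows} row(s) but the name map points at "
            f"missing rows: {out_of_range}"
        )


# The scenario described above: three names, one buffer row.
stale_map = {"a": 0, "b": 1, "c": 2}
try:
    check_name_map_matches_buffer(stale_map, num_rows=1)
    stale_map_rejected = False
except ValueError:
    stale_map_rejected = True

# After truncating the names to match the buffer, the check passes.
check_name_map_matches_buffer({"a": 0}, num_rows=1)
```

Running a check like this at the end of `__init__` would turn the late `index_select` failure into an immediate, readable error.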


Generated by Claude Code

Comment thread src/opentau/policies/pretrained.py Outdated
# tensors and Accelerate later does the host->device copy on a
# per-key basis — `dataset_index` rides along for tensor keys,
# but a caller hand-constructing a batch on GPU may leave the
# index on CPU, breaking `index_select` inside Normalize.
Member Author

The device-placement fix doesn't actually fix the bug if dataset_index iterates first.

_infer_batch_size_and_device walks batch.values() in insertion order and returns the device of the first tensor it finds. Python dicts preserve insertion order, so if a caller hand-builds the batch as:

batch = {'dataset_index': cpu_idx_tensor, 'state': gpu_state_tensor, ...}

…then _infer_batch_size_and_device returns (B, cpu) from the very first tensor it sees — dataset_index itself — even though state and the rest of the batch are on GPU. idx.to(device='cpu') becomes a no-op, and the subsequent index_select inside _gather_and_broadcast raises:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

The original device-mismatch bug is unfixed for any caller who happens to put dataset_index first in their dict (gRPC server, notebooks, custom inference loops).

Two viable fixes:

Suggested change
# index on CPU, breaking `index_select` inside Normalize.
# Skip `dataset_index` when inferring device — otherwise a caller
# who put it first in the dict gets `idx.to(idx.device)` (no-op).
device = next(
    (v.device for k, v in batch.items() if k != "dataset_index" and isinstance(v, Tensor)),
    idx.device,
)
return idx.to(dtype=torch.long, device=device)

Or use next(self.parameters()).device — the policy's own device is always the right target for a Normalize index_select.
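
The insertion-order pitfall is easy to reproduce without a GPU. The sketch below uses a stand-in `FakeTensor` class (an assumption, so the example runs without torch) and two hypothetical helpers that mimic the naive and the key-skipping device inference.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Stand-in for a tensor; only the `device` attribute matters here."""
    device: str


def first_tensor_device(batch, default="cpu"):
    """Naive inference: device of the first tensor in insertion order."""
    for value in batch.values():
        if isinstance(value, FakeTensor):
            return value.device
    return default


def infer_data_device(batch, skip_keys=("dataset_index", "dataset_repo_id"), default="cpu"):
    """Same walk, but ignoring bookkeeping keys that are not model inputs."""
    for key, value in batch.items():
        if key in skip_keys:
            continue
        if isinstance(value, FakeTensor):
            return value.device
    return default


# A hand-built batch with the index first and still on the CPU.
batch = {"dataset_index": FakeTensor("cpu"), "state": FakeTensor("cuda:0")}
naive_device = first_tensor_device(batch)
robust_device = infer_data_device(batch)
```

Here the naive walk reports the index's own device, while the key-skipping walk reports the device of the actual observations.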


Generated by Claude Code

num_datasets = len(self._dataset_name_to_index) if self._dataset_name_to_index else 1
if num_datasets <= 1:
    batch_size, device = self._infer_batch_size_and_device(batch)
    return torch.zeros(batch_size or 1, dtype=torch.long, device=device)
Member Author

num_datasets derived from _dataset_name_to_index, not from the actual buffer shape — silent corruption when they disagree.

The fallback uses len(self._dataset_name_to_index) (or 1 when None) to decide whether to default to all-zero indices. But this dict is built once in __init__ from config.dataset_names, which only make_policy sets as a side effect (line 255 of factory.py). A caller who constructs the policy directly with stats but no name list:

policy = PI05Policy(config, per_dataset_stats=[stats_robotA, stats_robotB])
# config.dataset_names is None, so PreTrainedPolicy.__init__ leaves
# _dataset_name_to_index = None.
policy.select_action({'state': ..., 'image': ...})

…falls through to the num_datasets <= 1 branch, returns torch.zeros(B), and every sample is normalized with robotA's mean/std regardless of the caller's intent. No exception, just silently degraded action quality.

The actual leading dim of the stats buffer is the source of truth. Either:

  • Look at self.normalize_inputs / normalize_targets / etc. and pull the leading dim from a buffer_* parameter, or
  • Always update self._dataset_name_to_index (or at least its length) when per_dataset_stats is passed at construction, even when the names list is None (use generated default_0, default_1 placeholder names).

Right now the relationship between "the buffer has N rows" and "the policy knows it has N rows" only holds when make_policy is the constructor.
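
A sketch of the fallback keyed on the buffer's row count rather than on the name map. It is a simplified illustration with plain lists instead of tensors; `resolve_dataset_index` here is a hypothetical free function, not the policy method, and its parameters are assumptions.

```python
def resolve_dataset_index(batch, num_buffer_rows, name_to_index=None, batch_size=1):
    """Return one dataset row index per sample.

    `num_buffer_rows` is the leading dim of the stats buffer and is treated
    as the source of truth for "is this a single-dataset policy?".
    """
    if "dataset_index" in batch:
        return list(batch["dataset_index"])

    if "dataset_repo_id" in batch:
        if name_to_index is None:
            raise KeyError("dataset_repo_id given but no dataset names are configured")
        repo_ids = batch["dataset_repo_id"]
        if isinstance(repo_ids, str):
            repo_ids = [repo_ids] * batch_size
        return [name_to_index[repo_id] for repo_id in repo_ids]

    # Defaulting to row 0 is only safe when there is a single row.
    if num_buffer_rows <= 1:
        return [0] * batch_size
    raise KeyError(
        f"policy has {num_buffer_rows} dataset rows; "
        "batch must carry 'dataset_index' or 'dataset_repo_id'"
    )
```

With two buffer rows and no name list, a batch lacking both keys raises instead of silently routing every sample to row 0.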


Generated by Claude Code

load_kwargs: dict = {}
if not cfg.policy.save_normalization_stats:
    load_kwargs["strict"] = False
accelerator.load_state(cfg.checkpoint_path, **load_kwargs)
Member Author

strict=False on resume is too broad — masks every missing-key bug, not just norm buffers.

accelerator.load_state passes strict down to the underlying model.load_state_dict. When save_normalization_stats=False, the fix correctly turns off strict to let the stripped normalize_*.buffer_* keys through. But that same flag also silently swallows:

  • A renamed module between save and resume (e.g. backbone → encoder).
  • A forgotten buffer in a new layer.
  • Any genuine architecture-regression introduced in the gap between save and resume.

The user would have to notice the warning log lines (which are easy to miss in 100k-step training output) rather than getting a hard error at resume time.

A targeted approach is safer: keep strict=True, register a corresponding register_load_state_pre_hook that drops norm-buffer keys from the loaded state_dict before load_state_dict runs. Symmetric to the save-state pre-hook on lines 561-568, and lets real missing-key bugs still raise. Sketch:

if not cfg.policy.save_normalization_stats:
    def _drop_norm_buffer_keys_post_load(models, input_dir):
        # No-op: keys never landed on disk, so they were already absent
        # from the model's `load_state_dict` call. The buffers were
        # repopulated by `_inject_stats` inside `make_policy` and the
        # missing-key warning is the only signal — turn it into a noop.
        del models, input_dir
    accelerator.register_load_state_pre_hook(_drop_norm_buffer_keys_post_load)

Or — simpler — assert that every missing key passes is_norm_buffer_key and re-raise otherwise.
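
The simpler option might be sketched as follows. `looks_like_norm_buffer_key` is a hypothetical stand-in for the PR's `is_norm_buffer_key` predicate (whose exact signature is not shown in this thread); the regex is an assumption based on the `normalize_*.buffer_*` / `unnormalize_*.buffer_*` key pattern.

```python
import re

# Assumed key pattern, e.g. "normalize_inputs.buffer_observation_state.mean".
_NORM_BUFFER_RE = re.compile(r"(^|\.)(un)?normalize_[^.]+\.buffer_")


def looks_like_norm_buffer_key(key):
    """Hypothetical stand-in for the PR's `is_norm_buffer_key`."""
    return _NORM_BUFFER_RE.search(key) is not None


def assert_only_norm_buffers_missing(missing_keys):
    """Tolerate stripped norm buffers, re-raise on anything else."""
    unexpected = [k for k in missing_keys if not looks_like_norm_buffer_key(k)]
    if unexpected:
        raise RuntimeError(f"Missing key(s) in state_dict: {unexpected}")
```

Called on the `missing_keys` reported by a non-strict load, this keeps the stripped-buffer case quiet while a renamed or forgotten module still fails at resume time.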


Generated by Claude Code

# buffers were already repopulated by `_inject_stats` inside
# `make_policy(cfg, ds_meta=...)` above, so this leaves them at the
# right values rather than silently overwriting.
load_kwargs: dict = {}
Member Author

Asymmetric resume — trained-with-False then resumed-with-True hard-crashes.

The condition if not cfg.policy.save_normalization_stats reads the current config, not what was on disk. This is fine for the symmetric case, but a user who:

  1. Initially trains with save_normalization_stats=False (smaller checkpoint disk footprint), then
  2. Flips the flag to True on the resume config (because they want stats included in the next round of checkpoints for downstream eval),

…gets:

RuntimeError: Error(s) in loading state_dict for ...:
  Missing key(s) in state_dict: "normalize_inputs.buffer_observation_state.mean", ...

The resume crashes because load_kwargs stays empty (current cfg has True), strict=True is applied, and the on-disk safetensors lacks the buffer keys that the previous run stripped.

Two options:

  • Persist the as-saved flag into the checkpoint's config.json (or sidecar metadata) and read that for the strict decision instead of cfg.policy.save_normalization_stats.
  • Or simpler: always pass strict=False on resume, paired with the targeted missing-key check above.

This is the same flavor of footgun as inverted-flag config drift — the only signal is a runtime crash after a long load.


Generated by Claude Code

if remap_count > 0 and is_main_process:
logging.info("Remapped %d state dict keys", remap_count)

# Promote legacy single-dataset Normalize/Unnormalize buffers from
Member Author

Indentation drift — first comment line is inside the if block but the call is outside (repeated in 6 modeling files).

Lines 361-368 read as:

            if remap_count > 0 and is_main_process:
                logging.info("Remapped %d state dict keys", remap_count)

                # Promote legacy single-dataset Normalize/Unnormalize buffers from   ← 16 spaces, inside if
            # `(*feat_shape,)` to the new `(1, *feat_shape)` stacked layout so pre-PR  ← 12 spaces
            # checkpoints load via `model.load_state_dict(...)`.                       ← 12 spaces
            model._promote_legacy_norm_buffers_in_state_dict(remapped_state_dict)      ← 12 spaces
            missing_keys, unexpected_keys = model.load_state_dict(...)

The first # Promote legacy ... line is inside if remap_count > 0 and is_main_process:, while the continuation comment and the actual _promote_legacy_norm_buffers_in_state_dict call sit at the outer scope. Functionally correct (the call DOES run unconditionally as intended), but the visual mismatch is exactly the kind of thing a future maintainer "fixes" by indenting the call into the if — at which point legacy promotion silently no-ops on every load that didn't need a key remap (and skips entirely on non-main ranks because of is_main_process), and pre-PR checkpoints stop loading.

Same pattern in pi06/modeling_pi06.py, pi07/low_level/modeling_pi07_low_level.py, pi07/high_level_planner/modeling_pi07_high_level.py, and both pi07_paligemma/*/modeling_*.py. Move the orphaned # Promote legacy ... line up by 4 spaces (or paragraph-break it out of the if block) in all six files.


Generated by Claude Code

…value name-map sync, on-disk strict detection, indent drift

Three correctness bugs from review #2:

- `_resolve_dataset_index` device fix is now belt-and-braces: tensor branch
  uses `next(self.parameters()).device` (the model's own device), not
  whatever device `_infer_batch_size_and_device` picked from the batch.
  The helper also skips the wrapper-injected `dataset_index` /
  `dataset_repo_id` keys so a CPU-resident index can't override a GPU
  batch when dict iteration order puts it first.

- `ValueFunction` no longer leaves a stale `_dataset_name_to_index` after
  truncating `per_dataset_stats` to the single-row buffer. Inference with
  `dataset_repo_id='<non-first>'` would otherwise index out of range on
  the 1-row buffer. Multi-dataset configs are explicitly truncated (with a
  warning) so the name map and buffer agree.

- New `_stacked_num_datasets()` helper reads the leading dim of any
  attached Normalize/Unnormalize buffer. `_resolve_dataset_index`'s
  single-dataset fallback now uses that as the source of truth instead of
  `len(self._dataset_name_to_index)`, so a direct
  `PI05Policy(config, per_dataset_stats=[s1, s2])` call (bypassing the
  factory's `cfg.dataset_names` population) raises the multi-dataset
  KeyError rather than silently normalizing every sample against `s1`.

Resume / strict-flag fixes:

- `train.py` now inspects the on-disk safetensors header to decide
  `strict=False`, not the *current* `cfg.policy.save_normalization_stats`.
  Toggling the flag between an initial run and a resume no longer
  hard-crashes. Strict is tightened to exactly the case where the on-disk
  file actually lacks norm-buffer keys, so an unrelated module rename in
  any other code path still raises (it would have been silently swallowed
  by the previous broad `strict=False`).

Indentation drift:

- `pi05_mem`, `pi06`, `pi07/low_level` had the legacy-buffer-promotion
  call's first comment line at 16 spaces (inside `if remap_count > 0`)
  while the call itself was at 12 (outside). All three are now uniformly
  outside the `if`, with an explicit "Always run" note so a future
  maintainer can't "fix" the misalignment by hoisting the call into the
  `if` block and silently skip promotion on no-rename loads.
@shuheng-liu
Member Author

Thanks again — all three correctness bugs, both resume edge cases, and the indentation drift are addressed in e8fa45d. GPU pytest stays green: 19 passed, 10 skipped, 1359 deselected, ~2.5 min.

Correctness

  • N1 (device-mismatch fix completion) — Tensor branch now uses next(self.parameters()).device (the model's compute device), unambiguously. As a defense-in-depth, _infer_batch_size_and_device also skips the wrapper-injected dataset_index / dataset_repo_id keys via a new _NON_DATA_BATCH_KEYS set so the fallback path also can't be hijacked by a CPU-resident index appearing first in dict-iteration order.
  • N2 (value multi-dataset corruption) — Truncating per_dataset_stats to length 1 now also truncates config.dataset_names and rebuilds _dataset_name_to_index to match the 1-row buffer. Multi-dataset configs carried into the value path emit an explicit warning at the truncation point. Inference with dataset_repo_id matching the kept name works; anything else raises a clear ValueError instead of crashing at index_select.
  • N3 (single-dataset fallback source-of-truth) — New _stacked_num_datasets() reads the leading dim of any attached Normalize/Unnormalize buffer. _resolve_dataset_index's fallback now consults that, not len(self._dataset_name_to_index). Direct PI05Policy(config, per_dataset_stats=[s1, s2]) (without populating config.dataset_names) now raises the explicit multi-dataset KeyError instead of silently normalizing every sample against s1.

Resume

  • R1+R2 collapsed into a single disk-state-driven decision: train.py now inspects <checkpoint_path>/model.safetensors via safe_open to see whether any normalize_*.buffer_* / unnormalize_*.buffer_* keys are actually present on disk, and uses that to decide strict=False. This:
    • Closes R2: a user who flips save_normalization_stats between the initial run and the resume no longer hard-crashes — the on-disk header is the only source of truth.
    • Tightens R1: strict=False now fires only when the file actually lacks the keys; an unrelated module rename anywhere else in the policy still raises (it would have been silently swallowed by the previous broad strict=False).
    • Falls back to the cfg flag (with a logged warning) only if the safetensors header read fails entirely.
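
For readers without the `safetensors` package at hand, the header check can be approximated with the standard library alone. This is a sketch, not the code in train.py (which uses `safe_open`): it relies on the safetensors file layout of an 8-byte little-endian header length followed by a JSON header, and the helper names and the key regex are assumptions.

```python
import json
import os
import re
import struct
import tempfile

# Assumed key pattern for normalize_*/unnormalize_* buffers.
NORM_BUFFER_RE = re.compile(r"(^|\.)(un)?normalize_[^.]+\.buffer_")


def read_safetensors_keys(path):
    """List tensor names from the file header without loading any tensor."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    return [name for name in header if name != "__metadata__"]


def norm_buffers_on_disk(path):
    """True when at least one norm-buffer key was actually saved."""
    return any(NORM_BUFFER_RE.search(k) for k in read_safetensors_keys(path))


def write_toy_safetensors(path, tensor_names):
    """Write a toy file in the same layout: one float32 zero per tensor."""
    header = {
        name: {"dtype": "F32", "shape": [1], "data_offsets": [4 * i, 4 * i + 4]}
        for i, name in enumerate(tensor_names)
    }
    blob = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(blob)))
        f.write(blob)
        f.write(b"\x00" * (4 * len(tensor_names)))


with tempfile.TemporaryDirectory() as tmp:
    stripped = os.path.join(tmp, "stripped.safetensors")
    full = os.path.join(tmp, "full.safetensors")
    write_toy_safetensors(stripped, ["encoder.weight"])
    write_toy_safetensors(
        full, ["encoder.weight", "normalize_inputs.buffer_observation_state.mean"]
    )
    # Relaxing strictness is only warranted when the file lacks the buffers.
    use_strict_for_stripped = norm_buffers_on_disk(stripped)
    use_strict_for_full = norm_buffers_on_disk(full)
```

Because the decision is read from the file itself, it stays correct even when the resume config disagrees with the config that produced the checkpoint.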

Style

  • S1 (indent drift): pi05_mem/modeling_pi05.py, pi06/modeling_pi06.py, and pi07/low_level/modeling_pi07_low_level.py had the first comment line of the legacy-promotion block at 16 spaces (inside if remap_count > 0 and is_main_process:) while the actual call was at 12. All three are now uniformly outside the if, with an explicit "Always run (outside the if remap_count > 0 block) — promotion is needed whether or not any other keys were renamed" note in each so a future maintainer can't "fix" the misalignment by hoisting the call into the if block and silently disabling promotion on no-rename loads. The four files that were already correct (pi05, pi07/high_level, pi07_paligemma/{low_level, high_level}) were not touched.

Ready for another pass when you have time.

Member Author

@shuheng-liu shuheng-liu left a comment

Third pass after e8fa45d. All six findings from round 2 are addressed cleanly:

| Round-2 finding | Fix in e8fa45d |
| --- | --- |
| _infer_batch_size_and_device device-from-dataset_index bug (pretrained.py:331) | New _NON_DATA_BATCH_KEYS skips the key during iteration; tensor branch now uses _model_device() (reads from self.parameters()) rather than dict-iteration order. Source of truth is now unambiguous. |
| Multi-dataset value silent corruption (value/modeling_value.py:160) | ValueFunction.__init__ now truncates cfg.dataset_names and rebuilds _dataset_name_to_index to match the 1-row buffer. The mismatch can no longer occur. |
| num_datasets derived from _dataset_name_to_index (pretrained.py:339) | New _stacked_num_datasets() reads the buffer's actual leading dim. Direct-construction PI05Policy(config, per_dataset_stats=[s1, s2]) without cfg.dataset_names now correctly raises KeyError instead of silently routing every sample to row 0. |
| Resume strict=False too broad (train.py:582) | safe_open(.../model.safetensors).keys() is inspected at resume time; strict=False only applies when norm-buffer keys are actually absent from disk. Unrelated missing keys still raise. |
| Asymmetric trained=False/resume=True (train.py:579) | Same on-disk-header fix: the source of truth is the file, not the current cfg flag. |
| Indent drift in legacy-promotion comment block | Fixed in pi05_mem, pi06, pi07/low_level. (My round-2 finding overstated the scope: pi05, pi07/high_level_planner, and both pi07_paligemma/* already had correct indentation in f8c7492; no false-fix risk in those files.) |

No new blockers and no new correctness bugs from this commit. A handful of small items that I'd leave as cleanups for either this PR or a follow-up:

  • pretrained.py:619: device_arg = map_location if map_location != "cpu" else "cpu" is still a dead ternary (both branches evaluate to map_location). Mentioned in round 2 as a nit; just device_arg = map_location would do.
  • tests/policies/test_value.py:155 — still constructs with ValueFunction(config, dataset_stats=...) (legacy single-dict). Works today (the new __init__ keeps the alias), but when the legacy alias is eventually retired the test regresses silently.
  • pretrained.py:559 (_promote_legacy_norm_buffers_in_state_dict) — still only handles N=1 promotion. Migrating a legacy single-dataset checkpoint into a fresh N>1-dataset policy raises a shape mismatch with no message pointing the user at _inject_stats. Edge case; defer to a follow-up.
  • pretrained.py:289 (_stacked_num_datasets) — returns the first buffer's leading dim and assumes every other Normalize module agrees. True today by construction (factory builds them together with the same num_datasets), but worth a one-line assert or a comment noting the invariant for future maintainers.
  • scripts/train.py:589 (except Exception) — the bare except around safe_open is broad; specific catches (FileNotFoundError, safetensors.SafetensorError) would surface unexpected I/O issues instead of swallowing them into the cfg-based fallback. Minor.
  • No tests for _stacked_num_datasets, _model_device, the ValueFunction config sync, or the resume header inspection. Each is a small behavior change that could regress silently. A unit test per method (e.g. PI05Policy(config, per_dataset_stats=[s1, s2]) then _resolve_dataset_index({'state': x}) should raise; policy._stacked_num_datasets() == 2) would catch future regressions cheaply.

Aside from the test-coverage point, none of those gate the merge. The five hard blockers from round 1 and the three correctness bugs from round 2 are all genuinely fixed. Once the open items from the PR's own "Not yet covered" list (smoke training, end-to-end skip-stats round-trip, seeded determinism, DDP/ZeRO-2) clear, this is ready.


Generated by Claude Code

@shuheng-liu shuheng-liu marked this pull request as ready for review May 27, 2026 04:58
@shuheng-liu shuheng-liu merged commit e27a1f1 into main May 27, 2026
15 checks passed
@shuheng-liu shuheng-liu deleted the claude/serene-euclid-ce1787 branch May 27, 2026 04:59
shuheng-liu added a commit that referenced this pull request May 27, 2026
PR #336 (per-dataset normalization) renamed `Normalize(stats=...)` to
`Normalize(per_dataset_stats=[...])` (a list, with `num_datasets=` as the
inf-init path). PR #337 was authored against the pre-#336 API and merged
first; #336 then merged without updating these tests, so CPU CI now fails
on every PR with:

    ValueError: create_stats_buffers requires either `per_dataset_stats` or `num_datasets`
    TypeError: Normalize.__init__() got an unexpected keyword argument 'stats'

Three failures in tests/policies/test_pi07_paligemma_low_level.py::TestSkipNormalizationWeights:

  - test_predicate_matches_real_normalize_buffer_keys
  - test_find_inf_normalize_buffers_detects_init_inf_sentinels
  - test_find_inf_normalize_buffers_empty_when_stats_passed

Updates:
  - inf-init paths now pass `num_datasets=1` (replaces the implicit single-dataset default).
  - the populated-stats path passes `per_dataset_stats=[stats]` (wrap the
    single stat dict in a singleton list to match the new D-row layout).

Verified locally: all 7 tests in TestSkipNormalizationWeights pass.

Labels

feature New feature or request refactor
