feat(policies): per-dataset normalization + skip-stats safetensors flag #336
Tears down the mixture-level stats aggregation in `DatasetMixtureMetadata` and pushes per-dataset stats into stacked Normalize/Unnormalize buffers indexed per sample. Each dataloader sample now carries `dataset_repo_id` (str) and `dataset_index` (long) — the policy gathers the matching row inside `Normalize.forward`. Adds a configurable `PreTrainedConfig.save_normalization_stats` flag (with method-level override) to strip `normalize_*.buffer_*` keys from the on-disk safetensors, and an Accelerate save-state hook so `accelerator.save_state` respects the same flag. Scope: all pi variants (pi0/pi05/pi05_mem/pi06/pi07/pi07_paligemma, low_level + high_level_planner). `value` policy is unchanged.
`_TaggedDataset` inside `WeightedDatasetMixture` adds `dataset_repo_id` and `dataset_index` to every sample, but the assertion in `test_integration_basic_functionality_with_same_fps_as_dataset` compares dataloader samples against direct `datasets[0][idx]` lookups (which bypass the wrapper). Strip the wrapper-added keys before the content comparison so the test reflects intent.
shuheng-liu
left a comment
Reviewing as a draft. The shape of the refactor — per-dataset stacked buffers indexed per sample, a skip-stats flag stripped at the safetensors write boundary — is the right one. But the change is incomplete in a few places that the CPU test suite cannot see, and at least four of those gaps will hard-block a real run after merge. Findings inline; this is the summary view.
Hard blockers
| # | Where | What breaks |
|---|---|---|
| 1 | pretrained.py:167 | First save_state with save_normalization_stats=False on any PaliGemma-backed PI policy raises SafetensorError: Some tensors share memory. save_file doesn't dedup tied lm_head/embed_tokens, despite the docstring promising a clone step the code never performs. |
| 2 | normalize.py:172 | New (num_datasets, *feat_shape) buffers are incompatible with every pre-PR checkpoint on disk (including everything under TensorAuto/*). load_state_dict(strict=False) still raises on size mismatch, so make_policy(cfg, pretrained_path='TensorAuto/pi05') fails after merge with no migration path. |
| 3 | pretrained.py:293 (= _resolve_dataset_index) | Every eval rollout and every gRPC inference call crashes with KeyError. The eval observation pipeline (scripts/eval.py:150-168) and scripts/grpc/server.py:107,246 never inject dataset_index / dataset_repo_id; only _TaggedDataset does, and only on training batches. |
| 4 | factory.py:250 | cfg.policy.type='value' raises TypeError immediately because ValueFunction.__init__ still has the old dataset_stats= kwarg. The PR description says "value is unchanged"; in practice it's broken. |
| 5 | train.py:568 | Resume with save_normalization_stats=False either (a) raises on missing buffer keys, or (b) silently retains +inf init and asserts on step 1. The save hook has no matching load hook, and make_policy's _inject_stats call is gated on cfg.pretrained_path which isn't set during resume. |
GPU test suite is not updated (PR's own checklist box is unchecked)
tests/policies/test_policies.py was updated cleanly, but the per-policy GPU-marked tests still use the old API and will fail the nightly gpu_test.yml run:
- tests/policies/test_pi05.py:265, 366-367, 593: PI05Policy(config, dataset_stats=...) and bare policy.normalize_targets(action).
- tests/policies/test_pi07_low_level.py:362, 460, 638, 747: same pattern.
- tests/policies/test_pi07_paligemma_low_level.py:344, 444: same pattern.
Each will raise TypeError: ... unexpected keyword argument 'dataset_stats' at construction, or TypeError: forward() missing 1 required positional argument: 'dataset_index' at the roundtrip assertion. Worth updating in this PR rather than discovering them in nightly.
Higher-severity edge cases (inline above)
- pretrained.py:286: CPU dataset_index + GPU policy = device mismatch in index_select; the string-list branch handles device, the tensor branch doesn't.
- pretrained.py:370: _inject_stats(stats, dataset_names=None) silently corrupts the buffer↔name mapping when the caller's order differs from the existing config.
- pretrained.py:404 (_check_norm_stats_loaded): only called from make_policy; direct from_pretrained users (gRPC, notebooks) silently get +inf and crash mid-forward.
- normalize.py:233: _gather_and_broadcast silently broadcasts to a wrong shape when batch_val.ndim < gathered.ndim instead of raising.
- dataset_mixture.py:236: aggregated_action_stats slices _dataset_weights[:N] instead of zipping; misaligned when any non-trailing dataset lacks 'actions'.
A few cleanups I'm leaving as suggestions, not blockers
- The 4-tuple ("normalize_inputs", "normalize_targets", "normalize_discrete_actions", "unnormalize_outputs") is hard-coded at pretrained.py:54-59, 337-342, 385-390; extract once.
- Normalize.__init__ and Unnormalize.__init__ duplicate ~18 lines verbatim; pull into a shared base.
- The 8 policy __init__ files repeat the same num_datasets = _num_datasets(...) + 3-4 Normalize(...) calls; a PreTrainedPolicy._build_normalize_modules(...) helper would absorb ~200 lines of boilerplate and prevent future per-policy drift.
- The training-time _TaggedDataset.__getitem__ stores both dataset_repo_id: str and dataset_index: long on every sample, but the string is never read in the training path (only at inference, where the caller supplies it explicitly). Worth dropping to skip the per-batch list[str] collate cost.
- Single-dataset configs (D=1) still pay an index_select allocating (B, *feat_shape) per stat per forward; broadcasting stat[0] would be free.
- _gather_and_broadcast uses .reshape(...) with a computed axis count; CLAUDE.md rule #4 specifically flags this kind of unsqueeze-chain reshape as a target for einops.rearrange / indexed None.
The blockers above (1-5) plus the GPU test gap are the things I'd hold the merge on. The rest are inline.
Generated by Claude Code
    # avoid safetensors' "duplicated tensors" rejection.
    state_dict = model_to_save.state_dict()
    filtered = {k: v for k, v in state_dict.items() if not is_norm_buffer_key(k)}
    save_safetensor_file(filtered, out_path)
Blocker — skip-stats save crashes on every PI policy.
The docstring at lines 162-164 promises a "clone tensors that share storage with another retained tensor" step, but the code never performs it: filtered is a {k: v for k, v in state_dict.items() if not is_norm_buffer_key(k)} — the retained tensors still alias each other.
PaliGemma ties lm_head.weight and embed_tokens.weight (handled explicitly in pi0/modeling_pi0.py:274 "lm_head.weight and embed_tokens.weight share memory"). At save time both keys appear in state_dict() pointing to the same storage. save_model deduplicates automatically; save_file does not — it raises:
safetensors.SafetensorError: Some tensors share memory, this will lead to duplicate memory on disk
So the very first accelerator.save_state / policy.save_pretrained call on any pi0/pi05/pi05_mem/pi07_paligemma run with save_normalization_stats=False will crash before any disk write happens. The unit test in test_save_pretrained_skip_stats.py uses a toy model with no tied weights, so it doesn't catch this.
    - save_safetensor_file(filtered, out_path)
    + # `save_model` doesn't expose a filter argument, so build the
    + # filtered state_dict ourselves. Use `_remove_duplicate_names`
    + # (the helper `save_model` calls internally) to drop alias
    + # entries — PaliGemma ties `lm_head.weight` and
    + # `embed_tokens.weight`, so a naive filter trips
    + # `safetensors`' shared-memory rejection.
    + from safetensors.torch import _remove_duplicate_names
    + state_dict = model_to_save.state_dict()
    + to_removes = _remove_duplicate_names(state_dict, preferred_names=set())
    + for kept, aliases in to_removes.items():
    +     for alias in aliases:
    +         state_dict.pop(alias, None)
    + filtered = {k: v for k, v in state_dict.items() if not is_norm_buffer_key(k)}
    + save_safetensor_file(filtered, out_path)
Generated by Claude Code
    # Note: we initialize mean, std, min, max to infinity. They should be overwritten
    # downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
    # we assert they are not infinity anymore.
    stacked_shape = (num_ds, *shape)
Blocker — every existing checkpoint becomes unloadable.
Pre-PR checkpoints (everything currently published under TensorAuto/*) saved each normalize_*.buffer_*.{mean,std,min,max} with shape (*feat_shape) — e.g. (32,) for state, (3, 1, 1) for a camera. After this PR every buffer is allocated with stacked_shape = (num_ds, *shape), and for a legacy config (no dataset_names set) resolve_num_datasets returns 1, so the new shape is (1, 32) / (1, 3, 1, 1).
PreTrainedPolicy._load_as_safetensor calls load_model_as_safetensor(..., strict=False), but PyTorch's load_state_dict(strict=False) only tolerates missing and unexpected keys — size mismatches on present keys are always raised. The result:
RuntimeError: Error(s) in loading state_dict for PI05Policy:
size mismatch for normalize_inputs.buffer_state.mean:
copying a param with shape torch.Size([32]) from checkpoint,
the shape in current model is torch.Size([1, 32]).
_tile_linear_input_weight only rewrites linear weights — no equivalent hook reshapes buffers. So make_policy(cfg) for any cfg.pretrained_path = 'TensorAuto/pi05' will break after merge. Needs either (a) a load-side hook in _load_as_safetensor that unsqueezes legacy 1-D buffers when the saved rank is feat_rank instead of 1 + feat_rank, or (b) an explicit migration script + release note.
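Option (a) can be sketched on a toy module. This is only an illustration under assumptions: `TinyNorm` and `promote_legacy_buffers` are hypothetical names, and the real hook would need to restrict itself to normalization buffers rather than every key in the state dict.

```python
import torch
from torch import nn


class TinyNorm(nn.Module):
    """Toy stand-in for a stacked stats buffer of shape (num_datasets, *feat_shape)."""

    def __init__(self, num_datasets: int, feat_shape: tuple[int, ...]) -> None:
        super().__init__()
        self.register_buffer("mean", torch.full((num_datasets, *feat_shape), float("inf")))


def promote_legacy_buffers(state_dict: dict[str, torch.Tensor], model: nn.Module) -> dict[str, torch.Tensor]:
    """Unsqueeze tensors saved as (*feat_shape) when the model expects (1, *feat_shape)."""
    expected = model.state_dict()
    promoted: dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        target = expected.get(key)
        if target is not None and tuple(target.shape) == (1, *value.shape):
            value = value.unsqueeze(0)  # legacy layout -> single-dataset stacked layout
        promoted[key] = value
    return promoted


model = TinyNorm(num_datasets=1, feat_shape=(32,))
legacy = {"mean": torch.zeros(32)}  # pre-PR layout

try:
    model.load_state_dict(legacy, strict=False)
    raised = False
except RuntimeError:
    raised = True  # size mismatches are reported even with strict=False

# After promotion the same checkpoint loads cleanly.
model.load_state_dict(promote_legacy_buffers(legacy, model), strict=False)
print(raised, tuple(model.mean.shape))
```

The shape comparison only promotes into a one-row buffer, so a legacy checkpoint loaded into a multi-dataset policy would still fail with the usual size-mismatch error.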
Generated by Claude Code
    "Per-dataset normalization requires either `dataset_index` "
    "(LongTensor of shape (B,)) or `dataset_repo_id` (str or "
    "list[str] of length B) in the batch."
    )
Blocker — every eval rollout and the gRPC inference server crash on the first step.
The eval observation pipeline (scripts/eval.py:150-168: preprocess_observation → add_envs_task → add_eval_metadata → add_subgoal_images) never injects dataset_index or dataset_repo_id. Likewise scripts/grpc/server.py:107, 246 builds an observation dict by hand from gRPC inputs.
So every policy.select_action(observation) / policy.sample_actions(observation, ...) on a per-dataset-normalized policy will raise this KeyError on the very first call — both opentau-eval and any cfg.eval_freq > 0 in-training eval block. The factory path is fine because _TaggedDataset injects both keys, but the eval/inference paths weren't updated to mirror it.
Three places need a fix to keep this PR's invariant ("every batch has a dataset_index") true end-to-end:
- scripts/eval.py::add_eval_metadata (or a new add_dataset_repo_id helper) needs to attach batch['dataset_repo_id'] per env, looked up from cfg.env/cfg.dataset_mixture so it resolves through _dataset_name_to_index.
- scripts/grpc/server.py's warmup + serving path must do the same.
- The single-dataset legacy fallback (config.dataset_names is None) should default to dataset_index = 0 rather than raising, otherwise loading any non-mixture checkpoint also fails. The message at lines 296-300 is descriptive, but a no-stats single-dataset reload is the common case and shouldn't need the caller to construct an index tensor.
Generated by Claude Code
    dataset_names = dataset_names or [f"external_{i}" for i in range(len(per_dataset_stats))]

    if per_dataset_stats is not None:
        kwargs["per_dataset_stats"] = per_dataset_stats
Blocker — cfg.policy.type = "value" immediately raises TypeError.
The PR description says the value policy is "out of scope and unchanged", but the factory now unconditionally passes per_dataset_stats= and dataset_names= to every policy class. ValueFunction.__init__ (policies/value/modeling_value.py:124-138) still has the old signature:
    def __init__(
        self,
        config: ValueConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
So make_policy(cfg, ds_meta=...) with cfg.type='value' triggers:
TypeError: ValueFunction.__init__() got an unexpected keyword argument 'per_dataset_stats'
Even if that were fixed, ValueFunction.__init__ still calls Normalize(config.input_features, config.normalization_mapping, dataset_stats) — dataset_stats is the 3rd positional arg, which is now per_dataset_stats: list[...]. Passing a dict there would either crash in create_stats_buffers (which iterates len(per_dataset_stats) as if it's a list) or worse, silently produce nonsense buffer shapes. And self.normalize_inputs(batch) inside forward is now missing the dataset_index positional arg.
Either: (a) update ValueFunction to the new API so "unchanged" becomes true, (b) carve value out of the per_dataset_stats kwarg passing in this factory and keep the legacy path, or (c) explicitly raise NotImplementedError when cfg.type=='value' lands in this branch so it fails loudly with a useful message instead of a generic TypeError.
Generated by Claude Code
            if is_norm_buffer_key(k):
                del sd[k]

    accelerator.register_save_state_pre_hook(_strip_norm_buffers_pre_save)
Blocker — resume with save_normalization_stats=False deadlocks / asserts at step 1.
You register a save_state_pre_hook that strips norm buffers on disk, but there is no matching register_load_state_pre_hook (or post-load injection). The resume flow at line 573 (accelerator.load_state(cfg.checkpoint_path)) loads the model state_dict from the same on-disk file the hook stripped:
1. make_policy(cfg, ds_meta=ds_meta) constructs the policy. Since cfg.pretrained_path is NOT set during a resume (it's only set when bootstrapping from a hub checkpoint), the _inject_stats(...) branch at factory.py:277 does not fire. Buffers stay at +inf.
2. accelerator.load_state(...) then loads the on-disk safetensors. With strict mode it raises on missing buffer keys; with strict=False the buffers stay at the +inf init.
3. First forward fires assert not torch.isinf(mean).any(), so training crashes on step 1, after potentially hours of resume-setup wall time.
This needs either:
- A symmetric register_load_state_pre_hook that calls policy._inject_stats(ds_meta.per_dataset_stats, dataset_names=ds_meta.dataset_names) after load, or
- make_policy unconditionally calls _inject_stats whenever per_dataset_stats is not None, not just under cfg.pretrained_path (move the guard at factory.py:277-280 so the resume path also benefits).
Either way, this PR's checklist explicitly defers smoke-running the round-trip ("End-to-end save_normalization_stats=false round-trip on a real checkpoint dir") to #335 — which means the bug above will land first and break the next person to flip the flag.
Generated by Claude Code
    f"buffers are still +inf: {bad}. Either re-save the checkpoint "
    "with `save_normalization_stats=True`, or pass `ds_meta=` to "
    "`make_policy(...)` so stats are injected after load."
    )
Footgun — from_pretrained bypasses both _inject_stats and _check_norm_stats_loaded.
Only factory.make_policy (lines 277-280) calls these. Anyone using PI05Policy.from_pretrained('path/to/skip-stats-ckpt') directly — and there are real callers:
- scripts/grpc/server.py::_load_policy (the production inference server),
- notebooks/pi05_evaluation_only.ipynb,
- Anything in external/ downstream code that doesn't go through factory.
…will silently get +inf buffers and crash mid-forward instead of at construction time, with no clean message pointing the user at the save_normalization_stats=False round-trip mistake.
The cheapest fix is to call self._check_norm_stats_loaded() at the bottom of from_pretrained right before the policy.eval() (or at the bottom of _load_as_safetensor). It's a no-op when the buffers loaded fine, and a clear RuntimeError when they didn't — the same invariant make_policy already enforces, just hoisted to the load boundary where it actually belongs.
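The invariant being hoisted is small enough to sketch. The code below is an assumption-laden illustration, not the PR's `_check_norm_stats_loaded`: `ToyNormalize` and `check_stats_loaded` are hypothetical names, and the check scans every buffer of the module, whereas the real helper presumably limits itself to the normalization submodules.

```python
import torch
from torch import nn


class ToyNormalize(nn.Module):
    """Toy module whose stats start at +inf, mirroring the init convention quoted above."""

    def __init__(self, num_datasets: int = 1, dim: int = 3) -> None:
        super().__init__()
        self.register_buffer("mean", torch.full((num_datasets, dim), float("inf")))
        self.register_buffer("std", torch.full((num_datasets, dim), float("inf")))


def check_stats_loaded(module: nn.Module) -> None:
    """Raise at the load boundary if any buffer still holds an infinite value."""
    bad = [name for name, buf in module.named_buffers() if torch.isinf(buf).any()]
    if bad:
        raise RuntimeError(f"normalization buffers are still +inf: {bad}")


norm = ToyNormalize()
try:
    check_stats_loaded(norm)
    failed_before_fill = False
except RuntimeError:
    failed_before_fill = True  # nothing has overwritten the +inf init yet

norm.mean.fill_(0.0)
norm.std.fill_(1.0)
check_stats_loaded(norm)  # passes silently once stats are populated
```

Calling a check of this shape right after loading turns a mid-forward assertion into an error raised at construction time.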
Generated by Claude Code
    idx = batch["dataset_index"]
    if not isinstance(idx, Tensor):
        idx = torch.as_tensor(idx, dtype=torch.long)
    return idx.to(dtype=torch.long)
Bug — CPU-tensor dataset_index + GPU policy raises device mismatch in index_select.
idx.to(dtype=torch.long) preserves the existing device. The training path is safe because _TaggedDataset returns a scalar tensor that accelerator.prepare's send_to_device will move. But any direct-inference caller — gRPC server, ROS client, the notebooks/pi05_evaluation_only.ipynb cell that builds an observation dict by hand, or an accelerator-less unit test — typically constructs a CPU LongTensor and hands it to a CUDA-resident policy. Then inside normalize.py::_gather_and_broadcast:
RuntimeError: Expected all tensors to be on the same device,
but found at least two devices, cuda:0 and cpu!
The dataset_repo_id (string) branch handles this correctly at line 312 — torch.tensor(indices, dtype=torch.long, device=device) — but the tensor branch above doesn't. The fix is symmetric:
    - return idx.to(dtype=torch.long)
    + if "dataset_index" in batch:
    +     idx = batch["dataset_index"]
    +     if not isinstance(idx, Tensor):
    +         idx = torch.as_tensor(idx, dtype=torch.long)
    +     _, device = self._infer_batch_size_and_device(batch)
    +     return idx.to(dtype=torch.long, device=device)
Generated by Claude Code
    )
    gathered = stat.index_select(0, dataset_index)  # (B, *feat_shape)
    extra = batch_val.ndim - gathered.ndim
    if extra > 0:
Bug — extra < 0 silently broadcasts wrong instead of raising.
The current branch only pads when batch_val.ndim > gathered.ndim. If a feature has shape (1,) (a scalar state dim) and a caller squeezed the trailing axis, batch_val.shape = (B,), gathered.shape = (B, 1), and extra = -1. No reshape happens, then (batch_val:(B,) - gathered:(B,1)) broadcasts to (B, B) — a silently wrong result with no diagnostic until something downstream crashes on the wrong-shape tensor.
The contract of this helper is "make gathered broadcastable to batch_val" — if it can't, that's a bug worth surfacing:
    - if extra > 0:
    + gathered = stat.index_select(0, dataset_index)  # (B, *feat_shape)
    + extra = batch_val.ndim - gathered.ndim
    + if extra < 0:
    +     raise ValueError(
    +         f"batch_val has fewer dims ({batch_val.ndim}) than the gathered stat "
    +         f"({gathered.ndim}); can't broadcast {tuple(batch_val.shape)} against "
    +         f"{tuple(gathered.shape)}."
    +     )
    + if extra > 0:
    +     gathered = gathered.reshape(gathered.shape[0], *((1,) * extra), *gathered.shape[1:])
    + return gathered
Generated by Claude Code
    # pull just the "actions" sub-dicts so the call is cheap.
    agg = aggregate_stats(
        [{"actions": s["actions"]} for s in stats_with_actions],
        weights=self._dataset_weights[: len(stats_with_actions)],
Bug — weights misaligned when any non-trailing dataset lacks 'actions'.
self._dataset_weights[: len(stats_with_actions)] takes the first N weights, not the weights for the datasets that survived the "actions" in s filter.
Concrete failure: a mixture of [A (VQA-only, no actions), B (has), C (has)] with weights [wA, wB, wC]:
- stats_with_actions = [B, C]
- weights = self._dataset_weights[:2] = [wA, wB] ← wrong
- aggregate_stats then weights B by wA and C by wB instead of the intended (wB, wC).
fit_fast_tokenizer.py consumes this to fit one BPE codec over a global action range — the result is a silently misaligned discrete action vocab. Fix by zipping & filtering both together:
    - weights=self._dataset_weights[: len(stats_with_actions)],
    + pairs = [
    +     (s, w)
    +     for s, w in zip(self.per_dataset_stats, self._dataset_weights, strict=True)
    +     if "actions" in s
    + ]
    + if not pairs:
    +     raise ValueError(
    +         "No dataset in the mixture exposes 'actions' stats; aggregated_action_stats() is undefined."
    +     )
    + agg = aggregate_stats(
    +     [{"actions": s["actions"]} for s, _ in pairs],
    +     weights=[w for _, w in pairs],
    + )
    + return agg["actions"]
Generated by Claude Code
            buffer[stat].data.copy_(new_tensor)
    # Refresh the module's own dataset_names cache.
    if dataset_names is not None:
        module.dataset_names = list(dataset_names)
Footgun — silent corruption when dataset_names=None and caller's order differs from the existing config.
When dataset_names is None, the loop above overwrites buffer[stat] rows in the order of per_dataset_stats, but the if dataset_names is not None guards leave self.config.dataset_names and self._dataset_name_to_index pointing at the old order. After injection, inference resolves batch['dataset_repo_id'] = 'robotA' through the old name→index map, but row 0 of the buffer is now whatever was at per_dataset_stats[0] in the caller's order. No error — just silently wrong stats applied per sample.
make_policy always passes dataset_names=dataset_names, so this is safe from the factory path. But the method is a public-shaped helper (documented in the docstring and used by the suggested-fix path for resume), and anyone calling policy._inject_stats(stats) directly hits a quiet correctness bug.
Either require dataset_names always (drop the Optional), or when None, verify the new rows match the existing config.dataset_names length AND require the caller to acknowledge the existing order is being preserved (e.g. an explicit assert_existing_order=True flag).
Generated by Claude Code
…esolve_dataset_index
Single-dataset policies (`num_datasets <= 1`) no longer require the caller to inject `dataset_index` / `dataset_repo_id` into the batch; `_resolve_dataset_index` returns zeros when both are absent. Multi-dataset policies still raise a clear `KeyError` so silent misuse stays loud.
Also:
- `ValueFunction` keeps its external `dataset_stats: dict | None` API but wraps into a singleton list internally before calling the new `Normalize`. Forward calls go through the shared `_resolve_dataset_index` helper.
- Update GPU pytest fixtures across pi0/pi05/pi05_mem/pi06/pi07/pi07_paligemma to use `per_dataset_stats=[...]` matching the new constructor signature.
…irectly
`test_complete_pi05_pipeline_integration` reaches into `policy.normalize_targets` / `policy.unnormalize_outputs` directly to verify the round-trip, which bypasses the policy-level `_resolve_dataset_index` helper. The new per-dataset Normalize forward needs an explicit `(B,)` long index, so pass a zeros tensor matching the single-dataset construction of this fixture.
…set normalize PR
Hard blockers (1-5 from review):
1. `_save_pretrained` with stats stripped: detach norm-buffer submodules and reuse `save_model_as_safetensor` so tied `lm_head`/`embed_tokens` on PaliGemma-backed policies dedup correctly. The raw `save_file` path crashed with `SafetensorError: Some tensors share memory`.
2. Legacy checkpoint migration: new `_promote_legacy_norm_buffers_in_state_dict` shim unsqueezes pre-PR `(*feat_shape,)` buffers to the new `(1, *feat_shape)` layout at load time. Called from base `_load_as_safetensor` and from every pi `from_pretrained` override so anything under `TensorAuto/*` still loads.
3. Inference dataset_index plumbing: added `EvalConfig.dataset_repo_id` and `ServerConfig.dataset_repo_id`. `scripts/eval.py::rollout` and `scripts/grpc/server.py` (both warmup and per-request batch) inject the field when set; otherwise the single-dataset `_resolve_dataset_index` fallback handles things.
4. `ValueFunction.__init__` now accepts both the legacy `dataset_stats: dict` and the new `per_dataset_stats: list[dict] + dataset_names: list[str]` API, with mutual-exclusion validation. Lets `make_policy(cfg.policy.type='value', ds_meta=...)` work after the factory plumbing change.
5. Resume with `save_normalization_stats=False`: pass `strict=False` to `accelerator.load_state` so the stripped buffer keys don't crash the load. Buffers stay at the `_inject_stats`-populated values.
Higher-severity edges:
- `_resolve_dataset_index` now moves the dataset_index tensor to the batch's device (was missing in the tensor branch).
- `_inject_stats` validates that any caller-supplied `dataset_names` matches `config.dataset_names` element-for-element (was silently corrupting the name->index map on reorder).
- `_check_norm_stats_loaded` now also runs at the end of base `from_pretrained` so direct callers (notebooks, gRPC server, downstream scripts) surface the missing-stats mistake at construction time instead of mid-forward. Uses the shared `NORM_MODULE_NAMES` tuple.
- `Normalize._gather_and_broadcast` raises `ValueError` when the batch tensor has fewer dims than the stats buffer (was silently broadcasting to a wrong shape).
- `aggregated_action_stats` zips weights with the filtered stats list instead of slicing `[:N]`; a non-trailing dataset lacking `actions` would otherwise misalign the BPE codec weights.
Test fixes:
- `test_pi07_low_level.py` and `test_pi07_paligemma_low_level.py`: pass explicit `dataset_index` when calling `normalize_targets`/`unnormalize_outputs` submodules directly (bypassing the policy-level helper).
Cleanups:
- Extract `NORM_MODULE_NAMES` constant in `pretrained.py` and use it in `_save_pretrained`, `_inject_stats`, `_check_norm_stats_loaded`.
- Drop the unused `packaging` and `safetensors` version imports from `pretrained.py` (the new `_load_as_safetensor` path doesn't need them).
Thanks for the thorough review. All five blockers and the higher-severity edge cases are addressed in
Blockers
Edges
Test gap
Deferred (cleanups, marked as suggestions)
Skipping for this PR; happy to do as follow-ups:
I extracted the
shuheng-liu
left a comment
Second pass after f8c7492. The five blockers from the first review are all addressed cleanly — _save_pretrained's detach/reattach approach is the right call for the tied-weights problem, the legacy-buffer promotion shim handles the migration path well, the eval/gRPC dataset_repo_id plumbing is clean, and ValueFunction now bridges both APIs. Most of the cleanup follow-ups from the prior review (the shared NORM_MODULE_NAMES constant, the _check_norm_stats_loaded invariant in from_pretrained, the aggregated_action_stats zip-fix, the _gather_and_broadcast ValueError, the explicit _inject_stats cross-check) all landed too.
A few new things surface this pass — three real correctness bugs introduced or missed by the fix commit, two train.py resume edge cases, and an indentation drift across six modeling files.
Correctness
- pretrained.py:331: The device-mismatch fix is incomplete. _infer_batch_size_and_device iterates batch.values() in dict-insertion order; if a caller hand-builds batch = {'dataset_index': cpu_idx, 'state': gpu_state, ...}, the function returns dataset_index's own device, idx.to(device=cpu) becomes a no-op, and index_select later crashes with the same device-mismatch error the fix was supposed to prevent. Inline above with a 4-line patch (skip dataset_index when inferring) or use next(self.parameters()).device.
- value/modeling_value.py:160: Multi-dataset value config silently corrupts. super().__init__(config) populates _dataset_name_to_index with all N names from config.dataset_names (length 3 in a typical mixture), but ValueFunction then truncates per_dataset_stats to length 1. Buffer has 1 row; the name→index map has 3 entries. dataset_repo_id='b' → index 1 → index_select(0, 1) on a 1-row buffer → CUDA-side IndexError. The warning at lines 153-157 is not enough; either refuse the multi-dataset case loudly or also truncate cfg.dataset_names.
- pretrained.py:339: The single-dataset fallback derives num_datasets from len(self._dataset_name_to_index), not from the actual buffer leading dim. Anyone constructing the policy directly (PI05Policy(config, per_dataset_stats=[s1, s2]) without setting config.dataset_names, which is common in notebooks/tests) hits _dataset_name_to_index = None, the <= 1 branch returns torch.zeros(B), and every sample silently normalizes against s1 even when the caller intended s2. Make the buffer's leading dim the source of truth, or always populate _dataset_name_to_index when stats are passed.
Resume / strict=False
- train.py:582: strict=False is too broad. It correctly lets stripped norm-buffer keys through, but it also silently swallows any other genuine missing key (renamed module, forgotten buffer, accidental refactor). A symmetric register_load_state_pre_hook that drops only is_norm_buffer_key-matching keys would let real missing-key bugs still raise.
- train.py:579: Asymmetric: trained-with-False then resumed-with-True hard-crashes. The condition reads the current config, not what was on disk. A user who flips save_normalization_stats between the initial run and the resume gets RuntimeError: Missing key(s) after a long load. Persist the as-saved flag into the checkpoint metadata, or always use strict=False on resume paired with the targeted skip above.
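One way to keep real missing-key bugs loud is to load non-strictly and then inspect what was actually missing. The sketch below does this on a plain `nn.Module` rather than through Accelerate's `load_state`; `looks_like_norm_buffer_key` is a hypothetical stand-in for the PR's `is_norm_buffer_key`, with a regex guessed from the `normalize_*.buffer_*` pattern in the PR description, and `load_allowing_stripped_stats` and the toy classes are likewise made up for illustration.

```python
import re

import torch
from torch import nn

# Assumption: norm-buffer keys look like "normalize_inputs.buffer_state.mean"
# or "unnormalize_outputs.buffer_actions.std".
_NORM_KEY = re.compile(r"(?:^|\.)(?:un)?normalize_\w+\.buffer_")


def looks_like_norm_buffer_key(key: str) -> bool:
    return _NORM_KEY.search(key) is not None


def load_allowing_stripped_stats(model: nn.Module, state_dict: dict[str, torch.Tensor]) -> list[str]:
    """Load non-strictly, then fail unless every missing key is a norm buffer."""
    result = model.load_state_dict(state_dict, strict=False)
    real_missing = [k for k in result.missing_keys if not looks_like_norm_buffer_key(k)]
    if real_missing or result.unexpected_keys:
        raise RuntimeError(f"missing={real_missing}, unexpected={list(result.unexpected_keys)}")
    return list(result.missing_keys)  # the tolerated, stripped stats keys


class _Stats(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("mean", torch.zeros(1, 3))


class _NormalizeInputs(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.buffer_state = _Stats()


class ToyPolicy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.normalize_inputs = _NormalizeInputs()
        self.head = nn.Linear(3, 2)


policy = ToyPolicy()
stripped = {k: v for k, v in policy.state_dict().items() if not looks_like_norm_buffer_key(k)}
tolerated = load_allowing_stripped_stats(ToyPolicy(), stripped)
print(tolerated)  # ['normalize_inputs.buffer_state.mean']
```

Because the check does not read `save_normalization_stats`, it behaves the same whichever way the flag was set at save time. Note that the weights are copied before the check runs, so a failure leaves the model partially loaded.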
Style / clarity
- pi05_mem/modeling_pi05.py:364 (and pi06, both pi07/*, both pi07_paligemma/*; six files): Indentation drift in the legacy-promotion comment block. The first comment line is at 16 spaces (inside if remap_count > 0 and is_main_process:); the continuation comment and the actual _promote_legacy_norm_buffers_in_state_dict call are at 12 spaces (outside the if). The code is functionally correct, but the visual mismatch invites a future maintainer to "fix" the indentation by moving the call into the if block, at which point promotion silently no-ops on every load that didn't need a key remap, and skips entirely on non-main ranks.
Minor nits not posted inline
- pretrained.py:562: device_arg = map_location if map_location != "cpu" else "cpu". Both branches evaluate to map_location; the ternary is dead code.
- tests/policies/test_value.py:155 still passes dataset_stats=... (legacy API). Today this is fine because ValueFunction.__init__ accepts both, but it leaves the test corpus inconsistent: when the legacy alias is eventually retired the test silently regresses.
- The make_policy factory now calls _inject_stats(...) (line 278) after from_pretrained already populated buffers via the per_dataset_stats=... kwargs (passed through **kwargs to the constructor inside from_pretrained). The injection is now redundant but harmless. The new equality check in _inject_stats catches reorder bugs, so this is essentially a defensive double-check. Worth a one-line comment noting the redundancy is intentional.
- _promote_legacy_norm_buffers_in_state_dict only handles N=1 promotion. Loading a legacy single-dataset checkpoint into a new multi-dataset policy (target shape (N, *feat), loaded (*feat)) still raises a shape mismatch: the shim unsqueezes to (1, *feat) and the subsequent load_state_dict rejects. Edge case (user has to actively migrate single→multi), but a clearer error message pointing at _inject_stats would help.
The three correctness items above are the only things I'd hold the merge on. Everything else is style or future-proofing.
Generated by Claude Code
    if per_dataset_stats is None and dataset_stats is not None:
        per_dataset_stats = [dataset_stats]
    if per_dataset_stats is not None and len(per_dataset_stats) > 1:
        logging.warning(
Multi-dataset value config silently corrupts at inference.
super().__init__(config) runs before this block and populates self._dataset_name_to_index from config.dataset_names. When make_policy(value_cfg, ds_meta=<DatasetMixtureMetadata with N>1 datasets>) calls this constructor, the factory has already set cfg.dataset_names = ['a', 'b', 'c'] (length 3) at factory.py:255. So:
- super().__init__(config) builds _dataset_name_to_index = {'a':0, 'b':1, 'c':2} (length 3).
- The truncation below sets per_dataset_stats = per_dataset_stats[:1] → 1-row buffer.
- At inference, batch['dataset_repo_id'] = 'b' resolves to index 1, then Normalize.forward calls stat.index_select(0, tensor([1])) on a (1, *feat) buffer → IndexError: index 1 is out of bounds for dimension 0 with size 1 (CUDA-side, hard to debug).
- At training, _resolve_dataset_index's single-dataset fallback won't fire either (num_datasets = len(_dataset_name_to_index) = 3 > 1), so missing-key batches raise KeyError instead of defaulting.
Two options:
- Refuse the multi-dataset case loudly (`raise ValueError("ValueFunction is single-dataset only; got len(per_dataset_stats)=N")`) instead of silently truncating.
- If the policy genuinely is single-dataset by design, also override `cfg.dataset_names = dataset_names[:1]` before `super().__init__` to keep the name map and buffer in sync.
The warning at lines 153-157 is not enough — the user gets a runtime crash they can't easily trace back to "value can only handle one dataset".
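The two options can be sketched in plain Python, outside the policy class. This is a minimal illustration rather than the PR's code: `sync_single_dataset` and its `strict` switch are hypothetical names, and plain lists stand in for the real config and stats objects.

```python
import logging


def sync_single_dataset(dataset_names, per_dataset_stats, strict=True):
    """Keep the name list and the stats list the same length.

    Hypothetical helper: `strict=True` models option 1 (refuse loudly),
    `strict=False` models option 2 (truncate both lists together).
    """
    if len(dataset_names) != len(per_dataset_stats):
        raise ValueError(
            f"got {len(dataset_names)} names but {len(per_dataset_stats)} stats entries"
        )
    if len(per_dataset_stats) <= 1:
        return list(dataset_names), list(per_dataset_stats)
    if strict:
        raise ValueError(
            "ValueFunction is single-dataset only; "
            f"got len(per_dataset_stats)={len(per_dataset_stats)}"
        )
    logging.warning("Truncating %d datasets to the first one", len(per_dataset_stats))
    # Truncate both lists so no name can map past the 1-row buffer.
    return list(dataset_names[:1]), list(per_dataset_stats[:1])


names, stats = sync_single_dataset(["a", "b", "c"], [{}, {}, {}], strict=False)
name_to_index = {name: i for i, name in enumerate(names)}
```

Either branch removes the state in which the name map has three entries while the buffer has one row; the strict branch surfaces the problem at construction time instead of at the first out-of-range lookup.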
Generated by Claude Code
| # tensors and Accelerate later does the host->device copy on a | ||
| # per-key basis — `dataset_index` rides along for tensor keys, | ||
| # but a caller hand-constructing a batch on GPU may leave the | ||
| # index on CPU, breaking `index_select` inside Normalize. |
The device-placement fix doesn't actually fix the bug if dataset_index iterates first.
_infer_batch_size_and_device walks batch.values() in insertion order and returns the device of the first tensor it finds. Python dicts preserve insertion order, so if a caller hand-builds the batch as:
batch = {'dataset_index': cpu_idx_tensor, 'state': gpu_state_tensor, ...}
…then `_infer_batch_size_and_device` returns `(B, cpu)` from the very first tensor it sees, `dataset_index` itself, even though `state` and the rest of the batch are on GPU. `idx.to(device='cpu')` becomes a no-op, and the subsequent `index_select` inside `_gather_and_broadcast` raises:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
The original device-mismatch bug is unfixed for any caller who happens to put dataset_index first in their dict (gRPC server, notebooks, custom inference loops).
Two viable fixes:
| # index on CPU, breaking `index_select` inside Normalize. | |
| # Skip `dataset_index` when inferring device — otherwise a caller | |
| # who put it first in the dict gets `idx.to(idx.device)` (no-op). | |
| device = next( | |
| (v.device for k, v in batch.items() if k != "dataset_index" and isinstance(v, Tensor)), | |
| idx.device, | |
| ) | |
| return idx.to(dtype=torch.long, device=device) |
Or use next(self.parameters()).device — the policy's own device is always the right target for a Normalize index_select.
Generated by Claude Code
| num_datasets = len(self._dataset_name_to_index) if self._dataset_name_to_index else 1 | ||
| if num_datasets <= 1: | ||
| batch_size, device = self._infer_batch_size_and_device(batch) | ||
| return torch.zeros(batch_size or 1, dtype=torch.long, device=device) |
num_datasets derived from _dataset_name_to_index, not from the actual buffer shape — silent corruption when they disagree.
The fallback uses len(self._dataset_name_to_index) (or 1 when None) to decide whether to default to all-zero indices. But this dict is built once in __init__ from config.dataset_names, which only make_policy sets as a side effect (line 255 of factory.py). A caller who constructs the policy directly with stats but no name list:
policy = PI05Policy(config, per_dataset_stats=[stats_robotA, stats_robotB])
# config.dataset_names is None, so PreTrainedPolicy.__init__ leaves
# _dataset_name_to_index = None.
policy.select_action({'state': ..., 'image': ...})
…falls through to the `num_datasets <= 1` branch, returns `torch.zeros(B)`, and every sample is normalized with robotA's mean/std regardless of the caller's intent. No exception, just silently degraded action quality.
The actual leading dim of the stats buffer is the source of truth. Either:
- Look at `self.normalize_inputs` / `normalize_targets` / etc. and pull the leading dim from a `buffer_*` parameter, or
- Always update `self._dataset_name_to_index` (or at least its length) when `per_dataset_stats` is passed at construction, even when the names list is `None` (use generated `default_0`, `default_1` placeholder names).
Right now the relationship between "the buffer has N rows" and "the policy knows it has N rows" only holds when make_policy is the constructor.
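A rough sketch of both suggestions, using shape tuples in place of real tensors so it runs without torch. The function names and the buffer-key layout are assumptions made for illustration; the policy's actual buffers would be read via `named_buffers()` or similar.

```python
def stacked_num_datasets(buffer_shapes):
    """Read the dataset count from the leading dim of the stats buffers.

    `buffer_shapes` maps a buffer name to its shape tuple, e.g.
    (num_datasets, *feat_shape). Hypothetical stand-in for inspecting
    the Normalize/Unnormalize modules directly.
    """
    leading_dims = {shape[0] for shape in buffer_shapes.values() if len(shape) > 0}
    if not leading_dims:
        return 1  # no stats buffers attached: behave as single-dataset
    if len(leading_dims) > 1:
        raise ValueError(f"stats buffers disagree on num_datasets: {sorted(leading_dims)}")
    return leading_dims.pop()


def placeholder_names(num_datasets):
    """Generate default_0, default_1, ... when no name list was configured."""
    return [f"default_{i}" for i in range(num_datasets)]


shapes = {
    "normalize_inputs.buffer_observation_state.mean": (2, 7),
    "normalize_inputs.buffer_observation_state.std": (2, 7),
}
num_datasets = stacked_num_datasets(shapes)
names = placeholder_names(num_datasets)
```

With the count taken from the buffers, a policy built from two stats dicts and no name list reports two datasets, so the all-zero fallback no longer applies to it.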
Generated by Claude Code
| load_kwargs: dict = {} | ||
| if not cfg.policy.save_normalization_stats: | ||
| load_kwargs["strict"] = False | ||
| accelerator.load_state(cfg.checkpoint_path, **load_kwargs) |
strict=False on resume is too broad — masks every missing-key bug, not just norm buffers.
`accelerator.load_state` passes `strict` down to the underlying `model.load_state_dict`. When `save_normalization_stats=False`, the fix correctly turns off strict to let the stripped `normalize_*.buffer_*` keys through. But that same flag also silently swallows:
- A renamed module between save and resume (e.g. `backbone` → `encoder`).
- A forgotten buffer in a new layer.
- Any genuine architecture regression introduced in the gap between save and resume.
The user would have to notice the warning log lines (which are easy to miss in 100k-step training output) rather than getting a hard error at resume time.
A targeted approach is safer: keep strict=True, register a corresponding register_load_state_pre_hook that drops norm-buffer keys from the loaded state_dict before load_state_dict runs. Symmetric to the save-state pre-hook on lines 561-568, and lets real missing-key bugs still raise. Sketch:
if not cfg.policy.save_normalization_stats:
    def _drop_norm_buffer_keys_post_load(models, input_dir):
        # No-op: keys never landed on disk, so they were already absent
        # from the model's `load_state_dict` call. The buffers were
        # repopulated by `_inject_stats` inside `make_policy` and the
        # missing-key warning is the only signal; turn it into a noop.
        del models, input_dir
    accelerator.register_load_state_pre_hook(_drop_norm_buffer_keys_post_load)
Or, simpler: assert that every missing key passes `is_norm_buffer_key` and re-raise otherwise.
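The simpler alternative might look like the sketch below. The predicate here is a stand-in written for this example (a glob over the `normalize_*.buffer_*` pattern named in the PR); the repository's real `is_norm_buffer_key` may be implemented differently, and `check_missing_keys` is a hypothetical name.

```python
import fnmatch


def is_norm_buffer_key(key):
    """Assumed stand-in for the repo's predicate; matches normalize_*.buffer_* keys."""
    return fnmatch.fnmatchcase(key, "*normalize_*.buffer_*")


def check_missing_keys(missing_keys):
    """Tolerate stripped normalization buffers, raise on anything else."""
    offending = [k for k in missing_keys if not is_norm_buffer_key(k)]
    if offending:
        raise RuntimeError(f"Missing non-normalization keys on resume: {offending}")
    return len(missing_keys)


# Keys as they might come back from a non-strict load_state_dict call.
tolerated = check_missing_keys(
    [
        "normalize_inputs.buffer_observation_state.mean",
        "unnormalize_outputs.buffer_action.std",
    ]
)
```

A renamed `backbone` module would show up in `offending` and raise, which is the behavior a blanket `strict=False` gives up.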
Generated by Claude Code
| # buffers were already repopulated by `_inject_stats` inside | ||
| # `make_policy(cfg, ds_meta=...)` above, so this leaves them at the | ||
| # right values rather than silently overwriting. | ||
| load_kwargs: dict = {} |
Asymmetric resume — trained-with-False then resumed-with-True hard-crashes.
The condition `if not cfg.policy.save_normalization_stats` reads the current config, not what was on disk. This is fine for the symmetric case, but a user who:
- Initially trains with `save_normalization_stats=False` (smaller checkpoint disk footprint), then
- Flips the flag to `True` on the resume config (because they want stats included in the next round of checkpoints for downstream eval),
…gets:
RuntimeError: Error(s) in loading state_dict for ...:
Missing key(s) in state_dict: "normalize_inputs.buffer_observation_state.mean", ...
The resume crashes because load_kwargs stays empty (current cfg has True), strict=True is applied, and the on-disk safetensors lacks the buffer keys that the previous run stripped.
Two options:
- Persist the as-saved flag into the checkpoint's `config.json` (or sidecar metadata) and read that for the strict decision instead of `cfg.policy.save_normalization_stats`.
- Or simpler: always pass `strict=False` on resume, paired with the targeted missing-key check above.
This is the same flavor of footgun as inverted-flag config drift — the only signal is a runtime crash after a long load.
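For the first option, a sidecar file is one way to record what was actually written. The sketch below is an assumption-laden illustration: the file name `norm_stats_meta.json` and both helper names are invented here and are not part of the PR.

```python
import json
import tempfile
from pathlib import Path

SIDECAR_NAME = "norm_stats_meta.json"  # hypothetical file name


def write_sidecar(checkpoint_dir, save_normalization_stats):
    """Record, next to the checkpoint, whether norm buffers were written."""
    path = Path(checkpoint_dir) / SIDECAR_NAME
    path.write_text(json.dumps({"save_normalization_stats": bool(save_normalization_stats)}))


def stats_were_saved(checkpoint_dir, default=True):
    """Read the as-saved flag; fall back to `default` for older checkpoints."""
    path = Path(checkpoint_dir) / SIDECAR_NAME
    if not path.exists():
        return default
    return bool(json.loads(path.read_text())["save_normalization_stats"])


with tempfile.TemporaryDirectory() as ckpt_dir:
    write_sidecar(ckpt_dir, save_normalization_stats=False)  # first run
    saved = stats_were_saved(ckpt_dir)                       # resume, any cfg value
    strict = saved  # strict load only when the buffers are really on disk
```

The resume path then ignores the current config value entirely, so flipping the flag between runs cannot change how the existing checkpoint is loaded.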
Generated by Claude Code
| if remap_count > 0 and is_main_process: | ||
| logging.info("Remapped %d state dict keys", remap_count) | ||
|
|
||
| # Promote legacy single-dataset Normalize/Unnormalize buffers from |
Indentation drift — first comment line is inside the if block but the call is outside (repeated in 6 modeling files).
Lines 361-368 read as:
            if remap_count > 0 and is_main_process:
                logging.info("Remapped %d state dict keys", remap_count)
                # Promote legacy single-dataset Normalize/Unnormalize buffers from    ← 16 spaces, inside if
            # `(*feat_shape,)` to the new `(1, *feat_shape)` stacked layout so pre-PR   ← 12 spaces
            # checkpoints load via `model.load_state_dict(...)`.                       ← 12 spaces
            model._promote_legacy_norm_buffers_in_state_dict(remapped_state_dict)      ← 12 spaces
            missing_keys, unexpected_keys = model.load_state_dict(...)
The first `# Promote legacy ...` line is inside `if remap_count > 0 and is_main_process:`, while the continuation comment and the actual `_promote_legacy_norm_buffers_in_state_dict` call sit at the outer scope. Functionally correct (the call DOES run unconditionally as intended), but the visual mismatch is exactly the kind of thing a future maintainer "fixes" by indenting the call into the `if`, at which point legacy promotion silently no-ops on every load that didn't need a key remap (and skips entirely on non-main ranks because of `is_main_process`), and pre-PR checkpoints stop loading.
Same pattern in pi06/modeling_pi06.py, pi07/low_level/modeling_pi07_low_level.py, pi07/high_level_planner/modeling_pi07_high_level.py, and both pi07_paligemma/*/modeling_*.py. Move the orphaned # Promote legacy ... line up by 4 spaces (or paragraph-break it out of the if block) in all six files.
Generated by Claude Code
…value name-map sync, on-disk strict detection, indent drift

Three correctness bugs from review #2:
- `_resolve_dataset_index` device fix is now belt-and-braces: tensor branch uses `next(self.parameters()).device` (the model's own device), not whatever device `_infer_batch_size_and_device` picked from the batch. The helper also skips the wrapper-injected `dataset_index` / `dataset_repo_id` keys so a CPU-resident index can't override a GPU batch when dict iteration order puts it first.
- `ValueFunction` no longer leaves a stale `_dataset_name_to_index` after truncating `per_dataset_stats` to the single-row buffer. Inference with `dataset_repo_id='<non-first>'` would otherwise index out of range on the 1-row buffer. Multi-dataset configs are explicitly truncated (with a warning) so the name map and buffer agree.
- New `_stacked_num_datasets()` helper reads the leading dim of any attached Normalize/Unnormalize buffer. `_resolve_dataset_index`'s single-dataset fallback now uses that as the source of truth instead of `len(self._dataset_name_to_index)`, so a direct `PI05Policy(config, per_dataset_stats=[s1, s2])` call (bypassing the factory's `cfg.dataset_names` population) raises the multi-dataset KeyError rather than silently normalizing every sample against `s1`.

Resume / strict-flag fixes:
- `train.py` now inspects the on-disk safetensors header to decide `strict=False`, not the *current* `cfg.policy.save_normalization_stats`. Toggling the flag between an initial run and a resume no longer hard-crashes. Strict is tightened to exactly the case where the on-disk file actually lacks norm-buffer keys, so an unrelated module rename in any other code path still raises (it would have been silently swallowed by the previous broad `strict=False`).

Indentation drift:
- `pi05_mem`, `pi06`, `pi07/low_level` had the legacy-buffer-promotion call's first comment line at 16 spaces (inside `if remap_count > 0`) while the call itself was at 12 (outside). All three are now uniformly outside the `if`, with an explicit "Always run" note so a future maintainer can't "fix" the misalignment by hoisting the call into the `if` block and silently skip promotion on no-rename loads.
Thanks again — all three correctness bugs, both resume edge cases, and the indentation drift are addressed in Correctness
Resume
Style
Ready for another pass when you have time. |
shuheng-liu
left a comment
Third pass after e8fa45d. All six findings from round 2 are addressed cleanly:
| Round-2 finding | Fix in e8fa45d |
|---|---|
| `_infer_batch_size_and_device` device-from-`dataset_index` bug (`pretrained.py:331`) | New `_NON_DATA_BATCH_KEYS` skips the key during iteration; tensor branch now uses `_model_device()` (reads from `self.parameters()`) rather than dict-iteration order. Source of truth is now unambiguous. |
| Multi-dataset `value` silent corruption (`value/modeling_value.py:160`) | `ValueFunction.__init__` now truncates `cfg.dataset_names` and rebuilds `_dataset_name_to_index` to match the 1-row buffer. The mismatch can no longer occur. |
| `num_datasets` derived from `_dataset_name_to_index` (`pretrained.py:339`) | New `_stacked_num_datasets()` reads the buffer's actual leading dim. Direct-construction `PI05Policy(config, per_dataset_stats=[s1, s2])` without `cfg.dataset_names` now correctly raises `KeyError` instead of silently routing every sample to row 0. |
| Resume `strict=False` too broad (`train.py:582`) | `safe_open(.../model.safetensors).keys()` is inspected at resume time; `strict=False` only applies when norm-buffer keys are actually absent from disk. Unrelated missing keys still raise. |
| Asymmetric trained=False/resume=True (`train.py:579`) | Same on-disk-header fix: the source of truth is the file, not the current cfg flag. |
| Indent drift in legacy-promotion comment block | Fixed in `pi05_mem`, `pi06`, `pi07/low_level`. (My round-2 finding overstated the scope: `pi05`, `pi07/high_level_planner`, and both `pi07_paligemma/*` already had correct indentation in f8c7492; no false-fix risk in those files.) |
No new blockers and no new correctness bugs from this commit. A handful of small items that I'd leave as cleanups for either this PR or a follow-up:
- `pretrained.py:619`: `device_arg = map_location if map_location != "cpu" else "cpu"` is still a dead ternary (both branches evaluate to `map_location`). Mentioned in round 2 as a nit; just `device_arg = map_location` would do.
- `tests/policies/test_value.py:155`: still constructs with `ValueFunction(config, dataset_stats=...)` (legacy single-dict). Works today (the new `__init__` keeps the alias), but when the legacy alias is eventually retired the test regresses silently.
- `pretrained.py:559` (`_promote_legacy_norm_buffers_in_state_dict`): still only handles N=1 promotion. Migrating a legacy single-dataset checkpoint into a fresh N>1-dataset policy raises a shape mismatch with no message pointing the user at `_inject_stats`. Edge case; defer to a follow-up.
- `pretrained.py:289` (`_stacked_num_datasets`): returns the first buffer's leading dim and assumes every other Normalize module agrees. True today by construction (factory builds them together with the same `num_datasets`), but worth a one-line assert or a comment noting the invariant for future maintainers.
- `scripts/train.py:589` (`except Exception`): the `except` around `safe_open` is too broad; specific catches (`FileNotFoundError`, `safetensors.SafetensorError`) would surface unexpected I/O issues instead of swallowing them into the cfg-based fallback. Minor.
- No tests for `_stacked_num_datasets`, `_model_device`, the ValueFunction config sync, or the resume header inspection. Each is a small behavior change that could regress silently. A unit test per method (e.g. `PI05Policy(config, per_dataset_stats=[s1, s2])` then `_resolve_dataset_index({'state': x})` should raise; `policy._stacked_num_datasets() == 2`) would catch future regressions cheaply.
Aside from the test-coverage point, none of those gate the merge. The five hard blockers from round 1 and the three correctness bugs from round 2 are all genuinely fixed. Once the open items from the PR's own "Not yet covered" list (smoke training, end-to-end skip-stats round-trip, seeded determinism, DDP/ZeRO-2) clear, this is ready.
Generated by Claude Code
PR #336 (per-dataset normalization) renamed `Normalize(stats=...)` to `Normalize(per_dataset_stats=[...])` (a list, with `num_datasets=` as the inf-init path). PR #337 was authored against the pre-#336 API and merged first; #336 then merged without updating these tests, so CPU CI now fails on every PR with:

ValueError: create_stats_buffers requires either `per_dataset_stats` or `num_datasets`
TypeError: Normalize.__init__() got an unexpected keyword argument 'stats'

Three failures in tests/policies/test_pi07_paligemma_low_level.py::TestSkipNormalizationWeights:
- test_predicate_matches_real_normalize_buffer_keys
- test_find_inf_normalize_buffers_detects_init_inf_sentinels
- test_find_inf_normalize_buffers_empty_when_stats_passed

Updates:
- inf-init paths now pass `num_datasets=1` (replaces the implicit single-dataset default).
- the populated-stats path passes `per_dataset_stats=[stats]` (wrap the single stat dict in a singleton list to match the new D-row layout).

Verified locally: all 7 tests in TestSkipNormalizationWeights pass.
What this does
Tears down the mixture-level stats aggregation in `DatasetMixtureMetadata` and routes per-dataset stats into stacked Normalize/Unnormalize buffers indexed per sample. Adds a configurable flag to omit those buffers from checkpoint safetensors.
Per-dataset normalization. `DatasetMixtureMetadata` no longer calls `aggregate_stats(...)` across datasets. Instead it exposes `per_dataset_stats: list[dict]` and `dataset_names: list[str]`. A new `_TaggedDataset` wrapper around every underlying `BaseDataset` injects `dataset_repo_id: str` and `dataset_index: torch.long` into each sample (default collate batches them into `list[str]` and `(B,)` long respectively).

`Normalize`/`Unnormalize` buffers are now shaped `(num_datasets, *feat_shape)`; `forward(batch, dataset_index)` gathers the right row per sample via `index_select` and broadcasts over any extra temporal/spatial axes. Each of the eight in-scope pi policies (`pi0`, `pi05`, `pi05_mem`, `pi06`, `pi07/low_level`, `pi07/high_level_planner`, `pi07_paligemma/low_level`, `pi07_paligemma/high_level_planner`) threads `dataset_index = self._resolve_dataset_index(batch)` through every `forward` / `select_action` / `sample_actions`. The `value` policy is out of scope; it keeps its single-dict `dataset_stats` external API but wraps internally into a singleton list before calling the new `Normalize`.

Inference-time callers pass `batch["dataset_repo_id"]` (str or `list[str]`); the base `PreTrainedPolicy._resolve_dataset_index` helper maps strings through `config.dataset_names` to an integer index. Single-dataset configs (`num_datasets <= 1`) default to all-zero indices when both batch keys are absent, so legacy single-dataset callers don't need to know about the per-sample plumbing. Multi-dataset policies still raise a clear `KeyError` when the caller forgets, so misuse fails loudly instead of silently.
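The resolve-then-gather flow can be illustrated without torch. The sketch below uses nested lists where the real modules use `(num_datasets, *feat_shape)` buffers and `index_select`; the function names and signatures are simplified assumptions, not the PR's API.

```python
def resolve_dataset_index(batch, dataset_names, batch_size):
    """Return one dataset row index per sample (simplified stand-in)."""
    if "dataset_index" in batch:
        return list(batch["dataset_index"])
    if "dataset_repo_id" in batch:
        repo_ids = batch["dataset_repo_id"]
        if isinstance(repo_ids, str):  # a single str applies to the whole batch
            repo_ids = [repo_ids] * batch_size
        name_to_index = {name: i for i, name in enumerate(dataset_names)}
        return [name_to_index[repo_id] for repo_id in repo_ids]
    if len(dataset_names) <= 1:
        return [0] * batch_size  # single-dataset fallback
    raise KeyError("multi-dataset policy needs `dataset_repo_id` or `dataset_index`")


def normalize_mean_std(samples, dataset_index, means, stds):
    """Normalize each sample with the mean/std row of its own dataset."""
    return [
        [(v - m) / s for v, m, s in zip(sample, means[idx], stds[idx])]
        for sample, idx in zip(samples, dataset_index)
    ]


dataset_names = ["robot_a", "robot_b"]
means = [[0.0, 0.0], [10.0, 10.0]]  # one row per dataset
stds = [[1.0, 1.0], [2.0, 2.0]]
batch = {
    "state": [[1.0, 2.0], [12.0, 14.0]],
    "dataset_repo_id": ["robot_a", "robot_b"],
}
index = resolve_dataset_index(batch, dataset_names, batch_size=2)
normalized = normalize_mean_std(batch["state"], index, means, stds)
```

The second sample is scaled with `robot_b`'s statistics only; with mixture-level aggregation both samples would have shared one mean and one standard deviation.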
Skip-stats safetensors flag. Two new fields on `PreTrainedConfig`: `save_normalization_stats: bool = True` and `dataset_names: list[str] | None = None`. `_save_pretrained` accepts an `include_norm_stats` override that wins over the config field; `accelerator.save_state` is gated via a `register_save_state_pre_hook` in `train.py` so the buffers are stripped from the on-disk safetensors in both code paths. Reloading a stats-less checkpoint requires `make_policy(..., ds_meta=...)` so `_inject_stats(...)` can repopulate; otherwise `_check_norm_stats_loaded` raises a clear error pointing the user at the right knob.
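A minimal sketch of the stripping step and the override precedence, operating on a plain dict instead of a real state dict. `filter_state_dict` is a hypothetical name, and the glob-based predicate is an assumed approximation of `is_norm_buffer_key`.

```python
import fnmatch


def is_norm_buffer_key(key):
    """Assumed stand-in for the repo's predicate; matches normalize_*.buffer_* keys."""
    return fnmatch.fnmatchcase(key, "*normalize_*.buffer_*")


def filter_state_dict(state_dict, config_flag=True, include_norm_stats=None):
    """Drop normalization buffers unless stats should be saved.

    The method-level `include_norm_stats` wins over the config field
    whenever it is not None.
    """
    keep_stats = config_flag if include_norm_stats is None else include_norm_stats
    if keep_stats:
        return dict(state_dict)
    return {k: v for k, v in state_dict.items() if not is_norm_buffer_key(k)}


state = {
    "model.linear.weight": [0.1, 0.2],
    "normalize_inputs.buffer_observation_state.mean": [0.0],
    "unnormalize_outputs.buffer_action.std": [1.0],
}
# Config says keep, the call-site override says strip: the override wins.
stripped = filter_state_dict(state, config_flag=True, include_norm_stats=False)
kept = filter_state_dict(state, config_flag=True)
```

Filtering at the write boundary leaves the in-memory buffers untouched, so training continues to normalize correctly while the on-disk file omits them.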
Other touches. `fit_fast_tokenizer.py` switches from the deleted `mixture.meta.stats["actions"]` to a new `DatasetMixtureMetadata.aggregated_action_stats()` helper that wraps `aggregate_stats(...)` on demand (BPE codec still needs one global range). `export_to_onnx.py` passes an explicit `dataset_index=0` to `unnormalize_outputs` since the ONNX tracer bypasses `_resolve_dataset_index`.

How it was tested
- `tests/policies/test_normalize_per_dataset.py` (6 tests): D=3-4 stacked buffers, per-row indexing on MEAN_STD and MIN_MAX, Unnormalize-inverts-Normalize round-trip, temporal-axis broadcasting, image `(C, 1, 1)` broadcasting, length-mismatch validation.
- `tests/policies/test_save_pretrained_skip_stats.py` (6 tests): save with stats keeps `normalize_*.buffer_*` keys, save without strips them, method-level `include_norm_stats` override wins over the config field, `safetensors.safe_open` round-trip, `is_norm_buffer_key` predicate.
- `tests/datasets/test_tagged_dataset.py` (5 tests): `_TaggedDataset.__getitem__` injects the two keys, preserves `.meta`, default collate produces `list[str]` and `(B,)` long tensor.
- `tests/datasets/test_dataset_mixture.py` to assert on `per_dataset_stats` / `dataset_names` instead of the deleted `stats`, to allow `mixture.datasets` to be the wrapped list, and to strip the wrapper-injected keys before the `test_integration_basic_functionality_with_same_fps_as_dataset` per-sample equality check.
- `tests/policies/test_policies.py::test_normalize` to thread `dataset_index` and wrap singleton stats.
- `test_pi05.py`, `test_pi05_mem_gpu.py`, `test_pi06.py`, `test_pi07_high_level_planner.py`, `test_pi07_low_level.py`, `test_pi07_paligemma_high_level_planner.py`, `test_pi07_paligemma_low_level.py`) to use the new `per_dataset_stats=[...]` constructor signature.
- unrelated `robosuite` import error in `tests/utils/test_libero_utils.py`.
- `pytest -m "gpu" -n 0`): 19 passed, 10 skipped (unrelated), 1359 deselected. Run on an internal GPU dev box (RTX 5090, CUDA 13.0, ~5 min). Two iterations were needed to land; the first uncovered stale `dataset_stats=...` kwargs in seven test fixtures and one `Normalize` submodule call that bypassed the policy-level `_resolve_dataset_index` helper.

Not yet covered (tracked in #335)
The following still need attention beyond a single-GPU pytest pass:
- `configs/dev/dev_config.json` or the pi05 smoke config to confirm `forward` doesn't trip the new inf-assertion at step 1 and the per-dataset val dataloader iterates clean.
- `save_normalization_stats=false` round-trip on a real checkpoint dir, including the `safe_open(...).keys()` filter check and `make_policy(..., ds_meta=...)` reload.
- `seed=0`, diff the per-step loss series, assert bit-identical.
- `index_select` is a local op with no new collectives, so the risk is low, but worth confirming on real backends.
- `regression_test.yml` on g6.12xlarge).

These checks are tracked in #335.
How to checkout & try? (for the reviewer)
gh pr checkout 336
uv sync --extra dev --extra libero
pre-commit run --all-files
pytest -m "not gpu" -n auto tests/policies/test_normalize_per_dataset.py tests/policies/test_save_pretrained_skip_stats.py tests/datasets/test_tagged_dataset.py tests/datasets/test_dataset_mixture.py tests/policies/test_policies.py

On a CUDA box (matches the run already done on the dev box):

pytest -m "gpu" -n 0

To exercise the safetensors round-trip on a real training run (still tracked in #335):

opentau-train \
  --accelerate-config configs/examples/accelerate_ddp_config.yaml \
  --config_path=configs/dev/dev_config.json \
  --steps=2 --save_freq=2 \
  --policy.save_normalization_stats=false

Then inspect the safetensors with `from safetensors import safe_open; safe_open(...).keys()` to confirm `normalize_*.buffer_*` are absent.

Checklist
Note: Before submitting this PR, please read the contributor guideline.