fix(pi07): mask padded history; preserve current state token #253
Two padded-history bugs in the pi07 low-level path, mirroring fixes from PR #205 (which targets the parallel pi07_paligemma path).

Bug A: SpaceTimeSigLIP temporal attention did not mask padded frames. Pixel-zeroing alone is insufficient: the SigLIP patch embedding bias and the temporal positional embedding e(t) for t < T-1 are non-zero, so padded "zero" frames produced non-zero hidden states that the current frame attended to. Thread `obs_history_is_pad` through `embed_video` and `SpaceTimeSiglipVideoEncoder.forward`; a new `_build_temporal_attn_mask` helper produces a (B*N, 1, T, T) additive float mask combining the causal lower-triangular mask with key-side `~obs_history_is_pad`. Inference fallback: when the mask is absent and T > 1, treat all history as padded so the current frame is uncontaminated by zero placeholders from `_build_history_batch`.

Bug B: `state_mask` masked the current state token whenever `obs_history_is_pad[:, -1]` was True. The dataset's `history_state_drop_prob` augmentation flips the entire `obs_history_is_pad` tensor to all-True while zeroing the state, so the policy was being conditioned on no state at all. Add an unconditional `state_mask[:, -1] = True` override after both branches; switch the absent-mask branch from `ones` to `zeros` so the encoder cannot attend to garbage history slots without a real signal.

Tests:
- TestTemporalAttentionMaskBlocksPaddedHistory: current-frame output is invariant to padded-history pixel values; inference fallback works with obs_history_is_pad=None; mask shape/values pinned.
- TestSubgoalSharesVideoEncoderWeights: locks in that the encoder holds the SigLIP vision tower by reference (one set of weights), and that embed_image(img) and encoder(img.unsqueeze(1)) produce byte-equivalent outputs at T=1.
- TestStateMaskCurrentStepAlwaysReal: current state token stays real when obs_history_is_pad is all-True or None; helper does not mutate the caller's tensor (also threaded into embed_video).
- TestPrefixLayoutInferenceMatchesTraining: minimal inference batch collapses prefix; training/inference prefixes agree on every position except the trailing "Action: " + discrete-action span.

GPU regression tests + nightly determinism check still TODO.
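The mask construction described for Bug A can be illustrated with a minimal sketch. It is written in plain Python for a single sequence so it runs without PyTorch; the real `_build_temporal_attn_mask` operates on tensors of shape (B*N, 1, T, T). The function names, the `NEG` fill value, and the list-based layout are assumptions made for illustration and are not taken from the PR.

```python
import math

# Assumed fill value for masked cells; the helper's actual value is not shown in the PR text.
NEG = -1e9


def build_temporal_attn_mask(num_frames, is_pad=None):
    """Return a T x T additive mask (rows = query frame, columns = key frame).

    A cell is 0.0 when the key is visible and NEG otherwise. A key is visible
    when it is not in the future (causal) and is not a padded frame.
    """
    if is_pad is None:
        # Inference fallback: treat all history as padded, keep the current frame.
        is_pad = [True] * (num_frames - 1) + [False]
    key_valid = [not p for p in is_pad]  # fresh list, the caller's flags are untouched
    key_valid[-1] = True  # the current frame is always a valid key
    return [
        [0.0 if (k <= q and key_valid[k]) else NEG for k in range(num_frames)]
        for q in range(num_frames)
    ]


def attention_weights(scores, mask_row):
    """Softmax over one query row after adding the mask to the raw scores."""
    logits = [s + m for s, m in zip(scores, mask_row)]
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Three frames, the oldest one is padding. Only the last (current) query row matters here.
mask = build_temporal_attn_mask(3, [True, False, False])
weights = attention_weights([0.3, 1.2, -0.5], mask[-1])
print(weights)  # the first entry is the weight placed on the padded frame
```

With the padded key masked, the current frame places no weight on it regardless of what hidden state the zeroed pixels produced, which is the invariance the new tests pin. Rows belonging to padded query frames end up fully masked in this sketch; their outputs are not used for the current frame.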
[claude-review] summary for commit 57716e4
Re-review of 57716e4 (follow-up to 9e702dd review). All four prior findings are resolved.
No blocking issues found.
Follow-up to the review on PR #253.

Main change: emit obs_history_is_pad from _build_history_batch. Previously the inference path never threaded the mask through select_action -> sample_actions -> embed_prefix, so the encoder fell back to "all history padded except current" via embed_prefix's zeros + state_mask[:, -1] = True branch and the encoder's matching None-fallback. That fixed the start-of-episode contamination but silently masked genuine mid-episode history once the buffer filled. _build_history_batch already owns the idx < 0 decision for each slot; emitting obs_history_is_pad = (i*interval - missing < 0) for i in range(n_hist) costs nothing and lets the encoder use real mid-episode history while still masking only the truly padded start-of-episode slots.

Threaded through PI07LowLevelPolicy.sample_actions and PI07LowLevelFlowMatching.sample_actions to reach embed_prefix with the live mask. The encoder's None-fallback and embed_prefix's zeros-then-current branch are both kept as defensive defaults: anything that bypasses _build_history_batch (direct sample_actions calls in tests, etc.) still gets a safe default that matches training-time augmentations.

Nits from review:
- Removed the redundant .clone() in embed_prefix's state_mask path and in _build_temporal_attn_mask's key_valid path. Both ~obs_history_is_pad expressions already allocate fresh tensors, so subsequent [:, -1] = True writes can't reach the caller's obs_history_is_pad. Kept the regression tests that pinned the no-mutation property.
- Rewrote the SpaceTimeEncoderLayerWrapper.forward docstring to drop the misleading "ignored by vanilla layers via SiglipEncoder's positional dispatch" claim. The encoder now branches on isinstance and never calls SiglipEncoder.forward, so vanilla layers simply aren't passed temporal_attn_mask.

Skipped the repeat_interleave perf nit: the suggested expand+reshape materializes anyway because SDPA flattens (B, N) into one dim. The real fix is a deeper refactor of the temporal Q/K/V layout; deferred.

New tests in TestBuildHistoryBatchEmitsObsHistoryIsPad (test_pi07_cpu.py) pin the (B, T) mask shape, the first-step all-but-current pattern, the buffer-full all-False pattern, the partial-fill mid-episode pattern, batch broadcasting, and slot-by-slot agreement with the actual zero-fill of state and camera tensors.
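As a rough illustration of the per-slot rule and of the state-mask default it feeds, here is a plain-Python sketch. The expression `i * interval - missing < 0` is quoted from the comment above. The function names, the list-based representation, and the reading of `missing` as the number of timesteps not yet in the buffer are assumptions, not code from the PR.

```python
def history_pad_pattern(n_hist, interval, missing):
    """Per-slot padding flags: True means the slot was zero-filled.

    `missing` is assumed to be how many timesteps the history buffer still
    lacks; slot i is padding when its source index would be negative.
    """
    return [i * interval - missing < 0 for i in range(n_hist)]


def state_mask_from_pad(n_hist, is_pad=None):
    """True marks a state token the model may attend to."""
    if is_pad is None:
        mask = [False] * n_hist  # absent mask: trust no history slot
    else:
        mask = [not p for p in is_pad]  # fresh list, the caller's flags are untouched
    mask[-1] = True  # the current state token is always real
    return mask


# Four history slots sampled every 2 steps.
first_step = history_pad_pattern(4, 2, missing=6)  # only the current slot is real
mid_episode = history_pad_pattern(4, 2, missing=3)  # the two oldest slots are padding
buffer_full = history_pad_pattern(4, 2, missing=0)  # nothing is padding

for pattern in (first_step, mid_episode, buffer_full):
    print(pattern, state_mask_from_pad(4, pattern))
```

The design point this mirrors is that the function which zero-fills a slot is also the one that reports it as padded, so the mask cannot drift from the data, while the `None` branch keeps callers that skip the history builder on the conservative default.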
Re-review of 57716e4. Prior suggestion (thread obs_history_is_pad from _build_history_batch → sample_actions → embed_prefix) and three nits all addressed. Verified the new mask agrees with the per-slot zero-fill decision in the loops above and broadcasts correctly. No new blocking issues.
Pushed Suggestion — inference-time mid-episode masking
pad_pattern = torch.tensor(
[i * interval - missing < 0 for i in range(n_hist)],
dtype=torch.bool,
device=device,
)
temporal_batch["obs_history_is_pad"] = pad_pattern.unsqueeze(0).expand(bsize, n_hist)
What this does
Ports two padded-history bug fixes from #205 (which targets `pi07_paligemma`) into the parallel `pi07` (Gemma 3 + SpaceTime SigLIP) low-level path. Audited the rest of #205's "Issues addressed" list against `pi07/`; the other 7 items are already addressed or not applicable in `pi07/`.

Bug A: SpaceTimeSigLIP temporal attention did not mask padded history frames. Pixel-zeroing alone is insufficient because the SigLIP patch embedding has a learned bias and the temporal positional embedding `e(t)` is non-zero for `t < T-1`, so padded "zero" frames produced non-zero hidden states that the current frame attended to. Threaded `obs_history_is_pad` through `embed_video` → `SpaceTimeSiglipVideoEncoder.forward`; a new `_build_temporal_attn_mask` helper produces a `(B*N, 1, T, T)` additive float mask combining the causal lower-triangular mask with key-side `~obs_history_is_pad`. Inference fallback: when the mask is absent and `T > 1`, treats all history as padded so the current frame is uncontaminated by zero placeholders from `_build_history_batch`.

Bug B: `state_mask` masked the current state token whenever `obs_history_is_pad[:, -1]` was `True`. The dataset's `history_state_drop_prob` augmentation (`src/opentau/datasets/lerobot_dataset.py:891-901`) flips the entire `obs_history_is_pad` tensor to all-True while zeroing the state, so the policy was being conditioned on no state at all. Added an unconditional `state_mask[:, -1] = True` override after both branches; switched the absent-mask branch from `ones` to `zeros` so the encoder cannot attend to garbage history slots without a real signal.

Audit and per-criterion status (the four explicit acceptance criteria from the planning thread):
- `SpaceTimeSiglipVideoEncoder` holds `vision_tower` and `multi_modal_projector` by reference (in lists, so `nn.Module.__setattr__` does not re-register them); `Gemma3WithExpertModel.embed_image` uses `suppress_spacetime_temporal` on the same wrapped tower. Locked in by new `TestSubgoalSharesVideoEncoderWeights`.
- `prepare_*` methods + `embed_prefix` gating (via `.any()` checks at lines 1232-1241) on both paths. Locked in by new `TestPrefixLayoutInferenceMatchesTraining`.
- `pi07_paligemma` is not touched; #205 will fix it on merge.

How it was tested
New regression tests:
- `tests/policies/test_pi07_video_encoder_cpu.py`:
  - `TestTemporalAttentionMaskBlocksPaddedHistory`: current-frame output is invariant to padded-history pixel values; inference fallback works with `obs_history_is_pad=None`; mask shape and per-cell values pinned; helper does not mutate the caller's tensor.
  - `TestSubgoalSharesVideoEncoderWeights`: encoder owns no parameters of its own (`vision_tower`/`multi_modal_projector` held by reference); `embed_image(img)` and `encoder(img.unsqueeze(1))` produce byte-equivalent output at T=1; only one copy of each SigLIP attention weight in the wrapped vision_tower's state_dict.
- `tests/policies/test_pi07_cpu.py`:
  - `TestStateMaskCurrentStepAlwaysReal`: current state token stays real when `obs_history_is_pad` is all-True or `None`; partial pad pattern still preserves current step; helper does not mutate the caller's `obs_history_is_pad` tensor (also threaded into `embed_video`).
  - `TestPrefixLayoutInferenceMatchesTraining`: minimal inference batch collapses prefix to `[videos | lang | "State: " | state | ":\n"]`; training and inference prefixes agree on every position except the trailing `"Action: " + discrete-action` span.

Verification run:
- `pre-commit run --all-files`: all hooks pass.
- `pytest -m "not gpu" -n auto tests/policies/`: 185 passed, 2 skipped (CPU subset).
- `pytest -m "gpu" -n 0 tests/policies/test_pi07_low_level.py tests/policies/test_pi07_high_level_planner.py`: 7/7 pass on a CUDA host (`test_complete_pi07_low_level_pipeline`, `test_no_optionals_path_on_real_gemma3`, the four `TestPI07LowLevelRegression` cases, plus the high-level pipeline integration).
- `PI07LowLevelPolicy.forward` run twice with the same `torch.manual_seed`/CUDA seed and pre-sampled noise+time; MSE and CE losses are bit-identical (`torch.equal`, not just `allclose`). Per CLAUDE.md hard rule for `modeling_*.py` touches.

How to checkout & try? (for the reviewer)
To see the bug-A regression directly (current-frame invariance under padded history):
Checklist
Note: Before submitting this PR, please read the contributor guideline.