fix(pi05_mem): mask padded history; preserve current state token; emit at inference#254
Merged
Merged
Conversation
…t at inference Three padded-history bugs in pi05_mem (the original SpaceTime SigLIP home), mirrors of fixes recently applied to pi07 (PR #253) and to pi07_paligemma (PR #205). Bug A — SpaceTimeSigLIP temporal attention did not mask padded frames. Pixel-zeroing alone is insufficient: the SigLIP patch embedding bias and the temporal positional embedding e(t) for t < T-1 are non-zero, so padded "zero" frames produced non-zero hidden states the current frame attended to. Threaded obs_history_is_pad through embed_video and SpaceTimeSiglipVideoEncoder.forward; a new _build_temporal_attn_mask helper produces a (B*N, 1, T, T) additive float mask combining causal lower-triangular with key-side ~obs_history_is_pad. Inference fallback: when the mask is absent and T > 1, treat all history as padded so the current frame is uncontaminated by zero placeholders from _build_history_batch. Bug B — state_mask masked the current state token whenever obs_history_is_pad[:, -1] was True. The dataset's history_state_drop_prob augmentation flips the entire obs_history_is_pad tensor to all-True while zeroing the state, so the policy was being conditioned on no state at all. Add an unconditional state_mask[:, -1] = True override after both branches; switch the absent-mask branch from ones to zeros so the encoder cannot attend to garbage history slots without a real signal. Mid-episode fix — _build_history_batch now emits obs_history_is_pad based on the same idx < 0 decision the state/camera loops already make. Threaded through PI05MemPolicy.sample_actions and PI05MemFlowMatching.sample_actions to reach embed_prefix with the live mask. Without this emit the encoder's None-fallback masks ALL history at inference (even genuine mid-episode frames once the buffer fills); with it, only the truly-padded start-of-episode slots get masked. Subgoal weight-sharing (criterion #3) and inference/training prompt parity (criterion #4) audited as N/A: pi05_mem has no subgoal concept, and embed_prefix has only video + lang + state + (training-only: discrete_actions) — no optional middle blocks to keep in sync, and the encoder already shares vision_tower weights with paligemma_with_expert by reference (the only set of vision weights that exists). Tests: - TestTemporalAttentionMaskBlocksPaddedHistory — current-frame invariance to padded-history pixel values; inference fallback; mask shape/values pinned; current-frame override defends the all-padded edge case; helper does not mutate caller tensor. - TestStateMaskCurrentStepAlwaysReal — current state stays real when obs_history_is_pad is all-True or None or partial-pad; helper doesn't mutate obs_history_is_pad. - TestBuildHistoryBatchEmitsObsHistoryIsPad — first-step [T,T,T,F] pattern, buffer-full all-False, partial-fill mid-episode, batch broadcast, slot-by-slot agreement with the zero-fill of state and camera tensors.
Contributor
There was a problem hiding this comment.
Reviewed the diff for correctness against the pi07 pattern this is porting from. Verified the mask-construction math, mutation safety of the broadcasted expand view, dataset emit-side conventions in lerobot_dataset.py, gradient-checkpointing positional-arg requirement, and slot-by-slot agreement between the emitted obs_history_is_pad and the actual zero-fill in _build_history_batch. No blocking findings.
Contributor
|
[claude-review] summary for commit cd0bd47 No blocking issues found. Verified:
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What this does
Ports the same three padded-history fixes recently applied to
pi07(#253) and topi07_paligemma(#205) intopi05_mem— the original SpaceTime SigLIP home, where all three bugs originated. Audited against the same five-criterion list; only the bug fixes applied.Bug A — SpaceTimeSigLIP temporal attention did not mask padded history frames. Pixel-zeroing alone is insufficient because the SigLIP patch embedding has a learned bias and the temporal positional embedding
e(t)is non-zero fort < T-1, so padded "zero" frames produced non-zero hidden states the current frame attended to. Threadedobs_history_is_padthroughembed_video→SpaceTimeSiglipVideoEncoder.forward; a new_build_temporal_attn_maskhelper produces a(B*N, 1, T, T)additive float mask combining causal lower-triangular with key-side~obs_history_is_pad. Inference fallback when the mask is absent andT > 1: treats all history as padded so the current frame is uncontaminated by zero placeholders from_build_history_batch.Bug B —
state_maskmasked the current state token wheneverobs_history_is_pad[:, -1] = True. The dataset'shistory_state_drop_probaugmentation flips the entireobs_history_is_padtensor to all-True while zeroing the state, so the policy was being conditioned on no state at all. Added an unconditionalstate_mask[:, -1] = Trueoverride after both branches; switched the absent-mask branch fromonestozerosso the encoder cannot attend to garbage history slots without a real signal.Mid-episode fix —
_build_history_batchnow emitsobs_history_is_padbased on the sameidx < 0decision the state/camera loops already make. Threaded throughPI05MemPolicy.sample_actions→PI05MemFlowMatching.sample_actions→embed_prefix. Without this emit, the encoder'sNone-fallback masks ALL history at inference (even genuine mid-episode frames once the buffer fills); with it, only the truly-padded start-of-episode slots get masked. TheNone-fallback andembed_prefix's zeros-then-current branch are kept as defensive defaults for callers that bypass_build_history_batch.Audit of the other two criteria from the original review:
pi05_memhas no subgoal concept (noprepare_subgoal*, no subgoal inembed_prefix). The "no second set of weights" property is already satisfied by modeling_pi05.py:801-803 constructing the encoder withvision_tower=self.paligemma_with_expert.paligemma.vision_towerheld by reference.embed_prefixis justvideos + lang + state + (training-only: discrete_actions). No optional response/subgoal/metadata blocks → no conditional layout to keep in sync. Training and inference are byte-equivalent on the shared prefix already.How it was tested
New regression tests in
tests/policies/test_pi05_mem.py:TestTemporalAttentionMaskBlocksPaddedHistory— current-frame output is invariant to padded-history pixel values; inference fallback works withobs_history_is_pad=None; mask shape and per-cell values pinned; current-frame override defends the all-padded edge case; helper does not mutate the caller's tensor.TestStateMaskCurrentStepAlwaysReal— current state token stays real whenobs_history_is_padis all-True orNoneor partial-pad; helper does not mutate the caller'sobs_history_is_pad.TestBuildHistoryBatchEmitsObsHistoryIsPad— pins shape, the first-step[T, T, T, F]pattern, the buffer-full all-False pattern, the partial-fill mid-episode pattern, batch broadcasting, and slot-by-slot agreement with the actual zero-fill of state and camera tensors.Verification run:
pre-commit run --all-files— all hooks pass.pytest -m "not gpu" -n auto tests/policies/— 184 passed, 2 skipped (CPU subset; the new pi05_mem tests land in this subset).pytest -m "gpu" -n 0 tests/policies/test_pi05_mem.py tests/policies/test_pi05_mem_gpu.py— 6/6 pass on a CUDA host (forward shape/dtype, single-frame invariance, gradient freeze/unfreeze, causality smoke, end-to-end policy).PI05MemPolicy.forwardtwice with the sametorch.manual_seed/CUDA seed and pre-sampled noise+time; MSE and CE losses are bit-identical (torch.equal, not justallclose). Per CLAUDE.md hard rule formodeling_*.pytouches.How to checkout & try? (for the reviewer)
To see the bug-A regression directly (current-frame invariance under padded history):
Checklist
Note: Before submitting this PR, please read the contributor guideline.