
fix(pi05_mem): mask padded history; preserve current state token; emit at inference #254

Merged
shuheng-liu merged 1 commit into main from claude/pi05-mem-padded-history
May 5, 2026

@shuheng-liu shuheng-liu commented May 4, 2026

What this does

Ports the same three padded-history fixes recently applied to pi07 (#253) and to pi07_paligemma (#205) into pi05_mem — the original SpaceTime SigLIP home, where all three bugs originated. Audited against the same five-criterion list; only the bug fixes applied.

Bug A — SpaceTimeSigLIP temporal attention did not mask padded history frames. Pixel-zeroing alone is insufficient because the SigLIP patch embedding has a learned bias and the temporal positional embedding e(t) is non-zero for t < T-1, so padded "zero" frames produced non-zero hidden states that the current frame attended to. Threaded obs_history_is_pad through embed_video → SpaceTimeSiglipVideoEncoder.forward; a new _build_temporal_attn_mask helper produces a (B*N, 1, T, T) additive float mask that combines a causal lower-triangular mask with a key-side ~obs_history_is_pad mask. Inference fallback: when the mask is absent and T > 1, all history is treated as padded so the current frame is uncontaminated by zero placeholders from _build_history_batch.
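
A minimal sketch of the mask construction described for Bug A, assuming PyTorch. The function name, argument names, and the use of -inf for blocked cells are illustrative assumptions, not the code merged in this PR; only the output shape, the causal-and-key-padding combination, and the current-frame override come from the description.

```python
import torch


def build_temporal_attn_mask(
    obs_history_is_pad: torch.Tensor,  # (B, T) bool, True = padded slot
    num_patches: int,                  # N, spatial patches per frame (assumed name)
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Return a (B*N, 1, T, T) additive mask: 0.0 = attend, -inf = blocked."""
    _, t = obs_history_is_pad.shape
    # `~` allocates a new tensor, so the override below never writes to the caller's mask.
    key_valid = ~obs_history_is_pad
    key_valid[:, -1] = True  # the current frame is always a valid key
    causal = torch.tril(torch.ones(t, t, dtype=torch.bool))
    # allowed[b, q, k] is True when key k is not in q's future and is not padded.
    allowed = causal.unsqueeze(0) & key_valid.unsqueeze(1)
    mask = torch.zeros(allowed.shape, dtype=dtype).masked_fill(~allowed, float("-inf"))
    # Add a broadcastable head dimension, then repeat each batch row once per patch.
    return mask.unsqueeze(1).repeat_interleave(num_patches, dim=0)


is_pad = torch.tensor([[True, True, False, False]])
mask = build_temporal_attn_mask(is_pad, num_patches=2)
all_pad = build_temporal_attn_mask(torch.ones(1, 4, dtype=torch.bool), num_patches=1)
```

In this sketch the last query row of `mask` blocks the two padded keys and leaves the two real ones open, while `all_pad` still lets the current frame attend to itself.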

Bug B — state_mask masked the current state token whenever obs_history_is_pad[:, -1] = True. The dataset's history_state_drop_prob augmentation flips the entire obs_history_is_pad tensor to all-True while zeroing the state, so the policy was being conditioned on no state at all. Added an unconditional state_mask[:, -1] = True override after both branches; switched the absent-mask branch from ones to zeros so the encoder cannot attend to garbage history slots without a real signal.
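
The Bug B override can be sketched as follows, assuming PyTorch and a boolean state_mask in which True marks a real token. The helper name and signature are hypothetical; the PR only states that the absent-mask branch now starts from zeros and that state_mask[:, -1] = True is applied after both branches.

```python
from typing import Optional

import torch


def build_state_mask(
    obs_history_is_pad: Optional[torch.Tensor],  # (B, n_hist) bool or None
    bsize: int,
    n_hist: int,
) -> torch.Tensor:
    """Return a (B, n_hist) bool mask where True marks a real state token."""
    if obs_history_is_pad is None:
        # Absent mask: treat every history slot as untrusted (zeros, not ones).
        state_mask = torch.zeros(bsize, n_hist, dtype=torch.bool)
    else:
        # `~` returns a new tensor, so the caller's mask is left untouched.
        state_mask = ~obs_history_is_pad
    # Unconditional override: the current state token is always real.
    state_mask[:, -1] = True
    return state_mask


dropped = torch.ones(2, 4, dtype=torch.bool)  # all-True, as after state-drop augmentation
from_dropped = build_state_mask(dropped, bsize=2, n_hist=4)
from_none = build_state_mask(None, bsize=2, n_hist=4)
```

Both `from_dropped` and `from_none` keep only the last slot, so the policy always sees the current state even when the whole history is marked padded.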

Mid-episode fix — _build_history_batch now emits obs_history_is_pad based on the same idx < 0 decision the state/camera loops already make. Threaded through PI05MemPolicy.sample_actions → PI05MemFlowMatching.sample_actions → embed_prefix. Without this emit, the encoder's None-fallback masks ALL history at inference (even genuine mid-episode frames once the buffer fills); with it, only the truly padded start-of-episode slots get masked. The None-fallback and embed_prefix's zeros-then-current branch are kept as defensive defaults for callers that bypass _build_history_batch.
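
The idx < 0 rule can be illustrated with a small pure-Python sketch. The function, its parameters, and the way `missing` is computed are assumptions made for illustration; the PR only states that a slot is padded when its buffer index is negative and that the first step with four slots yields [T, T, T, F].

```python
from typing import List


def history_pad_pattern(buffer_len: int, n_hist: int, interval: int = 1) -> List[bool]:
    """Return per-slot padding flags (True = padded), oldest slot first.

    buffer_len counts the observations seen so far, including the current one.
    """
    needed = (n_hist - 1) * interval + 1    # observations spanned by a full window
    missing = max(0, needed - buffer_len)   # observations not yet available
    # Slot i reads buffer index i * interval - missing; a negative index means
    # the observation does not exist yet, so the slot is zero-filled and padded.
    return [i * interval - missing < 0 for i in range(n_hist)]


first_step = history_pad_pattern(buffer_len=1, n_hist=4)
mid_episode = history_pad_pattern(buffer_len=2, n_hist=4)
buffer_full = history_pad_pattern(buffer_len=4, n_hist=4)
```

Under these assumptions `first_step` is the [T, T, T, F] pattern, `mid_episode` has two padded slots, and `buffer_full` has none, so only start-of-episode slots are masked.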

Audit of the other two criteria from the original review:

  1. Subgoal images encoded with the same video encoder weights — N/A. pi05_mem has no subgoal concept (no prepare_subgoal*, no subgoal in embed_prefix). The "no second set of weights" property is already satisfied by modeling_pi05.py:801-803 constructing the encoder with vision_tower=self.paligemma_with_expert.paligemma.vision_tower held by reference.
  2. Inference prompt construction matches training — already satisfied. embed_prefix is just videos + lang + state + (training-only: discrete_actions). No optional response/subgoal/metadata blocks → no conditional layout to keep in sync. Training and inference are byte-equivalent on the shared prefix already.

How it was tested

New regression tests in tests/policies/test_pi05_mem.py:

  • TestTemporalAttentionMaskBlocksPaddedHistory — current-frame output is invariant to padded-history pixel values; inference fallback works with obs_history_is_pad=None; mask shape and per-cell values pinned; current-frame override defends the all-padded edge case; helper does not mutate the caller's tensor.
  • TestStateMaskCurrentStepAlwaysReal — current state token stays real when obs_history_is_pad is all-True or None or partial-pad; helper does not mutate the caller's obs_history_is_pad.
  • TestBuildHistoryBatchEmitsObsHistoryIsPad — pins shape, the first-step [T, T, T, F] pattern, the buffer-full all-False pattern, the partial-fill mid-episode pattern, batch broadcasting, and slot-by-slot agreement with the actual zero-fill of state and camera tensors.

Verification run:

  • pre-commit run --all-files — all hooks pass.
  • pytest -m "not gpu" -n auto tests/policies/ — 184 passed, 2 skipped (CPU subset; the new pi05_mem tests land in this subset).
  • pytest -m "gpu" -n 0 tests/policies/test_pi05_mem.py tests/policies/test_pi05_mem_gpu.py — 6/6 pass on a CUDA host (forward shape/dtype, single-frame invariance, gradient freeze/unfreeze, causality smoke, end-to-end policy).
  • Determinism: ran a tiny PI05MemPolicy.forward twice with the same torch.manual_seed/CUDA seed and pre-sampled noise+time; MSE and CE losses are bit-identical (torch.equal, not just allclose). Per CLAUDE.md hard rule for modeling_*.py touches.
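
The determinism check can be sketched in the same spirit, assuming PyTorch. A small nn.Linear stands in for the tiny PI05MemPolicy, and `run_once` is a hypothetical helper; the point is only the comparison with torch.equal rather than allclose.

```python
import torch
from torch import nn


def run_once(seed: int) -> torch.Tensor:
    """Seed, build a stand-in model, and return one loss value."""
    torch.manual_seed(seed)            # seeds the default generators on all devices
    model = nn.Linear(8, 4)            # stand-in for a tiny policy (assumption)
    batch = torch.randn(2, 8)
    noise = torch.randn(2, 4)          # pre-sampled target, drawn under the same seed
    return nn.functional.mse_loss(model(batch), noise)


loss_a = run_once(0)
loss_b = run_once(0)
bit_identical = torch.equal(loss_a, loss_b)  # exact equality, not a tolerance check
```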

How to checkout & try? (for the reviewer)

git fetch origin claude/pi05-mem-padded-history && git checkout claude/pi05-mem-padded-history
pre-commit run --all-files
pytest -m "not gpu" -n auto tests/policies/test_pi05_mem.py
# GPU host:
pytest -m "gpu" -n 0 tests/policies/test_pi05_mem.py tests/policies/test_pi05_mem_gpu.py

To see the bug-A regression directly (current-frame invariance under padded history):

pytest -sx tests/policies/test_pi05_mem.py::TestTemporalAttentionMaskBlocksPaddedHistory::test_current_frame_invariant_to_padded_history

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

Note: Before submitting this PR, please read the contributor guideline.

…t at inference

Three padded-history bugs in pi05_mem (the original SpaceTime SigLIP
home), mirrors of fixes recently applied to pi07 (PR #253) and to
pi07_paligemma (PR #205).

Bug A — SpaceTimeSigLIP temporal attention did not mask padded frames.
Pixel-zeroing alone is insufficient: the SigLIP patch embedding bias
and the temporal positional embedding e(t) for t < T-1 are non-zero,
so padded "zero" frames produced non-zero hidden states the current
frame attended to. Threaded obs_history_is_pad through embed_video
and SpaceTimeSiglipVideoEncoder.forward; a new
_build_temporal_attn_mask helper produces a (B*N, 1, T, T) additive
float mask combining causal lower-triangular with key-side
~obs_history_is_pad. Inference fallback: when the mask is absent and
T > 1, treat all history as padded so the current frame is
uncontaminated by zero placeholders from _build_history_batch.

Bug B — state_mask masked the current state token whenever
obs_history_is_pad[:, -1] was True. The dataset's
history_state_drop_prob augmentation flips the entire
obs_history_is_pad tensor to all-True while zeroing the state, so
the policy was being conditioned on no state at all. Add an
unconditional state_mask[:, -1] = True override after both branches;
switch the absent-mask branch from ones to zeros so the encoder
cannot attend to garbage history slots without a real signal.

Mid-episode fix — _build_history_batch now emits obs_history_is_pad
based on the same idx < 0 decision the state/camera loops already
make. Threaded through PI05MemPolicy.sample_actions and
PI05MemFlowMatching.sample_actions to reach embed_prefix with the
live mask. Without this emit the encoder's None-fallback masks ALL
history at inference (even genuine mid-episode frames once the
buffer fills); with it, only the truly-padded start-of-episode slots
get masked.

Subgoal weight-sharing (criterion #3) and inference/training prompt
parity (criterion #4) audited as N/A: pi05_mem has no subgoal
concept, and embed_prefix has only video + lang + state +
(training-only: discrete_actions) — no optional middle blocks to
keep in sync, and the encoder already shares vision_tower weights
with paligemma_with_expert by reference (the only set of vision
weights that exists).

Tests:
- TestTemporalAttentionMaskBlocksPaddedHistory — current-frame
  invariance to padded-history pixel values; inference fallback;
  mask shape/values pinned; current-frame override defends the
  all-padded edge case; helper does not mutate caller tensor.
- TestStateMaskCurrentStepAlwaysReal — current state stays real when
  obs_history_is_pad is all-True or None or partial-pad; helper
  doesn't mutate obs_history_is_pad.
- TestBuildHistoryBatchEmitsObsHistoryIsPad — first-step
  [T,T,T,F] pattern, buffer-full all-False, partial-fill
  mid-episode, batch broadcast, slot-by-slot agreement with the
  zero-fill of state and camera tensors.
@shuheng-liu shuheng-liu added the bug Something isn't working label May 4, 2026
@shuheng-liu shuheng-liu self-assigned this May 4, 2026

@claude claude Bot left a comment

Reviewed the diff for correctness against the pi07 pattern this is porting from. Verified the mask-construction math, mutation safety of the broadcasted expand view, dataset emit-side conventions in lerobot_dataset.py, gradient-checkpointing positional-arg requirement, and slot-by-slot agreement between the emitted obs_history_is_pad and the actual zero-fill in _build_history_batch. No blocking findings.

claude Bot commented May 4, 2026

[claude-review] summary for commit cd0bd47

No blocking issues found.

Verified:

  • Mask construction (_build_temporal_attn_mask) mirrors the pi07 pattern: causal × ~obs_history_is_pad with key_valid[:, -1] = True override; shape (B*N, 1, T, T) via unsqueeze(1) + repeat_interleave(num_patches, dim=0).
  • state_mask[:, -1] = True correctly counteracts the dataset's history_state_drop_prob augmentation (lerobot_dataset.py:894/:901 flips obs_history_is_pad all-True), which would otherwise condition on no state at all.
  • Mutation safety: ~obs_history_is_pad materializes a fresh contiguous tensor even when the input is the expand(bsize, n_hist) broadcast view from _build_history_batch; subsequent [:, -1] = True writes do not reach the caller. Pinned by test_pad_tensor_not_mutated and test_state_mask_does_not_mutate_obs_history_is_pad.
  • Pad-pattern math: i*interval - missing < 0 reproduces the same idx < 0 decision the existing state/camera loops use, so the emitted mask agrees slot-by-slot with the actual zero-fill (pinned by test_state_and_camera_padding_match_emitted_mask).
  • T=1 short-circuit in SpaceTimeEncoderLayerWrapper.forward keeps the temporal sublayer bypassed; the encoder correctly skips mask construction in that case.
  • Gradient-checkpointing call site (video_encoder.py:559) passes temporal_attn_mask positionally, which is required because torch.utils.checkpoint.checkpoint(..., use_reentrant=False) does not forward kwargs.
  • NaN containment: padded query rows produce NaN at SDPA, but hidden[:, -1] is the only output — the current-frame query is masked to attend only to non-padded keys (with the [:, -1] = True override guaranteeing self-attention), so the returned tokens are uncontaminated. Backward only flows through hidden[:, -1]'s path.
  • Cross-policy consistency: embed_video, embed_prefix, sample_actions, and _build_history_batch signature changes match the pi07 port exactly.
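
Two of the points verified above, mutation safety and NaN containment, can be reproduced in isolation with a short sketch, assuming PyTorch. The tensors are toy values and the softmax is applied directly to mask values rather than through SDPA, so this illustrates the mechanism rather than the reviewed code path.

```python
import torch

# Mutation safety: `~` on an expand() view allocates new storage, so the
# in-place override does not write through to the broadcast source row.
row = torch.tensor([[True, True, False]])
view = row.expand(4, 3)          # broadcast view; all rows share the same memory
key_valid = ~view                # new tensor with its own storage
key_valid[:, -1] = True

# NaN containment: a query row whose keys are all blocked softmaxes to NaN,
# while a row with at least one open key stays finite.
ninf = float("-inf")
scores = torch.tensor([[ninf, ninf],    # padded query: every key blocked
                       [ninf, 0.0]])    # current query: its own key is open
weights = torch.softmax(scores, dim=-1)
```

Here `row` keeps its original values after the override, and only the first row of `weights` is NaN, which matches the claim that the current-frame row stays clean as long as it can attend to itself.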

@shuheng-liu shuheng-liu marked this pull request as ready for review May 4, 2026 19:46
@shuheng-liu shuheng-liu merged commit 8fffb25 into main May 5, 2026
8 checks passed
@shuheng-liu shuheng-liu deleted the claude/pi05-mem-padded-history branch May 5, 2026 06:06