Skip to content

fix(pi07): mask padded history; preserve current state token#253

Merged
shuheng-liu merged 2 commits into
mainfrom
claude/vibrant-chandrasekhar-e9b8ef
May 4, 2026
Merged

fix(pi07): mask padded history; preserve current state token#253
shuheng-liu merged 2 commits into
mainfrom
claude/vibrant-chandrasekhar-e9b8ef

Conversation

@shuheng-liu
Copy link
Copy Markdown
Member

What this does

Ports two padded-history bug fixes from #205 (which targets pi07_paligemma) into the parallel pi07 (Gemma 3 + SpaceTime SigLIP) low-level path. Audited the rest of #205's "Issues addressed" list against pi07/; the other 7 items are already addressed or not applicable in pi07/.

Bug A — SpaceTimeSigLIP temporal attention did not mask padded history frames. Pixel-zeroing alone is insufficient because the SigLIP patch embedding has a learned bias and the temporal positional embedding e(t) is non-zero for t < T-1, so padded "zero" frames produced non-zero hidden states the current frame attended to. Threaded obs_history_is_pad through embed_videoSpaceTimeSiglipVideoEncoder.forward; a new _build_temporal_attn_mask helper produces a (B*N, 1, T, T) additive float mask combining causal lower-triangular with key-side ~obs_history_is_pad. Inference fallback: when the mask is absent and T > 1, treats all history as padded so the current frame is uncontaminated by zero placeholders from _build_history_batch.

Bug B — state_mask masked the current state token whenever obs_history_is_pad[:, -1] = True. The dataset's history_state_drop_prob augmentation (src/opentau/datasets/lerobot_dataset.py:891-901) flips the entire obs_history_is_pad tensor to all-True while zeroing the state, so the policy was being conditioned on no state at all. Added an unconditional state_mask[:, -1] = True override after both branches; switched the absent-mask branch from ones to zeros so the encoder cannot attend to garbage history slots without a real signal.

Audit and per-criterion status (the four explicit acceptance criteria from the planning thread):

  1. Non-existent historical images masked out, not just zeroed out — fixed by Bug A.
  2. Current proprio state never padded — fixed by Bug B.
  3. Subgoal images encoded with the same video encoder weights as observations — already satisfied. SpaceTimeSiglipVideoEncoder holds vision_tower and multi_modal_projector by reference (in lists, so nn.Module.__setattr__ does not re-register them); Gemma3WithExpertModel.embed_image uses suppress_spacetime_temporal on the same wrapped tower. Locked in by new TestSubgoalSharesVideoEncoderWeights.
  4. Inference prompt construction follows the same conditional layout as training — already satisfied. Same prepare_* methods + embed_prefix gating (via .any() checks at lines 1232-1241) on both paths. Locked in by new TestPrefixLayoutInferenceMatchesTraining.

pi07_paligemma is not touched; #205 will fix it on merge.

How it was tested

New regression tests:

  • tests/policies/test_pi07_video_encoder_cpu.py:
    • TestTemporalAttentionMaskBlocksPaddedHistory — current-frame output is invariant to padded-history pixel values; inference fallback works with obs_history_is_pad=None; mask shape and per-cell values pinned; helper does not mutate the caller's tensor.
    • TestSubgoalSharesVideoEncoderWeights — encoder owns no parameters of its own (vision_tower/multi_modal_projector held by reference); embed_image(img) and encoder(img.unsqueeze(1)) produce byte-equivalent output at T=1; only one copy of each SigLIP attention weight in the wrapped vision_tower's state_dict.
  • tests/policies/test_pi07_cpu.py:
    • TestStateMaskCurrentStepAlwaysReal — current state token stays real when obs_history_is_pad is all-True or None; partial pad pattern still preserves current step; helper does not mutate the caller's obs_history_is_pad tensor (also threaded into embed_video).
    • TestPrefixLayoutInferenceMatchesTraining — minimal inference batch collapses prefix to [videos | lang | "State: " | state | ":\n"]; training and inference prefixes agree on every position except the trailing "Action: " + discrete-action span.

Verification run:

  • pre-commit run --all-files — all hooks pass.
  • pytest -m "not gpu" -n auto tests/policies/ — 185 passed, 2 skipped (CPU subset).
  • pytest -m "gpu" -n 0 tests/policies/test_pi07_low_level.py tests/policies/test_pi07_high_level_planner.py — 7/7 pass on a CUDA host (test_complete_pi07_low_level_pipeline, test_no_optionals_path_on_real_gemma3, the four TestPI07LowLevelRegression cases, plus the high-level pipeline integration).
  • Determinism: ran a tiny PI07LowLevelPolicy.forward twice with the same torch.manual_seed/CUDA seed and pre-sampled noise+time; MSE and CE losses are bit-identical (torch.equal, not just allclose). Per CLAUDE.md hard rule for modeling_*.py touches.

How to checkout & try? (for the reviewer)

git fetch origin claude/vibrant-chandrasekhar-e9b8ef && git checkout claude/vibrant-chandrasekhar-e9b8ef
pre-commit run --all-files
pytest -m "not gpu" -n auto tests/policies/test_pi07_video_encoder_cpu.py tests/policies/test_pi07_cpu.py
# GPU host:
pytest -m "gpu" -n 0 tests/policies/test_pi07_low_level.py tests/policies/test_pi07_high_level_planner.py

To see the bug-A regression directly (current-frame invariance under padded history):

pytest -sx tests/policies/test_pi07_video_encoder_cpu.py::TestTemporalAttentionMaskBlocksPaddedHistory::test_current_frame_invariant_to_padded_history

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

Note: Before submitting this PR, please read the contributor guideline.

Two padded-history bugs in pi07 low-level, mirrors of fixes from PR #205
(which targets the parallel pi07_paligemma path).

Bug A — SpaceTimeSigLIP temporal attention did not mask padded frames.
Pixel-zeroing alone is insufficient: the SigLIP patch embedding bias
and the temporal positional embedding e(t) for t < T-1 are non-zero,
so padded "zero" frames produced non-zero hidden states the current
frame attended to. Thread `obs_history_is_pad` through `embed_video`
and `SpaceTimeSiglipVideoEncoder.forward`; a new `_build_temporal_attn_mask`
helper produces a (B*N, 1, T, T) additive float mask combining causal
lower-triangular with key-side `~obs_history_is_pad`. Inference fallback:
when the mask is absent and T > 1, treat all history as padded so the
current frame is uncontaminated by zero placeholders from `_build_history_batch`.

Bug B — `state_mask` masked the current state token whenever
`obs_history_is_pad[:, -1]` was True. The dataset's `history_state_drop_prob`
augmentation flips the entire `obs_history_is_pad` tensor to all-True
while zeroing the state, so the policy was being conditioned on no state
at all. Add an unconditional `state_mask[:, -1] = True` override after
both branches; switch the absent-mask branch from `ones` to `zeros` so
the encoder cannot attend to garbage history slots without a real signal.

Tests:
- TestTemporalAttentionMaskBlocksPaddedHistory — current-frame output
  is invariant to padded-history pixel values; inference fallback works
  with obs_history_is_pad=None; mask shape/values pinned.
- TestSubgoalSharesVideoEncoderWeights — locks in that the encoder
  holds the SigLIP vision tower by reference (one set of weights), and
  that embed_image(img) and encoder(img.unsqueeze(1)) produce
  byte-equivalent outputs at T=1.
- TestStateMaskCurrentStepAlwaysReal — current state token stays real
  when obs_history_is_pad is all-True or None; helper does not mutate
  the caller's tensor (also threaded into embed_video).
- TestPrefixLayoutInferenceMatchesTraining — minimal inference batch
  collapses prefix; training/inference prefixes agree on every position
  except the trailing "Action: " + discrete-action span.

GPU regression tests + nightly determinism check still TODO.
@shuheng-liu shuheng-liu added the bug Something isn't working label May 4, 2026
@shuheng-liu shuheng-liu self-assigned this May 4, 2026
Copy link
Copy Markdown
Contributor

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review of the padded-history fix port to pi07.

@claude
Copy link
Copy Markdown
Contributor

claude Bot commented May 4, 2026

[claude-review] summary for commit 57716e4

Re-review of 57716e4 (follow-up to 9e702dd review). All four prior findings are resolved:

  • Prior suggestion (inference-time mid-episode masking): addressed. _build_history_batch now emits obs_history_is_pad and sample_actions threads it through prepare_videos + embed_prefix + model.sample_actions. The emitted (B, T) pattern (pad_pattern[i] = (i*interval - missing < 0)) matches the idx < 0 decision the existing state/camera loops already use, so the mask agrees with the actual zero-fill — verified by the new test_state_and_camera_padding_match_emitted_mask.
  • Prior nit (redundant .clone() in embed_prefix and _build_temporal_attn_mask): both removed; replaced with comments explaining that ~ and torch.zeros already allocate fresh tensors.
  • Prior nit (temporal_attn_mask docstring on SpaceTimeEncoderLayerWrapper): updated to reflect the actual isinstance dispatch in SpaceTimeSiglipVideoEncoder.forward.
  • Prior nit (perf) on repeat_interleave(num_patches, dim=0) (video_encoder.py:526): not addressed. Still minor at typical batch sizes; non-blocking.

No blocking issues found.

Follow-up to the review on PR #253.

Main change — emit obs_history_is_pad from _build_history_batch.
Previously the inference path never threaded the mask through
select_action -> sample_actions -> embed_prefix, so the encoder fell
back to "all history padded except current" via embed_prefix's
zeros + state_mask[:, -1] = True branch and the encoder's matching
None-fallback. That fixed the start-of-episode contamination but
silently masked genuine mid-episode history once the buffer filled.

_build_history_batch already owns the idx < 0 decision for each slot;
emitting obs_history_is_pad = (i*interval - missing < 0) for i in
range(n_hist) costs nothing and lets the encoder use real
mid-episode history while still masking only the truly-padded
start-of-episode slots. Threaded through PI07LowLevelPolicy.
sample_actions and PI07LowLevelFlowMatching.sample_actions to reach
embed_prefix with the live mask.

The encoder's None-fallback and embed_prefix's zeros-then-current
branch are both kept as defensive: anything that bypasses
_build_history_batch (direct sample_actions calls in tests, etc.)
still gets a safe default that matches training-time augmentations.

Nits from review:
- Removed redundant .clone() in embed_prefix's state_mask path and in
  _build_temporal_attn_mask's key_valid path. Both ~obs_history_is_pad
  expressions already allocate fresh tensors, so subsequent
  [:, -1] = True writes can't reach the caller's obs_history_is_pad.
  Kept the regression tests that pinned the no-mutation property.
- Rewrote the SpaceTimeEncoderLayerWrapper.forward docstring to drop
  the misleading "ignored by vanilla layers via SiglipEncoder's
  positional dispatch" claim — the encoder now branches on isinstance
  and never calls SiglipEncoder.forward, so vanilla layers simply
  aren't passed temporal_attn_mask.

Skipped the repeat_interleave perf nit: the suggested expand+reshape
materializes anyway because SDPA flattens (B, N) into one dim. Real
fix is a deeper refactor of the temporal Q/K/V layout; deferred.

New tests in TestBuildHistoryBatchEmitsObsHistoryIsPad
(test_pi07_cpu.py) — pin the (B, T) mask shape, the first-step
all-but-current pattern, the buffer-full all-False pattern, the
partial-fill mid-episode pattern, batch broadcasting, and slot-by-slot
agreement with the actual zero-fill of state and camera tensors.
Copy link
Copy Markdown
Contributor

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-review of 57716e4. Prior suggestion (thread obs_history_is_pad from _build_history_batchsample_actionsembed_prefix) and three nits all addressed. Verified the new mask agrees with the per-slot zero-fill decision in the loops above and broadcasts correctly. No new blocking issues.

@shuheng-liu
Copy link
Copy Markdown
Member Author

Pushed 57716e4 addressing the review. Per-item below.

Suggestion — inference-time mid-episode masking

_build_history_batch now emits obs_history_is_pad from the same idx < 0 decision the state/camera loops already make:

pad_pattern = torch.tensor(
    [i * interval - missing < 0 for i in range(n_hist)],
    dtype=torch.bool,
    device=device,
)
temporal_batch["obs_history_is_pad"] = pad_pattern.unsqueeze(0).expand(bsize, n_hist)

Threaded through PI07LowLevelPolicy.sample_actionsPI07LowLevelFlowMatching.sample_actionsembed_prefix. Mid-episode now uses the real buffered history. The encoder's None-fallback and embed_prefix's zeros-then-current branch are kept as defensive defaults for callers that bypass _build_history_batch (direct sample_actions invocation in tests, etc.) — the typical inference path no longer hits them.

New TestBuildHistoryBatchEmitsObsHistoryIsPad (six cases) pins shape, the first-step [T, T, T, F] pattern, the buffer-full all-False pattern, the partial-fill mid-episode pattern, batch broadcasting, and slot-by-slot agreement with the actual zero-fill of state and camera tensors.

Nit — redundant .clone()

Removed in both spots — ~obs_history_is_pad and torch.zeros(...) both allocate fresh tensors, so the subsequent [:, -1] = True writes don't reach the caller's obs_history_is_pad. Kept the test_state_mask_does_not_mutate_obs_history_is_pad and test_pad_tensor_not_mutated regression tests so the no-mutation property stays pinned.

Nit — misleading docstring

Reworded:

Vanilla SiglipEncoderLayer instances never receive it: SpaceTimeSiglipVideoEncoder.forward dispatches it only to SpaceTimeEncoderLayerWrapper instances via an isinstance check, bypassing SiglipEncoder.forward entirely.

Nit (perf) — repeat_interleave(num_patches, dim=0)

Not addressed in this PR. The expand+reshape route still materializes: after unsqueeze(1).expand(B, N, T, T) the stride is (T*T, 0, T, 1), and reshape(B*N, T, T) flattens across the stride-0 dim, which forces a .contiguous() copy. The only true zero-copy fix is to keep (B, N) split through SDPA — restructuring q/k/v from (B*N, num_heads, T, head_dim) to (B, N, num_heads, T, head_dim) and passing a (B, 1, 1, T, T) mask that broadcasts naturally. That's a deeper refactor of the temporal-attention layout and not localized to the mask helper, so deferred to a follow-up. Today's worst case is well under the ~10MB ballpark you mentioned, so not on the critical path.

Verification

  • pre-commit run --all-files — all hooks pass.
  • pytest -m "not gpu" -n auto tests/policies/ — 191 passed, 2 skipped (the 6 new TestBuildHistoryBatchEmitsObsHistoryIsPad cases land in this CPU subset).
  • pytest -m "gpu" -n 0 tests/policies/test_pi07_low_level.py tests/policies/test_pi07_high_level_planner.py — 7/7 pass.
  • Determinism: same MSE / CE losses bit-identical across two seeded forward() calls (MSE = 21.001718521118164, CE = 8.557117462158203 — same as the previous commit's run, confirming the new emit only affects inference, not training).

@shuheng-liu shuheng-liu marked this pull request as ready for review May 4, 2026 19:14
@shuheng-liu shuheng-liu merged commit aad7522 into main May 4, 2026
9 checks passed
@shuheng-liu shuheng-liu deleted the claude/vibrant-chandrasekhar-e9b8ef branch May 4, 2026 19:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant