diff --git a/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py b/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py index 642d3cad..91c523d9 100644 --- a/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py @@ -430,8 +430,12 @@ def _build_history_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: Appends the single-frame observation from ``batch`` to internal deque buffers, then assembles a batch with ``n_obs_history`` evenly-spaced - frames (interval = ``history_interval``). Early in an episode, missing - history slots are zero-padded. + frames (interval = ``history_interval``). Early in an episode the + buffer is partially filled, so some slots are zero-padded; the + returned ``"obs_history_is_pad"`` (B, T) bool tensor flags those + slots ``True`` so the model can mask them out of attention. Once the + buffer is full (typically a handful of steps in), the mask is all + ``False`` and the encoder uses the real history. Expected batch keys: - ``"state"``: (B, D) current proprioceptive state. @@ -439,8 +443,9 @@ def _build_history_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - ``"prompt"``: list[str] language instructions (passed through unchanged). - Any other metadata keys are forwarded unchanged. - Returns a new dict with ``"state"`` expanded to (B, T, D) and image keys - expanded to (B, T, C, H, W), where T = ``n_obs_history``. + Returns a new dict with ``"state"`` expanded to (B, T, D), image keys + expanded to (B, T, C, H, W), and a new ``"obs_history_is_pad"`` (B, T) + bool tensor (``True`` = padded). T = ``n_obs_history``. """ assert self.config.n_obs_history is not None n_hist: int = self.config.n_obs_history @@ -465,6 +470,8 @@ def _build_history_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # sample n_hist frames at the configured interval buf_len = len(self._state_buffer) missing = buf_maxlen - buf_len # how many slots are still empty + bsize = batch["state"].shape[0] + device = batch["state"].device # Pass through all non-image, non-state keys (e.g. "prompt" and other metadata). temporal_batch = {key: v for key, v in batch.items() if key not in img_keys and key != "state"} @@ -490,6 +497,21 @@ def _build_history_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: cam_frames.append(self._obs_buffers[key][idx]) temporal_batch[key] = torch.stack(cam_frames, dim=1) # (B, T, C, H, W) + # Same `idx < 0` decision as the loops above: a slot is padded iff the + # buffer didn't have an entry to fill it. The pattern is identical + # for state and every camera (they share the same buffer length), so + # we emit one (B, T) mask. Broadcast across batch — every sample sees + # the same padding pattern at any given step. Without this, the + # encoder's None-fallback masks ALL history at inference (including + # genuine mid-episode frames once the buffer is full); with it, only + # the actually-padded start-of-episode slots get masked. + pad_pattern = torch.tensor( + [i * interval - missing < 0 for i in range(n_hist)], + dtype=torch.bool, + device=device, + ) + temporal_batch["obs_history_is_pad"] = pad_pattern.unsqueeze(0).expand(bsize, n_hist) + return temporal_batch @torch.no_grad() @@ -564,7 +586,12 @@ def sample_actions( batch = self.normalize_inputs(batch) - videos, vid_masks = self.prepare_videos(batch) + # `_build_history_batch` (called from `select_action` upstream) emits + # this; it's None when the caller skipped that step (e.g. n_obs_history + # is None/1, or sample_actions is invoked directly without the buffer). + obs_history_is_pad = batch.get("obs_history_is_pad") + + videos, vid_masks = self.prepare_videos(batch, obs_history_is_pad=obs_history_is_pad) lang_tokens, lang_masks = self.prepare_language(batch) response_tokens, response_masks = self.prepare_response(batch) state = self.prepare_state(batch) @@ -616,6 +643,7 @@ def sample_actions( metadata_masks=metadata_masks, response_tokens=response_tokens, response_masks=response_masks, + obs_history_is_pad=obs_history_is_pad, ) action_feature = self.config.action_feature @@ -1121,16 +1149,23 @@ def sample_time(self, bsize: int, device: torch.device | str) -> Tensor: time = time_beta * 0.999 + 0.001 return time - def embed_video(self, video: Tensor) -> Tensor: + def embed_video(self, video: Tensor, obs_history_is_pad: Tensor | None = None) -> Tensor: """Encode a video through SpaceTimeSiglip + Perceiver reducer + projection. Args: video: (B, T, C, H, W) + obs_history_is_pad: Optional ``(B, T)`` bool mask — ``True`` for + padded history frames. Threaded into the SpaceTime SigLIP + encoder so temporal attention blocks padded frames (pixel- + zeroing alone is insufficient — the patch embedding bias and + temporal PE for ``t < T-1`` are non-zero, so zero pixels + still produce non-zero hidden states the current frame would + otherwise attend to). Returns: (B, num_video_tokens, vlm_hidden_size) """ - return self.video_encoder(video) + return self.video_encoder(video, obs_history_is_pad=obs_history_is_pad) def embed_prefix( self, @@ -1241,7 +1276,7 @@ def embed_prefix( has_any_optional = bool(has_response or has_subgoal or has_metadata) for vid, vid_mask in zip(videos, vid_masks, strict=True): - vid_emb = self.embed_video(vid) # (B, num_video_tokens, vlm_hidden) + vid_emb = self.embed_video(vid, obs_history_is_pad=obs_history_is_pad) vid_emb = vid_emb.to(dtype=_preferred_dtype()) num_vid_embs = vid_emb.shape[1] @@ -1292,9 +1327,19 @@ def embed_prefix( state_emb = self.state_proj(state.to(dtype=_preferred_dtype())) num_state_tokens = state_emb.shape[1] # T if obs_history_is_pad is not None: - state_mask = ~obs_history_is_pad # True = real, False = padded + state_mask = ~obs_history_is_pad # (B, T) — `~` allocates a fresh tensor else: - state_mask = torch.ones(bsize, num_state_tokens, dtype=torch.bool, device=state.device) + # Absent → assume all history is padded; only current step is real. + state_mask = torch.zeros(bsize, num_state_tokens, dtype=torch.bool, device=state.device) + # Current step (t = T-1) is ALWAYS real even when the dataset's + # history_state_drop_prob augmentation flips obs_history_is_pad to + # all-True. Without this override the policy would condition on no + # state at all, since attention to the current state token would be + # masked out — defeating the purpose of preserving the current frame. + # Both branches above produce fresh tensors (`~` allocates; + # `torch.zeros` allocates), so the `[:, -1] = True` write below does + # not reach the caller's `obs_history_is_pad`. + state_mask[:, -1] = True embs.append(state_emb) pad_masks.append(state_mask) @@ -1739,6 +1784,7 @@ def sample_actions( metadata_masks: Tensor | None = None, response_tokens: Tensor | None = None, response_masks: Tensor | None = None, + obs_history_is_pad: Tensor | None = None, ) -> Tensor: """Inference: iteratively denoise to produce a continuous action chunk. @@ -1763,6 +1809,13 @@ def sample_actions( metadata_masks: Optional mask for metadata tokens. response_tokens: Optional subtask response token IDs. response_masks: Optional mask for response tokens. + obs_history_is_pad: Optional ``(B, T)`` bool mask flagging padded + history slots (``True`` = padded). Emitted by + ``PI07LowLevelPolicy._build_history_batch`` so the encoder can + use real mid-episode history while still masking out the + start-of-episode zero-fill. ``None`` falls back to "all + history padded except current" via ``embed_prefix`` and the + encoder's None-fallback. Returns: Denoised action chunk ``(B, chunk_size, max_action_dim)``. @@ -1786,6 +1839,7 @@ def sample_actions( metadata_masks, subgoal_images=subgoal_images, subgoal_img_masks=subgoal_img_masks, + obs_history_is_pad=obs_history_is_pad, ) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 diff --git a/src/opentau/policies/pi07/low_level/video_encoder.py b/src/opentau/policies/pi07/low_level/video_encoder.py index 9ee4e7cf..e29943b6 100644 --- a/src/opentau/policies/pi07/low_level/video_encoder.py +++ b/src/opentau/policies/pi07/low_level/video_encoder.py @@ -128,8 +128,16 @@ def __init__(self, attn: SiglipAttention): def attn(self) -> SiglipAttention: return self._attn_ref[0] - def forward(self, hidden_states: Tensor) -> Tensor: - """hidden_states: (B*N, T, D) -> (B*N, T, D).""" + def forward(self, hidden_states: Tensor, temporal_attn_mask: Tensor | None = None) -> Tensor: + """hidden_states: (B*N, T, D) -> (B*N, T, D). + + Args: + hidden_states: ``(B*N, T, D)`` hidden states reshaped for temporal + attention (one sequence per spatial patch position). + temporal_attn_mask: Optional ``(B*N, 1, T, T)`` additive float mask. + ``0.0`` = attend, ``-inf`` = block. When ``None``, falls back to + the standard causal lower-triangular mask via ``is_causal=True``. + """ attn = self.attn bn, t, d = hidden_states.shape num_heads = attn.num_heads @@ -139,11 +147,19 @@ def forward(self, hidden_states: Tensor) -> Tensor: k = attn.k_proj(hidden_states).view(bn, t, num_heads, head_dim).transpose(1, 2) v = attn.v_proj(hidden_states).view(bn, t, num_heads, head_dim).transpose(1, 2) - # is_causal=True -> lower-triangular mask, each position attends to - # itself and earlier positions (our convention: t=T-1 is current). - out = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=True, dropout_p=0.0, scale=attn.scale - ) + if temporal_attn_mask is not None: + # Caller-supplied mask already encodes both causal AND padded-history + # blocking; do NOT also set is_causal=True (SDPA disallows combining + # the two). + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=temporal_attn_mask, is_causal=False, dropout_p=0.0, scale=attn.scale + ) + else: + # is_causal=True -> lower-triangular mask, each position attends to + # itself and earlier positions (our convention: t=T-1 is current). + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, dropout_p=0.0, scale=attn.scale + ) out = out.transpose(1, 2).reshape(bn, t, d) return attn.out_proj(out) @@ -270,11 +286,20 @@ def forward( hidden_states: Tensor, attention_mask: Optional[Tensor] = None, output_attentions: bool = False, + temporal_attn_mask: Tensor | None = None, ) -> tuple[Tensor, ...]: """hidden_states: (B*T, N, D) -> tuple starting with (B*T, N, D). - Signature matches ``SiglipEncoderLayer.forward`` so ``SiglipEncoder`` - can dispatch unchanged. + Signature extends ``SiglipEncoderLayer.forward`` with an extra + ``temporal_attn_mask`` kwarg. Vanilla ``SiglipEncoderLayer`` instances + never receive it: ``SpaceTimeSiglipVideoEncoder.forward`` dispatches + it only to ``SpaceTimeEncoderLayerWrapper`` instances via an + ``isinstance`` check, bypassing ``SiglipEncoder.forward`` entirely. + + Args: + temporal_attn_mask: Optional ``(B*N, 1, T, T)`` additive float mask + for temporal attention. ``0.0`` = attend, ``-inf`` = block. + When ``None``, the standard causal mask is used. """ t = self.num_frames bt, n, d = hidden_states.shape @@ -319,7 +344,7 @@ def forward( t_in = rearrange(x_pe, "b t n d -> (b n) t d") t_norm = self.layer_norm1(t_in) - t_out = self._temporal_attn(t_norm) + t_out = self._temporal_attn(t_norm, temporal_attn_mask=temporal_attn_mask) # Residual on the pre-PE hidden (not on x_pe): PE is a transient # positional signal, not a feature perturbation to carry forward. t_res = rearrange(x, "b t n d -> (b n) t d") + t_out @@ -442,13 +467,80 @@ def vision_tower(self) -> SiglipVisionModel: def multi_modal_projector(self) -> nn.Module: return self._multi_modal_projector_ref[0] - def forward(self, video: Tensor) -> Tensor: + @staticmethod + def _build_temporal_attn_mask( + obs_history_is_pad: Tensor, + num_patches: int, + dtype: torch.dtype, + ) -> Tensor: + """Build a causal temporal attention mask that blocks padded frames. + + Pure pixel-level zeroing of padded frames is not enough — the SigLIP + patch embedding has a learned bias and the temporal positional + embedding ``e(t)`` is non-zero for ``t < T-1``, so padded "zero" frames + still produce non-zero hidden states that the current frame would + attend to. This mask blocks attention to padded keys at the SDPA call. + + Args: + obs_history_is_pad: ``(B, T)`` bool — ``True`` for padded steps. + num_patches: ``N``, number of spatial patches per frame (the + video encoder runs one temporal sequence per patch position, + so each patch row of the (B*N) batch reuses the same mask). + dtype: Float dtype matching the hidden states (additive mask gets + added to attention scores; mismatched dtypes force upcasts). + + Returns: + ``(B*N, 1, T, T)`` additive float mask where ``0.0`` = attend and + ``-inf`` = block. Row ``i`` can attend to column ``j`` iff + ``j <= i`` (causal) **and** ``obs_history_is_pad[:, j]`` is + ``False``. The current frame (``j = T-1``) is always attendable + even if the caller set ``obs_history_is_pad[:, -1] = True`` — + losing the current frame would defeat the encoder. + """ + b, t = obs_history_is_pad.shape + device = obs_history_is_pad.device + + # Causal: position i attends to j <= i. + causal = torch.tril(torch.ones(t, t, dtype=torch.bool, device=device)) # (T, T) + + # Key-side visibility: True where frame j is real (not padded). + # Force the last frame always attendable as a defensive fallback — + # callers (e.g. the dataset's history_state_drop_prob augmentation) + # set obs_history_is_pad to all-True; without this override, the + # current frame would have no key to attend to and produce NaNs. + # `~obs_history_is_pad` allocates a fresh tensor, so the in-place + # write below does not reach the caller's `obs_history_is_pad`. + key_valid = ~obs_history_is_pad # (B, T) + key_valid[:, -1] = True + + # Combined: (B, T_query, T_key) + mask_bool = causal.unsqueeze(0) & key_valid.unsqueeze(1) # (B, T, T) + + # Bool → additive float: True → 0.0, False → -inf. + float_mask = torch.zeros(b, t, t, dtype=dtype, device=device) + float_mask.masked_fill_(~mask_bool, float("-inf")) + + # Expand for the (B*N) flattened-patch batch dimension: + # (B, T, T) → (B, 1, T, T) → repeat_interleave(N) → (B*N, 1, T, T) + float_mask = float_mask.unsqueeze(1) # (B, 1, T, T) + float_mask = float_mask.repeat_interleave(num_patches, dim=0) # (B*N, 1, T, T) + + return float_mask + + def forward(self, video: Tensor, obs_history_is_pad: Tensor | None = None) -> Tensor: """Encode a video clip and return the current-frame tokens. Args: video: ``(B, T, C, H, W)`` pixel values in ``[0, 1]``, with ``T == num_frames``, ``C == 3``, and spatial size matching the SigLIP config (224x224 by default). + obs_history_is_pad: Optional ``(B, T)`` bool mask where ``True`` + marks padded history frames. Padded frames are blocked in + the temporal attention so the current frame cannot read + contaminated hidden states from them. When ``None`` and + ``T > 1``, falls back to "only the current frame is real" + (matches inference-time semantics where + ``_build_history_batch`` does not populate this mask). Returns: ``(B, num_video_tokens, vlm_hidden_size)`` current-frame tokens, @@ -473,20 +565,54 @@ def forward(self, video: Tensor) -> Tensor: # Patch embedding + learned spatial position embedding. hidden = self.vision_tower.vision_model.embeddings(flat) + # Build temporal attention mask. Skipped at T=1 because the wrapper's + # T=1 short-circuit bypasses temporal attention entirely. + temporal_attn_mask: Tensor | None = None + if t > 1: + if obs_history_is_pad is not None: + temporal_attn_mask = self._build_temporal_attn_mask( + obs_history_is_pad, self.num_video_tokens, hidden.dtype + ) + else: + # Inference fallback: select_action -> _build_history_batch + # zero-pads missing slots but does NOT emit obs_history_is_pad. + # Treat all history as padded so the current frame's + # representation is uncontaminated by the zero-pixel + # placeholders. + fallback_pad = torch.ones(b, t, dtype=torch.bool, device=hidden.device) + fallback_pad[:, -1] = False + temporal_attn_mask = self._build_temporal_attn_mask( + fallback_pad, self.num_video_tokens, hidden.dtype + ) + # Encoder stack: standard spatial layers + wrapped every-Nth layer # with temporal attention. SpaceTimeEncoderLayerWrapper matches the # SiglipEncoderLayer signature, so we drive the loop manually here # (instead of calling SiglipEncoder.forward) so we can wrap each # layer in torch.utils.checkpoint.checkpoint when the flag is set — # the same explicit pattern PaliGemmaWithExpertModel uses. + # ``temporal_attn_mask`` is passed only to spacetime-wrapped layers + # (vanilla layers don't accept it). Under gradient checkpointing it + # MUST go in as a positional arg — torch.utils.checkpoint.checkpoint + # with use_reentrant=False does not forward kwargs to the wrapped + # function. use_ckpt = self.gradient_checkpointing and self.training for layer in self.vision_tower.vision_model.encoder.layers: + is_spacetime = isinstance(layer, SpaceTimeEncoderLayerWrapper) if use_ckpt: - layer_outputs = torch.utils.checkpoint.checkpoint( - layer, hidden, None, False, use_reentrant=False - ) + if is_spacetime: + layer_outputs = torch.utils.checkpoint.checkpoint( + layer, hidden, None, False, temporal_attn_mask, use_reentrant=False + ) + else: + layer_outputs = torch.utils.checkpoint.checkpoint( + layer, hidden, None, False, use_reentrant=False + ) else: - layer_outputs = layer(hidden, None, False) + if is_spacetime: + layer_outputs = layer(hidden, None, False, temporal_attn_mask=temporal_attn_mask) + else: + layer_outputs = layer(hidden, None, False) hidden = layer_outputs[0] hidden = self.vision_tower.vision_model.post_layernorm(hidden) diff --git a/tests/policies/test_pi07_cpu.py b/tests/policies/test_pi07_cpu.py index 7b9a9757..c64d39ee 100644 --- a/tests/policies/test_pi07_cpu.py +++ b/tests/policies/test_pi07_cpu.py @@ -585,8 +585,11 @@ def _state_proj(state): # state: (B, T, D) → (B, T, hidden) return torch.zeros(state.shape[0], state.shape[1], hidden, dtype=torch.float32) - def _embed_video(video): + def _embed_video(video, obs_history_is_pad=None): # video: (B, T, C, H, W) → (B, n_video_tokens, hidden) + # obs_history_is_pad accepted for signature compat; this fake ignores + # it (the real encoder uses it to build the temporal attention mask). + del obs_history_is_pad return torch.zeros(video.shape[0], n_video_tokens, hidden, dtype=torch.float32) fake = types.SimpleNamespace( @@ -1278,3 +1281,382 @@ def test_post_init_preserves_explicit_vlm_config_when_policy_default(self): assert cfg.gradient_checkpointing is False assert cfg.vlm_config.attention_implementation == "sdpa" assert cfg.vlm_config.gradient_checkpointing is True + + +# State mask: current step (t = T-1) must always be marked real, even when +# obs_history_is_pad sets it to True (e.g. dataset's history_state_drop_prob +# augmentation flips the entire tensor to all-True). Without the override the +# policy is conditioned on no state at all when that augmentation fires. + + +def _state_slice_indices(prompt_len: int, n_video_tokens: int, t_state: int) -> slice: + """Compute the slice of ``pad_masks`` corresponding to the state tokens. + + Layout (no optional blocks; fake tokenizer encodes every indicator phrase + to 2 tokens): + videos(n_video_tokens) + lang(prompt_len) + "State: "(2) + state(t_state) + """ + state_lo = n_video_tokens + prompt_len + 2 + return slice(state_lo, state_lo + t_state) + + +class TestStateMaskCurrentStepAlwaysReal: + """Pin the post-fix invariant: state_mask[:, -1] is True regardless of + obs_history_is_pad. Bug B from the PR #205 audit (port to pi07). + """ + + def test_state_mask_current_step_real_when_all_history_padded(self): + """``obs_history_is_pad = ones(B, T)`` (the + ``history_state_drop_prob=1.0`` case) MUST still leave the current + state token (index T-1) marked real, otherwise attention to it is + masked out and the policy conditions on no state at all. + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 2 + t_state = 4 # > 1 so the state span is multi-token + kwargs = _build_default_inputs(batch_size=bsize, t_state=t_state) + kwargs["obs_history_is_pad"] = torch.ones(bsize, t_state, dtype=torch.bool) + # No optional blocks — keeps the layout deterministic. + kwargs["response_tokens"] = None + kwargs["response_masks"] = None + kwargs["metadata_tokens"] = None + kwargs["metadata_masks"] = None + + _, pad_masks, _ = method(fake, **kwargs) + + state_slice = _state_slice_indices(prompt_len=3, n_video_tokens=3, t_state=t_state) + state_mask = pad_masks[:, state_slice] + assert state_mask.shape == (bsize, t_state) + + # Pre-fix: ~obs_history_is_pad = all-False → current also masked out. + # Post-fix: state_mask[:, -1] = True override. + for i in range(bsize): + assert state_mask[i, -1].item() is True, ( + f"sample {i}: current state token (T-1) is masked out — the " + f"history_state_drop_prob augmentation would condition on no " + f"state at all. state_mask = {state_mask[i].tolist()}" + ) + # Earlier history tokens are still padded. + assert (~state_mask[:, :-1]).all().item() is True + + def test_state_mask_none_branch_assumes_history_padded_keeps_current_real(self): + """``obs_history_is_pad = None`` means the caller didn't tell us + which slots are real. Post-fix: assume all history is padded so the + encoder cannot attend to garbage history slots — but the current step + is still real. + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 2 + t_state = 4 + kwargs = _build_default_inputs(batch_size=bsize, t_state=t_state) + kwargs["obs_history_is_pad"] = None + kwargs["response_tokens"] = None + kwargs["response_masks"] = None + kwargs["metadata_tokens"] = None + kwargs["metadata_masks"] = None + + _, pad_masks, _ = method(fake, **kwargs) + + state_slice = _state_slice_indices(prompt_len=3, n_video_tokens=3, t_state=t_state) + state_mask = pad_masks[:, state_slice] + assert state_mask.shape == (bsize, t_state) + + # Post-fix None-branch: zeros for history, True for current step. + for i in range(bsize): + assert state_mask[i, -1].item() is True + assert (~state_mask[i, :-1]).all().item() is True, ( + f"sample {i}: history slots should be padded by default in the " + f"None-branch, got state_mask = {state_mask[i].tolist()}" + ) + + def test_state_mask_partial_history_pad_preserves_current(self): + """Mixed pad pattern (typical of natural episode-boundary padding): + some history slots padded, current step real → state_mask matches + ``~obs_history_is_pad`` exactly, with the override a no-op since + the current bit was already True. + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 2 + t_state = 4 + kwargs = _build_default_inputs(batch_size=bsize, t_state=t_state) + # Sample 0: first 2 padded; sample 1: only first padded. + kwargs["obs_history_is_pad"] = torch.tensor([[True, True, False, False], [True, False, False, False]]) + kwargs["response_tokens"] = None + kwargs["response_masks"] = None + kwargs["metadata_tokens"] = None + kwargs["metadata_masks"] = None + + _, pad_masks, _ = method(fake, **kwargs) + + state_slice = _state_slice_indices(prompt_len=3, n_video_tokens=3, t_state=t_state) + state_mask = pad_masks[:, state_slice] + assert state_mask.shape == (bsize, t_state) + + torch.testing.assert_close( + state_mask, + torch.tensor([[False, False, True, True], [False, True, True, True]]), + ) + + def test_state_mask_does_not_mutate_obs_history_is_pad(self): + """The override path uses .clone() to avoid mutating the caller's + ``obs_history_is_pad`` tensor (which is also threaded into + ``embed_video`` for the temporal attention mask). + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 1 + t_state = 4 + kwargs = _build_default_inputs(batch_size=bsize, t_state=t_state) + original_pad = torch.ones(bsize, t_state, dtype=torch.bool) + kwargs["obs_history_is_pad"] = original_pad + snapshot = original_pad.clone() + kwargs["response_tokens"] = None + kwargs["response_masks"] = None + kwargs["metadata_tokens"] = None + kwargs["metadata_masks"] = None + + method(fake, **kwargs) + + torch.testing.assert_close(original_pad, snapshot) + + +# Inference vs training prompt-construction parity. The embed_prefix code path +# is the same for both — the only difference is whether ``discrete_actions`` +# is present (training) or None (inference). With identical optional-block +# inputs, the prefix tensors must agree on every dimension EXCEPT the +# trailing "Action: " + discrete-action span. + + +class TestPrefixLayoutInferenceMatchesTraining: + def test_minimal_inference_batch_collapses_prefix(self): + """An inference-style call (``discrete_actions=None``, no optional + blocks) produces the collapsed layout + videos | lang | "State: " | state | ":\\n" + with the state-end separator collapsed from ", " to ":\\n" and no + ";\\n " prefix-end. Mirrors what the user sees when calling + ``select_action`` with only state + prompt + camera. + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 1 + kwargs = _build_default_inputs(batch_size=bsize, t_state=1) + # No optional blocks AND no discrete_actions (inference signature). + kwargs["response_tokens"] = None + kwargs["response_masks"] = None + kwargs["metadata_tokens"] = None + kwargs["metadata_masks"] = None + + embs, pad_masks, att_masks = method(fake, **kwargs) + + # videos(3) + lang(3) + "State: "(2) + state(1) + ":\n"(2) = 11. + # NOT videos+lang+State:+state+", "+";\n " = 13 (which would mean + # has_any_optional incorrectly evaluated True). + expected_len = 11 + assert embs.shape == (bsize, expected_len, 4) + assert pad_masks.shape == (bsize, expected_len) + assert att_masks.shape == (bsize, expected_len) + + def test_train_inference_prefix_diff_only_in_action_tail(self): + """With identical optional-block inputs, training prefix + (``discrete_actions != None``) and inference prefix + (``discrete_actions == None``) must agree on every position EXCEPT + the trailing "Action: " + discrete-action span. This pins the + property that ``prepare_*`` and ``embed_prefix`` produce the same + layout for training and inference batches. + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 1 + kwargs_infer = _build_default_inputs(batch_size=bsize, t_state=1) + kwargs_infer["response_tokens"] = None + kwargs_infer["response_masks"] = None + kwargs_infer["metadata_tokens"] = None + kwargs_infer["metadata_masks"] = None + + kwargs_train = dict(kwargs_infer) + num_action_tokens = 3 + kwargs_train["discrete_actions"] = torch.zeros(bsize, num_action_tokens, dtype=torch.long) + kwargs_train["discrete_action_masks"] = torch.ones(bsize, num_action_tokens, dtype=torch.bool) + + embs_infer, pad_masks_infer, att_masks_infer = method(fake, **kwargs_infer) + embs_train, pad_masks_train, att_masks_train = method(fake, **kwargs_train) + + # Inference prefix is a strict prefix of the training one. + infer_len = embs_infer.shape[1] + # Training adds "Action: "(2 fake tokens) + discrete_actions(3) = 5. + train_len = infer_len + 2 + num_action_tokens + assert embs_train.shape[1] == train_len + + # Tensors agree on the shared inference-length prefix. + torch.testing.assert_close(embs_train[:, :infer_len], embs_infer) + torch.testing.assert_close(pad_masks_train[:, :infer_len], pad_masks_infer) + torch.testing.assert_close(att_masks_train[:, :infer_len], att_masks_infer) + + def test_full_optional_blocks_produce_same_layout_train_vs_infer(self): + """Same property as above but with ALL optional blocks present + (response, metadata, subgoals): training and inference prefixes + agree on the shared length, training extends with "Action: " + + discrete actions only. + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 2 + + def _build(with_actions: bool): + kwargs = _build_default_inputs(batch_size=bsize, t_state=1) + kwargs["response_tokens"] = torch.zeros(bsize, 5, dtype=torch.long) + kwargs["response_masks"] = torch.ones(bsize, 5, dtype=torch.bool) + kwargs["metadata_tokens"] = torch.zeros(bsize, 4, dtype=torch.long) + kwargs["metadata_masks"] = torch.ones(bsize, 4, dtype=torch.bool) + kwargs["subgoal_images"] = [torch.zeros(bsize, 3, 4, 4)] + kwargs["subgoal_img_masks"] = [torch.ones(bsize, dtype=torch.bool)] + if with_actions: + num_action_tokens = 3 + kwargs["discrete_actions"] = torch.zeros(bsize, num_action_tokens, dtype=torch.long) + kwargs["discrete_action_masks"] = torch.ones(bsize, num_action_tokens, dtype=torch.bool) + return kwargs + + embs_infer, pad_infer, att_infer = method(fake, **_build(with_actions=False)) + embs_train, pad_train, att_train = method(fake, **_build(with_actions=True)) + + infer_len = embs_infer.shape[1] + # Training extends by 2 ("Action: ") + 3 (discrete actions). + assert embs_train.shape[1] == infer_len + 5 + + torch.testing.assert_close(embs_train[:, :infer_len], embs_infer) + torch.testing.assert_close(pad_train[:, :infer_len], pad_infer) + torch.testing.assert_close(att_train[:, :infer_len], att_infer) + + +# `_build_history_batch` emits ``obs_history_is_pad`` so the encoder can use +# real mid-episode history while still masking start-of-episode zero-fill. +# Without this emit, the encoder's None-fallback masks ALL history at +# inference (mid-episode regression flagged in the PR #253 review). + + +class TestBuildHistoryBatchEmitsObsHistoryIsPad: + @staticmethod + def _make_policy_stub(*, n_obs_history: int, history_interval: int, image_keys: list[str]): + """Construct a partial PI07LowLevelPolicy that exposes only the + attrs ``_build_history_batch`` reads: ``config.{n_obs_history, + history_interval, obs_buffer_size, image_features}`` plus the deque + slots. Skips Gemma 3 init so the test stays CPU-cheap. + """ + import types + + from opentau.policies.pi07.low_level.modeling_pi07_low_level import ( + PI07LowLevelPolicy, + ) + + policy = PI07LowLevelPolicy.__new__(PI07LowLevelPolicy) + buf_size = (n_obs_history - 1) * history_interval + 1 + policy.config = types.SimpleNamespace( + n_obs_history=n_obs_history, + history_interval=history_interval, + obs_buffer_size=buf_size, + image_features=dict.fromkeys(image_keys), + ) + policy._state_buffer = None + policy._obs_buffers = None + return policy + + def _make_batch(self, image_keys: list[str], state_dim: int = 4) -> dict: + return { + "state": torch.zeros(1, state_dim), + **{k: torch.zeros(1, 3, 8, 8) for k in image_keys}, + } + + def test_first_step_marks_all_but_current_padded(self): + """At episode start, only the very first observation is in the + buffer; every other slot in the requested history was zero-filled. + Mask should be ``[True, ..., True, False]`` — the canonical case + the PR's Bug A fix protects against contamination from. + """ + policy = self._make_policy_stub(n_obs_history=4, history_interval=1, image_keys=["camera0"]) + out = policy._build_history_batch(self._make_batch(["camera0"])) + + assert "obs_history_is_pad" in out + assert out["obs_history_is_pad"].shape == (1, 4) + assert out["obs_history_is_pad"].dtype == torch.bool + assert out["obs_history_is_pad"].tolist() == [[True, True, True, False]] + + def test_buffer_full_emits_all_false(self): + """Once the buffer is full (after ``obs_buffer_size`` calls), every + slot maps to a real observation — mask is all-False. This is the + mid-episode case the previous PR regressed: with the None-fallback, + the encoder masked these real frames out as if they were padded. + """ + policy = self._make_policy_stub(n_obs_history=4, history_interval=2, image_keys=["camera0"]) + # obs_buffer_size = (4-1)*2 + 1 = 7. Need 7 calls to fill. + batch = self._make_batch(["camera0"]) + for _ in range(7): + out = policy._build_history_batch(batch) + assert out["obs_history_is_pad"].tolist() == [[False, False, False, False]] + + def test_partial_fill_marks_only_unfilled_slots(self): + """After ``k < obs_buffer_size`` calls, the leading ``T - + ceil(k / interval)`` slots are still virtual past-steps. With + ``n_obs_history=4, history_interval=2`` (buffer_size=7), after 4 + calls the deque has 4 entries → ``missing = 3`` → slots with + ``i*interval - 3 < 0`` are padded: i=0 → -3 (T), i=1 → -1 (T), + i=2 → 1 (F), i=3 → 3 (F). So mask = [T, T, F, F]. + """ + policy = self._make_policy_stub(n_obs_history=4, history_interval=2, image_keys=["camera0"]) + batch = self._make_batch(["camera0"]) + for _ in range(4): + out = policy._build_history_batch(batch) + assert out["obs_history_is_pad"].tolist() == [[True, True, False, False]] + + def test_mask_is_broadcast_over_batch(self): + """The buffer is shared across batch elements (every sample sees + the same buffer length at any given step), so the (B, T) mask is + the same across the batch dim. Verify by emitting from a B=3 batch. + """ + policy = self._make_policy_stub(n_obs_history=4, history_interval=1, image_keys=["camera0"]) + batch = { + "state": torch.zeros(3, 4), + "camera0": torch.zeros(3, 3, 8, 8), + } + out = policy._build_history_batch(batch) + + assert out["obs_history_is_pad"].shape == (3, 4) + # Every batch element sees the same mask. + assert torch.all(out["obs_history_is_pad"] == out["obs_history_is_pad"][0:1]) + + def test_n_obs_history_one_emits_all_false(self): + """With ``n_obs_history=1`` the buffer always contains the current + frame — no historical slots exist, so the (B, 1) mask is False + from step 1. (In practice ``select_action`` skips + ``_build_history_batch`` entirely when ``n_obs_history <= 1``, so + this is just defending the function's own contract.) + """ + policy = self._make_policy_stub(n_obs_history=1, history_interval=1, image_keys=["camera0"]) + out = policy._build_history_batch(self._make_batch(["camera0"])) + assert out["obs_history_is_pad"].tolist() == [[False]] + + def test_state_and_camera_padding_match_emitted_mask(self): + """The emitted mask must agree slot-for-slot with the actual + zero-padding pattern of state and camera tensors. State / camera + are zeroed where ``idx < 0``; the mask flags the same slots ``True``. + """ + policy = self._make_policy_stub(n_obs_history=3, history_interval=1, image_keys=["camera0"]) + # Inject a non-zero observation so we can detect zero-fill. + batch = { + "state": torch.full((1, 4), 7.0), + "camera0": torch.full((1, 3, 8, 8), 5.0), + } + out = policy._build_history_batch(batch) + # After one call: missing = 2; mask = [True, True, False]. + is_pad = out["obs_history_is_pad"][0] # (T,) + state = out["state"][0] # (T, D) + cam = out["camera0"][0] # (T, C, H, W) + for t, padded in enumerate(is_pad.tolist()): + if padded: + assert torch.all(state[t] == 0.0), f"state[{t}] not zero-filled" + assert torch.all(cam[t] == 0.0), f"camera[{t}] not zero-filled" + else: + assert torch.all(state[t] == 7.0), f"state[{t}] zero-filled but mask says real" + assert torch.all(cam[t] == 5.0), f"camera[{t}] zero-filled but mask says real" diff --git a/tests/policies/test_pi07_video_encoder_cpu.py b/tests/policies/test_pi07_video_encoder_cpu.py index 4d4fe0f2..8916e4a6 100644 --- a/tests/policies/test_pi07_video_encoder_cpu.py +++ b/tests/policies/test_pi07_video_encoder_cpu.py @@ -353,3 +353,197 @@ def test_non_divisible_batch_succeeds_when_suppressed(self): with suppress_spacetime_temporal(encoder.vision_tower): out = wrapper.forward(bad_input)[0] assert out.shape == bad_input.shape + + +# Padded-history attention masking (Bug A from the audit against PR #205). +# +# Pixel-zeroing of padded frames is not enough — the SigLIP patch embedding +# carries a learned bias and the temporal positional embedding e(t) is +# non-zero for t < T-1. So zero pixels still produce non-zero hidden states +# that the current frame would otherwise attend to. The encoder must build +# an attention mask blocking padded keys at the SDPA call. + + +class TestTemporalAttentionMaskBlocksPaddedHistory: + def test_current_frame_invariant_to_padded_history(self): + """Gold-standard: with the first three frames marked padded and + identical current frame at ``[:, -1]``, the encoder output must NOT + depend on the pixel values of those padded frames. + + Pre-fix code reads contaminated hidden states from padded frames + through the temporal attention path and the two outputs diverge. + """ + torch.manual_seed(0) + encoder = _make_encoder(num_frames=4, spacetime_stride=4).eval() + + current = torch.rand(1, 1, 3, 224, 224) + # vid_a: random padded history. + vid_a = torch.cat([torch.rand(1, 3, 3, 224, 224), current], dim=1) + # vid_b: zeroed padded history; same current frame. + vid_b = torch.cat([torch.zeros(1, 3, 3, 224, 224), current], dim=1) + + obs_history_is_pad = torch.tensor([[True, True, True, False]]) + with torch.no_grad(): + out_a = encoder(vid_a, obs_history_is_pad=obs_history_is_pad) + out_b = encoder(vid_b, obs_history_is_pad=obs_history_is_pad) + + torch.testing.assert_close(out_a, out_b, rtol=1e-5, atol=1e-5) + + def test_inference_fallback_blocks_all_history(self): + """``obs_history_is_pad=None`` triggers the inference fallback that + treats every history slot as padded; pins the property that + ``select_action`` (which never emits ``obs_history_is_pad``) does + not let zero-pixel placeholders contaminate the current frame. + """ + torch.manual_seed(1) + encoder = _make_encoder(num_frames=4, spacetime_stride=4).eval() + + current = torch.rand(1, 1, 3, 224, 224) + vid_a = torch.cat([torch.rand(1, 3, 3, 224, 224), current], dim=1) + vid_b = torch.cat([torch.zeros(1, 3, 3, 224, 224), current], dim=1) + + with torch.no_grad(): + out_a = encoder(vid_a) # obs_history_is_pad defaults to None + out_b = encoder(vid_b) + + torch.testing.assert_close(out_a, out_b, rtol=1e-5, atol=1e-5) + + def test_temporal_mask_shape_and_values(self): + """Exercise ``_build_temporal_attn_mask`` directly. Documents the + contract: causal lower-triangular AND key-side ``~obs_history_is_pad`` + with current-frame override; (B*N, 1, T, T) shape; additive 0/-inf. + """ + obs_history_is_pad = torch.tensor([[True, False, False]]) + mask = SpaceTimeSiglipVideoEncoder._build_temporal_attn_mask( + obs_history_is_pad, num_patches=4, dtype=torch.float32 + ) + + # Shape: B=1, N=4, T=3 → (B*N, 1, T, T) = (4, 1, 3, 3). + assert mask.shape == (4, 1, 3, 3) + assert mask.dtype == torch.float32 + + # All N rows for batch 0 share the same (T, T) submatrix. + m = mask[0, 0] # (3, 3) + + # Diagonal: each query attends to itself when key is real. + # Row 0: query=frame0 (padded query, but causal allows j=0 only; + # since frame 0 is padded its key is blocked → -inf). The current-frame + # override only forces key j=T-1 real, not j=0. + assert m[0, 0] == float("-inf"), "padded frame 0 key should be blocked" + # Row 1: query=frame1 attends to {0 (padded → -inf), 1 (real → 0)}. + assert m[1, 0] == float("-inf") + assert m[1, 1] == 0.0 + # Row 2: query=current attends to {0 (padded → -inf), 1 (real → 0), + # 2 (current, always real → 0)}. + assert m[2, 0] == float("-inf") + assert m[2, 1] == 0.0 + assert m[2, 2] == 0.0 + + # Upper triangular off-diagonals are blocked (causal). + assert m[0, 1] == float("-inf") + assert m[0, 2] == float("-inf") + assert m[1, 2] == float("-inf") + + # Patch rows for the same batch are identical (mask is broadcast over + # spatial patch positions). + for n in range(1, 4): + torch.testing.assert_close(mask[n, 0], mask[0, 0]) + + def test_current_frame_override_when_pad_includes_current(self): + """Defensive: even if a caller passes ``obs_history_is_pad`` with + the current step (T-1) marked padded — as the dataset's + ``history_state_drop_prob`` augmentation does — the mask must keep + the current frame attendable, otherwise the current query has no + valid key and softmax produces NaNs. + """ + all_padded = torch.tensor([[True, True, True, True]]) + mask = SpaceTimeSiglipVideoEncoder._build_temporal_attn_mask( + all_padded, num_patches=1, dtype=torch.float32 + ) + # Current row (T-1=3) attends to the current key (T-1=3) at minimum. + m = mask[0, 0] + assert m[3, 3] == 0.0, "current frame must remain self-attendable" + + def test_pad_tensor_not_mutated(self): + """The mask helper must not mutate the caller's ``obs_history_is_pad`` + tensor (it is also consumed downstream as the state mask). + """ + all_padded = torch.tensor([[True, True, True, True]]) + snapshot = all_padded.clone() + SpaceTimeSiglipVideoEncoder._build_temporal_attn_mask(all_padded, num_patches=1, dtype=torch.float32) + torch.testing.assert_close(all_padded, snapshot) + + +# Subgoal images share the video encoder weights — the only set of SigLIP +# weights that exists. ``embed_image`` (subgoal path) and the encoder's +# T=1 short-circuit go through the same modules; outputs must match +# byte-for-byte. + + +class TestSubgoalSharesVideoEncoderWeights: + def test_embed_image_via_suppress_matches_encoder_t1(self): + """A single-frame forward through ``vision_tower`` under + ``suppress_spacetime_temporal`` (the path ``Gemma3WithExpertModel. + embed_image`` takes for subgoal images) must produce byte-identical + output to ``SpaceTimeSiglipVideoEncoder.forward`` at T=1 on the same + image, **before** the [0, 1] → [-1, 1] rescale (since ``embed_image`` + expects callers to have already rescaled, while the encoder rescales + internally). + """ + torch.manual_seed(2) + vision_tower, projector = _build_siglip_and_projector() + encoder = SpaceTimeSiglipVideoEncoder( + vision_tower=vision_tower, + multi_modal_projector=projector, + num_frames=1, + spacetime_layer_stride=4, + ).eval() + + # Encoder takes [0, 1] and rescales internally. + image_unit = torch.rand(2, 3, 224, 224) + + # embed_image equivalent: the caller pre-rescales to [-1, 1] and + # invokes the same vision_tower under suppress_spacetime_temporal. + image_siglip = image_unit * 2.0 - 1.0 + with torch.no_grad(): + with suppress_spacetime_temporal(vision_tower): + last_hidden = vision_tower(pixel_values=image_siglip).last_hidden_state + out_image = projector(last_hidden) + out_video = encoder(image_unit.unsqueeze(1)) # (B, 1, C, H, W) + + torch.testing.assert_close(out_video, out_image, rtol=1e-5, atol=1e-5) + + def test_encoder_owns_no_siglip_parameters(self): + """``SpaceTimeSiglipVideoEncoder`` holds ``vision_tower`` and + ``multi_modal_projector`` by reference in lists. Its OWN parameters + and state_dict must therefore be empty — proving there is no second + copy of the SigLIP weights anywhere in the encoder's tree. + """ + encoder = _make_encoder(num_frames=4, spacetime_stride=4) + + own_param_count = sum(1 for _ in encoder.parameters()) + assert own_param_count == 0, ( + f"encoder owns {own_param_count} params; expected 0 — vision_tower " + "and multi_modal_projector should be held by reference, not registered" + ) + + # state_dict contains only registered buffers/params. _temporal_pe is + # non-persistent (excluded from state_dict). Adopted submodule keys + # are accessed through the wrapped vision_tower, NOT under the + # encoder's path. So the encoder's own state_dict must be empty. + own_state = encoder.state_dict() + assert len(own_state) == 0, f"encoder.state_dict() = {list(own_state.keys())}; expected empty" + + def test_only_one_copy_of_siglip_q_proj(self): + """Across the wrapped vision_tower's full state_dict, there is + exactly ONE copy of each SigLIP attention weight. Wrapping with + space-time attention must not duplicate any q/k/v/o projection. + """ + encoder = _make_encoder(num_frames=4, spacetime_stride=4) + keys = list(encoder.vision_tower.state_dict().keys()) + q_proj_keys = [k for k in keys if k.endswith("self_attn.q_proj.weight")] + + # 27 SigLIP layers → exactly 27 q_proj.weight entries. + assert len(q_proj_keys) == 27, ( + f"expected 27 q_proj entries (one per layer); got {len(q_proj_keys)}: {q_proj_keys}" + )