Commit 57716e4

fix(pi07): emit obs_history_is_pad at inference; address review nits
Follow-up to the review on PR #253.

Main change: emit obs_history_is_pad from _build_history_batch. Previously the inference path never threaded the mask through select_action -> sample_actions -> embed_prefix, so the encoder fell back to "all history padded except current" via embed_prefix's zeros + state_mask[:, -1] = True branch and the encoder's matching None-fallback. That fixed the start-of-episode contamination but silently masked genuine mid-episode history once the buffer filled.

_build_history_batch already owns the idx < 0 decision for each slot; emitting obs_history_is_pad = (i*interval - missing < 0) for i in range(n_hist) costs nothing and lets the encoder use real mid-episode history while still masking only the truly-padded start-of-episode slots. Threaded through PI07LowLevelPolicy.sample_actions and PI07LowLevelFlowMatching.sample_actions to reach embed_prefix with the live mask.

The encoder's None-fallback and embed_prefix's zeros-then-current branch are both kept as defensive: anything that bypasses _build_history_batch (direct sample_actions calls in tests, etc.) still gets a safe default that matches training-time augmentations.

Nits from review:
- Removed redundant .clone() in embed_prefix's state_mask path and in _build_temporal_attn_mask's key_valid path. Both ~obs_history_is_pad expressions already allocate fresh tensors, so subsequent [:, -1] = True writes can't reach the caller's obs_history_is_pad. Kept the regression tests that pinned the no-mutation property.
- Rewrote the SpaceTimeEncoderLayerWrapper.forward docstring to drop the misleading "ignored by vanilla layers via SiglipEncoder's positional dispatch" claim: the encoder now branches on isinstance and never calls SiglipEncoder.forward, so vanilla layers simply aren't passed temporal_attn_mask.

Skipped the repeat_interleave perf nit: the suggested expand+reshape materializes anyway because SDPA flattens (B, N) into one dim. Real fix is a deeper refactor of the temporal Q/K/V layout; deferred.

New tests in TestBuildHistoryBatchEmitsObsHistoryIsPad (test_pi07_cpu.py): pin the (B, T) mask shape, the first-step all-but-current pattern, the buffer-full all-False pattern, the partial-fill mid-episode pattern, batch broadcasting, and slot-by-slot agreement with the actual zero-fill of state and camera tensors.
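
The padding rule described in the message can be sketched in plain Python. This is an illustration only: `history_pad_pattern` is a hypothetical helper, not a function in the repository, and it assumes the buffer capacity is `(n_hist - 1) * interval + 1`, which is how the test stub in this commit sizes `obs_buffer_size`.

```python
def history_pad_pattern(n_hist: int, interval: int, buf_len: int) -> list[bool]:
    """Return True for each history slot that has no real observation yet."""
    # Assumed capacity, matching the test stub in this commit.
    buf_maxlen = (n_hist - 1) * interval + 1
    missing = buf_maxlen - buf_len  # slots still empty at the front of the buffer
    # Slot i reads buffer index i * interval - missing; a negative index means
    # the slot predates the episode and was zero-filled.
    return [i * interval - missing < 0 for i in range(n_hist)]


if __name__ == "__main__":
    print(history_pad_pattern(4, 1, 1))  # first step of an episode
    print(history_pad_pattern(4, 2, 4))  # partially filled buffer
    print(history_pad_pattern(4, 2, 7))  # buffer full
```

The three calls mirror the first-step, partial-fill, and buffer-full cases that the new tests pin.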
1 parent 9e702dd commit 57716e4

3 files changed

Lines changed: 183 additions & 10 deletions

src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py

Lines changed: 46 additions & 7 deletions
@@ -430,17 +430,22 @@ def _build_history_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:

         Appends the single-frame observation from ``batch`` to internal deque
         buffers, then assembles a batch with ``n_obs_history`` evenly-spaced
-        frames (interval = ``history_interval``). Early in an episode, missing
-        history slots are zero-padded.
+        frames (interval = ``history_interval``). Early in an episode the
+        buffer is partially filled, so some slots are zero-padded; the
+        returned ``"obs_history_is_pad"`` (B, T) bool tensor flags those
+        slots ``True`` so the model can mask them out of attention. Once the
+        buffer is full (typically a handful of steps in), the mask is all
+        ``False`` and the encoder uses the real history.

         Expected batch keys:
         - ``"state"``: (B, D) current proprioceptive state.
         - image keys matching ``config.image_features``: (B, C, H, W) camera frames.
         - ``"prompt"``: list[str] language instructions (passed through unchanged).
         - Any other metadata keys are forwarded unchanged.

-        Returns a new dict with ``"state"`` expanded to (B, T, D) and image keys
-        expanded to (B, T, C, H, W), where T = ``n_obs_history``.
+        Returns a new dict with ``"state"`` expanded to (B, T, D), image keys
+        expanded to (B, T, C, H, W), and a new ``"obs_history_is_pad"`` (B, T)
+        bool tensor (``True`` = padded). T = ``n_obs_history``.
         """
         assert self.config.n_obs_history is not None
         n_hist: int = self.config.n_obs_history
@@ -465,6 +470,8 @@ def _build_history_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         # sample n_hist frames at the configured interval
         buf_len = len(self._state_buffer)
         missing = buf_maxlen - buf_len  # how many slots are still empty
+        bsize = batch["state"].shape[0]
+        device = batch["state"].device

         # Pass through all non-image, non-state keys (e.g. "prompt" and other metadata).
         temporal_batch = {key: v for key, v in batch.items() if key not in img_keys and key != "state"}
@@ -490,6 +497,21 @@ def _build_history_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
                     cam_frames.append(self._obs_buffers[key][idx])
             temporal_batch[key] = torch.stack(cam_frames, dim=1)  # (B, T, C, H, W)

+        # Same `idx < 0` decision as the loops above: a slot is padded iff the
+        # buffer didn't have an entry to fill it. The pattern is identical
+        # for state and every camera (they share the same buffer length), so
+        # we emit one (B, T) mask. Broadcast across batch — every sample sees
+        # the same padding pattern at any given step. Without this, the
+        # encoder's None-fallback masks ALL history at inference (including
+        # genuine mid-episode frames once the buffer is full); with it, only
+        # the actually-padded start-of-episode slots get masked.
+        pad_pattern = torch.tensor(
+            [i * interval - missing < 0 for i in range(n_hist)],
+            dtype=torch.bool,
+            device=device,
+        )
+        temporal_batch["obs_history_is_pad"] = pad_pattern.unsqueeze(0).expand(bsize, n_hist)
+
         return temporal_batch

     @torch.no_grad()
@@ -564,7 +586,12 @@ def sample_actions(

         batch = self.normalize_inputs(batch)

-        videos, vid_masks = self.prepare_videos(batch)
+        # `_build_history_batch` (called from `select_action` upstream) emits
+        # this; it's None when the caller skipped that step (e.g. n_obs_history
+        # is None/1, or sample_actions is invoked directly without the buffer).
+        obs_history_is_pad = batch.get("obs_history_is_pad")
+
+        videos, vid_masks = self.prepare_videos(batch, obs_history_is_pad=obs_history_is_pad)
         lang_tokens, lang_masks = self.prepare_language(batch)
         response_tokens, response_masks = self.prepare_response(batch)
         state = self.prepare_state(batch)
@@ -616,6 +643,7 @@ def sample_actions(
             metadata_masks=metadata_masks,
             response_tokens=response_tokens,
             response_masks=response_masks,
+            obs_history_is_pad=obs_history_is_pad,
         )

         action_feature = self.config.action_feature
@@ -1299,7 +1327,7 @@ def embed_prefix(
         state_emb = self.state_proj(state.to(dtype=_preferred_dtype()))
         num_state_tokens = state_emb.shape[1]  # T
         if obs_history_is_pad is not None:
-            state_mask = ~obs_history_is_pad  # (B, T)
+            state_mask = ~obs_history_is_pad  # (B, T) — `~` allocates a fresh tensor
         else:
             # Absent → assume all history is padded; only current step is real.
             state_mask = torch.zeros(bsize, num_state_tokens, dtype=torch.bool, device=state.device)
@@ -1308,7 +1336,9 @@ def embed_prefix(
         # all-True. Without this override the policy would condition on no
         # state at all, since attention to the current state token would be
         # masked out — defeating the purpose of preserving the current frame.
-        state_mask = state_mask.clone()  # avoid in-place mutation of obs_history_is_pad
+        # Both branches above produce fresh tensors (`~` allocates;
+        # `torch.zeros` allocates), so the `[:, -1] = True` write below does
+        # not reach the caller's `obs_history_is_pad`.
         state_mask[:, -1] = True

         embs.append(state_emb)
@@ -1754,6 +1784,7 @@ def sample_actions(
         metadata_masks: Tensor | None = None,
         response_tokens: Tensor | None = None,
         response_masks: Tensor | None = None,
+        obs_history_is_pad: Tensor | None = None,
     ) -> Tensor:
         """Inference: iteratively denoise to produce a continuous action chunk.

@@ -1778,6 +1809,13 @@ def sample_actions(
             metadata_masks: Optional mask for metadata tokens.
             response_tokens: Optional subtask response token IDs.
             response_masks: Optional mask for response tokens.
+            obs_history_is_pad: Optional ``(B, T)`` bool mask flagging padded
+                history slots (``True`` = padded). Emitted by
+                ``PI07LowLevelPolicy._build_history_batch`` so the encoder can
+                use real mid-episode history while still masking out the
+                start-of-episode zero-fill. ``None`` falls back to "all
+                history padded except current" via ``embed_prefix`` and the
+                encoder's None-fallback.

         Returns:
             Denoised action chunk ``(B, chunk_size, max_action_dim)``.
@@ -1801,6 +1839,7 @@ def sample_actions(
             metadata_masks,
             subgoal_images=subgoal_images,
             subgoal_img_masks=subgoal_img_masks,
+            obs_history_is_pad=obs_history_is_pad,
         )
         prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
         prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
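
The hunks above thread the mask into `embed_prefix`, whose two branches can be mimicked with nested lists standing in for `(B, T)` bool tensors. `state_mask_from_pad` is a hypothetical name used only for this sketch; the real code operates on `torch` tensors and may differ in detail.

```python
from typing import Optional


def state_mask_from_pad(
    obs_history_is_pad: Optional[list[list[bool]]], bsize: int, n_tokens: int
) -> list[list[bool]]:
    """Mimic the two branches with nested lists: True means 'attend to this slot'."""
    if obs_history_is_pad is not None:
        # Inverting builds new rows, so the caller's mask is left untouched
        # (the tensor code relies on `~` allocating for the same reason).
        mask = [[not p for p in row] for row in obs_history_is_pad]
    else:
        # No mask supplied: treat every history slot as padded.
        mask = [[False] * n_tokens for _ in range(bsize)]
    for row in mask:
        row[-1] = True  # the current step is always treated as real
    return mask


if __name__ == "__main__":
    pad = [[True, True, False, False]]
    print(state_mask_from_pad(pad, 1, 4))   # live mask from the history buffer
    print(state_mask_from_pad(None, 2, 3))  # defensive fallback
    print(pad)                              # caller's mask is unchanged
```

With a live mask, real mid-episode slots stay attendable; with `None`, only the current step does, which is the defensive default the commit message describes.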

src/opentau/policies/pi07/low_level/video_encoder.py

Lines changed: 6 additions & 3 deletions
@@ -291,8 +291,10 @@ def forward(
         """hidden_states: (B*T, N, D) -> tuple starting with (B*T, N, D).

         Signature extends ``SiglipEncoderLayer.forward`` with an extra
-        ``temporal_attn_mask`` kwarg (ignored by vanilla layers via
-        ``SiglipEncoder``'s positional dispatch).
+        ``temporal_attn_mask`` kwarg. Vanilla ``SiglipEncoderLayer`` instances
+        never receive it: ``SpaceTimeSiglipVideoEncoder.forward`` dispatches
+        it only to ``SpaceTimeEncoderLayerWrapper`` instances via an
+        ``isinstance`` check, bypassing ``SiglipEncoder.forward`` entirely.

         Args:
             temporal_attn_mask: Optional ``(B*N, 1, T, T)`` additive float mask
@@ -506,8 +508,9 @@ def _build_temporal_attn_mask(
         # callers (e.g. the dataset's history_state_drop_prob augmentation)
         # set obs_history_is_pad to all-True; without this override, the
         # current frame would have no key to attend to and produce NaNs.
+        # `~obs_history_is_pad` allocates a fresh tensor, so the in-place
+        # write below does not reach the caller's `obs_history_is_pad`.
         key_valid = ~obs_history_is_pad  # (B, T)
-        key_valid = key_valid.clone()  # avoid in-place mutation of caller's tensor
         key_valid[:, -1] = True

         # Combined: (B, T_query, T_key)
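
The comment in this hunk says an all-padded mask would leave the current frame with no key to attend to and produce NaNs. A simplified, single-sample sketch of that failure mode follows. `additive_key_mask` and `softmax` are hypothetical helpers written for illustration; the real `_build_temporal_attn_mask` works on tensors, and its "Combined" step is not shown in the diff, so only key validity is modelled here.

```python
import math


def additive_key_mask(is_pad: list[bool]) -> list[list[float]]:
    """Build a (T, T) additive mask for one sample from per-slot pad flags."""
    key_valid = [not p for p in is_pad]  # fresh list, like `~` on a tensor
    key_valid[-1] = True  # the current frame always stays attendable
    row = [0.0 if ok else -math.inf for ok in key_valid]
    # Simplification: every query shares the same key row.
    return [list(row) for _ in is_pad]


def softmax(scores: list[float]) -> list[float]:
    """Plain softmax; yields NaN when every score is -inf."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    mask = additive_key_mask([True, True, True])  # all history dropped
    print(softmax(mask[0]))                       # weight lands on the current frame
    print(softmax([-math.inf] * 3))               # without the override: NaNs
```

Forcing the last key valid guarantees at least one finite score per row, which is why the override is kept even though the `.clone()` was removed.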

tests/policies/test_pi07_cpu.py

Lines changed: 131 additions & 0 deletions
@@ -1529,3 +1529,134 @@ def _build(with_actions: bool):
         torch.testing.assert_close(embs_train[:, :infer_len], embs_infer)
         torch.testing.assert_close(pad_train[:, :infer_len], pad_infer)
         torch.testing.assert_close(att_train[:, :infer_len], att_infer)
+
+
+# `_build_history_batch` emits ``obs_history_is_pad`` so the encoder can use
+# real mid-episode history while still masking start-of-episode zero-fill.
+# Without this emit, the encoder's None-fallback masks ALL history at
+# inference (mid-episode regression flagged in the PR #253 review).
+
+
+class TestBuildHistoryBatchEmitsObsHistoryIsPad:
+    @staticmethod
+    def _make_policy_stub(*, n_obs_history: int, history_interval: int, image_keys: list[str]):
+        """Construct a partial PI07LowLevelPolicy that exposes only the
+        attrs ``_build_history_batch`` reads: ``config.{n_obs_history,
+        history_interval, obs_buffer_size, image_features}`` plus the deque
+        slots. Skips Gemma 3 init so the test stays CPU-cheap.
+        """
+        import types
+
+        from opentau.policies.pi07.low_level.modeling_pi07_low_level import (
+            PI07LowLevelPolicy,
+        )
+
+        policy = PI07LowLevelPolicy.__new__(PI07LowLevelPolicy)
+        buf_size = (n_obs_history - 1) * history_interval + 1
+        policy.config = types.SimpleNamespace(
+            n_obs_history=n_obs_history,
+            history_interval=history_interval,
+            obs_buffer_size=buf_size,
+            image_features=dict.fromkeys(image_keys),
+        )
+        policy._state_buffer = None
+        policy._obs_buffers = None
+        return policy
+
+    def _make_batch(self, image_keys: list[str], state_dim: int = 4) -> dict:
+        return {
+            "state": torch.zeros(1, state_dim),
+            **{k: torch.zeros(1, 3, 8, 8) for k in image_keys},
+        }
+
+    def test_first_step_marks_all_but_current_padded(self):
+        """At episode start, only the very first observation is in the
+        buffer; every other slot in the requested history was zero-filled.
+        Mask should be ``[True, ..., True, False]`` — the canonical case
+        the PR's Bug A fix protects against contamination from.
+        """
+        policy = self._make_policy_stub(n_obs_history=4, history_interval=1, image_keys=["camera0"])
+        out = policy._build_history_batch(self._make_batch(["camera0"]))
+
+        assert "obs_history_is_pad" in out
+        assert out["obs_history_is_pad"].shape == (1, 4)
+        assert out["obs_history_is_pad"].dtype == torch.bool
+        assert out["obs_history_is_pad"].tolist() == [[True, True, True, False]]
+
+    def test_buffer_full_emits_all_false(self):
+        """Once the buffer is full (after ``obs_buffer_size`` calls), every
+        slot maps to a real observation — mask is all-False. This is the
+        mid-episode case the previous PR regressed: with the None-fallback,
+        the encoder masked these real frames out as if they were padded.
+        """
+        policy = self._make_policy_stub(n_obs_history=4, history_interval=2, image_keys=["camera0"])
+        # obs_buffer_size = (4-1)*2 + 1 = 7. Need 7 calls to fill.
+        batch = self._make_batch(["camera0"])
+        for _ in range(7):
+            out = policy._build_history_batch(batch)
+        assert out["obs_history_is_pad"].tolist() == [[False, False, False, False]]
+
+    def test_partial_fill_marks_only_unfilled_slots(self):
+        """After ``k < obs_buffer_size`` calls, the leading ``T -
+        ceil(k / interval)`` slots are still virtual past-steps. With
+        ``n_obs_history=4, history_interval=2`` (buffer_size=7), after 4
+        calls the deque has 4 entries → ``missing = 3`` → slots with
+        ``i*interval - 3 < 0`` are padded: i=0 → -3 (T), i=1 → -1 (T),
+        i=2 → 1 (F), i=3 → 3 (F). So mask = [T, T, F, F].
+        """
+        policy = self._make_policy_stub(n_obs_history=4, history_interval=2, image_keys=["camera0"])
+        batch = self._make_batch(["camera0"])
+        for _ in range(4):
+            out = policy._build_history_batch(batch)
+        assert out["obs_history_is_pad"].tolist() == [[True, True, False, False]]
+
+    def test_mask_is_broadcast_over_batch(self):
+        """The buffer is shared across batch elements (every sample sees
+        the same buffer length at any given step), so the (B, T) mask is
+        the same across the batch dim. Verify by emitting from a B=3 batch.
+        """
+        policy = self._make_policy_stub(n_obs_history=4, history_interval=1, image_keys=["camera0"])
+        batch = {
+            "state": torch.zeros(3, 4),
+            "camera0": torch.zeros(3, 3, 8, 8),
+        }
+        out = policy._build_history_batch(batch)
+
+        assert out["obs_history_is_pad"].shape == (3, 4)
+        # Every batch element sees the same mask.
+        assert torch.all(out["obs_history_is_pad"] == out["obs_history_is_pad"][0:1])
+
+    def test_n_obs_history_one_emits_all_false(self):
+        """With ``n_obs_history=1`` the buffer always contains the current
+        frame — no historical slots exist, so the (B, 1) mask is False
+        from step 1. (In practice ``select_action`` skips
+        ``_build_history_batch`` entirely when ``n_obs_history <= 1``, so
+        this is just defending the function's own contract.)
+        """
+        policy = self._make_policy_stub(n_obs_history=1, history_interval=1, image_keys=["camera0"])
+        out = policy._build_history_batch(self._make_batch(["camera0"]))
+        assert out["obs_history_is_pad"].tolist() == [[False]]
+
+    def test_state_and_camera_padding_match_emitted_mask(self):
+        """The emitted mask must agree slot-for-slot with the actual
+        zero-padding pattern of state and camera tensors. State / camera
+        are zeroed where ``idx < 0``; the mask flags the same slots ``True``.
+        """
+        policy = self._make_policy_stub(n_obs_history=3, history_interval=1, image_keys=["camera0"])
+        # Inject a non-zero observation so we can detect zero-fill.
+        batch = {
+            "state": torch.full((1, 4), 7.0),
+            "camera0": torch.full((1, 3, 8, 8), 5.0),
+        }
+        out = policy._build_history_batch(batch)
+        # After one call: missing = 2; mask = [True, True, False].
+        is_pad = out["obs_history_is_pad"][0]  # (T,)
+        state = out["state"][0]  # (T, D)
+        cam = out["camera0"][0]  # (T, C, H, W)
+        for t, padded in enumerate(is_pad.tolist()):
+            if padded:
+                assert torch.all(state[t] == 0.0), f"state[{t}] not zero-filled"
+                assert torch.all(cam[t] == 0.0), f"camera[{t}] not zero-filled"
+            else:
+                assert torch.all(state[t] == 7.0), f"state[{t}] zero-filled but mask says real"
+                assert torch.all(cam[t] == 5.0), f"camera[{t}] zero-filled but mask says real"
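
The slot-by-slot agreement that the last test pins can also be explored without the policy class. The sketch below is a stand-in under stated assumptions: `simulate_history` is a hypothetical function, observations are scalars rather than tensors, and the buffer is a `collections.deque` sized as in the test stub. It is not the repository's implementation.

```python
from collections import deque


def simulate_history(n_hist: int, interval: int, steps: int):
    """Feed `steps` scalar observations; return (frames, is_pad) for the last step."""
    buf_maxlen = (n_hist - 1) * interval + 1  # assumed, as in the test stub
    buffer = deque(maxlen=buf_maxlen)
    frames, is_pad = [], []
    for step in range(1, steps + 1):
        buffer.append(float(step))  # observation value = step number, never zero
        missing = buf_maxlen - len(buffer)
        frames, is_pad = [], []
        for i in range(n_hist):
            idx = i * interval - missing
            frames.append(0.0 if idx < 0 else buffer[idx])  # zero-fill empty slots
            is_pad.append(idx < 0)  # the same decision drives the mask
    return frames, is_pad


if __name__ == "__main__":
    for steps in (1, 4, 9):
        print(steps, simulate_history(4, 2, steps))
```

Because the zero-fill and the mask are derived from the same `idx < 0` test, a slot is zero exactly when it is flagged as padded, and once the deque is full it simply slides forward with an all-`False` mask.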
