Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 64 additions & 10 deletions src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,17 +430,22 @@ def _build_history_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:

Appends the single-frame observation from ``batch`` to internal deque
buffers, then assembles a batch with ``n_obs_history`` evenly-spaced
frames (interval = ``history_interval``). Early in an episode, missing
history slots are zero-padded.
frames (interval = ``history_interval``). Early in an episode the
buffer is partially filled, so some slots are zero-padded; the
returned ``"obs_history_is_pad"`` (B, T) bool tensor flags those
slots ``True`` so the model can mask them out of attention. Once the
buffer is full (typically a handful of steps in), the mask is all
``False`` and the encoder uses the real history.

Expected batch keys:
- ``"state"``: (B, D) current proprioceptive state.
- image keys matching ``config.image_features``: (B, C, H, W) camera frames.
- ``"prompt"``: list[str] language instructions (passed through unchanged).
- Any other metadata keys are forwarded unchanged.

Returns a new dict with ``"state"`` expanded to (B, T, D) and image keys
expanded to (B, T, C, H, W), where T = ``n_obs_history``.
Returns a new dict with ``"state"`` expanded to (B, T, D), image keys
expanded to (B, T, C, H, W), and a new ``"obs_history_is_pad"`` (B, T)
bool tensor (``True`` = padded). T = ``n_obs_history``.
"""
assert self.config.n_obs_history is not None
n_hist: int = self.config.n_obs_history
Expand All @@ -465,6 +470,8 @@ def _build_history_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# sample n_hist frames at the configured interval
buf_len = len(self._state_buffer)
missing = buf_maxlen - buf_len # how many slots are still empty
bsize = batch["state"].shape[0]
device = batch["state"].device

# Pass through all non-image, non-state keys (e.g. "prompt" and other metadata).
temporal_batch = {key: v for key, v in batch.items() if key not in img_keys and key != "state"}
Expand All @@ -490,6 +497,21 @@ def _build_history_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
cam_frames.append(self._obs_buffers[key][idx])
temporal_batch[key] = torch.stack(cam_frames, dim=1) # (B, T, C, H, W)

# Same `idx < 0` decision as the loops above: a slot is padded iff the
# buffer didn't have an entry to fill it. The pattern is identical
# for state and every camera (they share the same buffer length), so
# we emit one (B, T) mask. Broadcast across batch — every sample sees
# the same padding pattern at any given step. Without this, the
# encoder's None-fallback masks ALL history at inference (including
# genuine mid-episode frames once the buffer is full); with it, only
# the actually-padded start-of-episode slots get masked.
pad_pattern = torch.tensor(
[i * interval - missing < 0 for i in range(n_hist)],
dtype=torch.bool,
device=device,
)
temporal_batch["obs_history_is_pad"] = pad_pattern.unsqueeze(0).expand(bsize, n_hist)

return temporal_batch

@torch.no_grad()
Expand Down Expand Up @@ -564,7 +586,12 @@ def sample_actions(

batch = self.normalize_inputs(batch)

videos, vid_masks = self.prepare_videos(batch)
# `_build_history_batch` (called from `select_action` upstream) emits
# this; it's None when the caller skipped that step (e.g. n_obs_history
# is None/1, or sample_actions is invoked directly without the buffer).
obs_history_is_pad = batch.get("obs_history_is_pad")

videos, vid_masks = self.prepare_videos(batch, obs_history_is_pad=obs_history_is_pad)
lang_tokens, lang_masks = self.prepare_language(batch)
response_tokens, response_masks = self.prepare_response(batch)
state = self.prepare_state(batch)
Expand Down Expand Up @@ -616,6 +643,7 @@ def sample_actions(
metadata_masks=metadata_masks,
response_tokens=response_tokens,
response_masks=response_masks,
obs_history_is_pad=obs_history_is_pad,
)

action_feature = self.config.action_feature
Expand Down Expand Up @@ -1121,16 +1149,23 @@ def sample_time(self, bsize: int, device: torch.device | str) -> Tensor:
time = time_beta * 0.999 + 0.001
return time

def embed_video(self, video: Tensor) -> Tensor:
def embed_video(self, video: Tensor, obs_history_is_pad: Tensor | None = None) -> Tensor:
"""Encode a video through SpaceTimeSiglip + Perceiver reducer + projection.

Args:
video: (B, T, C, H, W)
obs_history_is_pad: Optional ``(B, T)`` bool mask — ``True`` for
padded history frames. Threaded into the SpaceTime SigLIP
encoder so temporal attention blocks padded frames (pixel-
zeroing alone is insufficient — the patch embedding bias and
temporal PE for ``t < T-1`` are non-zero, so zero pixels
still produce non-zero hidden states the current frame would
otherwise attend to).

Returns:
(B, num_video_tokens, vlm_hidden_size)
"""
return self.video_encoder(video)
return self.video_encoder(video, obs_history_is_pad=obs_history_is_pad)

def embed_prefix(
self,
Expand Down Expand Up @@ -1241,7 +1276,7 @@ def embed_prefix(
has_any_optional = bool(has_response or has_subgoal or has_metadata)

for vid, vid_mask in zip(videos, vid_masks, strict=True):
vid_emb = self.embed_video(vid) # (B, num_video_tokens, vlm_hidden)
vid_emb = self.embed_video(vid, obs_history_is_pad=obs_history_is_pad)
vid_emb = vid_emb.to(dtype=_preferred_dtype())

num_vid_embs = vid_emb.shape[1]
Expand Down Expand Up @@ -1292,9 +1327,19 @@ def embed_prefix(
state_emb = self.state_proj(state.to(dtype=_preferred_dtype()))
num_state_tokens = state_emb.shape[1] # T
if obs_history_is_pad is not None:
state_mask = ~obs_history_is_pad # True = real, False = padded
state_mask = ~obs_history_is_pad # (B, T) — `~` allocates a fresh tensor
else:
state_mask = torch.ones(bsize, num_state_tokens, dtype=torch.bool, device=state.device)
# Absent → assume all history is padded; only current step is real.
state_mask = torch.zeros(bsize, num_state_tokens, dtype=torch.bool, device=state.device)
# Current step (t = T-1) is ALWAYS real even when the dataset's
# history_state_drop_prob augmentation flips obs_history_is_pad to
# all-True. Without this override the policy would condition on no
# state at all, since attention to the current state token would be
# masked out — defeating the purpose of preserving the current frame.
# Both branches above produce fresh tensors (`~` allocates;
# `torch.zeros` allocates), so the `[:, -1] = True` write below does
# not reach the caller's `obs_history_is_pad`.
state_mask[:, -1] = True

embs.append(state_emb)
pad_masks.append(state_mask)
Expand Down Expand Up @@ -1739,6 +1784,7 @@ def sample_actions(
metadata_masks: Tensor | None = None,
response_tokens: Tensor | None = None,
response_masks: Tensor | None = None,
obs_history_is_pad: Tensor | None = None,
) -> Tensor:
"""Inference: iteratively denoise to produce a continuous action chunk.

Expand All @@ -1763,6 +1809,13 @@ def sample_actions(
metadata_masks: Optional mask for metadata tokens.
response_tokens: Optional subtask response token IDs.
response_masks: Optional mask for response tokens.
obs_history_is_pad: Optional ``(B, T)`` bool mask flagging padded
history slots (``True`` = padded). Emitted by
``PI07LowLevelPolicy._build_history_batch`` so the encoder can
use real mid-episode history while still masking out the
start-of-episode zero-fill. ``None`` falls back to "all
history padded except current" via ``embed_prefix`` and the
encoder's None-fallback.

Returns:
Denoised action chunk ``(B, chunk_size, max_action_dim)``.
Expand All @@ -1786,6 +1839,7 @@ def sample_actions(
metadata_masks,
subgoal_images=subgoal_images,
subgoal_img_masks=subgoal_img_masks,
obs_history_is_pad=obs_history_is_pad,
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
Expand Down
156 changes: 141 additions & 15 deletions src/opentau/policies/pi07/low_level/video_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,16 @@ def __init__(self, attn: SiglipAttention):
def attn(self) -> SiglipAttention:
return self._attn_ref[0]

def forward(self, hidden_states: Tensor) -> Tensor:
"""hidden_states: (B*N, T, D) -> (B*N, T, D)."""
def forward(self, hidden_states: Tensor, temporal_attn_mask: Tensor | None = None) -> Tensor:
"""hidden_states: (B*N, T, D) -> (B*N, T, D).

Args:
hidden_states: ``(B*N, T, D)`` hidden states reshaped for temporal
attention (one sequence per spatial patch position).
temporal_attn_mask: Optional ``(B*N, 1, T, T)`` additive float mask.
``0.0`` = attend, ``-inf`` = block. When ``None``, falls back to
the standard causal lower-triangular mask via ``is_causal=True``.
"""
attn = self.attn
bn, t, d = hidden_states.shape
num_heads = attn.num_heads
Expand All @@ -139,11 +147,19 @@ def forward(self, hidden_states: Tensor) -> Tensor:
k = attn.k_proj(hidden_states).view(bn, t, num_heads, head_dim).transpose(1, 2)
v = attn.v_proj(hidden_states).view(bn, t, num_heads, head_dim).transpose(1, 2)

# is_causal=True -> lower-triangular mask, each position attends to
# itself and earlier positions (our convention: t=T-1 is current).
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=None, is_causal=True, dropout_p=0.0, scale=attn.scale
)
if temporal_attn_mask is not None:
# Caller-supplied mask already encodes both causal AND padded-history
# blocking; do NOT also set is_causal=True (SDPA disallows combining
# the two).
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=temporal_attn_mask, is_causal=False, dropout_p=0.0, scale=attn.scale
)
else:
# is_causal=True -> lower-triangular mask, each position attends to
# itself and earlier positions (our convention: t=T-1 is current).
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=None, is_causal=True, dropout_p=0.0, scale=attn.scale
)
out = out.transpose(1, 2).reshape(bn, t, d)
return attn.out_proj(out)

Expand Down Expand Up @@ -270,11 +286,20 @@ def forward(
hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
output_attentions: bool = False,
temporal_attn_mask: Tensor | None = None,
) -> tuple[Tensor, ...]:
"""hidden_states: (B*T, N, D) -> tuple starting with (B*T, N, D).

Signature matches ``SiglipEncoderLayer.forward`` so ``SiglipEncoder``
can dispatch unchanged.
Signature extends ``SiglipEncoderLayer.forward`` with an extra
``temporal_attn_mask`` kwarg. Vanilla ``SiglipEncoderLayer`` instances
never receive it: ``SpaceTimeSiglipVideoEncoder.forward`` dispatches
it only to ``SpaceTimeEncoderLayerWrapper`` instances via an
``isinstance`` check, bypassing ``SiglipEncoder.forward`` entirely.

Args:
temporal_attn_mask: Optional ``(B*N, 1, T, T)`` additive float mask
for temporal attention. ``0.0`` = attend, ``-inf`` = block.
When ``None``, the standard causal mask is used.
"""
t = self.num_frames
bt, n, d = hidden_states.shape
Expand Down Expand Up @@ -319,7 +344,7 @@ def forward(

t_in = rearrange(x_pe, "b t n d -> (b n) t d")
t_norm = self.layer_norm1(t_in)
t_out = self._temporal_attn(t_norm)
t_out = self._temporal_attn(t_norm, temporal_attn_mask=temporal_attn_mask)
# Residual on the pre-PE hidden (not on x_pe): PE is a transient
# positional signal, not a feature perturbation to carry forward.
t_res = rearrange(x, "b t n d -> (b n) t d") + t_out
Expand Down Expand Up @@ -442,13 +467,80 @@ def vision_tower(self) -> SiglipVisionModel:
def multi_modal_projector(self) -> nn.Module:
return self._multi_modal_projector_ref[0]

def forward(self, video: Tensor) -> Tensor:
@staticmethod
def _build_temporal_attn_mask(
obs_history_is_pad: Tensor,
num_patches: int,
dtype: torch.dtype,
) -> Tensor:
"""Build a causal temporal attention mask that blocks padded frames.

Pure pixel-level zeroing of padded frames is not enough — the SigLIP
patch embedding has a learned bias and the temporal positional
embedding ``e(t)`` is non-zero for ``t < T-1``, so padded "zero" frames
still produce non-zero hidden states that the current frame would
attend to. This mask blocks attention to padded keys at the SDPA call.

Args:
obs_history_is_pad: ``(B, T)`` bool — ``True`` for padded steps.
num_patches: ``N``, number of spatial patches per frame (the
video encoder runs one temporal sequence per patch position,
so each patch row of the (B*N) batch reuses the same mask).
dtype: Float dtype matching the hidden states (additive mask gets
added to attention scores; mismatched dtypes force upcasts).

Returns:
``(B*N, 1, T, T)`` additive float mask where ``0.0`` = attend and
``-inf`` = block. Row ``i`` can attend to column ``j`` iff
``j <= i`` (causal) **and** ``obs_history_is_pad[:, j]`` is
``False``. The current frame (``j = T-1``) is always attendable
even if the caller set ``obs_history_is_pad[:, -1] = True`` —
losing the current frame would defeat the encoder.
"""
b, t = obs_history_is_pad.shape
device = obs_history_is_pad.device

# Causal: position i attends to j <= i.
causal = torch.tril(torch.ones(t, t, dtype=torch.bool, device=device)) # (T, T)

# Key-side visibility: True where frame j is real (not padded).
# Force the last frame always attendable as a defensive fallback —
# callers (e.g. the dataset's history_state_drop_prob augmentation)
# set obs_history_is_pad to all-True; without this override, the
# current frame would have no key to attend to and produce NaNs.
# `~obs_history_is_pad` allocates a fresh tensor, so the in-place
# write below does not reach the caller's `obs_history_is_pad`.
key_valid = ~obs_history_is_pad # (B, T)
key_valid[:, -1] = True

# Combined: (B, T_query, T_key)
mask_bool = causal.unsqueeze(0) & key_valid.unsqueeze(1) # (B, T, T)

# Bool → additive float: True → 0.0, False → -inf.
float_mask = torch.zeros(b, t, t, dtype=dtype, device=device)
float_mask.masked_fill_(~mask_bool, float("-inf"))

# Expand for the (B*N) flattened-patch batch dimension:
# (B, T, T) → (B, 1, T, T) → repeat_interleave(N) → (B*N, 1, T, T)
float_mask = float_mask.unsqueeze(1) # (B, 1, T, T)
float_mask = float_mask.repeat_interleave(num_patches, dim=0) # (B*N, 1, T, T)

return float_mask

def forward(self, video: Tensor, obs_history_is_pad: Tensor | None = None) -> Tensor:
"""Encode a video clip and return the current-frame tokens.

Args:
video: ``(B, T, C, H, W)`` pixel values in ``[0, 1]``, with
``T == num_frames``, ``C == 3``, and spatial size matching
the SigLIP config (224x224 by default).
obs_history_is_pad: Optional ``(B, T)`` bool mask where ``True``
marks padded history frames. Padded frames are blocked in
the temporal attention so the current frame cannot read
contaminated hidden states from them. When ``None`` and
``T > 1``, falls back to "only the current frame is real"
(the conservative default for callers that do not supply
the mask, e.g. when the history-batch step was skipped).

Returns:
``(B, num_video_tokens, vlm_hidden_size)`` current-frame tokens,
Expand All @@ -473,20 +565,54 @@ def forward(self, video: Tensor) -> Tensor:
# Patch embedding + learned spatial position embedding.
hidden = self.vision_tower.vision_model.embeddings(flat)

# Build temporal attention mask. Skipped at T=1 because the wrapper's
# T=1 short-circuit bypasses temporal attention entirely.
temporal_attn_mask: Tensor | None = None
if t > 1:
if obs_history_is_pad is not None:
temporal_attn_mask = self._build_temporal_attn_mask(
obs_history_is_pad, self.num_video_tokens, hidden.dtype
)
else:
# Fallback for callers that do not supply obs_history_is_pad
# (e.g. sample_actions invoked directly, without
# _build_history_batch having run). Treat all history as
# padded so the current frame's representation is
# uncontaminated by any zero-pixel placeholders.
fallback_pad = torch.ones(b, t, dtype=torch.bool, device=hidden.device)
fallback_pad[:, -1] = False
temporal_attn_mask = self._build_temporal_attn_mask(
fallback_pad, self.num_video_tokens, hidden.dtype
)

# Encoder stack: standard spatial layers + wrapped every-Nth layer
# with temporal attention. SpaceTimeEncoderLayerWrapper matches the
# SiglipEncoderLayer signature, so we drive the loop manually here
# (instead of calling SiglipEncoder.forward) so we can wrap each
# layer in torch.utils.checkpoint.checkpoint when the flag is set —
# the same explicit pattern PaliGemmaWithExpertModel uses.
# ``temporal_attn_mask`` is passed only to spacetime-wrapped layers
# (vanilla layers don't accept it). Under gradient checkpointing it
# goes in as a positional arg, which works with either checkpoint
# variant: the reentrant variant rejects keyword arguments for the
# wrapped function, while use_reentrant=False would also accept them.
use_ckpt = self.gradient_checkpointing and self.training
for layer in self.vision_tower.vision_model.encoder.layers:
is_spacetime = isinstance(layer, SpaceTimeEncoderLayerWrapper)
if use_ckpt:
layer_outputs = torch.utils.checkpoint.checkpoint(
layer, hidden, None, False, use_reentrant=False
)
if is_spacetime:
layer_outputs = torch.utils.checkpoint.checkpoint(
layer, hidden, None, False, temporal_attn_mask, use_reentrant=False
)
else:
layer_outputs = torch.utils.checkpoint.checkpoint(
layer, hidden, None, False, use_reentrant=False
)
else:
layer_outputs = layer(hidden, None, False)
if is_spacetime:
layer_outputs = layer(hidden, None, False, temporal_attn_mask=temporal_attn_mask)
else:
layer_outputs = layer(hidden, None, False)
hidden = layer_outputs[0]

hidden = self.vision_tower.vision_model.post_layernorm(hidden)
Expand Down
Loading
Loading