From c0cc4e9e8d81b7f0cd082cf98e012a91528ac4a9 Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Tue, 28 Apr 2026 18:30:34 -0700 Subject: [PATCH 1/8] Adding original pi07 with gemma3 backbone and space-time siglip video encoder --- src/opentau/policies/factory.py | 34 +- src/opentau/policies/pi07/__init__.py | 13 + .../policies/pi07/gemma3_with_expert.py | 712 +++++++ .../pi07/high_level_planner/__init__.py | 20 + .../configuration_pi07_high_level.py | 234 +++ .../modeling_pi07_high_level.py | 1434 ++++++++++++++ .../pi07/low_level_planner/__init__.py | 24 + .../configuration_pi07_low_level.py | 263 +++ .../modeling_pi07_low_level.py | 1740 +++++++++++++++++ .../pi07/low_level_planner/video_encoder.py | 460 +++++ .../policies/test_pi07_high_level_planner.py | 333 ++++ tests/policies/test_pi07_low_level_planner.py | 673 +++++++ 12 files changed, 5934 insertions(+), 6 deletions(-) create mode 100644 src/opentau/policies/pi07/__init__.py create mode 100644 src/opentau/policies/pi07/gemma3_with_expert.py create mode 100644 src/opentau/policies/pi07/high_level_planner/__init__.py create mode 100644 src/opentau/policies/pi07/high_level_planner/configuration_pi07_high_level.py create mode 100644 src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py create mode 100644 src/opentau/policies/pi07/low_level_planner/__init__.py create mode 100644 src/opentau/policies/pi07/low_level_planner/configuration_pi07_low_level.py create mode 100644 src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py create mode 100644 src/opentau/policies/pi07/low_level_planner/video_encoder.py create mode 100644 tests/policies/test_pi07_high_level_planner.py create mode 100644 tests/policies/test_pi07_low_level_planner.py diff --git a/src/opentau/policies/factory.py b/src/opentau/policies/factory.py index 5aab6cdb..91ed9db0 100644 --- a/src/opentau/policies/factory.py +++ b/src/opentau/policies/factory.py @@ -36,11 +36,17 @@ from opentau.policies.pi0.configuration_pi0 import PI0Config from opentau.policies.pi05.configuration_pi05 import PI05Config from opentau.policies.pi05_mem.configuration_pi05 import PI05MemConfig -from opentau.policies.pi07_paligemma.high_level_planner.configuration_pi07_high_level import ( +from opentau.policies.pi07.high_level_planner.configuration_pi07_high_level import ( PI07HighLevelPlannerConfig, ) +from opentau.policies.pi07.low_level_planner.configuration_pi07_low_level import ( + PI07LowLevelPlannerConfig, +) +from opentau.policies.pi07_paligemma.high_level_planner.configuration_pi07_high_level import ( + PI07HighLevelPlannerConfig as PI07PaligemmaHighLevelPlannerConfig, +) from opentau.policies.pi07_paligemma.low_level_planner.configuration_pi07_low_level import ( - PI07lowlevelPlannerConfig, + PI07lowlevelPlannerConfig as PI07PaligemmaLowLevelPlannerConfig, ) from opentau.policies.pretrained import PreTrainedPolicy from opentau.policies.value.configuration_value import ValueConfig @@ -82,12 +88,24 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: return PI05MemPolicy elif name == "pi07_paligemma_high_level_planner": from opentau.policies.pi07_paligemma.high_level_planner.modeling_pi07_high_level import ( - PI07HighLevelPlannerPolicy, + PI07HighLevelPlannerPolicy as PI07PaligemmaHighLevelPlannerPolicy, ) - return PI07HighLevelPlannerPolicy + return PI07PaligemmaHighLevelPlannerPolicy elif name == "pi07_paligemma_low_level_planner": from opentau.policies.pi07_paligemma.low_level_planner.modeling_pi07_low_level import ( + PI07LowLevelPlannerPolicy as PI07PaligemmaLowLevelPlannerPolicy, + ) + + return PI07PaligemmaLowLevelPlannerPolicy + elif name == "pi07_high_level": + from opentau.policies.pi07.high_level_planner.modeling_pi07_high_level import ( + PI07HighLevelPlannerPolicy, + ) + + return PI07HighLevelPlannerPolicy + elif name == "pi07_low_level": + from opentau.policies.pi07.low_level_planner.modeling_pi07_low_level import ( PI07LowLevelPlannerPolicy, ) @@ -128,9 +146,13 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: elif policy_type == "pi05_mem": return PI05MemConfig(**kwargs) elif policy_type == "pi07_paligemma_high_level_planner": - return PI07HighLevelPlannerConfig(**kwargs) + return PI07PaligemmaHighLevelPlannerConfig(**kwargs) elif policy_type == "pi07_paligemma_low_level_planner": - return PI07lowlevelPlannerConfig(**kwargs) + return PI07PaligemmaLowLevelPlannerConfig(**kwargs) + elif policy_type == "pi07_high_level": + return PI07HighLevelPlannerConfig(**kwargs) + elif policy_type == "pi07_low_level": + return PI07LowLevelPlannerConfig(**kwargs) elif policy_type == "value": return ValueConfig(**kwargs) else: diff --git a/src/opentau/policies/pi07/__init__.py b/src/opentau/policies/pi07/__init__.py new file mode 100644 index 00000000..787f750f --- /dev/null +++ b/src/opentau/policies/pi07/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/opentau/policies/pi07/gemma3_with_expert.py b/src/opentau/policies/pi07/gemma3_with_expert.py new file mode 100644 index 00000000..ea296e12 --- /dev/null +++ b/src/opentau/policies/pi07/gemma3_with_expert.py @@ -0,0 +1,712 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gemma 3 backbone with Gemma-v1 action expert, for the PI06 policy. + +Mirrors `paligemma_with_expert.py` but: + * the vision-language backbone is `Gemma3ForConditionalGeneration` (Gemma 3 4B, + SigLIP-400m/14 + Gemma 3 text, 34 interleaved sliding-window/global layers); + * the action expert is a Gemma-v1 `GemmaForCausalLM` with the AdaRMS and + gated-residual patches applied by `opentau.utils.transformers_patch`. + +The per-layer attention loop below concatenates backbone and expert queries/ +keys/values along the sequence dimension at every layer (the MoE-like pattern +introduced in π0), so the expert can cross-attend to the backbone's activations +at every depth. Gemma 3 specifics (q_norm/k_norm, pre/post feedforward RMSNorms, +per-layer local vs global RoPE, sliding-window attention) are all honored. + +`transformers_patch` is imported at module load so the expert path picks up +adaptive RMSNorm and `_gated_residual`. The Gemma 3 backbone remains stock — +its layer-norms return a plain tensor and are used without a `cond=` argument. +""" + +import torch +from torch import nn +from transformers import ( + AutoConfig, + Cache, + Gemma3ForConditionalGeneration, + GemmaForCausalLM, + PretrainedConfig, + PreTrainedModel, +) +from transformers.models.auto import CONFIG_MAPPING +from transformers.models.gemma import modeling_gemma + +# Ensure the Gemma-v1 AdaRMS / gated-residual patches are live before we +# construct an action expert. Import for side effects only. +from opentau.utils import transformers_patch # noqa: F401 + + +def _preferred_dtype(): + return torch.float32 if torch.onnx.is_in_onnx_export() else torch.bfloat16 + + +def apply_rope(x: torch.Tensor, positions: torch.Tensor, max_wavelength: float = 10_000.0) -> torch.Tensor: + """Applies RoPE to `x` with the given positions and base wavelength. + + Args: + x: Tensor of shape `(B, L, H, D)`. + positions: Tensor of shape `(B, L)`. + max_wavelength: RoPE base frequency. Gemma 3 uses 10_000 for sliding + (local) layers and 1_000_000 for full (global) layers; Gemma-v1 + expert uses 10_000. + + Returns: + RoPE-transformed tensor, same shape / dtype as the input. + """ + d_half = x.shape[-1] // 2 + device = x.device + dtype = x.dtype + x = x.to(torch.float32) + + freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) + timescale = max_wavelength**freq_exponents + radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) + radians = radians[..., None, :] + + sin = torch.sin(radians) + cos = torch.cos(radians) + + x1, x2 = x.split(d_half, dim=-1) + res = torch.empty_like(x) + res[..., :d_half] = x1 * cos - x2 * sin + res[..., d_half:] = x2 * cos + x1 * sin + + return res.to(dtype) + + +# NOTE: π0.6 deliberately does NOT enforce Gemma 3's sliding-window mask. +# The model card describes "bidirectional attention among ALL of the image +# tokens" and "block-wise causal" prefix attention — wording that's +# incompatible with a 1024-token window once you have 4 cameras × 256 image +# tokens = 1024 image tokens. The local layers' pretrained weights still +# rotate at θ=10_000 (we honour that), but the per-layer attention pattern +# is the global block-causal mask everywhere. + + +class Gemma3WithExpertConfig(PretrainedConfig): + """Configuration wrapper bundling a Gemma 3 VLM config and a Gemma-v1 expert config.""" + + model_type = "Gemma3WithExpertModel" + sub_configs = {"gemma3_config": AutoConfig, "gemma_expert_config": AutoConfig} + + def __init__( + self, + gemma3_config: dict | None = None, + gemma_expert_config: dict | None = None, + freeze_vision_encoder: bool = True, + train_expert_only: bool = True, + attention_implementation: str = "eager", + load_pretrained_gemma3: bool = False, + discrete_action_vocab_size: int | None = None, + dropout: float = 0.1, + **kwargs, + ): + """Initializes the configuration. + + Args: + gemma3_config: Optional Gemma 3 config dict. Defaults to the + `google/gemma-3-4b-pt` topology. + gemma_expert_config: Optional Gemma-v1 action-expert config dict. + Defaults to a ~860M-parameter Gemma with AdaRMS enabled. + freeze_vision_encoder: Freeze the SigLIP tower during training. + train_expert_only: Only update the expert and its heads. + attention_implementation: "eager" or "fa2" (fa2 not yet supported). + load_pretrained_gemma3: Whether to pull pretrained Gemma 3 weights + from the Hub (only recommended when He-initializing the expert). + discrete_action_vocab_size: FAST tokenizer vocab size. + dropout: Dropout probability applied in the per-layer loop. + **kwargs: Passed to `PretrainedConfig`. + """ + self.freeze_vision_encoder = freeze_vision_encoder + self.train_expert_only = train_expert_only + self.attention_implementation = attention_implementation + self.load_pretrained_gemma3 = load_pretrained_gemma3 + self.discrete_action_vocab_size = discrete_action_vocab_size + self.dropout = dropout + + # Gemma 3 backbone defaults (match google/gemma-3-4b-pt). + if gemma3_config is None: + self.gemma3_config = CONFIG_MAPPING["gemma3"]( + text_config={ + "model_type": "gemma3_text", + "hidden_size": 2560, + "intermediate_size": 10240, + "num_hidden_layers": 34, + "num_attention_heads": 8, + "num_key_value_heads": 4, + "head_dim": 256, + "query_pre_attn_scalar": 256, + "sliding_window": 1024, + "rope_theta": 1_000_000.0, + "rope_local_base_freq": 10_000.0, + "rms_norm_eps": 1e-6, + "vocab_size": 262_208, + "max_position_embeddings": 131_072, + "attention_bias": False, + "attention_dropout": 0.0, + "hidden_activation": "gelu_pytorch_tanh", + "sliding_window_pattern": 6, + "torch_dtype": "float32", + }, + vision_config={ + "model_type": "siglip_vision_model", + "hidden_size": 1152, + "intermediate_size": 4304, + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + # π0.6 feeds 448×448 images. `Gemma3MultiModalProjector` + # hardcodes `patches_per_image = image_size // patch_size`, + # so this MUST match the actual input resolution or the + # projector's reshape crashes (see test_pi06.py:: + # TestGemma3WithExpertConfig::test_vision_image_size_matches_input). + "image_size": 448, + "projection_dim": 2560, + "projector_hidden_act": "gelu_fast", + "vision_use_head": False, + "torch_dtype": "float32", + "layer_norm_eps": 1e-6, + }, + image_token_index=262144, + mm_tokens_per_image=256, + boi_token_index=255999, + eoi_token_index=256000, + initializer_range=0.02, + ) + elif isinstance(gemma3_config, dict): + if "model_type" not in gemma3_config: + gemma3_config["model_type"] = "gemma3" + cfg_cls = CONFIG_MAPPING[gemma3_config["model_type"]] + self.gemma3_config = cfg_cls(**gemma3_config) + else: + self.gemma3_config = gemma3_config + + # Gemma-v1 action-expert defaults (~860M params). + if gemma_expert_config is None: + self.gemma_expert_config = CONFIG_MAPPING["gemma"]( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=2, + eos_token_id=1, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + hidden_activation="gelu_pytorch_tanh", + hidden_size=1280, + initializer_range=0.02, + intermediate_size=5120, + max_position_embeddings=8192, + model_type="gemma", + num_attention_heads=8, + num_hidden_layers=34, + # GQA to match the backbone so per-layer KV concatenation works. + num_key_value_heads=4, + pad_token_id=0, + rms_norm_eps=1e-6, + rope_theta=10_000.0, + torch_dtype="float32", + use_adarms=True, + adarms_cond_dim=1280, + use_cache=True, + vocab_size=262_208, + ) + elif isinstance(gemma_expert_config, dict): + if "model_type" not in gemma_expert_config: + gemma_expert_config["model_type"] = "gemma" + cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]] + self.gemma_expert_config = cfg_cls(**gemma_expert_config) + else: + self.gemma_expert_config = gemma_expert_config + + if self.train_expert_only and not self.freeze_vision_encoder: + raise ValueError( + "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible." + ) + if self.attention_implementation not in ["eager", "fa2"]: + raise ValueError( + f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). " + "Expected 'eager' or 'fa2'." + ) + + super().__init__(**kwargs) + + +class Gemma3WithExpertModel(PreTrainedModel): + """Gemma 3 VLM interleaved layer-wise with a Gemma-v1 action expert.""" + + config_class = Gemma3WithExpertConfig + + def __init__(self, config: Gemma3WithExpertConfig): + super().__init__(config=config) + self.config = config + + if config.load_pretrained_gemma3: + self.gemma3 = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-pt") + else: + self.gemma3 = Gemma3ForConditionalGeneration(config=config.gemma3_config) + + self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config) + # The expert shares embeddings nowhere — drop the unused token table. + self.gemma_expert.model.embed_tokens = None + + text_hidden = config.gemma3_config.text_config.hidden_size + + self.discrete_action_embedding = nn.Embedding( + num_embeddings=config.discrete_action_vocab_size, + embedding_dim=text_hidden, + padding_idx=0, + ) + self.da_head = nn.Linear( + in_features=text_hidden, + out_features=config.discrete_action_vocab_size, + ) + + self.dropout = nn.Dropout(config.dropout) + + if not torch.compiler.is_compiling(): + self.to_bfloat16_like_physical_intelligence() + self.set_requires_grad() + + # Cache commonly accessed config scalars. + self._text_config = config.gemma3_config.text_config + self._expert_config = config.gemma_expert_config + self._num_layers = self._text_config.num_hidden_layers + self._head_dim = self._text_config.head_dim + self._rope_global = float(self._text_config.rope_theta) + self._rope_local = float(getattr(self._text_config, "rope_local_base_freq", 10_000.0)) + self._layer_types: list[str] = list(self._text_config.layer_types) + # Notes: + # * the expert's own `rope_theta` is deliberately ignored at runtime + # — the shared attention requires the backbone's per-layer θ for + # both streams (see `forward()`). + # * `text_config.sliding_window` is also deliberately unused — see + # the comment near `apply_rope` for why π0.6 doesn't enforce it. + self._query_pre_attn_scaling = float(self._text_config.query_pre_attn_scalar) ** -0.5 + + # Trainable / dtype plumbing + + def set_requires_grad(self) -> None: + if self.config.freeze_vision_encoder: + vision_tower = self._vision_tower() + if vision_tower is not None: + vision_tower.eval() + for params in vision_tower.parameters(): + params.requires_grad = False + + if self.config.train_expert_only: + self.gemma3.eval() + for params in self.gemma3.parameters(): + params.requires_grad = False + for param in self.da_head.parameters(): + param.requires_grad = False + for param in self.discrete_action_embedding.parameters(): + param.requires_grad = False + + def train(self, mode: bool = True): + super().train(mode) + if self.config.freeze_vision_encoder: + vision_tower = self._vision_tower() + if vision_tower is not None: + vision_tower.eval() + if self.config.train_expert_only: + self.gemma3.eval() + return self + + def to_bfloat16_like_physical_intelligence(self) -> None: + self.gemma3 = self.gemma3.to(dtype=torch.bfloat16) + params_to_change_dtype = [ + "language_model.model.layers", + "gemma_expert.model.layers", + "vision_tower", + "multi_modal_projector", + ] + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_change_dtype): + param.data = param.data.to(dtype=torch.bfloat16) + + # Embedding helpers + + def _vision_tower(self): + # Gemma 3's vision tower lives at `gemma3.model.vision_tower` depending on + # the transformers version; fall back gracefully. + for path in ("vision_tower", "model.vision_tower"): + obj = self.gemma3 + ok = True + for part in path.split("."): + if hasattr(obj, part): + obj = getattr(obj, part) + else: + ok = False + break + if ok: + return obj + return None + + def _multi_modal_projector(self): + for path in ("multi_modal_projector", "model.multi_modal_projector"): + obj = self.gemma3 + ok = True + for part in path.split("."): + if hasattr(obj, part): + obj = getattr(obj, part) + else: + ok = False + break + if ok: + return obj + return None + + def embed_image(self, image: torch.Tensor) -> torch.Tensor: + """Runs the SigLIP tower + multimodal projector to obtain image tokens. + + Gemma 3's `get_image_features` returns a `BaseModelOutputWithPooling` + whose `pooler_output` holds the projected image tokens (shape + `(B, mm_tokens_per_image, text_hidden_size)`). We extract that tensor + so callers can treat the return as a plain `(B, N, D)` tensor — matching + the patched PaliGemma behavior in `transformers_patch.py`. + """ + vision_tower = self._vision_tower() + projector = self._multi_modal_projector() + if vision_tower is None or projector is None: + raise RuntimeError( + "Gemma3 vision tower / multi_modal_projector could not be located on `self.gemma3`." + ) + last_hidden_state = vision_tower(pixel_values=image).last_hidden_state + return projector(last_hidden_state) + + def _lm_head(self) -> nn.Module: + """Returns the language modeling head. + + ``Gemma3ForConditionalGeneration`` owns ``lm_head`` directly (its + ``language_model`` is a ``Gemma3TextModel`` with no ``lm_head``). Older + / nested layouts are tolerated via the ``model.lm_head`` fallback. + """ + if hasattr(self.gemma3, "lm_head"): + return self.gemma3.lm_head + return self.gemma3.model.lm_head + + def embed_language_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + """Embed token ids through Gemma 3's shared text embedding table.""" + lm = getattr(self.gemma3, "language_model", None) + if lm is None: + lm = self.gemma3.model.language_model + return lm.embed_tokens(tokens) + + def embed_discrete_actions(self, actions: torch.Tensor) -> torch.Tensor: + if actions.dtype != torch.long: + actions = actions.long() + return self.discrete_action_embedding(actions) + + # Attention core + + def get_attention_interface(self): + if self.config.attention_implementation == "fa2": + raise NotImplementedError( + "fa2 attention is not supported for pi06 yet because of the interleaved " + "local/global mask pattern — use 'eager' instead." + ) + return self.eager_attention_forward + + def eager_attention_forward( + self, + attention_mask: torch.Tensor, + batch_size: int, + head_dim: int, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + scaling: float | None = None, + ) -> torch.Tensor: + """Standard eager scaled-dot-product attention. `attention_mask` is a + boolean 2D mask of shape `(B, Q, K)` (True = attend).""" + num_att_heads = self._text_config.num_attention_heads + num_key_value_heads = self._text_config.num_key_value_heads + num_key_value_groups = num_att_heads // num_key_value_heads + + sequence_length = key_states.shape[1] + + key_states = key_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + key_states = key_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + value_states = value_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + value_states = value_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + query_states = query_states.to(dtype=torch.float32) + key_states = key_states.to(dtype=torch.float32) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + att_weights *= scaling if scaling is not None else head_dim**-0.5 + big_neg = -2.3819763e38 + + masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) + + probs = nn.functional.softmax(masked_att_weights, dim=-1) + probs = probs.to(dtype=value_states.dtype) + + att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3)) + att_output = att_output.permute(0, 2, 1, 3) + att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) + return att_output + + # Per-layer interleaved forward + + def forward( + self, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | Cache | None = None, + inputs_embeds: list[torch.FloatTensor] | None = None, + n_cross_att_tokens: int | None = None, + use_cache: bool | None = None, + fill_kv_cache: bool | None = None, + adarms_cond: list[torch.Tensor] | None = None, + ) -> tuple[list[torch.FloatTensor | None], list[torch.FloatTensor] | Cache | None]: + """Interleaved per-layer forward for the Gemma 3 backbone and Gemma-v1 expert. + + The two streams (index 0 = backbone, index 1 = expert) share each layer's + attention — queries and KVs are concatenated along the sequence axis. When + one stream's embeddings are None the other runs alone, pulling KVs for the + missing stream from `past_key_values` when `use_cache=True`. + + Args: + attention_mask: 2D boolean mask of shape `(B, Q, K_total)`. See + `opentau.policies.pi05.modeling_pi05.make_att_2d_masks`. + position_ids: `(B, L_total)` token positions, used for RoPE. + past_key_values: Per-layer KV cache populated on a previous call. + inputs_embeds: `[backbone_embeds, expert_embeds]`. Either may be None. + n_cross_att_tokens: Number of prefix tokens to retain in the cache + (must be provided when `fill_kv_cache=True`). + use_cache: Read KVs from `past_key_values` (prefix cross-attention). + fill_kv_cache: Write this call's KVs into `past_key_values`. + adarms_cond: Per-stream AdaRMS conditioning tensors `[None, cond]`. + + Returns: + A pair `(outputs_embeds, past_key_values)` where `outputs_embeds` is a + two-element list mirroring the `inputs_embeds` layout. + """ + if adarms_cond is None: + adarms_cond = [None, None] + + backbone_layers = self._backbone_layers() + expert_layers = self.gemma_expert.model.layers + backbone_norm = self._backbone_final_norm() + expert_norm = self.gemma_expert.model.norm + + # Infer batch size from whichever stream is present. + batch_size = None + for h in inputs_embeds: + if h is not None: + batch_size = h.shape[0] + break + if batch_size is None: + raise ValueError("`inputs_embeds` must contain at least one non-None entry.") + + head_dim = self._head_dim + + for layer_idx in range(self._num_layers): + layer_type = self._layer_types[layer_idx] + is_sliding = layer_type == "sliding_attention" + layer_rope_theta = self._rope_local if is_sliding else self._rope_global + + layers_this_step = [backbone_layers[layer_idx], expert_layers[layer_idx]] + # Both streams MUST use the same RoPE base at this layer. Shared + # attention concatenates Q/K along the sequence axis; the dot-product + # invariant `R(q,p)·R(k,q) = q·R(q-p)k` only holds when the same θ + # produced both rotations. For global Gemma-3 layers (θ=1M) this + # means the expert also rotates at 1M even though the config carries + # a single fallback `rope_theta=10k`. + rope_thetas = [layer_rope_theta, layer_rope_theta] + + query_states = [] + key_states = [] + value_states = [] + gates = [] + # Track the pre-attention residual + post-attn layernorm output for the + # Gemma-3 backbone side, since it needs a second residual around the MLP + # using `pre_feedforward_layernorm` / `post_feedforward_layernorm`. + backbone_preattn_residual = None + + for stream_idx, hidden_states in enumerate(inputs_embeds): + if hidden_states is None: + gates.append(None) + query_states.append(None) + key_states.append(None) + value_states.append(None) + continue + + layer = layers_this_step[stream_idx] + + if stream_idx == 0: + # Gemma 3 backbone. + backbone_preattn_residual = hidden_states + h = layer.input_layernorm(hidden_states) + gate = None + else: + # Gemma-v1 expert (patched to return (tensor, gate)). + h, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[stream_idx]) + + gates.append(gate) + bsize, seq_len, _ = h.shape + h = h.to(dtype=_preferred_dtype()) + + q = layer.self_attn.q_proj(h).view(bsize, seq_len, -1, head_dim) + k = layer.self_attn.k_proj(h).view(bsize, seq_len, -1, head_dim) + v = layer.self_attn.v_proj(h).view(bsize, seq_len, -1, head_dim) + + if stream_idx == 0: + # Gemma-3 applies an extra per-head RMSNorm on Q and K. + q_norm = getattr(layer.self_attn, "q_norm", None) + k_norm = getattr(layer.self_attn, "k_norm", None) + if q_norm is not None: + q = q_norm(q) + if k_norm is not None: + k = k_norm(k) + + q = apply_rope(q, position_ids, max_wavelength=rope_thetas[stream_idx]) + k = apply_rope(k, position_ids, max_wavelength=rope_thetas[stream_idx]) + + query_states.append(q) + key_states.append(k) + value_states.append(v) + + # Drop Nones before concatenating. + q_list = [q for q in query_states if q is not None] + k_list = [k for k in key_states if k is not None] + v_list = [v for v in value_states if v is not None] + + q_concat = torch.cat(q_list, dim=1) + k_concat = torch.cat(k_list, dim=1) + v_concat = torch.cat(v_list, dim=1) + + if use_cache and past_key_values is not None and layer_idx in past_key_values: + k_concat = torch.cat([past_key_values[layer_idx]["key_states"], k_concat], dim=1) + v_concat = torch.cat([past_key_values[layer_idx]["value_states"], v_concat], dim=1) + + if fill_kv_cache: + if past_key_values is None: + past_key_values = {} + if n_cross_att_tokens is None: + raise ValueError("n_cross_att_tokens must be provided when fill_kv_cache is True") + past_key_values[layer_idx] = { + "key_states": k_concat[:, :n_cross_att_tokens, :, :], + "value_states": v_concat[:, :n_cross_att_tokens, :, :], + } + + # π0.6 keeps the prefix block-causal mask at every layer — the + # Gemma 3 sliding-window pattern is deliberately not applied + # (see the note next to `apply_rope`). + layer_attention_mask = attention_mask + + attention_interface = self.get_attention_interface() + att_output = attention_interface( + layer_attention_mask, + batch_size, + head_dim, + q_concat, + k_concat, + v_concat, + scaling=self._query_pre_attn_scaling, + ) + att_output = att_output.to(dtype=_preferred_dtype()) + + outputs_embeds: list[torch.Tensor | None] = [] + start = 0 + for stream_idx, hidden_states in enumerate(inputs_embeds): + if hidden_states is None: + outputs_embeds.append(None) + continue + + layer = layers_this_step[stream_idx] + seq_len = hidden_states.shape[1] + end = start + seq_len + part = att_output[:, start:end] + start = end + + if part.dtype != layer.self_attn.o_proj.weight.dtype: + part = part.to(layer.self_attn.o_proj.weight.dtype) + part = layer.self_attn.o_proj(part) + part = self.dropout(part) + + if stream_idx == 0: + # Gemma 3 block: residual + post_attn_norm(attn); then a second + # residual with pre_feedforward_layernorm / mlp / post_feedforward_layernorm. + post_attn = layer.post_attention_layernorm(part) + h = backbone_preattn_residual + post_attn + + ff_residual = h + h = layer.pre_feedforward_layernorm(h) + h = layer.mlp(h) + h = self.dropout(h) + h = layer.post_feedforward_layernorm(h) + h = ff_residual + h + outputs_embeds.append(h) + else: + # Gemma-v1 expert block with AdaRMS gates. + h = modeling_gemma._gated_residual(hidden_states, part, gates[stream_idx]) # noqa: SLF001 + ff_residual = h.clone() + h, gate2 = layer.post_attention_layernorm(h, cond=adarms_cond[stream_idx]) + h = layer.mlp(h) + h = self.dropout(h) + h = modeling_gemma._gated_residual(ff_residual, h, gate2) # noqa: SLF001 + outputs_embeds.append(h) + + inputs_embeds = outputs_embeds + + # Final norms. + final_outputs: list[torch.Tensor | None] = [] + for stream_idx, hidden_states in enumerate(inputs_embeds): + if hidden_states is None: + final_outputs.append(None) + continue + if stream_idx == 0: + final_outputs.append(backbone_norm(hidden_states)) + else: + out, _ = expert_norm(hidden_states, cond=adarms_cond[stream_idx]) + final_outputs.append(out) + + return final_outputs, past_key_values + + # Gemma 3 structural accessors + + def _backbone_text_model(self): + # Different transformers versions expose Gemma 3 under slightly different + # attribute paths. Resolve once. + if hasattr(self.gemma3, "language_model"): + return self.gemma3.language_model + return self.gemma3.model.language_model + + def _backbone_layers(self): + text_model = self._backbone_text_model() + if hasattr(text_model, "layers"): + return text_model.layers + return text_model.model.layers + + def _backbone_final_norm(self): + text_model = self._backbone_text_model() + if hasattr(text_model, "norm"): + return text_model.norm + return text_model.model.norm diff --git a/src/opentau/policies/pi07/high_level_planner/__init__.py b/src/opentau/policies/pi07/high_level_planner/__init__.py new file mode 100644 index 00000000..28a71afe --- /dev/null +++ b/src/opentau/policies/pi07/high_level_planner/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PI07 high level planner Policy Module. + +This module implements the High level planner for π07 which is basically π05 (Pi05) Vision-Language-model with memory and response prediction, +designed for high level planning, memory and subtask generation. It includes the policy definition, +configuration, and model architecture. +""" diff --git a/src/opentau/policies/pi07/high_level_planner/configuration_pi07_high_level.py b/src/opentau/policies/pi07/high_level_planner/configuration_pi07_high_level.py new file mode 100644 index 00000000..1a774ed6 --- /dev/null +++ b/src/opentau/policies/pi07/high_level_planner/configuration_pi07_high_level.py @@ -0,0 +1,234 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration module for the PI07 high level planner Policy. + +This module defines the `PI07HighLevelPlannerConfig` class, which handles the configuration parameters +for the PI07 high level planner. It includes settings for the model architecture, +optimization, scheduling, and data processing. +""" + +from dataclasses import dataclass, field + +from opentau.configs.policies import PreTrainedConfig +from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature +from opentau.optim.optimizers import AdamWConfig +from opentau.optim.schedulers import ( + CosineDecayWithWarmupSchedulerConfig, + LRSchedulerConfig, +) +from opentau.policies.pi07.gemma3_with_expert import Gemma3WithExpertConfig + + +@PreTrainedConfig.register_subclass("pi07_high_level") +@dataclass +class PI07HighLevelPlannerConfig(PreTrainedConfig): + """Configuration for the π07 high-level planner policy. + + The high-level planner takes images, a language instruction, robot state, + and past memory, then autoregressively predicts updated memory and the + next subtask string. This config controls model architecture, tokenizer + limits, initialization, and optimizer/scheduler presets. + + Args: + n_obs_steps: Number of observation steps to use. Only ``1`` is + currently supported. Defaults to 1. + normalization_mapping: Mapping from feature type names to + normalization modes. Defaults to identity for visual features + and mean-std for state. + max_state_dim: Maximum dimension for state vectors. Shorter vectors + are zero-padded. Defaults to 32. + resize_imgs_with_padding: Target ``(height, width)`` for image + resizing with aspect-ratio-preserving padding. Must match the + Gemma 3 vision tower's ``image_size`` (the projector hardcodes + ``patches_per_image = image_size // patch_size``). Defaults to + ``(448, 448)`` to match ``vlm_config.gemma3_config.vision_config.image_size``. + empty_cameras: Number of empty (zero-filled) camera inputs to add. + Defaults to 0. + prompt_max_length: Maximum token length for the composite language + prompt (task + past memory + state). Defaults to 256. + memory_max_length: Maximum token length for the updated memory + sequence. Defaults to 52. + response_max_length: Maximum token length for the subtask response + sequence. Defaults to 52. + metadata_max_length: Maximum token length for episode metadata + strings. Defaults to 52. + subtask_indicator_max_length: Number of tokenizer pieces for the fixed + ``"Subtask: "`` span (``encode(..., add_special_tokens=False)``). Used to + align CE slices with the prefix layout. MUST equal + ``len(tokenizer.encode("Subtask: ", add_special_tokens=False))`` + for whatever tokenizer the model uses; otherwise the memory CE + slice is misaligned. Defaults to 4. + memory_indicator_max_length: Number of tokenizer pieces for the fixed + ``"Updated Memory: "`` span. Used for documentation and layout + checks. MUST equal + ``len(tokenizer.encode("Updated Memory: ", add_special_tokens=False))`` + for whatever tokenizer the model uses. Defaults to 4. + dropout: Dropout rate applied in the transformer expert. + Defaults to 0.1. + attention_implementation: Attention backend — ``"eager"`` or + ``"fa2"`` (Flash Attention 2). Defaults to ``"eager"``. + freeze_vision_encoder: Whether to freeze the SigLIP vision encoder + during fine-tuning. Defaults to True. + optimizer_lr: Peak learning rate for AdamW. Defaults to 2.5e-5. + optimizer_betas: Beta parameters for AdamW. Defaults to (0.9, 0.95). + optimizer_eps: Epsilon for AdamW. Defaults to 1e-8. + optimizer_weight_decay: Weight decay for AdamW. Defaults to 1e-10. + scheduler_warmup_steps: Linear warmup steps. Defaults to 1_000. + scheduler_decay_steps: Cosine decay steps. Defaults to 30_000. + scheduler_decay_lr: Final learning rate after decay. + Defaults to 2.5e-6. + """ + + # Input / output structure. + n_obs_steps: int = 1 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + } + ) + + # Shorter state and action vectors will be padded + max_state_dim: int = 32 + + # Image preprocessing. Must equal the Gemma 3 vision tower's image_size: + # `Gemma3MultiModalProjector` hardcodes + # `patches_per_image = image_size // patch_size`, so feeding a different + # resolution crashes the projector's reshape. + resize_imgs_with_padding: tuple[int, int] = (448, 448) + + # Add empty images. Used by pi05_aloha_sim which adds the empty + # left and right wrist cameras in addition to the top camera. + empty_cameras: int = 0 + + # Language Tokenizer + prompt_max_length: int = 256 + + # Memory Tokenizer + memory_max_length: int = 52 + + # Response Tokenizer + response_max_length: int = 52 + + # Metadata Tokenizer + metadata_max_length: int = 52 + + subtask_indicator_max_length: int = 4 + + memory_indicator_max_length: int = 4 + + # Dropout + dropout: float = 0.1 + + # Attention utils + attention_implementation: str = "eager" + + # Finetuning settings + freeze_vision_encoder: bool = True + + vlm_config: Gemma3WithExpertConfig = field( + default_factory=lambda: Gemma3WithExpertConfig( + freeze_vision_encoder=True, + train_expert_only=False, + attention_implementation="eager", + load_pretrained_gemma3=False, + dropout=0.1, + ) + ) + + # Training presets + optimizer_lr: float = 2.5e-5 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-10 + + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + def __post_init__(self): + """Validates configuration values after dataclass initialization. + + Raises: + ValueError: If ``n_obs_steps`` is not 1. + """ + super().__post_init__() + + if self.n_obs_steps != 1: + raise ValueError( + f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" + ) + + def validate_features(self) -> None: + """Adds placeholder camera features for empty camera slots. + + Dynamically inserts zero-filled camera entries into + ``self.input_features`` for each configured empty camera, so the + model receives a fixed number of image inputs regardless of which + cameras are physically present. + """ + + for i in range(self.empty_cameras): + key = f"observation.images.empty_camera_{i}" + empty_camera = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 480, 640), + ) + self.input_features[key] = empty_camera + + def get_optimizer_preset(self) -> AdamWConfig: + """Returns the default AdamW optimizer configuration. + + Returns: + An ``AdamWConfig`` populated from this config's ``optimizer_*`` + fields. + """ + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self) -> LRSchedulerConfig: + """Returns the default cosine-decay-with-warmup scheduler configuration. + + Returns: + A ``CosineDecayWithWarmupSchedulerConfig`` populated from this + config's ``scheduler_*`` and ``optimizer_lr`` fields. + """ + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> None: + """Returns ``None``; observation deltas are not used by this planner.""" + return None + + @property + def action_delta_indices(self) -> None: + """Returns ``None``; action deltas are not used by this planner.""" + return None + + @property + def reward_delta_indices(self) -> None: + """Returns ``None``; reward deltas are not used by this planner.""" + return None diff --git a/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py new file mode 100644 index 00000000..f2ef2ec4 --- /dev/null +++ b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py @@ -0,0 +1,1434 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""π07 High-Level Planner: A Vision-Language Model for Memory and Subtask Prediction. + +This module implements the high-level planner for π07, built on top of the +Gemma 3 VLM backbone (with a Gemma-v1 action expert; see +``opentau.policies.pi07.gemma3_with_expert``). Given images, language +instructions, robot state, and past memory, the planner autoregressively +predicts updated memory and a subtask string. +""" + +import builtins +import logging +import math +from pathlib import Path + +import torch +import torch.nn.functional as F # noqa: N812 +from einops import rearrange +from torch import Tensor, nn +from transformers import AutoProcessor, AutoTokenizer + +from opentau.configs.policies import PreTrainedConfig +from opentau.policies.normalize import Normalize +from opentau.policies.pi07.gemma3_with_expert import ( + Gemma3WithExpertModel, +) +from opentau.policies.pi07.high_level_planner.configuration_pi07_high_level import ( + PI07HighLevelPlannerConfig, +) +from opentau.policies.pretrained import PreTrainedPolicy, T +from opentau.utils.accelerate_utils import get_proc_accelerator + + +def _preferred_dtype() -> torch.dtype: + """Returns the preferred compute dtype for the current execution context. + + Returns: + ``torch.float32`` during ONNX export, ``torch.bfloat16`` otherwise. + """ + return torch.float32 if torch.onnx.is_in_onnx_export() else torch.bfloat16 + + +def make_att_2d_masks( + pad_masks: Tensor, + att_masks: Tensor, + n_cross_att_tokens: int | None = None, + cross_att_pad_masks: Tensor | None = None, +) -> Tensor: + """Creates a 2-D attention mask given padding and 1-D attention masks. + + Tokens can attend to valid inputs tokens which have a cumulative `att_masks` + smaller or equal to theirs. This way `att_masks` int[B, N] can be used to + setup several types of attention, for example: + + [[1 1 1 1 1 1]]: pure causal attention. + + [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between + themselves and the last 3 tokens have a causal attention. The first + entry could also be a 1 without changing behaviour. + + [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a + block can attend all previous blocks and all tokens on the same block. + + Args: + pad_masks: bool[B, N] true if its part of the input, false if padding. + att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. + n_cross_att_tokens: Add attention mask for cross-attention tokens if + `n_cross_att_tokens` is provided. + cross_att_pad_masks: Padding masks for cross attention tokens. Required if + `n_cross_att_tokens` is provided. + + Returns: + A 2D attention mask tensor of shape (B, N + n_cross_att_tokens, N + n_cross_att_tokens) + if n_cross_att_tokens is provided, else (B, N, N). + + Raises: + ValueError: If att_masks or pad_masks are not 2D (including batch dimension). + AssertionError: If cross_att_pad_masks is missing when n_cross_att_tokens is set, + or if its shape is incorrect. + """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + att_2d_masks = att_2d_masks & pad_2d_masks + + # If `n_cross_att_tokens` is provided, we add a mask for cross-attention tokens at the end of the sequence. + if n_cross_att_tokens is not None: + assert cross_att_pad_masks is not None, ( + "cross_att_pad_masks must be provided if n_cross_att_tokens is provided" + ) + assert cross_att_pad_masks.shape == (att_masks.size(0), n_cross_att_tokens), ( + "cross_att_pad_masks must have shape (batch_size, n_cross_att_tokens)" + ) + + cross_att_mask = torch.full( + (att_masks.size(0), att_masks.size(1), n_cross_att_tokens), + True, + dtype=torch.bool, + device=att_masks.device, + ) + + # Apply padding masks: pad_masks for rows, cross_att_pad_masks for columns + cross_att_mask = cross_att_mask & pad_masks[:, :, None] & cross_att_pad_masks[:, None, :] + # The cross_att_masks are concatenated before the att_2d_masks + att_2d_masks = torch.cat((cross_att_mask, att_2d_masks), dim=2) + + return att_2d_masks + + +def resize_with_pad(img: Tensor, width: int, height: int, pad_value: int = -1) -> Tensor: + """Resizes an image to fit within the specified dimensions while maintaining aspect ratio, + and pads the remaining area with the specified value. + + Args: + img: Input image tensor of shape (batch_size, channels, current_height, current_width). + width: Target width. + height: Target height. + pad_value: Value to use for padding. Defaults to -1. + + Returns: + The resized and padded image tensor of shape (batch_size, channels, height, width). + + Raises: + ValueError: If the input image tensor does not have 4 dimensions. + """ + # assume no-op when width height fits already + if img.ndim != 4: + raise ValueError(f"(b,c,h,w) expected, but {img.shape}") + + cur_height, cur_width = img.shape[2:] + + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + resized_img = F.interpolate( + img, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + pad_height = max(0, int(height - resized_height)) + pad_width = max(0, int(width - resized_width)) + + # pad on left and top of image + padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) + return padded_img + + +class PI07HighLevelPlannerPolicy(PreTrainedPolicy): + """Policy wrapper for the π07 high-level planner. + + Handles input normalisation, tokenisation of language/memory/response, + and delegates to :class:`PI07HighLevelPlannerModel` for autoregressive + prediction of updated memory and subtask strings. + """ + + config_class = PI07HighLevelPlannerConfig + name = "pi07_high_level" + + def __init__( + self, + config: PI07HighLevelPlannerConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """Initializes the PI07HighLevelPlannerPolicy. + + Args: + config: Policy configuration instance. + dataset_stats: Dataset statistics for input normalization. If not + provided here, they must be supplied via ``load_state_dict`` + before the policy is used. + """ + + super().__init__(config) + config.validate_features() + self.config = config + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + + self.language_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt") + + self.discrete_action_processor = AutoProcessor.from_pretrained( + "physical-intelligence/fast", trust_remote_code=True + ) + # Get vocab size from processor + discrete_action_vocab_size = getattr(self.discrete_action_processor, "vocab_size", None) + self.model = PI07HighLevelPlannerModel(config, discrete_action_vocab_size=discrete_action_vocab_size) + + self.reset() + + def reset(self) -> None: + """Resets any internal state. Call when the environment resets.""" + pass + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: + """Override the from_pretrained method to handle key remapping. + + Args: + pretrained_name_or_path: Path to the pretrained model or its name on the Hub. + config: Configuration object. + force_download: Whether to force download the model weights. + resume_download: Whether to resume download. + proxies: Proxy configuration. + token: Authentication token. + cache_dir: Directory to cache downloaded files. + local_files_only: Whether to only look for files locally. + revision: Specific model revision. + strict: Whether to strictly enforce state dict matching. + **kwargs: Additional keyword arguments. + + Returns: + The loaded model instance. + + Raises: + ValueError: If pretrained_name_or_path is None. + """ + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") + + # Use provided config if available, otherwise create default config + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + + # Initialize model without loading weights + # Check if dataset_stats were provided in kwargs + model = cls(config, **kwargs) + + # Now manually load and remap the state dict + acc = get_proc_accelerator() + is_main_process = acc.is_main_process if acc else True + try: + # Try to load the pytorch_model.bin or model.safetensors file + if is_main_process: + print(f"Loading model from: {pretrained_name_or_path}") + try: + from transformers.utils import cached_file + + # Try safetensors first + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + use_auth_token=kwargs.get("use_auth_token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + from safetensors.torch import load_file + + original_state_dict = load_file(resolved_file) + if is_main_process: + print("✓ Loaded state dict from model.safetensors") + except Exception as e: + if is_main_process: + print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") + return model + + # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) + + # Then add "model." prefix for all keys that don't already have it + remapped_state_dict = {} + remap_count = 0 + + for key, value in fixed_state_dict.items(): + if not key.startswith("model.") and "normalize" not in key: + new_key = f"model.{key}" + remapped_state_dict[new_key] = value + remap_count += 1 + if remap_count <= 10 and is_main_process: # Only print first 10 to avoid spam + print(f"Remapped: {key} -> {new_key}") + else: + remapped_state_dict[key] = value + + if remap_count > 0 and is_main_process: + print(f"Remapped {remap_count} state dict keys") + + # Load the remapped state dict into the model + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False) + + if missing_keys and is_main_process: + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") + if len(missing_keys) <= 20: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:20]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 20} more") + + if unexpected_keys and is_main_process: + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + if len(unexpected_keys) <= 20: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:20]: + print(f" - {key}") + print(f" ... and {len(unexpected_keys) - 20} more") + + if not missing_keys and not unexpected_keys and is_main_process: + print("All keys loaded successfully!") + + except Exception as e: + if is_main_process: + print(f"Warning: Could not remap state dict keys: {e}") + + return model + + def _fix_pytorch_state_dict_keys( + self, state_dict: dict[str, Tensor], model_config: PreTrainedConfig + ) -> dict[str, Tensor]: # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys` + """Fix state dict keys to match current model architecture. + + Args: + state_dict: The state dictionary to fix. + model_config: The model configuration. + + Returns: + The fixed state dictionary. + """ + import re + + fixed_state_dict = {} + + for key, value in state_dict.items(): + new_key = key + + # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias + # For gemma expert layers + if re.match( + r"gemma3_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", + key, + ): + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.gemma3_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") + continue + + if re.match(r"gemma3_with_expert\.gemma_expert\.model\.norm\.weight", key): + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.gemma3_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") + continue + + # Handle MLP naming changes for pi05 + # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_* + if key.startswith("action_time_mlp_in."): + new_key = key.replace("action_time_mlp_in.", "time_mlp_in.") + elif key.startswith("action_time_mlp_out."): + new_key = key.replace("action_time_mlp_out.", "time_mlp_out.") + # Also handle state_proj which shouldn't exist in pi05 + if key.startswith("state_proj."): + logging.warning(f"Skipping state_proj key in pi05 mode: {key}") + continue + + # Handle vision tower embedding layer potential differences + if "patch_embedding" in key: + # Some checkpoints might have this, but current model expects different structure + logging.warning(f"Vision embedding key might need handling: {key}") + + fixed_state_dict[new_key] = value + + return fixed_state_dict + + def get_optim_params(self) -> dict: + """Returns the parameters to be optimized. + + Returns: + A generator over the model parameters. + """ + return self.parameters() + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Not implemented for the high-level planner. + + Args: + batch: Batch of data containing environment observations. + + Raises: + NotImplementedError: Always, since the high-level planner predicts + memory and subtask strings, not action chunks. + """ + raise NotImplementedError("The high-level planner does not predict action chunks.") + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """Not implemented for the high-level planner. + + Args: + batch: Batch of data containing environment observations. + + Raises: + NotImplementedError: Always, since the high-level planner predicts + memory and subtask strings, not action chunks. + """ + raise NotImplementedError("The high-level planner does not use select_action.") + + @torch.no_grad() + def sample_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + """Run inference to predict updated memory and subtask tokens. + + Normalizes inputs, prepares image and language embeddings, then + delegates to the inner model for autoregressive generation. + + Args: + batch: Batch of observations. Expected keys include images, + ``"prompt"``, ``"state"``, and ``"past_memory"``. + + Returns: + A tuple ``(memory_tokens, response_tokens)`` where each is a + ``Tensor`` of token IDs with shape ``(batch_size, seq_len)``. + """ + + batch = self.normalize_inputs(batch) + + images, img_masks = self.prepare_images(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + metadata_tokens, metadata_masks = self.prepare_metadata(batch) + + memory_tokens, response_tokens = self.model.sample_actions( + images, img_masks, lang_tokens, lang_masks, metadata_tokens, metadata_masks + ) + + return memory_tokens, response_tokens + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Runs a full training forward pass and computes the loss. + + Tokenizes images, language (with state and past memory), target memory, + and target response, then computes cross-entropy losses for both the + memory and response token predictions. + + Args: + batch: Batch of training data. Expected keys include images, + ``"prompt"``, ``"state"``, ``"past_memory"``, ``"response"``, + and ``"next_memory"``. + + Returns: + A dict with ``"MSE"`` (always zero, kept for interface + compatibility) and ``"CE"`` (sum of memory and response + cross-entropy losses). + """ + batch = self.normalize_inputs(batch) + + images, img_masks = self.prepare_images( + batch + ) # in img_masks we have True for real images and False for padded images + lang_tokens, lang_masks = self.prepare_language( + batch + ) # in lang_masks we have True for real tokens and False for padded tokens + # response prediction is to predict the response . It will attend to image and language inputs. + + metadata_tokens, metadata_masks = self.prepare_metadata( + batch + ) # in metadata_masks we have True for real tokens and False for padded tokens + response_tokens, response_masks = self.prepare_response( + batch + ) # in response_masks we have True for real tokens and False for padded tokens + + # memory prediction is to predict the memory . It will attend to image and language inputs. + memory_tokens, memory_masks = self.prepare_next_memory( + batch + ) # in memory_masks we have True for real tokens and False for padded tokens + losses = self.model.forward( + images, + img_masks, + lang_tokens, + lang_masks, + response_tokens, + response_masks, + memory_tokens, + memory_masks, + metadata_tokens, + metadata_masks, + ) + + mse_loss = losses["MSE"] + ce_loss = losses["CE"] + + return {"MSE": mse_loss, "CE": ce_loss} + + def prepare_discrete_state(self, batch: dict[str, Tensor]) -> list[str]: + """Discretizes the state into bins and converts it to a string representation. + + Each dimension of the state vector is discretized into 256 bins. + The values of each dimension of the state are expected to be in the range [-1, 1]. + The discretization bins are linearly spaced between -1 and 1. + The index of the bin for each dimension is then concatenated into a space-separated string. + + Args: + batch: Batch of data containing the "state" tensor. + + Returns: + A list of strings, where each string is a space-separated list of discretized state values. + + Raises: + ValueError: If the state values are not normalized between -1 and 1. + """ + state = batch["state"] + state_cpu = state.to(device="cpu", dtype=torch.float32) + if torch.any(state_cpu < -1.0) or torch.any(state_cpu > 1.0): + logging.warning( + f"State values are not normalized between -1 and 1. Min: {state_cpu.min().item()}, Max: {state_cpu.max().item()}" + ) + state_clipped = torch.clamp(state_cpu, -1.0, 1.0) + # replicate np.digitize with torch for torch.compile compatibility + bin_indices = ((state_clipped + 1.0) * 128.0).long().clamp(0, 255) + discretized_states = bin_indices.cpu().tolist() + return [ + " ".join(map(str, row)) for row in discretized_states + ] # TODO: return a tensor instead of a list of strings? + + def prepare_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: + """Apply preprocessing to the images. + + Resizes to 224x224 and padding to keep aspect ratio, and converts pixel range + from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP. + + Args: + batch: Batch of data containing image tensors. + + Returns: + A tuple containing: + - images: A list of processed image tensors. + - img_masks: A list of image mask tensors. + + Raises: + ValueError: If no image features are present in the batch. + """ + images = [] + img_masks = [] + + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" + ) + + # Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key] + + if self.config.resize_imgs_with_padding is not None: + img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) + + # Normalize from range [0,1] to [-1,1] as expected by siglip + img = img * 2.0 - 1.0 + + bsize = img.shape[0] + device = img.device + mask = torch.ones(bsize, dtype=torch.bool, device=device) + images.append(img) + img_masks.append(mask) + + # Create image features not present in the batch + # as fully 0 padded images. + for num_empty_cameras in range(len(missing_img_keys)): + if num_empty_cameras >= self.config.empty_cameras: + break + img = torch.ones_like(img) * -1 + mask = torch.zeros_like(mask) + images.append(img) + img_masks.append(mask) + + return images, img_masks + + def prepare_language(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + """Tokenizes the composite language prompt. + + Builds a prompt string from the task instruction, past memory, and + discretized robot state separated by ```` tokens, then tokenizes + and pads to ``prompt_max_length``. + + Args: + batch: Batch containing ``"prompt"`` (task strings), + ``"state"`` (state tensor), and ``"past_memory"`` (list of + past memory strings). + + Returns: + A tuple ``(lang_tokens, lang_masks)`` where: + - lang_tokens: Token IDs of shape ``(batch_size, prompt_max_length)``. + - lang_masks: Boolean attention mask of the same shape. + """ + device = batch["state"].device + tasks = batch["prompt"] + + # add state to the prompt + state = self.prepare_discrete_state(batch) + # using to separate each modality + past_memory = batch["past_memory"] + prompt = [ + f"Task: {task}, Past Memory: {past_mem}, State: {state}, " + for task, past_mem, state in zip(tasks, past_memory, state, strict=False) + ] + tokenized_prompt = self.language_tokenizer.__call__( + prompt, + padding="max_length", + padding_side="right", + max_length=self.config.prompt_max_length, + return_tensors="pt", + truncation=True, + ) + lang_tokens = tokenized_prompt["input_ids"].to(device=device) + lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) + + return lang_tokens, lang_masks + + def prepare_metadata(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + """Tokenizes the metadata for training. + + Wraps each metadata string with an ```` suffix, then tokenizes and + pads to ``metadata_max_length``. + """ + + metadata = [] + for speed, quality, mistake, speed_is_pad, quality_is_pad, mistake_is_pad in zip( + batch["speed"], + batch["quality"], + batch["mistake"], + batch["speed_is_pad"], + batch["quality_is_pad"], + batch["mistake_is_pad"], + strict=True, + ): + segments = [] + if not speed_is_pad: + segments.append(f"Speed: {str(speed.item())}, ") + + if not quality_is_pad: + segments.append(f"Quality: {str(quality.item())}, ") + + if not mistake_is_pad: + segments.append(f"Mistake: {str(mistake.item())}, ") + + metadata.append(f"Metadata: {' '.join(segments)}") + + device = batch["state"].device + tokenized_metadata = self.language_tokenizer.__call__( + metadata, + padding="max_length", + padding_side="right", + max_length=self.config.metadata_max_length, + return_tensors="pt", + truncation=True, + ) + metadata_tokens = tokenized_metadata["input_ids"].to(device=device) + metadata_masks = tokenized_metadata["attention_mask"].to(device=device, dtype=torch.bool) + + return metadata_tokens, metadata_masks + + def prepare_response(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + """Tokenizes the target subtask response for training. + + Wraps each response string with an ``Actions:`` suffix, then + tokenizes and pads to ``response_max_length``. + + Args: + batch: Batch containing ``"response"`` (list of subtask strings) + and ``"state"`` (used only to determine the device). + + Returns: + A tuple ``(response_tokens, response_masks)`` where: + - response_tokens: Token IDs of shape + ``(batch_size, response_max_length)``. + - response_masks: Boolean attention mask of the same shape + (``True`` for real tokens, ``False`` for padding). + """ + + device = batch["state"].device + responses = batch["response"] + + # if '' is found in response then response is not for loss calculation (used for robotic dataset with no subtask), so add pad token to the response. + response_prompt = [f"{response}" for response in responses] + + tokenized_response = self.language_tokenizer.__call__( + response_prompt, + padding="max_length", + padding_side="right", + max_length=self.config.response_max_length, + return_tensors="pt", + truncation=True, + ) + response_tokens = tokenized_response["input_ids"].to(device=device) + response_masks = tokenized_response["attention_mask"].to(device=device, dtype=torch.bool) + + return response_tokens, response_masks + + def prepare_next_memory(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + """Tokenizes the target updated memory for training. + + Wraps each memory string with an ```` suffix, then tokenizes and + pads to ``memory_max_length``. + + Args: + batch: Batch containing ``"next_memory"`` (list of target memory + strings) and ``"state"`` (used only to determine the device). + + Returns: + A tuple ``(memory_tokens, memory_masks)`` where: + - memory_tokens: Token IDs of shape + ``(batch_size, memory_max_length)``. + - memory_masks: Boolean attention mask of the same shape + (``True`` for real tokens, ``False`` for padding). + """ + + device = batch["state"].device + next_memory = batch["next_memory"] + + # if '' is found in next_memory then it is not for loss calculation (used for robotic dataset with no subtask), so add pad token. + memory_prompt = [f"{mem}" for mem in next_memory] + + tokenized_memory = self.language_tokenizer.__call__( + memory_prompt, + padding="max_length", + padding_side="right", + max_length=self.config.memory_max_length, + return_tensors="pt", + truncation=True, + ) + memory_tokens = tokenized_memory["input_ids"].to(device=device) + memory_masks = tokenized_memory["attention_mask"].to(device=device, dtype=torch.bool) + + return memory_tokens, memory_masks + + +class PI07HighLevelPlannerModel(nn.Module): + """π07 High-Level Planner inner model. + + Uses the Gemma 3 VLM backbone to encode images and a composite language + prompt (task + past context) and optional episode metadata, with fixed tokenizer + spans ``";\\n "``, ``"Updated Memory: "``, and (in full training runs) ``"Subtask: "`` + before the predicted text, then autoregressively + predicts updated memory and subtask text: + + 1. **Updated memory** — next-token CE over ``memory_max_length`` slots after the + ``"Updated Memory: "`` span. + 2. **Subtask (response)** — next-token CE over ``response_max_length`` slots after + the ``"Subtask: "`` span (training). + + Inference mirrors training by inserting the live ``"Subtask: "`` token IDs into the + KV cache after memory decoding and before response decoding. + + Architecture (rough dataflow):: + + ┌───────────────────────────────────────────┐ + │ response content (subtask text) │ + │ ▲ │ + │ memory, ``Subtask: ``, lang, ``";\\n "``, images, … │ + │ ┌───────────────────────┐ │ + │ │ Gemma 3 │ │ + │ │ (autoregressive LM) │ │ + │ └────────────────────────┘ │ + └───────────────────────────────────────────┘ + + Args: + config: High-level planner configuration. + discrete_action_vocab_size: Vocabulary size for the discrete action + tokenizer (passed through to ``Gemma3WithExpertModel``). + """ + + def __init__(self, config: PI07HighLevelPlannerConfig, discrete_action_vocab_size: int | None = None): + """Initializes the PI07HighLevelPlannerModel. + + Args: + config: High-level planner configuration. + discrete_action_vocab_size: Vocabulary size for the discrete action + tokenizer (passed through to ``Gemma3WithExpertModel``). + """ + super().__init__() + self.config = config + + self.config.vlm_config.discrete_action_vocab_size = discrete_action_vocab_size + self.gemma3_with_expert = Gemma3WithExpertModel(self.config.vlm_config) + + self.language_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt") + + def embed_prefix( + self, + images: list[Tensor], + img_masks: list[Tensor], + lang_tokens: Tensor, + lang_masks: Tensor, + response_tokens: Tensor | None = None, + response_masks: Tensor | None = None, + memory_tokens: Tensor | None = None, + memory_masks: Tensor | None = None, + metadata_tokens: Tensor | None = None, + metadata_masks: Tensor | None = None, + ) -> tuple[Tensor, Tensor, Tensor]: + """Embeds and concatenates all prefix modalities for the transformer. + + Embeds images with SigLIP and language/metadata/memory/response spans with the + Gemma 3 embedding layer. **Concatenation order** (training when memory and response + are provided): + + ``[images | language | metadata | ";\\n " | "Updated Memory: " | memory_tokens | + "Subtask: " | response_tokens]`` + + When ``memory_tokens`` / ``response_tokens`` are omitted (inference), only the + fixed spans before those segments are present; memory and subtask text are filled in + via KV-cache decoding plus an explicit ``"Subtask: "`` injection before response AR. + + Attention pattern (via ``att_masks`` cumsums): + - Image + language tokens: bidirectional (``0``). + - Metadata (if present): new bidirectional block (``[1, 0, …, 0]``). + - ``";\\n "`` (same string as ``encode(";\n ", add_special_tokens=False)``): continues previous block (``0``). + - ``"Updated Memory: "``: new bidirectional block (``[1, 0, …, 0]``). + - Memory token slots: causal segment (``1`` per slot). + - ``"Subtask: "`` (training): new block then causal continuation within span. + - Response token slots: causal (``1`` per slot). + + Args: + images: List of image tensors, one per camera. + img_masks: List of boolean masks indicating real vs. padded images. + lang_tokens: Language token IDs of shape ``(B, prompt_max_length)``. + lang_masks: Boolean attention mask for language tokens. + response_tokens: Optional subtask response token IDs of shape + ``(B, response_max_length)``. Provided during training. + response_masks: Optional boolean mask for response tokens. + memory_tokens: Optional updated memory token IDs of shape + ``(B, memory_max_length)``. Provided during training. + memory_masks: Optional boolean mask for memory tokens. + metadata_tokens: Optional metadata token IDs of shape + ``(B, metadata_max_length)``. + metadata_masks: Optional boolean mask for metadata tokens. + + Returns: + A tuple ``(embs, pad_masks, att_masks)`` where: + - embs: Concatenated embeddings ``(B, total_seq_len, D)``. + - pad_masks: Boolean padding mask ``(B, total_seq_len)``. + - att_masks: 1-D attention pattern ``(B, total_seq_len)`` + used by :func:`make_att_2d_masks`. + """ + # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty + embs = [] + pad_masks = [] + att_masks = [] + + # TODO: remove for loop + for ( + img, + img_mask, + ) in zip(images, img_masks, strict=False): + img_emb = self.gemma3_with_expert.embed_image(img) + img_emb = img_emb.to(dtype=_preferred_dtype()) + + # Gemma 3's projector does not apply the `/ sqrt(text_hidden_size)` + # scaling that stock PaliGemma does, so no un-normalization is + # required here (matches `embed_image` in `gemma3_with_expert.py`). + + bsize, num_img_embs = img_emb.shape[:2] + img_mask = img_mask[:, None].expand(bsize, num_img_embs) + + embs.append(img_emb) + pad_masks.append(img_mask) + + # Create attention masks so that image tokens attend to each other + att_masks += [0] * num_img_embs + + lang_emb = self.gemma3_with_expert.embed_language_tokens(lang_tokens) + + # Normalize language embeddings + lang_emb_dim = lang_emb.shape[-1] + lang_emb = lang_emb * math.sqrt(lang_emb_dim) + + embs.append(lang_emb) + pad_masks.append(lang_masks) + + # full attention between image and language inputs + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + if metadata_tokens is not None: + metadata_emb = self.gemma3_with_expert.embed_language_tokens(metadata_tokens) + metadata_emb_dim = metadata_emb.shape[-1] + metadata_emb = metadata_emb * math.sqrt(metadata_emb_dim) + embs.append(metadata_emb) + pad_masks.append(metadata_masks) + att_masks += [1] + [0] * (metadata_emb.shape[1] - 1) + + prefix_end_indicator_ids = self.language_tokenizer.encode(";\n ", add_special_tokens=False) + prefix_end_tokens = torch.tensor( + [prefix_end_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) + prefix_end_dim = prefix_end_emb.shape[-1] + prefix_end_emb = prefix_end_emb * math.sqrt(prefix_end_dim) + + num_prefix_end_embs = prefix_end_emb.shape[1] + prefix_end_mask = torch.ones(bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device) + + embs.append(prefix_end_emb) + pad_masks.append(prefix_end_mask) + att_masks += [0] * num_prefix_end_embs + + memory_start_indicator_ids = self.language_tokenizer.encode( + "Updated Memory: ", add_special_tokens=False + ) + memory_start_tokens = torch.tensor( + [memory_start_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + memory_start_emb = self.gemma3_with_expert.embed_language_tokens(memory_start_tokens) + memory_start_dim = memory_start_emb.shape[-1] + memory_start_emb = memory_start_emb * math.sqrt(memory_start_dim) + + num_memory_start_embs = memory_start_emb.shape[1] + memory_start_mask = torch.ones( + bsize, num_memory_start_embs, dtype=torch.bool, device=lang_tokens.device + ) + + embs.append(memory_start_emb) + pad_masks.append(memory_start_mask) + att_masks += [1] + [0] * (num_memory_start_embs - 1) + + if memory_tokens is not None: + memory_emb = self.gemma3_with_expert.embed_language_tokens(memory_tokens) + # Normalize memory language embeddings + memory_emb_dim = memory_emb.shape[-1] + memory_emb = memory_emb * math.sqrt(memory_emb_dim) + + embs.append(memory_emb) + pad_masks.append(memory_masks) + + # full attention between image, language and memory inputs + num_memory_embs = memory_emb.shape[1] + att_masks += [1] * num_memory_embs + + if response_tokens is not None: + response_start_indicator_ids = self.language_tokenizer.encode( + "Subtask: ", add_special_tokens=False + ) + response_start_tokens = torch.tensor( + [response_start_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + response_start_emb = self.gemma3_with_expert.embed_language_tokens(response_start_tokens) + response_start_dim = response_start_emb.shape[-1] + response_start_emb = response_start_emb * math.sqrt(response_start_dim) + + num_response_start_embs = response_start_emb.shape[1] + response_start_mask = torch.ones( + bsize, num_response_start_embs, dtype=torch.bool, device=lang_tokens.device + ) + + embs.append(response_start_emb) + pad_masks.append(response_start_mask) + att_masks += [1] + [0] * (num_response_start_embs - 1) + + response_emb = self.gemma3_with_expert.embed_language_tokens(response_tokens) + + # Normalize response language embeddings + response_emb_dim = response_emb.shape[-1] + response_emb = response_emb * math.sqrt(response_emb_dim) + + embs.append(response_emb) + pad_masks.append(response_masks) + + # full attention between image, language and response inputs + num_response_embs = response_emb.shape[1] + att_masks += [1] * num_response_embs + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def forward( + self, + images: list[Tensor], + img_masks: list[Tensor], + lang_tokens: Tensor, + lang_masks: Tensor, + response_tokens: Tensor | None = None, + response_masks: Tensor | None = None, + memory_tokens: Tensor | None = None, + memory_masks: Tensor | None = None, + metadata_tokens: Tensor | None = None, + metadata_masks: Tensor | None = None, + ) -> dict[str, Tensor]: + """Training forward pass: embeds all modalities and computes CE losses. + + The prefix matches :meth:`embed_prefix` when memory and response tensors are set: + fixed separators ``";\\n "``, ``"Updated Memory: "``, and ``"Subtask: "`` appear + in addition to ``metadata``, ``memory_tokens``, and ``response_tokens``. CE slices use + negative offsets from the **sequence tail**, relying on + ``config.subtask_indicator_max_length`` so memory logits align with memory contents + even though ``"Subtask: "`` sits between memory and response text. + + Args: + images: List of image tensors, one per camera. + img_masks: List of boolean masks for real vs. padded images. + lang_tokens: Language token IDs ``(B, prompt_max_length)``. + lang_masks: Boolean attention mask for language tokens. + response_tokens: Subtask response token IDs + ``(B, response_max_length)``. + response_masks: Boolean mask for response tokens. + memory_tokens: Updated memory token IDs + ``(B, memory_max_length)``. + memory_masks: Boolean mask for memory tokens. + metadata_tokens: Optional metadata token IDs + ``(B, metadata_max_length)``. + metadata_masks: Optional boolean mask for metadata tokens. + + Returns: + A dict with ``"MSE"`` (zero tensor, for interface compatibility) + and ``"CE"`` (sum of memory and response cross-entropy losses). + """ + # Run VLM first to get key value cache + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, + img_masks, + lang_tokens, + lang_masks, + response_tokens, + response_masks, + memory_tokens, + memory_masks, + metadata_tokens, + metadata_masks, + ) + + vlm_2d_attention_mask = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + vlm_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + # avoids using discrete action for predicting continuous flow matching action + num_cross_att_tokens = prefix_embs.shape[1] + + (prefix_out, _), past_key_values = self.gemma3_with_expert.forward( + attention_mask=vlm_2d_attention_mask, + position_ids=vlm_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + n_cross_att_tokens=num_cross_att_tokens, + use_cache=False, + fill_kv_cache=True, + ) + + batch_size, seq_len = response_tokens.shape + response_token_start = -self.config.response_max_length + # Slice covers only response **content** slots at the tail (after ``Subtask: ``). + response_token_end = -1 + response_slice_object = slice(response_token_start, response_token_end) + response_out = prefix_out[ + :, + response_slice_object, + ] + response_logits = self.gemma3_with_expert._lm_head()(response_out) + # response slice to exclude the token from response while calculating loss. + response_slice = slice(1, None) + response_logits = response_logits.to(dtype=torch.float32) # upcast to float32 for loss calculation + response_logits = rearrange(response_logits, "b s d -> (b s) d") + response_labels = rearrange(response_tokens[:, response_slice], "b s -> (b s)") + response_ce_loss = F.cross_entropy(response_logits, response_labels, reduction="none") + + response_ce_loss = rearrange(response_ce_loss, "(b s) -> b s", b=batch_size, s=seq_len - 1) + + # remove pad tokens + response_is_pad = ~response_masks # convert into format where value for pad is True + # helps to control loss for response tokens in case of robotic data and VQA data + response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice] + + # compute mean + response_ce_loss = response_ce_loss.mean() + + batch_size, seq_len = memory_tokens.shape + memory_token_start = ( + -self.config.memory_max_length + - self.config.response_max_length + - self.config.subtask_indicator_max_length + ) + # Memory **content** span: immediately after ``Subtask: `` and response (from the end). + memory_token_end = -self.config.response_max_length - self.config.subtask_indicator_max_length - 1 + memory_slice_object = slice(memory_token_start, memory_token_end) + memory_out = prefix_out[ + :, + memory_slice_object, + ] + memory_logits = self.gemma3_with_expert._lm_head()(memory_out) + # memory slice to exclude the token from memory while calculating loss. + memory_slice = slice(1, None) + memory_logits = memory_logits.to(dtype=torch.float32) # upcast to float32 for loss calculation + memory_logits = rearrange(memory_logits, "b s d -> (b s) d") + memory_labels = rearrange(memory_tokens[:, memory_slice], "b s -> (b s)") + memory_ce_loss = F.cross_entropy(memory_logits, memory_labels, reduction="none") + + memory_ce_loss = rearrange(memory_ce_loss, "(b s) -> b s", b=batch_size, s=seq_len - 1) + + # remove pad tokens + memory_is_pad = ~memory_masks # convert into format where value for pad is True + # helps to control loss for memory tokens in case of robotic data and VQA data + memory_ce_loss = memory_ce_loss * ~memory_is_pad[:, memory_slice] + + # compute mean + memory_ce_loss = memory_ce_loss.mean() + + ce_loss = response_ce_loss + memory_ce_loss + + return {"MSE": torch.zeros_like(ce_loss, requires_grad=False), "CE": ce_loss} + + def sample_actions( + self, + images: list[Tensor], + img_masks: list[Tensor], + lang_tokens: Tensor, + lang_masks: Tensor, + metadata_tokens: Tensor | None = None, + metadata_masks: Tensor | None = None, + ) -> tuple[Tensor, Tensor]: + """Inference forward: autoregressively generates memory and subtask tokens. + + Runs ``memory_max_length`` ``infer_autoregressive`` steps, then feeds the same + ``"Subtask: "`` token IDs used in training (tokenizer-dependent length + ``subtask_indicator_max_length``) through the cache, then runs ``response_max_length`` + response steps. Each step conditions on prior KV-cache entries. + + Args: + images: List of image tensors, one per camera. + img_masks: List of boolean masks for real vs. padded images. + lang_tokens: Language token IDs ``(B, prompt_max_length)``. + lang_masks: Boolean attention mask for language tokens. + metadata_tokens: Optional metadata token IDs + ``(B, metadata_max_length)``. + metadata_masks: Optional boolean mask for metadata tokens. + + Returns: + A tuple ``(memory_tokens, response_tokens)`` where each is a + ``Tensor`` of generated token IDs with shape + ``(B, memory_max_length)`` and ``(B, response_max_length)`` + respectively. + """ + bsize = lang_tokens.shape[0] + device = lang_tokens.device + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, + img_masks, + lang_tokens, + lang_masks, + metadata_tokens=metadata_tokens, + metadata_masks=metadata_masks, + ) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] - 1 + + num_cross_att_tokens = prefix_embs.shape[1] + + # Compute image and language key value cache + (prefix_out, _), past_key_values = self.gemma3_with_expert.forward( + attention_mask=prefix_att_2d_masks, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + n_cross_att_tokens=num_cross_att_tokens, + use_cache=False, + fill_kv_cache=True, + ) + + # initialize memory tokens to empty tensor for storing memory tokens during inference + memory_tokens = torch.empty((bsize, 0), device=device, dtype=torch.long) + # if memory prediction is enabled, then predict memory tokens autoregressively + for auto_step in range(self.config.memory_max_length): + ( + prefix_out, + prefix_embs, + prefix_pad_masks, + prefix_att_masks, + prefix_offsets, + memory_tokens, + past_key_values, + ) = self.infer_autoregressive( + prefix_out=prefix_out, + prefix_embs=prefix_embs, + prefix_pad_masks=prefix_pad_masks, + prefix_att_masks=prefix_att_masks, + past_key_values=past_key_values, + prefix_offsets=prefix_offsets, + tokens=memory_tokens, + auto_step=auto_step, + bsize=bsize, + device=device, + ) + + # Match training `embed_prefix`: "Subtask: " must be in the KV cache before subtask + # autoregression (inference does not call `embed_prefix` with `response_tokens`). + response_start_indicator_ids = self.language_tokenizer.encode("Subtask: ", add_special_tokens=False) + for i, tid in enumerate(response_start_indicator_ids): + token = torch.full((bsize, 1), int(tid), device=device, dtype=torch.long) + emb = self.gemma3_with_expert.embed_language_tokens(token) + emb = emb * math.sqrt(emb.shape[-1]) + pad_row = torch.ones((bsize, 1), device=device, dtype=prefix_pad_masks.dtype) + if prefix_att_masks.dtype == torch.bool: + new_att = torch.full((bsize, 1), i == 0, device=device, dtype=torch.bool) + else: + new_att = torch.full( + (bsize, 1), + 1.0 if i == 0 else 0.0, + device=device, + dtype=prefix_att_masks.dtype, + ) + prefix_embs = torch.cat([prefix_embs, emb], dim=1) + prefix_pad_masks = torch.cat([prefix_pad_masks, pad_row], dim=1) + prefix_att_masks = torch.cat([prefix_att_masks, new_att], dim=1) + num_cross = prefix_pad_masks.shape[1] + att_2d_masks = make_att_2d_masks( + pad_row, + new_att, + n_cross_att_tokens=num_cross - 1, + cross_att_pad_masks=prefix_pad_masks[:, : num_cross - 1], + ) + prefix_offsets = prefix_offsets + pad_row.long() + (prefix_out, _), past_key_values = self.gemma3_with_expert.forward( + attention_mask=att_2d_masks, + position_ids=prefix_offsets, + past_key_values=past_key_values, + inputs_embeds=[emb, None], + n_cross_att_tokens=num_cross, + use_cache=True, + fill_kv_cache=True, + ) + + # initialize response tokens to empty tensor for storing response tokens during inference + response_tokens = torch.empty((bsize, 0), device=device, dtype=torch.long) + # if response prediction is enabled, then predict response tokens autoregressively + for auto_step in range(self.config.response_max_length): + ( + prefix_out, + prefix_embs, + prefix_pad_masks, + prefix_att_masks, + prefix_offsets, + response_tokens, + past_key_values, + ) = self.infer_autoregressive( + prefix_out=prefix_out, + prefix_embs=prefix_embs, + prefix_pad_masks=prefix_pad_masks, + prefix_att_masks=prefix_att_masks, + past_key_values=past_key_values, + prefix_offsets=prefix_offsets, + tokens=response_tokens, + auto_step=auto_step, + bsize=bsize, + device=device, + ) + + return memory_tokens, response_tokens + + def infer_autoregressive( + self, + prefix_out: Tensor, + prefix_embs: Tensor, + prefix_pad_masks: Tensor, + prefix_att_masks: Tensor, + past_key_values: list[dict[str, Tensor]], + prefix_offsets: Tensor, + tokens: Tensor, + auto_step: int, + bsize: int, + device: torch.device, + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, list[dict[str, Tensor]]]: + """Performs one autoregressive generation step. + + At ``auto_step == 0`` a ```` token seeds the generation; on + subsequent steps the most-recent logits are argmax-decoded into the + next token. Once an ```` or ```` token appears in the + accumulated sequence the remaining positions are filled with padding. + + The method updates the KV-cache, prefix embeddings, and masks so that + the next call can attend to all previously generated tokens. + + Args: + prefix_out: Transformer output from the previous step + ``(B, 1, D)`` or ``(B, seq, D)`` on the first call. + prefix_embs: Running concatenation of all embeddings fed to the + transformer so far ``(B, current_seq, D)``. + prefix_pad_masks: Boolean padding mask ``(B, current_seq)``. + prefix_att_masks: 1-D attention pattern ``(B, current_seq)``. + past_key_values: KV-cache list from previous transformer calls. + prefix_offsets: Position ID offsets ``(B, 1)`` tracking the + current absolute position for each batch element. + tokens: Accumulated generated token IDs ``(B, steps_so_far)``. + auto_step: Current step index (0-based). + bsize: Batch size. + device: Torch device for tensor creation. + + Returns: + A tuple of updated state tensors for the next step: + ``(prefix_out, prefix_embs, prefix_pad_masks, prefix_att_masks, + prefix_offsets, tokens, past_key_values)``. + """ + EOS_TOKEN = self.language_tokenizer.convert_tokens_to_ids(self.language_tokenizer.eos_token) # noqa: N806 + if auto_step == 0: + # Start the autoregressive inference with token + token = torch.full( + (bsize, 1), + self.language_tokenizer.bos_token_id, + device=device, + dtype=torch.long, + ) + else: + # get the last predicted token from the prefix output which is predicted response + token = prefix_out[:, -1:] + token = self.gemma3_with_expert._lm_head()(token).argmax(dim=-1) + + PAD_TOKEN = self.language_tokenizer.pad_token_id # noqa: N806 + # Create pad masks: False if previous token was EOS or PAD + if tokens.shape[1] > 1: + prev_tokens = tokens + has_eos = (prev_tokens == EOS_TOKEN).any(dim=1, keepdim=True) + has_pad = (prev_tokens == PAD_TOKEN).any(dim=1, keepdim=True) + # check if the previous token was EOS or PAD. If so, then the current token should be padded, so its not attended by flow matching action expert. + pad_masks = ~(has_eos | has_pad) + token = torch.where( + pad_masks, + token, + torch.tensor(PAD_TOKEN, device=device, dtype=token.dtype), + ) + else: + pad_masks = torch.ones((bsize, 1), device=device, dtype=torch.bool) + + # Updating response tokens with current predicted token + tokens = torch.cat([tokens, token], dim=1) + + # Embed the current predicted token + emb = self.gemma3_with_expert.embed_language_tokens(token) + + # Normalize response language embeddings + emb_dim = emb.shape[-1] + emb = emb * math.sqrt(emb_dim) + + att_masks = torch.ones((bsize, 1), device=device, dtype=emb.dtype) + + # update the prefix embs, pad masks and att masks, so it can be used by action experts + prefix_embs = torch.cat([prefix_embs, emb], dim=1) + prefix_pad_masks = torch.cat([prefix_pad_masks, pad_masks], dim=1) + prefix_att_masks = torch.cat([prefix_att_masks, att_masks], dim=1) + + num_cross_att_tokens = prefix_pad_masks.shape[1] + # create the attention mask for the response tokens + att_2d_masks = make_att_2d_masks( + pad_masks, + att_masks, + n_cross_att_tokens=num_cross_att_tokens - 1, + cross_att_pad_masks=prefix_pad_masks[:, : num_cross_att_tokens - 1], + ) + prefix_offsets = prefix_offsets + pad_masks.long() + prefix_position_ids = prefix_offsets + + # Compute image and language key value cache + (prefix_out, _), past_key_values = self.gemma3_with_expert.forward( + attention_mask=att_2d_masks, + position_ids=prefix_position_ids, + past_key_values=past_key_values, + inputs_embeds=[emb, None], + n_cross_att_tokens=num_cross_att_tokens, + use_cache=True, + fill_kv_cache=True, + ) + + return ( + prefix_out, + prefix_embs, + prefix_pad_masks, + prefix_att_masks, + prefix_offsets, + tokens, + past_key_values, + ) diff --git a/src/opentau/policies/pi07/low_level_planner/__init__.py b/src/opentau/policies/pi07/low_level_planner/__init__.py new file mode 100644 index 00000000..6c76fe78 --- /dev/null +++ b/src/opentau/policies/pi07/low_level_planner/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +π07 Low-Level Planner Module. + +This module implements the low-level planner of the π07 hierarchical +architecture. It uses SpaceTimeSiglip as a video encoder, processes temporal state +sequences (one continuous token per timestep), and supports optional subtask +response, subgoal image, and metadata conditioning. Action generation +combines flow matching (continuous actions via an action expert) with FAST +discrete token prediction (through the VLM backbone with Knowledge +Insulation). +""" diff --git a/src/opentau/policies/pi07/low_level_planner/configuration_pi07_low_level.py b/src/opentau/policies/pi07/low_level_planner/configuration_pi07_low_level.py new file mode 100644 index 00000000..d4e01d6f --- /dev/null +++ b/src/opentau/policies/pi07/low_level_planner/configuration_pi07_low_level.py @@ -0,0 +1,263 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration module for the π07 Low-Level Planner. + +This module defines the ``PI07LowLevelPlannerConfig`` class, which handles +configuration parameters for the π07 low-level planner. This planner uses +SpaceTimeSiglip (the Gemma 3 SigLIP vision tower wrapped with space-time +separable attention) as a video encoder, processes temporal state sequences +(one continuous token per timestep), and supports optional subtask response, +subgoal image, and metadata conditioning. +""" + +from dataclasses import dataclass, field + +from opentau.configs.policies import PreTrainedConfig +from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature +from opentau.optim.optimizers import AdamWConfig +from opentau.optim.schedulers import ( + CosineDecayWithWarmupSchedulerConfig, + LRSchedulerConfig, +) +from opentau.policies.pi07.gemma3_with_expert import Gemma3WithExpertConfig + + +@PreTrainedConfig.register_subclass("pi07_low_level") +@dataclass +class PI07LowLevelPlannerConfig(PreTrainedConfig): + """Configuration for the π07 low-level planner. + + The low-level planner generates continuous action chunks via flow matching + and discrete FAST action tokens through the VLM backbone. It uses + :class:`SpaceTimeSiglipVideoEncoder` (the Gemma 3 SigLIP vision tower + wrapped with space-time separable attention) as a video encoder and + projects temporal state sequences into per-timestep continuous tokens. + + Args: + n_obs_steps: Number of temporal video frames passed to the + SpaceTimeSiglip encoder per forward call. Must equal + ``n_obs_history`` when the latter is set. + chunk_size: Size of the action chunk (upper bound for + ``n_action_steps``). Defaults to 50. + n_action_steps: Number of action steps to predict. Defaults to 50. + normalization_mapping: Mapping of feature names to normalization modes. + max_state_dim: Maximum dimension for state vectors. Defaults to 32. + max_action_dim: Maximum dimension for action vectors. Defaults to 32. + resize_imgs_with_padding: Target ``(H, W)`` for video frame resizing + with padding. Must match the Gemma 3 vision tower's + ``image_size`` (the projector hardcodes + ``patches_per_image = image_size // patch_size``). Defaults to + ``(448, 448)`` to match + ``vlm_config.gemma3_config.vision_config.image_size``. + empty_cameras: Number of empty camera inputs to add. Defaults to 0. + prompt_max_length: Maximum token length for language prompts. + Defaults to 256. + discrete_action_max_length: Maximum length for FAST action tokens. + Defaults to 32. + metadata_max_length: Maximum token length for metadata strings. + Defaults to 52. + response_max_length: Maximum token length for high-level planner + subtask responses. Defaults to 52. + proj_width: Width of the action projection layer. Must match the + Gemma-v1 action expert's ``hidden_size`` so that suffix embeddings + are compatible with the expert's transformer layers. Defaults to + 1280 (the expert hidden size in ``Gemma3WithExpertConfig``). + dropout: Dropout rate. Defaults to 0.1. + num_steps: Number of flow-matching denoising steps. Defaults to 10. + max_delay: Maximum number of prefix action steps for real-time + inference. Defaults to 0. + attention_implementation: Attention backend (``"eager"`` or + ``"fa2"``). Defaults to ``"eager"``. + freeze_vision_encoder: Whether to freeze the SigLIP vision tower. + Defaults to True. + train_expert_only: Whether to train only the action expert. + Defaults to False. + vlm_config: Bundled :class:`Gemma3WithExpertConfig` for the Gemma 3 + VLM backbone + Gemma-v1 action expert. + spacetime_layer_stride: Wrap every Nth SigLIP encoder layer with + space-time separable attention. ``1`` wraps every layer. + gradient_checkpointing: If True, wrap each SigLIP encoder layer in + ``torch.utils.checkpoint.checkpoint`` during training. + """ + + # Input / output structure. + n_obs_steps: int = 8 + chunk_size: int = 50 + n_action_steps: int = 50 + + # Observation history for inference buffering. + # ``n_obs_history`` controls how many evenly-spaced historical frames the + # inference buffer keeps. ``history_interval`` is the stride between those + # frames. Together they determine ``obs_buffer_size = (n_obs_history-1) * + # history_interval + 1``. Typically ``n_obs_history`` should equal + # ``n_obs_steps`` so the SpaceTimeSiglip encoder sees the same number of + # frames at training and inference time. + # Populated from DatasetMixtureConfig during training if unset. + n_obs_history: int | None = None + history_interval: int | None = None + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + max_state_dim: int = 32 + max_action_dim: int = 32 + + # Video frame preprocessing. Must equal the Gemma 3 vision tower's + # image_size: `Gemma3MultiModalProjector` hardcodes + # `patches_per_image = image_size // patch_size`, so feeding a different + # resolution crashes the projector's reshape. + resize_imgs_with_padding: tuple[int, int] = (448, 448) + + empty_cameras: int = 0 + + # Language Tokenizer + prompt_max_length: int = 256 + + # Maximum length of the action tokens + discrete_action_max_length: int = 32 + + metadata_max_length: int = 52 + + response_max_length: int = 52 + + # Projector + proj_width: int = 1280 + + # Dropout + dropout: float = 0.1 + + # Decoding + num_steps: int = 10 + + # Real Time Inference + max_delay: int = 0 + + # Attention utils + attention_implementation: str = "eager" + + # Finetuning settings + freeze_vision_encoder: bool = True + train_expert_only: bool = False + + vlm_config: Gemma3WithExpertConfig = field( + default_factory=lambda: Gemma3WithExpertConfig( + freeze_vision_encoder=True, + train_expert_only=False, + attention_implementation="eager", + load_pretrained_gemma3=False, + dropout=0.1, + ) + ) + + # SpaceTime settings + spacetime_layer_stride: int = 1 + gradient_checkpointing: bool = False + + # Training presets + optimizer_lr: float = 2.5e-5 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-10 + + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + @property + def obs_buffer_size(self) -> int: + """Total raw frames the observation buffer must keep. + + With ``n_obs_history=T`` and ``history_interval=k``, the buffer stores + the most recent ``(T-1)*k + 1`` frames so that ``T`` evenly-spaced + frames can be selected. + """ + if self.n_obs_history is None or self.n_obs_history <= 1: + return 1 + return (self.n_obs_history - 1) * (self.history_interval or 1) + 1 + + def __post_init__(self): + """Post-initialization validation.""" + super().__post_init__() + + if self.n_obs_history is not None: + if not isinstance(self.n_obs_history, int) or self.n_obs_history < 1: + raise ValueError( + f"`n_obs_history` must be None or a positive integer, got {self.n_obs_history}." + ) + if self.history_interval is None: + self.history_interval = 1 + if self.history_interval is not None and ( + not isinstance(self.history_interval, int) or self.history_interval < 1 + ): + raise ValueError( + f"`history_interval` must be None or a positive integer, got {self.history_interval}." + ) + + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + + if self.max_delay > self.chunk_size: + raise ValueError( + f"The max delay must be less than or equal to the chunk size. Got {self.max_delay} for `max_delay` and {self.chunk_size} for `chunk_size`." + ) + + def validate_features(self) -> None: + """Validates the features and adds empty cameras if configured.""" + for i in range(self.empty_cameras): + key = f"observation.images.empty_camera_{i}" + empty_camera = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 480, 640), + ) + self.input_features[key] = empty_camera + + def get_optimizer_preset(self) -> AdamWConfig: + """Returns the default optimizer configuration.""" + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self) -> LRSchedulerConfig: + """Returns the default scheduler configuration.""" + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list[int]: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py new file mode 100644 index 00000000..57ab1e70 --- /dev/null +++ b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py @@ -0,0 +1,1740 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""π07 Low-Level Planner: a Vision-Language-Action Flow Model for continuous +action generation. + +The low-level planner is one half of the π07 hierarchical architecture. +Given video observations (encoded by SpaceTimeSiglip), a language prompt, an +optional subtask response from the high-level planner, temporal +proprioceptive state, optional subgoal images, and optional metadata, it +produces continuous action chunks via flow matching while simultaneously +predicting discrete (FAST) action tokens through the VLM backbone. + +Key differences from the base π05 policy: + 1. SpaceTimeSiglip video encoder replaces SigLIP, compressing each camera's + temporal video into 256 tokens via a Perceiver cross-attention reducer. + 2. Temporal state sequences (B, T, D) are projected per-timestep into + separate continuous tokens for the Gemma backbone. + 3. Supports optional subtask response, subgoal image, and metadata + conditioning for hierarchical planning. + 4. Knowledge Insulation: action-expert gradients are detached from the + VLM backbone to preserve language understanding capabilities. +""" + +import builtins +import logging +import math +from collections import deque +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +from einops import rearrange, repeat +from torch import Tensor, nn +from transformers import AutoProcessor, AutoTokenizer + +from opentau.configs.policies import PreTrainedConfig +from opentau.configs.types import NormalizationMode +from opentau.policies.normalize import Normalize, Unnormalize +from opentau.policies.pi07.gemma3_with_expert import ( + Gemma3WithExpertModel, +) +from opentau.policies.pi07.low_level_planner.configuration_pi07_low_level import ( + PI07LowLevelPlannerConfig, +) +from opentau.policies.pi07.low_level_planner.video_encoder import SpaceTimeSiglipVideoEncoder +from opentau.policies.pretrained import PreTrainedPolicy, T +from opentau.utils.accelerate_utils import get_proc_accelerator +from opentau.utils.utils import get_safe_dtype + + +def _preferred_dtype(): + return torch.float32 if torch.onnx.is_in_onnx_export() else torch.bfloat16 + + +def create_sinusoidal_pos_embedding( + time: Tensor, dimension: int, min_period: float, max_period: float, device: torch.device | str = "cpu" +) -> Tensor: + """Computes sine-cosine positional embedding vectors for scalar positions. + + Args: + time: A 2-D tensor of shape (batch_size, action_chunk_length). + dimension: The dimension of the embedding vectors. Must be divisible by 2. + min_period: The minimum period of the sinusoidal functions. + max_period: The maximum period of the sinusoidal functions. + device: The device to create the tensors on. Defaults to "cpu". + + Returns: + A tensor of shape (batch_size, action_chunk_length, dimension). + """ + if dimension % 2 != 0: + raise ValueError(f"dimension ({dimension}) must be divisible by 2") + + if time.ndim != 2: + raise ValueError("The time tensor is expected to be of shape `(batch_size, action_chunk_length)`.") + + dtype = ( + get_safe_dtype(torch.float64, device.type) + if isinstance(device, torch.device) + else get_safe_dtype(torch.float64, device) + ) + fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) + period = min_period * (max_period / min_period) ** fraction + + scaling_factor = 1.0 / period * 2 * math.pi + sin_input = rearrange(scaling_factor, "d -> 1 1 d") * rearrange(time, "b c -> b c 1") + pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=2) + return pos_emb + + +def make_att_2d_masks( + pad_masks: Tensor, + att_masks: Tensor, + n_cross_att_tokens: int | None = None, + cross_att_pad_masks: Tensor | None = None, +) -> Tensor: + """Creates a 2-D attention mask given padding and 1-D attention masks. + + Args: + pad_masks: bool[B, N] true if its part of the input, false if padding. + att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. + n_cross_att_tokens: Add attention mask for cross-attention tokens if provided. + cross_att_pad_masks: Padding masks for cross attention tokens. + + Returns: + A 2D attention mask tensor. + """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + att_2d_masks = att_2d_masks & pad_2d_masks + + if n_cross_att_tokens is not None: + assert cross_att_pad_masks is not None, ( + "cross_att_pad_masks must be provided if n_cross_att_tokens is provided" + ) + assert cross_att_pad_masks.shape == (att_masks.size(0), n_cross_att_tokens), ( + "cross_att_pad_masks must have shape (batch_size, n_cross_att_tokens)" + ) + + cross_att_mask = torch.full( + (att_masks.size(0), att_masks.size(1), n_cross_att_tokens), + True, + dtype=torch.bool, + device=att_masks.device, + ) + + cross_att_mask = cross_att_mask & pad_masks[:, :, None] & cross_att_pad_masks[:, None, :] + att_2d_masks = torch.cat((cross_att_mask, att_2d_masks), dim=2) + + return att_2d_masks + + +def resize_with_pad(img: Tensor, width: int, height: int, pad_value: int = -1) -> Tensor: + """Resizes an image to fit within the specified dimensions while maintaining aspect ratio, + and pads the remaining area. + + Args: + img: Input image tensor of shape (batch_size, channels, current_height, current_width). + width: Target width. + height: Target height. + pad_value: Value to use for padding. Defaults to -1. + + Returns: + The resized and padded image tensor of shape (batch_size, channels, height, width). + """ + if img.ndim != 4: + raise ValueError(f"(b,c,h,w) expected, but {img.shape}") + + cur_height, cur_width = img.shape[2:] + + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + resized_img = F.interpolate( + img, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + pad_height = max(0, int(height - resized_height)) + pad_width = max(0, int(width - resized_width)) + + padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) + return padded_img + + +def pad_discrete_tokens(tokens: list[list[int]], max_length: int) -> tuple[np.ndarray, np.ndarray]: + """Pads or truncates a list of discrete action token sequences to a fixed length. + + Args: + tokens: A list of discrete action token sequences. + max_length: The target length. + + Returns: + A tuple of (discrete_action_tokens, discrete_action_masks) numpy arrays. + """ + discrete_action_tokens = [] + discrete_action_masks = [] + for token in tokens: + if len(token) > max_length: + logging.warning( + f"Discrete action token length {len(token)} is greater than max_length {max_length}, truncating" + ) + discrete_action_tokens.append(np.array(token[:max_length])) + discrete_action_masks.append(np.ones(max_length, dtype=bool)) + else: + discrete_action_masks.append( + np.concatenate( + [np.ones(len(token), dtype=bool), np.zeros(max_length - len(token), dtype=bool)] + ) + ) + discrete_action_tokens.append(np.pad(token, (0, max_length - len(token)), constant_values=0)) + return np.array(discrete_action_tokens), np.array(discrete_action_masks) + + +# Policy wrapper +class PI07LowLevelPlannerPolicy(PreTrainedPolicy): + """Policy wrapper for the π07 low-level planner. + + Handles tokenization, normalization, observation-history buffering, and + action queue management around the inner + :class:`PI07LowLevelPlannerFlowMatching` model. + """ + + config_class = PI07LowLevelPlannerConfig + name = "pi07_low_level" + + def __init__( + self, + config: PI07LowLevelPlannerConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__(config) + config.validate_features() + self.config = config + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.normalize_discrete_actions = Normalize( + config.output_features, {"ACTION": NormalizationMode.MIN_MAX}, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.language_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt") + + self.discrete_action_processor = AutoProcessor.from_pretrained( + "physical-intelligence/fast", trust_remote_code=True + ) + discrete_action_vocab_size = getattr(self.discrete_action_processor, "vocab_size", None) + self.model = PI07LowLevelPlannerFlowMatching( + config, discrete_action_vocab_size=discrete_action_vocab_size + ) + + self.reset() + + def reset(self) -> None: + """This should be called whenever the environment is reset.""" + self._action_queue = deque([], maxlen=self.config.n_action_steps) + # Observation history buffers for inference. + self._obs_buffers: dict[str, deque] = {} + self._state_buffer: deque | None = None + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: + """Override the from_pretrained method to handle key remapping.""" + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") + + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download if resume_download is not None else False, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + + model = cls(config, **kwargs) + + acc = get_proc_accelerator() + is_main_process = acc.is_main_process if acc else True + try: + if is_main_process: + logging.info("Loading model from: %s", pretrained_name_or_path) + try: + from transformers.utils.hub import cached_file + + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + ) + assert resolved_file is not None, "cached_file returned None" + from safetensors.torch import load_file + + original_state_dict = load_file(resolved_file) + if is_main_process: + logging.info("Loaded state dict from model.safetensors") + except Exception as e: + if is_main_process: + logging.warning("Could not load state dict from remote files: %s", e) + logging.info("Returning model without loading pretrained weights") + return model + + fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) + + remapped_state_dict = {} + remap_count = 0 + + for key, value in fixed_state_dict.items(): + if not key.startswith("model.") and "normalize" not in key: + new_key = f"model.{key}" + remapped_state_dict[new_key] = value + remap_count += 1 + if remap_count <= 10 and is_main_process: + logging.debug("Remapped: %s -> %s", key, new_key) + else: + remapped_state_dict[key] = value + + if remap_count > 0 and is_main_process: + logging.info("Remapped %d state dict keys", remap_count) + + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False) + + if missing_keys and is_main_process: + logging.warning("Missing keys when loading state dict: %d keys", len(missing_keys)) + for key in missing_keys[:20]: + logging.warning(" - %s", key) + if len(missing_keys) > 20: + logging.warning(" ... and %d more", len(missing_keys) - 20) + + if unexpected_keys and is_main_process: + logging.warning("Unexpected keys when loading state dict: %d keys", len(unexpected_keys)) + for key in unexpected_keys[:20]: + logging.warning(" - %s", key) + if len(unexpected_keys) > 20: + logging.warning(" ... and %d more", len(unexpected_keys) - 20) + + if not missing_keys and not unexpected_keys and is_main_process: + logging.info("All keys loaded successfully!") + + except Exception as e: + if is_main_process: + logging.warning("Could not remap state dict keys: %s", e) + + return model + + def _fix_pytorch_state_dict_keys( + self, state_dict: dict[str, Tensor], model_config: PreTrainedConfig + ) -> dict[str, Tensor]: + """Fix state dict keys to match current model architecture.""" + import re + + fixed_state_dict = {} + + for key, value in state_dict.items(): + new_key = key + + if re.match( + r"gemma3_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", + key, + ): + expert_uses_adarms = getattr( + self.model.gemma3_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") + continue + + if re.match(r"gemma3_with_expert\.gemma_expert\.model\.norm\.weight", key): + expert_uses_adarms = getattr( + self.model.gemma3_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") + continue + + if key.startswith("action_time_mlp_in."): + new_key = key.replace("action_time_mlp_in.", "time_mlp_in.") + elif key.startswith("action_time_mlp_out."): + new_key = key.replace("action_time_mlp_out.", "time_mlp_out.") + if "patch_embedding" in key: + logging.warning(f"Vision embedding key might need handling: {key}") + + fixed_state_dict[new_key] = value + + return fixed_state_dict + + def get_optim_params(self) -> dict: + return self.parameters() + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + raise NotImplementedError("Currently not implemented for PI07 low-level planner") + + def _build_history_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Buffer the current observation and construct a temporal batch. + + Appends the single-frame observation from ``batch`` to internal deque + buffers, then assembles a batch with ``n_obs_history`` evenly-spaced + frames (interval = ``history_interval``). Early in an episode, missing + history slots are zero-padded. + + Expected batch keys: + - ``"state"``: (B, D) current proprioceptive state. + - image keys matching ``config.image_features``: (B, C, H, W) camera frames. + - ``"prompt"``: list[str] language instructions (passed through unchanged). + - Any other metadata keys are forwarded unchanged. + + Returns a new dict with ``"state"`` expanded to (B, T, D) and image keys + expanded to (B, T, C, H, W), where T = ``n_obs_history``. + """ + assert self.config.n_obs_history is not None + n_hist: int = self.config.n_obs_history + interval = self.config.history_interval or 1 + buf_maxlen = self.config.obs_buffer_size + + # initialise buffers on first call after reset() + if self._state_buffer is None: + self._state_buffer = deque(maxlen=buf_maxlen) + self._obs_buffers = {} + + img_keys = [key for key in self.config.image_features if key in batch] + for key in img_keys: + if key not in self._obs_buffers: + self._obs_buffers[key] = deque(maxlen=buf_maxlen) + + # append current observation + self._state_buffer.append(batch["state"]) # (B, D) + for key in img_keys: + self._obs_buffers[key].append(batch[key]) # (B, C, H, W) + + # sample n_hist frames at the configured interval + buf_len = len(self._state_buffer) + missing = buf_maxlen - buf_len # how many slots are still empty + + # Pass through all non-image, non-state keys (e.g. "prompt" and other metadata). + temporal_batch = {key: v for key, v in batch.items() if key not in img_keys and key != "state"} + + # Build state tensor (B, T, D) + state_frames = [] + for i in range(n_hist): + idx = i * interval - missing # index into current buffer + if idx < 0: + state_frames.append(torch.zeros_like(self._state_buffer[0])) + else: + state_frames.append(self._state_buffer[idx]) + temporal_batch["state"] = torch.stack(state_frames, dim=1) # (B, T, D) + + # Build camera tensors (B, T, C, H, W) + for key in img_keys: + cam_frames = [] + for i in range(n_hist): + idx = i * interval - missing + if idx < 0: + cam_frames.append(torch.zeros_like(self._obs_buffers[key][0])) + else: + cam_frames.append(self._obs_buffers[key][idx]) + temporal_batch[key] = torch.stack(cam_frames, dim=1) # (B, T, C, H, W) + + return temporal_batch + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """Select a single action from the queue, regenerating if needed. + + Builds temporal observation history when configured, runs flow-matching + inference to fill the action queue, and pops the next action. + + Args: + batch: Environment observation dict with ``"state"``, image keys, + ``"prompt"``, ``"response"``, and optional ``"metadata"``. + noise: Optional pre-sampled noise for deterministic evaluation. + + Returns: + A single action tensor of shape ``(B, action_dim)``. + """ + self.eval() + + # Build temporal observation history if configured. + if self.config.n_obs_history is not None and self.config.n_obs_history > 1: + batch = self._build_history_batch(batch) + + if len(self._action_queue) == 0 or len(self._action_queue) <= self.config.max_delay: + action_prefix = None + delay = 0 + if len(self._action_queue) > 0: + prefix_actions = list(self._action_queue) + delay = min(len(prefix_actions), self.config.max_delay) + assert delay == self.config.max_delay, f"Delay must be equal to {self.config.max_delay}" + prefix_actions = prefix_actions[-delay:] + action_prefix = torch.stack(prefix_actions, dim=1) + delay = torch.tensor(delay, dtype=torch.long, device=batch["state"].device) + actions = self.sample_actions(batch, noise=noise, action_prefix=action_prefix, delay=delay) + actions = rearrange(actions, "b c d -> c b d") + self._action_queue.extend(actions[delay:]) + assert len(self._action_queue) == self.config.n_action_steps, ( + f"Action queue must have {self.config.n_action_steps} actions" + ) + + action = self._action_queue.popleft() + return action + + @torch.no_grad() + def sample_actions( + self, + batch: dict[str, Tensor], + action_prefix: Tensor | None = None, + delay: Tensor | None = None, + noise: Tensor | None = None, + ) -> Tensor: + """Sample a full action chunk via flow-matching inference. + + Normalizes inputs, prepares all modalities (video, language, response, + state, subgoal images, metadata), and delegates to the inner model's + ``sample_actions`` for iterative denoising. + + Args: + batch: Environment observation dict. + action_prefix: Optional previously-committed actions for real-time + inference with delay. + delay: Number of prefix action steps already committed. + noise: Optional pre-sampled noise for deterministic evaluation. + + Returns: + Action chunk tensor of shape ``(B, n_action_steps, action_dim)``. + """ + if not (torch.compiler.is_compiling() or torch.onnx.is_in_onnx_export()): + assert delay is None or 0 <= delay.item() <= self.config.max_delay, ( + f"Delay must be None or between 0 and {self.config.max_delay}" + ) + + batch = self.normalize_inputs(batch) + + videos, vid_masks = self.prepare_videos(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + response_tokens, response_masks = self.prepare_response(batch) + state = self.prepare_state(batch) + + subgoal_images, subgoal_img_masks = self.prepare_subgoal_images(batch) + + metadata_tokens, metadata_masks = self.prepare_metadata(batch) + + # Shape checks: videos must be 5D (B, T, C, H, W), state must be 3D (B, T, D). + for vid in videos: + assert vid.ndim == 5, f"Expected 5D video tensor (B, T, C, H, W), got {vid.shape}" + assert state.ndim == 3, f"Expected 3D state tensor (B, T, D), got {state.shape}" + + if self.config.n_obs_history is not None and self.config.n_obs_history > 1: + t_dim = state.shape[1] + if t_dim == 1: + logging.warning( + "Temporal dimension T=1: no historical frames included. " + "This should only happen at most %d time(s) at the start of an episode.", + self.config.history_interval or 1, + ) + + if delay is None: + delay = torch.tensor(0, dtype=torch.long, device=lang_tokens.device) + + if action_prefix is None: + bsize = lang_tokens.shape[0] + actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) + action_prefix = torch.zeros(actions_shape, dtype=lang_tokens.dtype, device=lang_tokens.device) + else: + normalized = self.normalize_targets({"actions": action_prefix})["actions"] + action_prefix = F.pad( + normalized, + (0, 0, 0, self.config.chunk_size - normalized.shape[1]), + ) + + actions = self.model.sample_actions( + videos, + vid_masks, + lang_tokens, + lang_masks, + state, + action_prefix, + delay, + noise=noise, + subgoal_images=subgoal_images, + subgoal_img_masks=subgoal_img_masks, + metadata_tokens=metadata_tokens, + metadata_masks=metadata_masks, + response_tokens=response_tokens, + response_masks=response_masks, + ) + + action_feature = self.config.action_feature + assert action_feature is not None, "action_feature must be set in output_features" + original_action_dim = action_feature.shape[0] + actions = actions[:, :, :original_action_dim] + + actions = self.unnormalize_outputs({"actions": actions})["actions"] + + return actions + + def forward( + self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None + ) -> dict[str, Tensor]: + """Training forward pass: normalize, prepare modalities, and compute losses. + + Returns a dict with ``"MSE"`` (flow-matching velocity loss) and + ``"CE"`` (discrete action cross-entropy loss). + + Args: + batch: Training batch dict with observations, actions, and prompts. + noise: Optional pre-sampled noise tensor. + time: Optional pre-sampled flow-matching timesteps. + + Returns: + Dict with ``"MSE"`` and ``"CE"`` scalar loss tensors. + """ + batch = self.normalize_inputs(batch) + batch["discrete_actions"] = self.normalize_discrete_actions(dict(batch))["actions"] + batch = self.normalize_targets(batch) + + obs_history_is_pad = batch.get("obs_history_is_pad") + if obs_history_is_pad is None: + logging.warning( + "obs_history_is_pad is missing from the training batch. " + "Padded observation-history timesteps will not be masked." + ) + videos, vid_masks = self.prepare_videos(batch, obs_history_is_pad=obs_history_is_pad) + lang_tokens, lang_masks = self.prepare_language(batch) + response_tokens, response_masks = self.prepare_response(batch) + state = self.prepare_state(batch) + subgoal_images, subgoal_masks = self.prepare_subgoal_images(batch) + metadata_tokens, metadata_masks = self.prepare_metadata(batch) + discrete_actions, discrete_action_masks = self.prepare_discrete_actions(batch) + actions = batch["actions"] + actions_is_pad = batch.get("action_is_pad") + + losses = self.model.forward( + videos, + vid_masks, + lang_tokens, + lang_masks, + state, + actions, + actions_is_pad, + noise, + time, + discrete_actions, + discrete_action_masks, + obs_history_is_pad=obs_history_is_pad, + subgoal_images=subgoal_images, + subgoal_img_masks=subgoal_masks, + metadata_tokens=metadata_tokens, + metadata_masks=metadata_masks, + response_tokens=response_tokens, + response_masks=response_masks, + ) + + mse_loss = losses["MSE"] + ce_loss = losses["CE"] + + return {"MSE": mse_loss, "CE": ce_loss} + + def prepare_state(self, batch: dict[str, Tensor]) -> Tensor: + """Prepares the temporal state tensor, padding or truncating to max_state_dim. + + Args: + batch: Batch of data containing "state" tensor of shape (B, T, D). + + Returns: + A tensor of shape (B, T, max_state_dim). + """ + state = batch["state"] # (B, T, D) or (B, D) during inference + if state.ndim == 2: + if self.config.n_obs_history is not None and self.config.n_obs_history > 1: + raise ValueError( + f"Expected 3D state tensor (B, T, D) when n_obs_history > 1, " + f"got shape {state.shape}. Ensure select_action() is being used." + ) + state = state.unsqueeze(1) # (B, D) -> (B, 1, D) + state_dim = state.shape[-1] + if state_dim > self.config.max_state_dim: + raise ValueError( + f"State dimension ({state_dim}) exceeds max_state_dim ({self.config.max_state_dim}). " + f"Increase max_state_dim in the config to accommodate the state vector." + ) + if state_dim < self.config.max_state_dim: + state = F.pad(state, (0, self.config.max_state_dim - state_dim)) + return state + + def prepare_discrete_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + """Tokenize continuous actions into discrete FAST tokens and pad to fixed length. + + Args: + batch: Batch dict containing ``"discrete_actions"`` (min-max normalized). + + Returns: + A tuple ``(token_ids, token_masks)`` with shapes + ``(B, discrete_action_max_length)``. + """ + device = batch["discrete_actions"].device + discrete_actions = batch["discrete_actions"].to(device="cpu", dtype=torch.float32) + tokens = self.discrete_action_processor.__call__(discrete_actions) + discrete_action_tokens, discrete_action_masks = pad_discrete_tokens( + tokens, self.config.discrete_action_max_length + ) + return torch.from_numpy(discrete_action_tokens).to(device=device, dtype=torch.long), torch.from_numpy( + discrete_action_masks + ).to(device=device, dtype=torch.bool) + + def prepare_videos( + self, batch: dict[str, Tensor], obs_history_is_pad: Tensor | None = None + ) -> tuple[list[Tensor], list[Tensor]]: + """Apply preprocessing to the video inputs. + + Each camera key now contains a video tensor of shape (B, T, C, H, W). + Frames are resized to 224x224 with padding. ImageNet normalization is + assumed to be already applied by the dataset loader. + + Args: + batch: Batch of data containing video tensors. + obs_history_is_pad: Optional bool tensor (B, T) indicating which + temporal frames are padded. Padded frames are zeroed out before + encoding so SpaceTimeSiglip does not process clamped/repeated content. + + Returns: + A tuple of (videos, vid_masks) lists. + """ + videos: list[Tensor] = [] + vid_masks: list[Tensor] = [] + + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. " + f"(batch: {batch.keys()}) (image_features:{self.config.image_features})" + ) + + last_vid: Tensor | None = None + last_mask: Tensor | None = None + + for key in present_img_keys: + vid = batch[key] # (B, T, C, H, W) or (B, C, H, W) during inference + if vid.ndim == 4: + if self.config.n_obs_history is not None and self.config.n_obs_history > 1: + raise ValueError( + f"Expected 5D video tensor (B, T, C, H, W) when n_obs_history > 1, " + f"got shape {vid.shape}. Ensure select_action() is being used." + ) + vid = vid.unsqueeze(1) # (B, C, H, W) -> (B, 1, C, H, W) + + if obs_history_is_pad is not None: + frame_mask = (~obs_history_is_pad)[:, :, None, None, None] # (B, T, 1, 1, 1) + vid = vid * frame_mask + + if self.config.resize_imgs_with_padding is not None: + b, t_frames = vid.shape[:2] + flat = rearrange(vid, "B T C H W -> (B T) C H W") + flat = resize_with_pad(flat, *self.config.resize_imgs_with_padding, pad_value=0) + vid = rearrange(flat, "(B T) C H W -> B T C H W", B=b, T=t_frames) + + bsize = vid.shape[0] + device = vid.device + mask = torch.ones(bsize, dtype=torch.bool, device=device) + videos.append(vid) + vid_masks.append(mask) + last_vid = vid + last_mask = mask + + n_empty = min(len(missing_img_keys), self.config.empty_cameras) + if n_empty > 0: + assert last_vid is not None and last_mask is not None + for _ in range(n_empty): + videos.append(torch.zeros_like(last_vid)) + vid_masks.append(torch.zeros_like(last_mask)) + + return videos, vid_masks + + def prepare_language(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + """Tokenize the language prompt into PaliGemma token IDs. + + Wraps each prompt string as ``"Task: {task}"`` and pads/truncates + to ``prompt_max_length``. + + Args: + batch: Batch dict containing ``"prompt"`` (list of strings). + + Returns: + A tuple ``(lang_tokens, lang_masks)`` with shapes + ``(B, prompt_max_length)``. + """ + device = batch["state"].device + tasks = batch["prompt"] + + prompt = [f"Task: {task}, " for task in tasks] + + tokenized_prompt = self.language_tokenizer.__call__( + prompt, + padding="max_length", + padding_side="right", + max_length=self.config.prompt_max_length, + return_tensors="pt", + truncation=True, + ) + lang_tokens = tokenized_prompt["input_ids"].to(device=device) + lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) + + return lang_tokens, lang_masks + + def prepare_response(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + """Tokenize the high-level planner subtask response into PaliGemma token IDs. + + Wraps each response string as ``"Subtask: {response}, "`` and + pads/truncates to ``response_max_length``. Uses ``add_special_tokens=False`` so + no BOS (or other special tokens) are inserted; the prefix already encodes a + ``"Subtask: "`` span in :meth:`PI07LowLevelPlannerFlowMatching.embed_prefix`. + + Args: + batch: Batch dict containing ``"response"`` (list of strings). + + Returns: + A tuple ``(response_tokens, response_masks)`` with shapes + ``(B, response_max_length)``.""" + device = batch["state"].device + responses = batch["response"] if "response" in batch else [""] * batch["state"].shape[0] + response_prompt = [f"Subtask: {response}, " if response != "" else "" for response in responses] + tokenized_response = self.language_tokenizer.__call__( + response_prompt, + padding="max_length", + padding_side="right", + max_length=self.config.response_max_length, + return_tensors="pt", + truncation=True, + add_special_tokens=False, + ) + response_tokens = tokenized_response["input_ids"].to(device=device) + response_masks = tokenized_response["attention_mask"].to(device=device, dtype=torch.bool) + return response_tokens, response_masks + + def prepare_subgoal_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: + """Preprocess subgoal images for SigLIP embedding. + + Derives subgoal keys from ``config.image_features``: for each + ``camera{k}`` the corresponding batch key is ``subgoal{k}`` (the + naming convention used by ``LeRobotDataset._emit_optional_keys``). + If no ``subgoal{k}`` keys are present, a warning is logged and + zero-filled images with masks all ``False`` are returned (no subgoal + signal), matching a fully padded subgoal batch. + + Resizes each subgoal image to 224×224 with aspect-ratio padding and + converts the pixel range from ``[0, 1]`` to ``[-1, 1]`` as expected + by SigLIP. Missing cameras are filled with ``-1`` padding tensors + up to ``empty_cameras``. + + When ``batch["subgoal_is_pad"]`` is ``True`` for a sample, all + subgoal slots for that sample are zeroed out and their masks set to + ``False`` so that downstream attention ignores them. + + Args: + batch: Batch dict containing subgoal image tensors keyed as + ``subgoal{k}`` for each ``camera{k}`` in + ``config.image_features``. If all are absent, see warning + + fallback above. + + Returns: + A tuple ``(subgoal_images, subgoal_img_masks)`` of lists. + """ + subgoal_images = [] + subgoal_img_masks = [] + + # Derive subgoal keys from image_features: camera{k} -> subgoal{k} + subgoal_keys = [key.replace("camera", "subgoal") for key in self.config.image_features] + present_subgoal_img_keys = [key for key in subgoal_keys if key in batch] + missing_subgoal_img_keys = [key for key in subgoal_keys if key not in batch] + + if len(present_subgoal_img_keys) == 0: + logging.getLogger(__name__).warning( + "All subgoal image features are missing from the batch; using zero tensors with " + "cleared masks (no subgoal conditioning). " + f"(batch keys: {list(batch.keys())}) (expected: {subgoal_keys})" + ) + + # Per-sample flag: True means the subgoal was dropped or absent. + subgoal_is_pad = batch.get( + "subgoal_is_pad", torch.zeros(batch["state"].shape[0], dtype=torch.bool) + ) # (B,) bool or None + + for key in present_subgoal_img_keys: + subgoal_img = batch[key] + + if self.config.resize_imgs_with_padding is not None: + subgoal_img = resize_with_pad(subgoal_img, *self.config.resize_imgs_with_padding, pad_value=0) + + # Normalize from range [0,1] to [-1,1] as expected by siglip + subgoal_img = subgoal_img * 2.0 - 1.0 + + bsize = subgoal_img.shape[0] + device = subgoal_img.device + mask = torch.ones(bsize, dtype=torch.bool, device=device) + + if subgoal_is_pad is not None: + is_pad = subgoal_is_pad.to(device=device, dtype=torch.bool) + mask = mask & ~is_pad + subgoal_img = subgoal_img * (~is_pad)[:, None, None, None] + + subgoal_images.append(subgoal_img) + subgoal_img_masks.append(mask) + + # Create image features not present in the batch + # as fully 0 padded images. + for num_empty_cameras in range(len(missing_subgoal_img_keys)): + if num_empty_cameras >= self.config.empty_cameras: + break + subgoal_img = torch.ones_like(subgoal_img) * -1 + mask = torch.zeros_like(mask) + subgoal_images.append(subgoal_img) + subgoal_img_masks.append(mask) + + return subgoal_images, subgoal_img_masks + + def prepare_metadata(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: + """Tokenize episode metadata into PaliGemma token IDs. + + Wraps each metadata string as ``"Metadata: {meta}"`` and + pads/truncates to ``metadata_max_length``. + + Args: + batch: Batch dict containing ``"speed"``, ``"quality"``, + ``"mistake"`` and their corresponding ``_is_pad`` flags. + + Returns: + A tuple ``(metadata_tokens, metadata_masks)`` with shapes + ``(B, metadata_max_length)``. + """ + + metadata = [] + # safety conditioning if metadata are not passed by sample actions + for speed, quality, mistake, speed_is_pad, quality_is_pad, mistake_is_pad in zip( + batch.get("speed", torch.zeros(batch["state"].shape[0], dtype=torch.float32)), + batch.get("quality", torch.zeros(batch["state"].shape[0], dtype=torch.float32)), + batch.get("mistake", torch.zeros(batch["state"].shape[0], dtype=torch.float32)), + batch.get("speed_is_pad", torch.zeros(batch["state"].shape[0], dtype=torch.bool)), + batch.get("quality_is_pad", torch.zeros(batch["state"].shape[0], dtype=torch.bool)), + batch.get("mistake_is_pad", torch.zeros(batch["state"].shape[0], dtype=torch.bool)), + strict=True, + ): + segments = [] + if not speed_is_pad: + segments.append(f"Speed: {str(speed.item())}, ") + + if not quality_is_pad: + segments.append(f"Quality: {str(quality.item())}, ") + + if not mistake_is_pad: + segments.append(f"Mistake: {str(mistake.item())}, ") + + metadata.append(f"Metadata: {' '.join(segments)}" if segments else "") + + device = batch["state"].device + tokenized_metadata = self.language_tokenizer.__call__( + metadata, + padding="max_length", + padding_side="right", + max_length=self.config.metadata_max_length, + return_tensors="pt", + truncation=True, + add_special_tokens=False, + ) + metadata_tokens = tokenized_metadata["input_ids"].to(device=device) + metadata_masks = tokenized_metadata["attention_mask"].to(device=device, dtype=torch.bool) + + return metadata_tokens, metadata_masks + + +# Flow-matching model +class PI07LowLevelPlannerFlowMatching(nn.Module): + """π07 Low-Level Planner: flow-matching action generation with Knowledge Insulation. + + Architecture overview:: + + ┌───────────────────────────────────────────────────┐ + │ actions │ + │ ▲ │ + │ ┌┴─────┐ │ + │ kv cache │Gemma │ (detached) │ + │ ┌──────────►│Expert│ │ + │ │ │ │ │ + │ ┌┴─────────┐ │x 10 │ │ + │ │ │ └▲─────┘ │ + │ │PaliGemma │ │ │ + │ │ (VLM) │ noise │ + │ └▲──▲──▲──▲──▲──▲──▲──▲ │ + │ │ │ │ │ │ └── ``Action:`` + discrete (training) │ + │ │ │ │ │ └───── ``";\n "`` + metadata │ + │ │ │ │ └──────── subgoal images, ``Subgoal:`` │ + │ │ │ └────────── response, commas, state, ``State:`` │ + │ │ └──────────── language │ + │ └────────────── video (SpaceTimeSiglip) │ + └───────────────────────────────────────────────────┘ + + The VLM processes the same prefix layout as :meth:`embed_prefix` (videos, language, + ``State:``, state, commas, response, ``Subgoal:``, subgoal images, ``";\n "``, + optional ``Action:``/discrete, metadata). The action expert receives + the prefix KV-cache (detached for Knowledge Insulation) together with + noisy continuous actions and flow-matching timestep embeddings to + predict the velocity field. + """ + + def __init__(self, config: PI07LowLevelPlannerConfig, discrete_action_vocab_size: int | None = None): + super().__init__() + self.config = config + + self.config.vlm_config.discrete_action_vocab_size = discrete_action_vocab_size + self.gemma3_with_expert = Gemma3WithExpertModel(self.config.vlm_config) + + expert_hidden = self.gemma3_with_expert.config.gemma_expert_config.hidden_size + if config.proj_width != expert_hidden: + raise ValueError( + f"proj_width ({config.proj_width}) must equal the action expert's " + f"hidden_size ({expert_hidden}) so suffix embeddings are compatible " + f"with the expert transformer layers." + ) + + vlm_hidden_size = self.gemma3_with_expert.config.gemma3_config.text_config.hidden_size + + self.video_encoder = SpaceTimeSiglipVideoEncoder( + vision_tower=self.gemma3_with_expert._vision_tower(), + multi_modal_projector=self.gemma3_with_expert._multi_modal_projector(), + num_frames=config.n_obs_steps, + spacetime_layer_stride=config.spacetime_layer_stride, + gradient_checkpointing=config.gradient_checkpointing, + ) + + # Per-timestep state projection: each of the T state vectors becomes one token + self.state_proj = nn.Linear(self.config.max_state_dim, vlm_hidden_size) + + self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width) + self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim) + + self.time_mlp_in = nn.Linear(self.config.proj_width, self.config.proj_width) + self.time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width) + + self.language_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt") + + def sample_noise(self, shape: tuple[int, ...], device: torch.device | str) -> Tensor: + return torch.normal(mean=0.0, std=1.0, size=shape, dtype=torch.float32, device=device) + + def sample_time(self, bsize: int, device: torch.device | str) -> Tensor: + beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) + time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) + time = time_beta * 0.999 + 0.001 + return time + + def embed_video(self, video: Tensor) -> Tensor: + """Encode a video through SpaceTimeSiglip + Perceiver reducer + projection. + + Args: + video: (B, T, C, H, W) + + Returns: + (B, num_video_tokens, vlm_hidden_size) + """ + return self.video_encoder(video) + + def embed_prefix( + self, + videos: list[Tensor], + vid_masks: list[Tensor], + lang_tokens: Tensor, + lang_masks: Tensor, + state: Tensor, + response_tokens: Tensor, + response_masks: Tensor, + metadata_tokens: Tensor, + metadata_masks: Tensor, + discrete_actions: Tensor | None = None, + discrete_action_masks: Tensor | None = None, + obs_history_is_pad: Tensor | None = None, + subgoal_images: list[Tensor] | None = None, + subgoal_img_masks: list[Tensor] | None = None, + ) -> tuple[Tensor, Tensor, Tensor]: + """Embed all prefix modalities and build the 1-D attention pattern. + + Concatenation order: + + ``[videos | language | State: | state(T) | ", " | response | + Subgoal: | subgoal_images… | ", " | metadata | ";\\n " | + ("Action:" + discrete_actions only when training)]`` + + Attention pattern (via ``att_masks`` cumsums): + - Video + language: bidirectional (``0``). + - ``State:``, projected state timestep tokens, comma after state: bidirectional (``0``). + - Response spans: prefix-LM style block opening (``[1, 0, …]`` inside the segment). + - ``Subgoal:``: new bidirectional block (``[1, 0, …]``). + - Subgoal image patches per camera: bidirectional blocks (``[1, 0, …]``). + - Commas/metadata / ``";\\n "``: mostly continued prefix blocks (see code). + - Discrete actions (training): causal ``1`` per timestep after ``Action:``. + + Args: + videos: List of video tensors, each ``(B, T, C, H, W)``. + vid_masks: List of boolean masks, each ``(B,)``. + lang_tokens: Language token IDs ``(B, prompt_max_length)``. + lang_masks: Boolean mask for language tokens. + state: Temporal state ``(B, T, max_state_dim)``. + response_tokens: Subtask response token IDs + ``(B, response_max_length)``. + response_masks: Boolean mask for response tokens. + metadata_tokens: Metadata token IDs ``(B, metadata_max_length)``. + metadata_masks: Boolean mask for metadata tokens. + discrete_actions: Optional FAST token IDs + ``(B, discrete_action_max_length)``. Provided during training. + discrete_action_masks: Boolean mask for discrete actions. + obs_history_is_pad: Optional ``(B, T)`` bool tensor; ``True`` for + padded timesteps. Used to mask state tokens during training. + subgoal_images: List of subgoal image tensors ``(B, C, H, W)``. + subgoal_img_masks: List of boolean masks ``(B,)``. + + Returns: + A tuple ``(embs, pad_masks, att_masks)`` where: + - embs: ``(B, total_seq_len, D)`` + - pad_masks: ``(B, total_seq_len)`` + - att_masks: ``(B, total_seq_len)`` for + :func:`make_att_2d_masks`. + """ + embs = [] + pad_masks = [] + att_masks = [] + bsize = lang_tokens.shape[0] + + for vid, vid_mask in zip(videos, vid_masks, strict=True): + vid_emb = self.embed_video(vid) # (B, num_video_tokens, vlm_hidden) + vid_emb = vid_emb.to(dtype=_preferred_dtype()) + + num_vid_embs = vid_emb.shape[1] + vid_mask_expanded = vid_mask[:, None].expand(bsize, num_vid_embs) + + embs.append(vid_emb) + pad_masks.append(vid_mask_expanded) + + att_masks += [0] * num_vid_embs + + lang_emb = self.gemma3_with_expert.embed_language_tokens(lang_tokens) + lang_emb_dim = lang_emb.shape[-1] + lang_emb = lang_emb * math.sqrt(lang_emb_dim) + + embs.append(lang_emb) + pad_masks.append(lang_masks) + + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + state_start_indicator_ids = self.language_tokenizer.encode("State: ", add_special_tokens=False) + state_start_tokens = torch.tensor( + [state_start_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + state_start_emb = self.gemma3_with_expert.embed_language_tokens(state_start_tokens) + state_start_dim = state_start_emb.shape[-1] + state_start_emb = state_start_emb * math.sqrt(state_start_dim) + + num_state_start_embs = state_start_emb.shape[1] + state_start_mask = torch.ones( + bsize, num_state_start_embs, dtype=torch.bool, device=lang_tokens.device + ) + + embs.append(state_start_emb) + pad_masks.append(state_start_mask) + att_masks += [0] * num_state_start_embs + + # Project each timestep's state into a separate VLM token + # state: (B, T, max_state_dim) -> state_emb: (B, T, vlm_hidden_size) + state_emb = self.state_proj(state.to(dtype=_preferred_dtype())) + num_state_tokens = state_emb.shape[1] # T + if obs_history_is_pad is not None: + state_mask = ~obs_history_is_pad # True = real, False = padded + else: + state_mask = torch.ones(bsize, num_state_tokens, dtype=torch.bool, device=state.device) + + embs.append(state_emb) + pad_masks.append(state_mask) + att_masks += [0] * num_state_tokens # full attention with video and language + + state_end_indicator_ids = self.language_tokenizer.encode(", ", add_special_tokens=False) + state_end_tokens = torch.tensor( + [state_end_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + state_end_emb = self.gemma3_with_expert.embed_language_tokens(state_end_tokens) + state_end_dim = state_end_emb.shape[-1] + state_end_emb = state_end_emb * math.sqrt(state_end_dim) + + num_state_end_embs = state_end_emb.shape[1] + state_end_mask = torch.ones(bsize, num_state_end_embs, dtype=torch.bool, device=lang_tokens.device) + + embs.append(state_end_emb) + pad_masks.append(state_end_mask) + att_masks += [0] * num_state_end_embs + + response_emb = self.gemma3_with_expert.embed_language_tokens(response_tokens) + response_emb_dim = response_emb.shape[-1] + response_emb = response_emb * math.sqrt(response_emb_dim) + embs.append(response_emb) + pad_masks.append(response_masks) + num_response_embs = response_emb.shape[1] + att_masks += [1] + [0] * (num_response_embs - 1) + + subgoal_img_start_indicator_ids = self.language_tokenizer.encode( + "Subgoal: ", add_special_tokens=False + ) + subgoal_img_start_tokens = torch.tensor( + [subgoal_img_start_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + subgoal_img_start_emb = self.gemma3_with_expert.embed_language_tokens(subgoal_img_start_tokens) + subgoal_img_start_dim = subgoal_img_start_emb.shape[-1] + subgoal_img_start_emb = subgoal_img_start_emb * math.sqrt(subgoal_img_start_dim) + + num_subgoal_img_start_embs = subgoal_img_start_emb.shape[1] + subgoal_img_start_mask = torch.ones( + bsize, num_subgoal_img_start_embs, dtype=torch.bool, device=lang_tokens.device + ) + + embs.append(subgoal_img_start_emb) + pad_masks.append(subgoal_img_start_mask) + att_masks += [1] + [0] * (num_subgoal_img_start_embs - 1) + + for ( + subgoal_img, + subgoal_img_mask, + ) in zip(subgoal_images, subgoal_img_masks, strict=True): + subgoal_img_emb = self.gemma3_with_expert.embed_image(subgoal_img) + subgoal_img_emb = subgoal_img_emb.to(dtype=_preferred_dtype()) + + # Gemma 3's projector does not apply the `/ sqrt(text_hidden_size)` + # scaling that stock PaliGemma does, so no un-normalization is + # required here (matches `embed_image` in `gemma3_with_expert.py`). + + bsize, num_subgoal_img_embs = subgoal_img_emb.shape[:2] + subgoal_img_mask = subgoal_img_mask[:, None].expand(bsize, num_subgoal_img_embs) + + embs.append(subgoal_img_emb) + pad_masks.append(subgoal_img_mask) + + # Create attention masks so that image tokens attend to each other + att_masks += [1] + [0] * (num_subgoal_img_embs - 1) + + subgoal_img_end_indicator_ids = self.language_tokenizer.encode(", ", add_special_tokens=False) + subgoal_img_end_tokens = torch.tensor( + [subgoal_img_end_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + subgoal_img_end_emb = self.gemma3_with_expert.embed_language_tokens(subgoal_img_end_tokens) + subgoal_img_end_dim = subgoal_img_end_emb.shape[-1] + subgoal_img_end_emb = subgoal_img_end_emb * math.sqrt(subgoal_img_end_dim) + + num_subgoal_img_end_embs = subgoal_img_end_emb.shape[1] + subgoal_img_end_mask = torch.ones( + bsize, num_subgoal_img_end_embs, dtype=torch.bool, device=lang_tokens.device + ) + + embs.append(subgoal_img_end_emb) + pad_masks.append(subgoal_img_end_mask) + att_masks += [0] * num_subgoal_img_end_embs + + metadata_emb = self.gemma3_with_expert.embed_language_tokens(metadata_tokens) + metadata_emb_dim = metadata_emb.shape[-1] + metadata_emb = metadata_emb * math.sqrt(metadata_emb_dim) + embs.append(metadata_emb) + pad_masks.append(metadata_masks) + att_masks += [1] + [0] * (metadata_emb.shape[1] - 1) + + prefix_end_indicator_ids = self.language_tokenizer.encode(";\n ", add_special_tokens=False) + prefix_end_tokens = torch.tensor( + [prefix_end_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) + prefix_end_dim = prefix_end_emb.shape[-1] + prefix_end_emb = prefix_end_emb * math.sqrt(prefix_end_dim) + + num_prefix_end_embs = prefix_end_emb.shape[1] + prefix_end_mask = torch.ones(bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device) + + embs.append(prefix_end_emb) + pad_masks.append(prefix_end_mask) + att_masks += [0] * num_prefix_end_embs + + if discrete_actions is not None: + discrete_action_start_indicator_ids = self.language_tokenizer.encode( + "Action: ", add_special_tokens=False + ) + discrete_action_start_tokens = torch.tensor( + [discrete_action_start_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + discrete_action_start_emb = self.gemma3_with_expert.embed_language_tokens( + discrete_action_start_tokens + ) + discrete_action_start_dim = discrete_action_start_emb.shape[-1] + discrete_action_start_emb = discrete_action_start_emb * math.sqrt(discrete_action_start_dim) + + num_discrete_action_start_embs = discrete_action_start_emb.shape[1] + discrete_action_start_mask = torch.ones( + bsize, num_discrete_action_start_embs, dtype=torch.bool, device=lang_tokens.device + ) + + embs.append(discrete_action_start_emb) + pad_masks.append(discrete_action_start_mask) + att_masks += [1] + [0] * (num_discrete_action_start_embs - 1) + + discrete_action_emb = self.gemma3_with_expert.embed_discrete_actions(discrete_actions) + embs.append(discrete_action_emb.to(dtype=_preferred_dtype())) + pad_masks.append(discrete_action_masks) + att_masks += [1] * discrete_action_emb.shape[1] + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def embed_suffix(self, noisy_actions: Tensor, timestep: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Embed noisy actions and flow-matching timestep for the action expert. + + Projects actions through ``action_in_proj`` and computes an + adaRMS-style conditioning vector from sinusoidal timestep embeddings + via a two-layer MLP. The suffix forms a single bidirectional block + (``att_masks = [1, 0, …, 0]``). + + Args: + noisy_actions: ``(B, chunk_size, max_action_dim)`` noisy action + tensor at the current flow-matching timestep. + timestep: ``(B, chunk_size)`` per-step timestep values. + + Returns: + A tuple ``(embs, pad_masks, att_masks, adarms_cond)`` where + ``adarms_cond`` is the conditioning vector for adaptive RMSNorm + in the Gemma expert layers. + """ + embs = [] + pad_masks = [] + att_masks = [] + + bsize = noisy_actions.shape[0] + dtype = _preferred_dtype() + device = noisy_actions.device + + time_emb = create_sinusoidal_pos_embedding( + timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device + ) + + noisy_actions = noisy_actions.to(dtype=dtype) + action_emb = self.action_in_proj(noisy_actions) + + def time_mlp_func(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) + x = self.time_mlp_out(x) + return F.silu(x) + + time_emb = time_emb.to(dtype=dtype) + adarms_cond = time_mlp_func(time_emb) + + embs.append(action_emb) + + bsize, action_dim = action_emb.shape[:2] + action_mask = torch.ones(bsize, action_dim, dtype=torch.bool, device=device) + pad_masks.append(action_mask) + + att_masks += [1] + ([0] * (self.config.chunk_size - 1)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks, adarms_cond + + def forward( + self, + videos: list[Tensor], + vid_masks: list[Tensor], + lang_tokens: Tensor, + lang_masks: Tensor, + state: Tensor, + actions: Tensor, + actions_is_pad: Tensor | None = None, + noise: Tensor | None = None, + time: Tensor | None = None, + discrete_actions: Tensor | None = None, + discrete_action_masks: Tensor | None = None, + obs_history_is_pad: Tensor | None = None, + subgoal_images: list[Tensor] | None = None, + subgoal_img_masks: list[Tensor] | None = None, + metadata_tokens: Tensor | None = None, + metadata_masks: Tensor | None = None, + response_tokens: Tensor | None = None, + response_masks: Tensor | None = None, + ) -> dict[str, Tensor]: + """Training forward pass: embed all modalities and compute losses. + + Runs the VLM on the prefix (video, language, response, state, subgoal + images, metadata, discrete actions), then the action expert on the + noisy action suffix. Returns both the flow-matching MSE loss and the + discrete-action cross-entropy loss. + + The VLM's KV-cache is detached before being passed to the action + expert (Knowledge Insulation). + + Args: + videos: List of video tensors ``(B, T, C, H, W)``. + vid_masks: List of boolean masks ``(B,)``. + lang_tokens: Language token IDs ``(B, prompt_max_length)``. + lang_masks: Boolean mask for language tokens. + state: Temporal state ``(B, T, max_state_dim)``. + actions: Ground-truth continuous actions ``(B, chunk_size, max_action_dim)``. + actions_is_pad: Optional ``(B, chunk_size)`` bool mask for padded actions. + noise: Optional pre-sampled noise. + time: Optional pre-sampled flow-matching timesteps. + discrete_actions: Optional FAST token IDs. + discrete_action_masks: Optional mask for discrete actions. + obs_history_is_pad: Optional ``(B, T)`` mask for padded frames. + subgoal_images: Optional list of subgoal image tensors. + subgoal_img_masks: Optional list of masks. + metadata_tokens: Optional metadata token IDs. + metadata_masks: Optional mask for metadata tokens. + response_tokens: Optional subtask response token IDs. + response_masks: Optional mask for response tokens. + + Returns: + Dict with ``"MSE"`` (flow-matching loss) and ``"CE"`` + (discrete action loss) scalar tensors. + """ + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + videos, + vid_masks, + lang_tokens, + lang_masks, + state, + response_tokens, + response_masks, + metadata_tokens, + metadata_masks, + discrete_actions=discrete_actions, + discrete_action_masks=discrete_action_masks, + obs_history_is_pad=obs_history_is_pad, + subgoal_images=subgoal_images, + subgoal_img_masks=subgoal_img_masks, + ) + + vlm_2d_attention_mask = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + vlm_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + num_cross_att_tokens = prefix_embs.shape[1] - self.config.discrete_action_max_length + + (prefix_out, _), past_key_values = self.gemma3_with_expert.forward( + attention_mask=vlm_2d_attention_mask, + position_ids=vlm_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + n_cross_att_tokens=num_cross_att_tokens, + use_cache=False, + fill_kv_cache=True, + ) + + batch_size = actions.shape[0] + if noise is None: + noise = self.sample_noise(actions.shape, actions.device) + + if time is None: + time = self.sample_time(batch_size, actions.device) + + delay = torch.randint(0, self.config.max_delay + 1, (batch_size,)) + prefix_mask = rearrange(torch.arange(self.config.chunk_size), "c -> 1 c") < rearrange( + delay, "b -> b 1" + ) + prefix_mask = prefix_mask.to(device=actions.device) + time = torch.where(prefix_mask, 0, rearrange(time, "b -> b 1")) + + time_expanded = rearrange(time, "b c -> b c 1") + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) + + action_expert_2d_attention_mask = make_att_2d_masks( + suffix_pad_masks, + suffix_att_masks, + n_cross_att_tokens=num_cross_att_tokens, + cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens], + ) + prefix_offsets = torch.sum(prefix_pad_masks[:, : -self.config.discrete_action_max_length], dim=-1)[ + :, None + ] + action_expert_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + assert past_key_values is not None + kv_cache: dict = past_key_values + for layer_idx in kv_cache: + kv_cache[layer_idx]["key_states"] = kv_cache[layer_idx]["key_states"].detach() + kv_cache[layer_idx]["value_states"] = kv_cache[layer_idx]["value_states"].detach() + + (_, suffix_out), _ = self.gemma3_with_expert.forward( + attention_mask=action_expert_2d_attention_mask, + position_ids=action_expert_position_ids, + past_key_values=kv_cache, + inputs_embeds=[None, suffix_embs], + use_cache=True, + fill_kv_cache=False, + adarms_cond=[None, adarms_cond], + ) + + assert suffix_out is not None + suffix_out = suffix_out[:, -self.config.n_action_steps :] + v_t = self.action_out_proj(suffix_out) + v_t = v_t.to(dtype=torch.float32) + + mse_loss = F.mse_loss(u_t, v_t, reduction="none") + + postfix_mask = rearrange(torch.logical_not(prefix_mask), "b c -> b c 1") + + if actions_is_pad is not None: + in_episode_bound = ~actions_is_pad + in_episode_bound = rearrange(in_episode_bound, "b c -> b c 1") + postfix_mask = torch.logical_and(postfix_mask, in_episode_bound) + + mse_loss = mse_loss * postfix_mask + + mse_loss = mse_loss[:, :, : self.config.max_action_dim] + + postfix_mask_expanded = repeat(postfix_mask, "b c 1 -> b c d", d=mse_loss.shape[-1]) + mse_loss = mse_loss.sum() / (postfix_mask_expanded.sum() + 1e-8) + + assert discrete_actions is not None + assert discrete_action_masks is not None + assert prefix_out is not None + batch_size, seq_len = discrete_actions.shape + discrete_token_start = -self.config.discrete_action_max_length + discrete_action_slice_object = slice(discrete_token_start - 1, -1) + discrete_action_out = prefix_out[:, discrete_action_slice_object] + logits = self.gemma3_with_expert.da_head(discrete_action_out) + + logits = logits.to(dtype=torch.float32) + logits = rearrange(logits, "b s d -> (b s) d") + labels = rearrange(discrete_actions, "b s -> (b s)") + discrete_action_ce_loss = F.cross_entropy(logits, labels, reduction="none") + + discrete_action_ce_loss = rearrange(discrete_action_ce_loss, "(b s) -> b s", b=batch_size, s=seq_len) + + discrete_action_is_pad = ~discrete_action_masks + discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad + + discrete_action_ce_loss = discrete_action_ce_loss.mean() + + return {"MSE": mse_loss, "CE": discrete_action_ce_loss} + + def sample_actions( + self, + videos: list[Tensor], + vid_masks: list[Tensor], + lang_tokens: Tensor, + lang_masks: Tensor, + state: Tensor, + action_prefix: Tensor, + delay: Tensor, + noise: Tensor | None = None, + subgoal_images: list[Tensor] | None = None, + subgoal_img_masks: list[Tensor] | None = None, + metadata_tokens: Tensor | None = None, + metadata_masks: Tensor | None = None, + response_tokens: Tensor | None = None, + response_masks: Tensor | None = None, + ) -> Tensor: + """Inference: iteratively denoise to produce a continuous action chunk. + + Embeds the prefix (without discrete actions), caches the VLM KV + states, then runs ``num_steps`` denoising iterations through the + action expert. Prefix action steps (up to ``delay``) are held fixed + from ``action_prefix``. + + Args: + videos: List of video tensors ``(B, T, C, H, W)``. + vid_masks: List of boolean masks ``(B,)``. + lang_tokens: Language token IDs ``(B, prompt_max_length)``. + lang_masks: Boolean mask for language tokens. + state: Temporal state ``(B, T, max_state_dim)``. + action_prefix: Previously committed actions + ``(B, chunk_size, max_action_dim)`` (zero-padded beyond delay). + delay: Scalar tensor indicating how many prefix steps are fixed. + noise: Optional pre-sampled noise. + subgoal_images: Optional list of subgoal image tensors. + subgoal_img_masks: Optional list of masks. + metadata_tokens: Optional metadata token IDs. + metadata_masks: Optional mask for metadata tokens. + response_tokens: Optional subtask response token IDs. + response_masks: Optional mask for response tokens. + + Returns: + Denoised action chunk ``(B, chunk_size, max_action_dim)``. + """ + bsize = lang_tokens.shape[0] + device = lang_tokens.device + + if noise is None: + actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) + noise = self.sample_noise(actions_shape, device) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + videos, + vid_masks, + lang_tokens, + lang_masks, + state, + response_tokens, + response_masks, + metadata_tokens, + metadata_masks, + subgoal_images=subgoal_images, + subgoal_img_masks=subgoal_img_masks, + ) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + num_cross_att_tokens = prefix_embs.shape[1] + + (prefix_out, _), past_kv = self.gemma3_with_expert.forward( + attention_mask=prefix_att_2d_masks, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + n_cross_att_tokens=num_cross_att_tokens, + use_cache=False, + fill_kv_cache=True, + ) + past_key_values: list[dict[str, Tensor]] = past_kv + + dt = -1.0 / self.config.num_steps + dt = torch.tensor(dt, dtype=torch.float32, device=device) + + x_t = noise + time = torch.tensor(1.0, dtype=torch.float32, device=device) + prefix_mask = rearrange(torch.arange(self.config.chunk_size, device=device), "c -> 1 c") < delay + while time >= -dt / 2: + x_t = torch.where(rearrange(prefix_mask, "b c -> b c 1"), action_prefix, x_t) + masked_time = torch.where(prefix_mask, 0, time) + v_t = self.denoise_step( + prefix_pad_masks, + past_key_values, + x_t, + masked_time, + ) + + x_t += dt * v_t + time += dt + + x_t = torch.where(rearrange(prefix_mask, "b c -> b c 1"), action_prefix, x_t) + return x_t + + def denoise_step( + self, + prefix_pad_masks: Tensor, + past_key_values: list[dict[str, Tensor]], + x_t: Tensor, + time: Tensor, + ) -> Tensor: + """Run one action-expert forward pass to predict the velocity field. + + Embeds the suffix (noisy actions + timestep), constructs the + cross-attention mask to the cached prefix, and runs the Gemma + expert to produce the predicted velocity ``v_t``. + + Args: + prefix_pad_masks: ``(B, prefix_len)`` padding mask from the + prefix pass (used for cross-attention masking). + past_key_values: Cached KV states from the VLM prefix pass. + x_t: Current noisy actions ``(B, chunk_size, max_action_dim)``. + time: Per-step timestep ``(B, chunk_size)``. + + Returns: + Predicted velocity ``(B, n_action_steps, max_action_dim)``. + """ + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) + + num_cross_att_tokens = prefix_pad_masks.shape[1] + action_expert_2d_attention_mask = make_att_2d_masks( + suffix_pad_masks, + suffix_att_masks, + n_cross_att_tokens=num_cross_att_tokens, + cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens], + ) + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + action_expert_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + outputs_embeds, _ = self.gemma3_with_expert.forward( + attention_mask=action_expert_2d_attention_mask, + position_ids=action_expert_position_ids, + past_key_values=past_key_values, + inputs_embeds=[None, suffix_embs], + use_cache=True, + fill_kv_cache=False, + adarms_cond=[None, adarms_cond], + ) + suffix_out = outputs_embeds[1] + assert suffix_out is not None + suffix_out = suffix_out[:, -self.config.n_action_steps :] + v_t = self.action_out_proj(suffix_out) + v_t = v_t.to(dtype=torch.float32) + return v_t diff --git a/src/opentau/policies/pi07/low_level_planner/video_encoder.py b/src/opentau/policies/pi07/low_level_planner/video_encoder.py new file mode 100644 index 00000000..16570a24 --- /dev/null +++ b/src/opentau/policies/pi07/low_level_planner/video_encoder.py @@ -0,0 +1,460 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SigLIP video encoder with space-time separable attention (MEM paper). + +Implements the low-level memory video encoder from Torne, Pertsch, Walke et al. +"MEM: Multi-Scale Embodied Memory for Vision Language Action Models" +(Section III-C + Appendix C): a standard SigLIP ViT extended with +space-time separable attention at every ``spacetime_layer_stride``-th layer +and a fixed sinusoidal temporal position encoding whose current-frame row is +zero. Past-timestep tokens are dropped after the encoder so the output shape +matches a single-image VLA. + +Key properties: + - Introduces no new learnable parameters on top of the pretrained SigLIP + weights (temporal attention re-uses each layer's own Q/K/V/O projections). + Any pi05/pi05_continuous_state checkpoint can be loaded directly — the + space-time layers just wrap the existing SiglipEncoderLayer weights. + - Single-frame invariance: with ``T=1`` the output is byte-identical to + ``PaliGemmaModel.get_image_features`` (see the single-frame invariance + tests). + - Convention: the current frame lives at the **last** time index + (``t = T-1``). This matches + ``src/opentau/datasets/factory.py:136`` (delta_timestamps) and + ``PI05MemPolicy._build_history_batch``. ``obs_history_is_pad[:, -1]`` is + always ``False`` by construction. + +The encoder does NOT own its own copy of the SigLIP weights. The caller +(``PI07LowLevelPlannerFlowMatching``) constructs a ``Gemma3WithExpertModel`` +— which already owns ``vision_tower`` and ``multi_modal_projector`` (under +``gemma3.model``) — and passes them in by reference. This avoids duplicating +~400M parameters in memory. +""" + +import math +from typing import Optional + +import torch +import torch.nn.functional as F # noqa: N812 +from einops import rearrange +from torch import Tensor, nn +from transformers.models.siglip.modeling_siglip import ( + SiglipAttention, + SiglipEncoderLayer, + SiglipVisionModel, +) + +# Import triggers the transformers patch (see opentau.utils.transformers_patch) +# which rewrites PaliGemmaModel.get_image_features to drop the +# `/ sqrt(hidden_size)` scaling that stock HuggingFace applies after the +# multi_modal_projector. Our forward must match that patched behavior for +# single-frame invariance to hold. +import opentau.utils.transformers_patch # noqa: F401 + + +def _build_temporal_sinusoidal_pe( + num_frames: int, + embed_dim: int, + *, + min_period: float = 4e-3, + max_period: float = 4.0, + dtype: torch.dtype = torch.float32, + device: torch.device | str = "cpu", +) -> Tensor: + """Fixed sinusoidal temporal positional embedding, ``(T, embed_dim)``. + + Row ``T-1`` (the current frame) is all zeros; earlier rows encode the + temporal offset into the past via sin/cos on a geometric period schedule + (matching ``create_sinusoidal_pos_embedding`` in ``modeling_pi05.py``). + + The zero-current-row condition lets a ``T=1`` forward pass match an + un-modified SigLIP ViT exactly, which is required for single-frame + invariance against ``PaliGemmaModel.get_image_features``. + """ + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim ({embed_dim}) must be divisible by 2") + if num_frames < 1: + raise ValueError(f"num_frames ({num_frames}) must be >= 1") + + # time[i] = i - (T-1) in {-(T-1), ..., -1, 0}; row T-1 has time = 0. + time = torch.arange(num_frames, dtype=torch.float64, device=device) - (num_frames - 1) + fraction = torch.linspace(0.0, 1.0, embed_dim // 2, dtype=torch.float64, device=device) + period = min_period * (max_period / min_period) ** fraction + scaling = 1.0 / period * 2 * math.pi # (embed_dim/2,) + phase = time.unsqueeze(-1) * scaling.unsqueeze(0) # (T, embed_dim/2) + pe = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1) # (T, embed_dim) + # Shift so row T-1 is exactly zero (preserves relative sinusoidal structure, + # enforces boundary condition e(current) = 0 from MEM Appendix C). + pe = pe - pe[-1:] + return pe.to(dtype=dtype) + + +class _TemporalSelfAttention(nn.Module): + """Parameter-free causal temporal self-attention. + + Re-uses an existing ``SiglipAttention`` instance's + ``q_proj``/``k_proj``/``v_proj``/``out_proj`` linear layers, but applies + them over the ``T`` axis (for each fixed patch position) with a + standard lower-triangular causal mask (position ``i`` attends to + positions ``j <= i``; since ``t = T-1`` is the current frame, the current + frame attends to all past frames). + + The referenced ``SiglipAttention`` is held in a list to keep ``nn.Module`` + from re-registering its parameters under this module's path (which would + duplicate them in ``state_dict`` under both + ``base_layer.self_attn.*`` and ``_temporal_attn.attn.*``). + """ + + def __init__(self, attn: SiglipAttention): + super().__init__() + # Wrap in a list so nn.Module.__setattr__ does not treat ``attn`` + # as a child submodule; the base layer already owns these params. + self._attn_ref: list[SiglipAttention] = [attn] + + @property + def attn(self) -> SiglipAttention: + return self._attn_ref[0] + + def forward(self, hidden_states: Tensor) -> Tensor: + """hidden_states: (B*N, T, D) -> (B*N, T, D).""" + attn = self.attn + bn, t, d = hidden_states.shape + num_heads = attn.num_heads + head_dim = attn.head_dim + + q = attn.q_proj(hidden_states).view(bn, t, num_heads, head_dim).transpose(1, 2) + k = attn.k_proj(hidden_states).view(bn, t, num_heads, head_dim).transpose(1, 2) + v = attn.v_proj(hidden_states).view(bn, t, num_heads, head_dim).transpose(1, 2) + + # is_causal=True -> lower-triangular mask, each position attends to + # itself and earlier positions (our convention: t=T-1 is current). + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, dropout_p=0.0, scale=attn.scale + ) + out = out.transpose(1, 2).reshape(bn, t, d) + return attn.out_proj(out) + + +class SpaceTimeEncoderLayerWrapper(nn.Module): + """Replaces a ``SiglipEncoderLayer`` in-place; adds a temporal sublayer. + + The wrapper **adopts** the original layer's submodules by reference — + ``self_attn``, ``layer_norm1``, ``layer_norm2``, ``mlp`` — so its + ``state_dict`` keys are **identical** to a vanilla ``SiglipEncoderLayer``. + That means any pi05 / pi05_continuous_state checkpoint can load directly + into the wrapped layer without any key remapping. The only new state is + a non-persistent ``_temporal_pe`` buffer (excluded from state_dict) and + an internal ``_temporal_attn`` wrapper that holds a by-reference pointer + to ``self_attn`` (also excluded because it's kept in a list). + + The forward computes: + + h_pe = h + e(t) # broadcast over (B, N) + h = h + temporal_attn( LN1(h_pe) ) # new; causal over T + # then the standard SigLIP block: + h = h + spatial_attn( LN1(h) ) + h = h + MLP( LN2(h) ) + + At ``T=1`` the temporal sublayer is skipped entirely so that the block + degenerates to the unmodified SigLIP forward, satisfying the MEM paper's + single-frame invariance claim. (With ``T=1`` causal attention over a + single timestep is not an identity — it returns ``out_proj(v_proj( + LN1(x)))`` — so ``e(0)=0`` alone is insufficient; the block itself must + also be bypassed.) + + Reusing ``layer_norm1`` for both the temporal and spatial sublayers keeps + the paper's "no new learnable parameters" guarantee. It is an intentional + design choice: the two attentions operate on different axes and the + LayerNorm is applied to different input tensors each time. + """ + + # Match the SiglipEncoderLayer class attribute so transformers' + # gradient-checkpointing plumbing sees a familiar interface. + gradient_checkpointing: bool = False + + def __init__( + self, + base_layer: SiglipEncoderLayer, + num_frames: int, + num_tokens_per_frame: int, + ): + super().__init__() + # Adopt the base layer's submodules as our own (same attribute names). + # The state_dict therefore uses keys like + # ``encoder.layers.{i}.self_attn.q_proj.weight`` — identical to a + # vanilla SiglipEncoderLayer, so pi05 checkpoints load directly. + self.self_attn = base_layer.self_attn + self.layer_norm1 = base_layer.layer_norm1 + self.layer_norm2 = base_layer.layer_norm2 + self.mlp = base_layer.mlp + self.embed_dim = base_layer.embed_dim + + self.num_frames = num_frames + self.num_tokens_per_frame = num_tokens_per_frame + # The temporal attention re-uses self_attn's Q/K/V/O projections; it + # holds its reference in a list (see _TemporalSelfAttention) so the + # params don't show up twice in state_dict. + self._temporal_attn = _TemporalSelfAttention(self.self_attn) + + # Build the PE on the base layer's current device / dtype. The parent + # vision_tower is often moved to GPU BEFORE this wrapper is inserted + # (the normal load flow for pi05_mem does + # ``paligemma = ...from_pretrained(...).to('cuda')`` and then wraps); + # with no parent ``.to(device)`` happening after wrapping, a PE built + # on CPU would stay on CPU and trigger a cross-device RuntimeError at + # forward time. Pinning to the base layer's device sidesteps that. + ref_param = base_layer.self_attn.q_proj.weight + pe = _build_temporal_sinusoidal_pe( + num_frames, self.embed_dim, dtype=ref_param.dtype, device=ref_param.device + ) + # Non-persistent: not saved in state_dict but moves with .to(device). + self.register_buffer("_temporal_pe", pe, persistent=False) + + def _spatial_block_forward( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor], + output_attentions: bool, + ) -> tuple[Tensor, ...]: + """Inlined SiglipEncoderLayer.forward using the adopted submodules. + + Mirrors + ``transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward`` + exactly — any upstream change to that forward would need to be + reflected here. + """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs: tuple[Tensor, ...] = (hidden_states,) + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + def forward( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + output_attentions: bool = False, + ) -> tuple[Tensor, ...]: + """hidden_states: (B*T, N, D) -> tuple starting with (B*T, N, D). + + Signature matches ``SiglipEncoderLayer.forward`` so ``SiglipEncoder`` + can dispatch unchanged. + """ + t = self.num_frames + bt, n, d = hidden_states.shape + if n != self.num_tokens_per_frame: + raise ValueError( + f"hidden_states.shape[1] ({n}) != num_tokens_per_frame ({self.num_tokens_per_frame})." + ) + + # Short-circuit at T=1: temporal self-attention over a single timestep + # collapses to ``out_proj(v_proj(LN1(x)))``, which is NOT an identity + # and would break single-frame invariance (the MEM paper's guarantee + # that a T=1 pass matches the unmodified SigLIP ViT). e(t=0)=0 alone + # is insufficient; the block must also be skipped. + # + # Also short-circuit when ``bt`` is not a multiple of ``t``: the same + # wrapped tower is shared with ``Gemma3WithExpertModel.embed_image``, + # which passes single images as ``(B, N, D)`` with ``B < t`` (e.g. + # subgoal frames in ``embed_prefix``). Those calls are spatial-only; + # there is no valid ``(b, t, ...)`` grouping for temporal attention. + if t == 1 or bt % t != 0: + return self._spatial_block_forward(hidden_states, attention_mask, output_attentions) + + b = bt // t + + # Temporal sublayer. + x = rearrange(hidden_states, "(b t) n d -> b t n d", b=b, t=t) + # Cast PE to match tensor device/dtype each call. Both are no-ops if + # already aligned (the common case — the buffer is constructed on + # the base layer's device). The cast only allocates when something + # external has moved the inputs onto a different device without + # propagating ``.to()`` through to this wrapper. + pe = self._temporal_pe.to(device=x.device, dtype=x.dtype).view(1, t, 1, d) + x_pe = x + pe + + t_in = rearrange(x_pe, "b t n d -> (b n) t d") + t_norm = self.layer_norm1(t_in) + t_out = self._temporal_attn(t_norm) + # Residual on the pre-PE hidden (not on x_pe): PE is a transient + # positional signal, not a feature perturbation to carry forward. + t_res = rearrange(x, "b t n d -> (b n) t d") + t_out + h_after_t = rearrange(t_res, "(b n) t d -> (b t) n d", n=n) + + # Spatial + MLP sublayers. + return self._spatial_block_forward(h_after_t, attention_mask, output_attentions) + + +class SpaceTimeSiglipVideoEncoder(nn.Module): + """SigLIP-based video encoder with space-time separable attention. + + Takes video tensors of shape ``(B, T, 3, H, W)`` in the ``[0, 1]`` range + and produces ``(B, num_video_tokens, vlm_hidden_size)``. Rescales pixels + to ``[-1, 1]`` internally (SigLIP's expected range). + + Past-timestep tokens are dropped after the encoder; only the current + frame's ``num_video_tokens`` tokens are returned, so the output shape is + identical to a single-frame VLA's vision-token prefix. + + The ``multi_modal_projector`` is applied to match the output space of + ``PaliGemmaModel.get_image_features``. We intentionally **do not** apply + the ``/ sqrt(text_hidden_size)`` scaling, matching + ``opentau.utils.transformers_patch.patched_paligemma_model_get_image_features`` + which removes it from stock HuggingFace. + + The caller owns ``vision_tower`` and ``multi_modal_projector``. This + module holds them by reference (via a list, so ``nn.Module`` does not + re-register their parameters under this module's path) and mutates the + vision_tower's encoder in place to wrap every ``spacetime_layer_stride``-th + layer. In practice the only caller is ``PI07LowLevelPlannerFlowMatching``, + which passes in the ``Gemma3WithExpertModel``'s SigLIP vision components + (resolved via ``_vision_tower()`` / ``_multi_modal_projector()``). + """ + + def __init__( + self, + vision_tower: SiglipVisionModel, + multi_modal_projector: nn.Module, + num_frames: int, + spacetime_layer_stride: int = 4, + gradient_checkpointing: bool = False, + ): + super().__init__() + if num_frames < 1: + raise ValueError(f"num_frames ({num_frames}) must be >= 1.") + if spacetime_layer_stride < 1: + raise ValueError(f"spacetime_layer_stride ({spacetime_layer_stride}) must be >= 1.") + + self.num_frames = num_frames + self.spacetime_layer_stride = spacetime_layer_stride + # Wrap each SigLIP encoder layer (vanilla or space-time) in + # torch.utils.checkpoint.checkpoint during training. Mirrors the + # explicit per-layer pattern used by pi05's PaliGemmaWithExpertModel + # so we do not depend on transformers' SiglipEncoder internal + # gradient-checkpointing plumbing. The strict distributed-backend + # guard in src/opentau/scripts/train.py applies (DDP, single, or + # DeepSpeed ZeRO-1/2 only). + self.gradient_checkpointing = gradient_checkpointing + + # Hold references in lists so nn.Module.__setattr__ does not + # re-register these modules under this encoder's path. They are owned + # by the caller (Gemma3WithExpertModel); double registration would + # duplicate ~400M params in state_dict. + self._vision_tower_ref: list[SiglipVisionModel] = [vision_tower] + self._multi_modal_projector_ref: list[nn.Module] = [multi_modal_projector] + + # The number of output tokens is fixed by the SigLIP patch grid + # (e.g. 224/14 = 16 -> 16*16 = 256 patches for the default config). + vision_cfg = vision_tower.config + num_patches = (vision_cfg.image_size // vision_cfg.patch_size) ** 2 + self.num_video_tokens = num_patches + self.siglip_hidden_size = vision_cfg.hidden_size + + # Wrap every stride-th layer with space-time attention. The wrapper + # holds the original SiglipEncoderLayer as ``base_layer`` so its + # pretrained weights flow through unchanged. State-dict keys for + # wrapped layers will carry a ``.base_layer.`` prefix; as long as + # reloads round-trip through this code, keys stay consistent. + layers = vision_tower.vision_model.encoder.layers + n_layers = len(layers) + for i in range(spacetime_layer_stride - 1, n_layers, spacetime_layer_stride): + layers[i] = SpaceTimeEncoderLayerWrapper( + base_layer=layers[i], + num_frames=num_frames, + num_tokens_per_frame=num_patches, + ) + + @property + def vision_tower(self) -> SiglipVisionModel: + return self._vision_tower_ref[0] + + @property + def multi_modal_projector(self) -> nn.Module: + return self._multi_modal_projector_ref[0] + + def forward(self, video: Tensor) -> Tensor: + """Encode a video clip and return the current-frame tokens. + + Args: + video: ``(B, T, C, H, W)`` pixel values in ``[0, 1]``, with + ``T == num_frames``, ``C == 3``, and spatial size matching + the SigLIP config (224x224 by default). + + Returns: + ``(B, num_video_tokens, vlm_hidden_size)`` current-frame tokens, + ready to concatenate into the VLA prefix. + """ + if video.ndim != 5: + raise ValueError(f"Expected 5D input (B, T, C, H, W); got {tuple(video.shape)}.") + b, t, c, h, w = video.shape + if t != self.num_frames: + raise ValueError( + f"Expected T={self.num_frames} frames; got {t}. " + "Reinstantiate the encoder with a matching num_frames." + ) + + # SigLIP expects pixel values in [-1, 1]. The dataset loader yields + # [0, 1]; rescale here (keeps prepare_videos producer-agnostic). + video = video * 2.0 - 1.0 + + # Flatten time into batch for the SigLIP pipeline. + flat = rearrange(video, "b t c h w -> (b t) c h w") + + # Patch embedding + learned spatial position embedding. + hidden = self.vision_tower.vision_model.embeddings(flat) + + # Encoder stack: standard spatial layers + wrapped every-Nth layer + # with temporal attention. SpaceTimeEncoderLayerWrapper matches the + # SiglipEncoderLayer signature, so we drive the loop manually here + # (instead of calling SiglipEncoder.forward) so we can wrap each + # layer in torch.utils.checkpoint.checkpoint when the flag is set — + # the same explicit pattern PaliGemmaWithExpertModel uses. + use_ckpt = self.gradient_checkpointing and self.training + for layer in self.vision_tower.vision_model.encoder.layers: + if use_ckpt: + layer_outputs = torch.utils.checkpoint.checkpoint( + layer, hidden, None, False, use_reentrant=False + ) + else: + layer_outputs = layer(hidden, None, False) + hidden = layer_outputs[0] + + hidden = self.vision_tower.vision_model.post_layernorm(hidden) + + # Drop past-timestep tokens: keep only the current frame (t = T-1). + # This matches the MEM paper's "we only pass the representation + # computed for the current timestep onwards" and makes the encoder + # a drop-in replacement for a single-frame vision tower. + hidden = rearrange(hidden, "(b t) n d -> b t n d", b=b, t=t) + current = hidden[:, -1] + + # multi_modal_projector: SigLIP hidden (1152) -> VLA hidden (2048). + # We deliberately omit the `/ sqrt(hidden_size)` division to match + # the patched ``PaliGemmaModel.get_image_features`` (see + # ``opentau.utils.transformers_patch``). + return self.multi_modal_projector(current) diff --git a/tests/policies/test_pi07_high_level_planner.py b/tests/policies/test_pi07_high_level_planner.py new file mode 100644 index 00000000..37d23790 --- /dev/null +++ b/tests/policies/test_pi07_high_level_planner.py @@ -0,0 +1,333 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature +from opentau.policies.pi07.high_level_planner.configuration_pi07_high_level import ( + PI07HighLevelPlannerConfig, +) +from opentau.policies.pi07.high_level_planner.modeling_pi07_high_level import ( + PI07HighLevelPlannerPolicy, + make_att_2d_masks, +) + +# Config defaults used across the test. +NUM_CAMERAS = 2 +# Gemma 3's `Gemma3MultiModalProjector` reduces every image to +# `mm_tokens_per_image` tokens regardless of the underlying patch grid, so each +# camera contributes 256 tokens to the prefix even though the SigLIP backbone +# produces a 32x32=1024 patch grid at 448x448. +SIGLIP_TOKENS_PER_CAMERA = 256 +IMAGE_SIZE = 448 +PROMPT_MAX_LENGTH = 256 +METADATA_MAX_LENGTH = 52 +MEMORY_MAX_LENGTH = 52 +RESPONSE_MAX_LENGTH = 52 +MAX_STATE_DIM = 32 + +# Token offsets (fixed by config and image tokenization). +IMAGE_TOKENS = NUM_CAMERAS * SIGLIP_TOKENS_PER_CAMERA # 512 +LANG_START = IMAGE_TOKENS # 512 +METADATA_START = LANG_START + PROMPT_MAX_LENGTH # 768 +METADATA_END = METADATA_START + METADATA_MAX_LENGTH # 820 + + +class TestPI07HighLevelPlannerIntegration: + """Integration tests for the PI07 high-level planner pipeline.""" + + @staticmethod + def _make_config() -> PI07HighLevelPlannerConfig: + config = PI07HighLevelPlannerConfig( + n_obs_steps=1, + max_state_dim=MAX_STATE_DIM, + prompt_max_length=PROMPT_MAX_LENGTH, + metadata_max_length=METADATA_MAX_LENGTH, + memory_max_length=MEMORY_MAX_LENGTH, + response_max_length=RESPONSE_MAX_LENGTH, + normalization_mapping={ + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + }, + ) + config.input_features = { + "camera0": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)), + "camera1": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)), + "state": PolicyFeature(type=FeatureType.STATE, shape=(MAX_STATE_DIM,)), + } + config.output_features = {} + return config + + @staticmethod + def _indicator_lens(tokenizer): + """Lengths of fixed text spans inserted by ``embed_prefix`` / inference. + + The Gemma 3 tokenizer (``google/gemma-3-4b-pt``) may tokenize these + spans differently than PaliGemma did. We resolve against the live + tokenizer so the prefix-layout assertions track the real token count. + """ + return { + "prefix_end": len(tokenizer.encode(";\n ", add_special_tokens=False)), + "memory_lead": len(tokenizer.encode("Updated Memory: ", add_special_tokens=False)), + "subtask_lead": len(tokenizer.encode("Subtask: ", add_special_tokens=False)), + } + + @classmethod + def _train_prefix_total(cls, tokenizer) -> int: + meta = cls._indicator_lens(tokenizer) + mem_tokens_start = METADATA_END + meta["prefix_end"] + meta["memory_lead"] + return mem_tokens_start + MEMORY_MAX_LENGTH + meta["subtask_lead"] + RESPONSE_MAX_LENGTH + + @classmethod + def _infer_embed_prefix_total(cls, tokenizer) -> int: + meta = cls._indicator_lens(tokenizer) + return METADATA_END + meta["prefix_end"] + meta["memory_lead"] + + @staticmethod + def _check_ones_before_zeros(mask_slice): + """Check that in a 1-D boolean mask all Trues precede all Falses.""" + mask = mask_slice.cpu().numpy() + first_zero = None + for idx, val in enumerate(mask): + if val == 0: + first_zero = idx + break + if first_zero is not None: + assert all(v == 0 for v in mask[first_zero:]), f"Zeros not contiguous: {mask}" + assert all(v == 1 for v in mask[:first_zero]), f"Ones not contiguous: {mask}" + else: + assert all(v == 1 for v in mask), f"Expected all ones: {mask}" + + # ------------------------------------------------------------------ + # Verification helpers + # ------------------------------------------------------------------ + + def _verify_pad_masks(self, prefix_pad_masks, tokenizer): + meta = self._indicator_lens(tokenizer) + total = self._train_prefix_total(tokenizer) + assert prefix_pad_masks.shape == (1, total) + assert prefix_pad_masks.dtype == torch.bool + + mem_tokens_start = METADATA_END + meta["prefix_end"] + meta["memory_lead"] + resp_tokens_start = mem_tokens_start + MEMORY_MAX_LENGTH + meta["subtask_lead"] + + for i in range(1): + assert torch.all(prefix_pad_masks[i, :IMAGE_TOKENS] == 1) + self._check_ones_before_zeros(prefix_pad_masks[i, LANG_START:METADATA_START]) + self._check_ones_before_zeros(prefix_pad_masks[i, METADATA_START:METADATA_END]) + self._check_ones_before_zeros( + prefix_pad_masks[i, mem_tokens_start : mem_tokens_start + MEMORY_MAX_LENGTH] + ) + self._check_ones_before_zeros( + prefix_pad_masks[i, resp_tokens_start : resp_tokens_start + RESPONSE_MAX_LENGTH] + ) + + def _verify_position_ids(self, prefix_position_ids, prefix_pad_masks): + expected = torch.cumsum(prefix_pad_masks, dim=1) - 1 + assert torch.equal(prefix_position_ids, expected) + + def _verify_vlm_attention_mask(self, vlm_attention_mask, prefix_pad_masks, prefix_att_masks): + expected = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + assert torch.equal(vlm_attention_mask, expected), ( + f"VLM attention mask mismatch vs make_att_2d_masks.\n" + f"Diff indices: {(vlm_attention_mask != expected).nonzero(as_tuple=False)[:20]}" + ) + + # ------------------------------------------------------------------ + # Main integration test + # ------------------------------------------------------------------ + + @pytest.mark.gpu + @pytest.mark.slow + def test_complete_pi07_high_level_planner_pipeline(self, lerobot_dataset_metadata): + """Test the PI07 high-level planner: forward (training) and sample_actions (inference).""" + + config = self._make_config() + policy = PI07HighLevelPlannerPolicy(config, dataset_stats=lerobot_dataset_metadata.stats) + tokenizer = policy.model.language_tokenizer + + batch_size = 1 + batch = { + "camera0": torch.randn(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE), + "camera1": torch.randn(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE), + "state": torch.randn(batch_size, MAX_STATE_DIM), + "prompt": ["Pick up the red block"], + "past_memory": ["Robot is near the table"], + "speed": torch.tensor([500]), + "quality": torch.tensor([3]), + "mistake": torch.tensor([0]), + "speed_is_pad": torch.tensor([False]), + "quality_is_pad": torch.tensor([False]), + "mistake_is_pad": torch.tensor([False]), + "next_memory": ["Robot is grasping the red block"], + "response": ["Grasp the red block"], + } + + policy.to(dtype=torch.bfloat16, device="cuda") + batch_cuda = { + key: value.to("cuda", non_blocking=True, dtype=torch.bfloat16) + if isinstance(value, torch.Tensor) + else value + for key, value in batch.items() + } + + # ── Monkey-patch to capture intermediate tensors ────────────── + captured = {} + original_gemma3_forward = policy.model.gemma3_with_expert.forward + original_embed_prefix = policy.model.embed_prefix + + def capture_forward(*args, **kwargs): + if kwargs.get("past_key_values") is None: + captured["vlm_2d_attention_mask"] = kwargs["attention_mask"].clone() + captured["vlm_position_ids"] = kwargs["position_ids"].clone() + return original_gemma3_forward(*args, **kwargs) + + def capture_embed_prefix(*args, **kwargs): + result = original_embed_prefix(*args, **kwargs) + captured["prefix_pad_masks"] = result[1].clone() + captured["prefix_att_masks"] = result[2].clone() + return result + + policy.model.gemma3_with_expert.forward = capture_forward + policy.model.embed_prefix = capture_embed_prefix + + # ── Training forward pass ──────────────────────────────────── + loss = policy.forward(batch_cuda) + + # Restore originals. + policy.model.gemma3_with_expert.forward = original_gemma3_forward + policy.model.embed_prefix = original_embed_prefix + + # Verify captures. + for var in ["prefix_pad_masks", "prefix_att_masks", "vlm_2d_attention_mask", "vlm_position_ids"]: + assert var in captured, f"{var} was not captured" + + assert captured["vlm_2d_attention_mask"].dtype == torch.bool + assert captured["prefix_pad_masks"].dtype == torch.bool + + self._verify_pad_masks(captured["prefix_pad_masks"], tokenizer) + self._verify_position_ids(captured["vlm_position_ids"], captured["prefix_pad_masks"]) + self._verify_vlm_attention_mask( + captured["vlm_2d_attention_mask"], + captured["prefix_pad_masks"], + captured["prefix_att_masks"], + ) + + assert isinstance(loss, dict) + assert "MSE" in loss + assert "CE" in loss + assert loss["MSE"].item() == 0.0 + assert loss["CE"].isfinite() + + # Optimizer params are non-empty. + assert len(list(policy.get_optim_params())) > 0 + + # ── Inference via sample_actions ────────────────────────────── + captured_infer = {} + original_infer_autoregressive = policy.model.infer_autoregressive + step_counter = [0] + + def capture_infer_autoregressive(*args, **kwargs): + prev_prefix_pad_masks = kwargs.get("prefix_pad_masks") + result = original_infer_autoregressive(*args, **kwargs) + ( + _prefix_out, + prefix_embs, + prefix_pad_masks, + prefix_att_masks, + prefix_offsets, + tokens, + _past_kv, + ) = result + + # Verify prefix grows by exactly one token each step. + assert prefix_pad_masks.shape[1] == prev_prefix_pad_masks.shape[1] + 1, ( + f"Step {step_counter[0]}: prefix should grow by 1" + ) + assert prefix_embs.shape[1] == prefix_pad_masks.shape[1] + + assert prefix_att_masks.shape[1] == prefix_pad_masks.shape[1] + + captured_infer["last_prefix_pad_masks"] = prefix_pad_masks.clone() + captured_infer["last_prefix_offsets"] = prefix_offsets.clone() + step_counter[0] += 1 + return result + + def capture_forward_infer(*args, **kwargs): + if kwargs.get("past_key_values") is None: + captured_infer["vlm_2d_attention_mask"] = kwargs["attention_mask"].clone() + captured_infer["vlm_position_ids"] = kwargs["position_ids"].clone() + return original_gemma3_forward(*args, **kwargs) + + def capture_embed_prefix_infer(*args, **kwargs): + result = original_embed_prefix(*args, **kwargs) + captured_infer["prefix_pad_masks"] = result[1].clone() + captured_infer["prefix_att_masks"] = result[2].clone() + return result + + policy.model.gemma3_with_expert.forward = capture_forward_infer + policy.model.embed_prefix = capture_embed_prefix_infer + policy.model.infer_autoregressive = capture_infer_autoregressive + + infer_batch = { + "camera0": batch_cuda["camera0"], + "camera1": batch_cuda["camera1"], + "state": batch_cuda["state"], + "prompt": ["Pick up the red block"], + "past_memory": ["Robot is near the table"], + "speed": torch.tensor([500], device="cuda"), + "quality": torch.tensor([3], device="cuda"), + "mistake": torch.tensor([0], device="cuda"), + "speed_is_pad": torch.tensor([False], device="cuda"), + "quality_is_pad": torch.tensor([False], device="cuda"), + "mistake_is_pad": torch.tensor([False], device="cuda"), + } + memory_tokens, response_tokens = policy.sample_actions(infer_batch) + + # Restore originals. + policy.model.gemma3_with_expert.forward = original_gemma3_forward + policy.model.embed_prefix = original_embed_prefix + policy.model.infer_autoregressive = original_infer_autoregressive + + # Verify total number of autoregressive steps (memory + response only). + assert step_counter[0] == MEMORY_MAX_LENGTH + RESPONSE_MAX_LENGTH + + # Verify inference captures. + assert "vlm_2d_attention_mask" in captured_infer + assert "prefix_pad_masks" in captured_infer + assert "prefix_att_masks" in captured_infer + assert "last_prefix_pad_masks" in captured_infer + assert "last_prefix_offsets" in captured_infer + + infer_base = self._infer_embed_prefix_total(tokenizer) + assert captured_infer["vlm_2d_attention_mask"].shape == (1, infer_base, infer_base) + assert captured_infer["vlm_2d_attention_mask"].dtype == torch.bool + + init_expected = make_att_2d_masks( + captured_infer["prefix_pad_masks"], + captured_infer["prefix_att_masks"], + ) + assert torch.equal(captured_infer["vlm_2d_attention_mask"], init_expected), ( + f"Inference VLM 2D mask vs make_att_2d_masks.\n" + f"Diff: {(captured_infer['vlm_2d_attention_mask'] != init_expected).nonzero(as_tuple=False)[:20]}" + ) + + subtask_lead = self._indicator_lens(tokenizer)["subtask_lead"] + final_prefix_len = infer_base + MEMORY_MAX_LENGTH + subtask_lead + RESPONSE_MAX_LENGTH + assert captured_infer["last_prefix_pad_masks"].shape == (1, final_prefix_len) + + # Output shapes. + assert memory_tokens.shape == (1, MEMORY_MAX_LENGTH) + assert response_tokens.shape == (1, RESPONSE_MAX_LENGTH) diff --git a/tests/policies/test_pi07_low_level_planner.py b/tests/policies/test_pi07_low_level_planner.py new file mode 100644 index 00000000..1b233c13 --- /dev/null +++ b/tests/policies/test_pi07_low_level_planner.py @@ -0,0 +1,673 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature +from opentau.policies.pi07.gemma3_with_expert import Gemma3WithExpertConfig +from opentau.policies.pi07.low_level_planner.configuration_pi07_low_level import ( + PI07LowLevelPlannerConfig, +) +from opentau.policies.pi07.low_level_planner.modeling_pi07_low_level import ( + PI07LowLevelPlannerFlowMatching, + PI07LowLevelPlannerPolicy, + make_att_2d_masks, +) + +# Tiny VLM config so that the full forward pass fits within 24 GB GPU memory. +# Backbone: 2 text layers × 512 hidden, 2 KV heads × 128 head_dim (=256 total). +# Expert: 2 layers × 256 hidden, matching head_dim/KV heads. +# Vision: 2 SigLIP layers, 448 image_size (matching production resolution). +_TINY_TEXT_HIDDEN = 512 +_TINY_EXPERT_HIDDEN = 256 +_TINY_HEAD_DIM = 128 +_TINY_NUM_LAYERS = 2 + +_TINY_VLM_CONFIG = Gemma3WithExpertConfig( + gemma3_config={ + "model_type": "gemma3", + "text_config": { + "model_type": "gemma3_text", + "hidden_size": _TINY_TEXT_HIDDEN, + "intermediate_size": _TINY_TEXT_HIDDEN * 4, + "num_hidden_layers": _TINY_NUM_LAYERS, + "num_attention_heads": _TINY_TEXT_HIDDEN // _TINY_HEAD_DIM, + "num_key_value_heads": 2, + "head_dim": _TINY_HEAD_DIM, + "query_pre_attn_scalar": _TINY_HEAD_DIM, + "sliding_window": 1024, + "rope_theta": 1_000_000.0, + "rope_local_base_freq": 10_000.0, + "rms_norm_eps": 1e-6, + "vocab_size": 262_208, + "max_position_embeddings": 8192, + "attention_bias": False, + "attention_dropout": 0.0, + "hidden_activation": "gelu_pytorch_tanh", + "sliding_window_pattern": 6, + "torch_dtype": "float32", + }, + "vision_config": { + "model_type": "siglip_vision_model", + "hidden_size": 256, + "intermediate_size": 512, + "num_attention_heads": 4, + "num_hidden_layers": 2, + "patch_size": 14, + "image_size": 448, + "projection_dim": _TINY_TEXT_HIDDEN, + "projector_hidden_act": "gelu_fast", + "vision_use_head": False, + "torch_dtype": "float32", + "layer_norm_eps": 1e-6, + }, + "image_token_index": 262144, + "mm_tokens_per_image": 256, + "boi_token_index": 255999, + "eoi_token_index": 256000, + "initializer_range": 0.02, + }, + gemma_expert_config={ + "model_type": "gemma", + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 1, + "head_dim": _TINY_HEAD_DIM, + "hidden_act": "gelu_pytorch_tanh", + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": _TINY_EXPERT_HIDDEN, + "initializer_range": 0.02, + "intermediate_size": _TINY_EXPERT_HIDDEN * 4, + "max_position_embeddings": 8192, + "num_attention_heads": _TINY_TEXT_HIDDEN // _TINY_HEAD_DIM, + "num_hidden_layers": _TINY_NUM_LAYERS, + "num_key_value_heads": 2, + "pad_token_id": 0, + "rms_norm_eps": 1e-6, + "rope_theta": 10_000.0, + "torch_dtype": "float32", + "use_adarms": True, + "adarms_cond_dim": _TINY_EXPERT_HIDDEN, + "use_cache": True, + "vocab_size": 262_208, + }, + freeze_vision_encoder=True, + train_expert_only=False, + attention_implementation="eager", + load_pretrained_gemma3=False, + dropout=0.1, +) + +# Config defaults used across the test. +NUM_CAMERAS = 2 +# SpaceTimeSiglip pipes per-frame patches through the Gemma 3 multimodal +# projector, which always reduces to ``mm_tokens_per_image=256`` tokens per +# camera regardless of the underlying 32x32 patch grid (1024 patches at 448x448). +SPACETIME_SIGLIP_TOKENS_PER_CAMERA = 256 +NUM_SUBGOAL_CAMERAS = 1 +SIGLIP_TOKENS_PER_SUBGOAL = 256 +IMAGE_SIZE = 448 +PROMPT_MAX_LENGTH = 256 +RESPONSE_MAX_LENGTH = 52 +METADATA_MAX_LENGTH = 52 +DISCRETE_ACTION_MAX_LENGTH = 32 +# chunk_size=50 must match the ``lerobot_dataset_metadata`` fixture's action +# stats shape (50, 32); the normalizer buffers are built from those stats. +# n_obs_steps is reduced from the production default (8) to 2 so that the +# full-resolution 448×448 video tensors fit within 24 GB GPU memory. T=2 is +# the minimum that exercises SpaceTime temporal attention (T=1 short-circuits +# to spatial-only). +CHUNK_SIZE = 50 +MAX_STATE_DIM = 32 +MAX_ACTION_DIM = 32 + +# For training the state is provided as (B, n_obs_steps, D) so T = n_obs_steps. +N_OBS_STEPS = 2 + +VIDEO_TOKENS = NUM_CAMERAS * SPACETIME_SIGLIP_TOKENS_PER_CAMERA # 512 +LANG_START = VIDEO_TOKENS # 512 +SUBGOAL_TOKENS = NUM_SUBGOAL_CAMERAS * SIGLIP_TOKENS_PER_SUBGOAL # 256 + +# For inference: no discrete actions. SpaceTimeSiglip requires the full +# temporal window, so state keeps all T == N_OBS_STEPS timesteps. +INFER_STATE_TOKENS = N_OBS_STEPS + + +class TestPI07LowLevelPlannerIntegration: + """Integration tests for the PI07 low-level planner pipeline.""" + + @staticmethod + def _make_config() -> PI07LowLevelPlannerConfig: + config = PI07LowLevelPlannerConfig( + n_obs_steps=N_OBS_STEPS, + chunk_size=CHUNK_SIZE, + n_action_steps=CHUNK_SIZE, + max_state_dim=MAX_STATE_DIM, + max_action_dim=MAX_ACTION_DIM, + prompt_max_length=PROMPT_MAX_LENGTH, + response_max_length=RESPONSE_MAX_LENGTH, + metadata_max_length=METADATA_MAX_LENGTH, + discrete_action_max_length=DISCRETE_ACTION_MAX_LENGTH, + proj_width=_TINY_EXPERT_HIDDEN, + vlm_config=_TINY_VLM_CONFIG, + normalization_mapping={ + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MEAN_STD, + }, + ) + config.input_features = { + "camera0": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)), + "camera1": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)), + "state": PolicyFeature(type=FeatureType.STATE, shape=(MAX_STATE_DIM,)), + } + config.output_features = { + "actions": PolicyFeature(type=FeatureType.ACTION, shape=(CHUNK_SIZE, MAX_ACTION_DIM)), + } + return config + + @staticmethod + def _indicator_lens(tokenizer): + """Fixed strings inserted by ``embed_prefix`` (matches modeling layout). + + Resolved against the live Gemma 3 tokenizer, which may segment these + spans differently than PaliGemma's tokenizer. + """ + return { + "state_lead": len(tokenizer.encode("State: ", add_special_tokens=False)), + "comma": len(tokenizer.encode(", ", add_special_tokens=False)), + "subgoal_lead": len(tokenizer.encode("Subgoal: ", add_special_tokens=False)), + "prefix_end": len(tokenizer.encode(";\n ", add_special_tokens=False)), + "action_lead": len(tokenizer.encode("Action: ", add_special_tokens=False)), + } + + @classmethod + def _train_prefix_total(cls, tokenizer) -> int: + m = cls._indicator_lens(tokenizer) + p = 0 + p += VIDEO_TOKENS + p += PROMPT_MAX_LENGTH + p += m["state_lead"] + N_OBS_STEPS + m["comma"] + p += RESPONSE_MAX_LENGTH + p += m["subgoal_lead"] + SUBGOAL_TOKENS + m["comma"] + p += METADATA_MAX_LENGTH + p += m["prefix_end"] + p += m["action_lead"] + DISCRETE_ACTION_MAX_LENGTH + return p + + @classmethod + def _infer_prefix_total(cls, tokenizer) -> int: + m = cls._indicator_lens(tokenizer) + p = 0 + p += VIDEO_TOKENS + p += PROMPT_MAX_LENGTH + p += m["state_lead"] + INFER_STATE_TOKENS + m["comma"] + p += RESPONSE_MAX_LENGTH + p += m["subgoal_lead"] + SUBGOAL_TOKENS + m["comma"] + p += METADATA_MAX_LENGTH + p += m["prefix_end"] + return p + + @staticmethod + def _check_ones_before_zeros(mask_slice): + """Check that in a 1-D boolean mask all Trues precede all Falses.""" + mask = mask_slice.cpu().numpy() + first_zero = None + for idx, val in enumerate(mask): + if val == 0: + first_zero = idx + break + if first_zero is not None: + assert all(v == 0 for v in mask[first_zero:]), f"Zeros not contiguous: {mask}" + assert all(v == 1 for v in mask[:first_zero]), f"Ones not contiguous: {mask}" + else: + assert all(v == 1 for v in mask), f"Expected all ones: {mask}" + + # ------------------------------------------------------------------ + # Verification helpers + # ------------------------------------------------------------------ + + def _verify_pad_masks(self, prefix_pad_masks, suffix_pad_masks, tokenizer, inference_mode=False): + assert prefix_pad_masks.shape[0] == 1 + total = self._infer_prefix_total(tokenizer) if inference_mode else self._train_prefix_total(tokenizer) + assert prefix_pad_masks.shape[1] == total + assert prefix_pad_masks.dtype == torch.bool + assert suffix_pad_masks.shape == (1, CHUNK_SIZE) + assert suffix_pad_masks.dtype == torch.bool + + m = self._indicator_lens(tokenizer) + + lang_slice = slice(LANG_START, LANG_START + PROMPT_MAX_LENGTH) + state_t = INFER_STATE_TOKENS if inference_mode else N_OBS_STEPS + + resp_lo = LANG_START + PROMPT_MAX_LENGTH + m["state_lead"] + state_t + m["comma"] + resp_slice = slice(resp_lo, resp_lo + RESPONSE_MAX_LENGTH) + + sg_lo = resp_lo + RESPONSE_MAX_LENGTH + m["subgoal_lead"] + sg_slice = slice(sg_lo, sg_lo + SUBGOAL_TOKENS) + + meta_lo = sg_lo + SUBGOAL_TOKENS + m["comma"] + meta_slice = slice(meta_lo, meta_lo + METADATA_MAX_LENGTH) + + for i in range(prefix_pad_masks.shape[0]): + assert torch.all(prefix_pad_masks[i, :VIDEO_TOKENS] == 1) + self._check_ones_before_zeros(prefix_pad_masks[i, lang_slice]) + self._check_ones_before_zeros(prefix_pad_masks[i, resp_slice]) + assert torch.all(prefix_pad_masks[i, sg_slice] == 1) + self._check_ones_before_zeros(prefix_pad_masks[i, meta_slice]) + + if not inference_mode: + da_lo = meta_lo + METADATA_MAX_LENGTH + m["prefix_end"] + m["action_lead"] + da_slice = slice(da_lo, da_lo + DISCRETE_ACTION_MAX_LENGTH) + self._check_ones_before_zeros(prefix_pad_masks[i, da_slice]) + + self._check_ones_before_zeros(suffix_pad_masks[i]) + + def _verify_position_ids( + self, + prefix_position_ids, + suffix_position_ids, + prefix_pad_masks, + suffix_pad_masks, + tokenizer, + inference_mode=False, + ): + expected_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1 + assert torch.equal(prefix_position_ids, expected_prefix) + + if inference_mode: + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + else: + prefix_offsets = torch.sum(prefix_pad_masks[:, :-DISCRETE_ACTION_MAX_LENGTH], dim=-1)[:, None] + + expected_suffix = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + assert torch.equal(suffix_position_ids, expected_suffix) + + def _verify_vlm_attention_mask( + self, vlm_attention_mask, prefix_pad_masks, prefix_att_masks, inference_mode=False + ): + del inference_mode # same rule as training + expected = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + assert torch.equal(vlm_attention_mask, expected), ( + f"VLM attention mask mismatch vs make_att_2d_masks.\n" + f"Diff indices: {(vlm_attention_mask != expected).nonzero(as_tuple=False)[:20]}" + ) + + def _verify_action_expert_attention_mask( + self, + action_expert_attention_mask, + prefix_pad_masks, + suffix_pad_masks, + suffix_att_masks, + inference_mode=False, + ): + if inference_mode: + num_cross = prefix_pad_masks.shape[1] + else: + num_cross = prefix_pad_masks.shape[1] - DISCRETE_ACTION_MAX_LENGTH + + expected = make_att_2d_masks( + suffix_pad_masks, + suffix_att_masks, + n_cross_att_tokens=num_cross, + cross_att_pad_masks=prefix_pad_masks[:, :num_cross], + ) + assert torch.equal(action_expert_attention_mask, expected), ( + f"Action expert attention mask mismatch vs make_att_2d_masks.\n" + f"Diff indices: {(action_expert_attention_mask != expected).nonzero(as_tuple=False)[:20]}" + ) + + # ------------------------------------------------------------------ + # Main integration test + # ------------------------------------------------------------------ + + @pytest.mark.gpu + @pytest.mark.slow + def test_complete_pi07_low_level_pipeline(self, lerobot_dataset_metadata): + """Test the PI07 low-level planner pipeline: forward (training) and select_action (inference).""" + + config = self._make_config() + policy = PI07LowLevelPlannerPolicy(config, dataset_stats=lerobot_dataset_metadata.stats) + tokenizer = policy.model.language_tokenizer + + batch_size = 1 + batch = { + "camera0": torch.randn(batch_size, N_OBS_STEPS, 3, IMAGE_SIZE, IMAGE_SIZE), + "camera1": torch.randn(batch_size, N_OBS_STEPS, 3, IMAGE_SIZE, IMAGE_SIZE), + "subgoal0": torch.randn(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE), + "state": torch.randn(batch_size, N_OBS_STEPS, MAX_STATE_DIM), + "actions": torch.randn(batch_size, CHUNK_SIZE, MAX_ACTION_DIM), + "prompt": ["Pick up the red block"], + "response": ["Grasp the red block"], + "speed": torch.tensor([500]), + "quality": torch.tensor([3]), + "mistake": torch.tensor([0]), + "speed_is_pad": torch.tensor([False]), + "quality_is_pad": torch.tensor([False]), + "mistake_is_pad": torch.tensor([False]), + "subgoal_is_pad": torch.tensor([False]), + "action_is_pad": torch.cat( + [ + torch.zeros(batch_size, CHUNK_SIZE // 2, dtype=torch.bool), + torch.ones(batch_size, CHUNK_SIZE - CHUNK_SIZE // 2, dtype=torch.bool), + ], + dim=1, + ), + } + + policy.to(dtype=torch.bfloat16, device="cuda") + batch_cuda = { + key: value.to("cuda", non_blocking=True, dtype=torch.bfloat16) + if isinstance(value, torch.Tensor) + else value + for key, value in batch.items() + } + batch_cuda["action_is_pad"] = batch_cuda["action_is_pad"].to(dtype=torch.bool) + + # ── Monkey-patch to capture intermediate tensors ────────────── + captured = {} + original_gemma3_forward = policy.model.gemma3_with_expert.forward + original_embed_prefix = policy.model.embed_prefix + original_embed_suffix = policy.model.embed_suffix + + def capture_forward(*args, **kwargs): + if kwargs["inputs_embeds"][0] is not None: + captured["vlm_2d_attention_mask"] = kwargs["attention_mask"].clone() + captured["vlm_position_ids"] = kwargs["position_ids"].clone() + else: + captured["action_expert_2d_attention_mask"] = kwargs["attention_mask"].clone() + captured["action_expert_position_ids"] = kwargs["position_ids"].clone() + return original_gemma3_forward(*args, **kwargs) + + def capture_embed_prefix(*args, **kwargs): + # Truncate discrete actions to inject padding (same workaround as PI05 test). + half = DISCRETE_ACTION_MAX_LENGTH // 2 + discrete_actions = kwargs["discrete_actions"] + discrete_action_masks = kwargs["discrete_action_masks"] + kwargs["discrete_actions"] = torch.cat( + ( + discrete_actions[:, :half], + torch.zeros((1, half), dtype=discrete_actions.dtype, device=discrete_actions.device), + ), + dim=-1, + ) + kwargs["discrete_action_masks"] = torch.cat( + ( + discrete_action_masks[:, :half], + torch.zeros((1, half), dtype=torch.bool, device=discrete_action_masks.device), + ), + dim=-1, + ) + result = original_embed_prefix(*args, **kwargs) + captured["prefix_pad_masks"] = result[1].clone() + captured["prefix_att_masks"] = result[2].clone() + return result + + def capture_embed_suffix(*args, **kwargs): + result = original_embed_suffix(*args, **kwargs) + captured["suffix_pad_masks"] = result[1].clone() + captured["suffix_att_masks"] = result[2].clone() + return result + + policy.model.gemma3_with_expert.forward = capture_forward + policy.model.embed_prefix = capture_embed_prefix + policy.model.embed_suffix = capture_embed_suffix + + # ── Training forward pass ──────────────────────────────────── + loss = policy.forward(batch_cuda) + + # Restore originals. + policy.model.gemma3_with_expert.forward = original_gemma3_forward + policy.model.embed_prefix = original_embed_prefix + policy.model.embed_suffix = original_embed_suffix + + # Verify normalize / unnormalize round-trip. + action_output = {"actions": batch["actions"].to("cuda")} + assert torch.allclose( + action_output["actions"], + policy.unnormalize_outputs(policy.normalize_targets(action_output))["actions"], + atol=1e-6, + ) + + # Verify all expected captures are present. + for var in [ + "prefix_pad_masks", + "prefix_att_masks", + "suffix_pad_masks", + "suffix_att_masks", + "vlm_2d_attention_mask", + "vlm_position_ids", + "action_expert_2d_attention_mask", + "action_expert_position_ids", + ]: + assert var in captured, f"{var} was not captured" + + assert captured["vlm_2d_attention_mask"].dtype == torch.bool + assert captured["action_expert_2d_attention_mask"].dtype == torch.bool + assert captured["prefix_pad_masks"].dtype == torch.bool + assert captured["suffix_pad_masks"].dtype == torch.bool + + self._verify_pad_masks(captured["prefix_pad_masks"], captured["suffix_pad_masks"], tokenizer) + self._verify_position_ids( + captured["vlm_position_ids"], + captured["action_expert_position_ids"], + captured["prefix_pad_masks"], + captured["suffix_pad_masks"], + tokenizer, + ) + self._verify_vlm_attention_mask( + captured["vlm_2d_attention_mask"], + captured["prefix_pad_masks"], + captured["prefix_att_masks"], + ) + self._verify_action_expert_attention_mask( + captured["action_expert_2d_attention_mask"], + captured["prefix_pad_masks"], + captured["suffix_pad_masks"], + captured["suffix_att_masks"], + ) + + assert isinstance(loss, dict) + assert "MSE" in loss + assert "CE" in loss + assert all(v.isfinite() for v in loss.values()) + + # Reset and check queue cleared. + policy.reset() + assert len(policy._action_queue) == 0 + + # Optimizer params are non-empty. + assert len(list(policy.get_optim_params())) > 0 + + # ── Inference via select_action ────────────────────────────── + captured_infer = {} + + def capture_forward_infer(*args, **kwargs): + if kwargs["inputs_embeds"][0] is not None and kwargs.get("past_key_values") is None: + captured_infer["vlm_2d_attention_mask"] = kwargs["attention_mask"].clone() + captured_infer["vlm_position_ids"] = kwargs["position_ids"].clone() + else: + captured_infer["action_expert_2d_attention_mask"] = kwargs["attention_mask"].clone() + captured_infer["action_expert_position_ids"] = kwargs["position_ids"].clone() + return original_gemma3_forward(*args, **kwargs) + + def capture_embed_prefix_infer(*args, **kwargs): + result = original_embed_prefix(*args, **kwargs) + captured_infer["prefix_pad_masks"] = result[1].clone() + captured_infer["prefix_att_masks"] = result[2].clone() + return result + + def capture_embed_suffix_infer(*args, **kwargs): + result = original_embed_suffix(*args, **kwargs) + captured_infer["suffix_pad_masks"] = result[1].clone() + captured_infer["suffix_att_masks"] = result[2].clone() + return result + + policy.model.gemma3_with_expert.forward = capture_forward_infer + policy.model.embed_prefix = capture_embed_prefix_infer + policy.model.embed_suffix = capture_embed_suffix_infer + + # Inference batch: SpaceTimeSiglip requires T == n_obs_steps frames, so + # videos stay 5-D (B, T, C, H, W) and state stays 3-D (B, T, D). + infer_batch = { + "camera0": batch_cuda["camera0"], # (B, T, C, H, W) + "camera1": batch_cuda["camera1"], + "subgoal0": batch_cuda["subgoal0"], # already (B, C, H, W) + "state": batch_cuda["state"], # (B, T, D) + "prompt": ["Pick up the red block"], + "response": ["Grasp the red block"], + "speed": torch.tensor([500], device="cuda"), + "quality": torch.tensor([3], device="cuda"), + "mistake": torch.tensor([0], device="cuda"), + "speed_is_pad": torch.tensor([False], device="cuda"), + "quality_is_pad": torch.tensor([False], device="cuda"), + "mistake_is_pad": torch.tensor([False], device="cuda"), + "subgoal_is_pad": torch.tensor([False], device="cuda"), + } + action = policy.select_action(infer_batch) + + # Restore originals. + policy.model.gemma3_with_expert.forward = original_gemma3_forward + policy.model.embed_prefix = original_embed_prefix + policy.model.embed_suffix = original_embed_suffix + + for var in [ + "prefix_pad_masks", + "prefix_att_masks", + "suffix_pad_masks", + "suffix_att_masks", + "vlm_2d_attention_mask", + "vlm_position_ids", + "action_expert_2d_attention_mask", + "action_expert_position_ids", + ]: + assert var in captured_infer, f"{var} was not captured for select_action" + + assert captured_infer["vlm_2d_attention_mask"].dtype == torch.bool + assert captured_infer["action_expert_2d_attention_mask"].dtype == torch.bool + assert captured_infer["prefix_pad_masks"].dtype == torch.bool + assert captured_infer["suffix_pad_masks"].dtype == torch.bool + + self._verify_pad_masks( + captured_infer["prefix_pad_masks"], + captured_infer["suffix_pad_masks"], + tokenizer, + inference_mode=True, + ) + self._verify_position_ids( + captured_infer["vlm_position_ids"], + captured_infer["action_expert_position_ids"], + captured_infer["prefix_pad_masks"], + captured_infer["suffix_pad_masks"], + tokenizer, + inference_mode=True, + ) + self._verify_vlm_attention_mask( + captured_infer["vlm_2d_attention_mask"], + captured_infer["prefix_pad_masks"], + captured_infer["prefix_att_masks"], + inference_mode=True, + ) + self._verify_action_expert_attention_mask( + captured_infer["action_expert_2d_attention_mask"], + captured_infer["prefix_pad_masks"], + captured_infer["suffix_pad_masks"], + captured_infer["suffix_att_masks"], + inference_mode=True, + ) + + assert action.shape == (1, MAX_ACTION_DIM) + + +class TestPI07LowLevelPlannerRegression: + """GPU regression tests pinning the low-level planner signature/dtype fixes. + + Covers the changes made to ``embed_prefix``, ``embed_suffix``, + ``prepare_metadata``, and the metadata-zip ``strict=True`` switch. + """ + + @staticmethod + def _make_policy(lerobot_dataset_metadata) -> PI07LowLevelPlannerPolicy: + config = TestPI07LowLevelPlannerIntegration._make_config() + policy = PI07LowLevelPlannerPolicy(config, dataset_stats=lerobot_dataset_metadata.stats) + policy.to(dtype=torch.bfloat16, device="cuda") + return policy + + @staticmethod + def _make_metadata_batch(batch_size: int) -> dict[str, torch.Tensor]: + return { + "state": torch.randn(batch_size, N_OBS_STEPS, MAX_STATE_DIM, device="cuda", dtype=torch.bfloat16), + "speed": torch.tensor([500] * batch_size, device="cuda"), + "quality": torch.tensor([3] * batch_size, device="cuda"), + "mistake": torch.tensor([0] * batch_size, device="cuda"), + "speed_is_pad": torch.tensor([False] * batch_size, device="cuda"), + "quality_is_pad": torch.tensor([False] * batch_size, device="cuda"), + "mistake_is_pad": torch.tensor([False] * batch_size, device="cuda"), + } + + @pytest.mark.gpu + @pytest.mark.slow + def test_prepare_metadata_always_returns_tensors(self, lerobot_dataset_metadata): + """prepare_metadata returns (Tensor, Tensor) — never (None, None) — with the documented shapes.""" + policy = self._make_policy(lerobot_dataset_metadata) + batch = self._make_metadata_batch(batch_size=2) + + tokens, masks = policy.prepare_metadata(batch) + + assert isinstance(tokens, torch.Tensor) + assert isinstance(masks, torch.Tensor) + assert tokens.shape == (2, METADATA_MAX_LENGTH) + assert masks.shape == (2, METADATA_MAX_LENGTH) + assert masks.dtype == torch.bool + + @pytest.mark.gpu + @pytest.mark.slow + def test_prepare_metadata_zip_strict_catches_mismatch(self, lerobot_dataset_metadata): + """The ``strict=True`` zip in prepare_metadata raises on length mismatch.""" + policy = self._make_policy(lerobot_dataset_metadata) + batch = self._make_metadata_batch(batch_size=2) + # Truncate quality to length 1 to break the zip. + batch["quality"] = batch["quality"][:1] + batch["quality_is_pad"] = batch["quality_is_pad"][:1] + + with pytest.raises(ValueError): + policy.prepare_metadata(batch) + + @pytest.mark.gpu + @pytest.mark.slow + def test_embed_suffix_returns_bool_att_masks(self, lerobot_dataset_metadata): + """The suffix att_masks must be bool, not embs.dtype (was a copy-paste bug).""" + policy = self._make_policy(lerobot_dataset_metadata) + + bsize = 1 + noisy_actions = torch.randn(bsize, CHUNK_SIZE, MAX_ACTION_DIM, device="cuda", dtype=torch.bfloat16) + timestep = torch.zeros(bsize, CHUNK_SIZE, device="cuda", dtype=torch.bfloat16) + + _, _, att_masks, _ = policy.model.embed_suffix(noisy_actions, timestep) + + assert att_masks.dtype == torch.bool, f"Expected torch.bool, got {att_masks.dtype}" + + @pytest.mark.gpu + @pytest.mark.slow + def test_embed_prefix_metadata_response_are_required(self): + """response_tokens / response_masks / metadata_tokens / metadata_masks are positional, no defaults.""" + import inspect + + params = inspect.signature(PI07LowLevelPlannerFlowMatching.embed_prefix).parameters + for name in ("response_tokens", "response_masks", "metadata_tokens", "metadata_masks"): + assert params[name].default is inspect.Parameter.empty, ( + f"{name} should be a required parameter (no default), got default={params[name].default}" + ) From e135cb64509b9d2c8effaedce59c3742fab9fa0a Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Wed, 29 Apr 2026 11:45:24 -0700 Subject: [PATCH 2/8] fix(pi07): align with #178/#171 invariants & restore CPU tests (#198) --- src/opentau/__init__.py | 9 +- .../policies/pi07/gemma3_with_expert.py | 28 +- .../modeling_pi07_high_level.py | 42 +- .../configuration_pi07_low_level.py | 10 +- .../modeling_pi07_low_level.py | 31 +- .../pi07/low_level_planner/video_encoder.py | 60 ++- tests/policies/test_pi07_cpu.py | 364 ++++++++++++++++++ tests/policies/test_pi07_video_encoder_cpu.py | 355 +++++++++++++++++ tests/test_available.py | 8 + 9 files changed, 844 insertions(+), 63 deletions(-) create mode 100644 tests/policies/test_pi07_cpu.py create mode 100644 tests/policies/test_pi07_video_encoder_cpu.py diff --git a/src/opentau/__init__.py b/src/opentau/__init__.py index 25e86281..88725cec 100644 --- a/src/opentau/__init__.py +++ b/src/opentau/__init__.py @@ -149,7 +149,14 @@ ) # lists all available policies from `src/opentau/policies` -available_policies = ["pi0", "pi05", "pi05_mem", "value"] +available_policies = [ + "pi0", + "pi05", + "pi05_mem", + "pi07_high_level", + "pi07_low_level", + "value", +] # keys and values refer to yaml files available_policies_per_env = {} diff --git a/src/opentau/policies/pi07/gemma3_with_expert.py b/src/opentau/policies/pi07/gemma3_with_expert.py index ea296e12..6a9670a6 100644 --- a/src/opentau/policies/pi07/gemma3_with_expert.py +++ b/src/opentau/policies/pi07/gemma3_with_expert.py @@ -372,19 +372,33 @@ def _multi_modal_projector(self): def embed_image(self, image: torch.Tensor) -> torch.Tensor: """Runs the SigLIP tower + multimodal projector to obtain image tokens. - Gemma 3's `get_image_features` returns a `BaseModelOutputWithPooling` - whose `pooler_output` holds the projected image tokens (shape - `(B, mm_tokens_per_image, text_hidden_size)`). We extract that tensor - so callers can treat the return as a plain `(B, N, D)` tensor — matching - the patched PaliGemma behavior in `transformers_patch.py`. + Mirrors ``Gemma3ForConditionalGeneration.get_image_features``: feed + ``pixel_values`` through the vision tower and run + ``multi_modal_projector`` on the resulting ``last_hidden_state``, + returning a plain ``(B, mm_tokens_per_image, text_hidden_size)`` + tensor. + + When the vision tower has been wrapped with space-time attention by + :class:`SpaceTimeSiglipVideoEncoder` (low-level planner), suppress the + temporal sublayer here — single-image inputs have no time axis to + attend over. The context manager is a no-op when no wrappers are + present. """ + # Local import keeps ``gemma3_with_expert`` importable from the + # high-level planner (which never constructs a video encoder) without + # forcing a transitive import of einops/F at module load time. + from opentau.policies.pi07.low_level_planner.video_encoder import ( + suppress_spacetime_temporal, + ) + vision_tower = self._vision_tower() projector = self._multi_modal_projector() if vision_tower is None or projector is None: raise RuntimeError( "Gemma3 vision tower / multi_modal_projector could not be located on `self.gemma3`." ) - last_hidden_state = vision_tower(pixel_values=image).last_hidden_state + with suppress_spacetime_temporal(vision_tower): + last_hidden_state = vision_tower(pixel_values=image).last_hidden_state return projector(last_hidden_state) def _lm_head(self) -> nn.Module: @@ -415,7 +429,7 @@ def embed_discrete_actions(self, actions: torch.Tensor) -> torch.Tensor: def get_attention_interface(self): if self.config.attention_implementation == "fa2": raise NotImplementedError( - "fa2 attention is not supported for pi06 yet because of the interleaved " + "fa2 attention is not supported for pi07 yet because of the interleaved " "local/global mask pattern — use 'eager' instead." ) return self.eager_attention_forward diff --git a/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py index f2ef2ec4..61b22508 100644 --- a/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py +++ b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py @@ -26,7 +26,6 @@ import builtins import logging -import math from pathlib import Path import torch @@ -371,6 +370,12 @@ def _fix_pytorch_state_dict_keys( fixed_state_dict = {} for key, value in state_dict.items(): + # Accept legacy `paligemma_with_expert.*` prefixes from + # `pi07_paligemma` checkpoints as a warm-start path. The rest of + # the rewrite logic applies uniformly to both prefixes. + if key.startswith("paligemma_with_expert."): + key = key.replace("paligemma_with_expert.", "gemma3_with_expert.", 1) + new_key = key # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias @@ -914,12 +919,14 @@ def embed_prefix( # Create attention masks so that image tokens attend to each other att_masks += [0] * num_img_embs + # Gemma 3's `embed_tokens` is a `Gemma3TextScaledWordEmbedding` that + # already multiplies by sqrt(hidden_size) internally — do NOT scale + # again here (unlike pi05's PaliGemma path, whose Gemma-v1 embedding + # is a plain nn.Embedding with the normalizer applied later in the + # stock forward that we bypass). Applies to every + # `embed_language_tokens` call in this method. lang_emb = self.gemma3_with_expert.embed_language_tokens(lang_tokens) - # Normalize language embeddings - lang_emb_dim = lang_emb.shape[-1] - lang_emb = lang_emb * math.sqrt(lang_emb_dim) - embs.append(lang_emb) pad_masks.append(lang_masks) @@ -929,8 +936,6 @@ def embed_prefix( if metadata_tokens is not None: metadata_emb = self.gemma3_with_expert.embed_language_tokens(metadata_tokens) - metadata_emb_dim = metadata_emb.shape[-1] - metadata_emb = metadata_emb * math.sqrt(metadata_emb_dim) embs.append(metadata_emb) pad_masks.append(metadata_masks) att_masks += [1] + [0] * (metadata_emb.shape[1] - 1) @@ -942,8 +947,6 @@ def embed_prefix( dtype=torch.long, ) prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) - prefix_end_dim = prefix_end_emb.shape[-1] - prefix_end_emb = prefix_end_emb * math.sqrt(prefix_end_dim) num_prefix_end_embs = prefix_end_emb.shape[1] prefix_end_mask = torch.ones(bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device) @@ -961,8 +964,6 @@ def embed_prefix( dtype=torch.long, ) memory_start_emb = self.gemma3_with_expert.embed_language_tokens(memory_start_tokens) - memory_start_dim = memory_start_emb.shape[-1] - memory_start_emb = memory_start_emb * math.sqrt(memory_start_dim) num_memory_start_embs = memory_start_emb.shape[1] memory_start_mask = torch.ones( @@ -975,9 +976,6 @@ def embed_prefix( if memory_tokens is not None: memory_emb = self.gemma3_with_expert.embed_language_tokens(memory_tokens) - # Normalize memory language embeddings - memory_emb_dim = memory_emb.shape[-1] - memory_emb = memory_emb * math.sqrt(memory_emb_dim) embs.append(memory_emb) pad_masks.append(memory_masks) @@ -996,8 +994,6 @@ def embed_prefix( dtype=torch.long, ) response_start_emb = self.gemma3_with_expert.embed_language_tokens(response_start_tokens) - response_start_dim = response_start_emb.shape[-1] - response_start_emb = response_start_emb * math.sqrt(response_start_dim) num_response_start_embs = response_start_emb.shape[1] response_start_mask = torch.ones( @@ -1010,10 +1006,6 @@ def embed_prefix( response_emb = self.gemma3_with_expert.embed_language_tokens(response_tokens) - # Normalize response language embeddings - response_emb_dim = response_emb.shape[-1] - response_emb = response_emb * math.sqrt(response_emb_dim) - embs.append(response_emb) pad_masks.append(response_masks) @@ -1251,8 +1243,9 @@ def sample_actions( response_start_indicator_ids = self.language_tokenizer.encode("Subtask: ", add_special_tokens=False) for i, tid in enumerate(response_start_indicator_ids): token = torch.full((bsize, 1), int(tid), device=device, dtype=torch.long) + # Gemma 3's `embed_tokens` already scales by sqrt(hidden_size); see + # the note in `embed_prefix`. emb = self.gemma3_with_expert.embed_language_tokens(token) - emb = emb * math.sqrt(emb.shape[-1]) pad_row = torch.ones((bsize, 1), device=device, dtype=prefix_pad_masks.dtype) if prefix_att_masks.dtype == torch.bool: new_att = torch.full((bsize, 1), i == 0, device=device, dtype=torch.bool) @@ -1387,13 +1380,10 @@ def infer_autoregressive( # Updating response tokens with current predicted token tokens = torch.cat([tokens, token], dim=1) - # Embed the current predicted token + # Embed the current predicted token. Gemma 3's `embed_tokens` already + # scales by sqrt(hidden_size); see the note in `embed_prefix`. emb = self.gemma3_with_expert.embed_language_tokens(token) - # Normalize response language embeddings - emb_dim = emb.shape[-1] - emb = emb * math.sqrt(emb_dim) - att_masks = torch.ones((bsize, 1), device=device, dtype=emb.dtype) # update the prefix embs, pad masks and att masks, so it can be used by action experts diff --git a/src/opentau/policies/pi07/low_level_planner/configuration_pi07_low_level.py b/src/opentau/policies/pi07/low_level_planner/configuration_pi07_low_level.py index d4e01d6f..9aaa452f 100644 --- a/src/opentau/policies/pi07/low_level_planner/configuration_pi07_low_level.py +++ b/src/opentau/policies/pi07/low_level_planner/configuration_pi07_low_level.py @@ -88,7 +88,9 @@ class PI07LowLevelPlannerConfig(PreTrainedConfig): vlm_config: Bundled :class:`Gemma3WithExpertConfig` for the Gemma 3 VLM backbone + Gemma-v1 action expert. spacetime_layer_stride: Wrap every Nth SigLIP encoder layer with - space-time separable attention. ``1`` wraps every layer. + space-time separable attention. Defaults to ``4`` (every 4th + of the 27 SigLIP layers, indices ``[3, 7, 11, 15, 19, 23]``), + matching the MEM paper / pi05_mem (#171). gradient_checkpointing: If True, wrap each SigLIP encoder layer in ``torch.utils.checkpoint.checkpoint`` during training. """ @@ -167,8 +169,10 @@ class PI07LowLevelPlannerConfig(PreTrainedConfig): ) ) - # SpaceTime settings - spacetime_layer_stride: int = 1 + # SpaceTime settings. Stride 4 matches the MEM paper and pi05_mem + # (#171): every 4th of the 27 SigLIP layers gets wrapped, indices + # [3, 7, 11, 15, 19, 23]. + spacetime_layer_stride: int = 4 gradient_checkpointing: bool = False # Training presets diff --git a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py index 57ab1e70..9aacd5e3 100644 --- a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py @@ -381,6 +381,13 @@ def _fix_pytorch_state_dict_keys( fixed_state_dict = {} for key, value in state_dict.items(): + # Accept legacy `paligemma_with_expert.*` prefixes from + # `pi07_paligemma` checkpoints as a warm-start path; the rest of + # the AdaRMS / time-mlp / patch_embedding handling below applies + # uniformly to both prefixes. + if key.startswith("paligemma_with_expert."): + key = key.replace("paligemma_with_expert.", "gemma3_with_expert.", 1) + new_key = key if re.match( @@ -1164,9 +1171,13 @@ def embed_prefix( att_masks += [0] * num_vid_embs + # Gemma 3's `embed_tokens` is a `Gemma3TextScaledWordEmbedding` that + # already multiplies by sqrt(hidden_size) internally — do NOT scale + # again here (unlike pi05, whose PaliGemma Gemma-v1 embedding is a + # plain nn.Embedding with the normalizer applied later in the stock + # forward that we bypass). Applies to every `embed_language_tokens` + # call in this method. lang_emb = self.gemma3_with_expert.embed_language_tokens(lang_tokens) - lang_emb_dim = lang_emb.shape[-1] - lang_emb = lang_emb * math.sqrt(lang_emb_dim) embs.append(lang_emb) pad_masks.append(lang_masks) @@ -1181,8 +1192,6 @@ def embed_prefix( dtype=torch.long, ) state_start_emb = self.gemma3_with_expert.embed_language_tokens(state_start_tokens) - state_start_dim = state_start_emb.shape[-1] - state_start_emb = state_start_emb * math.sqrt(state_start_dim) num_state_start_embs = state_start_emb.shape[1] state_start_mask = torch.ones( @@ -1213,8 +1222,6 @@ def embed_prefix( dtype=torch.long, ) state_end_emb = self.gemma3_with_expert.embed_language_tokens(state_end_tokens) - state_end_dim = state_end_emb.shape[-1] - state_end_emb = state_end_emb * math.sqrt(state_end_dim) num_state_end_embs = state_end_emb.shape[1] state_end_mask = torch.ones(bsize, num_state_end_embs, dtype=torch.bool, device=lang_tokens.device) @@ -1224,8 +1231,6 @@ def embed_prefix( att_masks += [0] * num_state_end_embs response_emb = self.gemma3_with_expert.embed_language_tokens(response_tokens) - response_emb_dim = response_emb.shape[-1] - response_emb = response_emb * math.sqrt(response_emb_dim) embs.append(response_emb) pad_masks.append(response_masks) num_response_embs = response_emb.shape[1] @@ -1240,8 +1245,6 @@ def embed_prefix( dtype=torch.long, ) subgoal_img_start_emb = self.gemma3_with_expert.embed_language_tokens(subgoal_img_start_tokens) - subgoal_img_start_dim = subgoal_img_start_emb.shape[-1] - subgoal_img_start_emb = subgoal_img_start_emb * math.sqrt(subgoal_img_start_dim) num_subgoal_img_start_embs = subgoal_img_start_emb.shape[1] subgoal_img_start_mask = torch.ones( @@ -1279,8 +1282,6 @@ def embed_prefix( dtype=torch.long, ) subgoal_img_end_emb = self.gemma3_with_expert.embed_language_tokens(subgoal_img_end_tokens) - subgoal_img_end_dim = subgoal_img_end_emb.shape[-1] - subgoal_img_end_emb = subgoal_img_end_emb * math.sqrt(subgoal_img_end_dim) num_subgoal_img_end_embs = subgoal_img_end_emb.shape[1] subgoal_img_end_mask = torch.ones( @@ -1292,8 +1293,6 @@ def embed_prefix( att_masks += [0] * num_subgoal_img_end_embs metadata_emb = self.gemma3_with_expert.embed_language_tokens(metadata_tokens) - metadata_emb_dim = metadata_emb.shape[-1] - metadata_emb = metadata_emb * math.sqrt(metadata_emb_dim) embs.append(metadata_emb) pad_masks.append(metadata_masks) att_masks += [1] + [0] * (metadata_emb.shape[1] - 1) @@ -1305,8 +1304,6 @@ def embed_prefix( dtype=torch.long, ) prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) - prefix_end_dim = prefix_end_emb.shape[-1] - prefix_end_emb = prefix_end_emb * math.sqrt(prefix_end_dim) num_prefix_end_embs = prefix_end_emb.shape[1] prefix_end_mask = torch.ones(bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device) @@ -1327,8 +1324,6 @@ def embed_prefix( discrete_action_start_emb = self.gemma3_with_expert.embed_language_tokens( discrete_action_start_tokens ) - discrete_action_start_dim = discrete_action_start_emb.shape[-1] - discrete_action_start_emb = discrete_action_start_emb * math.sqrt(discrete_action_start_dim) num_discrete_action_start_embs = discrete_action_start_emb.shape[1] discrete_action_start_mask = torch.ones( diff --git a/src/opentau/policies/pi07/low_level_planner/video_encoder.py b/src/opentau/policies/pi07/low_level_planner/video_encoder.py index 16570a24..8a6abd50 100644 --- a/src/opentau/policies/pi07/low_level_planner/video_encoder.py +++ b/src/opentau/policies/pi07/low_level_planner/video_encoder.py @@ -44,7 +44,8 @@ """ import math -from typing import Optional +from contextlib import contextmanager +from typing import Iterator, Optional import torch import torch.nn.functional as F # noqa: N812 @@ -208,6 +209,16 @@ def __init__( # params don't show up twice in state_dict. self._temporal_attn = _TemporalSelfAttention(self.self_attn) + # Caller-driven flag: when False, ``forward`` short-circuits to the + # vanilla spatial-only block (`_spatial_block_forward`) regardless of + # the input shape. This is used by ``Gemma3WithExpertModel.embed_image`` + # via the ``suppress_spacetime_temporal`` context manager so the same + # wrapped vision_tower can be reused for non-video inputs (e.g. single + # subgoal images) without firing temporal attention over data that has + # no time axis. Flag lives on the wrapper rather than on a kwarg + # because ``SiglipEncoder.forward`` does not accept extra kwargs. + self._temporal_active: bool = True + # Build the PE on the base layer's current device / dtype. The parent # vision_tower is often moved to GPU BEFORE this wrapper is inserted # (the normal load flow for pi05_mem does @@ -272,20 +283,28 @@ def forward( f"hidden_states.shape[1] ({n}) != num_tokens_per_frame ({self.num_tokens_per_frame})." ) + # Short-circuit when the caller has suppressed temporal attention. + # ``Gemma3WithExpertModel.embed_image`` shares the same wrapped vision + # tower for non-video inputs (e.g. subgoal frames) and toggles this + # flag via the ``suppress_spacetime_temporal`` context manager — those + # calls are spatial-only; there is no time axis to attend over. + if not self._temporal_active: + return self._spatial_block_forward(hidden_states, attention_mask, output_attentions) + # Short-circuit at T=1: temporal self-attention over a single timestep # collapses to ``out_proj(v_proj(LN1(x)))``, which is NOT an identity # and would break single-frame invariance (the MEM paper's guarantee # that a T=1 pass matches the unmodified SigLIP ViT). e(t=0)=0 alone # is insufficient; the block must also be skipped. - # - # Also short-circuit when ``bt`` is not a multiple of ``t``: the same - # wrapped tower is shared with ``Gemma3WithExpertModel.embed_image``, - # which passes single images as ``(B, N, D)`` with ``B < t`` (e.g. - # subgoal frames in ``embed_prefix``). Those calls are spatial-only; - # there is no valid ``(b, t, ...)`` grouping for temporal attention. - if t == 1 or bt % t != 0: + if t == 1: return self._spatial_block_forward(hidden_states, attention_mask, output_attentions) + if bt % t != 0: + raise ValueError( + f"hidden_states.shape[0] ({bt}) must be divisible by num_frames ({t}); " + "video encoder expects inputs flattened as (B*T, N, D). " + "Use `suppress_spacetime_temporal(...)` for non-video forwards." + ) b = bt // t # Temporal sublayer. @@ -310,6 +329,31 @@ def forward( return self._spatial_block_forward(h_after_t, attention_mask, output_attentions) +@contextmanager +def suppress_spacetime_temporal(module: nn.Module) -> Iterator[None]: + """Context manager that flips ``_temporal_active=False`` on every + :class:`SpaceTimeEncoderLayerWrapper` in ``module``'s subtree, and + restores the previous value on exit. + + Used by ``Gemma3WithExpertModel.embed_image`` so that single-image + forwards through a vision_tower that has been wrapped with space-time + attention skip the temporal sublayer (which has no time axis to attend + over for non-video inputs). When ``module`` contains no wrappers (e.g. + no video encoder has been constructed yet), this is a no-op. + """ + wrappers: list[SpaceTimeEncoderLayerWrapper] = [ + m for m in module.modules() if isinstance(m, SpaceTimeEncoderLayerWrapper) + ] + previous = [w._temporal_active for w in wrappers] + for w in wrappers: + w._temporal_active = False + try: + yield + finally: + for w, prev in zip(wrappers, previous, strict=True): + w._temporal_active = prev + + class SpaceTimeSiglipVideoEncoder(nn.Module): """SigLIP-based video encoder with space-time separable attention. diff --git a/tests/policies/test_pi07_cpu.py b/tests/policies/test_pi07_cpu.py new file mode 100644 index 00000000..f6179843 --- /dev/null +++ b/tests/policies/test_pi07_cpu.py @@ -0,0 +1,364 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU-only regression tests for the ``pi07`` policy backbone. + +Covers the architectural invariants ported from PR #178 (`pi06`): + + * Gemma 3's ``Gemma3TextScaledWordEmbedding`` already scales by + ``sqrt(hidden_size)`` — the planner code must NOT apply a second manual + ``* math.sqrt(hidden_size)`` to text embeddings. Otherwise text tokens + are scaled to ~51× the image-token magnitude. + * The vision-config ``image_size`` MUST match the planner's input + resolution; ``Gemma3MultiModalProjector`` hardcodes + ``patches_per_image = image_size // patch_size``. + * Per-layer RoPE ``θ`` is applied uniformly to backbone *and* expert at the + same layer — the shared-attention dot product would otherwise mix RoPE + bases. + * Gemma 3's 1024-token sliding-window pattern is **not** enforced; every + layer receives the unmodified block-causal prefix-LM mask. +""" + +from __future__ import annotations + +import pytest +import torch + +from opentau.policies.pi07 import gemma3_with_expert as g3we +from opentau.policies.pi07.gemma3_with_expert import ( + Gemma3WithExpertConfig, + Gemma3WithExpertModel, +) + +# Shared tiny config helper + + +def _make_tiny_g3we_cfg() -> Gemma3WithExpertConfig: + """Construct a minimally-sized ``Gemma3WithExpertConfig`` for fast tests. + + The text config is sized for two layers — one ``sliding_attention`` and + one ``full_attention`` — so the per-layer RoPE-θ and no-sliding-window + invariants can be exercised in a single forward pass. Vision config + matches the production 448 / 14 / 32-patch grid so the projector reshape + runs without crashing. + """ + return Gemma3WithExpertConfig( + gemma3_config={ + "text_config": { + "model_type": "gemma3_text", + "hidden_size": 32, + "intermediate_size": 64, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "head_dim": 16, + "sliding_window": 2, + "rope_theta": 1_000_000.0, + "rope_local_base_freq": 10_000.0, + "query_pre_attn_scalar": 16, + "rms_norm_eps": 1e-6, + "vocab_size": 128, + "max_position_embeddings": 512, + "attention_bias": False, + "attention_dropout": 0.0, + "hidden_activation": "gelu_pytorch_tanh", + "sliding_window_pattern": 2, + "torch_dtype": "float32", + "layer_types": ["sliding_attention", "full_attention"], + }, + "vision_config": { + "model_type": "siglip_vision_model", + "hidden_size": 16, + "intermediate_size": 32, + "num_attention_heads": 2, + "num_hidden_layers": 2, + "patch_size": 14, + "image_size": 448, + "projection_dim": 32, + "projector_hidden_act": "gelu_fast", + "vision_use_head": False, + "torch_dtype": "float32", + "layer_norm_eps": 1e-6, + }, + "image_token_index": 127, + "mm_tokens_per_image": 4, + "boi_token_index": 125, + "eoi_token_index": 126, + }, + gemma_expert_config={ + "attention_bias": False, + "attention_dropout": 0.0, + "head_dim": 16, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 16, + "intermediate_size": 32, + "max_position_embeddings": 512, + "num_attention_heads": 2, + "num_hidden_layers": 2, + "num_key_value_heads": 1, + "rms_norm_eps": 1e-6, + # Intentionally different from the backbone θ so the test can + # confirm this value is IGNORED during shared attention. + "rope_theta": 10_000.0, + "use_adarms": True, + "adarms_cond_dim": 16, + "vocab_size": 128, + }, + discrete_action_vocab_size=32, + freeze_vision_encoder=False, + train_expert_only=False, + ) + + +# Vision config invariants + + +class TestVisionConfig: + def test_vision_image_size_matches_input_resolution(self): + """Regression: ``Gemma3MultiModalProjector`` hardcodes + ``patches_per_image = image_size // patch_size``, so the default + config's ``vision_config.image_size`` MUST equal what the planners + actually feed (448). Otherwise the projector reshape crashes on + the first forward.""" + cfg = Gemma3WithExpertConfig(discrete_action_vocab_size=2048) + vc = cfg.gemma3_config.vision_config + assert vc.image_size == 448 + assert vc.patch_size == 14 + # 448 / 14 = 32 patches/side → 1024 vision tokens → avg-pool to 256 + # multimodal tokens. + assert (vc.image_size // vc.patch_size) ** 2 == 1024 + + def test_projector_accepts_448_inputs(self): + from transformers.models.gemma3.modeling_gemma3 import Gemma3MultiModalProjector + + cfg = Gemma3WithExpertConfig(discrete_action_vocab_size=2048) + proj = Gemma3MultiModalProjector(cfg.gemma3_config) + vision_hidden = cfg.gemma3_config.vision_config.hidden_size + # SigLIP at 448 produces 32×32 = 1024 patch tokens per image. + vision_out = torch.randn(1, 1024, vision_hidden) + out = proj(vision_out) + assert out.shape == (1, 256, cfg.gemma3_config.text_config.hidden_size) + + +# Per-layer RoPE-θ symmetry across backbone / expert + + +class TestRopeThetaSymmetryDuringForward: + def test_expert_uses_backbone_per_layer_theta(self, monkeypatch): + """Both streams' Q/K must rotate with the backbone's per-layer θ so + the shared-attention dot product stays in a consistent RoPE basis. + + Sliding (local) layers use ``rope_local_base_freq=10_000``; global + layers use ``rope_theta=1_000_000``. Each layer calls ``apply_rope`` + once for Q and once for K — a 2-layer config therefore generates 4 + captures, two per θ. + """ + captured: list[float] = [] + real_apply_rope = g3we.apply_rope + + def spy_apply_rope(x, positions, max_wavelength=10_000.0): + captured.append(float(max_wavelength)) + return real_apply_rope(x, positions, max_wavelength=max_wavelength) + + monkeypatch.setattr(g3we, "apply_rope", spy_apply_rope) + # Skip the unconditional bf16 cast in __init__ so a plain float32 + # forward through tiny linears doesn't complain about mixed dtypes. + monkeypatch.setattr(g3we, "_preferred_dtype", lambda: torch.float32) + + cfg = _make_tiny_g3we_cfg() + model = Gemma3WithExpertModel(cfg).to(dtype=torch.float32) + + batch, seq_len = 1, 3 + hidden_backbone = torch.randn(batch, seq_len, cfg.gemma3_config.text_config.hidden_size) + position_ids = torch.arange(seq_len)[None, :] + attention_mask = torch.ones(batch, seq_len, seq_len, dtype=torch.bool) + + model( + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=[hidden_backbone, None], + n_cross_att_tokens=seq_len, + use_cache=False, + fill_kv_cache=True, + ) + + # 2 layers × 2 tensors (Q, K) per layer = 4 captures. Sliding layer + # (idx 0) uses 10_000; global layer (idx 1) uses 1_000_000. + assert 10_000.0 in captured + assert 1_000_000.0 in captured + assert captured.count(10_000.0) == 2, captured + assert captured.count(1_000_000.0) == 2, captured + + +# Sliding-window mask is intentionally NOT enforced + + +class TestNoSlidingWindowEnforcement: + def test_per_layer_mask_equals_input_mask_on_both_layer_types(self, monkeypatch): + """π0.7 keeps the prefix block-causal mask global at every layer. + + Gemma 3 pretraining uses a 1024-token sliding window on local layers, + but π0.7 needs bidirectional attention across **all** image tokens. + If sliding-window enforcement crept back in, the captured mask on the + sliding (local) layer would differ from the one on the global layer. + """ + monkeypatch.setattr(g3we, "_preferred_dtype", lambda: torch.float32) + + cfg = _make_tiny_g3we_cfg() + assert cfg.gemma3_config.text_config.layer_types == [ + "sliding_attention", + "full_attention", + ] + model = Gemma3WithExpertModel(cfg).to(dtype=torch.float32) + + captured_masks: list[torch.Tensor] = [] + real_eager = model.eager_attention_forward + + def spy_eager(attention_mask, *args, **kwargs): + captured_masks.append(attention_mask.clone()) + return real_eager(attention_mask, *args, **kwargs) + + monkeypatch.setattr(model, "eager_attention_forward", spy_eager) + + batch, seq_len = 1, 8 + hidden_backbone = torch.randn(batch, seq_len, cfg.gemma3_config.text_config.hidden_size) + position_ids = torch.arange(seq_len)[None, :] + attention_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))[None] + + model( + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=[hidden_backbone, None], + n_cross_att_tokens=seq_len, + use_cache=False, + fill_kv_cache=True, + ) + + assert len(captured_masks) == 2, "expected one capture per layer" + # Sliding-window cap (= 2) WOULD have zeroed everything more than 2 + # positions away on the sliding layer; the global layer would have + # been left alone. Confirm both layers received the identical input + # mask. + assert torch.equal(captured_masks[0], attention_mask), ( + "sliding (local) layer mask differs from input — sliding window " + "enforcement appears to have been re-enabled" + ) + assert torch.equal(captured_masks[1], attention_mask), ( + "global layer mask differs from input — something is mutating it" + ) + + +# Embedding-magnitude invariant — guards against the double-scale bug + + +class TestEmbeddingMagnitudeInvariant: + """Gemma 3's ``embed_tokens`` is a ``Gemma3TextScaledWordEmbedding`` that + already multiplies by ``sqrt(hidden_size)`` internally. Apply a manual + ``* math.sqrt(hidden_size)`` on top of ``embed_language_tokens`` and + text tokens end up at ~51× image-token magnitude, corrupting the + bidirectional prefix attention and the FAST/response cross-entropy heads. + + These tests pin the invariant at the embedding level (not inside the + planner) so a regression is flagged regardless of which planner gets the + scaling wrong. + """ + + def test_embed_language_tokens_already_scaled(self, monkeypatch): + """``embed_language_tokens`` returns the scaled embedding directly.""" + monkeypatch.setattr(g3we, "_preferred_dtype", lambda: torch.float32) + + cfg = _make_tiny_g3we_cfg() + model = Gemma3WithExpertModel(cfg).to(dtype=torch.float32) + + tokens = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) + embedded = model.embed_language_tokens(tokens) + # The embedding table's underlying weights have stdev ≈ initializer + # range; after the internal sqrt(hidden_size) scaling the embedded + # tokens have stdev ≈ initializer_range × sqrt(hidden). Assert the + # scaling was applied (embedded > raw weights * 1.0) by sampling the + # raw weight matrix and comparing magnitudes. + lm = getattr(model.gemma3, "language_model", model.gemma3.model.language_model) + raw_weights = lm.embed_tokens.weight + embedded_std = embedded.detach().std().item() + raw_std = raw_weights.detach().std().item() + # Scaled embedding should be markedly larger than the raw row. + assert embedded_std > raw_std, ( + f"Gemma3TextScaledWordEmbedding does not appear to have applied its " + f"sqrt(hidden) scaling: embedded_std={embedded_std} <= raw_std={raw_std}" + ) + + def test_embedded_token_stdev_matches_internal_scaling(self, monkeypatch): + """``embed_language_tokens`` output stdev must be ≈ ``raw_weight_stdev + × sqrt(hidden)`` — and not an additional factor of ``sqrt(hidden)`` + on top. An untrained Gemma3WithExpertModel does not give us a well- + calibrated image-token magnitude to compare against (random SigLIP + + random projector → tiny stdev), so we verify the Gemma 3 internal + scaling alone fires, and rely on the source-level lint + (``TestNoManualSqrtScalingInPlannerSource``) plus the GPU integration + tests to catch a manual second scale further up the call stack. + """ + monkeypatch.setattr(g3we, "_preferred_dtype", lambda: torch.float32) + + cfg = _make_tiny_g3we_cfg() + model = Gemma3WithExpertModel(cfg).to(dtype=torch.float32) + + tokens = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) + embedded = model.embed_language_tokens(tokens) + + lm = getattr(model.gemma3, "language_model", model.gemma3.model.language_model) + raw_weight_std = lm.embed_tokens.weight.detach().std().item() + embedded_std = embedded.detach().std().item() + hidden = cfg.gemma3_config.text_config.hidden_size + expected_factor = hidden**0.5 + + # Allow ±0.5× slack because the sample stdev over 5 token rows is noisy. + actual_factor = embedded_std / max(raw_weight_std, 1e-8) + assert 0.5 * expected_factor < actual_factor < 1.5 * expected_factor, ( + f"embedded_std/raw_std = {actual_factor:.2f} but expected ≈ " + f"sqrt(hidden) = {expected_factor:.2f}. A second manual " + f"`* sqrt(hidden)` would push this ratio to ~{expected_factor**2:.0f}." + ) + + +# Spec sanity for the embed_prefix call sites — guards against accidental +# reintroduction of the manual scaling. + + +class TestNoManualSqrtScalingInPlannerSource: + """Source-level invariant: planner files must not call + ``math.sqrt(*_dim)`` on a tensor whose immediate origin is + ``embed_language_tokens``. This is a fast file-level lint so a + regression is caught even if no test exercises that prefix slot. + """ + + @pytest.mark.parametrize( + "module_path", + [ + "src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py", + "src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py", + ], + ) + def test_no_sqrt_emb_dim_left(self, module_path): + from pathlib import Path + + repo_root = Path(__file__).resolve().parents[2] + text = (repo_root / module_path).read_text() + # ``math.sqrt(*_dim)`` is only ever applied to a language-embedding + # tensor in this codebase; if it appears in a pi07 planner file, the + # double-scale fix from #178 has been undone. + assert "math.sqrt(" not in text, ( + f"{module_path} contains a residual math.sqrt(...) call; " + "Gemma 3's embed_tokens already scales by sqrt(hidden_size). " + "Reintroducing the manual scaling double-scales text embeddings." + ) diff --git a/tests/policies/test_pi07_video_encoder_cpu.py b/tests/policies/test_pi07_video_encoder_cpu.py new file mode 100644 index 00000000..d326835f --- /dev/null +++ b/tests/policies/test_pi07_video_encoder_cpu.py @@ -0,0 +1,355 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU-only tests for ``SpaceTimeSiglipVideoEncoder`` (pi07 low-level). + +Mirrors PR #171's pi05_mem video-encoder tests, ported to pi07's namespace +and using the Gemma 3 SigLIP config. Builds randomly-initialised SigLIP + +``Gemma3MultiModalProjector`` pairs from scratch so no HF download is needed. + +The tests pin three guarantees the planner relies on: + + * Wrapping every ``stride``-th SigLIP layer with space-time attention + leaves the wrapped vision_tower's state_dict keys identical to the + vanilla SigLIP — a vanilla pi07 / pi07_paligemma checkpoint loads in + untouched. + * At ``T=1`` the wrapper short-circuits to the vanilla spatial block, so + a stride-999 (no-wrapping) encoder and a stride-4 (6 layers wrapped) + encoder built on the same weights produce byte-identical outputs. + * The new ``suppress_spacetime_temporal`` context manager (replaces the + silent ``bt % t != 0`` numerology in earlier drafts) suppresses temporal + attention for non-video forwards while leaving the spatial block + untouched. +""" + +from __future__ import annotations + +import copy + +import pytest +import torch + +from opentau.policies.pi07.low_level_planner.video_encoder import ( + SpaceTimeEncoderLayerWrapper, + SpaceTimeSiglipVideoEncoder, + _build_temporal_sinusoidal_pe, + suppress_spacetime_temporal, +) + +# Helpers + + +def _build_siglip_and_projector(): + """Construct a fresh, randomly-initialised SigLIP + Gemma 3 projector.""" + from transformers import SiglipVisionConfig, SiglipVisionModel + from transformers.models.gemma3.configuration_gemma3 import Gemma3Config + from transformers.models.gemma3.modeling_gemma3 import Gemma3MultiModalProjector + + vision_cfg_dict = { + "hidden_size": 1152, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + "image_size": 224, + "projection_dim": 2560, + "projector_hidden_act": "gelu_fast", + "vision_use_head": False, + } + vision_tower = SiglipVisionModel(SiglipVisionConfig(**vision_cfg_dict)) + g3_cfg = Gemma3Config( + vision_config=vision_cfg_dict, + text_config={ + "model_type": "gemma3_text", + "hidden_size": 2560, + "vocab_size": 262_208, + "num_hidden_layers": 1, + "num_attention_heads": 8, + "num_key_value_heads": 4, + "head_dim": 256, + }, + mm_tokens_per_image=256, + boi_token_index=255_999, + eoi_token_index=256_000, + image_token_index=262_144, + ) + projector = Gemma3MultiModalProjector(g3_cfg) + return vision_tower, projector + + +def _make_encoder(num_frames: int = 2, spacetime_stride: int = 4): + vision_tower, projector = _build_siglip_and_projector() + return SpaceTimeSiglipVideoEncoder( + vision_tower=vision_tower, + multi_modal_projector=projector, + num_frames=num_frames, + spacetime_layer_stride=spacetime_stride, + ) + + +# Temporal sinusoidal PE + + +class TestTemporalSinusoidalPE: + def test_current_frame_row_is_zero(self): + pe = _build_temporal_sinusoidal_pe(num_frames=8, embed_dim=64) + assert pe.shape == (8, 64) + # Current frame lives at t=T-1; row must be all zeros so a T=1 pass + # collapses to the unmodified SigLIP forward. + assert torch.all(pe[-1] == 0) + + def test_earlier_rows_are_nonzero(self): + pe = _build_temporal_sinusoidal_pe(num_frames=8, embed_dim=64) + assert torch.any(pe[0] != 0) + + def test_single_frame_produces_zero_row(self): + pe = _build_temporal_sinusoidal_pe(num_frames=1, embed_dim=64) + assert pe.shape == (1, 64) + assert torch.all(pe == 0) + + def test_odd_embed_dim_raises(self): + with pytest.raises(ValueError, match="divisible by 2"): + _build_temporal_sinusoidal_pe(num_frames=4, embed_dim=63) + + def test_zero_num_frames_raises(self): + with pytest.raises(ValueError, match="num_frames"): + _build_temporal_sinusoidal_pe(num_frames=0, embed_dim=64) + + +# Wrapper structure / state_dict invariance + + +class TestSpaceTimeWrapperStructure: + def test_wrapped_layer_count(self): + """Verify every ``stride``-th layer (1-indexed from the Nth) is wrapped.""" + encoder = _make_encoder(num_frames=2, spacetime_stride=4) + layers = encoder.vision_tower.vision_model.encoder.layers + wrapped_indices = [ + i for i, layer in enumerate(layers) if isinstance(layer, SpaceTimeEncoderLayerWrapper) + ] + # 27 SigLIP layers; every 4th starting at index 3 → [3, 7, 11, 15, 19, 23] + assert wrapped_indices == [3, 7, 11, 15, 19, 23] + + def test_state_dict_keys_match_vanilla_siglip(self): + """The wrapped vision_tower's state_dict must have identical keys to + an unwrapped SigLIP — guarantees that a pi07_paligemma checkpoint + (or any vanilla SigLIP) loads without key remapping.""" + from transformers import SiglipVisionModel + + vision_tower, projector = _build_siglip_and_projector() + reference_keys = set(SiglipVisionModel(vision_tower.config).state_dict().keys()) + + SpaceTimeSiglipVideoEncoder( + vision_tower=vision_tower, + multi_modal_projector=projector, + num_frames=4, + spacetime_layer_stride=4, + ) + wrapped_keys = set(vision_tower.state_dict().keys()) + assert wrapped_keys == reference_keys, ( + f"state_dict keys diverged; " + f"extra in wrapped: {wrapped_keys - reference_keys}, " + f"missing from wrapped: {reference_keys - wrapped_keys}" + ) + + def test_temporal_pe_on_base_layer_device(self): + """The temporal PE buffer must be constructed on the base layer's + device/dtype, not default CPU. + + Regression for the pi05_mem GPU bug where the parent vision_tower had + already been moved to a non-default dtype/device BEFORE wrapping — + leaving the wrapper's PE on CPU/float32. + """ + vision_tower, projector = _build_siglip_and_projector() + vision_tower = vision_tower.to(dtype=torch.bfloat16) + projector = copy.deepcopy(projector).to(dtype=torch.bfloat16) + + SpaceTimeSiglipVideoEncoder( + vision_tower=vision_tower, + multi_modal_projector=projector, + num_frames=4, + spacetime_layer_stride=4, + ) + + ref_param = next(vision_tower.parameters()) + for layer in vision_tower.vision_model.encoder.layers: + if isinstance(layer, SpaceTimeEncoderLayerWrapper): + assert layer._temporal_pe.device == ref_param.device + assert layer._temporal_pe.dtype == ref_param.dtype + + +# Forward shape / dtype / arg validation + + +class TestSpaceTimeForward: + def test_forward_shape(self): + encoder = _make_encoder(num_frames=2) + encoder.eval() + with torch.no_grad(): + video = torch.rand(1, 2, 3, 224, 224) + out = encoder(video) + # 16 patches/side at 224/14 → 256 tokens; Gemma 3 projector outputs + # 256 mm tokens; per-token width is the Gemma 3 text hidden size. + assert out.shape == (1, 256, 2560) + assert out.dtype == torch.float32 + + def test_wrong_num_frames_raises(self): + encoder = _make_encoder(num_frames=4) + encoder.eval() + with torch.no_grad(), pytest.raises(ValueError, match="frames"): + encoder(torch.rand(1, 2, 3, 224, 224)) + + def test_wrong_ndim_raises(self): + encoder = _make_encoder(num_frames=2) + encoder.eval() + with torch.no_grad(), pytest.raises(ValueError, match="5D"): + encoder(torch.rand(2, 3, 224, 224)) + + def test_forward_after_external_dtype_cast(self): + """End-to-end forward must succeed when vision_tower was moved to a + different dtype BEFORE wrapping (mirrors the GPU load flow of + ``gemma3.to(cuda/bf16)`` then construct encoder). + """ + vision_tower, projector = _build_siglip_and_projector() + vision_tower = vision_tower.to(dtype=torch.bfloat16) + projector = copy.deepcopy(projector).to(dtype=torch.bfloat16) + + encoder = SpaceTimeSiglipVideoEncoder( + vision_tower=vision_tower, + multi_modal_projector=projector, + num_frames=4, + spacetime_layer_stride=4, + ).eval() + with torch.no_grad(): + out = encoder(torch.rand(1, 4, 3, 224, 224, dtype=torch.bfloat16)) + assert out.shape == (1, 256, 2560) + assert out.dtype == torch.bfloat16 + + +# Single-frame invariance + + +class TestSingleFrameInvariance: + def test_single_frame_invariance_structural(self): + """At T=1 the wrapper must short-circuit: byte-identical output to a + stride-999 (no-wrapping) encoder sharing the exact same underlying + weights. + """ + torch.manual_seed(0) + vision_tower, projector = _build_siglip_and_projector() + vt_no_st = copy.deepcopy(vision_tower) + proj_no_st = copy.deepcopy(projector) + + enc_no_st = SpaceTimeSiglipVideoEncoder( + vision_tower=vt_no_st, + multi_modal_projector=proj_no_st, + num_frames=1, + spacetime_layer_stride=999, # > 27 → no layers wrapped + ).eval() + enc_st = SpaceTimeSiglipVideoEncoder( + vision_tower=vision_tower, + multi_modal_projector=projector, + num_frames=1, + spacetime_layer_stride=4, # 6 layers wrapped + ).eval() + + # Adopt-submodules wrapping doesn't introduce ``.base_layer.`` keys, + # so the two state_dicts have identical keys; weights are still + # identical because we deep-copied. + assert set(vt_no_st.state_dict().keys()) == set(vision_tower.state_dict().keys()) + + video = torch.rand(1, 1, 3, 224, 224) + with torch.no_grad(): + out_no_st = enc_no_st(video) + out_st = enc_st(video) + + torch.testing.assert_close(out_st, out_no_st, rtol=1e-5, atol=1e-5) + + +# `suppress_spacetime_temporal` context manager (replaces the silent +# `bt % t != 0` numerology bypass). + + +class TestSuppressSpacetimeTemporal: + def test_flag_toggles_and_restores(self): + """Entering the context manager flips every wrapper to inactive; + exiting restores the prior value (idempotent — works nested).""" + encoder = _make_encoder(num_frames=4, spacetime_stride=4) + wrappers = [ + layer + for layer in encoder.vision_tower.vision_model.encoder.layers + if isinstance(layer, SpaceTimeEncoderLayerWrapper) + ] + assert wrappers, "expected the encoder to wrap at least one layer" + assert all(w._temporal_active for w in wrappers) + + with suppress_spacetime_temporal(encoder.vision_tower): + assert all(not w._temporal_active for w in wrappers) + with suppress_spacetime_temporal(encoder.vision_tower): + assert all(not w._temporal_active for w in wrappers) + # Inner context restored to outer (False), not to True. + assert all(not w._temporal_active for w in wrappers) + assert all(w._temporal_active for w in wrappers) + + def test_no_op_when_no_wrappers_present(self): + """When the module subtree contains no wrappers (e.g. the high-level + planner only ever calls ``embed_image`` on a vanilla SigLIP), the + context manager must be a silent no-op.""" + from transformers import SiglipVisionConfig, SiglipVisionModel + + plain = SiglipVisionModel( + SiglipVisionConfig( + hidden_size=64, + intermediate_size=128, + num_attention_heads=4, + num_hidden_layers=2, + patch_size=14, + image_size=224, + ) + ) + # Should not raise. + with suppress_spacetime_temporal(plain): + pass + + def test_non_divisible_batch_raises_when_active(self): + """When ``temporal_active=True`` (the default), the wrapper now + explicitly rejects ``bt % t != 0`` — replacing the previous silent + short-circuit which could have wrongly fired temporal attention over + spatial-only inputs that happened to be divisible by ``num_frames``. + """ + encoder = _make_encoder(num_frames=4, spacetime_stride=4) + wrapper = next( + layer + for layer in encoder.vision_tower.vision_model.encoder.layers + if isinstance(layer, SpaceTimeEncoderLayerWrapper) + ) + # 256 = (224/14)**2 patches per frame (the expected per-frame token count). + bad_input = torch.rand(3, 256, wrapper.embed_dim) + with pytest.raises(ValueError, match="divisible by num_frames"): + wrapper.forward(bad_input) + + def test_non_divisible_batch_succeeds_when_suppressed(self): + """Inside ``suppress_spacetime_temporal`` the wrapper accepts any + batch size and dispatches to the vanilla spatial block.""" + encoder = _make_encoder(num_frames=4, spacetime_stride=4) + wrapper = next( + layer + for layer in encoder.vision_tower.vision_model.encoder.layers + if isinstance(layer, SpaceTimeEncoderLayerWrapper) + ) + bad_input = torch.rand(3, 256, wrapper.embed_dim) + with suppress_spacetime_temporal(encoder.vision_tower): + out = wrapper.forward(bad_input)[0] + assert out.shape == bad_input.shape diff --git a/tests/test_available.py b/tests/test_available.py index b923fb45..73f8a6c3 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -20,6 +20,12 @@ from opentau.policies.pi0.modeling_pi0 import PI0Policy from opentau.policies.pi05.modeling_pi05 import PI05Policy from opentau.policies.pi05_mem.modeling_pi05 import PI05MemPolicy +from opentau.policies.pi07.high_level_planner.modeling_pi07_high_level import ( + PI07HighLevelPlannerPolicy, +) +from opentau.policies.pi07.low_level_planner.modeling_pi07_low_level import ( + PI07LowLevelPlannerPolicy, +) from opentau.policies.value.modeling_value import ValueFunction @@ -33,6 +39,8 @@ def test_available_policies(): ValueFunction, PI05Policy, PI05MemPolicy, + PI07HighLevelPlannerPolicy, + PI07LowLevelPlannerPolicy, ] policies = [pol_cls.name for pol_cls in policy_classes] assert set(policies) == set(opentau.available_policies), policies From 3c5ee905abdc1edb222d371e3252942995d4aeaf Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Wed, 29 Apr 2026 14:36:26 -0700 Subject: [PATCH 3/8] Adding control_type and robot_type to metadata in policy code --- .../modeling_pi07_high_level.py | 32 +++++++++++++++---- .../modeling_pi07_low_level.py | 32 +++++++++++++++---- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py index 61b22508..cf676eec 100644 --- a/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py +++ b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py @@ -675,13 +675,25 @@ def prepare_metadata(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: """ metadata = [] - for speed, quality, mistake, speed_is_pad, quality_is_pad, mistake_is_pad in zip( - batch["speed"], - batch["quality"], - batch["mistake"], - batch["speed_is_pad"], - batch["quality_is_pad"], - batch["mistake_is_pad"], + batch_size = batch["state"].shape[0] + for ( + speed, + quality, + mistake, + speed_is_pad, + quality_is_pad, + mistake_is_pad, + robot_type, + control_mode, + ) in zip( + batch.get("speed", torch.zeros(batch_size, dtype=torch.float32)), + batch.get("quality", torch.zeros(batch_size, dtype=torch.float32)), + batch.get("mistake", torch.zeros(batch_size, dtype=torch.float32)), + batch.get("speed_is_pad", torch.ones(batch_size, dtype=torch.bool)), + batch.get("quality_is_pad", torch.ones(batch_size, dtype=torch.bool)), + batch.get("mistake_is_pad", torch.ones(batch_size, dtype=torch.bool)), + batch.get("robot_type", [""] * batch_size), + batch.get("control_mode", [""] * batch_size), strict=True, ): segments = [] @@ -694,6 +706,12 @@ def prepare_metadata(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: if not mistake_is_pad: segments.append(f"Mistake: {str(mistake.item())}, ") + if robot_type: + segments.append(f"Robot: {robot_type}, ") + + if control_mode: + segments.append(f"Control: {control_mode}, ") + metadata.append(f"Metadata: {' '.join(segments)}") device = batch["state"].device diff --git a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py index 9aacd5e3..c2513151 100644 --- a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py @@ -966,14 +966,26 @@ def prepare_metadata(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: """ metadata = [] + batch_size = batch["state"].shape[0] # safety conditioning if metadata are not passed by sample actions - for speed, quality, mistake, speed_is_pad, quality_is_pad, mistake_is_pad in zip( - batch.get("speed", torch.zeros(batch["state"].shape[0], dtype=torch.float32)), - batch.get("quality", torch.zeros(batch["state"].shape[0], dtype=torch.float32)), - batch.get("mistake", torch.zeros(batch["state"].shape[0], dtype=torch.float32)), - batch.get("speed_is_pad", torch.zeros(batch["state"].shape[0], dtype=torch.bool)), - batch.get("quality_is_pad", torch.zeros(batch["state"].shape[0], dtype=torch.bool)), - batch.get("mistake_is_pad", torch.zeros(batch["state"].shape[0], dtype=torch.bool)), + for ( + speed, + quality, + mistake, + speed_is_pad, + quality_is_pad, + mistake_is_pad, + robot_type, + control_mode, + ) in zip( + batch.get("speed", torch.zeros(batch_size, dtype=torch.float32)), + batch.get("quality", torch.zeros(batch_size, dtype=torch.float32)), + batch.get("mistake", torch.zeros(batch_size, dtype=torch.float32)), + batch.get("speed_is_pad", torch.zeros(batch_size, dtype=torch.bool)), + batch.get("quality_is_pad", torch.zeros(batch_size, dtype=torch.bool)), + batch.get("mistake_is_pad", torch.zeros(batch_size, dtype=torch.bool)), + batch.get("robot_type", [""] * batch_size), + batch.get("control_mode", [""] * batch_size), strict=True, ): segments = [] @@ -986,6 +998,12 @@ def prepare_metadata(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: if not mistake_is_pad: segments.append(f"Mistake: {str(mistake.item())}, ") + if robot_type: + segments.append(f"Robot: {robot_type}, ") + + if control_mode: + segments.append(f"Control: {control_mode}, ") + metadata.append(f"Metadata: {' '.join(segments)}" if segments else "") device = batch["state"].device From 4485f55bcd0690815b70fa0b24a07a2e1952e504 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 23:03:59 +0000 Subject: [PATCH 4/8] [claude-fix] address review feedback on #197 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - addresses @claude[bot] (low-level is_pad defaults, low_level_planner modeling_pi07_low_level.py:984-986): switched the *_is_pad fallbacks to torch.ones to match the high-level planner — missing speed/quality/mistake no longer fabricate "Speed: 0.0" entries in the prompt. - addresses @claude[bot] (prepare_metadata docstrings): updated both planners' docstrings to reference Gemma 3 (not PaliGemma) and to enumerate robot_type / control_mode (string-valued, empty-string-as-pad). - addresses @claude[bot] (high-level all-empty guard, high_level_planner/modeling_pi07_high_level.py:715): mirrored the low-level "if segments else ''" guard so an all-padded sample emits "" instead of the literal "Metadata: ". Both planners now agree. - addresses @claude[bot] (CPU coverage for prepare_metadata): added TestPrepareMetadataSegments to tests/policies/test_pi07_cpu.py covering (a) robot+control populated, (b) both absent, (c) one populated and one empty, (d) per-sample all-empty emits "" (regression for both planners), and (e) low-level missing speed/quality/mistake never produces a fabricated "Speed: 0.0" segment. tests: passed — pytest -m "not gpu" tests/policies/test_pi07_cpu.py Co-Authored-By: Claude Opus 4.7 (1M context) --- .../modeling_pi07_high_level.py | 23 ++- .../modeling_pi07_low_level.py | 23 ++- tests/policies/test_pi07_cpu.py | 178 ++++++++++++++++++ 3 files changed, 212 insertions(+), 12 deletions(-) diff --git a/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py index cf676eec..2b8f372f 100644 --- a/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py +++ b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py @@ -668,10 +668,25 @@ def prepare_language(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: return lang_tokens, lang_masks def prepare_metadata(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: - """Tokenizes the metadata for training. + """Tokenize episode metadata into Gemma 3 token IDs. - Wraps each metadata string with an ```` suffix, then tokenizes and - pads to ``metadata_max_length``. + Wraps non-empty per-sample metadata segments into a single + ``"Metadata: {seg1} {seg2} ..."`` string, then pads/truncates to + ``metadata_max_length``. Samples with no active segments emit an + empty string. + + Args: + batch: Batch dict that may contain any of: + ``"speed"``, ``"quality"``, ``"mistake"`` (float tensors with + a corresponding ``_is_pad`` bool tensor — entries marked as + pad are dropped), and ``"robot_type"``, ``"control_mode"`` + (lists of strings — empty string is the pad signal, no + separate ``_is_pad`` flag). Missing keys are treated as + fully padded. + + Returns: + A tuple ``(metadata_tokens, metadata_masks)`` with shapes + ``(B, metadata_max_length)``. """ metadata = [] @@ -712,7 +727,7 @@ def prepare_metadata(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: if control_mode: segments.append(f"Control: {control_mode}, ") - metadata.append(f"Metadata: {' '.join(segments)}") + metadata.append(f"Metadata: {' '.join(segments)}" if segments else "") device = batch["state"].device tokenized_metadata = self.language_tokenizer.__call__( diff --git a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py index c2513151..127c9f61 100644 --- a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py @@ -951,14 +951,21 @@ def prepare_subgoal_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor] return subgoal_images, subgoal_img_masks def prepare_metadata(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: - """Tokenize episode metadata into PaliGemma token IDs. + """Tokenize episode metadata into Gemma 3 token IDs. - Wraps each metadata string as ``"Metadata: {meta}"`` and - pads/truncates to ``metadata_max_length``. + Wraps non-empty per-sample metadata segments into a single + ``"Metadata: {seg1} {seg2} ..."`` string, then pads/truncates to + ``metadata_max_length``. Samples with no active segments emit an + empty string. Args: - batch: Batch dict containing ``"speed"``, ``"quality"``, - ``"mistake"`` and their corresponding ``_is_pad`` flags. + batch: Batch dict that may contain any of: + ``"speed"``, ``"quality"``, ``"mistake"`` (float tensors with + a corresponding ``_is_pad`` bool tensor — entries marked as + pad are dropped), and ``"robot_type"``, ``"control_mode"`` + (lists of strings — empty string is the pad signal, no + separate ``_is_pad`` flag). Missing keys are treated as + fully padded. Returns: A tuple ``(metadata_tokens, metadata_masks)`` with shapes @@ -981,9 +988,9 @@ def prepare_metadata(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]: batch.get("speed", torch.zeros(batch_size, dtype=torch.float32)), batch.get("quality", torch.zeros(batch_size, dtype=torch.float32)), batch.get("mistake", torch.zeros(batch_size, dtype=torch.float32)), - batch.get("speed_is_pad", torch.zeros(batch_size, dtype=torch.bool)), - batch.get("quality_is_pad", torch.zeros(batch_size, dtype=torch.bool)), - batch.get("mistake_is_pad", torch.zeros(batch_size, dtype=torch.bool)), + batch.get("speed_is_pad", torch.ones(batch_size, dtype=torch.bool)), + batch.get("quality_is_pad", torch.ones(batch_size, dtype=torch.bool)), + batch.get("mistake_is_pad", torch.ones(batch_size, dtype=torch.bool)), batch.get("robot_type", [""] * batch_size), batch.get("control_mode", [""] * batch_size), strict=True, diff --git a/tests/policies/test_pi07_cpu.py b/tests/policies/test_pi07_cpu.py index f6179843..525c651f 100644 --- a/tests/policies/test_pi07_cpu.py +++ b/tests/policies/test_pi07_cpu.py @@ -362,3 +362,181 @@ def test_no_sqrt_emb_dim_left(self, module_path): "Gemma 3's embed_tokens already scales by sqrt(hidden_size). " "Reintroducing the manual scaling double-scales text embeddings." ) + + +# `prepare_metadata` segment-string construction +# +# The two planners share an identical contract for assembling the metadata +# prompt: each per-sample segment is appended only when the field is *not* +# padded, and a sample with zero active segments emits an empty string (not +# the literal "Metadata: "). These tests pin that contract on the CPU side +# without instantiating the full Gemma 3 backbone — we bind the unbound +# method to a SimpleNamespace and stub the tokenizer to capture the exact +# strings the loop produced. + + +def _make_tokenizer_capture(): + """Return a (tokenizer_stub, captured_list) pair. + + Use the captured list to assert on the metadata strings the planner + constructed, independent of the real Gemma 3 tokenizer. + """ + import types + + captured: list[str] = [] + + def stub_call(metadata, **kwargs): + captured.extend(metadata) + batch_size = len(metadata) + max_length = kwargs.get("max_length", 4) + return { + "input_ids": torch.zeros((batch_size, max_length), dtype=torch.long), + "attention_mask": torch.zeros((batch_size, max_length), dtype=torch.long), + } + + tokenizer = types.SimpleNamespace() + tokenizer.__call__ = stub_call + return tokenizer, captured + + +def _make_fake_planner(metadata_max_length: int = 4): + """Construct a minimal ``self`` stand-in for the planners' ``prepare_metadata``. + + The unbound method only reads ``self.language_tokenizer`` and + ``self.config.metadata_max_length``, so a plain SimpleNamespace is + sufficient — no Gemma 3 backbone, no HuggingFace download. + """ + import types + + tokenizer, captured = _make_tokenizer_capture() + fake = types.SimpleNamespace( + language_tokenizer=tokenizer, + config=types.SimpleNamespace(metadata_max_length=metadata_max_length), + ) + return fake, captured + + +class TestPrepareMetadataSegments: + """Verify the per-sample metadata string construction in both planners. + + Covers: + * ``robot_type`` / ``control_mode`` (PR-introduced) emit ``"Robot: ..."`` + and ``"Control: ..."`` segments only when non-empty. + * Missing batch keys default to "fully padded" → empty string. + * The high- and low-level planners agree on the all-empty case. + """ + + @staticmethod + def _planner_methods(): + from opentau.policies.pi07.high_level_planner.modeling_pi07_high_level import ( + PI07HighLevelPlannerPolicy, + ) + from opentau.policies.pi07.low_level_planner.modeling_pi07_low_level import ( + PI07LowLevelPlannerPolicy, + ) + + return { + "low": PI07LowLevelPlannerPolicy.prepare_metadata, + "high": PI07HighLevelPlannerPolicy.prepare_metadata, + } + + @pytest.mark.parametrize("planner", ["low", "high"]) + def test_robot_and_control_present(self, planner): + """Both robot_type and control_mode populated → both segments emitted.""" + method = self._planner_methods()[planner] + fake, captured = _make_fake_planner() + + batch_size = 2 + batch = { + "state": torch.zeros(batch_size, 1), + "robot_type": ["franka", "ur5"], + "control_mode": ["joint", "ee"], + } + + method(fake, batch) + + assert len(captured) == batch_size + for line, robot, ctrl in zip(captured, ["franka", "ur5"], ["joint", "ee"], strict=True): + assert line.startswith("Metadata: ") + assert f"Robot: {robot}, " in line + assert f"Control: {ctrl}, " in line + + @pytest.mark.parametrize("planner", ["low", "high"]) + def test_robot_and_control_absent_emits_empty_string(self, planner): + """No metadata keys present → all samples emit ``""`` (no fabricated values).""" + method = self._planner_methods()[planner] + fake, captured = _make_fake_planner() + + batch_size = 3 + batch = {"state": torch.zeros(batch_size, 1)} + + method(fake, batch) + + assert captured == ["", "", ""] + + @pytest.mark.parametrize("planner", ["low", "high"]) + def test_one_present_one_empty(self, planner): + """Mixed: only ``robot_type`` populated → only ``Robot:`` segment emitted.""" + method = self._planner_methods()[planner] + fake, captured = _make_fake_planner() + + batch_size = 2 + batch = { + "state": torch.zeros(batch_size, 1), + "robot_type": ["franka", "ur5"], + "control_mode": ["", ""], + } + + method(fake, batch) + + assert len(captured) == batch_size + for line, robot in zip(captured, ["franka", "ur5"], strict=True): + assert line.startswith("Metadata: ") + assert f"Robot: {robot}, " in line + assert "Control:" not in line + + @pytest.mark.parametrize("planner", ["low", "high"]) + def test_per_sample_empty_emits_empty_string(self, planner): + """Within a batch, samples whose every field is empty/padded emit ``""``, + even when sibling samples have populated fields. Pins the + ``if segments else ""`` guard at both planner sites so the two + planners agree on the all-empty contract. + """ + method = self._planner_methods()[planner] + fake, captured = _make_fake_planner() + + batch_size = 2 + batch = { + "state": torch.zeros(batch_size, 1), + "robot_type": ["franka", ""], + "control_mode": ["joint", ""], + } + + method(fake, batch) + + assert len(captured) == batch_size + assert captured[0].startswith("Metadata: ") + assert captured[1] == "" + + def test_low_level_missing_speed_pad_does_not_fabricate_value(self): + """Regression for the zeros→ones default fix: with ``speed`` and + ``speed_is_pad`` both absent from the batch, the low-level planner + must NOT emit ``"Speed: 0.0"`` — that would surface a value the + caller never provided. + """ + from opentau.policies.pi07.low_level_planner.modeling_pi07_low_level import ( + PI07LowLevelPlannerPolicy, + ) + + method = PI07LowLevelPlannerPolicy.prepare_metadata + fake, captured = _make_fake_planner() + + batch_size = 2 + batch = {"state": torch.zeros(batch_size, 1)} + + method(fake, batch) + + for line in captured: + assert "Speed:" not in line + assert "Quality:" not in line + assert "Mistake:" not in line From 2949e738bf3859b23e636cffa340c3877b16bd61 Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Fri, 1 May 2026 14:24:16 -0700 Subject: [PATCH 5/8] Applying pi07_paligemma fixes to pi07 --- src/opentau/datasets/lerobot_dataset.py | 44 +++-- .../modeling_pi07_low_level.py | 164 ++++++++++-------- 2 files changed, 126 insertions(+), 82 deletions(-) diff --git a/src/opentau/datasets/lerobot_dataset.py b/src/opentau/datasets/lerobot_dataset.py index 9b56bb67..65f38fa5 100644 --- a/src/opentau/datasets/lerobot_dataset.py +++ b/src/opentau/datasets/lerobot_dataset.py @@ -1766,15 +1766,23 @@ def _sample_subgoal_frame(self, ep_idx: int, frame_in_ep: int, *, at_end_of_segm current segment (clipped to the episode's last frame). Otherwise samples a timestamp uniformly in ``[t, t + 4s]`` (wall-clock) and converts it to a frame index, clipping to the current segment end and the episode end. + + Episodes that have no ``segments`` annotation in ``episodes.jsonl`` + skip segment-aware clipping entirely and fall back to a fixed + ~4-seconds-ahead subgoal frame (clipped to the episode end). This + keeps subgoal supervision available on legacy datasets that never + wrote per-episode segment boundaries. """ ep_length = self.episode_lengths[ep_idx] + window_frames = int(round(4.0 * self.fps)) + if "segments" not in self.meta.episodes[ep_idx]: + return min(frame_in_ep + window_frames, ep_length - 1) seg_idx = self._lookup_segment_index(ep_idx, frame_in_ep) seg_end_excl = self._segment_end_in_ep(ep_idx, seg_idx) upper = min(seg_end_excl, ep_length) - 1 # inclusive upper bound. upper = max(upper, frame_in_ep) if at_end_of_segment: return upper - window_frames = int(round(4.0 * self.fps)) top = min(frame_in_ep + window_frames, upper) if top <= frame_in_ep: return frame_in_ep @@ -1794,21 +1802,21 @@ def _load_subgoal_frames(self, ep_idx: int, frame_in_ep: int) -> dict[str, torch When the key IS present: - The at-end-of-segment vs uniform sampling roll happens ONCE per ``__getitem__`` call (shared across all camera slots); each slot - decodes the frame from its own video. + fetches the frame from its own source — video file for ``video`` + dtype features, parquet row for ``image`` dtype features (the + latter share the same within-episode frame index). - Drop-roll is short-circuited here so a dropped subgoal skips the - per-camera ``_query_videos`` decode. When + per-camera decode/lookup. When ``self.enable_optional_key_dropout`` is False (e.g. the validation subset), drop is never rolled — the frame-selection randomness stays live because it's about which future frame to read, not masking. - - Episodes with no ``segments`` entry in ``episodes.jsonl`` still - skip sampling (no segment boundaries → nothing to clip against). + - Episodes with no ``segments`` entry in ``episodes.jsonl`` fall + back to a fixed ~4 s lookahead inside ``_sample_subgoal_frame`` + rather than skipping subgoal loading, so legacy datasets without + segment annotations still get supervision. """ - if self.num_cams <= 0 or len(self.meta.video_keys) == 0: - return {} - if "subgoals" not in self.meta.info: - return {} - if "segments" not in self.meta.episodes[ep_idx]: + if self.num_cams <= 0 or len(self.meta.camera_keys) == 0: return {} # Roll drop before any video decoding — at `subgoal_drop_prob=0.75` the # old ordering threw away 75% of decodes. @@ -1818,13 +1826,21 @@ def _load_subgoal_frames(self, ep_idx: int, frame_in_ep: int) -> dict[str, torch at_end = bool(torch.rand(()) < self.subgoal_end_of_segment_prob) subgoal_frame = self._sample_subgoal_frame(ep_idx, frame_in_ep, at_end_of_segment=at_end) ts = subgoal_frame / self.fps + ep_start = int(self.episode_data_index["from"][self.epi2idx[ep_idx]].item()) out: dict[str, torch.Tensor] = {} for k in range(self.num_cams): - vid_key = name_map.get(f"camera{k}") - if vid_key is None or vid_key not in self.meta.video_keys: + cam_key = name_map.get(f"camera{k}") + if cam_key is None: continue - frames = self._query_videos({vid_key: np.array([ts])}, ep_idx) - out[f"subgoal{k}_raw"] = frames[vid_key] + if cam_key in self.meta.video_keys: + frames = self._query_videos({cam_key: np.array([ts])}, ep_idx) + out[f"subgoal{k}_raw"] = frames[cam_key] + elif cam_key in self.meta.image_keys: + # Image-dtype cameras are stored per-frame in the parquet + # rows of ``hf_dataset``. The within-episode index returned + # by ``_sample_subgoal_frame`` maps directly to the absolute + # row ``ep_start + subgoal_frame``. + out[f"subgoal{k}_raw"] = self.hf_dataset[ep_start + subgoal_frame][cam_key] return out def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict: diff --git a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py index 127c9f61..ffb59ef4 100644 --- a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py @@ -914,9 +914,12 @@ def prepare_subgoal_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor] # Per-sample flag: True means the subgoal was dropped or absent. subgoal_is_pad = batch.get( - "subgoal_is_pad", torch.zeros(batch["state"].shape[0], dtype=torch.bool) + "subgoal_is_pad", torch.ones(batch["state"].shape[0], dtype=torch.bool) ) # (B,) bool or None + last_subgoal_img: Tensor | None = None + last_mask: Tensor | None = None + for key in present_subgoal_img_keys: subgoal_img = batch[key] @@ -937,16 +940,16 @@ def prepare_subgoal_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor] subgoal_images.append(subgoal_img) subgoal_img_masks.append(mask) + last_subgoal_img = subgoal_img + last_mask = mask - # Create image features not present in the batch - # as fully 0 padded images. - for num_empty_cameras in range(len(missing_subgoal_img_keys)): - if num_empty_cameras >= self.config.empty_cameras: - break - subgoal_img = torch.ones_like(subgoal_img) * -1 - mask = torch.zeros_like(mask) - subgoal_images.append(subgoal_img) - subgoal_img_masks.append(mask) + # Create image features not present in the batch as fully 0 padded images. + if last_subgoal_img is not None and last_mask is not None: + for num_empty_cameras in range(len(missing_subgoal_img_keys)): + if num_empty_cameras >= self.config.empty_cameras: + break + subgoal_images.append(torch.ones_like(last_subgoal_img) * -1) + subgoal_img_masks.append(torch.zeros_like(last_mask)) return subgoal_images, subgoal_img_masks @@ -1255,72 +1258,90 @@ def embed_prefix( pad_masks.append(state_end_mask) att_masks += [0] * num_state_end_embs - response_emb = self.gemma3_with_expert.embed_language_tokens(response_tokens) - embs.append(response_emb) - pad_masks.append(response_masks) - num_response_embs = response_emb.shape[1] - att_masks += [1] + [0] * (num_response_embs - 1) + # Only embed response when at least one sample in the batch has real response tokens. + # With response_drop_prob=1.0 all masks are False; skipping avoids a spurious + # causal-block boundary (att_masks=[1,...]) that would corrupt cumsum for every + # subsequent token. + if response_tokens is not None and response_masks is not None and response_masks.any(): + response_emb = self.gemma3_with_expert.embed_language_tokens(response_tokens) + embs.append(response_emb) + pad_masks.append(response_masks) + num_response_embs = response_emb.shape[1] + att_masks += [1] + [0] * (num_response_embs - 1) + + # Only embed the subgoal block (header + images + footer) when subgoal images are + # actually present. Unconditionally adding "Subgoal: " injects real (non-padded) + # spurious tokens into every prefix even with subgoal_drop_prob=1.0. + if subgoal_images and any(mask.any() for mask in subgoal_img_masks): + # Per-sample availability: True iff at least one camera slot + # has a real subgoal image for that sample. In a mixed batch + # (some samples have subgoals, others don't), the "Subgoal: " + # header and trailing ", " footer must follow this same mask — + # otherwise pad-only samples would receive real (unmasked) + # indicator tokens with no image content behind them, which + # the prefix-LM block then attends to as if it were grounded. + sample_has_subgoal = torch.stack([m.to(dtype=torch.bool) for m in subgoal_img_masks], dim=0).any( + dim=0 + ) - subgoal_img_start_indicator_ids = self.language_tokenizer.encode( - "Subgoal: ", add_special_tokens=False - ) - subgoal_img_start_tokens = torch.tensor( - [subgoal_img_start_indicator_ids] * bsize, - device=lang_tokens.device, - dtype=torch.long, - ) - subgoal_img_start_emb = self.gemma3_with_expert.embed_language_tokens(subgoal_img_start_tokens) + subgoal_img_start_indicator_ids = self.language_tokenizer.encode( + "Subgoal: ", add_special_tokens=False + ) + subgoal_img_start_tokens = torch.tensor( + [subgoal_img_start_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + subgoal_img_start_emb = self.gemma3_with_expert.embed_language_tokens(subgoal_img_start_tokens) - num_subgoal_img_start_embs = subgoal_img_start_emb.shape[1] - subgoal_img_start_mask = torch.ones( - bsize, num_subgoal_img_start_embs, dtype=torch.bool, device=lang_tokens.device - ) + num_subgoal_img_start_embs = subgoal_img_start_emb.shape[1] + subgoal_img_start_mask = sample_has_subgoal[:, None].expand(bsize, num_subgoal_img_start_embs) - embs.append(subgoal_img_start_emb) - pad_masks.append(subgoal_img_start_mask) - att_masks += [1] + [0] * (num_subgoal_img_start_embs - 1) + embs.append(subgoal_img_start_emb) + pad_masks.append(subgoal_img_start_mask) + att_masks += [1] + [0] * (num_subgoal_img_start_embs - 1) - for ( - subgoal_img, - subgoal_img_mask, - ) in zip(subgoal_images, subgoal_img_masks, strict=True): - subgoal_img_emb = self.gemma3_with_expert.embed_image(subgoal_img) - subgoal_img_emb = subgoal_img_emb.to(dtype=_preferred_dtype()) + for ( + subgoal_img, + subgoal_img_mask, + ) in zip(subgoal_images, subgoal_img_masks, strict=True): + subgoal_img_emb = self.gemma3_with_expert.embed_image(subgoal_img) + subgoal_img_emb = subgoal_img_emb.to(dtype=_preferred_dtype()) - # Gemma 3's projector does not apply the `/ sqrt(text_hidden_size)` - # scaling that stock PaliGemma does, so no un-normalization is - # required here (matches `embed_image` in `gemma3_with_expert.py`). + # Gemma 3's projector does not apply the `/ sqrt(text_hidden_size)` + # scaling that stock PaliGemma does, so no un-normalization is + # required here (matches `embed_image` in `gemma3_with_expert.py`). - bsize, num_subgoal_img_embs = subgoal_img_emb.shape[:2] - subgoal_img_mask = subgoal_img_mask[:, None].expand(bsize, num_subgoal_img_embs) + bsize, num_subgoal_img_embs = subgoal_img_emb.shape[:2] + subgoal_img_mask = subgoal_img_mask[:, None].expand(bsize, num_subgoal_img_embs) - embs.append(subgoal_img_emb) - pad_masks.append(subgoal_img_mask) + embs.append(subgoal_img_emb) + pad_masks.append(subgoal_img_mask) - # Create attention masks so that image tokens attend to each other - att_masks += [1] + [0] * (num_subgoal_img_embs - 1) + # Create attention masks so that image tokens attend to each other + att_masks += [1] + [0] * (num_subgoal_img_embs - 1) - subgoal_img_end_indicator_ids = self.language_tokenizer.encode(", ", add_special_tokens=False) - subgoal_img_end_tokens = torch.tensor( - [subgoal_img_end_indicator_ids] * bsize, - device=lang_tokens.device, - dtype=torch.long, - ) - subgoal_img_end_emb = self.gemma3_with_expert.embed_language_tokens(subgoal_img_end_tokens) + subgoal_img_end_indicator_ids = self.language_tokenizer.encode(", ", add_special_tokens=False) + subgoal_img_end_tokens = torch.tensor( + [subgoal_img_end_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + subgoal_img_end_emb = self.gemma3_with_expert.embed_language_tokens(subgoal_img_end_tokens) - num_subgoal_img_end_embs = subgoal_img_end_emb.shape[1] - subgoal_img_end_mask = torch.ones( - bsize, num_subgoal_img_end_embs, dtype=torch.bool, device=lang_tokens.device - ) + num_subgoal_img_end_embs = subgoal_img_end_emb.shape[1] + subgoal_img_end_mask = sample_has_subgoal[:, None].expand(bsize, num_subgoal_img_end_embs) - embs.append(subgoal_img_end_emb) - pad_masks.append(subgoal_img_end_mask) - att_masks += [0] * num_subgoal_img_end_embs + embs.append(subgoal_img_end_emb) + pad_masks.append(subgoal_img_end_mask) + att_masks += [0] * num_subgoal_img_end_embs - metadata_emb = self.gemma3_with_expert.embed_language_tokens(metadata_tokens) - embs.append(metadata_emb) - pad_masks.append(metadata_masks) - att_masks += [1] + [0] * (metadata_emb.shape[1] - 1) + # Only embed metadata when at least one sample has real metadata tokens. + if metadata_tokens is not None and metadata_masks is not None and metadata_masks.any(): + metadata_emb = self.gemma3_with_expert.embed_language_tokens(metadata_tokens) + embs.append(metadata_emb) + pad_masks.append(metadata_masks) + att_masks += [1] + [0] * (metadata_emb.shape[1] - 1) prefix_end_indicator_ids = self.language_tokenizer.encode(";\n ", add_special_tokens=False) prefix_end_tokens = torch.tensor( @@ -1503,7 +1524,13 @@ def forward( vlm_2d_attention_mask = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) vlm_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - num_cross_att_tokens = prefix_embs.shape[1] - self.config.discrete_action_max_length + # Exclude both the "Action: " indicator tokens and the discrete action tokens from + # cross-attention so the action expert sees the same prefix length at train and inference + # (at inference, neither is in the prefix). Matches pi05's discrete_action_indicator_max_length logic. + _action_indicator_len = len(self.language_tokenizer.encode("Action: ", add_special_tokens=False)) + num_cross_att_tokens = ( + prefix_embs.shape[1] - _action_indicator_len - self.config.discrete_action_max_length + ) (prefix_out, _), past_key_values = self.gemma3_with_expert.forward( attention_mask=vlm_2d_attention_mask, @@ -1541,9 +1568,10 @@ def forward( n_cross_att_tokens=num_cross_att_tokens, cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens], ) - prefix_offsets = torch.sum(prefix_pad_masks[:, : -self.config.discrete_action_max_length], dim=-1)[ - :, None - ] + prefix_offsets = torch.sum( + prefix_pad_masks[:, : -(_action_indicator_len + self.config.discrete_action_max_length)], + dim=-1, + )[:, None] action_expert_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 assert past_key_values is not None From 2921cabeb3d9f6c2f7bec78eae1187533b6de186 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Sat, 2 May 2026 03:58:32 +0000 Subject: [PATCH 6/8] [claude-fix] address review feedback on #197 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - addresses @claude[bot] (subgoal opt-in docstring): updated _load_subgoal_frames docstring to match the always-on behaviour; replaced the stale test_missing_subgoals_key_in_info_returns_empty test (which encoded the old opt-in gate) with two pinned cases — no-cameras → {}, and "no info.subgoals key still loads subgoals". Added camera_keys / image_keys / episode_data_index attrs to the SimpleNamespace meta in the two existing video tests so they match the new attribute reads. - addresses @claude[bot] (image-dtype fallback row index): new test_image_dtype_fallback_uses_absolute_row_index in tests/datasets/ test_optional_keys.py stubs hf_dataset.__getitem__ and pins that the parquet-row lookup uses ep_start + subgoal_frame, never the within-episode index. - addresses @claude[bot] ("fully 0 padded" comment): rewrote the comment at modeling_pi07_low_level.py:946 to call out -1 ([-1, 1] SigLIP range) and the False-mask role of the placeholder. - addresses @claude[bot] (_action_indicator_len recompute): cached the Action-indicator length at PI07LowLevelPlannerFlowMatching.__init__; both forward sites now read self._action_indicator_len. - addresses @claude[bot] (embed_prefix CPU coverage): added TestEmbedPrefixConditionalGuards in tests/policies/test_pi07_cpu.py with a fake Gemma3WithExpert + tokenizer + state_proj + embed_video so the three guards (response_masks.any(), subgoal availability, metadata_masks.any()) are exercised without GPU. Cases: all-False optional masks → no spurious causal boundary; mixed-availability subgoal batch → header/footer pad mask zeroes the pad-only sample; response.any() → exactly one boundary. - addresses @claude[bot] (.base_layer. stale comment in video_encoder.py): rewrote the wrap comment at video_encoder.py:422 to explain that the wrapper adopts submodules by reference, so wrapped-layer state-dict keys are byte-for-byte identical to a vanilla SiglipEncoderLayer (no .base_layer. prefix). tests: passed — pytest tests/policies tests/datasets -m "not gpu" -n auto (452 passed, 12 skipped; the 1 collection error in tests/policies/test_pi07_paligemma_low_level_planner.py is pre-existing and unrelated — pi07_paligemma still imports VJEPA2VideoEncoder from pi05_mem, which was removed in #171). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/opentau/datasets/lerobot_dataset.py | 17 +- .../modeling_pi07_low_level.py | 16 +- .../pi07/low_level_planner/video_encoder.py | 9 +- tests/datasets/test_optional_keys.py | 159 +++++++++++++- tests/policies/test_pi07_cpu.py | 204 ++++++++++++++++++ 5 files changed, 381 insertions(+), 24 deletions(-) diff --git a/src/opentau/datasets/lerobot_dataset.py b/src/opentau/datasets/lerobot_dataset.py index 65f38fa5..ed7fdbba 100644 --- a/src/opentau/datasets/lerobot_dataset.py +++ b/src/opentau/datasets/lerobot_dataset.py @@ -1792,14 +1792,15 @@ def _sample_subgoal_frame(self, ep_idx: int, frame_in_ep: int, *, at_end_of_segm def _load_subgoal_frames(self, ep_idx: int, frame_in_ep: int) -> dict[str, torch.Tensor]: """Decode subgoal frames — one per camera slot — for this sample. - Subgoal image paths must be declared in ``meta/info.json`` under the - ``subgoals`` key. When the key is missing (the state of every - LeRobot dataset today), we assume no subgoal images exist and return - ``{}``; :meth:`BaseDataset._emit_optional_keys` then emits - ``subgoal_is_pad=True`` for every slot. Datasets opt in by adding - the key to info.json. - - When the key IS present: + Subgoal supervision is always-on for any dataset that exposes camera + keys; the dedicated ``subgoals`` info.json declaration that the older + pi07_paligemma path required is no longer consulted. Datasets without + any cameras (``self.num_cams == 0`` or empty + ``self.meta.camera_keys``) still return ``{}``, which lets + :meth:`BaseDataset._emit_optional_keys` emit ``subgoal_is_pad=True`` + for every slot. + + Behavior for camera-bearing datasets: - The at-end-of-segment vs uniform sampling roll happens ONCE per ``__getitem__`` call (shared across all camera slots); each slot fetches the frame from its own source — video file for ``video`` diff --git a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py index ffb59ef4..01dfeb41 100644 --- a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py @@ -943,7 +943,9 @@ def prepare_subgoal_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor] last_subgoal_img = subgoal_img last_mask = mask - # Create image features not present in the batch as fully 0 padded images. + # Materialize missing-camera placeholders. Filled with -1 (not 0) so + # they sit at the "black" end of SigLIP's [-1, 1] input range; the + # accompanying False mask still flags these slots as padded. if last_subgoal_img is not None and last_mask is not None: for num_empty_cameras in range(len(missing_subgoal_img_keys)): if num_empty_cameras >= self.config.empty_cameras: @@ -1102,6 +1104,13 @@ def __init__(self, config: PI07LowLevelPlannerConfig, discrete_action_vocab_size self.language_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt") + # Length of the "Action: " indicator tokens — constant for a given + # tokenizer, so cache once instead of re-tokenizing on every forward + # call. Used together with discrete_action_max_length to compute the + # cross-attention prefix slice in forward()/sample_actions(), so the + # action expert sees the same prefix length at train and inference. + self._action_indicator_len = len(self.language_tokenizer.encode("Action: ", add_special_tokens=False)) + def sample_noise(self, shape: tuple[int, ...], device: torch.device | str) -> Tensor: return torch.normal(mean=0.0, std=1.0, size=shape, dtype=torch.float32, device=device) @@ -1527,9 +1536,8 @@ def forward( # Exclude both the "Action: " indicator tokens and the discrete action tokens from # cross-attention so the action expert sees the same prefix length at train and inference # (at inference, neither is in the prefix). Matches pi05's discrete_action_indicator_max_length logic. - _action_indicator_len = len(self.language_tokenizer.encode("Action: ", add_special_tokens=False)) num_cross_att_tokens = ( - prefix_embs.shape[1] - _action_indicator_len - self.config.discrete_action_max_length + prefix_embs.shape[1] - self._action_indicator_len - self.config.discrete_action_max_length ) (prefix_out, _), past_key_values = self.gemma3_with_expert.forward( @@ -1569,7 +1577,7 @@ def forward( cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens], ) prefix_offsets = torch.sum( - prefix_pad_masks[:, : -(_action_indicator_len + self.config.discrete_action_max_length)], + prefix_pad_masks[:, : -(self._action_indicator_len + self.config.discrete_action_max_length)], dim=-1, )[:, None] action_expert_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 diff --git a/src/opentau/policies/pi07/low_level_planner/video_encoder.py b/src/opentau/policies/pi07/low_level_planner/video_encoder.py index 8a6abd50..e2601e8f 100644 --- a/src/opentau/policies/pi07/low_level_planner/video_encoder.py +++ b/src/opentau/policies/pi07/low_level_planner/video_encoder.py @@ -420,10 +420,11 @@ def __init__( self.siglip_hidden_size = vision_cfg.hidden_size # Wrap every stride-th layer with space-time attention. The wrapper - # holds the original SiglipEncoderLayer as ``base_layer`` so its - # pretrained weights flow through unchanged. State-dict keys for - # wrapped layers will carry a ``.base_layer.`` prefix; as long as - # reloads round-trip through this code, keys stay consistent. + # adopts the base layer's submodules (self_attn / layer_norm{1,2} / + # mlp) by reference, so wrapped-layer state-dict keys are byte-for-byte + # identical to a vanilla ``SiglipEncoderLayer`` — no ``.base_layer.`` + # prefix appears. Pinned by ``test_pi07_video_encoder_cpu.py:: + # test_state_dict_keys_unchanged_after_wrapping``. layers = vision_tower.vision_model.encoder.layers n_layers = len(layers) for i in range(spacetime_layer_stride - 1, n_layers, spacetime_layer_stride): diff --git a/tests/datasets/test_optional_keys.py b/tests/datasets/test_optional_keys.py index 9d1d3ccc..b0a8708e 100644 --- a/tests/datasets/test_optional_keys.py +++ b/tests/datasets/test_optional_keys.py @@ -320,11 +320,18 @@ def test_subgoal_end_of_segment_roll_stays_active(self, monkeypatch): ds.resolution = (8, 8) ds.episode_lengths = {0: 100} ds.segment_starts_by_episode = {0: _np.array([0])} + ds.episode_data_index = { + "from": torch.tensor([0], dtype=torch.long), + "to": torch.tensor([100], dtype=torch.long), + } + ds.epi2idx = {0: 0} ds.meta = SimpleNamespace( video_keys=["camera0"], + image_keys=[], + camera_keys=["camera0"], episodes={0: {"segments": [0]}}, fps=30, - info={"subgoals": {"subgoal0": "camera0"}}, # opt in to subgoal loading + info={}, ) monkeypatch.setattr(type(ds), "_get_feature_mapping_key", lambda self: mapping_key) @@ -369,9 +376,11 @@ def test_subgoal_drop_skips_video_decode_in_train_mode(self, monkeypatch): ds.segment_starts_by_episode = {0: _np.array([0])} ds.meta = SimpleNamespace( video_keys=["camera0"], + image_keys=[], + camera_keys=["camera0"], episodes={0: {"segments": [0]}}, fps=30, - info={"subgoals": {"subgoal0": "camera0"}}, # opt in so the drop path is exercised + info={}, ) monkeypatch.setattr(type(ds), "_get_feature_mapping_key", lambda self: mapping_key) @@ -383,40 +392,174 @@ def _fake_query_videos(self, query_ts, ep_idx): out = ds._load_subgoal_frames(0, 0) assert out == {}, "drop should return an empty dict and skip decoding" - def test_missing_subgoals_key_in_info_returns_empty(self, monkeypatch): - """Default state: info.json has no ``subgoals`` key, so no subgoals load.""" + def test_no_cameras_returns_empty(self, monkeypatch): + """Datasets with no camera keys still short-circuit to ``{}`` so + :meth:`BaseDataset._emit_optional_keys` can mark every subgoal slot + padded. + """ from types import SimpleNamespace import numpy as _np import opentau.datasets.lerobot_dataset as _ld - mapping_key = "_tests/subgoal_missing_info_key" + mapping_key = "_tests/subgoal_no_cameras" _ld.DATA_FEATURES_NAME_MAPPING[mapping_key] = {"camera0": "camera0"} ds = _ld.LeRobotDataset.__new__(_ld.LeRobotDataset) ds.enable_optional_key_dropout = False ds.subgoal_end_of_segment_prob = 1.0 ds.subgoal_drop_prob = 0.0 + ds.num_cams = 0 + ds.resolution = (8, 8) + ds.episode_lengths = {0: 100} + ds.segment_starts_by_episode = {0: _np.array([0])} + ds.meta = SimpleNamespace( + video_keys=[], + image_keys=[], + camera_keys=[], + episodes={0: {"segments": [0]}}, + fps=30, + info={}, + ) + monkeypatch.setattr(type(ds), "_get_feature_mapping_key", lambda self: mapping_key) + + def _fail_query_videos(self, query_ts, ep_idx): + raise AssertionError("_query_videos must not be called when no cameras are present") + + monkeypatch.setattr(type(ds), "_query_videos", _fail_query_videos) + + assert ds._load_subgoal_frames(0, 0) == {} + + def test_subgoals_load_without_info_subgoals_key(self, monkeypatch): + """Always-on behavior: info.json with no ``subgoals`` key still + triggers subgoal loading from the camera streams. + + Pins the deliberate removal of the prior opt-in gate so a future + refactor doesn't silently restore it. + """ + from types import SimpleNamespace + + import numpy as _np + + import opentau.datasets.lerobot_dataset as _ld + + mapping_key = "_tests/subgoal_no_info_key" + _ld.DATA_FEATURES_NAME_MAPPING[mapping_key] = {"camera0": "camera0"} + + ds = _ld.LeRobotDataset.__new__(_ld.LeRobotDataset) + ds.enable_optional_key_dropout = False + ds.subgoal_end_of_segment_prob = 0.0 + ds.subgoal_drop_prob = 0.0 ds.num_cams = 1 ds.resolution = (8, 8) ds.episode_lengths = {0: 100} ds.segment_starts_by_episode = {0: _np.array([0])} - # No "subgoals" key in info — mirrors every production LeRobot dataset today. + ds.episode_data_index = { + "from": torch.tensor([0], dtype=torch.long), + "to": torch.tensor([100], dtype=torch.long), + } + ds.epi2idx = {0: 0} + # info.json deliberately omits "subgoals" — mirrors every legacy + # LeRobot dataset that pre-dates this PR. ds.meta = SimpleNamespace( video_keys=["camera0"], + image_keys=[], + camera_keys=["camera0"], episodes={0: {"segments": [0]}}, fps=30, info={}, ) monkeypatch.setattr(type(ds), "_get_feature_mapping_key", lambda self: mapping_key) + called: list = [] + + def _fake_query_videos(self, query_ts, ep_idx): + called.append(dict(query_ts)) + return {"camera0": torch.zeros((3, *self.resolution))} + + monkeypatch.setattr(type(ds), "_query_videos", _fake_query_videos) + + out = ds._load_subgoal_frames(0, 0) + assert "subgoal0_raw" in out + assert len(called) == 1, "video decode fired exactly once for the single camera" + + def test_image_dtype_fallback_uses_absolute_row_index(self, monkeypatch): + """``image``-dtype cameras read from the parquet rows of + ``hf_dataset`` instead of decoding a video. The lookup index must be + ``ep_start + subgoal_frame`` — the absolute row in the table — not + the within-episode index. + """ + from types import SimpleNamespace + + import numpy as _np + + import opentau.datasets.lerobot_dataset as _ld + + mapping_key = "_tests/subgoal_image_dtype" + _ld.DATA_FEATURES_NAME_MAPPING[mapping_key] = {"camera0": "camera0"} + + ds = _ld.LeRobotDataset.__new__(_ld.LeRobotDataset) + ds.enable_optional_key_dropout = False + ds.subgoal_end_of_segment_prob = 0.0 + ds.subgoal_drop_prob = 0.0 + ds.num_cams = 1 + ds.resolution = (8, 8) + # Episode 1 starts at absolute row 100, length 50. + ep_start_abs = 100 + ep_len = 50 + ds.episode_lengths = {1: ep_len} + ds.segment_starts_by_episode = {1: _np.array([0])} + ds.episode_data_index = { + "from": torch.tensor([0, ep_start_abs], dtype=torch.long), + "to": torch.tensor([ep_start_abs, ep_start_abs + ep_len], dtype=torch.long), + } + ds.epi2idx = {0: 0, 1: 1} + ds.meta = SimpleNamespace( + video_keys=[], + image_keys=["camera0"], + camera_keys=["camera0"], + episodes={1: {"segments": [0]}}, + fps=10, + info={}, + ) + monkeypatch.setattr(type(ds), "_get_feature_mapping_key", lambda self: mapping_key) + + # Pin the within-episode subgoal index so the test asserts a known + # absolute row. + within_ep_subgoal = 7 + + def _fake_sample(self, ep_idx, frame_in_ep, *, at_end_of_segment): + assert ep_idx == 1 and frame_in_ep == 0 + return within_ep_subgoal + + monkeypatch.setattr(type(ds), "_sample_subgoal_frame", _fake_sample) + + # Stub hf_dataset to capture the row index requested and return a + # sentinel tensor identifying that row. + sentinel_payload = torch.full((3, 8, 8), 0.42, dtype=torch.float32) + hf_calls: list[int] = [] + + class _HFStub: + def __getitem__(self, idx): + hf_calls.append(idx) + return {"camera0": sentinel_payload} + + ds.hf_dataset = _HFStub() + def _fail_query_videos(self, query_ts, ep_idx): - raise AssertionError("_query_videos must not be called when info has no 'subgoals'") + raise AssertionError("_query_videos must not be called for image-dtype cameras") monkeypatch.setattr(type(ds), "_query_videos", _fail_query_videos) - assert ds._load_subgoal_frames(0, 0) == {} + out = ds._load_subgoal_frames(1, 0) + + assert hf_calls == [ep_start_abs + within_ep_subgoal], ( + f"image-dtype subgoal lookup used the wrong row: got {hf_calls}, " + f"expected [{ep_start_abs + within_ep_subgoal}]" + ) + assert "subgoal0_raw" in out + assert torch.equal(out["subgoal0_raw"], sentinel_payload) # Default collate tolerates a batch with mixed _is_pad flags. diff --git a/tests/policies/test_pi07_cpu.py b/tests/policies/test_pi07_cpu.py index 525c651f..85dea0eb 100644 --- a/tests/policies/test_pi07_cpu.py +++ b/tests/policies/test_pi07_cpu.py @@ -540,3 +540,207 @@ def test_low_level_missing_speed_pad_does_not_fabricate_value(self): assert "Speed:" not in line assert "Quality:" not in line assert "Mistake:" not in line + + +# `embed_prefix` conditional-block guards +# +# The low-level planner skips entire prefix blocks when their availability +# masks are all-False (response_masks, subgoal_img_masks, metadata_masks). +# Pinning that on the CPU side means a regression that re-introduces a +# spurious causal boundary (att_masks=[1,...]) on a fully-padded slot is +# caught without firing up the Gemma 3 backbone. + + +def _make_fake_flow_matching(*, hidden: int = 4, n_video_tokens: int = 3): + """Construct a minimal stand-in for ``PI07LowLevelPlannerFlowMatching`` + so ``embed_prefix`` runs end-to-end without instantiating the real + Gemma 3 backbone. + + All language tokens, image tokens, and video tokens project to deterministic + zero tensors with the correct shape — the test only asserts on the + structure of ``embs`` / ``pad_masks`` / ``att_masks`` (lengths, sample-mask + rows), not on numeric values. + """ + import types + + class _FakeGemma3WithExpert: + def embed_language_tokens(self, tokens): + return torch.zeros((*tokens.shape, hidden), dtype=torch.float32) + + def embed_image(self, image): + # Each image becomes 2 tokens — small enough to keep prefix lengths readable. + return torch.zeros(image.shape[0], 2, hidden, dtype=torch.float32) + + def embed_discrete_actions(self, da): + return torch.zeros((*da.shape, hidden), dtype=torch.float32) + + class _FakeTokenizer: + # Every indicator phrase encodes to exactly 2 tokens. embed_prefix + # uses tokenizer.encode() to get the literal token ids, then later + # passes them back through embed_language_tokens for the embedding. + def encode(self, text, add_special_tokens=False): + return [1, 2] + + def _state_proj(state): + # state: (B, T, D) → (B, T, hidden) + return torch.zeros(state.shape[0], state.shape[1], hidden, dtype=torch.float32) + + def _embed_video(video): + # video: (B, T, C, H, W) → (B, n_video_tokens, hidden) + return torch.zeros(video.shape[0], n_video_tokens, hidden, dtype=torch.float32) + + fake = types.SimpleNamespace( + gemma3_with_expert=_FakeGemma3WithExpert(), + language_tokenizer=_FakeTokenizer(), + state_proj=_state_proj, + embed_video=_embed_video, + config=types.SimpleNamespace(discrete_action_max_length=2), + ) + return fake + + +def _embed_prefix_method(): + from opentau.policies.pi07.low_level_planner.modeling_pi07_low_level import ( + PI07LowLevelPlannerFlowMatching, + ) + + return PI07LowLevelPlannerFlowMatching.embed_prefix + + +def _build_default_inputs(*, batch_size: int = 2, prompt_len: int = 3, t_state: int = 1): + """Build the minimal kwargs every embed_prefix invocation needs.""" + return { + "videos": [torch.zeros(batch_size, t_state, 3, 4, 4)], + "vid_masks": [torch.ones(batch_size, dtype=torch.bool)], + "lang_tokens": torch.zeros(batch_size, prompt_len, dtype=torch.long), + "lang_masks": torch.ones(batch_size, prompt_len, dtype=torch.bool), + "state": torch.zeros(batch_size, t_state, 7), + "response_tokens": None, + "response_masks": None, + "metadata_tokens": None, + "metadata_masks": None, + } + + +class TestEmbedPrefixConditionalGuards: + """Pin per-block toggles in ``embed_prefix``. + + The three conditional emit-or-skip toggles (response, subgoal, metadata) + are normally exercised only by the GPU integration tests. These CPU tests + fake the Gemma 3 backbone so a regression that emits a spurious header + with all-False masks (or fails to emit a real header on a mixed batch) + is caught directly. + """ + + def test_all_optional_blocks_absent_skips_emission(self): + """All-False response_masks + no subgoal_images + all-False metadata_masks → + the prefix collapses to ``videos + lang + State: + state + ", " + ";\\n "``. + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 2 + kwargs = _build_default_inputs(batch_size=bsize) + + # Empty response/metadata still passed in: tokens with all-False masks. + kwargs["response_tokens"] = torch.zeros(bsize, 5, dtype=torch.long) + kwargs["response_masks"] = torch.zeros(bsize, 5, dtype=torch.bool) + kwargs["metadata_tokens"] = torch.zeros(bsize, 4, dtype=torch.long) + kwargs["metadata_masks"] = torch.zeros(bsize, 4, dtype=torch.bool) + # No subgoals at all. + kwargs["subgoal_images"] = [] + kwargs["subgoal_img_masks"] = [] + + embs, pad_masks, att_masks = method(fake, **kwargs) + + # Expected layout (each video produces 3 fake video tokens, each + # indicator encodes to 2 tokens, hidden=4): + # videos: 3 + # lang: 3 (prompt_len) + # "State: ": 2 + # state(T=1): 1 + # ", ": 2 + # ";\n ": 2 + # Total = 13 + assert embs.shape == (bsize, 13, 4) + assert pad_masks.shape == (bsize, 13) + assert att_masks.shape == (bsize, 13) + # No causal boundaries (1's) anywhere in the att_masks — every block + # should remain bidirectional with all-skip optional blocks. + assert int(att_masks[0].sum().item()) == 0, ( + "att_masks contained a causal boundary — a guarded block was emitted " + "anyway: this would corrupt the cumsum for every subsequent token." + ) + + def test_mixed_subgoal_availability_zeros_pad_only_samples(self): + """In a mixed batch (some samples have subgoals, others don't), the + ``Subgoal: `` header and trailing comma must follow the per-sample + availability mask — pad-only samples get False on those header tokens. + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 2 + kwargs = _build_default_inputs(batch_size=bsize) + + # Sample 0 has a subgoal, sample 1 does not. + per_sample_mask = torch.tensor([True, False]) + subgoal_img = torch.zeros(bsize, 3, 4, 4) + kwargs["subgoal_images"] = [subgoal_img] + kwargs["subgoal_img_masks"] = [per_sample_mask] + kwargs["metadata_tokens"] = torch.zeros(bsize, 4, dtype=torch.long) + kwargs["metadata_masks"] = torch.zeros(bsize, 4, dtype=torch.bool) + + _, pad_masks, _ = method(fake, **kwargs) + + # The Subgoal: header is 2 tokens (fake tokenizer); the image is 2 tokens + # (fake embed_image); the trailing ", " is 2 tokens. The mask values for + # those 6 tokens must mirror per_sample_mask on every emitted slot. + # Easiest assertion: at least one row has pad=False everywhere in the + # subgoal block (sample 1), and at least one row has pad=True (sample 0). + # We don't pin exact slice boundaries here — instead we check that the + # union of per_sample_mask=False rows are all-zero on the subgoal slice. + # Total length sanity: + # videos(3) + lang(3) + "State: "(2) + state(1) + ", "(2) + # + "Subgoal: "(2) + subgoal_img(2) + ", "(2) + # + ";\n "(2) = 19 + assert pad_masks.shape == (bsize, 19) + # The subgoal block is the 6 tokens at indices 11..16 (after the + # prefix-LM transition); the sample with no subgoal must be False + # across all 6, and the sample with a subgoal must be True across all 6. + subgoal_block = pad_masks[:, 11:17] + assert subgoal_block[0].all().item() is True, ( + "sample 0 (has subgoal) should have all-True pad mask on the subgoal block" + ) + assert (~subgoal_block[1]).all().item() is True, ( + "sample 1 (no subgoal) should have all-False pad mask on the subgoal block — " + "otherwise the prefix-LM block opening on a pad-only sample injects a " + "spurious unmasked indicator token" + ) + + def test_response_mask_any_true_emits_block(self): + """When at least one sample has a real response, the response block + IS emitted and contributes a causal boundary token. + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 2 + kwargs = _build_default_inputs(batch_size=bsize) + kwargs["response_tokens"] = torch.zeros(bsize, 5, dtype=torch.long) + # Sample 0 has real response; sample 1 is all padded. + kwargs["response_masks"] = torch.tensor( + [[True, True, True, False, False], [False, False, False, False, False]] + ) + kwargs["subgoal_images"] = [] + kwargs["subgoal_img_masks"] = [] + kwargs["metadata_tokens"] = torch.zeros(bsize, 4, dtype=torch.long) + kwargs["metadata_masks"] = torch.zeros(bsize, 4, dtype=torch.bool) + + _, _, att_masks = method(fake, **kwargs) + + # Without the response block, att_masks would be all-zeros (see + # test_all_optional_blocks_absent_skips_emission). With the response + # block, att_masks gets exactly one ``1`` (the prefix-LM block opening + # at the start of the 5-token response span). + per_sample_sum = int(att_masks[0].sum().item()) + assert per_sample_sum == 1, ( + f"expected exactly one causal boundary from the response block opening, got {per_sample_sum}" + ) From 62a5677b4005b87092bab72565164a7fabf96f73 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Fri, 1 May 2026 23:38:58 -0700 Subject: [PATCH 7/8] fix(pi07): gate optional prefix tokens in low- and high-level planners (#229) --- .github/workflows/claude-implement-fixes.yml | 22 +++++- .github/workflows/claude-pr-review.yml | 2 +- .github/workflows/cpu_test.yml | 3 +- .github/workflows/extract-claude-lessons.yml | 7 +- .github/workflows/gpu_test.yml | 3 +- .../modeling_pi07_high_level.py | 40 +++++++---- .../modeling_pi07_low_level.py | 70 +++++++++++++----- tests/policies/test_pi06.py | 20 +++++- tests/policies/test_pi07_cpu.py | 72 +++++++++++++++++-- tests/policies/test_pi07_low_level_planner.py | 17 ++++- 10 files changed, 209 insertions(+), 47 deletions(-) diff --git a/.github/workflows/claude-implement-fixes.yml b/.github/workflows/claude-implement-fixes.yml index f88d9231..e3e97bc9 100644 --- a/.github/workflows/claude-implement-fixes.yml +++ b/.github/workflows/claude-implement-fixes.yml @@ -100,6 +100,7 @@ jobs: with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} claude_args: | + --model ${{ vars.CLAUDE_MODEL || 'claude-opus-4-7[1m]' }} --permission-mode bypassPermissions prompt: | A reviewer asked you to address review feedback on this PR. @@ -111,16 +112,33 @@ jobs: 2. For each actionable comment containing `@claude fix`, implement the fix on the PR's branch. 3. Skip comments that are questions, taste preferences, or already addressed. - 4. Run the test command from CLAUDE.md before pushing. + 4. Decide whether to run tests: + - If the diff is purely documentation, comments, formatting, + string-literal text, or otherwise CANNOT change runtime behavior, + you MAY skip tests. Be honest about confidence — when in doubt, run. + If you skip, the commit body MUST contain a line of the form: + tests: skipped — + - Otherwise (changes touching imports, function bodies, control + flow, types, configs read at runtime, dependencies, or build + manifests): run `pytest -m "not gpu" -n auto`. Scope to the + changed subtree where possible (e.g. `pytest tests/policies/test_pi05.py` + for a pi05 change) to keep the run fast. The commit body MUST contain: + tests: passed — + If tests fail: do NOT push; reply on the relevant PR comment + explaining the failure and stop. 5. Make ONE commit at the end of the session that addresses every comment you decided to act on — do NOT push one commit per comment. Subject line (must be < 80 chars per CLAUDE.md): [claude-fix] address review feedback on #${{ github.event.issue.number || github.event.pull_request.number }} - Commit body: a bulleted list, one bullet per addressed comment: + Commit body: a bulleted list, one bullet per addressed comment, + followed by the `tests:` line from step 4: - addresses @ (): + ... + tests: passed — pytest -m "not gpu" tests/policies/test_pi05.py + (or: tests: skipped — comment-only change, no runtime impact) Then push the single commit to the PR branch. 6. Reply individually to each addressed comment on the PR with diff --git a/.github/workflows/claude-pr-review.yml b/.github/workflows/claude-pr-review.yml index def19ebe..68ef60c8 100644 --- a/.github/workflows/claude-pr-review.yml +++ b/.github/workflows/claude-pr-review.yml @@ -50,7 +50,7 @@ jobs: with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} claude_args: | - --model claude-opus-4-7 + --model ${{ vars.CLAUDE_MODEL || 'claude-opus-4-7[1m]' }} --permission-mode bypassPermissions prompt: | You are reviewing PR #${{ github.event.pull_request.number }} in diff --git a/.github/workflows/cpu_test.yml b/.github/workflows/cpu_test.yml index 2ed9b5f5..bc6f8be9 100644 --- a/.github/workflows/cpu_test.yml +++ b/.github/workflows/cpu_test.yml @@ -102,7 +102,8 @@ jobs: python3 -c "import sys; print(sys.path)" python3 -c "import libero.libero" && echo "LIBERO config set successfully." echo "Running cpu based pytest and generating coverage report..." - pytest -m "not gpu" -n auto -v --cov=lerobot/ --cov-report=xml:cpu_test/cpu_test.xml --ignore=tests/planner/test_planner.py --ignore tests/utils/test_libero_utils.py --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_async_vector_env --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_sync_vector_env tests/ + # TODO(#210): drop --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py once pi07 migrates to SpaceTimeSiglipVideoEncoder (#192). + pytest -m "not gpu" -n auto -v --cov=lerobot/ --cov-report=xml:cpu_test/cpu_test.xml --ignore=tests/planner/test_planner.py --ignore tests/utils/test_libero_utils.py --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_async_vector_env --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_sync_vector_env tests/ echo "Pytest execution and coverage report generation completed." - name: Upload coverage reports diff --git a/.github/workflows/extract-claude-lessons.yml b/.github/workflows/extract-claude-lessons.yml index 5826321e..e0a7cbf0 100644 --- a/.github/workflows/extract-claude-lessons.yml +++ b/.github/workflows/extract-claude-lessons.yml @@ -25,9 +25,13 @@ permissions: jobs: extract-lessons: + # Gate on the head branch name, not user.login: in this repo Claude Code + # pushes to a `claude/*` branch and a human opens the PR, so the PR's + # `user.login` never contains 'claude'. The branch prefix is the reliable + # signal that Claude touched the PR. if: >- github.event.pull_request.merged == true - && contains(github.event.pull_request.user.login, 'claude') + && startsWith(github.event.pull_request.head.ref, 'claude/') && !startsWith(github.event.pull_request.title, 'chore(claude): learn from') runs-on: ubuntu-latest timeout-minutes: 20 @@ -39,6 +43,7 @@ jobs: with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} claude_args: | + --model ${{ vars.CLAUDE_MODEL || 'claude-opus-4-7[1m]' }} --permission-mode bypassPermissions prompt: | Review the comments on this merged PR. If any reviewer feedback diff --git a/.github/workflows/gpu_test.yml b/.github/workflows/gpu_test.yml index 453db4b5..9c52a9f5 100644 --- a/.github/workflows/gpu_test.yml +++ b/.github/workflows/gpu_test.yml @@ -91,7 +91,8 @@ jobs: source .venv/bin/activate mkdir -p /tmp/libero-assets/libero/libero export LIBERO_CONFIG_PATH="$(pwd)/.github/assets/libero" - pytest -m "gpu" -n 0 -v tests/ + # TODO(#210): drop --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py once pi07 migrates to SpaceTimeSiglipVideoEncoder (#192). + pytest -m "gpu" -n 0 -v --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py tests/ stop-runner: name: Stop GPU Runner diff --git a/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py index 2b8f372f..2af13448 100644 --- a/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py +++ b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py @@ -888,9 +888,16 @@ def embed_prefix( Gemma 3 embedding layer. **Concatenation order** (training when memory and response are provided): - ``[images | language | metadata | ";\\n " | "Updated Memory: " | memory_tokens | + ``[images | language | metadata? | ";\\n "? | "Updated Memory: " | memory_tokens | "Subtask: " | response_tokens]`` + ``";\\n "`` is gated on ``metadata_tokens`` — it serves as the metadata → + ``"Updated Memory:"`` separator, so when no metadata is provided there is + nothing to terminate and emitting it would dangle spurious tokens. The + ``"Updated Memory: "`` anchor itself is unconditional because inference + relies on it as the autoregressive starting point for memory decoding + (memory_tokens is None at inference by design). + When ``memory_tokens`` / ``response_tokens`` are omitted (inference), only the fixed spans before those segments are present; memory and subtask text are filled in via KV-cache decoding plus an explicit ``"Subtask: "`` injection before response AR. @@ -898,7 +905,7 @@ def embed_prefix( Attention pattern (via ``att_masks`` cumsums): - Image + language tokens: bidirectional (``0``). - Metadata (if present): new bidirectional block (``[1, 0, …, 0]``). - - ``";\\n "`` (same string as ``encode(";\n ", add_special_tokens=False)``): continues previous block (``0``). + - ``";\\n "`` (only when metadata present): continues previous block (``0``). - ``"Updated Memory: "``: new bidirectional block (``[1, 0, …, 0]``). - Memory token slots: causal segment (``1`` per slot). - ``"Subtask: "`` (training): new block then causal continuation within span. @@ -973,20 +980,25 @@ def embed_prefix( pad_masks.append(metadata_masks) att_masks += [1] + [0] * (metadata_emb.shape[1] - 1) - prefix_end_indicator_ids = self.language_tokenizer.encode(";\n ", add_special_tokens=False) - prefix_end_tokens = torch.tensor( - [prefix_end_indicator_ids] * bsize, - device=lang_tokens.device, - dtype=torch.long, - ) - prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) + # ";\n " is the metadata -> "Updated Memory:" separator. With no metadata, + # there is nothing to terminate, so omit it; "Updated Memory:" still anchors + # AR memory decoding either way. + prefix_end_indicator_ids = self.language_tokenizer.encode(";\n ", add_special_tokens=False) + prefix_end_tokens = torch.tensor( + [prefix_end_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) - num_prefix_end_embs = prefix_end_emb.shape[1] - prefix_end_mask = torch.ones(bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device) + num_prefix_end_embs = prefix_end_emb.shape[1] + prefix_end_mask = torch.ones( + bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device + ) - embs.append(prefix_end_emb) - pad_masks.append(prefix_end_mask) - att_masks += [0] * num_prefix_end_embs + embs.append(prefix_end_emb) + pad_masks.append(prefix_end_mask) + att_masks += [0] * num_prefix_end_embs memory_start_indicator_ids = self.language_tokenizer.encode( "Updated Memory: ", add_special_tokens=False diff --git a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py index 01dfeb41..ae18a54e 100644 --- a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py @@ -1152,17 +1152,24 @@ def embed_prefix( Concatenation order: - ``[videos | language | State: | state(T) | ", " | response | - Subgoal: | subgoal_images… | ", " | metadata | ";\\n " | + ``[videos | language | State: | state(T) | | response? | + Subgoal: | subgoal_images… | ", " | metadata? | ";\\n "? | ("Action:" + discrete_actions only when training)]`` + ```` is ``", "`` when at least one optional middle block + (response / subgoal / metadata) contributes real tokens, else ``":\\n"``. + The trailing ``";\\n "`` prefix-end is only emitted in the former case; + without optional content the state-end already serves as the separator + before ``"Action: "``, so appending another would dangle spurious tokens. + Attention pattern (via ``att_masks`` cumsums): - Video + language: bidirectional (``0``). - - ``State:``, projected state timestep tokens, comma after state: bidirectional (``0``). + - ``State:``, projected state timestep tokens, state-end separator: bidirectional (``0``). - Response spans: prefix-LM style block opening (``[1, 0, …]`` inside the segment). - ``Subgoal:``: new bidirectional block (``[1, 0, …]``). - Subgoal image patches per camera: bidirectional blocks (``[1, 0, …]``). - Commas/metadata / ``";\\n "``: mostly continued prefix blocks (see code). + - ``Action:`` indicator: each token is its own causal block (``[1, 1, …]``). - Discrete actions (training): causal ``1`` per timestep after ``Action:``. Args: @@ -1196,6 +1203,22 @@ def embed_prefix( att_masks = [] bsize = lang_tokens.shape[0] + # Whether any optional middle block (response / subgoal / metadata) will + # actually contribute real tokens to the prefix. When all are dropped the + # state-end separator collapses to ":\n" and the trailing prefix-end is + # omitted, eliminating spurious dangling tokens that would otherwise break + # the cumsum at the indicator -> first-discrete boundary. + has_response = ( + response_tokens is not None and response_masks is not None and bool(response_masks.any()) + ) + has_subgoal = ( + bool(subgoal_images) and bool(subgoal_img_masks) and any(bool(m.any()) for m in subgoal_img_masks) + ) + has_metadata = ( + metadata_tokens is not None and metadata_masks is not None and bool(metadata_masks.any()) + ) + has_any_optional = bool(has_response or has_subgoal or has_metadata) + for vid, vid_mask in zip(videos, vid_masks, strict=True): vid_emb = self.embed_video(vid) # (B, num_video_tokens, vlm_hidden) vid_emb = vid_emb.to(dtype=_preferred_dtype()) @@ -1252,7 +1275,10 @@ def embed_prefix( pad_masks.append(state_mask) att_masks += [0] * num_state_tokens # full attention with video and language - state_end_indicator_ids = self.language_tokenizer.encode(", ", add_special_tokens=False) + # When optional middle blocks follow, use ", " as a state -> optional separator; + # otherwise collapse to ":\n" so the trailing prefix-end can be omitted. + state_end_str = ", " if has_any_optional else ":\n" + state_end_indicator_ids = self.language_tokenizer.encode(state_end_str, add_special_tokens=False) state_end_tokens = torch.tensor( [state_end_indicator_ids] * bsize, device=lang_tokens.device, @@ -1352,20 +1378,26 @@ def embed_prefix( pad_masks.append(metadata_masks) att_masks += [1] + [0] * (metadata_emb.shape[1] - 1) - prefix_end_indicator_ids = self.language_tokenizer.encode(";\n ", add_special_tokens=False) - prefix_end_tokens = torch.tensor( - [prefix_end_indicator_ids] * bsize, - device=lang_tokens.device, - dtype=torch.long, - ) - prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) + # Only emit the ";\n " prefix-end when at least one optional middle block was added + # above. With no optional content, the state-end already collapsed to ":\n" and acts + # as the separator before "Action: " — appending another would dangle spurious tokens. + if has_any_optional: + prefix_end_indicator_ids = self.language_tokenizer.encode(";\n ", add_special_tokens=False) + prefix_end_tokens = torch.tensor( + [prefix_end_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) - num_prefix_end_embs = prefix_end_emb.shape[1] - prefix_end_mask = torch.ones(bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device) + num_prefix_end_embs = prefix_end_emb.shape[1] + prefix_end_mask = torch.ones( + bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device + ) - embs.append(prefix_end_emb) - pad_masks.append(prefix_end_mask) - att_masks += [0] * num_prefix_end_embs + embs.append(prefix_end_emb) + pad_masks.append(prefix_end_mask) + att_masks += [0] * num_prefix_end_embs if discrete_actions is not None: discrete_action_start_indicator_ids = self.language_tokenizer.encode( @@ -1387,7 +1419,11 @@ def embed_prefix( embs.append(discrete_action_start_emb) pad_masks.append(discrete_action_start_mask) - att_masks += [1] + [0] * (num_discrete_action_start_embs - 1) + # Each "Action: " indicator token is its own causal block. Using + # [1] + [0]*(N-1) collapses them into a single bidirectional block, which + # shifts the cumsum at the indicator -> first-discrete boundary by N-1 + # and breaks the discrete-action CE loss. + att_masks += [1] * num_discrete_action_start_embs discrete_action_emb = self.gemma3_with_expert.embed_discrete_actions(discrete_actions) embs.append(discrete_action_emb.to(dtype=_preferred_dtype())) diff --git a/tests/policies/test_pi06.py b/tests/policies/test_pi06.py index 5bf32a57..f6b74538 100644 --- a/tests/policies/test_pi06.py +++ b/tests/policies/test_pi06.py @@ -513,7 +513,25 @@ def test_complete_pi06_pipeline_integration_smoke(lerobot_dataset_metadata): config.output_features = {k: ft for k, ft in features.items() if ft.type is FeatureType.ACTION} config.input_features = {k: ft for k, ft in features.items() if k not in config.output_features} - policy = PI06Policy(config, dataset_stats=lerobot_dataset_metadata.stats) + # The shared lerobot_dataset_metadata fixture carries actions stats shaped + # (50, 32) — matching the default PI06Config(chunk_size=50). This test + # uses chunk_size=10 to keep the model small, so override the actions + # stats to (10, 32) before constructing Normalize buffers; otherwise + # `(actions - min) / (max - min + EPS)` mismatches at dim=1 (actions is + # (B, 10, 32) but the buffer is (50, 32)). + import copy + + import numpy as np + + dataset_stats = copy.deepcopy(lerobot_dataset_metadata.stats) + for k in ("max", "mean", "min", "std"): + dataset_stats["actions"][k] = np.full( + (config.chunk_size, 32), + float(dataset_stats["actions"][k].flatten()[0]), + dtype=np.float32, + ) + + policy = PI06Policy(config, dataset_stats=dataset_stats) policy.to(dtype=torch.bfloat16, device="cuda") batch = { diff --git a/tests/policies/test_pi07_cpu.py b/tests/policies/test_pi07_cpu.py index 85dea0eb..35b49df9 100644 --- a/tests/policies/test_pi07_cpu.py +++ b/tests/policies/test_pi07_cpu.py @@ -634,7 +634,12 @@ class TestEmbedPrefixConditionalGuards: def test_all_optional_blocks_absent_skips_emission(self): """All-False response_masks + no subgoal_images + all-False metadata_masks → - the prefix collapses to ``videos + lang + State: + state + ", " + ";\\n "``. + the prefix collapses to ``videos + lang + State: + state + ":\\n"``. + + With ``has_any_optional == False`` the state-end separator collapses + from ``", "`` to ``":\\n"`` (same fake-tokenizer length: 2 tokens) and + the trailing ``";\\n "`` prefix-end is omitted entirely so it cannot + dangle as a spurious separator before ``"Action: "``. """ method = _embed_prefix_method() fake = _make_fake_flow_matching() @@ -658,12 +663,11 @@ def test_all_optional_blocks_absent_skips_emission(self): # lang: 3 (prompt_len) # "State: ": 2 # state(T=1): 1 - # ", ": 2 - # ";\n ": 2 - # Total = 13 - assert embs.shape == (bsize, 13, 4) - assert pad_masks.shape == (bsize, 13) - assert att_masks.shape == (bsize, 13) + # ":\n": 2 (state-end; collapsed from ", " because no optionals) + # Total = 11 (";\n " prefix-end is omitted with no optional content) + assert embs.shape == (bsize, 11, 4) + assert pad_masks.shape == (bsize, 11) + assert att_masks.shape == (bsize, 11) # No causal boundaries (1's) anywhere in the att_masks — every block # should remain bidirectional with all-skip optional blocks. assert int(att_masks[0].sum().item()) == 0, ( @@ -744,3 +748,57 @@ def test_response_mask_any_true_emits_block(self): assert per_sample_sum == 1, ( f"expected exactly one causal boundary from the response block opening, got {per_sample_sum}" ) + + def test_discrete_actions_indicator_uses_per_token_causal_blocks(self): + """The ``"Action: "`` indicator must use ``[1]*N`` (one causal block per + token), not ``[1] + [0]*(N-1)`` (single bidirectional block). + + The buggy ``[1] + [0]*(N-1)`` pattern collapses the indicator into one + bidirectional block, shifting the cumsum at the indicator -> first-discrete + boundary by N-1 and breaking the discrete-action CE loss. This test pins + the tail of ``att_masks`` to all 1's after the indicator, so a regression + to the old pattern fails immediately. + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 2 + kwargs = _build_default_inputs(batch_size=bsize) + # No optional middle blocks — keeps the prefix layout deterministic so + # the indicator + discrete-action span sits exactly at the tail. + kwargs["subgoal_images"] = [] + kwargs["subgoal_img_masks"] = [] + + num_action_tokens = 3 + kwargs["discrete_actions"] = torch.zeros(bsize, num_action_tokens, dtype=torch.long) + kwargs["discrete_action_masks"] = torch.ones(bsize, num_action_tokens, dtype=torch.bool) + + _, pad_masks, att_masks = method(fake, **kwargs) + + # Expected layout (no optional blocks; fake tokenizer encodes every + # indicator phrase to 2 tokens; ``discrete_action_emb`` matches its + # input shape): + # videos(3) + lang(3) + "State: "(2) + state(1) + ":\n"(2) + # + "Action: "(2) + discrete_actions(3) = 16 + num_indicator_tokens = 2 + base_prefix_len = 11 + total_len = base_prefix_len + num_indicator_tokens + num_action_tokens + assert att_masks.shape == (bsize, total_len) + assert pad_masks.shape == (bsize, total_len) + + # The first ``base_prefix_len`` positions are bidirectional (no causal + # boundaries) — same invariant as test_all_optional_blocks_absent. + assert int(att_masks[0, :base_prefix_len].sum().item()) == 0 + + # Tail invariant: every position from the indicator onward must be 1 + # (per-token causal blocks for the indicator + one causal step per + # discrete action). A regression to ``[1] + [0]*(N-1)`` on the + # indicator would put zeros at indices ``base_prefix_len + 1 .. + # base_prefix_len + num_indicator_tokens - 1`` and the assertion below + # would fail. + tail = att_masks[0, base_prefix_len:] + assert int(tail.sum().item()) == num_indicator_tokens + num_action_tokens, ( + f"expected all-ones tail of length {num_indicator_tokens + num_action_tokens} " + f"(indicator + discrete actions), got {tail.tolist()} — a zero in the indicator " + "span signals a regression to the old [1]+[0]*(N-1) pattern, which shifts the " + "cumsum at the indicator -> first-discrete boundary and breaks the CE loss." + ) diff --git a/tests/policies/test_pi07_low_level_planner.py b/tests/policies/test_pi07_low_level_planner.py index 1b233c13..52ef1dc3 100644 --- a/tests/policies/test_pi07_low_level_planner.py +++ b/tests/policies/test_pi07_low_level_planner.py @@ -291,7 +291,14 @@ def _verify_position_ids( if inference_mode: prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] else: - prefix_offsets = torch.sum(prefix_pad_masks[:, :-DISCRETE_ACTION_MAX_LENGTH], dim=-1)[:, None] + # Training: model's prefix_offsets exclude both the "Action: " indicator + # and the discrete-action span from cross-attention (matches pi05's + # discrete_action_indicator_max_length logic) so the action expert sees + # the same prefix length at train and inference. + action_lead_len = self._indicator_lens(tokenizer)["action_lead"] + prefix_offsets = torch.sum( + prefix_pad_masks[:, : -(action_lead_len + DISCRETE_ACTION_MAX_LENGTH)], dim=-1 + )[:, None] expected_suffix = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 assert torch.equal(suffix_position_ids, expected_suffix) @@ -312,12 +319,16 @@ def _verify_action_expert_attention_mask( prefix_pad_masks, suffix_pad_masks, suffix_att_masks, + tokenizer, inference_mode=False, ): if inference_mode: num_cross = prefix_pad_masks.shape[1] else: - num_cross = prefix_pad_masks.shape[1] - DISCRETE_ACTION_MAX_LENGTH + # Training: cross-attention excludes both the "Action: " indicator and + # the discrete-action span (mirrors the prefix_offsets logic above). + action_lead_len = self._indicator_lens(tokenizer)["action_lead"] + num_cross = prefix_pad_masks.shape[1] - action_lead_len - DISCRETE_ACTION_MAX_LENGTH expected = make_att_2d_masks( suffix_pad_masks, @@ -478,6 +489,7 @@ def capture_embed_suffix(*args, **kwargs): captured["prefix_pad_masks"], captured["suffix_pad_masks"], captured["suffix_att_masks"], + tokenizer, ) assert isinstance(loss, dict) @@ -586,6 +598,7 @@ def capture_embed_suffix_infer(*args, **kwargs): captured_infer["prefix_pad_masks"], captured_infer["suffix_pad_masks"], captured_infer["suffix_att_masks"], + tokenizer, inference_mode=True, ) From 7ed41af73993e34c6ce3fa1575d588f3bcf299fa Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Sat, 2 May 2026 06:41:43 +0000 Subject: [PATCH 8/8] chore(claude): learn from #229 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add hard rule 4: pin training-path layout fixes with a CPU unit test. PR #229 review surfaced that the `[1]+[0]*(N-1)` -> `[1]*N` att_masks fix — the most material correctness change in the PR, since it shifts the cumsum at the indicator -> first-discrete boundary and changes the discrete-action CE loss — initially shipped without a CPU test pinning the pattern. It would have only been caught by deferred GPU and nightly regression tests. Author added a CPU assertion in dda4446 after review. Rule 3 (determinism) doesn't cover this case: two seeded runs of the new code agree, but determinism cannot tell you the layout is correct. A pinned CPU test does. Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index 3551b31c..168d373b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -20,6 +20,8 @@ These override defaults — read them before running anything. 3. **Verify determinism on any change to the training loop or model.** ML bugs hide in stochasticity: a bad change can still produce loss curves that *look* plausible. After touching anything in `scripts/train.py`, `policies/*/modeling_*.py`, `optim/`, or `datasets/sampler.py`, run a smoke config twice with the same `seed` and confirm the per-step loss series is bit-identical (not just "close"). Seeding utilities live in `src/opentau/utils/random_utils.py` (`set_seed`, `serialize_python_rng_state`, etc.). If two seeded runs diverge, that's a bug — investigate before claiming the change works. +4. **Pin training-path layout fixes with a CPU unit test.** When you fix a detail in `embed_prefix` / `embed_suffix` / position-id construction that shifts a cumsum boundary or alters which tokens fall inside vs. outside a causal block — `att_masks` patterns (e.g. `[1]+[0]*(N-1)` vs `[1]*N` per-token causal blocks), `position_ids` slices, prefix/suffix layout — add a CPU unit test in the policy's `test_*_cpu.py` that asserts the exact pattern (token count, the att_masks tail, indicator boundaries). Determinism (rule 3) only proves two runs agree; it does *not* prove the layout is correct. GPU integration and nightly regression tests run on a delayed schedule, so a layout regression can merge silently if only those gate it. A pinned CPU assertion fails on the same PR. + ## Project overview OpenTau is Tensor's open-source PyTorch training toolchain for vision-language-action (VLA) models — a fork of LeRobot with extra capabilities (heterogeneous-dataset co-training, discrete actions for π₀.₅, knowledge insulation, dropout in PaliGemma, π*₀.₆-style RL, validation splits, profilers). Any LeRobot-compliant policy and dataset works directly. Pinned to **Python 3.10**.