Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/opentau/policies/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@
from opentau.policies.pi05.configuration_pi05 import PI05Config
from opentau.policies.pi05_continuous_state.configuration_pi05 import PI05ContinuousStateConfig
from opentau.policies.pi05_mem.configuration_pi05 import PI05MemConfig
from opentau.policies.pi07_paligemma.high_level_planner.configuration_pi07_high_level import (
PI07HighLevelPlannerConfig,
)
from opentau.policies.pi07_paligemma.low_level_planner.configuration_pi07_low_level import (
PI07lowlevelPlannerConfig,
)
from opentau.policies.pretrained import PreTrainedPolicy
from opentau.policies.value.configuration_value import ValueConfig

Expand Down Expand Up @@ -69,6 +75,18 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from opentau.policies.pi05_mem.modeling_pi05 import PI05MemPolicy

return PI05MemPolicy
elif name == "pi07_paligemma_high_level_planner":
from opentau.policies.pi07_paligemma.high_level_planner.modeling_pi07_high_level import (
PI07HighLevelPlannerPolicy,
)

return PI07HighLevelPlannerPolicy
elif name == "pi07_paligemma_low_level_planner":
from opentau.policies.pi07_paligemma.low_level_planner.modeling_pi07_low_level import (
PI07LowLevelPlannerPolicy,
)

return PI07LowLevelPlannerPolicy
elif name == "value":
from opentau.policies.value.modeling_value import ValueFunction

Expand Down Expand Up @@ -98,6 +116,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI05ContinuousStateConfig(**kwargs)
elif policy_type == "pi05_mem":
return PI05MemConfig(**kwargs)
elif policy_type == "pi07_paligemma_high_level_planner":
return PI07HighLevelPlannerConfig(**kwargs)
elif policy_type == "pi07_paligemma_low_level_planner":
return PI07lowlevelPlannerConfig(**kwargs)
elif policy_type == "value":
return ValueConfig(**kwargs)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
optimization, scheduling, and data processing.
"""

import logging
from dataclasses import dataclass, field
from typing import Literal

from opentau.configs.policies import PreTrainedConfig
from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature
Expand Down Expand Up @@ -66,9 +64,6 @@ class PI07HighLevelPlannerConfig(PreTrainedConfig):
strings. Defaults to 52.
dropout: Dropout rate applied in the transformer expert.
Defaults to 0.1.
init_strategy: Weight initialization strategy. One of ``"no_init"``,
``"full_he_init"``, ``"expert_only_he_init"``. Defaults to
``"full_he_init"``.
attention_implementation: Attention backend — ``"eager"`` or
``"fa2"`` (Flash Attention 2). Defaults to ``"eager"``.
freeze_vision_encoder: Whether to freeze the SigLIP vision encoder
Expand Down Expand Up @@ -118,9 +113,6 @@ class PI07HighLevelPlannerConfig(PreTrainedConfig):
# Dropout
dropout: float = 0.1

# Initialization strategy
init_strategy: Literal["no_init", "full_he_init", "expert_only_he_init"] = "full_he_init"

# Attention utils
attention_implementation: str = "eager"

Expand All @@ -142,8 +134,6 @@ def __post_init__(self):

Raises:
ValueError: If ``n_obs_steps`` is not 1.
AssertionError: If ``init_strategy`` is not one of the allowed
values.
"""
super().__post_init__()

Expand All @@ -152,19 +142,6 @@ def __post_init__(self):
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)

assert self.init_strategy in ["no_init", "full_he_init", "expert_only_he_init"], (
f"Invalid init strategy: {self.init_strategy} must be one of ['no_init', 'full_he_init', 'expert_only_he_init']"
)

if self.init_strategy == "expert_only_he_init" and self.pretrained_path == "lerobot/pi05":
raise ValueError(
"You cannot load pretrained PI0 model when init_strategy is 'expert_only_he_init' due to differences in PaliGemma tokenizer vocab sizes."
)

if self.pretrained_path is not None and self.pretrained_path != "lerobot/pi05":
logging.info("Setting init_strategy to 'no_init' because we are resuming from a checkpoint.")
self.init_strategy = "no_init"

def validate_features(self) -> None:
"""Adds placeholder camera features for empty camera slots.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
Args:
batch: Batch of training data. Expected keys include images,
``"prompt"``, ``"state"``, ``"past_memory"``, ``"response"``,
and ``"memory"``.
and ``"next_memory"``.

Returns:
A dict with ``"MSE"`` (always zero, kept for interface
Expand All @@ -512,7 +512,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
) # in response_masks we have True for real tokens and False for padded tokens

# memory prediction is to predict the memory . It will attend to image and language inputs.
memory_tokens, memory_masks = self.prepare_memory(
memory_tokens, memory_masks = self.prepare_next_memory(
batch
) # in memory_masks we have True for real tokens and False for padded tokens
losses = self.model.forward(
Expand Down Expand Up @@ -676,19 +676,19 @@ def prepare_metadata(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
batch["speed_is_pad"],
batch["quality_is_pad"],
batch["mistake_is_pad"],
strict=False,
strict=True,
):
meta = ""
segments = []
if not speed_is_pad:
meta += f"Speed: {str(speed.item())} "
segments.append(f"Speed: {str(speed.item())}")

if not quality_is_pad:
meta += f"Quality: {str(quality.item())} "
segments.append(f"Quality: {str(quality.item())}")

if not mistake_is_pad:
meta += f"Mistake: {str(mistake.item())}"
segments.append(f"Mistake: {str(mistake.item())}")

metadata.append(f"Metadata: {meta}<eos>")
metadata.append(f"Metadata: {' '.join(segments)}<eos>")

device = batch["state"].device
tokenized_metadata = self.language_tokenizer.__call__(
Expand Down Expand Up @@ -741,15 +741,15 @@ def prepare_response(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:

return response_tokens, response_masks

def prepare_memory(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
def prepare_next_memory(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
"""Tokenizes the target updated memory for training.

Wraps each memory string with an ``<eos>`` suffix, then tokenizes and
pads to ``memory_max_length``.

Args:
batch: Batch containing ``"memory"`` (list of memory strings) and
``"state"`` (used only to determine the device).
batch: Batch containing ``"next_memory"`` (list of target memory
strings) and ``"state"`` (used only to determine the device).

Returns:
A tuple ``(memory_tokens, memory_masks)`` where:
Expand All @@ -760,10 +760,10 @@ def prepare_memory(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
"""

device = batch["state"].device
memory = batch["memory"]
next_memory = batch["next_memory"]

# if '' is found in response then response is not for loss calculation (used for robotic dataset with no subtask), so add pad token to the response.
memory_prompt = [f"{mem}<eos>" for mem in memory]
# if '' is found in next_memory then it is not for loss calculation (used for robotic dataset with no subtask), so add pad token.
memory_prompt = [f"{mem}<eos>" for mem in next_memory]

tokenized_memory = self.language_tokenizer.__call__(
memory_prompt,
Expand Down Expand Up @@ -824,50 +824,18 @@ def __init__(self, config: PI07HighLevelPlannerConfig, discrete_action_vocab_siz
super().__init__()
self.config = config

load_pretrained_paligemma = (
self.config.init_strategy == "expert_only_he_init"
) # only load pretrained paligemma if we are He-initializing the expert only
paligemma_with_expert_config = PaliGemmaWithExpertConfig(
freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=False,
attention_implementation=self.config.attention_implementation,
load_pretrained_paligemma=load_pretrained_paligemma,
load_pretrained_paligemma=False,
discrete_action_vocab_size=discrete_action_vocab_size,
dropout=self.config.dropout,
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_expert_config)

self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")

self._init_model()

def _init_weights(self, module: nn.Module) -> None:
"""Initialize weights using He (Kaiming) initialization.

Args:
module: The module to initialize.
"""
if isinstance(module, nn.Linear):
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)

def _init_model(self) -> None:
"""Initialize the model weights based on the configuration."""
if self.config.init_strategy == "no_init":
return
elif self.config.init_strategy == "full_he_init":
for m in self.modules():
self._init_weights(m)
elif self.config.init_strategy == "expert_only_he_init":
for m in self.paligemma_with_expert.gemma_expert.modules():
self._init_weights(m)
else:
raise ValueError(f"Invalid init strategy: {self.config.init_strategy}")

def embed_prefix(
self,
images: list[Tensor],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
subgoal image, and metadata conditioning.
"""

import logging
from dataclasses import dataclass, field
from typing import Literal

from opentau.configs.policies import PreTrainedConfig
from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature
Expand Down Expand Up @@ -70,8 +68,6 @@ class PI07lowlevelPlannerConfig(PreTrainedConfig):
num_steps: Number of flow-matching denoising steps. Defaults to 10.
max_delay: Maximum number of prefix action steps for real-time
inference. Defaults to 0.
init_strategy: Weight initialization strategy. Defaults to
``"full_he_init"``.
attention_implementation: Attention backend (``"eager"`` or
``"fa2"``). Defaults to ``"eager"``.
freeze_vision_encoder: Whether to freeze V-JEPA2. Defaults to True.
Expand Down Expand Up @@ -142,9 +138,6 @@ class PI07lowlevelPlannerConfig(PreTrainedConfig):
# Real Time Inference
max_delay: int = 0

# Initialization strategy
init_strategy: Literal["no_init", "full_he_init", "expert_only_he_init"] = "full_he_init"

# Attention utils
attention_implementation: str = "eager"

Expand Down Expand Up @@ -205,19 +198,6 @@ def __post_init__(self):
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)

assert self.init_strategy in ["no_init", "full_he_init", "expert_only_he_init"], (
f"Invalid init strategy: {self.init_strategy} must be one of ['no_init', 'full_he_init', 'expert_only_he_init']"
)

if self.init_strategy == "expert_only_he_init" and self.pretrained_path == "lerobot/pi05":
raise ValueError(
"You cannot load pretrained PI05 model when init_strategy is 'expert_only_he_init' due to differences in PaliGemma tokenizer vocab sizes."
)

if self.pretrained_path is not None and self.pretrained_path != "lerobot/pi05":
logging.info("Setting init_strategy to 'no_init' because we are resuming from a checkpoint.")
self.init_strategy = "no_init"

if self.max_delay > self.chunk_size:
raise ValueError(
f"The max delay must be less than or equal to the chunk size. Got {self.max_delay} for `max_delay` and {self.chunk_size} for `chunk_size`."
Expand Down
Loading
Loading