Skip to content
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
"src/opentau/scripts/grpc/server.py" = ["N802"]
# Uppercase names for rotation matrices follow standard math convention
"src/opentau/scripts/recordhuman_to_lerobot.py" = ["N803", "N806"]
# Math-convention names (B, S, Hq, Dh, O, LSE, P, ...) for ring attention.
"src/opentau/policies/pi07/ring_attention.py" = ["N803", "N806", "E741"]
"src/opentau/scripts/ringattn_experiments/*.py" = ["N803", "N806", "E741"]

[tool.bandit]
exclude_dirs = [
Expand Down
41 changes: 41 additions & 0 deletions src/opentau/configs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,14 @@ class TrainPipelineConfig(HubMixin):
dataloader_batch_size: int | None = None
# Prefetch factor for the dataloader.
prefetch_factor: int | None = None
# Ring-attention sub-group size. When set (and the policy enables
# ``attention_implementation="ring"``), splits the WORLD process group
# into ``world_size / ring_group_size`` ring sub-groups of this size.
# Sequence parallelism happens within each ring; DP-replication happens
# across rings (ZeRO over WORLD scales the gradient correctly via a
# ``loss *= ring_group_size`` pre-backward, see train.py). None keeps the
# legacy single-ring behaviour (ring spans WORLD). Must divide world_size.
ring_group_size: int | None = None
steps: int = 100_000
log_freq: int = 200
save_checkpoint: bool = True
Expand Down Expand Up @@ -287,6 +295,39 @@ def validate(self):
"to match policy.n_obs_steps."
)

# Ring sub-group validation. ``validate()`` runs before the Accelerator
# initialises distributed, so probe ``WORLD_SIZE`` from the env (set
# by torchrun / accelerate) rather than ``dist.get_world_size()``.
if self.ring_group_size is not None:
import os as _os

if self.ring_group_size < 1:
raise ValueError(f"ring_group_size must be >= 1; got {self.ring_group_size}.")
world_size_env = _os.environ.get("WORLD_SIZE")
if world_size_env is not None:
world_size = int(world_size_env)
if world_size % self.ring_group_size != 0:
raise ValueError(
f"world_size ({world_size}) must be divisible by "
f"ring_group_size ({self.ring_group_size})."
)
# Ring sub-grouping is only meaningful when the policy actually
# consumes ring attention. pi07 exposes the flag directly on the
# policy config (it's plumbed into ``vlm_config`` in
# ``__post_init__``); other policies (pi05, etc.) don't have the
# field at all — for those we treat ring_group_size as a no-op
# rather than erroring, so a shared launcher script can leave
# the flag set without breaking non-pi07 runs.
if self.policy is not None:
impl = getattr(self.policy, "attention_implementation", None)
if impl is not None and impl != "ring":
raise ValueError(
"ring_group_size is set but "
f"policy.attention_implementation = {impl!r}. "
"Set attention_implementation='ring' or unset "
"ring_group_size."
)

@classmethod
def __get_path_fields__(cls) -> list[str]:
"""Get list of field names that support path-based loading.
Expand Down
14 changes: 13 additions & 1 deletion src/opentau/datasets/dataset_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,12 +488,23 @@ def _get_worker_name_mapping_overrides(self) -> dict[str, dict[str, str]]:
overrides[dataset_cfg.repo_id] = dataset_cfg.data_features_name_mapping
return overrides

def get_dataloader(self) -> DataLoader:
def get_dataloader(self, sampler_seed: int | None = None) -> DataLoader:
"""Create and return a PyTorch DataLoader with weighted sampling.

Uses HierarchicalSampler to first sample a dataset according to weights,
then uniformly sample within that dataset.

Args:
sampler_seed: Optional seed for the HierarchicalSampler's RNG.
Identical seeds across ranks produce identical sample streams,
which is what 2D ring + DP parallelism needs: ranks in the same
ring sub-group should see the same batch (so sequence
parallelism is computing one consistent batch), while ranks in
different DP sub-groups should see different batches. Callers
that want this should pass ``base_seed + dp_rank * large_prime``.
None leaves the sampler unseeded (each rank keeps its own independent
``torch.Generator()`` state), preserving the pre-ring behaviour.

Returns:
DataLoader configured for weighted hierarchical sampling.

Expand Down Expand Up @@ -535,6 +546,7 @@ def get_dataloader(self) -> DataLoader:
dataset_lengths=ds_lengths,
dataset_probs=self.dataset_weights,
num_samples=num_samples_per_epoch,
seed=sampler_seed,
)

dataloader = DataLoader(
Expand Down
172 changes: 161 additions & 11 deletions src/opentau/policies/pi07/gemma3_with_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma

from opentau.policies.pi07.ring_attention import (
gather_seq,
ring_attention_forward,
ring_world_size,
)

# Ensure the Gemma-v1 AdaRMS / gated-residual patches are live before we
# construct an action expert. Import for side effects only.
from opentau.utils import transformers_patch # noqa: F401
Expand Down Expand Up @@ -129,13 +135,24 @@ def __init__(
Defaults to a ~860M-parameter Gemma with AdaRMS enabled.
freeze_vision_encoder: Freeze the SigLIP tower during training.
train_expert_only: Only update the expert and its heads.
attention_implementation: "eager", "sdpa", or "fa2". "fa2" is not
implemented and falls back to eager with a warning. "sdpa"
dispatches to ``torch.nn.functional.scaled_dot_product_attention``;
see the per-layer note about Gemma 3's interleaved local/global
attention_implementation: "eager", "sdpa", "ring", or "fa2".
"fa2" is not implemented and falls back to eager with a
warning. "sdpa" dispatches to
``torch.nn.functional.scaled_dot_product_attention``; see
the per-layer note about Gemma 3's interleaved local/global
pattern in ``forward()`` — π0.7 deliberately keeps the same
block-causal mask at every layer, so the SDPA call sees a
regular bool mask and takes the standard fused path.
regular bool mask and takes the standard fused path. "ring"
shards the sequence axis across the active distributed
process group and computes attention with paper-style ring
rotation + online softmax (see
``opentau.policies.pi07.ring_attention``). The ring path
also subsumes per-layer activation rematerialisation — full
``(S, S)`` attention scores are never stored, so the legacy
``gradient_checkpointing`` flag is ignored whenever the ring
path is active. "ring" only takes effect on the prefix forward
pass (single-stream, backbone-only); the action-expert suffix
is short enough that it always falls back to SDPA.
load_pretrained_gemma3: Whether to pull pretrained Gemma 3 weights
from the Hub (only recommended when He-initializing the expert).
discrete_action_vocab_size: FAST tokenizer vocab size.
Expand Down Expand Up @@ -284,10 +301,10 @@ def __init__(
"`train_expert_only=False` (the high-level-planner default) when "
"`disable_action_expert=True`."
)
if self.attention_implementation not in ["eager", "sdpa", "fa2"]:
if self.attention_implementation not in ["eager", "sdpa", "fa2", "ring"]:
raise ValueError(
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). "
"Expected 'eager', 'sdpa', or 'fa2'."
"Expected 'eager', 'sdpa', 'ring', or 'fa2'."
)
if self.attention_implementation == "fa2":
# fa2 has been considered but never implemented for pi07 because of
Expand Down Expand Up @@ -474,9 +491,18 @@ def forward(
if fill_kv_cache:
if n_cross_att_tokens is None:
raise ValueError("n_cross_att_tokens must be provided when fill_kv_cache is True")
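# Under ring attention, k_concat / v_concat hold this rank's shard
# along the sequence axis. The cache feeds the (unsharded) suffix
# forward, so gather full K/V before storing. ``gather_seq`` is a
# no-op when the ring group has world size 1, and the guard below
# skips the call entirely in that case.
# NOTE(review): the guard tests only the ring world size, not whether
# this forward actually sharded its inputs — confirm that
# ``ring_world_size()`` is 1 whenever ring attention is not in use.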
if ring_world_size() > 1:
k_full = gather_seq(k_concat, seq_dim=1)
v_full = gather_seq(v_concat, seq_dim=1)
else:
k_full, v_full = k_concat, v_concat
past_key_values[layer_idx] = {
"key_states": k_concat[:, :n_cross_att_tokens, :, :],
"value_states": v_concat[:, :n_cross_att_tokens, :, :],
"key_states": k_full[:, :n_cross_att_tokens, :, :],
"value_states": v_full[:, :n_cross_att_tokens, :, :],
}

# π0.7 keeps the prefix block-causal mask at every layer — the
Expand Down Expand Up @@ -843,16 +869,72 @@ def get_attention_interface(self):
layer (sliding window deliberately not enforced — see the comment
near ``apply_rope``), so SDPA sees a regular bool mask and does
not need a per-layer mask shape branch.
- ``"ring"``: paper-style ring attention — see
``opentau.policies.pi07.ring_attention``. Each rank holds 1/W
of the sequence; K/V rotate around the ring while an online
softmax accumulates the output per rank. Falls back to SDPA
transparently when world size is 1.
- ``"fa2"``: accepted for backward compatibility; falls back to
eager with a warning emitted at config validation time.
"""
impl = self.config.attention_implementation
if impl == "sdpa":
return self.sdpa_attention_forward
if impl == "ring":
# The model only enters the ring code path during the prefix
# forward (single-stream, fill_kv_cache=True). The suffix
# forward — short action-expert stream that cross-attends to
# the cached prefix K/V — still benefits from SDPA on full
# unsharded data. ``_ring_active`` is True only inside the
# sharded prefix forward (see Gemma3WithExpertModel.forward).
if getattr(self, "_ring_active", False):
return self.ring_attention_forward
return self.sdpa_attention_forward
# "eager" and legacy "fa2" both land here; "fa2" already warned during
# config construction.
return self.eager_attention_forward

def ring_attention_forward(
    self,
    attention_mask: torch.Tensor,
    batch_size: int,
    head_dim: int,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    scaling: float | None = None,
) -> torch.Tensor:
    """Adapt the per-layer attention call to the standalone ring kernel.

    The per-layer dispatch invokes every attention backend with the same
    ``(attention_mask, batch_size, head_dim, q, k, v, scaling)`` signature.
    The module-level :func:`ring_attention_forward` from
    ``opentau.policies.pi07.ring_attention`` additionally requires the GQA
    head counts, which are not part of that signature; this method reads
    them from ``self._text_config`` and forwards everything else untouched.

    When :meth:`Gemma3WithExpertModel.forward` has sharded its inputs along
    the sequence axis, the Q / K / V tensors received here are this rank's
    local shards. With a ring group of world size 1 (single GPU / eager
    tests) the underlying function is expected to fall back to SDPA on the
    full sequence.

    Args:
        attention_mask: Attention mask, passed through unchanged.
        batch_size: Batch size, passed through unchanged.
        head_dim: Per-head dimension, passed through unchanged.
        query_states: Query tensor, passed through unchanged.
        key_states: Key tensor, passed through unchanged.
        value_states: Value tensor, passed through unchanged.
        scaling: Optional scaling factor, passed through unchanged.

    Returns:
        Whatever the standalone ring-attention function returns.
    """
    text_cfg = self._text_config
    positional = (
        attention_mask,
        batch_size,
        head_dim,
        query_states,
        key_states,
        value_states,
    )
    # Inside the method body the bare name below resolves to the
    # module-level import, not to this method (class scope is not
    # searched from within a function body).
    return ring_attention_forward(
        *positional,
        num_query_heads=text_cfg.num_attention_heads,
        num_kv_heads=text_cfg.num_key_value_heads,
        scaling=scaling,
    )

def eager_attention_forward(
self,
attention_mask: torch.Tensor,
Expand Down Expand Up @@ -1060,7 +1142,58 @@ def forward(
if fill_kv_cache and past_key_values is None:
past_key_values = {}

use_ckpt = self.config.gradient_checkpointing and self.training
# Ring attention sharding. Active when the user opted in via the config
# and we are running in a multi-rank group AND the call is a
# single-stream prefix forward — the expert-suffix forward is short
# enough that ring's per-rank fixed costs outweigh its memory savings,
# so it transparently routes through SDPA via the get_attention_interface
# branch on ``self._ring_active`` below.
ring_active = (
self.config.attention_implementation == "ring"
and ring_world_size() > 1
and inputs_embeds[1] is None
)
# The single-stream prefix forward holds the (long) backbone shard
# only; ``inputs_embeds[0]`` is therefore the tensor to shard.
ring_pad_len = 0
ring_original_seq_len = 0
if ring_active:
ws = ring_world_size()
backbone_embs = inputs_embeds[0]
ring_original_seq_len = backbone_embs.shape[1]
ring_pad_len = (-ring_original_seq_len) % ws
if ring_pad_len > 0:
# Pad inputs_embeds, attention_mask, position_ids so the sequence
# length is a multiple of the ring world size. attention_mask:
# False entries along both padded dimensions (padded Q and K
# positions) so padding contributes nothing to the softmax.
# position_ids: replicate the last valid id so RoPE on padded
# slots produces a finite (ignored) rotation.
pad_emb = nn.functional.pad(backbone_embs, (0, 0, 0, ring_pad_len))
inputs_embeds = [pad_emb, None]
attention_mask = nn.functional.pad(
attention_mask, (0, ring_pad_len, 0, ring_pad_len), value=False
)
last_pos = position_ids[:, -1:].expand(-1, ring_pad_len)
position_ids = torch.cat([position_ids, last_pos], dim=1)
# Shard inputs along the sequence axis. The attention mask stays
# replicated on every rank (it's small relative to activations)
# and ring_attention_forward slices it internally per ring step.
from opentau.policies.pi07.ring_attention import split_seq

shard = split_seq(inputs_embeds[0], seq_dim=1)
inputs_embeds = [shard, None]
position_ids = split_seq(position_ids, seq_dim=1)

# Tell the per-layer attention dispatch which branch to take. Layers
# capture ``get_attention_interface`` lazily on every call, so this
# flag is honoured by every layer in the loop below.
self._ring_active = ring_active

# Per-layer ``torch.utils.checkpoint`` is redundant under ring: the
# ring attention kernel already does blockwise rematerialisation
# (Section 3.2.2 of the paper) — full (Q, K) score matrices are
# never stored across the forward → backward boundary on any rank.
use_ckpt = self.config.gradient_checkpointing and self.training and not ring_active
for layer_idx, layer in enumerate(self.interleaved_layers):
if use_ckpt:
# use_reentrant=False is the modern, DDP-safe path; it
Expand Down Expand Up @@ -1098,7 +1231,9 @@ def forward(
layer_idx,
)

# Final norms.
# Final norms. RMSNorm is per-token so it commutes with sequence
# sharding — we apply it on the per-rank shards and gather *after*
# to minimise communicated bytes.
final_outputs: list[torch.Tensor | None] = []
for stream_idx, hidden_states in enumerate(inputs_embeds):
if hidden_states is None:
Expand All @@ -1110,6 +1245,21 @@ def forward(
out, _ = expert_norm(hidden_states, cond=adarms_cond[stream_idx])
final_outputs.append(out)

# Gather the backbone stream back to the full sequence length when
# we sharded at the top, and strip any padding we added to make the
# length divisible by world size. Callers see the same shape /
# semantics as the non-ring path.
if ring_active:
assert final_outputs[0] is not None
gathered = gather_seq(final_outputs[0], seq_dim=1)
if ring_pad_len > 0:
gathered = gathered[:, :ring_original_seq_len]
final_outputs[0] = gathered
# Clear the flag so subsequent (unsharded) forwards — e.g. the suffix
# forward in low-level training — see ``_ring_active=False`` and
# take the SDPA branch in ``get_attention_interface``.
self._ring_active = False

return final_outputs, past_key_values

# Gemma 3 structural accessors
Expand Down
Loading
Loading