Skip to content

feat(policies): add authentic pi06 policy with Gemma 3 4B backbone#178

Merged
shuheng-liu merged 7 commits into
mainfrom
claude/compare-pi-models-NgCLN
Apr 29, 2026
Merged

feat(policies): add authentic pi06 policy with Gemma 3 4B backbone#178
shuheng-liu merged 7 commits into
mainfrom
claude/compare-pi-models-NgCLN

Conversation

@shuheng-liu
Copy link
Copy Markdown
Member

@shuheng-liu shuheng-liu commented Apr 23, 2026

What this does

Adds a new pi06 policy that ports Physical Intelligence's π0.6 architecture (Model Card, 2025-11-17; arXiv:2511.14759) into OpenTau, alongside supporting config, docs, and tests. Label: 🗃️ Feature.

What changed vs pi05

Aspect pi05 pi06
VLM backbone PaliGemma 3B (Gemma 2B text, 18 L) Gemma 3 4B (34 layers, 5:1 sliding/global ratio in pretraining)
Text hidden / heads / KV 2048 / 8 / 1 (MQA) 2560 / 8 / 4 (GQA), head_dim=256
Per-layer RoPE θ single θ=10 000 per-layer: θ=10 000 on local layers, θ=1 000 000 on global layers, applied uniformly to backbone and expert at the same layer (necessary for shared-attention dot products to stay in a consistent RoPE basis)
Per-layer attention mask block-causal prefix-LM same block-causal prefix-LM at every layer — Gemma 3's 1024-token sliding-window pattern is deliberately not enforced (see "Architectural choices" below)
Q/K layernorm none Gemma 3's q_norm / k_norm per attention block
Action expert 18 L Gemma, hidden 1024, ~300M 34 L Gemma-v1, hidden 1280, GQA-4, AdaRMS, ~860M
Image resolution 224×224 padded 448×448 padded, vision tower configured with image_size=448 so Gemma3MultiModalProjector reshapes correctly (32 patches/side → 256 mm tokens/view)
Flow-matching default num_steps=10 num_steps=5 (~63 ms / chunk on H100)
FAST + KI + Beta time unchanged unchanged

Training recipe (FAST discrete actions co-trained with flow matching, Knowledge Insulation gradient stop, Beta(1.5, 1.0) time sampler, block-causal prefix + bidirectional action suffix) is identical to pi05.

Architectural choices that aren't obvious from the model card

These were each surfaced and resolved during careful re-review while writing this PR — flagging them explicitly so future reviewers understand the reasoning:

  • Text-embedding scale. Gemma 3's embed_tokens is a Gemma3TextScaledWordEmbedding that already multiplies by √hidden_size internally. The pi05-style manual * math.sqrt(lang_emb_dim) would double-scale text tokens to ~51× the image-token magnitude — corrupting both the bidirectional prefix attention and the FAST/response cross-entropy heads. Removed (was #179, merged into this branch).
  • Sliding-window attention deliberately disabled. The model card says "bidirectional attention among all of the image tokens" — incompatible with Gemma 3's 1024-token window once you have 4 cameras × 256 image tokens. Every layer in Gemma3WithExpertModel.forward now receives the unmodified block-causal mask; local layers still rotate at θ=10 000 (preserving the pretrained RoPE basis), but their attention pattern matches the global layers'.
  • Per-layer RoPE θ shared across both streams. The shared-attention q · R(Δp) · k invariant only holds when both rotations use the same θ. Backbone Q/K rotates at the layer's θ; expert Q/K is forced to use the same value (its config's own rope_theta is documented as ignored at runtime).
  • Vision tower image_size matches input resolution. Gemma3MultiModalProjector hardcodes patches_per_image = image_size // patch_size, so a default image_size=896 would crash the projector reshape on the first 448×448 forward pass.

Implementation notes

  • gemma3_with_expert.py — runs a per-layer interleaved attention loop that concatenates backbone and expert Q/K/V along the sequence axis, honouring Gemma 3's q_norm / k_norm, per-layer RoPE θ, and the four-RMSNorm Gemma 3 block (input_layernorm, post_attention_layernorm, pre_feedforward_layernorm, post_feedforward_layernorm). The Gemma-v1 AdaRMS / _gated_residual / patched GemmaRMSNorm monkey-patches in opentau.utils.transformers_patch cover the expert path; the Gemma 3 backbone runs stock, so no new patches are introduced.
  • modeling_pi06.py — mirrors modeling_pi05.py structure with the Gemma 3 tokenizer (google/gemma-3-4b-pt) and a _fix_pytorch_state_dict_keys that also accepts legacy paligemma_with_expert.* prefixes as a warm-start path for users converting pi05 checkpoints.
  • Factory + READMEpi06 registered in opentau.policies.factory and opentau.available_policies; README updated with a pi06 bullet, comparison-table row, checkpoints-coming-soon placeholder, and a pointer at configs/examples/pi06_training_config.json.

How it was tested

tests/policies/test_pi06.py ships 26 CPU-only unit tests (run on each push):

  • Block-causal attention-mask semantics (pure bidirectional / pure causal / prefix-LM / padding rows+cols / cross-attention prepend).
  • apply_rope shape / dtype preservation and θ sensitivity (zero-position identity).
  • PI06Config defaults and validators.
  • Gemma3WithExpertConfig topology — Gemma 3 4B hidden/layer/head counts, 5:1 layer-type pattern, ~860M expert with AdaRMS on, GQA matched.
  • test_vision_image_size_matches_input_resolution + test_projector_accepts_448_inputs — config invariant + end-to-end Gemma3MultiModalProjector forward at 448×448.
  • TestRopeThetaSymmetryDuringForward.test_expert_uses_backbone_per_layer_theta — monkey-patches apply_rope to record per-call θ and asserts each layer uses the backbone's θ for both streams.
  • TestNoSlidingWindowEnforcement.test_per_layer_mask_equals_input_mask_on_both_layer_types — monkey-patches eager_attention_forward to verify both local and global layers receive the unmodified input mask.
  • resize_with_pad (448×448 default, aspect-ratio preservation) and FAST discrete-action padding/truncation.

Other:

  • tests/test_available.py updated so the available_policies invariant test picks up the new PI06Policy.
  • End-to-end GPU smoke test (test_complete_pi06_pipeline_integration_smoke) — marked @pytest.mark.slow + @pytest.mark.gpu, runs on the nightly GPU CI job.
  • Full local tests/policies/ CPU suite: 75 passed / 2 skipped / 6 deselected. Pre-commit: every hook green (ruff, ruff-format, pyupgrade, typos, gitleaks, bandit, license headers).

How to checkout & try? (for the reviewer)

git fetch origin claude/compare-pi-models-NgCLN
git checkout claude/compare-pi-models-NgCLN
pytest -sx tests/policies/test_pi06.py
# Point at a dataset and start a real training run:
opentau-train --config_path=configs/examples/pi06_training_config.json

Out of scope

  • No new monkey-patches in transformers_patch.py. The Gemma 3 backbone runs stock; the expert keeps using the already-patched Gemma v1.
  • No Flash-Attention-2 path — fa2 still raises NotImplementedError until we can validate numerics on the four-RMSNorm Gemma 3 block.
  • No π*0.6 RECAP / RL loop — that's a separate PR on pistar06.
  • No checkpoint conversion script from the (not-yet-released) official π0.6 weights; the _fix_pytorch_state_dict_keys hook is ready when they drop.

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

Note: this PR adds a new policy, but the GPU integration test is deliberately guarded behind @pytest.mark.gpu and will be validated by the nightly gpu_test.yml workflow. Leaving the policy checkbox unchecked here so the CI check-checklist job does not require me to attest to GPU tests I haven't run myself.

https://claude.ai/code/session_01MibvjbcZo38nxrx6n9giLi

claude added 2 commits April 23, 2026 07:47
Adds a new `pi06` policy that ports the π0.6 architecture from Physical
Intelligence (Model Card, Nov 17 2025; arXiv:2511.14759) into OpenTau.

Relative to `pi05` (PaliGemma-3B, 224×224, 10-step flow matching), `pi06`:
  * swaps the backbone to Gemma 3 4B (34 interleaved sliding-window/global
    layers, SigLIP-400m/14, head_dim 256, GQA with 4 KV heads);
  * enlarges the action expert to ~860M params so it matches the backbone
    depth (34 Gemma-v1 layers, hidden 1280, intermediate 5120, AdaRMS);
  * raises the default image resolution to 448×448;
  * halves the default flow-matching schedule to 5 denoising steps.

Training recipe (FAST discrete action co-training, flow matching, Knowledge
Insulation gradient-stop, Beta(1.5, 1.0) time sampler, block-causal prefix
with bidirectional action suffix) is unchanged from `pi05`.

Implementation notes:
  * `gemma3_with_expert.py` runs a per-layer interleaved attention loop that
    concatenates backbone and expert Q/K/V along the sequence axis, honouring
    Gemma 3's q_norm/k_norm, per-layer local vs global RoPE theta, and
    sliding-window masks. The existing Gemma-v1 AdaRMS monkey-patches in
    `transformers_patch.py` cover the expert; the Gemma 3 backbone runs
    stock (no new patches needed).
  * `modeling_pi06.py` mirrors `modeling_pi05.py` structure — new tokenizer
    (`google/gemma-3-4b-pt`), new `_fix_pytorch_state_dict_keys` that also
    accepts legacy `paligemma_with_expert.*` prefixes as a warm-start path.

Other changes:
  * Register `pi06` in `policies/factory.py` and `opentau.available_policies`.
  * Add `configs/examples/pi06_training_config.json` to bootstrap runs.
  * Update `README.md` with a pi06 bullet, comparison-table row, checkpoints
    placeholder, and a link to the new config.
  * Add `tests/policies/test_pi06.py` covering attention-mask block semantics,
    sliding-window masks, RoPE theta selection, padding-mask contiguity,
    image resizing, and discrete-action padding.

https://claude.ai/code/session_01MibvjbcZo38nxrx6n9giLi
…ant test

`tests/test_available.py::test_available_policies` hardcodes the list of
known policy classes and asserts its set of `name`s matches
`opentau.available_policies`. The previous commit added `pi06` to the
latter without updating the former, breaking the CPU CI run. Picking up
`PI06Policy` keeps the invariant satisfied.

https://claude.ai/code/session_01MibvjbcZo38nxrx6n9giLi
@shuheng-liu shuheng-liu self-assigned this Apr 23, 2026
@shuheng-liu shuheng-liu added bug Something isn't working feature New feature or request and removed bug Something isn't working labels Apr 23, 2026
shuheng-liu and others added 5 commits April 23, 2026 11:17
Per reviewer feedback. Drops the long `# ---` lines that bracketed
section headings (now just the heading as a single-line comment),
and collapses the `# --- text ---` inline variants to `# text.`.
No functional change — ruff-format, ruff-check, and the full
pre-commit hook battery (pyupgrade, typos, bandit, gitleaks, etc.)
all pass on the touched files.

https://claude.ai/code/session_01MibvjbcZo38nxrx6n9giLi
Fixes three distinct correctness issues in the Gemma 3 4B backbone +
Gemma-v1 action expert wiring that a careful second pass uncovered:

1. Vision projector crashes at 448×448. `Gemma3MultiModalProjector`
   hard-codes `patches_per_image = image_size // patch_size`, so the
   default config's `vision_config.image_size = 896` made the
   projector reshape `(B, 1024, 64, 64)` when SigLIP actually emits
   `B × 1024` patches at 448 input — a runtime crash on the first
   forward pass. Set `image_size = 448` to match π0.6's stated
   resolution (→ 32 patches/side, 256 mm tokens/image, which also
   matches the model card).

2. RoPE θ asymmetry between backbone and expert for global layers.
   Gemma 3 interleaves local layers (θ = 10 000) with global layers
   (θ = 1 000 000). The expert's own `rope_theta = 10 000` was being
   applied to its Q/K even on global layers, so the shared
   cross-attention ran backbone-Q (rotated at 1M) against expert-K
   (rotated at 10k) — the `q·R(Δpos)·k` invariant breaks when the two
   rotations live in different RoPE bases. Fix: both streams now use
   the backbone's per-layer θ; the expert's fallback θ is ignored at
   runtime and documented as such.

3. Sliding-window mask used dense indices instead of absolute
   positions. During expert cross-attention the Q tokens sit at
   `prefix_offsets + chunk_idx`, but `_build_sliding_window_mask`
   built its mask from `torch.arange(seq_len)` and the call site
   sliced `[:, :T_suffix, :]`. Result: every prefix key farther than
   `window` from the dense suffix row index was dropped — i.e. the
   sliding layers saw essentially no prefix during cross-attention.
   Fix: take `(query_positions, key_positions)` and compute
   `|pos_q - pos_k| < window` over absolute positions; stash the
   prefix positions in the KV cache so the expert can reconstruct
   them on each denoising step.

Tests:
  * `TestSlidingWindowMask.test_cross_attention_uses_absolute_positions`
    — regression guard for #3.
  * `TestGemma3WithExpertConfig.test_vision_image_size_matches_input_resolution`
    and `.test_projector_accepts_448_inputs` — regression guards
    for #1 (config invariant + end-to-end projector forward).
  * `TestRopeThetaSymmetryDuringForward.test_expert_uses_backbone_per_layer_theta`
    — regression guard for #2, uses monkeypatched `apply_rope` to
    observe exactly which θ each layer asks for.

Local pytest: 29 passed / 1 deselected (up from 25). Full policies/
CPU suite: 75 passed / 2 skipped / 6 deselected. Pre-commit: all
hooks green (ruff, ruff-format, pyupgrade, typos, bandit, gitleaks).

https://claude.ai/code/session_01MibvjbcZo38nxrx6n9giLi
…gnment)

The π0.6 model card specifies "bidirectional attention among ALL of the
image tokens" and a "block-wise causal" prefix attention pattern, with
no mention of sliding-window or local attention. With 4 cams × 256 image
tokens = 1024 image tokens, the wording "all" is incompatible with
Gemma 3's 1024-token sliding window. The most defensible reading is
that π0.6 deliberately runs every backbone layer with the global
block-causal mask; the local layers' pretrained weights are presumably
adapted via training rather than constrained at inference.

This commit:

  * Removes `_build_sliding_window_mask` (and the `is_sliding`-conditional
    AND in `forward`) — every layer now receives the unmodified prefix
    block-causal `attention_mask`.
  * Reverts the KV-cache `key_positions` field added for the sliding
    mask; cross-attention no longer needs absolute key positions because
    no per-layer mask is constructed from them.
  * Keeps the per-layer RoPE θ split — local layers still rotate at
    θ=10 000 and global layers at θ=1 000 000, because that's baked into
    the pretrained Gemma 3 4B weights and we have to honour it.
  * Replaces the `TestSlidingWindowMask` class with a regression test
    `TestNoSlidingWindowEnforcement` that monkey-patches
    `eager_attention_forward` to verify the per-layer mask equals the
    input mask on BOTH layer types — guarding against the window
    silently sneaking back in.

Local tests: 26 passed / 1 deselected. Pre-commit: all hooks green.

https://claude.ai/code/session_01MibvjbcZo38nxrx6n9giLi
…els-NgCLN

# Conflicts:
#	README.md
#	src/opentau/policies/factory.py
@shuheng-liu shuheng-liu marked this pull request as ready for review April 28, 2026 22:28
@akshay18iitg akshay18iitg self-requested a review April 29, 2026 00:25
Comment thread src/opentau/policies/pi06/gemma3_with_expert.py
@shuheng-liu shuheng-liu merged commit 9aa75f8 into main Apr 29, 2026
5 checks passed
@shuheng-liu shuheng-liu deleted the claude/compare-pi-models-NgCLN branch April 29, 2026 17:57
@claude claude Bot mentioned this pull request Apr 29, 2026
3 tasks
shuheng-liu added a commit that referenced this pull request Apr 30, 2026
`test_complete_pi06_pipeline_integration_smoke` was added in #178 with
chunk_size=10 but the shared `lerobot_dataset_metadata` fixture provides
actions stats shaped (50, 32) — matching the default PI06Config. The
Normalize buffer was therefore (50, 32) while the test batch's actions
were (B, 10, 32), and MIN_MAX normalization in normalize.py:232 raised
``RuntimeError: The size of tensor a (10) must match the size of tensor
b (50) at non-singleton dimension 1``.

Pre-existing bug — never caught in CI because the test is gated by
@pytest.mark.gpu and skipped in CPU runs. Surfaced now while validating
this PR's SDPA + grad-ckpt port on a real GPU.

Fix by deep-copying the fixture stats and reshaping the actions
max/mean/min/std arrays from (50, 32) to (chunk_size, 32) before
constructing the policy. Same numeric values, just the right shape.
shuheng-liu added a commit that referenced this pull request May 2, 2026
`test_complete_pi06_pipeline_integration_smoke` was added in #178 with
chunk_size=10 but the shared `lerobot_dataset_metadata` fixture provides
actions stats shaped (50, 32) — matching the default PI06Config. The
Normalize buffer was therefore (50, 32) while the test batch's actions
were (B, 10, 32), and MIN_MAX normalization in normalize.py:232 raised
``RuntimeError: The size of tensor a (10) must match the size of tensor
b (50) at non-singleton dimension 1``.

Pre-existing bug — never caught in CI because the test is gated by
@pytest.mark.gpu and skipped in CPU runs. Surfaced now while validating
this PR's SDPA + grad-ckpt port on a real GPU.

Fix by deep-copying the fixture stats and reshaping the actions
max/mean/min/std arrays from (50, 32) to (chunk_size, 32) before
constructing the policy. Same numeric values, just the right shape.

(cherry picked from commit 6425cb4)
shuheng-liu added a commit that referenced this pull request May 2, 2026
`test_complete_pi06_pipeline_integration_smoke` was added in #178 with
chunk_size=10 but the shared `lerobot_dataset_metadata` fixture provides
actions stats shaped (50, 32) — matching the default PI06Config. The
Normalize buffer was therefore (50, 32) while the test batch's actions
were (B, 10, 32), and MIN_MAX normalization in normalize.py:232 raised
``RuntimeError: The size of tensor a (10) must match the size of tensor
b (50) at non-singleton dimension 1``.

Pre-existing bug — never caught in CI because the test is gated by
@pytest.mark.gpu and skipped in CPU runs. Surfaced now while validating
this PR's SDPA + grad-ckpt port on a real GPU.

Fix by deep-copying the fixture stats and reshaping the actions
max/mean/min/std arrays from (50, 32) to (chunk_size, 32) before
constructing the policy. Same numeric values, just the right shape.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature New feature or request model new model or model request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants