Skip to content

feat(pi07): SDPA attention + gradient checkpointing options#233

Merged
shuheng-liu merged 1 commit into
feat/pi07from
claude/sdpa-ckpt-pi07
May 2, 2026
Merged

feat(pi07): SDPA attention + gradient checkpointing options#233
shuheng-liu merged 1 commit into
feat/pi07from
claude/sdpa-ckpt-pi07

Conversation

@shuheng-liu
Copy link
Copy Markdown
Member

@shuheng-liu shuheng-liu commented May 2, 2026

What this does

Backports the pi05 PR #182 / pi06 PR #214 attention + checkpointing knobs to the pi07 engine (src/opentau/policies/pi07/gemma3_with_expert.py) and both planner policies (high-level + low-level). pi07_paligemma is intentionally untouched.

Enginepi07/gemma3_with_expert.py:

  • attention_implementation now accepts "sdpa"; "fa2" warns and falls back to eager (was: NotImplementedError).
  • New sdpa_attention_forward dispatches via F.scaled_dot_product_attention with optional scale= for Gemma 3's query_pre_attn_scalar.
  • The interleaved per-layer body is extracted into _run_layer (bit-identical to the original inlined loop) so it can be the unit of torch.utils.checkpoint.checkpoint(use_reentrant=False) when config.gradient_checkpointing=True.
  • Hoists the lazy past_key_values={} initialization out of the loop so checkpoint recompute is idempotent (saved-tensor hooks would otherwise see a different argument identity on the second pass).

High-level + low-level configs:

  • High-level (PI07HighLevelPlannerConfig): new gradient_checkpointing field.
  • Low-level (PI07LowLevelPlannerConfig): the existing gradient_checkpointing field now applies to both the SpaceTimeSiglip video encoder and the engine (combined semantics, mirrors PR feat(pi0, pi06): add SDPA attention + gradient checkpointing options #214's "one flag per policy" pattern). Behavior change for existing low-level users with the flag set: they additionally get engine ckpt — strictly safer numerically (extra recompute, identical training math), saves more memory than they currently get. Documented in the docstring.
  • Both: __post_init__ plumbs policy-level attention_implementation and gradient_checkpointing into vlm_config so a single --policy.attention_implementation / --policy.gradient_checkpointing CLI override reaches the engine. The previously-dead policy-level attention_implementation field is now wired up.
  • Direct --policy.vlm_config.* overrides still work when the policy-level field is at its default (the plumbing is gated on the policy-level field being non-default).

train.py:255 ZeRO-3/FSDP guard keys off cfg.policy.gradient_checkpointing — policy-agnostic, so pi07 inherits it automatically once the field exists at the policy level.

Defaults stay attention_implementation="eager" and gradient_checkpointing=False — matches PRs #182/#214's opt-in rollout.

How it was tested

CPU equivalence (added in tests/policies/test_pi07_cpu.py)

11 new tests across 5 classes:

  • TestGemma3WithExpertConfig (3): bad/sdpa/fa2 validation.
  • TestPi07AttentionDispatcher (3): dispatcher routing for each implementation; verifies fa2 falls back to eager (no NotImplementedError).
  • TestPi07SdpaEquivalence (1): eager vs sdpa within atol/rtol 1e-4 in fp32. Drives both streams (inputs_embeds=[hidden_backbone, hidden_expert] with adarms_cond=[None, cond]) so the q_concat/k_concat/v_concat cross-stream concat path is the actual codepath under test — pi06 PR feat(pi0, pi06): add SDPA attention + gradient checkpointing options #214 only covered backbone-only.
  • TestPi07GradCkptEquivalence (1): grad-ckpt forward bit-identical (atol/rtol 1e-6) in train mode, dropout=0, both streams. Pins that _run_layer extraction is bit-identical to the inlined loop body.
  • TestPi07ConfigPlumbing (3): __post_init__ propagation invariants — gradient_checkpointing=True and attention_implementation="sdpa" on the policy reach vlm_config, and direct vlm_config overrides survive when policy-level fields are at default.
uv run pytest -m "not gpu" tests/policies/test_pi07_cpu.py -v
# 60 passed (49 existing pi07 CPU tests + 11 new)

uv run pytest -m "not gpu" -n auto tests/policies tests/configs --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py
# 194 passed, 2 skipped — no regressions in adjacent policies/configs.

GPU verification (mlbox, RTX 5090 32 GB)

Pi07 GPU pytest subset on the new branch:

uv run pytest -m "gpu" -n 0 tests/policies/test_pi07_cpu.py tests/policies/test_pi07_high_level_planner.py tests/policies/test_pi07_low_level_planner.py tests/policies/test_pi07_video_encoder_cpu.py -v
# 7 passed, 49 deselected — all 7 pi07 GPU integration tests still pass with the refactored engine.

Ad-hoc GPU smoke (4-layer tiny config so it fits with budget for a forward + backward) verifies the actual SDPA + grad-ckpt code paths end-to-end on real CUDA + bf16:

  • eager vs sdpa in eval mode: max-abs-diff backbone=1.80e-01 expert=3.91e-03 (within bf16 reassociation noise; the unit tests assert tighter tolerances in fp32).
  • sdpa + grad-ckpt in train mode: forward + backward complete without OOM, gradients flow on 99/143 trainable parameters (the 44 without grads are da_head/discrete_action_embedding/vision tower paths that the synthetic batch does not exercise — expected).

How to checkout & try? (for the reviewer)

CPU equivalence:

uv run pytest -m "not gpu" tests/policies/test_pi07_cpu.py -v

GPU pytest subset:

uv run pytest -m "gpu" -n 0 tests/policies/test_pi07_cpu.py tests/policies/test_pi07_high_level_planner.py tests/policies/test_pi07_low_level_planner.py

GPU smoke training (single rank, ~10 steps; pick whichever lerobot dataset you have cached):

uv run python -m opentau.scripts.train \
  --policy.type=pi07_low_level \
  --policy.gradient_checkpointing=true \
  --policy.attention_implementation=sdpa \
  --batch_size=1 \
  --steps=10 \
  --dataset.repo_id=lerobot/<cached-dataset> \
  --output_dir=/tmp/pi07_low_level_smoke \
  --wandb.enable=false

uv run python -m opentau.scripts.train \
  --policy.type=pi07_high_level \
  --policy.gradient_checkpointing=true \
  --policy.attention_implementation=sdpa \
  --batch_size=1 \
  --steps=10 \
  --dataset.repo_id=lerobot/<cached-dataset> \
  --output_dir=/tmp/pi07_high_level_smoke \
  --wandb.enable=false

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

Backports the pi05 PR #182 / pi06 PR #214 attention + checkpointing
knobs to the pi07 engine and both planner policies (high-level and
low-level). pi07_paligemma is intentionally untouched.

- src/opentau/policies/pi07/gemma3_with_expert.py:
  - attention_implementation now accepts "sdpa"; "fa2" warns and falls
    back to eager (was: NotImplementedError).
  - sdpa_attention_forward dispatches via F.scaled_dot_product_attention
    with optional scale= for Gemma 3's query_pre_attn_scalar.
  - Per-layer interleaved body extracted into _run_layer (bit-identical
    to the inlined loop) so it can be the unit of
    torch.utils.checkpoint.checkpoint(use_reentrant=False) when
    config.gradient_checkpointing=True.
  - Hoists past_key_values={} init out of the loop for checkpoint
    recompute idempotency.

- pi07/high_level_planner/configuration_pi07_high_level.py and
  pi07/low_level_planner/configuration_pi07_low_level.py:
  - High-level: new gradient_checkpointing field. Low-level: existing
    flag now applies to BOTH the SpaceTimeSiglip video encoder AND the
    engine (combined semantics, mirrors PR #214). __post_init__ plumbs
    policy-level attention_implementation + gradient_checkpointing into
    vlm_config so a single --policy.* CLI override reaches the engine.
  - Direct --policy.vlm_config.* overrides still work when the
    policy-level field is at its default.

- tests/policies/test_pi07_cpu.py: adds 11 tests across 5 classes:
  - TestGemma3WithExpertConfig (3): bad/sdpa/fa2 validation.
  - TestPi07AttentionDispatcher (3): dispatch routing for each impl.
  - TestPi07SdpaEquivalence (1): eager vs sdpa within atol/rtol 1e-4 in
    fp32, drives BOTH streams to cover the q_concat/k_concat/v_concat
    path (pi06 PR #214 only covered backbone-only).
  - TestPi07GradCkptEquivalence (1): grad-ckpt forward bit-identical
    (atol/rtol 1e-6) in train mode dropout=0, both streams.
  - TestPi07ConfigPlumbing (3): __post_init__ propagation invariants.

The train.py:255 ZeRO-3/FSDP guard keys off cfg.policy.gradient_checkpointing
— policy-agnostic, so pi07 inherits it automatically once the field
exists at the policy level.

Defaults stay attention_implementation="eager" and
gradient_checkpointing=False — matches the opt-in rollout pattern from
PRs #182/#214.
@shuheng-liu shuheng-liu added the feature New feature or request label May 2, 2026
@shuheng-liu shuheng-liu self-assigned this May 2, 2026
@shuheng-liu shuheng-liu added the feature New feature or request label May 2, 2026
@shuheng-liu shuheng-liu marked this pull request as ready for review May 2, 2026 08:10
@shuheng-liu shuheng-liu merged commit 6f256ec into feat/pi07 May 2, 2026
5 of 7 checks passed
@shuheng-liu shuheng-liu deleted the claude/sdpa-ckpt-pi07 branch May 2, 2026 08:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant