feat(pi07): SDPA attention + gradient checkpointing options #233
Merged
Backports the pi05 PR #182 / pi06 PR #214 attention + checkpointing knobs to the pi07 engine and both planner policies (high-level and low-level). pi07_paligemma is intentionally untouched.

- src/opentau/policies/pi07/gemma3_with_expert.py:
  - attention_implementation now accepts "sdpa"; "fa2" warns and falls back to eager (was: NotImplementedError).
  - sdpa_attention_forward dispatches via F.scaled_dot_product_attention with optional scale= for Gemma 3's query_pre_attn_scalar.
  - Per-layer interleaved body extracted into _run_layer (bit-identical to the inlined loop) so it can be the unit of torch.utils.checkpoint.checkpoint(use_reentrant=False) when config.gradient_checkpointing=True.
  - Hoists past_key_values={} init out of the loop for checkpoint recompute idempotency.
- pi07/high_level_planner/configuration_pi07_high_level.py and pi07/low_level_planner/configuration_pi07_low_level.py:
  - High-level: new gradient_checkpointing field.
  - Low-level: existing flag now applies to BOTH the SpaceTimeSiglip video encoder AND the engine (combined semantics, mirrors PR #214).
  - __post_init__ plumbs policy-level attention_implementation + gradient_checkpointing into vlm_config so a single --policy.* CLI override reaches the engine.
  - Direct --policy.vlm_config.* overrides still work when the policy-level field is at its default.
- tests/policies/test_pi07_cpu.py: adds 11 tests across 5 classes:
  - TestGemma3WithExpertConfig (3): bad/sdpa/fa2 validation.
  - TestPi07AttentionDispatcher (3): dispatch routing for each impl.
  - TestPi07SdpaEquivalence (1): eager vs sdpa within atol/rtol 1e-4 in fp32, drives BOTH streams to cover the q_concat/k_concat/v_concat path (pi06 PR #214 only covered backbone-only).
  - TestPi07GradCkptEquivalence (1): grad-ckpt forward bit-identical (atol/rtol 1e-6) in train mode, dropout=0, both streams.
  - TestPi07ConfigPlumbing (3): __post_init__ propagation invariants.
The train.py:255 ZeRO-3/FSDP guard keys off cfg.policy.gradient_checkpointing — policy-agnostic, so pi07 inherits it automatically once the field exists at the policy level. Defaults stay attention_implementation="eager" and gradient_checkpointing=False — matches the opt-in rollout pattern from PRs #182/#214.
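The policy-to-engine plumbing described above can be sketched with plain dataclasses. EngineConfig and PolicyConfig are hypothetical stand-ins, not the real PI07 config classes; the sketch assumes the gating works as the description says, with policy-level values pushed into vlm_config only when they are non-default.

```python
from dataclasses import dataclass, field


@dataclass
class EngineConfig:
    # Hypothetical stand-in for the engine-level vlm_config.
    attention_implementation: str = "eager"
    gradient_checkpointing: bool = False


@dataclass
class PolicyConfig:
    # Hypothetical stand-in for a planner policy config.
    attention_implementation: str = "eager"
    gradient_checkpointing: bool = False
    vlm_config: EngineConfig = field(default_factory=EngineConfig)

    def __post_init__(self):
        # Push policy-level values down only when they are non-default, so a
        # direct vlm_config override survives an untouched policy-level field.
        if self.attention_implementation != "eager":
            self.vlm_config.attention_implementation = self.attention_implementation
        if self.gradient_checkpointing:
            self.vlm_config.gradient_checkpointing = True


# A single policy-level override reaches the engine config.
via_policy = PolicyConfig(attention_implementation="sdpa", gradient_checkpointing=True)

# A direct engine-level override is preserved when the policy field is default.
direct = PolicyConfig(vlm_config=EngineConfig(gradient_checkpointing=True))
```

Under this gating, a guard that reads only the policy-level gradient_checkpointing field would not see the direct engine-level override in the second case.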
What this does
Backports the pi05 PR #182 / pi06 PR #214 attention + checkpointing knobs to the pi07 engine (src/opentau/policies/pi07/gemma3_with_expert.py) and both planner policies (high-level + low-level). pi07_paligemma is intentionally untouched.

Engine (pi07/gemma3_with_expert.py):
- attention_implementation now accepts "sdpa"; "fa2" warns and falls back to eager (was: NotImplementedError).
- sdpa_attention_forward dispatches via F.scaled_dot_product_attention with optional scale= for Gemma 3's query_pre_attn_scalar.
- Per-layer interleaved body extracted into _run_layer (bit-identical to the original inlined loop) so it can be the unit of torch.utils.checkpoint.checkpoint(use_reentrant=False) when config.gradient_checkpointing=True.
- Hoists the past_key_values={} initialization out of the loop so checkpoint recompute is idempotent (saved-tensor hooks would otherwise see a different argument identity on the second pass).

High-level + low-level configs:
- High-level (PI07HighLevelPlannerConfig): new gradient_checkpointing field.
- Low-level (PI07LowLevelPlannerConfig): the existing gradient_checkpointing field now applies to both the SpaceTimeSiglip video encoder and the engine (combined semantics, mirrors PR #214's "one flag per policy" pattern). Behavior change for existing low-level users with the flag set: they additionally get engine checkpointing. The training math is identical (the only difference is extra recompute), and it saves more memory than they currently get. Documented in the docstring.
- __post_init__ plumbs policy-level attention_implementation and gradient_checkpointing into vlm_config so a single --policy.attention_implementation / --policy.gradient_checkpointing CLI override reaches the engine. The previously dead policy-level attention_implementation field is now wired up.
- Direct --policy.vlm_config.* overrides still work when the policy-level field is at its default (the plumbing is gated on the policy-level field being non-default).

The train.py:255 ZeRO-3/FSDP guard keys off cfg.policy.gradient_checkpointing. It is policy-agnostic, so pi07 inherits it automatically once the field exists at the policy level.

Defaults stay attention_implementation="eager" and gradient_checkpointing=False, matching PRs #182/#214's opt-in rollout.

How it was tested
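As a standalone illustration of the checkpoint-equivalence idea used in this PR, the sketch below wraps a per-layer function in torch.utils.checkpoint.checkpoint(use_reentrant=False) and compares outputs and input gradients against the plain loop. TinyStack and its residual tanh layer body are hypothetical; the real _run_layer interleaves two streams and is not reproduced here.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TinyStack(nn.Module):
    # Hypothetical toy model; only the per-layer checkpointing pattern matters.
    def __init__(self, dim=16, depth=3, gradient_checkpointing=False):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.gradient_checkpointing = gradient_checkpointing

    def _run_layer(self, layer, hidden):
        # One layer's body, factored out so it can be the checkpointed unit.
        return hidden + torch.tanh(layer(hidden))

    def forward(self, hidden):
        for layer in self.layers:
            if self.gradient_checkpointing:
                # Activations inside _run_layer are recomputed during backward.
                hidden = checkpoint(self._run_layer, layer, hidden, use_reentrant=False)
            else:
                hidden = self._run_layer(layer, hidden)
        return hidden


torch.manual_seed(0)
plain = TinyStack().train()
ckpt = TinyStack(gradient_checkpointing=True).train()
ckpt.load_state_dict(plain.state_dict())  # identical weights in both models

x = torch.randn(4, 16)
x_plain = x.clone().requires_grad_(True)
x_ckpt = x.clone().requires_grad_(True)

out_plain = plain(x_plain)
out_ckpt = ckpt(x_ckpt)
out_plain.sum().backward()
out_ckpt.sum().backward()
```

Because the checkpointed function has no dropout or other randomness, the recomputed forward repeats the same operations, which is why a tight tolerance is reasonable for this kind of check.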
CPU equivalence (added in tests/policies/test_pi07_cpu.py)

11 new tests across 5 classes:
- TestGemma3WithExpertConfig (3): bad/sdpa/fa2 validation.
- TestPi07AttentionDispatcher (3): dispatcher routing for each implementation; verifies fa2 falls back to eager (no NotImplementedError).
- TestPi07SdpaEquivalence (1): eager vs sdpa within atol/rtol 1e-4 in fp32. Drives both streams (inputs_embeds=[hidden_backbone, hidden_expert] with adarms_cond=[None, cond]) so the q_concat/k_concat/v_concat cross-stream concat path is the actual codepath under test; pi06 PR #214 only covered backbone-only.
- TestPi07GradCkptEquivalence (1): grad-ckpt forward bit-identical (atol/rtol 1e-6) in train mode, dropout=0, both streams. Pins that the _run_layer extraction is bit-identical to the inlined loop body.
- TestPi07ConfigPlumbing (3): __post_init__ propagation invariants: gradient_checkpointing=True and attention_implementation="sdpa" on the policy reach vlm_config, and direct vlm_config overrides survive when policy-level fields are at default.

GPU verification (mlbox, RTX 5090 32 GB)
Pi07 GPU pytest subset on the new branch:

Ad-hoc GPU smoke test (a 4-layer tiny config, small enough to leave room for a forward + backward pass) verifies the actual SDPA + grad-ckpt code paths end-to-end on real CUDA + bf16:

max-abs-diff backbone=1.80e-01 expert=3.91e-03 (within bf16 reassociation noise; the unit tests assert tighter tolerances in fp32).

How to checkout & try? (for the reviewer)
CPU equivalence:

uv run pytest -m "not gpu" tests/policies/test_pi07_cpu.py -v

GPU pytest subset:

uv run pytest -m "gpu" -n 0 tests/policies/test_pi07_cpu.py tests/policies/test_pi07_high_level_planner.py tests/policies/test_pi07_low_level_planner.py

GPU smoke training (single rank, ~10 steps; pick whichever lerobot dataset you have cached):
Checklist