fix(pi06): stop double-scaling text embeddings (PR #178 correctness bug)#179
Merged
shuheng-liu merged 1 commit intoApr 23, 2026
Conversation
Gemma 3's text embedding layer is a `Gemma3TextScaledWordEmbedding` which already multiplies its output by `sqrt(hidden_size)` inside its `forward`. The three call sites that previously mirrored pi05's manual `lang_emb * math.sqrt(lang_emb_dim)` were therefore applying the scale twice, leaving every language / response embedding ≈51× too large and mismatched against the (correctly unscaled) image features. This corrupted the block-bidirectional prefix attention, the KV cache the action expert cross-attends to, and the Gemma 3 `lm_head` logits for response / FAST tokens. The pi05 copy was correct because PaliGemma's Gemma-v1 `embed_tokens` is a plain `nn.Embedding` whose `hidden_size**0.5` normalizer is applied later in the stock forward that the policy bypasses — so the manual scale there reproduces the stock behavior. Gemma 3 folds that normalizer into the embedding itself, so the manual scale has to go. https://claude.ai/code/session_01YaPeoLsMfaLu3NojeFnWY1
23376bd
into
claude/compare-pi-models-NgCLN
5 of 6 checks passed
3 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What this does
Stacks on top of #178. Fixes a correctness bug where every text embedding fed into the π0.6 backbone was being multiplied by
sqrt(hidden_size)twice, leaving prompt / response / FAST tokens ≈51× larger than the image tokens they sit next to in the prefix. Label: 🐛 Bug.Root cause
PI06FlowMatching.embed_prefixand.infer_responsecopied pi05's manual normalizer:That is correct for pi05 because PaliGemma's Gemma-v1
embed_tokensis a plainnn.Embeddingwhosehidden_size**0.5normalizer is applied later in the stockforwardthat the policy bypasses — so pi05 reproduces it by hand.Gemma 3 is different.
transformers/models/gemma3/modeling_gemma3.py:134-144folds the normalizer into the embedding layer itself:Instantiated at line 538–539 with
embed_scale=hidden_size**0.5. Soembed_language_tokens(ids)already returnsE[ids] * sqrt(hidden_size), and the manual multiplier doubles it toE[ids] * hidden_size— about 2560× for π0.6's default hidden size.Blast radius
Because the bug is in the prefix assembly, the inflated embeddings propagate through:
lm_headlogits used for response / FAST discrete-action CE.In practice this breaks training convergence and inference fidelity for both the flow-matching action head and the supervised discrete-action / response co-training.
Fix
Remove the three manual
* math.sqrt(..._dim)multiplies inmodeling_pi06.py(prompt tokens inembed_prefix, response tokens inembed_prefix, response token ininfer_response). Image embeddings fromget_image_featuresare correctly left alone — Gemma 3 does not scale those internally, so the prefix is now consistent.How it was tested
transformers/models/gemma3/modeling_gemma3.pyto confirmGemma3TextScaledWordEmbedding.forwardmultiplies byembed_scale = hidden_size**0.5, andtransformers/models/gemma/modeling_gemma.pyto confirm the Gemma-v1embed_tokensis a plainnn.Embedding(so pi05's manual scaling there is still correct).tests/policies/test_pi06.pyCPU tests are unaffected — none of them assert on embedding magnitudes; the mask / RoPE / resize / config tests are the regression surface and continue to pass.@pytest.mark.gpusmoke test from feat(policies): add authentic pi06 policy with Gemma 3 4B backbone #178 and will be exercised by the nightlygpu_test.ymlworkflow once feat(policies): add authentic pi06 policy with Gemma 3 4B backbone #178 lands. I did not runpytest -m "gpu"locally.How to checkout & try? (for the reviewer)
Out of scope
image_size=896vs.resize_imgs_with_padding=(448, 448)) are deliberately left for follow-up discussion — they warrant their own review and may not be bugs.Checklist
https://claude.ai/code/session_01YaPeoLsMfaLu3NojeFnWY1