fix(pi06): stop double-scaling text embeddings (PR #178 correctness bug) #179

Merged
shuheng-liu merged 1 commit into claude/compare-pi-models-NgCLN from claude/fix-model-bug-pr178-8xPCI
Apr 23, 2026

shuheng-liu (Member) commented Apr 23, 2026

What this does

Stacks on top of #178. Fixes a correctness bug where every text embedding fed into the π0.6 backbone was multiplied by sqrt(hidden_size) twice, leaving prompt / response / FAST token embeddings ≈51× larger than intended and out of scale with the image tokens they sit next to in the prefix. Label: 🐛 Bug.

Root cause

PI06FlowMatching.embed_prefix and .infer_response copied pi05's manual normalizer:

lang_emb = self.gemma3_with_expert.embed_language_tokens(lang_tokens)
lang_emb_dim = lang_emb.shape[-1]
lang_emb = lang_emb * math.sqrt(lang_emb_dim)   # bug on pi06: lang_emb is already scaled

That is correct for pi05 because PaliGemma's Gemma-v1 embed_tokens is a plain nn.Embedding whose hidden_size**0.5 normalizer is applied later in the stock forward that the policy bypasses — so pi05 reproduces it by hand.

Gemma 3 is different. transformers/models/gemma3/modeling_gemma3.py:134-144 folds the normalizer into the embedding layer itself:

class Gemma3TextScaledWordEmbedding(nn.Embedding):
    def __init__(self, ..., embed_scale: float = 1.0):
        super().__init__(...)
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)

It is instantiated at lines 538–539 with embed_scale=hidden_size**0.5, so embed_language_tokens(ids) already returns E[ids] * sqrt(hidden_size). The manual multiplier applies the same scale a second time, giving E[ids] * hidden_size: about 2560× the raw embedding for π0.6's default hidden size, or sqrt(2560) ≈ 51× the intended value.
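
The two figures quoted in this PR (≈51× and about 2560×) come from the same arithmetic and can be checked directly. This is a minimal sketch that assumes hidden_size = 2560, the value implied by the 2560× figure; the variable names are illustrative and do not come from modeling_pi06.py.

```python
import math

# Assumption: 2560 is the hidden size implied by the "about 2560x" figure.
hidden_size = 2560

# Scale that Gemma 3's embedding layer already applies inside its forward.
intended_scale = math.sqrt(hidden_size)

# Scale after the extra manual multiply by sqrt(hidden_size) at the call site.
buggy_scale = intended_scale * math.sqrt(hidden_size)

# How much larger the double-scaled embeddings are than the intended ones.
inflation = buggy_scale / intended_scale

print(f"intended scale vs raw table: {intended_scale:.1f}")
print(f"buggy scale vs raw table:    {buggy_scale:.1f}")
print(f"buggy vs intended:           {inflation:.1f}x")
```

The 2560× figure is measured against the raw embedding table, while the ≈51× figure is measured against the correctly scaled embedding.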

Blast radius

Because the bug is in the prefix assembly, the inflated embeddings propagate through:

  1. the bidirectional prefix attention (image vs. text tokens now live on completely different scales),
  2. the KV cache that the action expert cross-attends to at every layer,
  3. the Gemma 3 lm_head logits used for response / FAST discrete-action CE.

In practice this breaks training convergence and inference fidelity for both the flow-matching action head and the supervised discrete-action / response co-training.

Fix

Remove the three manual * math.sqrt(..._dim) multiplies in modeling_pi06.py (prompt tokens in embed_prefix, response tokens in embed_prefix, response token in infer_response). Image embeddings from get_image_features are correctly left alone — Gemma 3 does not scale those internally, so the prefix is now consistent.
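
The difference between the pi05 and pi06 call sites can be illustrated with a framework-free toy. This is only a sketch under stated assumptions: PlainEmbedding, ScaledEmbedding, with_manual_scale, and without_manual_scale are hypothetical stand-ins written for this example, not the classes or functions in transformers or modeling_pi06.py, and the hidden size of 4 is chosen so the numbers come out exact.

```python
import math


class PlainEmbedding:
    """Toy stand-in for a plain embedding table: lookup only, no scaling."""

    def __init__(self, table):
        self.table = table

    def __call__(self, token_id):
        return list(self.table[token_id])


class ScaledEmbedding(PlainEmbedding):
    """Toy stand-in for an embedding that scales inside its own forward."""

    def __init__(self, table, embed_scale):
        super().__init__(table)
        self.embed_scale = embed_scale

    def __call__(self, token_id):
        return [x * self.embed_scale for x in super().__call__(token_id)]


def with_manual_scale(embed, token_id):
    """Pre-fix call-site pattern: look up, then multiply by sqrt(dim)."""
    emb = embed(token_id)
    return [x * math.sqrt(len(emb)) for x in emb]


def without_manual_scale(embed, token_id):
    """Post-fix call-site pattern: rely on the embedding layer's own scale."""
    return embed(token_id)


def rms(vector):
    """Root-mean-square magnitude of one embedding vector."""
    return math.sqrt(sum(x * x for x in vector) / len(vector))


table = [[0.5, -0.5, 0.5, -0.5]]  # one token, hidden size 4, raw RMS 0.5
plain = PlainEmbedding(table)
scaled = ScaledEmbedding(table, embed_scale=math.sqrt(4))

pi05_style = rms(with_manual_scale(plain, 0))      # scaled once, by hand
pi06_buggy = rms(with_manual_scale(scaled, 0))     # scaled twice
pi06_fixed = rms(without_manual_scale(scaled, 0))  # scaled once, by the layer

print(pi05_style, pi06_buggy, pi06_fixed)  # prints: 1.0 2.0 1.0
```

In this toy the manual multiply is needed on the plain table and harmful on the scaled one, which is why the same three lines are correct in pi05 and wrong in pi06.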

How it was tested

  • Read the installed transformers/models/gemma3/modeling_gemma3.py to confirm Gemma3TextScaledWordEmbedding.forward multiplies by embed_scale = hidden_size**0.5, and transformers/models/gemma/modeling_gemma.py to confirm the Gemma-v1 embed_tokens is a plain nn.Embedding (so pi05's manual scaling there is still correct).
  • Existing tests/policies/test_pi06.py CPU tests are unaffected — none of them assert on embedding magnitudes; the mask / RoPE / resize / config tests are the regression surface and continue to pass.
  • End-to-end GPU training/inference validation inherits the @pytest.mark.gpu smoke test from #178 and will be exercised by the nightly gpu_test.yml workflow once #178 lands. I did not run pytest -m "gpu" locally.

How to checkout & try? (for the reviewer)

git fetch origin claude/fix-model-bug-pr178-8xPCI
git checkout claude/fix-model-bug-pr178-8xPCI
pytest -sx tests/policies/test_pi06.py

Out of scope

  • Other pi06 architectural concerns (per-layer RoPE θ asymmetry between backbone and expert for global layers, Gemma 3 vision tower image_size=896 vs. resize_imgs_with_padding=(448, 448)) are deliberately left for follow-up discussion — they warrant their own review and may not be bugs.

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

Note: This PR touches a policy file but only deletes three numerically incorrect lines; it introduces no new functions, signatures, or code paths. Leaving the policy checkbox unchecked so the check-checklist job does not require me to attest to GPU tests I did not run — same rationale as #178.

https://claude.ai/code/session_01YaPeoLsMfaLu3NojeFnWY1

Gemma 3's text embedding layer is a `Gemma3TextScaledWordEmbedding`
which already multiplies its output by `sqrt(hidden_size)` inside its
`forward`. The three call sites that previously mirrored pi05's manual
`lang_emb * math.sqrt(lang_emb_dim)` were therefore applying the scale
twice, leaving every language / response embedding ≈51× too large and
mismatched against the (correctly unscaled) image features. This
corrupted the block-bidirectional prefix attention, the KV cache the
action expert cross-attends to, and the Gemma 3 `lm_head` logits for
response / FAST tokens.

The pi05 copy was correct because PaliGemma's Gemma-v1 `embed_tokens`
is a plain `nn.Embedding` whose `hidden_size**0.5` normalizer is
applied later in the stock forward that the policy bypasses — so the
manual scale there reproduces the stock behavior. Gemma 3 folds that
normalizer into the embedding itself, so the manual scale has to go.

https://claude.ai/code/session_01YaPeoLsMfaLu3NojeFnWY1
shuheng-liu self-assigned this Apr 23, 2026
shuheng-liu added the bug (Something isn't working) label Apr 23, 2026
shuheng-liu marked this pull request as ready for review April 23, 2026 18:17
shuheng-liu merged commit 23376bd into claude/compare-pi-models-NgCLN Apr 23, 2026
5 of 6 checks passed
shuheng-liu deleted the claude/fix-model-bug-pr178-8xPCI branch April 23, 2026 18:17