Skip to content

cp: fix(gemma4_moe): vision-aware mask when use_bidirectional_attention==vision (1905) into r0.4.0#1907

Merged
akoumpa merged 1 commit intor0.4.0from
cherry-pick-1905-r0.4.0
Apr 19, 2026
Merged

cp: fix(gemma4_moe): vision-aware mask when use_bidirectional_attention==vision (1905) into r0.4.0#1907
akoumpa merged 1 commit intor0.4.0from
cherry-pick-1905-r0.4.0

Conversation

@svcnvidia-nemo-ci
Copy link
Copy Markdown
Contributor

beep boop [🤖]: Hi @jQizhang 👋,

we've cherry picked #1905 into  for you! 🚀

Please review and approve this cherry pick by your convenience!

…vision (#1905)

fix(gemma4_moe): use vision-aware attention mask when use_bidirectional_attention="vision"

Gemma4 multimodal variants with `use_bidirectional_attention="vision"` in
their text config (e.g. gemma-4-26B-A4B-it, gemma-4-31B-it) require a
vision-aware attention mask that makes tokens inside the same vision group
visible to each other bidirectionally. HF's `Gemma4Model.forward` builds
this mask via `create_causal_mask_mapping`.

The MoE backend `Gemma4MoETextModelBackend.forward` was always building
plain `create_causal_mask` + `create_sliding_window_causal_mask` regardless
of the config flag. For `gemma-4-26B-A4B-it` this makes the MoE forward
numerically diverge from HF on multimodal inputs (vision token logprobs
can differ by 20+ in log-space), and increases `train/gen_kl_error` by
roughly an order of magnitude during GRPO training (~0.01 vs ~0.001 on
text-only).

Fix:
- Accept `mm_token_type_ids` and `pixel_values` in
  `Gemma4MoETextModelBackend.forward`.
- When `config.use_bidirectional_attention == "vision"`, call HF's
  `create_causal_mask_mapping` (matches `Gemma4Model.forward`). Otherwise
  keep the existing plain causal-mask path.
- Plumb `mm_token_type_ids` / `pixel_values` from
  `Gemma4ForConditionalGeneration.forward` down to the text backend.

Measured impact on gemma-4-26B-A4B-it multimodal forward (HF as ground
truth, single synthetic image, 341-token sequence):

  metric                     before      after (expected)
  HF vs Automodel gen_kl     0.034       ~0.01 (FSDP noise floor)
  HF vs vLLM gen_kl          0.092       0.092 (unchanged)

Signed-off-by: jQizhang <larkz@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
@svcnvidia-nemo-ci
Copy link
Copy Markdown
Contributor Author

/ok to test 813b0d3

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 18, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@akoumpa akoumpa merged commit 7011ed4 into r0.4.0 Apr 19, 2026
51 of 54 checks passed
@akoumpa akoumpa deleted the cherry-pick-1905-r0.4.0 branch April 19, 2026 01:55
Copy link
Copy Markdown
Contributor

@akoumpa akoumpa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cherry-pick Run CICD Trigger Testing CICD

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants