Fix: Gemma 3 & 4 Base Model and Flax Linen Decoding #4066

Open
RexBearIU wants to merge 2 commits into main from jackyf/gemma3_4-base-decoding-fixes

@RexBearIU RexBearIU commented Jun 4, 2026

Collaborator

Description

This PR introduces architectural configurations for Gemma3 and Gemma4, and fixes KV-cache propagation in the decoding layers of the JAX/NNX (nnx_decoders.py) and Flax Linen (decoders.py) pipelines.

Before this change, Flax Linen scanned blocks (_apply_gemma3_scanned_blocks and _apply_gemma4_scanned_blocks) lacked mechanisms to propagate intermediate KV-caches across scanned layers under nn.scan, causing end-to-end decodes to fail.

Key Changes:
- Flax Linen Decoder (src/maxtext/layers/decoders.py): Added kv_caches and attention_metadata propagation across Gemma3/4 scanned blocks via a stack/unstack mapping over nn.scan.
- JAX/NNX Decoder (src/maxtext/layers/nnx_decoders.py): Unified dynamic KV-cache carry updates within scanned block execution.
- Model Scaffolding (gemma3.py, gemma4.py): Added kv_cache routing through scannable blocks and layer inputs.
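
To illustrate the stack/unstack idea in the list above, here is a minimal sketch, not the code in this PR. It assumes each per-layer cache is a PyTree (a dict with "k" and "v" here), that layers are grouped into blocks of block_len, and it uses jax.lax.scan as a stand-in for nn.scan. The helper names stack_kv_caches, unstack_kv_caches, and toy_scanned_blocks are hypothetical.

```python
import jax
import jax.numpy as jnp


def stack_kv_caches(kv_caches, block_len):
  """Group a flat per-layer list into block_len stacked PyTrees.

  Entry j stacks the caches of layers j, j + block_len, ... along a new
  leading axis of size scan_length, which is the axis a scan iterates over.
  """
  if len(kv_caches) % block_len != 0:
    raise ValueError("number of layers must be a multiple of block_len")
  scan_length = len(kv_caches) // block_len
  stacked = []
  for offset in range(block_len):
    per_step = [kv_caches[i * block_len + offset] for i in range(scan_length)]
    stacked.append(jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *per_step))
  return stacked


def unstack_kv_caches(stacked, scan_length):
  """Inverse of stack_kv_caches: result[i] holds the block_len caches of step i."""
  return [
      [jax.tree_util.tree_map(lambda x: x[i], s) for s in stacked]
      for i in range(scan_length)
  ]


def toy_scanned_blocks(stacked):
  """Hypothetical stand-in for the scanned decoder blocks: adds 1 to every leaf."""

  def body(carry, caches_for_step):
    updated = jax.tree_util.tree_map(lambda x: x + 1.0, caches_for_step)
    return carry, updated  # updated caches are re-stacked by scan

  _, updated_stacked = jax.lax.scan(body, jnp.zeros(()), stacked)
  return updated_stacked
```

With six layers and block_len=2, the scan runs three steps, and unstack_kv_caches(toy_scanned_blocks(stacked), 3) yields three lists of two updated caches each, ready to be written back into the flat per-layer list.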

Tests

  • Validated via end-to-end decoding with Gemma3-4B on an 8-device TPU mesh.
  • Run unit tests:
    python3 -m pytest tests/post_training/unit/lora_utils_test.py

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Jun 4, 2026

github-actions Bot commented Jun 5, 2026

🤖 Hi @RexBearIU, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@RexBearIU RexBearIU force-pushed the jackyf/gemma3_4-base-decoding-fixes branch from 3a983c6 to c4c93a7 Compare June 5, 2026 10:41
@AI-Hypercomputer AI-Hypercomputer deleted a comment from github-actions Bot Jun 5, 2026
@RexBearIU RexBearIU force-pushed the jackyf/gemma3_4-base-decoding-fixes branch 3 times, most recently from e755816 to 5c2a98d Compare June 8, 2026 12:18
@RexBearIU RexBearIU force-pushed the jackyf/gemma3_4-base-decoding-fixes branch 3 times, most recently from 355c904 to 59ca028 Compare June 12, 2026 08:15
@RexBearIU RexBearIU force-pushed the jackyf/gemma3_4-base-decoding-fixes branch 3 times, most recently from fca7f7c to 338f2fc Compare June 18, 2026 09:09
@RexBearIU RexBearIU force-pushed the jackyf/gemma3_4-base-decoding-fixes branch 6 times, most recently from 8bf05e7 to d96b8a6 Compare June 30, 2026 08:11
for i in range(scan_length):
  start_idx = i * block_len
  for offset, updated_item in enumerate(unstacked[i]):
    kv_caches[start_idx + offset] = updated_item

Collaborator

Is kv_caches a Python list rather than a JAX array? JAX arrays are immutable, so this in-place item assignment only works on a list. If so, add an isinstance(kv_caches, list) check to make sure kv_caches is a list.

Collaborator Author

Yes, good catch. kv_caches is indeed a Python list containing JAX PyTrees. Since JAX arrays/PyTrees are immutable but we mutate the outer list structure in update_kv_caches_after_scan, I've added explicit isinstance(kv_caches, list) checks and validation tests to ensure it behaves correctly. I just pushed the fixes.
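
For readers following the thread, the guarded write-back could look roughly like the sketch below. The function name update_kv_caches_after_scan comes from the reply above, but the signature, the error messages, and the extra length checks are assumptions for illustration, not the code pushed to this branch.

```python
def update_kv_caches_after_scan(kv_caches, unstacked, block_len):
  """Write per-step updated caches back into the flat per-layer list.

  kv_caches must be a mutable Python list; its elements (JAX PyTrees in the
  PR, arbitrary objects here) are replaced, never mutated in place.
  """
  if not isinstance(kv_caches, list):
    raise TypeError(f"kv_caches must be a list, got {type(kv_caches).__name__}")
  scan_length = len(unstacked)
  if scan_length * block_len > len(kv_caches):
    raise ValueError("unstacked caches do not fit into kv_caches")
  for i in range(scan_length):
    if len(unstacked[i]) != block_len:
      raise ValueError(f"step {i} has {len(unstacked[i])} caches, expected {block_len}")
    start_idx = i * block_len
    for offset, updated_item in enumerate(unstacked[i]):
      kv_caches[start_idx + offset] = updated_item
  return kv_caches
```

Failing fast with a TypeError means a tuple or a stacked array passed by mistake is reported at the call site rather than as an opaque item-assignment error inside the loop.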

@RexBearIU RexBearIU force-pushed the jackyf/gemma3_4-base-decoding-fixes branch from d96b8a6 to dbcc192 Compare July 1, 2026 08:10
@RexBearIU RexBearIU requested a review from xibinliu as a code owner July 1, 2026 08:10