Fix: Gemma 3 & 4 Base Model and Flax Linen Decoding #4066
for i in range(scan_length):
  start_idx = i * block_len
  for offset, updated_item in enumerate(unstacked[i]):
    kv_caches[start_idx + offset] = updated_item
Is kv_caches a Python list rather than a JAX array? JAX arrays are immutable, so item assignment would fail on one. If it is a list, add an isinstance(kv_caches, list) check to make sure.
Yes, good catch. kv_caches is indeed a Python list containing JAX PyTrees. Since JAX arrays/PyTrees are immutable but we mutate the outer list structure in update_kv_caches_after_scan, I've added explicit isinstance(kv_caches, list) checks and validation tests to ensure it behaves correctly. I just pushed the fixes.
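For readers following the thread, here is a minimal sketch of the pattern under discussion, not the PR's actual code. It stacks per-layer caches, runs them through jax.lax.scan, unstacks the result, and writes it back into the Python list behind an isinstance guard. The helper names (write_back_scanned_caches, unstack_blocks, toy_block), the cache layout (a dict with "k" and "v" arrays), and the toy update are all assumptions made for illustration; the real update_kv_caches_after_scan may have a different signature.

```python
import jax
import jax.numpy as jnp


def unstack_blocks(stacked, scan_length, block_len):
  """Splits a pytree with leaves [scan_length, block_len, ...] into nested lists."""
  return [
      [jax.tree_util.tree_map(lambda x: x[i, j], stacked) for j in range(block_len)]
      for i in range(scan_length)
  ]


def write_back_scanned_caches(kv_caches, unstacked, scan_length, block_len):
  """Hypothetical helper: copies per-layer caches back into a mutable list."""
  # The outer container is mutated in place, so it must be a Python list.
  if not isinstance(kv_caches, list):
    raise TypeError(f"kv_caches must be a list, got {type(kv_caches).__name__}")
  if len(kv_caches) != scan_length * block_len:
    raise ValueError("kv_caches length must equal scan_length * block_len")
  for i in range(scan_length):
    start_idx = i * block_len
    for offset, updated_item in enumerate(unstacked[i]):
      kv_caches[start_idx + offset] = updated_item
  return kv_caches


def toy_block(hidden, block_cache):
  """Stand-in for one scanned block of decoder layers (assumption)."""
  new_cache = jax.tree_util.tree_map(lambda x: x + 1.0, block_cache)
  return hidden + 1.0, new_cache


scan_length, block_len = 2, 3
kv_caches = [
    {
        "k": jnp.full((2, 4), float(n), dtype=jnp.float32),
        "v": jnp.full((2, 4), -float(n), dtype=jnp.float32),
    }
    for n in range(scan_length * block_len)
]

# Stack the per-layer caches so every leaf has shape [scan_length, block_len, 2, 4].
stacked = jax.tree_util.tree_map(
    lambda *xs: jnp.stack(xs).reshape(scan_length, block_len, *xs[0].shape),
    kv_caches[0],
    *kv_caches[1:],
)

# Scan over the leading axis; the updated caches come back as stacked outputs.
hidden, updated = jax.lax.scan(toy_block, jnp.zeros((), dtype=jnp.float32), stacked)

unstacked = unstack_blocks(updated, scan_length, block_len)
kv_caches = write_back_scanned_caches(kv_caches, unstacked, scan_length, block_len)
```

The arrays inside each cache are never modified in place. Only the list slots are reassigned, which is why the type check targets the outer container rather than its contents.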
Description
This PR introduces architectural configurations for Gemma3 and Gemma4, and implements critical decoding layer KV-cache propagation fixes for the JAX/NNX (nnx_decoders.py) and Flax Linen (decoders.py) pipelines.

Before this change, the Flax Linen scanned blocks (_apply_gemma3_scanned_blocks and _apply_gemma4_scanned_blocks) lacked mechanisms to propagate intermediate KV-caches across scanned layers under nn.scan, causing end-to-end decodes to fail.

Key Changes:
- Flax Linen Decoder (src/maxtext/layers/decoders.py): Added kv_caches and attention_metadata propagation across Gemma3/4 scanned blocks via a stack/unstack mapping over nn.scan.
- JAX/NNX Decoder (src/maxtext/layers/nnx_decoders.py): Unified dynamic KV-cache carry updates within scanned block execution.
- Model Scaffolding (gemma3.py, gemma4.py): Added kv_cache routing through scannable blocks and layer inputs.

Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.