Add conversion script for Qwen3 Next and Readme #2672
copybara-service[bot] merged 1 commit into main
Conversation
🤖 Hi @Rohan-Bierneni, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This Pull Request introduces comprehensive support for the Qwen3-Next model, including both scanned and unscanned checkpoint conversion scripts. The integration of heterogeneous layers and the new configuration validation are positive additions, demonstrating a thoughtful approach to supporting this new model.
🔍 General Feedback
- The overall structure for Qwen3-Next integration appears well-designed, particularly the handling of alternating Gated Delta Net and Gated Attention layers.
- The addition of configuration validation for `gdn_num_value_heads` is a good practice.
- There are a few areas identified for potential improvement in terms of code clarity, naming conventions, and a critical logic change in the attention mechanism that warrants further review and verification.
Force-pushed dad763d to 1d0d01c
🤖 I'm sorry @Rohan-Bierneni, but I was unable to process your request. Please see the logs for more details.
Force-pushed 3f0799a to 531058d
📋 Review Summary
This pull request adds support for the Qwen3-Next model, including checkpoint conversion scripts, model configurations, and necessary modifications to the attention and decoder layers. The changes are well-structured and the new conversion scripts are comprehensive.
🔍 General Feedback
- The addition of both scanned and unscanned conversion scripts is a great feature, providing flexibility for users.
- The refactoring in `qwen3.py` to make the QKV splitting logic more explicit is a significant improvement in readability and maintainability.
- The new documentation and test scripts are clear and helpful for getting started with the new model.
- A few minor improvements were suggested regarding documentation, comments, and default values in test scripts.
Thanks for the great work!
At a high level:
- Could you explain why `Qwen3NextGatedDeltaNet` is updated? Might be good to include this in the PR description.
- For the readme, have you tested these example commands? It might be good to have an example that can run on a smaller cluster than v5p-512.
- Follow-up: if you onboard the test to XLML, it might be better to generate golden logits and compare against them (otherwise we would need to download the 80B model and run HF every time).
RissyRan
left a comment
Thanks for the change!
Shall we also test pre-training functionality? Or at least add a compile test to the unit tests. Otherwise, we won't know when the Gated DeltaNet is broken.
Force-pushed 531058d to 00ae854
Force-pushed b982f96 to 8426337
Force-pushed b3e501c to caf5021
aireenmei
left a comment
Thanks for the detailed description!
RissyRan
left a comment
LGTM, just a few comments.
Force-pushed b27537e to d46c4e6
Squashed commits:
- add debug statements
- Conversion script ran without failing
- test verify orbax hf tensors
- Add unscanned conversion script for qwen3 next
- Move gating op to after sharding optimizations
- added zero centered rmsnorm
- Add layer by layer comparision script
- Remove debug files
- Remove zero centered rms norm logic
- Remove changes from forward pass logit checker
- Remove sow debug line
- Fix qkvz split in gated delta net and fix normalization after decoder layers
- Run linter and modify ckpt conversion config
- remove scanned script since it is not working yet
- move qwen3 next unscanned conversion script to utils folder
- Remove rms norm after decoder block for qwen3 next
- Add scanned conversion script for qwen3 next
- Added qwen3 next conversion test script
- Resolved gemini review comments
- Ran pyink for indentation errors
- Added readme for qwen3 next
- typo in qwen3 next readme
- Reformatted unscanned script
- Formatted scripts again
- Undo changes in decoders.py
- Formatted function with long line length
- fix linter issues
- Revise gemini-review comment
- Add back change to pyconfig after rebase
- Resolved pr comments
- Added moe strategies section to qwen3 next readme
- resolved comments in scripts
- Dynamically get batch_size and seq_len
- Add logic to decouple touple when using scanned
- Resolve pr comments
- Add train compile test for qwen3-next
- Update train_compile test for qwen3-next
- Moved checks to types.py from pyconfig_deprecated.py
- Resolved comment for qwen3 next readme
- Ran pyink formatter
- Remove sparse_matmul test
Force-pushed 7081435 to b6c32b6
Adding pull_ready as all tests have passed and the PR has been LGTM'd.
Description
This PR is a follow-up to the initial Qwen3-Next PRs and enables the model to be fully supported in MaxText for pre-training. It includes conversion scripts from Hugging Face to Orbax format, a train_compile test for the model, and verification that forward-pass logits match between the HF and MaxText models.
The model currently doesn't support decode/inference; adding caching to the Gated Delta Net will come in a follow-up PR.
Bugs found and fixed:
1) Simplified qkvz split in GatedDeltaNet
We wrote tests to compare logits between the HF implementation and our JAX implementation. However, in the test case for the GatedDeltaNet block, we initially wrote simplified tensor-splitting logic for the PyTorch code.
Our JAX code then matched that simplified PyTorch logic rather than what the reference implementation actually does. This caused our test cases to pass while the forward-pass logit checker failed. We have updated the GatedDeltaNet with the correct qkvz splitting logic and verified that the forward-pass logits match.
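To make the failure mode concrete, here is a minimal JAX sketch contrasting a naive contiguous [Q | K | V | Z] split with a per-key-head grouped split of the kind reference implementations use. All dimensions and the exact grouped layout are illustrative assumptions, not the real Qwen3-Next sizes or code.

```python
import jax.numpy as jnp

# Illustrative dimensions (assumptions, not the real Qwen3-Next sizes).
num_k_heads, num_v_heads = 2, 4            # 2 value heads per key head
head_k_dim, head_v_dim = 8, 16
batch, seq = 1, 3
v_per_k = num_v_heads // num_k_heads

qkvz_dim = num_k_heads * 2 * head_k_dim + num_v_heads * 2 * head_v_dim
mixed = jnp.arange(batch * seq * qkvz_dim, dtype=jnp.float32).reshape(batch, seq, qkvz_dim)

# Naive split: treat the projection as contiguous [Q | K | V | Z] blocks.
q_naive = mixed[..., : num_k_heads * head_k_dim].reshape(batch, seq, num_k_heads, head_k_dim)

# Grouped split: the projection is laid out per key head, each group holding
# [q_head | k_head | v_group | z_group] for that head's value-head group.
per_group = 2 * head_k_dim + 2 * v_per_k * head_v_dim
grouped = mixed.reshape(batch, seq, num_k_heads, per_group)
q = grouped[..., :head_k_dim]
k = grouped[..., head_k_dim : 2 * head_k_dim]
vz = grouped[..., 2 * head_k_dim :]
v = vz[..., : v_per_k * head_v_dim].reshape(batch, seq, num_v_heads, head_v_dim)
z = vz[..., v_per_k * head_v_dim :].reshape(batch, seq, num_v_heads, head_v_dim)

# The two conventions produce different tensors from the same projection,
# which is why tests written against the simplified split can pass while
# real-checkpoint logits diverge.
assert not jnp.array_equal(q_naive, q)
```

Because both splits yield the same shapes, only a logit-level comparison against the reference checkpoint catches the mismatch.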
2) Wrong RMSNorm after decoder layers
After the 48 decoder layers we were normalizing with the default RMSNorm instead of the model-specific Qwen3NextRMSNorm, which caused the forward-pass check to fail.
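A toy sketch shows why picking the wrong norm class silently corrupts logits: a standard RMSNorm and a zero-centered variant (where the checkpoint stores the scale as an offset from 1, as some model-specific norms do) give very different outputs for the same stored weights. This is an illustration of the general failure mode, not a claim about Qwen3NextRMSNorm's exact definition.

```python
import jax
import jax.numpy as jnp

def rms_norm(x, weight, eps=1e-6):
  # Standard RMSNorm: normalize, then scale by the learned weight directly.
  var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(var + eps) * weight

def zero_centered_rms_norm(x, weight, eps=1e-6):
  # Zero-centered variant: the stored weight is an offset from 1, so a
  # checkpoint weight of 0 means "identity scale". Loading such a checkpoint
  # into a standard RMSNorm scales activations by ~0 instead of ~1.
  var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(var + eps) * (1.0 + weight)

x = jnp.ones((4,))
w = jnp.zeros((4,))  # a plausible zero-centered checkpoint weight
print(rms_norm(x, w))                # ~[0, 0, 0, 0] -- activations wiped out
print(zero_centered_rms_norm(x, w))  # ~[1, 1, 1, 1]
```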
3) MaxText.decode was failing
i) Previously we hardcoded the query and key-value tensor shapes via cfg.per_device_batch_size. Good practice is to use max_utils.get_batch_seq_len_for_mode() so the shapes account for the different model_modes, and the code has been changed to do so.
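The pattern looks roughly like the following standalone sketch. The function name mirrors the MaxText helper mentioned above, but the signature, mode constants, and config fields here are illustrative assumptions, not the repository's actual API.

```python
# Hypothetical sketch of mode-aware shape selection; all names illustrative.
MODEL_MODE_TRAIN = "train"
MODEL_MODE_PREFILL = "prefill"
MODEL_MODE_AUTOREGRESSIVE = "autoregressive"

def get_batch_seq_len_for_mode(cfg, model_mode):
  # Decode emits one token per step, so the sequence length there must be 1;
  # hardcoding the training batch/length breaks autoregressive decoding.
  if model_mode == MODEL_MODE_AUTOREGRESSIVE:
    return cfg["per_device_batch_size"], 1
  if model_mode == MODEL_MODE_PREFILL:
    return cfg["per_device_batch_size"], cfg["max_prefill_predict_length"]
  return cfg["per_device_batch_size"], cfg["max_target_length"]

cfg = {"per_device_batch_size": 4, "max_prefill_predict_length": 64, "max_target_length": 128}
print(get_batch_seq_len_for_mode(cfg, MODEL_MODE_TRAIN))           # (4, 128)
print(get_batch_seq_len_for_mode(cfg, MODEL_MODE_AUTOREGRESSIVE))  # (4, 1)
```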
ii) Moved gating operation
Previously, we applied gating to the attention_output before the sharding logic, which caused decode to fail. Once gating was moved to after the sharding logic, decode worked as intended in GatedFullAttention().
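A minimal sketch of the ordering fix (illustrative only: `reshard_fn` stands in for MaxText's sharding-constraint logic, and the sigmoid gate is the usual gated-attention form):

```python
import jax
import jax.numpy as jnp

def gated_attention_output(attention_output, gate, reshard_fn):
  # Fixed ordering: apply the sharding/layout transformation first, then
  # gate. Gating before resharding was what broke MaxText.decode.
  out = reshard_fn(attention_output)
  return out * jax.nn.sigmoid(gate)

# Toy usage with an identity "reshard" to show the gate's effect.
attn = jnp.ones((2, 3))
gate = jnp.zeros((2, 3))  # sigmoid(0) == 0.5
out = gated_attention_output(attn, gate, reshard_fn=lambda x: x)
print(out)  # all entries 0.5
```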
Tests
(Unscanned) Forward pass logit checker: https://paste.googleplex.com/6146326802857984
(Scanned) Forward pass logit checker: https://paste.googleplex.com/6195553369194496
Added train_compile test is passing.