
Add conversion script for Qwen3 Next and Readme #2672

Merged
copybara-service[bot] merged 1 commit into main from rbierneni-qwen3-next-ckpt-conversion on Dec 16, 2025

Conversation


@Rohan-Bierneni (Collaborator) commented Nov 12, 2025

Description

This PR is a follow-up to the initial Qwen3-Next PRs and enables the model to be fully supported in MaxText for pre-training. It includes conversion scripts from Hugging Face to Orbax format, a train_compile test for the model, and verification that forward-pass logits match between the HF and MaxText models.

The model does not yet support decode/inference; adding caching to the Gated Delta Net will come in a follow-up PR.

Bugs found and fixed:

1) Oversimplified qkvz split in GatedDeltaNet

We wrote tests to compare logits between the HF implementation and our JAX implementation. However, for the GatedDeltaNet block test case, we initially wrote simplified tensor-splitting logic on the PyTorch side.

Our JAX code then matched that simplified PyTorch logic rather than what the reference implementation actually does. This caused our test cases to pass while the forward-pass logit checker failed. We have updated GatedDeltaNet with the correct qkvz splitting logic and verified that the forward-pass logits match.
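To illustrate the kind of mistake involved, here is a minimal sketch of splitting a fused qkvz projection per key head. The dimensions are hypothetical, and the exact channel layout of the reference implementation is an assumption here; the point is only that the per-head axis must be exposed before slicing, or a naive flat split silently reorders channels.

```python
import numpy as np

# Hypothetical dimensions for illustration only (not Qwen3-Next's real sizes).
num_k_heads, num_v_heads = 2, 4
head_k_dim, head_v_dim = 8, 8
batch, seq = 1, 3
ratio = num_v_heads // num_k_heads  # v heads grouped under each k head

# Assumed layout: per key head, the fused channels are [q | k | v | z].
per_head = 2 * head_k_dim + 2 * head_v_dim * ratio
mixed_qkvz = np.arange(batch * seq * num_k_heads * per_head, dtype=np.float32)
mixed_qkvz = mixed_qkvz.reshape(batch, seq, num_k_heads * per_head)

# Correct split: reshape to expose the per-head axis first, then slice.
x = mixed_qkvz.reshape(batch, seq, num_k_heads, per_head)
q = x[..., :head_k_dim]
k = x[..., head_k_dim : 2 * head_k_dim]
v = x[..., 2 * head_k_dim : 2 * head_k_dim + head_v_dim * ratio]
z = x[..., 2 * head_k_dim + head_v_dim * ratio :]

print(q.shape, k.shape, v.shape, z.shape)
```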


2) Wrong RMSNorm after decoder layers

After the 48 decoder layers, we were normalizing with the default RMSNorm instead of the model-specific Qwen3NextRMSNorm, which caused the forward pass to fail.
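For reference, a minimal RMSNorm sketch follows. The `zero_centered` variant, which applies the learned scale as (1 + weight), is one common model-specific difference between norm implementations; whether that is the precise distinction of Qwen3NextRMSNorm is an assumption made here purely for illustration.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6, zero_centered=False):
    """Minimal RMSNorm sketch. `zero_centered=True` applies the learned
    scale as (1 + weight), one common model-specific variant; this is an
    illustrative assumption, not the verified Qwen3NextRMSNorm definition."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    scale = (1.0 + weight) if zero_centered else weight
    return (x / rms) * scale

x = np.array([[1.0, 2.0, 3.0]])
w = np.zeros(3)  # freshly initialized weights
# With zero-initialized weights, default RMSNorm zeroes the output,
# while the zero-centered variant still normalizes and passes signal through.
print(rms_norm(x, w).tolist())  # all zeros
print(rms_norm(x, w, zero_centered=True).round(3))
```

Picking the wrong variant does not raise any shape error, so the mismatch only shows up when comparing logits end to end.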


3) Using MaxText.decode was failing

i) Previously we were hardcoding the query and key-value tensor shapes via cfg.per_device_batch_size:

inputs_q_shape = (cfg.per_device_batch_size, cfg.max_target_length, cfg.emb_dim)
inputs_kv_shape = (cfg.per_device_batch_size, cfg.max_target_length, cfg.emb_dim)

However, good practice is to use max_utils.get_batch_seq_len_for_mode() to account for the different model modes. The code is now:

batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)
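A hypothetical stand-in for this helper makes the mode dependence concrete. The real `max_utils.get_batch_seq_len_for_mode()` lives in MaxText; the mode names and the decode behavior below are assumptions for illustration.

```python
# Hypothetical stand-in for max_utils.get_batch_seq_len_for_mode; the real
# helper lives in MaxText, and the decode handling here is an assumption.
MODEL_MODE_TRAIN = "train"
MODEL_MODE_PREFILL = "prefill"
MODEL_MODE_AUTOREGRESSIVE = "autoregressive"

def get_batch_seq_len_for_mode(config, model_mode):
    if model_mode == MODEL_MODE_AUTOREGRESSIVE:
        # Decode processes one token at a time.
        return config["per_device_batch_size"], 1
    if model_mode == MODEL_MODE_PREFILL:
        return config["per_device_batch_size"], config["max_prefill_predict_length"]
    return config["per_device_batch_size"], config["max_target_length"]

cfg = {"per_device_batch_size": 4, "max_target_length": 2048,
       "max_prefill_predict_length": 1024, "emb_dim": 512}
for mode in (MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
    batch_size, seq_len = get_batch_seq_len_for_mode(cfg, mode)
    print(mode, (batch_size, seq_len, cfg["emb_dim"]))
```

Hardcoding `max_target_length` works for training but produces the wrong dummy shapes in prefill and decode, which is why the hardcoded version failed.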

ii) Moved the gating operation
Previously, we applied gating to the attention output before the sharding logic, which caused decode to fail. Once the gating was moved to after the sharding logic, decode worked as intended in GatedFullAttention(). The code now looks like this:

    if model_mode == MODEL_MODE_PREFILL:
      out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
      out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
    elif model_mode == MODEL_MODE_TRAIN:
      out = self._maybe_shard_with_logical(out, self.out_axis_names)
    else:
      out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
    if self.is_qwen3_next:
      out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
      out = out * jax.nn.sigmoid(gate)
    out = self.out_projection(out, out_sharding=out_sharding)
    out = checkpoint_name(out, "out_proj")
    return out, kv_cache
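In isolation, the gating step is just an elementwise sigmoid multiply after flattening the head dimension. The NumPy sketch below uses toy shapes and omits sharding entirely; it illustrates only the gating arithmetic, not the sharding ordering that the fix actually concerns.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Toy shapes standing in for the attention output and the gate projection.
batch_size, seq_len, num_heads, head_dim = 2, 3, 4, 8
out = np.ones((batch_size, seq_len, num_heads, head_dim))
gate = np.zeros((batch_size, seq_len, num_heads * head_dim))

# As in the snippet above: flatten the head axes, then gate elementwise.
# sigmoid(0) == 0.5, so a zero gate halves every activation.
out = out.reshape(batch_size, seq_len, num_heads * head_dim)
out = out * sigmoid(gate)
print(out.shape, float(out[0, 0, 0]))
```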


Tests

(Unscanned) Forward pass logit checker: https://paste.googleplex.com/6146326802857984

(Scanned) Forward pass logit checker: https://paste.googleplex.com/6195553369194496

The added train_compile test is passing.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


🤖 Hi @Rohan-Bierneni, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions bot left a comment:

📋 Review Summary

This Pull Request introduces comprehensive support for the Qwen3-Next model, including both scanned and unscanned checkpoint conversion scripts. The integration of heterogeneous layers and the new configuration validation are positive additions, demonstrating a thoughtful approach to supporting this new model.

🔍 General Feedback

  • The overall structure for Qwen3-Next integration appears well-designed, particularly the handling of alternating Gated Delta Net and Gated Attention layers.
  • The addition of configuration validation for gdn_num_value_heads is a good practice.
  • There are a few areas identified for potential improvement in terms of code clarity, naming conventions, and a critical logic change in the attention mechanism that warrants further review and verification.


@Rohan-Bierneni force-pushed the rbierneni-qwen3-next-ckpt-conversion branch from 3f0799a to 531058d on November 15, 2025 00:20

@github-actions bot left a comment:

📋 Review Summary

This pull request adds support for the Qwen3-Next model, including checkpoint conversion scripts, model configurations, and necessary modifications to the attention and decoder layers. The changes are well-structured and the new conversion scripts are comprehensive.

🔍 General Feedback

  • The addition of both scanned and unscanned conversion scripts is a great feature, providing flexibility for users.
  • The refactoring in qwen3.py to make the QKV splitting logic more explicit is a significant improvement in readability and maintainability.
  • The new documentation and test scripts are clear and helpful for getting started with the new model.
  • A few minor improvements were suggested regarding documentation, comments, and default values in test scripts.

@shuningjin (Collaborator) left a comment:

Thanks for the great work!

At a high level:

  • Could you explain why Qwen3NextGatedDeltaNet was updated? It might be good to include this in the PR description.
  • For the readme, have you tested these example commands? It might be good to have an example that can run on a cluster smaller than v5p-512.
  • Followup: if you are to onboard the test to xlml, it might be better to generate golden logits and compare against them (otherwise we would need to download the 80B model and run HF every time).

@RissyRan (Collaborator) left a comment:

Thanks for the change!

Shall we also test pre-training functionality? Or at least have a compile test in the unit tests. Otherwise, we won't know when the Gated DeltaNet is broken.

@Rohan-Bierneni force-pushed the rbierneni-qwen3-next-ckpt-conversion branch 2 times, most recently from 531058d to 00ae854 on November 17, 2025 04:21
@Rohan-Bierneni force-pushed the rbierneni-qwen3-next-ckpt-conversion branch from b982f96 to 8426337 on November 26, 2025 23:26
@shuningjin (Collaborator) left a comment:

LGTM, thank you!

@Rohan-Bierneni force-pushed the rbierneni-qwen3-next-ckpt-conversion branch 4 times, most recently from b3e501c to caf5021 on December 8, 2025 19:41
@aireenmei (Collaborator) left a comment:

Thanks for the detailed description!

@RissyRan (Collaborator) left a comment:

LGTM, just a few comments.

@Rohan-Bierneni force-pushed the rbierneni-qwen3-next-ckpt-conversion branch from b27537e to d46c4e6 on December 15, 2025 15:26
add debug statements

Conversion script ran without failing

test verify orbax hf tensors

Add unscanned conversion script for qwen3 next

Move gating op to after sharding optimizations

added zero centered rmsnorm

Add layer by layer comparison script

Remove debug files

Remove zero centered rms norm logic

Remove changes from forward pass logit checker

Remove sow debug line

Fix qkvz split in gated delta net and fix normalization after decoder layers

Run linter and modify ckpt conversion config

remove scanned script since it is not working yet

move qwen3 next unscanned conversion script to utils folder

Remove rms norm after decoder block for qwen3 next

Add scanned conversion script for qwen3 next

Added qwen3 next conversion test script

Resolved gemini review comments

Ran pyink for indentation errors

Added readme for qwen3 next

typo in qwen3 next readme

Reformatted unscanned script

Formatted scripts again

Undo changes in decoders.py

Formatted function with long line length

fix linter issues

Revise gemini-review comment

Add back change to pyconfig after rebase

Resolved pr comments

Added moe strategies section to qwen3 next readme

resolved comments in scripts

Dynamically get batch_size and seq_len

Add logic to decouple tuple when using scanned

Resolve pr comments

Add train compile test for qwen3-next

Update train_compile test for qwen3-next

Moved checks to types.py from pyconfig_deprecated.py

Resolved comment for qwen3 next readme

Ran pyink formatter

Remove sparse_matmul test
@Rohan-Bierneni force-pushed the rbierneni-qwen3-next-ckpt-conversion branch from 7081435 to b6c32b6 on December 15, 2025 16:01
@Rohan-Bierneni (Collaborator, Author) commented:

Adding pull_ready as all tests passed and lgtmed.

@copybara-service bot merged commit 9edcdba into main on Dec 16, 2025
52 of 64 checks passed
@copybara-service bot deleted the rbierneni-qwen3-next-ckpt-conversion branch on December 16, 2025 21:57