From 545b677b1df47550ffe04449d9ae6cab7ba17803 Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Thu, 21 May 2026 05:58:23 +0000 Subject: [PATCH] Add Qwen3.5 vision layers --- pytest.ini | 1 + .../configs/models/qwen3.5-397b-a17b.yml | 20 +++ src/maxtext/configs/types.py | 1 + src/maxtext/layers/attentions.py | 6 +- src/maxtext/layers/decoders.py | 1 + src/maxtext/layers/encoders.py | 10 ++ src/maxtext/models/qwen3_5_vision.py | 38 ++++ src/maxtext/multimodal/processor.py | 16 +- tests/unit/qwen3_5_layers_test.py | 167 ++++++++++++++++++ 9 files changed, 250 insertions(+), 10 deletions(-) create mode 100644 src/maxtext/models/qwen3_5_vision.py create mode 100644 tests/unit/qwen3_5_layers_test.py diff --git a/pytest.ini b/pytest.ini index a38a9883a0..840b4fa7e5 100644 --- a/pytest.ini +++ b/pytest.ini @@ -24,6 +24,7 @@ addopts = --ignore=tests/unit/profiler_test.py --ignore=tests/unit/qwen3_omni_layers_test.py --ignore=tests/unit/qwen3_next_vs_reference_test.py + --ignore=tests/unit/qwen3_5_layers_test.py --ignore=tests/unit/deepseek32_vs_reference_test.py --ignore=tests/unit/engram_vs_reference_test.py markers = diff --git a/src/maxtext/configs/models/qwen3.5-397b-a17b.yml b/src/maxtext/configs/models/qwen3.5-397b-a17b.yml index 1613286b4b..6809804780 100644 --- a/src/maxtext/configs/models/qwen3.5-397b-a17b.yml +++ b/src/maxtext/configs/models/qwen3.5-397b-a17b.yml @@ -49,3 +49,23 @@ partial_rotary_factor: 0.25 # General Model Settings enable_dropout: False + +# Vision Encoder Configuration (need to set use_multimodal=true) +# Based on Qwen3.5 MoE Vision Model Config +image_size_for_vit: 768 +hidden_size_for_vit: 1152 +intermediate_size_for_vit: 4304 +num_attention_heads_for_vit: 16 +num_hidden_layers_for_vit: 27 +num_channels_for_vit: 3 +patch_size_for_vit: 16 +temporal_patch_size_for_vit: 2 +spatial_merge_size_for_vit: 2 +out_hidden_size_for_vit: 4096 # Projects to decoder emb_dim (4096) +num_position_embeddings_for_vit: 2304 +deepstack_visual_indexes_for_vit: [] # No deepstack for Qwen3.5 VL +rope_theta_for_vit: 10000 + +# MRoPE Settings (Multi-dimensional RoPE for multimodal) +use_mrope: true +mrope_section: [11, 11, 10] diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index f6a92bbb8a..77484575dd 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2927,6 +2927,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de "llama4-17b-16e", "llama4-17b-128e", "qwen3-omni-30b-a3b", + "qwen3.5-397b-a17b", ) if self.model_name not in valid_mm_models and self.model_name != "default": raise ValueError(f"Multimodal is only supported for {valid_mm_models}, not {self.model_name}") diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 3e347fd0ab..ff44a95b5d 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -429,7 +429,9 @@ def __init__( self.share_kv_layer = share_kv_layer self.is_qwen2 = self.config.decoder_block == DecoderBlockType.QWEN2 - self.is_qwen3_hybrid = self.config.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5) + self.is_qwen3_hybrid = ( + self.config.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5) and not self.is_vision + ) # Module attribute names must match names previously passed to Linen for checkpointing self.KVCache_0 = ( @@ -804,7 +806,7 @@ def init_rotary_embedding(self): rope_type = self.rope_type rope_use_scale = self.config.rope_use_scale if self.is_vision: - if self.config.model_name.startswith("qwen3-omni"): + if self.config.model_name.startswith("qwen3-omni") or self.config.model_name.startswith("qwen3.5"): rotary_embedding = Qwen3OmniMoeVisionRotaryEmbedding( hidden_size=self.config.hidden_size_for_vit, num_attention_heads=self.config.num_attention_heads_for_vit, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 1003aeb249..eff0cdfdf8 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -661,6 +661,7 @@ def _apply_embedding( "llama4-17b-16e", "llama4-17b-128e", "qwen3-omni-30b-a3b", + "qwen3.5-397b-a17b", ]: y = mm_utils.merge_mm_embeddings( text_embeddings=y, diff --git a/src/maxtext/layers/encoders.py b/src/maxtext/layers/encoders.py index 16f14c2d20..4fdd524c9d 100644 --- a/src/maxtext/layers/encoders.py +++ b/src/maxtext/layers/encoders.py @@ -70,6 +70,16 @@ def _setup_vision_encoder_layers(self): self, projector_name, gemma4_vision.Gemma4VisionProjector(config=self.config, mesh=self.mesh, rngs=self.rngs) ) return encoder_name, projector_name + elif self.config.model_name in ["qwen3.5-397b-a17b"]: + from maxtext.models import qwen3_5_vision # pylint: disable=import-outside-toplevel + + encoder_name = "Qwen3_5MoeVisionEncoder_0" + projector_name = "Qwen3_5MoeVisionProjector_0" + setattr( + self, encoder_name, qwen3_5_vision.Qwen3_5MoeVisionEncoder(config=self.config, mesh=self.mesh, rngs=self.rngs) + ) + setattr(self, projector_name, qwen3_5_vision.Qwen3_5MoeVisionProjector(config=self.config, rngs=self.rngs)) + return encoder_name, projector_name else: raise ValueError(f"No VisionEncoder implemented for {self.config.model_name} yet") diff --git a/src/maxtext/models/qwen3_5_vision.py b/src/maxtext/models/qwen3_5_vision.py new file mode 100644 index 0000000000..90feec4c10 --- /dev/null +++ b/src/maxtext/models/qwen3_5_vision.py @@ -0,0 +1,38 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qwen3.5 Vision model tower NNX subclasses. + +These classes subclass the Qwen3-Omni vision tower components to provide +clean class type names (Qwen3_5MoeVision...) in the Flax NNX metadata, +ensuring that the JAX parameter keys stored in checkpoints do not contain +the word 'Omni'. +""" + +from maxtext.models.qwen3 import Qwen3OmniMoeVisionEncoder, Qwen3OmniMoeVisionProjector + + +class Qwen3_5MoeVisionEncoder(Qwen3OmniMoeVisionEncoder): + """Subclass of Qwen3OmniMoeVisionEncoder for Qwen3.5 VL models. + + Inherits all core vision tower layers (patch embedding, position embedding, + rotary embeddings, attention, and transformer blocks) without modification. + """ + + +class Qwen3_5MoeVisionProjector(Qwen3OmniMoeVisionProjector): + """Subclass of Qwen3OmniMoeVisionProjector for Qwen3.5 VL models. + + Inherits the final projection/merger layers without modification. + """ diff --git a/src/maxtext/multimodal/processor.py b/src/maxtext/multimodal/processor.py index ca880107aa..3ed44ce1b3 100644 --- a/src/maxtext/multimodal/processor.py +++ b/src/maxtext/multimodal/processor.py @@ -44,7 +44,7 @@ def preprocess_mm_data(config): images = [mm_utils.load_image_from_path(p) for p in config.image_path.split(",")] processor_outputs = preprocess_mm_data_llama4(images) - elif config.model_name in ["qwen3-omni-30b-a3b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni # pylint: disable=import-outside-toplevel processor_outputs = preprocess_mm_data_qwen3_omni(config) @@ -68,7 +68,7 @@ def preprocess_image_for_training(image, model_name): from maxtext.multimodal.processor_llama4 import preprocess_mm_data_llama4 # pylint: disable=import-outside-toplevel return preprocess_mm_data_llama4(image) - elif model_name in ["qwen3-omni-30b-a3b"]: + elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training # pylint: disable=import-outside-toplevel return preprocess_mm_data_qwen3_omni_for_training(image) @@ -90,7 +90,7 @@ def get_image_offsets(config, processor_output: mm_utils.PreprocessorOutput | No from maxtext.multimodal.processor_llama4 import get_image_offsets_llama4 # pylint: disable=import-outside-toplevel return get_image_offsets_llama4(processor_output) - elif config.model_name in ["qwen3-omni-30b-a3b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import get_mm_offsets_qwen3_omni # pylint: disable=import-outside-toplevel return get_mm_offsets_qwen3_omni(config, processor_output) @@ -112,7 +112,7 @@ def reformat_prompt(prompt, image_placeholder, model_name, num_images, video_pla from maxtext.multimodal.processor_llama4 import reformat_prompt_llama4 # pylint: disable=import-outside-toplevel return reformat_prompt_llama4(prompt, image_placeholder, num_images) - elif model_name in ["qwen3-omni-30b-a3b"]: + elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import reformat_prompt_qwen3_omni # pylint: disable=import-outside-toplevel return reformat_prompt_qwen3_omni( @@ -137,7 +137,7 @@ def reformat_response(response, model_name): elif model_name in ["gemma4-26b", "gemma4-31b", "gemma4-e2b", "gemma4-e4b"]: formatted_response = f"{response}" return formatted_response - elif model_name in ["qwen3-omni-30b-a3b"]: + elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]: formatted_response = f"{response}<|im_end|>" return formatted_response else: @@ -158,7 +158,7 @@ def prepare_text_for_image_fusion(tokens, config, processor_output=None): from maxtext.multimodal.processor_llama4 import add_extra_tokens_for_images_llama4 # pylint: disable=import-outside-toplevel return add_extra_tokens_for_images_llama4(tokens, processor_output) - elif config.model_name in ["qwen3-omni-30b-a3b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import add_extra_tokens_for_qwen3_omni # pylint: disable=import-outside-toplevel return add_extra_tokens_for_qwen3_omni(tokens, config, processor_output) @@ -181,7 +181,7 @@ def get_dummy_image_shape_for_init(model_name, batch_size=1, num_image_per_seque from maxtext.multimodal.processor_llama4 import get_dummy_image_shape_for_init_llama4 # pylint: disable=import-outside-toplevel image_shape = get_dummy_image_shape_for_init_llama4(batch_size, num_image_per_sequence) - elif model_name.startswith("qwen3-omni-30b-a3b"): + elif model_name.startswith("qwen3-omni-30b-a3b") or model_name.startswith("qwen3.5-397b-a17b"): from maxtext.multimodal.processor_qwen3_omni import get_dummy_image_shape_for_init_qwen3_omni # pylint: disable=import-outside-toplevel image_shape = get_dummy_image_shape_for_init_qwen3_omni(batch_size) @@ -222,7 +222,7 @@ def get_bidirectional_mask_vision(config, decoder_input_tokens): from maxtext.multimodal.processor_llama4 import LLAMA4_PATCH_TOKEN # pylint: disable=import-outside-toplevel bidirectional_mask_vision = decoder_input_tokens == LLAMA4_PATCH_TOKEN - elif config.model_name in ["qwen3-omni-30b-a3b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import QWEN3_OMNI_IMAGE_TOKEN, QWEN3_OMNI_VIDEO_TOKEN # pylint: disable=import-outside-toplevel # Create bidirectional_mask for vision/video token merging diff --git a/tests/unit/qwen3_5_layers_test.py b/tests/unit/qwen3_5_layers_test.py new file mode 100644 index 0000000000..62e513f03f --- /dev/null +++ b/tests/unit/qwen3_5_layers_test.py @@ -0,0 +1,167 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Qwen3.5 Moe Vision subclasses comparing JAX implementations against PyTorch references.""" + +import os +import unittest +import numpy as np +import torch +from flax import nnx +import jax +import jax.numpy as jnp +from jax.sharding import Mesh + +from maxtext.configs import pyconfig +from maxtext.utils.globals import MAXTEXT_REPO_ROOT + +# Explicit, differentiated imports matching Gemma4 layer-wise testing style +from maxtext.models.qwen3_5_vision import ( + Qwen3_5MoeVisionEncoder as JaxQwen3_5MoeVisionEncoder, + Qwen3_5MoeVisionProjector as JaxQwen3_5MoeVisionProjector, +) + +from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import Qwen3_5MoeVisionConfig +from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeVisionModel as TorchQwen3_5MoeVisionModel, +) + +from tests.utils.multimodal_test_utils import ( + assert_all_close_jax_torch, + copy_patch_embed_weights, + copy_layernorm_weights, + copy_attention_weights_to_maxtext, + copy_linear_weights, + create_random_jax_torch, +) + +# Initialize JAX config cleanly using Qwen3.5-397B VL model registered config +base_config_path = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "base.yml") +jax_config = pyconfig.initialize( + ["", base_config_path], + model_name="qwen3.5-397b-a17b", + attention="dot_product", + attention_type="full", + matmul_precision="highest", + dropout_rate=0.0, + dtype="float32", + dtype_mm="float32", + weight_dtype="float32", + float32_logits=True, + float32_qk_product=True, +) + +# PyTorch Vision Config: all fields match Qwen3_5MoeVisionConfig defaults for the 397B-A17B vision +# tower except out_hidden_size, which must match the text decoder's base_emb_dim (4096 vs default 2048 of 35B-A3B). +torch_vision_config = Qwen3_5MoeVisionConfig( + out_hidden_size=jax_config.out_hidden_size_for_vit, + attn_implementation="eager", +) + +torch.set_grad_enabled(False) + + +def create_torch_vision_encoder(): + """Create and configure PyTorch Qwen3.5 vision model.""" + encoder = TorchQwen3_5MoeVisionModel(torch_vision_config) + encoder.eval() + return encoder + + +def copy_qwen3_5_vision_encoder_weights(torch_encoder, jax_encoder): + """Copy weights from PyTorch Qwen3.5 vision encoder to JAX subclassed encoder.""" + # Copy patch embedding + copy_patch_embed_weights(torch_encoder.patch_embed, jax_encoder.patch_embed) + + # Copy positional embedding weights + torch_pos_embed = torch_encoder.pos_embed.weight.detach().cpu().numpy() + jax_encoder.pos_embed_interpolate.pos_embed.value = jnp.array(torch_pos_embed) + + # Copy encoder blocks + for i, torch_block in enumerate(torch_encoder.blocks): + jax_block = getattr(jax_encoder, f"blocks_{i}") + copy_layernorm_weights(torch_block.norm1, jax_block.ln1) + copy_layernorm_weights(torch_block.norm2, jax_block.ln2) + copy_attention_weights_to_maxtext(torch_block.attn, jax_block.attn.attn, fused_qkv=True) + copy_linear_weights(torch_block.mlp.linear_fc1, jax_block.mlp) + copy_linear_weights(torch_block.mlp.linear_fc2, jax_block.mlp_out) + + +def copy_qwen3_5_patch_merger_weights(torch_merger, jax_merger): + """Copy patch merger weights from PyTorch Qwen3.5 to JAX subclassed merger.""" + copy_layernorm_weights(torch_merger.norm, jax_merger.ln_q) + copy_linear_weights(torch_merger.linear_fc1, jax_merger.mlp_0) + copy_linear_weights(torch_merger.linear_fc2, jax_merger.mlp_2) + + +class TestQwen3_5MoeVisionEncoderEndToEnd(unittest.TestCase): + """End-to-end equivalence test for Qwen3.5 Moe Vision Encoder + Projector JAX subclasses.""" + + def setUp(self): + np.random.seed(42) + torch.manual_seed(42) + devices = jax.devices() + self.mesh = Mesh(np.array(devices[:1]), axis_names=("data",)) + + def test_vision_encoder_subclasses_match_torch(self): + """Test full JAX vision subclassed tower matches PyTorch Qwen3.5 vision tower.""" + torch_encoder = create_torch_vision_encoder() + + jax_encoder = JaxQwen3_5MoeVisionEncoder(config=jax_config, mesh=self.mesh, rngs=nnx.Rngs(42)) + jax_projector = JaxQwen3_5MoeVisionProjector(config=jax_config, rngs=nnx.Rngs(43)) + + # Copy weights + copy_qwen3_5_vision_encoder_weights(torch_encoder, jax_encoder) + copy_qwen3_5_patch_merger_weights(torch_encoder.merger, jax_projector.merger) + + patch_size = jax_config.patch_size_for_vit + temporal_patch_size = jax_config.temporal_patch_size_for_vit + in_channels = jax_config.num_channels_for_vit + h, w = 8, 8 # 8x8 patches + + n_patches = h * w + total_elements = n_patches * in_channels * temporal_patch_size * patch_size * patch_size + flat_data, _ = create_random_jax_torch(total_elements) + + # Reshape inputs + jax_hidden_states = flat_data.reshape(1, in_channels, temporal_patch_size, h * patch_size, w * patch_size) + torch_hidden_states = torch.from_numpy( + np.array(flat_data).reshape((n_patches, in_channels, temporal_patch_size, patch_size, patch_size)) + ) + + grid_thw = np.array([[1, h, w]], dtype=np.int64) + grid_thw_torch = torch.from_numpy(grid_thw) + + # PyTorch forward + torch_out = torch_encoder(torch_hidden_states, grid_thw_torch) + torch_output = torch_out.pooler_output # after merger + + # JAX forward + jax_encoder_output, _ = jax_encoder(jax_hidden_states) + jax_output = jax_projector(jax_encoder_output) + jax_output = jax_output[0] + + # Compare final projected outputs. Use 2e-2 (vs 1e-2 for omni) because out_hidden_size=4096 + # (vs 2048 for omni) doubles the projector MLP output, accumulating slightly more float32 error. + assert_all_close_jax_torch( + jax_output, + torch_output, + rtol=1e-2, + atol=2e-2, + error_msg="Qwen3.5 JAX subclassed vision tower final output differs from PyTorch reference", + ) + + +if __name__ == "__main__": + unittest.main()