Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ addopts =
--ignore=tests/unit/profiler_test.py
--ignore=tests/unit/qwen3_omni_layers_test.py
--ignore=tests/unit/qwen3_next_vs_reference_test.py
--ignore=tests/unit/qwen3_5_layers_test.py
--ignore=tests/unit/deepseek32_vs_reference_test.py
--ignore=tests/unit/engram_vs_reference_test.py
markers =
Expand Down
20 changes: 20 additions & 0 deletions src/maxtext/configs/models/qwen3.5-397b-a17b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,23 @@ partial_rotary_factor: 0.25

# General Model Settings
enable_dropout: False

# Vision Encoder Configuration (need to set use_multimodal=true)
# Based on Qwen3.5 MoE Vision Model Config
image_size_for_vit: 768
hidden_size_for_vit: 1152
intermediate_size_for_vit: 4304
num_attention_heads_for_vit: 16
num_hidden_layers_for_vit: 27
num_channels_for_vit: 3
patch_size_for_vit: 16
temporal_patch_size_for_vit: 2
spatial_merge_size_for_vit: 2
out_hidden_size_for_vit: 4096 # Projects to decoder emb_dim (4096)
num_position_embeddings_for_vit: 2304
deepstack_visual_indexes_for_vit: [] # No deepstack for Qwen3.5 VL
rope_theta_for_vit: 10000

# MRoPE Settings (Multi-dimensional RoPE for multimodal)
use_mrope: true
mrope_section: [11, 11, 10]
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2927,6 +2927,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
"llama4-17b-16e",
"llama4-17b-128e",
"qwen3-omni-30b-a3b",
"qwen3.5-397b-a17b",
)
if self.model_name not in valid_mm_models and self.model_name != "default":
raise ValueError(f"Multimodal is only supported for {valid_mm_models}, not {self.model_name}")
Expand Down
6 changes: 4 additions & 2 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,9 @@ def __init__(
self.share_kv_layer = share_kv_layer

self.is_qwen2 = self.config.decoder_block == DecoderBlockType.QWEN2
self.is_qwen3_hybrid = self.config.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5)
self.is_qwen3_hybrid = (
self.config.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5) and not self.is_vision
)

# Module attribute names must match names previously passed to Linen for checkpointing
self.KVCache_0 = (
Expand Down Expand Up @@ -804,7 +806,7 @@ def init_rotary_embedding(self):
rope_type = self.rope_type
rope_use_scale = self.config.rope_use_scale
if self.is_vision:
if self.config.model_name.startswith("qwen3-omni"):
if self.config.model_name.startswith("qwen3-omni") or self.config.model_name.startswith("qwen3.5"):
rotary_embedding = Qwen3OmniMoeVisionRotaryEmbedding(
hidden_size=self.config.hidden_size_for_vit,
num_attention_heads=self.config.num_attention_heads_for_vit,
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ def _apply_embedding(
"llama4-17b-16e",
"llama4-17b-128e",
"qwen3-omni-30b-a3b",
"qwen3.5-397b-a17b",
]:
y = mm_utils.merge_mm_embeddings(
text_embeddings=y,
Expand Down
10 changes: 10 additions & 0 deletions src/maxtext/layers/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ def _setup_vision_encoder_layers(self):
self, projector_name, gemma4_vision.Gemma4VisionProjector(config=self.config, mesh=self.mesh, rngs=self.rngs)
)
return encoder_name, projector_name
elif self.config.model_name in ["qwen3.5-397b-a17b"]:
from maxtext.models import qwen3_5_vision # pylint: disable=import-outside-toplevel

encoder_name = "Qwen3_5MoeVisionEncoder_0"
projector_name = "Qwen3_5MoeVisionProjector_0"
setattr(
self, encoder_name, qwen3_5_vision.Qwen3_5MoeVisionEncoder(config=self.config, mesh=self.mesh, rngs=self.rngs)
)
setattr(self, projector_name, qwen3_5_vision.Qwen3_5MoeVisionProjector(config=self.config, rngs=self.rngs))
return encoder_name, projector_name
else:
raise ValueError(f"No VisionEncoder implemented for {self.config.model_name} yet")

Expand Down
38 changes: 38 additions & 0 deletions src/maxtext/models/qwen3_5_vision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Qwen3.5 Vision model tower NNX subclasses.

These classes subclass the Qwen3-Omni vision tower components to provide
clean class type names (Qwen3_5MoeVision...) in the Flax NNX metadata,
ensuring that the JAX parameter keys stored in checkpoints do not contain
the word 'Omni'.
"""

from maxtext.models.qwen3 import Qwen3OmniMoeVisionEncoder, Qwen3OmniMoeVisionProjector


class Qwen3_5MoeVisionEncoder(Qwen3OmniMoeVisionEncoder):
"""Subclass of Qwen3OmniMoeVisionEncoder for Qwen3.5 VL models.

Inherits all core vision tower layers (patch embedding, position embedding,
rotary embeddings, attention, and transformer blocks) without modification.
"""


class Qwen3_5MoeVisionProjector(Qwen3OmniMoeVisionProjector):
"""Subclass of Qwen3OmniMoeVisionProjector for Qwen3.5 VL models.

Inherits the final projection/merger layers without modification.
"""
16 changes: 8 additions & 8 deletions src/maxtext/multimodal/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def preprocess_mm_data(config):

images = [mm_utils.load_image_from_path(p) for p in config.image_path.split(",")]
processor_outputs = preprocess_mm_data_llama4(images)
elif config.model_name in ["qwen3-omni-30b-a3b"]:
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni # pylint: disable=import-outside-toplevel

processor_outputs = preprocess_mm_data_qwen3_omni(config)
Expand All @@ -68,7 +68,7 @@ def preprocess_image_for_training(image, model_name):
from maxtext.multimodal.processor_llama4 import preprocess_mm_data_llama4 # pylint: disable=import-outside-toplevel

return preprocess_mm_data_llama4(image)
elif model_name in ["qwen3-omni-30b-a3b"]:
elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training # pylint: disable=import-outside-toplevel

return preprocess_mm_data_qwen3_omni_for_training(image)
Expand All @@ -90,7 +90,7 @@ def get_image_offsets(config, processor_output: mm_utils.PreprocessorOutput | No
from maxtext.multimodal.processor_llama4 import get_image_offsets_llama4 # pylint: disable=import-outside-toplevel

return get_image_offsets_llama4(processor_output)
elif config.model_name in ["qwen3-omni-30b-a3b"]:
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import get_mm_offsets_qwen3_omni # pylint: disable=import-outside-toplevel

return get_mm_offsets_qwen3_omni(config, processor_output)
Expand All @@ -112,7 +112,7 @@ def reformat_prompt(prompt, image_placeholder, model_name, num_images, video_pla
from maxtext.multimodal.processor_llama4 import reformat_prompt_llama4 # pylint: disable=import-outside-toplevel

return reformat_prompt_llama4(prompt, image_placeholder, num_images)
elif model_name in ["qwen3-omni-30b-a3b"]:
elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import reformat_prompt_qwen3_omni # pylint: disable=import-outside-toplevel

return reformat_prompt_qwen3_omni(
Expand All @@ -137,7 +137,7 @@ def reformat_response(response, model_name):
elif model_name in ["gemma4-26b", "gemma4-31b", "gemma4-e2b", "gemma4-e4b"]:
formatted_response = f"{response}<end_of_turn>"
return formatted_response
elif model_name in ["qwen3-omni-30b-a3b"]:
elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
formatted_response = f"{response}<|im_end|>"
return formatted_response
else:
Expand All @@ -158,7 +158,7 @@ def prepare_text_for_image_fusion(tokens, config, processor_output=None):
from maxtext.multimodal.processor_llama4 import add_extra_tokens_for_images_llama4 # pylint: disable=import-outside-toplevel

return add_extra_tokens_for_images_llama4(tokens, processor_output)
elif config.model_name in ["qwen3-omni-30b-a3b"]:
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import add_extra_tokens_for_qwen3_omni # pylint: disable=import-outside-toplevel

return add_extra_tokens_for_qwen3_omni(tokens, config, processor_output)
Expand All @@ -181,7 +181,7 @@ def get_dummy_image_shape_for_init(model_name, batch_size=1, num_image_per_seque
from maxtext.multimodal.processor_llama4 import get_dummy_image_shape_for_init_llama4 # pylint: disable=import-outside-toplevel

image_shape = get_dummy_image_shape_for_init_llama4(batch_size, num_image_per_sequence)
elif model_name.startswith("qwen3-omni-30b-a3b"):
elif model_name.startswith("qwen3-omni-30b-a3b") or model_name.startswith("qwen3.5-397b-a17b"):
from maxtext.multimodal.processor_qwen3_omni import get_dummy_image_shape_for_init_qwen3_omni # pylint: disable=import-outside-toplevel

image_shape = get_dummy_image_shape_for_init_qwen3_omni(batch_size)
Expand Down Expand Up @@ -222,7 +222,7 @@ def get_bidirectional_mask_vision(config, decoder_input_tokens):
from maxtext.multimodal.processor_llama4 import LLAMA4_PATCH_TOKEN # pylint: disable=import-outside-toplevel

bidirectional_mask_vision = decoder_input_tokens == LLAMA4_PATCH_TOKEN
elif config.model_name in ["qwen3-omni-30b-a3b"]:
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import QWEN3_OMNI_IMAGE_TOKEN, QWEN3_OMNI_VIDEO_TOKEN # pylint: disable=import-outside-toplevel

# Create bidirectional_mask for vision/video token merging
Expand Down
167 changes: 167 additions & 0 deletions tests/unit/qwen3_5_layers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Qwen3.5 Moe Vision subclasses comparing JAX implementations against PyTorch references."""

import os
import unittest
import numpy as np
import torch
from flax import nnx
import jax
import jax.numpy as jnp
from jax.sharding import Mesh

from maxtext.configs import pyconfig
from maxtext.utils.globals import MAXTEXT_REPO_ROOT

# Explicit, differentiated imports matching Gemma4 layer-wise testing style
from maxtext.models.qwen3_5_vision import (
Qwen3_5MoeVisionEncoder as JaxQwen3_5MoeVisionEncoder,
Qwen3_5MoeVisionProjector as JaxQwen3_5MoeVisionProjector,
)

from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import Qwen3_5MoeVisionConfig
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
Qwen3_5MoeVisionModel as TorchQwen3_5MoeVisionModel,
)

from tests.utils.multimodal_test_utils import (
assert_all_close_jax_torch,
copy_patch_embed_weights,
copy_layernorm_weights,
copy_attention_weights_to_maxtext,
copy_linear_weights,
create_random_jax_torch,
)

# Initialize JAX config cleanly using Qwen3.5-397B VL model registered config
base_config_path = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "base.yml")
jax_config = pyconfig.initialize(
["", base_config_path],
model_name="qwen3.5-397b-a17b",
attention="dot_product",
attention_type="full",
matmul_precision="highest",
dropout_rate=0.0,
dtype="float32",
dtype_mm="float32",
weight_dtype="float32",
float32_logits=True,
float32_qk_product=True,
)

# PyTorch Vision Config: all fields match Qwen3_5MoeVisionConfig defaults for the 397B-A17B vision
# tower except out_hidden_size, which must match the text decoder's base_emb_dim (4096 vs default 2048 of 35B-A3B).
torch_vision_config = Qwen3_5MoeVisionConfig(
out_hidden_size=jax_config.out_hidden_size_for_vit,
attn_implementation="eager",
)

torch.set_grad_enabled(False)


def create_torch_vision_encoder():
"""Create and configure PyTorch Qwen3.5 vision model."""
encoder = TorchQwen3_5MoeVisionModel(torch_vision_config)
encoder.eval()
return encoder


def copy_qwen3_5_vision_encoder_weights(torch_encoder, jax_encoder):
"""Copy weights from PyTorch Qwen3.5 vision encoder to JAX subclassed encoder."""
# Copy patch embedding
copy_patch_embed_weights(torch_encoder.patch_embed, jax_encoder.patch_embed)

# Copy positional embedding weights
torch_pos_embed = torch_encoder.pos_embed.weight.detach().cpu().numpy()
jax_encoder.pos_embed_interpolate.pos_embed.value = jnp.array(torch_pos_embed)

# Copy encoder blocks
for i, torch_block in enumerate(torch_encoder.blocks):
jax_block = getattr(jax_encoder, f"blocks_{i}")
copy_layernorm_weights(torch_block.norm1, jax_block.ln1)
copy_layernorm_weights(torch_block.norm2, jax_block.ln2)
copy_attention_weights_to_maxtext(torch_block.attn, jax_block.attn.attn, fused_qkv=True)
copy_linear_weights(torch_block.mlp.linear_fc1, jax_block.mlp)
copy_linear_weights(torch_block.mlp.linear_fc2, jax_block.mlp_out)


def copy_qwen3_5_patch_merger_weights(torch_merger, jax_merger):
"""Copy patch merger weights from PyTorch Qwen3.5 to JAX subclassed merger."""
copy_layernorm_weights(torch_merger.norm, jax_merger.ln_q)
copy_linear_weights(torch_merger.linear_fc1, jax_merger.mlp_0)
copy_linear_weights(torch_merger.linear_fc2, jax_merger.mlp_2)


class TestQwen3_5MoeVisionEncoderEndToEnd(unittest.TestCase):
"""End-to-end equivalence test for Qwen3.5 Moe Vision Encoder + Projector JAX subclasses."""

def setUp(self):
np.random.seed(42)
torch.manual_seed(42)
devices = jax.devices()
self.mesh = Mesh(np.array(devices[:1]), axis_names=("data",))

def test_vision_encoder_subclasses_match_torch(self):
"""Test full JAX vision subclassed tower matches PyTorch Qwen3.5 vision tower."""
torch_encoder = create_torch_vision_encoder()

jax_encoder = JaxQwen3_5MoeVisionEncoder(config=jax_config, mesh=self.mesh, rngs=nnx.Rngs(42))
jax_projector = JaxQwen3_5MoeVisionProjector(config=jax_config, rngs=nnx.Rngs(43))

# Copy weights
copy_qwen3_5_vision_encoder_weights(torch_encoder, jax_encoder)
copy_qwen3_5_patch_merger_weights(torch_encoder.merger, jax_projector.merger)

patch_size = jax_config.patch_size_for_vit
temporal_patch_size = jax_config.temporal_patch_size_for_vit
in_channels = jax_config.num_channels_for_vit
h, w = 8, 8 # 8x8 patches

n_patches = h * w
total_elements = n_patches * in_channels * temporal_patch_size * patch_size * patch_size
flat_data, _ = create_random_jax_torch(total_elements)

# Reshape inputs
jax_hidden_states = flat_data.reshape(1, in_channels, temporal_patch_size, h * patch_size, w * patch_size)
torch_hidden_states = torch.from_numpy(
np.array(flat_data).reshape((n_patches, in_channels, temporal_patch_size, patch_size, patch_size))
)

grid_thw = np.array([[1, h, w]], dtype=np.int64)
grid_thw_torch = torch.from_numpy(grid_thw)

# PyTorch forward
torch_out = torch_encoder(torch_hidden_states, grid_thw_torch)
torch_output = torch_out.pooler_output # after merger

# JAX forward
jax_encoder_output, _ = jax_encoder(jax_hidden_states)
jax_output = jax_projector(jax_encoder_output)
jax_output = jax_output[0]

# Compare final projected outputs. Use 2e-2 (vs 1e-2 for omni) because out_hidden_size=4096
# (vs 2048 for omni) doubles the projector MLP output, accumulating slightly more float32 error.
assert_all_close_jax_torch(
jax_output,
torch_output,
rtol=1e-2,
atol=2e-2,
error_msg="Qwen3.5 JAX subclassed vision tower final output differs from PyTorch reference",
)


if __name__ == "__main__":
unittest.main()
Loading