173 changes: 173 additions & 0 deletions tests/integration/model_bridge/test_qwen3_moe_bridge.py
@@ -0,0 +1,173 @@
"""Integration tests for the Qwen3MoE TransformerBridge.

Uses a tiny programmatic config on the meta device — no network access or
weight downloads. Tensor ops can't execute on meta, so forward-pass tests are
skipped and run manually during verification. Fixture pattern mirrors
tests/unit/model_bridge/test_gpt_oss_moe.py.
"""

import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.model_bridge.bridge import TransformerBridge
from transformer_lens.model_bridge.generalized_components import MoEBridge
from transformer_lens.model_bridge.sources.transformers import (
map_default_transformer_lens_config,
)
from transformer_lens.model_bridge.supported_architectures.qwen3_moe import (
Qwen3MoeArchitectureAdapter,
)


class _MockTokenizer:
    """Stand-in to satisfy TransformerBridge(tokenizer=...)."""


@pytest.fixture(scope="module")
def tiny_qwen3moe_config():
"""Small Qwen3MoeConfig: 2 layers, 4 heads, 4 experts."""
return AutoConfig.for_model(
"qwen3_moe",
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
head_dim=16,
moe_intermediate_size=32,
num_experts=4,
num_experts_per_tok=2,
vocab_size=256,
max_position_embeddings=128,
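        # decoder_sparse_step=1 + empty mlp_only_layers => every layer is sparse MoE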
decoder_sparse_step=1,
mlp_only_layers=[],
)


@pytest.fixture(scope="module")
def tiny_qwen3moe_model_meta(tiny_qwen3moe_config):
"""Qwen3MoE model on meta device (no weights loaded)."""
with torch.device("meta"):
model = AutoModelForCausalLM.from_config(tiny_qwen3moe_config)
return model


@pytest.fixture(scope="module")
def tiny_qwen3moe_bridge(tiny_qwen3moe_config, tiny_qwen3moe_model_meta):
"""TransformerBridge wrapping the tiny meta-device Qwen3MoE model."""
tl_config = map_default_transformer_lens_config(tiny_qwen3moe_config)

bridge_config = TransformerBridgeConfig(
d_model=tl_config.d_model,
d_head=tl_config.d_head,
n_layers=tl_config.n_layers,
n_ctx=tl_config.n_ctx,
n_heads=tl_config.n_heads,
n_key_value_heads=tl_config.n_key_value_heads,
d_vocab=tl_config.d_vocab,
architecture="Qwen3MoeForCausalLM",
)

adapter = Qwen3MoeArchitectureAdapter(bridge_config)

return TransformerBridge(
model=tiny_qwen3moe_model_meta,
adapter=adapter,
tokenizer=_MockTokenizer(),
)


class TestQwen3MoeModelStructure:
def test_model_has_layers(self, tiny_qwen3moe_model_meta) -> None:
assert hasattr(tiny_qwen3moe_model_meta, "model")
assert hasattr(tiny_qwen3moe_model_meta.model, "layers")
assert len(tiny_qwen3moe_model_meta.model.layers) == 2

def test_layer_has_sparse_moe_block(self, tiny_qwen3moe_model_meta) -> None:
# Qwen3MoeSparseMoeBlock stores experts as batched 3D tensors, not a ModuleList
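        # (Assumed layout in recent transformers releases: gate_up_proj is one 3D
        # Parameter, e.g. (num_experts, hidden_size, 2 * moe_intermediate_size),
        # with down_proj the reverse projection. The asserts below only check the
        # attributes exist, not their shapes.)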
layer0_mlp = tiny_qwen3moe_model_meta.model.layers[0].mlp
assert hasattr(layer0_mlp, "experts")
experts = layer0_mlp.experts
assert hasattr(experts, "gate_up_proj")
assert hasattr(experts, "down_proj")
assert not hasattr(experts, "__iter__")

def test_layer_has_gate_router(self, tiny_qwen3moe_model_meta) -> None:
layer0_mlp = tiny_qwen3moe_model_meta.model.layers[0].mlp
assert hasattr(layer0_mlp, "gate")

def test_attention_has_q_norm_k_norm(self, tiny_qwen3moe_model_meta) -> None:
attn = tiny_qwen3moe_model_meta.model.layers[0].self_attn
assert hasattr(attn, "q_norm")
assert hasattr(attn, "k_norm")


class TestQwen3MoeBridgeStructure:
def test_block_count(self, tiny_qwen3moe_bridge) -> None:
assert len(tiny_qwen3moe_bridge.blocks) == 2

def test_has_core_components(self, tiny_qwen3moe_bridge) -> None:
assert hasattr(tiny_qwen3moe_bridge, "embed")
assert hasattr(tiny_qwen3moe_bridge, "unembed")
assert hasattr(tiny_qwen3moe_bridge, "ln_final")

def test_cfg_final_rms_is_true(self, tiny_qwen3moe_bridge) -> None:
"""Qwen3MoE uses final_rms=True; OLMoE uses False."""
assert tiny_qwen3moe_bridge.cfg.final_rms is True

def test_cfg_n_kv_heads(self, tiny_qwen3moe_bridge) -> None:
assert tiny_qwen3moe_bridge.cfg.n_key_value_heads == 2

def test_cfg_positional_embedding_type(self, tiny_qwen3moe_bridge) -> None:
assert tiny_qwen3moe_bridge.cfg.positional_embedding_type == "rotary"

def test_cfg_normalization_type(self, tiny_qwen3moe_bridge) -> None:
assert tiny_qwen3moe_bridge.cfg.normalization_type == "RMS"

def test_mlp_blocks_are_moe_bridge(self, tiny_qwen3moe_bridge) -> None:
for i, block in enumerate(tiny_qwen3moe_bridge.blocks):
assert isinstance(
block.mlp, MoEBridge
), f"Block {i} mlp is {type(block.mlp).__name__}, expected MoEBridge"

def test_moe_bridge_has_router_scores_hook(self, tiny_qwen3moe_bridge) -> None:
mlp = tiny_qwen3moe_bridge.blocks[0].mlp
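        # The hook is expected to expose per-token router scores (assumed shape
        # roughly [n_tokens, num_experts]); on meta we can only check it exists.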
assert hasattr(mlp, "hook_router_scores")

def test_block_has_ln1_and_ln2(self, tiny_qwen3moe_bridge) -> None:
block = tiny_qwen3moe_bridge.blocks[0]
assert hasattr(block, "ln1")
assert hasattr(block, "ln2")

def test_block_attn_has_q_norm_k_norm(self, tiny_qwen3moe_bridge) -> None:
attn = tiny_qwen3moe_bridge.blocks[0].attn
assert hasattr(attn, "q_norm")
assert hasattr(attn, "k_norm")


# Forward-pass tests require real weights — meta-device tensor ops raise
# NotImplementedError. Run these manually during Step 3 verification.
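#
# A minimal manual-run sketch (checkpoint name is illustrative — substitute the
# real Qwen3MoE checkpoint used in Step 3; assumes the
# TransformerBridge.boot_transformers entry point):
#
#     bridge = TransformerBridge.boot_transformers("Qwen/Qwen3-30B-A3B")
#     logits, cache = bridge.run_with_cache(torch.tensor([[1, 2, 3, 4]]))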


@pytest.mark.skip(reason="Requires real weights — run manually during verification")
def test_forward_pass_matches_hf(tiny_qwen3moe_bridge) -> None:
"""Bridge logits match the HF model."""
tokens = torch.tensor([[1, 2, 3, 4]])
with torch.no_grad():
bridge_out = tiny_qwen3moe_bridge(tokens)
hf_out = tiny_qwen3moe_bridge.original_model(tokens).logits
max_diff = (bridge_out - hf_out).abs().max().item()
assert max_diff < 1e-4, f"Bridge vs HF max diff = {max_diff}"


@pytest.mark.skip(reason="Requires real weights — run manually during verification")
def test_run_with_cache_captures_moe_router_scores(tiny_qwen3moe_bridge) -> None:
"""MoEBridge captures router scores in the activation cache."""
tiny_qwen3moe_bridge.enable_compatibility_mode(no_processing=True)
tokens = torch.tensor([[1, 2, 3, 4]])
_, cache = tiny_qwen3moe_bridge.run_with_cache(tokens)
for i in range(len(tiny_qwen3moe_bridge.blocks)):
assert f"blocks.{i}.mlp.hook_router_scores" in cache, f"Missing router scores for block {i}"
194 changes: 194 additions & 0 deletions tests/unit/model_bridge/test_qwen3_moe_adapter.py
@@ -0,0 +1,194 @@
"""Unit tests for the Qwen3MoeArchitectureAdapter.

All tests use programmatic TransformerBridgeConfig instances — no network access
or model downloads.
"""

import pytest

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.conversion_utils.conversion_steps.rearrange_tensor_conversion import (
RearrangeTensorConversion,
)
from transformer_lens.conversion_utils.param_processing_conversion import (
ParamProcessingConversion,
)
from transformer_lens.factories.architecture_adapter_factory import (
SUPPORTED_ARCHITECTURES,
)
from transformer_lens.model_bridge.generalized_components import (
MoEBridge,
RMSNormalizationBridge,
)
from transformer_lens.model_bridge.supported_architectures.qwen3_moe import (
Qwen3MoeArchitectureAdapter,
)


@pytest.fixture
def cfg() -> TransformerBridgeConfig:
return TransformerBridgeConfig(
d_model=64,
d_head=16,
n_layers=2,
n_ctx=128,
n_heads=4,
n_key_value_heads=2,
d_vocab=256,
architecture="Qwen3MoeForCausalLM",
)


@pytest.fixture
def adapter(cfg: TransformerBridgeConfig) -> Qwen3MoeArchitectureAdapter:
return Qwen3MoeArchitectureAdapter(cfg)


class TestQwen3MoeAdapterConfig:
def test_normalization_type_is_rms(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
assert adapter.cfg.normalization_type == "RMS"

def test_positional_embedding_type_is_rotary(
self, adapter: Qwen3MoeArchitectureAdapter
) -> None:
assert adapter.cfg.positional_embedding_type == "rotary"

def test_final_rms_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
"""Qwen3MoE uses final_rms=True; OLMoE uses False."""
assert adapter.cfg.final_rms is True

def test_gated_mlp_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
assert adapter.cfg.gated_mlp is True

def test_uses_rms_norm_is_true(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
assert adapter.cfg.uses_rms_norm is True

def test_attn_implementation_is_eager(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
assert adapter.cfg.attn_implementation == "eager"

def test_default_prepend_bos_is_false(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
assert adapter.cfg.default_prepend_bos is False

def test_n_kv_heads_propagated(self) -> None:
"""n_key_value_heads from the loaded config is preserved."""
cfg = TransformerBridgeConfig(
d_model=64,
d_head=16,
n_layers=2,
n_ctx=128,
n_heads=4,
n_key_value_heads=2,
d_vocab=256,
architecture="Qwen3MoeForCausalLM",
)
adapter = Qwen3MoeArchitectureAdapter(cfg)
assert adapter.cfg.n_key_value_heads == 2


class TestQwen3MoeWeightConversions:
def test_has_qkvo_keys(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
convs = adapter.weight_processing_conversions
assert convs is not None
assert "blocks.{i}.attn.q.weight" in convs
assert "blocks.{i}.attn.k.weight" in convs
assert "blocks.{i}.attn.v.weight" in convs
assert "blocks.{i}.attn.o.weight" in convs

def test_q_rearrange_uses_n_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
"""Q rearrange uses n_heads (4)."""
convs = adapter.weight_processing_conversions
assert convs is not None
q_conv = convs["blocks.{i}.attn.q.weight"]
assert isinstance(q_conv, ParamProcessingConversion)
assert isinstance(q_conv.tensor_conversion, RearrangeTensorConversion)
axes = q_conv.tensor_conversion.axes_lengths
assert axes.get("n") == 4

def test_kv_rearrange_uses_n_kv_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
"""K/V rearrange uses n_key_value_heads (2) for GQA."""
convs = adapter.weight_processing_conversions
assert convs is not None
k_conv = convs["blocks.{i}.attn.k.weight"]
v_conv = convs["blocks.{i}.attn.v.weight"]
assert isinstance(k_conv, ParamProcessingConversion)
assert isinstance(v_conv, ParamProcessingConversion)
assert isinstance(k_conv.tensor_conversion, RearrangeTensorConversion)
assert isinstance(v_conv.tensor_conversion, RearrangeTensorConversion)
assert k_conv.tensor_conversion.axes_lengths.get("n") == 2
assert v_conv.tensor_conversion.axes_lengths.get("n") == 2

def test_o_rearrange_uses_n_heads(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
"""O rearrange uses n_heads (4)."""
convs = adapter.weight_processing_conversions
assert convs is not None
o_conv = convs["blocks.{i}.attn.o.weight"]
assert isinstance(o_conv, ParamProcessingConversion)
assert isinstance(o_conv.tensor_conversion, RearrangeTensorConversion)
assert o_conv.tensor_conversion.axes_lengths.get("n") == 4


class TestQwen3MoeComponentMapping:
def test_has_required_top_level_keys(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert mapping is not None
for key in ("embed", "rotary_emb", "blocks", "ln_final", "unembed"):
assert key in mapping, f"Missing top-level key: {key!r}"

def test_blocks_has_required_submodules(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert mapping is not None
blocks = mapping["blocks"]
for key in ("ln1", "ln2", "attn", "mlp"):
assert key in blocks.submodules, f"Missing blocks submodule: {key!r}"

def test_attn_has_all_submodules(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert mapping is not None
attn = mapping["blocks"].submodules["attn"]
for key in ("q", "k", "v", "o", "q_norm", "k_norm"):
assert key in attn.submodules, f"Missing attn submodule: {key!r}"

def test_ln1_ln2_are_rms_norm_bridges(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert mapping is not None
subs = mapping["blocks"].submodules
assert isinstance(subs["ln1"], RMSNormalizationBridge)
assert isinstance(subs["ln2"], RMSNormalizationBridge)

def test_mlp_is_moe_bridge(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert mapping is not None
mlp = mapping["blocks"].submodules["mlp"]
assert isinstance(mlp, MoEBridge)

def test_mlp_has_gate_submodule(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert mapping is not None
mlp = mapping["blocks"].submodules["mlp"]
assert "gate" in mlp.submodules

def test_q_norm_k_norm_are_rms_norm_bridges(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
mapping = adapter.component_mapping
assert mapping is not None
attn_subs = mapping["blocks"].submodules["attn"].submodules
assert isinstance(attn_subs["q_norm"], RMSNormalizationBridge)
assert isinstance(attn_subs["k_norm"], RMSNormalizationBridge)

def test_hf_module_paths(self, adapter: Qwen3MoeArchitectureAdapter) -> None:
"""HF module path names are mapped correctly."""
mapping = adapter.component_mapping
assert mapping is not None
assert mapping["embed"].name == "model.embed_tokens"
assert mapping["ln_final"].name == "model.norm"
assert mapping["unembed"].name == "lm_head"
assert mapping["blocks"].name == "model.layers"
subs = mapping["blocks"].submodules
assert subs["ln1"].name == "input_layernorm"
assert subs["ln2"].name == "post_attention_layernorm"
assert subs["attn"].name == "self_attn"
assert subs["mlp"].name == "mlp"


class TestQwen3MoeFactoryRegistration:
def test_factory_lookup_returns_adapter_class(self) -> None:
assert SUPPORTED_ARCHITECTURES["Qwen3MoeForCausalLM"] is Qwen3MoeArchitectureAdapter
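

# Consumer-side usage mirrors the assert above: resolve the adapter class by the
# HF architecture string, then construct it with a TransformerBridgeConfig:
#
#     adapter_cls = SUPPORTED_ARCHITECTURES["Qwen3MoeForCausalLM"]
#     adapter = adapter_cls(cfg)  # cfg as in the fixture at the top of this file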
2 changes: 2 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
@@ -44,6 +44,7 @@
PhiArchitectureAdapter,
Qwen2ArchitectureAdapter,
Qwen3ArchitectureAdapter,
Qwen3MoeArchitectureAdapter,
Qwen3NextArchitectureAdapter,
QwenArchitectureAdapter,
StableLmArchitectureAdapter,
@@ -91,6 +92,7 @@
"QwenForCausalLM": QwenArchitectureAdapter,
"Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
"Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
"Qwen3MoeForCausalLM": Qwen3MoeArchitectureAdapter,
"Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter,
"StableLmForCausalLM": StableLmArchitectureAdapter,
"T5ForConditionalGeneration": T5ArchitectureAdapter,