Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion defuser/modeling/unfused_moe/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Route tokens exactly like HF Qwen2 MoE, then run explicit expert modules."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_expert_output = self.shared_expert(hidden_states)
_, routing_weights, selected_experts = self.gate(hidden_states)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = run_routed_experts(
Expand All @@ -44,7 +45,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
self.num_experts,
)

shared_expert_output = self.shared_expert(hidden_states)
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output

final_hidden_states = final_hidden_states + shared_expert_output
Expand Down
2 changes: 1 addition & 1 deletion defuser/modeling/unfused_moe/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Route tokens exactly like HF Qwen3-Next MoE, then run explicit experts."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_expert_output = self.shared_expert(hidden_states)
_, routing_weights, selected_experts = self.gate(hidden_states)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = run_routed_experts(
Expand All @@ -43,7 +44,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
self.num_experts,
)

shared_expert_output = self.shared_expert(hidden_states)
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output

final_hidden_states = final_hidden_states + shared_expert_output
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "Defuser"
version = "0.0.20"
version = "0.0.21"
description = "Model defuser helper for HF Transformers."
readme = "README.md"
requires-python = ">=3.9"
Expand Down
92 changes: 92 additions & 0 deletions tests/test_convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,76 @@ def _assert_sparse_moe_defused_matches_fused_math(
torch.testing.assert_close(actual, expected, **assert_close_kwargs)


def _force_route_all_experts(block: nn.Module) -> None:
"""Set MoE routers to select all experts so execution-order hooks always fire."""

router = getattr(block, "gate", None)
num_experts = getattr(block, "num_experts", None)
if router is None or num_experts is None:
return

for name in ("top_k", "num_experts_per_tok"):
if hasattr(router, name):
setattr(router, name, num_experts)
return


def _semantic_sparse_moe_execution_order(block: nn.Module, hidden_states: torch.Tensor) -> list[str]:
    """Run *block* once and return the order its MoE sub-modules were invoked.

    Forward-pre hooks record the shared expert, router gate, routed experts,
    and shared-expert gate; consecutive expert invocations are collapsed into
    a single ``"routed_experts"`` event so fused and defused blocks compare
    on semantics rather than expert count.
    """
    _force_route_all_experts(block)
    events: list[str] = []
    hook_handles = []

    def _recorder(label: str):
        # Forward-pre hooks fire just before the module runs; we only need
        # the ordering, so inputs are ignored.
        def _pre_hook(_mod, _args):
            events.append(label)

        return _pre_hook

    for attr in ("shared_expert", "gate", "shared_expert_gate"):
        if hasattr(block, attr):
            hook_handles.append(getattr(block, attr).register_forward_pre_hook(_recorder(attr)))

    expert_container = getattr(block, "experts", None)
    if isinstance(expert_container, nn.ModuleList):
        for idx, expert_module in enumerate(expert_container):
            hook_handles.append(expert_module.register_forward_pre_hook(_recorder(f"expert_{idx}")))
    elif isinstance(expert_container, nn.Module):
        hook_handles.append(expert_container.register_forward_pre_hook(_recorder("experts")))

    try:
        with torch.inference_mode():
            block.eval()(hidden_states)
    finally:
        # Always detach the hooks so the block is reusable after this call.
        for handle in hook_handles:
            handle.remove()

    # Collapse per-expert events into one semantic "routed_experts" marker and
    # drop immediate repeats.
    collapsed: list[str] = []
    for label in events:
        semantic = "routed_experts" if label == "experts" or label.startswith("expert_") else label
        if not collapsed or collapsed[-1] != semantic:
            collapsed.append(semantic)
    return collapsed


def _assert_sparse_moe_defused_matches_fused_execution_order(
    original_block: nn.Module,
    defused_block: nn.Module,
    hidden_states: torch.Tensor,
) -> None:
    """Assert a defused block replays the fused HF block's semantic module order.

    Seeds the fused block's floating tensors, mirrors its weights into the
    defused block, then compares the recorded execution-order traces.
    """
    _seed_floating_tensors(original_block)
    _copy_sparse_moe_weights(original_block, defused_block)

    fused_order = _semantic_sparse_moe_execution_order(original_block, hidden_states)
    defused_order = _semantic_sparse_moe_execution_order(defused_block, hidden_states)
    assert defused_order == fused_order


def test_qwen2_moe():
model_type = "qwen2_moe"
replace_fused_blocks(model_type)
Expand Down Expand Up @@ -858,6 +928,17 @@ def test_qwen2_moe_defused_forward_matches_fused_math():
)


def test_qwen2_moe_defused_forward_matches_fused_execution_order():
    """Qwen2-MoE: defused block must mirror the fused block's execution order."""
    config = _tiny_moe_config(Qwen2MoeConfig)
    # Small (batch=2, seq=3) activations keep the trace comparison fast.
    sample = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)

    fused_block = Qwen2MoeSparseMoeBlock(config)
    defused_block = LinearQwen2MoeSparseMoeBlock(config)
    _assert_sparse_moe_defused_matches_fused_execution_order(fused_block, defused_block, sample)


def test_qwen3_moe_defused_forward_matches_fused_math():
config = _tiny_moe_config(Qwen3MoeConfig)
hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)
Expand All @@ -880,6 +961,17 @@ def test_qwen3_next_defused_forward_matches_fused_math():
)


def test_qwen3_next_defused_forward_matches_fused_execution_order():
    """Qwen3-Next: defused block must mirror the fused block's execution order."""
    config = _tiny_moe_config(Qwen3NextConfig)
    # Small (batch=2, seq=3) activations keep the trace comparison fast.
    sample = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)

    fused_block = Qwen3NextSparseMoeBlock(config)
    defused_block = LinearQwen3NextSparseMoeBlock(config)
    _assert_sparse_moe_defused_matches_fused_execution_order(fused_block, defused_block, sample)


def test_qwen3_omni_defused_forward_matches_fused_math():
config = _tiny_qwen3_omni_config().thinker_config.text_config
hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)
Expand Down