
[NVBug 6108145] Fix PTQ calibration and export for fused-experts MoE (Qwen3.5-MoE VLM) #1340

Merged
meenchen merged 6 commits into main from weimingc/fix_qwen36_moe_ptq on Apr 29, 2026

Conversation

@meenchen
Contributor

@meenchen meenchen commented Apr 24, 2026

What does this PR do?

Type of change: Bug fix

Fixes a 4-bug cascade that caused silent PTQ failure on Qwen3.5-MoE VLMs (Qwen3.6-35B-A3B): calibration
appeared to succeed but produced token salad at inference. Root cause: HF's @use_experts_implementation
dispatches the experts' forward pass to torch._grouped_mm / torch.bmm, bypassing the F.linear hook that captures
activations — so gate_up_proj_input_quantizer / down_proj_input_quantizer were never calibrated and no input_scale
tensors were emitted.

Changes:

  • examples/llm_ptq/hf_ptq.py — force config._experts_implementation = "eager" (recursing into text_config /
    vision_config / …) so per-expert F.linear calls are visible to the calibration hook (see the first sketch
    after this list).
  • modelopt/torch/quantization/conversion.py — normalize plural ModuleList quantizer names (weight_quantizers.N
    → weight_quantizer) before fnmatch, so wildcards like *mlp.experts*weight_quantizer match fused-expert
    quantizers (see the second sketch after this list).
  • modelopt/torch/export/unified_export_hf.py — hoist the _QuantFusedExperts export branch above the
    get_quantization_format() gate so _export_fused_experts() runs even when the top-level format query returns
    QUANTIZATION_NONE (which happens for experts-only recipes).
  • modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml — layerwise: false (the VLM's nested layer
    structure breaks the layerwise walker).
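
Sketch 1 (eager-forcing step): a minimal illustration of the idea, assuming HF configs expose
_experts_implementation and nest sub-configs under text_config / vision_config / audio_config / speech_config;
the actual plugin (force_eager_experts_impl_on_the_fly) may differ in detail.

    # Hypothetical sketch: recursively force the eager experts implementation so each expert
    # runs through F.linear (and the calibration hook) instead of the fused
    # torch._grouped_mm / torch.bmm path. Names mirror the PR description, not the real code.
    def force_eager_experts(config) -> None:
        seen: set[int] = set()  # guard against cyclic config graphs

        def _walk(cfg) -> None:
            if cfg is None or id(cfg) in seen:
                return
            seen.add(id(cfg))
            if hasattr(cfg, "_experts_implementation"):
                cfg._experts_implementation = "eager"
            for name in ("text_config", "vision_config", "audio_config", "speech_config"):
                _walk(getattr(cfg, name, None))

        _walk(config)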
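
Sketch 2 (name normalization): a minimal sketch of the normalize-then-match idea, assuming the rule is simply
"strip the trailing expert index and the plural s"; the real _normalize_fused_experts_quantizer_name and
_match_quantizer in conversion.py may cover more cases.

    import fnmatch
    import re

    # Hypothetical sketch: map an indexed per-expert quantizer name such as
    # "model.layers.0.mlp.experts.gate_up_proj_weight_quantizers.17" onto the singular form
    # "...gate_up_proj_weight_quantizer" so wildcard rules written against the singular
    # name still match.
    def normalize_fused_experts_quantizer_name(name: str) -> str:
        return re.sub(r"_quantizers\.\d+$", "_quantizer", name)

    def matches(rule: str, name: str) -> bool:
        # Try both the original and the normalized name, as the PR description says.
        return fnmatch.fnmatch(name, rule) or fnmatch.fnmatch(
            normalize_fused_experts_quantizer_name(name), rule
        )

    # A wildcard written for the singular name now also matches the indexed plural form.
    assert matches(
        "*mlp.experts*weight_quantizer",
        "model.layers.0.mlp.experts.gate_up_proj_weight_quantizers.17",
    )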

Usage

  python examples/llm_ptq/hf_ptq.py \
      --pyt_ckpt_path Qwen/Qwen3.6-35B-A3B \
      --qformat nvfp4 \
      --kv_cache_qformat fp8 \
      --calib_size 512 \
      --export_path Qwen3.6-35B-A3B-NVFP4

Testing


End-to-end PTQ → vLLM deploy → NEL eval on Qwen3.6-35B-A3B (256 experts × 40 layers, 35B params):

Hook-call diagnostic: per-expert F.linear calls during calibration went from 0 to 6720 after the fix, and
input_scale tensors emitted in the exported checkpoint went from 0 to 30720.
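
To reproduce the second number, a quick check like the following can count the emitted scales (a hypothetical
verification snippet; it assumes a standard *.safetensors HF export layout and the export path from the Usage
section above):

    from pathlib import Path

    from safetensors import safe_open

    # Count input_scale tensors across all shards of the exported checkpoint.
    export_dir = Path("Qwen3.6-35B-A3B-NVFP4")
    count = 0
    for shard in sorted(export_dir.glob("*.safetensors")):
        with safe_open(shard, framework="pt") as f:
            count += sum(1 for key in f.keys() if key.endswith("input_scale"))
    print(f"{count} input_scale tensors")  # 0 before the fix, 30720 after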

FP8 fused-MoE path still produces gibberish — separate follow-up (vLLM per-expert weight_scale handling).

  • vLLM full-FP8: the FlashInfer TRTLLM Fp8MoE loader doesn't stack the 256 per-expert scalar weight_scale tensors
    into a [num_experts] per-expert vector — it ends up applying one expert's scale across all 256, so every
    routed expert dequantizes with the wrong amplitude and the coherent token stream collapses into multilingual
    gibberish.

  • SGLang full-FP8: qwen3_5.py::_make_packed_weight_loader rejects the load with AssertionError: Unexpected scalar
    for tuple shard load: loaded_shard_id=(0,1,2), split_sizes=[1,1,1] — its packed loader has no path for "N
    independent per-tensor source scalars combining into one fused-shard parameter," so the fused QKV (or
    in_proj_qkvz) load is structurally refused and the model never finishes loading.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Better fused-expert export flow, a plugin to force eager expert execution during calibration/export, and a representative quantizer discovery utility.
  • Bug Fixes

    • Reliable matching/discovery of per-expert indexed quantizers enabling correct calibration and mixed-precision export; fixes for calibration in nested decoder layouts.
  • Documentation

    • Clarified PTQ config guidance on layerwise calibration.
  • Tests

    • Added fused-experts calibration, export, and name-normalization tests.

@meenchen meenchen requested review from a team as code owners April 24, 2026 05:17
@meenchen meenchen requested review from cjluo-nv and realAsma April 24, 2026 05:17
@copy-pr-bot

copy-pr-bot Bot commented Apr 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@meenchen meenchen self-assigned this Apr 24, 2026
@coderabbitai
Contributor

coderabbitai Bot commented Apr 24, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Adds end-to-end support for HuggingFace fused-experts in PTQ: normalizes per-expert quantizer names for wildcard matching, forces eager experts implementation at runtime, discovers representative weight quantizers, moves fused-expert export earlier, updates a recipe, and adds tests covering calibration and export.

Changes

Cohort / File(s) Summary
Exporter changes
modelopt/torch/export/unified_export_hf.py
Moves fused-expert export handling earlier in the quantized-module loop; exports fused experts immediately under an FSDP2-aware weight update (reshard=False) to ensure fused-expert splitting runs even when quantization-format checks would skip it.
Quantizer name matching
modelopt/torch/quantization/conversion.py
Adds _normalize_fused_experts_quantizer_name and updates _match_quantizer() to test both original and normalized names so wildcard rules match per-expert indexed quantizers (e.g., ...weight_quantizers.<N>).
Representative quantizer discovery
modelopt/torch/quantization/utils/core_utils.py, modelopt/torch/quantization/utils/__init__.py
Adds representative_weight_quantizer(module, weight_name) and exports it via __all__; refactors weight_attr_names to use the representative quantizer so ModuleList-style fused-expert quantizers are recognized.
Export quant utilities
modelopt/torch/export/quant_utils.py
Replaces direct attribute lookups with representative_weight_quantizer(...) for deriving weight block sizes and per-layer quantizer extraction, enabling fused-expert ModuleList quantizers to be discovered during export.
HuggingFace plugin
modelopt/torch/quantization/plugins/huggingface.py
Implements _QuantFusedExperts.fold_weight(keep_attrs) to fold per-expert quantizers into fused weight slices; adds force_eager_experts_impl_on_the_fly(model) that sets _experts_implementation = "eager" recursively on model/configs and registers it in CUSTOM_MODEL_PLUGINS.
Recipe tweak
modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml
Disables layerwise calibration (quantize.algorithm.layerwise: false) and documents a note about VLM decoder layer discovery limitations.
Tests
tests/unit/torch/quantization/plugins/test_fused_experts.py
Adds tests for force_eager_experts_impl_on_the_fly, _normalize_fused_experts_quantizer_name, end-to-end PTQ calibration for fused-experts (ensuring per-expert amax values), and mixed-precision export coverage for ModuleList-style expert quantizers.

Sequence Diagram(s)

sequenceDiagram
  participant Model
  participant Plugin as force_eager_experts_impl_on_the_fly
  participant Calibrator as PTQ Calibrator
  participant Exporter as Unified Exporter

  Model->>Plugin: detect fused-experts
  Plugin-->>Model: set _experts_implementation = "eager" (recursively)
  Model->>Calibrator: run calibration
  Calibrator->>Model: invoke _QuantFusedExperts.forward (eager path)
  Calibrator->>Calibrator: collect per-expert quantizer amax via hooks
  Model->>Exporter: request export
  Exporter->>Exporter: resolve representative_weight_quantizer for weight names
  Exporter->>Exporter: detect fused-experts early
  Exporter->>Exporter: call _export_fused_experts under fsdp2_aware_weight_update
  Exporter-->>Caller: produce export metadata (includes fused-experts quant spec)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 55.81% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title directly and accurately describes the main change: fixing PTQ calibration and export for fused-experts MoE models, with a specific bug reference and model example.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed Pull request contains no critical security anti-patterns as defined in SECURITY.md. All modified Python files follow secure coding practices.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Contributor

github-actions Bot commented Apr 24, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-04-29 07:01 UTC

@meenchen meenchen requested a review from sychen52 April 24, 2026 05:21
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 649-657: The elif branch that checks for the same attribute
gate_up_proj_weight_quantizers is dead code because the preceding if block
handles that case and continues; remove the unreachable elif block (the second
check for gate_up_proj_weight_quantizers and its body) so only the initial
handling using fsdp2_aware_weight_update and _export_fused_experts(sub_module,
dtype) remains, leaving no duplicate checks for gate_up_proj_weight_quantizers
in the loop.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 6c028304-1560-45eb-bcba-de04e7c03a20

📥 Commits

Reviewing files that changed from the base of the PR and between 5887410 and 7c6c132.

📒 Files selected for processing (4)
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/quantization/conversion.py
  • modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml

@codecov

codecov Bot commented Apr 24, 2026

Codecov Report

❌ Patch coverage is 87.23404% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.96%. Comparing base (8eec6d4) to head (fda1e20).
⚠️ Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/export/plugins/vllm_fakequant_hf.py 67.85% 9 Missing ⚠️
modelopt/torch/quantization/plugins/huggingface.py 90.00% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1340      +/-   ##
==========================================
+ Coverage   76.93%   76.96%   +0.03%     
==========================================
  Files         471      471              
  Lines       50404    50482      +78     
==========================================
+ Hits        38776    38855      +79     
+ Misses      11628    11627       -1     
Flag Coverage Δ
examples 41.58% <39.36%> (+0.90%) ⬆️
gpu 59.71% <86.17%> (-0.45%) ⬇️
regression 14.90% <10.63%> (+0.20%) ⬆️
unit 52.74% <51.06%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.


@meenchen meenchen added bug Something isn't working cherry-pick-0.44.0 After code freeze, cherry-pick to release branch for next rc (bulk update). Only for bug fixes / doc labels Apr 24, 2026
Collaborator

@cjluo-nv cjluo-nv left a comment


This is a well-structured bug fix with good test coverage. The 4-part fix (eager impl forcing, quantizer name normalization, export path hoisting, YAML layerwise change) is logically coherent and well-tested. However, there's a duplicate code path in unified_export_hf.py that should be cleaned up.

Design: Despite the complexity gate firing, this is a bug fix within existing systems, not an architectural change. No new abstractions are introduced.

Tests: Comprehensive — covers force_eager_experts_impl_on_the_fly edge cases, and an end-to-end calibration test that guards the full pipeline (name normalization → wildcard matching → amax collection). Good.

Issue: The new early-exit block in _process_quantized_modules makes the existing elif hasattr(sub_module, "gate_up_proj_weight_quantizers") block (deeper in the same function, inside the get_quantization_format() != QUANTIZATION_NONE guard) dead code. One of these should be removed.

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (2)
modelopt/torch/quantization/plugins/huggingface.py (1)

1441-1449: Harden recursive config traversal with a cycle guard.

The recursive _force() walk can loop forever on cyclic config graphs. Add a visited-set guard to make this robust.

Proposed patch
     nested_cfg_attrs = ("text_config", "vision_config", "audio_config", "speech_config")
+    visited_cfg_ids = set()
 
     def _force(cfg):
         if cfg is None:
             return
+        cfg_id = id(cfg)
+        if cfg_id in visited_cfg_ids:
+            return
+        visited_cfg_ids.add(cfg_id)
         if hasattr(cfg, "_experts_implementation"):
             cfg._experts_implementation = "eager"
         for sub in nested_cfg_attrs:
             if hasattr(cfg, sub):
                 _force(getattr(cfg, sub))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 1441 - 1449,
The recursive helper _force can loop on cyclic config graphs; modify _force to
accept and maintain a visited set (e.g., of object ids) and skip recursing into
cfg instances already seen, so each cfg is processed only once. Specifically,
update the _force signature to take an optional visited set, add the current cfg
to visited (use id(cfg) or cfg itself), return early if already visited, and
keep the existing behavior of setting cfg._experts_implementation = "eager" and
iterating nested_cfg_attrs; ensure recursive calls pass the same visited set.
This change hardens _force and prevents infinite recursion on cycles while still
touching the same symbols (_force, nested_cfg_attrs,
cfg._experts_implementation).
tests/unit/torch/quantization/plugins/test_fused_experts.py (1)

385-441: Make registry cleanup exception-safe in the calibration test.

If an assertion fails before the last line, the temporary registry entry may leak into subsequent tests.

Proposed patch
     def test_calibration_populates_all_expert_quantizers(self):
         """After PTQ, every input/weight quantizer on the fused-experts module has amax set."""
         import modelopt.torch.quantization as mtq
 
         model = _TinyMoEModel()
         expert_type = type(model.moe.experts)
         self._cleanup_registry(expert_type)
-
-        quant_cfg = {
+        try:
+            quant_cfg = {
             "quant_cfg": [
                 {"quantizer_name": "*", "enable": False},
                 {
                     "quantizer_name": "*gate_up_proj_input_quantizer",
                     "cfg": {"num_bits": 8, "axis": None},
@@
         for idx in range(NUM_EXPERTS):
             assert experts.gate_up_proj_weight_quantizers[idx].amax is not None, (
                 f"gate_up_proj_weight_quantizers[{idx}].amax is None — "
                 "plural ModuleList name normalization in _match_quantizer likely broken."
             )
             assert experts.down_proj_weight_quantizers[idx].amax is not None, (
                 f"down_proj_weight_quantizers[{idx}].amax is None."
             )
-
-        self._cleanup_registry(expert_type)
+        finally:
+            self._cleanup_registry(expert_type)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/quantization/plugins/test_fused_experts.py` around lines 385
- 441, The test creates a temporary registry entry via expert_type and calls
self._cleanup_registry(expert_type) at the end but can leak if an assertion
fails; wrap the main test actions (quant_cfg setup, forward_loop, mtq.quantize,
and all asserts) in a try/finally and move the final
self._cleanup_registry(expert_type) into the finally block so cleanup always
runs; keep expert_type assigned before the try and leave the initial cleanup
call (before quantization) as-is.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 310a1d2b-745c-4df5-8317-038bf82c199f

📥 Commits

Reviewing files that changed from the base of the PR and between 7c6c132 and 9414089.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/plugins/huggingface.py
  • tests/unit/torch/quantization/plugins/test_fused_experts.py

@meenchen meenchen requested a review from cjluo-nv April 24, 2026 22:27
Collaborator

@cjluo-nv cjluo-nv left a comment


All critical previous review comments have been addressed: dead code removed, import moved to top, control flow restructured, and singular/plural quantizer naming covered. The PR is a well-structured 4-part bug fix (eager impl forcing, quantizer name normalization, export path hoisting, YAML layerwise change) with comprehensive tests. No new abstractions — purely surgical fixes within existing systems. Code is correct and clean.

Complex PR: spans 6 directories (≥ 5). Looping in a human for approval.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/utils/core_utils.py`:
- Around line 232-257: weight_attr_names now yields plural fused-experts names
but consumers still assume a singular "<name>_weight_quantizer" attr; update
those consumers (e.g. the getattr call in quant_module.py and the logic in
model_calib.py) to use representative_weight_quantizer(module, name) to fetch
the correct quantizer (handles singular, plural ModuleList, or None) or
explicitly check for "<name>_weight_quantizers" and handle ModuleList entries,
ensuring you avoid bare getattr(...) without a default and preserve existing
behavior for singular quantizers.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 376b95ff-4015-46e6-81e9-e9ae47a24a05

📥 Commits

Reviewing files that changed from the base of the PR and between 5a11aeb and ecf7dd2.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/utils/core_utils.py

@meenchen meenchen force-pushed the weimingc/fix_qwen36_moe_ptq branch from c1242ce to 3eb0f21 Compare April 27, 2026 22:58
@meenchen
Contributor Author

/ok to test 3eb0f21

@meenchen meenchen force-pushed the weimingc/fix_qwen36_moe_ptq branch from e65e48e to ce3e0e0 Compare April 28, 2026 20:53
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
…standard export branches

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
The base QuantModule.fold_weight only walks singular *_weight_quantizer
attributes, so per-expert quantizers in _QuantFusedExperts'
gate_up_proj_weight_quantizers / down_proj_weight_quantizers ModuleLists
are never folded, leaving _amax behind. The vLLM fake-quant export test
test_hf_vllm_export_tiny_qwen3_moe[FP8] surfaces this. Override
fold_weight on _QuantFusedExperts to walk the per-expert ModuleList,
apply each quantizer to its 3-D slice, disable, and drop _amax /
_pre_quant_scale.

Also add the blank line between docstring intro and bullet list in
representative_weight_quantizer / weight_attr_names so sphinx stops
emitting "Block quote ends without a blank line" warnings during
build-docs.

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
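
A minimal sketch of that per-expert fold, assuming a fused 3-D weight of shape [num_experts, out, in] and a
ModuleList of per-expert weight quantizers; the actual _QuantFusedExperts.fold_weight override operates on the
real TensorQuantizer objects and the gate_up_proj / down_proj parameters.

    import torch
    from torch import nn

    # Hypothetical sketch: fake-quantize each expert's 2-D slice of the fused 3-D weight in
    # place, then disable the quantizer and drop its calibration state so no _amax /
    # _pre_quant_scale survives the fold.
    def fold_fused_expert_weight(fused_weight: nn.Parameter, quantizers: nn.ModuleList) -> None:
        with torch.no_grad():
            for idx, quantizer in enumerate(quantizers):
                fused_weight.data[idx] = quantizer(fused_weight.data[idx])
                if hasattr(quantizer, "disable"):
                    quantizer.disable()
                for attr in ("_amax", "_pre_quant_scale"):
                    if hasattr(quantizer, attr):
                        delattr(quantizer, attr)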
@meenchen meenchen force-pushed the weimingc/fix_qwen36_moe_ptq branch 2 times, most recently from 58b897a to 29fdcec Compare April 28, 2026 21:19
The vllm-fakequant export's _fakequant_module_weights iterated
module.named_children() looking for attr_name.endswith("weight_quantizer")
plus isinstance(quantizer, TensorQuantizer). Both filters skip the plural
*_weight_quantizers ModuleList that _QuantFusedExperts uses for per-expert
quantizers. As a result, fused MoE experts had their fused 3-D
gate_up_proj / down_proj saved un-fake-quantized while the rest of the
pipeline (fold_weight, modelopt state) treated them as fake-quanted ->
weight mismatch on round-trip and the gpu test
test_hf_vllm_export_tiny_qwen3_moe[FP8_DEFAULT_CFG] failed by 0.00244
(the FP8 quant noise that should have been zero).

Fix:
- Add _fakequant_fused_experts_weights helper that walks the per-expert
  ModuleList and applies each TensorQuantizer to its slice of the 3-D
  fused weight, writing into state_dict (non-mutating export) or
  in-place on the parameter, mirroring _QuantFusedExperts.fold_weight.
- Extend _WEIGHT_QUANTIZER_STATE_KEY regex to match the plural form
  weight_quantizers? so per-expert quantizer state (_amax etc.) is
  recognized as weight-quantizer state and stripped before save -
  otherwise the saved modelopt_state_weights would carry per-expert
  amax tensors and trip the test's "weight quantizer ... should have
  empty state after fold" invariant.

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
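
A small illustration of the plural-aware key matching described above, assuming a regex of roughly this shape;
the actual _WEIGHT_QUANTIZER_STATE_KEY pattern in the exporter may differ.

    import re

    # Hypothetical sketch: accept both the singular and the indexed plural weight-quantizer
    # state keys so per-expert _amax entries are also recognized and stripped before save.
    WEIGHT_QUANTIZER_STATE_KEY = re.compile(r"weight_quantizers?(\.\d+)?\._\w+$")

    assert WEIGHT_QUANTIZER_STATE_KEY.search("model.layers.0.mlp.down_proj.weight_quantizer._amax")
    assert WEIGHT_QUANTIZER_STATE_KEY.search("model.layers.0.mlp.experts.gate_up_proj_weight_quantizers.5._amax")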
@meenchen meenchen force-pushed the weimingc/fix_qwen36_moe_ptq branch from 29fdcec to fda1e20 Compare April 28, 2026 21:45
@meenchen
Contributor Author

/ok to test fda1e20

@meenchen meenchen merged commit 077e29a into main Apr 29, 2026
55 of 59 checks passed
@meenchen meenchen deleted the weimingc/fix_qwen36_moe_ptq branch April 29, 2026 07:01