[None][fix] Relax W8A16 MoE test tolerance for DTP mode#12335
[None][fix] Relax W8A16 MoE test tolerance for DTP mode#12335xxi-nv merged 1 commit intoNVIDIA:mainfrom
Conversation
…error Replace strict torch.testing.assert_close with percent-based check_accuracy in W8A16RefGatedMLPFusedMoE. In DTP/TTP mode (moe_tp_size > 1), TP AllReduce accumulates bf16 rounding errors on top of INT8 quantization error, causing 0.024% element mismatch that exceeds the element-wise atol. Use percent=0.96 for TP mode (consistent with UnquantizedRefMLPFusedMoE) and percent=0.99 for single-GPU/EP mode. Signed-off-by: xxi <xxi@nvidia.com>
|
/bot run --disable-fail-fast |
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
📝 WalkthroughWalkthroughUpdated the Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~8 minutes 🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
📝 Coding Plan
Comment Tip Flake8 can be used to improve the quality of Python code reviews.Flake8 is a Python linter that wraps PyFlakes, pycodestyle and Ned Batchelder's McCabe script. To configure Flake8, add a '.flake8' or 'setup.cfg' file to your project root. See Flake8 Documentation for more details. |
|
/bot run --disable-fail-fast |
|
PR_Github #39582 [ run ] triggered by Bot. Commit: |
|
PR_Github #39582 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #39653 [ run ] triggered by Bot. Commit: |
|
PR_Github #39653 [ run ] completed with state |
Summary
torch.testing.assert_closewith percent-basedcheck_accuracyinW8A16RefGatedMLPFusedMoEmoe_tp_size > 1), TP AllReduce accumulates bf16 rounding errors on top of INT8 quantization error, causing 0.024% element mismatch that exceeds the element-wise atolpercent=0.96for TP mode (consistent withUnquantizedRefMLPFusedMoE) andpercent=0.99for single-GPU/EP modeRoot cause
When running CUTLASS W8A16 with DTP parallel mode and DeepSeekV3 routing (
top_k == num_experts), each expert's weight matrix is split across 4 ranks. Each rank computes a partial GEMM, then AllReduce sums the partials. The AllReduce introduces bf16 rounding error that compounds with INT8 quantization error. Only 1/4096 elements (0.024%) exceed the strict tolerance.Affected test cases
All 3 failures share:
parallel=DTP, backend=CUTLASS, quant=W8A16, routing=DeepSeekV3, e4_k4_h512_i512, seq=8, bfloat16— only comm method differs (NVLINK_ONE_SIDED / NVLINK_TWO_SIDED / DEEPEP).Test plan
🤖 Generated with Claude Code
Summary by CodeRabbit
Release Notes