## Summary

PR #12847 fixed `multi_stream_moe` + `mlir_elementwise_fusion` accuracy for the piecewise/eager path but not for the monolithic CUDA graph path (decode). With `multi_stream_moe: enabled: true` and `mlir_elementwise_fusion: enabled: true`, GSM8K accuracy drops from ~90% to ~80% on Gemma4 MoE (google/gemma-4-26B-A4B-it).

Disabling MLIR while keeping multi-stream enabled restores accuracy to ~90%, confirming the issue is specific to MLIR Triton kernels running on the aux stream during CUDA graph capture/replay.
## Root Cause

MLIR-generated Triton kernels interact with PyTorch's CUDA caching allocator in a way that breaks cross-stream safety during CUDA graph capture. The `caller_stream.synchronize()` fix from #12847 only covers the non-capture path (gated by `not torch.cuda.is_current_stream_capturing()`). The monolithic CUDA graph path (used for decode-only batches) captures with multi-stream + MLIR kernels on both main and aux streams, leading to silent data corruption.

The piecewise path is correct because #12847 reclassifies stream-switch partitions as dynamic (eager execution with `synchronize()`).
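The capture gating described above can be sketched as follows; the helper name and return strings are illustrative, not the actual `multi_stream_utils` API:

```python
import torch

def maybe_sync_caller_stream(caller_stream=None) -> str:
    """Sketch (hypothetical helper) of the #12847 gating pattern."""
    # stream.synchronize() is illegal inside CUDA graph capture, so the
    # fix bails out there - which is exactly the path that remains
    # unprotected when the monolithic graph is captured.
    if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
        return "capture: sync skipped (unprotected)"
    if caller_stream is not None:
        caller_stream.synchronize()
    return "eager: synced"
```

Outside capture (the eager/piecewise path), the synchronize runs and cross-stream ordering is safe; inside capture, the guard makes the fix a no-op.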
## Reproduction

```bash
# With multi_stream_moe enabled + MLIR enabled (gemma4_moe.yaml)
TRTLLM_ACCURACY_NO_REFERENCE=1 pytest tests/integration/defs/accuracy/test_llm_api_autodeploy.py::TestGemma4MoE::test_bf16 -s -v
# GSM8K: ~80% (expected: ~90%)
```
To confirm MLIR is the cause, set `mlir_elementwise_fusion: enabled: false` in `gemma4_moe.yaml`; accuracy returns to ~90%.
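For reference, a plausible shape of the relevant config fragment; the `transforms:` nesting is an assumption, only the two enabled flags come from the issue text:

```yaml
# Hypothetical fragment of gemma4_moe.yaml (key nesting assumed)
transforms:
  multi_stream_moe:
    enabled: true            # keep multi-stream overlap
  mlir_elementwise_fusion:
    enabled: false           # workaround: accuracy returns to ~90%
```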
## Proposed Fix
Prevent MLIR fusion of ops that will run on the aux stream (the shared-expert branch). Main-stream ops (MoE, merge, attention) keep MLIR fusion. This preserves both multi-stream overlap for decode latency and MLIR fusion for main-stream performance.
### Approach: pre-MLIR annotation pass

- Add a lightweight transform (`mark_multi_stream_moe_shared_experts`) that runs in the `post_load_fusion` stage, before `mlir_elementwise_fusion`.
- This transform reuses the shared-expert identification logic from `multi_stream_moe` (find fused MoE → walk forward to merge node → trace shared branch backwards) but only sets `node.meta["skip_mlir_fusion"] = True`; it performs no graph mutation.
- In `FXToMLIRConverter._convert_call_function()`, check the flag and route marked nodes to opaque (non-fusible) lowering. Exclude `operator.getitem` from the check: it must always use `_convert_getitem` for correct multi-result tuple access.
- `multi_stream_moe` stays in the `compile` stage and inserts stream switches as before.
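Assuming the shared-expert nodes have already been identified, the annotation pass and the converter-side check could look roughly like this; function names mirror the proposal above, but the surrounding API is hypothetical:

```python
import operator
import torch
import torch.fx as fx

def mark_skip_mlir_fusion(gm: fx.GraphModule, shared_expert_nodes) -> int:
    """Sketch of the proposed annotation pass: tag aux-stream
    (shared-expert) nodes so MLIR lowers them opaquely. Metadata
    only - the graph itself is not mutated."""
    marked = 0
    for node in gm.graph.nodes:
        if node in shared_expert_nodes and node.target is not operator.getitem:
            node.meta["skip_mlir_fusion"] = True
            marked += 1
    return marked

def should_lower_opaque(node: fx.Node) -> bool:
    """Sketch of the proposed _convert_call_function check:
    getitem is always exempt so multi-result tuple access keeps
    its dedicated lowering."""
    return (node.target is not operator.getitem
            and node.meta.get("skip_mlir_fusion", False))

# Toy usage: mark every call_function node of a traced module.
class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + x

gm = fx.symbolic_trace(Toy())
shared = {n for n in gm.graph.nodes if n.op == "call_function"}
num_marked = mark_skip_mlir_fusion(gm, shared)
```

In the real pass, `shared_expert_nodes` would come from the shared-branch trace in `multi_stream_moe` rather than from a blanket op filter as in this toy example.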
### Constraints discovered during investigation

- Cannot move `multi_stream_moe` before MLIR: the MLIR FX→MLIR→FX roundtrip cannot handle the `begin_aux_stream_passthrough` / `end_aux_stream_passthrough` nodes (they lack `val` metadata), corrupting the graph.
- Cannot move MLIR to the `compile` stage: the MLIR roundtrip cannot handle ops inserted by `cache_init` transforms (cached attention, etc.), causing `getitem` index errors on fused GEMM tuple returns.
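The first constraint is mechanically checkable; a pre-flight guard along these lines (hypothetical helper, not existing code) would surface missing `val` metadata before the roundtrip corrupts the graph:

```python
import torch
import torch.fx as fx

def nodes_missing_val_meta(gm: fx.GraphModule):
    """Sketch: the FX->MLIR->FX roundtrip needs node.meta["val"]
    (a fake tensor) on every compute node; the stream-passthrough
    markers described above lack it, so a check like this would
    fail fast instead of producing a corrupted graph."""
    return [n for n in gm.graph.nodes
            if n.op == "call_function" and "val" not in n.meta]

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

# symbolic_trace alone does not populate meta["val"], so the relu
# node shows up as missing - as would the passthrough markers.
gm = fx.symbolic_trace(Toy())
missing = nodes_missing_val_meta(gm)
```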
## Key Files

- `tensorrt_llm/_torch/auto_deploy/utils/multi_stream_utils.py`: `synchronize()` fix (non-capture path only)
- `tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py`: piecewise dynamic reclassification
- `tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py`: shared-expert identification logic
- `tensorrt_llm/_torch/auto_deploy/mlir/fx_to_mlir.py`: `_convert_call_function` (skip-flag check point)
- `tensorrt_llm/_torch/auto_deploy/transform/library/mlir_elementwise_fusion.py`: MLIR fusion transform
- `examples/auto_deploy/model_registry/configs/gemma4_moe.yaml`: model config