
[None][fix] multi_stream_moe + MLIR accuracy regression in monolithic CUDA graph decode path #12954

@suyoggupta

Description


Summary

PR #12847 fixed multi_stream_moe + mlir_elementwise_fusion accuracy for the piecewise/eager path but not for the monolithic CUDA graph path (decode). With multi_stream_moe: enabled: true and mlir_elementwise_fusion: enabled: true, GSM8K accuracy drops from ~90% to ~80% on Gemma4 MoE (google/gemma-4-26B-A4B-it).

Disabling MLIR while keeping multi-stream enabled restores accuracy to ~90%, confirming the issue is specific to MLIR Triton kernels running on the aux stream during CUDA graph capture/replay.

Root Cause

MLIR-generated Triton kernels interact with PyTorch's CUDA caching allocator in a way that breaks cross-stream safety during CUDA graph capture. The caller_stream.synchronize() fix from #12847 only covers the non-capture path (gated by not torch.cuda.is_current_stream_capturing()). The monolithic CUDA graph path (used for decode-only batches) captures with multi-stream + MLIR kernels on both main and aux streams, leading to silent data corruption.

The piecewise path is correct because #12847 reclassifies stream-switch partitions as dynamic (eager execution with synchronize()).
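The gating described above can be sketched in a few lines. This is a schematic, not the real TRT-LLM code: `FakeStream` and `join_aux_stream` are stand-ins invented here so the control flow is runnable without CUDA; the actual fix lives in `multi_stream_utils.py` and gates on `torch.cuda.is_current_stream_capturing()`.

```python
# Schematic of the #12847 gating (hypothetical names; the real code uses
# torch.cuda streams and torch.cuda.is_current_stream_capturing()).

class FakeStream:
    """Stand-in for a CUDA stream; records whether synchronize() ran."""
    def __init__(self):
        self.synced = False

    def synchronize(self):
        self.synced = True

def join_aux_stream(caller_stream, is_capturing):
    """The #12847 fix only orders the aux stream OUTSIDE graph capture."""
    if not is_capturing:  # i.e. not torch.cuda.is_current_stream_capturing()
        caller_stream.synchronize()

stream = FakeStream()
join_aux_stream(stream, is_capturing=True)  # monolithic-capture path
captured_path_synced = stream.synced         # stays False: aux work unordered
```

The `captured_path_synced is False` outcome is the gap: during monolithic CUDA graph capture the sync never runs, so nothing orders the aux-stream MLIR kernels against the main stream.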

Reproduction

# With multi_stream_moe enabled + MLIR enabled (gemma4_moe.yaml)
TRTLLM_ACCURACY_NO_REFERENCE=1 pytest tests/integration/defs/accuracy/test_llm_api_autodeploy.py::TestGemma4MoE::test_bf16 -s -v
# GSM8K: ~80% (expected: ~90%)

To confirm MLIR is the cause, set mlir_elementwise_fusion: enabled: false in gemma4_moe.yaml — accuracy returns to ~90%.
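For reference, the two toggles as they would appear in the config (the nesting here is an assumption based on the `enabled:` keys quoted above; check `gemma4_moe.yaml` for the actual layout):

```yaml
# Hypothetical sketch of the relevant gemma4_moe.yaml toggles
multi_stream_moe:
  enabled: true
mlir_elementwise_fusion:
  enabled: false   # set false to confirm MLIR is the cause; accuracy returns to ~90%
```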

Proposed Fix

Prevent MLIR fusion of ops that will run on the aux stream (the shared-expert branch). Main-stream ops (MoE, merge, attention) keep MLIR fusion. This preserves both multi-stream overlap for decode latency and MLIR fusion for main-stream performance.

Approach: pre-MLIR annotation pass

  1. Add a lightweight transform (mark_multi_stream_moe_shared_experts) that runs in the post_load_fusion stage, before mlir_elementwise_fusion.
  2. This transform reuses the shared-expert identification logic from multi_stream_moe (find fused MoE → walk forward to merge node → trace shared branch backwards) but only sets node.meta["skip_mlir_fusion"] = True — no graph mutation.
  3. In FXToMLIRConverter._convert_call_function(), check the flag and route marked nodes to opaque (non-fusible) lowering. Exclude operator.getitem from the check (it must always use _convert_getitem for correct multi-result tuple access).
  4. multi_stream_moe stays in compile stage and inserts stream switches as before.
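Steps 1–3 can be sketched with a toy node class in place of torch.fx (the `Node` class, op names, and helper names below are illustrative assumptions, not the real TRT-LLM API; only the `meta["skip_mlir_fusion"]` flag and the getitem exclusion come from the proposal above):

```python
# Toy sketch of the annotation pass + converter-side check.

class Node:
    """Minimal stand-in for torch.fx.Node: producers, consumers, meta dict."""
    def __init__(self, op, args=()):
        self.op = op
        self.args = list(args)   # producer nodes
        self.users = []          # consumer nodes
        self.meta = {}
        for a in args:
            a.users.append(self)

def mark_shared_expert_branch(moe):
    """Step 2: walk forward from the fused MoE to the merge node, then trace
    the shared-expert branch backwards, annotating nodes (no graph mutation)."""
    merge = next(u for u in moe.users if u.op == "merge")
    # The merge input that is NOT the MoE output is the shared-expert branch.
    frontier = [a for a in merge.args if a is not moe]
    while frontier:
        n = frontier.pop()
        if n.meta.get("skip_mlir_fusion"):
            continue  # already visited
        n.meta["skip_mlir_fusion"] = True
        frontier.extend(a for a in n.args if a.op != "input")

def should_lower_opaque(node):
    """Step 3: converter-side check. getitem is excluded so multi-result
    tuple access keeps its dedicated lowering."""
    return node.op != "getitem" and node.meta.get("skip_mlir_fusion", False)

# Toy graph: input feeds both the MoE path and a shared-expert branch
# (linear -> gelu) that rejoins at the merge node.
inp = Node("input")
moe = Node("fused_moe", [inp])
shared_mlp = Node("linear", [inp])
shared_act = Node("gelu", [shared_mlp])
merge = Node("merge", [moe, shared_act])

mark_shared_expert_branch(moe)

getitem = Node("getitem", [moe])
getitem.meta["skip_mlir_fusion"] = True  # even if marked, getitem keeps normal lowering
```

After the pass, only `shared_mlp` and `shared_act` carry the flag; the MoE, merge, and getitem nodes still go through MLIR's normal lowering, which is exactly the split the proposal needs to keep main-stream fusion while excluding aux-stream ops.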

Constraints discovered during investigation

  • Cannot move multi_stream_moe before MLIR: The MLIR FX→MLIR→FX roundtrip cannot handle the begin_aux_stream_passthrough / end_aux_stream_passthrough nodes (they lack val metadata), corrupting the graph.
  • Cannot move MLIR to compile stage: The MLIR roundtrip cannot handle ops inserted by cache_init transforms (cached attention, etc.), causing getitem index errors on fused GEMM tuple returns.

Key Files

  • tensorrt_llm/_torch/auto_deploy/utils/multi_stream_utils.py — synchronize() fix (non-capture path only)
  • tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py — piecewise dynamic reclassification
  • tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py — shared-expert identification logic
  • tensorrt_llm/_torch/auto_deploy/mlir/fx_to_mlir.py — _convert_call_function (skip flag check point)
  • tensorrt_llm/_torch/auto_deploy/transform/library/mlir_elementwise_fusion.py — MLIR fusion transform
  • examples/auto_deploy/model_registry/configs/gemma4_moe.yaml — model config

Metadata

Labels

  • AutoDeploy<NV> — AutoDeploy Backend
  • CUDA Graph
  • Customized kernels<NV> — Specialized/modified CUDA kernels in TRTLLM for LLM ops, beyond standard TRT. Dev & perf.

Status

Backlog
