[None][feat] Optimize GDN of Qwen3-Next/3.5; adds BF16 TRTLLM MoE #12557
rosenrodt wants to merge 9 commits into NVIDIA:main from
Conversation
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Three optimizations to eliminate GPU idle bubbles during prefill in
Mamba2Metadata.prepare() for hybrid GDN models (e.g. Qwen3.5):
1. Remove tl.constexpr from num_seqs and N in _cu_seqlens_triton_kernel.
Triton JIT recompiles for each unique constexpr value (~120ms each).
In serving, num_seqs varies every prefill step, causing repeated
recompilation. With dynamic parameters, only one compilation occurs.
2. Accept total_seqlens from the caller to skip the first GPU->CPU sync.
cu_seqlens[-1].item() blocks on all pending GPU work. The caller
(Mamba2Metadata.prepare) already has num_ctx_tokens on the CPU.
3. Compute extra_chunks with pure Python arithmetic on CPU seq_lens
to eliminate the second GPU->CPU sync (cumsum + p[-1].item()).
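Optimization 1 hinges on how Triton specializes kernels: `tl.constexpr` arguments become part of the compiled-kernel cache key, so a value that changes every step defeats the cache. Here is a toy model of that behavior using `lru_cache` as a stand-in for Triton's JIT cache; it is an illustration of the caching pattern, not Triton's actual implementation.

```python
from functools import lru_cache

compilations = 0

@lru_cache(maxsize=None)
def compile_kernel(specialization_key):
    # Stand-in for Triton's JIT cache: one "compilation" (~120 ms in the
    # PR's measurements) happens per distinct specialization key.
    global compilations
    compilations += 1
    return f"kernel<{specialization_key}>"

# Before the fix: num_seqs was tl.constexpr, so every new batch size
# appears in the cache key and forces a recompile.
for num_seqs in [1, 7, 3, 12, 7, 5]:
    compile_kernel(("num_seqs", num_seqs))
before = compilations  # 5 distinct values -> 5 compilations

# After the fix: num_seqs is a dynamic kernel argument, so only the truly
# static parameters remain in the key and a single compilation is reused.
compilations = 0
compile_kernel.cache_clear()
for _ in [1, 7, 3, 12, 7, 5]:
    compile_kernel(("static_shapes_only",))
after = compilations  # 1 compilation, reused for every batch size
```

With the dynamic parameter, all six serving steps hit the same cached kernel, which is why the recompilation cost disappears after the first prefill.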
Before: _prepare_inputs ~120-460ms per prefill step (Triton recompile +
GPU sync bubbles)
After: _prepare_inputs ~1-2ms steady state
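Optimizations 2 and 3 share one idea: any count derivable from the CPU-resident seq_lens can be computed with plain Python integers, avoiding a device-side cumsum followed by a blocking `.item()`. The sketch below is illustrative only; the function name and the exact chunk-counting rule are assumptions, not the actual Mamba2Metadata code.

```python
def prepare_counts_cpu(seq_lens, chunk_size=256):
    """Compute token/chunk totals from CPU-side seq_lens with pure Python
    arithmetic, replacing a GPU cumsum plus a blocking tensor.item()
    device->host sync. (Illustrative; not the PR's exact formula.)"""
    total_seqlens = sum(seq_lens)  # replaces cu_seqlens[-1].item()
    # ceil(n / chunk_size) per sequence: a sequence not aligned to a
    # chunk boundary contributes an extra partial chunk.
    num_chunks = sum(-(-n // chunk_size) for n in seq_lens)
    # Chunks beyond what a single packed stream of the same total length
    # would need (assumed definition, for illustration).
    extra_chunks = num_chunks - (total_seqlens // chunk_size)
    return total_seqlens, num_chunks, extra_chunks

total, chunks, extra = prepare_counts_cpu([3, 256, 300], chunk_size=256)
```

Because no tensor leaves the device here, the host never stalls waiting for pending GPU work, which is where the 100ms-scale bubbles came from.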
Verified: 9200+ random equivalence tests + e2e serving assertion with
1000 requests (0 mismatches). GSM8K accuracy unchanged (90.07% on full
1319 samples).
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
- update chunked Gated Delta Rule prefill to use indexed in-kernel state updates
- remove explicit Qwen3Next prefill state gather/scatter in forward_extend
- retune causalConv1d forward launch selection for varlen and short sequences

Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
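The first two bullets replace a gather -> update -> scatter round trip with the kernel indexing the state cache directly. A minimal sketch of the difference, using plain Python lists and made-up names (not the PR's actual kernels or state layout):

```python
def update_gather_scatter(state_cache, slot_idx, deltas):
    # Old pattern: materialize a contiguous copy of the active states,
    # update it, then copy the results back to their slots.
    gathered = [state_cache[i] for i in slot_idx]        # gather (copy)
    updated = [s + d for s, d in zip(gathered, deltas)]  # update
    for i, s in zip(slot_idx, updated):                  # scatter (copy)
        state_cache[i] = s

def update_indexed(state_cache, slot_idx, deltas):
    # New pattern: read and write each slot in place through its index,
    # so no temporary gathered buffer is allocated or copied back.
    for i, d in zip(slot_idx, deltas):
        state_cache[i] += d

# Both paths leave the state cache in the same final state.
a = [0.0] * 6
b = [0.0] * 6
update_gather_scatter(a, [4, 1], [1.5, 2.5])
update_indexed(b, [4, 1], [1.5, 2.5])
```

In a GPU kernel the indexed form additionally saves two full passes over the state tensor and the intermediate allocation, which matters when the recurrent state is large.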
Force-pushed from e2c4962 to 85ec854 (compare)
/bot run --disable-fail-fast
PR_Github #40455 Bot args parsing error: usage: /bot [-h]
PR_Github #40456 [ run ] triggered by Bot. Commit:
PR_Github #40456 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #40481 [ run ] triggered by Bot. Commit:
PR_Github #40481 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #40500 [ run ] triggered by Bot. Commit:
- keep decode qkv views and make the fused recurrent kernel stride-aware
- restore the decode tile choice that wins on the representative bs256 pure-decode benchmark

Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
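"Stride-aware" here means the kernel addresses its input through explicit strides instead of assuming contiguity, so a q/k/v view sliced out of a fused qkv buffer can be consumed without a `.contiguous()` copy. A hedged sketch of the addressing pattern over a flat buffer; all names and the toy row-sum computation are illustrative, not the PR's actual kernel signature:

```python
def row_sums_strided(flat, base, row_stride, n_rows, n_cols):
    """Sum each logical row of a (possibly non-contiguous) 2D view,
    addressing the flat buffer as base + r * row_stride + c."""
    return [
        sum(flat[base + r * row_stride + c] for c in range(n_cols))
        for r in range(n_rows)
    ]

# A fused "qkv" buffer: 2 token rows of [q0 q1 | k0 k1 | v0 v1].
qkv = [1, 2, 10, 20, 100, 200,
       3, 4, 30, 40, 300, 400]

# The q view is width-2 rows starting at offset 0 with row stride 6.
# A stride-aware kernel consumes this view directly; a contiguity-only
# kernel would first need a copy of the q columns.
q_sums = row_sums_strided(qkv, base=0, row_stride=6, n_rows=2, n_cols=2)
```

The same call with `base=2` would walk the k view of the buffer, again with no copy.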
Force-pushed from 162777e to 6b67c8e (compare)
/bot run --disable-fail-fast
PR_Github #40502 [ run ] triggered by Bot. Commit:
PR_Github #40502 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #40530 [ run ] triggered by Bot. Commit:
PR_Github #40530 [ run ] completed with state
cc @VALLIS-NERIA @nv-guomingz as this PR modifies some of the GDN and mamba state kernels
@coderabbitai summary
Description
Perf
- Model: Qwen3.5-35B-A3B, BF16, TP1
- Workload: ISL/OSL = 4k/1k synthetic (ignore_eos=True)
- Tested on B200
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.