[None][perf] FC2 DenseGEMM autotune: split-K, swap_ab, fine-grained tuning buckets #13833

Merged
zongfeijing merged 5 commits into NVIDIA:main from JacobHu-NV:pr/fc2-kernel-autotune on May 18, 2026
Collaborator

@JacobHu-NV JacobHu-NV commented May 7, 2026

[None][perf] FC2 DenseGEMM autotune: split-K, swap_ab, fine-grained tuning buckets

Summary

Improves the MoE FC2 DenseGEMM kernel and its autotune setup:

  • Add split-K (sk ∈ {1, 2, 4}) and additional (mma_tiler_mn, cluster_shape_mn) candidates; epilogue uses atomic-add reduction when split_k > 1.
  • Add swap_ab mode in the Blackwell MoE FC2 kernel (token-major alpha-scale buffer, dedicated cp.async warp).
  • Cast m/n/k/l from Int64 → Int32 at the kernel entry; required by downstream cutlass.range / cute.size for small-m autotune profiling.
  • Switch FC2 autotuner to deep_gemm_gen_tuning_buckets (8-stride below 128, 128-stride above) and raise tune_max_num_tokens 256 → 512 to match FC1.
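
As a rough illustration of the bucket grid described in the last bullet, the following sketch generates token counts with a stride of 8 up to 128 and a stride of 128 above that. It is not the body of deep_gemm_gen_tuning_buckets, which is not shown in this PR; `sketch_tuning_buckets` is a hypothetical name, and treating 128 itself as part of the fine-grained region is an assumption.

```python
def sketch_tuning_buckets(max_num_tokens: int) -> list[int]:
    """Hypothetical stand-in for deep_gemm_gen_tuning_buckets (not the real code)."""
    # Fine-grained region: every multiple of 8 up to and including 128 (assumed boundary).
    fine = list(range(8, min(max_num_tokens, 128) + 1, 8))
    # Coarse region: every multiple of 128 above 128, up to the tuning limit.
    coarse = list(range(256, max_num_tokens + 1, 128))
    return fine + coarse


# Old and new values of tune_max_num_tokens from the summary above.
buckets_256 = sketch_tuning_buckets(256)
buckets_512 = sketch_tuning_buckets(512)
print(len(buckets_256), len(buckets_512))
print(buckets_512)
```

Under these assumptions, raising the limit from 256 to 512 only adds the 384 and 512 buckets, so the denser grid below 128 accounts for most of the extra profiling work.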

Performance

NVFP4, n=7168, k=65536, expert_count=256, B200, cold L2.
Brute-force sweep over (m × tactic × split_k × swap_ab) on the new kernel vs baseline (HEAD~4, no split-K / no swap_ab).

| m | new min (us) | new tactic | baseline min (us) | baseline tactic |
|---|---|---|---|---|
| 32 | 52.09 | (128, 128),(1, 2),sk=2,swap=0 | 57.63 | (128, 64),(1, 1),sk=1,swap=0 |
| 64 | 52.68 | (128, 128),(1, 2),sk=2,swap=0 | 58.04 | (128, 64),(1, 1),sk=1,swap=0 |
| 96 | 53.15 | (128, 128),(1, 2),sk=2,swap=0 | 58.88 | (128, 64),(1, 1),sk=1,swap=0 |
| 128 | 53.89 | (128, 128),(1, 2),sk=2,swap=0 | 66.20 | (128, 64),(1, 1),sk=1,swap=0 |
| 160 | 68.00 | (256, 128),(2, 1),sk=1,swap=0 | 70.61 | (256, 128),(2, 1),sk=1,swap=0 |
| 192 | 68.30 | (256, 128),(2, 1),sk=1,swap=0 | 70.86 | (256, 128),(2, 1),sk=1,swap=0 |
| 224 | 68.56 | (256, 128),(2, 1),sk=1,swap=0 | 70.92 | (256, 128),(2, 1),sk=1,swap=0 |
| 256 | 68.78 | (256, 128),(2, 1),sk=1,swap=0 | 71.24 | (256, 128),(2, 1),sk=1,swap=0 |
| 288 | 97.32 | (128, 128),(1, 2),sk=4,swap=0 | 135.45 | (256, 128),(2, 1),sk=1,swap=0 |
| 320 | 99.00 | (128, 128),(1, 2),sk=4,swap=0 | 135.56 | (256, 128),(2, 1),sk=1,swap=0 |
| 352 | 99.81 | (128, 128),(1, 2),sk=4,swap=0 | 135.34 | (256, 128),(2, 1),sk=1,swap=0 |
| 384 | 100.75 | (128, 128),(1, 2),sk=4,swap=0 | 135.84 | (256, 128),(2, 1),sk=1,swap=0 |

Highlights:

  • Small m (32–128): split_k=2 lifts per-wave occupancy → +10–23%.
  • Large m (288–384): the m=288 step in the baseline (71→135us) is reduced to ~100us by split_k=4 → +35–39%.
  • Mid m (160–256): same tactic chosen as baseline, +3–4% kernel-level tweak.
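
The percentages in the highlights can be reproduced from the table, assuming they are computed as baseline time divided by new time, minus one (the PR does not state the formula). A small sketch using the measured values copied from the table:

```python
# (m, new_min_us, baseline_min_us), copied from the performance table.
ROWS = [
    (32, 52.09, 57.63), (64, 52.68, 58.04), (96, 53.15, 58.88),
    (128, 53.89, 66.20), (160, 68.00, 70.61), (192, 68.30, 70.86),
    (224, 68.56, 70.92), (256, 68.78, 71.24), (288, 97.32, 135.45),
    (320, 99.00, 135.56), (352, 99.81, 135.34), (384, 100.75, 135.84),
]


def speedup_pct(new_us: float, baseline_us: float) -> float:
    # Assumed definition: throughput gain relative to the baseline kernel.
    return (baseline_us / new_us - 1.0) * 100.0


speedups = {m: speedup_pct(new, base) for m, new, base in ROWS}
for m, pct in speedups.items():
    print(f"m={m:>3}: +{pct:.1f}%")
```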

Test Plan

  • Kernel-level sweep across (m × tactic × split_k × swap_ab) on B200; numerical correctness verified against the reference implementation.
  • Op-level autotune + bench through torch.ops.trtllm.cute_dsl_nvfp4_dense_gemm_fc2_blackwell; autotune cache populated for all buckets with no failures.
  • CI: /bot run on B200 stages.

contributors: @mingyangHao @zongfeijing @JacobHu-NV

Summary by CodeRabbit

  • New Features

    • Added split-K execution support for MoE FC2 dense GEMM with configurable split factors.
    • Added SwapAB mode to change output layout and alpha-scaling behavior.
    • Added optimized vectorized FP16 atomic-reduction operations for split-K paths.
  • Tests

    • Extended test runner and CLI to validate split-K and SwapAB configurations.

@JacobHu-NV JacobHu-NV changed the title from "Enhance MoE DenseGEMM FC2 with split-K, swap_ab, and tuning improvements" to "[None][perf] FC2 DenseGEMM autotune: split-K, swap_ab, fine-grained tuning buckets" May 7, 2026
@JacobHu-NV JacobHu-NV marked this pull request as ready for review May 7, 2026 06:01
@JacobHu-NV JacobHu-NV requested review from a team as code owners May 7, 2026 06:01
@JacobHu-NV JacobHu-NV requested review from hyukn and leslie-fang25 May 7, 2026 06:01
@coderabbitai
Contributor

coderabbitai Bot commented May 7, 2026

Walkthrough

This PR adds split-K execution and SwapAB transposition support to the MoE FC2 dense GEMM kernel. It expands tactic search with split_k, introduces vectorized/Float16 atomics, threads split_k/swap_ab through kernel/grid/wrapper, rewrites epilogue storage for atomic reductions, and updates the test runner and CLI.

Changes

Split-K and SwapAB MoE FC2 GEMM

| Layer | File(s) | Summary |
|---|---|---|
| Low-level Atomic Operations | tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py | New vectorized_atomic_add_fp16x8 for vectorized FP16 atomics and added cutlass.Float16 branch in atomic_add_func emitting inline-asm atomic adds. |
| Kernel Interface and Data Shapes | tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py, tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py | Sm100BlockScaledPersistentDenseGemmKernel constructor gains split_k and swap_ab; computes epi_compute_layout/epi_layout_atomic conditionally; wrapper casts shapes to Int32 and constructs C/alpha layouts based on swap_ab. |
| Tactic Space and Search | tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py | FC2 tactic enumeration extended with split_k candidates [1,2,4]; adds _MMA_TILE_K, validates K-tile divisibility and expert-boundary alignment; supports legacy 2-tuple or new 3-tuple tactics and includes split_k in the cache key; tuning config adjusted. |
| Grid Computation and Scheduling | tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py | _compute_grid accepts split_k; persistent grid L/third dimension inflates when split_k > 1; stage counts recomputed using epi_compute_layout. |
| TMA / TMEM Adjustments | tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py | Special-case TMA tensor shape/stride and TMEM pointer shifting for specific CTA tile shapes before loading SFB. |
| MMA Main Loop and Buffer Management | tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py | MMA mainloop refactored to operate over split-local K tiles (k_tiles_per_split / k_tile_cnt_local) with expert grouping and adjusted producer/consumer and accumulate/commit semantics. |
| Alpha Scale Loading and Indexing | tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py | Alpha staging/loading reworked for split-K expert indexing and SwapAB semantics: L-coordinate decomposition, per-split tile mapping, alpha SMEM sizing and pipeline barrier updates. |
| Epilogue Alpha Application | tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py | Two paths: non-swap applies per-subtile alpha scalars; SwapAB broadcasts per-CTA-N alpha across M and uses a dedicated alpha partition/consumer pipeline. |
| Epilogue Storage and Atomics | tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py | split_k==1 retains TMA store; split_k>1 uses atomic-add reductions with dtype-specific vectorized atomics or scalar fallback; SwapAB stages via non-swizzled SMEM for vectorized atomics. |
| Epilogue Copy Operations | tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py | T2R/R2S and SMEM partition/copy operations updated to use epi_compute_layout; TMA completion semantics adjusted for atomic-add path. |
| Infrastructure and Memory Layout | tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py, tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py | Int64→Int32 casting for shapes; C tensor layout conditional on swap_ab; alpha_scale token-major layout choice by swap_ab; alpha SMEM sizing generalized with max(M,N); tuning config and compiled-kernel cache key include split_k. |
| Test Harness and Validation | tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py | CLI adds --split_k and --swap_ab; run() accepts them; kernel instantiation and reference computation updated for SwapAB; zero-initialize c_torch and generated workspace when split_k>1; alpha_scale layout and einsum indexing adjusted. |
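
To make the split-K epilogue behaviour concrete, here is a minimal pure-Python sketch of the idea, not the kernel code. Each split computes a partial product over its own slice of K and accumulates it into a shared output; the `+=` stands in for the atomic add that the kernel uses when split_k > 1. `matmul_split_k` is a hypothetical name introduced for illustration.

```python
def matmul_split_k(a, b, split_k):
    """Reference split-K matmul: C = A @ B, with K partitioned into split_k slices."""
    m, k, n = len(a), len(b), len(b[0])
    assert k % split_k == 0, "K must divide evenly across the splits"
    k_per_split = k // split_k
    # The output must start at zero because every split accumulates into it.
    c = [[0.0] * n for _ in range(m)]
    for s in range(split_k):  # on the GPU each split would run in a different CTA
        k0 = s * k_per_split
        for i in range(m):
            for j in range(n):
                partial = sum(a[i][kk] * b[kk][j] for kk in range(k0, k0 + k_per_split))
                c[i][j] += partial  # stands in for an atomic add to global memory
    return c


a = [[1, 2, 3, 4], [5, 6, 7, 8]]
b = [[1, 0], [0, 1], [1, 0], [0, 1]]
c1 = matmul_split_k(a, b, 1)
c2 = matmul_split_k(a, b, 2)
c4 = matmul_split_k(a, b, 4)
print(c1, c2, c4)
```

The accumulate-into-output structure is why the test harness zero-initializes c_torch when split_k > 1: any stale value in the output would be added to the result.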

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The PR title clearly summarizes the main changes: adding split-K support, swap_ab mode, and fine-grained tuning buckets to the FC2 DenseGEMM kernel autotune. |
| Description check | ✅ Passed | The PR description is comprehensive and well-structured. It includes a clear summary of changes, detailed performance metrics with specific benchmarking results, and a complete test plan with verification steps. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py (1)

2574-2580: 💤 Low value

Int64 to Int32 cast lacks bounds validation.

The cast assumes "practical tensor dimensions always fit in Int32" but doesn't verify this. While dimensions >2B elements are rare, a silent overflow would cause subtle correctness issues. Consider adding an assertion for safety, especially since this is a public wrapper API.

🛡️ Optional bounds check
+        INT32_MAX = 2147483647
+        if m > INT32_MAX or n > INT32_MAX or k > INT32_MAX or l > INT32_MAX:
+            raise ValueError("Tensor dimensions exceed Int32 range")
         m = cutlass.Int32(m)
         n = cutlass.Int32(n)
         k = cutlass.Int32(k)
         l = cutlass.Int32(l)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py`
around lines 2574 - 2580, Add explicit bounds validation before casting m, n, k,
l to cutlass.Int32: check each dimension is within the signed 32-bit range
(e.g., between -(2**31) and 2**31-1) and raise/throw a clear error or assert if
not. Update the wrapper where m, n, k, l are converted (the cutlass.Int32(...)
lines) to perform these checks and include the offending variable name and value
in the error message so callers of this public API get immediate, actionable
feedback instead of silent overflow. Ensure the validation runs before any
cutlass.range() or other 32-bit-only calls.
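
The bounds check requested above could look roughly like the following sketch. It is plain Python with no cutlass dependency; `check_int32_dims` is a hypothetical helper, and in the real wrapper it would run immediately before the cutlass.Int32(...) casts.

```python
INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1


def check_int32_dims(**dims):
    """Raise ValueError naming the first dimension that does not fit in Int32."""
    for name, value in dims.items():
        if not INT32_MIN <= value <= INT32_MAX:
            raise ValueError(
                f"{name}={value} does not fit in a signed 32-bit integer"
            )
    return dims


# Shapes from the PR's benchmark setup pass the check unchanged.
ok = check_int32_dims(m=8, n=7168, k=65536, l=1)
print(ok)
```

Passing the dimensions as keyword arguments keeps the offending name and value available for the error message, which is what the review comment asks for.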
tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py (1)

319-351: 💤 Low value

Misleading comment and unused parameter in create_alpha_scale_tensor.

  1. Line 320: The comment ### if not swapAB is misleading since this function is called for both swap_ab modes — the caller passes different values for m (either the actual M or N dimension).

  2. Line 319: The n parameter is never used inside create_alpha_scale_tensor. Consider removing it or documenting why it's retained.

These are documentation nits; the logic is correct.

📝 Suggested documentation fix
-    def create_alpha_scale_tensor(l, m, n, expert_count, dtype):  # noqa: E741
-        ### if not swapAB
-        # True means alpha_scale is token (m) major for coalesced global memory access.
+    def create_alpha_scale_tensor(l, token_dim, expert_count, dtype):  # noqa: E741
+        # token_dim = M (standard) or N (swap_ab)
+        # alpha_scale is token-major for coalesced global memory access.
         alpha_scale_ref = cutlass_torch.matrix(
             l,
-            m,
+            token_dim,
             expert_count,
             True,  # token_dim is major
             cutlass.Float32,

Then update the call sites accordingly.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py`
around lines 319 - 351, The function create_alpha_scale_tensor currently has an
unused parameter n and a misleading comment "### if not swapAB"; remove the
unused n parameter from create_alpha_scale_tensor's signature and from its
internal references, update the call site where create_alpha_scale_tensor(...)
is invoked (the call that passes alpha_token_dim, n, expert_count, ...) to pass
only the needed parameters (l, alpha_token_dim, expert_count, dtype), and
replace the "### if not swapAB" comment with a short clarifying comment that the
caller controls token-major ordering via alpha_token_dim (derived from swap_ab).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py`:
- Around line 1873-1878: The scalar atomic fallback loop incorrectly calls
atomic_add_func(rVec_flat[j], scatter_out) instead of using the per-j computed
address; change the call to atomic_add_func(rVec_flat[j], scatter_j) so each
iteration uses scatter_j (computed via cute.domain_offset) rather than the
shared scatter_out, preserving correct per-element destinations; this affects
the branch handling in the vectorized_atomic_add_bf16x8 fallback where rVec,
rVec_flat, scatter_out, scatter_j, and atomic_add_func are used.
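
The effect of the bug described in this inline comment can be shown with a small pure-Python model, under the assumption that scatter_j is the base destination offset by j. The list `out` stands in for global memory and `+=` for atomic_add_func; `scalar_fallback` and its flag are hypothetical names used only for this illustration.

```python
def scalar_fallback(out, base, values, use_per_element_offset):
    """Model of the scalar atomic fallback loop over a register vector."""
    for j, v in enumerate(values):
        # Fixed path: each element goes to its own address (scatter_j).
        # Buggy path: every element is added to the shared base (scatter_out).
        dst = base + j if use_per_element_offset else base
        out[dst] += v
    return out


values = [1.0, 2.0, 3.0, 4.0]
buggy = scalar_fallback([0.0] * 6, 1, values, use_per_element_offset=False)
fixed = scalar_fallback([0.0] * 6, 1, values, use_per_element_offset=True)
print(buggy)
print(fixed)
```

In the buggy variant all four values collapse into one output element, so the result is wrong even though no out-of-bounds access occurs, which makes the error easy to miss without a numerical check.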

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 9d51f9b2-9a7e-418d-893b-e4cbb7797a11

📥 Commits

Reviewing files that changed from the base of the PR and between c20b192 and ccf1718.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py
  • tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py

@JacobHu-NV
Collaborator Author

/bot run

@JacobHu-NV
Collaborator Author

@coderabbitai full review

@zongfeijing
Collaborator

/bot run --disable-fail-fast

@coderabbitai
Contributor

coderabbitai Bot commented May 8, 2026

✅ Actions performed

Full review triggered.

@tensorrt-cicd
Collaborator

PR_Github #47359 [ run ] triggered by Bot. Commit: f7b6b2e Link to invocation

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py (1)

4255-4274: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Expose swap_ab in the FC2 tactic path too.

This runner now autotunes only (mma_tiler_mn, cluster_shape_mn, split_k). Because swap_ab never enters the tactic tuple, forward parsing, or the cache/kernel construction, cute_dsl_nvfp4_dense_gemm_fc2_blackwell still cannot autotune or launch the new swap-ab mode described for this PR, so the op-level sweep/cache population will miss that dimension entirely.

Also applies to: 4338-4437

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py` around lines 4255 -
4274, The FC2 tactic candidate generator currently only returns (mma_tiler_mn,
cluster_shape_mn, split_k) and omits the swap_ab dimension; update the generator
in the method that builds FC2 candidates (the function surrounding the shown
diff that feeds cute_dsl_nvfp4_dense_gemm_fc2_blackwell) to include swap_ab in
the candidate tuple and enumeration (e.g., add a swap_ab_candidates list like
[False, True] and produce candidates as (mma_tiler_mn, cluster_shape_mn,
split_k, swap_ab)); then propagate this new 4-tuple shape through the FC2 tactic
path by (1) updating the forward parsing logic that reads tactics to expect
swap_ab for cute_dsl_nvfp4_dense_gemm_fc2_blackwell, and (2) including swap_ab
when constructing cache keys and kernel launch configuration so the autotuner
and cache population account for and can select the swap-ab mode.
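
One way to thread swap_ab through the tactic path while staying compatible with older tactics is sketched below. The walkthrough confirms that legacy 2-tuples and new 3-tuples are accepted; the 4-tuple form is only what this comment proposes, and `parse_fc2_tactic` together with its defaults (split_k=1, swap_ab=False) is a hypothetical illustration rather than the code in cute_dsl_custom_ops.py.

```python
def parse_fc2_tactic(tactic):
    """Normalize a tactic to (mma_tiler_mn, cluster_shape_mn, split_k, swap_ab)."""
    if len(tactic) == 2:  # legacy form: no split-K, no swap
        mma_tiler_mn, cluster_shape_mn = tactic
        return mma_tiler_mn, cluster_shape_mn, 1, False
    if len(tactic) == 3:  # form with split_k only
        mma_tiler_mn, cluster_shape_mn, split_k = tactic
        return mma_tiler_mn, cluster_shape_mn, split_k, False
    if len(tactic) == 4:  # proposed form that also carries swap_ab
        mma_tiler_mn, cluster_shape_mn, split_k, swap_ab = tactic
        return mma_tiler_mn, cluster_shape_mn, split_k, bool(swap_ab)
    raise ValueError(f"unsupported tactic length {len(tactic)}: {tactic!r}")


legacy = parse_fc2_tactic(((128, 64), (1, 1)))
with_split = parse_fc2_tactic(((128, 128), (1, 2), 2))
with_swap = parse_fc2_tactic(((128, 128), (1, 2), 4, True))
print(legacy, with_split, with_swap)
```

Because the normalized tuple is hashable, it could also serve directly as part of a compiled-kernel cache key, so that swap_ab variants do not collide with non-swap ones.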
🧹 Nitpick comments (3)
tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py (2)

603-608: 💤 Low value

Consider adding CLI validation for split_k values.

Per the PR description, split-K only supports values in {1, 2, 4}. Adding a choices constraint would provide clearer error messages than failing later in kernel construction.

💡 Suggested fix
     parser.add_argument(
         "--split_k",
         type=int,
         default=1,
+        choices=[1, 2, 4],
         help="Split-K factor (default: 1)",
     )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py`
around lines 603 - 608, The CLI currently accepts any integer for the
"--split_k" argument in parser.add_argument; restrict accepted values to the
supported set {1,2,4} by adding an argparse validation (e.g., use the choices
parameter) on the "--split_k" argument so invalid inputs produce a helpful
argparse error before kernel construction; update the parser.add_argument call
for "--split_k" to include choices=[1,2,4] and adjust the help text to reflect
allowed values.
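
A self-contained version of the suggested argparse change is sketched below. The `choices` parameter is standard argparse; declaring --swap_ab as a store_true flag is an assumption, since the script's actual definition of that option is not shown here.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="FC2 split-K CLI sketch")
    parser.add_argument(
        "--split_k",
        type=int,
        default=1,
        choices=[1, 2, 4],  # reject unsupported factors at parse time
        help="Split-K factor, one of 1, 2, 4 (default: 1)",
    )
    parser.add_argument(
        "--swap_ab",
        action="store_true",  # assumed flag style for this illustration
        help="Enable SwapAB mode",
    )
    return parser


args = build_parser().parse_args(["--split_k", "4", "--swap_ab"])
print(args.split_k, args.swap_ab)
```

With `choices` in place, an input such as `--split_k 3` makes argparse print a usage error and exit before any kernel is constructed.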

105-106: 💤 Low value

Consider documenting new parameters in docstring.

The new split_k and swap_ab parameters are not documented in the function docstring below. For consistency with other parameters, consider adding brief descriptions.

📝 Suggested docstring addition
     :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False
     :type use_cold_l2: bool, optional
+    :param split_k: Split-K factor for parallel reduction (valid values: 1, 2, 4), defaults to 1
+    :type split_k: int, optional
+    :param swap_ab: Whether to swap A/B roles (A=weight, B=activation), defaults to False
+    :type swap_ab: bool, optional
     :raises RuntimeError: If CUDA GPU is not available
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py`
around lines 105 - 106, The function run_moe_as_dense_gemm_fc2 has two new
parameters split_k and swap_ab that are not described in the function docstring;
update the docstring for run_moe_as_dense_gemm_fc2 to add short descriptions for
"split_k: int" (what it controls, default 1) and "swap_ab: bool" (what swapping
A/B does, default False), matching the style/format of the existing parameter
descriptions so callers and generated docs are consistent.
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py (1)

2627-2634: 💤 Low value

Clarify semantic difference in swap_ab alpha tensor.

The comments state "alpha per M (token)" and "alpha per N (token)" but this could be confusing since the term "token" has a specific meaning in MoE contexts. Consider clarifying whether swap_ab actually swaps which dimension represents tokens, or if the alpha semantics change for a different reason.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py`
around lines 2627 - 2634, The comment around alpha_token_dim/alpha_scale is
ambiguous about what swap_ab changes; update the comment to clearly state that
swap_ab toggles which matrix axis corresponds to the token dimension (i.e., when
swap_ab is true the token index is N instead of M) and that
alpha_scale_ptr/make_tensor/make_ordered_layout are built accordingly. Edit the
block around alpha_token_dim, alpha_scale_ptr, and the call to
cute.make_ordered_layout((alpha_token_dim, expert_count, l), order=(0,1,2)) so
the comment explicitly says "when swap_ab is False, tokens map to M (first dim);
when True, tokens map to N (first dim)" and reference swap_ab, alpha_token_dim,
alpha_scale, and alpha_scale_ptr so readers know the layout change is
intentional.
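
The layout in question can be illustrated with plain index arithmetic. The sketch assumes that order=(0,1,2) gives the first mode (alpha_token_dim) a stride of 1 and the expert mode a stride of alpha_token_dim, and it ignores the l mode for brevity. Both function names are hypothetical.

```python
def token_major_offset(token: int, expert: int, token_dim: int) -> int:
    # Token index is fastest-varying: stride 1 for tokens, token_dim for experts.
    return token + expert * token_dim


def expert_major_offset(token: int, expert: int, expert_count: int) -> int:
    # Row-major (token_dim, expert_count) storage: stride 1 for experts.
    return token * expert_count + expert


token_dim, expert_count, expert = 128, 256, 3
tm = [token_major_offset(t, expert, token_dim) for t in range(32)]
em = [expert_major_offset(t, expert, expert_count) for t in range(32)]
print(tm[:4], em[:4])
```

For a fixed expert, 32 consecutive tokens occupy 32 consecutive elements in the token-major layout but are expert_count elements apart in the expert-major one. With swap_ab enabled, the same reasoning applies with token_dim taken from N instead of M.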
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py`:
- Around line 4306-4314: The wrapped conditional causing flake8 E129 should be
reindented so continuations align under the opening parenthesis; update the
predicate in the loop over split_k_candidates (the if that checks k_tiles %
split_k and (k_tiles // split_k) % tiles_per_expert) to use a single
parenthesized expression with proper alignment, keeping the same logic, and then
append the tactic tuple (mma_tiler_mn, cluster_shape_mn, split_k) to tactics as
before; ensure variables referenced (_MMA_TILE_K, self.weight_per_expert,
k_tiles, tiles_per_expert, split_k) are unchanged.

In `@tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py`:
- Around line 2265-2273: The calculation of alpha_dim uses
tiled_mma.thr_id.shape directly which can be a layout/tuple; change the division
to use cute.size(tiled_mma.thr_id.shape) so alpha_dim = max(mma_tiler_mnk[0] //
cute.size(tiled_mma.thr_id.shape), mma_tiler_mnk[1]) and recompute alpha_bytes
accordingly (update the expression that constructs alpha_bytes if necessary) to
match usages elsewhere in this file (see references to tiled_mma.thr_id.shape
and cute.size).
- Around line 916-919: Add an explicit validation that k_tile_total is divisible
by self.split_k to avoid silently dropping K-tiles: check the divisibility
either in the class constructor or in can_implement() (where other kernel
constraints are validated) and raise/assert with a clear message if k_tile_total
% self.split_k != 0; refer to k_tile_total, self.split_k, and k_tiles_per_split
when adding the check so it runs before computing k_tiles_per_split.
- Line 1582: The swap_ab split-K epilogue is using m_total =
malpha_scale_mnl.shape[0] and then checking m_global + 7 < m_total which
compares an M-coordinate against the N-sized alpha tensor; instead use the
actual M bound from the output tensor and the existing helper used in the
non-swap path. Replace the m_total/m_global comparison in the swap_ab split-K
branch with a bounds check against mC_raw.shape[0] (or call thread_in_bounds
with m_global and mC_raw) so the epilogue validates M against the real M
dimension of mC_raw, mirroring the non-swap split-K approach.
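
The two divisibility conditions mentioned in these comments (K tiles divide evenly across splits, and each split covers a whole number of experts) can be sketched as a filter over the split_k candidates. `valid_split_ks` is a hypothetical helper; the tile size passed as `mma_tile_k` is an illustrative value, not necessarily the value of _MMA_TILE_K, and the sketch assumes weight_per_expert is a multiple of the tile size.

```python
def valid_split_ks(k, weight_per_expert, mma_tile_k, candidates=(1, 2, 4)):
    """Return the split_k candidates that keep K tiles and expert boundaries aligned."""
    assert weight_per_expert % mma_tile_k == 0, "expert width must be whole tiles"
    k_tiles = k // mma_tile_k
    tiles_per_expert = weight_per_expert // mma_tile_k
    valid = []
    for split_k in candidates:
        if k_tiles % split_k != 0:
            continue  # some K tiles would be silently dropped
        if (k_tiles // split_k) % tiles_per_expert != 0:
            continue  # a split would end in the middle of an expert
        valid.append(split_k)
    return valid


# k=65536 with 256 experts gives 256 columns per expert (benchmark shape).
all_ok = valid_split_ks(k=65536, weight_per_expert=256, mma_tile_k=256)
# A small shape where only split_k=1 keeps expert boundaries intact.
only_one = valid_split_ks(k=768, weight_per_expert=256, mma_tile_k=128)
print(all_ok, only_one)
```

Running the same check inside the kernel class, as the second fc2.py comment suggests, would protect direct callers that bypass the tactic enumeration.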

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 12182d32-835b-4e75-8042-1e1c43c79320

📥 Commits

Reviewing files that changed from the base of the PR and between 01618fa and f7b6b2e.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/moe_as_dense_gemm/fc2.py
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py
  • tests/scripts/cute_dsl_kernels/moe_as_dense_gemm/run_moe_as_dense_gemm_fc2.py

Comment thread tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@tensorrt-cicd
Collaborator

PR_Github #47359 [ run ] completed with state SUCCESS. Commit: f7b6b2e
/LLM/main/L0_MergeRequest_PR pipeline #37294 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@JacobHu-NV
Collaborator Author

/bot run --disable-fail-fast

1 similar comment
@zongfeijing
Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47645 [ run ] triggered by Bot. Commit: b12b493 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47645 [ run ] completed with state SUCCESS. Commit: b12b493
/LLM/main/L0_MergeRequest_PR pipeline #37550 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@JacobHu-NV
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47883 [ run ] triggered by Bot. Commit: b12b493 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47883 [ run ] completed with state SUCCESS. Commit: b12b493
/LLM/main/L0_MergeRequest_PR pipeline #37736 completed with status: 'SUCCESS'

CI Report

Link to invocation

@zongfeijing zongfeijing self-requested a review May 12, 2026 08:30
@JacobHu-NV
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #48264 [ run ] triggered by Bot. Commit: e5fbbb2 Link to invocation

…GEMM FC2

Signed-off-by: JacobHu-NV <266902545+JacobHu-NV@users.noreply.github.com>

Introduces swap_ab mode in the Blackwell MoE FC2 kernel plus the
accompanying run script. Dedicated warp (warp_id=6) cp.async-loads
alpha scale into a pipelined smem buffer; alpha buffer is now
token-major (stride=1 along token dim) and token_dim flips to N
when swap_ab is enabled.

Signed-off-by: JacobHu-NV <266902545+JacobHu-NV@users.noreply.github.com>

The Sm100BlockScaledPersistentDenseGemmKernel.__call__ entry receives
m/n/k/l as Int64 for API compatibility, but downstream cutlass.range and
cute.size derivations require Int32. Without an explicit cast, autotune
profiling at small m (e.g. m=8 from deep_gemm_gen_tuning_buckets) fails
with "DSLRuntimeError: expected Int32 for stop, got Int64".

Cast at the kernel entry; practical tensor dimensions always fit in 32-bit.
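
The assumption that dimensions fit in 32 bits can be made explicit with a host-side guard. This is a hedged sketch, not the PR's code: `checked_int32` and `narrow_problem_shape` are hypothetical helpers, and the actual change performs the cast inside the kernel entry using the DSL's own integer types.

```python
INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1


def checked_int32(name, value):
    """Return value as an int after verifying it fits in a signed 32-bit range."""
    if not INT32_MIN <= value <= INT32_MAX:
        raise OverflowError(f"{name}={value} does not fit in Int32")
    return int(value)


def narrow_problem_shape(m, n, k, l):
    """Validate all four GEMM extents before they are narrowed to 32 bits."""
    return tuple(
        checked_int32(name, value)
        for name, value in zip("mnkl", (m, n, k, l))
    )


# Shape taken from the small-m profiling case mentioned above (m=8),
# with n and k from the PR's benchmark configuration and l assumed to be 1.
shape = narrow_problem_shape(8, 7168, 65536, 1)
```

A guard like this would turn a silent wrap-around into an immediate error if a dimension ever exceeded the 32-bit range.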

Signed-off-by: JacobHu-NV <266902545+JacobHu-NV@users.noreply.github.com>
…ckets

Match the FC1 DenseGEMM autotuner setup so FC2 uses the dense per-token
buckets generated by deep_gemm_gen_tuning_buckets (8-stride below 128,
128-stride above) instead of the coarse power-of-2 bucket grid, and raise
tune_max_num_tokens from 256 to 512.

This gives autotune room to pick split-K tactics that benefit small and
mid-range m without inflating the cache or losing the M=288 step that
the previous bucket grid masked.
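
For illustration, the "8-stride below 128, 128-stride above" description can be sketched as below. `sketch_tuning_buckets` is a hypothetical stand-in, not the real deep_gemm_gen_tuning_buckets; in particular, how the 128 boundary and the upper limit are handled (inclusive or exclusive) is an assumption here.

```python
def sketch_tuning_buckets(max_num_tokens, fine_stride=8, coarse_stride=128, boundary=128):
    """Dense buckets up to `boundary`, coarse buckets above it (assumed inclusive)."""
    # Fine region: 8, 16, ..., up to the boundary (or the limit, if smaller).
    fine = list(range(fine_stride, min(boundary, max_num_tokens) + 1, fine_stride))
    # Coarse region: boundary + 128, boundary + 256, ..., up to the limit.
    coarse = list(range(boundary + coarse_stride, max_num_tokens + 1, coarse_stride))
    return fine + coarse


buckets_512 = sketch_tuning_buckets(512)  # tune_max_num_tokens after this change
buckets_256 = sketch_tuning_buckets(256)  # tune_max_num_tokens before this change
```

Under these assumptions the grid stays small (19 entries at a limit of 512) while still giving every multiple of 8 up to 128 its own tuning entry.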

Signed-off-by: JacobHu-NV <266902545+JacobHu-NV@users.noreply.github.com>
…er alpha layout / SF zero-volume

Three discrete fixes that together unblock the full
test_nvfp4_dense_gemm_fc2_blackwell parametrize sweep (180 cases) and
the production-path test_moe_backend -k DENSEGEMM (8 cases):

1. acc / alpha pipeline cadence (kernel hang). MMA accumulates
   tiles_per_expert k-tiles into a single tmem stage and commits the
   acc_pipeline once per expert, but the epilogue and the alpha producer
   were both iterating k_tile_cnt_local times. The 2nd consumer_wait per
   expert blocked forever for weight_per_expert > mma_tiler_k. Producer
   and consumer of both acc_pipeline and alpha_scale_pipeline now iterate
   experts_per_split = k_tile_cnt_local / tiles_per_expert (compile-time
   constant per tactic), matching MMA's commit cadence. Applied to both
   swap_ab and non-swap epilogue paths.

2. Wrapper alpha layout (numerical mismatch). Kernel wrapper builds
   alpha_scale token-major (token has stride 1, expert has stride m) so
   that warp 6 can coalesce-load 32 contiguous M alphas per expert.
   PyTorch's default contiguous (M, expert_count) is expert-major, so
   the runner now does alpha_scale = alpha_scale.t().contiguous() before
   taking data_ptr(). This is invisible at m=1 or expert_count=1 but
   produced a match rate of only 16-34% for any case with m>1 AND
   expert_count>1.

3. Wrapper SF zero-volume at m<128. The A/B SF cute layouts used floor
   division (m // 128 and n // 128) for the block dim, which collapses
   to 0 when m or n < 128 and made the MMA read undefined SF state.
   Mirror FC1: use ceil-div m_blocks = (m + 127) // 128 (and the same
   for n).
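
Fixes 1 and 3 above both come down to simple integer arithmetic, sketched here in plain Python. The function names are hypothetical, and `unmatched_waits` is only a counting model of the producer/consumer mismatch, not the kernel's pipeline API.

```python
def floor_blocks(extent, block=128):
    """Block count with floor division: 0 for any extent below one block."""
    return extent // block


def ceil_blocks(extent, block=128):
    """Block count with ceil division: at least 1 for any positive extent."""
    return (extent + block - 1) // block


def unmatched_waits(k_tile_cnt_local, tiles_per_expert, consumer_iters):
    """Consumer waits that never see a matching producer commit.

    The producer commits once per expert, i.e.
    experts_per_split = k_tile_cnt_local // tiles_per_expert times.
    """
    experts_per_split = k_tile_cnt_local // tiles_per_expert
    return max(0, consumer_iters - experts_per_split)


# Fix 3: scale-factor block dim for a small m.
sf_blocks_before = floor_blocks(32)   # zero-volume layout
sf_blocks_after = ceil_blocks(32)     # one partially filled block

# Fix 1: illustrative values, 8 k-tiles with 2 k-tiles per expert.
waits_before = unmatched_waits(8, 2, consumer_iters=8)       # iterated k_tile_cnt_local
waits_after = unmatched_waits(8, 2, consumer_iters=8 // 2)   # iterated experts_per_split
```

In this model any nonzero `unmatched_waits` corresponds to a consumer_wait with no commit behind it, which is the hang described in fix 1; the count is zero when tiles_per_expert is 1, matching the observation that the hang only appeared for weight_per_expert > mma_tiler_k.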

Verification on B200 (sm_100):
- test_moe_densegemm.py (kernel-direct, 564 cases): PASS
- test_nvfp4_dense_gemm_fc2_blackwell (FC2 only, 180 cases): PASS
- test_moe_backend.py -k DENSEGEMM (production path, 8 cases): PASS

Signed-off-by: JacobHu-NV <266902545+JacobHu-NV@users.noreply.github.com>
@JacobHu-NV JacobHu-NV force-pushed the pr/fc2-kernel-autotune branch from e5fbbb2 to 7ddc37e Compare May 14, 2026 02:48
@JacobHu-NV
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #48282 [ run ] triggered by Bot. Commit: 7ddc37e Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #48282 [ run ] completed with state SUCCESS. Commit: 7ddc37e
/LLM/main/L0_MergeRequest_PR pipeline #38094 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@JacobHu-NV
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #48317 [ run ] triggered by Bot. Commit: 7ddc37e Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #48317 [ run ] completed with state SUCCESS. Commit: 7ddc37e
/LLM/main/L0_MergeRequest_PR pipeline #38126 completed with status: 'SUCCESS'

CI Report

Link to invocation

Collaborator

@zongfeijing zongfeijing left a comment


LGTM

Collaborator

@hyukn hyukn left a comment


LGTM

@zongfeijing zongfeijing merged commit 5e25977 into NVIDIA:main May 18, 2026
7 checks passed
KleinBlueC pushed a commit to KleinBlueC/TensorRT-LLM that referenced this pull request May 19, 2026
…uning buckets (NVIDIA#13833)

Signed-off-by: JacobHu-NV <266902545+JacobHu-NV@users.noreply.github.com>
4 participants