
[None][feat] Optimize causal_conv1d prefill and decode kernels #13103

Open
Wanli-Jiang wants to merge 1 commit into NVIDIA:main from
Wanli-Jiang:user/williamj/update-causal-conv1d-update

Conversation

@Wanli-Jiang
Collaborator

@Wanli-Jiang Wanli-Jiang commented Apr 16, 2026

Features

Decode:

  • Add fast-path kernel for seqlen=1 with zero loops and fast math
  • Increase thread block from 64 to 128 to halve block count
  • Compile-time specialize conv_state_indices and silu branches

Prefill:

  • Use 128 threads for varlen with long sequences
  • Enable VecLoad for varlen BS=1 (seq_start=0 is always aligned)
  • Move conv_state save before main loop, removing 80 lines of complex cross-chunk state extraction from smem_exchange
  • Compile-time specialize conv_state_indices and silu branches
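
For intuition, the seqlen=1 decode step described above reduces to a handful of FMAs over the cached state window plus the new token, followed by SiLU, then a one-slot state shift. A minimal scalar C++ sketch — illustrative only; names and the scalar layout are simplified from the actual per-thread CUDA kernel:

```cpp
#include <cassert>
#include <cmath>

// Scalar sketch of the seqlen=1 decode step for one channel:
// gather kWidth-1 cached state values plus the new token, apply the
// width-kWidth causal convolution, SiLU-activate, and shift the state
// left by one. kWidth and all names are illustrative.
constexpr int kWidth = 4;

float conv1d_update_sl1(float state[kWidth - 1], float x_new,
                        float const w[kWidth], float bias)
{
    float s[kWidth];
    for (int i = 0; i < kWidth - 1; ++i) { s[i] = state[i]; }
    s[kWidth - 1] = x_new;

    float out = bias;
    for (int i = 0; i < kWidth; ++i) { out = std::fma(w[i], s[i], out); }
    out = out / (1.0f + std::exp(-out)); // SiLU

    // Shift the cached state left by one and append the new token.
    for (int i = 0; i < kWidth - 1; ++i) { state[i] = s[i + 1]; }
    return out;
}
```

With no sequence loop and a fixed, fully unrollable width, this shape is what allows the fast-path kernel to avoid runtime loops entirely.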

Benchmark Results

Hardware: NVIDIA B300 SXM6 AC, bf16
Model config: conv_dim=10240, width=4, 40 Mamba layers
Methodology: CUDA event timing (median of 1000 iterations, 300 warmup) for per-seqlen prefill; nsys pure GPU kernel time for decode

Decode (conv1d_update, SL=1, nsys kernel time)

| BS  | Before (us) | After (us) | Speedup |
|----:|------------:|-----------:|--------:|
| 1   | 3.8         | 2.0        | 1.90x   |
| 16  | 6.0         | 2.7        | 2.22x   |
| 64  | 14.6        | 5.2        | 2.81x   |
| 256 | 47.8        | 14.4       | 3.32x   |

Prefill (conv1d_fwd, BS=1, CUDA event per-seqlen)

| ISL    | Before (us) | After (us) | Speedup |
|-------:|------------:|-----------:|--------:|
| 1,000  | 35.0        | 24.9       | 1.41x   |
| 8,000  | 228.0       | 106.7      | 2.14x   |
| 10,000 | 280.0       | 129.4      | 2.16x   |
| 50,000 | 1,280.0     | 678.0      | 1.89x   |
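
The Speedup columns are simply before/after ratios rounded to two decimals; a small C++ helper (illustrative, not from the PR) reproduces them from the table values:

```cpp
#include <cassert>
#include <cmath>

// Speedup = before_us / after_us, rounded to two decimal places as
// reported in the benchmark tables above.
double speedup(double before_us, double after_us)
{
    return std::round(before_us / after_us * 100.0) / 100.0;
}
```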

Summary by CodeRabbit

  • Refactor
    • Optimized causal convolution kernels through compile-time specialization for improved performance.
    • Introduced specialized fast-path kernel for single-token decode operations.
    • Enhanced kernel launch configuration for better hardware utilization.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@Wanli-Jiang Wanli-Jiang force-pushed the user/williamj/update-causal-conv1d-update branch from b9d7663 to 545f026 on April 16, 2026 05:05
@Wanli-Jiang
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #43653 [ run ] triggered by Bot. Commit: 545f026 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43653 [ run ] completed with state FAILURE. Commit: 545f026
/LLM/main/L0_MergeRequest_PR pipeline #34140 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@Wanli-Jiang Wanli-Jiang force-pushed the user/williamj/update-causal-conv1d-update branch from 545f026 to e35ae73 on April 16, 2026 07:38
@Wanli-Jiang
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #43712 [ run ] triggered by Bot. Commit: e35ae73 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43712 [ run ] completed with state FAILURE. Commit: e35ae73
/LLM/main/L0_MergeRequest_PR pipeline #34195 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@Wanli-Jiang Wanli-Jiang force-pushed the user/williamj/update-causal-conv1d-update branch from e35ae73 to f10252c on April 17, 2026 07:21
Decode:
- Add fast-path kernel for seqlen=1 with zero loops and fast math
- Increase thread block from 64 to 128 to halve block count
- Compile-time specialize conv_state_indices and silu branches

Prefill:
- Use 128 threads for varlen with long sequences
- Enable VecLoad for varlen BS=1 (seq_start=0 is always aligned)
- Move conv_state save before main loop, removing 80 lines of
  complex cross-chunk state extraction from smem_exchange
- Compile-time specialize conv_state_indices and silu branches

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
@Wanli-Jiang Wanli-Jiang force-pushed the user/williamj/update-causal-conv1d-update branch from f10252c to 3a24d6e on April 17, 2026 07:23
@Wanli-Jiang Wanli-Jiang marked this pull request as ready for review on April 17, 2026 07:23
@Wanli-Jiang
Collaborator Author

/bot run --disable-fail-fast

@coderabbitai
Contributor

coderabbitai bot commented Apr 17, 2026

📝 Walkthrough

The causal convolution CUDA kernel is refactored to use compile-time template specialization for kHasConvStateIndices and kSiluActivation parameters, replacing runtime branching. The forward kernel's state update logic is simplified by removing smem_exchange reconstruction paths. A new specialized decode kernel causal_conv1d_update_kernel_sl1 is added for single-token sequences, and the update kernel launch thread count increases from 64 to 128.
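
The compile-time specialization mentioned here follows a standard pattern: runtime flags are promoted to template booleans, dispatched once at launch, so each instantiation carries no dead branch. An illustrative C++ sketch (not the actual kernel code; names are simplified):

```cpp
#include <cassert>
#include <cmath>

// Illustrative sketch of the specialization pattern: a runtime flag such
// as silu_activation becomes the template boolean kSiluActivation, and
// `if constexpr` removes the untaken branch from each instantiation at
// compile time.
template <bool kSiluActivation>
float finalize(float v)
{
    if constexpr (kSiluActivation)
    {
        return v / (1.0f + std::exp(-v)); // SiLU
    }
    else
    {
        return v;
    }
}

// Host-side dispatch selects the fully specialized instantiation once,
// instead of branching per element inside the hot loop.
float dispatch(float v, bool silu)
{
    return silu ? finalize<true>(v) : finalize<false>(v);
}
```

The same dispatch shape applies to the kHasConvStateIndices flag: the cost of the runtime branch is paid once at launch rather than per thread.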

Changes

| Cohort / File(s) | Summary |
|---|---|
| Causal Convolution Kernel Specialization: cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu | Forward kernel refactored to use compile-time template booleans (kHasConvStateIndices, kSiluActivation) for specialization instead of runtime branching. State update logic simplified by removing smem_exchange paths in favor of direct conv_states writes. Update kernel gains compile-time specialization and a new dedicated causal_conv1d_update_kernel_sl1 for the decode case. Launch configurations updated with vectorized-load expansion and revised preferNarrowKernel heuristics. Update kernel thread count increased from 64 to 128. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Description check | ⚠️ Warning | The PR description includes a Features section, a Benchmark Results table, and a checked PR Checklist, but the Description and Test Coverage sections required by the template are missing or empty. | Complete the empty Description section with a short explanation of the issue and solution, and list the relevant tests in the Test Coverage section to ensure sufficient coverage for the changes. |
| Docstring coverage | ⚠️ Warning | Docstring coverage is 0.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (1 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately describes the main optimization work on the causal_conv1d kernels for both the prefill and decode phases. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu (2)

379-380: Add braces around the early return statement.

The coding guidelines require that if statements always be followed by brace-delimited statements.

Proposed fix:

```diff
-    if (channel_id >= params.dim)
-        return;
+    if (channel_id >= params.dim)
+    {
+        return;
+    }
```

As per coding guidelines: "if and else in C++ should always be followed by brace-delimited statements, even if empty or a single statement"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu` around lines 379 -
380, The if statement checking "if (channel_id >= params.dim) return;" violates
the brace rule; modify the conditional in the causalConv1d kernel so it uses
brace-delimited body (e.g., if (channel_id >= params.dim) { return; }) ensuring
you update the statement that references channel_id and params.dim to include
the braces.

499-500: Add braces around single-statement bodies.

Multiple if statements and for loops are missing brace-delimited bodies, violating the coding guidelines.

Proposed fix for the if statements and for loops:

```diff
     int const channel_id = blockIdx.y * kNThreads + tidx;
-    if (channel_id >= params.dim)
-        return;
+    if (channel_id >= params.dim)
+    {
+        return;
+    }
 
     int conv_state_batch_coord;
     if constexpr (kHasConvStateIndices)
     {
         conv_state_batch_coord = params.conv_state_indices_ptr[batch_id];
         if (conv_state_batch_coord == params.pad_slot_id)
+        {
             return;
+        }
     }
     else
     {
         conv_state_batch_coord = batch_id;
     }
     float w[kWidth];
 #pragma unroll
     for (int i = 0; i < kWidth; ++i)
+    {
         w[i] = float(__ldg(&weight[i * params.weight_width_stride]));
+    }
 
     float s[kWidth];
 #pragma unroll
     for (int i = 0; i < kWidth - 1; ++i)
+    {
         s[i] = float(conv_state[i * params.conv_state_l_stride]);
+    }
     s[kWidth - 1] = float(x[0]);
 
     float out_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t*>(params.bias_ptr)[channel_id]);
 #pragma unroll
     for (int i = 0; i < kWidth; ++i)
+    {
         out_val = __fmaf_rn(w[i], s[i], out_val);
+    }
     out_val = out_val * __frcp_rn(1.0f + __expf(-out_val));
-    x[0] = input_t(out_val);
+    out[0] = input_t(out_val);
 
     // Shift conv_state left by one and append the new token.
 #pragma unroll
     for (int i = 0; i < kWidth - 1; ++i)
+    {
         conv_state[i * params.conv_state_l_stride] = input_t(s[i + 1]);
+    }
```

As per coding guidelines: "if and else in C++ should always be followed by brace-delimited statements" and "The body of a switch, while, do..while, or for statement in C++ must be a compound statement (use braces)"

Also applies to: 506-507, 522-523, 527-528, 533-534, 540-541

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu` around lines 499 -
500, The code contains single-statement bodies for control structures (e.g., the
check "if (channel_id >= params.dim) return;") which violates the project's C++
style; update each such if and every for/while mentioned (including the other
occurrences around the same block) to use brace-delimited compound statements
instead of bare single statements (wrap the existing single-line body in { ...
}), keeping the original logic and indentation intact; specifically modify the
"if (channel_id >= params.dim)" check and the adjacent if/for statements
referenced in this region so they all have braces.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 980dc93f-cb60-4bce-8e6e-579f5aa95eb3

📥 Commits

Reviewing files that changed from the base of the PR and between 461f3b9 and 3a24d6e.

📒 Files selected for processing (1)
  • cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu

Comment on lines +513 to +536

```cpp
input_t* conv_state = reinterpret_cast<input_t*>(params.conv_state_ptr)
    + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride;
weight_t* weight = reinterpret_cast<weight_t*>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t* x
    = reinterpret_cast<input_t*>(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride;

float w[kWidth];
#pragma unroll
for (int i = 0; i < kWidth; ++i)
    w[i] = float(__ldg(&weight[i * params.weight_width_stride]));

float s[kWidth];
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i)
    s[i] = float(conv_state[i * params.conv_state_l_stride]);
s[kWidth - 1] = float(x[0]);

float out_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t*>(params.bias_ptr)[channel_id]);
#pragma unroll
for (int i = 0; i < kWidth; ++i)
    out_val = __fmaf_rn(w[i], s[i], out_val);
out_val = out_val * __frcp_rn(1.0f + __expf(-out_val));
x[0] = input_t(out_val);
```
Contributor


⚠️ Potential issue | 🔴 Critical

Output written to input buffer instead of output buffer.

The sl1 kernel writes the result to x[0] (line 536), but the general update kernel writes to out[i * params.out_l_stride] (line 476). This kernel doesn't set up the out pointer at all.

If params.x_ptr != params.out_ptr, the output will be written to the wrong buffer, causing incorrect results or data corruption.

Proposed fix:

```diff
     input_t* conv_state = reinterpret_cast<input_t*>(params.conv_state_ptr)
         + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride;
     weight_t* weight = reinterpret_cast<weight_t*>(params.weight_ptr) + channel_id * params.weight_c_stride;
     input_t* x
         = reinterpret_cast<input_t*>(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride;
+    input_t* out = reinterpret_cast<input_t*>(params.out_ptr) + batch_id * params.out_batch_stride
+        + channel_id * params.out_c_stride;
 
     float w[kWidth];
 #pragma unroll
@@ -533,7 +535,7 @@
     for (int i = 0; i < kWidth; ++i)
         out_val = __fmaf_rn(w[i], s[i], out_val);
     out_val = out_val * __frcp_rn(1.0f + __expf(-out_val));
-    x[0] = input_t(out_val);
+    out[0] = input_t(out_val);
```

Verification script:

```bash
#!/bin/bash
# Check if there are any callers that assume in-place operation (x_ptr == out_ptr)
# for the update kernel, which might justify writing to x instead of out.

echo "=== Searching for causal_conv1d_update calls to check buffer usage ==="
rg -n -C5 'causal_conv1d_update' --type cpp -g '!*.cu'

echo ""
echo "=== Checking ConvParamsBase setup to see if x_ptr and out_ptr are ever the same ==="
rg -n -B3 -A3 'out_ptr\s*=' --type cpp | head -60
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu` around lines 513 -
536, The sl1 kernel is writing the result into the input buffer x (x[0]) instead
of the designated output buffer; fix by mirroring the other kernel’s output
setup: create an output pointer using params.out_ptr and the same batch/channel
strides (e.g. input_t* out = reinterpret_cast<input_t*>(params.out_ptr) +
batch_id * params.out_batch_stride + channel_id * params.out_c_stride) and
replace x[0] = input_t(out_val) with out[0] = input_t(out_val); ensure you use
the same stride symbols from the diff (params.out_ptr,
params.out_batch_stride/out_c_stride) so the kernel writes to params.out_ptr
when x_ptr != out_ptr.

@tensorrt-cicd
Collaborator

PR_Github #43996 [ run ] triggered by Bot. Commit: 3a24d6e Link to invocation

@Wanli-Jiang Wanli-Jiang requested review from a team, leslie-fang25 and xxi-nv and removed request for a team April 17, 2026 09:24
@Wanli-Jiang
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #44090 [ run ] triggered by Bot. Commit: 3a24d6e Link to invocation

