[None][feat] Optimize causal_conv1d prefill and decode kernels#13103
[None][feat] Optimize causal_conv1d prefill and decode kernels#13103Wanli-Jiang wants to merge 1 commit intoNVIDIA:mainfrom
Conversation
b9d7663 to
545f026
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #43653 [ run ] triggered by Bot. Commit: |
|
PR_Github #43653 [ run ] completed with state
|
545f026 to
e35ae73
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #43712 [ run ] triggered by Bot. Commit: |
|
PR_Github #43712 [ run ] completed with state
|
e35ae73 to
f10252c
Compare
Decode: - Add fast-path kernel for seqlen=1 with zero loops and fast math - Increase thread block from 64 to 128 to halve block count - Compile-time specialize conv_state_indices and silu branches Prefill: - Use 128 threads for varlen with long sequences - Enable VecLoad for varlen BS=1 (seq_start=0 is always aligned) - Move conv_state save before main loop, removing 80 lines of complex cross-chunk state extraction from smem_exchange - Compile-time specialize conv_state_indices and silu branches Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
f10252c to
3a24d6e
Compare
|
/bot run --disable-fail-fast |
📝 WalkthroughWalkthroughThe causal convolution CUDA kernel is refactored to use compile-time template specialization for Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes 🚥 Pre-merge checks | ✅ 1 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (2)
cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu (2)
379-380: Add braces around the early return statement.The coding guidelines require that
ifstatements always be followed by brace-delimited statements.Proposed fix
- if (channel_id >= params.dim) - return; + if (channel_id >= params.dim) + { + return; + }As per coding guidelines: "
ifandelsein C++ should always be followed by brace-delimited statements, even if empty or a single statement"🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu` around lines 379 - 380, The if statement checking "if (channel_id >= params.dim) return;" violates the brace rule; modify the conditional in the causalConv1d kernel so it uses brace-delimited body (e.g., if (channel_id >= params.dim) { return; }) ensuring you update the statement that references channel_id and params.dim to include the braces.
499-500: Add braces around single-statement bodies.Multiple
ifstatements andforloops are missing brace-delimited bodies, violating the coding guidelines.Proposed fix for if statements and for loops
int const channel_id = blockIdx.y * kNThreads + tidx; - if (channel_id >= params.dim) - return; + if (channel_id >= params.dim) + { + return; + } int conv_state_batch_coord; if constexpr (kHasConvStateIndices) { conv_state_batch_coord = params.conv_state_indices_ptr[batch_id]; if (conv_state_batch_coord == params.pad_slot_id) + { return; + } } else { conv_state_batch_coord = batch_id; }float w[kWidth]; `#pragma` unroll for (int i = 0; i < kWidth; ++i) + { w[i] = float(__ldg(&weight[i * params.weight_width_stride])); + } float s[kWidth]; `#pragma` unroll for (int i = 0; i < kWidth - 1; ++i) + { s[i] = float(conv_state[i * params.conv_state_l_stride]); + } s[kWidth - 1] = float(x[0]); float out_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t*>(params.bias_ptr)[channel_id]); `#pragma` unroll for (int i = 0; i < kWidth; ++i) + { out_val = __fmaf_rn(w[i], s[i], out_val); + } out_val = out_val * __frcp_rn(1.0f + __expf(-out_val)); - x[0] = input_t(out_val); + out[0] = input_t(out_val); // Shift conv_state left by one and append the new token. `#pragma` unroll for (int i = 0; i < kWidth - 1; ++i) + { conv_state[i * params.conv_state_l_stride] = input_t(s[i + 1]); + }As per coding guidelines: "
ifandelsein C++ should always be followed by brace-delimited statements" and "The body of aswitch,while,do..while, orforstatement in C++ must be a compound statement (use braces)"Also applies to: 506-507, 522-523, 527-528, 533-534, 540-541
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu` around lines 499 - 500, The code contains single-statement bodies for control structures (e.g., the check "if (channel_id >= params.dim) return;") which violates the project's C++ style; update each such if and every for/while mentioned (including the other occurrences around the same block) to use brace-delimited compound statements instead of bare single statements (wrap the existing single-line body in { ... }), keeping the original logic and indentation intact; specifically modify the "if (channel_id >= params.dim)" check and the adjacent if/for statements referenced in this region so they all have braces.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu`:
- Around line 513-536: The sl1 kernel is writing the result into the input
buffer x (x[0]) instead of the designated output buffer; fix by mirroring the
other kernel’s output setup: create an output pointer using params.out_ptr and
the same batch/channel strides (e.g. input_t* out =
reinterpret_cast<input_t*>(params.out_ptr) + batch_id * params.out_batch_stride
+ channel_id * params.out_c_stride) and replace x[0] = input_t(out_val) with
out[0] = input_t(out_val); ensure you use the same stride symbols from the diff
(params.out_ptr, params.out_batch_stride/out_c_stride) so the kernel writes to
params.out_ptr when x_ptr != out_ptr.
---
Nitpick comments:
In `@cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu`:
- Around line 379-380: The if statement checking "if (channel_id >= params.dim)
return;" violates the brace rule; modify the conditional in the causalConv1d
kernel so it uses brace-delimited body (e.g., if (channel_id >= params.dim) {
return; }) ensuring you update the statement that references channel_id and
params.dim to include the braces.
- Around line 499-500: The code contains single-statement bodies for control
structures (e.g., the check "if (channel_id >= params.dim) return;") which
violates the project's C++ style; update each such if and every for/while
mentioned (including the other occurrences around the same block) to use
brace-delimited compound statements instead of bare single statements (wrap the
existing single-line body in { ... }), keeping the original logic and
indentation intact; specifically modify the "if (channel_id >= params.dim)"
check and the adjacent if/for statements referenced in this region so they all
have braces.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 980dc93f-cb60-4bce-8e6e-579f5aa95eb3
📒 Files selected for processing (1)
cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu
|
|
||
| input_t* conv_state = reinterpret_cast<input_t*>(params.conv_state_ptr) | ||
| + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; | ||
| weight_t* weight = reinterpret_cast<weight_t*>(params.weight_ptr) + channel_id * params.weight_c_stride; | ||
| input_t* x | ||
| = reinterpret_cast<input_t*>(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride; | ||
|
|
||
| float w[kWidth]; | ||
| #pragma unroll | ||
| for (int i = 0; i < kWidth; ++i) | ||
| w[i] = float(__ldg(&weight[i * params.weight_width_stride])); | ||
|
|
||
| float s[kWidth]; | ||
| #pragma unroll | ||
| for (int i = 0; i < kWidth - 1; ++i) | ||
| s[i] = float(conv_state[i * params.conv_state_l_stride]); | ||
| s[kWidth - 1] = float(x[0]); | ||
|
|
||
| float out_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t*>(params.bias_ptr)[channel_id]); | ||
| #pragma unroll | ||
| for (int i = 0; i < kWidth; ++i) | ||
| out_val = __fmaf_rn(w[i], s[i], out_val); | ||
| out_val = out_val * __frcp_rn(1.0f + __expf(-out_val)); | ||
| x[0] = input_t(out_val); |
There was a problem hiding this comment.
Output written to input buffer instead of output buffer.
The sl1 kernel writes the result to x[0] (line 536), but the general update kernel writes to out[i * params.out_l_stride] (line 476). This kernel doesn't set up the out pointer at all.
If params.x_ptr != params.out_ptr, the output will be written to the wrong buffer, causing incorrect results or data corruption.
Proposed fix
input_t* conv_state = reinterpret_cast<input_t*>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride;
weight_t* weight = reinterpret_cast<weight_t*>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t* x
= reinterpret_cast<input_t*>(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride;
+ input_t* out = reinterpret_cast<input_t*>(params.out_ptr) + batch_id * params.out_batch_stride
+ + channel_id * params.out_c_stride;
float w[kWidth];
`#pragma` unroll
@@ -533,7 +535,7 @@
for (int i = 0; i < kWidth; ++i)
out_val = __fmaf_rn(w[i], s[i], out_val);
out_val = out_val * __frcp_rn(1.0f + __expf(-out_val));
- x[0] = input_t(out_val);
+ out[0] = input_t(out_val);#!/bin/bash
# Check if there are any callers that assume in-place operation (x_ptr == out_ptr)
# for the update kernel, which might justify writing to x instead of out.
echo "=== Searching for causal_conv1d_update calls to check buffer usage ==="
rg -n -C5 'causal_conv1d_update' --type cpp -g '!*.cu'
echo ""
echo "=== Checking ConvParamsBase setup to see if x_ptr and out_ptr are ever the same ==="
rg -n -B3 -A3 'out_ptr\s*=' --type cpp | head -60🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu` around lines 513 -
536, The sl1 kernel is writing the result into the input buffer x (x[0]) instead
of the designated output buffer; fix by mirroring the other kernel’s output
setup: create an output pointer using params.out_ptr and the same batch/channel
strides (e.g. input_t* out = reinterpret_cast<input_t*>(params.out_ptr) +
batch_id * params.out_batch_stride + channel_id * params.out_c_stride) and
replace x[0] = input_t(out_val) with out[0] = input_t(out_val); ensure you use
the same stride symbols from the diff (params.out_ptr,
params.out_batch_stride/out_c_stride) so the kernel writes to params.out_ptr
when x_ptr != out_ptr.
|
PR_Github #43996 [ run ] triggered by Bot. Commit: |
|
/bot run --disable-fail-fast |
|
PR_Github #44090 [ run ] triggered by Bot. Commit: |
Features
Decode:
Prefill:
Benchmark Results
Hardware: NVIDIA B300 SXM6 AC, bf16
Model config: conv_dim=10240, width=4, 40 Mamba layers
Methodology: CUDA event timing (median of 1000 iterations, 300 warmup) for per-seqlen prefill; nsys pure GPU kernel time for decode
Decode (conv1d_update, SL=1, nsys kernel time)
Prefill (conv1d_fwd, BS=1, CUDA event per-seqlen)
Summary by CodeRabbit
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.