[None][feat] Support kv cache in Trtllm-Gen attention backend #11667

yihwang-nv merged 19 commits into NVIDIA:main from
Conversation
📝 Walkthrough

Introduces a new CUDA-accelerated PyTorch extension for QKV processing and decoder info construction, adds request ID support for beam search through the attention backend, and refactors the attention backend with a workspace-centric architecture for the context and generation phases.

Changes
Sequence Diagram(s)

sequenceDiagram
actor PyTorch as PyTorch Extension<br/>(trtllm.py)
participant QKV as QKV Processor<br/>(Op)
participant CUDA as CUDA Kernels<br/>(trtllm_gen.cu)
participant Cache as KV Cache<br/>Manager
PyTorch->>QKV: qkv_preprocessing(tensors, params)
QKV->>CUDA: Dispatch to kernel (float/half/BF16/FP8)
CUDA->>CUDA: Preprocess QKV data
CUDA->>Cache: Build/update KV buffers
Cache-->>CUDA: Return buffer pointers
CUDA->>CUDA: Apply rotary embeddings
CUDA-->>QKV: Kernel complete
QKV-->>PyTorch: Return (void)
PyTorch->>QKV: kv_cache_postprocessing(...)
QKV->>CUDA: Dispatch postprocessing
CUDA->>Cache: Finalize cache state
Cache-->>CUDA: Confirm sync
CUDA-->>QKV: Complete
QKV-->>PyTorch: Return (void)
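The per-dtype dispatch step in the diagram above (float/half/BF16/FP8) can be sketched as a simple dtype-to-kernel lookup. This is a hedged illustration only: the variant names and dtype strings below are assumptions, not the extension's real symbols.

```python
# Illustrative sketch of the dtype dispatch the diagram shows. The kernel
# instantiation names are hypothetical, not the real trtllm_gen.cu symbols.
KERNELS = {
    "float32": "invokeQKVPreprocess<float>",
    "float16": "invokeQKVPreprocess<half>",
    "bfloat16": "invokeQKVPreprocess<__nv_bfloat16>",
    "fp8_e4m3": "invokeQKVPreprocess<__nv_fp8_e4m3>",
}

def dispatch_qkv_preprocess(dtype: str) -> str:
    """Return the kernel instantiation that would be launched for `dtype`."""
    if dtype not in KERNELS:
        raise TypeError(f"unsupported QKV dtype: {dtype}")
    return KERNELS[dtype]

print(dispatch_qkv_preprocess("bfloat16"))  # → invokeQKVPreprocess<__nv_bfloat16>
```

Unsupported dtypes fail loudly with `TypeError` rather than falling through to a default kernel, mirroring the fixed set of supported cache formats.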
sequenceDiagram
participant Attn as TrtllmAttention<br/>(Frontend)
participant WM as WorkspaceManager<br/>(Allocator)
participant FI as FlashInfer<br/>FMHA
participant QKV as QKV<br/>Preprocess
Attn->>WM: compute_workspace_size(phase, params)
WM-->>Attn: size (context/generation)
Attn->>WM: allocate(size)
WM-->>Attn: workspace buffer
alt Context Phase
Attn->>WM: split_context_workspace()
WM-->>Attn: ContextWorkspaceBuffers
Attn->>QKV: qkv_preprocessing(...)
QKV-->>Attn: (void)
Attn->>FI: fmha_forward(ctx_buffers)
FI-->>Attn: attention output
else Generation Phase
Attn->>WM: split_generation_workspace()
WM-->>Attn: GenerationWorkspaceBuffers
Attn->>QKV: qkv_preprocessing(..., request_ids)
QKV-->>Attn: (void)
Attn->>FI: fmha_decode(gen_buffers)
FI-->>Attn: attention output
end
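The workspace-centric flow above (compute_workspace_size, allocate, split into per-phase buffers) can be sketched in a few lines: compute one flat byte size, then split it into named sub-buffers at aligned offsets. The field names, sizes, and 256-byte alignment below are assumptions for illustration, not the backend's actual layout.

```python
# Sketch of a workspace-centric allocator: one flat byte buffer per phase,
# split into named (offset, size) slices. All names/sizes are hypothetical.
ALIGNMENT = 256  # assumed allocation alignment, typical for CUDA buffers

def _align(n: int) -> int:
    # Round n up to the next multiple of ALIGNMENT.
    return (n + ALIGNMENT - 1) // ALIGNMENT * ALIGNMENT

def compute_workspace_size(fields: dict) -> int:
    """Total bytes needed for all named sub-buffers, each padded to alignment."""
    return sum(_align(sz) for sz in fields.values())

def split_workspace(fields: dict) -> dict:
    """Map each field name to its (offset, size) slice of the flat buffer."""
    offsets, cursor = {}, 0
    for name, size in fields.items():
        offsets[name] = (cursor, size)
        cursor += _align(size)
    return offsets

# Hypothetical generation-phase fields, in bytes.
gen_fields = {"block_ids": 4096, "rotary_inv_freq": 512, "softmax_stats": 100}
print(compute_workspace_size(gen_fields))  # → 4864
print(split_workspace(gen_fields))
```

A single flat buffer keeps allocation out of the hot path; the split step is pure offset arithmetic, so context and generation phases can reuse the same storage with different layouts.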
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks: ❌ Failed checks (2 warnings) | ✅ Passed checks (1 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/attention_backend/trtllm_gen.py (1)
1558-1913: ⚠️ Potential issue | 🟡 Minor

Allocate workspace on demand if None.

`workspace` is declared as `Optional[torch.Tensor]` in the function signature, but the code calls `resize_()` unconditionally at line 1783 when the existing size is insufficient. If `workspace` is `None`, this raises `AttributeError` at runtime. Allocate a `torch.uint8` tensor on demand to match the byte-buffer semantics used throughout the code.

Suggested fix
- current_workspace_size = (
-     workspace.numel() * workspace.element_size() if workspace is not None else 0
- )
-
- if current_workspace_size < required_workspace_size:
-     workspace.resize_(required_workspace_size)
+ if workspace is None:
+     workspace = torch.empty(required_workspace_size, device=q.device, dtype=torch.uint8)
+ elif workspace.numel() * workspace.element_size() < required_workspace_size:
+     workspace.resize_(required_workspace_size)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py` around lines 1558 - 1913, The workspace can be None but the code calls workspace.resize_(), causing AttributeError; modify the allocation logic around WorkspaceManager.get_workspace_size/current_workspace_size to allocate a byte buffer when workspace is None: if workspace is None create a torch.empty(required_workspace_size, dtype=torch.uint8, device=q.device) and assign it to workspace, otherwise check current size and call workspace.resize_(required_workspace_size) only when needed; update usages that pass workspace (e.g., common_params passed to EnqueueContextParams/EnqueueGenerationParams and FlashInferTrtllmGenAttention) to use the newly allocated workspace tensor.
🧹 Nitpick comments (1)
cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp (1)
208-237: Align local constants with k-prefixed naming.

`kv_factor` and `vector_size` are constants; please rename them to follow the k-prefixed uppercase snake_case rule.

♻️ Suggested refactor

- int32_t const kv_factor = 2;
+ int32_t const kKV_FACTOR = 2;
  auto const block_size = tokens_per_block * kv_head_num * size_per_head;
  auto const bytes_per_block = block_size * kvElemBits / 8 /*bits*/;
- auto const intra_pool_offset = layer_idx_in_cache_pool * kv_factor * bytes_per_block;
+ auto const intra_pool_offset = layer_idx_in_cache_pool * kKV_FACTOR * bytes_per_block;
@@
- auto constexpr vector_size = 16;
- auto const bytes_per_block_sf = block_size / vector_size * 1 /*bytes per E4M3 sf*/;
- auto const intra_pool_offset_sf = layer_idx_in_cache_pool * kv_factor * bytes_per_block_sf;
+ auto constexpr kVECTOR_SIZE = 16;
+ auto const bytes_per_block_sf = block_size / kVECTOR_SIZE * 1 /*bytes per E4M3 sf*/;
+ auto const intra_pool_offset_sf = layer_idx_in_cache_pool * kKV_FACTOR * bytes_per_block_sf;

As per coding guidelines: "Constants naming should be uppercase snake_case with prefix 'k' (e.g., kDIGIT_NUM = 10). Function-scope constants that are not magic numbers are named like non-constant variables."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp` around lines 208 - 237, Rename the function-scope constants kv_factor and vector_size to follow the project convention (prefix 'k' + uppercase snakecase) and update all uses: change kv_factor -> k_KV_FACTOR and vector_size -> K_VECTOR_SIZE (or follow repo-specific 'k' + UPPER_SNAKE style) in trtllmGenQKVProcessOp.cpp so intra_pool_offset, intra_pool_offset_sf and any calculations that reference kv_factor or vector_size use the new constant names (e.g., replace kv_factor in the calculation of block offsets and vector_size in bytes_per_block_sf), and ensure the new constants retain the same types/constexpr qualifiers as the originals.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp`:
- Around line 1-16: Update the copyright header year range in the file
trtllmGenQKVProcessOp.cpp from "1993-2025" to "1993-2026" (i.e., change the end
year to 2026) so the Apache-2.0 license block reflects the latest modification;
locate the top-of-file NVIDIA copyright header comment and adjust the year range
accordingly.
In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 1104-1112: This branch uses self.request_ids unguarded and can
crash or copy wrong blocks; before calling
self.kv_cache_manager.get_block_ids_per_seq(...) validate that self.request_ids
is not None and has at least self.num_seqs elements, then use a sliced active
batch (e.g. active_ids = self.request_ids[:self.num_seqs]) when calling
self.kv_cache_manager.get_block_ids_per_seq(active_ids).pin_memory(), compute
num_blocks from the sliced result, and copy only that many blocks into
self.block_ids_per_seq[:self.num_seqs, :num_blocks] to ensure lengths match and
avoid runtime errors; keep the checks around _TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION
and self.kv_cache_manager as before.
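The guard described above can be sketched as follows. `gather_block_ids` and the stub lookup are hypothetical helpers for illustration, not names from tensorrt_llm; the real call goes through `KVCacheManager.get_block_ids_per_seq`.

```python
# Hypothetical sketch of the requested guard: validate request_ids before the
# block-id lookup and operate only on the active slice of the batch.
def gather_block_ids(request_ids, num_seqs, get_block_ids_per_seq):
    if request_ids is None or len(request_ids) < num_seqs:
        raise ValueError("request_ids must cover all active sequences")
    active_ids = request_ids[:num_seqs]        # slice the active batch
    block_ids = get_block_ids_per_seq(active_ids)
    num_blocks = max(len(row) for row in block_ids)  # widest sequence
    return block_ids, num_blocks

# Stub manager: request rid owns (rid % 3 + 1) blocks, numbered from rid * 10.
lookup = lambda ids: [[rid * 10 + b for b in range(rid % 3 + 1)] for rid in ids]
rows, nb = gather_block_ids([1, 2, 3, 4], 2, lookup)
print(rows, nb)  # → [[10, 11], [20, 21, 22]] 3
```

The caller would then copy only `rows` into `block_ids_per_seq[:num_seqs, :nb]`, so the destination slice and the gathered data always agree in shape.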
---
Outside diff comments:
In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py`:
- Around line 1558-1913: The workspace can be None but the code calls
workspace.resize_(), causing AttributeError; modify the allocation logic around
WorkspaceManager.get_workspace_size/current_workspace_size to allocate a byte
buffer when workspace is None: if workspace is None create a
torch.empty(required_workspace_size, dtype=torch.uint8, device=q.device) and
assign it to workspace, otherwise check current size and call
workspace.resize_(required_workspace_size) only when needed; update usages that
pass workspace (e.g., common_params passed to
EnqueueContextParams/EnqueueGenerationParams and FlashInferTrtllmGenAttention)
to use the newly allocated workspace tensor.
---
Nitpick comments:
In `@cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp`:
- Around line 208-237: Rename the function-scope constants kv_factor and
vector_size to follow the project convention (prefix 'k' + uppercase snakecase)
and update all uses: change kv_factor -> k_KV_FACTOR and vector_size ->
K_VECTOR_SIZE (or follow repo-specific 'k' + UPPER_SNAKE style) in
trtllmGenQKVProcessOp.cpp so intra_pool_offset, intra_pool_offset_sf and any
calculations that reference kv_factor or vector_size use the new constant names
(e.g., replace kv_factor in the calculation of block offsets and vector_size in
bytes_per_block_sf), and ensure the new constants retain the same
types/constexpr qualifiers as the originals.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
cpp/tensorrt_llm/thop/CMakeLists.txt
cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp
tensorrt_llm/_torch/attention_backend/trtllm.py
tensorrt_llm/_torch/attention_backend/trtllm_gen.py
dd3cce1 to b939fbf (Compare)

/bot run --disable-fail-fast

b939fbf to 74ccd35 (Compare)

/bot run --disable-fail-fast

PR_Github #36600 [ run ] triggered by Bot. Commit:

PR_Github #36601 [ run ] triggered by Bot. Commit:

PR_Github #36601 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #36638 [ run ] triggered by Bot. Commit:

PR_Github #36638 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #36774 [ run ] triggered by Bot. Commit:

PR_Github #36774 [ run ] completed with state

PR_Github #38858 [ run ] completed with state

trtllm_gen_attention and thop.attention both pass.

Signed-off-by: Yihan Wang <yihwang@nvidia.com>

…ttention CI Signed-off-by: Yihan Wang <yihwang@nvidia.com>

/bot run --disable-fail-fast

PR_Github #39032 [ run ] triggered by Bot. Commit:

Signed-off-by: Yihan Wang <yihwang@nvidia.com>

/bot run --disable-fail-fast

PR_Github #39038 [ run ] triggered by Bot. Commit:

PR_Github #39038 [ run ] completed with state

…hop.attention path Signed-off-by: Yihan Wang <yihwang@nvidia.com>

/bot run --disable-fail-fast

PR_Github #39094 [ run ] triggered by Bot. Commit:

PR_Github #39094 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #39143 [ run ] triggered by Bot. Commit:

PR_Github #39143 [ run ] completed with state
…#11667) Signed-off-by: Yihan Wang <yihwang@nvidia.com>
Summary by CodeRabbit
Release Notes
New Features
Refactor
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
Details
run [--reuse-test (optional) pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional) pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL): Disable fail-fast on build/tests/infra failures.

--skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.

--post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

kill: Kill all running builds associated with the pull request.

skip

skip --comment COMMENT: Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline: Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.