[None][feat] Integrate FP4 indexer for DSA on Blackwell #13340
longlee0622 merged 1 commit into NVIDIA:main from
Conversation
📝 Walkthrough
This pull request adds FP4 quantization support to the indexer K-cache system. It introduces new CUDA kernels for fused concatenation and FP4 quantization, updates configuration parameters throughout the stack to enable FP4 mode, generalizes kernel implementations to support variable head dimensions (128 for FP8 or 64 for FP4 packed), and adds comprehensive test coverage for the FP4 indexer path.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Warning: Review ran into problems.
🔥 Problems: Git: Failed to clone repository. Please run the
There was a problem hiding this comment.
Actionable comments posted: 12
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (6)
cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp (1)
2-2: ⚠️ Potential issue | 🟡 Minor
Update the copyright year on this modified source file.
The file was changed in this PR but still shows 2025.
🛠️ Proposed fix
- * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines: "Add NVIDIA copyright header on ALL new files and update year on modified files."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp` at line 2, Update the copyright header at the top of cacheTransceiver.cpp to reflect the current modification year (replace "2025" with the correct year) so the NVIDIA copyright header is current for this modified file; ensure the header text and SPDX line remain otherwise unchanged and match the project's standard header format used in other files.

tensorrt_llm/llmapi/llm_args.py (1)
1-1: ⚠️ Potential issue | 🟠 Major
Add the required NVIDIA source header to this modified Python file.
This file was modified but is missing the mandatory NVIDIA copyright/license header.
🛠️ Proposed fix
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
  import ast

As per coding guidelines: "All TensorRT-LLM source files must contain an NVIDIA copyright header with the year of latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/llmapi/llm_args.py` at line 1, Add the mandatory NVIDIA copyright/license source header to the top of the modified module llm_args.py (before any imports); ensure the header matches the project's standard NVIDIA header used in other TensorRT-LLM files and includes the correct year of latest meaningful modification and full license text, then save the file so the header precedes the existing "import ast" statement.

tensorrt_llm/_torch/attention_backend/sparse/dsa.py (3)
1-1: ⚠️ Potential issue | 🟡 Minor
Add the required NVIDIA copyright header to this modified source file.
This file is part of the changed surface, but it still starts directly with the module docstring.
As per coding guidelines, "All TensorRT-LLM source files must contain an NVIDIA copyright header with the year of latest meaningful modification".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` at line 1, Add the required NVIDIA copyright header at the top of the modified source file (before the module docstring) so the file begins with the standard multi-line copyright comment including "NVIDIA" and the year of latest meaningful modification; update the header year to the current modification year and ensure it precedes the existing module docstring in tensorrt_llm._torch.attention_backend.sparse.dsa (the top of the file).
76-78: ⚠️ Potential issue | 🔴 Critical
Use the FP4 packed byte width in slot-mapping math.
`block_stride` and `scale_base_offset` still assume `head_dim` data bytes per token. In FP4 mode the cache view is `head_dim // 2 + 4` bytes per token (`get_indexer_k_cache_buffers()` at lines 2194-2198), so these offsets are too large and the scatter/gather path will address the wrong locations.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` around lines 76 - 78, The current math uses head_dim bytes per token when computing scale_size, block_stride, and scale_base_offset, but FP4 packs two values per byte so the cache layout uses head_dim//2 + 4 bytes per token; update the calculations to use a data_bytes_per_token variable that is head_dim for normal modes and head_dim // 2 for FP4 mode, then compute scale_size (based on quant_block_size) and set block_stride = tokens_per_block * (data_bytes_per_token + scale_size) and scale_base_offset = tokens_per_block * data_bytes_per_token; refer to get_indexer_k_cache_buffers() to detect/align with the FP4 layout and use the same packed-byte logic for FP4.
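The byte math the fix asks for can be sketched in plain Python (the function name and the `tokens_per_block` value below are illustrative; the layout follows the review's description of `get_indexer_k_cache_buffers()`):

```python
def indexer_slot_layout(head_dim: int, tokens_per_block: int, use_fp4: bool):
    """Per-block byte offsets for the indexer K-cache slot mapping (sketch)."""
    # FP8 stores one byte per element; FP4 packs two elements per byte.
    data_bytes_per_token = head_dim // 2 if use_fp4 else head_dim
    scale_size = 4  # packed UE8M0 x4 for FP4, one float32 for FP8
    block_stride = tokens_per_block * (data_bytes_per_token + scale_size)
    scale_base_offset = tokens_per_block * data_bytes_per_token
    return data_bytes_per_token, block_stride, scale_base_offset


# head_dim=128 as in the indexer; tokens_per_block=64 is illustrative.
fp8 = indexer_slot_layout(128, 64, use_fp4=False)
fp4 = indexer_slot_layout(128, 64, use_fp4=True)
assert fp8 == (128, 64 * 132, 64 * 128)  # 132 bytes per token
assert fp4 == (64, 64 * 68, 64 * 64)     # 68 bytes per token
```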
2207-2258: ⚠️ Potential issue | 🟠 Major
Update the cache-size estimators for FP4.
Both sizing helpers still charge the indexer cache as `index_head_dim + scale_bytes` per token. The runtime layout is only `index_head_dim // 2 + 4` bytes in FP4 mode, so capacity planning overestimates KV usage and under-admits requests.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` around lines 2207 - 2258, Both cache-size helpers overcount indexer K-cache for FP4; update get_cache_size_per_token and get_cache_bytes_per_token to use the FP4 runtime layout (index_head_dim//2 + 4) instead of the current index_head_dim + (index_head_dim // quant_block_size * 4) when FP4 is active. In get_cache_size_per_token detect FP4 via model_config.quant_config / model_config.quant_config.quant_mode (same check used elsewhere) and compute head_dim_factor = (index_head_dim//2 + 4) / head_dim for FP4, otherwise keep the existing formula; in get_cache_bytes_per_token branch on self.dtype == DataType.NVFP4 and replace the head_dim_factor calculation with (self.index_head_dim//2 + 4) / self.head_dim, leaving the rest (kv_factor, size conversion, and calculate_scaling_factor_size_bytes) unchanged.

cpp/tensorrt_llm/executor/serialization.cpp (1)
2-2: ⚠️ Potential issue | 🟡 Minor
Update the copyright year in this modified source file.
This file was meaningfully modified in this PR, but the header still lists 2025 only.
🛠️ Suggested header update
- * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines: "Add NVIDIA copyright header on ALL new files and update year on modified files."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/executor/serialization.cpp` at line 2, The file header in serialization.cpp still lists only "2025" but this file was modified; update the copyright header to include the current year (e.g., change "2025" to "2025-2026" or "2026") so the NVIDIA copyright header reflects the modification; locate the top-of-file comment block in serialization.cpp and adjust the year range accordingly.
🧹 Nitpick comments (2)
cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp (1)
150-154: Add parameter-name comments to this long positional constructor call.
This call has multiple non-obvious positional values (especially bools); inline param comments will reduce maintenance risk.
♻️ Proposed refactor
- mCacheState
-     = std::make_unique<executor::kv_cache::CacheState>(cacheStateModelCfg, worldConfig, attentionLayerNumPerPP,
-         dataType, attentionType, kvFactor, cacheManager->isEnableBlockReuse(), cacheManager->isEnablePartialReuse(),
-         cacheManager->isEnableIndexerKCache(), cacheManager->getIndexerKCacheIndexHeadDim(),
-         cacheManager->getIndexerKCacheQuantBlockSize(), cacheManager->getIndexerKCacheUseFp4());
+ mCacheState = std::make_unique<executor::kv_cache::CacheState>(
+     /*modelConfig=*/cacheStateModelCfg,
+     /*worldConfig=*/worldConfig,
+     /*attentionLayerNumPerPP=*/attentionLayerNumPerPP,
+     /*dataType=*/dataType,
+     /*attentionType=*/attentionType,
+     /*kvFactor=*/kvFactor,
+     /*enableBlockReuse=*/cacheManager->isEnableBlockReuse(),
+     /*enablePartialReuse=*/cacheManager->isEnablePartialReuse(),
+     /*hasIndexerKCache=*/cacheManager->isEnableIndexerKCache(),
+     /*indexerDimPerHead=*/cacheManager->getIndexerKCacheIndexHeadDim(),
+     /*indexerKCacheQuantBlockSize=*/cacheManager->getIndexerKCacheQuantBlockSize(),
+     /*indexerKCacheUseFp4=*/cacheManager->getIndexerKCacheUseFp4());

As per coding guidelines: "In C++ function calls with non-obvious parameters, use inline C comments with the format `/*paramName=*/` to document parameters."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp` around lines 150 - 154, The long positional call creating mCacheState using executor::kv_cache::CacheState should document non-obvious parameters with inline C comments; update the call to annotate each argument (especially the bool-returning cacheManager methods) using the /*paramName=*/ style so readers know what each value means — e.g., label cacheStateModelCfg, worldConfig, attentionLayerNumPerPP, dataType, attentionType, kvFactor, and the cacheManager calls as /*enableBlockReuse=*/ cacheManager->isEnableBlockReuse(), /*enablePartialReuse=*/ cacheManager->isEnablePartialReuse(), /*enableIndexerKCache=*/ cacheManager->isEnableIndexerKCache(), /*indexerKCacheIndexHeadDim=*/ cacheManager->getIndexerKCacheIndexHeadDim(), /*indexerKCacheQuantBlockSize=*/ cacheManager->getIndexerKCacheQuantBlockSize(), /*indexerKCacheUseFp4=*/ cacheManager->getIndexerKCacheUseFp4() so the CacheState(...) constructor arguments are explicit and maintainable.cpp/tensorrt_llm/kernels/indexerKCacheGather.cu (1)
143-147: Prefer named constants for head-dim and scale literals.
Please replace inline literals (`128`, `64`, `4`) with named constants in this check block.
♻️ Suggested cleanup
  constexpr int32_t VEC_SIZE = 4;
+ constexpr int32_t kFp8HeadDim = 128;
+ constexpr int32_t kFp4PackedHeadDim = 64;
+ constexpr int32_t kScaleBytes = 4;
- TLLM_CHECK_WITH_INFO(head_dim == 128 || head_dim == 64,
+ TLLM_CHECK_WITH_INFO(head_dim == kFp8HeadDim || head_dim == kFp4PackedHeadDim,
      "head_dim must be 128 (FP8) or 64 (FP4 packed) for the indexer cache (got %d)", head_dim);
  TLLM_CHECK_WITH_INFO(head_dim % VEC_SIZE == 0, "head_dim (%d) must be a multiple of %d", head_dim, VEC_SIZE);
- TLLM_CHECK_WITH_INFO(scale_size == 4,
+ TLLM_CHECK_WITH_INFO(scale_size == kScaleBytes,
      "scale_size must equal 4 bytes (packed UE8M0 x4 for FP4, 1 float32 for FP8, got %d)", scale_size);

As per coding guidelines: "Except for `0`, `nullptr`, `true`, and `false`, all other literal values in C++ should only be used for variable initialization; use named constants instead."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/indexerKCacheGather.cu` around lines 143 - 147, Replace the magic literals in the TLLM_CHECK_WITH_INFO calls with named constants: define constants like kHeadDimFp8 = 128, kHeadDimFp4Packed = 64 and kScaleSizeBytes = 4 (in this translation unit or the appropriate header) and use them in the checks and formatted messages instead of raw numbers; update the conditions (head_dim == kHeadDimFp8 || head_dim == kHeadDimFp4Packed, head_dim % VEC_SIZE == 0, scale_size == kScaleSizeBytes) and the error strings to reference the constant names so the checks in indexerKCacheGather.cu (the TLLM_CHECK_WITH_INFO calls referencing head_dim, VEC_SIZE, scale_size) no longer contain hard-coded literals.
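The checks and the runtime thread-count choice can be rendered in a short Python sketch (constant names follow the review's suggestion; the `threads_per_block = head_dim / VEC_SIZE` rule comes from the PR summary):

```python
# Constants mirroring the values the review suggests naming in the CUDA check block.
VEC_SIZE = 4                # elements handled per thread in the gather kernel
K_FP8_HEAD_DIM = 128        # one FP8 byte per element
K_FP4_PACKED_HEAD_DIM = 64  # 128 FP4 values packed two per byte
K_SCALE_BYTES = 4           # packed UE8M0 x4 for FP4, one float32 for FP8


def pick_threads_per_block(head_dim: int, scale_size: int) -> int:
    """Python rendering of the TLLM_CHECK_WITH_INFO conditions quoted above."""
    if head_dim not in (K_FP8_HEAD_DIM, K_FP4_PACKED_HEAD_DIM):
        raise ValueError(f"head_dim must be 128 (FP8) or 64 (FP4 packed), got {head_dim}")
    if head_dim % VEC_SIZE != 0:
        raise ValueError(f"head_dim ({head_dim}) must be a multiple of {VEC_SIZE}")
    if scale_size != K_SCALE_BYTES:
        raise ValueError(f"scale_size must equal 4 bytes, got {scale_size}")
    # The PR summary states the kernels pick threads_per_block = head_dim / VEC_SIZE.
    return head_dim // VEC_SIZE


assert pick_threads_per_block(128, 4) == 32  # FP8 layout
assert pick_threads_per_block(64, 4) == 16   # FP4-packed layout
```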
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h`:
- Around line 2108-2109: The new boolean flag indexerKCacheUseFp4 is declared on
the KVCacheManager constructor signature but not threaded through the remaining
overloaded constructors and call sites; update every KVCacheManager constructor
overload (the signatures that include SizeType32 indexerKCacheQuantBlockSize and
indexerKCacheIndexHeadDim) to accept and forward the indexerKCacheUseFp4
parameter, and update the production instantiation in
trtGptModelInflightBatching.cpp (the KVCacheManager construction around the
block previously at lines ~686-694) to pass the appropriate indexerKCacheUseFp4
argument (and the related indexerKCacheQuantBlockSize/indexerKCacheIndexHeadDim
where present) so the FP4 path is not silently disabled.
In `@cpp/include/tensorrt_llm/executor/dataTransceiverState.h`:
- Around line 56-57: Update the copyright header block in the top-of-file
comment for tensorrt_llm/executor/dataTransceiverState.h to include 2026
(replace the current 2024 year with 2026); ensure the header matches the NVIDIA
copyright format used across the project so the file-level comment reflects the
latest meaningful modification.
- Around line 114-119: operator== in dataTransceiverState (kv_cache::CacheState)
currently omits several layout-defining fields so states with different cache
layouts compare equal; update the operator== implementation (the method named
operator== in dataTransceiverState.h) to also compare mEnableBlockReuse,
mEnablePartialReuse, mHasIndexerKCache, mIndexerDimPerHead, and
mIndexerKCacheQuantBlockSize (i.e., include all layout-defining members used to
determine cache layout) so any differing cache-layout-related fields cause
inequality.
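Why the missing fields matter can be illustrated with a Python analogue (this dataclass is hypothetical; field names follow the C++ members named above, and the auto-generated `__eq__` compares every field, which is the behavior the review asks `operator==` to match):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CacheStateSketch:
    # Layout-defining fields the review says operator== must compare.
    kv_factor: int
    enable_block_reuse: bool
    enable_partial_reuse: bool
    has_indexer_k_cache: bool
    indexer_dim_per_head: int
    indexer_k_cache_quant_block_size: int


# dataclass-generated __eq__ compares every field, so two states that differ
# only in an indexer/layout flag correctly compare unequal:
a = CacheStateSketch(2, True, True, True, 128, 128)
b = CacheStateSketch(2, True, True, False, 128, 128)
assert a != b
assert a == CacheStateSketch(2, True, True, True, 128, 128)
```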
In `@cpp/tensorrt_llm/kernels/fusedCatFp4.cu`:
- Around line 188-202: Summary: The kernel uses *reinterpret_cast<int2
const*>(src + col) which requires 8-byte alignment of the base pointers; current
checks only validate strides not base-address alignment. Fix: In the
launcher/wrapper that prepares inputs for fusedCatFp4 (where pe.data_ptr() and
nope.data_ptr() are passed), add runtime checks that
reinterpret_cast<uintptr_t>(pe.data_ptr()) % 8 == 0 and
reinterpret_cast<uintptr_t>(nope.data_ptr()) % 8 == 0 before launching
fusedCatFp4; if either base pointer is not 8-byte aligned, either return an
error or take the scalar-safe fallback path (e.g., a separate kernel or code
path that reads elements individually instead of using int2 loads). Ensure the
new checks live alongside the existing dimension/stride checks and reference the
same symbols (pe.data_ptr(), nope.data_ptr(), fusedCatFp4 kernel launch).
In `@cpp/tensorrt_llm/thop/fusedCatFp4Op.cpp`:
- Around line 27-71: In fused_cat_fp4, add a device check and a device guard:
TORCH_CHECK that nope.device() == pe.device() (and that pe.is_cuda()) to reject
mixed-device inputs, then create an at::cuda::CUDAGuard (or at::DeviceGuard)
scoped to pe.device() before calling at::cuda::getCurrentCUDAStream(...) and
invoking tensorrt_llm::kernels::invokeFusedCatFp4 so the kernel launch and raw
data_ptr() access occur on the correct CUDA device/stream.
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 1957-1970: The FP4 branch reshapes q_fp8 but never reshapes
q_scale, so scales returned by torch.ops.trtllm.fused_cat_fp4 (shape [num_tokens
* n_heads, 1]) are left in token*head-major layout and later indexing/views
break; fix by reshaping q_scale in the same branch (e.g., after q_fp8 =
q_fp8.view(-1, self.n_heads, self.head_dim // 2) add q_scale = q_scale.view(-1,
self.n_heads, 1)) so callers (and any later _weight_scale usage) see q_scale in
[num_tokens, n_heads, 1] token-major format while preserving the existing
weights logic (weights *= self.weight_scale_factor).
In `@tensorrt_llm/_torch/modules/attention.py`:
- Around line 1752-1757: The fake custom-op implementation _mla_dsa_proj_fake
currently models the old 8-output contract while
forward_dsa_proj/forward_dsa_attn now return nine outputs (with q_scale as the
9th and the FP4 path expecting 5 tensors passed into mla_dsa_attn_inplace via
indexer_intermediates); update _mla_dsa_proj_fake to return the new 9-tensor
tuple (and ensure the list/tuple used for indexer_intermediates reflects five
runtime tensors for the FP4/FP8 path), and update the surrounding
docstrings/comments to state the new 9/5 contract so torch.compile shape
propagation matches runtime (adjust any unpacking in callers of
_mla_dsa_proj_fake, e.g., where indexer_intermediates is built/consumed, to
match the new ordering).
In `@tests/unittest/_torch/attention/sparse/test_cpp_custom_ops.py`:
- Around line 168-170: Parametrize the gather tests to run with head_dim values
128 and 64 so the FP4-packed branch in indexer_k_cache_gather_op is exercised:
replace the hardcoded HEAD_DIM usage in the test(s) that call
torch.ops.trtllm.indexer_k_cache_gather_op with a parameter (e.g., head_dim) and
compute per_token_size and slot_scale from that head_dim (derive per_token_size
= head_dim / some_unit used in the op and set slot_scale accordingly) before
constructing k_cache/slot_fp8/slot_scale; update the three affected call sites
around the test lines (168-170, 235-237, 261-263) to use the parameterized
head_dim so both 128 and 64 cases run.
In `@tests/unittest/_torch/attention/sparse/test_dsa_fp4_indexer.py`:
- Around line 198-224: The test
test_fp4_indexer_k_cache_per_token_size_drops_to_68_bytes currently only
recomputes constants and never exercises the implementation; replace the
synthetic assertions with an integration check that constructs the real FP4
cache manager and inspects the allocated pools: instantiate or obtain
WindowBlockManager and/or DSACacheManager, call createIndexerKCachePools (or
get_indexer_k_cache_buffers) with index_head_dim=128 and quant_block_size=128,
retrieve the actual buffer/stride/bytes-per-token from the returned pool objects
or buffers, and assert those runtime values equal the expected 132 (FP8) and 68
(FP4) and the shrink ratio; if constructing the real managers is heavy, use a
minimal fixture or monkeypatch to run the real allocation code instead of
recomputing literals so the test will fail on regressions.
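The 132-byte and 68-byte expectations follow directly from the layout described in this PR; a minimal sketch of the arithmetic (helper name is illustrative):

```python
def indexer_k_cache_bytes_per_token(index_head_dim: int, use_fp4: bool) -> int:
    # FP8: one byte per element; FP4: two elements packed per byte.
    data_bytes = index_head_dim // 2 if use_fp4 else index_head_dim
    scale_bytes = 4  # packed UE8M0 x4 for FP4, one float32 for FP8
    return data_bytes + scale_bytes


assert indexer_k_cache_bytes_per_token(128, use_fp4=False) == 132
assert indexer_k_cache_bytes_per_token(128, use_fp4=True) == 68
```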
- Around line 230-308: The test
test_fp4_paged_mqa_logits_jit_first_compile_latency currently only logs JIT
timings and must be turned into a perf-sanity check: add an assertion that fails
when jit_overhead (first_ms - warm_ms) exceeds a small, documented threshold
(parameterize by next_n) so regressions block CI, and make the threshold
configurable via environment or pytest marker; update the test to record the
measured values in the authoritative B200 perf DB entry and add the matching QA
perf list entry so this case is tracked by scheduled/perf jobs (ensure you
update the test metadata that references the B200 path and QA list accordingly).
- Around line 29-34: The current broad except hides real import/runtime errors;
change the import block for tensorrt_llm.deep_gemm so only import failures are
caught and attribute checks are done separately: wrap "from tensorrt_llm import
deep_gemm" in a try/except ImportError (or ModuleNotFoundError) and set
HAS_DEEP_GEMM = False on import failure; if import succeeds (else branch) set
HAS_DEEP_GEMM = hasattr(deep_gemm, "fp8_fp4_mqa_logits") so unexpected
exceptions during import or attribute access are allowed to surface instead of
being swallowed by a bare except.
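The corrected guard described above would look like this (a sketch of the suggested pattern; `HAS_DEEP_GEMM` ends up `False` wherever `tensorrt_llm.deep_gemm` cannot be imported):

```python
try:
    from tensorrt_llm import deep_gemm  # only import failures should be caught
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    HAS_DEEP_GEMM = False
else:
    # Attribute probing happens outside the except, so any unexpected
    # exception raised during import or attribute access still surfaces.
    HAS_DEEP_GEMM = hasattr(deep_gemm, "fp8_fp4_mqa_logits")

assert isinstance(HAS_DEEP_GEMM, bool)
```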
In `@tests/unittest/_torch/attention/sparse/test_dsa_indexer.py`:
- Around line 460-470: The mock currently fills only kv_lens_cuda_2d and leaves
scheduler_metadata_buffer zeroed; update the mock to mirror production by
copying the same 2D slice into scheduler_metadata_buffer using the same shape
and dtype as done by get_paged_mqa_logits_metadata (i.e., take
gen_kv_lens.unsqueeze(-1).expand(-1, next_n_cap) or the equivalent 2D view and
write it into scheduler_metadata_buffer[:num_generations, :next_n_cap]),
ensuring the expanded path does the same so
test_indexer_decode_with_paged_kv_cache sees the populated scheduler metadata.
---
Outside diff comments:
In `@cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp`:
- Line 2: Update the copyright header at the top of cacheTransceiver.cpp to
reflect the current modification year (replace "2025" with the correct year) so
the NVIDIA copyright header is current for this modified file; ensure the header
text and SPDX line remain otherwise unchanged and match the project's standard
header format used in other files.
In `@cpp/tensorrt_llm/executor/serialization.cpp`:
- Line 2: The file header in serialization.cpp still lists only "2025" but this
file was modified; update the copyright header to include the current year
(e.g., change "2025" to "2025-2026" or "2026") so the NVIDIA copyright header
reflects the modification; locate the top-of-file comment block in
serialization.cpp and adjust the year range accordingly.
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Line 1: Add the required NVIDIA copyright header at the top of the modified
source file (before the module docstring) so the file begins with the standard
multi-line copyright comment including "NVIDIA" and the year of latest
meaningful modification; update the header year to the current modification year
and ensure it precedes the existing module docstring in
tensorrt_llm._torch.attention_backend.sparse.dsa (the top of the file).
- Around line 76-78: The current math uses head_dim bytes per token when
computing scale_size, block_stride, and scale_base_offset, but FP4 packs two
values per byte so the cache layout uses head_dim//2 + 4 bytes per token; update
the calculations to use a data_bytes_per_token variable that is head_dim for
normal modes and head_dim // 2 for FP4 mode, then compute scale_size (based on
quant_block_size) and set block_stride = tokens_per_block *
(data_bytes_per_token + scale_size) and scale_base_offset = tokens_per_block *
data_bytes_per_token; refer to get_indexer_k_cache_buffers() to detect/align
with the FP4 layout and use the same packed-byte logic for FP4.
- Around line 2207-2258: Both cache-size helpers overcount indexer K-cache for
FP4; update get_cache_size_per_token and get_cache_bytes_per_token to use the
FP4 runtime layout (index_head_dim//2 + 4) instead of the current index_head_dim
+ (index_head_dim // quant_block_size * 4) when FP4 is active. In
get_cache_size_per_token detect FP4 via model_config.quant_config /
model_config.quant_config.quant_mode (same check used elsewhere) and compute
head_dim_factor = (index_head_dim//2 + 4) / head_dim for FP4, otherwise keep the
existing formula; in get_cache_bytes_per_token branch on self.dtype ==
DataType.NVFP4 and replace the head_dim_factor calculation with
(self.index_head_dim//2 + 4) / self.head_dim, leaving the rest (kv_factor, size
conversion, and calculate_scaling_factor_size_bytes) unchanged.
In `@tensorrt_llm/llmapi/llm_args.py`:
- Line 1: Add the mandatory NVIDIA copyright/license source header to the top of
the modified module llm_args.py (before any imports); ensure the header matches
the project's standard NVIDIA header used in other TensorRT-LLM files and
includes the correct year of latest meaningful modification and full license
text, then save the file so the header precedes the existing "import ast"
statement.
---
Nitpick comments:
In `@cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp`:
- Around line 150-154: The long positional call creating mCacheState using
executor::kv_cache::CacheState should document non-obvious parameters with
inline C comments; update the call to annotate each argument (especially the
bool-returning cacheManager methods) using the /*paramName=*/ style so readers
know what each value means — e.g., label cacheStateModelCfg, worldConfig,
attentionLayerNumPerPP, dataType, attentionType, kvFactor, and the cacheManager
calls as /*enableBlockReuse=*/ cacheManager->isEnableBlockReuse(),
/*enablePartialReuse=*/ cacheManager->isEnablePartialReuse(),
/*enableIndexerKCache=*/ cacheManager->isEnableIndexerKCache(),
/*indexerKCacheIndexHeadDim=*/ cacheManager->getIndexerKCacheIndexHeadDim(),
/*indexerKCacheQuantBlockSize=*/ cacheManager->getIndexerKCacheQuantBlockSize(),
/*indexerKCacheUseFp4=*/ cacheManager->getIndexerKCacheUseFp4() so the
CacheState(...) constructor arguments are explicit and maintainable.
In `@cpp/tensorrt_llm/kernels/indexerKCacheGather.cu`:
- Around line 143-147: Replace the magic literals in the TLLM_CHECK_WITH_INFO
calls with named constants: define constants like kHeadDimFp8 = 128,
kHeadDimFp4Packed = 64 and kScaleSizeBytes = 4 (in this translation unit or the
appropriate header) and use them in the checks and formatted messages instead of
raw numbers; update the conditions (head_dim == kHeadDimFp8 || head_dim ==
kHeadDimFp4Packed, head_dim % VEC_SIZE == 0, scale_size == kScaleSizeBytes) and
the error strings to reference the constant names so the checks in
indexerKCacheGather.cu (the TLLM_CHECK_WITH_INFO calls referencing head_dim,
VEC_SIZE, scale_size) no longer contain hard-coded literals.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: a1585573-a2ff-41c4-9421-166d5d822372
📒 Files selected for processing (24)
3rdparty/fetch_content.json
cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
cpp/include/tensorrt_llm/executor/dataTransceiverState.h
cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
cpp/tensorrt_llm/deep_gemm/CMakeLists.txt
cpp/tensorrt_llm/executor/serialization.cpp
cpp/tensorrt_llm/kernels/fusedCatFp4.cu
cpp/tensorrt_llm/kernels/fusedCatFp4.h
cpp/tensorrt_llm/kernels/indexerKCacheGather.cu
cpp/tensorrt_llm/kernels/indexerKCacheScatter.cu
cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
cpp/tensorrt_llm/thop/CMakeLists.txt
cpp/tensorrt_llm/thop/IndexerKCacheGatherOp.cpp
cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp
cpp/tensorrt_llm/thop/fusedCatFp4Op.cpp
tensorrt_llm/_torch/attention_backend/sparse/dsa.py
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
tensorrt_llm/_torch/modules/attention.py
tensorrt_llm/_torch/pyexecutor/resource_manager.py
tensorrt_llm/llmapi/llm_args.py
tests/unittest/_torch/attention/sparse/test_cpp_custom_ops.py
tests/unittest/_torch/attention/sparse/test_dsa_fp4_indexer.py
tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
|
/bot run --disable-fail-fast |
|
PR_Github #44980 [ run ] triggered by Bot. Commit: |
|
The 2 yaml files need to be updated in https://github.com/NVIDIA/TensorRT-LLM/tree/main/scripts/attribution/data. |
Re-ran scripts/attribute.py against the Ninja build produced from this
branch. The scanner picked up the dependency versions this PR introduces
(or that landed on main since the last attribution refresh) and added:
- deepgemm/c491439ed5966833d56883ca302b6f72e74f8105 (the upgrade this
PR pulls in via 3rdparty/fetch_content.json)
- cutlass/v4.4.2
- cuda/13.2
- nccl/2.29.2-1+cuda13.1
plus the matching file-hash entries in files_to_dependency.yml and three
new content-addressable license/copyright blobs under data/cas/.
No manual edits to the YAML/CAS files — everything is a straight
regeneration from `python scripts/attribute.py --build-dir cpp/build`.
Addresses PR feedback from Barry-Delaney:
NVIDIA#13340 (comment)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
|
/bot run --disable-fail-fast |
|
PR_Github #45181 [ run ] triggered by Bot. Commit: |
|
PR_Github #45181 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #45296 [ run ] triggered by Bot. Commit: |
|
PR_Github #45296 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #45494 [ run ] triggered by Bot. Commit: |
edf9dd7 to
102e6be
Compare
|
/bot run |
|
PR_Github #47010 [ run ] triggered by Bot. Commit: |
|
PR_Github #47010 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #47024 [ run ] triggered by Bot. Commit: |
|
/bot help |
GitHub Bot Help
Provide a user-friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.
Details
run
Launch build/test pipelines. All previously running jobs will be killed.
kill
Kill all running builds associated with the pull request.
skip
Skip testing for latest commit on pull request.
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
|
PR_Github #47024 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #47061 [ run ] triggered by Bot. Commit: |
|
PR_Github #47061 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #47104 [ run ] triggered by Bot. Commit: |
|
PR_Github #47104 [ run ] completed with state
|
…elated fails Add waives for tests blocking PR NVIDIA#13340 CI that bisect to causes unrelated to this PR: 1. piecewise CUDA graph capture rework (https://nvbugs/6153575) Two symptoms with the same bisection point: - perf/test_perf_sanity.py::test_e2e[aggr_upload-super_ad_blackwell-super_ad_ws1_1k1k] Reproducible -14% throughput / +103% mean_ttft on Nemotron-3-Super-120B-A12B-NVFP4 served via openai_server + AutoDeploy. TPOT only +3%; the regression lives in the warmup / first-token window, matching CUDA graph capture overhead. - unittest/tools/test_layer_wise_benchmarks.py::test_qwen3_next_gen_tep[1] Triton "out of memory" during chunk_fwd_kernel_o autotune on Qwen3-Next-80B-A3B-Instruct (--moe-backend=TRTLLM, layers 6,7). Increased CUDA graph footprint from the rework starves the Triton autotune scratch, which OOMs on the larger BLOCK_K/V configs. Both bisect to main commit 9c1869b "[https://nvbugs/5615248][fix] Broader capture of piecewise cudagraph (NVIDIA#13574)" which reworked piecewise CUDA graph capture filtering in tensorrt_llm/_torch/pyexecutor/model_engine.py to force-include max_batch_size*(max_seq_len-1-N) ceiling. Verified by comparing PR builds rebased before vs after that commit: L0/36967 (base 3e4a775) : both PASS L0/37033 (base 7943536) : both FAIL L0/37073 (base 7943536, retry) : both FAIL PR NVIDIA#13340 itself does not touch model_engine.py, AutoDeploy, fla, mamba, or fused_moe in the relevant paths. 2. openai-server smoke tests on A10 (https://nvbugs/6153638) - test_e2e.py::test_openai_lora (A10-PyTorch-1) - test_e2e.py::test_openai_tool_call (A10-PyTorch-2) - test_e2e.py::test_trtllm_serve_lora_example (A10-PyTorch-2) All three crash uniformly with "Server exited unexpectedly" / "Connection refused on health endpoint", with no underlying server-side traceback captured. 
Inconsistent across CI runs of the same PR head 102e6be (build 47061 / L0 37033 did not hit any of these), so the symptom is environmental/host-level flake on A10 stages, not a code regression. PR NVIDIA#13340 only touches openai_server.py to add getattr fallbacks for hf_tokenizer_path and vocab_size; both are no-op on the happy path. Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
|
/bot run --disable-fail-fast |
|
PR_Github #47108 [ run ] triggered by Bot. Commit: |
|
/bot kill |
|
PR_Github #47130 [ kill ] triggered by Bot. Commit: |
|
PR_Github #47130 [ kill ] completed with state |
|
/bot run --disable-fail-fast |
|
PR_Github #47153 [ run ] triggered by Bot. Commit: |
|
PR_Github #47153 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #47157 [ run ] triggered by Bot. Commit: |
|
PR_Github #47157 [ run ] completed with state |
…xer_gemm

Resolve conflicts in DSA paged-MQA-logits dispatch and tests after the DeepGEMM submodule bump (4ff3f54d -> c491439e via PR NVIDIA#13340 / DG NVIDIA#304):

- dsa.py: take upstream's scheduler_metadata_buffer / _full_next_n selection (mtp3 buffer removed); add a DSL early-branch using the existing scheduler_metadata_buffer (built with (num_gen, 1) shape, num_atoms=1, matching DSL's 1-atom-per-q design) and the 1D kv_lens_cuda_runtime slice for context_lens.
- dsa.py: introduce module-level _DG_SCHEDULE_BLOCK_KV = 64, used by all 6 get_paged_mqa_logits_metadata calls (3 in on_update_kv_lens(), 3 in Indexer.prepare()) instead of cache tokens_per_block. This decouples schedule SPLIT_KV from cache page size and side-steps an SM100 + block_kv=32 latent regression in DG commit 7f2a703 (NVIDIA#304).
- test_dsa_indexer.py: take upstream's scheduler buffer selection; the DSL test branch reads scheduler_metadata_buffer + 1D kv_lens.
- test_cute_dsl_fp8_paged_mqa_logits.py: 4 metadata calls now pass 2D context_lens via .unsqueeze(-1) and DG_METADATA_BLOCK_KV=64; the DG bench drops cluster(2,1,1) for next_n=4 (SM100 always uses num_kv_multicast=1) and passes 2D context_lens to fp8_paged_mqa_logits.

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
@coderabbitai summary
Description
Add the FP4 indexer path for DSA on Blackwell (SM100+) and land it with a native CUDA fused-quantize op. Four commits on top of `main`:

**[None][feat] Integrate FP4 indexer for DSA on Blackwell** (e9542afeb)

- Bump the DeepGEMM submodule to `Barry-Delaney/DeepGEMM@a97b74d7` (`user/jinshik/nv_dev_rebase`), which adds the FP4 MQA logits kernels and switches the paged MQA logits APIs to require 2D `context_lens`. All `get_paged_mqa_logits_metadata` / `fp8_paged_mqa_logits` call sites in `dsa.py` are migrated to 2D via a pre-allocated `kv_lens_cuda_2d` buffer (no capture-time allocations).
- Add `DeepSeekSparseAttentionConfig.indexer_k_dtype: Literal["fp8", "fp4"]`.
- Plumb an `indexerKCacheUseFp4` flag through the C++ KVCacheManager chain (`WindowBlockManager`, `BlockManager`, all 4 `KVCacheManager` overloads, the nanobind binding) and through the disagg `CacheState` (serialization + `operator==`, so prefill/decode refuse to pair when the flag disagrees). `createIndexerKCachePools` picks `data_size = index_head_dim / 2` under FP4, shrinking the per-token size from 132 B to 68 B.
- Generalize `IndexerKCacheScatterOp` / `IndexerKCacheGatherOp` (+ their kernels) to accept `head_dim ∈ {128 (FP8), 64 (FP4-packed)}`, pick `threads_per_block = head_dim / VEC_SIZE` at runtime, and forward `head_dim=64` under FP4 from the callers.
- `Indexer._prep_q_or_k` dispatches to the FP4 quantizer when `use_fp4`; `pre_indexer_proj` returns a 5th `q_scale` output; weights under FP4 carry only `softmax_scale * n_heads^-0.5` because the kernel applies `q_scale` internally.
- New `_call_mqa_logits` / `_call_paged_mqa_logits` helpers reinterpret the FP8 gather output as int8 / int32 under FP4 and route to `fp8_fp4_mqa_logits` / `fp8_fp4_paged_mqa_logits`. `DSACacheManager` computes the per-token size based on `use_fp4`.
- Plumb `q_scale` through the two-op CUDA-graph-split DSA path in `attention.py` and through `KVCacheManager.__init__` in `resource_manager.py`.
- Add `fp4_quantize_1x32_sf_transpose` — a Triton kernel that fuses amax, UE8M0 ceil, FP4 E2M1 quantize, nibble packing, and four-per-int32 scale packing, bit-identical to DeepGEMM's `testing.per_token_cast_to_fp4(..., gran_k=32, use_packed_ue8m0=True)` reference.

**[None][perf] Replace Triton FP4 indexer quantizer with fused_cat_fp4 CUDA op** (c90c56e85)

- Mirror the existing `fused_cat_fp8` op with an FP4 E2M1 variant: `torch.ops.trtllm.fused_cat_fp4(pe, nope) -> (packed_int8, scale_int32)`. Fuses concat + per-block-32 quantize + UE8M0 scale packing into one CUDA kernel, removing the remaining Triton DSL dependency on the DSA CUDA-graph hot path and giving the FP4/FP8 indexer branches a uniform native-CUDA call-site shape.
- Kernel layout: launch grid `(ceil(M / 8),)`, block 256. Each warp handles one 128-element row; thread `t` covers elements `[4t, 4t+3]`. Per-block-32 amax via 3-round `__shfl_xor_sync` (offsets 1, 2, 4 — stays inside each group of 8 lanes). UE8M0 scale via an IEEE 754 bit trick; `MIN_AMAX = 1e-12f` keeps `ratio` normal so no `exp` clamp is needed. FP4 E2M1 quantize via a 7-way bucketize on `{0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0}`. IEEE `div.rn.f32` for the scaled value. Nibble pack (even → low, odd → high); lane 0 gathers the four UE8M0 exponent bytes from lanes `{0, 8, 16, 24}` via `__shfl_sync` and writes one int32 per row (little-endian).
- `Indexer._prep_q_or_k` is switched to the new op; the `fp4_quantize_1x32_sf_transpose` import is removed. `tensorrt_llm/quantization/utils/fp4_utils.py` is reverted exactly to its pre-PR state (`git diff HEAD^` on this file is empty). `tests/unittest/_torch/quantization/{__init__.py,test_fp4_quantize.py}` are removed; coverage moves to `test_cpp_custom_ops.py`. The `unittest/_torch/quantization` entry is dropped from `l0_b200.yml`.

**[None][fix] Address review feedback on FP4 indexer PR** (397eafe06) + **[None][fix] Address round-2 review feedback on FP4 indexer PR** (eac155a4c)

- `pre_indexer_proj` now runs `q_scale = q_scale.view(-1, self.n_heads, 1)` — `fused_cat_fp4` flattens to `M = N * n_heads`, so downstream token-axis slicing in `sparse_attn_indexer` (chunk / ctx / decode) needs the same reshape the FP8 branch already had. Without this, the FP4 path was silently slicing wrong per-token scales.
- `DSACacheManager.get_cache_size_per_token` / `get_cache_bytes_per_token` now pick `index_head_dim // 2` under FP4, so the KV cache block budget sees the 68 B/token footprint (it was hardcoded to 132 B). Without this, the block budget over-counted and the FP4 memory win never reached the allocator.
- A `@model_validator(mode="after")` on `DeepSeekSparseAttentionConfig` rejects `fp4` on SM<100 or `index_head_dim != 128`, degrades gracefully when CUDA is unavailable, and now includes recovery hints in the error message.
- `ModelConfig.from_pretrained` (DSV3.2 / GlmMoeDsa rebuild branch) now forwards `indexer_k_dtype` — previously it was hand-copying every other sparse-attention field, so the user knob was silently reset to `"fp8"` before any downstream consumer saw it. Added a regression-guard test that exercises the rebuild and statically asserts the keyword is present in the source.
- Drop the `scheduler_metadata_buffer_mtp3` path. DeepGEMM's upgraded paged MQA logits kernel picks `num_kv_multicast=1` for every `next_n` on SM100 (verified by the `_schedule_meta_size` assertion in `deepgemm-src/csrc/apis/attention.hpp:339` firing when the legacy layout is passed).
- Removed the buffer allocation, both population sites, and the production dispatch branch.
- `CacheState::operator==` now also compares `mHasIndexerKCache`, `mIndexerDimPerHead`, `mIndexerKCacheQuantBlockSize`, so prefill/decode refuse to pair on any incompatible indexer layout (not just FP8/FP4); the `quant_block_size` assertion is tightened to `== 128`; a DeepGEMM-source reference comment documents the FP4 magic constants; an 8-byte base-pointer `TORCH_CHECK` (with the pointer in the message) guards `fused_cat_fp4Op`; the `_call_mqa_logits` comment is refreshed after the q_scale reshape fix; copyright year bumps on `dataTransceiverState.h`, `cacheTransceiver.cpp`, `serialization.cpp`.

Test Coverage
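One number the coverage below repeatedly pins down is the 132 B → 68 B per-token indexer K-cache shrink. A back-of-envelope sketch of that arithmetic (the 4-byte per-head scale term is an assumption, chosen because it is consistent with both quoted figures):

```python
# Hypothetical sketch of the per-token indexer K-cache footprint.
# FP8 stores 128 one-byte values per 128-dim head; FP4 packs two
# 4-bit values per byte. Both layouts are assumed to carry one
# 4-byte scale word per head (fp32 for FP8, four packed UE8M0
# exponent bytes for FP4).
def indexer_k_cache_bytes_per_token(index_head_dim=128, use_fp4=False):
    data_size = index_head_dim // 2 if use_fp4 else index_head_dim
    scale_size = 4
    return data_size + scale_size

print(indexer_k_cache_bytes_per_token(use_fp4=False))  # 132
print(indexer_k_cache_bytes_per_token(use_fp4=True))   # 68
```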
Bit-exactness of the new CUDA op:
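These checks compare against DeepGEMM's `per_token_cast_to_fp4` reference. As orientation for what is being compared bit-for-bit, the per-block-32 E2M1 + packed-UE8M0 scheme can be sketched in pure Python (illustrative only; the thresholds and bias-127 exponent encoding follow the kernel description above, and rounding corner cases may differ from the CUDA op):

```python
import math

# E2M1 representable magnitudes and the kernel's bucketize thresholds.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
THRESHOLDS = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]

def quantize_row_fp4(row, block=32):
    """Quantize one row to packed FP4 E2M1 nibbles plus one UE8M0
    scale exponent per 32-element block (bias 127 assumed)."""
    assert len(row) % block == 0
    packed, exps = [], []
    for b in range(0, len(row), block):
        blk = row[b:b + block]
        amax = max(1e-12, max(abs(v) for v in blk))
        # UE8M0: smallest power of two >= amax / 6 (E2M1 max magnitude)
        e = math.ceil(math.log2(amax / 6.0))
        sf = 2.0 ** e
        exps.append(e + 127)
        codes = []
        for v in blk:
            a = abs(v) / sf
            mag = next((i for i, t in enumerate(THRESHOLDS) if a < t), 7)
            codes.append((8 if v < 0 else 0) | mag)  # sign in bit 3
        # nibble pack: even element -> low nibble, odd -> high nibble
        packed += [codes[i] | (codes[i + 1] << 4) for i in range(0, block, 2)]
    return packed, exps

def dequantize(packed, exps, block=32):
    out = []
    for bi, e in enumerate(exps):
        sf = 2.0 ** (e - 127)
        for byte in packed[bi * block // 2:(bi + 1) * block // 2]:
            for code in (byte & 0xF, byte >> 4):
                sign = -1.0 if code & 8 else 1.0
                out.append(sign * E2M1_VALUES[code & 7] * sf)
    return out

packed, exps = quantize_row_fp4([3.0] * 128)
# every element maps to E2M1 value 6 at scale 2**-1, so this
# particular round-trip is exact
assert dequantize(packed, exps) == [3.0] * 128
```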
- `tests/unittest/_torch/attention/sparse/test_cpp_custom_ops.py::test_fused_cat_fp4_matches_deepgemm` — parametrized over shapes `(4, 128)`, `(1, 32, 128)`, `(3, 7, 128)`, `(2, 5, 4, 128)` and seeds `{0, 42, 2026}`. Uses `torch.equal` (not `allclose`) on both packed bytes and scale int32 vs `tensorrt_llm.deep_gemm.utils.math.per_token_cast_to_fp4(..., use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)` — all 12 parametrizations pass.
- `test_fused_cat_fp4_shape_dispatch[64-64 / 32-96 / 16-112 / 96-32]` — asymmetric `pe`/`nope` splits, shape and dtype sanity.
- `test_fused_cat_fp4_noncontiguous_split` — the op accepts non-contiguous views from `torch.split()` and matches the contiguous baseline.

Config plumbing:
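Behind the plumbing test below sits the config-side gate added in the review-feedback commits. A rough sketch of that validation logic as a plain function (illustrative names, not the actual pydantic validator):

```python
def check_indexer_k_dtype(indexer_k_dtype, sm_version, index_head_dim):
    """Illustrative stand-in for the @model_validator described above.

    sm_version=None models "CUDA unavailable": the SM check degrades
    gracefully and defers to runtime. Error messages carry recovery
    hints, matching the behavior the PR describes.
    """
    if indexer_k_dtype != "fp4":
        return
    if sm_version is not None and sm_version < 100:
        raise ValueError(
            "indexer_k_dtype='fp4' requires SM100+ (Blackwell); "
            "use indexer_k_dtype='fp8' on this GPU.")
    if index_head_dim != 128:
        raise ValueError(
            "indexer_k_dtype='fp4' only supports index_head_dim=128 "
            f"(got {index_head_dim}); use indexer_k_dtype='fp8' instead.")

check_indexer_k_dtype("fp4", sm_version=100, index_head_dim=128)  # ok
check_indexer_k_dtype("fp4", sm_version=None, index_head_dim=128)  # ok: no CUDA
```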
- `test_dsa_fp4_indexer.py::test_indexer_k_dtype_survives_model_config_rebuild` — new. Exercises the `ModelConfig.from_pretrained` DSV3.2 rebuild branch with a stub pretrained config, asserts `indexer_k_dtype="fp4"` round-trips, and statically greps `ModelConfig.from_pretrained`'s source for the forwarding keyword so a future edit that drops it fails fast.

FP4 indexer end-to-end on B200:
- `test_fp4_mqa_logits_shape_and_topk_intersection[32|64]` — FP4 vs FP8 top-k overlap.
- `test_fp4_quantize_roundtrip_matches_bf16_kv` — packing / scale recovery round-trip.
- `test_fp4_indexer_k_cache_per_token_size_drops_to_68_bytes` — per-token KV-cache shrinkage from 132 B to 68 B.

FP8 indexer regression:
- `test_dsa_indexer.py` — FP8 path unregressed; all parametrizations, including `next_n ∈ {1, 2, 3, 4}`, now pass with the base scheduler metadata buffer.
- `test_cpp_custom_ops.py` — the existing gather / scatter / convert_req_index tests still pass.

Full attention sweep:
- `tests/unittest/_torch/attention` on B200 — no regressions from any of the four commits.

FP8 DSA smoke on a real model:
- `tests/integration/defs/accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline]` launched end-to-end on 8 × B200 with `DeepSeek-V3.2-Exp-hf`; the model loaded and forward passes proceeded without error, confirming the FP8 `fused_cat_fp8` path is unaffected (stopped after ramp-up to save time).

PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.