[TRTLLM-11058][feat] Support Helix CP with GQA #11570
Conversation
📝 Walkthrough

This PR adds Helix context-parallelism (CP) support to TensorRT-LLM's attention computation stack, threading new `helix_position_offsets` and `helix_is_inactive_rank` parameters from the Python attention modules through the Torch backend and C++ enqueue/dispatch layers down to the CUDA attention kernels, and combining per-rank partial outputs via `softmax_stats`.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant App as Python App
    participant AttentionMod as Attention Module<br/>(Python)
    participant TorchBackend as Torch Backend
    participant Enqueue as Enqueue Interface<br/>(C++)
    participant Dispatcher as XQA Dispatcher<br/>(C++)
    participant KernelParams as Kernel Parameter<br/>Structs
    participant AttentionKernel as Attention Kernels<br/>(CUDA)
    participant KVCache as KV Cache
    App->>AttentionMod: forward(input, attn_metadata)
    AttentionMod->>AttentionMod: detect Helix CP mode<br/>(mapping.has_cp_helix())
    AttentionMod->>AttentionMod: compute helix_position_offsets<br/>helix_is_inactive_rank
    AttentionMod->>TorchBackend: call attention backend<br/>with helix params
    TorchBackend->>Enqueue: create EnqueueGenerationParams<br/>with helix_position_offsets,<br/>helix_is_inactive_rank
    Enqueue->>Dispatcher: dispatch with helix params
    Dispatcher->>KernelParams: populate XQAParams<br/>helix_position_offsets,<br/>helix_is_inactive_rank,<br/>softmax_stats
    Dispatcher->>AttentionKernel: launch kernel<br/>with configured params
    AttentionKernel->>AttentionKernel: select rope_position from<br/>helix_position_offsets or tlength
    AttentionKernel->>AttentionKernel: check helix_is_inactive_rank<br/>for KV store gating
    AttentionKernel->>KVCache: conditionally update KV<br/>based on inactive flag
    AttentionKernel->>AttentionKernel: compute attention with<br/>adjusted rope positions
    AttentionKernel-->>TorchBackend: partial outputs + softmax_stats
    TorchBackend->>AttentionMod: return outputs
    AttentionMod->>AttentionMod: if Helix CP:<br/>_helix_post_process<br/>(partial_o, softmax_stats)
    AttentionMod->>AttentionMod: alltoall redistribution<br/>across CP ranks
    AttentionMod-->>App: final attention output
```
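The diagram's final step merges per-rank partial outputs using their softmax stats. As a hedged illustration only (not the actual `helix_post_process` kernel, whose exact semantics live in the CUDA code), the standard log-sum-exp merge of partial softmax-attention outputs can be sketched in pure Python for a single head/token value:

```python
import math

def merge_partial_attention(partials):
    """Merge per-rank partial attention outputs using their softmax stats.

    Each entry is (partial_out, lse), where partial_out is that rank's
    locally normalized attention output and lse is the log of the
    rank-local softmax denominator.
    """
    # Subtracting the global max keeps the exponentials numerically stable.
    m = max(lse for _, lse in partials)
    num = sum(math.exp(lse - m) * out for out, lse in partials)
    den = sum(math.exp(lse - m) for _, lse in partials)
    return num / den
```

Splitting the key/value tokens across two "ranks" and merging reproduces the full softmax-weighted result, which is why only partial outputs plus one scalar stat per rank need to cross the all-to-all.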
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 2 passed
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (10)
cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h (1)
Lines 1-15: ⚠️ Potential issue | 🟡 Minor
Update the copyright year.
This file was modified for Helix support, but the header still lists 2020–2023. Please update it to include 2026. As per coding guidelines: “All source files must contain an NVIDIA copyright header with the year of latest meaningful modification. Include NVIDIA copyright header on ALL new files and update year on modified files.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h` around lines 1-15, update the NVIDIA copyright header at the top of the file by changing the year range "2020-2023" to include the latest modification year (e.g., "2020-2026") so the header reflects the most recent meaningful change; locate the header comment block at the file top and replace the year range accordingly.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h (1)
Lines 1-15: ⚠️ Potential issue | 🟡 Minor
Update the copyright year.
Please update the header to include 2026 to reflect this modification. As per coding guidelines: “All source files must contain an NVIDIA copyright header with the year of latest meaningful modification. Include NVIDIA copyright header on ALL new files and update year on modified files.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h` around lines 1-15, update the top-of-file NVIDIA copyright header so the year range includes 2026 (e.g., change "2020-2025" to "2020-2026") in the header block at the top of xqaParams.h; edit the existing comment block rather than adding a new header and ensure the License text remains unchanged.

cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h (1)
Lines 1-16: ⚠️ Potential issue | 🟡 Minor
Update the copyright year.
The header still lists 2019–2024, but the file was modified in 2026. Please update the year range accordingly. As per coding guidelines: “All source files must contain an NVIDIA copyright header with the year of latest meaningful modification. Include NVIDIA copyright header on ALL new files and update year on modified files.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h` around lines 1-16, update the copyright header's year range by replacing the existing "Copyright (c) 2019-2024, NVIDIA CORPORATION." entry in the top-of-file comment block so it reflects the latest modification year (e.g., change "2019-2024" to "2019-2026"); ensure the rest of the header text (including the NAVER/CLOVA line and Apache License block) remains unchanged.

tensorrt_llm/_torch/attention_backend/trtllm.py (2)
Lines 1-5: ⚠️ Potential issue | 🟡 Minor
Add the NVIDIA copyright header (2026).
This modified Python file currently has no NVIDIA Apache 2.0 header. Please add the standard header with the latest modification year. As per coding guidelines: “All source files must contain an NVIDIA copyright header with the year of latest meaningful modification. Include NVIDIA copyright header on ALL new files and update year on modified files.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/trtllm.py` around lines 1 - 5, Add the standard NVIDIA Apache-2.0 copyright header (with year 2026) to the top of the file tensorrt_llm/_torch/attention_backend/trtllm.py before any imports; ensure the header matches the project's canonical NVIDIA header text and license block and update the file's modification year to 2026, leaving the rest of the module (imports like math, os, weakref and dataclass/type hints) unchanged.
Lines 513-629: ⚠️ Potential issue | 🟠 Major
Avoid hard-failing on non-KV-cache unsupported reasons.
The new assert crashes whenever TRTLLM-GEN is enabled but unsupported for reasons other than "KV cache update", including missing flashinfer, unsupported head configs (MLA, cross-attention, spec-decoding), ALiBi, padded input, or custom mask types. This is a behavior regression; previous behavior fell back to `thop.attention` for all unsupported reasons. Prefer a warning + fallback to preserve backward compatibility and prevent crashes for users with the env var enabled.

🐛 Proposed fix (warn + fallback)
```diff
-        # KV cache update is expected to fall back to thop since
-        # trtllm-gen only reads from KV cache. Assert on other reasons.
-        assert not _TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION or (
-            "KV cache update" in _trtllm_gen_reason
-        ), (
-            f"TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION is set but trtllm-gen "
-            f"is not supported: {_trtllm_gen_reason}"
-        )
+        # KV cache update is expected to fall back to thop since
+        # trtllm-gen only reads from KV cache. Warn (don't hard-fail) for other reasons.
+        if _TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION and _trtllm_gen_reason and (
+                "KV cache update" not in _trtllm_gen_reason):
+            logger.warning(
+                "TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION is set but trtllm-gen "
+                f"is not supported: {_trtllm_gen_reason}. Falling back to thop."
+            )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/trtllm.py` around lines 513-629, the assert that checks _TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION and _trtllm_gen_reason should be replaced with a non-fatal warning and fall back to the existing thop attention path when trtllm_gen.is_supported returns False for reasons other than KV cache update; locate the assert block after the trtllm_gen.is_supported call and the trtllm_gen_attention invocation and change it to log a warning including _trtllm_gen_reason (use your logger or warnings.warn) and let execution continue to the else path (thop.attention fallback) instead of raising, keeping the special case for KV cache update handling if needed.

cpp/tensorrt_llm/kernels/xqaDispatcher.cpp (1)
Lines 1-15: ⚠️ Potential issue | 🟡 Minor
Update the copyright year.
The header still lists 2020–2024; please include 2026 for this modification. As per coding guidelines: “All source files must contain an NVIDIA copyright header with the year of latest meaningful modification. Include NVIDIA copyright header on ALL new files and update year on modified files.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/xqaDispatcher.cpp` around lines 1-15, update the copyright header block at the top of cpp/tensorrt_llm/kernels/xqaDispatcher.cpp by changing the year range in the comment that currently reads "2020-2024" to include 2026 (e.g., "2020-2026") so the NVIDIA copyright header reflects the latest modification.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h (1)
Lines 1-15: ⚠️ Potential issue | 🟡 Minor
Update the copyright year.
The header still lists 2020–2023; please update to include 2026 after these changes. As per coding guidelines: “All source files must contain an NVIDIA copyright header with the year of latest meaningful modification. Include NVIDIA copyright header on ALL new files and update year on modified files.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h` around lines 1-15, update the top-of-file NVIDIA copyright header block (the multi-line comment starting with "Copyright (c) 2020-2023, NVIDIA CORPORATION.") to reflect the latest modification year by replacing "2020-2023" with "2020-2026" so the header shows the current year of change.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h (1)
Lines 1562-1564: ⚠️ Potential issue | 🟠 Major
`update_rotary_base_n_scale` should use `rope_position` instead of `tlength` when Helix is active with dynamic scaling.
For `RotaryScalingType::kDYNAMIC`, the dynamic base extension depends directly on the sequence-length parameter, specifically in the formula `base * powf((scale*seq_len/max_positions)-(scale-1), d/(d-2))`. When Helix is enabled and `rope_position != tlength`, passing `tlength` (the KV cache length) to this function causes the base frequency extension to be computed for the wrong context window, while `apply_rotary_embedding` later applies rotation at `rope_position`. This mismatch yields inconsistent rotation frequencies. Line 1726 already uses `rope_position` for the m_scale decision; line 1564 should be updated similarly.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h` around lines 1562-1564, the call to mmha::update_rotary_base_n_scale is using tlength (KV cache length) which causes incorrect base extension for RotaryScalingType::kDYNAMIC when Helix dynamic scaling is active; change the third argument for sequence length from tlength to rope_position so the dynamic base is computed using the actual rotation position used later by apply_rotary_embedding; update the invocation that passes rotary_embedding_base, rotary_embedding_scale, params.rotary_embedding_scale_type, params.rotary_embedding_dim, params.rotary_embedding_max_positions, tlength to instead pass rope_position (keeping the same other symbols) so the computed frequencies match apply_rotary_embedding's behavior.

cpp/tensorrt_llm/common/attentionOp.cpp (1)
Lines 1-2: ⚠️ Potential issue | 🟡 Minor
Update copyright year to 2026.
The copyright header still shows `1993-2025`, but this file has meaningful modifications in 2026. As per coding guidelines, the year should reflect the latest meaningful modification.

```diff
- * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION &
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/common/attentionOp.cpp` around lines 1-2, update the SPDX copyright header at the top of the file by changing the year range from "1993-2025" to "1993-2026" so it reflects the current meaningful modification; locate the SPDX header line beginning with "SPDX-FileCopyrightText" (the comment block in attentionOp.cpp) and replace the trailing year range accordingly, preserving the rest of the header text and formatting.

tensorrt_llm/_torch/modules/attention.py (1)
Lines 1633-1637: ⚠️ Potential issue | 🟡 Minor
`forward_context_default` sets `helix_position_offsets` with only the `enable_helix_test` guard, unlike `Attention.forward`.
In `Attention.forward` (line 728), helix position offsets are only set when `self.enable_helix_test and self.mapping.has_cp_helix()`. However, in `MLA.forward_context_default` (line 1633), the guard is only `self.enable_helix_test`; it doesn't check `self.mapping.has_cp_helix()`. If `enable_helix_test=True` is ever set without Helix CP, this will unconditionally write to `attn_metadata.helix_position_offsets`.

Proposed fix

```diff
-        if self.enable_helix_test:
+        if self.enable_helix_test and self.mapping.has_cp_helix():
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/attention.py` around lines 1633 - 1637, MLA.forward_context_default currently sets attn_metadata.helix_position_offsets whenever self.enable_helix_test is true; change it to match Attention.forward by guarding the write with both self.enable_helix_test and self.mapping.has_cp_helix() so helix_position_offsets is only set when Helix CP is present. Locate MLA.forward_context_default and replace the single-condition block that assigns attn_metadata.helix_position_offsets = position_ids with a compound condition checking self.enable_helix_test and self.mapping.has_cp_helix() before assigning.
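The compound guard the comment asks for can be sketched generically. `maybe_set_offsets` and its dict-based metadata below are hypothetical stand-ins for the module's attributes, shown only to illustrate the guarded-write shape:

```python
def maybe_set_offsets(metadata: dict, position_ids,
                      enable_helix_test: bool, has_cp_helix: bool) -> None:
    """Record helix position offsets only when both conditions hold,
    mirroring the guard used in Attention.forward."""
    if enable_helix_test and has_cp_helix:
        metadata["helix_position_offsets"] = position_ids
```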
🧹 Nitpick comments (9)
cpp/tensorrt_llm/common/attentionOp.h (1)
Lines 147-206: `enqueueContextParamsToString()` is missing the two new Helix fields.
`helix_position_offsets` and `helix_is_inactive_rank` are not emitted. When debugging Helix-related attention issues, this omission makes the dump incomplete.

♻️ Proposed addition

```diff
     ss << "v_ptr: " << this->v_ptr << std::endl;
+    ss << "helix_position_offsets: " << this->helix_position_offsets << std::endl;
+    ss << "helix_is_inactive_rank: " << this->helix_is_inactive_rank << std::endl;
     return ss.str();
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/common/attentionOp.h` around lines 147-206, enqueueContextParamsToString() currently omits the two new Helix members; update this function to append the helix fields to the string dump by adding lines that output this->helix_position_offsets and this->helix_is_inactive_rank (mirror the style used for other pointer/int members such as "block_offsets" and "cross_kv"); ensure you format them consistently (e.g., "helix_position_offsets: " << this->helix_position_offsets << std::endl and "helix_is_inactive_rank: " << this->helix_is_inactive_rank << std::endl) so Helix-related attention debugging includes these values.

tests/unittest/_torch/modules/test_mha_helix.py (5)
Lines 616-628: Broad exception catch and re-raise without chaining loses context.
Catching bare `Exception` is overly broad, and re-raising without `from` loses the exception chain. Use `raise ... from err` to preserve the original traceback. As per coding guidelines: "When using try-except blocks, limit the except to the smallest set of errors possible. Avoid bare `except:` clauses."

Suggested fix

```diff
 def _run_single_rank(func, *args, **kwargs):
     rank = tensorrt_llm.mpi_rank()
     torch.cuda.set_device(rank)
     print(f"rank {rank} starting")
     try:
         ret = func(rank, *args, **kwargs)
-        print(f"rank {rank} done")
-        return ret
-    except Exception:
+    except Exception as err:
         traceback.print_exc()
         tb = traceback.format_exc()
-        raise Exception(f"\n\nError occurred. Original traceback is\n{tb}\n")
+        raise RuntimeError(
+            f"\n\nError occurred on rank {rank}. Original traceback is\n{tb}\n"
+        ) from err
+    print(f"rank {rank} done")
+    return ret
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/modules/test_mha_helix.py` around lines 616 - 628, The helper _run_single_rank currently catches Exception broadly and re-raises a new Exception without chaining, which loses the original traceback; change the except block to "except Exception as err" and re-raise the new Exception using "raise Exception(... ) from err" (or simply re-raise the original error) so the original exception context from the call to func(rank, ...) is preserved; update references in this function around tensorrt_llm.mpi_rank() and torch.cuda.set_device(rank) accordingly.
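The chaining behavior the review recommends can be demonstrated in isolation. This is a simplified stand-in for the test helper (no MPI or CUDA specifics); `run_single_rank` is a hypothetical name:

```python
import traceback

def run_single_rank(func, rank, *args, **kwargs):
    """Run one rank's work, re-raising failures with the chain preserved."""
    try:
        return func(rank, *args, **kwargs)
    except Exception as err:
        tb = traceback.format_exc()
        # `from err` stores the original exception in __cause__, so the
        # root cause survives even after the message is reformatted.
        raise RuntimeError(f"rank {rank} failed; original traceback:\n{tb}") from err
```

Because the original exception rides along as `__cause__`, a caller can still distinguish, say, a `ValueError` from an OOM even though every rank failure surfaces as `RuntimeError`.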
Lines 311-317: Dead assignment: `start` is assigned twice before the loop.
Line 311 assigns `start = time.time()`, but line 317 immediately overwrites it before any use. Remove the first assignment.

```diff
     outputs = []
-    start = time.time()
 
     # CUDA graph setup for timing
     use_cuda_graph = gen_steps > scenario.ref_steps
     graph = None
     graph_output = None
 
     start = time.time()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/modules/test_mha_helix.py` around lines 311 - 317, Remove the redundant initial timestamp assignment: the variable start is set to time.time() at the top and immediately overwritten later before use; delete the first start = time.time() so only the later assignment remains. Edit the test (around the CUDA graph setup where use_cuda_graph, graph, and graph_output are declared) to keep a single start = time.time() just before timing begins.
Lines 197-205: Unused loop variable `name`.

```diff
-    for name, param in attn.named_parameters():
+    for _name, param in attn.named_parameters():
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/modules/test_mha_helix.py` around lines 197 - 205, The loop in _generate_random_weights uses an unused variable name from attn.named_parameters(); change the loop to avoid the unused binding by either iterating over attn.parameters() or replacing name with an underscore (for _, param in attn.named_parameters()), then keep the existing dtype/initialization logic for param.data so there are no unused variables flagged.
Lines 596-602: Using `cp_allgather` to broadcast is functional but wasteful.
Only rank 0 has a valid `ref_output`; other ranks allocate an empty tensor just to participate in the allgather. A `torch.distributed.broadcast` from rank 0 would be more efficient. This is a test, so performance isn't critical, but worth noting.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/modules/test_mha_helix.py` around lines 596 - 602, The test currently uses cp_allgather(ref_output, mapping=mapping, dim=0) to broadcast the reference from rank 0, which forces all other ranks to allocate empty tensors; replace that allgather with a broadcast from rank 0 (e.g., torch.distributed.broadcast or your test-suite broadcast helper) so only rank 0 provides the real data and other ranks create an appropriately shaped/typed tensor and receive it; update the code around the cp_allgather call (referencing cp_allgather, ref_output, and mapping) to allocate ref_output on non-root ranks with the same shape/dtype and then call broadcast(ref_output, src=0), removing the mapping/allgather usage.
Line 22: Use built-in generic types instead of `typing.List` and `typing.Optional`.
Since TensorRT-LLM requires Python ≥ 3.10, `List` and `Optional` from `typing` are unnecessary. Use `list[float]` and `int | None` directly.

```diff
-from typing import List, Optional
```

Then update usages at lines 639-641:

```diff
-    gen_steps: Optional[int] = None,
+    gen_steps: int | None = None,
     max_mismatch_ratio: float = 0.02,
-    mismatch_ratios: Optional[List[float]] = None,
+    mismatch_ratios: list[float] | None = None,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/modules/test_mha_helix.py` at line 22, remove the typing import and migrate annotations that use List and Optional to Python 3.10+ built-in generics: delete the line importing "List" and "Optional" and replace any occurrences of "List[float]" with "list[float]" and "Optional[int]" (or similar Optional[...] uses) with the union form "int | None" (or the appropriate type | None) throughout the module; specifically update the spots that reference the symbols "List" and "Optional" so all type annotations use built-in generics.

tensorrt_llm/_torch/modules/attention.py (3)
Lines 737-741: Simplify the redundant `hasattr` + `getattr` guard.
`getattr(obj, attr, None) is not None` already handles the case where the attribute doesn't exist; the preceding `hasattr` check is redundant.

Proposed simplification

```diff
-        if hasattr(attn_metadata,
-                   'helix_is_inactive_rank') and getattr(
-                       attn_metadata, 'helix_is_inactive_rank',
-                       None) is not None:
-            attn_metadata.helix_is_inactive_rank.fill_(False)
+        inactive = getattr(attn_metadata, 'helix_is_inactive_rank', None)
+        if inactive is not None:
+            inactive.fill_(False)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/attention.py` around lines 737 - 741, The guard before clearing helix_is_inactive_rank is redundant: replace the combined hasattr + getattr check with a single getattr(attn_metadata, 'helix_is_inactive_rank', None) is not None check and only then call attn_metadata.helix_is_inactive_rank.fill_(False); locate this in the attention module where attn_metadata and its helix_is_inactive_rank attribute are referenced and remove the initial hasattr(...) condition so the code uses the single getattr-based null check.
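The equivalence behind this simplification is easy to check in isolation. A minimal sketch, using a list's `clear()` as a stand-in for the tensor's `fill_(False)`:

```python
def clear_inactive_flags(attn_metadata) -> None:
    """Reset helix_is_inactive_rank if present.

    getattr with a None default covers both a missing attribute and an
    attribute explicitly set to None, so no separate hasattr check is needed.
    """
    flags = getattr(attn_metadata, "helix_is_inactive_rank", None)
    if flags is not None:
        flags.clear()  # stand-in for tensor.fill_(False)
```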
Lines 991-995: Silent fallback of `rms_norm_eps` in Helix test mode may hide configuration bugs.
When `enable_helix_test` is `True`, `rms_norm_eps` silently falls back to `1e-6` if the attribute is missing from `pretrained_config`. If the pretrained model actually uses a different epsilon (e.g., `1e-5`), this silent default could produce subtly wrong test results. Consider logging when the fallback is used.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/attention.py` around lines 991 - 995, The silent fallback of rms_norm_eps when enable_helix_test is True can hide config mismatches; update the block that sets rms_norm_eps (referencing enable_helix_test and config.pretrained_config) to detect whether pretrained_config actually has rms_norm_eps and, if not, emit a warning or info log stating that the default 1e-6 is being used for helix tests (include the model identifier or config reference if available) so callers know a fallback occurred; keep the existing fallback value but ensure the log is clear and only triggered when the attribute is missing.
Lines 459-507: Helix post-processing logic is duplicated between `Attention._helix_post_process` and `MLA._attn_forward_gen`.
The NCCL path (lines 465-475) is character-for-character identical to `MLA._attn_forward_gen` lines 1298-1315, and the FIFO paths differ only in the value dimension (`head_dim` vs `kv_lora_rank`) and the use of `maybe_execute_in_parallel` in MLA. Consider extracting a shared helper (e.g., `_helix_alltoall_and_combine(partial_o, softmax_stats, mapping, num_heads_tp_cp, value_dim, ...)`) to avoid maintaining two copies of the same multi-branch logic.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/attention.py` around lines 459 - 507, The duplicated Helix all-to-all + combine logic in Attention._helix_post_process and MLA._attn_forward_gen should be extracted into a shared helper (suggested name _helix_alltoall_and_combine) that accepts (partial_o, softmax_stats, mapping, num_heads_tp_cp, value_dim, fifo_version_override=None, use_maybe_parallel=False) and encapsulates the NCCL branch (torch.transpose/contiguous, torch.split, alltoall_helix, transpose back, torch.ops.trtllm.helix_post_process) and the FIFO branches (HelixAllToAllNative.get(mapping), view/transpose patterns for fifo_version==1 and else, helix.alltoall_native, appropriate reshapes, and calls to torch.ops.trtllm.helix_post_process_native with the correct final flag); then replace logic in _helix_post_process to call this helper with value_dim=head_dim and in MLA._attn_forward_gen to call it with value_dim=kv_lora_rank and use_maybe_parallel set as before, preserving fifo_version from mapping.cp_config and cp_size/num_tokens behavior.
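The shape of the suggested refactor, stripped of all communication and softmax-stat details, is simply "one parameterized combine, two thin wrappers". This is a toy sketch of that structure only; the function names and the trivial sum stand in for the real all-to-all plus combine logic:

```python
def combine_partials(partials, value_dim, scale=1.0):
    """Hypothetical shared combine; callers differ only in value_dim."""
    merged = [0.0] * value_dim
    for chunk in partials:
        for i in range(value_dim):
            merged[i] += chunk[i]
    return [x * scale for x in merged]

def mha_post_process(partials, head_dim):
    # MHA variant: the value dimension is head_dim.
    return combine_partials(partials, value_dim=head_dim)

def mla_post_process(partials, kv_lora_rank):
    # MLA variant: the value dimension is kv_lora_rank.
    return combine_partials(partials, value_dim=kv_lora_rank)
```

Centralizing the branchy body this way means a future fix to the combine path lands in one place instead of two.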
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h`:
- Around line 1494-1500: The helix inactivity flag list is being appended
per-beam causing a [b*beam_width] shape; in _torch/pyexecutor/model_engine.py
fix the accumulation of helix_is_inactive_rank so it appends once per request
instead of once per beam: move the helix_is_inactive_rank.append(...) out of the
beam loop (or wrap it with a conditional like only append when beam_idx==0) so
that helix_is_inactive_rank has length b (one entry per request) and indexing
used by decoderMaskedMultiheadAttentionTemplate.h (batch_idx_for_helix,
helix_is_inactive_rank) works correctly with beam_width>1.
In
`@cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h`:
- Around line 431-445: The code computes rotary_position and helix_inactive
using helix_position_offsets[global_token_idx] (and other arrays) before
checking valid_token, which can cause out-of-bounds access when remove_padding
is enabled; move the valid_token guard to before the rotary_position and
helix_inactive calculations (i.e., evaluate valid_token first), and only read
params.helix_position_offsets[global_token_idx],
params.spec_decoding_position_offsets[...], and
params.mrope_position_deltas[...] when valid_token is true (otherwise use safe
defaults like 0 for rotary_position and false for helix_inactive); update
references in the block that sets rotary_position and helix_inactive (symbols:
rotary_position, helix_inactive, helix_position_offsets, global_token_idx,
spec_decoding_position_offsets, local_token_idx, batch_idx, past_seq_len,
mrope_position_deltas, helix_is_inactive_rank) accordingly.
In `@cpp/tensorrt_llm/thop/attentionOp.cpp`:
- Around line 474-483: Add brace-delimited blocks for the single-statement if
bodies around helix extraction and remove the duplicated extraction by creating
a small lambda that captures mla_tensor_params and assigns
helix_position_offsets and helix_is_inactive_rank into a given enqueue_params
instance; call this lambda for both EnqueueContextParams<T> and
EnqueueGenerationParams<T> (both inherit helix_* from EnqueueParams<T>) so the
logic for extracting mla_tensor_params and setting
enqueue_params.helix_position_offsets / enqueue_params.helix_is_inactive_rank is
centralized and all if(...) statements use braces.
In `@tensorrt_llm/_torch/modules/attention.py`:
- Around line 543-561: The Helix CP branch in attention.py (the block guarded by
self.mapping.has_cp_helix() and attn_metadata.num_contexts == 0) silently
bypasses all quantization parameters (out_scale, out_scale_sf, kv_scales_sf,
kv_scales_sf_inv); add an explicit guard or comment: either assert that
quantization is incompatible with Helix CP (e.g., raise/assert if any of those
quant parameters are set) or add a clear comment above the block referencing
mapping.has_cp_helix(), attn_metadata and explaining that Helix CP currently
disables quantization and why, and if applicable emit a one-line warning/log
when quant params are present to prevent silent skipping; ensure references to
self.attn.forward(), softmax_stats and self._helix_post_process() remain
unchanged.
---
Outside diff comments:
In `@cpp/tensorrt_llm/common/attentionOp.cpp`:
- Around line 1-2: Update the SPDX copyright header at the top of the file by
changing the year range from "1993-2025" to "1993-2026" so it reflects the
current meaningful modification; locate the SPDX header line beginning with
"SPDX-FileCopyrightText" (the comment block in attentionOp.cpp) and replace the
trailing year range accordingly, preserving the rest of the header text and
formatting.
In `@cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h`:
- Around line 1-15: Update the top-of-file NVIDIA copyright header block (the
multi-line comment starting with "Copyright (c) 2020-2023, NVIDIA CORPORATION.")
to reflect the latest modification year by replacing "2020-2023" with
"2020-2026" so the header shows the current year of change.
In
`@cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h`:
- Around line 1562-1564: The call to mmha::update_rotary_base_n_scale is using
tlength (KV cache length) which causes incorrect base extension for
RotaryScalingType::kDYNAMIC when Helix dynamic scaling is active; change the
third argument for sequence length from tlength to rope_position so the dynamic
base is computed using the actual rotation position used later by
apply_rotary_embedding; update the invocation that passes rotary_embedding_base,
rotary_embedding_scale, params.rotary_embedding_scale_type,
params.rotary_embedding_dim, params.rotary_embedding_max_positions, tlength to
instead pass rope_position (keeping the same other symbols) so the computed
frequencies match apply_rotary_embedding’s behavior.
In `@cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h`:
- Around line 1-15: Update the top-of-file NVIDIA copyright header so the year
range includes 2026 (e.g., change "2020-2025" to "2020-2026") in the header
block at the top of xqaParams.h; edit the existing comment block rather than
adding a new header and ensure the License text remains unchanged.
In `@cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h`:
- Around line 1-15: Update the NVIDIA copyright header at the top of the file by
changing the year range "2020-2023" to include the latest modification year
(e.g., "2020-2026") so the header reflects the most recent meaningful change;
locate the header comment block at the file top and replace the year range
accordingly.
In
`@cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h`:
- Around line 1-16: Update the copyright header's year range in the file by
replacing the existing "Copyright (c) 2019-2024, NVIDIA CORPORATION." entry in
the top-of-file comment block so it reflects the latest modification year (e.g.,
change "2019-2024" to "2019-2026"); ensure the rest of the header text
(including the NAVER/CLOVA line and Apache License block) remains unchanged.
In `@cpp/tensorrt_llm/kernels/xqaDispatcher.cpp`:
- Around line 1-15: Update the copyright header block at the top of
cpp/tensorrt_llm/kernels/xqaDispatcher.cpp by changing the year range in the
comment that currently reads "2020-2024" to include 2026 (e.g., "2020-2026") so
the NVIDIA copyright header reflects the latest modification.
In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 1-5: Add the standard NVIDIA Apache-2.0 copyright header (with
year 2026) to the top of the file
tensorrt_llm/_torch/attention_backend/trtllm.py before any imports; ensure the
header matches the project's canonical NVIDIA header text and license block and
update the file's modification year to 2026, leaving the rest of the module
(imports like math, os, weakref and dataclass/type hints) unchanged.
- Around line 513-629: The assert that checks
_TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION and _trtllm_gen_reason should be replaced
with a non‑fatal warning and fall back to the existing thop attention path when
trtllm_gen.is_supported returns False for reasons other than KV cache update;
locate the assert block after the trtllm_gen.is_supported call and the
trtllm_gen_attention invocation and change it to log a warning including
_trtllm_gen_reason (use your logger or warnings.warn) and let execution continue
to the else path (thop.attention fallback) instead of raising, keeping the
special case for KV cache update handling if needed.
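The suggested change, replacing the hard assert with a logged warning and a fall-through to the thop path, could look roughly like this sketch. `dispatch_attention`, the stub backends, and the `(bool, reason)` return shape of `is_supported` are illustrative assumptions, not the real TensorRT-LLM API:

```python
import logging

logger = logging.getLogger("attention_dispatch")

def dispatch_attention(trtllm_gen, thop, params):
    """Sketch of the warn-and-fall-back pattern.

    `trtllm_gen` and `thop` stand in for the real backends; is_supported
    returning a (bool, reason) pair is an assumption for illustration.
    """
    supported, reason = trtllm_gen.is_supported(params)
    if supported:
        return trtllm_gen.attention(params)
    # Previously a hard assert: now warn with the reason and continue
    # to the existing thop attention fallback path.
    logger.warning("trtllm-gen attention not supported (%s); "
                   "falling back to thop.attention", reason)
    return thop.attention(params)
```

The key behavioral difference is that an unsupported configuration degrades to the slower path instead of aborting the run.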
In `@tensorrt_llm/_torch/modules/attention.py`:
- Around line 1633-1637: MLA.forward_context_default currently sets
attn_metadata.helix_position_offsets whenever self.enable_helix_test is true;
change it to match Attention.forward by guarding the write with both
self.enable_helix_test and self.mapping.has_cp_helix() so helix_position_offsets
is only set when Helix CP is present. Locate MLA.forward_context_default and
replace the single-condition block that assigns
attn_metadata.helix_position_offsets = position_ids with a compound condition
checking self.enable_helix_test and self.mapping.has_cp_helix() before
assigning.
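A minimal sketch of the compound guard, using stand-in objects rather than the real MLA and metadata classes (`maybe_set_helix_offsets` is a hypothetical helper, not actual module code):

```python
from types import SimpleNamespace

def maybe_set_helix_offsets(module, attn_metadata, position_ids):
    # Guard with BOTH conditions, mirroring Attention.forward: write the
    # offsets only when helix testing is enabled AND Helix CP is present.
    # `module` is a stand-in for the MLA instance.
    if module.enable_helix_test and module.mapping.has_cp_helix():
        attn_metadata.helix_position_offsets = position_ids

mod = SimpleNamespace(enable_helix_test=True,
                      mapping=SimpleNamespace(has_cp_helix=lambda: False))
meta = SimpleNamespace(helix_position_offsets=None)
maybe_set_helix_offsets(mod, meta, [0, 1, 2])  # no Helix CP: left unset
```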
---
Nitpick comments:
In `@cpp/tensorrt_llm/common/attentionOp.h`:
- Around line 147-206: enqueueContextParamsToString() currently omits the two
new Helix members; update this function to append the helix fields to the string
dump by adding lines that output this->helix_position_offsets and
this->helix_is_inactive_rank (mirror the style used for other pointer/int
members such as "block_offsets" and "cross_kv"); ensure you format them
consistently (e.g., "helix_position_offsets: " << this->helix_position_offsets
<< std::endl and "helix_is_inactive_rank: " << this->helix_is_inactive_rank <<
std::endl) so Helix-related attention debugging includes these values.
In `@tensorrt_llm/_torch/modules/attention.py`:
- Around line 737-741: The guard before clearing helix_is_inactive_rank is
redundant: replace the combined hasattr + getattr check with a single
getattr(attn_metadata, 'helix_is_inactive_rank', None) is not None check and
only then call attn_metadata.helix_is_inactive_rank.fill_(False); locate this in
the attention module where attn_metadata and its helix_is_inactive_rank
attribute are referenced and remove the initial hasattr(...) condition so the
code uses the single getattr-based null check.
- Around line 991-995: The silent fallback of rms_norm_eps when
enable_helix_test is True can hide config mismatches; update the block that sets
rms_norm_eps (referencing enable_helix_test and config.pretrained_config) to
detect whether pretrained_config actually has rms_norm_eps and, if not, emit a
warning or info log stating that the default 1e-6 is being used for helix tests
(include the model identifier or config reference if available) so callers know
a fallback occurred; keep the existing fallback value but ensure the log is
clear and only triggered when the attribute is missing.
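A hedged sketch of the suggested logging fallback; `resolve_rms_norm_eps` is a hypothetical helper name (the real code reads `config.pretrained_config` inline):

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger("helix_test")
_DEFAULT_RMS_NORM_EPS = 1e-6

def resolve_rms_norm_eps(pretrained_config, model_id="<unknown>"):
    # Prefer the config's own epsilon; when it is missing, log the
    # fallback so silent config mismatches become visible.
    eps = getattr(pretrained_config, "rms_norm_eps", None)
    if eps is None:
        logger.warning(
            "pretrained_config for %s has no rms_norm_eps; "
            "defaulting to %g for helix tests",
            model_id, _DEFAULT_RMS_NORM_EPS)
        eps = _DEFAULT_RMS_NORM_EPS
    return eps
```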
- Around line 459-507: The duplicated Helix all-to-all + combine logic in
Attention._helix_post_process and MLA._attn_forward_gen should be extracted into
a shared helper (suggested name _helix_alltoall_and_combine) that accepts
(partial_o, softmax_stats, mapping, num_heads_tp_cp, value_dim,
fifo_version_override=None, use_maybe_parallel=False) and encapsulates the NCCL
branch (torch.transpose/contiguous, torch.split, alltoall_helix, transpose back,
torch.ops.trtllm.helix_post_process) and the FIFO branches
(HelixAllToAllNative.get(mapping), view/transpose patterns for fifo_version==1
and else, helix.alltoall_native, appropriate reshapes, and calls to
torch.ops.trtllm.helix_post_process_native with the correct final flag); then
replace logic in _helix_post_process to call this helper with value_dim=head_dim
and in MLA._attn_forward_gen to call it with value_dim=kv_lora_rank and
use_maybe_parallel set as before, preserving fifo_version from mapping.cp_config
and cp_size/num_tokens behavior.
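The communication plumbing in such a helper is backend-specific, but the numerical core it feeds, merging per-rank partial attention outputs using their softmax stats, can be sketched in pure Python. This is an assumed model of what a combine like `helix_post_process` computes (a log-sum-exp merge), not the kernel's actual layout or signature:

```python
import math

def local_partial(scores, values):
    """One rank's softmax-weighted sum over its local scores/values,
    plus the stats (m = local max, l = sum of exp(score - m))."""
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    out = [0.0] * len(values[0])
    for s, v in zip(scores, values):
        w = math.exp(s - m) / l
        for i, x in enumerate(v):
            out[i] += w * x
    return out, m, l

def combine_partials(partials):
    """Merge per-rank (out, m, l) triples so the result equals a single
    softmax over the concatenated scores."""
    m_glob = max(m for _, m, _ in partials)
    l_glob = sum(l * math.exp(m - m_glob) for _, m, l in partials)
    out = [0.0] * len(partials[0][0])
    for vec, m, l in partials:
        scale = l * math.exp(m - m_glob) / l_glob  # rescale each rank
        for i, v in enumerate(vec):
            out[i] += scale * v
    return out
```

Because each rank only needs its own max and normalizer, the all-to-all only has to move the partial outputs plus two scalars per head, which is why the softmax stats travel with `partial_o`.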
In `@tests/unittest/_torch/modules/test_mha_helix.py`:
- Around line 616-628: The helper _run_single_rank currently catches Exception
broadly and re-raises a new Exception without chaining, which loses the original
traceback; change the except block to "except Exception as err" and re-raise the
new Exception using "raise Exception(... ) from err" (or simply re-raise the
original error) so the original exception context from the call to func(rank,
...) is preserved; update references in this function around
tensorrt_llm.mpi_rank() and torch.cuda.set_device(rank) accordingly.
- Around line 311-317: Remove the redundant initial timestamp assignment: the
variable start is set to time.time() at the top and immediately overwritten
later before use; delete the first start = time.time() so only the later
assignment remains. Edit the test (around the CUDA graph setup where
use_cuda_graph, graph, and graph_output are declared) to keep a single start =
time.time() just before timing begins.
- Around line 197-205: The loop in _generate_random_weights uses an unused
variable name from attn.named_parameters(); change the loop to avoid the unused
binding by either iterating over attn.parameters() or replacing name with an
underscore (for _, param in attn.named_parameters()), then keep the existing
dtype/initialization logic for param.data so there are no unused variables
flagged.
- Around line 596-602: The test currently uses cp_allgather(ref_output,
mapping=mapping, dim=0) to broadcast the reference from rank 0, which forces all
other ranks to allocate empty tensors; replace that allgather with a broadcast
from rank 0 (e.g., torch.distributed.broadcast or your test-suite broadcast
helper) so only rank 0 provides the real data and other ranks create an
appropriately shaped/typed tensor and receive it; update the code around the
cp_allgather call (referencing cp_allgather, ref_output, and mapping) to
allocate ref_output on non-root ranks with the same shape/dtype and then call
broadcast(ref_output, src=0), removing the mapping/allgather usage.
- Line 22: Remove the typing import and migrate annotations that use List and
Optional to Python 3.10+ built-in generics: delete the line importing "List" and
"Optional" and replace any occurrences of "List[float]" with "list[float]" and
"Optional[int]" (or similar Optional[...] uses) with the union form "int | None"
(or the appropriate type | None) throughout the module; specifically update the
spots that reference the symbols "List" and "Optional" so all type annotations
use built-in generics.
Force-pushed e409754 to d046dac
Force-pushed 6c9550f to 681c940
Force-pushed edb7f6c to d90ab3a
Force-pushed d90ab3a to 7165932
Force-pushed 5416db9 to 6afc9e0
Force-pushed 6eb0db6 to 4043768
Force-pushed 6b2eafa to 4ae518f
/bot run --disable-fail-fast
PR_Github #36195 [ run ] triggered by Bot. Commit:
Force-pushed b2a951f to 579333c
/bot run --disable-fail-fast
PR_Github #36878 [ run ] triggered by Bot. Commit:
PR_Github #36878 [ run ] completed with state
mikeiovine left a comment:
Signing off on torch module changes
Force-pushed 579333c to df698fd
/bot run --disable-fail-fast
PR_Github #36965 [ run ] triggered by Bot. Commit:
PR_Github #36965 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #37023 [ run ] triggered by Bot. Commit:
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Force-pushed df698fd to dafaee1
/bot run --disable-fail-fast
PR_Github #37024 [ run ] triggered by Bot. Commit:
PR_Github #37024 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #37090 [ run ] triggered by Bot. Commit:
PR_Github #37090 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #37116 [ run ] triggered by Bot. Commit:
PR_Github #37116 [ run ] completed with state
Description
This MR generalizes Helix CP from MLA-only to standard GQA/MHA.
- Moves Helix CP support into the `Attention` class.
- Moves `enable_helix_test` out of the MLA module as it's only for testing.
- Updates `splitKVCacheKernel` and `cacheFormatter` for CP-aware block distribution.

Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`
Provides a user-friendly way for developers to interact with a Jenkins server.
Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.
Details
`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill

`kill`: Kill all running builds associated with the pull request.

skip

`skip --comment COMMENT`: Skip testing for the latest commit on the pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`: Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.