[None][feat] Avoid duplicated computation with ADP + Helix CP in GQA #11891
brb-nv merged 3 commits into NVIDIA:main from
Conversation
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
📝 Walkthrough

The pull request adds Helix Context Parallelism (CP) support with residual forwarding in Qwen3 models. The changes introduce CP utility functions for data partitioning and cross-rank communication, extend the Attention and MLA forward signatures to accept and propagate residuals, and update the Qwen3 decoder layers to pass residuals through attention operations with CP-aware data handling.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant DecoderLayer as Qwen3DecoderLayer
    participant Attn as Attention/MLA
    participant CPUtil as CP Utilities
    participant CPGroup as CP Rank Group
    DecoderLayer->>Attn: forward(hidden_states, residual, AllReduceParams)
    activate Attn
    Attn->>CPUtil: _helix_cp_allgather_input(hidden_states, attn_metadata, mapping)
    activate CPUtil
    CPUtil->>CPGroup: cp_allgather across ranks
    CPGroup-->>CPUtil: gathered data
    CPUtil-->>Attn: concatenated input
    deactivate CPUtil
    Attn->>Attn: attention computation
    Attn->>CPUtil: _helix_cp_output_projection(o_proj, attn_output, residual, mapping_o)
    activate CPUtil
    CPUtil->>CPUtil: CP-aware projection & slice
    CPUtil->>CPUtil: handle residual slicing if provided
    CPUtil-->>Attn: (projected_output, sliced_residual)
    deactivate CPUtil
    Attn-->>DecoderLayer: (hidden_states, residual) or tensor
    deactivate Attn
    DecoderLayer->>DecoderLayer: propagate residual to next layer
```
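The gather, compute, and slice flow in the diagram can be sketched on plain Python lists. This is an illustrative simulation only: real Helix CP code runs one process per rank and uses collective communication over tensors, and the helper names below mirror the diagram, not the actual TensorRT-LLM API. The doubling step is a placeholder for the attention computation.

```python
# Toy simulation of the Helix CP flow: every rank allgathers the full
# sequence, computes attention over it, then slices the output (and any
# residual) back down to its own contiguous chunk. Names follow the
# diagram above; they are stand-ins, not the real implementation.

def cp_allgather(shards):
    # Each rank contributes its local shard; every rank ends up with the
    # concatenation of all shards (the full hidden_states sequence).
    full = []
    for shard in shards:
        full.extend(shard)
    return full

def cp_slice(full, rank, cp_size):
    # Slice the full sequence back down to this rank's contiguous chunk.
    chunk = len(full) // cp_size
    return full[rank * chunk:(rank + 1) * chunk]

def attention_layer(shards, rank, cp_size):
    # 1. _helix_cp_allgather_input: reassemble the full input on each rank.
    gathered = cp_allgather(shards)
    # 2. Attention computation over the full sequence (placeholder: x * 2).
    attn_out = [x * 2 for x in gathered]
    # 3. _helix_cp_output_projection: project, then slice so each rank
    #    keeps only its own portion of the output.
    return cp_slice(attn_out, rank, cp_size)

shards = [[0, 1], [2, 3]]            # two CP ranks, two tokens each
print(attention_layer(shards, 0, 2))  # [0, 2]
print(attention_layer(shards, 1, 2))  # [4, 6]
```

Note how every rank performs the same full-sequence computation but keeps only its own slice, which is what makes the residual slicing in step 3 of the diagram necessary.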
Estimated Code Review Effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/attention.py (1)
279-280: The Ellipsis (`...`) sentinel pattern for optional residual handling.

The `residual is not ...` check accommodates `MLA.forward`, where `residual` defaults to `...` (Ellipsis) as a sentinel distinguishing "no residual provided" from `residual=None`. This is a valid runtime pattern, though the type annotation `residual: Optional[torch.Tensor]` doesn't capture it. Consider documenting this sentinel pattern in the docstring for clarity, since the interaction between `Attention.forward` (where `residual=None` is the default) and `MLA.forward` (where `residual=...` is the default) relies on this distinction.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tensorrt_llm/_torch/modules/attention.py` around lines 279-280: the code uses the Ellipsis sentinel (`residual is ...`) to distinguish "no residual provided" (the `MLA.forward` default) from an explicit `None` (the `Attention.forward` default); document this pattern in the relevant docstrings so future readers understand why the check `residual is not ...` exists. Update the docstrings on `MLA.forward` and `Attention.forward` to state that the residual parameter can be Ellipsis to mean "not provided" (and explain the semantic difference from `None`), reference the residual parameter and the runtime check (`residual is not ...`), and, if desired, mention the helper `_helix_cp_slice` that runs when a residual is present.
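The sentinel distinction flagged for attention.py above can be sketched in a few lines. This is a simplified stand-in for the `Attention.forward`/`MLA.forward` interaction, not the real signatures; the `upper()` call is just a placeholder for the attention computation.

```python
# Ellipsis-sentinel sketch: `...` means "caller said nothing", while an
# explicit None is a deliberate "no residual". When any residual argument
# is supplied (even None), the function returns a (output, residual) pair;
# otherwise it keeps the legacy single-tensor return.

def forward(hidden_states, residual=...):
    out = hidden_states.upper()  # stand-in for attention + projection
    if residual is not ...:
        # A residual was supplied (possibly None): return the pair.
        return out, residual
    # No residual argument at all: legacy single-value return.
    return out

print(forward("h"))        # 'H'            (residual never passed)
print(forward("h", None))  # ('H', None)    (explicit None)
print(forward("h", "r"))   # ('H', 'r')
```

The key point is that `residual=None` and "residual omitted" produce different return shapes, which is exactly what the `residual is not ...` check distinguishes and why the reviewer asks for it to be documented.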
tensorrt_llm/_torch/models/modeling_qwen3.py (1)

285-300: Consider documenting the `_frozen` attribute access pattern.

The temporary unfreezing of `model_config._frozen` to modify the mapping is necessary for the CP-to-TP repurposing logic. However, accessing this private attribute (`_frozen`) creates a coupling with the internal implementation of `ModelConfig`. If this pattern is used elsewhere or `ModelConfig` changes, this could silently break. Consider either:
- Adding a comment explaining this is intentional and tested, or
- Exposing a public method on `ModelConfig`, like `with_mapping(new_mapping)`, that handles the freeze/unfreeze safely.

The logic itself is correct: attention layers use the original CP mapping via `mapping_with_cp`, while other components (MLP, etc.) see the repurposed TP mapping.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tensorrt_llm/_torch/models/modeling_qwen3.py` around lines 285-300: this code temporarily toggles the private flag `model_config._frozen` to swap mappings (using `model_config.mapping.repurpose_helix_cp_to_tp`) before constructing `Qwen3Model`, then restores it. Document this pattern: add a concise inline comment above the block explaining the intentional unfreeze/freeze of `model_config._frozen`, why `mapping_with_cp` is preserved, and that this behavior is tested/required for CP-to-TP repurposing. Alternatively, add a public helper on `ModelConfig` (e.g., `with_mapping(new_mapping)`) that encapsulates the freeze/unfreeze, and call that here instead to avoid direct `_frozen` access.
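The reviewer's suggested `with_mapping` helper could look roughly like the sketch below. The `ModelConfig` class here is a toy with an illustrative `_frozen` flag, not the actual TensorRT-LLM `ModelConfig`; the point is only to show how a context manager can encapsulate the unfreeze/swap/refreeze dance so callers never touch `_frozen` directly.

```python
# Hypothetical with_mapping helper: temporarily swap the mapping on a
# frozen config, restoring both the mapping and the frozen state on exit.
from contextlib import contextmanager

class ModelConfig:
    def __init__(self, mapping):
        self.mapping = mapping   # allowed: _frozen not set yet
        self._frozen = True

    def __setattr__(self, name, value):
        # Reject mutation while frozen, except toggling _frozen itself.
        if getattr(self, "_frozen", False) and name != "_frozen":
            raise AttributeError(f"config is frozen; cannot set {name!r}")
        super().__setattr__(name, value)

    @contextmanager
    def with_mapping(self, new_mapping):
        old_mapping, old_frozen = self.mapping, self._frozen
        self._frozen = False
        self.mapping = new_mapping
        self._frozen = True
        try:
            yield self
        finally:
            self._frozen = False
            self.mapping = old_mapping
            self._frozen = old_frozen

config = ModelConfig(mapping="helix_cp")
with config.with_mapping("repurposed_tp"):
    print(config.mapping)  # repurposed_tp (what Qwen3Model would see)
print(config.mapping)      # helix_cp (original mapping restored)
```

The `try`/`finally` guarantees the original mapping and frozen state come back even if model construction raises, which is the main robustness gain over toggling `_frozen` inline.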
ℹ️ Review info
Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 559091fc-4df2-4371-818b-a83d68fd0ee6
📒 Files selected for processing (2)
- tensorrt_llm/_torch/models/modeling_qwen3.py
- tensorrt_llm/_torch/modules/attention.py
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
/bot run --disable-fail-fast
PR_Github #37920 [ run ] triggered by Bot. Commit:
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
e93b873 to e1d12bb (Compare)
/bot run --disable-fail-fast
PR_Github #37935 [ run ] triggered by Bot. Commit:
PR_Github #37935 [ run ] completed with state
…VIDIA#11891) Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Description
This PR is the GQA counterpart of #11167.
It commonizes functionality between `MLA()` and `Attention()`.

Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.