[https://nvbugs/6094072][fix] swizzle GPT-OSS dummy MXFP4 weights#13708
Conversation
|
/bot run |
📝 WalkthroughWalkthroughThis PR introduces a reusable static helper method ChangesMXFP4 Weight Swizzling and Memory Management
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Review rate limit: 9/10 reviews remaining, refill in 6 minutes. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py (1)
1124-1126: ⚡ Quick winAdd type annotations to the new helper and override.
These new methods are unannotated, which weakens the mypy coverage this file is expected to support. Adding explicit
torch.nn.Module/str/torch.Tensorparameter types and-> Nonewhere appropriate would keep the new code aligned with the repo typing rules. As per coding guidelines, "Always annotate functions; make the return typeNoneif the function does not return anything" and "code should support mypy type checking".Also applies to: 1346-1360
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py` around lines 1124 - 1126, Annotate the new helper _swizzle_and_replace and the corresponding override method mentioned in the review with explicit types: use module: torch.nn.Module, weight_name: str, scale_name: str, weight_data: torch.Tensor, scale_data: torch.Tensor (or Optional[torch.Tensor] if it can be None) and add a return type of -> None; also ensure the override method's parameters and return type are similarly annotated (use torch.nn.Module/str/torch.Tensor as appropriate) and add any necessary imports (torch and typing.Optional) so mypy sees the types.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py`:
- Around line 1204-1208: After performing the weight replacement/swizzle path
you must refresh module.quant_scales so it no longer references the old tensors
whose storage was resized or set to None; update the code after the
_swizzle_and_replace calls (and after any assignments that set
fc31_input_dequant / fc2_input_dequant to None in the float8 branch) to rebuild
module.quant_scales (the object populated by setup_quant_scales() which
originally captured objects from create_weights()) so entries like
module.quant_scales.fc31_dequant, fc2_dequant, fc31_input_dequant and
fc2_input_dequant point to the new tensors or None as appropriate.
- Around line 1127-1134: _swizzle_and_replace unconditionally calls
old_param.data.storage().resize_(0) which can corrupt new_weight/new_scale if
they alias the original storage; update _swizzle_and_replace to check aliasing
the same way swizzle_weight_and_scale does (compare old_param.data.data_ptr() or
storage pointer to new_weight.data_ptr() and new_scale.data_ptr()) and only
free/resize the old storage when there is no alias, or simply remove the manual
resize; also add explicit type annotations to the _swizzle_and_replace method
signature for its parameters and return type so callers and linters know
expected types (referencing function name _swizzle_and_replace, helper
swizzle_weight_and_scale, variables old_param,
old_param.data.storage().resize_(0), new_weight, new_scale, and data_ptr()).
---
Nitpick comments:
In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py`:
- Around line 1124-1126: Annotate the new helper _swizzle_and_replace and the
corresponding override method mentioned in the review with explicit types: use
module: torch.nn.Module, weight_name: str, scale_name: str, weight_data:
torch.Tensor, scale_data: torch.Tensor (or Optional[torch.Tensor] if it can be
None) and add a return type of -> None; also ensure the override method's
parameters and return type are similarly annotated (use
torch.nn.Module/str/torch.Tensor as appropriate) and add any necessary imports
(torch and typing.Optional) so mypy sees the types.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 4f467fba-aca2-4000-af70-24353769124e
📒 Files selected for processing (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
|
PR_Github #46571 [ run ] triggered by Bot. Commit: |
|
PR_Github #46571 [ run ] completed with state
|
d1caab5 to
1766894
Compare
|
/bot run |
|
PR_Github #46579 [ run ] triggered by Bot. Commit: |
|
PR_Github #46579 [ run ] completed with state
|
1766894 to
322ed86
Compare
|
/bot run |
|
PR_Github #46606 [ run ] triggered by Bot. Commit: |
|
PR_Github #46606 [ run ] completed with state
|
|
/bot run |
|
PR_Github #46612 [ run ] triggered by Bot. Commit: |
|
PR_Github #46612 [ run ] completed with state
|
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
…rmat Fixed by the swizzle GPT-OSS dummy MXFP4 weights commit on this branch. Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
322ed86 to
c51ee54
Compare
|
/bot run |
|
PR_Github #46653 [ run ] triggered by Bot. Commit: |
|
PR_Github #46653 [ run ] completed with state
|
|
/bot run |
|
PR_Github #46694 [ run ] triggered by Bot. Commit: |
|
PR_Github #46694 [ run ] completed with state |
Signed-off-by: dongfengy <99041270+dongfengy@users.noreply.github.com>
|
/bot run |
|
PR_Github #46829 [ run ] triggered by Bot. Commit: |
|
PR_Github #46829 [ run ] completed with state
|
|
/bot skip --comment "Passed 19 hours ago. No change since then except rebase. CI failing with unrelated tests." |
|
PR_Github #46884 [ skip ] triggered by Bot. Commit: |
|
PR_Github #46884 [ skip ] completed with state |
…IDIA#13708) Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com> Signed-off-by: dongfengy <99041270+dongfengy@users.noreply.github.com>
TestGPTOSS::test_dummy_load_format tests dummy weights loading, which means actual weights loading is not called. We need to do some necessary post process to ensure the weights format is correct.
@coderabbitai summary
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.