[None][perf] Fuse add + norm + fp8 quant pattern #12674
Conversation
Force-pushed a47bca0 to 9c5dd30, then 9c5dd30 to 81e4fbe.
📝 Walkthrough

This pull request introduces a fused add + RMSNorm + quantization operation integrated into the TensorRT-LLM compilation pipeline. The changes add a new custom operator, pattern matcher registration, backend integration, utility mappings, and tests to support flashinfer's fused quantized normalization kernel for single-process execution.
Sequence Diagram

```mermaid
sequenceDiagram
    participant torch as torch.compile
    participant backend as Backend
    participant matcher as Pattern Matcher
    participant op as Custom Operator
    participant kernel as Flashinfer Kernel
    torch->>backend: get_custom_pass(world_size=1)
    backend->>matcher: register_add_norm_quant()
    matcher->>matcher: Match: add → rmsnorm → quantize
    backend->>matcher: register_add_norm()
    torch->>matcher: Apply pattern matching
    matcher->>op: Detected fusion pattern
    op->>kernel: Dispatch to fused_add_rmsnorm_quant
    kernel-->>op: fp8_out, updated_residual
    op-->>torch: Return fused result
```
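The fused kernel collapses three steps — residual add, RMSNorm, and static fp8 quantization — into one launch. As a minimal sketch of the unfused math (pure Python on a single row, with an assumed static `scale` and the fp8 e4m3 clamp range; not the real kernel or its API):

```python
import math

FP8_E4M3_MAX = 448.0  # max finite magnitude representable in float8_e4m3fn


def add_rmsnorm_quant_ref(x, residual, weight, scale, eps=1e-6):
    """Unfused reference: residual add -> RMSNorm -> static fp8-style quant.

    Pure-Python stand-in for the fused kernel's math on one row of floats.
    Returns (quantized row, updated residual); the updated residual is the
    raw sum x + residual, as in the fused_add_rmsnorm_quant contract above.
    """
    # 1. Residual add; the sum also becomes the updated residual.
    inter = [a + b for a, b in zip(x, residual)]
    # 2. RMSNorm over the row, scaled by the learned weight.
    rms = math.sqrt(sum(v * v for v in inter) / len(inter) + eps)
    normed = [v / rms * w for v, w in zip(inter, weight)]
    # 3. Static quantization: divide by scale, clamp to the fp8 e4m3 range.
    quant = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in normed]
    return quant, inter
```

A fused kernel computes the same values in one pass over the row, which is why the residual output should match the unfused path closely (see the tolerance discussion in the review comments below).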
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
🧹 Nitpick comments (2)
tests/unittest/_torch/compilation/test_add_norm_quant.py (2)
53: Assertion message could be more descriptive. Consider including the actual match count in the assertion message for easier debugging when the test fails.

💡 Suggested improvement

```diff
- assert backend.match_count[0] == 1, "Pattern Matching Failed"
+ assert backend.match_count[0] == 1, f"Pattern Matching Failed: expected 1 match, got {backend.match_count[0]}"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/unittest/_torch/compilation/test_add_norm_quant.py` at line 53, the assertion uses a generic message; update it so the failure message interpolates backend.match_count[0] (and the expected value 1) into the string, showing both actual and expected counts.
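The suggested pattern generalizes to a small helper; a sketch (the helper name and `expected` parameter are illustrative, not part of the test file):

```python
def check_match_count(match_count, expected=1):
    """Assert on the pattern-match counter with a descriptive message.

    `match_count` mirrors the backend.match_count list used in the test.
    """
    assert match_count[0] == expected, (
        f"Pattern Matching Failed: expected {expected} match(es), "
        f"got {match_count[0]}"
    )
```

On failure the message reports both the expected and actual counts, so a CI log alone is enough to see whether the fusion pass matched too few or too many sites.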
50-67: Consider tightening tolerances for the inter_out comparison.

The inter_out result (the residual after the add operation) should be numerically identical between the fused and unfused paths, since both compute x + residual. The current tolerances (rtol=0.05, atol=0.15) are quite loose for a simple element-wise addition. If the fused kernel produces bit-identical results for the residual update, consider using stricter tolerances (e.g., rtol=1e-5, atol=1e-5 for float16/bfloat16). If there are known numerical differences due to the kernel implementation, documenting this in a comment would be helpful.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/unittest/_torch/compilation/test_add_norm_quant.py` around lines 50-67, tighten the tolerances for the actual_inter vs ref_inter comparison: replace rtol=0.05, atol=0.15 with stricter, dtype-aware tolerances (e.g., 1e-5 for float16/bfloat16), add a brief comment above the assertion if a relaxed tolerance is needed for known kernel-induced differences, and keep the backend.match_count[0] assertion.
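One way to implement the dtype-aware suggestion is a tiny lookup helper whose result is splatted into `torch.testing.assert_close(..., **tolerances_for(dtype))`. A sketch (the helper name and the non-half-precision values are assumptions, not values from the PR):

```python
def tolerances_for(dtype_name):
    """Pick assert_close tolerances by dtype name.

    The element-wise residual add should be (near) bit-exact, so even
    half-precision dtypes get tight bounds; float32 gets tighter still.
    """
    if dtype_name in ("float16", "bfloat16"):
        return {"rtol": 1e-5, "atol": 1e-5}
    return {"rtol": 1e-6, "atol": 1e-6}
```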
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: cbc06d65-16ae-4ef2-8e9b-1ac4cab2c992
📒 Files selected for processing (6)
- requirements.txt
- tensorrt_llm/_torch/compilation/backend.py
- tensorrt_llm/_torch/compilation/patterns/residual_add_norm.py
- tensorrt_llm/_torch/compilation/utils.py
- tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
- tests/unittest/_torch/compilation/test_add_norm_quant.py
Force-pushed 81e4fbe to ddebe9d, then ddebe9d to 37befd4.
/bot run
PR_Github #41577 [ run ] triggered by Bot. Commit:
PR_Github #41577 [ run ] completed with state

/bot run
PR_Github #41717 [ run ] triggered by Bot. Commit:
PR_Github #41717 [ run ] completed with state

/bot run
PR_Github #41738 [ run ] triggered by Bot. Commit:
PR_Github #41738 [ run ] completed with state

Force-pushed c3910b9 to 4554a23.

/bot run
PR_Github #41767 [ run ] triggered by Bot. Commit:
PR_Github #41767 [ run ] completed with state

/bot run
PR_Github #41959 [ run ] triggered by Bot. Commit:
PR_Github #41959 [ run ] completed with state

Force-pushed 4554a23 to db73338, then db73338 to a4194eb.

/bot run --disable-fail-fast
PR_Github #46430 [ run ] triggered by Bot. Commit:
PR_Github #46430 [ run ] completed with state

Signed-off-by: Anurag Mukkara <134339030+amukkara@users.noreply.github.com>

/bot run
PR_Github #46516 [ run ] triggered by Bot. Commit:
PR_Github #46516 [ run ] completed with state

/bot run --disable-fail-fast
PR_Github #46528 [ run ] triggered by Bot. Commit:
PR_Github #46528 [ run ] completed with state

/bot run --disable-fail-fast
PR_Github #46669 [ run ] triggered by Bot. Commit:
PR_Github #46669 [ run ] completed with state
Summary by CodeRabbit
Release Notes
Description
Add an additional PatternMatcherPass() in the compilation backend to fuse the (residual_add, rms_norm, fp8 static quantization) pattern. This pattern is replaced by a fused kernel from FlashInfer.

~8% speedup for the Qwen3-4B-FP8 checkpoint:
- ISL=OSL=1000, concurrency=1, requests=10
- ISL=OSL=1000, concurrency=8, requests=80
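The core idea of such a pass is to scan the compiled graph for the three-op sequence and rewrite it to one fused call. A toy illustration of that idea on a flat op trace (this is not the torch inductor PatternMatcherPass API; the op names and function are hypothetical):

```python
# Fuse the add -> rmsnorm -> fp8-quant sequence into one op, counting
# matches the way the real pass's match counter does.
PATTERN = ("add", "rmsnorm", "quantize_fp8")
FUSED = "fused_add_rmsnorm_quant"


def fuse_add_norm_quant(ops):
    """Return (rewritten op trace, number of pattern matches)."""
    out, i, matches = [], 0, 0
    while i < len(ops):
        if tuple(ops[i:i + len(PATTERN)]) == PATTERN:
            out.append(FUSED)   # replace the whole window with the fused op
            i += len(PATTERN)
            matches += 1
        else:
            out.append(ops[i])  # leave unmatched ops untouched
            i += 1
    return out, matches
```

The real pass matches on FX graph nodes rather than a flat list, but the replacement and counting semantics are analogous, which is what the unit test's match-count assertion checks.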
Test Coverage
New unittest in unittest/_torch/compilation/test_add_norm_quant.py. Existing integration tests will cover e2e FP8 accuracy.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.