[None][feat] Minimax RMS norm optimization #12163
Conversation
Force-pushed from 096ec75 to 201717c
Force-pushed from 049daba to d406d02
/bot help
GitHub Bot Help
Provide a user friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.

run
Launch build/test pipelines. All previously running jobs will be killed.

kill
Kill all running builds associated with the pull request.

skip
Skip testing for the latest commit on the pull request.

reuse-pipeline
Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
/bot run --disable-fail-fast

PR_Github #42503 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This pull request introduces a new MiniMax collective all-reduce operation for RMS normalization using Lamport-style cross-rank synchronization. It adds CUDA kernels, PyTorch bindings, a distributed module wrapper, integration into the MiniMaxM2 attention layer, benchmarks, and unit tests to support both single-tensor and dual Q+K tensor paths.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu`:
- Around line 1-7: This new CUDA source (MiniMaxReduceRMSKernel.cu) is missing
the required NVIDIA copyright/SPDX file header; add the standard NVIDIA header
block at the very top of MiniMaxReduceRMSKernel.cu (before any `#include`),
including the current year, "NVIDIA CORPORATION" copyright line and the
SPDX-License-Identifier (e.g., Apache-2.0) per the repo guideline so the file
matches other TensorRT-LLM sources.
In `@cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h`:
- Around line 1-68: This header (containing MiniMaxReduceRMSParams and
minimax_reduce_rms_op in namespace kernels::minimax_ar) is new and must include
the required NVIDIA OSS file header; add the NVIDIA copyright/SPDX header block
at the top of the file (with the correct latest modification year and SPDX
identifier) before the `#pragma once` directive so the file complies with the project coding guidelines.
In `@cpp/tensorrt_llm/thop/allreduceOp.cpp`:
- Around line 1837-1848: The kernel currently assumes rms_gamma is BF16 but
accepts any norm_weight dtype; add a runtime dtype guard that rejects non-BF16
norm_weight before constructing MiniMaxReduceRMSParams to avoid silent mis-typed
gamma (check norm_weight.scalar_type() and return/throw a clear error if not
torch::kBFloat16), and apply the same guard for the analogous assignment blocks
around the other allreduce params region (the block using
allreduce_params.rms_gamma/_k at ~1867-1898) so both entrypoints refuse non-BF16
gamma until the kernel supports other types.
In `@tensorrt_llm/_torch/distributed/__init__.py`:
- Around line 5-8: The export block in
tensorrt_llm/_torch/distributed/__init__.py is not sorted and fails pre-commit;
run isort (or manually sort alphabetically) on the from .ops import (...) line
so the imported names (AllReduce, AllReduceParams, AllReduceStrategy,
HelixAllToAllNative, MiniMaxAllReduceRMS, MoEAllReduce, MoEAllReduceParams,
all_to_all_4d, all_to_all_5d, allgather, alltoall_helix, cp_allgather,
reducescatter, userbuffers_allreduce_finalize) are in the linter-expected order
and update the single import line accordingly.
In `@tests/unittest/_torch/multi_gpu/test_allreduce.py`:
- Around line 900-903: The test test_minimax_allreduce_rms_qk currently forces
mpi_pool_executor=4 but lacks a guard; add a pytest skip condition to the test
so it only runs when at least 4 GPUs are visible (e.g., use
pytest.mark.skipif(torch.cuda.device_count() < 4, reason="requires 4 GPUs")) and
ensure torch is imported at top of the test file; apply this to the parametrized
decorator that sets mpi_pool_executor so fixture setup won't fail on smaller
runners.
- Around line 715-725: The current reference path computes rms_norm only over
the local hidden slice (after reshape to [total_tokens, tp_size, local_hidden])
so it misses the cross-rank reduction; change the reference computation to
perform normalization over the full hidden dimension (tp_size * local_hidden)
before slicing back to the per-rank view: reshape input to [total_tokens, -1]
(or compute squared-sum/mean across the combined hidden dimension using
tensor_parallel_size * local_hidden), run rms_norm (or equivalent manual rms
calculation using rms_weights and eps) on that full-hidden tensor to produce a
global ref_output, then reshape to [total_tokens, tensor_parallel_size, -1],
cast to origin_dtype and finally take the slice ref_output[:,
tensor_parallel_rank, :] so the reference includes the cross-rank reduction;
update uses of rms_norm, ref_output, input and rms_weights accordingly.
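The corrected reference path described in the last comment can be sketched in plain Python. This is a hedged illustration only: `full_hidden_reference` and its list-based "tensors" are stand-ins for the real torch code in the test, not the actual test symbols.

```python
import math

def full_hidden_reference(input_rows, rms_weights, eps, tp_size, tp_rank):
    """Sketch of the suggested reference: normalize over the FULL hidden
    dimension (tp_size * local_hidden) first, then slice back to the
    per-rank view. All names here are illustrative stand-ins."""
    out = []
    for row in input_rows:  # one full-hidden vector per token
        # Squared mean over the WHOLE hidden dim, not just a local slice.
        inv_rms = 1.0 / math.sqrt(sum(x * x for x in row) / len(row) + eps)
        normed = [x * inv_rms * w for x, w in zip(row, rms_weights)]
        local_hidden = len(row) // tp_size
        # Equivalent of reshape([tokens, tp_size, -1])[:, tp_rank, :].
        out.append(normed[tp_rank * local_hidden:(tp_rank + 1) * local_hidden])
    return out
```

Computing the squared sum only over the local slice, as the current reference does, produces a different `inv_rms` per rank and so misses exactly the cross-rank reduction the kernel performs.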
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a69d8bb5-91e6-4f32-815b-f00f4d262d37
📒 Files selected for processing (8)
- cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu
- cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h
- cpp/tensorrt_llm/thop/allreduceOp.cpp
- tensorrt_llm/_torch/distributed/__init__.py
- tensorrt_llm/_torch/distributed/ops.py
- tensorrt_llm/_torch/models/modeling_minimaxm2.py
- tests/microbenchmarks/minimax_all_reduce.py
- tests/unittest/_torch/multi_gpu/test_allreduce.py
PR_Github #42503 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #42663 [ run ] triggered by Bot. Commit:

PR_Github #42663 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #42930 [ run ] triggered by Bot. Commit:

PR_Github #42930 [ run ] completed with state
Force-pushed from 328d56c to 8cf5272
Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com>
Force-pushed from d1a546f to d912d1c
/bot run --disable-fail-fast

PR_Github #43710 [ run ] triggered by Bot. Commit:

/bot kill

/bot run --disable-fail-fast

PR_Github #43784 [ kill ] triggered by Bot. Commit:

PR_Github #43785 [ run ] triggered by Bot. Commit:

PR_Github #43784 [ kill ] completed with state

PR_Github #43710 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #43793 [ run ] triggered by Bot. Commit:

PR_Github #43793 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #43941 [ run ] triggered by Bot. Commit:

PR_Github #43941 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #44086 [ run ] triggered by Bot. Commit:

PR_Github #44086 [ run ] completed with state
This PR optimizes MiniMax M2 Q/K RMSNorm in tensor-parallel attention.

Previously, after qkv_proj, each rank only owned a local shard [N, D / tp]. To perform RMSNorm over the full Q/K hidden dimension, the implementation first all-gathered local shards into a full [N, D] tensor, applied RMSNorm, and then sliced the result back to each rank. This introduced unnecessary communication and temporary full-tensor materialization.

This PR adds a dedicated MiniMax allreduce RMS kernel that keeps computation on local shards. Each rank computes the local variance sum for its [N, D / tp] shard, reduces the variance across TP ranks, and then applies RMSNorm locally using the rank-local gamma shard. This reduces synchronization volume from full Q/K activations to per-token variance sums and removes the allgather -> full RMSNorm -> reshard path.
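The per-token math above can be illustrated with a small pure-Python sketch. Function and variable names are illustrative only; the real kernel fuses these steps on-GPU and uses Lamport-style synchronization for the cross-rank scalar exchange.

```python
import math

def sharded_rms_norm(shards, gamma_shards, eps=1e-6):
    """Per-token sketch: shards[r] is rank r's [D / tp] slice of one token's
    hidden vector, gamma_shards[r] the matching rank-local gamma slice."""
    hidden_size = sum(len(s) for s in shards)
    # Step 1: each rank computes the sum of squares over its local shard.
    local_sq_sums = [sum(x * x for x in shard) for shard in shards]
    # Step 2: only this per-token scalar crosses ranks (the reduction),
    # instead of the full [N, D] activation an allgather would move.
    global_sq_sum = sum(local_sq_sums)
    inv_rms = 1.0 / math.sqrt(global_sq_sum / hidden_size + eps)
    # Step 3: each rank normalizes its shard with its rank-local gamma.
    return [[x * inv_rms * g for x, g in zip(shard, gamma)]
            for shard, gamma in zip(shards, gamma_shards)]
```

Because the reduced quantity is mathematically the same squared sum the full-hidden RMSNorm would compute, each rank's output matches the corresponding slice of the old allgather path.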
Main changes:
Here are the benchmark results on B200 * 4, isl/osl 2k/256, concurrency 10:
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.