Added support for quantizing TEGroupedMLP for megatron-lm #403
Conversation
Walkthrough

The changes add expert-model-parallel (EP) and MOE support to quantization: extend amax synchronization across DP+EP+TP and MOE local experts, integrate Transformer Engine GroupedLinear quantization, add Megatron MOE/TE-aware quantizers and registrations, extend distributed ParallelState with an expert group, and expand tests/fixtures for MOE multi-GPU validation.
Sequence Diagram(s)

sequenceDiagram
participant Calib as Calibration Flow
participant DP as DataParallel
participant EP as ExpertParallel
participant TP as TensorParallel
participant MOE as MOE Local Experts
Calib->>DP: sync_quantizer_amax_across_dp_ep()
DP->>DP: collect & reduce amax across DP ranks
DP->>EP: synchronize amax across expert_model_parallel_group
Calib->>TP: sync_quantizer_amax_across_tp() (tensor parallel)
Calib->>MOE: for each module with sync_moe_local_experts_amax -> call it
MOE->>MOE: sync local-expert amax across EP/ETP ranks
Note over Calib: Amax synchronized across DP, EP, TP, and MOE local experts
sequenceDiagram
participant Model as Quantized MOE Model
participant TE as TransformerEngine GroupedLinear
participant Seq as Sequential MLP
participant Qt as Quantizers
Model->>TE: forward(GroupedLinear)
TE->>Qt: quantize input
TE->>Qt: quantize weight (uses `weight0` alias during setup)
Qt-->>TE: return quant params
TE->>Model: execute grouped linear op
Model->>Seq: forward(Sequential MLP)
Seq->>Qt: input & weight quantize (standard path)
Seq->>Model: sync amax across expert-parallel groups when calibrating
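As a rough illustration of the synchronization step in the first diagram, the sketch below MAX-reduces each quantizer's amax over a single process group. The helper name is made up; the real calibration path in this PR walks the DP, EP, and TP groups plus MOE local experts.

```python
import torch
import torch.distributed as dist


def sync_amax_across_group(quantizers, group=None):
    """Element-wise MAX-reduce every quantizer's amax across one process group (illustrative)."""
    if not (dist.is_available() and dist.is_initialized()):
        return  # single-process calibration: nothing to synchronize
    for quantizer in quantizers:
        amax = getattr(quantizer, "amax", None)
        if amax is None:
            continue  # quantizer was disabled or never calibrated
        synced = amax.detach().clone()
        dist.all_reduce(synced, op=dist.ReduceOp.MAX, group=group)
        quantizer.amax = synced  # every rank now holds the same calibration range
```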
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed)
Codecov Report: ❌ Patch coverage is

Additional details and impacted files

@@           Coverage Diff           @@
##             main     #403   +/-  ##
=======================================
  Coverage   73.37%   73.38%
=======================================
  Files         180      180
  Lines       17925    17942    +17
=======================================
+ Hits        13152    13166    +14
- Misses       4773     4776     +3

☔ View full report in Codecov by Sentry.
22bfe0e to 1c821d8 (Compare)
e2858f9 to 4d7dbce (Compare)
def _test_expert_model_parallel_amax_sync(
nice!
We should not need `register_custom_post_calibration_plugins`. Let's not introduce new infrastructure unnecessarily.
I see the point of post_calibration plugins now. Let's keep them as we discussed.
this change looks good!
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Actionable comments posted: 1
♻️ Duplicate comments (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
560-608: Fix tensor handling in expert amax sync checker.

The checker still calls `.item()` on `module.amax` and feeds tensor objects to Python's `max`/`min`. Per-channel quantizers surface multi-element tensors, so `.item()` raises, and even if it didn't, `max(values)` on tensors triggers runtime errors. This is exactly what earlier feedback flagged, and the current code will break in the expert-parallel configurations this PR adds. Please keep amax tensors intact (clone/detach, move to CPU), compare them with torch ops, and derive a scalar diff for the return path.

- amax_val = module.amax.item() if hasattr(module.amax, "item") else module.amax
- expert_amax_values[name] = amax_val
+ expert_amax_values[name] = module.amax.detach().clone().cpu()
@@
- if (
-     quantizer_type in expert_quantizers
-     and rank_idx in expert_quantizers[quantizer_type]
- ):
-     # compare expert value across expert for sequential MoE
-     assert expert_quantizers[quantizer_type][rank_idx] == amax_val, (
-         f"{rank_idx}, {quantizer_type}, expert_quantizers[quantizer_type][rank_idx]: "
-         f"{expert_quantizers[quantizer_type][rank_idx]}, amax_val: {amax_val}"
-     )
- expert_quantizers[quantizer_type][rank_idx] = amax_val
+ existing = expert_quantizers[quantizer_type].get(rank_idx)
+ if existing is not None and not torch.allclose(existing, amax_val, rtol=1e-6, atol=1e-6):
+     return False, quantizer_type, {rank_idx: {"expected": existing, "actual": amax_val}}
+ expert_quantizers[quantizer_type][rank_idx] = amax_val
@@
- values = list(rank_values.values())
- max_diff = max(values) - min(values)
- if max_diff > 1e-6:  # Allow for small floating point differences
+ values = torch.stack([v.flatten() for v in rank_values.values()])
+ diff = (values.max(dim=0).values - values.min(dim=0).values).max().item()
+ if diff > 1e-6:  # Allow for small floating point differences
      return False, quantizer_type, rank_values

tests/gpu/torch/quantization/plugins/test_megatron.py (2)
541-564: Cap spawned world size to the TP×EP topology.

This test always spawns `torch.cuda.device_count()` ranks even though the geometry is hard-coded to `tp_size=2` and `ep_size=2`. On a 6- or 8-GPU host, Megatron init asserts because `data_parallel_size = world_size / tp_size` is no longer divisible by `ep_size`. Please skip when the host has fewer than four GPUs and otherwise launch exactly `tp_size * ep_size` ranks so the process groups line up.

- size = torch.cuda.device_count()
+ required_world_size = moe_config["tp_size"] * moe_config["ep_size"]
+ available = torch.cuda.device_count()
+ if available < required_world_size:
+     pytest.skip(f"Need {required_world_size} GPUs, found {available}")
@@
- spawn_multiprocess_job(
-     size=size,
+ spawn_multiprocess_job(
+     size=required_world_size,
      job=partial(
          _test_sharded_state_dict,
710-729: Align expert-sync spawn size with TP×EP requirements.

`test_expert_parallel_sync` launches `torch.cuda.device_count()` ranks while the worker fixes `tp_size=2`. Whenever the host has more GPUs than the minimal topology (e.g., 6 GPUs with `ep_size=2`), `initialize_model_parallel` trips on `data_parallel_size % ep_size != 0`. Compute the largest usable multiple of `tp_size * ep_size` (per parameter set), skip if none fits, and pass that to `spawn_multiprocess_job`.

- size = torch.cuda.device_count()
- if size < ep_size * etp_size:
-     pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
-
- spawn_multiprocess_job(
-     size=size,
+ available = torch.cuda.device_count()
+ required = 2 * ep_size  # tp_size is fixed to 2 in the worker
+ if available < required:
+     pytest.skip(f"Requires at least {required} GPUs for expert model parallel test")
+ usable = (available // required) * required
+ if usable == 0:
+     pytest.skip(f"No usable world size for tp={2}, ep={ep_size} on {available} GPUs")
+
+ spawn_multiprocess_job(
+     size=usable,
      job=partial(
          _test_expert_model_parallel_amax_sync,
          2,
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/torch/quantization/plugins/megatron.py (4 hunks)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (10 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
- TensorQuantizer (65-1111), amax (236-241), amax (244-255)

modelopt/torch/quantization/plugins/megatron.py (4)
modelopt/torch/quantization/plugins/custom.py (3)
- _ParallelLinear (83-181), _setup (114-122), modelopt_post_restore (124-181)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
- TensorQuantizer (65-1111), amax (236-241), amax (244-255)
modelopt/torch/utils/distributed.py (1)
- ParallelState (232-257)
modelopt/torch/opt/plugins/megatron.py (1)
- _MegatronMLP (120-142)

tests/gpu/torch/quantization/plugins/test_megatron.py (5)
tests/_test_utils/torch_dist/plugins/megatron_common.py (4)
- compare_amax_sync_across_expert_parallel (543-609), copy_weights_from_grouped_to_non_grouped (511-540), get_mcore_gpt_model (147-242), initialize_for_megatron (425-444)
tests/gpu/torch/conftest.py (1)
- need_4_gpus (44-46)
modelopt/torch/utils/plugins/megatron_generate.py (1)
- megatron_prefill (41-130)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
- TensorQuantizer (65-1111), amax (236-241), amax (244-255)
modelopt/torch/quantization/model_calib.py (1)
- max_calibrate (62-177)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
565-566: `.item()` and scalar aggregation will fail for per-channel quantizers; keep tensors throughout.

Line 565 calls `.item()` on potentially multi-element amax tensors (e.g., per-channel quantizers), which raises at runtime. Even if it passed, lines 605-606 would fail because Python's `max()`/`min()` can't operate on a list of tensors. The past review claimed this was resolved in 2df77b1, but the code still has the same issue. Store amax as tensors (detached, on CPU), then compute differences via tensor ops.

Apply this diff:

  # Check for both TEGrouped and sequential MoE patterns
  if "local_experts" in name or ("experts" in name and "linear_fc" in name):
-     amax_val = module.amax.item() if hasattr(module.amax, "item") else module.amax
-     expert_amax_values[name] = amax_val
+     expert_amax_values[name] = module.amax.detach().clone()

Then fix the aggregation at lines 602-608:

  # Check synchronization - fail fast on first inconsistency
  for quantizer_type, rank_values in expert_quantizers.items():
      if len(rank_values) > 1:  # Only check if we have multiple ranks
-         values = list(rank_values.values())
-         max_diff = max(values) - min(values)
-         if max_diff > 1e-6:  # Allow for small floating point differences
+         values = [v.detach().cpu().flatten() for v in rank_values.values()]
+         stacked = torch.stack(values)
+         max_diff = (stacked.max(dim=0).values - stacked.min(dim=0).values).max().item()
+         if max_diff > 1e-6:
              return False, quantizer_type, rank_values
🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/megatron.py (1)
263-276: Pass `expert_model_parallel_group` to `ParallelState` for consistency with MOE classes.

Lines 272-275 create a `ParallelState` without passing `expert_model_parallel_group`, but the MOE-specific classes (`_MegatronTEGroupedMLP`, `_MegatronSequentialMLP`) do pass it (lines 614-616, 628-631). For consistency and to ensure expert-parallel quantizer synchronization works in all paths, also retrieve and pass `mcore_parallel.get_expert_model_parallel_group()` here.

Apply this diff:

  self.parallel_state = ParallelState(
      data_parallel_group,
      mcore_parallel.get_tensor_model_parallel_group(),
+     mcore_parallel.get_expert_model_parallel_group(),
  )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/torch/quantization/plugins/megatron.py (4 hunks)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
modelopt/torch/utils/distributed.py (2)
- size (61-65), world_size (204-206)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
- TensorQuantizer (65-1111), amax (236-241), amax (244-255)

modelopt/torch/quantization/plugins/megatron.py (4)
modelopt/torch/quantization/plugins/custom.py (3)
- _ParallelLinear (83-181), _setup (114-122), modelopt_post_restore (124-181)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
- TensorQuantizer (65-1111), amax (236-241), amax (244-255)
modelopt/torch/utils/distributed.py (1)
- ParallelState (232-257)
modelopt/torch/opt/plugins/megatron.py (1)
- _MegatronMLP (120-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
modelopt/torch/quantization/plugins/megatron.py (3)
51-81: LGTM: amax sync correctly uses tensor ops across experts.

The implementation properly uses `torch.maximum()` (line 74) to aggregate amax tensors and assigns the shared value back to all experts (line 81), avoiding the scalar/tensor issues. Well done!

611-621: LGTM: parallel state correctly initialized and propagated for TE grouped MLP.

The implementation properly sets up `expert_model_parallel_group` (line 616) and propagates `parallel_state` to `linear_fc1` and `linear_fc2` submodules (lines 619-620), addressing the past review feedback.

625-637: LGTM: parallel state correctly initialized and propagated for sequential MLP.

The implementation mirrors the TE grouped approach, correctly setting up expert parallel groups (line 631) and propagating `parallel_state` to all `local_experts` submodules (lines 635-637).
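A condensed sketch of the propagation pattern these comments describe, using a stand-in container rather than modelopt's actual ParallelState; the module and attribute names follow the comments above but the code is only illustrative.

```python
from dataclasses import dataclass
from typing import Any

import torch.nn as nn


@dataclass
class ParallelStateStub:
    """Stand-in for a parallel-state container holding DP, TP, and EP groups."""
    data_parallel_group: Any = None
    tensor_parallel_group: Any = None
    expert_model_parallel_group: Any = None


def propagate_parallel_state(moe_mlp: nn.Module, state: ParallelStateStub) -> None:
    moe_mlp.parallel_state = state
    # TEGrouped layout: all experts live inside linear_fc1 / linear_fc2.
    for name in ("linear_fc1", "linear_fc2"):
        child = getattr(moe_mlp, name, None)
        if child is not None:
            child.parallel_state = state
    # Sequential layout: one expert module per entry in local_experts.
    for expert in getattr(moe_mlp, "local_experts", []):
        for submodule in expert.modules():
            submodule.parallel_state = state
```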
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
4919b08 to 23daf38 (Compare)
Actionable comments posted: 1
♻️ Duplicate comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (2)
540-564: Cap spawn size to the required topology dimensions.

Spawning `torch.cuda.device_count()` ranks (line 541) violates the MOE topology constraints when the device count exceeds `tp_size * ep_size = 4`. For instance, a 6-GPU host will attempt world size 6 with TP=2, yielding DP=3, which fails Megatron's requirement that DP is divisible by EP=2. The `need_4_gpus` fixture only ensures a minimum, not an exact match.

Apply this diff to restrict the spawn size:

-def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm):
-    size = torch.cuda.device_count()
+def test_moe_sharded_state_dict(tmp_path, config, moe_grouped_gemm):
+    device_count = torch.cuda.device_count()
+    required_size = 4  # tp_size * ep_size = 2 * 2
+    if device_count < required_size:
+        pytest.skip(f"Requires exactly {required_size} GPUs, found {device_count}")
+    size = required_size
     # TODO: Add support for compress=True for TEGroupedMLP
699-718: Enforce valid world size for MOE topology.

The skip condition `size < ep_size * etp_size` is insufficient. With TP=2 (hardcoded at line 711), EP=2, and device_count=6, the test spawns 6 ranks, yielding DP=3, which violates Megatron's `data_parallel_size % ep_size == 0` constraint. You must ensure the world size is an exact multiple of `tp_size * ep_size`.

Apply this diff to fix the topology check:

 def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm):
     """Test expert model parallel synchronization."""
-    size = torch.cuda.device_count()
-    if size < ep_size * etp_size:
-        pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
+    device_count = torch.cuda.device_count()
+    tp_size = 2  # hardcoded in the partial call below
+    required_size = tp_size * ep_size
+    if device_count < required_size:
+        pytest.skip(f"Requires at least {required_size} GPUs (TP={tp_size}, EP={ep_size})")
+    # Use the largest valid multiple that doesn't exceed device_count
+    size = (device_count // required_size) * required_size
     spawn_multiprocess_job(
         size=size,

tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
563-609: Critical: Multi-element amax tensors still break this checker.

Despite the past review comment indicating this was resolved, the code at line 566 still calls `module.amax.item()`, which raises `RuntimeError` for per-channel quantizers that produce multi-element tensors. Additionally, lines 607-609 use Python's `max()` and `min()` on a list that could contain tensors (when `.item()` is skipped via the `else` branch), causing a `TypeError`. This breaks synchronization checks for exactly the expert-parallel cases this utility is meant to cover.

Apply this diff to handle multi-element tensors correctly:

 if "local_experts" in name or ("experts" in name and "linear_fc" in name):
-    amax_val = module.amax.item() if hasattr(module.amax, "item") else module.amax
-    expert_amax_values[name] = amax_val
+    expert_amax_values[name] = module.amax.detach().clone().cpu()
 ...
 for quantizer_type, rank_values in expert_quantizers.items():
     if len(rank_values) > 1:  # Only check if we have multiple ranks
-        values = list(rank_values.values())
-        max_diff = max(values) - min(values)
-        if max_diff > 1e-6:  # Allow for small floating point differences
+        values = [v.flatten() for v in rank_values.values()]
+        stacked = torch.stack(values)
+        max_diff = (stacked.max(dim=0).values - stacked.min(dim=0).values).max().item()
+        if max_diff > 1e-6:  # Allow for small floating point differences
             return False, quantizer_type, rank_values
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/torch/quantization/plugins/megatron.py (4 hunks)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (10 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
modelopt/torch/utils/distributed.py (2)
- size (61-65), world_size (204-206)
modelopt/torch/trace/symbols.py (1)
- named_modules (444-447)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
- amax (236-241), amax (244-255)
tests/gpu/torch/quantization/plugins/test_megatron.py (4)
tests/_test_utils/torch_dist/plugins/megatron_common.py (4)
- compare_amax_sync_across_expert_parallel (544-611), copy_weights_from_grouped_to_non_grouped (511-541), get_mcore_gpt_model (147-242), initialize_for_megatron (425-444)
modelopt/torch/utils/plugins/megatron_generate.py (1)
- megatron_prefill (41-130)
modelopt/torch/quantization/model_calib.py (1)
- max_calibrate (62-177)
modelopt/torch/quantization/plugins/megatron.py (1)
- sync_amax_across_sequential_mlp (51-81)
modelopt/torch/quantization/plugins/megatron.py (4)
modelopt/torch/quantization/plugins/custom.py (3)
- _ParallelLinear (83-181), _setup (114-122), modelopt_post_restore (124-181)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
- TensorQuantizer (65-1111), amax (236-241), amax (244-255)
modelopt/torch/trace/symbols.py (3)
- named_modules (444-447), register (289-324), items (434-437)
modelopt/torch/utils/distributed.py (1)
- ParallelState (232-257)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
511-542: Weight and bias mapping looks correct.

The template correctly omits the hardcoded `.weight` suffix, and `param_name` is properly extracted and appended to the sequential key. This addresses the past review concern about bias parameters mapping incorrectly. (An illustrative name-mapping sketch follows this comment block.)

modelopt/torch/quantization/plugins/megatron.py (1)
51-84: LGTM! Amax synchronization correctly uses tensor operations.

The function properly:
- Uses `torch.maximum()` for element-wise maximum across experts (lines 72-75)
- Assigns the synchronized amax back to each expert with proper device placement (lines 78-81)
- Handles both collection and distribution phases cleanly

This addresses the past review concern about using Python's `max()` instead of tensor operations.
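For reference, a hypothetical version of the grouped-to-sequential parameter-name mapping that the weight-copy helper above relies on; the exact module paths in the test utility may differ.

```python
import re

# Hypothetical mapping between TEGrouped and sequential MoE parameter names:
#   grouped:    "...mlp.experts.linear_fc1.weight3"  (all experts packed into one GroupedLinear)
#   sequential: "...mlp.experts.local_experts.3.linear_fc1.weight"
_GROUPED_PARAM = re.compile(
    r"^(?P<prefix>.*\.experts)\.(?P<fc>linear_fc[12])\.(?P<param>weight|bias)(?P<idx>\d+)$"
)


def grouped_to_sequential_key(grouped_name: str) -> str:
    match = _GROUPED_PARAM.match(grouped_name)
    if match is None:
        return grouped_name  # non-expert parameters keep their original names
    return f"{match['prefix']}.local_experts.{match['idx']}.{match['fc']}.{match['param']}"


assert (
    grouped_to_sequential_key("decoder.layers.0.mlp.experts.linear_fc2.weight1")
    == "decoder.layers.0.mlp.experts.local_experts.1.linear_fc2.weight"
)
```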
__all__ = []

def sync_amax_across_sequential_mlp(model: torch.nn.Module):
We should do this only for per-tensor amax
Per-channel weight amax: element-wise maximum for RowParallel (fc2; the RowParallel Cout dim is shared across experts).
Per-channel weight amax for ColumnParallel: no-op.
Will do in the follow-up MR.
Edit: per-tensor amax also works now. I have modified the test case to correctly check that.
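A sketch of the policy discussed in this thread, assuming each expert exposes a `weight_quantizer` with an `amax` tensor; this is illustrative, not the merged implementation.

```python
import torch


def sync_expert_weight_amax(expert_quantizers, is_row_parallel: bool) -> None:
    """Share weight amax across a SequentialMLP's local experts.

    Policy from the discussion above:
      * per-tensor amax: element-wise max across experts (a single scalar)
      * per-channel amax on RowParallel fc2: element-wise max (the Cout dim is shared)
      * per-channel amax on ColumnParallel fc1: no-op
    """
    amaxes = [q.amax for q in expert_quantizers if q.amax is not None]
    if not amaxes:
        return
    per_tensor = all(a.numel() == 1 for a in amaxes)
    if not per_tensor and not is_row_parallel:
        return  # ColumnParallel per-channel: leave each expert's amax alone
    synced = amaxes[0].detach().clone()
    for a in amaxes[1:]:
        synced = torch.maximum(synced, a.to(synced.device))
    for q in expert_quantizers:
        if q.amax is not None:
            q.amax = synced.clone().to(q.amax.device)
```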
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
8bff6b0 to 5481d10 (Compare)
Actionable comments posted: 1
♻️ Duplicate comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
642-650: Cap spawn size to 4 ranks for EP×ETP topology.

Line 645 spawns all available GPUs, but the test topology (EP=2, ETP=2) requires exactly 4 ranks. On hosts with more than 4 GPUs, the excess ranks break Megatron's parallelism constraints.

Apply this diff:

 def test_te_grouped_vs_sequential_quantize(need_4_gpus):
     """Test that TEGrouped and sequential MoE models produce similar quantized models."""
     pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
-    size = torch.cuda.device_count()
+    required_size = 4  # ep_size * etp_size = 2 * 2
+    size = required_size
     spawn_multiprocess_job(
714-737: Cap spawn size to match expert-parallel topology.

Line 719 spawns all available GPUs, but the test parametrization requires at most 4 ranks (max of EP×ETP across test cases is 2×2=4). On hosts with 6+ GPUs, the extra ranks violate Megatron's topology constraints.

Apply this diff:

 def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
     """Test expert model parallel synchronization."""
-    size = torch.cuda.device_count()
+    device_count = torch.cuda.device_count()
+    required_size = ep_size * etp_size
-    if size < ep_size * etp_size:
-        pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
+    if device_count < required_size:
+        pytest.skip(f"Requires at least {required_size} GPUs for expert model parallel test")
+    size = required_size
     if moe_grouped_gemm:
542-569: Cap spawn size to match EP×TP topology.

Line 545 spawns `torch.cuda.device_count()` ranks, but the test requires exactly 4 ranks (TP=2, EP=2). On a host with 6+ GPUs, the extra ranks violate Megatron's topology constraints (e.g., data-parallel size must be divisible by EP).

Apply this diff:

 def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm):
     if moe_grouped_gemm:
         pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
-    size = torch.cuda.device_count()
+    required_size = 4  # tp_size * ep_size = 2 * 2
+    size = required_size
     # TODO: Add support for compress=True for TEGroupedMLP
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/torch/quantization/plugins/megatron.py (4 hunks)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (10 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/plugins/megatron.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
modelopt/torch/utils/distributed.py (2)
- size (61-65), world_size (204-206)
modelopt/torch/trace/symbols.py (1)
- named_modules (444-447)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
- TensorQuantizer (65-1111), amax (236-241), amax (244-255)
tests/gpu/torch/quantization/plugins/test_megatron.py (4)
tests/_test_utils/torch_dist/plugins/megatron_common.py (4)
- compare_amax_sync_across_expert_parallel (552-663), copy_weights_from_grouped_to_non_grouped (519-549), get_mcore_gpt_model (149-250), initialize_for_megatron (433-452)
modelopt/torch/utils/plugins/megatron_generate.py (1)
- megatron_prefill (41-130)
modelopt/torch/quantization/model_calib.py (1)
- max_calibrate (62-181)
modelopt/torch/quantization/plugins/megatron.py (1)
- sync_moe_local_experts_amax (508-533)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (10)
tests/_test_utils/torch_dist/plugins/megatron_common.py (5)
16-17: LGTM!

The new imports (`re` and `defaultdict`) are used in the helper functions added later in the file.

43-46: LGTM!

The expert-parallel imports are correctly added and used in `compare_amax_sync_across_expert_parallel`.

152-246: LGTM!

The MOE and expert-parallel parameters are correctly propagated through `get_mcore_gpt_model` and into the model configuration and initialization flow.

438-451: LGTM!

The expert-parallel parameters are correctly added to `initialize_for_megatron` and propagated to `initialize_model_parallel`.

519-550: LGTM!

The weight-copying logic correctly handles both `weight` and `bias` parameters, and the past review concern about the hardcoded `.weight` in the template has been properly addressed.

tests/gpu/torch/quantization/plugins/test_megatron.py (5)

24-25: LGTM!

The new helper imports from `megatron_common` are correctly added and used in the MOE tests below.

47-48: LGTM!

The Megatron MOE imports are correctly added for use in the new MOE-focused tests.

234-282: LGTM!

The MOE-related parameters are correctly added to `_gpt_model_provider` and properly propagated to the model construction calls.

285-357: LGTM!

The `moe_config` parameter is correctly integrated into `_test_sharded_state_dict`, and MOE parameters are properly extracted and propagated to model initialization and provider calls.

572-640: LGTM!

The test helper correctly creates both TEGrouped and Sequential MOE models, copies weights between them, and validates output equivalence before and after quantization.
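A minimal skeleton of the equivalence check described above, assuming a `forward_fn` such as a prefill helper that maps (model, tokens) to logits; this is not the actual test code.

```python
import torch


def assert_grouped_matches_sequential(forward_fn, grouped_model, sequential_model, tokens, atol=1e-5):
    """Once weights are copied from the TEGrouped model into the sequential model,
    both variants should produce near-identical logits (before and after quantization)."""
    with torch.no_grad():
        grouped_out = forward_fn(grouped_model, tokens)
        sequential_out = forward_fn(sequential_model, tokens)
    torch.testing.assert_close(grouped_out, sequential_out, rtol=0.0, atol=atol)
```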
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/torch/quantization/model_calib.py (3 hunks)
- modelopt/torch/quantization/plugins/transformer_engine.py (2 hunks)
- modelopt/torch/quantization/utils.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-18T20:15:04.615Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#332
File: modelopt/torch/quantization/algorithms.py:323-326
Timestamp: 2025-09-18T20:15:04.615Z
Learning: In modelopt/torch/quantization/algorithms.py, the `_is_auto_quantize_module` method requires `isinstance(module, QuantModule)` because some modules like MCore Column/Row Parallel Linear are `QuantModule` but not `QuantLinearConvBase`. The check ensures all quantization-capable modules are included in AutoQuantize search.
Applied to files:
modelopt/torch/quantization/utils.py
🧬 Code graph analysis (2)
modelopt/torch/quantization/model_calib.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
- SequentialQuantizer (1114-1222), sync_amax_across_distributed_group (1071-1083), TensorQuantizer (65-1111)
modelopt/torch/quantization/plugins/megatron.py (1)
- sync_moe_local_experts_amax (508-533)
modelopt/torch/quantization/plugins/transformer_engine.py (1)
modelopt/torch/quantization/plugins/custom.py (1)
- _ParallelLinear (76-174)
if hasattr(module, "sync_moe_local_experts_amax"):
    module.sync_moe_local_experts_amax()
Guard MOE expert sync behind an initialized process group
max_calibrate is invoked in single-process flows. The new call into module.sync_moe_local_experts_amax() executes torch.distributed.barrier() unconditionally, so on a non-initialized default group this now throws RuntimeError: Default process group has not been initialized. Please gate this loop on dist.is_available() / dist.is_initialized() (or make the callee accept a group handle) so single-process calibration keeps working.
- for name, module in model.named_modules():
-     if hasattr(module, "sync_moe_local_experts_amax"):
-         module.sync_moe_local_experts_amax()
+ if dist.is_available() and dist.is_initialized():
+     for name, module in model.named_modules():
+         if hasattr(module, "sync_moe_local_experts_amax"):
+             module.sync_moe_local_experts_amax()

🤖 Prompt for AI Agents
In modelopt/torch/quantization/model_calib.py around lines 180-182, the call to
module.sync_moe_local_experts_amax() triggers a torch.distributed.barrier()
without checking if the default process group is initialized, causing errors in
single-process runs. Fix this by wrapping the call with a guard that checks if
torch.distributed.is_available() and torch.distributed.is_initialized() return
True before invoking the method, ensuring it only runs when the distributed
backend is properly set up.
def _setup(self):
    # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
    # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
    # self.weight0 to self.weight to run the quantizer states initialization.
    assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
    self.weight = self.weight0
    # Memorize the original weight.dtype for modelopt_post_restore given that
    # the dtype can change later.
    super()._setup()
    # Remove self.weight after setup.
    delattr(self, "weight")

def modelopt_post_restore(self, prefix: str = ""):
    # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
    # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
    # self.weight0 to self.weight to run the quantizer states initialization.
    assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
    self.weight = self.weight0
    super().modelopt_post_restore(prefix=prefix)
    # Remove self.weight after post_restore.
    delattr(self, "weight")

@staticmethod
def te_grouped_quantized_linear_fn(package, func_name, self, *args):
    # For "_forward", args[0] is passed through untouched and the input tensor sits at args[1];
    # otherwise the input tensor is args[0].
    idx = 1 if func_name == "_forward" else 0
    inp = args[idx]
    # args[idx + 1] has one entry per GEMM; the trailing 2 * num_gemms args are the
    # per-GEMM weights followed by the per-GEMM biases.
    num_gemms = len(args[idx + 1])
    weights_and_biases = args[-2 * num_gemms :]
    weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
    # Quantize the shared input once and each expert's weight individually.
    quantized_inputs = self.input_quantizer(inp)
    quantized_weights = [self.weight_quantizer(weight) for weight in weights]

    # Call the original TE grouped-linear function with quantized inputs and weights.
    output = getattr(package, func_name)(
        *(
            args[0],
            quantized_inputs,
        )
        if func_name == "_forward"
        else (quantized_inputs,),
        *args[idx + 1 : -2 * num_gemms],
        *quantized_weights,
        *biases,
    )
    return self.output_quantizer(output)
Expose a stable .weight view for grouped TE layers
With is_quantized_linear() now recognizing modules that only provide weight0, helpers such as smoothquant, disable_pre_quant_scale_and_resmooth, etc. immediately access module.weight. Because _QuantTEGroupedLinear deletes that alias after setup, those helpers now hit AttributeError and break the quantization flows for grouped TE models. Please keep a .weight view backed by weight0 (without registering a duplicate parameter) so the existing utilities continue to function.
 def _setup(self):
-    # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
-    # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-    # self.weight0 to self.weight to run the quantizer states initialization.
-    assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-    self.weight = self.weight0
+    # GroupedMLP stores the weights as weight0, weight1, etc. Use weight0 to drive quantizer setup.
+    assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+    self.weight = self.weight0
     # Memorize the original weight.dtype for modelopt_post_restore given that
     # the dtype can change later.
     super()._setup()
-    # Remove self.weight after setup.
-    delattr(self, "weight")
+    # Setter below is a no-op so we do not register a duplicate Parameter named "weight".
@@
 def modelopt_post_restore(self, prefix: str = ""):
-    # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
-    # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-    # self.weight0 to self.weight to run the quantizer states initialization.
-    assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-    self.weight = self.weight0
+    # GroupedMLP stores the weights as weight0, weight1, etc. Reuse weight0 to drive post_restore.
+    assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+    self.weight = self.weight0
     super().modelopt_post_restore(prefix=prefix)
-    # Remove self.weight after post_restore.
-    delattr(self, "weight")
+    # Setter below keeps weight0 as the canonical tensor.
+
+@property
+def weight(self):
+    return self.weight0
+
+@weight.setter
+def weight(self, value):
+    if value is not self.weight0:
+        raise ValueError("TEGroupedLinear expects weight0 to back the canonical weight parameter.")
In modelopt/torch/quantization/plugins/transformer_engine.py around lines 72 to
115, the current implementation temporarily assigns self.weight to self.weight0
during setup and post_restore, then deletes self.weight afterward. This deletion
causes AttributeError in utilities that expect a stable .weight attribute. To
fix this, keep a persistent .weight property backed by self.weight0 without
deleting it so that .weight remains accessible, ensuring compatibility with
helpers relying on this attribute.
What does this PR do?
new feature
- Added support for quantizing TEGroupedMLP for Megatron-LM
- Added support for synchronizing amax across experts for SequentialMLP
- Added support for synchronizing amax across the expert_model_parallel group
Usage
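A minimal sketch of how quantization might be invoked on a Megatron MoE model with modelopt, assuming the standard `mtq.quantize` entry point; the config choice and calibration loop are placeholders, not part of this PR.

```python
import modelopt.torch.quantization as mtq


def quantize_moe_model(model, calib_dataloader, run_forward):
    # Calibration loop: run a few batches so the quantizers record amax values.
    def forward_loop(m):
        for batch in calib_dataloader:
            run_forward(m, batch)

    # FP8 config is only an example; during calibration the plugin synchronizes amax
    # across DP/EP/TP ranks and MOE local experts when the model is expert-parallel.
    return mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```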
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Tests