
Conversation

@kinjalpatel27 (Contributor) commented Oct 7, 2025

What does this PR do?

New features:
- Added support for quantizing TEGroupedMLP for Megatron-LM (see the sketch after this list)
- Added support for synchronizing amax across experts for SequentialMLP
- Added support for synchronizing amax across expert_model_parallel groups
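For context, a minimal usage sketch (hedged: the Megatron MoE model and calibration dataloader are assumed/hypothetical; only the mtq.quantize call pattern follows Model Optimizer's public API):

import modelopt.torch.quantization as mtq

def quantize_moe_model(model, dataloader):
    """Calibrate and quantize a Megatron-LM MoE model whose experts use
    TEGroupedMLP or SequentialMLP (sketch only, not the test code in this PR)."""

    def forward_loop(m):
        for batch in dataloader:  # a few calibration batches are enough
            m(batch)

    # During calibration, amax is synchronized across data-, tensor-, and
    # expert-model-parallel groups and across an MoE layer's local experts
    # (the behavior added by this PR).
    return mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)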

Usage

pytest gpu/torch/quantization/plugins/test_megatron.py -k test_expert_parallel_sync
pytest gpu/torch/quantization/plugins/test_megatron.py -k test_expert_parallel_sync_with_tp
pytest gpu/torch/quantization/plugins/test_megatron.py -k test_te_grouped_vs_sequential_quantize
pytest gpu/torch/quantization/plugins/test_megatron.py -k test_moe_sharded_state_dict

Testing

  • Added tests for amax sync under EP and ETP
  • Added tests comparing outputs between the SequentialMLP and TEGroupedMLP models, before and after quantization
  • Added tests for sharded state dict store and restore

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: -
  • Did you update Changelog?: Not yet

Additional Information

Summary by CodeRabbit

  • New Features

    • MoE quantization with cross-group amax synchronization across data-, tensor-, and expert-parallel groups
    • Transformer Engine (TE) GroupedLinear quantization and TE-aware quantization paths
    • Recognition of alternative weight layouts so more linear layers are detected as quantized
  • Tests

    • MoE-focused quantization tests, amax synchronization checks, utilities for grouped-vs-sequential comparisons, and a new 4-GPU test fixture

@copy-pr-bot bot commented Oct 7, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai bot commented Oct 7, 2025

Walkthrough

The changes add expert-model-parallel (EP) and MOE support to quantization: extend amax synchronization across DP+EP+TP and MOE local experts, integrate Transformer Engine GroupedLinear quantization, add Megatron MOE/TE-aware quantizers and registrations, extend distributed ParallelState with expert group, and expand tests/fixtures for MOE multi-GPU validation.

Changes

Cohort / File(s) Summary
Distributed infrastructure updates
modelopt/torch/utils/distributed.py
Add expert_model_parallel_group parameter to ParallelState, initialize self.expert_model_parallel_group via DistributedProcessGroup, and include it in __repr__.
Quantization calibration & utils
modelopt/torch/quantization/model_calib.py, modelopt/torch/quantization/utils.py
Rename/extend the DP sync to sync_quantizer_amax_across_dp_ep, which syncs amax across data-parallel and expert-model-parallel groups; add a post-TP pass calling module.sync_moe_local_experts_amax() when present; update is_quantized_linear() to accept weight or weight0.
Megatron plugin: MOE and TE integration
modelopt/torch/quantization/plugins/megatron.py
Make _setup lazy for parallel_state; add MOE-aware _MegatronSequentialMLP; add TE-aware quantized classes and registrations (_QuantMegatronTEGroupedLinear, _MegatronTEGroupedColumnParallelLinear, _MegatronTEGroupedRowParallelLinear, _MegatronTEGroupedMLP); propagate parallel_state to local experts and filter TE grouped-linear extra state on load.
Transformer Engine plugin
modelopt/torch/quantization/plugins/transformer_engine.py
Register TE GroupedLinear handler; add _QuantTEGroupedLinear that aliases weight0 to weight during setup/post-restore and provides a TE-aware quantized forward/apply function handling TE argument/layout differences.
Test utilities (Megatron)
tests/_test_utils/torch_dist/plugins/megatron_common.py
Extend helpers with EP/ETP parameters, num_moe_experts, moe_grouped_gemm, and use_te; add copy_weights_from_grouped_to_non_grouped and compare_amax_sync_across_expert_parallel; propagate expert-parallel sizes into model initialization.
Test fixtures
tests/gpu/torch/conftest.py
Add need_4_gpus pytest fixture (skips when CUDA device count < 4); see the sketch after this table.
Megatron quantization tests
tests/gpu/torch/quantization/plugins/test_megatron.py
Expand _gpt_model_provider and _test_sharded_state_dict signatures for MOE/TE; add MOE-focused tests (MOE sharded state-dict, TE-grouped vs sequential quantize, expert-parallel amax sync); import new utilities and MLP types.
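For reference, the need_4_gpus fixture summarized above likely has roughly this shape (a sketch inferred from the summary, not the verbatim conftest.py source):

import pytest
import torch

@pytest.fixture
def need_4_gpus():
    # Skip tests that need a 4-GPU topology when fewer devices are available.
    if torch.cuda.device_count() < 4:
        pytest.skip("Requires at least 4 GPUs")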

Sequence Diagram(s)

sequenceDiagram
    participant Calib as Calibration Flow
    participant DP as DataParallel
    participant EP as ExpertParallel
    participant TP as TensorParallel
    participant MOE as MOE Local Experts

    Calib->>DP: sync_quantizer_amax_across_dp_ep()
    DP->>DP: collect & reduce amax across DP ranks
    DP->>EP: synchronize amax across expert_model_parallel_group
    Calib->>TP: sync_quantizer_amax_across_tp() (tensor parallel)
    Calib->>MOE: for each module with sync_moe_local_experts_amax -> call it
    MOE->>MOE: sync local-expert amax across EP/ETP ranks
    Note over Calib: Amax synchronized across DP, EP, TP, and MOE local experts
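Read as code, the calibration flow above is a MAX-reduction of each quantizer's amax over the relevant process groups. A minimal sketch (illustrative helper, not the actual model_calib.py implementation):

import torch
import torch.distributed as dist

def sync_amax_across_groups(amax: torch.Tensor, groups) -> torch.Tensor:
    """MAX-reduce an amax tensor across each process group in turn,
    e.g. data-parallel, then expert-model-parallel, then tensor-parallel."""
    for group in groups:
        if group is not None and dist.is_initialized():
            dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax
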
sequenceDiagram
    participant Model as Quantized MOE Model
    participant TE as TransformerEngine GroupedLinear
    participant Seq as Sequential MLP
    participant Qt as Quantizers

    Model->>TE: forward(GroupedLinear)
    TE->>Qt: quantize input
    TE->>Qt: quantize weight (uses `weight0` alias during setup)
    Qt-->>TE: return quant params
    TE->>Model: execute grouped linear op
    Model->>Seq: forward(Sequential MLP)
    Seq->>Qt: input & weight quantize (standard path)
    Seq->>Model: sync amax across expert-parallel groups when calibrating

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰
I hopped through syncs of DP and EP,
counted amax under moonlit spree,
grouped weights whispered via TE,
experts aligned in tidy rows,
quant dreams bloom where the rabbit goes. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 50.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped - CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The PR title "Added support for quantizing TEGroupedMLP for megatron-lm" accurately identifies and focuses on the primary deliverable of this changeset. While the PR introduces several supporting infrastructure changes (expert parallel group synchronization in distributed.py, amax synchronization logic in model_calib.py, and various MOE-related classes in megatron.py), the core objective stated in the PR summary is to enable quantization of TEGroupedMLP models for Megatron-LM. The title clearly conveys this main outcome and is specific enough that a developer reviewing the commit history would understand the key enablement being added. The supporting changes (expert parallel synchronization, grouped linear support) serve as enabling mechanisms rather than the primary deliverable.


@codecov bot commented Oct 7, 2025

Codecov Report

❌ Patch coverage is 87.50000% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 73.38%. Comparing base (99c44d3) to head (ca55348).
⚠️ Report is 12 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/quantization/model_calib.py 85.71% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #403   +/-   ##
=======================================
  Coverage   73.37%   73.38%           
=======================================
  Files         180      180           
  Lines       17925    17942   +17     
=======================================
+ Hits        13152    13166   +14     
- Misses       4773     4776    +3     

☔ View full report in Codecov by Sentry.

@kinjalpatel27 kinjalpatel27 force-pushed the kinjal/grouped_linear branch 2 times, most recently from 22bfe0e to 1c821d8 on October 8, 2025 23:54
@realAsma realAsma changed the base branch from main to jennifchen/cp_amax_sync on October 9, 2025 15:53
@kinjalpatel27 kinjalpatel27 force-pushed the kinjal/grouped_linear branch from e2858f9 to 4d7dbce on October 9, 2025 16:49
Review comments on _test_expert_model_parallel_amax_sync:
Contributor: nice!

Contributor: We should not need register_custom_post_calibration_plugins. Let's not introduce new infrastructure unnecessarily.

Contributor: I see the point of post_calibration plugins now. Let's keep them as we discussed.

Contributor: this change looks good!

Base automatically changed from jennifchen/cp_amax_sync to main October 10, 2025 23:16
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com> (9 commits)
@coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)

560-608: Fix tensor handling in expert amax sync checker.

The checker still calls .item() on module.amax and feeds tensor objects to Python’s max/min. Per-channel quantizers surface multi-element tensors, so .item() raises, and even if it didn’t, max(values) on tensors triggers runtime errors. This is exactly what earlier feedback flagged, and the current code will break in the expert-parallel configurations this PR adds. Please keep amax tensors intact (clone/detach, move to CPU), compare them with torch ops, and derive a scalar diff for the return path.

-                amax_val = module.amax.item() if hasattr(module.amax, "item") else module.amax
-                expert_amax_values[name] = amax_val
+                expert_amax_values[name] = module.amax.detach().clone().cpu()
@@
-            if (
-                quantizer_type in expert_quantizers
-                and rank_idx in expert_quantizers[quantizer_type]
-            ):
-                # compare expert value across expert for sequential MoE
-                assert expert_quantizers[quantizer_type][rank_idx] == amax_val, (
-                    f"{rank_idx}, {quantizer_type}, expert_quantizers[quantizer_type][rank_idx]: "
-                    f"{expert_quantizers[quantizer_type][rank_idx]}, amax_val: {amax_val}"
-                )
-            expert_quantizers[quantizer_type][rank_idx] = amax_val
+            existing = expert_quantizers[quantizer_type].get(rank_idx)
+            if existing is not None and not torch.allclose(existing, amax_val, rtol=1e-6, atol=1e-6):
+                return False, quantizer_type, {rank_idx: {"expected": existing, "actual": amax_val}}
+            expert_quantizers[quantizer_type][rank_idx] = amax_val
@@
-            values = list(rank_values.values())
-            max_diff = max(values) - min(values)
-            if max_diff > 1e-6:  # Allow for small floating point differences
+            values = torch.stack([v.flatten() for v in rank_values.values()])
+            diff = (values.max(dim=0).values - values.min(dim=0).values).max().item()
+            if diff > 1e-6:  # Allow for small floating point differences
                 return False, quantizer_type, rank_values
tests/gpu/torch/quantization/plugins/test_megatron.py (2)

541-564: Cap spawned world size to the TP×EP topology.

This test always spawns torch.cuda.device_count() ranks even though the geometry is hard-coded to tp_size=2 and ep_size=2. On a 6‑ or 8‑GPU host, Megatron init asserts because data_parallel_size = world_size / tp_size is no longer divisible by ep_size. Please skip when the host has fewer than four GPUs and otherwise launch exactly tp_size * ep_size ranks so the process groups line up.

-    size = torch.cuda.device_count()
+    required_world_size = moe_config["tp_size"] * moe_config["ep_size"]
+    available = torch.cuda.device_count()
+    if available < required_world_size:
+        pytest.skip(f"Need {required_world_size} GPUs, found {available}")
@@
-    spawn_multiprocess_job(
-        size=size,
+    spawn_multiprocess_job(
+        size=required_world_size,
         job=partial(
             _test_sharded_state_dict,

710-729: Align expert-sync spawn size with TP×EP requirements.

test_expert_parallel_sync launches torch.cuda.device_count() ranks while the worker fixes tp_size=2. Whenever the host has more GPUs than the minimal topology (e.g., 6 GPUs with ep_size=2), initialize_model_parallel trips on data_parallel_size % ep_size != 0. Compute the largest usable multiple of tp_size * ep_size (per parameter set), skip if none fits, and pass that to spawn_multiprocess_job.

-    size = torch.cuda.device_count()
-    if size < ep_size * etp_size:
-        pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
-
-    spawn_multiprocess_job(
-        size=size,
+    available = torch.cuda.device_count()
+    required = 2 * ep_size  # tp_size is fixed to 2 in the worker
+    if available < required:
+        pytest.skip(f"Requires at least {required} GPUs for expert model parallel test")
+    usable = (available // required) * required
+    if usable == 0:
+        pytest.skip(f"No usable world size for tp={2}, ep={ep_size} on {available} GPUs")
+
+    spawn_multiprocess_job(
+        size=usable,
         job=partial(
             _test_expert_model_parallel_amax_sync,
             2,
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 2df77b1 and 5bc99e0.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/plugins/megatron.py (4 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (10 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/quantization/plugins/megatron.py (4)
modelopt/torch/quantization/plugins/custom.py (3)
  • _ParallelLinear (83-181)
  • _setup (114-122)
  • modelopt_post_restore (124-181)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/utils/distributed.py (1)
  • ParallelState (232-257)
modelopt/torch/opt/plugins/megatron.py (1)
  • _MegatronMLP (120-142)
tests/gpu/torch/quantization/plugins/test_megatron.py (5)
tests/_test_utils/torch_dist/plugins/megatron_common.py (4)
  • compare_amax_sync_across_expert_parallel (543-609)
  • copy_weights_from_grouped_to_non_grouped (511-540)
  • get_mcore_gpt_model (147-242)
  • initialize_for_megatron (425-444)
tests/gpu/torch/conftest.py (1)
  • need_4_gpus (44-46)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/quantization/model_calib.py (1)
  • max_calibrate (62-177)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality

@coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)

565-566: .item() and scalar aggregation will fail for per-channel quantizers—keep tensors throughout.

Line 565 calls .item() on potentially multi-element amax tensors (e.g., per-channel quantizers), which raises at runtime. Even if it passed, lines 605–606 would fail because Python's max()/min() can't operate on a list of tensors. The past review claimed this was resolved in 2df77b1, but the code still has the same issue. Store amax as tensors (detached, on CPU), then compute differences via tensor ops.

Apply this diff:

             # Check for both TEGrouped and sequential MoE patterns
             if "local_experts" in name or ("experts" in name and "linear_fc" in name):
-                amax_val = module.amax.item() if hasattr(module.amax, "item") else module.amax
-                expert_amax_values[name] = amax_val
+                expert_amax_values[name] = module.amax.detach().clone()

Then fix the aggregation at lines 602–608:

     # Check synchronization - fail fast on first inconsistency
     for quantizer_type, rank_values in expert_quantizers.items():
         if len(rank_values) > 1:  # Only check if we have multiple ranks
-            values = list(rank_values.values())
-            max_diff = max(values) - min(values)
-            if max_diff > 1e-6:  # Allow for small floating point differences
+            values = [v.detach().cpu().flatten() for v in rank_values.values()]
+            stacked = torch.stack(values)
+            max_diff = (stacked.max(dim=0).values - stacked.min(dim=0).values).max().item()
+            if max_diff > 1e-6:
                 return False, quantizer_type, rank_values
🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/megatron.py (1)

263-276: Pass expert_model_parallel_group to ParallelState for consistency with MOE classes.

Lines 272–275 create a ParallelState without passing expert_model_parallel_group, but the MOE-specific classes (_MegatronTEGroupedMLP, _MegatronSequentialMLP) do pass it (lines 614–616, 628–631). For consistency and to ensure expert-parallel quantizer synchronization works in all paths, also retrieve and pass mcore_parallel.get_expert_model_parallel_group() here.

Apply this diff:

         self.parallel_state = ParallelState(
             data_parallel_group,
             mcore_parallel.get_tensor_model_parallel_group(),
+            mcore_parallel.get_expert_model_parallel_group(),
         )
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 5bc99e0 and 4919b08.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/plugins/megatron.py (4 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
modelopt/torch/utils/distributed.py (2)
  • size (61-65)
  • world_size (204-206)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/quantization/plugins/megatron.py (4)
modelopt/torch/quantization/plugins/custom.py (3)
  • _ParallelLinear (83-181)
  • _setup (114-122)
  • modelopt_post_restore (124-181)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/utils/distributed.py (1)
  • ParallelState (232-257)
modelopt/torch/opt/plugins/megatron.py (1)
  • _MegatronMLP (120-142)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (3)
modelopt/torch/quantization/plugins/megatron.py (3)

51-81: LGTM—amax sync correctly uses tensor ops across experts.

The implementation properly uses torch.maximum() (line 74) to aggregate amax tensors and assigns the shared value back to all experts (line 81), avoiding the scalar/tensor issues. Well done!


611-621: LGTM—parallel state correctly initialized and propagated for TE grouped MLP.

The implementation properly sets up expert_model_parallel_group (line 616) and propagates parallel_state to linear_fc1 and linear_fc2 submodules (lines 619–620), addressing the past review feedback.


625-637: LGTM—parallel state correctly initialized and propagated for sequential MLP.

The implementation mirrors the TE grouped approach, correctly setting up expert parallel groups (line 631) and propagating parallel_state to all local_experts submodules (lines 635–637).

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
@coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (2)

540-564: Cap spawn size to the required topology dimensions.

Spawning torch.cuda.device_count() ranks (line 541) violates the MOE topology constraints when the device count exceeds tp_size * ep_size = 4. For instance, a 6-GPU host will attempt world size 6 with TP=2, yielding DP=3, which fails Megatron's requirement that DP is divisible by EP=2. The need_4_gpus fixture only ensures a minimum, not an exact match.

Apply this diff to restrict the spawn size:

-def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm):
-    size = torch.cuda.device_count()
+def test_moe_sharded_state_dict(tmp_path, config, moe_grouped_gemm):
+    device_count = torch.cuda.device_count()
+    required_size = 4  # tp_size * ep_size = 2 * 2
+    if device_count < required_size:
+        pytest.skip(f"Requires exactly {required_size} GPUs, found {device_count}")
+    size = required_size
     # TODO: Add support for compress=True for TEGroupedMLP

699-718: Enforce valid world size for MOE topology.

The skip condition size < ep_size * etp_size is insufficient. With TP=2 (hardcoded at line 711), EP=2, and device_count=6, the test spawns 6 ranks, yielding DP=3, which violates Megatron's data_parallel_size % ep_size == 0 constraint. You must ensure the world size is an exact multiple of tp_size * ep_size.

Apply this diff to fix the topology check:

 def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm):
     """Test expert model parallel synchronization."""
-    size = torch.cuda.device_count()
-    if size < ep_size * etp_size:
-        pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
+    device_count = torch.cuda.device_count()
+    tp_size = 2  # hardcoded in the partial call below
+    required_size = tp_size * ep_size
+    if device_count < required_size:
+        pytest.skip(f"Requires at least {required_size} GPUs (TP={tp_size}, EP={ep_size})")
+    # Use the largest valid multiple that doesn't exceed device_count
+    size = (device_count // required_size) * required_size
 
     spawn_multiprocess_job(
         size=size,
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)

563-609: Critical: Multi-element amax tensors still break this checker.

Despite the past review comment indicating this was resolved, the code at line 566 still calls module.amax.item(), which raises RuntimeError for per-channel quantizers that produce multi-element tensors. Additionally, lines 607–609 use Python's max() and min() on a list that could contain tensors (when .item() is skipped via the else branch), causing a TypeError.

This breaks synchronization checks for exactly the expert-parallel cases this utility is meant to cover.

Apply this diff to handle multi-element tensors correctly:

             if "local_experts" in name or ("experts" in name and "linear_fc" in name):
-                amax_val = module.amax.item() if hasattr(module.amax, "item") else module.amax
-                expert_amax_values[name] = amax_val
+                expert_amax_values[name] = module.amax.detach().clone().cpu()
...
     for quantizer_type, rank_values in expert_quantizers.items():
         if len(rank_values) > 1:  # Only check if we have multiple ranks
-            values = list(rank_values.values())
-            max_diff = max(values) - min(values)
-            if max_diff > 1e-6:  # Allow for small floating point differences
+            values = [v.flatten() for v in rank_values.values()]
+            stacked = torch.stack(values)
+            max_diff = (stacked.max(dim=0).values - stacked.min(dim=0).values).max().item()
+            if max_diff > 1e-6:  # Allow for small floating point differences
                 return False, quantizer_type, rank_values
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 4919b08 and 23daf38.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/plugins/megatron.py (4 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (10 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
modelopt/torch/utils/distributed.py (2)
  • size (61-65)
  • world_size (204-206)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • amax (236-241)
  • amax (244-255)
tests/gpu/torch/quantization/plugins/test_megatron.py (4)
tests/_test_utils/torch_dist/plugins/megatron_common.py (4)
  • compare_amax_sync_across_expert_parallel (544-611)
  • copy_weights_from_grouped_to_non_grouped (511-541)
  • get_mcore_gpt_model (147-242)
  • initialize_for_megatron (425-444)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/quantization/model_calib.py (1)
  • max_calibrate (62-177)
modelopt/torch/quantization/plugins/megatron.py (1)
  • sync_amax_across_sequential_mlp (51-81)
modelopt/torch/quantization/plugins/megatron.py (4)
modelopt/torch/quantization/plugins/custom.py (3)
  • _ParallelLinear (83-181)
  • _setup (114-122)
  • modelopt_post_restore (124-181)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • amax (236-241)
  • amax (244-255)
modelopt/torch/trace/symbols.py (3)
  • named_modules (444-447)
  • register (289-324)
  • items (434-437)
modelopt/torch/utils/distributed.py (1)
  • ParallelState (232-257)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (1)

511-542: Weight and bias mapping looks correct.

The template correctly omits the hardcoded .weight suffix, and param_name is properly extracted and appended to the sequential key. This addresses the past review concern about bias parameters mapping incorrectly.

modelopt/torch/quantization/plugins/megatron.py (1)

51-84: LGTM! Amax synchronization correctly uses tensor operations.

The function properly:

  • Uses torch.maximum() for element-wise maximum across experts (lines 72–75)
  • Assigns the synchronized amax back to each expert with proper device placement (lines 78–81)
  • Handles both collection and distribution phases cleanly

This addresses the past review concern about using Python's max() instead of tensor operations.
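
As a rough illustration of the collection/distribution pattern described above (names are illustrative, not the PR's exact sync_moe_local_experts_amax code):

import torch

def sync_local_experts_amax(expert_quantizers) -> None:
    """Element-wise max of amax across an MoE module's local experts,
    then write the shared value back to every expert (sketch only)."""
    # Collection phase: accumulate the element-wise maximum across experts.
    merged = None
    for q in expert_quantizers:
        amax = getattr(q, "amax", None)
        if amax is None:
            continue
        merged = amax.detach().clone() if merged is None else torch.maximum(merged, amax.to(merged.device))
    if merged is None:
        return
    # Distribution phase: assign the shared amax back, keeping each
    # quantizer's original device placement.
    for q in expert_quantizers:
        if getattr(q, "amax", None) is not None:
            q.amax = merged.to(q.amax.device)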

@cjluo-nv cjluo-nv requested a review from sugunav14 October 14, 2025 16:45
Review comments on sync_amax_across_sequential_mlp in modelopt/torch/quantization/plugins/megatron.py:
Contributor: We should do this only for per-tensor amax.

Contributor: Per-channel weight amax: element-wise maximum for RowParallel (fc2; the RowParallel Cout dim is shared across experts). Per-channel weight amax for ColumnParallel: no-op.

Contributor Author: will do in the followup MR

Contributor Author: Edit: per-tensor amax also works now. I have modified the test case to correctly check that.

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com> (3 commits)
@coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)

642-650: Cap spawn size to 4 ranks for EP×ETP topology.

Line 645 spawns all available GPUs, but the test topology (EP=2, ETP=2) requires exactly 4 ranks. On hosts with more than 4 GPUs, the excess ranks break Megatron's parallelism constraints.

Apply this diff:

 def test_te_grouped_vs_sequential_quantize(need_4_gpus):
     """Test that TEGrouped and sequential MoE models produce similar quantized models."""
     pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
-    size = torch.cuda.device_count()
+    required_size = 4  # ep_size * etp_size = 2 * 2
+    size = required_size
     spawn_multiprocess_job(

714-737: Cap spawn size to match expert-parallel topology.

Line 719 spawns all available GPUs, but the test parametrization requires at most 4 ranks (max of EP×ETP across test cases is 2×2=4). On hosts with 6+ GPUs, the extra ranks violate Megatron's topology constraints.

Apply this diff:

 def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
     """Test expert model parallel synchronization."""
-    size = torch.cuda.device_count()
+    device_count = torch.cuda.device_count()
+    required_size = ep_size * etp_size
-    if size < ep_size * etp_size:
-        pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
+    if device_count < required_size:
+        pytest.skip(f"Requires at least {required_size} GPUs for expert model parallel test")
+    size = required_size
 
     if moe_grouped_gemm:

542-569: Cap spawn size to match EP×TP topology.

Line 545 spawns torch.cuda.device_count() ranks, but the test requires exactly 4 ranks (TP=2, EP=2). On a host with 6+ GPUs, the extra ranks violate Megatron's topology constraints (e.g., data-parallel size must be divisible by EP).

Apply this diff:

 def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm):
     if moe_grouped_gemm:
         pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
-    size = torch.cuda.device_count()
+    required_size = 4  # tp_size * ep_size = 2 * 2
+    size = required_size
     # TODO: Add support for compress=True for TEGroupedMLP
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 8bff6b0 and 5481d10.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/plugins/megatron.py (4 hunks)
  • tests/_test_utils/torch_dist/plugins/megatron_common.py (10 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/plugins/megatron.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
modelopt/torch/utils/distributed.py (2)
  • size (61-65)
  • world_size (204-206)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (65-1111)
  • amax (236-241)
  • amax (244-255)
tests/gpu/torch/quantization/plugins/test_megatron.py (4)
tests/_test_utils/torch_dist/plugins/megatron_common.py (4)
  • compare_amax_sync_across_expert_parallel (552-663)
  • copy_weights_from_grouped_to_non_grouped (519-549)
  • get_mcore_gpt_model (149-250)
  • initialize_for_megatron (433-452)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/quantization/model_calib.py (1)
  • max_calibrate (62-181)
modelopt/torch/quantization/plugins/megatron.py (1)
  • sync_moe_local_experts_amax (508-533)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (10)
tests/_test_utils/torch_dist/plugins/megatron_common.py (5)

16-17: LGTM!

The new imports (re and defaultdict) are used in the helper functions added later in the file.


43-46: LGTM!

The expert-parallel imports are correctly added and used in compare_amax_sync_across_expert_parallel.


152-246: LGTM!

The MOE and expert-parallel parameters are correctly propagated through get_mcore_gpt_model and into the model configuration and initialization flow.


438-451: LGTM!

The expert-parallel parameters are correctly added to initialize_for_megatron and propagated to initialize_model_parallel.


519-550: LGTM!

The weight-copying logic correctly handles both weight and bias parameters, and the past review concern about the hardcoded .weight in the template has been properly addressed.

tests/gpu/torch/quantization/plugins/test_megatron.py (5)

24-25: LGTM!

The new helper imports from megatron_common are correctly added and used in the MOE tests below.


47-48: LGTM!

The Megatron MOE imports are correctly added for use in the new MOE-focused tests.


234-282: LGTM!

The MOE-related parameters are correctly added to _gpt_model_provider and properly propagated to the model construction calls.


285-357: LGTM!

The moe_config parameter is correctly integrated into _test_sharded_state_dict, and MOE parameters are properly extracted and propagated to model initialization and provider calls.


572-640: LGTM!

The test helper correctly creates both TEGrouped and Sequential MOE models, copies weights between them, and validates output equivalence before and after quantization.

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com> (2 commits)
@coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 91837c3 and ca55348.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/model_calib.py (3 hunks)
  • modelopt/torch/quantization/plugins/transformer_engine.py (2 hunks)
  • modelopt/torch/quantization/utils.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-18T20:15:04.615Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#332
File: modelopt/torch/quantization/algorithms.py:323-326
Timestamp: 2025-09-18T20:15:04.615Z
Learning: In modelopt/torch/quantization/algorithms.py, the `_is_auto_quantize_module` method requires `isinstance(module, QuantModule)` because some modules like MCore Column/Row Parallel Linear are `QuantModule` but not `QuantLinearConvBase`. The check ensures all quantization-capable modules are included in AutoQuantize search.

Applied to files:

  • modelopt/torch/quantization/utils.py
🧬 Code graph analysis (2)
modelopt/torch/quantization/model_calib.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • SequentialQuantizer (1114-1222)
  • sync_amax_across_distributed_group (1071-1083)
  • TensorQuantizer (65-1111)
modelopt/torch/quantization/plugins/megatron.py (1)
  • sync_moe_local_experts_amax (508-533)
modelopt/torch/quantization/plugins/transformer_engine.py (1)
modelopt/torch/quantization/plugins/custom.py (1)
  • _ParallelLinear (76-174)

Comment on lines +180 to +182
if hasattr(module, "sync_moe_local_experts_amax"):
module.sync_moe_local_experts_amax()


⚠️ Potential issue | 🔴 Critical

Guard MOE expert sync behind an initialized process group

max_calibrate is invoked in single-process flows. The new call into module.sync_moe_local_experts_amax() executes torch.distributed.barrier() unconditionally, so on a non-initialized default group this now throws RuntimeError: Default process group has not been initialized. Please gate this loop on dist.is_available() / dist.is_initialized() (or make the callee accept a group handle) so single-process calibration keeps working.

-    for name, module in model.named_modules():
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
+    if dist.is_available() and dist.is_initialized():
+        for name, module in model.named_modules():
+            if hasattr(module, "sync_moe_local_experts_amax"):
+                module.sync_moe_local_experts_amax()
🤖 Prompt for AI Agents
In modelopt/torch/quantization/model_calib.py around lines 180-182, the call to
module.sync_moe_local_experts_amax() triggers a torch.distributed.barrier()
without checking if the default process group is initialized, causing errors in
single-process runs. Fix this by wrapping the call with a guard that checks if
torch.distributed.is_available() and torch.distributed.is_initialized() return
True before invoking the method, ensuring it only runs when the distributed
backend is properly set up.

Comment on lines +72 to +115
def _setup(self):
# GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
# initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
# self.weight0 to self.weight to run the quantizer states initialization.
assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
self.weight = self.weight0
# Memorize the original weight.dtype for modelopt_post_restore given that
# the dtype can change later.
super()._setup()
# Remove self.weight after setup.
delattr(self, "weight")

def modelopt_post_restore(self, prefix: str = ""):
# GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
# initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
# self.weight0 to self.weight to run the quantizer states initialization.
assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
self.weight = self.weight0
super().modelopt_post_restore(prefix=prefix)
# Remove self.weight after post_restore.
delattr(self, "weight")

@staticmethod
def te_grouped_quantized_linear_fn(package, func_name, self, *args):
idx = 1 if func_name == "_forward" else 0
inp = args[idx]
num_gemms = len(args[idx + 1])
weights_and_biases = args[-2 * num_gemms :]
weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
quantized_inputs = self.input_quantizer(inp)
quantized_weights = [self.weight_quantizer(weight) for weight in weights]

output = getattr(package, func_name)(
*(
args[0],
quantized_inputs,
)
if func_name == "_forward"
else (quantized_inputs,),
*args[idx + 1 : -2 * num_gemms],
*quantized_weights,
*biases,
)
return self.output_quantizer(output)

⚠️ Potential issue | 🔴 Critical

Expose a stable .weight view for grouped TE layers

With is_quantized_linear() now recognizing modules that only provide weight0, helpers such as smoothquant, disable_pre_quant_scale_and_resmooth, etc. immediately access module.weight. Because _QuantTEGroupedLinear deletes that alias after setup, those helpers now hit AttributeError and break the quantization flows for grouped TE models. Please keep a .weight view backed by weight0 (without registering a duplicate parameter) so the existing utilities continue to function.

     def _setup(self):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-        self.weight = self.weight0
+        # GroupedMLP stores the weights as weight0, weight1, etc. Use weight0 to drive quantizer setup.
+        assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
         # Memorize the original weight.dtype for modelopt_post_restore given that
         # the dtype can change later.
         super()._setup()
-        # Remove self.weight after setup.
-        delattr(self, "weight")
+        # Setter below is a no-op so we do not register a duplicate Parameter named "weight".
@@
     def modelopt_post_restore(self, prefix: str = ""):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-        self.weight = self.weight0
+        # GroupedMLP stores the weights as weight0, weight1, etc. Reuse weight0 to drive post_restore.
+        assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
         super().modelopt_post_restore(prefix=prefix)
-        # Remove self.weight after post_restore.
-        delattr(self, "weight")
+        # Setter below keeps weight0 as the canonical tensor.
+
+    @property
+    def weight(self):
+        return self.weight0
+
+    @weight.setter
+    def weight(self, value):
+        if value is not self.weight0:
+            raise ValueError("TEGroupedLinear expects weight0 to back the canonical weight parameter.")
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/transformer_engine.py around lines 72 to
115, the current implementation temporarily assigns self.weight to self.weight0
during setup and post_restore, then deletes self.weight afterward. This deletion
causes AttributeError in utilities that expect a stable .weight attribute. To
fix this, keep a persistent .weight property backed by self.weight0 without
deleting it so that .weight remains accessible, ensuring compatibility with
helpers relying on this attribute.

@kinjalpatel27 kinjalpatel27 merged commit 6ef9954 into main Oct 17, 2025
27 checks passed
@kinjalpatel27 kinjalpatel27 deleted the kinjal/grouped_linear branch October 17, 2025 21:22