Merged
Changes from all commits
28 commits
9cac53c
sync amax in context parallel and awq act scale
jenchen13 Sep 24, 2025
be5e838
lint
jenchen13 Sep 25, 2025
4a2a8d7
test weight quantizer too
jenchen13 Sep 25, 2025
cacee61
fix test
jenchen13 Sep 26, 2025
41cc9bd
awq test
jenchen13 Sep 29, 2025
4a706ef
move awq test inside megatron tests
jenchen13 Sep 29, 2025
7b2c969
fix amax tests
jenchen13 Sep 30, 2025
e6dc5e5
fix awq lite param
jenchen13 Sep 30, 2025
f17320e
fix test
jenchen13 Sep 30, 2025
a1fdf18
add print
jenchen13 Oct 1, 2025
cd31159
docstring
jenchen13 Oct 1, 2025
5a67acf
fix tests
jenchen13 Oct 2, 2025
9d7dff1
fix multiprocess size
jenchen13 Oct 2, 2025
3bf16e6
Added quantization support for TEGroupedMoE for megatron-lm
kinjalpatel27 Oct 7, 2025
70776c3
code cleanup
kinjalpatel27 Oct 7, 2025
bab9ca2
code and test cleanup
kinjalpatel27 Oct 8, 2025
f9ba6e8
Updated moe names in tests
kinjalpatel27 Oct 9, 2025
a917c2b
updated parallel state for experts
kinjalpatel27 Oct 9, 2025
1ea4ed1
fixed bug for is_quantized_linear check
kinjalpatel27 Oct 9, 2025
169677c
code cleanup and bug fixes
kinjalpatel27 Oct 11, 2025
153e376
rebase bug fixes
kinjalpatel27 Oct 11, 2025
5bc99e0
fixing test and comments
kinjalpatel27 Oct 11, 2025
23daf38
Code cleanup
kinjalpatel27 Oct 13, 2025
15ffb87
Code cleanup and test update
kinjalpatel27 Oct 14, 2025
28c8bbf
remove post calib hook
kinjalpatel27 Oct 14, 2025
5481d10
fixed tests for per-channel support
kinjalpatel27 Oct 16, 2025
91837c3
minor fix
kinjalpatel27 Oct 16, 2025
ca55348
Addressed MR comments
kinjalpatel27 Oct 17, 2025
14 changes: 10 additions & 4 deletions modelopt/torch/quantization/model_calib.py
Contributor
this change looks good!

@@ -80,21 +80,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
if not distributed_sync:
return

def sync_quantizer_amax_across_dp(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel group."""
def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp(_q, parallel_state)
sync_quantizer_amax_across_dp_ep(_q, parallel_state)
return
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
# TODO: create sync_bias_across_distributed_group

for name, module in model.named_modules():
if isinstance(module, QuantModule):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp(child, module.parallel_state)
sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
# TP sync:
# Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

@@ -117,6 +118,7 @@ def sync_quantizer_amax_across_tp(
# Syncing amax across TP for sequential quantizer
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
# Syncing amax across TP for sequential quantizer
sync_quantizer_amax_across_tp(
_q, linear_name, quantizer_type, axes_for_sync, parallel_state
)
@@ -174,6 +176,10 @@ def sync_quantizer_amax_across_tp(
parallel_state=module.parallel_state,
)

for name, module in model.named_modules():
if hasattr(module, "sync_moe_local_experts_amax"):
module.sync_moe_local_experts_amax()

Comment on lines +180 to +182
Contributor

⚠️ Potential issue | 🔴 Critical

Guard MOE expert sync behind an initialized process group

max_calibrate is invoked in single-process flows. The new call into module.sync_moe_local_experts_amax() executes torch.distributed.barrier() unconditionally, so on a non-initialized default group this now throws RuntimeError: Default process group has not been initialized. Please gate this loop on dist.is_available() / dist.is_initialized() (or make the callee accept a group handle) so single-process calibration keeps working.

-    for name, module in model.named_modules():
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
+    if dist.is_available() and dist.is_initialized():
+        for name, module in model.named_modules():
+            if hasattr(module, "sync_moe_local_experts_amax"):
+                module.sync_moe_local_experts_amax()
🤖 Prompt for AI Agents
In modelopt/torch/quantization/model_calib.py around lines 180-182, the call to
module.sync_moe_local_experts_amax() triggers a torch.distributed.barrier()
without checking if the default process group is initialized, causing errors in
single-process runs. Fix this by wrapping the call with a guard that checks if
torch.distributed.is_available() and torch.distributed.is_initialized() return
True before invoking the method, ensuring it only runs when the distributed
backend is properly set up.
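
A minimal sketch of the suggested guard, assuming MOE modules expose sync_moe_local_experts_amax as in this diff; the helper name sync_moe_amax_if_distributed is illustrative and not part of the codebase:

import torch.distributed as dist
import torch.nn as nn


def sync_moe_amax_if_distributed(model: nn.Module) -> None:
    """Sync MOE expert amax only when a default process group is initialized."""
    if dist.is_available() and dist.is_initialized():
        for _, module in model.named_modules():
            if hasattr(module, "sync_moe_local_experts_amax"):
                module.sync_moe_local_experts_amax()

With this guard, single-process calibration simply skips the sync instead of hitting the unconditional barrier.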


def enable_stats_collection(model: nn.Module):
"""Enable stats collection for all quantizers in the model."""
128 changes: 118 additions & 10 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -22,6 +22,7 @@
import megatron.core.parallel_state as mcore_parallel
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import megatron.core.transformer.moe.experts as megatron_moe
import torch
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
@@ -40,6 +41,18 @@
from ..qtensor import QTensorWrapper
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear

try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelGroupedLinear,
TERowParallelGroupedLinear,
)

from .transformer_engine import _QuantTEGroupedLinear

HAS_TE = True
except ImportError:
HAS_TE = False

logger = logging.getLogger(__name__)

__all__ = []
@@ -221,16 +234,19 @@ class _MegatronParallelLinear(_ParallelLinear):
]

def _setup(self):
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
logger.warning("Context parallel group is not initialized, using data parallel group")
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
)
if not hasattr(self, "parallel_state") or self.parallel_state is None:
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
logger.warning(
"Context parallel group is not initialized, using data parallel group"
)
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
)
super()._setup()

def _process_quantizer_amax(self, k, v, quantizer_state_dict):
@@ -472,3 +488,95 @@ class _RealQuantMegatronRowParallelLinear(

def forward(self, input, *args, **kwargs):
return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)


@QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
class _MegatronSequentialMLP(_MegatronMLP):
def _setup(self):
if not hasattr(self, "parallel_state") or self.parallel_state is None:
self.parallel_state = ParallelState(
mcore_parallel.get_expert_data_parallel_group(),
tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
)

# Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2
for expert in self.local_experts:
expert.linear_fc1.parallel_state = self.parallel_state
expert.linear_fc2.parallel_state = self.parallel_state

def sync_moe_local_experts_amax(self):
"""Sync amax across local experts in a SequentialMLP.

amax values across EP and ETP (for RowParallel) are synchronized as part of model_calib.max_calibrate().
This function synchronizes the amax values across local experts so that all local experts
share the same amax.
"""
torch.distributed.barrier()
# Collect amax from all local experts
amax_dict = {}
for expert in self.local_experts:
for name, module in expert.named_modules():
if isinstance(module, TensorQuantizer) and module.amax is not None:
stored_amax = amax_dict.get(name)
amax_tensor = module.amax.detach().clone()
amax_dict[name] = (
amax_tensor
if stored_amax is None
else torch.maximum(stored_amax, amax_tensor)
)

# Apply synchronized amax values back to all local experts
for expert in self.local_experts:
for name, module in expert.named_modules():
if isinstance(module, TensorQuantizer) and module.amax is not None:
module.amax = amax_dict[name].detach().clone().to(module.amax.device)


if HAS_TE:
# Quantized subclasses to support TEGroupedMLP quantization
class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear):
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] to the
# sharded_state_dict; each entry is the same as _extra_state. These entries are only needed
# for the TE FP8 checkpoint, so we remove them here for modelopt checkpoint restore.
filtered_state_dict = {
k: v
for k, v in state_dict.items()
if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
}
return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)

def _process_quantizer_amax(self, k, v, quantizer_state_dict):
assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
quantizer_state_dict[k] = v.view(-1)

@QuantModuleRegistry.register(
{TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
)
class _MegatronTEGroupedColumnParallelLinear(
_QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear
):
pass

@QuantModuleRegistry.register(
{TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
)
class _MegatronTEGroupedRowParallelLinear(
_QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear
):
pass

@QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"})
class _MegatronTEGroupedMLP(_MegatronMLP):
def _setup(self):
if not hasattr(self, "parallel_state") or self.parallel_state is None:
self.parallel_state = ParallelState(
mcore_parallel.get_expert_data_parallel_group(),
tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
)
# initialize parallel state for submodules linear_fc1 and linear_fc2
self.linear_fc1.parallel_state = self.parallel_state
self.linear_fc2.parallel_state = self.parallel_state
58 changes: 58 additions & 0 deletions modelopt/torch/quantization/plugins/transformer_engine.py
@@ -17,6 +17,7 @@

import torch
import transformer_engine as te
import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
import transformer_engine.pytorch.module.linear as te_linear

from ..nn import QuantModuleRegistry
@@ -58,3 +59,60 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):

# Override the quantized linear function
_quantized_linear_fn = te_quantized_linear_fn


# Register the public te.pytorch.GroupedLinear class
@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
class _QuantTEGroupedLinear(_ParallelLinear):
_functionals_to_replace = [
(te_grouped_linear._GroupedLinear, "forward"),
(te_grouped_linear._GroupedLinear, "apply"),
]

def _setup(self):
# GroupedMLP stores the weights as weight0, weight1, etc. Setup reads self.weight to extract
# shape, dtype, etc. when initializing the quantizer states, so temporarily alias self.weight0
# as self.weight.
assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
self.weight = self.weight0
# Memorize the original weight.dtype for modelopt_post_restore given that
# the dtype can change later.
super()._setup()
# Remove self.weight after setup.
delattr(self, "weight")

def modelopt_post_restore(self, prefix: str = ""):
# GroupedMLP stores the weights as weight0, weight1, etc. post_restore reads self.weight to
# extract shape, dtype, etc. when initializing the quantizer states, so temporarily alias
# self.weight0 as self.weight.
assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
self.weight = self.weight0
super().modelopt_post_restore(prefix=prefix)
# Remove self.weight after post_restore.
delattr(self, "weight")

@staticmethod
def te_grouped_quantized_linear_fn(package, func_name, self, *args):
idx = 1 if func_name == "_forward" else 0
inp = args[idx]
num_gemms = len(args[idx + 1])
weights_and_biases = args[-2 * num_gemms :]
weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
quantized_inputs = self.input_quantizer(inp)
quantized_weights = [self.weight_quantizer(weight) for weight in weights]

output = getattr(package, func_name)(
*(
args[0],
quantized_inputs,
)
if func_name == "_forward"
else (quantized_inputs,),
*args[idx + 1 : -2 * num_gemms],
*quantized_weights,
*biases,
)
return self.output_quantizer(output)
Comment on lines +72 to +115
Contributor

⚠️ Potential issue | 🔴 Critical

Expose a stable .weight view for grouped TE layers

With is_quantized_linear() now recognizing modules that only provide weight0, helpers such as smoothquant, disable_pre_quant_scale_and_resmooth, etc. immediately access module.weight. Because _QuantTEGroupedLinear deletes that alias after setup, those helpers now hit AttributeError and break the quantization flows for grouped TE models. Please keep a .weight view backed by weight0 (without registering a duplicate parameter) so the existing utilities continue to function.

     def _setup(self):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-        self.weight = self.weight0
+        # GroupedMLP stores the weights as weight0, weight1, etc. Use weight0 to drive quantizer setup.
+        assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
         # Memorize the original weight.dtype for modelopt_post_restore given that
         # the dtype can change later.
         super()._setup()
-        # Remove self.weight after setup.
-        delattr(self, "weight")
+        # Setter below is a no-op so we do not register a duplicate Parameter named "weight".
@@
     def modelopt_post_restore(self, prefix: str = ""):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-        self.weight = self.weight0
+        # GroupedMLP stores the weights as weight0, weight1, etc. Reuse weight0 to drive post_restore.
+        assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
         super().modelopt_post_restore(prefix=prefix)
-        # Remove self.weight after post_restore.
-        delattr(self, "weight")
+        # Setter below keeps weight0 as the canonical tensor.
+
+    @property
+    def weight(self):
+        return self.weight0
+
+    @weight.setter
+    def weight(self, value):
+        if value is not self.weight0:
+            raise ValueError("TEGroupedLinear expects weight0 to back the canonical weight parameter.")
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/transformer_engine.py around lines 72 to
115, the current implementation temporarily assigns self.weight to self.weight0
during setup and post_restore, then deletes self.weight afterward. This deletion
causes AttributeError in utilities that expect a stable .weight attribute. To
fix this, keep a persistent .weight property backed by self.weight0 without
deleting it so that .weight remains accessible, ensuring compatibility with
helpers relying on this attribute.
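
As a rough, self-contained illustration of the property-backed alias the reviewer suggests (GroupedLinearLike is a hypothetical stand-in, not the actual TE or ModelOpt class):

import torch
import torch.nn as nn


class GroupedLinearLike(nn.Module):
    """Toy grouped layer that registers weight0, weight1, ... instead of a single weight."""

    def __init__(self, num_gemms: int, out_features: int, in_features: int):
        super().__init__()
        for i in range(num_gemms):
            # One parameter per GEMM; there is intentionally no parameter named "weight".
            self.register_parameter(f"weight{i}", nn.Parameter(torch.empty(out_features, in_features)))

    @property
    def weight(self) -> nn.Parameter:
        # Stable read-only view so helpers that access module.weight keep working,
        # without registering a duplicate parameter.
        return self.weight0


layer = GroupedLinearLike(num_gemms=2, out_features=8, in_features=4)
assert layer.weight is layer.weight0 and layer.weight.dim() == 2

The real fix would live on _QuantTEGroupedLinear, but the aliasing pattern is the same.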


# Override the quantized linear function
_quantized_linear_fn = te_grouped_quantized_linear_fn
7 changes: 5 additions & 2 deletions modelopt/torch/quantization/utils.py
@@ -251,8 +251,11 @@ def is_quantized_linear(module):
isinstance(module, QuantModule)
and isinstance(getattr(module, "input_quantizer", None), TensorQuantizer)
and hasattr(module, "weight_quantizer")
and getattr(module, "weight", None) is not None
and module.weight.dim() == 2
and (
(getattr(module, "weight", None) is not None and module.weight.dim() == 2)
# module.weight0 check is required to support TEGroupedLinear
or (getattr(module, "weight0", None) is not None and module.weight0.dim() == 2)
)
)


6 changes: 5 additions & 1 deletion modelopt/torch/utils/distributed.py
@@ -241,16 +241,20 @@ def __init__(
self,
data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
):
"""Initialize the parallel state."""
self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group)

def __repr__(self) -> str:
return (
parallel_groups = (
f"data_parallel_group: {self.data_parallel_group}, "
f"tensor_parallel_group: {self.tensor_parallel_group}, "
f"expert_model_parallel_group: {self.expert_model_parallel_group}"
)
return parallel_groups


def get_group(ranks: list[int]):
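
A brief usage sketch of the extended constructor, assuming ParallelState is importable from modelopt.torch.utils.distributed (the module changed above) and that single-process construction is supported; the argument values simply mirror the defaults shown in the diff:

from modelopt.torch.utils.distributed import ParallelState

# Only expert_model_parallel_group is new in this PR; None/-1 mirror the defaults above.
state = ParallelState(
    data_parallel_group=None,
    tensor_parallel_group=-1,
    expert_model_parallel_group=-1,
)
print(state)  # __repr__ now also reports the expert_model_parallel_group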