Added support for quantizing TEGroupedMLP for megatron-lm #403
@@ -80,21 +80,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return

-    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
-        """Synchronize the amax across all ranks in the data parallel group."""
+    def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
+        """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp(_q, parallel_state)
+                sync_quantizer_amax_across_dp_ep(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
+            quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
         # TODO: create sync_bias_across_distributed_group

     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp(child, module.parallel_state)
+                    sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
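For readers unfamiliar with the calibration flow: synchronizing amax across a process group amounts to a max all-reduce. Below is a minimal sketch of that idea, assuming `torch.distributed` as the collective backend; the helper name `sync_amax_across_group` and the standalone-tensor signature are illustrative, not ModelOpt's actual `sync_amax_across_distributed_group` implementation.

    import torch
    import torch.distributed as dist

    def sync_amax_across_group(amax: torch.Tensor, group) -> torch.Tensor:
        """Illustrative max all-reduce of a calibration amax over a process group."""
        # Nothing to reduce when running single-process or before init_process_group.
        if group is None or not (dist.is_available() and dist.is_initialized()):
            return amax
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
        return amax

The renamed `sync_quantizer_amax_across_dp_ep` applies this kind of reduction twice, once over the data-parallel group and once over the expert-model-parallel group, so experts that only see a shard of the calibration data still end up with a common scale.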
@@ -117,6 +118,7 @@ def sync_quantizer_amax_across_tp(
         # Syncing amax across TP for sequential quantizer
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
+                # Syncing amax across TP for sequential quantizer
                 sync_quantizer_amax_across_tp(
                     _q, linear_name, quantizer_type, axes_for_sync, parallel_state
                 )
@@ -174,6 +176,10 @@ def sync_quantizer_amax_across_tp(
                 parallel_state=module.parallel_state,
             )

+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
Comment on lines +180 to +182

Guard MOE expert sync behind an initialized process group:

-    for name, module in model.named_modules():
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
+    if dist.is_available() and dist.is_initialized():
+        for name, module in model.named_modules():
+            if hasattr(module, "sync_moe_local_experts_amax"):
+                module.sync_moe_local_experts_amax()

 def enable_stats_collection(model: nn.Module):
     """Enable stats collection for all quantizers in the model."""
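The `sync_moe_local_experts_amax` hook invoked by the new loop above lives on the Megatron MoE modules and is not shown in this diff. As a rough, hypothetical sketch of the intent (make every local expert share the per-quantizer maximum amax so the result does not depend on how experts are sharded), assuming a `local_experts` container on the MoE module and ModelOpt's `TensorQuantizer`:

    import torch

    from modelopt.torch.quantization.nn import TensorQuantizer  # assumed import path

    def sync_moe_local_experts_amax(self):
        """Hypothetical sketch: unify amax across the local experts of one MoE layer."""
        # Group each expert's quantizers by their relative name, then write back the
        # elementwise max so all experts quantize with identical parameters.
        per_name: dict[str, list[TensorQuantizer]] = {}
        for expert in self.local_experts:  # assumed attribute name
            for qname, submodule in expert.named_modules():
                if isinstance(submodule, TensorQuantizer) and getattr(submodule, "_amax", None) is not None:
                    per_name.setdefault(qname, []).append(submodule)
        for quantizers in per_name.values():
            unified = torch.stack([q._amax for q in quantizers]).amax(dim=0)
            for q in quantizers:
                q._amax = unified.clone()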
@@ -17,6 +17,7 @@

 import torch
 import transformer_engine as te
+import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
 import transformer_engine.pytorch.module.linear as te_linear

 from ..nn import QuantModuleRegistry
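Background for the hunk below: TE's `GroupedLinear` fuses several per-group GEMMs into one module and registers its parameters as `weight0`, `weight1`, ... rather than a single `weight`. The following toy stand-in (not TE's real implementation; the class and its forward signature are simplified for illustration) shows the storage layout that `_QuantTEGroupedLinear` has to work around.

    import torch
    import torch.nn as nn

    class ToyGroupedLinear(nn.Module):
        """Toy stand-in: num_gemms independent weights stored as weight0, weight1, ..."""

        def __init__(self, num_gemms: int, in_features: int, out_features: int):
            super().__init__()
            self.num_gemms = num_gemms
            for i in range(num_gemms):
                # There is deliberately no single self.weight parameter.
                self.register_parameter(
                    f"weight{i}", nn.Parameter(torch.randn(out_features, in_features))
                )

        def forward(self, inp: torch.Tensor, m_splits: list[int]) -> torch.Tensor:
            # One GEMM per group; m_splits gives the number of rows routed to each group.
            outs = [
                chunk @ getattr(self, f"weight{i}").t()
                for i, chunk in enumerate(torch.split(inp, m_splits, dim=0))
            ]
            return torch.cat(outs, dim=0)

This is why `_setup` and `modelopt_post_restore` in the next hunk temporarily alias `self.weight = self.weight0`: the parent class reads shape and dtype from a `weight` attribute that `GroupedLinear` does not provide.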
@@ -58,3 +59,60 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):

     # Override the quantized linear function
     _quantized_linear_fn = te_quantized_linear_fn


+# Register the public te.pytorch.GroupedLinear class
+@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
+class _QuantTEGroupedLinear(_ParallelLinear):
+    _functionals_to_replace = [
+        (te_grouped_linear._GroupedLinear, "forward"),
+        (te_grouped_linear._GroupedLinear, "apply"),
+    ]
+
+    def _setup(self):
+        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
+        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
+        # self.weight0 to self.weight to run the quantizer states initialization.
+        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
+        # Memorize the original weight.dtype for modelopt_post_restore given that
+        # the dtype can change later.
+        super()._setup()
+        # Remove self.weight after setup.
+        delattr(self, "weight")
+
+    def modelopt_post_restore(self, prefix: str = ""):
+        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
+        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
+        # self.weight0 to self.weight to run the quantizer states initialization.
+        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
+        super().modelopt_post_restore(prefix=prefix)
+        # Remove self.weight after post_restore.
+        delattr(self, "weight")
+
+    @staticmethod
+    def te_grouped_quantized_linear_fn(package, func_name, self, *args):
+        idx = 1 if func_name == "_forward" else 0
+        inp = args[idx]
+        num_gemms = len(args[idx + 1])
+        weights_and_biases = args[-2 * num_gemms :]
+        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
+        quantized_inputs = self.input_quantizer(inp)
+        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
+
+        output = getattr(package, func_name)(
+            *(
+                args[0],
+                quantized_inputs,
+            )
+            if func_name == "_forward"
+            else (quantized_inputs,),
+            *args[idx + 1 : -2 * num_gemms],
+            *quantized_weights,
+            *biases,
+        )
+        return self.output_quantizer(output)
Comment on lines +72 to +115
Expose a stable weight property

     def _setup(self):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-        self.weight = self.weight0
+        # GroupedMLP stores the weights as weight0, weight1, etc. Use weight0 to drive quantizer setup.
+        assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
         # Memorize the original weight.dtype for modelopt_post_restore given that
         # the dtype can change later.
         super()._setup()
-        # Remove self.weight after setup.
-        delattr(self, "weight")
+        # Setter below is a no-op so we do not register a duplicate Parameter named "weight".
@@
     def modelopt_post_restore(self, prefix: str = ""):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-        self.weight = self.weight0
+        # GroupedMLP stores the weights as weight0, weight1, etc. Reuse weight0 to drive post_restore.
+        assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
         super().modelopt_post_restore(prefix=prefix)
-        # Remove self.weight after post_restore.
-        delattr(self, "weight")
+        # Setter below keeps weight0 as the canonical tensor.
+
+    @property
+    def weight(self):
+        return self.weight0
+
+    @weight.setter
+    def weight(self, value):
+        if value is not self.weight0:
+            raise ValueError("TEGroupedLinear expects weight0 to back the canonical weight parameter.")
+
+    # Override the quantized linear function
+    _quantized_linear_fn = te_grouped_quantized_linear_fn
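To make the argument slicing in `te_grouped_quantized_linear_fn` concrete, here is a small self-contained mock of the positional layout it assumes for `func_name == "forward"`. Placeholder strings stand in for tensors, and the middle arguments are arbitrary; for `"_forward"`, `args[0]` is presumably the autograd ctx, hence `idx = 1`.

    # Layout assumed by the wrapper: (inp, m_splits, <other args...>, *weights, *biases)
    num_gemms = 3
    inp = "INP"
    m_splits = [4, 2, 2]                       # len(m_splits) == num_gemms
    other = ["arg_a", "arg_b"]                 # placeholder middle arguments
    weights = [f"W{i}" for i in range(num_gemms)]
    biases = [f"B{i}" for i in range(num_gemms)]
    args = (inp, m_splits, *other, *weights, *biases)

    idx = 0                                    # would be 1 for "_forward"
    assert args[idx] == inp
    num_gemms = len(args[idx + 1])             # recovered from m_splits
    weights_and_biases = args[-2 * num_gemms:]
    assert list(weights_and_biases[:num_gemms]) == weights
    assert list(weights_and_biases[num_gemms:]) == biases
    # The wrapper quantizes inp and each weight, forwards everything else unchanged,
    # and finally passes the result through self.output_quantizer.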
this change looks good!