-
Notifications
You must be signed in to change notification settings - Fork 702
[PyTorch] Recipe heuristics for initializing quantized weights #1827
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
586a081
6efac63
6413aaf
cd325cd
9cb709d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1152,25 +1152,39 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: | |||||||||||||||||||
| with get_rng_state_tracker().fork(): | ||||||||||||||||||||
| init_fn(param) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # If primary weights are in fp8, wrap the parameter as FP8Tensor | ||||||||||||||||||||
| # Wrap parameters in QuantizedTensor if needed | ||||||||||||||||||||
| fp8_meta_index = self.param_init_meta[name].fp8_meta_index | ||||||||||||||||||||
| high_precision_init_val = None | ||||||||||||||||||||
| if self.primary_weights_in_fp8 and fp8_meta_index is not None: | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Keep high-precision values on CPU if needed | ||||||||||||||||||||
| if self.preserve_high_precision_init_val: | ||||||||||||||||||||
| high_precision_init_val = param.detach().cpu() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Get quantizer | ||||||||||||||||||||
| quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] | ||||||||||||||||||||
| assert ( | ||||||||||||||||||||
| quantizer is not None | ||||||||||||||||||||
| ) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe. | ||||||||||||||||||||
| if quantizer is None: | ||||||||||||||||||||
| raise RuntimeError("Weight quantizer has not been initialized") | ||||||||||||||||||||
| quantizer.internal = False | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Recipe-specific quantizer configuration | ||||||||||||||||||||
| recipe = self.fp8_meta["recipe"] | ||||||||||||||||||||
| if recipe is not None: | ||||||||||||||||||||
| if recipe.heuristic == "inference": | ||||||||||||||||||||
| # Weight needs column-wise usage for dgrad | ||||||||||||||||||||
| # GEMM, so not needed for inference | ||||||||||||||||||||
| quantizer.set_usage(rowwise=True, columnwise=False) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Quantize parameter | ||||||||||||||||||||
| param = quantizer(param) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Redo parameter wrap in case we broke it above | ||||||||||||||||||||
| # NOTE: Currently this can only be broken when primary weights are in Fp8 but | ||||||||||||||||||||
| # re-applying the nn.Parameter() wrap is a no-op when the input is already | ||||||||||||||||||||
| # a parameter so we always re-apply it just for extra safety. | ||||||||||||||||||||
| param = torch.nn.Parameter(param) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Keep high-precision values on CPU if needed | ||||||||||||||||||||
| if high_precision_init_val is not None: | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # - Master weights are initialized from model weights, if we use fp8 primary | ||||||||||||||||||||
|
|
@@ -1214,7 +1228,7 @@ def get_weight_workspace( | |||||||||||||||||||
| fsdp_group: Optional[dist_group_type] = None, | ||||||||||||||||||||
| workspace_dtype: Optional[torch.dtype] = None, | ||||||||||||||||||||
| ) -> QuantizedTensor: | ||||||||||||||||||||
| """Get FP8 workspace buffer and maybe update its values | ||||||||||||||||||||
| """Get workspace buffer for weights and maybe update its values | ||||||||||||||||||||
|
|
||||||||||||||||||||
| The workspace buffer may be cached for future function calls. | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -1238,15 +1252,19 @@ def get_weight_workspace( | |||||||||||||||||||
| workspace_dtype: torch.dtype, default = None | ||||||||||||||||||||
| If weight workspace contains high-precision tensor - for example | ||||||||||||||||||||
| for debug quantization, this is dtype of the tensor. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| """ | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # FP8 primary weights | ||||||||||||||||||||
| # Handle case where weights are already quantized | ||||||||||||||||||||
| # Note: Make sure weights have required usages, but do not | ||||||||||||||||||||
| # destroy unnecessary usages since they may be used later. | ||||||||||||||||||||
| if isinstance(tensor, QuantizedTensor): | ||||||||||||||||||||
| if update_workspace and quantizer is not None: | ||||||||||||||||||||
| tensor.update_usage( | ||||||||||||||||||||
| rowwise_usage=quantizer.rowwise_usage, | ||||||||||||||||||||
| columnwise_usage=quantizer.columnwise_usage, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| update_rowwise_usage = True if quantizer.rowwise_usage else None | ||||||||||||||||||||
| update_columnwise_usage = True if quantizer.columnwise_usage else None | ||||||||||||||||||||
| tensor.update_usage( | ||||||||||||||||||||
| rowwise_usage=update_rowwise_usage, | ||||||||||||||||||||
| columnwise_usage=update_columnwise_usage, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
Comment on lines
+1262
to
+1267
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Destroying unnecessary usages was causing problems when alternating between training steps (column-wise data needed) and validation steps (column-wise data not needed). See #1832 (comment).
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. TBH this issue is just because the optimizer is not doing the right job with quantizing. If we made it use the quantizer, then we would not need this part at all.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The layers will configure the quantizer to avoid unnecessary allocations: TransformerEngine/transformer_engine/pytorch/module/linear.py Lines 225 to 233 in 855fa65
This is what we want when allocating new buffers, but is overly aggressive when dealing with an existing QuantizedTensor. We could remove this logic from get_weight_workspace, but I don't like how it would ignore the configuration within the quantizer.
|
||||||||||||||||||||
| return tensor | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Try getting workspace from cache | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -384,6 +384,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: | |
|
|
||
| # Quantize to FP8 | ||
| assert self._quantizer is not None, "Can't quantize without a quantizer" | ||
| self._quantizer.internal = False | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This makes me think that internal should maybe be an option to tex.quantize rather than a member of the quantizer.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have mixed opinions.
Maybe |
||
| self.data = self._quantizer.quantize(tensor) | ||
| if self.requires_grad != tensor.requires_grad: | ||
| self.requires_grad_(requires_grad=tensor.requires_grad) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name is not the best - wouldn't you want performance during inference?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not necessarily if you're memory constrained.
Perhaps a naming scheme like "training_performance", "inference_performance", "training_memory", "inference_memory" would be more precise?