[PyTorch Debug] Support precision debug tools for fp8 model parameters. #2141
ptrendx merged 26 commits into NVIDIA:main from
Conversation
/te-ci pytorch
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The PR extends TransformerEngine's nvidia-dlfw-inspect debugging tools to support FP8 model parameters (when fp8_primary_weight is enabled). Previously, the debug infrastructure lived entirely inside the quantize() method, which isn't called for FP8 weights stored in quantized format. The solution introduces DebugQuantizer.wrap_quantized_tensor() as a new entry point that wraps pre-quantized weights in debug functionality, and makes the tensor parameter optional in the inspect_tensor API (passing None for FP8 weights). When the high-precision tensor is unavailable, LogTensorStats dequantizes on demand, while LogFp8TensorStats disables statistics that require high precision (underflows%, overflows%, mse). The implementation preserves existing behavior for non-weight tensors (activations, gradients) and removes three test skips to validate the feature end-to-end.
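A minimal, self-contained sketch of the flow this summary describes; the method names (`wrap_quantized_tensor`, `inspect_tensor`) mirror the PR, but all class bodies are simplified stand-ins, not TransformerEngine's implementation.

```python
# Sketch of the new debug entry point for pre-quantized FP8 weights.
# All classes here are illustrative stand-ins for TransformerEngine's types.
from typing import Optional

class QuantizedTensor:
    """Stand-in for a weight already stored in FP8."""
    def dequantize(self):
        return "high-precision view"

class DebugQuantizedTensor:
    def __init__(self, quantized):
        self.quantized = quantized

class DebugQuantizer:
    def wrap_quantized_tensor(self, t: QuantizedTensor) -> DebugQuantizedTensor:
        # New entry point: debug logic no longer depends on quantize(),
        # which is skipped for weights already stored in FP8.
        return DebugQuantizedTensor(t)

def inspect_tensor(name: str, tensor: Optional[object], quantized: QuantizedTensor):
    # tensor is None for FP8 weights; LogTensorStats would dequantize on
    # demand, while LogFp8TensorStats would drop high-precision-only stats.
    hp = tensor if tensor is not None else quantized.dequantize()
    print(f"{name}: stats computed on {hp!r}")

weight = QuantizedTensor()
debug_weight = DebugQuantizer().wrap_quantized_tensor(weight)
inspect_tensor("layer1.weight", None, debug_weight.quantized)
```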
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/debug/pytorch/debug_quantization.py | 4/5 | Added wrap_quantized_tensor() method to inject debug logic for pre-quantized FP8 weights with assertions preventing weight modification |
| transformer_engine/debug/features/api.py | 4/5 | Made tensor parameter optional in inspect_tensor API and added backward compatibility logic to filter kwargs for legacy feature implementations |
| transformer_engine/debug/features/log_fp8_tensor_stats.py | 3/5 | Added validation to reject stats requiring high-precision tensors when only quantized weights available, with potential issue in cross-recipe stat path |
| transformer_engine/debug/features/log_tensor_stats.py | 4/5 | Added dequantization path when tensor=None to compute stats from quantized FP8 weights, introduces performance overhead |
| transformer_engine/debug/features/utils/stats_buffer.py | 3/5 | Allows None tensors in stats buffer for FP8 weights, but stat computation functions may not handle None safely |
| transformer_engine/pytorch/module/base.py | 4/5 | Removed FP8 debug blocker and added wrap_quantized_tensor() call, but wrapping only occurs on early-return path causing potential inconsistency |
| tests/pytorch/test_sanity.py | 5/5 | Removed three test skips that blocked FP8 parameters in debug mode, enabling validation of the new feature |
| tests/pytorch/debug/test_log.py | 4/5 | Added integration test for FP8 model parameter debug logging with explicit state reset to prevent test pollution |
Confidence score: 3/5
- This PR requires careful review due to the complexity of threading optional tensor handling through multiple layers of the debug infrastructure
- Score lowered because: (1) the wrapping logic in `base.py` only applies to one code path (early return) but not to newly constructed workspaces around lines 1424-1443, potentially causing inconsistent debug behavior, (2) stat computation functions in `stats_buffer.py` may dereference `None` tensors without null checks, causing runtime errors, (3) the cross-recipe stat path in `log_fp8_tensor_stats.py` might attempt to requantize a `None` tensor, and (4) the dequantization in `log_tensor_stats.py` adds undocumented performance overhead
- Pay close attention to transformer_engine/debug/features/utils/stats_buffer.py, transformer_engine/debug/features/log_fp8_tensor_stats.py, and transformer_engine/pytorch/module/base.py
Additional Comments (2)
- transformer_engine/debug/features/utils/stats_buffer.py, lines 108-110 — logic: When `tensor` is `None`, every stat function `fn(tensor, aux_dict)` receives `None` as the first argument. Verify that all stat functions in `STATS` handle `None` tensors correctly, or add explicit guards here to prevent runtime errors.
- transformer_engine/debug/features/log_fp8_tensor_stats.py, line 347 — logic: When `tensor` is `None` (FP8 weight case), passing `None` as `original_tensor` to `update_aux_dict` may cause issues at line 249, where `quantizer(original_tensor)` is called if different recipes are requested. What happens when a user requests stats for a different recipe (e.g., mxfp8_underflows%) on FP8 weight tensors where `tensor=None`? Line 249 will fail with a `NoneType` error.
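A minimal sketch of the explicit guard the first comment asks for; `STATS`, `tensor`, and `aux_dict` mirror the names in the comment, but the concrete stat functions are illustrative stand-ins, not stats_buffer.py's real code.

```python
# Sketch of a None guard for the stat loop flagged above.
import torch

STATS = {
    "mean": lambda t, aux: t.float().mean().item(),
    "std": lambda t, aux: t.float().std().item(),
}

def compute_stats(tensor, aux_dict):
    results = {}
    for name, fn in STATS.items():
        if tensor is None:
            # Explicit guard: skip instead of passing None into fn.
            continue
        results[name] = fn(tensor, aux_dict)
    return results

print(compute_stats(torch.ones(4), {}))  # {'mean': 1.0, 'std': 0.0}
print(compute_stats(None, {}))           # {} — no NoneType crash
```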
8 files reviewed, 7 comments
/te-ci pytorch
Greptile Summary
This PR enables debug tools for FP8 model parameters by adding `DebugQuantizer.wrap_quantized_tensor()`.
Critical issue found: when the weight is quantized through `quantize()` rather than arriving pre-quantized, the result is not wrapped in a `DebugQuantizedTensor` (the "Bug: Missing wrap" path in the flowchart below).
Minor issues:
Confidence Score: 2/5
Important Files Changed
Flowchart
```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[FP8 Model Parameters] --> B{Debug Mode Enabled?}
    B -->|No| C[Standard FP8 Flow]
    B -->|Yes| D[Create DebugQuantizer]
    D --> E{Weight Already Quantized?}
    E -->|Yes| F[wrap_quantized_tensor]
    E -->|No| G[Quantize Weight]
    G --> H{Is DebugQuantizer?}
    H -->|Yes| I[Should wrap but doesn't]
    H -->|No| J[Return quantized tensor]
    F --> K[DebugQuantizedTensor]
    I --> L[Bug: Missing wrap]
    K --> M[inspect_tensor with tensor=None]
    M --> N{Feature Type?}
    N -->|LogTensorStats| O[Dequantize for stats]
    N -->|LogFp8TensorStats| P[Use quantized tensor directly]
    O --> Q[Compute stats]
    P --> R[Compute scale_inv stats only]
```
Last reviewed commit: e59091f
8 files reviewed, no comments
/te-ci pytorch
9 files reviewed, no comments
9 files reviewed, 2 comments
Additional Comments (1)
- transformer_engine/pytorch/module/linear.py, line 257 — logic: operator precedence bug. `and` binds tighter than `or`, so this evaluates as `(weight_quantizer is not None and not isinstance(weight, QuantizedTensor)) or debug`. When `weight_quantizer` is `None` and `debug` is `True`, the condition is `True`, causing line 264 to call `weight_quantizer.set_usage(...)` on `None`, resulting in `AttributeError`.
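A self-contained demonstration of the precedence issue described above; the variable names mirror the comment, the class is an illustrative stand-in, and the parenthesized variant is one plausible fix, not necessarily the exact one the PR adopts.

```python
# Demo of the precedence bug: `and` binds tighter than `or`.
# QuantizedTensor here is a stand-in, not transformer_engine's class.
class QuantizedTensor:
    pass

weight_quantizer = None          # no quantizer available
weight = QuantizedTensor()       # weight already quantized
debug = True                     # debug mode on

# Buggy: parses as (A and B) or debug, so it is True even with no quantizer,
# and the subsequent weight_quantizer.set_usage(...) would hit None.
buggy = weight_quantizer is not None and not isinstance(weight, QuantizedTensor) or debug
assert buggy is True

# Parenthesized: debug only relaxes the "already quantized" check; the
# quantizer must still exist before it is configured.
fixed = weight_quantizer is not None and (not isinstance(weight, QuantizedTensor) or debug)
assert fixed is False
```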
12 files reviewed, 1 comment
Add 'or debug' / 'and not debug' conditions to weight quantizer configuration in linear.py, grouped_linear.py, layernorm_linear.py, and layernorm_mlp.py. In debug mode, quantizers are recreated every iteration, so we need to set quantizer states even when weights are already quantized. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Force-pushed from 3a8c6a8 to 863de12.
/te-ci pytorch
Additional Comments (1)
- When FP8 model parameters are enabled and the high-precision tensor is unavailable, a stat that requests a different recipe than the current one cannot be re-quantized. The following guard raises a clear error instead:

```python
if not high_precision_tensor_provided and recipe_from_stat != "" and recipe_from_stat != current_recipe:
    raise ValueError(
        f"Stat {stat} requires a different recipe than the current one, which needs "
        "a high precision tensor to re-quantize. "
        "This is not supported for weight tensors when using fp8 model parameters."
    )
```

With this guard, requesting a cross-recipe stat (e.g., mxfp8_underflows%) on an FP8 weight fails fast with a ValueError instead of crashing later on a `NoneType`.
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Additional Comments (1)
- When fp8 model parameters are used for weights, the high-precision tensor is `None`. Consider adding a guard either here (to skip cross-recipe quantization when the high-precision tensor is `None`) or at stat-validation time, as in the guard suggested above.
/te-ci pytorch
Additional Comments (1)
- Move the wrapping logic to wrap ALL returns:

```python
# After line 1454, before any return
if isinstance(quantizer, DebugQuantizer):
    out = quantizer.wrap_quantized_tensor(out)
```

Or consolidate wrapping at the end before line 1473.
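A minimal sketch of that consolidation, assuming a single helper with several internal paths; the function name and variables are illustrative stand-ins for base.py's workspace logic, not the actual implementation.

```python
# Sketch of the "single exit point" consolidation suggested above.
class DebugQuantizer:
    def wrap_quantized_tensor(self, out):
        return ("debug-wrapped", out)

def get_weight_workspace(cached, quantizer):
    # Multiple internal paths can produce `out` (cache hit, fresh quantize, ...).
    if cached is not None:
        out = cached                      # early-return path from the review
    else:
        out = "freshly-quantized-weight"  # newly constructed workspace path
    # One consolidated wrap before the only return: no path can skip it.
    if isinstance(quantizer, DebugQuantizer):
        out = quantizer.wrap_quantized_tensor(out)
    return out

print(get_weight_workspace("cached-weight", DebugQuantizer()))
print(get_weight_workspace(None, DebugQuantizer()))
```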
/te-ci pytorch
…s. (NVIDIA#2141)

* initial code drop
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
* fixes
* Fix weight quantizer logic in debug mode: add 'or debug' / 'and not debug' conditions to weight quantizer configuration in linear.py, grouped_linear.py, layernorm_linear.py, and layernorm_mlp.py. In debug mode, quantizers are recreated every iteration, so we need to set quantizer states even when weights are already quantized.
* Update transformer_engine/debug/pytorch/debug_quantization.py

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Description
Currently, precision debug tools are not supported for FP8 model parameters. This is because all the debug-tool logic lives inside the quantize() function of the DebugQuantizers, which is not called if the weight is already in FP8. Also, for some stats, like the number of underflows, we need both the high-precision tensor and the quantized tensor.
I added a function `DebugQuantizer.wrap_quantized_tensor(QuantizedTensor) -> DebugQuantizedTensor`, which is called for the weight on debug iterations. Debugging of all other tensors works without any change.

I made the `tensor` argument of the `inspect_tensor` call optional; it is `None` for the weight tensor in the case of fp8 model parameters.

If one wants to use LogTensorStats, the quantized tensor is dequantized. For LogFp8TensorStats, the stats which need the high-precision tensor are disabled in this case.
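A minimal sketch of that rule, under the assumption that the quantized weight exposes a dequantize() method; the feature names match the PR, but the function and the set of high-precision stats below are condensed from this description, not the feature classes' real code.

```python
# Condensed sketch: LogTensorStats dequantizes on demand, while
# LogFp8TensorStats drops stats that need the high-precision tensor.
HIGH_PRECISION_STATS = {"underflows%", "overflows%", "mse"}

class FakeQuantizedWeight:
    def dequantize(self):
        return "high-precision copy"  # placeholder for the real dequantize

def resolve_inputs(feature, stats, tensor, quantized):
    if tensor is not None:
        return stats, tensor  # non-weight tensors: unchanged behavior
    if feature == "LogTensorStats":
        return stats, quantized.dequantize()  # dequantize on demand
    if feature == "LogFp8TensorStats":
        # disable stats that require the high-precision reference
        return [s for s in stats if s not in HIGH_PRECISION_STATS], quantized
    raise ValueError(f"unknown feature {feature}")

w = FakeQuantizedWeight()
print(resolve_inputs("LogTensorStats", ["mean", "mse"], None, w))
print(resolve_inputs("LogFp8TensorStats", ["underflows%", "scale_inv"], None, w))
```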
Fixes #2140
Type of change
Checklist: