
[PyTorch Debug] Support precision debug tools for fp8 model parameters.#2141

Merged
ptrendx merged 26 commits into NVIDIA:main from pggPL:nvinspect_fp8_model_weights on Feb 25, 2026

Conversation

@pggPL
Collaborator

@pggPL pggPL commented Sep 1, 2025

Description

Currently, precision debug tools are not supported for FP8 model parameters. This is because all of the debug-tool logic lives inside the quantize() function of the DebugQuantizer, which is not called when the weight is already in FP8. Moreover, some stats, such as the number of underflows, require both the high-precision tensor and the quantized tensor.

I added the function DebugQuantizer.wrap_quantized_tensor(QuantizedTensor) -> DebugQuantizedTensor, which is called on the weight during debug iterations. Debugging of all other tensors works without any change.

I made the tensor argument of the inspect_tensor call optional: it is None for the weight tensor when FP8 model parameters are used.
If one wants to use LogTensorStats, the quantized tensor is dequantized first. For LogFp8TensorStats, the stats that need the high-precision tensor are disabled in this case.
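
A minimal sketch of the resulting control flow (illustrative only: wrap_quantized_tensor and the tensor=None convention for inspect_tensor come from this PR, while the helper and its arguments are stand-ins, not the actual TransformerEngine internals):

# Illustrative sketch, not the actual TransformerEngine code path.
def get_weight_for_debug_iteration(weight, debug_quantizer, quantized_tensor_cls):
    """Return a debug-wrapped weight, whichever format it is stored in."""
    if isinstance(weight, quantized_tensor_cls):
        # FP8 model parameters: the weight is already quantized, so
        # quantize() never runs; wrap it so the debug hooks still fire.
        return debug_quantizer.wrap_quantized_tensor(weight)
    # High-precision weight: the existing quantize() path already
    # contains all of the debug logic.
    return debug_quantizer(weight)

# Downstream, stats features receive tensor=None for such weights:
# LogTensorStats dequantizes the FP8 weight before computing stats, and
# LogFp8TensorStats disables stats that need the high-precision tensor.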

Fixes #2140

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 9 commits September 1, 2025 10:17
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL marked this pull request as ready for review September 1, 2025 13:06
@pggPL
Collaborator Author

pggPL commented Sep 1, 2025

/te-ci pytorch

pre-commit-ci Bot and others added 3 commits September 1, 2025 13:06
@pggPL
Collaborator Author

pggPL commented Sep 15, 2025

/te-ci pytorch

1 similar comment
@pggPL
Collaborator Author

pggPL commented Sep 15, 2025

/te-ci pytorch

Comment thread transformer_engine/debug/features/log_fp8_tensor_stats.py Outdated
pggPL and others added 3 commits October 24, 2025 13:51
Contributor

@greptile-apps greptile-apps Bot left a comment


Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The PR extends TransformerEngine's nvidia-dlfw-inspect debugging tools to support FP8 model parameters (when fp8_primary_weight is enabled). Previously, the debug infrastructure lived entirely inside the quantize() method, which wasn't called for FP8 weights stored in quantized format. The solution introduces DebugQuantizer.wrap_quantized_tensor() as a new entry point that wraps pre-quantized weights in debug functionality, and makes the tensor parameter optional in the inspect_tensor API (passing None for FP8 weights). When high-precision tensors are unavailable, LogTensorStats dequantizes on demand while LogFp8TensorStats disables statistics requiring high precision (underflows%, overflows%, mse). The implementation preserves existing behavior for non-weight tensors (activations, gradients) and removes three test skips to validate the feature end to end.

Important Files Changed

Filename Score Overview
transformer_engine/debug/pytorch/debug_quantization.py 4/5 Added wrap_quantized_tensor() method to inject debug logic for pre-quantized FP8 weights with assertions preventing weight modification
transformer_engine/debug/features/api.py 4/5 Made tensor parameter optional in inspect_tensor API and added backward compatibility logic to filter kwargs for legacy feature implementations
transformer_engine/debug/features/log_fp8_tensor_stats.py 3/5 Added validation to reject stats requiring high-precision tensors when only quantized weights available, with potential issue in cross-recipe stat path
transformer_engine/debug/features/log_tensor_stats.py 4/5 Added dequantization path when tensor=None to compute stats from quantized FP8 weights, introduces performance overhead
transformer_engine/debug/features/utils/stats_buffer.py 3/5 Allows None tensors in stats buffer for FP8 weights, but stat computation functions may not handle None safely
transformer_engine/pytorch/module/base.py 4/5 Removed FP8 debug blocker and added wrap_quantized_tensor() call, but wrapping only occurs on early-return path causing potential inconsistency
tests/pytorch/test_sanity.py 5/5 Removed three test skips that blocked FP8 parameters in debug mode, enabling validation of the new feature
tests/pytorch/debug/test_log.py 4/5 Added integration test for FP8 model parameter debug logging with explicit state reset to prevent test pollution

Confidence score: 3/5

  • This PR requires careful review due to the complexity of threading optional tensor handling through multiple layers of the debug infrastructure
  • Score lowered because: (1) the wrapping logic in base.py only applies to one code path (early-return) but not to newly constructed workspaces around lines 1424-1443, potentially causing inconsistent debug behavior, (2) stat computation functions in stats_buffer.py may dereference None tensors without null checks causing runtime errors, (3) the cross-recipe stat path in log_fp8_tensor_stats.py might attempt to requantize a None tensor, and (4) the dequantization in log_tensor_stats.py adds undocumented performance overhead
  • Pay close attention to transformer_engine/debug/features/utils/stats_buffer.py, transformer_engine/debug/features/log_fp8_tensor_stats.py, and transformer_engine/pytorch/module/base.py

Additional Comments (2)

  1. transformer_engine/debug/features/utils/stats_buffer.py, line 108-110 (link)

    logic: When tensor is None, every stat function fn(tensor, aux_dict) receives None as its first argument. Verify that all stat functions in STATS handle None tensors correctly, or add explicit guards here to prevent runtime errors (see the sketch after this list).

  2. transformer_engine/debug/features/log_fp8_tensor_stats.py, line 347 (link)

    logic: When tensor is None (the FP8 weight case), passing None as original_tensor to update_aux_dict may cause issues at line 249, where quantizer(original_tensor) is called if a different recipe is requested. What happens when a user requests stats for a different recipe (e.g., mxfp8_underflows%) on FP8 weight tensors where tensor=None? Line 249 will fail with a NoneType error.
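
One way to guard the first issue, sketched in Python under the assumptions stated above (the STATS mapping and the fn(tensor, aux_dict) signature come from the review text; the set of stats that need the high-precision tensor is an assumption based on the PR description):

# Hypothetical guard in the stats-buffer loop; a sketch, not the actual
# TransformerEngine source.
STATS_NEEDING_HIGH_PRECISION = {"underflows%", "overflows%", "mse"}  # assumed set

def compute_stats(tensor, aux_dict, stats):
    """Compute requested stats, skipping those that need the original
    high-precision tensor when only the quantized weight is available."""
    results = {}
    for stat_name, fn in stats.items():
        if tensor is None and stat_name in STATS_NEEDING_HIGH_PRECISION:
            continue  # cannot compute this stat without the high-precision tensor
        results[stat_name] = fn(tensor, aux_dict)
    return results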

8 files reviewed, 7 comments


Comment thread transformer_engine/debug/features/log_tensor_stats.py
Comment thread transformer_engine/debug/features/utils/stats_buffer.py
Comment thread transformer_engine/pytorch/module/base.py
Comment thread transformer_engine/debug/pytorch/debug_quantization.py
Comment thread transformer_engine/debug/features/log_fp8_tensor_stats.py
pggPL added 2 commits October 31, 2025 10:09
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Collaborator Author

pggPL commented Nov 6, 2025

/te-ci pytorch

@greptile-apps
Contributor

greptile-apps Bot commented Nov 18, 2025

Greptile Summary

This PR enables debug tools for FP8 model parameters by adding wrap_quantized_tensor() to wrap pre-quantized weights and making the tensor parameter optional in inspect_tensor(). The implementation handles two cases: LogTensorStats dequantizes FP8 weights to compute stats, while LogFp8TensorStats restricts available stats to scale_inv_min and scale_inv_max when the high-precision tensor is unavailable.

Critical issue found:

  • transformer_engine/pytorch/module/base.py:1454-1461 - Newly created quantized tensors (line 1454) and cached tensors (line 1473) are never wrapped with wrap_quantized_tensor(). Only the early-return path (line 1395) wraps cached FP8 model parameters. This breaks the feature for non-cached weight tensors.

Minor issues:

  • transformer_engine/debug/features/log_tensor_stats.py:193 - Missing assertion message
  • Previous review comments about STANDARD_FP8_QUANTIZE being undefined and operator precedence bugs are incorrect - the code uses the correct STANDARD_QUANTIZE constant and has proper parentheses

Confidence Score: 2/5

  • Critical logic bug in base.py breaks FP8 debug feature for non-cached tensors
  • The wrapping logic at line 1395 only applies to already-quantized tensors in the early return path, but newly created quantized tensors (line 1454) and cached tensors (line 1473) bypass wrapping entirely. This means the core feature will not work correctly when weights are quantized on-the-fly rather than being pre-quantized. While tests pass, they may only exercise the early-return path with pre-quantized weights.
  • transformer_engine/pytorch/module/base.py requires immediate attention - wrapping logic must be fixed to cover all return paths

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/base.py Adds wrap_quantized_tensor call for FP8 model params, but only wraps cached tensors (line 1395), not newly created ones (line 1454) - critical logic bug
transformer_engine/debug/pytorch/debug_quantization.py Adds wrap_quantized_tensor method to wrap FP8 model parameters for debug support; uses correct constant STANDARD_QUANTIZE
transformer_engine/debug/features/log_tensor_stats.py Handles None tensor by dequantizing FP8 weight; assertion at line 193 missing descriptive message
transformer_engine/debug/features/utils/stats_buffer.py Adds None check for tensor before empty-tensor validation; assumes aux_dict contains valid tensors when tensor is None

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[FP8 Model Parameters] --> B{Debug Mode Enabled?}
    B -->|No| C[Standard FP8 Flow]
    B -->|Yes| D[Create DebugQuantizer]
    D --> E{Weight Already Quantized?}
    E -->|Yes| F[wrap_quantized_tensor]
    E -->|No| G[Quantize Weight]
    G --> H{Is DebugQuantizer?}
    H -->|Yes| I[Should wrap but doesn't]
    H -->|No| J[Return quantized tensor]
    F --> K[DebugQuantizedTensor]
    I --> L[Bug: Missing wrap]
    K --> M[inspect_tensor with tensor=None]
    M --> N{Feature Type?}
    N -->|LogTensorStats| O[Dequantize for stats]
    N -->|LogFp8TensorStats| P[Use quantized tensor directly]
    O --> Q[Compute stats]
    P --> R[Compute scale_inv stats only]

Last reviewed commit: e59091f

Contributor

@greptile-apps greptile-apps Bot left a comment


8 files reviewed, no comments


@pggPL
Collaborator Author

pggPL commented Nov 18, 2025

/te-ci pytorch

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Contributor

@greptile-apps greptile-apps Bot left a comment


9 files reviewed, no comments


Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Contributor

@greptile-apps greptile-apps Bot left a comment


9 files reviewed, 2 comments


Comment thread transformer_engine/pytorch/module/linear.py Outdated
Comment thread transformer_engine/pytorch/module/base.py
Contributor

@greptile-apps greptile-apps Bot left a comment


Additional Comments (1)

  1. transformer_engine/pytorch/module/linear.py, line 257 (link)

    logic: operator precedence bug: and binds tighter than or, so this evaluates as (weight_quantizer is not None and not isinstance(weight, QuantizedTensor)) or debug. When weight_quantizer is None and debug is True, the condition is True, causing line 264 to call weight_quantizer.set_usage(...) on None and raising an AttributeError (see the sketch below).
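
A minimal Python illustration of the precedence issue and the parenthesized fix (a sketch only; quantized_cls and the set_usage call stand in for the real types and the call at line 264):

# Buggy: 'and' binds tighter than 'or', so 'or debug' can make the whole
# condition True even when weight_quantizer is None.
def configure_quantizer_buggy(weight_quantizer, weight, quantized_cls, debug):
    if weight_quantizer is not None and not isinstance(weight, quantized_cls) or debug:
        weight_quantizer.set_usage(rowwise=True)  # AttributeError when quantizer is None

# Fixed: parentheses keep the None check guarding the entire condition.
def configure_quantizer_fixed(weight_quantizer, weight, quantized_cls, debug):
    if weight_quantizer is not None and (not isinstance(weight, quantized_cls) or debug):
        weight_quantizer.set_usage(rowwise=True)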

12 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Add 'or debug' / 'and not debug' conditions to weight quantizer configuration
in linear.py, grouped_linear.py, layernorm_linear.py, and layernorm_mlp.py.

In debug mode, quantizers are recreated every iteration, so we need to
set quantizer states even when weights are already quantized.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the nvinspect_fp8_model_weights branch from 3a8c6a8 to 863de12 Compare December 11, 2025 12:53
Contributor

@greptile-apps greptile-apps Bot left a comment


12 files reviewed, no comments


@pggPL
Collaborator Author

pggPL commented Dec 12, 2025

/te-ci pytorch

Contributor

@greptile-apps greptile-apps Bot left a comment


12 files reviewed, 2 comments


Comment thread transformer_engine/debug/pytorch/debug_quantization.py Outdated
@greptile-apps
Contributor

greptile-apps Bot commented Feb 16, 2026

Additional Comments (1)

transformer_engine/debug/features/log_fp8_tensor_stats.py
None tensor passed to update_aux_dict can cause crash with cross-recipe stats

When tensor is None (FP8 model parameters weight case), it is passed as original_tensor to update_aux_dict. Inside that method (line 271), if any stat uses a different recipe than the current one (e.g., user requests mxfp8_scale_inv_min during fp8_delayed_scaling training), quantizer(original_tensor) will be called with None, causing a runtime error.

The check_if_stat_is_supported validation above only blocks underflows%, overflows%, and mse — it does not prevent cross-recipe scale_inv_min/scale_inv_max stats. Consider either:

  1. Blocking all cross-recipe stats when the high-precision tensor is not available, or
  2. Guarding against original_tensor is None inside update_aux_dict (a sketch of this option follows below).
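
A sketch of option 2 in Python, using the names from the comment above (update_aux_dict, quantizer, and the cross-recipe call at line 271); the signature, recipe-prefix parsing, and recipe names are assumptions:

# Hypothetical None-guard inside update_aux_dict; simplified from the
# review discussion, not the actual TransformerEngine source.
KNOWN_RECIPE_PREFIXES = ("mxfp8", "fp8_delayed_scaling")  # assumed names

def recipe_prefix(stat):
    """Return the recipe prefix of a stat name, e.g. 'mxfp8_scale_inv_min' -> 'mxfp8'."""
    for prefix in KNOWN_RECIPE_PREFIXES:
        if stat.startswith(prefix + "_"):
            return prefix
    return ""

def update_aux_dict(aux_dict, original_tensor, quantizer, stat, current_recipe):
    recipe_from_stat = recipe_prefix(stat)
    if recipe_from_stat and recipe_from_stat != current_recipe:
        if original_tensor is None:
            raise ValueError(
                f"Stat {stat} needs re-quantization with a different recipe, "
                "which requires the high-precision tensor; this is not "
                "available for weights when fp8 model parameters are used."
            )
        aux_dict[recipe_from_stat] = quantizer(original_tensor)
    return aux_dict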

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Contributor

@greptile-apps greptile-apps Bot left a comment


12 files reviewed, 2 comments


Comment thread transformer_engine/debug/pytorch/debug_quantization.py Outdated
@greptile-apps
Contributor

greptile-apps Bot commented Feb 16, 2026

Additional Comments (1)

transformer_engine/debug/features/log_fp8_tensor_stats.py
Cross-recipe stats crash when original_tensor is None

When FP8 model parameters are enabled and tensor is None, the original_tensor passed to update_aux_dict is also None. If a user configures a cross-recipe stat like mxfp8_scale_inv_min while training with fp8_delayed_scaling, line 269 evaluates to True (different recipe), and line 271 calls quantizer(original_tensor) which is quantizer(None) — this will crash.

The check_if_stat_is_supported guard only blocks underflows%, overflows%, and mse when the high-precision tensor is missing, but it allows scale_inv_min/scale_inv_max with a cross-recipe prefix through. Consider also blocking cross-recipe stats when high_precision_tensor_provided is False, since re-quantizing with a different recipe requires the original tensor:

        if not high_precision_tensor_provided and recipe_from_stat != "" and recipe_from_stat != current_recipe:
            raise ValueError(
                f"Stat {stat} requires a different recipe than the current one, which needs "
                "a high precision tensor to re-quantize. "
                "This is not supported for weight tensors when using fp8 model parameters."
            )

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Contributor

@greptile-apps greptile-apps Bot left a comment


12 files reviewed, 1 comment


@greptile-apps
Contributor

greptile-apps Bot commented Feb 16, 2026

Additional Comments (1)

transformer_engine/debug/features/log_fp8_tensor_stats.py
Cross-recipe stats crash when original_tensor is None

When fp8 model parameters are used for weights, original_tensor is None. If a user requests a stat prefixed with a different recipe than the current one (e.g., mxfp8_scale_inv_min when the model uses DelayedScaling), this branch is taken and quantizer(original_tensor) becomes quantizer(None), which will crash.

The check_if_stat_is_supported guard only rejects underflows%, overflows%, and mse when the high-precision tensor is not provided. Stats like scale_inv_min and scale_inv_max pass through even with high_precision_tensor_provided=False, so this crash path is reachable for cross-recipe scale stats.

Consider adding a guard either here (to skip cross-recipe quantization when original_tensor is None) or in check_if_stat_is_supported (to disallow all cross-recipe stats when no high-precision tensor is available).

@pggPL
Collaborator Author

pggPL commented Feb 23, 2026

/te-ci pytorch

Contributor

@greptile-apps greptile-apps Bot left a comment


12 files reviewed, 2 comments


Comment thread transformer_engine/debug/features/log_tensor_stats.py
@greptile-apps
Contributor

greptile-apps Bot commented Feb 23, 2026

Additional Comments (1)

transformer_engine/pytorch/module/base.py
Newly created quantized tensors bypass debug wrapping: the wrap_quantized_tensor call at line 1395 only wraps cached tensors (the early-return path). Tensors created here and returned at line 1461, plus tensors returned at line 1473, are never wrapped.

Move the wrapping logic to wrap ALL returns:

# After line 1454, before any return
if isinstance(quantizer, DebugQuantizer):
    out = quantizer.wrap_quantized_tensor(out)

Or consolidate wrapping at the end before line 1473.

Comment thread transformer_engine/debug/pytorch/debug_quantization.py
@ptrendx
Member

ptrendx commented Feb 24, 2026

/te-ci pytorch

@ptrendx ptrendx merged commit 7222d87 into NVIDIA:main Feb 25, 2026
10 of 14 checks passed
Oleg-Goncharov pushed a commit to Oleg-Goncharov/TransformerEngine that referenced this pull request Feb 27, 2026
…s. (NVIDIA#2141)

* initial code drop

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Fix weight quantizer logic in debug mode

Add 'or debug' / 'and not debug' conditions to weight quantizer configuration
in linear.py, grouped_linear.py, layernorm_linear.py, and layernorm_mlp.py.

In debug mode, quantizers are recreated every iteration, so we need to
set quantizer states even when weights are already quantized.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/debug/pytorch/debug_quantization.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>


Development

Successfully merging this pull request may close these issues.

[REQUEST] Support using nvidia-dlfw-inspect when the fp8_primary_weight option is enabled.
