Merged

Commits (26)
10fdc46  initial code drop (pggPL, Sep 1, 2025)
2e5debb  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 1, 2025)
11b4c26  fixes (pggPL, Sep 1, 2025)
8a8eed3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 1, 2025)
ecdc727  fix (pggPL, Sep 1, 2025)
b0162af  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 1, 2025)
20515c7  fixes (pggPL, Sep 1, 2025)
26926cc  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 1, 2025)
39d500a  fix (pggPL, Sep 1, 2025)
206cbc6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 1, 2025)
82e3f2c  fix (pggPL, Sep 15, 2025)
ef7ef41  Merge branch 'main' into nvinspect_fp8_model_weights (pggPL, Sep 15, 2025)
c986708  fix (pggPL, Oct 24, 2025)
510822e  Merge branch 'main' into nvinspect_fp8_model_weights (pggPL, Oct 24, 2025)
3740ff6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 24, 2025)
1aa421c  fix (pggPL, Oct 31, 2025)
e0ba7e7  fix (pggPL, Nov 4, 2025)
de379b9  Merge branch 'main' into nvinspect_fp8_model_weights (pggPL, Nov 18, 2025)
a3ec90d  fix (pggPL, Nov 18, 2025)
91e0332  fix (pggPL, Nov 18, 2025)
c5e047f  Merge remote-tracking branch 'upstream/main' into nvinspect_fp8_model… (pggPL, Dec 11, 2025)
863de12  Fix weight quantizer logic in debug mode (pggPL, Dec 11, 2025)
304b77a  Merge branch 'main' into nvinspect_fp8_model_weights (pggPL, Feb 16, 2026)
ae85aef  fix (pggPL, Feb 16, 2026)
fca78e9  Update transformer_engine/debug/pytorch/debug_quantization.py (pggPL, Feb 16, 2026)
e59091f  Merge branch 'main' into nvinspect_fp8_model_weights (pggPL, Feb 23, 2026)
52 changes: 52 additions & 0 deletions tests/pytorch/debug/test_log.py
@@ -151,6 +151,58 @@ def test_sanity(feature_dirs):
assert stat in output, f"Stat {stat} not found in output"


LOG_FP8_MODEL_PARAMETERS_CONFIG_BASE = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogTensorStats:
enabled:
True
stats: [min]
tensors: [weight, activation, gradient]
freq: 1
LogFp8TensorStats:
enabled:
True
tensors_struct:
- tensor: activation
stats: [scale_inv_min, scale_inv_max, underflows%]
- tensor: weight
stats: [scale_inv_min, scale_inv_max]
freq: 1
"""


def test_sanity_log_fp8_model_parameters(feature_dirs):
"""
Tests logging stats when model parameters are in fp8.
It tests three things:
- LogTensorStats for the weight tensor should work unchanged,
- LogTensorStats and LogFp8TensorStats for non-weight tensors should work unchanged,
- LogFp8TensorStats should support scale_inv_min and scale_inv_max for the weight tensor.

"""
if not fp8_available:
pytest.skip(reason_for_no_fp8)

with debug_session(LOG_FP8_MODEL_PARAMETERS_CONFIG_BASE, feature_dirs) as log_dir:
with te.fp8_model_init(recipe=recipe.DelayedScaling()):
model = te.Linear(128, 128, params_dtype=torch.bfloat16)
inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
for _ in range(10):
with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
output = model(inp)
loss = output.sum()
loss.backward()
debug_api.step()
output = read_log(log_dir)
assert output, "Output is empty"
TEDebugState._reset()


fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
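To make the documented limit concrete: a config that stays within what this PR supports for fp8 weights requests only scale-factor stats for `weight`, while other tensors keep the full stat set. A hypothetical sketch (the name `WEIGHT_SAFE_CONFIG` and the exact stat selection are illustrative, not part of this PR), following the same YAML schema as the test config above:

```python
# Hypothetical config sketch, not part of this PR: with fp8 model parameters,
# the weight entry is limited to scale_inv_min/scale_inv_max, while other
# tensors may still request high-precision stats such as underflows%.
WEIGHT_SAFE_CONFIG = """
log:
  layers:
    layer_name_regex_pattern: .*
  enabled:
    True
  transformer_engine:
    LogFp8TensorStats:
      enabled:
        True
      tensors_struct:
        - tensor: weight
          stats: [scale_inv_min, scale_inv_max]
        - tensor: gradient
          stats: [underflows%, scale_inv_min]
      freq: 1
"""
```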
6 changes: 0 additions & 6 deletions tests/pytorch/test_sanity.py
@@ -551,8 +551,6 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
pytest.skip("Quantized model parameters are not supported in debug mode.")
config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size
num_tokens = bs * config.max_seqlen_q
@@ -599,8 +597,6 @@ def test_sanity_grouped_linear(
num_gemms,
empty_split,
):
if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
pytest.skip("FP8 model parameters are not supported in debug mode.")
config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size
# Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
@@ -1222,8 +1218,6 @@ def test_inference_mode(
quantization: Optional[str],
) -> None:
"""Test heuristics for initializing quantized weights"""
if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
pytest.skip("Quantized model parameters are not supported in debug mode.")

# Tensor dimensions
sequence_length = 32
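The three skips removed here all guarded the same combination, which this PR now enables. A minimal sketch of that combination, mirroring the new test in test_log.py above:

```python
# Minimal sketch of the now-supported combination: fp8 model parameters
# together with debug mode (the NVTE_TEST_NVINSPECT_ENABLED path).
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

with te.fp8_model_init(recipe=recipe.DelayedScaling()):
    model = te.Linear(128, 128, params_dtype=torch.bfloat16)  # weights stored in fp8
inp = torch.zeros(128, 128, dtype=torch.bfloat16, device="cuda")
with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
    out = model(inp)
```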
6 changes: 3 additions & 3 deletions transformer_engine/debug/features/api.py
@@ -244,7 +244,7 @@ def inspect_tensor(
config: Dict,
layer_name: str,
tensor_name: str,
tensor: torch.Tensor,
tensor: Optional[torch.Tensor],
rowwise_quantized_tensor: Optional[torch.Tensor],
columnwise_quantized_tensor: Optional[torch.Tensor],
quantizer: Optional[Quantizer],
@@ -262,8 +262,8 @@
layer_name: str
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
tensor: torch.Tensor
tensor in high precision,
tensor: Optional[torch.Tensor]
tensor in high precision. It can be None only when fp8 model parameters are used and the tensor name is `weight`.
rowwise_quantized_tensor: Optional[torch.Tensor]
rowwise quantized tensor,
columnwise_quantized_tensor: Optional[torch.Tensor]
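For feature authors, honoring the new Optional contract looks roughly like the skeleton below; the class name and trailing parameters are illustrative, not from this PR.

```python
# Illustrative feature skeleton, not from this PR: `tensor` may now be None,
# which happens only for `weight` when fp8 model parameters are used.
from typing import Optional
import torch

class MyFeature:
    def inspect_tensor(self, config, layer_name, tensor_name,
                       tensor: Optional[torch.Tensor],
                       rowwise_quantized_tensor=None,
                       columnwise_quantized_tensor=None,
                       quantizer=None, **kwargs):
        if tensor is None:
            # Recover a high precision view, as LogTensorStats does below.
            tensor = rowwise_quantized_tensor.dequantize()
        # ... compute and log stats on `tensor` ...
```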
25 changes: 22 additions & 3 deletions transformer_engine/debug/features/log_fp8_tensor_stats.py
@@ -122,6 +122,10 @@ class LogFp8TensorStats(BaseLogTensorStats):
- scale_inv_max - maximum of the inverse of the scaling factors,
- mse - mean squared error of the quantized tensor and the original tensor = sum((quantized_tensor - original_tensor)**2) / num_elements,

When collecting stats for the weight tensor with FP8 model parameters enabled,
only "scale_inv_min" and "scale_inv_max" are available.
All other statistics require access to the high precision tensor.

tensors/tensors_struct: List[str]
list of tensors to log
- activation,
@@ -159,14 +163,27 @@ class LogFp8TensorStats(BaseLogTensorStats):
end_step: 80
"""

def check_if_stat_is_supported(self, stat: str, current_recipe: str):
def check_if_stat_is_supported(
self, stat: str, current_recipe: str, high_precision_tensor_provided: bool
):
"""Returns True if stat is supported, raises ValueError otherwise."""
columnwise = stat.endswith("_columnwise")
if columnwise:
stat = stat[: -len("_columnwise")]
recipe_from_stat, _ = self.get_recipe_from_stat(stat, default_recipe=current_recipe)
stat_without_recipe = stat.replace(recipe_from_stat + "_", "")

need_high_precision_tensor_stats = ["underflows%", "overflows%", "mse"]
if (
stat_without_recipe in need_high_precision_tensor_stats
and not high_precision_tensor_provided
):
raise ValueError(
f"Stat {stat} requires a high precision tensor to be provided. "
"This feature is not supported for weight tensors when using fp8 model "
"parameters."
)

if current_recipe == "" and recipe_from_stat == "":
raise ValueError(
f"Stat {stat} does not contain a recipe name and the current recipe is not set."
@@ -290,7 +307,7 @@ def inspect_tensor(
tensor_name: str,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
tensor: torch.Tensor,
tensor: Optional[torch.Tensor],
rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
@@ -322,7 +339,9 @@
recipe_name = _get_recipe_name(quantizer)

for stat in config["stats"]:
self.check_if_stat_is_supported(stat, recipe_name)
self.check_if_stat_is_supported(
stat, recipe_name, high_precision_tensor_provided=tensor is not None
)

start_step = config.get("start_step", None)
end_step = config.get("end_step", None)
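A standalone restatement of the validation rule added above may help; this sketch is simplified (the real code resolves the recipe via get_recipe_from_stat) and the helper name is hypothetical.

```python
# Simplified sketch of the rule: a stat requires the high precision tensor
# iff, after stripping an optional "_columnwise" suffix and the recipe
# prefix, it is one of underflows%, overflows%, or mse.
NEED_HIGH_PRECISION = {"underflows%", "overflows%", "mse"}

def needs_high_precision(stat: str, recipe_name: str) -> bool:
    if stat.endswith("_columnwise"):
        stat = stat[: -len("_columnwise")]
    if recipe_name:
        stat = stat.replace(recipe_name + "_", "")
    return stat in NEED_HIGH_PRECISION

assert needs_high_precision("mxfp8_underflows%", "mxfp8")      # rejected for fp8 weights
assert not needs_high_precision("scale_inv_min", "mxfp8")      # still allowed
```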
8 changes: 7 additions & 1 deletion transformer_engine/debug/features/log_tensor_stats.py
@@ -180,13 +180,19 @@ def inspect_tensor(
tensor_name: str,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
tensor: torch.Tensor,
tensor: Optional[torch.Tensor],
rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
): # pylint: disable=unused-argument
"""API call used to collect the data about the tensor before process_tensor()/quantization."""

# Tensor is None only when fp8 model parameters are used and the tensor name is `weight`.
# To collect stats for this tensor, we need to dequantize it.
if tensor is None:
assert isinstance(rowwise_quantized_tensor, QuantizedTensor)
tensor = rowwise_quantized_tensor.dequantize()

assert (
type(tensor) not in [Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage]
and tensor.dtype != torch.uint8
19 changes: 13 additions & 6 deletions transformer_engine/debug/features/utils/stats_buffer.py
@@ -90,12 +90,19 @@ def feed(self, tensor, iteration, aux_dict=None):
if self.modified[0] and not self.reduce_within_microbatch:
return

if (
tensor.numel() == 0
if hasattr(tensor, "numel")
else all((t is None or t.numel() == 0) for t in tensor.get_data_tensors())
):
return
if tensor is not None:
# tensor can be None when fp8 stats are computed for the weight and fp8 model
# parameters are used; the high precision tensor is then not provided and the
# quantized tensor from aux_dict is used instead.

# This condition prevents computing stats for an empty tensor. The empty case
# cannot occur for the weight, and the weight is the only tensor that can be
# None, so no analogous check is needed inside the None branch.
if (
tensor.numel() == 0
if hasattr(tensor, "numel")
else all((t is None or t.numel() == 0) for t in tensor.get_data_tensors())
):
return

# save stats for tensor to tmp buffer
for stat_name in self.stats_to_compute:
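The reshaped guard reads as a small predicate; a sketch with a hypothetical helper name:

```python
# Hypothetical helper equivalent to the guard above: `tensor` is either a plain
# torch.Tensor or a quantized storage exposing get_data_tensors().
def _is_empty(tensor) -> bool:
    if hasattr(tensor, "numel"):
        return tensor.numel() == 0
    return all(t is None or t.numel() == 0 for t in tensor.get_data_tensors())
```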
26 changes: 25 additions & 1 deletion transformer_engine/debug/pytorch/debug_quantization.py
@@ -267,7 +267,7 @@ def _call_inspect_tensor_api(
"rowwise_quantized_tensor": rowwise_gemm_tensor,
"quantizer": self.parent_quantizer,
}
if tensor is not None and self.inspect_tensor_enabled:
if self.inspect_tensor_enabled:
debug_api.transformer_engine.inspect_tensor(**args)

if self.output_tensor:
@@ -559,6 +559,30 @@ def set_usage(self, rowwise: bool = None, columnwise: bool = None):
if not self.output_tensor:
self._update_parent_quantizer_usage()

def wrap_quantized_tensor(self, tensor: QuantizedTensor):
"""
Wraps the quantized tensor with the debug quantizer.
It is used for weight tensors when fp8 model parameters are enabled.
"""

assert (
self.rowwise_tensor_plan == STANDARD_QUANTIZE
and self.columnwise_tensor_plan == STANDARD_QUANTIZE
), (
"[NVTORCH INSPECT ERROR] Weight tensor with fp8 model parameters enabled cannot be"
" modified by any feature."
)

self._call_inspect_tensor_api(None, tensor, tensor)

return DebugQuantizedTensor(
rowwise_gemm_tensor=tensor,
columnwise_gemm_tensor=tensor,
quantizer=self,
layer_name=self.layer_name,
tensor_name=self.tensor_name,
)

@classmethod
def multi_tensor_quantize(
cls,
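Taken together with the base.py change below, the intended call path is roughly the following; the helper name and the final else branch are illustrative, not from this PR.

```python
# Rough, illustrative call path for fp8 model weights under debug mode: the
# already-quantized weight is wrapped rather than re-quantized, and the
# inspect hook runs with tensor=None (no high precision copy exists).
def get_debug_weight(weight, quantizer):
    if isinstance(quantizer, DebugQuantizer) and isinstance(weight, QuantizedTensor):
        return quantizer.wrap_quantized_tensor(weight)
    return quantizer(weight)  # ordinary quantization path
```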
6 changes: 4 additions & 2 deletions transformer_engine/pytorch/module/base.py
@@ -1390,6 +1390,10 @@ def get_weight_workspace(
rowwise_usage=update_rowwise_usage,
columnwise_usage=update_columnwise_usage,
)

if isinstance(quantizer, DebugQuantizer):
tensor = quantizer.wrap_quantized_tensor(tensor)
return tensor

# Try getting workspace from cache
Expand Down Expand Up @@ -1585,8 +1589,6 @@ def no_debug_features_active(self, quantizers):
if not run_current:
return True

if self.primary_weights_in_fp8:
raise RuntimeError("FP8 weights are not supported in debug mode.")
return False

def _validate_name(self):
8 changes: 3 additions & 5 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -123,8 +123,9 @@ def forward(
and not in_fp8_activation_recompute_phase()
)
# No need to set the quantizer states if weight is already quantized
if weight_quantizers[0] is not None and not isinstance(
weights[0], QuantizedTensorStorage
# In debug mode we create a quantizer every iteration, so we need to set the quantizer states
if weight_quantizers[0] is not None and (
not isinstance(weights[0], QuantizedTensorStorage) or debug
):
for weight_quantizer in weight_quantizers:
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
@@ -874,9 +875,6 @@ def forward(
debug = False
quantizers = self._get_quantizers()

if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")

(
input_quantizers,
weight_quantizers,
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
@@ -293,7 +293,8 @@ def forward(

# Configure quantizer
# If weight is already quantized, no need to set quantizer states
if is_weight_param_quantized:
# In debug mode we create a quantizer every iteration, so we need to set the quantizer states
if is_weight_param_quantized and not debug:
weight_quantizer = weight._quantizer
elif weight_quantizer is not None:
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
5 changes: 3 additions & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -473,12 +473,13 @@ def _forward(
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
# No need to set the quantizer states if weights are already quantized
if isinstance(fc1_weight, QuantizedTensorStorage):
# In debug mode we create a quantizer every iteration, so we need to set the quantizer states
if isinstance(fc1_weight, QuantizedTensorStorage) and not debug:
fc1_weight_quantizer = fc1_weight._quantizer
elif fc1_weight_quantizer is not None:
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)

if isinstance(fc2_weight, QuantizedTensorStorage):
if isinstance(fc2_weight, QuantizedTensorStorage) and not debug:
fc2_weight_quantizer = fc2_weight._quantizer
elif fc2_weight_quantizer is not None:
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/linear.py
@@ -253,7 +253,8 @@ def forward(
if fp8 or debug:
# Configure quantizer
# No need to set the quantizer states if weight is already quantized
if weight_quantizer is not None and not isinstance(weight, QuantizedTensor):
# In debug mode we create a quantizer every iteration, so we need to set the quantizer states
if weight_quantizer is not None and (not isinstance(weight, QuantizedTensor) or debug):
columnwise_usage = is_grad_enabled and inp.requires_grad
if not columnwise_usage:
columnwise_usage = (
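The same idea recurs across grouped_linear.py, layernorm_linear.py, layernorm_mlp.py, and linear.py; in linear.py it takes this shape (a condensed restatement of the hunk above, not new logic):

```python
# Condensed form of the repeated condition: (re)set quantizer usage when the
# weight is not yet quantized, or always in debug mode, because debug
# quantizers are recreated every iteration.
needs_quantizer_setup = weight_quantizer is not None and (
    not isinstance(weight, QuantizedTensor) or debug
)
```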