diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py
index 5d6fc41ac7..b16291ff61 100644
--- a/tests/pytorch/debug/test_log.py
+++ b/tests/pytorch/debug/test_log.py
@@ -151,6 +151,58 @@ def test_sanity(feature_dirs):
         assert stat in output, f"Stat {stat} not found in output"


+LOG_FP8_MODEL_PARAMETERS_CONFIG_BASE = """
+log:
+  layers:
+    layer_name_regex_pattern: .*
+  enabled:
+    True
+  transformer_engine:
+    LogTensorStats:
+      enabled:
+        True
+      stats: [min]
+      tensors: [weight, activation, gradient]
+      freq: 1
+    LogFp8TensorStats:
+      enabled:
+        True
+      tensors_struct:
+        - tensor: activation
+          stats: [scale_inv_min, scale_inv_max, underflows%]
+        - tensor: weight
+          stats: [scale_inv_min, scale_inv_max]
+      freq: 1
+"""
+
+
+def test_sanity_log_fp8_model_parameters(feature_dirs):
+    """
+    Tests logging stats when model parameters are in fp8.
+    It tests 3 things:
+    - LogTensorStats for weight tensor should work without change,
+    - LogTensorStats and LogFp8TensorStats for non-weight tensors should work without change,
+    - LogFp8TensorStats should support scale_inv_min, scale_inv_max for weight tensor.
+
+    """
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+
+    with debug_session(LOG_FP8_MODEL_PARAMETERS_CONFIG_BASE, feature_dirs) as log_dir:
+        with te.fp8_model_init(recipe=recipe.DelayedScaling()):
+            model = te.Linear(128, 128, params_dtype=torch.bfloat16)
+        inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
+        for _ in range(10):
+            with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
+                output = model(inp)
+            loss = output.sum()
+            loss.backward()
+            debug_api.step()
+        output = read_log(log_dir)
+
+    assert output, "Output is empty"
+    TEDebugState._reset()
+
+
 fp8_recipes = [
     recipe.MXFP8BlockScaling(),
     recipe.DelayedScaling(),
diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index d47bc553b0..3ef8c0983f 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -551,8 +551,6 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("use_bias", all_boolean)
 def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
-    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
-        pytest.skip("Quantized model parameters are not supported in debug mode.")
     config = model_configs[model]
     ffn_hidden_size = 4 * config.hidden_size
     num_tokens = bs * config.max_seqlen_q
@@ -599,8 +597,6 @@ def test_sanity_grouped_linear(
     num_gemms,
     empty_split,
 ):
-    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
-        pytest.skip("FP8 model parameters are not supported in debug mode.")
     config = model_configs[model]
     ffn_hidden_size = 4 * config.hidden_size
     # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
@@ -1222,8 +1218,6 @@ def test_inference_mode(
     quantization: Optional[str],
 ) -> None:
     """Test heuristics for initializing quantized weights"""
-    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
-        pytest.skip("Quantized model parameters are not supported in debug mode.")

     # Tensor dimensions
     sequence_length = 32
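The new test above drives the whole path end to end through the debug_session/read_log test helpers. Outside the test harness, enabling the same logging would look roughly like the sketch below. It reuses the config shown in LOG_FP8_MODEL_PARAMETERS_CONFIG_BASE and assumes the nvdlfw_inspect entry points (debug_api.initialize, debug_api.end_debug) and a ./config.yaml path, so treat it as an illustration of intended usage rather than part of the patch.

    # Illustrative usage sketch (not part of the patch): FP8 model parameters plus stat logging.
    import torch
    import transformer_engine.pytorch as te
    import nvdlfw_inspect.api as debug_api
    from transformer_engine.common import recipe

    # config.yaml is assumed to hold the YAML shown in LOG_FP8_MODEL_PARAMETERS_CONFIG_BASE above;
    # the feature_dirs path is a placeholder.
    debug_api.initialize(
        config_file="./config.yaml",
        feature_dirs=["/path/to/transformer_engine/debug/features"],
    )

    # Weights are created directly in FP8, so no high-precision copy of the weight exists.
    with te.fp8_model_init(recipe=recipe.DelayedScaling()):
        model = te.Linear(128, 128, params_dtype=torch.bfloat16)

    inp = torch.zeros(128, 128, dtype=torch.bfloat16, device="cuda")
    for _ in range(10):
        with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
            out = model(inp)
        out.sum().backward()
        debug_api.step()  # advance the logging iteration so `freq: 1` stats are emitted

    debug_api.end_debug()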
diff --git a/transformer_engine/debug/features/api.py b/transformer_engine/debug/features/api.py
index 9c30f87c3b..774fae3594 100644
--- a/transformer_engine/debug/features/api.py
+++ b/transformer_engine/debug/features/api.py
@@ -244,7 +244,7 @@ def inspect_tensor(
         config: Dict,
         layer_name: str,
         tensor_name: str,
-        tensor: torch.Tensor,
+        tensor: Optional[torch.Tensor],
         rowwise_quantized_tensor: Optional[torch.Tensor],
         columnwise_quantized_tensor: Optional[torch.Tensor],
         quantizer: Optional[Quantizer],
@@ -262,8 +262,8 @@ def inspect_tensor(
         layer_name: str
         tensor_name: str
             one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
-        tensor: torch.Tensor
-            tensor in high precision,
+        tensor: Optional[torch.Tensor]
+            tensor in high precision. It can be None only if fp8 model parameters are used and the tensor name is `weight`.
         rowwise_quantized_tensor: Optional[torch.Tensor]
             rowwise quantized tensor,
         columnwise_quantized_tensor: Optional[torch.Tensor]
diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py
index 108b33fd86..fd18d590ec 100644
--- a/transformer_engine/debug/features/log_fp8_tensor_stats.py
+++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py
@@ -122,6 +122,10 @@ class LogFp8TensorStats(BaseLogTensorStats):
             - scale_inv_max - maximum of the inverse of the scaling factors,
             - mse - mean squared error of the quantized tensor and the original tensor
               = sum((quantized_tensor - original_tensor)**2) / num_elements,
+        When collecting stats for the weight tensor with FP8 model parameters enabled,
+        only "scale_inv_min" and "scale_inv_max" are available.
+        All other statistics require access to the high precision tensor.
+
     tensors/tensors_struct: List[str]
         list of tensors to log
             - activation,
@@ -159,7 +163,9 @@ class LogFp8TensorStats(BaseLogTensorStats):
            end_step: 80
     """

-    def check_if_stat_is_supported(self, stat: str, current_recipe: str):
+    def check_if_stat_is_supported(
+        self, stat: str, current_recipe: str, high_precision_tensor_provided: bool
+    ):
         """Returns True if stat is supported, raises ValueError otherwise."""
         columnwise = stat.endswith("_columnwise")
         if columnwise:
@@ -167,6 +173,17 @@
         recipe_from_stat, _ = self.get_recipe_from_stat(stat, default_recipe=current_recipe)
         stat_without_recipe = stat.replace(recipe_from_stat + "_", "")

+        need_high_precision_tensor_stats = ["underflows%", "overflows%", "mse"]
+        if (
+            stat_without_recipe in need_high_precision_tensor_stats
+            and not high_precision_tensor_provided
+        ):
+            raise ValueError(
+                f"Stat {stat} requires a high precision tensor to be provided. "
+                "This feature is not supported for weight tensors when using fp8 model "
+                "parameters."
+            )
+
         if current_recipe == "" and recipe_from_stat == "":
             raise ValueError(
                 f"Stat {stat} does not contain a recipe name and the current recipe is not set."
@@ -290,7 +307,7 @@ def inspect_tensor(
         tensor_name: str,
         iteration: int,
         tp_group: torch.distributed.ProcessGroup,
-        tensor: torch.Tensor,
+        tensor: Optional[torch.Tensor],
         rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
         columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
         quantizer: Optional[Quantizer] = None,
@@ -322,7 +339,9 @@
         recipe_name = _get_recipe_name(quantizer)

         for stat in config["stats"]:
-            self.check_if_stat_is_supported(stat, recipe_name)
+            self.check_if_stat_is_supported(
+                stat, recipe_name, high_precision_tensor_provided=tensor is not None
+            )

         start_step = config.get("start_step", None)
         end_step = config.get("end_step", None)
diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py
index 100fa64481..76e61fab24 100644
--- a/transformer_engine/debug/features/log_tensor_stats.py
+++ b/transformer_engine/debug/features/log_tensor_stats.py
@@ -180,13 +180,19 @@ def inspect_tensor(
         tensor_name: str,
         iteration: int,
         tp_group: torch.distributed.ProcessGroup,
-        tensor: torch.Tensor,
+        tensor: Optional[torch.Tensor],
         rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
         columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
         quantizer: Optional[Quantizer] = None,
     ):  # pylint: disable=unused-argument
         """API call used to collect the data about the tensor before process_tensor()/quantization."""

+        # Tensor is None only if fp8 model parameters are used and the tensor name is `weight`.
+        # To collect stats for this tensor, we need to dequantize it first.
+        if tensor is None:
+            assert isinstance(rowwise_quantized_tensor, QuantizedTensor)
+            tensor = rowwise_quantized_tensor.dequantize()
+
         assert (
             type(tensor) not in [Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage]
             and tensor.dtype != torch.uint8
diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py
index 9ce56dd76d..ca7f22e2de 100644
--- a/transformer_engine/debug/features/utils/stats_buffer.py
+++ b/transformer_engine/debug/features/utils/stats_buffer.py
@@ -90,12 +90,19 @@ def feed(self, tensor, iteration, aux_dict=None):
         if self.modified[0] and not self.reduce_within_microbatch:
             return

-        if (
-            tensor.numel() == 0
-            if hasattr(tensor, "numel")
-            else all((t is None or t.numel() == 0) for t in tensor.get_data_tensors())
-        ):
-            return
+        if tensor is not None:
+            # tensor can be None if we compute fp8 stats for the weight while fp8 model parameters
+            # are used; then no high precision tensor is provided and the quantized tensor from aux_dict is used.
+
+            # This check skips stats computation for empty tensors. That cannot happen for the
+            # weight, and since the weight is the only tensor that can be None here,
+            # no analogous check is needed in that case.
+            if (
+                tensor.numel() == 0
+                if hasattr(tensor, "numel")
+                else all((t is None or t.numel() == 0) for t in tensor.get_data_tensors())
+            ):
+                return

         # save stats for tensor to tmp buffer
         for stat_name in self.stats_to_compute:
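Taken together, the two files above split the work by stat type: LogTensorStats falls back to a dequantized copy of the FP8 weight, while the feed() method in stats_buffer.py simply accepts tensor=None and lets the quantized tensor from aux_dict drive the FP8-only stats. A simplified sketch of that dispatch, assuming only that the quantized weight is a QuantizedTensor with a dequantize() method as asserted above (illustrative, not the actual implementation):

    # Simplified sketch of how a stat request is served for an FP8-stored weight.
    def resolve_stats_input(tensor, rowwise_quantized_tensor):
        # High-precision stats (min, max, ...): dequantize on demand when no high-precision
        # weight exists, mirroring the fallback added to LogTensorStats.inspect_tensor().
        if tensor is None:
            return rowwise_quantized_tensor.dequantize()
        return tensor

    # FP8 stats: scale_inv_min / scale_inv_max are read from the quantized weight itself, while
    # underflows%, overflows% and mse are rejected earlier by
    # LogFp8TensorStats.check_if_stat_is_supported() because they need the original tensor.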
diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py
index 455079143b..5624970547 100644
--- a/transformer_engine/debug/pytorch/debug_quantization.py
+++ b/transformer_engine/debug/pytorch/debug_quantization.py
@@ -267,7 +267,7 @@ def _call_inspect_tensor_api(
             "rowwise_quantized_tensor": rowwise_gemm_tensor,
             "quantizer": self.parent_quantizer,
         }
-        if tensor is not None and self.inspect_tensor_enabled:
+        if self.inspect_tensor_enabled:
             debug_api.transformer_engine.inspect_tensor(**args)

         if self.output_tensor:
@@ -559,6 +559,30 @@ def set_usage(self, rowwise: bool = None, columnwise: bool = None):
         if not self.output_tensor:
             self._update_parent_quantizer_usage()

+    def wrap_quantized_tensor(self, tensor: QuantizedTensor):
+        """
+        Wraps the quantized tensor with the debug quantizer.
+        It is used for weight tensors when fp8 model parameters are enabled.
+        """
+
+        assert (
+            self.rowwise_tensor_plan == STANDARD_QUANTIZE
+            and self.columnwise_tensor_plan == STANDARD_QUANTIZE
+        ), (
+            "[NVTORCH INSPECT ERROR] Weight tensor with fp8 model parameters enabled cannot be"
+            " modified by any feature."
+        )
+
+        self._call_inspect_tensor_api(None, tensor, tensor)
+
+        return DebugQuantizedTensor(
+            rowwise_gemm_tensor=tensor,
+            columnwise_gemm_tensor=tensor,
+            quantizer=self,
+            layer_name=self.layer_name,
+            tensor_name=self.tensor_name,
+        )
+
     @classmethod
     def multi_tensor_quantize(
         cls,
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 09b12afa21..4858383c26 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -1390,6 +1390,10 @@ def get_weight_workspace(
                     rowwise_usage=update_rowwise_usage,
                     columnwise_usage=update_columnwise_usage,
                 )
+
+            if isinstance(quantizer, DebugQuantizer):
+                tensor = quantizer.wrap_quantized_tensor(tensor)
+
             return tensor

         # Try getting workspace from cache
@@ -1585,8 +1589,6 @@ def no_debug_features_active(self, quantizers):

         if not run_current:
             return True
-
-        if self.primary_weights_in_fp8:
-            raise RuntimeError("FP8 weights are not supported in debug mode.")
         return False

     def _validate_name(self):
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 2f859e748b..b381073d78 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -123,8 +123,9 @@ def forward(
             and not in_fp8_activation_recompute_phase()
         )
         # No need to set the quantizer states if weight is already quantized
-        if weight_quantizers[0] is not None and not isinstance(
-            weights[0], QuantizedTensorStorage
+        # In debug mode a new quantizer is created every iteration, so the quantizer states must be set
+        if weight_quantizers[0] is not None and (
+            not isinstance(weights[0], QuantizedTensorStorage) or debug
         ):
             for weight_quantizer in weight_quantizers:
                 weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
@@ -874,9 +875,6 @@ def forward(
         debug = False

         quantizers = self._get_quantizers()
-        if isinstance(weight_tensors, QuantizedTensorStorage):
-            raise RuntimeError("FP8 weights are not supported in debug mode.")
-
         (
             input_quantizers,
             weight_quantizers,
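On the module side, get_weight_workspace() now hands an already-quantized weight to the DebugQuantizer instead of raising, and the forward() functions in grouped_linear.py above (and in the remaining module files below) re-apply quantizer usage in debug mode even when the weight is already quantized. A rough sketch of the resulting call chain for an FP8-stored weight, simplified from the hunks above (illustrative only; the DebugQuantizer import path is inferred from the file paths in this patch):

    from transformer_engine.debug.pytorch.debug_quantization import DebugQuantizer

    def debug_weight_workspace_sketch(weight_fp8, quantizer):
        # base.py: after the workspace usage flags are updated, the quantized weight is wrapped.
        if isinstance(quantizer, DebugQuantizer):
            # debug_quantization.py: wrap_quantized_tensor() asserts that no feature modifies the
            # weight, calls _call_inspect_tensor_api(None, weight_fp8, weight_fp8) -- so
            # inspect_tensor() receives tensor=None together with the quantized weight -- and
            # returns a DebugQuantizedTensor that reuses the same FP8 data for both GEMM inputs.
            return quantizer.wrap_quantized_tensor(weight_fp8)
        return weight_fp8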
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 702916696b..27632db15b 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -293,7 +293,8 @@ def forward(

         # Configure quantizer
         # If weight is already quantized, no need to set quantizer states
-        if is_weight_param_quantized:
+        # In debug mode a new quantizer is created every iteration, so the quantizer states must be set
+        if is_weight_param_quantized and not debug:
             weight_quantizer = weight._quantizer
         elif weight_quantizer is not None:
             weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 4532ea60e7..b8823e46ca 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -473,12 +473,13 @@ def _forward(
             # FP8 cast to workspace buffer
             update_workspace = is_first_microbatch is None or is_first_microbatch

         # No need to set the quantizer states if weights are already quantized
-        if isinstance(fc1_weight, QuantizedTensorStorage):
+        # In debug mode a new quantizer is created every iteration, so the quantizer states must be set
+        if isinstance(fc1_weight, QuantizedTensorStorage) and not debug:
             fc1_weight_quantizer = fc1_weight._quantizer
         elif fc1_weight_quantizer is not None:
             fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
-        if isinstance(fc2_weight, QuantizedTensorStorage):
+        if isinstance(fc2_weight, QuantizedTensorStorage) and not debug:
             fc2_weight_quantizer = fc2_weight._quantizer
         elif fc2_weight_quantizer is not None:
             fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 23ad8cacb0..a55429d33d 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -253,7 +253,8 @@ def forward(
         if fp8 or debug:
             # Configure quantizer
             # No need to set the quantizer states if weight is already quantized
-            if weight_quantizer is not None and not isinstance(weight, QuantizedTensor):
+            # In debug mode a new quantizer is created every iteration, so the quantizer states must be set
+            if weight_quantizer is not None and (not isinstance(weight, QuantizedTensor) or debug):
                 columnwise_usage = is_grad_enabled and inp.requires_grad
                 if not columnwise_usage:
                     columnwise_usage = (