From 270c356096a726d4cd416c8d6ba77674d3f9c511 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Fri, 3 Nov 2023 16:35:19 +0800
Subject: [PATCH] [Quant] [PT2] Enable Decomposed quant per tensor/channel to
 accept bfloat16 input (#112225)

**Summary**
- PR 4 of the series enabling Int8-Mixed-BF16 PT2E PTQ quantization with
  Inductor: https://github.com/pytorch/pytorch/issues/111640.
- Enable the decomposed `quantize_per_tensor` and `quantize_per_channel` ops
  to accept bfloat16 input.

**Test Plan**
```
python -m pytest test_quantized_tensor.py -k test_decomposed_quantize_per_tensor_bfloat16_input
python -m pytest test_quantized_tensor.py -k test_decomposed_quantize_per_channel_bfloat16_input
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112225
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
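For illustration only (not part of the committed change), a minimal usage
sketch of the behavior this PR enables. The scale/zero_point values are made
up, and importing `torch.ao.quantization.fx._decomposed` is what registers
the op:

```python
import torch
# Importing this module registers the quantized_decomposed ops.
import torch.ao.quantization.fx._decomposed  # noqa: F401

# A bfloat16 input is now accepted; it is upcast to float32 internally
# before the round/clamp quantization math runs.
x = torch.randn(5, 5).to(torch.bfloat16)
scale, zero_point = 0.05, 10  # illustrative quantization parameters
q = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zero_point, 0, 255, torch.uint8
)
assert q.dtype == torch.uint8
```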
---
 .../core/test_quantized_tensor.py        | 30 +++++++++++++++++++
 torch/_inductor/decomposition.py         |  4 +++
 torch/ao/quantization/fx/_decomposed.py  | 10 +++++--
 3 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py
index 96d5cea156af3..4e83b044a7239 100644
--- a/test/quantization/core/test_quantized_tensor.py
+++ b/test/quantization/core/test_quantized_tensor.py
@@ -1478,6 +1478,18 @@ def test_decomposed_quantize_per_tensor(self):
         self.assertEqual(quantized_decomposed_X.dtype, dtype)
         self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
 
+    def test_decomposed_quantize_per_tensor_bfloat16_input(self):
+        # register the ops
+        import torch.ao.quantization.fx._decomposed
+        X = torch.randint(1, 10, (5, 5)).to(torch.float32)
+        scale, zero_point = _calculate_dynamic_qparams(X, torch.quint8)
+        quantized_X = torch.quantize_per_tensor(X, scale, zero_point, torch.quint8)
+        quantized_decomposed_X = \
+            torch.ops.quantized_decomposed.quantize_per_tensor(
+                X.to(torch.bfloat16), scale, zero_point, 0, 255, torch.uint8)
+        self.assertEqual(quantized_decomposed_X.dtype, torch.uint8)
+        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
+
     def test_decomposed_dequantize_per_tensor(self):
         import torch.ao.quantization.fx._decomposed
         X = torch.randn(5, 10)
@@ -1541,6 +1553,24 @@ def test_decomposed_quantize_per_channel(self):
         self.assertEqual(quantized_decomposed_X.dtype, dtype)
         self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
 
+    def test_decomposed_quantize_per_channel_bfloat16_input(self):
+        # register the ops
+        import torch.ao.quantization.fx._decomposed
+        X = torch.randint(1, 10, (5, 5)).to(torch.float32)
+        qdtype = torch.quint8
+        dtype = torch.uint8
+        scales = torch.randn(5,)
+        zero_points = torch.randint(0, 100, (5,))
+        quant_min, quant_max = 0, 255
+        axis = 0
+
+        quantized_X = torch.quantize_per_channel(X, scales, zero_points, axis, qdtype)
+        quantized_decomposed_X = \
+            torch.ops.quantized_decomposed.quantize_per_channel(
+                X.to(torch.bfloat16), scales, zero_points, axis, quant_min, quant_max, dtype)
+        self.assertEqual(quantized_decomposed_X.dtype, dtype)
+        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
+
     def test_decomposed_dequantize_per_channel(self):
         # register the ops
         import torch.ao.quantization.fx._decomposed
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 883ad4e7e95d7..a8bf2aa4a0f0b 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -437,6 +437,8 @@ def quantize_per_tensor_default_decomp_impl(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
     inv_scale = 1.0 / scale
     return torch.clamp(
         torch.round(input * inv_scale) + zero_point, quant_min, quant_max
@@ -466,6 +468,8 @@ def quantize_per_tensor_tensor_decomp_impl(
     quant_max: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
     inv_scale = 1.0 / scale
     return torch.clamp(
         torch.round(input * inv_scale) + zero_point, quant_min, quant_max
diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py
index 08f0608fa788a..c64defa89f700 100644
--- a/torch/ao/quantization/fx/_decomposed.py
+++ b/torch/ao/quantization/fx/_decomposed.py
@@ -46,7 +46,7 @@ def quantize_per_tensor(
        from floating point to quantized values
 
     Args:
-       input (torch.Tensor): original float32 Tensor
+       input (torch.Tensor): original float32 or bfloat16 Tensor
        scale (float): quantization parameter for affine quantization
        zero_point (int): quantization parameter for affine quantization
        quant_min (int): minimum quantized value for output Tensor
@@ -57,6 +57,9 @@ def quantize_per_tensor(
        Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
        are not stored in the Tensor, we are storing them in function arguments instead
     """
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
@@ -355,7 +358,7 @@ def quantize_per_channel(
        parameters for each channel/axis to map from floating point to quantized values
 
     Args:
-       input (torch.Tensor): original float32 Tensor
+       input (torch.Tensor): original float32 or bfloat16 Tensor
        scales (torch.Tensor): a list of scale quantization parameter for
        affine quantization, one per channel
        zero_point (torch.Tensor): a list of zero_point quantization parameter for
@@ -368,6 +371,9 @@ def quantize_per_channel(
        Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
        are not stored in the Tensor, we are storing them in function arguments instead
     """
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
     assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
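For completeness, a matching usage sketch for the per-channel path changed
above, mirroring the call signature used in
`test_decomposed_quantize_per_channel_bfloat16_input` (the quantization
parameters below are illustrative, not from the test):

```python
import torch
# Importing this module registers the quantized_decomposed ops.
import torch.ao.quantization.fx._decomposed  # noqa: F401

x = torch.randn(5, 5).to(torch.bfloat16)
scales = torch.rand(5) + 0.01             # one positive scale per row (axis 0)
zero_points = torch.randint(0, 100, (5,))
# Arguments: input, scales, zero_points, axis, quant_min, quant_max, dtype.
q = torch.ops.quantized_decomposed.quantize_per_channel(
    x, scales, zero_points, 0, 0, 255, torch.uint8
)
assert q.dtype == torch.uint8
```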