[Quant] [PT2] Add ConvBNAdd(ReLU) Annotation into X86InductorQuantizer (pytorch#111281)

**Summary**
This PR adds ConvBNAdd(ReLU) QAT Annotation into `X86InductorQuantizer`.
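
For context, a minimal sketch of how the new annotation is expected to be exercised from the PT2E QAT flow. The `ConvBNAddReLU` helper module below is illustrative only, and `capture_pre_autograd_graph` is an assumption about the export entry point, which varies across PyTorch versions; the quantizer setup matches what the tests in this PR use.

```
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch._export import capture_pre_autograd_graph  # export entry point is an assumption


class ConvBNAddReLU(torch.nn.Module):
    # Conv2d -> BN -> Add -> ReLU: the pattern this PR teaches the quantizer to annotate.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.bn = torch.nn.BatchNorm2d(3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)) + x)


example_inputs = (torch.randn(2, 3, 6, 6),)
m = ConvBNAddReLU().train()

# QAT config keeps BN during fake-quant training; BN is folded into Conv at convert time.
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config(is_qat=True)
)

exported = capture_pre_autograd_graph(m, example_inputs)
prepared = prepare_qat_pt2e(exported, quantizer)
prepared(*example_inputs)  # stands in for the QAT training loop
converted = convert_pt2e(prepared)
```

After `convert_pt2e`, compiling the converted model with Inductor should let it match the new ConvBNAdd(ReLU) pattern, which is what the pattern-matcher counts in the tests below verify.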

**Test Plan**
```
python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_binary_with_quantizer_api
python -m pytest test_x86inductor_quantizer.py -k test_qat_conv2d_binary_unary_with_quantizer_api
python -m pytest test_mkldnn_pattern_matcher.py -k test_qat_qconv2d_add
python -m pytest test_mkldnn_pattern_matcher.py -k test_qat_qconv2d_add_relu
```

Pull Request resolved: pytorch#111281
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: pytorch#111280
leslie-fang-intel authored and Skylion007 committed Nov 14, 2023
1 parent d88b6a8 commit aa05f19
Showing 4 changed files with 412 additions and 28 deletions.
127 changes: 127 additions & 0 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -726,6 +726,133 @@ def forward(self, x):
is_qat=True,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_qat_qconv2d_add(self):
r"""
This testcase will quantize a Conv2d->Add pattern as:
X
/ \
Conv1(X) Conv2(X)
\ /
Add
|
Y
"""

class M(torch.nn.Module):
def __init__(
self,
**kwargs,
):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
self.bn1 = torch.nn.BatchNorm2d(6)
self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
self.bn2 = torch.nn.BatchNorm2d(6)

def forward(self, x):
x1 = self.bn1(self.conv1(x))
x2 = self.bn2(self.conv2(x))
return x1 + x2

mod = M().train()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)
# In total: 8 pattern_matcher_count, 39 pattern_matcher_nodes
# 1. Pair of to_int8 and to_fp32 at conv input * 1, extra input of add * 1, and graph output * 1
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# NB: since the quant workflow now duplicates the DQ node for each user, we won't necessarily see
# pointless_convert. A pointless convert appears in [q -> dq] decomposed, in inductor
# decomp, as [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# However when dq has multiple users we will have
# [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# \-> to_float -> sub -> mul]
# So for now we will discount one pattern here
# 2. Dequant pattern matcher for dequant promotion * 1
# [convert_element_type_3, sub_1, mul_3]
# 3. Dequant-conv pattern matched in quantization weight prepack * 2
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 4. Quantization fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 5. Qconv2d_add * 1
# [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3,
# mul_6, round_4, add_4, clamp_min_3, clamp_max_3, convert_element_type_6]
self._test_common(
mod,
(v,),
7,
37,
check_quantization=True,
is_qat=True,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_qat_qconv2d_add_relu(self):
r"""
This testcase will quantize a Conv2d->Add->ReLU pattern as:
X
/ \
Conv1(X) Conv2(X)
\ /
Add
|
ReLU
|
Y
"""

class M(torch.nn.Module):
def __init__(
self,
**kwargs,
):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
self.bn1 = torch.nn.BatchNorm2d(6)
self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
self.bn2 = torch.nn.BatchNorm2d(6)
self.relu = torch.nn.ReLU()

def forward(self, x):
x1 = self.bn1(self.conv1(x))
x2 = self.bn2(self.conv2(x))
return self.relu(x1 + x2)

mod = M().train()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)
# In total: 8 pattern_matcher_count, 40 pattern_matcher_nodes
# 1. Pair of to_int8 and to_fp32 at conv input * 1, extra input of add * 1, and graph output * 1
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# NB: since the quant workflow now duplicates the DQ node for each user, we won't necessarily see
# pointless_convert. A pointless convert appears in [q -> dq] decomposed, in inductor
# decomp, as [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# However when dq has multiple users we will have
# [mul(fp32) -> add(fp32) -> to_int8 -> to_float -> sub -> mul]
# \-> to_float -> sub -> mul]
# So for now we will discount one pattern here
# 2. Dequant pattern matcher for dequant promotion * 1
# [convert_element_type_3, sub_1, mul_3]
# 3. Dequant-conv pattern matched in quantization weight prepack * 2
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 4. Quantization fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 5. Qconv2d_add * 1
# [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, relu,
# mul_6, round_4, add_4, clamp_min_3, clamp_max_3, convert_element_type_6]
self._test_common(
mod,
(v,),
7,
38,
check_quantization=True,
is_qat=True,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
117 changes: 108 additions & 9 deletions test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -62,6 +62,7 @@ def __init__(self,
inplace_add: bool = False,
conv2d_type: Conv2DType = Conv2DType.left,
use_bias: bool = False,
with_bn: bool = False,
) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
@@ -73,15 +74,22 @@ def __init__(self,
self.relu = nn.ReLU()
self.inplace_add = inplace_add
self.conv2d_type = conv2d_type
self.bn = torch.nn.BatchNorm2d(3)
self.with_bn = with_bn

def forward(self, x):
if self.conv2d_type == Conv2DType.left:
if self.inplace_add:
tmp = self.conv(x)
if self.with_bn:
tmp = self.bn(tmp)
tmp += self.relu(x)
return tmp
else:
return self.conv(x) + self.relu(x)
tmp = self.conv(x)
if self.with_bn:
tmp = self.bn(tmp)
return tmp + self.relu(x)
elif self.conv2d_type == Conv2DType.right:
if self.inplace_add:
tmp = self.relu(x)
@@ -103,6 +111,7 @@ def __init__(self,
conv2d_type: Conv2DType = Conv2DType.left,
inplace_relu: bool = False,
use_bias: bool = False,
with_bn: bool = False,
) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
@@ -115,15 +124,22 @@ def __init__(self,
self.inplace_add = inplace_add
self.conv2d_type = conv2d_type
self.relu2 = nn.ReLU(inplace=inplace_relu)
self.bn = torch.nn.BatchNorm2d(3)
self.with_bn = with_bn

def forward(self, x):
if self.conv2d_type == Conv2DType.left:
if self.inplace_add:
tmp = self.conv(x)
if self.with_bn:
tmp = self.bn(tmp)
tmp += self.relu(x)
return self.relu2(tmp)
else:
return self.relu2(self.conv(x) + self.relu(x))
tmp = self.conv(x)
if self.with_bn:
tmp = self.bn(tmp)
return self.relu2(tmp + self.relu(x))
elif self.conv2d_type == Conv2DType.right:
if self.inplace_add:
tmp = self.relu(x)
@@ -284,7 +300,7 @@ def _test_quantizer(
@skipIfNoDynamoSupport
class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
@skipIfNoX86
def test_conv2d_with_quantizer_api(self):
def test_conv2d(self):
"""
Test pattern of single conv2d with X86InductorQuantizer.
"""
@@ -318,7 +334,7 @@ def test_conv2d_with_quantizer_api(self):
)

@skipIfNoX86
def test_conv2d_unary_with_quantizer_api(self):
def test_conv2d_unary(self):
"""
Test pattern of conv2d with unary post ops (such as relu, sigmoid) with X86InductorQuantizer.
Currently, only relu as unary post op is supported.
@@ -357,7 +373,7 @@ def test_conv2d_unary_with_quantizer_api(self):
)

@skipIfNoX86
def test_conv2d_binary_with_quantizer_api(self):
def test_conv2d_binary(self):
"""
Test pattern of conv2d with binary post ops (such as add) with X86InductorQuantizer.
Currently, only add as binary post op is supported.
@@ -411,7 +427,7 @@ def test_conv2d_binary_with_quantizer_api(self):
)

@skipIfNoX86
def test_conv2d_binary_unary_with_quantizer_api(self):
def test_conv2d_binary_unary(self):
"""
Test pattern of conv2d with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
Currently, only add as binary post op and relu as unary post op are supported.
@@ -467,7 +483,7 @@ def test_conv2d_binary_unary_with_quantizer_api(self):
)

@skipIfNoX86
def test_conv2d_serials_binary_unary_with_quantizer_api(self):
def test_conv2d_serials_binary_unary(self):
"""
Test pattern of 2 following up conv2d add relu with X86InductorQuantizer.
"""
@@ -889,7 +905,7 @@ def test_linear_unary(self):
)

@skipIfNoX86
def test_qat_conv2d_with_quantizer_api(self):
def test_qat_conv2d(self):
"""
Test QAT pattern of conv2d_bn with X86InductorQuantizer.
"""
@@ -926,7 +942,7 @@ def test_qat_conv2d_with_quantizer_api(self):
)

@skipIfNoX86
def test_qat_conv2d_unary_with_quantizer_api(self):
def test_qat_conv2d_unary(self):
"""
Test QAT pattern of conv2d_bn with unary post ops (such as relu, sigmoid) with X86InductorQuantizer.
Currently, only relu as unary post op is supported.
@@ -965,3 +981,86 @@ def test_qat_conv2d_unary_with_quantizer_api(self):
node_list,
is_qat=True,
)

@skipIfNoX86
def test_qat_conv2d_binary(self):
"""
Test qat pattern of conv2d_bn with binary post ops (such as add) with X86InductorQuantizer.
Currently, only add as binary post op is supported.
"""
example_inputs = (torch.randn(2, 3, 6, 6),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config(is_qat=True)
)
with override_quantized_engine("x86"):
for inplace_add in [True, False]:
m = TestHelperModules.Conv2dAddModule(inplace_add=inplace_add, with_bn=True)
node_occurrence = {
# one for input and weight of the conv
# one for output for the add
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
# quantize_per_channel for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
# BN should be folded into Conv
torch.ops.aten._native_batch_norm_legit.default: 0,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
is_qat=True,
)

@skipIfNoX86
def test_qat_conv2d_binary_unary(self):
"""
Test QAT pattern of conv2d_bn with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
Currently, only add as binary post op and relu as unary post op are supported.
"""
example_inputs = (torch.randn(2, 3, 6, 6),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config(is_qat=True)
)
with override_quantized_engine("x86"):
m = TestHelperModules.Conv2dAddReLUModule(with_bn=True)
node_occurrence = {
# one for input for conv
# one for output for the relu
# one for extra input node of add
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
# note: quantize op for weights are const propagated
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
# BN should be folded into Conv
torch.ops.aten._native_batch_norm_legit.default: 0,
}
node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.add.Tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
self._test_quantizer(
m,
example_inputs,
quantizer,
node_occurrence,
node_list,
is_qat=True,
)
