From 48e0634824b69399b90897264a8970f28f1bc0b9 Mon Sep 17 00:00:00 2001
From: Tracin <953719031@qq.com>
Date: Tue, 14 Feb 2023 14:12:30 +0800
Subject: [PATCH] [Misc] Update about STPU. (#232)

* [Misc] Update about STPU.

---------

Co-authored-by: zhangqi3
---
 mqbench/custom_quantizer/model_quantizer.py  |  6 +--
 .../custom_quantizer/total_int_quantizer.py  | 11 +++--
 mqbench/deploy/common.py                     |  7 ++++
 mqbench/deploy/deploy_linear.py              |  5 +--
 mqbench/deploy/deploy_stpu.py                | 37 +++++++++++++++--
 mqbench/fuser_method_mappings.py             | 41 +++++++------------
 mqbench/prepare_by_platform.py               |  4 +-
 7 files changed, 67 insertions(+), 44 deletions(-)

diff --git a/mqbench/custom_quantizer/model_quantizer.py b/mqbench/custom_quantizer/model_quantizer.py
index 8a4a64bb..b05d8425 100644
--- a/mqbench/custom_quantizer/model_quantizer.py
+++ b/mqbench/custom_quantizer/model_quantizer.py
@@ -233,9 +233,7 @@ def _find_act_quants(self, model: GraphModule) -> List:
                 ((node.op == 'call_function' or node.op == 'call_method') and
                     node.target in self.function_type_to_quant_input) or node.name in self.additional_node_name:
                 input_node_list = self._flatten_args(node.args)
-                # Means this is not Tensor + Tensor.
-                if not all([isinstance(_node, torch.fx.node.Node) for _node in input_node_list]):
-                    continue
+                input_node_list = [_node for _node in input_node_list if isinstance(_node, torch.fx.node.Node)]
                 for _node in input_node_list:
                     if self._is_implicit_merge(modules, (node, _node)):
                         logger.info("Implicit merge: {} + {}".format(_node.name, node.name))
@@ -272,4 +270,4 @@ def _convert(self, module, mapping=None, inplace=False, scope=''):
 
         for key, value in reassign.items():
             module._modules[key] = value
-        return module
\ No newline at end of file
+        return module
diff --git a/mqbench/custom_quantizer/total_int_quantizer.py b/mqbench/custom_quantizer/total_int_quantizer.py
index a1c54603..3846b3be 100644
--- a/mqbench/custom_quantizer/total_int_quantizer.py
+++ b/mqbench/custom_quantizer/total_int_quantizer.py
@@ -32,6 +32,7 @@ def _passed_func_type(self):
     @property
     def _passed_module_type(self):
         return (
+            torch.nn.Dropout2d,
            torch.nn.ReLU,
            torch.nn.ReLU6
        )
@@ -50,9 +51,11 @@ def _find_act_quants(self, model: GraphModule) -> list:
                 ((node.op == 'call_function' or node.op == 'call_method') and
                     node.target in self.function_type_to_quant_input):
                 for next_node in node.users:
-                    if not ((next_node.op == 'call_function' and next_node.target in self._passed_func_type) or
+                    if ((next_node.op == 'call_function' and next_node.target in self._passed_func_type) or
                         (next_node.op == 'call_module' and isinstance(modules[next_node.target], self._passed_module_type))):
-                        node_need_to_quantize_output.append(node)
-                    else:
                         node_need_to_quantize_output.append(next_node)
-        return node_need_to_quantize_output
\ No newline at end of file
+                    elif self._is_implicit_merge(modules, (next_node, node)):
+                        continue
+                    else:
+                        node_need_to_quantize_output.append(node)
+        return node_need_to_quantize_output
diff --git a/mqbench/deploy/common.py b/mqbench/deploy/common.py
index 31363fc7..6a3382e1 100644
--- a/mqbench/deploy/common.py
+++ b/mqbench/deploy/common.py
@@ -225,6 +225,13 @@ def prepare_initializer(graph):
     return named_initializer
 
 
+def insert_initializer(graph, new_init):
+    for init in graph.initializer:
+        if init.name == new_init.name:
+            graph.initializer.remove(init)
+    graph.initializer.append(new_init)
+
+
 def parse_attrs(node_attrs):
     attrs = {}
     for attr in node_attrs:
diff --git a/mqbench/deploy/deploy_linear.py b/mqbench/deploy/deploy_linear.py
index b3c9fd50..ef7e8063 100644
--- a/mqbench/deploy/deploy_linear.py
+++ b/mqbench/deploy/deploy_linear.py
@@ -75,7 +75,6 @@ def deal_with_activation_fakequant(self, node, inp2node):
         next_nodes = inp2node[node.output[0]]
         for next_node, idx in next_nodes:
             next_node.input[idx] = node.input[0]
-        return
 
     def parse_qparams(self, node, name2data):
         tensor_name, scale, zero_point = node.input[:3]
@@ -119,13 +118,13 @@ def post_process_clip_ranges(self, clip_ranges, graph, inp2node):
         def find_the_closest_clip_range(node):
             if node.input[0] in clip_ranges:
                 return node.input[0]
-            elif node.op_type in ['Flatten', 'Resize'] and node.output[0] in inp2node:
+            elif node.op_type in ['Flatten', 'Resize', 'Reshape'] and node.output[0] in inp2node:
                 return find_the_closest_clip_range(inp2node[node.output[0]][0][0])
             else:
                 return None
 
         for node in graph.node:
-            if node.op_type in ['Flatten', 'Resize']:
+            if node.op_type in ['Flatten', 'Resize', 'Reshape']:
                 tensor_name = find_the_closest_clip_range(node)
                 if tensor_name:
                     clip_ranges[node.input[0]] = clip_ranges[tensor_name]
diff --git a/mqbench/deploy/deploy_stpu.py b/mqbench/deploy/deploy_stpu.py
index a16b7833..e44d975a 100644
--- a/mqbench/deploy/deploy_stpu.py
+++ b/mqbench/deploy/deploy_stpu.py
@@ -3,9 +3,10 @@
 from collections import OrderedDict
 
 import onnx
+from onnx import numpy_helper
 
 from mqbench.deploy.common import (get_constant_inputs, prepare_data,
-                                   prepare_initializer,
+                                   prepare_initializer, insert_initializer,
                                    update_inp2node_out2node)
 from mqbench.deploy.deploy_linear import (PERTENSOR_FAKEQUANTIZER,
                                           LinearQuantizer_process)
@@ -17,10 +18,8 @@ class STPU_process(LinearQuantizer_process):
     def remove_fakequantize_and_collect_params(self, onnx_path, model_name):
         model = onnx.load(onnx_path)
         graph = model.graph
-        out2node, inp2node = update_inp2node_out2node(graph)
         name2data = prepare_data(graph)
         named_initializer = prepare_initializer(graph)
-
         out2node, inp2node = update_inp2node_out2node(graph)
 
         quant_params = OrderedDict()
@@ -57,6 +56,35 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name):
                     "min": -127 * scale,
                     "max": 127 * scale
                 }
+        # Merge Conv + mul
+        for conv_node in graph.node:
+            # Network output.
+            if conv_node.output[0] not in inp2node or len(inp2node[conv_node.output[0]]) < 1:
+                continue
+            mul_node = inp2node[conv_node.output[0]][0][0]
+            if conv_node.op_type == 'Conv' and mul_node.op_type == 'Mul':
+                mul_scale = numpy_helper.to_array(out2node[mul_node.input[1]].attribute[0].t)
+                weight_name = named_initializer[conv_node.input[1]].name
+                bias_name = named_initializer[conv_node.input[2]].name
+                weight = numpy_helper.to_array(named_initializer[conv_node.input[1]])
+                bias = numpy_helper.to_array(named_initializer[conv_node.input[2]])
+                new_weight = numpy_helper.from_array(weight * mul_scale)
+                new_bias = numpy_helper.from_array(bias * mul_scale)
+                new_weight.name = weight_name
+                new_bias.name = bias_name
+                insert_initializer(graph, new_weight)
+                insert_initializer(graph, new_bias)
+                quant_params[conv_node.name + '_weights']['min'] *= mul_scale
+                quant_params[conv_node.name + '_weights']['max'] *= mul_scale
+                # Delete mul node.
+                nodes_to_be_removed.append(mul_node)
+                conv_node.output[0] = mul_node.output[0]
+        # Pass concat
+        for node in graph.node:
+            if node.op_type == 'Concat' and node.output[0] in quant_params:
+                for node_input in node.input:
+                    quant_params[node_input] = quant_params[node.output[0]]
+                    logger.info(f'Pass {node.output[0]} range to {node.name} input {node_input}.')
         # Update bias scale = input scale * weight scale
         for node in graph.node:
             if node.op_type in ['Gemm', 'Conv'] and len(node.input) == 3:
@@ -74,6 +102,7 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name):
             }
         quant_params = self.post_process_clip_ranges(quant_params, graph, inp2node)
         self.merge_relu_layer(graph, quant_params, out2node)
+        # Update emin.
         for node in graph.node:
             self.update_emin(node, quant_params, named_initializer)
         # Delete node and init.
@@ -131,7 +160,7 @@ def find_conv_emin(i_vmax, w_vmax, o_vmax, n, r):
 
         if node.op_type in ['Upsample', 'DynamicUpsample']:
             emin = find_interp_emin(quant_params[node.output[0]]['max'], 2)
-            quant_params[node.output[0]]['emin'] = emin 
+            quant_params[node.output[0]]['emin'] = emin
         if node.op_type in ['Conv', 'ConvTranspose']:
             weight_shape = named_initializer[node.input[1]].dims
             n = weight_shape[1] * weight_shape[2] * weight_shape[3]
diff --git a/mqbench/fuser_method_mappings.py b/mqbench/fuser_method_mappings.py
index 5535f2c5..d4baf662 100644
--- a/mqbench/fuser_method_mappings.py
+++ b/mqbench/fuser_method_mappings.py
@@ -1,5 +1,3 @@
-from typing import Optional, Type
-
 import torch
 import torch.nn as nn
 from torch.quantization.fx.fusion_patterns import ConvBNReLUFusion, ModuleReLUFusion
@@ -13,7 +11,7 @@
 from mqbench.nn.modules import FrozenBatchNorm2d
 
 
-class ConvFreezebnReLUFusion(ConvBNReLUFusion):
+class ConvExtendBnReLUFusion(ConvBNReLUFusion):
     def __init__(self, quantizer: QuantizerCls, node: Node):
         super(ConvBNReLUFusion, self).__init__(quantizer, node)
         self.relu_node = None
@@ -87,39 +85,27 @@ def fuse_deconv_bn_relu(deconv, bn, relu):
 
 def fuse_conv_freezebn(conv, bn):
     assert bn.training is False, "Freezebn must be eval."
-    fused_module_class_map = {
-        nn.Conv2d: qnni.ConvFreezebn2d,
-    }
-
     if conv.training:
         assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
         assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
-        fused_module_class = fused_module_class_map.get((type(conv)), None)
-        return fused_module_class(conv, bn)
+        return qnni.ConvFreezebn2d(conv, bn)
     else:
         return nn.utils.fuse_conv_bn_eval(conv, bn)
 
 
 def fuse_conv_freezebn_relu(conv, bn, relu):
-    assert conv.training == relu.training and bn.training is False, "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
-    fused_module : Optional[Type[nn.Sequential]] = None
+    assert conv.training == relu.training and bn.training is False, \
+        "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
+
     if conv.training:
-        map_to_fused_module_train = {
-            nn.Conv2d: qnni.ConvFreezebnReLU2d,
-        }
         assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
         assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
-        fused_module = map_to_fused_module_train.get(type(conv), None)
-        return fused_module(conv, bn, relu)
+        return qnni.ConvFreezebnReLU2d(conv, bn, relu)
     else:
-        map_to_fused_module_eval = {
-            nn.Conv2d: nn.intrinsic.ConvReLU2d,
-        }
-        fused_module = map_to_fused_module_eval.get(type(conv), None)
-        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
-        return fused_module(fused_conv, relu)
+        fused_conv = nn.utils.fuse_conv_bn_eval(conv, bn)
+        return nn.intrinsic.ConvReLU2d(fused_conv, relu)
 
 
 def fuse_deconv_freezebn(deconv, bn):
@@ -135,7 +121,8 @@ def fuse_deconv_freezebn(deconv, bn):
 
 
 def fuse_deconv_freezebn_relu(deconv, bn, relu):
-    assert deconv.training == relu.training and bn.training is False, "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
+    assert deconv.training == relu.training and bn.training is False, \
+        "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
 
     if deconv.training:
         assert bn.num_features == deconv.out_channels, 'Output channel of ConvTranspose2d must match num_features of BatchNorm2d'
@@ -171,13 +158,13 @@ def fuse_deconv_freezebn_relu(deconv, bn, relu):
         (torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.ConvTranspose2d)):
             ConvBNReLUFusion,
         (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.Conv2d)):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
         (FrozenBatchNorm2d, torch.nn.Conv2d):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
         (torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.ConvTranspose2d)):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
         (FrozenBatchNorm2d, torch.nn.ConvTranspose2d):
-            ConvFreezebnReLUFusion,
+            ConvExtendBnReLUFusion,
     },
     "additional_qat_module_mappings": {
         nn.ConvTranspose2d: qnn.qat.ConvTranspose2d,
diff --git a/mqbench/prepare_by_platform.py b/mqbench/prepare_by_platform.py
index f1eab9c3..ccd0d5be 100644
--- a/mqbench/prepare_by_platform.py
+++ b/mqbench/prepare_by_platform.py
@@ -123,8 +123,8 @@ class BackendType(Enum):
                            default_weight_observer=MinMaxObserver,
                            default_act_observer=EMAMinMaxObserver),
     BackendType.STPU: dict(qtype="affine",
-                           w_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8),
-                           a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8),
+                           w_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8, symmetric_range=True),
+                           a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8, symmetric_range=True),
                            default_weight_quantize=FixedFakeQuantize,
                            default_act_quantize=FixedFakeQuantize,
                            default_weight_observer=MinMaxObserver,