[Misc] Update about STPU. (#232)
* [Misc] Update about STPU.

---------

Co-authored-by: zhangqi3 <zhangqi3@sensetime.com>
Tracin and zhangqi3 committed Feb 14, 2023
1 parent e2f6d78 commit 48e0634
Showing 7 changed files with 67 additions and 44 deletions.
6 changes: 2 additions & 4 deletions mqbench/custom_quantizer/model_quantizer.py
@@ -233,9 +233,7 @@ def _find_act_quants(self, model: GraphModule) -> List:
((node.op == 'call_function' or node.op == 'call_method') and
node.target in self.function_type_to_quant_input) or node.name in self.additional_node_name:
input_node_list = self._flatten_args(node.args)
# Means this is not Tensor + Tensor.
if not all([isinstance(_node, torch.fx.node.Node) for _node in input_node_list]):
continue
input_node_list = [_node for _node in input_node_list if isinstance(_node, torch.fx.node.Node)]
for _node in input_node_list:
if self._is_implicit_merge(modules, (node, _node)):
logger.info("Implicit merge: {} + {}".format(_node.name, node.name))
@@ -272,4 +270,4 @@ def _convert(self, module, mapping=None, inplace=False, scope=''):
for key, value in reassign.items():
module._modules[key] = value

return module
return module
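
Note on the model_quantizer.py change above: previously a candidate node was skipped entirely if any of its flattened args was not a torch.fx Node (e.g. Tensor + scalar); now the non-Node args are filtered out and the remaining Node inputs are still considered for activation quantization. A minimal, self-contained sketch of the difference (toy module, not part of the diff):

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1.0  # Tensor + scalar: flattened args are (Node, float)

gm = fx.symbolic_trace(M())
add_node = next(n for n in gm.graph.nodes if n.op == 'call_function')
args = list(add_node.args)  # [x_node, 1.0]

# Old check: the whole node was skipped because not every arg is an fx Node.
skipped_before = not all(isinstance(a, fx.node.Node) for a in args)  # True

# New behavior: drop the scalar, keep x_node as a quantizable input.
node_args = [a for a in args if isinstance(a, fx.node.Node)]  # [x_node]
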
11 changes: 7 additions & 4 deletions mqbench/custom_quantizer/total_int_quantizer.py
@@ -32,6 +32,7 @@ def _passed_func_type(self):
@property
def _passed_module_type(self):
return (
torch.nn.Dropout2d,
torch.nn.ReLU,
torch.nn.ReLU6
)
@@ -50,9 +51,11 @@ def _find_act_quants(self, model: GraphModule) -> list:
((node.op == 'call_function' or node.op == 'call_method') and
node.target in self.function_type_to_quant_input):
for next_node in node.users:
if not ((next_node.op == 'call_function' and next_node.target in self._passed_func_type) or
if ((next_node.op == 'call_function' and next_node.target in self._passed_func_type) or
(next_node.op == 'call_module' and isinstance(modules[next_node.target], self._passed_module_type))):
node_need_to_quantize_output.append(node)
else:
node_need_to_quantize_output.append(next_node)
return node_need_to_quantize_output
elif self._is_implicit_merge(modules, (next_node, node)):
continue
else:
node_need_to_quantize_output.append(node)
return node_need_to_quantize_output
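
Note on the total_int_quantizer.py change above: torch.nn.Dropout2d joins ReLU/ReLU6 as a "passed" module type, and when a quantizable node feeds a passed op the fake-quant is now recorded on that op's output (next_node) rather than on the producer, so e.g. Conv -> ReLU is quantized after the ReLU. A rough standalone sketch of the decision, assuming `modules` maps qualified names to module instances as in the quantizer (helper name is illustrative):

import torch

PASSED_MODULE_TYPES = (torch.nn.Dropout2d, torch.nn.ReLU, torch.nn.ReLU6)

def node_to_quantize(node, next_node, modules):
    # Sketch only: the real logic also checks _passed_func_type and implicit merges.
    passes_through = (
        next_node.op == 'call_module'
        and isinstance(modules[next_node.target], PASSED_MODULE_TYPES)
    )
    return next_node if passes_through else node
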
7 changes: 7 additions & 0 deletions mqbench/deploy/common.py
@@ -225,6 +225,13 @@ def prepare_initializer(graph):
return named_initializer


def insert_initializer(graph, new_init):
for init in graph.initializer:
if init.name == new_init.name:
graph.initializer.remove(init)
graph.initializer.append(new_init)


def parse_attrs(node_attrs):
attrs = {}
for attr in node_attrs:
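
The new insert_initializer helper replaces any same-named initializer before appending, so rewritten (e.g. scale-folded) weights do not pile up as duplicates. A small usage sketch with the standard onnx API (toy names, not from the diff):

import numpy as np
from onnx import helper, numpy_helper
from mqbench.deploy.common import insert_initializer  # available after this commit

w = numpy_helper.from_array(np.ones((2, 2), dtype=np.float32), name='w')
graph = helper.make_graph(nodes=[], name='g', inputs=[], outputs=[], initializer=[w])

# Replace 'w' with scaled values; the old entry is removed first.
new_w = numpy_helper.from_array(np.full((2, 2), 0.5, dtype=np.float32), name='w')
insert_initializer(graph, new_w)

assert len(graph.initializer) == 1
assert numpy_helper.to_array(graph.initializer[0])[0, 0] == 0.5
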
5 changes: 2 additions & 3 deletions mqbench/deploy/deploy_linear.py
@@ -75,7 +75,6 @@ def deal_with_activation_fakequant(self, node, inp2node):
next_nodes = inp2node[node.output[0]]
for next_node, idx in next_nodes:
next_node.input[idx] = node.input[0]
return

def parse_qparams(self, node, name2data):
tensor_name, scale, zero_point = node.input[:3]
@@ -119,13 +118,13 @@ def post_process_clip_ranges(self, clip_ranges, graph, inp2node):
def find_the_closest_clip_range(node):
if node.input[0] in clip_ranges:
return node.input[0]
elif node.op_type in ['Flatten', 'Resize'] and node.output[0] in inp2node:
elif node.op_type in ['Flatten', 'Resize', 'Reshape'] and node.output[0] in inp2node:
return find_the_closest_clip_range(inp2node[node.output[0]][0][0])
else:
return None

for node in graph.node:
if node.op_type in ['Flatten', 'Resize']:
if node.op_type in ['Flatten', 'Resize', 'Reshape']:
tensor_name = find_the_closest_clip_range(node)
if tensor_name:
clip_ranges[node.input[0]] = clip_ranges[tensor_name]
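
Note on the deploy_linear.py change above: Reshape is value-preserving, so it now joins Flatten and Resize as an op that post_process_clip_ranges may look through when a tensor has no clip range of its own. A toy illustration of the effect (hypothetical tensor names):

# A Reshape sits between a quantized consumer and the tensor being clipped; after
# this change the Reshape's input simply inherits the clip range found downstream.
clip_ranges = {'fc_input': {'min': -6.0, 'max': 6.0}}
clip_ranges['reshape_input'] = clip_ranges['fc_input']  # propagated, values unchanged
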
37 changes: 33 additions & 4 deletions mqbench/deploy/deploy_stpu.py
@@ -3,9 +3,10 @@
from collections import OrderedDict

import onnx
from onnx import numpy_helper

from mqbench.deploy.common import (get_constant_inputs, prepare_data,
prepare_initializer,
prepare_initializer, insert_initializer,
update_inp2node_out2node)
from mqbench.deploy.deploy_linear import (PERTENSOR_FAKEQUANTIZER,
LinearQuantizer_process)
@@ -17,10 +18,8 @@ class STPU_process(LinearQuantizer_process):
def remove_fakequantize_and_collect_params(self, onnx_path, model_name):
model = onnx.load(onnx_path)
graph = model.graph
out2node, inp2node = update_inp2node_out2node(graph)
name2data = prepare_data(graph)
named_initializer = prepare_initializer(graph)

out2node, inp2node = update_inp2node_out2node(graph)

quant_params = OrderedDict()
@@ -57,6 +56,35 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name):
"min": -127 * scale,
"max": 127 * scale
}
# Merge Conv + mul
for conv_node in graph.node:
# Network output.
if conv_node.output[0] not in inp2node or len(inp2node[conv_node.output[0]]) < 1:
continue
mul_node = inp2node[conv_node.output[0]][0][0]
if conv_node.op_type == 'Conv' and mul_node.op_type == 'Mul':
mul_scale = numpy_helper.to_array(out2node[mul_node.input[1]].attribute[0].t)
weight_name = named_initializer[conv_node.input[1]].name
bias_name = named_initializer[conv_node.input[2]].name
weight = numpy_helper.to_array(named_initializer[conv_node.input[1]])
bias = numpy_helper.to_array(named_initializer[conv_node.input[2]])
new_weight = numpy_helper.from_array(weight * mul_scale)
new_bias = numpy_helper.from_array(bias * mul_scale)
new_weight.name = weight_name
new_bias.name = bias_name
insert_initializer(graph, new_weight)
insert_initializer(graph, new_bias)
quant_params[conv_node.name + '_weights']['min'] *= mul_scale
quant_params[conv_node.name + '_weights']['max'] *= mul_scale
# Delete mul node.
nodes_to_be_removed.append(mul_node)
conv_node.output[0] = mul_node.output[0]
# Pass concat
for node in graph.node:
if node.op_type == 'Concat' and node.output[0] in quant_params:
for node_input in node.input:
quant_params[node_input] = quant_params[node.output[0]]
logger.info(f'Pass {node.output[0]} range to {node.name} input {node_input}.')
# Update bias scale = input scale * weight scale
for node in graph.node:
if node.op_type in ['Gemm', 'Conv'] and len(node.input) == 3:
@@ -74,6 +102,7 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name):
}
quant_params = self.post_process_clip_ranges(quant_params, graph, inp2node)
self.merge_relu_layer(graph, quant_params, out2node)
# Update emin.
for node in graph.node:
self.update_emin(node, quant_params, named_initializer)
# Delete node and init.
@@ -131,7 +160,7 @@ def find_conv_emin(i_vmax, w_vmax, o_vmax, n, r):

if node.op_type in ['Upsample', 'DynamicUpsample']:
emin = find_interp_emin(quant_params[node.output[0]]['max'], 2)
quant_params[node.output[0]]['emin'] = emin
quant_params[node.output[0]]['emin'] = emin
if node.op_type in ['Conv', 'ConvTranspose']:
weight_shape = named_initializer[node.input[1]].dims
n = weight_shape[1] * weight_shape[2] * weight_shape[3]
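
The main STPU addition folds a constant Mul that directly follows a Conv into the Conv parameters (W' = W * s, b' = b * s), scales the weight clip range by the same factor, reroutes the Conv output past the Mul, and deletes the Mul node. A numpy-only sketch of why the fold is exact for a scalar scale s; shapes and names are illustrative, not from the diff:

import numpy as np

s = np.float32(0.25)                                  # constant Mul scale
W = np.random.randn(8, 3, 3, 3).astype(np.float32)   # Conv weight, OIHW
b = np.random.randn(8).astype(np.float32)             # Conv bias
x = np.random.randn(3, 3, 3).astype(np.float32)       # one input patch

# (Conv(x, W) + b) * s == Conv(x, W * s) + (b * s): one output position per filter.
lhs = (np.tensordot(W, x, axes=3) + b) * s
rhs = np.tensordot(W * s, x, axes=3) + b * s
assert np.allclose(lhs, rhs, atol=1e-6)

# The weight quantization range has to follow the same scale.
w_clip = {'min': -1.0, 'max': 1.0}
w_clip['min'] *= float(s)
w_clip['max'] *= float(s)
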
41 changes: 14 additions & 27 deletions mqbench/fuser_method_mappings.py
@@ -1,5 +1,3 @@
from typing import Optional, Type

import torch
import torch.nn as nn
from torch.quantization.fx.fusion_patterns import ConvBNReLUFusion, ModuleReLUFusion
@@ -13,7 +11,7 @@
from mqbench.nn.modules import FrozenBatchNorm2d


class ConvFreezebnReLUFusion(ConvBNReLUFusion):
class ConvExtendBnReLUFusion(ConvBNReLUFusion):
def __init__(self, quantizer: QuantizerCls, node: Node):
super(ConvBNReLUFusion, self).__init__(quantizer, node)
self.relu_node = None
@@ -87,39 +85,27 @@ def fuse_deconv_bn_relu(deconv, bn, relu):
def fuse_conv_freezebn(conv, bn):
assert bn.training is False, "Freezebn must be eval."

fused_module_class_map = {
nn.Conv2d: qnni.ConvFreezebn2d,
}

if conv.training:
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
fused_module_class = fused_module_class_map.get((type(conv)), None)
return fused_module_class(conv, bn)
return qnni.ConvFreezebn2d(conv, bn)
else:
return nn.utils.fuse_conv_bn_eval(conv, bn)


def fuse_conv_freezebn_relu(conv, bn, relu):
assert conv.training == relu.training and bn.training is False, "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
fused_module : Optional[Type[nn.Sequential]] = None
assert conv.training == relu.training and bn.training is False, \
"Conv and relu both must be in the same mode (train or eval) and bn must be eval."

if conv.training:
map_to_fused_module_train = {
nn.Conv2d: qnni.ConvFreezebnReLU2d,
}
assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
fused_module = map_to_fused_module_train.get(type(conv), None)
return fused_module(conv, bn, relu)
return qnni.ConvFreezebnReLU2d(conv, bn, relu)
else:
map_to_fused_module_eval = {
nn.Conv2d: nn.intrinsic.ConvReLU2d,
}
fused_module = map_to_fused_module_eval.get(type(conv), None)
fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
return fused_module(fused_conv, relu)
fused_conv = nn.utils.fuse_conv_bn_eval(conv, bn)
return nn.intrinsic.ConvReLU2d(fused_conv, relu)


def fuse_deconv_freezebn(deconv, bn):
@@ -135,7 +121,8 @@ def fuse_deconv_freezebn(deconv, bn):


def fuse_deconv_freezebn_relu(deconv, bn, relu):
assert deconv.training == relu.training and bn.training is False, "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
assert deconv.training == relu.training and bn.training is False, \
"Conv and relu both must be in the same mode (train or eval) and bn must be eval."

if deconv.training:
assert bn.num_features == deconv.out_channels, 'Output channel of ConvTranspose2d must match num_features of BatchNorm2d'
@@ -171,13 +158,13 @@ def fuse_deconv_freezebn_relu(deconv, bn, relu):
(torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.ConvTranspose2d)):
ConvBNReLUFusion,
(torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.Conv2d)):
ConvFreezebnReLUFusion,
ConvExtendBnReLUFusion,
(FrozenBatchNorm2d, torch.nn.Conv2d):
ConvFreezebnReLUFusion,
ConvExtendBnReLUFusion,
(torch.nn.ReLU, (FrozenBatchNorm2d, torch.nn.ConvTranspose2d)):
ConvFreezebnReLUFusion,
ConvExtendBnReLUFusion,
(FrozenBatchNorm2d, torch.nn.ConvTranspose2d):
ConvFreezebnReLUFusion,
ConvExtendBnReLUFusion,
},
"additional_qat_module_mappings": {
nn.ConvTranspose2d: qnn.qat.ConvTranspose2d,
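
In eval mode the simplified fuse_conv_freezebn / fuse_conv_freezebn_relu helpers still fall back to torch's standard conv-bn folding (W' = W * gamma / sqrt(var + eps), b' = (b - mean) * gamma / sqrt(var + eps) + beta). A small sanity check of that folding using the stock torch utility (not code from this commit):

import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

conv = nn.Conv2d(3, 8, 3).eval()
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)

fused = fuse_conv_bn_eval(conv, bn)

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
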
4 changes: 2 additions & 2 deletions mqbench/prepare_by_platform.py
@@ -123,8 +123,8 @@ class BackendType(Enum):
default_weight_observer=MinMaxObserver,
default_act_observer=EMAMinMaxObserver),
BackendType.STPU: dict(qtype="affine",
w_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8),
a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8),
w_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8, symmetric_range=True),
a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8, symmetric_range=True),
default_weight_quantize=FixedFakeQuantize,
default_act_quantize=FixedFakeQuantize,
default_weight_observer=MinMaxObserver,
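
The STPU scheme now sets symmetric_range=True for both weights and activations; as I read the QuantizeScheme option this narrows the signed 8-bit grid from [-128, 127] to [-127, 127], which lines up with the -127 * scale / 127 * scale clamp values written out in deploy_stpu.py above. A quick sketch of the two grids (assumed semantics):

bit = 8
default_qmin, default_qmax = -(1 << (bit - 1)), (1 << (bit - 1)) - 1          # -128, 127
symmetric_qmin, symmetric_qmax = -(1 << (bit - 1)) + 1, (1 << (bit - 1)) - 1  # -127, 127

scale = 0.02
print(default_qmin * scale, default_qmax * scale)      # -2.56 2.54
print(symmetric_qmin * scale, symmetric_qmax * scale)  # -2.54 2.54
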
