Skip to content

Commit

Permalink
[STPU] Add STPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangqi3 authored and Tracin committed Jan 13, 2023
1 parent c4a84db commit b8e3015
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 0 deletions.
1 change: 1 addition & 0 deletions mqbench/custom_quantizer/total_int_quantizer.py
Expand Up @@ -6,6 +6,7 @@
from mqbench.custom_quantizer import ModelQuantizer


@register_model_quantizer(BackendType.STPU)
@register_model_quantizer(BackendType.PPLCUDA)
@register_model_quantizer(BackendType.SNPE)
@register_model_quantizer(BackendType.PPLW8A16)
Expand Down
1 change: 1 addition & 0 deletions mqbench/deploy/__init__.py
Expand Up @@ -4,3 +4,4 @@
from .deploy_onnx_qnn import ONNXQNNPass
from .deploy_openvino import replace_fakequantize_and_collect_params_openvino
from .deploy_tengine import remove_fakequantize_and_collect_params_tengine
from .deploy_stpu import remove_fakequantize_and_collect_params_stpu
8 changes: 8 additions & 0 deletions mqbench/deploy/deploy_stpu.py
Expand Up @@ -78,6 +78,7 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name):
"max": 1
}
quant_params = self.post_process_clip_ranges(quant_params, graph, inp2node)
self.merge_relu_layer(graph, quant_params, out2node)
for node in graph.node:
self.update_emin(node, quant_params, named_initializer)
# Delete node and init.
Expand All @@ -98,6 +99,13 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_name):

logger.info("Finish deploy process.")

def merge_relu_layer(self, graph, quant_params, out2node):
    """Propagate each Relu output's clip range back onto its producer's output.

    STPU fuses conv + relu into one op, so the node feeding a Relu must
    carry the same quantization range as the Relu's output for the fused
    kernel to clip correctly.

    Args:
        graph: ONNX GraphProto whose nodes are scanned.
        quant_params: dict mapping tensor name -> range params; mutated in place.
        out2node: dict mapping a tensor name to the single node producing it.
    """
    for node in graph.node:
        if node.op_type != 'Relu' or node.output[0] not in quant_params:
            continue
        # A Relu fed directly by a graph input (or an initializer) has no
        # producer node to merge into — skip instead of raising KeyError.
        if node.input[0] not in out2node:
            continue
        prev_node = out2node[node.input[0]]
        # .copy() so later tweaks to the Relu's range don't alias the producer's.
        quant_params[prev_node.output[0]] = quant_params[node.output[0]].copy()
        logger.info("Merge conv + relu range pass {} to {}".format(node.output[0], prev_node.output[0]))

def update_emin(self, node, quant_params, named_initializer):
'''EMIN is some kind of magic number for STPU.
Do not try to understand it.
Expand Down
13 changes: 13 additions & 0 deletions test/backend/test_backend.py
Expand Up @@ -184,3 +184,16 @@ def test_quantize_tengine_u8(self):
loss.backward()
model_prepared.eval()
convert_deploy(model_prepared, BackendType.Tengine_u8, {'x': [1, 3, 224, 224]}, model_name='resnet18')

def test_quantize_stpu(self):
    """End-to-end smoke test for the STPU backend.

    Prepares a resnet18 for STPU, runs a calibration pass, a quantized
    forward/backward pass, then exports via convert_deploy.
    """
    net = torch.hub.load(GITHUB_RES, 'resnet18', pretrained=False)
    sample = torch.randn(2, 3, 224, 224, device='cpu')
    net.train()
    prepared = prepare_by_platform(net, BackendType.STPU)
    # Calibration pass: collect activation statistics with fake-quant disabled.
    enable_calibration(prepared)
    prepared(sample)
    # Quantized pass: check gradients flow through the fake-quant nodes.
    enable_quantization(prepared)
    loss = prepared(sample).sum()
    loss.backward()
    prepared.eval()
    convert_deploy(prepared, BackendType.STPU, {'x': [1, 3, 224, 224]}, model_name='resnet18')

0 comments on commit b8e3015

Please sign in to comment.