Skip to content

Commit

Permalink
[optimize] requantize while input and output quantization parameters …
Browse files Browse the repository at this point in the history
…of cat are different
  • Loading branch information
LynnL4 committed May 31, 2023
1 parent d150cf8 commit 6f62020
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion tinynn/converter/operators/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2698,6 +2698,43 @@ def cat_split_pass(self):
self.graph.add_operator(op, transform=True)

self.graph.try_restore_edges(mapping)

@class_conditional(lambda self: self.tflite_micro_rewrite)
def cat_requantize(self):
vertices = self.graph.graph.vs.select(functools.partial(is_quantize_cat_node, graph_converter=self.graph.graph))
# If the quantization parameter of the concatenate input is different from the output quantization parameter,
# add a quantization operator after the input, and reconnect the output of the quantization operator to the
# input of the concatenate operator.
for cat in vertices:
op = cat['op']
op_out = cat['op'].outputs[0]
for i, op_in in enumerate(op.inputs):
if op_in.quantization.scale != op_out.quantization.scale or op_in.quantization.zero_point != op_out.quantization.zero_point:

requantized = self.create_transform_tensor(op_in.tensor.copy(), quantization=op_out.quantization)
quantize_op = tfl.QuantizeOperator([op_in], [requantized])
self.graph.add_operator(quantize_op)

# Get the node of the requantized tensor
requantized_node_name = self.graph.tensor_node_map[requantized.name]
requantized_node = self.graph.graph.vs.find(name=requantized_node_name)

# Connect the quantize op to the graph
node_in = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_in.name])
self.graph.replace_next_tensors(node_in, requantized_node, requantized.name, [requantized_node_name])

# Connect the quantize op to the graph
self.graph.replace_operator_input(cat, i, requantized)











def input_transpose_pass(self):
nhwc2nchw_perm = np.array([0, 3, 1, 2], dtype='int32')
Expand Down Expand Up @@ -3307,9 +3344,13 @@ def optimize(self):
# TFLite micro specific
self.cat_split_pass()
self.split_requantize()

self.cat_requantize()

# Group the same tensors into one
self.group_tensors_pass()




# Final cleanup
self.cleanup_dead_nodes()
Expand Down Expand Up @@ -3441,6 +3482,9 @@ def is_requantize_node(vertex: ig.Vertex, graph_converter: ig.Graph):
def is_large_cat_node(vertex: ig.Vertex, graph_converter: ig.Graph):
return vertex['node_type'] == ExtendedOperator.CONCATENATION and len(vertex['op'].inputs) > 10

def is_quantize_cat_node(vertex: ig.Vertex, graph_converter: ig.Graph):
return vertex['node_type'] == ExtendedOperator.CONCATENATION and vertex['op'].outputs[0].quantization is not None


def is_high_dim_transpose_node(vertex: ig.Vertex, graph_converter: ig.Graph, max_transpose_dims: int):
return vertex['node_type'] == ExtendedOperator.TRANSPOSE and vertex['op'].inputs[1].tensor.size > max_transpose_dims
Expand Down

0 comments on commit 6f62020

Please sign in to comment.