Skip to content

Commit

Permalink
eliminate sliding window ops with size/stride = 1, that result in a nop
Browse files Browse the repository at this point in the history
  • Loading branch information
Viktor Gyenes committed Oct 11, 2021
1 parent a339121 commit 0295fde
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion nnef_tools/optimization/nnef_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def __call__(self, graph, only_required=False):
self._is_constant(op.inputs[0], 1.0) or self._is_constant(op.inputs[1], 1.0))
changed |= self._remove_identity_ops(graph, 'add', lambda op:
self._is_constant(op.inputs[0], 0.0) or self._is_constant(op.inputs[1], 0.0))
changed |= self._remove_identity_ops(graph, ('box', 'debox', 'avg_pool', 'max_pool'), lambda op:
self._is_uniform(op.attribs['size'], 1) and
self._is_uniform(op.attribs['stride'], 1) and
self._is_uniform(op.attribs['dilation'], 1) and
self._is_uniform(op.attribs['padding'], 0))
changed |= self._remove_identity_ops(graph,
('nearest_downsample', 'area_downsample', 'nearest_upsample', 'multilinear_upsample'), lambda op:
self._is_uniform(op.attribs['factor'], 1))

changed |= self._remove_inverse_ops(graph, 'squeeze', 'unsqueeze', lambda op1, op2:
op1.attribs['axes'] == op2.attribs['axes'])
Expand Down Expand Up @@ -116,10 +124,14 @@ def _insert_copy(tensor, copy=None):
Operation(tensor.graph, type='copy', inputs=tensor, outputs=copy)
return copy

@staticmethod
def _match_op_type(type, types):
return type in types if isinstance(types, tuple) else type == types

def _remove_identity_ops(self, graph, type, cond):
changed = False
for op in graph.operations:
if op.type == type and cond(op) and op.input.quant == op.output.quant:
if self._match_op_type(op.type, type) and cond(op) and op.input.quant == op.output.quant:
changed |= self._bypass_and_remove(graph, op)

return changed
Expand Down Expand Up @@ -428,6 +440,10 @@ def _is_constant(tensor, value):

return (not isinstance(tensor.data, np.ndarray) or data.shape == ()) and data == value

@staticmethod
def _is_uniform(array, value):
return all(item == value for item in array)

@staticmethod
def _merge_transpose_squeeze(transpose, squeeze):
transpose_axes = transpose.attribs['axes']
Expand Down

0 comments on commit 0295fde

Please sign in to comment.