Skip to content

Commit

Permalink
[Feature] duplicate the reused nodes in the graph. (#73)
Browse files Browse the repository at this point in the history
* [Feature]: duplicate the reused nodes in the graph.
Some nodes could be reused to save model size, but it
will cause reconstruction issues, e.g. a node is recon by
many blocks.

Co-authored-by: fanyunqian <fanyunqian@sensetime.com>
  • Loading branch information
PannenetsF and fanyunqian committed Apr 11, 2022
1 parent 78e6051 commit cc07b13
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion mqbench/prepare_by_platform.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from enum import Enum
from typing import Any, Dict

Expand Down Expand Up @@ -290,6 +291,26 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo
return True
return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)

def duplicate_reused_nodes(graph: torch.fx.Graph, modules: Dict[str, Any] = {}):
_dup_prefix = '_dup'
target_dict = dict()
dup_modules = dict()
for node in graph.nodes:
if isinstance(node.target, str):
if node.target not in target_dict:
target_dict[node.target] = [node]
else:
target_dict[node.target].append(node)
for key in target_dict:
if len(target_dict[key]) > 1:
for idx, node in enumerate(target_dict[key]):
if idx == 0:
continue
module = deepcopy(modules[node.target])
node.target += _dup_prefix + str(idx)
dup_modules[node.target] = module
graph.lint()
return graph, dup_modules

def prepare_by_platform(
model: torch.nn.Module,
Expand Down Expand Up @@ -341,7 +362,10 @@ def prepare_by_platform(
tracer = custom_tracer
graph = tracer.trace(model, concrete_args)
name = model.__class__.__name__ if isinstance(model, torch.nn.Module) else model.__name__
graph_module = GraphModule(model, graph, name)
modules = dict(model.named_modules())
graph, duplicated_modules = duplicate_reused_nodes(graph, modules)
modules.update(duplicated_modules)
graph_module = GraphModule(modules, graph, name)
# Model fusion.
extra_fuse_dict = prepare_custom_config_dict.get('extra_fuse_dict', {})
extra_fuse_dict.update(fuse_custom_config_dict)
Expand Down

0 comments on commit cc07b13

Please sign in to comment.