Pytorch1.5 update extract method (#876)
* change trace function for pytorch 1.5 (#870)

Co-authored-by: XiaoXYe <50827462+XiaoXYe@users.noreply.github.com>

Co-authored-by: nate.river <lvyufeng2007@hotmail.com>
XiaoXYe and lvyufeng committed Jul 31, 2020
1 parent f0a9798 commit e3dbf30
Showing 1 changed file with 86 additions and 53 deletions.
139 changes: 86 additions & 53 deletions mmdnn/conversion/pytorch/pytorch_graph.py
@@ -10,7 +10,44 @@
 import contextlib
 from torch.jit import _unique_state_dict
 
+class scope_name_workaround(object):
+    def __init__(self):
+        self.backup = None
+
+    def __enter__(self):
+        def _tracing_name(self_, tracing_state):
+            if not tracing_state._traced_module_stack:
+                return None
+            module = tracing_state._traced_module_stack[-1]
+            for name, child in module.named_children():
+                if child is self_:
+                    return name
+            return None
+
+        def _slow_forward(self_, *input, **kwargs):
+            tracing_state = torch._C._get_tracing_state()
+            if not tracing_state or isinstance(self_.forward, torch._C.ScriptMethod):
+                return self_.forward(*input, **kwargs)
+            if not hasattr(tracing_state, '_traced_module_stack'):
+                tracing_state._traced_module_stack = []
+            name = _tracing_name(self_, tracing_state)
+            if name:
+                tracing_state.push_scope('%s[%s]' % (self_._get_name(), name))
+            else:
+                tracing_state.push_scope(self_._get_name())
+            tracing_state._traced_module_stack.append(self_)
+            try:
+                result = self_.forward(*input, **kwargs)
+            finally:
+                tracing_state.pop_scope()
+                tracing_state._traced_module_stack.pop()
+            return result
+
+        self.backup = torch.nn.Module._slow_forward
+        setattr(torch.nn.Module, '_slow_forward', _slow_forward)
+
+    def __exit__(self, type, value, tb):
+        setattr(torch.nn.Module, '_slow_forward', self.backup)
+
 class PytorchGraphNode(GraphNode):
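
The scope_name_workaround class added above patches torch.nn.Module._slow_forward so that tracing once again pushes 'ClassName[attribute]' scopes, which PyTorch 1.5 no longer records on its own. A minimal usage sketch mirroring extractgraph further down; model and dummy_input are placeholders for an nn.Module and a matching input tensor:

    # Sketch: trace under the workaround so nodes carry readable scope names.
    from torch.onnx.utils import OperatorExportTypes, _trace

    model.eval()
    with scope_name_workaround():
        graph = _trace(model, dummy_input, OperatorExportTypes.ONNX)
    for node in graph.nodes():
        print(node.scopeName())  # e.g. 'ResNet/Sequential[layer1]/BasicBlock[0]/Conv2d[conv1]'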

@@ -71,54 +108,12 @@ def __init__(self, model):
         self.shape_dict = dict()
         self.layer_weight_map = dict()
 
-
-    @staticmethod
-    def _optimize_graph(graph, aten, export_raw_ir=False):
-        # run dce first to eliminate dead parts of the graph that might have been
-        # left behind by things like symbolic_override
-
-        torch._C._jit_pass_dce(graph)
-        torch._C._jit_pass_lint(graph)
-
-        torch._C._jit_pass_peephole(graph)
-        torch._C._jit_pass_lint(graph)
-        if not export_raw_ir:
-            graph = torch._C._jit_pass_onnx(graph, aten)
-            torch._C._jit_pass_lint(graph)
-            torch._C._jit_pass_onnx_peephole(graph)
-            torch._C._jit_pass_lint(graph)
-        torch._C._jit_pass_dce(graph)
-        torch._C._jit_pass_lint(graph)
-        graph = torch._C._jit_pass_canonicalize(graph)
-        torch._C._jit_pass_lint(graph)
-        return graph
-
-
     @staticmethod
     def get_node_id(node):
         import re
         node_id = re.search(r"[\d]+", node.__str__())
         return node_id.group(0)
 
-    @contextlib.contextmanager
-    def set_training(self, model, mode):
-        r"""
-        A context manager to temporarily set the training mode of 'model'
-        to 'mode', resetting it when we exit the with-block. A no-op if
-        mode is None.
-        """
-        if mode is None:
-            yield
-            return
-        old_mode = model.training
-        if old_mode != mode:
-            model.train(mode)
-        try:
-            yield
-        finally:
-            if old_mode != mode:
-                model.train(old_mode)
-
-
     def build(self, shape):
         """
@@ -180,6 +175,45 @@ def node_connection(self, graph, node, node_name):
     def CreateGraphNode(self, node):
         return PytorchGraphNode040(node)
 
+    @staticmethod
+    def _optimize_graph(graph, aten, export_raw_ir=False):
+        # run dce first to eliminate dead parts of the graph that might have been
+        # left behind by things like symbolic_override
+
+        torch._C._jit_pass_dce(graph)
+        torch._C._jit_pass_lint(graph)
+
+        torch._C._jit_pass_peephole(graph)
+        torch._C._jit_pass_lint(graph)
+        if not export_raw_ir:
+            graph = torch._C._jit_pass_onnx(graph, aten)
+            torch._C._jit_pass_lint(graph)
+            torch._C._jit_pass_onnx_peephole(graph)
+            torch._C._jit_pass_lint(graph)
+        torch._C._jit_pass_dce(graph)
+        torch._C._jit_pass_lint(graph)
+        graph = torch._C._jit_pass_canonicalize(graph)
+        torch._C._jit_pass_lint(graph)
+        return graph
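
The relocated _optimize_graph reproduces the JIT pass pipeline that torch.onnx runs internally: dead-code elimination, peephole simplification, optional lowering to ONNX ops, then canonicalization, with a lint pass after each step. It is not invoked in this hunk, so the call below is only a hedged sketch; passing False for aten (lower to ONNX ops) is an assumption, not shown in this diff:

    # Sketch: normalize a freshly traced graph before walking its nodes.
    graph = PytorchGraph._optimize_graph(graph, False)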

+    @contextlib.contextmanager
+    def set_training(self, model, mode):
+        r"""
+        A context manager to temporarily set the training mode of 'model'
+        to 'mode', resetting it when we exit the with-block. A no-op if
+        mode is None.
+        """
+        if mode is None:
+            yield
+            return
+        old_mode = model.training
+        if old_mode != mode:
+            model.train(mode)
+        try:
+            yield
+        finally:
+            if old_mode != mode:
+                model.train(old_mode)
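
set_training temporarily forces the model into the requested mode and restores the old flag on exit, so tracing cannot leave the model in the wrong state. An illustrative use, reusing the _trace import shown in extractgraph below (this exact call is not made by this diff):

    # Sketch: trace in eval mode, then restore the previous training flag.
    with self.set_training(self.model, False):
        graph = _trace(self.model, dummy_input, OperatorExportTypes.ONNX)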

class PytorchGraph151(PytorchGraph):

@@ -188,22 +222,21 @@ def __init__(self, model):

     def extractgraph(self, dummy_input):
         import re
-        import torch.onnx.utils
-        # connect name and id in nodes with weights
-        graph, params_dict, torch_out = torch.onnx.utils._model_to_graph(self.model, dummy_input, _retain_param_name=True)
+        from torch.onnx.utils import OperatorExportTypes
+        from torch.onnx.utils import _trace
+
+        self.model.eval()
+        with scope_name_workaround():
+            graph = _trace(self.model, dummy_input, OperatorExportTypes.ONNX)
         nodes = list(graph.nodes())
 
         for node in nodes:
             # print(node.__str__())
             node_id = PytorchGraph.get_node_id(node)
             node_name = 'node' + node_id
-            node_scope_str = re.findall(r'[^()!]+', node.__str__())[-2]
-            for x in node_scope_str.split(','):
-                if re.findall(r'%\S+.weight', x):
-                    node_scope = '.'.join(re.findall(r'%\S+.weight', x)[0].replace('%','',1).split('.')[:-1])
-                    self.layer_weight_map[node_name] = node_scope
-
-        graph, params_dict, torch_out = torch.onnx.utils._model_to_graph(self.model, dummy_input)
-        nodes = list(graph.nodes())
+            self.layer_weight_map[node_name] = '.'.join(
+                re.findall(r'\[([\w\d.]+)\]', node.scopeName())
+            )
         return graph, nodes
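
The new mapping turns the scope names restored by scope_name_workaround into dotted module paths by collecting every bracketed attribute name, instead of parsing weight names out of the node's string form. A small worked example; the scope string is illustrative:

    # Sketch: what the regex extracts from a typical traced scope name.
    import re
    scope = 'ResNet/Sequential[layer1]/BasicBlock[0]/Conv2d[conv1]'
    path = '.'.join(re.findall(r'\[([\w\d.]+)\]', scope))
    print(path)  # -> 'layer1.0.conv1', the module's dotted path and state_dict key prefix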

     def rename_nodes(self, node, node_id):
