Skip to content
Permalink
Browse files

make tf meta file temporary (#621)

* make tf meta file temporary
  • Loading branch information...
rainLiuplus committed Apr 2, 2019
1 parent ac1ade8 commit 56275b65128dbe1fab8460c5bde47084c47819c6
@@ -254,7 +254,7 @@ def emit_Pool(self, IR_node):
ceil_mode = self.is_ceil_mode(IR_node.get_attr('pads'))

# input_node = self._defuse_padding(IR_node, exstr)
code = "{:<15} = F.{}({}, kernel_size={}, stride={}, padding={}, ceil_mode={})".format(
code = "{:<15} = F.{}({}, kernel_size={}, stride={}, padding={}, ceil_mode={}, count_include_pad=False)".format(
IR_node.variable_name,
pool_name,
self.parent_variable_name(IR_node),
@@ -8,6 +8,9 @@
from mmdnn.conversion.common.utils import *
from mmdnn.conversion.common.DataStructure.parser import Parser
from distutils.version import LooseVersion
import tempfile
import os
import shutil


class TensorflowParser2(Parser):
@@ -120,11 +123,14 @@ def __init__(self, frozen_file, inputshape, in_nodes, dest_nodes):
output_node_names = dest_nodes,
placeholder_type_enum = dtypes.float32.as_datatype_enum)
# Save it to an output file
frozen_model_file = './frozen.pb'
tempdir = tempfile.mkdtemp()
frozen_model_file = os.path.join(tempdir, 'frozen.pb')
with gfile.GFile(frozen_model_file, "wb") as f:
f.write(original_gdef.SerializeToString())
with open(frozen_model_file, 'rb') as f:
serialized = f.read()
shutil.rmtree(tempdir)

tensorflow.reset_default_graph()
model = tensorflow.GraphDef()
model.ParseFromString(serialized)
@@ -149,14 +155,15 @@ def __init__(self, frozen_file, inputshape, in_nodes, dest_nodes):
x = tensorflow.placeholder(dtype)

input_map[in_nodes[i] + ':0'] = x

tensorflow.import_graph_def(model, name='', input_map=input_map)

with tensorflow.Session(graph = g) as sess:

meta_graph_def = tensorflow.train.export_meta_graph(filename='./my-model.meta')
tempdir = tempfile.mkdtemp()
meta_graph_def = tensorflow.train.export_meta_graph(filename=os.path.join(tempdir, 'my-model.meta'))
model = meta_graph_def.graph_def

shutil.rmtree((tempdir))

self.tf_graph = TensorflowGraph(model)
self.tf_graph.build()
@@ -346,7 +353,7 @@ def _add_constant_node(self, source_node):
parent_node = self.tf_graph.get_node(s)
if parent_node.type == 'Const':
self._rename_Const(parent_node)

def _rename_Const(self, source_node):
IR_node = self._convert_identity_operation(source_node, end_idx=0, new_op='Constant') # Constant
value = source_node.get_attr('value')
@@ -369,13 +376,13 @@ def gen_IR(self):
continue

node_type = current_node.type

if hasattr(self, "rename_" + node_type):

func = getattr(self, "rename_" + node_type)
func(current_node)
else:

self.rename_UNKNOWN(current_node)


@@ -800,9 +807,9 @@ def rename_Gather(self, source_node):
assign_IRnode_values(IR_node, kwargs)

return IR_node

def rename_GatherV2(self, source_node):

IR_node = self.rename_Gather(source_node)

kwargs = {}
@@ -1013,7 +1020,7 @@ def rename_Rank(self, source_node):

def rename_Transpose(self, source_node):
IR_node = self._convert_identity_operation(source_node, end_idx=1, new_op = 'Transpose')

input_node_perm = self.get_parent(source_node.name, [1])
# input_node_perm = self.check_const(self.get_parent(source_node.name, [1], True))
tensor_content = input_node_perm.get_attr('value')
@@ -1142,6 +1149,6 @@ def rename_Tanh(self, source_node):
kwargs['shape'] = self.tensor_shape_to_list(input_node.get_attr('_output_shapes'))[0]

assign_IRnode_values(IR_node, kwargs)

def rename_Log(self, source_node):
IR_node = self._convert_identity_operation(source_node, new_op = 'Log')
@@ -14,6 +14,9 @@
from mmdnn.conversion.common.DataStructure.parser import Parser
from tensorflow.tools.graph_transforms import TransformGraph
from mmdnn.conversion.rewriter.utils import *
import tempfile
import os
import shutil


class TensorflowParser(Parser):
@@ -308,9 +311,10 @@ def __init__(self, meta_file, checkpoint_file, dest_nodes, inputShape = None, in
tensorflow.import_graph_def(transformed_graph_def, name='', input_map=input_map)

with tensorflow.Session(graph = g) as sess:

meta_graph_def = tensorflow.train.export_meta_graph(filename='./my-model.meta')
tempdir = tempfile.mkdtemp()
meta_graph_def = tensorflow.train.export_meta_graph(filename=os.path.join(tempdir, 'my-model.meta'))
model = meta_graph_def.graph_def
shutil.rmtree(tempdir)

self.tf_graph = TensorflowGraph(model)
self.tf_graph.build()
@@ -691,12 +691,6 @@ def onnx_emit(original_framework, architecture_name, architecture_path, weight_p

predict = tf_rep.run(input_data)[0]

return predict

except ImportError:
print('Please install Onnx! Or Onnx is not supported in your platform.', file=sys.stderr)

finally:
del prepare
del model_converted
del tf_rep
@@ -705,6 +699,10 @@ def onnx_emit(original_framework, architecture_name, architecture_path, weight_p
os.remove(converted_file + '.py')
os.remove(converted_file + '.npy')

return predict

except ImportError:
print('Please install Onnx! Or Onnx is not supported in your platform.', file=sys.stderr)


# In case of odd number add the extra padding at the end for SAME_UPPER(eg. pads:[0, 2, 2, 0, 0, 3, 3, 0]) and at the beginning for SAME_LOWER(eg. pads:[0, 3, 3, 0, 0, 2, 2, 0])

0 comments on commit 56275b6

Please sign in to comment.
You can’t perform that action at this time.