Skip to content

Commit

Permalink
[FRONTEND][TENSORFLOW] Helper function to add shapes into the graph. …
Browse files Browse the repository at this point in the history
…Use tmp folder for model files and clean it.
  • Loading branch information
srkreddy1238 committed Sep 8, 2018
1 parent ab4946c commit 4bf1345
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 9 deletions.
35 changes: 33 additions & 2 deletions nnvm/python/nnvm/testing/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os.path
import collections
import numpy as np
from tvm.contrib import util

# Tensorflow imports
import tensorflow as tf
Expand Down Expand Up @@ -43,6 +44,31 @@ def ProcessGraphDefParam(graph_def):
raise TypeError('graph_def must be a GraphDef proto.')
return graph_def


def AddShapesToGraphDef(out_node):
""" Add shapes attribute to nodes of the graph.
Input graph here is the default graph in context.
Parameters
----------
out_node: String
Final output node of the graph.
Returns
-------
graph_def : Obj
tensorflow graph definition with shapes attribute added to nodes.
"""

with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
[out_node],
)
return graph_def

class NodeLookup(object):
"""Converts integer node ID's to human readable labels."""

Expand Down Expand Up @@ -128,13 +154,18 @@ def get_workload(model_path):
model_url = os.path.join(repo_base, model_path)

from mxnet.gluon.utils import download
download(model_url, model_name)

temp = util.tempdir()
path_model = temp.relpath(model_name)

download(model_url, path_model)

# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
with tf.gfile.FastGFile(path_model, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
temp.remove()
return graph_def

#######################################################################
Expand Down
7 changes: 2 additions & 5 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False)
with tf.Session() as sess:
if init_global_variables:
sess.run(variables.global_variables_initializer())
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
[out_node],
)

final_graph_def = nnvm.testing.tf.AddShapesToGraphDef(out_node)

tf_output = run_tf_graph(sess, in_data, in_name, out_name)
tvm_output = run_tvm_graph(final_graph_def, in_data,
Expand Down
4 changes: 2 additions & 2 deletions tutorials/nnvm/from_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
download(map_proto_url, map_proto)
download(lable_map_url, lable_map)


######################################################################
# Import model
# ------------
Expand All @@ -76,7 +75,8 @@
graph = tf.import_graph_def(graph_def, name='')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)

# Add shapes to the graph.
graph_def = nnvm.testing.tf.AddShapesToGraphDef('softmax')

######################################################################
# Decode image
Expand Down

0 comments on commit 4bf1345

Please sign in to comment.