Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
178 lines (155 sloc) 6.73 KB
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TF: Tensorflow parser"""
from __future__ import absolute_import as _abs
from __future__ import print_function
import os
from tvm.contrib import util
class TFParser(object):
    """
    A wrapper to handle TensorFlow model parsing; TensorFlow must be installed.

    Parameters
    ----------
    model_dir : str
        A TensorFlow frozen ``.pb``/``.pbtxt`` file, or a directory that
        contains a saved model or checkpoints.

    Examples
    --------
    .. code-block:: python

        parser = TFParser(model_dir)
        graphdef = parser.parse()
    """

    def __init__(self, model_dir):
        from tensorflow.core.framework import graph_pb2
        # Scratch directory; _load_saved_model writes the intermediate frozen
        # graph here.
        self._tmp_dir = util.tempdir()
        self._model_dir = model_dir
        self._graph = graph_pb2.GraphDef()

    def _set_graph(self, graph):
        """Set Graph"""
        self._graph = graph

    def _get_graph(self):
        """Get Graph"""
        return self._graph

    def _load_pb_file(self):
        """Load a single frozen graph file into the parser's GraphDef.

        Files ending in ``.pbtxt`` are parsed as protobuf text format; any
        other file is parsed as binary protobuf.

        Returns
        -------
        graph : GraphDef
            The parser's own GraphDef instance, now holding the file contents.
        """
        graph = self._get_graph()
        with open(self._model_dir, "rb") as f:
            contents = f.read()
        if self._model_dir.endswith(".pbtxt"):
            # protobuf is what GraphDef itself is built on, so it is always
            # available when tensorflow is.
            from google.protobuf import text_format
            # text_format.Parse merges into the message instead of replacing
            # it (unlike ParseFromString), so clear it first.
            graph.Clear()
            text_format.Parse(contents, graph)
        else:
            graph.ParseFromString(contents)
        return graph

    def _get_tag_set(self):
        """Return the tag set of the saved model in ``self._model_dir``.

        Multiple metagraphs are not supported: only the first tag set reported
        by the saved-model reader is returned.

        Raises
        ------
        ImportError
            If the saved-model reader cannot be imported.
        RuntimeError
            If the saved model reports no tag set at all.
        """
        try:
            from tensorflow.contrib.saved_model.python.saved_model import reader
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import saved_model.reader which is "
                "required to get tag set from saved model.")
        tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
        if not tag_sets:
            raise RuntimeError("InputConfiguration: No tag set found in saved "
                               "model: " + str(self._model_dir))
        return tag_sets[0]

    @staticmethod
    def _tensor_name_to_node_name(tensor_name):
        """Strip a trailing ``:<output index>`` from a tensor name.

        ``"dense/BiasAdd:0"`` -> ``"dense/BiasAdd"`` and ``"split:1"`` ->
        ``"split"``; a name without a numeric suffix is returned unchanged.
        """
        node_name, sep, index = tensor_name.rpartition(":")
        if sep and index.isdigit():
            return node_name
        return tensor_name

    def _get_output_names(self):
        """Return the comma-separated node names of all signature outputs.

        The saved model is loaded into a session to read its signature
        definitions; every output tensor is mapped to the node producing it.

        Returns
        -------
        str
            Sorted, comma-joined output node names.

        Raises
        ------
        ImportError
            If tensorflow cannot be imported.
        """
        try:
            import tensorflow as tf
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import tensorflow which is "
                "required to restore from saved model.")
        tags = self._get_tag_set()
        output_names = set()
        try:
            with tf.Session() as sess:
                meta_graph_def = tf.saved_model.loader.load(sess,
                                                            tags,
                                                            self._model_dir)
                for sig_def in meta_graph_def.signature_def.values():
                    for output_tensor in sig_def.outputs.values():
                        output_names.add(
                            self._tensor_name_to_node_name(output_tensor.name))
        finally:
            # tf.Session() runs on the global default graph, which the loader
            # populates; reset it even when loading fails.
            tf.reset_default_graph()
        # Sorted so the result does not depend on set iteration order.
        return ",".join(sorted(output_names))

    def _load_saved_model(self):
        """Load the saved model in ``self._model_dir`` as a frozen GraphDef.

        The model is frozen with ``freeze_graph`` (outputs taken from its
        signature definitions) into a temporary file, which is read back and
        stripped of training-only nodes.

        Returns
        -------
        output_graph_def : GraphDef
        """
        try:
            from tensorflow.python.tools import freeze_graph
            from tensorflow.python.framework import ops
            from tensorflow.python.framework import graph_util
            from tensorflow.core.framework import graph_pb2
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import tensorflow which is "
                "required to restore from saved model.")
        output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
        input_saved_model_dir = self._model_dir
        output_node_names = self._get_output_names()
        # Freeze straight from the saved model: no separate input graph,
        # saver, checkpoint or meta graph is supplied.
        input_binary = False
        input_saver_def_path = False
        restore_op_name = None
        filename_tensor_name = None
        clear_devices = True
        input_meta_graph = False
        checkpoint_path = None
        input_graph_filename = None
        saved_model_tags = ",".join(self._get_tag_set())
        # The three "" arguments are initializer_nodes and the variable-name
        # whitelist/blacklist of freeze_graph.
        freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
                                  input_binary, checkpoint_path, output_node_names,
                                  restore_op_name, filename_tensor_name,
                                  output_graph_filename, clear_devices, "", "", "",
                                  input_meta_graph, input_saved_model_dir,
                                  saved_model_tags)
        with ops.Graph().as_default():
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_filename, "rb") as f:
                output_graph_def.ParseFromString(f.read())
            output_graph_def = graph_util.remove_training_nodes(output_graph_def)
        return output_graph_def

    def _load_ckpt(self):
        """TODO: Load checkpoint model."""
        raise RuntimeError("InputConfiguration: Loading tf checkpoint model is "
                           "not supported yet.")

    def parse(self):
        """
        Parse tensorflow models: checkpoints, saved models, and single frozen pb file.

        A directory containing a ``checkpoint`` file is a checkpoint (not
        supported yet); otherwise it must contain a ``variables`` directory and
        is loaded as a saved model. A ``.pb``/``.pbtxt`` file is loaded as a
        saved model when a ``variables`` directory sits next to it (and
        ``_model_dir`` is then rewritten to that directory), otherwise as a
        single frozen graph.

        Returns
        -------
        GraphDef of the passed model

        Raises
        ------
        RuntimeError
            If the path does not exist, has an unsupported suffix, is a
            directory that is neither a checkpoint nor a saved model, or is a
            checkpoint.
        """
        if os.path.isdir(self._model_dir):
            ckpt = os.path.join(self._model_dir, "checkpoint")
            if os.path.isfile(ckpt):
                graph = self._load_ckpt()
            elif os.path.isdir(os.path.join(self._model_dir, "variables")):
                graph = self._load_saved_model()
            else:
                raise RuntimeError("InputConfiguration: Invalid model path.")
        elif os.path.isfile(self._model_dir):
            # Only .pb or .pbtxt is a valid suffix name.
            if not self._model_dir.endswith((".pb", ".pbtxt")):
                raise RuntimeError("InputConfiguration: Invalid model format.")
            cur_dir = os.path.dirname(self._model_dir)
            # It is a saved model if a `variables` directory is present in the
            # same directory as the pb or pbtxt file.
            if os.path.isdir(os.path.join(cur_dir, "variables")):
                self._model_dir = cur_dir
                graph = self._load_saved_model()
            else:
                graph = self._load_pb_file()
        else:
            raise RuntimeError("InputConfiguration: Unrecognized model "
                               "file or path.")
        self._set_graph(graph)
        return graph
You can’t perform that action at this time.