In [None]:
import numpy as np
import os
import tensorflow as tf
from tensorflow.contrib.session_bundle import exporter
import time

In [None]:
# Make the notebook container use the full browser width.
# `IPython.display` is the public home of these helpers; importing `display`
# from `IPython.core.display` is deprecated in recent IPython releases.
from IPython.display import clear_output, Image, display, HTML

display(HTML("<style>.container { width:100% !important; }</style>"))

def strip_consts(graph_def, max_const_size=32):
    """Return a copy of `graph_def` with large constant payloads replaced.

    Every node is copied into a fresh GraphDef. For `Const` nodes whose
    `tensor_content` is longer than `max_const_size` bytes, the content is
    replaced by a short "<stripped N bytes>" placeholder so the graph text
    stays small enough to embed in a page (it is used by `show_graph`).

    Args:
        graph_def: the GraphDef to copy; it is not modified.
        max_const_size: largest `tensor_content` length (in bytes) kept as-is.

    Returns:
        A new GraphDef.
    """
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add()
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                # `tensor_content` is a protobuf `bytes` field: assigning a
                # Python 3 `str` raises TypeError, so encode the placeholder.
                tensor.tensor_content = ("<stripped %d bytes>" % size).encode("ascii")
    return strip_def

def show_graph(graph_def=None, width=1200, height=800, max_const_size=32, ungroup_gradients=False):
    """Visualize a TensorFlow graph inline with the TensorBoard graph widget.

    Args:
        graph_def: a GraphDef, or any object with an `as_graph_def()` method
            (such as a `tf.Graph`). When None, the current default graph is
            rendered.
        width: width of the embedding iframe, in pixels.
        height: height of the embedding iframe, in pixels. (The inner widget
            container is fixed at 600px regardless.)
        max_const_size: Const payloads larger than this many bytes are replaced
            by a placeholder via `strip_consts` to keep the page small.
        ungroup_gradients: when True, quoted node names beginning with
            `gradients/` are rewritten to begin with `b_` (presumably so the
            viewer does not fold them into a single `gradients` group).
    """
    if graph_def is None:
        graph_def = tf.get_default_graph().as_graph_def()
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    data = str(strip_def)
    if ungroup_gradients:
        data = data.replace('"gradients/', '"b_')
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(data), id='graph'+str(np.random.rand()))

    # The widget markup is placed in an iframe's srcdoc, so its double quotes
    # must be HTML-escaped.
    iframe = """
        <iframe seamless style="width:{}px;height:{}px;border:0" srcdoc="{}"></iframe>
    """.format(width, height, code.replace('"', '&quot;'))
    display(HTML(iframe))

In [None]:
# Hyper-parameters and paths used by the training/export cell below.
# NOTE(review): tf.app.flags registers these globally, so re-running this cell
# in the same kernel may raise a duplicate-flag error — TODO confirm for the
# TF version in use.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer("batch_size", 10, "The batch size to train")
flags.DEFINE_integer("epoch_number", 10, "Number of epochs to run trainer")
flags.DEFINE_integer("steps_to_validate", 1,
                     "Steps to validate and print loss")
flags.DEFINE_string("checkpoint_dir", "./checkpoint/",
                    "indicates the checkpoint directory")
flags.DEFINE_string("model_path", "./model/", "The export path of the model")
flags.DEFINE_integer("export_version", 4, "The version number of the model")


In [1]:
def main():
  """Train a toy linear model `y = w * x + b`, checkpoint it, and export it.

  Configuration comes from the notebook-level FLAGS (batch_size, epoch_number,
  steps_to_validate, checkpoint_dir, model_path, export_version). The model is
  built in the current default graph. If `checkpoint_dir` already holds a
  checkpoint, training resumes from it. Training runs `epoch_number`
  gradient-descent steps on an all-ones dataset. Every `steps_to_validate`
  steps, it saves a checkpoint plus binary (`trained_model.pb`) and text
  (`trained_model.txt`) copies of the GraphDef into `checkpoint_dir`. Finally
  the model is exported to `model_path` with the session_bundle Exporter.
  """
  # Define training data
  x = np.ones(FLAGS.batch_size)
  y = np.ones(FLAGS.batch_size)

  # Define the model
  X = tf.placeholder(tf.float32, shape=[None], name="X")
  Y = tf.placeholder(tf.float32, shape=[None], name="yhat")
  w = tf.Variable(1.0, name="weight")
  b = tf.Variable(1.0, name="bias")
  loss = tf.square(Y - tf.mul(X, w) - b)
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
  predict_op = tf.mul(X, w) + b

  saver = tf.train.Saver()
  checkpoint_dir = FLAGS.checkpoint_dir
  # os.path.join avoids the "./checkpoint//checkpoint.ckpt" double slash that
  # string concatenation produced with the default trailing-slash directory.
  checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.ckpt")
  if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

  # Start the session
  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      print("Continue training from the model {}".format(ckpt.model_checkpoint_path))
      saver.restore(sess, ckpt.model_checkpoint_path)

    # These names are what a later freeze step needs for
    # --filename_tensor_name and --restore_op_name.
    saver_def = saver.as_saver_def()
    print(saver_def.filename_tensor_name)
    print(saver_def.restore_op_name)

    # Start training
    start_time = time.time()
    for epoch in range(FLAGS.epoch_number):
      sess.run(train_op, feed_dict={X: x, Y: y})

      # Start validating
      if epoch % FLAGS.steps_to_validate == 0:
        end_time = time.time()
        print("[{}] Epoch: {}".format(end_time - start_time, epoch))

        saver.save(sess, checkpoint_file)
        tf.train.write_graph(sess.graph_def, checkpoint_dir, 'trained_model.pb', as_text=False)
        tf.train.write_graph(sess.graph_def, checkpoint_dir, 'trained_model.txt', as_text=True)

        start_time = end_time

    # Print model variables
    w_value, b_value = sess.run([w, b])
    print("The model of w: {}, b: {}".format(w_value, b_value))

    # Export the model
    print("Exporting trained model to {}".format(FLAGS.model_path))
    model_exporter = exporter.Exporter(saver)
    model_exporter.init(
      sess.graph.as_graph_def(),
      named_graph_signatures={
        'inputs': exporter.generic_signature({"features": X}),
        'outputs': exporter.generic_signature({"prediction": predict_op})
      })
    model_exporter.export(FLAGS.model_path, tf.constant(FLAGS.export_version), sess)
    print('Done exporting!')

if __name__ == "__main__":
  main()

NameError: name 'np' is not defined

In [None]:
# Render the current default graph (where main() above built its model) inline.
show_graph(graph_def=None)

In [None]:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts checkpoint variables into Const ops in a standalone GraphDef file.

This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
variable values stored in a checkpoint file, and output a GraphDef with all of
the variable ops converted into const ops containing the values of the
variables.

It's useful to do this when we need to load a single file in C++, especially in
environments like mobile or embedded where we may not have access to the
RestoreTensor ops and file loading calls that they rely on.

An example of command-line usage is:
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax

You can also look at freeze_graph_test.py for an example of how to use it.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from google.protobuf import text_format
from tensorflow.python.framework import graph_util


FLAGS = tf.app.flags.FLAGS

# NOTE(review): this default is an absolute path from the original author's
# environment; pass --input_graph to point at your own GraphDef.
tf.app.flags.DEFINE_string("input_graph", "/root/pipeline/jupyterhub.ml/notebooks/Spark/ML/checkpoint/trained_model.pb",
                           """TensorFlow 'GraphDef' file to load.""")
tf.app.flags.DEFINE_string("input_saver", "",
                           """TensorFlow saver file to load.""")
tf.app.flags.DEFINE_string("input_checkpoint", "",
                           """TensorFlow variables file to load.""")
tf.app.flags.DEFINE_string("output_graph", "",
                           """Output 'GraphDef' file name.""")
# The default --input_graph above is the binary `trained_model.pb` that the
# training cell writes with as_text=False. A False default here made the
# default run text-parse a binary proto. Pass --input_binary=False for text
# GraphDef/SaverDef files.
tf.app.flags.DEFINE_boolean("input_binary", True,
                            """Whether the input files are in binary format.""")
tf.app.flags.DEFINE_string("output_node_names", "",
                           """The name of the output nodes, comma separated.""")
tf.app.flags.DEFINE_string("restore_op_name", "save/restore_all",
                           """The name of the master restore operator.""")
tf.app.flags.DEFINE_string("filename_tensor_name", "save/Const:0",
                           """The name of the tensor holding the save path.""")
tf.app.flags.DEFINE_boolean("clear_devices", True,
                            """Whether to remove device specifications.""")
tf.app.flags.DEFINE_string("initializer_nodes", "", "comma separated list of "
                           "initializer nodes to run before freezing.")
tf.app.flags.DEFINE_string("variable_names_blacklist", "", "comma separated "
                           "list of variables to skip converting to constants ")


def _merge_text_proto(data, message):
  """Parse text-format proto `data` (bytes or str) into `message`."""
  # Depending on the Python/TF version, a text-mode GFile read returns either
  # bytes or str; only bytes need decoding.
  if isinstance(data, bytes):
    data = data.decode("utf-8")
  text_format.Merge(data, message)


def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes,
                 variable_names_blacklist=None):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: path of the GraphDef file to load.
    input_saver: optional path of a SaverDef file. When empty, the op named by
      `restore_op_name` is run directly, fed with the checkpoint path.
    input_binary: True if `input_graph` (and `input_saver`) are binary protos,
      False if they are text protos.
    input_checkpoint: checkpoint path (may be a prefix for Saver V2 format).
    output_node_names: comma-separated names of the output nodes to keep.
    restore_op_name: name of the restore op, used when `input_saver` is empty.
    filename_tensor_name: tensor fed with `input_checkpoint` by that restore op.
    output_graph: path where the frozen binary GraphDef is written.
    clear_devices: if True, remove device placements from every node.
    initializer_nodes: comma-separated node names to run after restoring
      (only in the path without `input_saver`).
    variable_names_blacklist: comma-separated variable names to leave as
      variables. When None, the --variable_names_blacklist flag is used.

  Returns:
    -1 if an input is missing or invalid; otherwise None.
  """
  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not tf.gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not tf.train.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Fail before doing any work rather than when opening "" for writing.
  if not output_graph:
    print("You need to supply a file name to --output_graph.")
    return -1

  if variable_names_blacklist is None:
    variable_names_blacklist = FLAGS.variable_names_blacklist

  mode = "rb" if input_binary else "r"

  # Use a private graph: in a notebook the default graph usually already holds
  # other ops (e.g. a training model), and importing with name="" into it
  # could collide with those node names.
  graph = tf.Graph()
  with graph.as_default():
    input_graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(input_graph, mode) as f:
      if input_binary:
        input_graph_def.ParseFromString(f.read())
      else:
        _merge_text_proto(f.read(), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    _ = tf.import_graph_def(input_graph_def, name="")

    with tf.Session(graph=graph) as sess:
      if input_saver:
        with tf.gfile.FastGFile(input_saver, mode) as f:
          saver_def = tf.train.SaverDef()
          if input_binary:
            saver_def.ParseFromString(f.read())
          else:
            _merge_text_proto(f.read(), saver_def)
          saver = tf.train.Saver(saver_def=saver_def)
          saver.restore(sess, input_checkpoint)
      else:
        sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
        if initializer_nodes:
          # The flag is a comma-separated list; run each named node.
          init_names = [name.strip() for name in initializer_nodes.split(",")
                        if name.strip()]
          sess.run(init_names)

      blacklist = (variable_names_blacklist.split(",") if
                   variable_names_blacklist else None)
      output_graph_def = graph_util.convert_variables_to_constants(
          sess, input_graph_def, output_node_names.split(","),
          variable_names_blacklist=blacklist)

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))


def main(unused_args):
  """`tf.app.run` entry point: freeze the graph described by the FLAGS."""
  freeze_graph(
      input_graph=FLAGS.input_graph,
      input_saver=FLAGS.input_saver,
      input_binary=FLAGS.input_binary,
      input_checkpoint=FLAGS.input_checkpoint,
      output_node_names=FLAGS.output_node_names,
      restore_op_name=FLAGS.restore_op_name,
      filename_tensor_name=FLAGS.filename_tensor_name,
      output_graph=FLAGS.output_graph,
      clear_devices=FLAGS.clear_devices,
      initializer_nodes=FLAGS.initializer_nodes,
  )


if __name__ == "__main__":
  # tf.app.run parses the command line into FLAGS, then calls main().
  tf.app.run()

In [None]:
# pylint: disable=g-bad-file-header
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Removes parts of a graph that are only needed for training.

There are several common transformations that can be applied to GraphDefs
created to train a model, that help reduce the amount of computation needed when
the network is used only for inference. These include:

 - Removing training-only operations like checkpoint saving.

 - Stripping out parts of the graph that are never reached.

 - Removing debug operations like CheckNumerics.

 - Folding batch normalization ops into the pre-calculated weights.

 - Fusing common operations into unified versions.

This script takes either a frozen binary GraphDef file (where the weight
variables have been converted into constants by the freeze_graph script), or a
text GraphDef proto file (the weight variables are stored in a separate
checkpoint file), and outputs a new GraphDef with the optimizations applied.

If the input graph is a text graph file, make sure to include the node that
restores the variable weights in output_names. That node is usually named
"restore_all".

An example of command-line usage is:

bazel build tensorflow/python/tools:optimize_for_inference && \
bazel-bin/tensorflow/python/tools/optimize_for_inference \
--input=frozen_inception_graph.pb \
--output=optimized_inception_graph.pb \
--frozen_graph=True \
--input_names=Mul \
--output_names=softmax


"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf

from google.protobuf import text_format

from tensorflow.python.tools import optimize_for_inference_lib

FLAGS = tf.app.flags.FLAGS
flags = tf.app.flags

# File to read and file to write.
flags.DEFINE_string("input", "",
                    """TensorFlow 'GraphDef' file to load.""")
flags.DEFINE_string("output", "",
                    """File to save the output graph to.""")

# Graph endpoints to preserve, as comma-separated node names.
flags.DEFINE_string("input_names", "",
                    """Input node names, comma separated.""")
flags.DEFINE_string("output_names", "", """Output node names, comma separated.""")

# Input/output serialization format (binary frozen vs. text proto).
flags.DEFINE_boolean("frozen_graph", True,
                     """If true, the input graph is a binary frozen GraphDef
                     file; if false, it is a text GraphDef proto file.""")

# DataType enum handed to optimize_for_inference for placeholder nodes.
flags.DEFINE_integer("placeholder_type_enum",
                     tf.float32.as_datatype_enum,
                     """The AttrValue enum to use for placeholders.""")


def main(unused_args):
  """Optimize the GraphDef named by --input for inference and write it out.

  When --frozen_graph is true, the input is parsed as a binary GraphDef and
  the output is written as one. Otherwise both are text protos. The graph is
  passed through `optimize_for_inference_lib.optimize_for_inference` with
  the comma-separated --input_names / --output_names.

  Returns:
    -1 if the input file does not exist, otherwise 0.
  """
  if not tf.gfile.Exists(FLAGS.input):
    print("Input graph file '" + FLAGS.input + "' does not exist!")
    return -1

  input_graph_def = tf.GraphDef()
  # Read raw bytes: a binary proto must not pass through text-mode decoding,
  # and the text branch decodes explicitly.
  with tf.gfile.Open(FLAGS.input, "rb") as f:
    data = f.read()
  if FLAGS.frozen_graph:
    input_graph_def.ParseFromString(data)
  else:
    text_format.Merge(data.decode("utf-8"), input_graph_def)

  output_graph_def = optimize_for_inference_lib.optimize_for_inference(
      input_graph_def,
      FLAGS.input_names.split(","),
      FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)

  if FLAGS.frozen_graph:
    # Binary mode for serialized bytes; `with` guarantees the file is flushed
    # and closed (the previous handle was never closed).
    with tf.gfile.GFile(FLAGS.output, "wb") as f:
      f.write(output_graph_def.SerializeToString())
  else:
    # A bare file name has an empty dirname; write into the current directory.
    tf.train.write_graph(output_graph_def,
                         os.path.dirname(FLAGS.output) or ".",
                         os.path.basename(FLAGS.output))
  return 0

if __name__ == "__main__":
  tf.app.run()

In [None]:
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A simple script for inspecting checkpoint files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("file_name", "", "Checkpoint filename")
tf.app.flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect")
# Default to a real boolean False. The previous default was the string "True",
# which is truthy, so every run dumped all tensors; --tensor_name and the
# names/shapes listing were unreachable unless --all_tensors=False was given.
tf.app.flags.DEFINE_bool("all_tensors", False,
                         "If True, print the values of all the tensors.")


def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors=None):
  """Prints tensors in a checkpoint file.

  If `all_tensors` is true, prints the name and value of every tensor.
  Otherwise, if no `tensor_name` is provided, prints the reader's debug
  string (the tensor names and shapes in the checkpoint file). If
  `tensor_name` is provided, prints the content of that tensor.

  Errors are reported by printing them rather than raising. A hint is added
  for the "corrupted compressed block contents" message.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
    all_tensors: Whether to print every tensor's value. When None (the
      default), the --all_tensors flag is used.
  """
  if all_tensors is None:
    all_tensors = FLAGS.all_tensors
  try:
    reader = tf.train.NewCheckpointReader(file_name)
    if all_tensors:
      var_to_shape_map = reader.get_variable_to_shape_map()
      # Sorted so the output order is deterministic across runs.
      for key in sorted(var_to_shape_map):
        print("tensor_name: ", key)
        print(reader.get_tensor(key))
    elif not tensor_name:
      print(reader.debug_string().decode("utf-8"))
    else:
      print("tensor_name: ", tensor_name)
      print(reader.get_tensor(tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    # Deliberately best-effort: report the error instead of raising.
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")


def main(unused_argv):
  """`tf.app.run` entry point: inspect the checkpoint named by --file_name.

  Prints a usage message and exits with status 1 when no file is given.
  """
  if FLAGS.file_name:
    print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name)
    return
  print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
        "[--tensor_name=tensor_to_print]")
  sys.exit(1)


if __name__ == "__main__":
  # Parses the command line into FLAGS, then calls main().
  tf.app.run()