# Tensorflow Surgery

The goal of this notebook is to explore how to remove and alter nodes in Tensorflow models. Pre-trained models are great ... until you find some fatal flaw in them that prevents you from using them as you'd like (I'm looking at you, 'DecodeJpeg' op that can't be run on Android). This is all about taking one of those frozen or checkpointed models and picking it apart.

Though this is most useful for a model you haven't made yourself (or one you simply don't want to spend time training again), we're going to set up a basic model so that everything that's happening is clearer. First we'll save it and reload it, then freeze it and reload the frozen version, in case you need help with those steps as well. Tensorflow doesn't exactly make them simple.

Then we will load the frozen model, strip an operation out of it, save the result and run it to see the difference.

Finally we'll optimize our model for inference so it is ready for production.

# Set up basic model and save it

In [34]:
import tensorflow as tf
import numpy as np
import os
from tensorflow.python.framework import graph_io
from tensorflow.python.tools import freeze_graph

First let's put two empty folders in the working directory -- one labeled "checkpoints" and the other "new_checkpoints".

In [35]:
if not os.path.exists('checkpoints'):
    os.mkdir('checkpoints')
if not os.path.exists('new_checkpoints'):
    os.mkdir('new_checkpoints')

In [36]:
# This is important when using interactive shells like iPython notebooks
# If you don't reset the default graph, then tensorflow will create new
# variables and number them like this -- "a_1" -- each time you run a block
# in the notebook over again.
tf.reset_default_graph()

In [37]:
# Placeholders (values we feed in at run time)
a = tf.placeholder(tf.int16, name="a")
b = tf.placeholder(tf.int16, name="b")

In [38]:
# A trainable variable
var = tf.Variable(initial_value=16.2, name="var")

In [39]:
# build the model
add = tf.add(a, b, name="add_ab")
var = tf.cast(var, tf.int16, name="cast_var")
var = tf.add(var, add, name="add_var")
mul = tf.multiply(a, var, name="mul_ab")

Here are our operations. Right now the variable "var" (which I have entered as a float and cast into an int just to keep things interesting) gets the sum of placeholders a and b added to it and is then multiplied by a. Our goal will be to remove that "add_var" operation while still retaining a working graph.

So the operation looks like this ...

answer = a * (a + b + 16)

And we want to make it this ...

answer = a * 16

*NOTE: "var" is 16 in the equation and not 16.2 because we cast it to int.

In [40]:
init = tf.global_variables_initializer()
saver = tf.train.Saver()
model_path = './checkpoints/test_model'
input_graph_path = './checkpoints/input_graph.pb'

For the purposes of this notebook we'll use a=20 and b=10. So the result of our equation as written should be ...

answer = 20 * (20 + 10 + 16) = 920

And after we alter our graph, it should be ...

answer = 20 * 16 = 320
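The same arithmetic can be checked in plain Python, without Tensorflow. This is only a sketch of the math: it assumes, as in our graph, that casting 16.2 to an integer truncates toward zero, which is also what Python's `int()` does.

```python
def original_answer(a, b, var=16.2):
    cast_var = int(var)          # "cast_var": 16.2 -> 16 (truncates toward zero)
    add_ab = a + b               # "add_ab"
    add_var = cast_var + add_ab  # "add_var": the node we want to remove
    return a * add_var           # "mul_ab"


def altered_answer(a, b, var=16.2):
    # After surgery "mul_ab" reads "cast_var" directly; b is accepted but unused
    cast_var = int(var)
    return a * cast_var


print(original_answer(20, 10))  # 920
print(altered_answer(20, 10))   # 320
```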

In [41]:
with tf.Session() as sess:
    
    sess.run(init)
    output = sess.run(mul, feed_dict={a: 20, b: 10})
    
    # Save the variables in a checkpoint and the graph in a meta graph
    saver.save(sess, model_path, write_meta_graph=True, global_step=1)
    
    # Save the graph def as well.
    # This is a bit redundant because the graph def is already contained
    # in the meta graph we just saved but I want to use it to show you
    # an alternative method of freezing the graph later on.
    with open(input_graph_path, 'wb') as f:
        f.write(sess.graph_def.SerializeToString())
        
    print ("output:", output)

output: 920


Tensorflow saves models in three different ways:

**CHECKPOINTs** store the values of your variables. The file labeled 'test_model-1.data-00000-of-00001' holds the actual values. 'test_model-1.index' is an index that tells Tensorflow where each variable is stored in the data file(s), and the plain 'checkpoint' file keeps track of which saved checkpoint is the most up-to-date (that is the file tf.train.latest_checkpoint reads).

**GRAPH_DEFs** store graph operations (add, multiply, matmul, etc.) and constant values. They are saved either in binary form as .pb or in text form as .pbtxt (which, by the way, takes forever to load for a normal-sized model). Those extensions are just naming conventions, though; you choose the file name yourself. In this case, the graph def is 'input_graph.pb'. A graph def is not enough on its own to restore a model for continued training (you need the variables for that), but once your model has been trained you can convert its variables to constants and store everything in a graph def. That is why frozen and optimized-for-inference models are traded around as graph def files.

**META_GRAPHs** store the graph defs as well as other additional information you need to restore the model (like the SaverDef). Our Meta file appears in the checkpoints folder as 'test_model-1.meta'.
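To keep the roles straight, here is a small lookup keyed on the file names that end up in our checkpoints folder. `describe_saved_file` is a hypothetical helper (not a Tensorflow function), and the `.pb`/`.pbtxt` rule only reflects the naming convention mentioned above.

```python
import re


def describe_saved_file(filename):
    """Rough role of a saved file, judged only by its name."""
    if filename == "checkpoint":
        return "checkpoint state: records which checkpoint is the latest"
    if filename.endswith(".meta"):
        return "meta graph: graph def plus SaverDef and other metadata"
    if filename.endswith(".index"):
        return "checkpoint index: where each variable sits in the data shards"
    if re.search(r"\.data-\d+-of-\d+$", filename):
        return "checkpoint data: the variable values"
    if filename.endswith((".pb", ".pbtxt")):
        return "graph def (binary or text, by naming convention only)"
    return "unknown"


for name in ["checkpoint", "test_model-1.meta", "test_model-1.index",
             "test_model-1.data-00000-of-00001", "input_graph.pb"]:
    print(f"{name:34} {describe_saved_file(name)}")
```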

When you're *training* -- you use the checkpoint and meta graph to restore the full model.

For *production* -- you use a graph def with the variables frozen into constants to perform inference.

# Load and use model

Now let's load our saved model and use it again. Calling tf.reset_default_graph() is equivalent to opening a new notebook and calling these functions from there: there's no carryover from the last section (but the files we saved are still in our folder).

In [68]:
tf.reset_default_graph()

In [69]:
checkpoint_path = tf.train.latest_checkpoint("./checkpoints/")

In [70]:
graph = tf.Graph()
with graph.as_default():
    saver = tf.train.import_meta_graph(checkpoint_path + '.meta')

In [71]:
session = tf.Session(graph=graph)
saver.restore(session, checkpoint_path)

In [72]:
# get_tensor_by_name needs the output index (":0") because a tensor is
# named "<op_name>:<output_index>"; without it the name refers to the
# operation, not its output. If you're troubleshooting, check this first.
a = graph.get_tensor_by_name('a:0')
b = graph.get_tensor_by_name('b:0')
out = graph.get_tensor_by_name('mul_ab:0')
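The ":0" comes from how tensors are named: "mul_ab:0" is output number 0 (here the only output) of the "mul_ab" op. The hypothetical helper below is plain Python, not a Tensorflow function; it just makes the convention explicit, treating a bare name as output 0, which is how most inputs are written inside a graph def.

```python
def split_tensor_name(name):
    """Split "op_name:index" into ("op_name", index); a bare name means index 0."""
    op_name, sep, index = name.rpartition(":")
    if not sep:  # no ":" at all, e.g. "mul_ab"
        return name, 0
    return op_name, int(index)


print(split_tensor_name("mul_ab:0"))  # ('mul_ab', 0)
print(split_tensor_name("a"))         # ('a', 0)
```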

Now let's take a look at the nodes in our graph_def. A graph def is stored in Google's protobuf format, which serves a similar purpose to JSON: it's Google's way of efficiently representing structured data so that programs written in different languages can read it and use it easily. Printing a node shows it in protobuf's text form.

In [83]:
graph_def = graph.as_graph_def()
for node in graph_def.node:
    print(node)

Take a look at "var" -- it is labeled with the op "VariableV2". See if you can match the operations that are happening in protobuf to the operations in the equation.
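To make that matching concrete, here is a toy interpreter written purely for illustration. It mirrors the eight nodes on the path to "mul_ab" as plain Python dicts (the restored graph also contains initializer and saver nodes, left out here) and evaluates them by following each node's inputs. It is not how Tensorflow executes a graph, it only knows the few ops used here, and "var" simply holds its checkpoint value.

```python
# Toy mirror of the nodes "mul_ab" depends on (not a real GraphDef)
NODES = [
    {"name": "a", "op": "Placeholder", "input": []},
    {"name": "b", "op": "Placeholder", "input": []},
    {"name": "add_ab", "op": "Add", "input": ["a", "b"]},
    {"name": "var", "op": "VariableV2", "input": [], "value": 16.2},
    {"name": "var/read", "op": "Identity", "input": ["var"]},
    {"name": "cast_var", "op": "Cast", "input": ["var/read"]},
    {"name": "add_var", "op": "Add", "input": ["cast_var", "add_ab"]},
    {"name": "mul_ab", "op": "Mul", "input": ["a", "add_var"]},
]


def evaluate(nodes, output, feed):
    """Compute `output` by recursively evaluating the inputs it depends on."""
    by_name = {node["name"]: node for node in nodes}
    cache = dict(feed)  # fed placeholders count as already computed

    def run(name):
        if name in cache:
            return cache[name]
        node = by_name[name]
        args = [run(inp) for inp in node["input"]]
        op = node["op"]
        if op == "Placeholder":
            raise KeyError(f"placeholder {name!r} was not fed")
        elif op in ("VariableV2", "Const"):
            result = node["value"]
        elif op == "Identity":
            result = args[0]
        elif op == "Cast":
            result = int(args[0])  # float -> int truncates toward zero
        elif op == "Add":
            result = args[0] + args[1]
        elif op == "Mul":
            result = args[0] * args[1]
        else:
            raise ValueError(f"unsupported op {op!r} in node {name!r}")
        cache[name] = result
        return result

    return run(output)


print(evaluate(NODES, "mul_ab", {"a": 20, "b": 10}))  # 920
```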

In [74]:
output = session.run(out, feed_dict={a: 20, b: 10})

In [75]:
print("output:", output)

output: 920


In [76]:
session.close()

# Load and freeze model

**TECHNIQUE #1**

In [54]:
tf.reset_default_graph()

In [55]:
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.tools import inspect_checkpoint

In [56]:
checkpoint_path = tf.train.latest_checkpoint("./checkpoints/")
inspect_checkpoint.print_tensors_in_checkpoint_file(checkpoint_path, tensor_name='', all_tensors=True)
# Set all_tensors to False if you only want the names (and shapes) of the
# tensors in the checkpoint rather than their values. For a real model that's
# usually the better choice, since the values are huge arrays of numbers.

tensor_name:  var
16.2


As "var" is our only graph variable, that's the one we'll want to freeze. Fortunately, Tensorflow's utilities work out on their own which variables need freezing (the ones the output nodes depend on) and how, so we don't have to specify them.
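Conceptually, freezing swaps each variable op for a Const op holding the value read from the checkpoint. The sketch below shows just that swap on toy dict-based nodes; `freeze_variables` is a hypothetical helper, and the real convert_variables_to_constants also prunes the graph down to what the output nodes need and handles more variable types, both of which this skips.

```python
import copy


def freeze_variables(nodes, checkpoint_values):
    """Return (frozen_copy, count) with every VariableV2 node turned into a Const.

    `checkpoint_values` maps variable names to their saved values, e.g.
    {"var": 16.2} as printed by inspect_checkpoint above.
    """
    frozen = copy.deepcopy(nodes)  # leave the caller's graph untouched
    count = 0
    for node in frozen:
        if node["op"] == "VariableV2":
            node["op"] = "Const"
            node["value"] = checkpoint_values[node["name"]]
            count += 1
    return frozen, count


toy = [
    {"name": "var", "op": "VariableV2", "input": []},
    {"name": "var/read", "op": "Identity", "input": ["var"]},
    {"name": "cast_var", "op": "Cast", "input": ["var/read"]},
]
frozen, converted = freeze_variables(toy, {"var": 16.2})
print(converted)  # 1
print([(n["name"], n["op"]) for n in frozen])
# [('var', 'Const'), ('var/read', 'Identity'), ('cast_var', 'Cast')]
```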

In [57]:
# Import our saved meta graph into the current graph we're using
saver = tf.train.import_meta_graph(checkpoint_path + '.meta', import_scope=None)
# Output node names go in one comma-separated string (e.g. "out_1,out_2");
# the .split(",") below turns it into the list of names that
# convert_variables_to_constants expects.
output_node_names = "mul_ab"

with tf.Session() as sess:
    
    # Restore the variable values
    saver.restore(sess, checkpoint_path)
    # Get the graph def from our current graph
    graph_def = tf.get_default_graph().as_graph_def()
    # Turn all variables into constants
    frozen_graph_def = convert_variables_to_constants(sess, graph_def, output_node_names.split(","))
    
    # Save our new graph def
    with tf.gfile.GFile("./checkpoints/" + "frozen.pb", "wb") as f:
        f.write(frozen_graph_def.SerializeToString())

INFO:tensorflow:Froze 1 variables.
Converted 1 variables to const ops.


It should report that 1 variable has been frozen/converted to const ops.

Now let's look at the nodes again. The op for our node "var" should now be changed to "Const".

In [84]:
for node in frozen_graph_def.node:
    print(node)

**TECHNIQUE #2**

This does the exact same thing but uses tensorflow's freeze_graph utility.

In [59]:
tf.reset_default_graph()

In [60]:
from tensorflow.python.tools import freeze_graph
import os

In [61]:
save_dir = './checkpoints/'
checkpoint_path = tf.train.latest_checkpoint(save_dir)
input_graph_path = os.path.join(save_dir, 'input_graph.pb')
meta_path = checkpoint_path + '.meta'
output_frozen_graph_name = os.path.join(save_dir, 'frozen2.pb')
input_saver_def_path = ""
input_binary = True
output_node_names = "mul_ab"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
clear_devices = True

In [62]:
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                         input_binary, checkpoint_path, output_node_names,
                         restore_op_name, filename_tensor_name,
                         output_frozen_graph_name, clear_devices, "")

INFO:tensorflow:Froze 1 variables.
Converted 1 variables to const ops.
8 ops in the final graph.


# A useful function for loading graphs

In [77]:
def load_graph(my_path):
    # Load the pb file and parse it to retrieve the
    # deserialized graph_def
    with tf.gfile.GFile(my_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    
    # Now import the graph_def into a new, empty graph.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            # If you put name=None here instead of ''
            # it will relabel all your ops as "import/original_name"
            name='',
            op_dict=None, 
            producer_op_list=None
        )
    # Return the loaded graph
    return graph
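To see what the name argument changes, here is a simplified model of the renaming on toy dict-based nodes (the helper is hypothetical, not Tensorflow's code). With the default name=None the prefix is "import", and any other non-empty string is used the same way: every node name and every input reference gets the prefix, so you would have to ask for graph.get_tensor_by_name('import/a:0') instead of 'a:0'.

```python
def add_import_prefix(nodes, prefix="import"):
    """Return copies of toy nodes with `prefix/` added to names and inputs."""
    renamed = []
    for node in nodes:
        inputs = []
        for inp in node["input"]:
            if inp.startswith("^"):  # control dependency: keep the caret in front
                inputs.append("^" + prefix + "/" + inp[1:])
            else:
                inputs.append(prefix + "/" + inp)
        renamed.append({**node, "name": prefix + "/" + node["name"], "input": inputs})
    return renamed


nodes = [
    {"name": "a", "op": "Placeholder", "input": []},
    {"name": "mul_ab", "op": "Mul", "input": ["a", "add_var:0"]},
]
for node in add_import_prefix(nodes):
    print(node["name"], node["input"])
# import/a []
# import/mul_ab ['import/a', 'import/add_var:0']
```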

# Load and use frozen model

In [78]:
tf.reset_default_graph()

In [79]:
filename = "./checkpoints/" + "frozen.pb"

graph = load_graph(filename)

a = graph.get_tensor_by_name('a:0')
b = graph.get_tensor_by_name('b:0')
out = graph.get_tensor_by_name('mul_ab:0')

with tf.Session(graph=graph) as sess:
    
    # Now we don't need to initialize any variables because
    # there aren't any, only constants.
    
    output = sess.run(out, feed_dict={a: 20, b: 10})
    print ("output: ", output)

output:  920


# Load graph and alter node

In [80]:
tf.reset_default_graph()

In [81]:
# Note that sometimes you need the './' before the address
# and sometimes you don't. Tensorflow can be a bit fickle
# on this point (like with adding ":0" to tensor names)
# so if you get an error reporting that it can't find your 
# file, this should be one of the first things you try.
frozen_path = './checkpoints/frozen.pb'

# Load our frozen graph
graph = load_graph(frozen_path)

In [85]:
graph_def = graph.as_graph_def()
for node in graph_def.node:
    print(node)

Now before we get rid of "add_var", let's change the node that takes "add_var" as input and make its input the node before "add_var" instead, which in this case is "cast_var". This way we cut "add_var" out of the step-by-step operations of the graph entirely.
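Setting node.input[1] works here because we printed the node and know which position holds "add_var". If a node has several consumers, or you'd rather not rely on positions, a loop that rewrites every matching input is safer. The sketch below does this on toy dict-based nodes with a hypothetical helper; it ignores control inputs ("^name") and assumes single-output ops. The same loop should carry over to graph_def.node, since node.input supports item assignment as in the cell below.

```python
def rewire_consumers(nodes, old, new):
    """Make every input that reads `old` read `new` instead.

    Returns the names of the nodes that were changed. Control inputs
    ("^old") are left alone, and only output 0 of `old` is matched.
    """
    changed = []
    for node in nodes:
        for i, inp in enumerate(node["input"]):
            if inp in (old, old + ":0"):
                node["input"][i] = new
                changed.append(node["name"])
    return changed


nodes = [
    {"name": "cast_var", "op": "Cast", "input": ["var/read"]},
    {"name": "add_var", "op": "Add", "input": ["cast_var", "add_ab"]},
    {"name": "mul_ab", "op": "Mul", "input": ["a", "add_var"]},
]
print(rewire_consumers(nodes, "add_var", "cast_var"))  # ['mul_ab']
print(nodes[2]["input"])                               # ['a', 'cast_var']
```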

In [35]:
for node in graph_def.node:
    # Find the node that we noticed above takes "add_var" as input
    if node.name == "mul_ab":
        # Change its input to the node before "add_var"
        node.input[1] = "cast_var"
        print(node)

name: "mul_ab"
op: "Mul"
input: "a"
input: "cast_var"
attr {
  key: "T"
  value {
    type: DT_INT16
  }
}



Now let's make "add_var" into an identity function (an op that returns exactly what went into it). Also get rid of the second input because an identity function should only have one input.

This is a workaround: Tensorflow has no simple "delete this node" call, and neutering the node this way keeps the graph valid without having to rebuild the node list ourselves.

There is another option. You could use extract_sub_graph (tf.graph_util.extract_sub_graph) to rebuild the graph from only the nodes your outputs still need, leaving out the one you don't want. I found the approach taken in this notebook easier, though, because every time we take out a node we need to alter another one (change its input) anyway, whether we're using extract_sub_graph or not. So we might as well finish the job with this approach.
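The sub-graph route works by keeping only the nodes the outputs still depend on. Once "mul_ab" reads from "cast_var", a backward walk from "mul_ab" no longer reaches "add_var" (nor "add_ab" and "b", which only fed it). The sketch below models that walk on toy dict-based nodes with hypothetical helper names; on a real graph def, tf.graph_util.extract_sub_graph(graph_def, dest_nodes) is the Tensorflow 1.x function for this.

```python
from collections import deque

# Toy mirror of the frozen graph right after rewiring "mul_ab" to read
# "cast_var"; "add_var" still exists, but nothing consumes it any more.
REWIRED = [
    {"name": "a", "op": "Placeholder", "input": []},
    {"name": "b", "op": "Placeholder", "input": []},
    {"name": "add_ab", "op": "Add", "input": ["a", "b"]},
    {"name": "var", "op": "Const", "input": []},
    {"name": "var/read", "op": "Identity", "input": ["var"]},
    {"name": "cast_var", "op": "Cast", "input": ["var/read"]},
    {"name": "add_var", "op": "Add", "input": ["cast_var", "add_ab"]},
    {"name": "mul_ab", "op": "Mul", "input": ["a", "cast_var"]},
]


def node_name(graph_input):
    # "^x" (control input) and "x:1" (output index) both refer to node "x"
    return graph_input.lstrip("^").split(":")[0]


def extract_subgraph(nodes, outputs):
    """Keep only the nodes the outputs depend on, in their original order."""
    by_name = {n["name"]: n for n in nodes}
    keep, queue = set(), deque(outputs)
    while queue:
        name = queue.popleft()
        if name not in keep:
            keep.add(name)
            queue.extend(node_name(i) for i in by_name[name]["input"])
    return [n for n in nodes if n["name"] in keep]


kept = extract_subgraph(REWIRED, ["mul_ab"])
print([n["name"] for n in kept])  # ['a', 'var', 'var/read', 'cast_var', 'mul_ab']
```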

Also, assuming you will be optimizing the model for inference eventually (more info below), that process will find nodes that we've neutered and cut them out of the graph, so from a memory/final product perspective the result will be the same.

In [36]:
for node in graph_def.node:
    # Find the node we want to get rid of
    if node.name == "add_var":
        # Change its op to Identity
        node.op = "Identity"
        # Delete its second input
        del node.input[1]
        print(node)

name: "add_var"
op: "Identity"
input: "cast_var"
attr {
  key: "T"
  value {
    type: DT_INT16
  }
}



**Note**: If the node didn't have an attribute "T" already, we would have to add one as that is also a requirement of an "Identity" op.
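The two steps above, plus the "T" requirement from the note, can be bundled into one helper. This is a sketch on a toy dict-based node with a hypothetical helper name. On a real NodeDef, the last step would look something like `if "T" not in node.attr: node.attr["T"].type = tf.int16.as_datatype_enum`, assuming the same Tensorflow 1.x API used elsewhere in this notebook.

```python
def make_identity(node, keep_input=0, dtype="DT_INT16"):
    """Turn a toy node into an Identity that forwards one of its inputs.

    Identity takes a single input and needs a "T" attribute; `dtype` is
    only used when the node does not already have one.
    """
    node["op"] = "Identity"
    node["input"] = [node["input"][keep_input]]
    node.setdefault("attr", {}).setdefault("T", dtype)
    return node


add_var = {"name": "add_var", "op": "Add", "input": ["cast_var", "add_ab"],
           "attr": {"T": "DT_INT16"}}
print(make_identity(add_var))
# {'name': 'add_var', 'op': 'Identity', 'input': ['cast_var'], 'attr': {'T': 'DT_INT16'}}
```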

In [86]:
for node in graph_def.node:
    print(node)

Here we can see that "add_var" and "mul_ab" have been successfully changed.

So now we will save our new, altered frozen graph in the new_checkpoints folder.

In [38]:
tf.train.write_graph(graph_def, "./new_checkpoints", "altered_frozen.pb", False)

'./new_checkpoints/altered_frozen.pb'

**Note**: I've included more info about different alterations to nodes at the bottom.

# Check if it worked

In [39]:
tf.reset_default_graph()

In [40]:
graph_def_path = './new_checkpoints/altered_frozen.pb'

In [41]:
# Load our new and improved graph def
graph = load_graph(graph_def_path)

In [42]:
session = tf.Session(graph=graph)

In [43]:
a = graph.get_tensor_by_name('a:0')
b = graph.get_tensor_by_name('b:0')
out = graph.get_tensor_by_name('mul_ab:0')

In [46]:
output = session.run(out, feed_dict={a: 20, b: 10})

In [47]:
print("output:", output)

output: 320


In the beginning we said that if our surgery is successful then the resulting answer will be 320.

This means that we have successfully excised the "add_var" node!

In [48]:
session.close()

# Optimize for Inference

This step is only needed if we want to make the model as lean as possible (useful if you want it to run in an app, for instance). We could also use the altered_frozen model as is and it would work fine.

In [83]:
tf.reset_default_graph()

In [84]:
from tensorflow.python.tools import optimize_for_inference_lib

In [85]:
graph_def_path = './new_checkpoints/altered_frozen.pb'
optimized_model = './new_checkpoints/optimized.pb'

In [86]:
# Load the altered frozen graph
input_graph_def = tf.GraphDef()
with tf.gfile.Open(graph_def_path, "rb") as f:
    data = f.read()
    input_graph_def.ParseFromString(data)

output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        # A list of input nodes
        ["a","b"],
        # A list of output nodes
        ["mul_ab"],
        tf.int16.as_datatype_enum)

# Save the optimized graph
with tf.gfile.GFile(optimized_model, "wb") as f:
    f.write(output_graph_def.SerializeToString())

In [87]:
for node in output_graph_def.node:
    print(node)

That got rid of unnecessary nodes like "var/read" and "add_var" (which did nothing since we turned it into an identity function). It also removed the input node "b" and the "add_ab" node that consumed it, because nothing on the path to our answer uses them now that our equation is just 'answer = a * var'. Thanks to optimize_for_inference_lib we are down from a bloated 8-node graph to a lean, mean 4-node inference machine.
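Two clean-ups account for this result, and they can be sketched on toy dict-based nodes: bypass Identity nodes that aren't outputs (rewiring their consumers), then keep only what the outputs depend on. This is a simplified model with hypothetical helper names that assumes every input is a bare node name; the real optimize_for_inference also rewrites the listed inputs as placeholders and folds batch-norm ops, among other passes.

```python
from collections import deque

# Toy mirror of altered_frozen.pb: "var" is a Const, "add_var" an Identity,
# and "mul_ab" reads "cast_var" directly. Inputs are bare node names.
ALTERED = [
    {"name": "a", "op": "Placeholder", "input": []},
    {"name": "b", "op": "Placeholder", "input": []},
    {"name": "add_ab", "op": "Add", "input": ["a", "b"]},
    {"name": "var", "op": "Const", "input": []},
    {"name": "var/read", "op": "Identity", "input": ["var"]},
    {"name": "cast_var", "op": "Cast", "input": ["var/read"]},
    {"name": "add_var", "op": "Identity", "input": ["cast_var"]},
    {"name": "mul_ab", "op": "Mul", "input": ["a", "cast_var"]},
]


def bypass_identities(nodes, protected):
    """Drop Identity nodes (except protected ones) and rewire their consumers."""
    forward = {n["name"]: n["input"][0] for n in nodes
               if n["op"] == "Identity" and n["name"] not in protected}

    def resolve(name):
        while name in forward:  # follow Identity -> Identity chains
            name = forward[name]
        return name

    return [{**n, "input": [resolve(i) for i in n["input"]]}
            for n in nodes if n["name"] not in forward]


def prune_unreachable(nodes, outputs):
    """Keep only nodes the outputs depend on, preserving the original order."""
    by_name = {n["name"]: n for n in nodes}
    keep, queue = set(), deque(outputs)
    while queue:
        name = queue.popleft()
        if name not in keep:
            keep.add(name)
            queue.extend(by_name[name]["input"])
    return [n for n in nodes if n["name"] in keep]


optimized = prune_unreachable(bypass_identities(ALTERED, {"mul_ab"}), ["mul_ab"])
print(len(ALTERED), "->", len(optimized))  # 8 -> 4
print([n["name"] for n in optimized])      # ['a', 'var', 'cast_var', 'mul_ab']
```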

# Load and use optimized model

In [94]:
tf.reset_default_graph()

In [95]:
optimized_model = './new_checkpoints/optimized.pb'

In [96]:
graph = load_graph(optimized_model)

In [97]:
a = graph.get_tensor_by_name('a:0')
out = graph.get_tensor_by_name('mul_ab:0')

In [98]:
with tf.Session(graph=graph) as sess:
    output = sess.run(out, feed_dict={a: 20})
    print("output:", output)

output: 320


sweet.

# Other useful alterations

Another common edit you might need to make is changing a node's data type.

Let's take this node as an example:

##################

    name: "input_feed"
    op: "Placeholder"
    attr {
      key: "dtype"
      value {
        type: DT_INT64
      }
    }
##################

In this node, "input_feed", the placeholder takes a 64-bit integer. Let's say we want to change that to a 32-bit integer.

First you need to know that the DataType enum used in Tensorflow's protobuf files gives each type an integer label. Here are the values 0 through 20 (newer versions of Tensorflow add more, such as DT_VARIANT = 21):

# Data Types
0: DT_INVALID

1: DT_FLOAT

2: DT_DOUBLE

3: DT_INT32

4: DT_UINT8

5: DT_INT16

6: DT_INT8

7: DT_STRING

8: DT_COMPLEX64

9: DT_INT64

10: DT_BOOL

11: DT_QINT8

12: DT_QUINT8

13: DT_QINT32

14: DT_BFLOAT16

15: DT_QINT16

16: DT_QUINT16

17: DT_UINT16

18: DT_COMPLEX128

19: DT_HALF

20: DT_RESOURCE


And here's some info on what those types represent:

https://www.tensorflow.org/programmers_guide/dims_types

To make the change, just set the node's "attr['dtype'].type" to the integer label of the data type you want, like so ...

In [None]:
# for node in my_graph_def.node:
#     if node.name == "input_feed":
#         # This will make it a 32-bit integer instead of 64-bit
#         node.attr['dtype'].type = 3
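Writing the bare number 3 works but is easy to get wrong. A name-to-number table copied from the list above keeps the edit readable; in Tensorflow 1.x itself, tf.int32.as_datatype_enum gives the same number without a hand-written table. The dict-based node and the `set_dtype` helper below are hypothetical stand-ins for a real NodeDef edit.

```python
# Name -> number, copied from the DataType list above (values 0 through 20)
DATA_TYPES = {
    "DT_INVALID": 0, "DT_FLOAT": 1, "DT_DOUBLE": 2, "DT_INT32": 3,
    "DT_UINT8": 4, "DT_INT16": 5, "DT_INT8": 6, "DT_STRING": 7,
    "DT_COMPLEX64": 8, "DT_INT64": 9, "DT_BOOL": 10, "DT_QINT8": 11,
    "DT_QUINT8": 12, "DT_QINT32": 13, "DT_BFLOAT16": 14, "DT_QINT16": 15,
    "DT_QUINT16": 16, "DT_UINT16": 17, "DT_COMPLEX128": 18, "DT_HALF": 19,
    "DT_RESOURCE": 20,
}
TYPE_NAMES = {number: name for name, number in DATA_TYPES.items()}


def set_dtype(node, type_name):
    """Set a toy node's "dtype" attribute by name instead of by raw number."""
    node["attr"]["dtype"] = DATA_TYPES[type_name]  # KeyError on a typo


input_feed = {"name": "input_feed", "op": "Placeholder",
              "attr": {"dtype": DATA_TYPES["DT_INT64"]}}
set_dtype(input_feed, "DT_INT32")
print(TYPE_NAMES[input_feed["attr"]["dtype"]])  # DT_INT32
```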

You can make other changes to the attributes of a node the same way. And you can delete attributes like this ...

In [None]:
# for node in my_graph_def.node:
#     if node.name == "input_feed":
#         if 'acceptable_fraction' in node.attr: del node.attr['acceptable_fraction']
#         if 'channels' in node.attr: del node.attr['channels']
#         if 'fancy_upscaling' in node.attr: del node.attr['fancy_upscaling']
#         if 'ratio' in node.attr: del node.attr['ratio']
#         if 'try_recover_truncated' in node.attr: del node.attr['try_recover_truncated']

If you made a lot of alterations to the graph, it's a good idea to check that the graph is still valid as well (meaning no two nodes share a name and every input a node lists refers to a node that actually exists).

Conveniently, Tensorflow has a function for that!

In [None]:
# from tensorflow.python.tools import optimize_for_inference_lib

# optimize_for_inference_lib.ensure_graph_is_valid(my_graph_def)

That function raises a ValueError if the graph isn't valid. If it returns without complaint, the graph passed those checks.

It won't catch every kind of error, however: some problems (like a missing attribute) only show up when you actually run the inference operation.
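For intuition, here is a sketch of the two structural checks described above (duplicate names, inputs pointing at nodes that don't exist), run on toy dict-based nodes. `check_graph` is a hypothetical illustration, not Tensorflow's implementation; like ensure_graph_is_valid, it says nothing about attributes.

```python
def check_graph(nodes):
    """Raise ValueError for duplicate node names or inputs that point nowhere."""
    names = set()
    for node in nodes:
        if node["name"] in names:
            raise ValueError(f"duplicate node name: {node['name']!r}")
        names.add(node["name"])
    for node in nodes:
        for inp in node["input"]:
            source = inp.lstrip("^").split(":")[0]  # "^x" and "x:1" both mean node "x"
            if source not in names:
                raise ValueError(f"input {inp!r} of node {node['name']!r} not found")


good = [
    {"name": "a", "op": "Placeholder", "input": []},
    {"name": "mul_ab", "op": "Mul", "input": ["a", "a:0"]},
]
check_graph(good)  # passes silently

broken = [{"name": "mul_ab", "op": "Mul", "input": ["a", "add_var"]}]
try:
    check_graph(broken)
except ValueError as err:
    print(err)  # input 'a' of node 'mul_ab' not found
```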

Good luck!