In [1]:
import tensorflow as tf

from graph_utils import Graph_Utils

# Checkpoint location and base name that Graph_Utils loads in the next cell.
ckpt_dir, name = './tf_logs', 'model'

In [2]:
# Restore the checkpoint `name` from `ckpt_dir` and draw its computation graph.
utils = Graph_Utils(name, ckpt_dir)
utils.visualize_graph()

INFO:tensorflow:Restoring parameters from ./tf_logs1\model2


# Inserting a Node

To insert a node C into the path A <-- B, first create C with B as its input, so that C and B are connected. Then call the `reroute()` method so that A reads from C instead of B. This completes the insertion A <-- C <-- B.

In [3]:
# Goal: insert a new 'Abs2' node into the 'Neg' <-- 'Abs' path.
# Step 1: build 'Abs2' with 'Abs' as its input so the two nodes are connected.
abs_tensor = utils.get_tensors(['Abs'])[0]
tf.abs(abs_tensor, name='Abs2')

utils.visualize_graph()

In [4]:
# Step 2: call reroute() to complete the insertion.
# Each spec passed to reroute() has the order
#   [consumer nodes, current input node, node to insert],
# so here 'Neg' stops reading from 'Abs' and reads from 'Abs2' instead,
# giving Neg <-- Abs2 <-- Abs.
# NOTE(review): the consumer entry is a list, presumably so several consumers
# can be rerouted at once -- confirm against Graph_Utils.reroute.
abs2_insertion = [['Neg'], 'Abs', 'Abs2']

utils.reroute([abs2_insertion])
utils.visualize_graph()

INFO:tensorflow:Restoring parameters from ./tf_logs1\model2


# Inserting a Placeholder

To insert a placeholder, first create it and then call the `reroute()` method. A placeholder takes no inputs, so it replaces the original input node rather than sitting between two connected nodes.

In [5]:
# Goal: replace the 'strided_slice' input of 'Abs' with a placeholder.
# Step 1: create the scalar float placeholder that will be inserted.
# The explicit name matches TF's default op name. Later cells look the op up as
# 'Placeholder'. Note that TF appends a suffix (e.g. 'Placeholder_1') if that name
# is already taken, for example when this cell is re-run.
placeholder = tf.placeholder(tf.float32, shape=[], name='Placeholder')
utils.visualize_graph()

In [6]:
# Step 2: call reroute() so 'Abs' reads from 'Placeholder' instead of 'strided_slice'.
# Spec order: [consumer nodes, current input node, node to insert].
# The resulting graph is disconnected: placeholders take no inputs, so
# 'strided_slice' and everything feeding it no longer reach 'Abs'.
utils.reroute([[['Abs'], 'strided_slice', 'Placeholder']])
utils.visualize_graph()

INFO:tensorflow:Restoring parameters from ./tf_logs1\model2


# Removing Nodes

In [7]:
# Deleting nodes takes a single call: pass their names to remove_nodes(), then redraw.
to_remove = ['Neg']
utils.remove_nodes(to_remove)
utils.visualize_graph()

INFO:tensorflow:Restoring parameters from ./tf_logs1\model2


# Saving Graph

In [8]:
# Save the edited graph to ckpt_dir under a new name, leaving the original
# `name` checkpoint in place. Files written to ckpt_dir:
#   model2.data-00000-of-00001
#   model2.index
#   model2.meta
#   checkpoint
saved_name = 'model2'
utils.save(saved_name, ckpt_dir)

# Performing Computations

In [9]:
# Run the edited graph in the live session. After the edits above, the path is
# Abs2 <-- Abs <-- Placeholder, so d|(|x|)|/dx evaluated at x = 1.0 is 1.0.
graph, sess = utils.graph_sess

tensors = utils.get_tensors(['Abs2', 'Placeholder'])
abs2_out, ph_in = tensors[0], tensors[1]

grad = tf.gradients(abs2_out, ph_in)[0]

sess.run(grad, feed_dict={ph_in: 1.})

1.0

# Getting Paths

Given two nodes A and B in a computation graph, the `get_paths()` method returns the names of all the nodes on the path from A to B, including both endpoints. It can optionally visualize that path.

In [10]:
# Start from a clean default graph and reload the checkpoint saved under `name`
# ('model'), not the edited 'model2'. 'Neg' was removed from the working graph
# above, so the 'Neg' -> 'Variable' path is traced in the original model.
tf.reset_default_graph()
utils = Graph_Utils(name, ckpt_dir)
paths = utils.get_paths('Neg', 'Variable', visualize=True)
paths  # last expression -> rich display instead of print()

INFO:tensorflow:Restoring parameters from ./tf_logs1\model2


['Neg', 'Abs', 'strided_slice', 'Variable/read', 'Variable']
