In [1]:
# source: https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

# Saving and restoring tensorflow models

# Components of a tensorflow model:
    - Meta graph
        - A protocol buffer which saves the complete tensorfolow graph with all the variables, operations, collections, etc... The extension is ".meta"
    - Checkpoint file
        - Binary file which contains all the values of the weights, biases, gradients and all the other variables saved. Extension: ".ckpt" (there has been an alterations, now we have two files: .data-0000-of-0001 and .index)
        - .data contains our training variables
        - It also has a file named "checkpoint" which simply keeps a record of latest checkpoint files saved

# Saving a tensorflow model
    - Create an instance of the tf.train.Saver() class
    - Save the model variables from inside the session
    - If we are saving after a certain number of training steps we set
      the parameter: "global_step" to that number 
    - We dont need to save the meta file at every step because the basics parameters will not change across itarations so we can set the parameter in saver to false as shown below
    - We can also specify which variables of the model we want to save
    by setting the variables as elements in a list when instantiating the Saver class.

In [2]:
import tensorflow as tf
import os

tf.reset_default_graph()
w1 = tf.placeholder("float", name = "w1")
w2 = tf.placeholder("float", name = "w2")
b1 = tf.Variable(2.0, name = "bias")
op_add = tf.add(w1,w2, name = "op_add")
op_multiply = tf.multiply(op_add, b1, "op_multiply")
feed_dict = {w1: 3.0, w2: 4.0}
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, "./test_model/test_model")

# To print the list of files in the 
# directory with the elements of the model
print(os.listdir("./test_model/"))

['test_model.data-00000-of-00001', 'test_model.index', 'checkpoint', 'test_model.meta']


In [3]:
# To avoid saving the meta file we do this
saver.save(sess, "./my_test_model", global_step=10, write_meta_graph=False)

'./my_test_model-10'

In [4]:
# T keep only the last 4 models and to set the saving according 
# to a certain number of hours we set the parameters when we instantiate 
# the Saver() class

saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

In [9]:
# To save specific variables we do

saver = tf.train.Saver([w1,w2])

TypeError: Variable to save is not a Variable: Tensor("w2:0", dtype=float32)

# Importing a pre-trained model

- Create the network
- Load the parameters

In [7]:
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph("my_test_model-1000.meta")
    new_saver.restore(sess, tf.train.latest_checkpoint("./"))
    print(sess.run("w1:0"))

# Working with restored models (a template)

In [57]:
#Define the variables of my network

# Prepare to feed input, meaning: feed_dict and placeholders
w1 = tf.placeholder("float", name = "w1")
w2 = tf.placeholder("float", name = "w2")
b1= tf.Variable(2.0, name = "bias")
feed_dict = {w1:4, w2:8}

# Define the test operation that we will restore and the session
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1,name = "op_to_restore")
sess = tf.Session()
var_initializer = tf.global_variables_initializer()
sess.run(var_initializer)

# Create a saver object which will save all the variables

saver = tf.train.Saver()

# Run the operation by feeding the input (will print 24)

print(sess.run(w4, feed_dict))

# Now, save the graph
saver.save(sess, "./my_test_model/my_test_model", global_step = 1000)

24.0


'./my_test_model/my_test_model-1000'

To restore the model we want to also prepare a new feed_dict the will
feed the new training data to the network.
To get the reference to these saved operations and placeholder variables we use "graph.get_tensor_by_name()":

Example_1: w1 = graph.get_tensor_by_name("w1:0")

Example_2: op_restore = graph.get_tensor_by_name("op_to_restore:0")


In [73]:
# Restoring and retraining

import tensorflow

sess=tf.Session()
# get metagraph and restore the weights
saver = tf.train.import_meta_graph("./test_model/test_model.meta")
saver.restore(sess, tf.train.latest_checkpoint("./test_model/"))

# Now access and create placeholders and variables
# and create a new feed dict do feed the new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1:13.0, w2:17.0}
op_to_restore = graph.get_tensor_by_name("op_multiply:0")

#Add more to the current graph

add_on_top = tf.multiply(op_to_restore, 2)

add_on_another_top = tf.multiply(add_on_top, 4)

print(sess.run(add_on_another_top, feed_dict))


INFO:tensorflow:Restoring parameters from ./test_model/test_model
480.0


To get more specific about what you want to restore and add to you can restor part of an old grpah and add on top of that to fine tune using grpah.get_tensor_by_name(), lets look at an example loading from the vgg model. 

WARNING: Do not run this example because we do not have the vgg model data 

In [77]:
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 
 
#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')
 
#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()
 
new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
 
# Now, you run this with fine-tuning data in sess.run()

TypeError: load_vgg() missing 2 required positional arguments: 'sess' and 'vgg_path'