# Load Variables from Existing Checkpoints
#### Without recreating the original computational graph

```python
import os
import tensorflow as tf
```

## CheckpointReader
tf.train.NewCheckpointReader is a handy function that creates a CheckpointReader object for a saved checkpoint. CheckpointReader has several useful methods:
* get_variable_to_shape_map() - returns a dictionary mapping variable names to their shapes
* debug_string() - returns a string describing every variable stored in the checkpoint
* has_tensor(var_name) - checks whether a variable exists in the checkpoint
* get_tensor(var_name) - returns the stored value of a variable as a NumPy array

For illustration, I will define a function that checks that the checkpoint path exists and then loads the checkpoint reader.

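```python
def load_reader(path):
    """Return a CheckpointReader for the checkpoint at `path`."""
    # V1 checkpoints are a single file; V2 checkpoints are a path prefix with
    # `.index` and `.data-*` files next to it, so accept either layout.
    if not (os.path.exists(path) or os.path.exists(path + '.index')):
        raise IOError("Provided incorrect path to the checkpoint. {} doesn't exist".format(path))
    return tf.train.NewCheckpointReader(path)
```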

```python
your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)
```

### - reader.debug_string() returns a string containing the following:
* variable name
* data type
* tensor shape

Each variable takes one line, and its three fields are separated by whitespace. You can therefore split the debug string and take every third element to build lists of variable names and shapes:

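```python
# debug_string() returns bytes under Python 3, so decode before splitting
all_var_descriptions = reader.debug_string().decode('utf-8').split()
# Tokens repeat as (name, dtype, shape): names start at 0, shapes at 2
var_names, var_shapes = all_var_descriptions[::3], all_var_descriptions[2::3]
print(var_names[:4])
print(var_shapes[:4])
```

```
['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']
```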
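Slicing with a stride of 3 works only as long as every line has exactly three fields. A slightly sturdier sketch parses one line at a time and keeps the data type as well. It assumes each line has the form `name (DTYPE) [d0,d1,...]`, with the data type written as a parenthesized token such as `(DT_FLOAT)`; `parse_debug_string` is a hypothetical helper, not part of TensorFlow.

```python
def parse_debug_string(debug_str):
    """Parse CheckpointReader.debug_string() output into {name: (dtype, shape)}.

    Assumes each non-empty line looks like `name (DTYPE) [d0,d1,...]`.
    Any fields after the third are ignored.
    """
    if isinstance(debug_str, bytes):
        # Under Python 3, debug_string() returns bytes.
        debug_str = debug_str.decode("utf-8")
    parsed = {}
    for line in debug_str.splitlines():
        fields = line.split()
        if len(fields) < 3:
            continue  # skip blank or malformed lines
        name = fields[0]
        dtype = fields[1].strip("()")
        # "[1,1,512,64]" -> [1, 1, 512, 64]; "[]" (a scalar) -> []
        dims = fields[2].rstrip(",").strip("[]")
        shape = [int(d) for d in dims.split(",")] if dims else []
        parsed[name] = (dtype, shape)
    return parsed
```

With the reader from above, `parse_debug_string(reader.debug_string())` would give a dictionary keyed by variable name, with shapes as lists of ints rather than strings.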


### - reader.get_variable_to_shape_map() returns a dictionary of variable names and shapes
This is a better method for the same job: variable names are the dictionary keys and their shapes (lists of ints) are the values, so no string parsing is needed.

```python
saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
```

```
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
```
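Because the values are plain lists of ints, the map is easy to post-process. The sketch below uses a hypothetical helper, `summarize_shapes`, to separate the optimizer slot variables (which in this checkpoint carry a `/Momentum` suffix) from the remaining variables and to count how many values those remaining variables hold. The suffix list is an assumption; checkpoints written with other optimizers name their slots differently, so pass your own suffixes.

```python
def summarize_shapes(shape_map, slot_suffixes=("/Momentum",)):
    """Split a {name: shape} map into non-slot variables and a value count.

    Returns (weights, n_values): `weights` maps every variable whose name
    does not end with one of `slot_suffixes` to its shape, and `n_values`
    is the total number of scalars those variables hold.
    """
    weights = {name: shape for name, shape in shape_map.items()
               if not name.endswith(tuple(slot_suffixes))}
    n_values = 0
    for shape in weights.values():
        size = 1
        for dim in shape:
            size *= dim
        n_values += size  # a scalar (shape []) counts as 1
    return weights, n_values
```

For example, `weights, n_values = summarize_shapes(saved_shapes)` would drop every `.../Momentum` entry before counting.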


### - reader.has_tensor(var_name) returns a bool
It is a convenience method for checking whether the variable in question exists in the checkpoint.

```python
names_that_exist = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for name, exists in names_that_exist.items():
    print(name + ':', exists)
```

```
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
```
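Because the names checked above were read from the checkpoint itself, every result is True. The method is more useful for names you expect but have not verified, for example the variables of a new graph you plan to initialize from the checkpoint. A minimal sketch with a hypothetical helper, `find_missing`, that reports the absent names in the order you passed them:

```python
def find_missing(reader, expected_names):
    """Return the names from `expected_names` that the checkpoint lacks.

    `reader` only needs a has_tensor(name) -> bool method, so a
    CheckpointReader works, as does any stand-in object used for testing.
    """
    return [name for name in expected_names if not reader.has_tensor(name)]
```

Calling `find_missing(reader, ['conv1/kernels', 'conv1/biases'])` returns an empty list only if both names are present, which makes it a cheap check to run before reading any values.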


### - reader.get_tensor(tensor_name) returns a NumPy array containing the tensor values from the checkpoint
A typical use is to recover the stored values first and then initialize your own variable with them:

```python
def recover_var(reader, var_name):
    """Return the value of `var_name` stored in the checkpoint as a NumPy array."""
    # Check first so a misspelled name gives a clear error message
    if not reader.has_tensor(var_name):
        raise KeyError("{} variable doesn't exist in the checkpoint. "
                       "Please check the variable name".format(var_name))
    return reader.get_tensor(var_name)
```

```python
checkpoint_var = recover_var(reader, 'conv1/kernels')
print("Recovered variable has the following shape:\n{}".format(checkpoint_var.shape))
# Initialize a new variable in the current graph with the recovered values
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print("New variable will be initialized with recovered values and the following shape:\n{}".format(new_var.get_shape()))
```

```
Recovered variable has the following shape:
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape:
(3, 3, 3, 64)
```
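When a new graph needs several pretrained values, possibly under different names, the single-variable pattern above extends naturally to a dictionary. The sketch below is one way to do it, assuming a hypothetical helper `recover_vars` and a `name_map` that maps each new variable name to the checkpoint name it should be initialized from. All names are checked before anything is read, so a typo fails fast with the full list of missing names.

```python
def recover_vars(reader, name_map):
    """Read several checkpoint values at once.

    `name_map` maps a new variable name to a checkpoint variable name,
    e.g. {"new_conv1": "conv1/kernels"}. Returns {new_name: value}.
    `reader` needs has_tensor(name) and get_tensor(name) methods.
    """
    missing = [ckpt_name for ckpt_name in name_map.values()
               if not reader.has_tensor(ckpt_name)]
    if missing:
        raise KeyError("Not in the checkpoint: {}".format(", ".join(sorted(missing))))
    return {new_name: reader.get_tensor(ckpt_name)
            for new_name, ckpt_name in name_map.items()}
```

Each returned array can then initialize a new variable exactly as above, for example `tf.Variable(initial_value=arrays['new_conv1'], name='new_conv1')` where `arrays = recover_vars(reader, {'new_conv1': 'conv1/kernels'})`.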
