In [1]:
import tensorflow as tf

  from ._conv import register_converters as _register_converters


使用 tf.train.Saver() 创建 Saver 来管理模型中的所有变量。例如，以下代码片段展示了如何调用 tf.train.Saver.save 方法，将变量保存到检查点文件中：

In [2]:
# Start from an empty default graph so this cell can be re-run safely.
# Without the reset, a second run fails in tf.get_variable because the
# variables "v1" and "v2" already exist in the default graph.
tf.reset_default_graph()

# Create two variables, both initialized to zeros.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

# Ops that modify the variables so the saved values differ from the
# initial zeros: v1 is incremented by 1 and v2 is decremented by 1.
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model. Operation.run() executes in the default
  # session, which is `sess` inside this `with` block.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk; save() returns the path prefix of the
  # checkpoint that was written.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)

Model saved in path: /tmp/model.ckpt


In [4]:
# Inspect the tensor returned by v1.assign(v1+1); as the last expression in
# the cell, its repr is displayed as the cell output.
inc_v1

<tf.Tensor 'Assign:0' shape=(3,) dtype=float32_ref>

##### tf.train.Saver 对象不仅可以将变量保存到检查点文件中，还可以恢复变量。请注意，恢复变量时，不必事先将其初始化。例如，以下代码片段展示了如何调用 tf.train.Saver.restore 方法，从检查点文件中恢复变量：

In [2]:
# Discard whatever is in the default graph before rebuilding the variables.
tf.reset_default_graph()

# Re-declare the variables under the same names and shapes that were saved.
# No initializer is given: the values are loaded from the checkpoint below.
v1 = tf.get_variable(name="v1", shape=[3])
v2 = tf.get_variable(name="v2", shape=[5])

# The Saver covers all variables defined so far.
saver = tf.train.Saver()

# Open a session, load the saved values into the variables, then print them.
with tf.Session() as sess:
  # Load the variable values from the checkpoint on disk.
  saver.restore(sess, save_path="/tmp/model.ckpt")
  print("Model restored.")
  # Fetch and show the restored contents of each variable.
  print("v1 : %s" % sess.run(v1))
  print("v2 : %s" % sess.run(v2))

INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
Model restored.
v1 : [1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]


##### 检查某个检查点的变量

In [5]:
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp

In [10]:
# Dump the name and value of every tensor stored in the checkpoint.
chkp.print_tensors_in_checkpoint_file(
    file_name="/tmp/model.ckpt",
    tensor_name='',
    all_tensors=True,
    all_tensor_names=True,
)

tensor_name:  v1
[1. 1. 1.]
tensor_name:  v2
[-1. -1. -1. -1. -1.]


In [13]:
# Show just the tensor named "v1" from the checkpoint.
chkp.print_tensors_in_checkpoint_file(
    file_name="/tmp/model.ckpt",
    tensor_name='v1',
    all_tensors=False,
    all_tensor_names=False,
)

tensor_name:  v1
[1. 1. 1.]


In [14]:
# Show just the tensor named "v2" from the checkpoint.
chkp.print_tensors_in_checkpoint_file(
    file_name="/tmp/model.ckpt",
    tensor_name='v2',
    all_tensors=False,
    all_tensor_names=False,
)

tensor_name:  v2
[-1. -1. -1. -1. -1.]
