In [1]:
# chap 5
# MNIST
from tensorflow.examples.tutorials.mnist import input_datat_data

In [7]:
mnist = input_data.read_data_sets('E:\\project\\TensorFlow\\MNIST', one_hot=True)
print('train data size: ', mnist.train.num_examples)
print('validating data size: ', mnist.validation.num_examples)
print('testing train data: ', mnist.test.num_examples)
print('example training data: ', mnist.train.images[0][:4])
print('example train data label: ', mnist.train.labels[0])

Extracting E:\project\TensorFlow\MNIST\train-images-idx3-ubyte.gz
Extracting E:\project\TensorFlow\MNIST\train-labels-idx1-ubyte.gz
Extracting E:\project\TensorFlow\MNIST\t10k-images-idx3-ubyte.gz
Extracting E:\project\TensorFlow\MNIST\t10k-labels-idx1-ubyte.gz
train data size:  55000
validating data size:  5000
testing train data:  10000
example training data:  [0. 0. 0. 0.]
example train data label:  [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]


In [9]:
batch_size = 100
xs, ys = mnist.train.next_batch(batch_size)
print('xs shape: ', xs.shape)
print('ys shape: ', ys.shape)

xs shape:  (100, 784)
ys shape:  (100, 10)


In [23]:
# 在MNIST数据集中使用神经网络模型实现对手写体数字进行识别
# 本代码中使用了如下优化方法：带指数衰减的学习率设置、使用正则化避免过拟合、
# 使用滑动平均模型
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

input_node = 784    # 输入节点个数等于图像像素个数
output_node = 10    # 0到9的数字，类别个数
layer1_node = 500   # 隐层节点个数
batch_size  = 100

learning_rate_base = 0.8    # 基础学习率
learning_rate_decay = 0.99  # 学习率衰减率
regularization_rate = 0.0001  # 描述模型复杂度的正则化项在损失函数中的系数
train_steps = 30000 
moving_average_decay = 0.99   # 滑动平均衰减率

# 用于计算前向传播
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    if not avg_class:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        return tf.matmul(layer1, weights2) + biases2
    else:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))
        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)
    
def train(mnist):
    x = tf.placeholder(tf.float32, shape=(None, input_node), name = 'x-input')
    y_= tf.placeholder(tf.float32, shape=(None, output_node), name= 'y_input')
    
    weights1 = tf.Variable(tf.truncated_normal([input_node, layer1_node], stddev=0.1))
    biases1 = tf.Variable(tf.constant(0.1, shape=[layer1_node]))
    
    weights2 = tf.Variable(tf.truncated_normal([layer1_node, output_node], stddev=0.1))
    biases2 = tf.Variable(tf.constant(0.1, shape=[output_node]))
    
    # 计算前向传播
    y = inference(x, None, weights1, biases1, weights2, biases2)
    
    # 训练神经网络的时候，一般将代表训练轮数的变量指定为不可训练的参数
    global_step = tf.Variable(0, trainable=False)
    
    # 实例化滑动平均类
    variable_averages = tf.train.ExponentialMovingAverage(moving_average_decay, global_step)
    # 意思是给神经网络图中的所有参数使用滑动平均，其中trainable=False不包括在内
    variable_averages_op = variable_averages.apply(tf.trainable_variables())    
    average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)
    
    #在分类问题中，当结果只有一个正确答案时，可使用如下函数计算交叉熵
    #由于标准答案是一个数组，而如下计算交叉熵的函数需要提供一个数字，因此要使用tf.argmax函数得到答案对应的类别编号
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))   # 计算batch中每个样例的交叉熵
    # 计算当前batch中所有样例的交叉熵
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    regularizer = tf.contrib.layers.l2_regularizer(regularization_rate)
    # 计算模型的正则化损失，一般只计算边的权重
    regularization = regularizer(weights1) + regularizer(weights2)
    # 模型的总损失为交叉熵和正则化损失之和
    loss = cross_entropy_mean + regularization
    # 设置指数衰减学习率
    learning_rate = tf.train.exponential_decay(learning_rate_base, 
                                              global_step,   # 当前迭代轮数
                                              mnist.train.num_examples / batch_size,    # 过完所有数据需要的迭代次数
                                              learning_rate_decay)    # 学习率衰减速度
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # 以下两种实现方法都可通过反向传播更新网络参数和参数的滑动平均值
    #train_op = tf.group(train_step, variable_averages_op)
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name='train')
    
    # 验证滑动平均的神经网络模型前向传播是否正确
    # average_y是一个二维数组，共有batch_size行，10列， tf.argmax是将第一个参数中的每行中的最大值的下标取出形成一个列表
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))     # 比较结果是bool型变量
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))    # 将bool型变量转换成实数，并计算均值
    
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}        
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}
        print('weitht1: ', weights1.eval())
        print('weitht2: ', weights2.eval())
        for i in range(train_steps):
            if not (i % 1000):
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print('after %d training steps, validation accuracy using average model is %g' % (i, validate_acc))
            
                
            xs, ys = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={x: xs, y_: ys})
            
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print ('finall accaracy using average model is: ', test_acc)
        print('weitht1: ', weights1.eval())
        print('weitht2: ', weights2.eval())
        
def main(argv=None):
    mnist = input_data.read_data_sets('E:\\project\\TensorFlow\\MNIST', one_hot=True)
    train(mnist)
    
if __name__ == '__main__':
    tf.app.run()

Extracting E:\project\TensorFlow\MNIST\train-images-idx3-ubyte.gz
Extracting E:\project\TensorFlow\MNIST\train-labels-idx1-ubyte.gz
Extracting E:\project\TensorFlow\MNIST\t10k-images-idx3-ubyte.gz
Extracting E:\project\TensorFlow\MNIST\t10k-labels-idx1-ubyte.gz
weitht1:  [[-0.14306453 -0.18635322  0.1099795  ... -0.09052034 -0.08416008
  -0.09471335]
 [ 0.08705756  0.03762971  0.04566057 ... -0.05304244 -0.05849532
   0.04415581]
 [ 0.09139423 -0.12043717  0.02521717 ...  0.11786151 -0.05820546
   0.08825582]
 ...
 [ 0.03652653  0.0123915   0.03455595 ... -0.04015323 -0.00131873
  -0.00719508]
 [ 0.15182127  0.07061632 -0.04262719 ... -0.0628676  -0.10799395
  -0.07982754]
 [-0.06275754  0.03779794  0.0351049  ... -0.00591473 -0.01883379
   0.02244921]]
weitht2:  [[ 0.17716314  0.14353247  0.0877187  ... -0.04124888 -0.12901127
  -0.02800767]
 [ 0.05315333  0.04188048 -0.00701133 ... -0.05572772  0.08595856
  -0.16677698]
 [-0.01312683 -0.03935822  0.06776911 ... -0.10528753  0.0299221

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [33]:
# 变量管理: variable_scope  get_variable
with tf.variable_scope('foo'):
    v = tf.get_variable('v', [1], initializer=tf.constant_initializer(1.0))
  

In [37]:

with tf.variable_scope('foo'):
    v = tf.get_variable('v', [1])   # 将报错,因此在foo这个上下文空间中已经创建了v变量
  

In [36]:

with tf.variable_scope('foo', reuse=True): 
    v = tf.get_variable('v', [1])    # 不会报错，因为使用了reuse参数


In [None]:
# 使用变量管理的方法重写以上神经网络层的定义
def inference(input_tensor, reuse=False):
    # 第一次构建网络时需要创建新的变量，之后每次调用这个函数都可以直接使用。
    with tf.variable_scope('layer1', reuse=reuse):
        weights = tf.get_variable('weights', [input_node, layer1_node],initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable('biases', [layer1_node], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
        
    with tf.variable_scopei('layer2', reuse=reuse):
        weights = tf.get_variable('weights', [layer1_node, output_node], initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable('biases', [output_node], initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases
        
    return layer2     # 神经网络最后的前向传播结果

x = tf.placeholder(tf.float32, [None, input_node], name='x-input')
y = inference(x)

newx = ...
new_y = inference(newx, True)

In [42]:
# 持久化，将模型保存到文件中
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, 'E:\\project\\TensorFlow\\save_model\\model.ckpt')

In [43]:
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, 'E:\\project\\TensorFlow\\save_model\\model.ckpt')
    print(sess.run(result))

INFO:tensorflow:Restoring parameters from E:\project\TensorFlow\save_model\model.ckpt


NotFoundError: Key v1_4 not found in checkpoint
	 [[Node: save_4/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_4/Const_0_0, save_4/RestoreV2/tensor_names, save_4/RestoreV2/shape_and_slices)]]

Caused by op 'save_4/RestoreV2', defined at:
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
    app.start()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tornado\ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2698, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2802, in run_ast_nodes
    if self.run_code(code, result):
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-43-37de42c1a475>", line 5, in <module>
    saver = tf.train.Saver()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 1311, in __init__
    self.build()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 1320, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 1357, in _build
    build_save=build_save, build_restore=build_restore)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 809, in _build_internal
    restore_sequentially, reshape)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 448, in _AddRestoreOps
    restore_sequentially)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 860, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gen_io_ops.py", line 1541, in restore_v2
    shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 3290, in create_op
    op_def=op_def)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 1654, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

NotFoundError (see above for traceback): Key v1_4 not found in checkpoint
	 [[Node: save_4/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_4/Const_0_0, save_4/RestoreV2/tensor_names, save_4/RestoreV2/shape_and_slices)]]


In [45]:
saver = tf.train.import_meta_graph('E:\\project\\TensorFlow\\save_model\\model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess, 'E:\\project\\TensorFlow\\save_model\\model.ckpt')
    print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))

INFO:tensorflow:Restoring parameters from E:\project\TensorFlow\save_model\model.ckpt


InvalidArgumentError: You must feed a value for placeholder tensor 'x-input' with dtype float and shape [?,784]
	 [[Node: x-input = Placeholder[dtype=DT_FLOAT, shape=[?,784], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'x-input', defined at:
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
    app.start()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tornado\ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2698, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2802, in run_ast_nodes
    if self.run_code(code, result):
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-12-67766be83910>", line 95, in <module>
    tf.app.run()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 126, in run
    _sys.exit(main(argv))
  File "<ipython-input-12-67766be83910>", line 92, in main
    train(mnist)
  File "<ipython-input-12-67766be83910>", line 28, in train
    x = tf.placeholder(tf.float32, shape=(None, input_node), name = 'x-input')
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\array_ops.py", line 1777, in placeholder
    return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 5496, in placeholder
    "Placeholder", dtype=dtype, shape=shape, name=name)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 3290, in create_op
    op_def=op_def)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 1654, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'x-input' with dtype float and shape [?,784]
	 [[Node: x-input = Placeholder[dtype=DT_FLOAT, shape=[?,784], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]


In [47]:
#利用convert_to_constants函数将计算图中的变量及其取值通过常量的方式保存，
#这种方法可将计算图统一保存到一个文件中
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])   # 只保存相加的那个节点(每个计算就是图中的一个节点)
    with tf.gfile.GFile('E:\\project\\TensorFlow\\save_model\\one_file\\combined_model.pb','wb') as f:
        f.write(output_graph_def.SerializeToString())
    

INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.


In [48]:
# 将以下代码放到一个新文件中即可运行成功。
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = 'E:\\project\\TensorFlow\\save_model\\one_file\\combined_model.pb'
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        
    result = tf.import_graph_def(graph_def, return_elements=['add:0'])
    print(sess.run(result))

InvalidArgumentError: You must feed a value for placeholder tensor 'import/x-input' with dtype float and shape [?,784]
	 [[Node: import/x-input = Placeholder[dtype=DT_FLOAT, shape=[?,784], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'import/x-input', defined at:
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
    app.start()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tornado\ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2698, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2802, in run_ast_nodes
    if self.run_code(code, result):
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-48-62a8923480a1>", line 9, in <module>
    result = tf.import_graph_def(graph_def, return_elements=['add:0'])
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\util\deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\importer.py", line 577, in import_graph_def
    op_def=op_def)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 3290, in create_op
    op_def=op_def)
  File "E:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 1654, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'import/x-input' with dtype float and shape [?,784]
	 [[Node: import/x-input = Placeholder[dtype=DT_FLOAT, shape=[?,784], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
