In [1]:
# TensorFlow 保存和导入计算图中的部分节点
#保存计算图中的节点
import numpy as np
import tensorflow as tf
from tensorflow.compat.v1 import graph_util
from tensorflow.python.platform import gfile

#建立测试用图
tf.reset_default_graph()

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')

res = tf.add(v1, v2, name='add_res')
res2 = tf.add(res, v2, name='add_res1')

graph_def = tf.get_default_graph().as_graph_def()
for node in graph_def.node:
    print(node.name)
print('show graph\n')

#保存图
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    #获取到达add_res1的图    
    #convert_variables_to_constants：通过这个函数可以将计算图中的变量及其取值通过常量保存。
    #add_res1 没有“:0”，表示这是计算节点（operator），而“add_res:0” 表示节点计算后的输出张量。
    #这里的节点存在依赖关系，会自动查找所依赖的节点，并将相关的节点一起保存
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add_res1'])
    
for node in output_graph_def.node:
    print(node.name)
print('graph_def convert\n')
    
#文件句柄，保存文件。
with gfile.GFile('save/combined_model.pb', 'wb') as f:
    f.write(output_graph_def.SerializeToString())
#文件句柄，保存文件。
# tf.train.write_graph(graph_def, export_dir, 'expert-graph.pb', as_text=False)

Instructions for updating:
Colocations handled automatically by placer.
Const
v1
v1/Assign
v1/read
Const_1
v2
v2/Assign
v2/read
add_res
add_res1
show graph

Instructions for updating:
Use tf.compat.v1.graph_util.convert_variables_to_constants
Instructions for updating:
Use tf.compat.v1.graph_util.extract_sub_graph
INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.
v1
v1/read
v2
v2/read
add_res
add_res1
graph_def convert



In [2]:
#清空当前图
tf.reset_default_graph()

#导入已保存的节点
model_filename = 'save/combined_model.pb'
with gfile.GFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    
#这里导入了之前保存的节点，并且“:0”表示该节点的第一次输出的结果
#注意，每次使用该函数，就导入一次图中的节点，故对于关联的节点，必须同时导入才能相关，否则属于不同的组
result1 = tf.import_graph_def(graph_def, return_elements=['v1:0']) #导入节点
result2 = tf.import_graph_def(graph_def, return_elements=['add_res1:0'])#重复导入节点，导致冗余
graph_def = tf.get_default_graph().as_graph_def()
for node in graph_def.node:
    print(node.name)
print('show graph import\n')

with tf.Session() as sess:
    print(sess.run(result1))
    print(sess.run(result2))

import/v1
import/v1/read
import/v2
import/v2/read
import/add_res
import/add_res1
import_1/v1
import_1/v1/read
import_1/v2
import_1/v2/read
import_1/add_res
import_1/add_res1
show graph import

[array([1.], dtype=float32)]
[array([5.], dtype=float32)]


In [3]:
#导入节点，并修改对节点进行修改
tf.reset_default_graph()

model_filename = 'save/combined_model.pb'
with gfile.GFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    print(node.name)
print('graph_def read\n')

#修改图节点
input_node_names_list = ['v2','add_res']
inputs_replaced_graph_def = tf.GraphDef()
for node in graph_def.node:
    if node.name in input_node_names_list:
        placeholder_node = tf.NodeDef()
        placeholder_node.op = "Placeholder"
        placeholder_node.name = node.name
        placeholder_node.attr["dtype"].CopyFrom(
            tf.AttrValue(type=tf.float32.as_datatype_enum))
        inputs_replaced_graph_def.node.extend([placeholder_node])
    else:
        old_node = tf.NodeDef()
        old_node.CopyFrom(node)
        inputs_replaced_graph_def.node.extend([old_node])

for node in inputs_replaced_graph_def.node:
    print(node.name)
print('inputs_replaced_graph_def read\n')

#抽取子图
output_graph_def = graph_util.extract_sub_graph(
        inputs_replaced_graph_def, ['add_res1'])

for node in output_graph_def.node:
    print(node.name)
print('output_graph_def read\n')

v1
v1/read
v2
v2/read
add_res
add_res1
graph_def read

v1
v1/read
v2
v2/read
add_res
add_res1
inputs_replaced_graph_def read

v2
v2/read
add_res
add_res1
output_graph_def read



In [4]:
code1 = np.array([[5,],[4,]], dtype = np.float32)
# code2 = np.array([5,], dtype = np.float32)
code2 = code1

In [5]:
# 方法一
tf.reset_default_graph()

#导入tensor和graph_def
input_x,input_y,result = tf.import_graph_def(output_graph_def, return_elements=["v2:0","add_res:0","add_res1:0"])

graph_def = tf.get_default_graph().as_graph_def()
for node in graph_def.node:
    print(node.name)
print('show graph\n')

with tf.Session() as sess:
    print (input_x),print (input_y),print (result)
    print(sess.run(result,feed_dict={input_x:code1,input_y:code2}))

import/v2
import/v2/read
import/add_res
import/add_res1
show graph

Tensor("import/v2:0", dtype=float32)
Tensor("import/add_res:0", dtype=float32)
Tensor("import/add_res1:0", dtype=float32)
[[10.]
 [ 8.]]


In [6]:
#方法二
tf.reset_default_graph()

tf.import_graph_def(output_graph_def, name='') 

graph_def = tf.get_default_graph().as_graph_def()
for node in graph_def.node:
    print(node.name)
print('show graph\n')

with tf.Session() as sess:
    input_x = sess.graph.get_tensor_by_name("v2:0")
    input_y = sess.graph.get_tensor_by_name("add_res:0")
    result = sess.graph.get_tensor_by_name("add_res1:0")
    
    print (input_x),print (input_y),print (result)
    print(sess.run(result,feed_dict={input_x:code1,input_y:code2}))

v2
v2/read
add_res
add_res1
show graph

Tensor("v2:0", dtype=float32)
Tensor("add_res:0", dtype=float32)
Tensor("add_res1:0", dtype=float32)
[[10.]
 [ 8.]]


In [7]:
#方法三
tf.reset_default_graph()

result, = tf.import_graph_def(output_graph_def, return_elements=["add_res1:0"])#返回的是一个列表，取首元

graph_def = tf.get_default_graph().as_graph_def()
for node in graph_def.node:
    print(node.name)
print('show graph\n')

with tf.Session() as sess:
    input_x = sess.graph.get_tensor_by_name("import/v2:0")
    input_y = sess.graph.get_tensor_by_name("import/add_res:0")
    
    print (input_x),print (input_y),print (result)
    print(sess.run(result,feed_dict={input_x:code1,input_y:code2}))

import/v2
import/v2/read
import/add_res
import/add_res1
show graph

Tensor("import/v2:0", dtype=float32)
Tensor("import/add_res:0", dtype=float32)
Tensor("import/add_res1:0", dtype=float32)
[[10.]
 [ 8.]]


In [8]:
#其他实例
import tensorflow as tf
import os
import numpy as np
from tensorflow.python.platform import gfile

data = np.arange(10, dtype=np.int32)
with tf.Session() as sess:
    print("# build graph and run")
    input1 = tf.placeholder(tf.int32, [10], name="input")
    output1 = tf.add(input1, tf.constant(100, dtype=tf.int32),
                     name="output")  #  data depends on the input data
    saved_result = tf.Variable(data, name="saved_result")
    do_save = tf.assign(saved_result, output1)
    tf.initialize_all_variables()
    os.system("rm -rf /tmp/load")
    tf.train.write_graph(sess.graph_def, "/tmp/load", "test.pb", False)  #proto
    # now set the data:
    result, _ = sess.run(
        [output1, do_save],
        {input1: data})  # calculate output1 and assign to 'saved_result'
    saver = tf.train.Saver(tf.all_variables())
    saver.save(sess, "checkpoint.data")

with tf.Session() as persisted_sess:
    print("load graph")
    with gfile.FastGFile("/tmp/load/test.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        persisted_sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
    print("map variables")
    persisted_result = persisted_sess.graph.get_tensor_by_name(
        "saved_result:0")
    tf.add_to_collection(tf.GraphKeys.VARIABLES, persisted_result)
    try:
        saver = tf.train.Saver(
            tf.all_variables())  # 'Saver' misnomer! Better: Persister!
    except:
        pass
    print("load data")
    saver.restore(persisted_sess, "checkpoint.data")  # now OK
    print(persisted_result.eval())
    print("DONE")

# build graph and run
Instructions for updating:
Use `tf.global_variables_initializer` instead.
Object was never used (type <class 'tensorflow.python.framework.ops.Operation'>):
<tf.Operation 'init' type=NoOp>
If you want to mark it as used call its "mark_used()" method.
It was originally created here:
  File "D:\MyApp\Anaconda3\envs\tf1\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)  File "D:\MyApp\Anaconda3\envs\tf1\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)  File "D:\MyApp\Anaconda3\envs\tf1\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()  File "D:\MyApp\Anaconda3\envs\tf1\lib\site-packages\traitlets\config\application.py", line 664, in launch_instance
    app.start()  File "D:\MyApp\Anaconda3\envs\tf1\lib\site-packages\ipykernel\kernelapp.py", line 563, in start
    self.io_loop.start()  File "D:\MyApp\Anaconda3\envs\tf1\lib\site-packages\tornado\platform\asyncio.py", line 148, in start
  