In [1]:
import tensorflow as tf
from google.protobuf import text_format

`google.protobuf.text_format`을 사용하려면 protobuf 패키지가 필요하며, 다음 명령으로 설치
``` bash
pip install protobuf
```

### Loading GraphDef

Binary File 이용

In [4]:
# Read the serialized (binary) GraphDef and deserialize it into a fresh message.
with tf.gfile.GFile(name='saved/mnist_graphdef.pb', mode='rb') as pb_file:
    graph_def_binary = pb_file.read()
graph_def1 = tf.GraphDef()
graph_def1.ParseFromString(graph_def_binary)

Text file 이용

In [2]:
# Read the human-readable (text format) GraphDef and merge it into a fresh message;
# Merge returns the message, so the cell displays the parsed graph.
with tf.gfile.GFile(name='saved/mnist_graphdef.pbtxt', mode='r') as pbtxt_file:
    graph_def_text = pbtxt_file.read()
graph_def2 = tf.GraphDef()
text_format.Merge(text=graph_def_text, message=graph_def2)

node {
  name: "Placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 784
        }
      }
    }
  }
}
node {
  name: "Reshape/shape"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 4
          }
        }
        tensor_content: "\377\377\377\377\034\000\000\000\034\000\000\000\001\000\000\000"
      }
    }
  }
}
node {
  name: "Reshape"
  op: "Reshape"
  input: "Placeholder"
  input: "Reshape/shape"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "Tshape"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Placeholder_1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    

In [5]:
graph_def2 == graph_def1  # binary and text loading should yield the same GraphDef

True

### Importing GraphDef
Using trained model and graph from https://github.com/hunkim/DeepLearningZeroToAll/blob/master/lab-11-2-mnist_deep_cnn.py

In [6]:
# Keep a handle on the notebook's default graph; the GraphDef is imported into it next.
g = tf.get_default_graph()

In [7]:
# Import every node of graph_def1 into g. name='' keeps the original node names
# (no "import/" prefix), so lookups such as 'Equal:0' work unchanged.
with g.as_default():
    tf.import_graph_def(graph_def=graph_def1, name='')

In [8]:
list(g.get_operations())  # every op now present in g, in creation order

[<tf.Operation 'Placeholder' type=Placeholder>,
 <tf.Operation 'Reshape/shape' type=Const>,
 <tf.Operation 'Reshape' type=Reshape>,
 <tf.Operation 'Placeholder_1' type=Placeholder>,
 <tf.Operation 'random_normal/shape' type=Const>,
 <tf.Operation 'random_normal/mean' type=Const>,
 <tf.Operation 'random_normal/stddev' type=Const>,
 <tf.Operation 'random_normal/RandomStandardNormal' type=RandomStandardNormal>,
 <tf.Operation 'random_normal/mul' type=Mul>,
 <tf.Operation 'random_normal' type=Add>,
 <tf.Operation 'Variable' type=VariableV2>,
 <tf.Operation 'Variable/Assign' type=Assign>,
 <tf.Operation 'Variable/read' type=Identity>,
 <tf.Operation 'Conv2D' type=Conv2D>,
 <tf.Operation 'Relu' type=Relu>,
 <tf.Operation 'MaxPool' type=MaxPool>,
 <tf.Operation 'random_normal_1/shape' type=Const>,
 <tf.Operation 'random_normal_1/mean' type=Const>,
 <tf.Operation 'random_normal_1/stddev' type=Const>,
 <tf.Operation 'random_normal_1/RandomStandardNormal' type=RandomStandardNormal>,
 <tf.Operati

예측이 정답과 일치하는지 여부를 반환하는 node (`tf.equal`로 만들어져 이름이 `Equal`이 됨)
```python
# Test model and check accuracy
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(Y, 1))
```

In [9]:
# First output of the `Equal` op (the per-example correctness vector).
g.get_tensor_by_name(name='Equal:0')

<tf.Tensor 'Equal:0' shape=(?,) dtype=bool>

In [10]:
list(map(lambda node_def: node_def.name, graph_def1.node))  # names of all NodeDefs

['Placeholder',
 'Reshape/shape',
 'Reshape',
 'Placeholder_1',
 'random_normal/shape',
 'random_normal/mean',
 'random_normal/stddev',
 'random_normal/RandomStandardNormal',
 'random_normal/mul',
 'random_normal',
 'Variable',
 'Variable/Assign',
 'Variable/read',
 'Conv2D',
 'Relu',
 'MaxPool',
 'random_normal_1/shape',
 'random_normal_1/mean',
 'random_normal_1/stddev',
 'random_normal_1/RandomStandardNormal',
 'random_normal_1/mul',
 'random_normal_1',
 'Variable_1',
 'Variable_1/Assign',
 'Variable_1/read',
 'Conv2D_1',
 'Relu_1',
 'MaxPool_1',
 'Reshape_1/shape',
 'Reshape_1',
 'W3/Initializer/random_uniform/shape',
 'W3/Initializer/random_uniform/min',
 'W3/Initializer/random_uniform/max',
 'W3/Initializer/random_uniform/RandomUniform',
 'W3/Initializer/random_uniform/sub',
 'W3/Initializer/random_uniform/mul',
 'W3/Initializer/random_uniform',
 'W3',
 'W3/Assign',
 'W3/read',
 'random_normal_2/shape',
 'random_normal_2/mean',
 'random_normal_2/stddev',
 'random_normal_2/RandomS

`name`을 명시하지 않으면 `Equal`, `ArgMax_3`처럼 자동 생성된 이름으로 노드를 찾아야 해서 어려우므로, 나중에 다시 사용할 노드에는 이름을 명시적으로 붙이는 것이 좋음

### `GraphDef`를 이용하여 Node 연결
https://www.tensorflow.org/api_docs/python/tf/import_graph_def
```python
tf.import_graph_def(
    graph_def,
    input_map=None,
    return_elements=None,
    name=None,
    op_dict=None,
    producer_op_list=None
)
```

`input_map`: `GraphDef` 내부의 tensor 이름(예: `'Reshape:0'`)을 key로, 현재 그래프에 이미 존재하는 `Tensor`를 value로 하는 `dict`. import 시 해당 입력을 주어진 `Tensor`로 대체함

`return_elements`: 반환받고 싶은 `GraphDef` 내 op 또는 tensor 이름의 목록. `'ArgMax_3:0'`처럼 output index를 붙이면 `Tensor`를, `'ArgMax_3'`처럼 op 이름만 쓰면 `Operation`을 반환함

In [19]:
from tensorflow.examples.tutorials.mnist import input_data

# Graph-level seed for the current default graph (set before loading, for reproducibility).
tf.set_random_seed(seed=777)
mnist = input_data.read_data_sets(train_dir="MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [29]:
# First test example, given batch and channel axes so it matches the
# [-1, 28, 28, 1] shape produced by `Reshape:0` in the saved graph.
img = mnist.test.images[0].reshape(1, 28, 28, 1)
label = mnist.test.labels[0]

In [57]:
# Import the saved GraphDef into a fresh graph, feeding a constant image in place
# of `Reshape:0` (the reshaped output of the input placeholder).
new_g = tf.Graph()
with new_g.as_default():
    tf_img = tf.constant(img)
    # Request the *tensor* 'ArgMax_3:0'. The bare op name 'ArgMax_3' would return an
    # Operation, and running an Operation yields None instead of the prediction.
    pred, = tf.import_graph_def(graph_def1,
                                input_map={'Reshape:0': tf_img},
                                return_elements=['ArgMax_3:0'],
                                name='')

In [58]:
# Expected to fail (kept as a demonstration): variables created by
# tf.import_graph_def are not registered in the GLOBAL_VARIABLES collection, so
# global_variables_initializer() does not initialize them, and the GraphDef holds
# no trained weight values anyway. The error is caught so Run All still completes.
with tf.Session(graph=new_g) as sess:
    sess.run(tf.global_variables_initializer())
    try:
        print(sess.run(pred))
    except tf.errors.FailedPreconditionError as err:
        print('Expected failure:', err.message)

`tf.Variable`이든 `tf.get_variable`이든 변수 op(예: `W3`, `Variable`)는 `GraphDef`에 포함되지만, 학습된 변수 값은 `GraphDef`에 저장되지 않음!!

#### 다른 문제점

In [18]:
# NOTE(review): the recorded output lists Variable_3, which is only created by the
# later `tf.Variable(1)` cell; this cell was executed out of order.
g.get_collection(name=tf.GraphKeys.TRAINABLE_VARIABLES)

[<tf.Variable 'Variable_3:0' shape=() dtype=int32_ref>]

In [14]:
# Look up the imported variable *op* (an Operation, not a tf.Variable object).
var2 = g.as_graph_element('Variable_2', allow_tensor=False, allow_operation=True)

In [15]:
tf.Variable(initial_value=1)  # created in Python, so it does get registered in the collections

<tf.Variable 'Variable_3:0' shape=() dtype=int32_ref>

In [17]:
# Only the Python-created variable appears; the imported Variable/Variable_1/... ops do not.
g.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=None)

[<tf.Variable 'Variable_3:0' shape=() dtype=int32_ref>]

#### Import한 `Graph`의 `Variable`이 `GraphKeys`에 등록이 되지 않음.

https://github.com/tensorflow/tensorflow/issues/696

https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model/33763208#33763208

### Saver와 MetaGraphDef를 이용
https://www.tensorflow.org/api_docs/python/tf/train/import_meta_graph

https://blog.simonszu.de/2018/01/tensorflow---connect-two-graphs/

In [83]:
# Rebuild the saved graph from its MetaGraphDef (which also yields a Saver for its
# variables), substituting the test image for the tensor `Reshape:0`.
new_g = tf.Graph()
with new_g.as_default():
    tf_img = tf.constant(img)
    new_saver = tf.train.import_meta_graph(
        meta_graph_or_file='saved/mnist.ckpt.meta',
        input_map={'Reshape:0': tf_img},
    )

In [84]:
# Restore trained weights into the rebuilt graph, then predict the test image's class.
with tf.Session(graph=new_g) as sess:
    # NOTE(review): the meta graph was loaded from 'saved/' but weights are restored
    # from the working directory; confirm both come from the same checkpoint.
    new_saver.restore(sess=sess, save_path='mnist.ckpt')
    pred = new_g.get_tensor_by_name(name='ArgMax_3:0')
    print(sess.run(fetches=pred))

INFO:tensorflow:Restoring parameters from mnist.ckpt
[7]
