In [3]:
import tensorflow as tf

In [4]:
def freeze_session(output_node_names, destination, name="frozen_model.pb"):
    """Freeze the current Keras session graph into a binary GraphDef (.pb) file.

    Every variable needed to compute ``output_node_names`` is replaced by a constant
    holding its current value, and unrelated nodes are pruned, so the written file
    does not need a separate checkpoint.

    Args:
        output_node_names: op names (without the ``:0`` tensor suffix) of the graph
            outputs to keep, e.g. ``["dense_1/Softmax"]``. A single string or any
            iterable of strings is accepted.
        destination: directory to write into (passed as ``logdir`` to
            ``tf.train.write_graph``).
        name: file name of the written graph.

    Returns:
        The path of the written file, as returned by ``tf.train.write_graph``.

    Side effects:
        Calls ``tf.keras.backend.set_learning_phase(0)``, which changes global Keras
        state for the rest of the kernel session.
    """
    if isinstance(output_node_names, str):
        # Accept a single node name for convenience; the TF converter expects a list.
        output_node_names = [output_node_names]

    # NOTE(review): setting the learning phase after the model is built may not
    # rewrite ops already in the graph (e.g. Dropout/BatchNorm). This model has
    # neither, but confirm before reusing this helper on other architectures.
    tf.keras.backend.set_learning_phase(0)  # inference phase
    session = tf.keras.backend.get_session()
    # GraphDef proto of the Keras session's graph; it still references variables.
    input_graph_def = session.graph.as_graph_def()

    with session.as_default():
        # Convert variables into constants so they are stored inside the GraphDef.
        # Uses the tf.compat.v1 path recommended by the deprecation warning.
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_node_names=list(output_node_names))
        path = tf.train.write_graph(
            graph_or_graph_def=output_graph_def, logdir=destination, name=name, as_text=False)

    # The session is deliberately not cleared: later cells may still use the model.
    return path

In [5]:
# Load the Keras-bundled MNIST digits as (images, labels) pairs for the train and test splits.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

In [6]:
# Turn the raw MNIST arrays into model inputs:
#   images -> float32 in [0, 1] with an explicit channel axis, shape (N, 28, 28, 1)
#   labels -> one-hot vectors over the 10 digit classes
# The ndim guards keep this cell idempotent. Re-running it must not divide the images
# by 255 a second time or one-hot encode labels that are already one-hot.
NUM_CLASSES = 10  # digits 0-9

if train_images.ndim == 3:  # still raw (N, 28, 28)
    train_images = train_images.reshape((-1, 28, 28, 1)).astype('float32') / 255
if test_images.ndim == 3:
    test_images = test_images.reshape((-1, 28, 28, 1)).astype('float32') / 255

# Pass num_classes explicitly so the one-hot width never depends on which labels occur.
if train_labels.ndim == 1:
    train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=NUM_CLASSES)
if test_labels.ndim == 1:
    test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=NUM_CLASSES)

In [7]:
# Small CNN classifier:
#   conv(32) -> pool -> conv(64) -> pool -> conv(64) -> flatten -> dense(64) -> softmax(10).
# Keras auto-names the layers (e.g. 'conv2d_input', 'dense_1/Softmax' printed below).
# Later cells therefore read tensor names from the model instead of hard-coding them.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

Instructions for updating:
Colocations handled automatically by placer.


In [8]:
# Categorical cross-entropy pairs with the one-hot labels and the softmax output layer.
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

In [9]:
# Training hyperparameters as named constants so they are easy to find and tune.
EPOCHS = 5
BATCH_SIZE = 128

# Keep the returned History (per-epoch metrics in history.history) for later inspection,
# rather than letting its bare repr become the cell output.
history = model.fit(train_images, train_labels, epochs=EPOCHS, batch_size=BATCH_SIZE)

Instructions for updating:
Use tf.cast instead.
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x2338e710>

In [10]:
# Score the held-out test split. With metrics=['accuracy'], evaluate yields (loss, accuracy).
test_metrics = model.evaluate(test_images, test_labels)
test_loss, test_accuracy = test_metrics
print('test_accuracy:', test_accuracy)

test_accuracy: 0.9905


In [11]:
# Tensor name of the model input, in '<op_name>:<output_index>' form.
input_tensor_name = model.input.name
print(input_tensor_name)

conv2d_input:0


In [12]:
# Tensor name of the model output, in '<op_name>:<output_index>' form.
output_tensor_name = model.output.name
print(output_tensor_name)

dense_1/Softmax:0


In [13]:
# The freezer expects the bare op name, so drop the ':<output_index>' suffix of the tensor name.
output_layer_name, _sep, _output_index = model.output.name.partition(':')
print(output_layer_name)

dense_1/Softmax


In [14]:
# Export the trained network as a single frozen GraphDef named MNIST.pb.
# An empty destination presumably means the file lands relative to the current
# working directory. TODO confirm for the deployment setup.
freeze_session(
    [output_layer_name],
    destination='',
    name="MNIST.pb",
)

Instructions for updating:
Use tf.compat.v1.graph_util.convert_variables_to_constants
Instructions for updating:
Use tf.compat.v1.graph_util.extract_sub_graph
INFO:tensorflow:Froze 10 variables.
INFO:tensorflow:Converted 10 variables to const ops.
