In [1]:
"""tensorflow - 1.3.0
python - 3.6.3"""

import os
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

from tensorflow.examples.tutorials.mnist import input_data

# Basename shared by every artifact written under out/
# (.pbtxt graph, .chkp checkpoint, frozen_*.pb and opt_*.pb exports).
MODEL_NAME = 'mnist_for_android_convnet'

# Training schedule: number of optimizer steps and images per mini-batch.
NUM_STEPS, BATCH_SIZE = 3000, 16

# Short alias for os.path (used by main() to check for the output directory).
os_path = os.path


# model input 
def model_input(input_node_name, keep_prob_node_name):
    """Create the three input placeholders of the graph.

    Args:
        input_node_name: name given to the flattened-image placeholder; the
            same name is later handed to optimize_for_inference on export.
        keep_prob_node_name: name given to the dropout keep-probability
            placeholder (created without a static shape).

    Returns:
        Tuple ``(x, keep_prob, y_)``: float32 images of shape [None, 784],
        the keep-probability placeholder, and float32 labels of shape
        [None, 10] (fed one-hot by train()).
    """
    pixels_per_image = 28 * 28
    num_classes = 10

    # Creation order is kept as-is so auto-generated op names are unchanged.
    images = tf.placeholder(tf.float32, shape=[None, pixels_per_image], name=input_node_name)
    keep_probability = tf.placeholder(tf.float32, name=keep_prob_node_name)
    labels = tf.placeholder(tf.float32, shape=[None, num_classes])

    return images, keep_probability, labels

In [2]:
# build the model
def build_model(x, keep_prob, y_, output_node_name, learning_rate=1e-4):
    """Build the convnet, its loss, training op, accuracy, and summaries.

    Architecture: three (3x3 conv + ReLU, 2x2/stride-2 'same' max-pool) stages
    with 64/128/256 filters, then a 1024-unit ReLU dense layer, dropout, and a
    10-way dense layer whose softmax is the exported output node.

    Args:
        x: flattened image placeholder of shape [None, 784].
        keep_prob: dropout keep-probability placeholder (fed 0.5 while
            training and 1.0 for evaluation by train()).
        y_: one-hot label placeholder of shape [None, 10].
        output_node_name: name given to the softmax output op; this is the
            node frozen and optimized by export_model().
        learning_rate: Adam learning rate (default keeps the original 1e-4).

    Returns:
        Tuple ``(train_step, loss, accuracy, merged_summary_op)``.
    """
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    # 28*28*1

    conv1 = tf.layers.conv2d(x_image, 64, 3, 1, 'same', activation=tf.nn.relu)
    # 28*28*64

    pool1 = tf.layers.max_pooling2d(conv1, 2, 2, 'same')
    # 14*14*64

    conv2 = tf.layers.conv2d(pool1, 128, 3, 1, 'same', activation=tf.nn.relu)
    # 14*14*128

    pool2 = tf.layers.max_pooling2d(conv2, 2, 2, 'same')
    # 7*7*128

    conv3 = tf.layers.conv2d(pool2, 256, 3, 1, 'same', activation=tf.nn.relu)
    # 7*7*256

    pool3 = tf.layers.max_pooling2d(conv3, 2, 2, 'same')
    # 4*4*256  ('same' padding rounds 7/2 up to 4)

    flatten = tf.reshape(pool3, [-1, 4 * 4 * 256])
    fc = tf.layers.dense(flatten, 1024, activation=tf.nn.relu)
    # TF 1.x signature: second argument is the *keep* probability.
    dropout = tf.nn.dropout(fc, keep_prob)
    logits = tf.layers.dense(dropout, 10)
    outputs = tf.nn.softmax(logits, name=output_node_name)

    # loss: computed from raw logits (numerically stable fused op), not from outputs.
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))

    # train step
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)

    # accuracy: fraction of rows whose predicted class matches the one-hot label.
    correct_prediction = tf.equal(tf.argmax(outputs, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    tf.summary.scalar("loss", loss)
    tf.summary.scalar("accuracy", accuracy)
    # merge_all gathers every summary registered in the default graph.
    merged_summary_op = tf.summary.merge_all()

    return train_step, loss, accuracy, merged_summary_op


In [3]:
# train the model
def train(x, keep_prob, y_, train_step, loss, accuracy, merged_summary_op, saver,
          eval_batch_size=1000):
    """Train the model on MNIST, checkpoint it, and print test accuracy.

    Side effects: writes the text graph ``out/<MODEL_NAME>.pbtxt`` (read later
    by export_model), TensorBoard events under ``logs/``, and the checkpoint
    ``out/<MODEL_NAME>.chkp``.

    Args:
        x, keep_prob, y_: placeholders returned by model_input().
        train_step: optimizer op run once per training step.
        loss: loss tensor; not read here, kept for interface compatibility.
        accuracy: scalar mean-accuracy tensor for whatever batch is fed.
        merged_summary_op: merged summary op evaluated every training step.
        saver: tf.train.Saver used to write the checkpoint.
        eval_batch_size: number of test images fed per accuracy evaluation.
            The test set is evaluated in slices so the whole set never has to
            pass through the conv layers in a single run.
    """
    print('training started...')
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        # Text-format graph definition consumed by freeze_graph in export_model().
        tf.train.write_graph(sess.graph_def, 'out', MODEL_NAME + '.pbtxt', True)

        # op to write logs to Tensorboard
        summary_writer = tf.summary.FileWriter('logs/', graph=sess.graph)
        try:
            for step in range(NUM_STEPS):
                batch_images, batch_labels = mnist.train.next_batch(BATCH_SIZE)
                if step % 100 == 0:
                    # keep_prob 1.0 disables dropout for the progress measurement.
                    train_accuracy = sess.run(accuracy, feed_dict={
                        x: batch_images, y_: batch_labels, keep_prob: 1.0})
                    print('step %d, training accuracy %f' % (step, train_accuracy))
                _, summary = sess.run([train_step, merged_summary_op], feed_dict={
                    x: batch_images, y_: batch_labels, keep_prob: .5})
                summary_writer.add_summary(summary, step)
        finally:
            # Flush pending events to disk and release the event file; without
            # this the last summaries can linger unwritten in a live kernel.
            summary_writer.close()

        saver.save(sess, 'out/' + MODEL_NAME + '.chkp')

        test_images = mnist.test.images
        test_labels = mnist.test.labels
        num_test = len(test_images)
        weighted_correct = 0.0
        for start in range(0, num_test, eval_batch_size):
            stop = min(start + eval_batch_size, num_test)
            batch_accuracy = sess.run(accuracy, feed_dict={
                x: test_images[start:stop], y_: test_labels[start:stop], keep_prob: 1.0})
            # accuracy is a per-slice mean; weight by slice size so the total
            # equals the accuracy over the whole test set.
            weighted_correct += batch_accuracy * (stop - start)
        test_accuracy = weighted_correct / num_test
        print('test accuracy %g' % test_accuracy)

    print('training finished...!')

In [4]:
# export the model
def export_model(input_node_names, output_node_name):
    """Freeze the trained graph and write an inference-optimized copy.

    Reads ``out/<MODEL_NAME>.pbtxt`` and ``out/<MODEL_NAME>.chkp`` (both
    produced by train()), writes the frozen graph to
    ``out/frozen_<MODEL_NAME>.pb``, then writes the optimize_for_inference
    result to ``out/opt_<MODEL_NAME>.pb``.

    Args:
        input_node_names: list of placeholder names kept as graph inputs.
        output_node_name: name of the single output node to keep.
    """
    # Compute every path once so the frozen file written by freeze_graph is
    # guaranteed to be the same file read back below.
    graph_path = 'out/' + MODEL_NAME + '.pbtxt'
    checkpoint_path = 'out/' + MODEL_NAME + '.chkp'
    frozen_path = 'out/frozen_' + MODEL_NAME + '.pb'
    optimized_path = 'out/opt_' + MODEL_NAME + '.pb'

    freeze_graph.freeze_graph(graph_path,           # input_graph
                              None,                 # input_saver
                              False,                # input_binary: .pbtxt is text
                              checkpoint_path,      # input_checkpoint
                              output_node_name,     # output_node_names
                              "save/restore_all",   # restore_op_name
                              "save/Const:0",       # filename_tensor_name
                              frozen_path,          # output_graph
                              True,                 # clear_devices
                              "")                   # initializer_nodes

    input_graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_path, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    # Every kept placeholder (image input and keep_prob) is float32.
    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def, input_node_names, [output_node_name],
        tf.float32.as_datatype_enum)

    with tf.gfile.GFile(optimized_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("graph saved!")


In [5]:
def main():
    """Build, train, and export the MNIST convnet end to end.

    The default graph is reset first. Without that, calling main() again in
    the same notebook kernel adds a second copy of the model to the existing
    graph: the new placeholders get uniquified names (e.g. 'input_1'),
    tf.summary.merge_all() also collects the first copy's summaries (whose
    placeholders are never fed, so training fails), and export_model would
    freeze the stale 'output' node from the first copy.
    """
    tf.reset_default_graph()
    os.makedirs('out', exist_ok=True)

    input_node_name = 'input'
    keep_prob_node_name = 'keep_prob'
    output_node_name = 'output'

    x, keep_prob, y_ = model_input(input_node_name, keep_prob_node_name)

    train_step, loss, accuracy, merged_summary_op = build_model(x, keep_prob, y_, output_node_name)
    # Created after the model so it covers all of the model's variables.
    saver = tf.train.Saver()

    train(x, keep_prob, y_, train_step, loss, accuracy, merged_summary_op, saver)

    # keep_prob stays an input of the exported graph, so callers must feed it
    # (1.0 for inference).
    export_model([input_node_name, keep_prob_node_name], output_node_name)


In [6]:
# In a Jupyter kernel __name__ is '__main__', so running this cell executes
# the full train-and-export pipeline; importing the file as a module does not.
if __name__ == '__main__':
    main()


training started...
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz


Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


step 0, training accuracy 0.125000


step 100, training accuracy 0.937500


step 200, training accuracy 0.750000


step 300, training accuracy 0.812500


step 400, training accuracy 0.937500


step 500, training accuracy 0.937500


step 600, training accuracy 1.000000


step 700, training accuracy 1.000000


step 800, training accuracy 0.875000


step 900, training accuracy 0.937500


step 1000, training accuracy 1.000000


step 1100, training accuracy 0.937500


step 1200, training accuracy 0.875000


step 1300, training accuracy 1.000000


step 1400, training accuracy 1.000000


step 1500, training accuracy 1.000000


step 1600, training accuracy 0.937500


step 1700, training accuracy 1.000000


step 1800, training accuracy 1.000000


step 1900, training accuracy 0.937500


step 2000, training accuracy 1.000000


step 2100, training accuracy 1.000000


step 2200, training accuracy 0.875000


step 2300, training accuracy 1.000000


step 2400, training accuracy 1.000000


step 2500, training accuracy 1.000000


step 2600, training accuracy 1.000000


step 2700, training accuracy 0.937500


step 2800, training accuracy 1.000000


step 2900, training accuracy 1.000000


test accuracy 0.9843
training finished...!


INFO:tensorflow:Restoring parameters from out/mnist_for_android_convnet.chkp


INFO:tensorflow:Froze 10 variables.


Converted 10 variables to const ops.
55 ops in the final graph.


graph saved!
