In [37]:
import tensorflow as tf

# Layer widths of the two-layer fully connected network built by inference().
# INPUT_NODE: number of rows of the layer1 weight matrix, i.e. the width of
# the input tensor (presumably 28x28 flattened MNIST pixels -- confirm).
INPUT_NODE = 784
# OUTPUT_NODE: number of columns of the layer2 weight matrix (width of the logits).
OUTPUT_NODE = 10
# HIDDEN_NODE: width of the hidden layer shared by layer1's output and layer2's input.
HIDDEN_NODE = 500

def get_weight(shape, regularizer):
    """Create the ``"weight"`` variable for the enclosing variable scope.

    The variable is initialised from a truncated normal distribution with a
    standard deviation of 0.1.

    Args:
        shape: Shape handed to ``tf.get_variable``.
        regularizer: Callable that maps the weight variable to a
            regularisation loss, or ``None`` to skip regularisation.

    Returns:
        The weight variable.
    """
    weight = tf.get_variable(
        name="weight",
        shape=shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1),
    )
    # Register the penalty only when a regularizer was supplied.  The previous
    # `if not regularizer` test was inverted: it called `None(weight)` when no
    # regularizer was given and silently skipped the penalty when one was.
    if regularizer is not None:
        tf.add_to_collection("losses", regularizer(weight))
    return weight

def get_bias(shape):
    """Create the zero-initialised ``"bias"`` variable for the enclosing scope.

    Args:
        shape: Shape handed to ``tf.get_variable``.

    Returns:
        The bias variable.
    """
    zero_init = tf.zeros_initializer()
    bias = tf.get_variable(name="bias", shape=shape, initializer=zero_init)
    return bias
    
def inference(input_tensor, regularizer):
    """Build the forward pass of the two-layer fully connected network.

    Layer 1 applies an affine transform followed by ReLU; layer 2 applies a
    second affine transform and returns its raw output (no softmax).

    Args:
        input_tensor: Tensor multiplied by the ``[INPUT_NODE, HIDDEN_NODE]``
            weight matrix of layer 1.
        regularizer: Forwarded unchanged to ``get_weight`` for both layers.

    Returns:
        The output tensor of layer 2.
    """
    with tf.variable_scope("layer1"):
        hidden_w = get_weight([INPUT_NODE, HIDDEN_NODE], regularizer)
        hidden_b = get_bias([HIDDEN_NODE])
        hidden_pre = tf.matmul(input_tensor, hidden_w) + hidden_b
        hidden_out = tf.nn.relu(hidden_pre)

    with tf.variable_scope("layer2"):
        out_w = get_weight([HIDDEN_NODE, OUTPUT_NODE], regularizer)
        out_b = get_bias([OUTPUT_NODE])
        logits = tf.matmul(hidden_out, out_w) + out_b

    return logits

In [4]:
import tensorflow as tf
import mnist_inference
import os
from tensorflow.examples.tutorials.mnist import input_data

# network setting
# Training hyper-parameters and checkpoint location.
# NOTE: some names below are misspelled (MOVEING_..., REGULARAZTION_...), but
# they are read by name from other code (e.g. mnist_train.MOVEING_AVERAGE_DECAY),
# so renaming them would break those readers.

# Number of iterations of the training loop; each iteration consumes one
# mini-batch, so this counts optimisation steps rather than passes over the data.
EPOCH = 3000
# Number of examples requested per mini-batch.
BATCH_SIZE = 128
# Initial learning rate given to tf.train.exponential_decay.
LEARNING_RATE_BASE = 0.01
# Decay rate given to tf.train.exponential_decay.
LEARNING_RATE_DECAY = 0.99
# Decay given to tf.train.ExponentialMovingAverage.
MOVEING_AVERAGE_DECAY = 0.99
# Scale given to tf.contrib.layers.l2_regularizer.
REGULARAZTION_RATE = 0.01
# Directory and file-name prefix used when saving checkpoints.
MODEL_SAVE_PATH = 'model_save'
MODEL_SAVE_NAME = 'fully_connected_model'

def train(mnist):
    """Build the training graph, run ``EPOCH`` steps and save checkpoints.

    The loss is the mean softmax cross-entropy plus every term stored in the
    ``"losses"`` collection.  Each training step also updates exponential
    moving averages of the trainable variables.  A checkpoint is written and
    the loss printed every 1000 iterations.

    Args:
        mnist: Dataset bundle as returned by ``input_data.read_data_sets``
            with one-hot labels; its ``train`` split provides
            ``num_examples`` and ``next_batch``.
    """
    x = tf.placeholder(dtype=tf.float32, shape=[None, mnist_inference.INPUT_NODE], name="x-input")
    y_ = tf.placeholder(dtype=tf.float32, shape=[None, mnist_inference.OUTPUT_NODE], name="y-output")

    regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
    # The network lives in the imported `mnist_inference` module; a bare
    # `inference` name is never defined in this cell.
    y = mnist_inference.inference(x, regularizer)

    cem_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))
    loss = cem_loss + tf.add_n(tf.get_collection("losses"))

    global_step = tf.Variable(0, trainable=False)

    # Decay the learning rate once per pass over the training split
    # (staircase=True); the example count lives on `mnist.train`.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True,
    )
    # Named `train_step` so the op does not shadow this function's own name.
    train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step=global_step)

    ema = tf.train.ExponentialMovingAverage(MOVEING_AVERAGE_DECAY, global_step)
    # Average only the trainable variables: the global collection also holds
    # the integer `global_step` (rejected by the EMA, which needs float
    # variables) and the optimizer's internal slot variables.
    ema_op = ema.apply(tf.trainable_variables())

    # tf.group takes the ops as separate positional arguments.
    train_op = tf.group(train_step, ema_op)

    init_op = tf.global_variables_initializer()

    saver = tf.train.Saver()
    # Make sure the checkpoint directory exists before the first save.
    if not os.path.isdir(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)
    checkpoint_prefix = os.path.join(MODEL_SAVE_PATH, MODEL_SAVE_NAME)

    with tf.Session() as sess:
        sess.run(init_op)
        for i in range(EPOCH):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                saver.save(sess, checkpoint_prefix, global_step=step)
                print("after %i steps, the loss is %f." % (i, loss_value))
                
def main(argv=None):
    """Load the MNIST data with one-hot labels and start training.

    Args:
        argv: Unused; accepted so the function can serve as an app entry point.
    """
    dataset = input_data.read_data_sets("../mnist/", one_hot=True)
    train(dataset)
    
# Script entry point: hand control to TensorFlow's app runner.
# (`rf` was a typo for the `tf` module alias and raised NameError.)
if __name__ == "__main__":
    tf.app.run()

In [6]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train
import time

# Seconds to sleep between two consecutive evaluation passes in evaluate().
EVAL_INTERVAL_SECS = 20

def evaluate(mnist):
    """Repeatedly score the latest checkpoint on the validation split.

    The inference graph is rebuilt without a regularizer, and the variables
    are restored from their exponential-moving-average shadow values.  The
    function then loops forever: it loads the newest checkpoint, prints the
    validation accuracy, and sleeps ``EVAL_INTERVAL_SECS`` seconds.

    Args:
        mnist: Dataset bundle as returned by ``input_data.read_data_sets``
            with one-hot labels; only its ``validation`` split is used.
    """
    with tf.Graph().as_default():
        x = tf.placeholder(dtype=tf.float32, shape=[None, mnist_inference.INPUT_NODE], name="x-input")
        y_ = tf.placeholder(dtype=tf.float32, shape=[None, mnist_inference.OUTPUT_NODE], name="y-input")

        # The network lives in the imported `mnist_inference` module.
        y = mnist_inference.inference(x, None)

        # Map each variable to its moving-average shadow so that the averaged
        # weights are the ones loaded from the checkpoint.
        ema = tf.train.ExponentialMovingAverage(mnist_train.MOVEING_AVERAGE_DECAY)
        ema_restore = ema.variables_to_restore()
        saver = tf.train.Saver(ema_restore)

        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))

        # Labels must feed the `y_` placeholder.  Feeding them into `y` would
        # overwrite the logits and leave `y_` without a value.
        validation_feed = {x: mnist.validation.images, y_: mnist.validation.labels}

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # The step is the text after the last "-" of the checkpoint
                    # file name; it is a string, hence the %s below.
                    global_step = ckpt.model_checkpoint_path.split("/")[-1].split("-")[-1]
                    validation_acc = sess.run(accuracy, feed_dict=validation_feed)
                    print("After %s steps, the accuracy on validation set is %g." % (global_step, validation_acc))
                else:
                    print("checkpoint model not found!")
            time.sleep(EVAL_INTERVAL_SECS)

def main(argv=None):
    """Load the MNIST data with one-hot labels and start the evaluation loop.

    Args:
        argv: Unused; accepted so the function can serve as an app entry point.
    """
    dataset = input_data.read_data_sets("../mnist/", one_hot=True)
    evaluate(dataset)
    
# Script entry point: hand control to TensorFlow's app runner.
if __name__ == "__main__":
    tf.app.run()