In [12]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import math

mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

half_test_images = mnist.test.images.shape[0] / 2

x_train, y_train = mnist.train.images, mnist.train.labels
x_valid, y_valid = mnist.test.images[:half_test_images], mnist.test.labels[:half_test_images]
x_test, y_test = mnist.test.images[half_test_images:], mnist.test.labels[half_test_images:]

print("x_train.shape: {}, y_train.shape: {}".format(x_train.shape, y_train.shape))
print("x_valid.shape: {}, y_valid.shape: {}".format(x_valid.shape, y_valid.shape))
print("x_test.shape: {}, y_test.shape: {}".format(x_test.shape, y_test.shape))

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
x_train.shape: (55000, 784), y_train.shape: (55000, 10)
x_valid.shape: (5000, 784), y_valid.shape: (5000, 10)
x_test.shape: (5000, 784), y_test.shape: (5000, 10)


In [19]:
learning_rate = 1e-4
epochs = 10
batch_size = 50

x = tf.placeholder(tf.float32, shape=[None, 784])
x_shaped = tf.reshape(x, [-1, 28, 28, 1])
y = tf.placeholder(tf.float32, shape=[None, 10])

def create_conv2d(input_data, num_input_channels, num_filters, filter_shape, pool_shape, name):
    conv_filter_shape = [filter_shape[0], filter_shape[1], num_input_channels, num_filters]
    
    weights = tf.Variable(tf.truncated_normal(conv_filter_shape, stddev=0.03), name=name+"_W")
    bias = tf.Variable(tf.truncated_normal([num_filters]), name=name+"_b")
    
    out_layer = tf.nn.conv2d(input_data, weights, (1, 1, 1, 1), padding="SAME")
    out_layer += bias
    out_layer = tf.nn.relu(out_layer)
    out_layer = tf.nn.max_pool(out_layer, ksize=(1, pool_shape[0], pool_shape[1], 1), strides=(1, 2, 2, 1), padding="SAME")
    return out_layer

layer1 = create_conv2d(x_shaped, 1, 32, (5, 5), (2, 2), name="layer1")
layer2 = create_conv2d(layer1, 32, 64, (5, 5), (2, 2), name="layer2")
flattened = tf.reshape(layer2, (-1, 7 * 7 * 64))

wd1 = tf.Variable(tf.truncated_normal((7 * 7 * 64, 1000), stddev=0.03), name="wd1")
bd1 = tf.Variable(tf.truncated_normal([1000], stddev=0.01), name="bd1")
dense_layer1 = tf.add(tf.matmul(flattened, wd1), bd1)
dense_layer1 = tf.nn.relu(dense_layer1)

wd2 = tf.Variable(tf.truncated_normal((1000, 10), stddev=0.03), name="wd2")
bd2 = tf.Variable(tf.truncated_normal([10], stddev=0.01), name="bd2")
dense_layer2 = tf.add(tf.matmul(dense_layer1, wd2), bd2)
y_ = tf.nn.softmax(dense_layer2)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

iteration = 0
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    batch_count = int(math.ceil(x_train.shape[0] / float(batch_size)))
    for e in range(epochs):
        for batch_i in range(batch_count):
            batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
            _, loss = sess.run([optimizer, cost], feed_dict={x: batch_x, y: batch_y})
            
            if batch_i % 20 == 0:
                print("Epoch: {}/{}".format(e+1, epochs), 
                      "Iteration: {}".format(iteration), 
                      "Training loss: {:.5f}".format(loss))
            iteration += 1
            
            if iteration % batch_size == 0:
                valid_acc = sess.run(accuracy, feed_dict={x: x_valid, y: y_valid})
                print("Epoch: {}/{}".format(e, epochs),
                      "Iteration: {}".format(iteration),
                      "Validation Accuracy: {:.5f}".format(valid_acc))

    saver.save(sess, "checkpoints/mnist_cnn_tf.ckpt")

('Epoch: 1/10', 'Iteration: 0', 'Training loss: 2.32924')
('Epoch: 1/10', 'Iteration: 20', 'Training loss: 2.30679')
('Epoch: 1/10', 'Iteration: 40', 'Training loss: 2.28476')
('Epoch: 0/10', 'Iteration: 50', 'Validation Accuracy: 0.11420')
('Epoch: 1/10', 'Iteration: 60', 'Training loss: 2.30400')
('Epoch: 1/10', 'Iteration: 80', 'Training loss: 2.31607')
('Epoch: 0/10', 'Iteration: 100', 'Validation Accuracy: 0.09240')
('Epoch: 1/10', 'Iteration: 100', 'Training loss: 2.30153')
('Epoch: 1/10', 'Iteration: 120', 'Training loss: 2.27674')
('Epoch: 1/10', 'Iteration: 140', 'Training loss: 2.26295')
('Epoch: 0/10', 'Iteration: 150', 'Validation Accuracy: 0.19260')
('Epoch: 1/10', 'Iteration: 160', 'Training loss: 2.27241')
('Epoch: 1/10', 'Iteration: 180', 'Training loss: 2.13509')
('Epoch: 0/10', 'Iteration: 200', 'Validation Accuracy: 0.39960')
('Epoch: 1/10', 'Iteration: 200', 'Training loss: 2.11183')
('Epoch: 1/10', 'Iteration: 220', 'Training loss: 2.01921')
('Epoch: 1/10', 'Iterat

('Epoch: 1/10', 'Iteration: 1900', 'Validation Accuracy: 0.92660')
('Epoch: 2/10', 'Iteration: 1900', 'Training loss: 1.55960')
('Epoch: 2/10', 'Iteration: 1920', 'Training loss: 1.51721')
('Epoch: 2/10', 'Iteration: 1940', 'Training loss: 1.54494')
('Epoch: 1/10', 'Iteration: 1950', 'Validation Accuracy: 0.92580')
('Epoch: 2/10', 'Iteration: 1960', 'Training loss: 1.52806')
('Epoch: 2/10', 'Iteration: 1980', 'Training loss: 1.53460')
('Epoch: 1/10', 'Iteration: 2000', 'Validation Accuracy: 0.92660')
('Epoch: 2/10', 'Iteration: 2000', 'Training loss: 1.51763')
('Epoch: 2/10', 'Iteration: 2020', 'Training loss: 1.50209')
('Epoch: 2/10', 'Iteration: 2040', 'Training loss: 1.50146')
('Epoch: 1/10', 'Iteration: 2050', 'Validation Accuracy: 0.93060')
('Epoch: 2/10', 'Iteration: 2060', 'Training loss: 1.48412')
('Epoch: 2/10', 'Iteration: 2080', 'Training loss: 1.55136')
('Epoch: 1/10', 'Iteration: 2100', 'Validation Accuracy: 0.92580')
('Epoch: 2/10', 'Iteration: 2100', 'Training loss: 1.55

('Epoch: 4/10', 'Iteration: 3760', 'Training loss: 1.48085')
('Epoch: 4/10', 'Iteration: 3780', 'Training loss: 1.46663')
('Epoch: 3/10', 'Iteration: 3800', 'Validation Accuracy: 0.95780')
('Epoch: 4/10', 'Iteration: 3800', 'Training loss: 1.46960')
('Epoch: 4/10', 'Iteration: 3820', 'Training loss: 1.49716')
('Epoch: 4/10', 'Iteration: 3840', 'Training loss: 1.54839')
('Epoch: 3/10', 'Iteration: 3850', 'Validation Accuracy: 0.95740')
('Epoch: 4/10', 'Iteration: 3860', 'Training loss: 1.50936')
('Epoch: 4/10', 'Iteration: 3880', 'Training loss: 1.50019')
('Epoch: 3/10', 'Iteration: 3900', 'Validation Accuracy: 0.95740')
('Epoch: 4/10', 'Iteration: 3900', 'Training loss: 1.51012')
('Epoch: 4/10', 'Iteration: 3920', 'Training loss: 1.46336')
('Epoch: 4/10', 'Iteration: 3940', 'Training loss: 1.55328')
('Epoch: 3/10', 'Iteration: 3950', 'Validation Accuracy: 0.96420')
('Epoch: 4/10', 'Iteration: 3960', 'Training loss: 1.48221')
('Epoch: 4/10', 'Iteration: 3980', 'Training loss: 1.47506')


('Epoch: 6/10', 'Iteration: 5640', 'Training loss: 1.49307')
('Epoch: 5/10', 'Iteration: 5650', 'Validation Accuracy: 0.97060')
('Epoch: 6/10', 'Iteration: 5660', 'Training loss: 1.49561')
('Epoch: 6/10', 'Iteration: 5680', 'Training loss: 1.46590')
('Epoch: 5/10', 'Iteration: 5700', 'Validation Accuracy: 0.96880')
('Epoch: 6/10', 'Iteration: 5700', 'Training loss: 1.48898')
('Epoch: 6/10', 'Iteration: 5720', 'Training loss: 1.46825')
('Epoch: 6/10', 'Iteration: 5740', 'Training loss: 1.48273')
('Epoch: 5/10', 'Iteration: 5750', 'Validation Accuracy: 0.97100')
('Epoch: 6/10', 'Iteration: 5760', 'Training loss: 1.49462')
('Epoch: 6/10', 'Iteration: 5780', 'Training loss: 1.49847')
('Epoch: 5/10', 'Iteration: 5800', 'Validation Accuracy: 0.97220')
('Epoch: 6/10', 'Iteration: 5800', 'Training loss: 1.46815')
('Epoch: 6/10', 'Iteration: 5820', 'Training loss: 1.47542')
('Epoch: 6/10', 'Iteration: 5840', 'Training loss: 1.47016')
('Epoch: 5/10', 'Iteration: 5850', 'Validation Accuracy: 0.96

('Epoch: 7/10', 'Iteration: 7520', 'Training loss: 1.49520')
('Epoch: 7/10', 'Iteration: 7540', 'Training loss: 1.48791')
('Epoch: 6/10', 'Iteration: 7550', 'Validation Accuracy: 0.97540')
('Epoch: 7/10', 'Iteration: 7560', 'Training loss: 1.46854')
('Epoch: 7/10', 'Iteration: 7580', 'Training loss: 1.48774')
('Epoch: 6/10', 'Iteration: 7600', 'Validation Accuracy: 0.97240')
('Epoch: 7/10', 'Iteration: 7600', 'Training loss: 1.51465')
('Epoch: 7/10', 'Iteration: 7620', 'Training loss: 1.47204')
('Epoch: 7/10', 'Iteration: 7640', 'Training loss: 1.48266')
('Epoch: 6/10', 'Iteration: 7650', 'Validation Accuracy: 0.97960')
('Epoch: 7/10', 'Iteration: 7660', 'Training loss: 1.48114')
('Epoch: 7/10', 'Iteration: 7680', 'Training loss: 1.46453')
('Epoch: 6/10', 'Iteration: 7700', 'Validation Accuracy: 0.97700')
('Epoch: 8/10', 'Iteration: 7700', 'Training loss: 1.46239')
('Epoch: 8/10', 'Iteration: 7720', 'Training loss: 1.47838')
('Epoch: 8/10', 'Iteration: 7740', 'Training loss: 1.46170')


('Epoch: 8/10', 'Iteration: 9400', 'Validation Accuracy: 0.97620')
('Epoch: 9/10', 'Iteration: 9400', 'Training loss: 1.49897')
('Epoch: 9/10', 'Iteration: 9420', 'Training loss: 1.46127')
('Epoch: 9/10', 'Iteration: 9440', 'Training loss: 1.48471')
('Epoch: 8/10', 'Iteration: 9450', 'Validation Accuracy: 0.97800')
('Epoch: 9/10', 'Iteration: 9460', 'Training loss: 1.47302')
('Epoch: 9/10', 'Iteration: 9480', 'Training loss: 1.46614')
('Epoch: 8/10', 'Iteration: 9500', 'Validation Accuracy: 0.97760')
('Epoch: 9/10', 'Iteration: 9500', 'Training loss: 1.46119')
('Epoch: 9/10', 'Iteration: 9520', 'Training loss: 1.46131')
('Epoch: 9/10', 'Iteration: 9540', 'Training loss: 1.46764')
('Epoch: 8/10', 'Iteration: 9550', 'Validation Accuracy: 0.97760')
('Epoch: 9/10', 'Iteration: 9560', 'Training loss: 1.46839')
('Epoch: 9/10', 'Iteration: 9580', 'Training loss: 1.46406')
('Epoch: 8/10', 'Iteration: 9600', 'Validation Accuracy: 0.98000')
('Epoch: 9/10', 'Iteration: 9600', 'Training loss: 1.46

In [20]:
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints/'))
    
    test_acc = sess.run(accuracy, feed_dict={x: x_test, y: y_test})
    print("test accuracy: {:.5f}".format(test_acc))

INFO:tensorflow:Restoring parameters from checkpoints/mnist_cnn_tf.ckpt
test accuracy: 0.99180
