High Accuracy CNN on MNIST
---------------------------

The goal is to train a CNN that classifies MNIST digits with high accuracy. After a few experiments, the architecture described below, combined with dropout and early stopping, achieved this goal. Even better performance can likely still be achieved with more complex architectures.

#### Import required libraries

In [1]:
import tensorflow as tf
import numpy as np

#### Build CNN
Architecture: three convolutional layers, one pooling layer, and two fully connected layers. Dropout is applied after the pooling layer and after the first fully connected layer.

In [2]:
# Input geometry: each MNIST digit is a square, single-channel image that is
# fed to the network as a flat vector of n_inputs pixels.
height, width = 28, 28
channels = 1
n_inputs = width * height

# Convolution stack: three 3x3, stride-1, SAME-padded layers whose feature-map
# count doubles at every stage.
c1_n_fmaps, c1_kernel_size, c1_strides = 16, 3, 1
c1_padding = "SAME"

c2_n_fmaps, c2_kernel_size, c2_strides = 32, 3, 1
c2_padding = "SAME"

c3_n_fmaps, c3_kernel_size, c3_strides = 64, 3, 1
c3_padding = "SAME"

# Max-pooling layer: 2x2 window moved with stride 2. Pooling does not change
# the number of feature maps, so it inherits the count of the last conv layer.
p1_kernel_size, p1_strides = 2, 2
p1_padding = "VALID"
p1_n_fmaps = c3_n_fmaps

# Dense head: one hidden layer followed by one output unit per digit class.
fc1_units = 128
outputs = 10

# Build graph
# Start from an empty default graph so that re-running this cell does not
# clash with layers (e.g. 'C1', 'fc1') created by a previous run.
tf.reset_default_graph()

with tf.name_scope("inputs"):
    # Flattened images: one row of n_inputs pixel values per example.
    X = tf.placeholder(tf.float32, shape=(None, n_inputs), name='X')
    # Restore the (batch, height, width, channels) layout for the conv layers.
    X_reshaped = tf.reshape(X, shape=(-1, height, width, channels))
    # One integer class id per example. Note that `shape=(None)` is simply
    # `None` (parentheses alone do not build a tuple) and would leave the rank
    # unknown; `(None,)` declares the intended 1-D batch of labels.
    y = tf.placeholder(tf.int32, shape=(None,), name='y')
    # Scalar boolean switch that turns dropout on (True) or off (False).
    is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

# Three stacked ReLU convolutions; each one consumes the previous layer's
# feature maps.
c1 = tf.layers.conv2d(X_reshaped, filters=c1_n_fmaps,
                      kernel_size=c1_kernel_size, strides=c1_strides,
                      padding=c1_padding, activation=tf.nn.relu, name='C1')
c2 = tf.layers.conv2d(c1, filters=c2_n_fmaps,
                      kernel_size=c2_kernel_size, strides=c2_strides,
                      padding=c2_padding, activation=tf.nn.relu, name='C2')
c3 = tf.layers.conv2d(c2, filters=c3_n_fmaps,
                      kernel_size=c3_kernel_size, strides=c3_strides,
                      padding=c3_padding, activation=tf.nn.relu, name='C3')

with tf.name_scope('pool1'):
    p1 = tf.nn.max_pool(c3,
                        ksize=[1, p1_kernel_size, p1_kernel_size, 1],
                        strides=[1, p1_strides, p1_strides, 1],
                        padding=p1_padding, name="p1")
    # Read the pooled spatial size from the tensor's static shape instead of
    # hard-coding 14*14, so the flatten step stays correct if the image size
    # or the conv/pool hyperparameters are changed.
    p1_height, p1_width = p1.get_shape().as_list()[1:3]
    p1_flat = tf.reshape(p1, shape=[-1, p1_n_fmaps * p1_height * p1_width])
    # dropout on the flattened pooled features (active only when training)
    p1_flat_drop = tf.layers.dropout(p1_flat, rate=0.25, training=is_training)


# fully connected layer fc1
with tf.name_scope("fc1"):
    fc1 = tf.layers.dense(p1_flat_drop, fc1_units, activation=tf.nn.relu,
                          name='fc1')
    fc1_drop = tf.layers.dropout(fc1, rate=0.5, training=is_training)

# fully connected output layer
with tf.name_scope('outputs'):
    # Raw class scores; the softmax is only needed to read off probabilities,
    # the loss below works directly on the logits.
    logits = tf.layers.dense(fc1_drop, outputs, name='logits')
    y_proba = tf.nn.softmax(logits, name='y_proba')

with tf.name_scope('loss'):
    # Labels are class ids (not one-hot), hence the sparse cross-entropy.
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                              labels=y)
    loss = tf.reduce_mean(xentropy, name='loss')

with tf.name_scope("train"):
    # Adam with its default hyperparameters.
    optimizer = tf.train.AdamOptimizer()
    training_op = optimizer.minimize(loss)

with tf.name_scope("eval"):
    # A prediction is correct when the true label is the top-1 logit.
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name='accuracy')

with tf.name_scope('init_and_save'):
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

#### Extract MNIST data

In [3]:
# Load MNIST through the TensorFlow tutorial helper, using /tmp/data/ as the
# dataset directory. The resulting `mnist` object provides the train,
# validation and test splits consumed by the training cell.
# NOTE(review): the labels are fed to a sparse cross-entropy loss, so they are
# presumably integer class ids rather than one-hot vectors -- confirm.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


#### Train CNN

In [4]:
# helper functions for implementing early stopping

def get_model_params():
    """Take a snapshot of every global variable in the default graph.

    All variables in the GLOBAL_VARIABLES collection are evaluated in a single
    run of the current default session.

    Returns:
        A dict mapping each variable's op name to the value the session
        returned for it.
    """
    variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    values = tf.get_default_session().run(variables)
    snapshot = {}
    for variable, value in zip(variables, values):
        snapshot[variable.op.name] = value
    return snapshot

def restore_model_params(model_params):
    """Write a previously captured snapshot back into the graph's variables.

    For every name in `model_params`, the operation "<name>/Assign" is looked
    up in the default graph and run in the current default session, with the
    saved value fed in place of that operation's second input.

    Args:
        model_params: dict mapping a variable's op name to the value to
            restore, as produced by get_model_params().
    """
    graph = tf.get_default_graph()
    assign_ops = {}
    feed_dict = {}
    for name, saved_value in model_params.items():
        assign_op = graph.get_operation_by_name(name + "/Assign")
        assign_ops[name] = assign_op
        # The second input of the assign op is the tensor whose value gets
        # written to the variable; feeding it substitutes the saved value.
        feed_dict[assign_op.inputs[1]] = saved_value
    tf.get_default_session().run(assign_ops, feed_dict=feed_dict)

In [5]:
# Training schedule. n_epochs is only an upper bound: early stopping is
# expected to end training long before it is reached.
n_epochs = 1000
batch_size = 50

# Early-stopping state.
# np.inf is the canonical spelling; the np.infty alias was removed in
# NumPy 2.0.
best_loss_val = np.inf                # lowest validation loss seen so far
check_interval = 500                  # run a validation check every N iterations
checks_since_last_progress = 0        # consecutive checks without improvement
max_checks_without_progress = 20      # patience before training is stopped
best_model_params = None              # snapshot of the best weights so far

# Train with mini-batches, tracking the validation loss for early stopping,
# then evaluate the best snapshot on the test set and save it to disk.
with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        for iteration in range(mnist.train.num_examples // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch,is_training: True})
            # `iteration` restarts at 0 every epoch, so the validation check
            # runs at iterations 0, check_interval, 2*check_interval, ... of
            # each epoch.
            if iteration % check_interval == 0:
                loss_val = loss.eval(feed_dict={X: mnist.validation.images,
                                                y: mnist.validation.labels,is_training: False})
                if loss_val < best_loss_val:
                    # New best: remember the loss and snapshot the weights.
                    best_loss_val = loss_val
                    checks_since_last_progress = 0
                    best_model_params = get_model_params()
                else:
                    checks_since_last_progress += 1
        # The reported train accuracy is measured on the last mini-batch of
        # the epoch only, with dropout disabled.
        acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch,is_training: False})
        acc_val = accuracy.eval(feed_dict={X: mnist.validation.images,
                                           y: mnist.validation.labels,is_training: False})
        print("Epoch {}, train accuracy: {:.4f}%, valid. accuracy: {:.4f}%, valid. best loss: {:.6f}".format(
                  epoch, acc_train * 100, acc_val * 100, best_loss_val))
        # The patience counter is only inspected once per epoch, so training
        # always finishes the current epoch before stopping.
        if checks_since_last_progress > max_checks_without_progress:
            print("Early stopping!")
            break

    # None is the "no snapshot taken" sentinel, so test for it explicitly
    # rather than relying on the truthiness of the snapshot dict.
    if best_model_params is not None:
        restore_model_params(best_model_params)
    acc_test = accuracy.eval(feed_dict={X: mnist.test.images,
                                        y: mnist.test.labels,is_training: False})
    print("Final accuracy on test set:", acc_test)
    save_path = saver.save(sess, "./mnist_model_final")

Epoch 0, train accuracy: 100.0000%, valid. accuracy: 98.4200%, valid. best loss: 0.058289
Epoch 1, train accuracy: 98.0000%, valid. accuracy: 98.7200%, valid. best loss: 0.040067
Epoch 2, train accuracy: 100.0000%, valid. accuracy: 99.0200%, valid. best loss: 0.036280
Epoch 3, train accuracy: 98.0000%, valid. accuracy: 98.9800%, valid. best loss: 0.031201
Epoch 4, train accuracy: 100.0000%, valid. accuracy: 99.1400%, valid. best loss: 0.031201
Epoch 5, train accuracy: 100.0000%, valid. accuracy: 99.2200%, valid. best loss: 0.030522
Epoch 6, train accuracy: 100.0000%, valid. accuracy: 99.3400%, valid. best loss: 0.030522
Epoch 7, train accuracy: 100.0000%, valid. accuracy: 99.1800%, valid. best loss: 0.028896
Epoch 8, train accuracy: 98.0000%, valid. accuracy: 99.2000%, valid. best loss: 0.028896
Epoch 9, train accuracy: 100.0000%, valid. accuracy: 99.1200%, valid. best loss: 0.028896
Epoch 10, train accuracy: 100.0000%, valid. accuracy: 99.3800%, valid. best loss: 0.027055
Epoch 11, tr