In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers import batch_norm, dropout
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")

# to make this notebook's output stable across runs
def reset_graph(seed=42):
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [19]:
def he_normal_initialisation(n_inputs, n_outputs):
    stddev = np.power(2 / (n_inputs + n_outputs), 1 / np.sqrt(2))
    # truncated normal distributions limit the size of the weights, speeding up the training time.
    return tf.truncated_normal((n_inputs, n_outputs), stddev=stddev)

def he_uniform_initialisation(n_inputs, n_outputs):
    r = np.power(6 / (n_inputs + n_outputs), 1 / np.sqrt(2))
    # truncated normal distributions limit the size of the weights, speeding up the training time.
    return tf.random_uniform((n_inputs, n_outputs), -r, r)

def neuron_layer(X, n_neurons, name):
    with tf.name_scope(name):
        #print(X.get_shape()[1])
        n_inputs = int(X.get_shape()[1])
        W = tf.Variable(he_normal_initialisation(n_inputs, n_neurons), name="weights")
        b = tf.Variable(tf.zeros([n_neurons]), name="biases")
        z = tf.matmul(X, W) + b
        return tf.nn.elu(z)

def cnn_layer(X, patch_size, n_input_filters, n_filters, name, initialised_weights_stddev = 0.1):
    with tf.name_scope(name):
        w = tf.Variable(tf.truncated_normal(
            [patch_size, patch_size, n_input_filters, n_filters], stddev=initialised_weights_stddev))
        b = tf.Variable(tf.zeros([n_filters]))
        return tf.nn.elu(tf.nn.conv2d(X, w, strides=[1, 2, 2, 1], padding="SAME") + b)

input_spatial_size = 28
input_channels = 1
n_filters_per_layer = [10, 10]
patch_size = 3
n_output = 10
batch_size = 200
    
with tf.device("/gpu:0"):
    x = tf.placeholder(tf.float32, shape=(batch_size, input_spatial_size ** 2), name="input")
    reshaped_x = tf.reshape(x, (tf.shape(x)[0], input_spatial_size, input_spatial_size, 1))
    y = tf.placeholder(tf.int64, shape=(batch_size), name="y")
    
    with tf.name_scope("dnn"):
        input_tensor = reshaped_x
        n_input_filters = input_channels
        for i in range(len(n_filters_per_layer)):
            input_tensor = cnn_layer(input_tensor, patch_size, n_input_filters, n_filters_per_layer[i], "hidden" + str(i + 1))
            n_input_filters = n_filters_per_layer[i]
        shape = input_tensor.get_shape().as_list()
        #print(shape)
        reshape = tf.reshape(input_tensor, [shape[0], shape[1] * shape[2] * shape[3]])
        #print(reshape.get_shape())
        logits = neuron_layer(reshape, n_output, "output")
        evaluation = tf.nn.softmax(logits)

    with tf.name_scope("loss"):
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
        loss = tf.reduce_mean(cross_entropy, name="loss")

    with tf.name_scope("training"):
        optimizer = tf.train.AdamOptimizer()
        training_op = optimizer.minimize(loss)

with tf.name_scope("eval"):
    k = 1
    correctness = tf.nn.in_top_k(logits, y, k)
    accuracy = tf.reduce_mean(tf.cast(correctness, tf.float32)) * 100

In [20]:
init = tf.global_variables_initializer()
saver = tf.train.Saver()

interim_checkpoint_path = "./checkpoints/mnist_cnn_model.ckpt"
early_stopping_checkpoint_path = "./checkpoints/mnist_cnn_model_early_stopping.ckpt"

from datetime import datetime

now = datetime.utcnow().strftime("%Y%m%d%H%M%S")
root_logdir = "tf_logs"
log_dir = "{}/run-{}/".format(root_logdir, now)

loss_summary = tf.summary.scalar('loss', loss)
accuracy_summary = tf.summary.scalar("accuracy", accuracy)
summary_op = tf.summary.merge([loss_summary, accuracy_summary])
file_writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())

In [None]:
def eval_predictions(session, data):
    dataset_size = data.shape[0]
    #print "dataset_size: ", dataset_size, " batch_size: ", batch_size
    if dataset_size % batch_size != 0:
        raise "batch_size must be a multiple of dataset_size."
    predictions = np.ndarray(shape=(dataset_size, n_output), dtype=np.float32)
    steps = dataset_size // batch_size
    #print "steps: ", steps
    for step in range(steps):
        offset = (step * batch_size)
        #print "offset ", offset
        batch_data = data[offset:(offset + batch_size), :]
        feed_dict = {
            x: batch_data
        }
        predictions[offset:offset+batch_size, :] = session.run(evaluation, feed_dict=feed_dict)
    return predictions

def prediction_accuracy(predictions, labels):
  return (100.0 * np.sum(np.argmax(predictions, 1) == labels)
          / predictions.shape[0])

epochs = 20
n_batches = int(np.ceil(mnist.train.num_examples // batch_size))

early_stopping_check_frequency = batch_size // 4
early_stopping_check_limit = batch_size * 2

with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    sess.run(init)
    #saver.restore(sess, interim_checkpoint_path)
    
    best_validation_acc = 0.0
    best_validation_step = 0
    for epoch in range(epochs):
        print("epoch", epoch)
        for batch_index in range(n_batches):
            step = epoch * n_batches + batch_index
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            if batch_index % 10 == 0:
                summary_str = summary_op.eval(feed_dict={x: X_batch, y: y_batch})
                file_writer.add_summary(summary_str, step)
            t, l, a = sess.run([training_op, loss, accuracy], feed_dict={x: X_batch, y: y_batch})
            if batch_index % 10 == 0: print("loss:", l, "train accuracy:", a)
            # Early stopping check
            if batch_index % early_stopping_check_frequency == 0:
                predictions = eval_predictions(sess, mnist.validation.images)
                validation_acc = prediction_accuracy(predictions, mnist.validation.labels)
                print("validation accuracy", validation_acc)
                if validation_acc > best_validation_acc:
                    saver.save(sess, early_stopping_checkpoint_path)
                    best_validation_acc = validation_acc
                    best_validation_step = step
                elif step >= (best_validation_step + early_stopping_check_limit):
                    print("Stopping early during epoch", epoch)
                    break
        else:
            continue
        break
    save_path = saver.save(sess, interim_checkpoint_path)
    test_predictions = eval_predictions(sess, mnist.test.images)
    test_acc = prediction_accuracy(predictions, mnist.test.labels)
    print(">>>>>>>>>> test dataset accuracy:", test_acc)

    save_path = saver.save(sess, "./checkpoints/mnist_cnn_model_final.ckpt")

epoch 0
loss: 2.30844 train accuracy: 3.5
validation accuracy 9.6
loss: 2.22732 train accuracy: 51.0
loss: 2.0181 train accuracy: 71.0
loss: 1.69386 train accuracy: 73.0
loss: 1.26418 train accuracy: 77.0
loss: 0.95412 train accuracy: 77.0
validation accuracy 80.82
loss: 0.656738 train accuracy: 84.0
loss: 0.576777 train accuracy: 85.0
loss: 0.499835 train accuracy: 85.0
loss: 0.407467 train accuracy: 88.0
loss: 0.385032 train accuracy: 90.0
validation accuracy 89.02
loss: 0.427316 train accuracy: 89.5
loss: 0.363884 train accuracy: 87.5
loss: 0.44623 train accuracy: 87.0
loss: 0.342677 train accuracy: 92.5
loss: 0.345295 train accuracy: 88.0
validation accuracy 90.42
loss: 0.363489 train accuracy: 88.5
loss: 0.312713 train accuracy: 91.0
loss: 0.3171 train accuracy: 91.5
loss: 0.250527 train accuracy: 93.0
loss: 0.368013 train accuracy: 88.5
validation accuracy 90.84
loss: 0.4191 train accuracy: 91.5
loss: 0.348346 train accuracy: 92.5
loss: 0.46249 train accuracy: 85.5
loss: 0.28662 