In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [2]:
#  Create a test graph
a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
c = tf.matmul(a, b)
# Creates a session with log_device_placement set to True.
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
# Runs the op.
print(sess.run(c))

[[ 22.  28.]
 [ 49.  64.]]


In [19]:
number_of_inputs = 28 * 28
n_hidden_per_layer = [100, 100, 100, 100, 100]
n_output = 10

def he_normal_initialisation(n_inputs, n_outputs):
    stddev = np.power(2 / (n_inputs + n_outputs), 1 / np.sqrt(2))
    # truncated normal distributions limit the size of the weights, speeding up the training time.
    return tf.truncated_normal((n_inputs, n_outputs), stddev=stddev)

def he_uniform_initialisation(n_inputs, n_outputs):
    r = np.power(6 / (n_inputs + n_outputs), 1 / np.sqrt(2))
    # truncated normal distributions limit the size of the weights, speeding up the training time.
    return tf.random_uniform((n_inputs, n_outputs), -r, r)

def neuron_layer(X, n_neurons, name):
    with tf.name_scope(name):
        n_inputs = int(X.get_shape()[1])
        W = tf.Variable(he_normal_initialisation(n_inputs, n_neurons), name="weights")
        b = tf.Variable(tf.zeros([n_neurons]), name="biases")
        z = tf.matmul(X, W) + b
        return tf.nn.elu(z)

with tf.device("/gpu:0"):
    x = tf.placeholder(tf.float32, shape=(None, number_of_inputs), name="input")
    y = tf.placeholder(tf.int64, shape=(None), name="y")

    with tf.name_scope("dnn"):
        input_tensor = x
        for i in range(len(n_hidden_per_layer)):
            input_tensor = neuron_layer(input_tensor, n_hidden_per_layer[i], "hidden" + str(i + 1))
        logits = neuron_layer(input_tensor, n_output, "output")

    with tf.name_scope("loss"):
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
        loss = tf.reduce_mean(cross_entropy, name="loss")

    with tf.name_scope("training"):
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
        training_op = optimizer.minimize(loss)

with tf.name_scope("eval"):
    k = 1
    correctness = tf.nn.in_top_k(logits, y, k)
    accuracy = tf.reduce_mean(tf.cast(correctness, tf.float32)) * 100

In [20]:
init = tf.global_variables_initializer()
saver = tf.train.Saver()

interim_checkpoint_path = "./checkpoints/mnist_model.ckpt"

from datetime import datetime

now = datetime.utcnow().strftime("%Y%m%d%H%M%S")
root_logdir = "tf_logs"
log_dir = "{}/run-{}/".format(root_logdir, now)

loss_summary = tf.summary.scalar('loss', loss)
accuracy_summary = tf.summary.scalar("accuracy", accuracy)
summary_op = tf.summary.merge([loss_summary, accuracy_summary])
file_writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())

In [None]:
epochs = 15
batch_size = 100
n_batches = int(np.ceil(mnist.train.num_examples // batch_size))

with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    sess.run(init)
    #saver.restore(sess, interim_checkpoint_path)
    
    for epoch in range(epochs):
        print("epoch", epoch)
        for batch_index in range(n_batches):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            if batch_index % 10 == 0:
                summary_str = summary_op.eval(feed_dict={x: X_batch, y: y_batch})
                step = epoch * n_batches + batch_index
                file_writer.add_summary(summary_str, step)
            t, l, a = sess.run([training_op, loss, accuracy], feed_dict={x: X_batch, y: y_batch})
            
            if batch_index % 10000: print("loss:", l, "accuracy:", a)
    save_path = saver.save(sess, interim_checkpoint_path)
    test_acc = accuracy.eval(feed_dict={x: mnist.test.images, y: mnist.test.labels})
    print(">>>>>>>>>> test dataset accuracy:", test_acc)

    save_path = saver.save(sess, "./checkpoints/mnist_model_final.ckpt")

epoch 0
loss: 2.30251 accuracy: 8.0
loss: 2.30269 accuracy: 6.0
loss: 2.3025 accuracy: 11.0
loss: 2.30195 accuracy: 9.0
loss: 2.3027 accuracy: 7.0
loss: 2.30182 accuracy: 12.0
loss: 2.30049 accuracy: 14.0
loss: 2.30121 accuracy: 9.0
loss: 2.30358 accuracy: 9.0
loss: 2.30344 accuracy: 5.0
loss: 2.30097 accuracy: 17.0
loss: 2.30379 accuracy: 8.0
loss: 2.30156 accuracy: 11.0
loss: 2.30132 accuracy: 11.0
loss: 2.30303 accuracy: 4.0
loss: 2.30413 accuracy: 10.0
loss: 2.30222 accuracy: 13.0
loss: 2.3008 accuracy: 14.0
loss: 2.30152 accuracy: 13.0
loss: 2.29999 accuracy: 14.0
loss: 2.30085 accuracy: 13.0
loss: 2.30523 accuracy: 7.0
loss: 2.30345 accuracy: 14.0
loss: 2.2994 accuracy: 15.0
loss: 2.30276 accuracy: 10.0
loss: 2.3013 accuracy: 14.0
loss: 2.29428 accuracy: 17.0
loss: 2.30139 accuracy: 11.0
loss: 2.30102 accuracy: 9.0
loss: 2.30092 accuracy: 12.0
loss: 2.29768 accuracy: 12.0
loss: 2.30196 accuracy: 9.0
loss: 2.30756 accuracy: 7.0
loss: 2.29712 accuracy: 17.0
loss: 2.29816 accuracy: 

loss: 2.29923 accuracy: 6.0
loss: 2.29368 accuracy: 14.0
loss: 2.30963 accuracy: 2.0
loss: 2.29714 accuracy: 8.0
loss: 2.29992 accuracy: 7.0
loss: 2.28978 accuracy: 15.0
loss: 2.29998 accuracy: 10.0
loss: 2.28847 accuracy: 13.0
loss: 2.3026 accuracy: 7.0
loss: 2.29745 accuracy: 13.0
loss: 2.30063 accuracy: 9.0
loss: 2.29459 accuracy: 14.0
loss: 2.2983 accuracy: 12.0
loss: 2.29393 accuracy: 13.0
loss: 2.29566 accuracy: 11.0
loss: 2.29562 accuracy: 12.0
loss: 2.30321 accuracy: 3.0
loss: 2.30063 accuracy: 12.0
loss: 2.29527 accuracy: 13.0
loss: 2.29568 accuracy: 7.0
loss: 2.29299 accuracy: 12.0
loss: 2.29768 accuracy: 6.0
loss: 2.29007 accuracy: 12.0
loss: 2.29215 accuracy: 12.0
loss: 2.29971 accuracy: 8.0
loss: 2.29043 accuracy: 17.0
loss: 2.29398 accuracy: 11.0
loss: 2.29074 accuracy: 16.0
loss: 2.2907 accuracy: 11.0
loss: 2.29413 accuracy: 16.0
loss: 2.2927 accuracy: 15.0
loss: 2.29016 accuracy: 12.0
loss: 2.29335 accuracy: 16.0
loss: 2.29248 accuracy: 14.0
loss: 2.29567 accuracy: 9.0


loss: 1.27723 accuracy: 56.0
loss: 1.1999 accuracy: 57.0
loss: 1.2636 accuracy: 48.0
loss: 1.41318 accuracy: 50.0
loss: 1.15579 accuracy: 55.0
loss: 1.13439 accuracy: 58.0
loss: 1.1615 accuracy: 55.0
loss: 1.22973 accuracy: 53.0
loss: 1.07883 accuracy: 58.0
loss: 1.0867 accuracy: 59.0
loss: 1.34135 accuracy: 53.0
loss: 1.41001 accuracy: 49.0
loss: 1.40076 accuracy: 40.0
loss: 1.23199 accuracy: 48.0
loss: 1.24943 accuracy: 50.0
loss: 1.11413 accuracy: 59.0
loss: 1.15062 accuracy: 61.0
loss: 1.11444 accuracy: 57.0
loss: 1.07061 accuracy: 59.0
loss: 1.2485 accuracy: 53.0
loss: 1.08311 accuracy: 60.0
loss: 1.17301 accuracy: 55.0
loss: 1.20767 accuracy: 58.0
loss: 1.24957 accuracy: 57.0
loss: 1.09283 accuracy: 59.0
loss: 1.02179 accuracy: 61.0
loss: 1.1732 accuracy: 56.0
loss: 1.04197 accuracy: 66.0
loss: 1.04237 accuracy: 65.0
loss: 1.13006 accuracy: 64.0
loss: 1.12813 accuracy: 60.0
loss: 0.962546 accuracy: 65.0
loss: 0.985602 accuracy: 65.0
loss: 1.05792 accuracy: 62.0
loss: 1.00818 accu

loss: 0.646149 accuracy: 82.0
loss: 0.621902 accuracy: 81.0
loss: 0.574681 accuracy: 86.0
loss: 0.44973 accuracy: 88.0
loss: 0.548429 accuracy: 82.0
loss: 0.503128 accuracy: 83.0
loss: 0.518677 accuracy: 82.0
loss: 0.49524 accuracy: 83.0
loss: 0.641269 accuracy: 79.0
loss: 0.436432 accuracy: 88.0
loss: 0.708863 accuracy: 82.0
loss: 0.610228 accuracy: 80.0
loss: 0.650899 accuracy: 84.0
loss: 0.823725 accuracy: 75.0
loss: 0.813352 accuracy: 71.0
loss: 0.659655 accuracy: 79.0
loss: 0.514451 accuracy: 85.0
loss: 0.534129 accuracy: 86.0
loss: 0.670923 accuracy: 79.0
loss: 0.674923 accuracy: 78.0
loss: 0.723838 accuracy: 83.0
loss: 0.478418 accuracy: 84.0
loss: 0.643194 accuracy: 78.0
loss: 0.667668 accuracy: 80.0
loss: 0.597882 accuracy: 78.0
loss: 0.471797 accuracy: 84.0
loss: 0.510707 accuracy: 84.0
loss: 0.466958 accuracy: 83.0
loss: 0.573412 accuracy: 81.0
loss: 0.527393 accuracy: 85.0
loss: 0.556579 accuracy: 82.0
loss: 0.486233 accuracy: 87.0
loss: 0.649318 accuracy: 85.0
loss: 0.5507

loss: 0.339096 accuracy: 91.0
loss: 0.627132 accuracy: 81.0
loss: 0.272423 accuracy: 93.0
loss: 0.316135 accuracy: 90.0
loss: 0.500695 accuracy: 86.0
loss: 0.481726 accuracy: 86.0
loss: 0.573603 accuracy: 84.0
loss: 0.396408 accuracy: 87.0
loss: 0.27981 accuracy: 91.0
loss: 0.550941 accuracy: 86.0
loss: 0.364517 accuracy: 90.0
loss: 0.48782 accuracy: 88.0
loss: 0.295309 accuracy: 91.0
loss: 0.266669 accuracy: 94.0
loss: 0.467092 accuracy: 89.0
loss: 0.236238 accuracy: 92.0
loss: 0.509121 accuracy: 87.0
loss: 0.233575 accuracy: 94.0
loss: 0.283318 accuracy: 92.0
loss: 0.415919 accuracy: 90.0
loss: 0.369294 accuracy: 88.0
loss: 0.642046 accuracy: 85.0
loss: 0.370543 accuracy: 88.0
loss: 0.594221 accuracy: 81.0
loss: 0.445222 accuracy: 85.0
loss: 0.479537 accuracy: 89.0
loss: 0.284906 accuracy: 93.0
loss: 0.459336 accuracy: 88.0
loss: 0.326922 accuracy: 88.0
loss: 0.30149 accuracy: 94.0
loss: 0.484342 accuracy: 85.0
loss: 0.436363 accuracy: 88.0
loss: 0.245489 accuracy: 93.0
loss: 0.66292

loss: 0.369572 accuracy: 89.0
loss: 0.305478 accuracy: 89.0
loss: 0.224194 accuracy: 93.0
loss: 0.334292 accuracy: 91.0
loss: 0.413194 accuracy: 90.0
loss: 0.388757 accuracy: 84.0
loss: 0.262424 accuracy: 90.0
loss: 0.302153 accuracy: 93.0
loss: 0.23759 accuracy: 94.0
loss: 0.207405 accuracy: 94.0
loss: 0.394198 accuracy: 91.0
loss: 0.243603 accuracy: 90.0
loss: 0.179547 accuracy: 94.0
loss: 0.449457 accuracy: 88.0
loss: 0.268721 accuracy: 94.0
loss: 0.399767 accuracy: 92.0
loss: 0.262522 accuracy: 91.0
loss: 0.362744 accuracy: 90.0
loss: 0.255853 accuracy: 92.0
loss: 0.338351 accuracy: 89.0
loss: 0.317667 accuracy: 89.0
loss: 0.305344 accuracy: 88.0
loss: 0.211418 accuracy: 93.0
loss: 0.329497 accuracy: 90.0
loss: 0.223952 accuracy: 93.0
loss: 0.424498 accuracy: 91.0
loss: 0.285891 accuracy: 93.0
loss: 0.331538 accuracy: 91.0
loss: 0.314662 accuracy: 89.0
loss: 0.282446 accuracy: 90.0
loss: 0.216539 accuracy: 95.0
loss: 0.228776 accuracy: 93.0
loss: 0.376232 accuracy: 91.0
loss: 0.234

loss: 0.185203 accuracy: 93.0
loss: 0.347196 accuracy: 93.0
loss: 0.303016 accuracy: 89.0
loss: 0.231958 accuracy: 90.0
loss: 0.164293 accuracy: 96.0
loss: 0.133068 accuracy: 96.0
loss: 0.327676 accuracy: 92.0
loss: 0.181713 accuracy: 93.0
loss: 0.171973 accuracy: 95.0
loss: 0.223511 accuracy: 92.0
loss: 0.175575 accuracy: 92.0
loss: 0.298933 accuracy: 92.0
loss: 0.301141 accuracy: 92.0
loss: 0.235369 accuracy: 92.0
loss: 0.425707 accuracy: 87.0
loss: 0.187754 accuracy: 92.0
loss: 0.348821 accuracy: 89.0
loss: 0.273327 accuracy: 94.0
loss: 0.386757 accuracy: 91.0
loss: 0.372553 accuracy: 88.0
loss: 0.388882 accuracy: 89.0
loss: 0.0887185 accuracy: 97.0
loss: 0.277046 accuracy: 89.0
loss: 0.285424 accuracy: 95.0
loss: 0.328951 accuracy: 87.0
loss: 0.223167 accuracy: 94.0
loss: 0.253117 accuracy: 93.0
loss: 0.503045 accuracy: 86.0
loss: 0.229297 accuracy: 93.0
loss: 0.275408 accuracy: 92.0
loss: 0.282477 accuracy: 94.0
loss: 0.308946 accuracy: 93.0
loss: 0.228682 accuracy: 95.0
loss: 0.1

loss: 0.268237 accuracy: 93.0
loss: 0.125575 accuracy: 96.0
loss: 0.20588 accuracy: 92.0
loss: 0.151287 accuracy: 95.0
loss: 0.266746 accuracy: 91.0
loss: 0.241426 accuracy: 89.0
loss: 0.30348 accuracy: 92.0
loss: 0.11276 accuracy: 95.0
loss: 0.141919 accuracy: 93.0
loss: 0.180403 accuracy: 94.0
loss: 0.230531 accuracy: 94.0
loss: 0.209552 accuracy: 90.0
loss: 0.072404 accuracy: 99.0
loss: 0.157503 accuracy: 96.0
loss: 0.121289 accuracy: 96.0
loss: 0.188477 accuracy: 92.0
loss: 0.155161 accuracy: 97.0
loss: 0.105184 accuracy: 96.0
loss: 0.146398 accuracy: 96.0
loss: 0.224761 accuracy: 94.0
loss: 0.103862 accuracy: 95.0
loss: 0.266495 accuracy: 95.0
loss: 0.160077 accuracy: 94.0
loss: 0.0697373 accuracy: 97.0
loss: 0.367384 accuracy: 90.0
loss: 0.166495 accuracy: 94.0
loss: 0.12427 accuracy: 96.0
loss: 0.111302 accuracy: 96.0
loss: 0.31066 accuracy: 90.0
loss: 0.217159 accuracy: 94.0
loss: 0.105715 accuracy: 96.0
loss: 0.129815 accuracy: 95.0
loss: 0.237879 accuracy: 93.0
loss: 0.22215 

loss: 0.227071 accuracy: 93.0
loss: 0.193894 accuracy: 93.0
loss: 0.230257 accuracy: 94.0
loss: 0.181865 accuracy: 93.0
loss: 0.208908 accuracy: 96.0
loss: 0.222406 accuracy: 96.0
loss: 0.123653 accuracy: 96.0
loss: 0.29226 accuracy: 91.0
loss: 0.19157 accuracy: 95.0
loss: 0.155042 accuracy: 94.0
loss: 0.121431 accuracy: 96.0
loss: 0.212196 accuracy: 92.0
loss: 0.0977907 accuracy: 96.0
loss: 0.14259 accuracy: 96.0
loss: 0.14007 accuracy: 96.0
loss: 0.0909549 accuracy: 97.0
loss: 0.220627 accuracy: 91.0
loss: 0.234313 accuracy: 93.0
loss: 0.280373 accuracy: 93.0
loss: 0.0680204 accuracy: 97.0
loss: 0.0950293 accuracy: 97.0
loss: 0.0842573 accuracy: 99.0
loss: 0.3423 accuracy: 91.0
loss: 0.120491 accuracy: 96.0
loss: 0.154549 accuracy: 96.0
loss: 0.0644235 accuracy: 98.0
loss: 0.184977 accuracy: 94.0
loss: 0.0855772 accuracy: 98.0
loss: 0.122057 accuracy: 97.0
loss: 0.194397 accuracy: 93.0
loss: 0.136343 accuracy: 94.0
loss: 0.0760156 accuracy: 97.0
loss: 0.244678 accuracy: 94.0
loss: 0.