In [1]:
import numpy as np
import tensorflow as tf
import math
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
mnist = read_data_sets("data", one_hot=True, reshape=False, validation_size=5000)

Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz


In [2]:
import datetime
def timestamp():
    d = datetime.datetime.now()
    return d.strftime("%Y/%m/%d/%X")

In [3]:
# Parameters
training_epochs = 25
batch_size = 100
display_step = 1
max_learning_rate = 0.0004
min_learning_rate = 0.0001
decay_speed = 10
bnepsilon = 1e-5

logs_path_train = '/tmp/tensorflow_logs/example/train' + timestamp()
logs_path_val = '/tmp/tensorflow_logs/example/val' + timestamp()

In [4]:
# Placeholders
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y_ = tf.placeholder(tf.float32, [None, 10])
# variable learning rate
lr = tf.placeholder(tf.float32)
# test flag for batch norm
tst = tf.placeholder(tf.bool)
iter = tf.placeholder(tf.int32)
# Probability of keeping a node during dropout = 1.0 at test time (no dropout)  and 0.75 at training time
pkeep = tf.placeholder(tf.float32)
pkeep_conv = tf.placeholder(tf.float32)

# Variables
W1 = tf.Variable(tf.truncated_normal([6, 6, 1, 24], stddev=0.1))
b1 = tf.Variable(tf.ones([24]))
W2 = tf.Variable(tf.truncated_normal([5, 5, 24, 48], stddev=0.1))
b2 = tf.Variable(tf.ones([48]))
W3 = tf.Variable(tf.truncated_normal([4, 4, 48, 64], stddev=0.1))
b3 = tf.Variable(tf.ones([64]))
W4 = tf.Variable(tf.truncated_normal([7 * 7 * 64, 200], stddev=0.1))
b4 = tf.Variable(tf.ones([200]))
W5 = tf.Variable(tf.truncated_normal([200, 10], stddev=0.1))
b5 = tf.Variable(tf.ones([10]))

In [5]:
def batchnorm(Ylogits, is_test, iteration, offset, convolutional=False):
    # adding the iteration prevents from averaging across non-existing iterations
    exp_moving_avg = tf.train.ExponentialMovingAverage(0.999, iteration) 
    if convolutional:
        mean, variance = tf.nn.moments(Ylogits, [0, 1, 2])
    else:
        mean, variance = tf.nn.moments(Ylogits, [0])
    update_moving_everages = exp_moving_avg.apply([mean, variance])
    m = tf.cond(is_test, lambda: exp_moving_avg.average(mean), lambda: mean)
    v = tf.cond(is_test, lambda: exp_moving_avg.average(variance), lambda: variance)
    Ybn = tf.nn.batch_normalization(Ylogits, m, v, offset, None, bnepsilon)
    return Ybn, update_moving_everages

def compatible_convolutional_noise_shape(Y):
    noiseshape = tf.shape(Y)
    noiseshape = noiseshape * tf.constant([1,0,0,1]) + tf.constant([0,1,1,0])
    return noiseshape

In [6]:
# The model
# batch norm scaling is not useful with relus
# batch norm offsets are used instead of biases

stride = 1  # output is 28x28
y1l = tf.nn.conv2d(x, W1, strides=[1, stride, stride, 1], padding='SAME')
y1bn, update_ema1 = batchnorm(y1l, tst, iter, b1, convolutional=True)
y1r = tf.nn.relu(y1bn)
y1 = tf.nn.dropout(y1r, pkeep_conv, compatible_convolutional_noise_shape(y1r))

stride = 2  # output is 14x14
y2l = tf.nn.conv2d(y1, W2, strides=[1, stride, stride, 1], padding='SAME')
y2bn, update_ema2 = batchnorm(y2l, tst, iter, b2, convolutional=True)
y2r = tf.nn.relu(y2bn)
y2 = tf.nn.dropout(y2r, pkeep_conv, compatible_convolutional_noise_shape(y2r))

stride = 2  # output is 7x7
y3l = tf.nn.conv2d(y2, W3, strides=[1, stride, stride, 1], padding='SAME')
y3bn, update_ema3 = batchnorm(y3l, tst, iter, b3, convolutional=True)
y3r = tf.nn.relu(y3bn)
y3 = tf.nn.dropout(y3r, pkeep_conv, compatible_convolutional_noise_shape(y3r))

# reshape the output from the third convolution for the fully connected layer
y3_conv = tf.reshape(y3, shape=[-1, 7 * 7 * 64])

y4l = tf.matmul(y3_conv, W4)
y4bn, update_ema4 = batchnorm(y4l, tst, iter, b4)
y4r = tf.nn.relu(y4bn)
y4 = tf.nn.dropout(y4r, pkeep)
y = tf.matmul(y4, W5) + b5

update_ema = tf.group(update_ema1, update_ema2, update_ema3, update_ema4)

In [7]:
with tf.name_scope('Loss'):
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))

with tf.name_scope('Optimizer'):
    optimizer = tf.train.AdamOptimizer(lr).minimize(cross_entropy)
    
with tf.name_scope('Accuracy'):
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
# Initializing the variables
init = tf.global_variables_initializer()

# Create a summary to monitor cost tensor
tf.summary.scalar("loss", cross_entropy)
# Create a summary to monitor accuracy tensor
tf.summary.scalar("accuracy", acc)
# Merge all summaries into a single op
merged_summary_op = tf.summary.merge_all()

In [8]:
# Launch the graph
with tf.Session() as sess:
    sess.run(init)

    # op to write logs to Tensorboard
    summary_writer_train = tf.summary.FileWriter(logs_path_train, 
                                                 graph=tf.get_default_graph())
    summary_writer_val = tf.summary.FileWriter(logs_path_val, 
                                               graph=tf.get_default_graph())

    # Training cycle
    for epoch in range(training_epochs):
        avg_train_loss = 0.
        avg_val_loss = 0.
        total_batch = int(mnist.train.num_examples / batch_size)
        learning_rate = min_learning_rate + (
            max_learning_rate - min_learning_rate) * math.exp(-epoch / decay_speed)  
        # Loop over all batches
        for i in range(total_batch):
            batch_x_train, batch_y_train = mnist.train.next_batch(batch_size)
            batch_x_val, batch_y_val = mnist.validation.next_batch(batch_size)
            
            # Run optimization op (backprop), cost op (to get loss value) and summary nodes
            # and summary nodes
            a, c, summary = sess.run([optimizer, cross_entropy, merged_summary_op], 
                                     feed_dict={x: batch_x_train, 
                                                y_: batch_y_train,
                                                lr: learning_rate, 
                                                tst: False,
                                                pkeep: 0.75, 
                                                pkeep_conv: 0.75})  
            
            sess.run(update_ema, {x: batch_x_train, 
                                  y_: batch_y_train, 
                                  tst: False, 
                                  iter: i, 
                                  pkeep: 0.75, 
                                  pkeep_conv: 0.75})
           
            c_val, summary_val = sess.run([cross_entropy, merged_summary_op], 
                                          feed_dict={x: batch_x_val, 
                                                     y_: batch_y_val,
                                                     lr: learning_rate, 
                                                     tst: False, 
                                                     pkeep: 1.0, 
                                                     pkeep_conv: 1.0}) 
            
            
            # Write logs at every iteration
            summary_writer_train.add_summary(summary, epoch * total_batch + i)
            summary_writer_val.add_summary(summary_val, epoch * total_batch + i)

            # Compute average loss
            avg_train_loss += c / total_batch
            avg_val_loss += c_val / total_batch
            
        # Display logs per epoch step
        if (epoch + 1) % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1),
                  "train loss", "{:.9f}".format(avg_train_loss),
                  "val loss", "{:.9f}".format(avg_val_loss),
                  "learning rate", "{:.9f}".format(learning_rate))

    print("Optimization Finished!")

    # Test model
    # Calculate accuracy    
    total_test_batch = mnist.test.num_examples // batch_size
    print('total_test_batch', total_test_batch)    
    acc_test_lst = []
    for step in range(mnist.test.images.shape[0] // total_test_batch):        
        acc_test = acc.eval(
            {x: mnist.test.images[step * total_test_batch:(step + 1) * total_test_batch, :], 
             y_: mnist.test.labels[step * total_test_batch:(step + 1) * total_test_batch, :],
             lr: learning_rate, 
             tst: False, 
             pkeep: 1.0, 
             pkeep_conv: 1.0}) 
        acc_test_lst.append(acc_test)    
    print("Accuracy:", np.mean(acc_test_lst))

    print("Run the command line:\n" \
          "--> tensorboard --logdir=/tmp/tensorflow_logs " \
          "\nThen open http://0.0.0.0:6006/")

Epoch: 0001 train loss 0.679447803 val loss 0.429449026 learning rate 0.000400000
Epoch: 0002 train loss 0.296709327 val loss 0.187034261 learning rate 0.000371451
Epoch: 0003 train loss 0.205166025 val loss 0.127128779 learning rate 0.000345619
Epoch: 0004 train loss 0.159735332 val loss 0.100147673 learning rate 0.000322245
Epoch: 0005 train loss 0.133731233 val loss 0.083816353 learning rate 0.000301096
Epoch: 0006 train loss 0.115231307 val loss 0.072443589 learning rate 0.000281959
Epoch: 0007 train loss 0.103943510 val loss 0.065619865 learning rate 0.000264643
Epoch: 0008 train loss 0.095510741 val loss 0.061461575 learning rate 0.000248976
Epoch: 0009 train loss 0.086480129 val loss 0.056793766 learning rate 0.000234799
Epoch: 0010 train loss 0.081258942 val loss 0.052968601 learning rate 0.000221971
Epoch: 0011 train loss 0.074817300 val loss 0.049768013 learning rate 0.000210364
Epoch: 0012 train loss 0.070222282 val loss 0.047685462 learning rate 0.000199861
Epoch: 0013 trai