# Batch normalisation on MNIST

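Ioffe & Szegedy (2015), *Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift*: https://arxiv.org/abs/1502.03167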

Batch normalisation mitigates the vanishing and exploding gradients problems, makes networks less sensitive to weight initialisation, reduces training time and acts as a regulariser.

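Before the activation function is applied at each layer, the layer's inputs are zero-centred and normalised using the mean and standard deviation of the current mini-batch. They are then scaled and shifted by two parameters (γ and β) that are learned along with the weights. At test time the output should not depend on which examples happen to share a batch, so moving averages of the mean and variance, accumulated during training, are used instead of mini-batch statistics.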
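The transform can be sketched in NumPy as below. This is an illustration, not TensorFlow's implementation. During training each feature is normalised as `(x - mean) / sqrt(var + eps)` with the mini-batch's own statistics, then multiplied by `gamma` and shifted by `beta`. At inference the same formula is applied with stored statistics. The function names, the made-up input data and `eps = 1e-3` (which I believe matches the `tf.layers.batch_normalization` default) are assumptions.

```python
import numpy as np

EPS = 1e-3  # assumed epsilon: small constant that avoids division by zero


def batch_norm_train(x, gamma, beta, eps=EPS):
    """Training mode: normalise a (batch, features) array with its own statistics."""
    mean = x.mean(axis=0)                     # per-feature mini-batch mean
    var = x.var(axis=0)                       # per-feature mini-batch variance
    x_hat = (x - mean) / np.sqrt(var + eps)   # zero-centred, roughly unit variance
    return gamma * x_hat + beta               # learned scale (gamma) and shift (beta)


def batch_norm_infer(x, moving_mean, moving_var, gamma, beta, eps=EPS):
    """Inference mode: same transform, but with statistics estimated during training."""
    x_hat = (x - moving_mean) / np.sqrt(moving_var + eps)
    return gamma * x_hat + beta


rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(50, 4))  # made-up pre-activations for one mini-batch
out = batch_norm_train(x, gamma=np.ones(4), beta=np.zeros(4))
print(out.mean(axis=0), out.std(axis=0))  # means close to 0, standard deviations close to 1
```

With `gamma = 1` and `beta = 0` the layer is a pure normalisation. Because `gamma` and `beta` are learned, the network can undo the normalisation where that helps.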

In [1]:
import tensorflow as tf  # written for TensorFlow 1.x (tf.placeholder, tf.layers, tf.Session)
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
mnist = input_data.read_data_sets("/tmp/data/")

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [3]:
n_inputs = 28 * 28 
n_hidden1 = 300
n_hidden2 = 100
n_outputs = 10

In [4]:
X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.int64, shape=(None,), name="y")  # integer class labels, one per example

In [5]:
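# Training flag: when True, batch norm normalises with the current mini-batch's mean and variance;
# when False (the default, used at test time) it uses the moving averages accumulated during training
training = tf.placeholder_with_default(False, shape=(), name="training")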

In [6]:
hidden1 = tf.layers.dense(X, n_hidden1, name="hidden1")
bn1 = tf.layers.batch_normalization(hidden1, training=training, momentum=0.9)
bn1_act = tf.nn.elu(bn1)  # Exponential linear unit (ELU), https://arxiv.org/abs/1511.07289: slower to compute than ReLU, but often converges faster

hidden2 = tf.layers.dense(bn1_act, n_hidden2, name="hidden2")
bn2 = tf.layers.batch_normalization(hidden2, training=training, momentum=0.9)
bn2_act = tf.nn.elu(bn2)

logits_before_bn = tf.layers.dense(bn2_act, n_outputs, name="outputs")
logits = tf.layers.batch_normalization(logits_before_bn, training=training, momentum=0.9)
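For reference, `tf.nn.elu` returns `z` for positive inputs and `exp(z) - 1` otherwise. The NumPy sketch below shows the function next to ReLU; the `alpha` parameter is a generalisation added here for illustration, and `tf.nn.elu` corresponds to `alpha = 1`. Unlike ReLU, ELU gives negative inputs a small negative output that saturates at `-alpha`, so their gradient is not exactly zero. The extra `exp` is why it costs more to compute.

```python
import numpy as np


def elu(z, alpha=1.0):
    """ELU: z where z > 0, alpha * (exp(z) - 1) elsewhere (tf.nn.elu uses alpha = 1)."""
    # np.minimum keeps expm1 away from large positive inputs, which np.where would
    # otherwise evaluate (and possibly overflow) even though those values are discarded
    return np.where(z > 0, z, alpha * np.expm1(np.minimum(z, 0.0)))


def relu(z):
    return np.maximum(z, 0.0)


z = np.array([-5.0, -1.0, 0.0, 1.0, 5.0])
print(elu(z))   # negative inputs map to values in (-1, 0) instead of being clipped to 0
print(relu(z))
```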

In [7]:
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # ops that update the batch-norm moving mean and variance; training_op does not run them, so they must be run explicitly at each training step
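The update ops maintain an exponential moving average of each batch-norm layer's mean and variance, which is what the layer uses when `training` is False. Below is a NumPy sketch of that bookkeeping. It assumes the documented update rule `moving = moving * momentum + batch_stat * (1 - momentum)`, the default initial values (moving mean 0, moving variance 1) and the `momentum=0.9` used in this notebook. The function name and the synthetic batches are hypothetical.

```python
import numpy as np


def update_moving_stats(moving_mean, moving_var, batch, momentum=0.9):
    """One step of the exponential moving average kept by a batch-norm layer (assumed rule)."""
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)
    moving_var = moving_var * momentum + batch_var * (1 - momentum)
    return moving_mean, moving_var


rng = np.random.default_rng(1)
moving_mean, moving_var = np.zeros(3), np.ones(3)  # assumed initial values
for step in range(100):
    # synthetic mini-batches whose true mean is 2.0 and true variance is 0.25
    batch = rng.normal(loc=2.0, scale=0.5, size=(50, 3))
    moving_mean, moving_var = update_moving_stats(moving_mean, moving_var, batch)
print(moving_mean, moving_var)
```

With `momentum=0.9` the averages effectively track roughly the last 10 mini-batches. If the update ops were never run, the moving statistics would stay at their initial values and the test-time normalisation would be wrong.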

In [8]:
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
loss = tf.reduce_mean(cross_entropy, name="loss") 
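`tf.nn.sparse_softmax_cross_entropy_with_logits` expects integer class indices (hence the `int64` placeholder for `y`) and raw logits, applying the softmax internally. A NumPy sketch of the per-example value it computes is shown below. The function name and example arrays are illustrative only.

```python
import numpy as np


def sparse_softmax_cross_entropy(labels, logits):
    """Per-example cross-entropy for integer labels of shape (batch,) and logits (batch, classes)."""
    # Subtracting the row-wise maximum does not change the softmax but avoids overflow in exp
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Negative log-probability of each example's true class
    return -log_probs[np.arange(len(labels)), labels]


logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.3]])
labels = np.array([0, 1])
loss = sparse_softmax_cross_entropy(labels, logits).mean()  # analogue of tf.reduce_mean(...)
print(loss)
```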

In [9]:
learning_rate = 0.01
optimiser = tf.train.GradientDescentOptimizer(learning_rate)
training_op = optimiser.minimize(loss)

In [10]:
correct = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
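With `k=1`, `tf.nn.in_top_k` marks an example as correct when its true class has the largest logit, so the accuracy above is ordinary top-1 accuracy. The NumPy sketch below is a rough equivalent with hypothetical names and data. It differs only on exact ties, which `in_top_k` counts as correct for every tied class.

```python
import numpy as np


def top1_accuracy(logits, labels):
    """Fraction of examples whose true class has the highest logit."""
    # np.argmax breaks ties by taking the first maximal index; tf.nn.in_top_k instead
    # counts every tied class as "in the top k", so results can differ only on exact ties
    predictions = logits.argmax(axis=1)
    return np.mean(predictions == labels)


logits = np.array([[0.1, 0.9],
                   [0.8, 0.2],
                   [0.3, 0.7]])
labels = np.array([1, 0, 0])
print(top1_accuracy(logits, labels))
```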

In [11]:
init = tf.global_variables_initializer()

In [12]:
n_epochs = 50 
batch_size = 50
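As a rough sense of scale, the arithmetic below counts the training steps. It assumes `read_data_sets` keeps its default split, holding out 5,000 of MNIST's 60,000 training images for validation and leaving 55,000 in `mnist.train`.

```python
# Assumption: read_data_sets() holds out 5,000 of MNIST's 60,000 training images
# for validation by default, leaving 55,000 in mnist.train
n_train = 60000 - 5000
batch_size = 50
n_epochs = 50

steps_per_epoch = n_train // batch_size   # iterations of the inner loop per epoch
total_steps = steps_per_epoch * n_epochs  # gradient-descent updates over the whole run
print(steps_per_epoch, total_steps)
```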

In [13]:
with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        for iteration in range(mnist.train.num_examples // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run([training_op, extra_update_ops], feed_dict={training: True, X: X_batch, y: y_batch})
        # training defaults to False here, so batch norm uses the moving averages
        accuracy_val = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels})
        print(epoch, "Test accuracy: ", accuracy_val)

0 Test accuracy:  0.9244
1 Test accuracy:  0.9452
2 Test accuracy:  0.9539
3 Test accuracy:  0.9628
4 Test accuracy:  0.9669
5 Test accuracy:  0.9661
6 Test accuracy:  0.9714
7 Test accuracy:  0.9718
8 Test accuracy:  0.9744
9 Test accuracy:  0.9733
10 Test accuracy:  0.9724
11 Test accuracy:  0.9741
12 Test accuracy:  0.9759
13 Test accuracy:  0.9758
14 Test accuracy:  0.9766
15 Test accuracy:  0.9759
16 Test accuracy:  0.9768
17 Test accuracy:  0.9761
18 Test accuracy:  0.9751
19 Test accuracy:  0.9791
20 Test accuracy:  0.9773
21 Test accuracy:  0.9774
22 Test accuracy:  0.9774
23 Test accuracy:  0.979
24 Test accuracy:  0.9782
25 Test accuracy:  0.9798
26 Test accuracy:  0.9787
27 Test accuracy:  0.9789
28 Test accuracy:  0.977
29 Test accuracy:  0.9781
30 Test accuracy:  0.9783
31 Test accuracy:  0.98
32 Test accuracy:  0.979
33 Test accuracy:  0.9799
34 Test accuracy:  0.98
35 Test accuracy:  0.9801
36 Test accuracy:  0.9776
37 Test accuracy:  0.98
38 Test accuracy:  0.9785
...

Test accuracy reaches about 98%, levelling off after roughly 30 epochs.
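This summary can be checked against the logged values. The snippet below copies the test accuracies for epochs 0 to 38, the last complete line before the log is cut off. It finds the best epoch, the first epoch at or above 98%, and how much accuracy still fluctuates over epochs 30 to 38.

```python
# Test accuracies copied from the log above; the log is cut off after epoch 38
accs = [0.9244, 0.9452, 0.9539, 0.9628, 0.9669, 0.9661, 0.9714, 0.9718, 0.9744, 0.9733,
        0.9724, 0.9741, 0.9759, 0.9758, 0.9766, 0.9759, 0.9768, 0.9761, 0.9751, 0.9791,
        0.9773, 0.9774, 0.9774, 0.979, 0.9782, 0.9798, 0.9787, 0.9789, 0.977, 0.9781,
        0.9783, 0.98, 0.979, 0.9799, 0.98, 0.9801, 0.9776, 0.98, 0.9785]

best = max(accs)
best_epoch = accs.index(best)
first_98 = next(epoch for epoch, acc in enumerate(accs) if acc >= 0.98)
late = accs[30:]                # epochs 30 to 38
spread = max(late) - min(late)  # remaining epoch-to-epoch fluctuation
print(best_epoch, best, first_98, round(spread, 4))
```

The best logged accuracy is 98.01% at epoch 35. Accuracy first reaches 98% at epoch 31, and from epoch 30 onwards it varies within about a quarter of a percentage point.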

This network has only two hidden layers; the benefits of batch normalisation are most pronounced when training much deeper networks.