# Train several different types of neural nets to solve MNIST problem

In [1]:
import numpy as np
import tensorflow as tf
import os
from datetime import datetime

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./tmp/MNIST_data/", one_hot=True)

Extracting ./tmp/MNIST_data/train-images-idx3-ubyte.gz
Extracting ./tmp/MNIST_data/train-labels-idx1-ubyte.gz
Extracting ./tmp/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ./tmp/MNIST_data/t10k-labels-idx1-ubyte.gz


In [3]:
# Verify parts of MNIST data
for s in 'train validation test'.split():
    tmp = getattr(mnist, s)
    print(tmp.images.shape)
    print(tmp.labels.shape)
    
img_size = mnist.train.images.shape[1]
n_labels = mnist.train.labels.shape[1]
n_train = mnist.train.images.shape[0]

(55000, 784)
(55000, 10)
(5000, 784)
(5000, 10)
(10000, 784)
(10000, 10)


In [4]:
# Function to get a random batch
def random_batch(batch_size):
    indices = np.random.randint(n_train, size=batch_size)
    X_ = mnist.train.images[indices, :]
    y_ = mnist.train.labels[indices]
    return X_, y_

In [5]:
[x.shape for x in random_batch(5)]

[(5, 784), (5, 10)]

In [6]:
mnist.train.labels

array([[ 0.,  0.,  0., ...,  1.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  1.,  0.]])

# Vanilla Net
Make a dense neural net with several layers

In [7]:
# Network architecture params
n_h1 = 300
n_h2 = 150
n_h3 = 35
tf.reset_default_graph()

# Training params
n_epochs = 150
batch_size = 128
n_batches = n_train // batch_size
print("Number of Batches per epoch: {}".format(n_batches))

# Saving data
name = "mnist_vanilla_net"
model_dir = os.path.join('.', 'models', name)
now = datetime.utcnow().strftime("%Y%m%d%H%M%S")
log_dir = os.path.join('.', 'tf_logs', name+now)
for d in [model_dir, log_dir]:
    if not os.path.isdir(d):
        os.mkdir(d)

checkpoint_path = os.path.join(model_dir, "model_ckpt.ckpt")
checkpoint_epoch_path = checkpoint_path + ".epoch"
final_model_path = os.path.join(model_dir, "final_model.nn")

# Make the actual neural net
with tf.name_scope("dnn"):
    # Inputs and outputs
    X = tf.placeholder(tf.float32, shape=(None, img_size), name="X")
    y = tf.placeholder(tf.int32, shape=(None, n_labels), name="y")

    # Hidden layers
    h1_pre = tf.layers.dense(X, n_h1, activation=tf.nn.elu, name="h1_pre")
    h1 = tf.nn.dropout(h1_pre, keep_prob=0.9, name="h1")
    h2_pre = tf.layers.dense(h1, n_h2, activation=tf.nn.elu, name="h2_pre")
    h2 = tf.nn.dropout(h2_pre, keep_prob=0.95, name="h2")
    h3 = tf.layers.dense(h2, n_h3, activation=tf.nn.elu, name="h3")
    
    # Output
    logits = tf.layers.dense(h3, n_labels, name="outputs")
    
with tf.name_scope("loss"):
    xentropy =\
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)
    loss = tf.reduce_mean(xentropy)
    loss_summary = tf.summary.scalar('log_loss', loss)
    optimizer = tf.train.MomentumOptimizer(learning_rate=2e-3, momentum=0.9,
                                          use_nesterov=True)
    training_op = optimizer.minimize(loss)
    
# For evaluation
with tf.name_scope("eval"):
    output_class = tf.argmax(y, 1)
    correct = tf.nn.in_top_k(logits, tf.argmax(y, 1), 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    accuracy_summary = tf.summary.scalar('accuracy', accuracy)
#     correct = tf.nn.in_top_k(logits, y, 1)
#     accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

with tf.name_scope("admin"):
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    file_writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())

# Run the training
with tf.Session() as sess:
    if os.path.isfile(checkpoint_epoch_path):
        # if the checkpoint file exists, restore the model and load the epoch number
        with open(checkpoint_epoch_path, "rb") as f:
            start_epoch = int(f.read())
        print("Training was interrupted. Continuing at epoch", start_epoch)
        saver.restore(sess, checkpoint_path)
    else:
        start_epoch = 0
        sess.run(init)   
    
    for epoch in range(start_epoch, n_epochs):
        
        # Run batches
        for batch_index in range(n_batches):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        
        # Run validation
        val_feed_dict = {X: mnist.validation.images, 
                         y: mnist.validation.labels}        
        loss_val, acc_val, summary_str = sess.run([loss, accuracy, loss_summary], 
                                         feed_dict=val_feed_dict)

        file_writer.add_summary(summary_str, epoch)
        
        
        if epoch % 1 == 0:
            print("Epoch:", epoch, "\tLoss:", loss_val, "\tAccuracy:", acc_val)
            saver.save(sess, checkpoint_path)
            with open(checkpoint_epoch_path, "wb") as f:
                f.write(b"%d" % (epoch + 1))

    saver.save(sess, final_model_path)

Number of Batches per epoch: 429
Epoch: 0 	Loss: 0.371223 	Accuracy: 0.8956
Epoch: 1 	Loss: 0.305884 	Accuracy: 0.9154
Epoch: 2 	Loss: 0.274899 	Accuracy: 0.9232
Epoch: 3 	Loss: 0.253709 	Accuracy: 0.9276
Epoch: 4 	Loss: 0.237029 	Accuracy: 0.9334
Epoch: 5 	Loss: 0.223717 	Accuracy: 0.9336
Epoch: 6 	Loss: 0.202572 	Accuracy: 0.9432
Epoch: 7 	Loss: 0.194844 	Accuracy: 0.9442
Epoch: 8 	Loss: 0.186339 	Accuracy: 0.948
Epoch: 9 	Loss: 0.176538 	Accuracy: 0.9512
Epoch: 10 	Loss: 0.173116 	Accuracy: 0.9498
Epoch: 11 	Loss: 0.161841 	Accuracy: 0.9536
Epoch: 12 	Loss: 0.160125 	Accuracy: 0.9544
Epoch: 13 	Loss: 0.153496 	Accuracy: 0.9562
Epoch: 14 	Loss: 0.145711 	Accuracy: 0.9574
Epoch: 15 	Loss: 0.145038 	Accuracy: 0.957
Epoch: 16 	Loss: 0.135227 	Accuracy: 0.9608
Epoch: 17 	Loss: 0.137797 	Accuracy: 0.9602
Epoch: 18 	Loss: 0.129534 	Accuracy: 0.9616
Epoch: 19 	Loss: 0.125262 	Accuracy: 0.9618
Epoch: 20 	Loss: 0.127384 	Accuracy: 0.9616
Epoch: 21 	Loss: 0.125475 	Accuracy: 0.9636
Epoch: 22 	

In [8]:
# Test accuracy
with tf.Session() as sess:
    # if the checkpoint file exists, restore the model and load the epoch number
    print("Loading model")
    saver.restore(sess, final_model_path)

    acc_val = sess.run(accuracy, feed_dict={
        X: mnist.test.images, 
        y: mnist.test.labels
    })
    print("Accuracy: {}".format(acc_val))

Loading model
INFO:tensorflow:Restoring parameters from .\models\mnist_vanilla_net\final_model.nn
Accuracy: 0.9776999950408936


# Deep learning using He initialization, training only on digits 0 to 4, followed by transfer learning for the rest of the variables
This will be used in transfer learning later

In [9]:
# Make 0 to 4 training data
def filter_04(X, y):
    indices = np.argmax(y, 1) < 5
    return X[indices], y[indices]

def next_batch04(size):
    return filter_04(*mnist.train.next_batch(size))

In [10]:
del model_dir
del log_dir

In [11]:
# Network architecture params
tf.reset_default_graph()

# Training params
n_epochs = 5
batch_size = 128
n_batches = n_train // batch_size
print("Number of Batches per epoch: {}".format(n_batches))

# Saving data
name = "mnist_deep_net1"
model_dir04 = os.path.join('.', 'models', name)
log_dir04 = os.path.join('.', 'tf_logs', name)
for d in [model_dir04, log_dir04]:
    if not os.path.isdir(d):
        os.mkdir(d)

checkpoint_path = os.path.join(model_dir04, "model_ckpt.ckpt")
checkpoint_epoch_path = checkpoint_path + ".epoch"
checkpoint_path_best = os.path.join(model_dir04, "best_model_ckpt.ckpt")
final_model_path = os.path.join(model_dir04, "final_model.nn")

# Make the actual neural net
with tf.name_scope("dnn"):
    # Inputs and outputs
    X = tf.placeholder(tf.float32, shape=(None, img_size), name="X")
    y = tf.placeholder(tf.int32, shape=(None, n_labels), name="y")

    # Hidden layers
    he_init = tf.contrib.layers.variance_scaling_initializer(mode="FAN_AVG")
    kwargs = {"activation": tf.nn.elu, "kernel_initializer": he_init}
    h1 = tf.layers.dense(X, 100, name="h1", **kwargs)
    h2 = tf.layers.dense(h1, 100, name="h2", **kwargs)
    h3 = tf.layers.dense(h2, 100, name="h3", **kwargs)
    h4 = tf.layers.dense(h3, 100, name="h4", **kwargs)
    h5 = tf.layers.dense(h4, 100, name="h5", **kwargs)
    
    # Output
    logits = tf.layers.dense(h5, n_labels, name="logits")
    
with tf.name_scope("loss"):
    Y_proba = tf.nn.softmax(logits, name="Y_proba")
    xentropy =\
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)
    loss = tf.reduce_mean(xentropy, name="loss")
    loss_summary = tf.summary.scalar('log_loss', loss)
    optimizer = tf.train.AdamOptimizer()
    training_op = optimizer.minimize(loss)
    
# For evaluation
with tf.name_scope("eval"):
    output_class = tf.argmax(y, 1)
    correct = tf.nn.in_top_k(logits, tf.argmax(y, 1), 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    accuracy_summary = tf.summary.scalar('accuracy', accuracy)

with tf.name_scope("admin"):
    init = tf.global_variables_initializer()
    file_writer = tf.summary.FileWriter(log_dir04, tf.get_default_graph())

# Run the training
highest_acc = 0.
epochs_with_no_improvement = 0
max_epochs_with_no_improvement = 10
with tf.Session() as sess:
    saver = tf.train.Saver()
    if os.path.isfile(checkpoint_epoch_path):
        # if the checkpoint file exists, restore the model and load the epoch number
        with open(checkpoint_epoch_path, "rb") as f:
            start_epoch = int(f.read())
        print("Training was interrupted. Continuing at epoch", start_epoch)
        saver.restore(sess, checkpoint_path)
    else:
        start_epoch = 0
        sess.run(init)   
    
    for epoch in range(start_epoch, n_epochs):
        
        # Run batches
        for batch_index in range(n_batches):
            X_batch, y_batch = next_batch04(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        
        if epoch % 1 == 0:
            
            # Run validation
            X_val_04, y_val_04 = filter_04(mnist.validation.images, 
                                           mnist.validation.labels)
            val_feed_dict = {X: X_val_04, 
                             y: y_val_04}        
            loss_val, acc_val, summary_str = sess.run([loss, accuracy, loss_summary], 
                                             feed_dict=val_feed_dict)

            file_writer.add_summary(summary_str, epoch)
            print("Epoch:", epoch, "\tLoss:", loss_val, "\tAccuracy:", acc_val)
            saver.save(sess, checkpoint_path)
            with open(checkpoint_epoch_path, "wb") as f:
                f.write(b"%d" % (epoch + 1))

            # Early stopping
            if acc_val > highest_acc:
                highest_acc = acc_val
                epochs_with_no_improvement = 0
                saver.save(sess, checkpoint_path_best)
            else:
                epochs_with_no_improvement += 1
            if epochs_with_no_improvement > max_epochs_with_no_improvement:
                print("Early stopping due to no improvement")

    saver.save(sess, final_model_path)

Number of Batches per epoch: 429
Epoch: 0 	Loss: 0.0640176 	Accuracy: 0.980844
Epoch: 1 	Loss: 0.0538823 	Accuracy: 0.984754
Epoch: 2 	Loss: 0.0376537 	Accuracy: 0.988272
Epoch: 3 	Loss: 0.0700493 	Accuracy: 0.97889
Epoch: 4 	Loss: 0.0464074 	Accuracy: 0.988272


In [12]:
# Test accuracy
with tf.Session() as sess:
    # if the checkpoint file exists, restore the model and load the epoch number
    print("Loading model")
    saver.restore(sess, final_model_path)
    X_test_04, y_test_04 = filter_04(mnist.test.images, mnist.test.labels)
    acc_val = sess.run(accuracy, feed_dict={
        X: X_test_04, 
        y: y_test_04
    })
    print("Accuracy on 0-4: {}".format(acc_val))

    acc_val = sess.run(accuracy, feed_dict={
        X: mnist.test.images, 
        y: mnist.test.labels
    })
    print("Accuracy on all: {}".format(acc_val))

Loading model
INFO:tensorflow:Restoring parameters from .\models\mnist_deep_net1\final_model.nn
Accuracy on 0-4: 0.9900758862495422
Accuracy on all: 0.5088000297546387


## Transfer learning for the rest of the digits

In [13]:
tf.reset_default_graph()

restore_saver = tf.train.import_meta_graph(r".\models\mnist_deep_net1\final_model.nn.meta")

# Get the inputs and outputs
X = tf.get_default_graph().get_tensor_by_name("dnn/X:0")
y = tf.get_default_graph().get_tensor_by_name("dnn/y:0")
loss = tf.get_default_graph().get_tensor_by_name("loss/loss:0")
Y_proba = tf.get_default_graph().get_tensor_by_name("loss/Y_proba:0")
logits = Y_proba.op.inputs[0]
accuracy = tf.get_default_graph().get_tensor_by_name("eval/accuracy:0")

output_layer_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope="logits")
optimizer = tf.train.AdamOptimizer(1e-2, name="Adam2")
training_op = optimizer.minimize(loss, var_list=output_layer_vars)

output_class = tf.argmax(y, 1)
correct = tf.nn.in_top_k(logits, tf.argmax(y, 1), 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")
accuracy_summary = tf.summary.scalar('accuracy', accuracy)

init = tf.global_variables_initializer()
five_frozen_saver = tf.train.Saver()

# Getting 5-9 data
def filter_59(X, y):
    indices = np.argmax(y, 1) >= 5
    return X[indices], y[indices]

def next_batch59(size):
    return filter_59(*mnist.train.next_batch(size))

# Training params
n_epochs = 3
batch_size = 128
n_batches = n_train // batch_size
print("Number of Batches per epoch: {}".format(n_batches))

# Run the training
highest_acc = 0.
epochs_with_no_improvement = 0
max_epochs_with_no_improvement = 10
with tf.Session() as sess:
    restore_saver.restore(sess, './models/mnist_deep_net1/final_model.nn')
    start_epoch = 0
    sess.run(init)
    for var in output_layer_vars:
        var.initializer.run()
    
    for epoch in range(start_epoch, n_epochs):
        
        # Run batches
        for batch_index in range(n_batches):
            X_batch, y_batch = next_batch59(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        
        if epoch % 1 == 0:
            
            # Run validation
            X_val_59, y_val_59 = filter_59(mnist.validation.images, 
                                           mnist.validation.labels)
            val_feed_dict = {X: X_val_59, 
                             y: y_val_59}        
            loss_val, acc_val = sess.run([loss, accuracy], 
                                             feed_dict=val_feed_dict)

            file_writer.add_summary(summary_str, epoch)
            print("Epoch:", epoch, "\tLoss:", loss_val, "\tAccuracy:", acc_val)
            saver.save(sess, checkpoint_path)
            with open(checkpoint_epoch_path, "wb") as f:
                f.write(b"%d" % (epoch + 1))

            # Early stopping
            if acc_val > highest_acc:
                highest_acc = acc_val
                epochs_with_no_improvement = 0
                saver.save(sess, checkpoint_path_best)
            else:
                epochs_with_no_improvement += 1
            if epochs_with_no_improvement > max_epochs_with_no_improvement:
                print("Early stopping due to no improvement")
                
    
    # Testing
    X_test_59, y_test_59 = filter_59(mnist.test.images, mnist.test.labels)
    acc_val = sess.run(accuracy, feed_dict={
        X: X_test_59, 
        y: y_test_59
    })
    print("Accuracy on 0-4: {}".format(acc_val))

    acc_val = sess.run(accuracy, feed_dict={
        X: mnist.test.images, 
        y: mnist.test.labels
    })
    print("Accuracy on all: {}".format(acc_val))

Number of Batches per epoch: 429
INFO:tensorflow:Restoring parameters from ./models/mnist_deep_net1/final_model.nn
Epoch: 0 	Loss: 0.365271 	Accuracy: 0.88534
Epoch: 1 	Loss: 0.339092 	Accuracy: 0.889025
Epoch: 2 	Loss: 0.34577 	Accuracy: 0.886568
Accuracy on 0-4: 0.8679283857345581
Accuracy on all: 0.4219000041484833


## Understandably, just training the last layer on 5-9 destroys its accuracy on 0-4