In [1]:
'''
Author: Aishik Chakraborty
Using checkpointing in Tensorflow
'''

import tensorflow as tf
import numpy as np
import random
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline

# NOTE(review): tf.set_random_seed applies to the graph that is the default at call
# time. The NN class in this notebook builds its model inside its own tf.Graph(), so
# confirm that the seed actually reaches that graph.
tf.set_random_seed(123)  # reproducibility

In [2]:
# Load MNIST through the tf.contrib dataset helper.
mnist = tf.contrib.learn.datasets.load_dataset("mnist")

# Integer class labels are turned into one-hot rows by indexing a 10x10 identity matrix.
train_data = mnist.train.images  # Returns np.array
train_labels = np.eye(10)[np.asarray(mnist.train.labels, dtype=np.int32)]
test_data = mnist.test.images  # Returns np.array
test_labels = np.eye(10)[np.asarray(mnist.test.labels, dtype=np.int32)]

# Hold out 10% of the training split for validation; fixed random_state keeps the split stable.
train_data, val_data, train_labels, val_labels = train_test_split(
    train_data, train_labels, test_size=0.1, random_state=123)

# Call-form print produces the same output under Python 2 and matches the other cells.
print(val_data.shape)

# Hyperparameters read by the model definition and the training loop.
learning_rate = 0.01
training_epochs = 50
batch_size = 100

class NN(object):
    """Feed-forward classifier (784 inputs -> 500 -> 500 -> 10 outputs) built in its own graph.

    Every instance owns a private ``tf.Graph`` (``self.graph``) so the model can be
    rebuilt from scratch, e.g. before restoring a checkpoint.

    Placeholders to feed:
        X:         float32 inputs, shape [None, 784].
        Y:         float32 one-hot targets, shape [None, 10].
        mode:      bool, passed as ``training`` to the batch-normalization layers.
        keep_prob: float32 keep probability handed to ``tf.nn.dropout``.

    Tensors / ops exposed:
        h1, h2:      outputs of the two hidden blocks (after dropout).
        logits:      batch-normalized output of the final dense layer.
        pred:        softmax of ``logits``.
        accuracy:    mean of argmax(pred) == argmax(Y).
        cost:        mean softmax cross-entropy between ``logits`` and ``Y``.
        global_step: non-trainable counter incremented by ``optimizer``.
        merged:      merged 'accuracy' and 'mean_loss' scalar summaries.
        optimizer:   Adam training op using the module-level ``learning_rate``.

    Args:
        seed: graph-level random seed applied to this instance's graph.
    """

    def __init__(self, seed=123):
        self.graph = tf.Graph()
        with self.graph.as_default():
            # A graph-level seed only affects the graph that is default when it is set.
            # This model lives in a brand-new graph, so it has to be seeded here for
            # weight initialisation and dropout to be reproducible.
            tf.set_random_seed(seed)

            # input place holders
            self.X = tf.placeholder(tf.float32, [None, 784])
            self.Y = tf.placeholder(tf.float32, [None, 10])
            self.mode = tf.placeholder(tf.bool)
            self.keep_prob = tf.placeholder(tf.float32)

            # Two identical hidden blocks, created in the same order as before so the
            # auto-generated variable names (and existing checkpoints) are unchanged.
            self.h1 = self._hidden_layer(self.X, 500)
            self.h2 = self._hidden_layer(self.h1, 500)

            self.logits = tf.layers.dense(inputs=self.h2, units=10)
            self.logits = tf.layers.batch_normalization(self.logits, training=self.mode)
            self.pred = tf.nn.softmax(self.logits)

            # Test model and check accuracy
            self.correct_prediction = tf.equal(tf.argmax(self.pred, 1), tf.argmax(self.Y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
            tf.summary.scalar('accuracy', self.accuracy)

            # define cost/loss & optimizer
            self.cost = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.Y))
            self.global_step = tf.Variable(0, name='global_step', trainable=False)
            tf.summary.scalar('mean_loss', self.cost)
            self.merged = tf.summary.merge_all()

            # Batch normalization registers its moving-average updates in the UPDATE_OPS
            # collection; they only run if something depends on them, so the training op
            # is created under a control dependency on that collection.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.optimizer = tf.train.AdamOptimizer(
                    learning_rate=learning_rate).minimize(self.cost, global_step=self.global_step)

    def _hidden_layer(self, inputs, units):
        """Build one hidden block: dense -> batch norm -> ReLU -> dropout, and return its output."""
        out = tf.layers.dense(inputs=inputs, units=units)
        out = tf.layers.batch_normalization(out, training=self.mode)
        out = tf.nn.relu(out)
        return tf.nn.dropout(out, self.keep_prob)

        
nn = NN()

# Best validation accuracy seen so far.
best_validation_accuracy = 0.0

# Epoch index of the last improvement in validation accuracy.
last_improvement = 0

# Stop optimization if no improvement is found in this many epochs.
patience = 10

# Write training summaries once every this many mini-batches.
summary_every = 50

# Number of full mini-batches per epoch (a trailing partial batch is dropped).
# It does not change between epochs, so it is computed once.
total_batch = len(train_data) // batch_size

# Start session. summary_op=None and save_model_secs=0 turn off the Supervisor's
# automatic summaries and periodic checkpoints; both are driven manually below.
sv = tf.train.Supervisor(graph=nn.graph,
                         logdir='logs/',
                         summary_op=None,
                         save_model_secs=0)

with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    for epoch in range(training_epochs):
        if sv.should_stop():
            break

        avg_cost = 0
        for i in range(total_batch):
            start = i * batch_size
            stop = start + batch_size
            batch_xs, batch_ys = train_data[start:stop], train_labels[start:stop]
            feed_dict = {nn.X: batch_xs, nn.Y: batch_ys, nn.mode: True, nn.keep_prob: 0.8}
            c, _ = sess.run([nn.cost, nn.optimizer], feed_dict=feed_dict)
            avg_cost += c / total_batch

            # `if i % 50:` was true for every batch that is NOT a multiple of 50, so
            # summaries were written on almost every step. Compare against zero.
            if i % summary_every == 0:
                sv.summary_computed(sess, sess.run(nn.merged, feed_dict))

        print('Epoch : ' + str(epoch) + ' Training Loss: ' + str(avg_cost))

        # Validation pass: batch norm in inference mode, dropout disabled.
        acc = sess.run(nn.accuracy, feed_dict={
            nn.X: val_data, nn.Y: val_labels, nn.mode: False, nn.keep_prob: 1.0})
        print('Validation Accuracy: ' + str(acc))

        if acc > best_validation_accuracy:
            last_improvement = epoch
            best_validation_accuracy = acc
            # Read the step counter at save time, so the checkpoint suffix does not
            # depend on when a summary was last written.
            gs = sess.run(nn.global_step)
            sv.saver.save(sess, 'logs' + '/model_gs', global_step=gs)

        # `>=` stops after exactly `patience` epochs without improvement.
        if epoch - last_improvement >= patience:
            print("Early stopping ...")
            break

Extracting MNIST-data/train-images-idx3-ubyte.gz
Extracting MNIST-data/train-labels-idx1-ubyte.gz
Extracting MNIST-data/t10k-images-idx3-ubyte.gz
Extracting MNIST-data/t10k-labels-idx1-ubyte.gz
(5500, 784)
INFO:tensorflow:Starting standard services.
INFO:tensorflow:Starting queue runners.
INFO:tensorflow:global_step/sec: 0
Epoch : 0 Training Loss: 0.271696949472
Validation Accuracy: 0.961091
Epoch : 1 Training Loss: 0.123334100549
Validation Accuracy: 0.972909
Epoch : 2 Training Loss: 0.084519075892
Validation Accuracy: 0.976182
Epoch : 3 Training Loss: 0.0630652308558
Validation Accuracy: 0.976363
Epoch : 4 Training Loss: 0.052724371398
Validation Accuracy: 0.976
Epoch : 5 Training Loss: 0.0418031172408
Validation Accuracy: 0.981091
Epoch : 6 Training Loss: 0.0377528071046
Validation Accuracy: 0.980364
Epoch : 7 Training Loss: 0.0312446936433
Validation Accuracy: 0.981091
Epoch : 8 Training Loss: 0.0264482819275
Validation Accuracy: 0.977636
Epoch : 9 Training Loss: 0.0289179145199
Va

In [3]:
# Rebuild the model in a fresh graph; its variables are filled in from the checkpoint below.
nn = NN()
print("Graph loaded")
with nn.graph.as_default():
    sv = tf.train.Supervisor()
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with sv.managed_session(config=session_config) as sess:
        # Restore parameters from the most recent checkpoint found under logs/.
        checkpoint_path = tf.train.latest_checkpoint('logs/')
        sv.saver.restore(sess, checkpoint_path)
        print("Restored!")

        # Evaluate on the test split: batch norm in inference mode, dropout disabled.
        test_feed = {nn.X: test_data, nn.Y: test_labels, nn.mode: False, nn.keep_prob: 1.0}
        acc = sess.run(nn.accuracy, feed_dict=test_feed)
        print('Accuracy:', acc)

Graph loaded
INFO:tensorflow:Starting standard services.
INFO:tensorflow:Starting queue runners.
INFO:tensorflow:Restoring parameters from logs/model_gs-6930
Restored!
('Accuracy:', 0.98120016)
