# Saving and Restoring `tf.Session`

In [1]:
import os.path
import tensorflow as tf
import prettytensor as pt
from tqdm import tqdm

## Load datasets

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets('../datasets/MNIST/', one_hot=True)

Extracting ../datasets/MNIST/train-images-idx3-ubyte.gz
Extracting ../datasets/MNIST/train-labels-idx1-ubyte.gz
Extracting ../datasets/MNIST/t10k-images-idx3-ubyte.gz
Extracting ../datasets/MNIST/t10k-labels-idx1-ubyte.gz


In [3]:
print('Training:   = {:,}'.format(data.train.num_examples))
print('Testing:    = {:,}'.format(data.test.num_examples))
print('Validation: =  {:,}'.format(data.validation.num_examples))

Training:   = 55,000
Testing:    = 10,000
Validation: =  5,000


## Hyperparameters

In [4]:
# Network
image_size = 28
num_channels = 1
image_shape = image_size * image_size * num_channels
kernel_size = 5
conv1_depth = 8
conv2_depth = 16
conv3_depth = 32
conv4_depth = 64
conv5_depth = 128
conv6_depth = 256
fc_size = 1024
num_classes = 10

# Training
learning_rate = 1e-2
batch_size = 24
iterations = 0
save_step = 1000
save_path = '../logs/save-restore-convnet/'
best_val_acc = 0.0
last_improvement = 0
improvement_requirement = 1000

## Create Log dir

In [5]:
if not os.path.exists(save_path):
    os.makedirs(save_path)

## Define Model's placeholder variables

In [6]:
X = tf.placeholder(tf.float32, [None, image_shape])
y = tf.placeholder(tf.float32, [None, num_classes])
y_true = tf.argmax(y, axis=1)

## Constructing the Network

In [7]:
X_image = tf.reshape(X, shape=[-1, image_size, image_size, num_channels])
X_pretty = pt.wrap(X_image)

with pt.defaults_scope(activation_fn=tf.nn.relu):
    y_pred, loss = X_pretty.\
                    conv2d(kernel=kernel_size, depth=conv1_depth, name='conv1').\
                    max_pool(kernel=2, stride=2).\
                    conv2d(kernel=kernel_size, depth=conv2_depth, name='conv2').\
                    max_pool(kernel=2, stride=2).\
                    conv2d(kernel=kernel_size, depth=conv3_depth, name='conv1').\
                    max_pool(kernel=2, stride=2).\
                    conv2d(kernel=kernel_size, depth=conv4_depth, name='conv2').\
                    max_pool(kernel=2, stride=2).\
                    conv2d(kernel=kernel_size, depth=conv5_depth, name='conv1').\
                    max_pool(kernel=2, stride=2).\
                    conv2d(kernel=kernel_size, depth=conv6_depth, name='conv2').\
                    max_pool(kernel=2, stride=2).\
                    flatten().\
                    fully_connected(size=fc_size, name='fully_connected').\
                    softmax_classifier(num_classes=num_classes, labels=y)

## Optimize the `loss` from the Network

In [8]:
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_step = optimizer.minimize(loss)

## Define a `tf.train.Saver` object

In [9]:
saver = tf.train.Saver()

## Evaluate Network's accuracy

In [10]:
y_pred_true = tf.argmax(y_pred, axis=1)
correct = tf.equal(y_pred_true, y_true)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

## Running the network
### Define `tf.Session` as the default graph

In [11]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

In [12]:
# Run accuracy for both test and validation sets
def accuracy_eval(validation=False, test=True):
    test_acc = 0.0
    val_acc = 0.0
    if test:
        feed_dict_test = {X:data.test.images, y:data.test.labels}
        test_acc = sess.run(accuracy, feed_dict=feed_dict_test)
    if validation:
        feed_dict_val = {X:data.validation.images, y:data.validation.labels}
        val_acc = sess.run(accuracy, feed_dict=feed_dict_val)
    return test_acc, val_acc


# Display Accuracy
def print_accuracy(validation=False, test=True):
    test_acc, val_acc = accuracy_eval(validation=validation, test=True)
    msg = 'After {:,} iterations:\n'.format(iterations)
    if test:
        msg += '\tTest Accuracy\t\t= {:.2%}\n'.format(test_acc)
    if validation:
        msg += '\tValidation Accuracy\t= {:.2%}\n'.format(val_acc)
    print(msg)


# Run the optimizer
def optimize(num_iter=100):
    global iterations
    global last_improvement
    global best_val_acc
    
    for i in tqdm(range(0, num_iter)):
        # Early stopping
        if iterations - last_improvement > improvement_requirement:
            print('\nStopping optimization @ {:,} iter due to none improvement in accuracy!!!'.format(iterations))
            break
        # Update iterations
        iterations += 1
        # Get training batch
        X_batch, y_batch = data.train.next_batch(batch_size=batch_size)
        feed_dict = {X: X_batch, y: y_batch}
        # Train the network
        sess.run(train_step, feed_dict=feed_dict)
        
        # Log after every `save_step`
        if i != 0 and ((i%save_step) == 0 or i == num_iter - 1):
            _, val_acc = accuracy_eval(validation=True, test=False)
            if val_acc > best_val_acc:
                # Save the session into the saver object
                saver.save(sess=sess, save_path=save_path)
                print('Iteration: {:,}'.format(iterations))
                print('Last validation = {:.02%}\tNew validation: {:.02%}'.format(best_val_acc, val_acc))
                # Update the best_val_acc and last improvement
                last_improvement = i
                best_val_acc = val_acc
    # Log optimization info
    print('Optimization details:')
    print_accuracy(validation=True, test=True)

In [13]:
print_accuracy()

After 0 iterations:
	Test Accuracy		= 8.47%



In [14]:
optimize(num_iter=100)

100%|██████████| 100/100 [00:09<00:00, 10.28it/s]

Iteration: 100
Last validation = 0.00%	New validation: 11.26%
Optimization details:





After 100 iterations:
	Test Accuracy		= 11.35%
	Validation Accuracy	= 11.26%



In [15]:
optimize(num_iter=900)

100%|██████████| 900/900 [00:59<00:00,  2.63it/s]


Optimization details:
After 1,000 iterations:
	Test Accuracy		= 10.09%
	Validation Accuracy	= 9.90%



In [16]:
optimize(num_iter=10000)

  1%|          | 99/10000 [00:06<10:10, 16.23it/s]



Stopping optimization @ 1,100 iter due to none improvement in accuracy!!!
Optimization details:
After 1,100 iterations:
	Test Accuracy		= 11.35%
	Validation Accuracy	= 11.26%



## Restoring the `tf.Session`

In [17]:
# reset global variables
sess.run(init)
# print the accuracy
print_accuracy(validation=True, test=True)

After 1,100 iterations:
	Test Accuracy		= 9.79%
	Validation Accuracy	= 9.42%



In [18]:
# Restore the session
saver.restore(sess=sess, save_path=save_path)
print_accuracy(validation=True, test=True)

INFO:tensorflow:Restoring parameters from log_dir/
After 1,100 iterations:
	Test Accuracy		= 11.35%
	Validation Accuracy	= 11.26%



## Close the `tf.Session`

In [19]:
sess.close()