## Reusing Deep Learning Models

Leverage TensorFlow tools to reuse lower layers of a deep nueral network and train a model on MNIST data (5-9) using minimal amounts of data.

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Load the data

mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()

sample_size = 500
p = np.random.permutation(5 * sample_size)

#Filter 5-9 for auxilary training
x_train = np.concatenate((
    x_train[y_train == 5][:sample_size],
    x_train[y_train == 6][:sample_size],
    x_train[y_train == 7][:sample_size],
    x_train[y_train == 8][:sample_size],
    x_train[y_train == 9][:sample_size]
), axis=0)[p]

y_train = np.concatenate((
    y_train[y_train == 5][:sample_size],
    y_train[y_train == 6][:sample_size],
    y_train[y_train == 7][:sample_size],
    y_train[y_train == 8][:sample_size],
    y_train[y_train == 9][:sample_size]
), axis=0)[p]

x_test, y_test = x_test[y_test > 4], y_test[y_test > 4]

#Clean
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_test = x_train.reshape(-1,28*28), x_test.reshape(-1,28*28)
y_train, y_test = y_train - 5, y_test - 5

In [3]:
# Create method for getting batches for training

class mini_batches:
    
    def __init__(self, x, y, size):
        self.x = x
        self.y = y
        self.size = size
        self.index = 0
    
    def next_batch(self):
        if self.index + self.size >= len(self.x):            
            batch_x = self.x[self.index:]
            batch_y = self.y[self.index:]
            self.index = 0
            return batch_x, batch_y
        
        batch_x = self.x[self.index:self.index + self.size]
        batch_y = self.y[self.index:self.index + self.size]
        self.index = self.index + self.size
        return batch_x, batch_y

In [4]:
# Build the computational graph

from tensorflow.contrib.layers import fully_connected 
from tensorflow.contrib.layers import batch_norm
from tensorflow.contrib.layers import dropout

tf.reset_default_graph()


is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

# Inputs for training
X = tf.placeholder(tf.float32, shape=(None,28*28), name='X')
y = tf.placeholder(tf.int32, shape=(None), name='y')
X_drop = dropout(X,.5, is_training=is_training)

# Nueral Network layers
with tf.name_scope('network'):
    he_init = tf.contrib.layers.variance_scaling_initializer()
    bn_params = {'is_training':is_training, 'decay':0.99, 'updates_collections':None}
    
    with tf.contrib.framework.arg_scope([fully_connected], weights_initializer=he_init, activation_fn=tf.nn.elu, 
                                        normalizer_fn=batch_norm, normalizer_params=bn_params):
        h1 = dropout(fully_connected(X_drop, 100, scope='h1'))
        h2 = dropout(fully_connected(h1, 100, scope='h2'))
        h3 = dropout(fully_connected(h2, 100, scope='h3'))
        h4 = dropout(fully_connected(h3, 100, scope='h4'))
        h5 = dropout(fully_connected(h4, 100, scope='h5'))
        output = fully_connected(h5, 5, scope='output', activation_fn=None)

# Loss from Network
with tf.name_scope('loss'):
    x_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=output)
    loss = tf.reduce_mean(x_entropy, name='loss')

# SGD
# NOTE: freeze lower layers at optimizer call
with tf.name_scope('train'):
    optimizer = tf.train.AdamOptimizer()
    train = optimizer.minimize(loss, var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='h[345]|output'))
    
# Evaluation of performance
with tf.name_scope('eval'):
    correct = tf.nn.in_top_k(output, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

In [5]:
#Map weights to current variables
init = tf.global_variables_initializer()

value_list = []
value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='h[12]'))

og_saver = tf.train.Saver(value_list)

In [6]:
# Train the model

# Mini batches
batches = mini_batches(x_train, y_train, sample_size)
max_acc = 0
epochs = 0

# For saving model
saver = tf.train.Saver()

# Log files
import os
from datetime import datetime
now = datetime.utcnow().strftime('%Y%m%d%H%M%S')
log_dir = os.path.join(os.getcwd(), 'tensorflow/logs/11-reuse-model-{}/'.format(now))
acc_summary = tf.summary.scalar('11_reuse_model_accuracy',accuracy)
writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())

with tf.Session() as sess:
    init.run()
    
    #Restore Weights from old model
    og_saver.restore(sess,'./tensorflow/models/11_deep_learning.ckpt')
    
    # SGD Updates
    for index, batch in enumerate(range(sample_size)):
        batch_x, batch_y = batches.next_batch()
        sess.run(train, feed_dict={X: batch_x, y:batch_y, is_training:True})
        
        # Early stopping and Checkpoint logging
        if index % 50 == 0:
            saver.save(sess, os.path.join(os.getcwd(), 'tensorflow/models/11_reuse_model.ckpt'))
            log_str = acc_summary.eval(feed_dict={X: x_test, y:y_test, is_training:False})
            writer.add_summary(log_str, index)
            
            cur_acc = accuracy.eval(feed_dict={X: x_test, y:y_test, is_training:False})
            print(cur_acc)
            if cur_acc > max_acc:
                max_acc = cur_acc
                epochs = 0
            else:
                epochs = epochs + 1
                if epochs > 3:
                    saver.save(sess, os.path.join(os.getcwd(), 'tensorflow/models/11_reuse_model.ckpt'))
                    break

    # Save final model
    saver.save(sess, os.path.join(os.getcwd(), 'tensorflow/models/11_reuse_model.ckpt'))

INFO:tensorflow:Restoring parameters from ./tensorflow/models/11_deep_learning.ckpt
0.21826784
0.4929027
0.5671672
0.6056367
0.6218885
0.6295001
0.6336145
0.6406089
0.6533635
0.64945483


Pretty meh... how do we get the network to abstract better at lower layers..