## Auxiliary Learning 

Train a tensorflow network to learn to recognize if two MNIST digits are the same. Then retrain on minimal data to recognize all nine digits!

In [1]:
import tensorflow as tf
import numpy as np

In [55]:
# Load the data

mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()

p = np.random.permutation(len(y_train))

x_reg, y_reg = x_train[p][:5000], y_train[p][:5000]
x_aux, y_aux = x_train[p][5000:], y_train[p][5000:]

x_reg, x_aux, x_test = x_reg/255.0, x_aux/255.0, x_test/255.0
x_reg, x_aux, x_test = x_reg.reshape(-1,28*28), x_aux.reshape(-1,28*28), x_test.reshape(-1,28*28)

### Auxiliary Graph
Build and train the auxiliary graph. Use its layers to build a new model!

In [35]:
# Function for getting mini batches of matching and non matching images

def aux_mini_batch(x_aux, y_aux, size):
    p1 = np.random.permutation(len(y_aux))
    p2 = np.random.permutation(len(y_aux))
    
    s_left = x_aux[p1][y_aux[p1] == y_aux[p2]][:size]
    s_right = x_aux[p2][y_aux[p1] == y_aux[p2]][:size]
    same = np.ones(size)
    
    d_left = x_aux[p1][y_aux[p1] != y_aux[p2]][:size]
    d_right = x_aux[p2][y_aux[p1] != y_aux[p2]][:size]
    diff = np.zeros(size)

    left = np.concatenate((s_left, d_left), axis=0)
    right = np.concatenate((s_right, d_right), axis=0)
    labels = np.concatenate((same, diff))
    
    pf = np.random.permutation(len(labels))
    
    return left[pf], right[pf], labels[pf]

In [38]:
# Aux computational graph
# Two networks, each one takes an image and matches i

from tensorflow.contrib.layers import fully_connected 
from tensorflow.contrib.layers import batch_norm
from tensorflow.contrib.layers import dropout

tf.reset_default_graph()

is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

# Inputs for training
left = tf.placeholder(tf.float32, shape=(None,28*28), name='X')
right = tf.placeholder(tf.float32, shape=(None,28*28), name='X')
y = tf.placeholder(tf.int32, shape=(None), name='y')

left_drop = dropout(left,.5, is_training=is_training)
right_drop = dropout(right,.5, is_training=is_training)

# Nueral Network layers
with tf.name_scope('network'):
    he_init = tf.contrib.layers.variance_scaling_initializer()
    bn_params = {'is_training':is_training, 'decay':0.99, 'updates_collections':None}
    
    # Left side of NN
    with tf.name_scope('left_network'):
        with tf.contrib.framework.arg_scope([fully_connected], weights_initializer=he_init, activation_fn=tf.nn.elu, 
                                        normalizer_fn=batch_norm, normalizer_params=bn_params):
            l1 = dropout(fully_connected(left_drop, 100, scope='l1'))
            l2 = dropout(fully_connected(l1, 100, scope='l2'))
            l3 = dropout(fully_connected(l2, 100, scope='l3'))
            l4 = dropout(fully_connected(l3, 100, scope='l4'))
            l5 = dropout(fully_connected(l4, 100, scope='l5'))
    
    # Right side of NN
    with tf.name_scope('right_network'):
        with tf.contrib.framework.arg_scope([fully_connected], weights_initializer=he_init, activation_fn=tf.nn.elu, 
                                        normalizer_fn=batch_norm, normalizer_params=bn_params):
            r1 = dropout(fully_connected(right_drop, 100, scope='r1'))
            r2 = dropout(fully_connected(r1, 100, scope='r2'))
            r3 = dropout(fully_connected(r2, 100, scope='r3'))
            r4 = dropout(fully_connected(r3, 100, scope='r4'))
            r5 = dropout(fully_connected(r4, 100, scope='r5'))
    
    top = tf.concat([l5,r5], axis=1, name='top')
    output = fully_connected(top, 1, scope='output', activation_fn=tf.nn.sigmoid)

# Loss from Network
# Equivalent to performance
with tf.name_scope('loss'):
    loss = tf.losses.mean_squared_error(labels=y, predictions=output)

# SGD
with tf.name_scope('train'):
    optimizer = tf.train.AdamOptimizer()
    train = optimizer.minimize(loss)
    
init = tf.global_variables_initializer()

In [42]:
# Train the model

max_loss = np.inf
epochs = 0

# Save model
saver = tf.train.Saver()

# Log files
import os
from datetime import datetime
now = datetime.utcnow().strftime('%Y%m%d%H%M%S')
log_dir = os.path.join(os.getcwd(), 'tensorflow/logs/11-aux-learning-{}/'.format(now))
mse_summary = tf.summary.scalar('11_MSE_aux_learning',loss)
writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())

with tf.Session() as sess:
    init.run()
    
    # SGD Updates
    for index, batch in enumerate(range(2000)):
        l_batch, r_batch, y_batch = aux_mini_batch(x_aux, y_aux, 1000)
        sess.run(train, feed_dict={left: l_batch, right:r_batch, y:y_batch, is_training:True})
        
        # Early stopping Logging and Checkpoint Saving
        if index % 2 == 0:
            saver.save(sess, os.path.join(os.getcwd(), 'tensorflow/models/11_aux_learning.ckpt'))
            log_str = mse_summary.eval(feed_dict={left: l_batch, right:r_batch, y:y_batch, is_training:False})
            writer.add_summary(log_str, index)
            
            cur_loss = loss.eval(feed_dict={left: l_batch, right:r_batch, y:y_batch, is_training:False})
            print(cur_loss)
            if cur_loss < max_loss:
                max_loss = cur_loss
                epochs = 0
            else:
                epochs = epochs + 1
                if epochs == 5:
                    saver.save(sess, os.path.join(os.getcwd(), 'tensorflow/models/11_aux_learning.ckpt'))
                    break

    # Save final model
    saver.save(sess, os.path.join(os.getcwd(), 'tensorflow/models/11_aux_learning.ckpt'))

0.41190735
0.40104935
0.39424938
0.39006904
0.3798521
0.3758725
0.3690762
0.3643566
0.35875094
0.35261756
0.3475057
0.3392602
0.33750314
0.33381855
0.32999775
0.3236753
0.32369298
0.32036403
0.3175712
0.31400186
0.31197724
0.30757162
0.30443674
0.29822594
0.29896343
0.29569244
0.2943005
0.29203767
0.2900126
0.28767797
0.2850445
0.28298783
0.28146452
0.27855822
0.27762994
0.27572697
0.27355176
0.27378914
0.27139482
0.2715755
0.26874173
0.26764333
0.26627398
0.2660837
0.26478156
0.26381898
0.26254287
0.2617081
0.26137707
0.26064038
0.2598513
0.25861531
0.25847474
0.2579613
0.25740445
0.2572683
0.25644597
0.2558217
0.25560394
0.25504237
0.25468007
0.25469816
0.25420633
0.2538021
0.2535938
0.25332958
0.25309005
0.25283602
0.25272945
0.25236842
0.2522171
0.25215274
0.25187445
0.25169408
0.251594
0.25154728
0.25136232
0.2512448
0.25118205
0.25109467
0.25102684
0.2509364
0.2508367
0.25070387
0.25066844
0.25063768
0.2506633
0.2505549
0.25049782
0.25046346
0.2504297
0.250388
0.25033915
0.250336

### Supported Graph
Reuse layers of trained auxiliary network to quickly train our own MNIST classifier

In [152]:
# Load the data

mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()

#Filter 0-4 for auxilary training in later notebook
x_train, y_train = x_train[y_train <= 4], y_train[y_train <= 4]
x_test, y_test = x_test[y_test <= 4], y_test[y_test <= 4]

x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_test = x_train.reshape(-1,28*28), x_test.reshape(-1,28*28)

In [153]:
# Create method for getting batches for training

class mini_batches:
    
    def __init__(self, x, y, size):
        self.x = x
        self.y = y
        self.size = size
        self.index = 0
    
    def next_batch(self):
        if self.index + self.size >= len(self.x):            
            batch_x = self.x[self.index:]
            batch_y = self.y[self.index:]
            self.index = 0
            return batch_x, batch_y
        
        batch_x = self.x[self.index:self.index + self.size]
        batch_y = self.y[self.index:self.index + self.size]
        self.index = self.index + self.size
        return batch_x, batch_y

In [154]:
# Build the computational graph

from tensorflow.contrib.layers import fully_connected 
from tensorflow.contrib.layers import batch_norm
from tensorflow.contrib.layers import dropout

tf.reset_default_graph()


is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

# Inputs for training
X = tf.placeholder(tf.float32, shape=(None,28*28), name='X')
y = tf.placeholder(tf.int32, shape=(None), name='y')
X_drop = dropout(X,.5, is_training=is_training)

# Nueral Network layers
with tf.name_scope('network'):
    he_init = tf.contrib.layers.variance_scaling_initializer()
    bn_params = {'is_training':is_training, 'decay':0.99, 'updates_collections':None}
    
    with tf.contrib.framework.arg_scope([fully_connected], weights_initializer=he_init, activation_fn=tf.nn.elu, 
                                        normalizer_fn=batch_norm, normalizer_params=bn_params):
        h1 = dropout(fully_connected(X_drop, 100, scope='l1'))
        h2 = dropout(fully_connected(h1, 100, scope='l2'))
        h3 = dropout(fully_connected(h2, 100, scope='l3'))
        h4 = dropout(fully_connected(h3, 100, scope='l4'))
        h5 = dropout(fully_connected(h4, 100, scope='l5'))
        output = fully_connected(h5, 10, scope='output', activation_fn=None)

# Loss from Network
with tf.name_scope('loss'):
    x_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=output)
    loss = tf.reduce_mean(x_entropy, name='loss')

# SGD
with tf.name_scope('train'):
    optimizer = tf.train.AdamOptimizer()
    train = optimizer.minimize(loss, var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='l[345]|output'))
    
# Evaluation of performance
with tf.name_scope('eval'):
    correct = tf.nn.in_top_k(output, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    
init = tf.global_variables_initializer()

In [155]:
#Map weights to current variables
value_list = []
value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='l[12]'))

og_saver = tf.train.Saver(value_list)

value_list

[<tf.Variable 'l1/weights:0' shape=(784, 100) dtype=float32_ref>,
 <tf.Variable 'l1/BatchNorm/beta:0' shape=(100,) dtype=float32_ref>,
 <tf.Variable 'l1/BatchNorm/moving_mean:0' shape=(100,) dtype=float32_ref>,
 <tf.Variable 'l1/BatchNorm/moving_variance:0' shape=(100,) dtype=float32_ref>,
 <tf.Variable 'l2/weights:0' shape=(100, 100) dtype=float32_ref>,
 <tf.Variable 'l2/BatchNorm/beta:0' shape=(100,) dtype=float32_ref>,
 <tf.Variable 'l2/BatchNorm/moving_mean:0' shape=(100,) dtype=float32_ref>,
 <tf.Variable 'l2/BatchNorm/moving_variance:0' shape=(100,) dtype=float32_ref>]

In [156]:
# Train the model

# Mini batches
batches = mini_batches(x_train, y_train, 1000)
max_acc = 0
epochs = 0

# Save model
saver = tf.train.Saver()

# Log files
import os
from datetime import datetime
now = datetime.utcnow().strftime('%Y%m%d%H%M%S')
log_dir = os.path.join(os.getcwd(), 'tensorflow/logs/11-axu-reuse-learning-{}/'.format(now))
acc_summary = tf.summary.scalar('11_aux_reuse_acc',accuracy)
writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())

with tf.Session() as sess:
    init.run()
    
    #Restore Weights from old model
    og_saver.restore(sess,'./tensorflow/models/11_aux_learning.ckpt')
    
    # SGD Updates
    for index, batch in enumerate(range(50000)):
        batch_x, batch_y = batches.next_batch()
        sess.run(train, feed_dict={X: batch_x, y:batch_y, is_training:True})
        
        # Early stopping and Checkpoint logging
        if index % 1000 == 0:
            saver.save(sess, os.path.join(os.getcwd(), 'tensorflow/models/11_aux_reuse_learning.ckpt'))
            log_str = acc_summary.eval(feed_dict={X: x_test, y:y_test, is_training:False})
            writer.add_summary(log_str, index)
            
            cur_acc = accuracy.eval(feed_dict={X: x_test, y:y_test, is_training:False})
            print(cur_acc)
            if cur_acc > max_acc:
                max_acc = cur_acc
                epochs = 0
            else:
                epochs = epochs + 1
                if epochs > 3:
                    saver.save(sess, os.path.join(os.getcwd(), 'tensorflow/models/11_aux_reuse_learning.ckpt'))
                    break

    # Save final model
    saver.save(sess, os.path.join(os.getcwd(), 'tensorflow/models/11_aux_reuse_learning.ckpt'))

INFO:tensorflow:Restoring parameters from ./tensorflow/models/11_aux_learning.ckpt
0.103911266
0.769605
0.7742751
0.78361547
0.7853668
0.7853668
0.7838101
0.78906405
0.7931504
0.78886944
0.7915937
0.7929558
0.7896478
