## Auxiliary Learning 

Train a TensorFlow network to recognize whether two MNIST digits are the same. Then retrain it on minimal data to recognize all ten digits!

In [104]:
import tensorflow as tf
import numpy as np

In [105]:
# Load the data

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Shuffle the training set once, then split it: the first 5000 shuffled
# examples become the small "regular" set and the remainder the auxiliary set.
p = np.random.permutation(len(y_train))
reg_idx, aux_idx = p[:5000], p[5000:]

# Index with the already-sliced permutation. This yields the same arrays as
# `x_train[p][:5000]` / `x_train[p][5000:]` but avoids materialising a full
# permuted copy of the training set for every slice.
x_reg, y_reg = x_train[reg_idx], y_train[reg_idx]
x_aux, y_aux = x_train[aux_idx], y_train[aux_idx]

# Divide by 255 to rescale the pixel intensities, then flatten every 28x28
# image into a 784-element row vector.
x_reg, x_aux, x_test = x_reg / 255.0, x_aux / 255.0, x_test / 255.0
x_reg = x_reg.reshape(-1, 28 * 28)
x_aux = x_aux.reshape(-1, 28 * 28)
x_test = x_test.reshape(-1, 28 * 28)

In [108]:
# Function for getting mini batches of matching and non matching images

def aux_mini_batch(x_aux, y_aux, size):
    """Sample a shuffled mini batch of same-label and different-label pairs.

    Two independent random permutations of the dataset are lined up
    element-wise to form candidate pairs. Up to ``size`` pairs whose labels
    agree and up to ``size`` pairs whose labels differ are kept.

    Args:
        x_aux: NumPy array of examples, indexed along its first axis.
        y_aux: NumPy array of labels, one per example in ``x_aux``.
        size: Maximum number of matching pairs and, separately, of
            non-matching pairs to return.

    Returns:
        A tuple ``(left, right, labels)``. ``left[i]`` and ``right[i]`` are
        the two members of pair ``i``; ``labels[i]`` is ``1.0`` when their
        labels are equal and ``0.0`` otherwise. All three share the same
        length, which is at most ``2 * size`` and is smaller when the random
        pairing produced fewer than ``size`` pairs of either kind.
    """
    p1 = np.random.permutation(len(y_aux))
    p2 = np.random.permutation(len(y_aux))

    # Compute the match mask once and work with index arrays, so only the
    # selected rows of x_aux are copied.
    matches = y_aux[p1] == y_aux[p2]
    same_idx = np.flatnonzero(matches)[:size]
    diff_idx = np.flatnonzero(~matches)[:size]

    left = np.concatenate((x_aux[p1[same_idx]], x_aux[p1[diff_idx]]), axis=0)
    right = np.concatenate((x_aux[p2[same_idx]], x_aux[p2[diff_idx]]), axis=0)

    # Size the labels from the pairs actually found. Using `size` directly
    # made labels longer than left/right whenever fewer than `size` pairs of
    # a kind existed, and the shuffle below then indexed out of bounds.
    labels = np.concatenate((np.ones(len(same_idx)), np.zeros(len(diff_idx))))

    # Shuffle so matching and non-matching pairs are interleaved.
    pf = np.random.permutation(len(labels))

    return left[pf], right[pf], labels[pf]

In [111]:
# Try the sampler: request 100 matching and 100 non-matching pairs from the
# auxiliary set, then display the shuffled pair labels as the cell output.
left, right, same = aux_mini_batch(x_aux, y_aux, 100)
same

array([0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 0., 1., 0.,
       0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 1., 0., 0., 0.,
       1., 0., 0., 1., 1., 0., 1., 1., 0., 1., 0., 1., 1., 1., 1., 0., 1.,
       0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.,
       0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 1., 0., 1., 0., 1., 1., 0.,
       1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1.,
       0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 1., 0., 1., 0., 1., 1.,
       1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 1., 1., 0., 1., 0., 1., 1.,
       1., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 1., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0.,
       1., 0., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 1., 0., 0.,
       0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 0., 1., 0.])

In [None]:
# Build the computational graph
# The original note here was cut off ("Two networks, each one takes an image
# and matches i").
# NOTE(review): the graph below builds a single network on one input X, not
# two networks — confirm the intended architecture.

from tensorflow.contrib.layers import fully_connected 
from tensorflow.contrib.layers import batch_norm
from tensorflow.contrib.layers import dropout

tf.reset_default_graph()


# Fed as True during training and False during evaluation; it switches both
# batch normalisation and dropout between their two modes.
is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

# Inputs for training
X = tf.placeholder(tf.float32, shape=(None,28*28), name='X')
# `(None,)` declares a rank-1 vector of labels; `(None)` is just `None`, which
# leaves the placeholder's shape entirely unspecified.
y = tf.placeholder(tf.int32, shape=(None,), name='y')
X_drop = dropout(X,.5, is_training=is_training)

# Neural Network layers
with tf.name_scope('network'):
    he_init = tf.contrib.layers.variance_scaling_initializer()
    bn_params = {'is_training':is_training, 'decay':0.99, 'updates_collections':None}
    
    with tf.contrib.framework.arg_scope([fully_connected], weights_initializer=he_init, activation_fn=tf.nn.elu, 
                                        normalizer_fn=batch_norm, normalizer_params=bn_params):
        # Pass is_training to every dropout layer: its default is True, so
        # without it units would still be dropped at evaluation time.
        h1 = dropout(fully_connected(X_drop, 100, scope='h1'), is_training=is_training)
        h2 = dropout(fully_connected(h1, 100, scope='h2'), is_training=is_training)
        h3 = dropout(fully_connected(h2, 100, scope='h3'), is_training=is_training)
        h4 = dropout(fully_connected(h3, 100, scope='h4'), is_training=is_training)
        h5 = dropout(fully_connected(h4, 100, scope='h5'), is_training=is_training)
        # Raw logits: no activation, softmax is applied inside the loss.
        # NOTE(review): 5 output units matches neither task from the notebook
        # intro (2 classes for same/different, one class per digit for digit
        # recognition) — confirm before training.
        output = fully_connected(h5, 5, scope='output', activation_fn=None)

# Loss from Network
with tf.name_scope('loss'):
    x_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=output)
    loss = tf.reduce_mean(x_entropy, name='loss')

# Optimizer (Adam with default hyperparameters)
with tf.name_scope('train'):
    optimizer = tf.train.AdamOptimizer()
    train = optimizer.minimize(loss)
    
# Evaluation of performance: fraction of examples whose top-scoring logit
# is the true label.
with tf.name_scope('eval'):
    correct = tf.nn.in_top_k(output, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    
init = tf.global_variables_initializer()