# MNIST with Virtual Branching

In [1]:
import tensorflow as tf
import numpy as np
import os
from scipy.special import softmax
import matplotlib.pyplot as plt
import time
from sklearn.manifold import TSNE

In [2]:
import vbranch as vb

In [3]:
# Experiment switches read by later cells.
architecture = 'cnn'  # handed to the data loader and used to pick the model builder ('fcn' or 'cnn')
model_id = 1          # embedded in the model name and the checkpoint directory
save = False          # when True, the training cell writes a checkpoint

## Load Data

In [4]:
# Dataset dimensions.
num_classes = 10     # number of output classes
input_dim = 28 * 28  # 784; presumably the flattened image size (not referenced by the visible cells)

In [5]:
# Load the MNIST train/test splits through vbranch. `format` receives the
# architecture name (presumably selecting flat vs. image-shaped inputs -- confirm
# against vbranch.datasets.mnist). Labels are expected to be one-hot encoded,
# judging by the variable names and the later argmax over the last axis.
(X_train, y_train_one_hot), (X_test, y_test_one_hot) = vb.datasets.mnist.load_data(format=architecture)

## Train

### Build Model

In [6]:
# Optimisation schedule.
BATCH_SIZE = 32
STEPS_PER_EPOCH = 100
EPOCHS = 10

# Virtual-branching layout: number of branches, and the fraction that scales
# each layer width into its shared size when the model is built.
NUM_BRANCHES = 3
SHARED_FRAC = 1

# Checkpoint directory, named after the architecture, branch count,
# shared fraction (two decimals) and run id.
model_path = os.path.join(
    'models',
    'vb-mnist-{}-B{:d}-S{:.2f}_{:d}'.format(
        architecture, NUM_BRANCHES, SHARED_FRAC, model_id))

In [7]:
# Display the checkpoint directory derived from the current settings.
model_path

'models/vb-mnist-cnn-B3-S1.00_1'

In [8]:
# Start from an empty default graph so re-running this cell does not pile up
# duplicate nodes.
tf.reset_default_graph()

# Images are cast to float32; labels are passed through exactly as loaded.
train_data = (X_train.astype('float32'), y_train_one_hot)
test_data = (X_test.astype('float32'), y_test_one_hot)

# Fed every time an iterator is (re)initialised, so the same pipeline can serve
# small training batches as well as one evaluation batch covering the test set.
batch_size = tf.placeholder('int64', name='batch_size')

train_datasets = []
test_datasets = []
inputs = [None] * NUM_BRANCHES
labels_one_hot = [None] * NUM_BRANCHES
train_init_ops = []
test_init_ops = []

for i in range(NUM_BRANCHES):
    # Shuffle individual samples across the WHOLE training set before batching.
    # The training loop re-initialises the iterator at the start of every epoch,
    # which restarts the pipeline at the first sample; shuffling whole batches
    # through a small buffer after batch()/repeat() would therefore only ever
    # expose the leading slice of the data. Each branch has its own pipeline and
    # hence its own independent shuffle order.
    # NOTE: a full-size shuffle buffer holds one copy of the training set per branch.
    train_datasets.append(
        tf.data.Dataset.from_tensor_slices(train_data)
        .shuffle(buffer_size=len(train_data[0]))
        .batch(batch_size)
        .repeat())

    # Evaluation data is served in order, without shuffling or repetition.
    test_datasets.append(
        tf.data.Dataset.from_tensor_slices(test_data)
        .batch(batch_size))

    # One reinitialisable iterator per branch: the same get_next() tensors are
    # switched between train and test data by running the matching init op.
    iterator = tf.data.Iterator.from_structure(train_datasets[i].output_types,
                                               train_datasets[i].output_shapes)
    # The explicit name lets the tensors be fetched by name ('input_<k>:0' and
    # 'input_<k>:1') after the graph is re-imported from a checkpoint.
    inputs[i], labels_one_hot[i] = iterator.get_next(name='input_'+str(i+1))

    train_init_ops.append(iterator.make_initializer(train_datasets[i]))
    test_init_ops.append(iterator.make_initializer(test_datasets[i],
                                                   name='test_init_op_'+str(i+1)))

In [9]:
# Build the branched network for the selected architecture.
# Each layer spec appears to be a (per-branch sizes, shared size) pair, with the
# shared size scaled by SHARED_FRAC -- confirm against the vbranch builders.
if architecture == 'fcn':
    model = vb.vbranch_simple_fcn(inputs,
        ([128]*NUM_BRANCHES, int(128*SHARED_FRAC)),
        # Output width follows num_classes, matching the cnn branch below.
        ([num_classes]*NUM_BRANCHES, int(num_classes*SHARED_FRAC)),
        branches=NUM_BRANCHES, name='model_' + str(model_id))
elif architecture == 'cnn':
    model = vb.vbranch_simple_cnn(inputs, (num_classes, 0),
        ([16]*NUM_BRANCHES, int(16*SHARED_FRAC)), ([32]*NUM_BRANCHES, int(32*SHARED_FRAC)),
        branches=NUM_BRANCHES, name='model_' + str(model_id))
else:
    # Fail here rather than with a NameError on `model` in a later cell.
    raise ValueError(
        "Unsupported architecture {!r}; expected 'fcn' or 'cnn'".format(architecture))

In [10]:
# Print the layer-by-layer summary (output shapes, parameter counts, inbound layers).
model.summary()

i   Layer name         Output shape         Num param  Inbound            
--------------------------------------------------------------------------
    Input              [None,28,28,1]                                     
--------------------------------------------------------------------------
    Input              [None,28,28,1]                                     
--------------------------------------------------------------------------
    Input              [None,28,28,1]                                     
--------------------------------------------------------------------------
0   conv2d_1_1         [None,26,26,16] []   160        input              
                       [None,26,26,16] []                                 
                       [None,26,26,16] []                                 
--------------------------------------------------------------------------
1   bn_1_1             [None,26,26,16] []   32         conv2d_1_1         
                       [N

In [11]:
# List every global variable in the default graph; the variable names show how
# the weights are split (e.g. the '..._shared_to_shared_...' entries).
tf.global_variables()

[<tf.Variable 'model_1/conv2d_1_1_shared_to_shared_f:0' shape=(3, 3, 1, 16) dtype=float32_ref>,
 <tf.Variable 'model_1/conv2d_1_1_shared_to_shared_b:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn_1_1_shared_to_shared_scale:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn_1_1_shared_to_shared_beta:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'model_1/conv2d_1_2_shared_to_shared_f:0' shape=(3, 3, 16, 16) dtype=float32_ref>,
 <tf.Variable 'model_1/conv2d_1_2_shared_to_shared_b:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn_1_2_shared_to_shared_scale:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn_1_2_shared_to_shared_beta:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'model_1/conv2d_2_1_shared_to_shared_f:0' shape=(3, 3, 16, 32) dtype=float32_ref>,
 <tf.Variable 'model_1/conv2d_2_1_shared_to_shared_b:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn_2_1_shared_to_shared_scale:0' shape=(32,) dtype=float32_ref>,
 <tf.Var

In [12]:
# Adam with a fixed step size of 1e-3.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)

# compile() receives the per-branch one-hot label tensors. The training cell
# afterwards reads model.train_ops, model.losses, model.train_accs and
# model.test_acc.
model.compile(
    optimizer,
    'softmax_cross_entropy_with_logits',
    labels_one_hot=labels_one_hot,
)

### Run Ops

In [13]:
# Train all branches for EPOCHS x STEPS_PER_EPOCH steps, evaluating on the full
# test set at the end of every epoch, and optionally checkpoint the result.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for e in range(EPOCHS):
        print("Epoch {}/{}".format(e + 1, EPOCHS))
        progbar = tf.keras.utils.Progbar(STEPS_PER_EPOCH)

        # Point every branch iterator back at its training data; needed each
        # epoch because the end-of-epoch evaluation switches them to test data.
        sess.run(train_init_ops, feed_dict={batch_size: BATCH_SIZE})

        for i in range(STEPS_PER_EPOCH):
            _, train_losses, train_accs = sess.run([model.train_ops, model.losses,
                                                    model.train_accs])

            prog_vals = [('loss_'+str(b+1),train_losses[b]) for b in range(NUM_BRANCHES)]
            prog_vals += [('acc_'+str(b+1),train_accs[b]) for b in range(NUM_BRANCHES)]

            if i == STEPS_PER_EPOCH - 1:
                # Evaluate on the whole test set in a single batch. train_ops is
                # not fetched here, so no weights are updated; the per-branch
                # accuracy tensors are simply re-evaluated on test inputs.
                sess.run(test_init_ops, feed_dict={batch_size: len(X_test)})
                val_losses, val_acc, indiv_accs = sess.run([model.losses, model.test_acc,
                                                            model.train_accs])

                prog_vals += [("val_loss", np.mean(val_losses)), ("val_acc", val_acc)] + \
                    [('ind_acc_'+str(b+1), indiv_accs[b]) for b in range(NUM_BRANCHES)]

            progbar.update(i+1, values=prog_vals)

    if save:
        # Saver.save() does not create missing parent directories, so make sure
        # the checkpoint directory exists before writing into it.
        if not os.path.isdir(model_path):
            os.makedirs(model_path)
        saver = tf.train.Saver()
        path = os.path.join(model_path, 'ckpt')
        saver.save(sess, path)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


## Load Model

In [None]:
# Names of the ops/tensors to look up in the re-imported graph, one per branch
# (branches are numbered from 1). From here on these variables hold plain
# strings rather than the graph objects created when the model was built.
test_init_ops = ['test_init_op_{}'.format(k) for k in range(1, NUM_BRANCHES + 1)]
losses = ['loss_{}:0'.format(k) for k in range(1, NUM_BRANCHES + 1)]
train_acc_ops = ['train_acc_{}:0'.format(k) for k in range(1, NUM_BRANCHES + 1)]

# Output 0 of each iterator's get_next op is the image batch, output 1 the labels.
inputs = ['input_' + str(k) + ':0' for k in range(1, NUM_BRANCHES + 1)]
labels_one_hot = ['input_' + str(k) + ':1' for k in range(1, NUM_BRANCHES + 1)]
outputs = ['model_' + str(model_id) + '/output_vb' + str(k) + ':0'
           for k in range(1, NUM_BRANCHES + 1)]

In [None]:
# Restore the saved model into its own fresh graph. Importing into the default
# graph would collide with same-named nodes left over from the training cells,
# and name lookups such as 'batch_size:0' could then resolve to the wrong
# (never restored) nodes.
load_graph = tf.Graph()
with load_graph.as_default(), tf.Session(graph=load_graph) as sess:
    model_path = os.path.join('models', 'vb-mnist-{}-B{:d}-S{:.2f}_{:d}'.format(architecture,
        NUM_BRANCHES, SHARED_FRAC, model_id))
    meta_path = os.path.join(model_path, 'ckpt.meta')
    ckpt = tf.train.get_checkpoint_state(model_path)
    if ckpt is None or not ckpt.model_checkpoint_path:
        # get_checkpoint_state returns None when the directory has no checkpoint.
        raise IOError(
            'No checkpoint found in {!r}; run the training cell with save = True '
            'first'.format(model_path))

    # Rebuild the graph structure from the .meta file, then load the weights.
    imported_graph = tf.train.import_meta_graph(meta_path)
    imported_graph.restore(sess, ckpt.model_checkpoint_path)

    # Full test set in one batch: losses, the model's test accuracy and the
    # per-branch accuracies.
    sess.run(test_init_ops, feed_dict={'batch_size:0': len(X_test)})
    val_losses, val_acc, indiv_accs = sess.run([losses, 'test_acc:0', train_acc_ops])

    # Re-initialise with a small batch and fetch inputs, labels and branch
    # outputs in a single run so they all refer to the same samples.
    sample_size = 250
    sess.run(test_init_ops, feed_dict={'batch_size:0':sample_size})
    X_test_samples, y_test_samples, features = sess.run([inputs, labels_one_hot, outputs])

In [None]:
# Report the restored model's test metrics: loss averaged over the fetched loss
# tensors, the model's test accuracy, and the per-branch accuracies.
for caption, value in (('Loss:', np.mean(val_losses)),
                       ('Acc:', val_acc),
                       ('Indiv accs:', indiv_accs)):
    print(caption, value)

## Feature Visualization

In [None]:
# Average the fetched branch outputs element-wise; axis 0 indexes the branches.
mean_features = np.asarray(features).mean(axis=0)
print(mean_features.shape)

In [None]:
# Time a 2-D t-SNE embedding of the branch-averaged outputs.
start = time.time()
tsne = TSNE(perplexity=40, n_iter=300, n_components=2, verbose=1)
tsne_results = tsne.fit_transform(mean_features)
print('t-SNE done! Time elapsed: {} seconds'.format(time.time()-start))

In [None]:
# Integer class index per sample, recovered from the first branch's one-hot labels.
labels = np.asarray(y_test_samples[0]).argmax(axis=-1)

In [None]:
# Scatter the embedding (column 0 on x, column 1 on y), coloured by class index.
plt.scatter(
    tsne_results[:, 0],
    tsne_results[:, 1],
    c=labels,
    cmap=plt.cm.jet,
)
plt.colorbar()
plt.show()