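# MNIST FCN with Virtual Branching

Trains a fully connected network with `NUM_BRANCHES` virtual branches on MNIST using the `vbranch` library. Each branch draws its own independently shuffled training batches and has its own cross-entropy loss; the training objective is the sum of the branch losses. On the test set, accuracy is reported per branch and for the ensemble obtained by averaging the branch logits.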

In [1]:
import tensorflow as tf
import numpy as np
import os
from scipy.special import softmax
import matplotlib.pyplot as plt

In [2]:
import vbranch as vb

In [3]:
# Run configuration: which model id to build/save, and whether to write a checkpoint after training.
model_id = 1
save = True

## Load Data

In [4]:
# Load the MNIST train/test split through the Keras datasets API.
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

In [5]:
# Each 28x28 image is flattened to one feature vector; there are 10 digit classes.
num_classes = 10
input_dim = 28 * 28

In [6]:
# Flatten every image into a row of `input_dim` features.
X_train_flat, X_test_flat = (images.reshape(-1, input_dim) for images in (X_train, X_test))

# One-hot encode the integer digit labels.
to_one_hot = tf.keras.utils.to_categorical
y_train_one_hot = to_one_hot(y_train, num_classes)
y_test_one_hot = to_one_hot(y_test, num_classes)

## Build Model

In [7]:
# Training hyperparameters and where the checkpoint for this model id is written.
NUM_BRANCHES = 2
BATCH_SIZE = 32
EPOCHS = 10
STEPS_PER_EPOCH = 100
model_path = './models/vb-mnist_{}'.format(model_id)

In [8]:
tf.reset_default_graph()

# Model inputs are float32; labels are already one-hot.
train_data = (X_train_flat.astype('float32'), y_train_one_hot)
test_data = (X_test_flat.astype('float32'), y_test_one_hot)

# Fed when an iterator is initialized, so train and test can use different batch sizes.
batch_size = tf.placeholder('int64', name='batch_size')

# Shuffle over the whole training set. The train iterators are re-initialized at the
# start of every epoch and only STEPS_PER_EPOCH batches are drawn per epoch, so a small
# buffer would keep sampling from the first few thousand examples of the array.
SHUFFLE_BUFFER = len(X_train_flat)

# Build the source datasets once: from_tensor_slices embeds the arrays as graph
# constants, and creating it per branch would duplicate that data in the graph
# (and in the saved .meta file).
base_train_dataset = tf.data.Dataset.from_tensor_slices(train_data)
base_test_dataset = tf.data.Dataset.from_tensor_slices(test_data)

train_datasets = []
test_datasets = []
inputs = [None] * NUM_BRANCHES
labels_one_hot = [None] * NUM_BRANCHES
train_init_ops = []
test_init_ops = []

for i in range(NUM_BRANCHES):
    # Each branch shuffles independently, so branches see different batches.
    # Shuffle individual examples *before* batching; shuffling after batching only
    # permutes fixed, consecutive batches.
    train_datasets.append(base_train_dataset
                          .shuffle(buffer_size=SHUFFLE_BUFFER)
                          .repeat()
                          .batch(batch_size))

    # Test data is not shuffled, so every branch sees the same examples in the same order.
    test_datasets.append(base_test_dataset.batch(batch_size))

    # One reinitializable iterator per branch, switched between train and test via its init ops.
    iterator = tf.data.Iterator.from_structure(train_datasets[i].output_types,
                                               train_datasets[i].output_shapes)
    inputs[i], labels_one_hot[i] = iterator.get_next()

    train_init_ops.append(iterator.make_initializer(train_datasets[i]))
    # Explicitly named, presumably so the test initializer can be found by name later
    # (e.g. after restoring the saved graph) — confirm against the consumer.
    test_init_ops.append(iterator.make_initializer(test_datasets[i],
                                                   name='test_init_op_' + str(i + 1)))

In [9]:
with tf.variable_scope('model_' + str(model_id), reuse=tf.AUTO_REUSE):
    # One spec tuple per fully connected layer. The list holds one unit count per
    # branch, so it is derived from NUM_BRANCHES rather than hard-coded; the last
    # layer emits one logit per class.
    # NOTE(review): the per-branch reading of the list is inferred from the printed
    # summary (fc1/fc2 parameter counts are exactly NUM_BRANCHES x a single
    # 784->128 / 128->10 layer); the second tuple element (0 here) is presumably a
    # shared-unit count — confirm against vbranch_fcn's signature.
    model = vb.models.vbranch_fcn(inputs,
                                  ([128] * NUM_BRANCHES, 0),
                                  ([num_classes] * NUM_BRANCHES, 0),
                                  branches=NUM_BRANCHES)

In [10]:
# Per-layer output shapes (one row per branch) and parameter counts.
model.summary()

i   Layer name          Output shapes       Num param 
------------------------------------------------------
    Input               [None,784]                    
------------------------------------------------------
    Input               [None,784]                    
------------------------------------------------------
0   fc1                 [None,128]          200960    
                        [None,128]                    
------------------------------------------------------
1   bn1                 [None,128]          512       
                        [None,128]                    
------------------------------------------------------
2   relu1               [None,128]          0         
                        [None,128]                    
------------------------------------------------------
3   fc2                 [None,10]           2580      
                        [None,10]                     
------------------------------------------------------
4   bn2   

In [11]:
# All graph variables; each branch gets its own weights (name suffixes _vb1, _vb2, ...).
tf.global_variables()

[<tf.Variable 'model_1/fc1_vb1_w:0' shape=(784, 128) dtype=float32_ref>,
 <tf.Variable 'model_1/fc1_vb1_b:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/fc1_vb2_w:0' shape=(784, 128) dtype=float32_ref>,
 <tf.Variable 'model_1/fc1_vb2_b:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn1_vb1_scale:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn1_vb1_beta:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn1_vb2_scale:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn1_vb2_beta:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/fc2_vb1_w:0' shape=(128, 10) dtype=float32_ref>,
 <tf.Variable 'model_1/fc2_vb1_b:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'model_1/fc2_vb2_w:0' shape=(128, 10) dtype=float32_ref>,
 <tf.Variable 'model_1/fc2_vb2_b:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn2_vb1_scale:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn2_vb1_beta:0' shape=(10,) dtype=float32_ref>,
 <t

In [12]:
# One input tensor per branch, each produced by that branch's own iterator.
model.input

[<tf.Tensor 'IteratorGetNext:0' shape=(?, 784) dtype=float32>,
 <tf.Tensor 'IteratorGetNext_1:0' shape=(?, 784) dtype=float32>]

In [13]:
# Each branch gets its own mean softmax cross-entropy against its own labels;
# the training objective is the sum over branches.
losses = [
    tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=labels_one_hot[b], logits=model.output[b]))
    for b in range(len(model.output))
]
loss = tf.reduce_sum(losses, name='loss')
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

In [14]:
def _one_hot_prediction(probs):
    """One-hot vector at the argmax of each row of class probabilities."""
    return tf.one_hot(tf.argmax(probs, axis=-1), num_classes)


def _batch_accuracy(labels, one_hot_pred, name=None):
    """Fraction of rows whose one-hot prediction matches the one-hot label."""
    return tf.reduce_mean(tf.reduce_sum(labels * one_hot_pred, [1]), name=name)


# Per-branch accuracy: each branch is scored on its own output against its own labels.
train_acc_ops = [
    _batch_accuracy(labels_one_hot[b], _one_hot_prediction(tf.nn.softmax(model.output[b])))
    for b in range(NUM_BRANCHES)
]

# Ensemble accuracy: average the branch logits, then take the argmax. Test data is not
# shuffled, so branch 0's labels line up with every branch's inputs.
pred = tf.nn.softmax(tf.reduce_mean(model.output, [0]))
pred_max = _one_hot_prediction(pred)
test_acc_op = _batch_accuracy(labels_one_hot[0], pred_max, name='acc')

In [15]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for e in range(EPOCHS):
        print("Epoch {}/{}".format(e + 1, EPOCHS))
        progbar = tf.keras.utils.Progbar(STEPS_PER_EPOCH)

        # Point every branch's iterator at its (shuffled, repeating) training set.
        sess.run(train_init_ops, feed_dict={batch_size: BATCH_SIZE})

        for i in range(STEPS_PER_EPOCH):
            _, loss_value, train_accs = sess.run([train_op, loss, train_acc_ops])

            progbar_vals = [("loss", loss_value),] + \
                [('acc_'+str(b+1), train_accs[b]) for b in range(NUM_BRANCHES)]

            if i == STEPS_PER_EPOCH - 1:
                # End of epoch: switch the iterators to the test set and evaluate it as
                # one full-size batch. With test data fed in, train_acc_ops give each
                # branch's individual test accuracy and test_acc_op the ensemble's.
                sess.run(test_init_ops, feed_dict={batch_size: len(X_test_flat)})
                val_loss, val_acc, indiv_accs = sess.run([loss, test_acc_op, train_acc_ops])

                progbar_vals += [("val_loss", val_loss), ("val_acc", val_acc)] + \
                    [('ind_acc_'+str(b+1), indiv_accs[b]) for b in range(NUM_BRANCHES)]

            progbar.update(i+1, values=progbar_vals)

    if save:
        # tf.train.Saver.save does not create missing parent directories; it raises instead.
        os.makedirs(model_path, exist_ok=True)
        saver = tf.train.Saver()
        path = os.path.join(model_path, 'ckpt')
        saver.save(sess, path)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
