# MNIST FCN with Virtual Branching

In [1]:
import tensorflow as tf
import numpy as np
import os
from scipy.special import softmax
import matplotlib.pyplot as plt

In [2]:
import vbranch as vb

In [3]:
save = True
model_id = 1

## Load Data

In [4]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

In [5]:
input_dim = 784
num_classes = 10

In [6]:
X_train_flat = X_train.reshape([-1, input_dim])
X_test_flat = X_test.reshape([-1, input_dim])

y_train_one_hot = tf.keras.utils.to_categorical(y_train, num_classes)
y_test_one_hot = tf.keras.utils.to_categorical(y_test, num_classes)

## Build Model

In [7]:
BATCH_SIZE = 32
EPOCHS = 20
STEPS_PER_EPOCH = 100
NUM_BRANCHES = 3
model_path = './models/vb-mnist-B' + str(NUM_BRANCHES) + '_' + str(model_id)

In [8]:
model_path

'./models/vb-mnist-B3_1'

In [9]:
tf.reset_default_graph()

train_data = (X_train_flat.astype('float32'), y_train_one_hot)
test_data = (X_test_flat.astype('float32'), y_test_one_hot)

batch_size = tf.placeholder('int64', name='batch_size')

train_datasets = []
test_datasets = []
inputs = [None] * NUM_BRANCHES
labels_one_hot = [None] * NUM_BRANCHES
train_init_ops = []
test_init_ops = []

for i in range(NUM_BRANCHES):
    train_datasets.append(tf.data.Dataset.from_tensor_slices(train_data).\
        batch(batch_size).repeat().\
        shuffle(buffer_size=4*BATCH_SIZE))

    test_datasets.append(tf.data.Dataset.from_tensor_slices(test_data).\
        batch(batch_size))
    
    iterator = tf.data.Iterator.from_structure(train_datasets[i].output_types, 
                                           train_datasets[i].output_shapes)
    inputs[i], labels_one_hot[i] = iterator.get_next()    

    train_init_ops.append(iterator.make_initializer(train_datasets[i]))
    test_init_ops.append(iterator.make_initializer(test_datasets[i], 
                                                name='test_init_op_'+str(i+1)))

In [10]:
with tf.variable_scope('model_' + str(model_id), reuse=tf.AUTO_REUSE):
    model = vb.models.vbranch_fcn(inputs, ([128]*NUM_BRANCHES, 0), ([10]*NUM_BRANCHES, 0), 
                                  branches=NUM_BRANCHES)

In [11]:
model.summary()

i   Layer name          Output shapes       Num param 
------------------------------------------------------
    Input               [None,784]                    
------------------------------------------------------
    Input               [None,784]                    
------------------------------------------------------
    Input               [None,784]                    
------------------------------------------------------
0   fc1                 [None,128]          301440    
                        [None,128]                    
                        [None,128]                    
------------------------------------------------------
1   bn1                 [None,128]          768       
                        [None,128]                    
                        [None,128]                    
------------------------------------------------------
2   relu1               [None,128]          0         
                        [None,128]                    
          

In [12]:
tf.global_variables()

[<tf.Variable 'model_1/fc1_vb1_w:0' shape=(784, 128) dtype=float32_ref>,
 <tf.Variable 'model_1/fc1_vb1_b:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/fc1_vb2_w:0' shape=(784, 128) dtype=float32_ref>,
 <tf.Variable 'model_1/fc1_vb2_b:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/fc1_vb3_w:0' shape=(784, 128) dtype=float32_ref>,
 <tf.Variable 'model_1/fc1_vb3_b:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn1_vb1_scale:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn1_vb1_beta:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn1_vb2_scale:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn1_vb2_beta:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn1_vb3_scale:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/bn1_vb3_beta:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'model_1/fc2_vb1_w:0' shape=(128, 10) dtype=float32_ref>,
 <tf.Variable 'model_1/fc2_vb1_b:0' shape=(10,) dtype=float32_ref>,

In [13]:
model.input

[<tf.Tensor 'IteratorGetNext:0' shape=(?, 784) dtype=float32>,
 <tf.Tensor 'IteratorGetNext_1:0' shape=(?, 784) dtype=float32>,
 <tf.Tensor 'IteratorGetNext_2:0' shape=(?, 784) dtype=float32>]

In [14]:
# Multi output loss
losses = []
train_ops = []
for i in range(len(model.output)):
    losses.append(vb.losses.softmax_cross_entropy_with_logits(labels=labels_one_hot[i], 
                                                              logits=model.output[i], 
                                                              name='loss_'+str(i+1)))
    train_ops.append(tf.train.AdamOptimizer(learning_rate=0.001).minimize(losses[i]))

In [15]:
# Train accuracies
train_acc_ops = []
for i in range(NUM_BRANCHES):
    pred_max = tf.one_hot(tf.argmax(tf.nn.softmax(model.output[i]), axis=-1), 
                          num_classes)
    train_acc_op = tf.reduce_mean(tf.reduce_sum(labels_one_hot[i]*pred_max, [1]), 
                                  name='train_acc_'+str(i+1))
    train_acc_ops.append(train_acc_op)

# Test accuracy
pred = tf.nn.softmax(tf.reduce_mean(model.output, [0]))
pred_max = tf.one_hot(tf.argmax(pred, axis=-1), num_classes)
test_acc_op = tf.reduce_mean(tf.reduce_sum(labels_one_hot[0]*pred_max, [1]), 
                             name='test_acc')

In [16]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for e in range(EPOCHS):
        print("Epoch {}/{}".format(e + 1, EPOCHS))
        progbar = tf.keras.utils.Progbar(STEPS_PER_EPOCH)
        
        sess.run(train_init_ops, feed_dict={batch_size: BATCH_SIZE})

        for i in range(STEPS_PER_EPOCH):
            _, train_losses, train_accs = sess.run([train_ops, losses, 
                                                 train_acc_ops])
            
            prog_vals = [('loss_'+str(b+1),train_losses[b]) for b in range(NUM_BRANCHES)]
            prog_vals += [('acc_'+str(b+1),train_accs[b]) for b in range(NUM_BRANCHES)]
            
            if i == STEPS_PER_EPOCH - 1:
                sess.run(test_init_ops, feed_dict={batch_size: len(X_test_flat)})
                val_losses, val_acc, indiv_accs = sess.run([losses, test_acc_op, 
                                                            train_acc_ops])
                
                prog_vals += [("val_loss", np.mean(val_losses)), ("val_acc", val_acc)] + \
                    [('ind_acc_'+str(b+1), indiv_accs[b]) for b in range(NUM_BRANCHES)]
            
            progbar.update(i+1, values=prog_vals)
    
    if save:
        saver = tf.train.Saver()
        path = os.path.join(model_path, 'ckpt')
        saver.save(sess, path)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


## Load Model

In [20]:
test_init_ops = ['test_init_op_'+str(i+1) for i in range(NUM_BRANCHES)]
losses = ['loss_'+str(i+1)+':0' for i in range(NUM_BRANCHES)]
train_acc_ops = ['train_acc_'+str(i+1)+':0' for i in range(NUM_BRANCHES)]

In [21]:
with tf.Session() as sess:
    model_path = './models/vb-mnist-B' + str(NUM_BRANCHES) + '_' + str(model_id)
    meta_path = os.path.join(model_path, 'ckpt.meta')
    ckpt = tf.train.get_checkpoint_state(model_path)

    imported_graph = tf.train.import_meta_graph(meta_path)
    imported_graph.restore(sess, ckpt.model_checkpoint_path)

    sess.run(test_init_ops, feed_dict={'batch_size:0': len(X_test_flat)})

    val_losses, val_acc, indiv_accs = sess.run([losses, 'test_acc:0', train_acc_ops])

INFO:tensorflow:Restoring parameters from ./models/vb-mnist-B3_1/ckpt


In [22]:
print('Loss:', np.mean(val_losses))
print('Acc:', val_acc)
print('Indiv accs:', indiv_accs)

Loss: 0.22113991
Acc: 0.956
Indiv accs: [0.9473, 0.9489, 0.9468]
