# Classification with Virtual Branching

In [1]:
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
import time
from sklearn.manifold import TSNE

In [2]:
import vbranch as vb
from vbranch.utils.training_utils import get_data, bag_samples, get_data_iterator
from vbranch.utils import TFSessionGrow, restore_sess
from vbranch.callbacks import classification_acc
from vbranch.applications.fcn import SimpleFCNv1, SimpleFCNv2
from vbranch.applications.cnn import SimpleCNNSmall

In [3]:
# --- Run bookkeeping ---
SAVE = False            # the fit cell only passes a log_path when this is True
MODEL_ID = 1            # numeric suffix of the model directory name under models/

# --- Dataset / architecture selection ---
DATASET = 'mnist'
ARCHITECTURE = 'fcn2'   # build_model accepts 'fcn', 'fcn2' or 'cnn'
NUM_CLASSES = 10
NUM_FEATURES = 784
SAMPLES_PER_CLASS = 100
BAGGING_SAMPLES = 1     # forwarded to bag_samples as max_samples; == 1 turns on share_xy

# --- Branching and optimisation schedule ---
NUM_BRANCHES = 1
SHARED_FRAC = 0.25
BATCH_SIZE = 32
EPOCHS = 15
STEPS_PER_EPOCH = 100

## Data

In [4]:
# Fetch the train/test split using the settings from the configuration cell.
data_splits = get_data(DATASET, ARCHITECTURE, NUM_CLASSES, NUM_FEATURES, SAMPLES_PER_CLASS)
(X_train, y_train), (X_test, y_test) = data_splits

# Placeholder shapes: the leading (batch) dimension is left open.
x_shape = (None, *X_train.shape[1:])
y_shape = (None, NUM_CLASSES)

In [5]:
# Bagging: the resulting lists are indexed per branch when the training feed is built.
bagged = bag_samples(X_train, y_train, NUM_BRANCHES, max_samples=BAGGING_SAMPLES)
x_train_list, y_train_list = bagged

## Train

### Build

In [6]:
# Create the checkpoint directory without shelling out; exist_ok makes the cell
# safe to re-run, and a real failure now raises instead of being silently ignored.
os.makedirs('models', exist_ok=True)

# Single-branch (baseline) runs and virtual-branching runs use different naming
# schemes; the latter also encodes the branch count and the shared fraction.
if NUM_BRANCHES == 1:
    model_name = '{}-{}_{:d}'.format(DATASET, ARCHITECTURE, MODEL_ID)
else:
    model_name = 'vb-{}-{}-B{:d}-S{:.2f}_{:d}'.format(DATASET, ARCHITECTURE,
                                                      NUM_BRANCHES, SHARED_FRAC, MODEL_ID)
model_path = os.path.join('models', model_name)
print(model_path)

models/mnist-fcn2_1


In [7]:
# Start from an empty default graph so re-running this cell does not pile up ops.
tf.reset_default_graph()

# share_xy is switched on exactly when BAGGING_SAMPLES == 1.
share_inputs = (BAGGING_SAMPLES == 1)
iterator_parts = get_data_iterator(x_shape, y_shape, batch_size=BATCH_SIZE,
                                   n=NUM_BRANCHES, share_xy=share_inputs)
inputs, labels_one_hot, train_init_op, test_init_op = iterator_parts

Instructions for updating:
Colocations handled automatically by placer.


In [8]:
# Display the test initializer op returned by get_data_iterator (inspection only).
test_init_op

<tf.Operation 'test_init_op' type=MakeIterator>

In [9]:
def build_model(architecture, inputs, labels, num_classes, num_branches, model_id, shared_frac):
    """Build and compile one of the sample classifiers.

    Args:
        architecture: 'fcn', 'fcn2' or 'cnn'; selects SimpleFCNv1, SimpleFCNv2
            or SimpleCNNSmall respectively.
        inputs: input tensor or list of tensors. A single ``tf.Tensor`` is
            replicated ``num_branches`` times when ``num_branches > 1``.
        labels: one-hot labels, either a list or a single object. A non-list
            value is replicated ``num_branches`` times when ``num_branches > 1``.
        num_classes: forwarded to the model constructor.
        num_branches: number of branches; also passed to ``classification_acc``.
        model_id: currently unused; kept so existing calls keep working.
        shared_frac: forwarded to the model constructor as ``shared_frac``.

    Returns:
        The compiled model object.

    Raises:
        ValueError: if ``architecture`` is not one of the supported names.

    Note:
        ``train_init_op`` and ``test_init_op`` are read from the notebook's
        global namespace, so the data-iterator cell must have been run first.
    """
    constructors = {
        'fcn': SimpleFCNv1,
        'fcn2': SimpleFCNv2,
        'cnn': SimpleCNNSmall,
    }
    # Validate before touching the graph so a typo does not create a variable scope.
    if architecture not in constructors:
        raise ValueError('invalid model: {!r}'.format(architecture))

    if num_branches > 1 and isinstance(inputs, tf.Tensor):
        inputs = [inputs] * num_branches

    if isinstance(labels, list) or num_branches == 1:
        labels_list = labels
    else:
        labels_list = [labels] * num_branches

    name = 'model'

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        model = constructors[architecture](inputs, num_classes, name=name,
                                           shared_frac=shared_frac)

        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        # Use the num_branches argument rather than the notebook-level
        # NUM_BRANCHES, so the callback always matches the model being built.
        model.compile(optimizer, 'softmax_cross_entropy_with_logits',
                      train_init_op, test_init_op,
                      labels_one_hot=labels_list,
                      callbacks={'acc': classification_acc(num_branches)})

    return model

In [10]:
# Assemble the network from the configuration constants, then print its layer table.
model = build_model(
    architecture=ARCHITECTURE,
    inputs=inputs,
    labels=labels_one_hot,
    num_classes=NUM_CLASSES,
    num_branches=NUM_BRANCHES,
    model_id=MODEL_ID,
    shared_frac=SHARED_FRAC,
)
model.summary()

Instructions for updating:
Use tf.cast instead.
i  Layer name                Output shape  Parameters       Num param  Inbound  
--------------------------------------------------------------------------------
   Input                     [None,784]                                         
--------------------------------------------------------------------------------
0  fc1 (Dense)               [None,512]    [784,512] [512]  401920     input:0  
--------------------------------------------------------------------------------
1  bn1 (BatchNormalization)  [None,512]    [512] [512]      1024       fc1      
--------------------------------------------------------------------------------
2  relu1 (Activation)        [None,512]                     0          bn1      
--------------------------------------------------------------------------------
3  fc2 (Dense)               [None,256]    [512,256] [256]  131328     relu1    
-------------------------------------------------------------

In [11]:
# Print the compiled model's output attribute.
print(model.output)

Tensor("model/output/output:0", shape=(?, 10), dtype=float32)


### Fit

In [12]:
# Training feed: one x/y pair when there is a single branch or no bagging,
# otherwise one vb<k>_x / vb<k>_y pair per branch.
if NUM_BRANCHES != 1 and BAGGING_SAMPLES != 1:
    train_dict = {}
    for branch in range(1, NUM_BRANCHES + 1):
        train_dict['vb{}_x:0'.format(branch)] = x_train_list[branch - 1]
        train_dict['vb{}_y:0'.format(branch)] = y_train_list[branch - 1]
    train_dict['batch_size:0'] = BATCH_SIZE
else:
    train_dict = {'x:0': X_train, 'y:0': y_train, 'batch_size:0': BATCH_SIZE}

# Validation feeds the whole test set as a single batch.
val_dict = {'x:0': X_test, 'y:0': y_test, 'batch_size:0': len(X_test)}

# A log path is only handed to fit when SAVE is enabled.
log_path = model_path if SAVE else None
history = model.fit(train_dict, EPOCHS, STEPS_PER_EPOCH,
                    val_dict=val_dict, log_path=log_path)

Epoch 1/15
 - 1s - loss: 0.4878 - acc: 0.9156 - val_loss: 0.2815 - val_acc: 0.9140
Epoch 2/15
 - 1s - loss: 0.1792 - acc: 0.9307 - val_loss: 0.2297 - val_acc: 0.9264
Epoch 3/15
 - 1s - loss: 0.1050 - acc: 0.9373 - val_loss: 0.2113 - val_acc: 0.9321
Epoch 4/15
 - 1s - loss: 0.0940 - acc: 0.9385 - val_loss: 0.2027 - val_acc: 0.9346
Epoch 5/15
 - 1s - loss: 0.0767 - acc: 0.9426 - val_loss: 0.1954 - val_acc: 0.9399
Epoch 6/15
 - 1s - loss: 0.0500 - acc: 0.9496 - val_loss: 0.1780 - val_acc: 0.9427
Epoch 7/15
 - 1s - loss: 0.0448 - acc: 0.9470 - val_loss: 0.1826 - val_acc: 0.9437
Epoch 8/15
 - 1s - loss: 0.0331 - acc: 0.9511 - val_loss: 0.1750 - val_acc: 0.9455
Epoch 9/15
 - 1s - loss: 0.0276 - acc: 0.9523 - val_loss: 0.1675 - val_acc: 0.9490
Epoch 10/15
 - 1s - loss: 0.0277 - acc: 0.9510 - val_loss: 0.1712 - val_acc: 0.9488
Epoch 11/15
 - 1s - loss: 0.0234 - acc: 0.9458 - val_loss: 0.1791 - val_acc: 0.9459
Epoch 12/15
 - 1s - loss: 0.0201 - acc: 0.9482 - val_loss: 0.1855 - val_acc: 0.9436
E

## Evaluation

### Baseline

In [13]:
# Restore each saved single-branch checkpoint, score it on the test set and
# report the mean/std accuracy over `model_id_list`.
assert NUM_BRANCHES == 1

model_id_list = [1]
baseline_acc_list = []

for model_id in model_id_list:
    tf.reset_default_graph()
    # Same naming scheme as the single-branch case in the model-path cell.
    # NOTE(review): with SAVE = False the fit cell passes log_path=None, so
    # presumably no checkpoint is written by this run — confirm the directory exists.
    model_name = '{}-{}_{:d}'.format(DATASET, ARCHITECTURE, model_id)
    model_path = os.path.join('models', model_name)
    
    with TFSessionGrow() as sess:
        restore_sess(sess, model_path)
        # NOTE(review): `baseline_classification` is not imported or defined in any
        # cell visible in this notebook; import it before running on a fresh kernel.
        acc = baseline_classification(sess, X_test, y_test)
        print('Model {} acc:'.format(model_id), acc)
        baseline_acc_list.append(acc)
        
print('Mean acc:', np.mean(baseline_acc_list), ', std:', np.std(baseline_acc_list))

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from models/mnist-fcn_1/ckpt


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/gong/anaconda3/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2961, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-13-cf44fd73e05c>", line 14, in <module>
    model_name='model_'+str(model_id)+'_1')
TypeError: baseline_classification() got multiple values for argument 'model_name'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/gong/anaconda3/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 1863, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'TypeError' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/gong/anaconda3/lib/python3.5/site-packages/IPython/core/ultratb.py", line 1095, in get_records
    return _fixed_getinnerframes(etb, number_of_lines_of_

TypeError: baseline_classification() got multiple values for argument 'model_name'

### Ensemble

In [None]:
# Evaluate a random subset of independently trained baseline models, collecting
# each model's test outputs, loss and accuracy for ensembling.
test_outputs = []
test_losses = []
test_accs = []

num_models = 4
num_candidates = 5  # size of the pool of saved baseline models (ids 1..5) to draw from
graphs = [tf.Graph() for _ in range(num_candidates)]
sessions = [tf.Session(graph=g) for g in graphs]

try:
    for i in np.random.choice(num_candidates, num_models, replace=False):
        with graphs[i].as_default():
            # Uses the ARCHITECTURE / DATASET constants; the previous lowercase
            # `architecture` was an undefined name and raised NameError.
            model_path = 'models/{}-{}_{}'.format(DATASET, ARCHITECTURE, i + 1)
            meta_path = os.path.join(model_path, 'ckpt.meta')
            ckpt = tf.train.get_checkpoint_state(model_path)

            imported_graph = tf.train.import_meta_graph(meta_path)
            imported_graph.restore(sessions[i], ckpt.model_checkpoint_path)

            # Feed the whole test set as a single batch.
            sessions[i].run('test_init_op', feed_dict={'batch_size:0': len(X_test)})

            # NOTE(review): build_model in this notebook names its scope 'model' and
            # prints the output as 'model/output/output:0'; confirm that the saved
            # graphs really expose 'model_<id>/output:0'.
            output, loss, acc = sessions[i].run(['model_{}/output:0'.format(i + 1),
                                                 'loss:0', 'acc:0'])
            test_outputs.append(output)
            test_losses.append(loss)
            test_accs.append(acc)
finally:
    # Release session resources even if restoring or evaluating a model fails.
    for sess in sessions:
        sess.close()

### Virtual Branching

In [None]:
# Restore each saved virtual-branching checkpoint, score it on the test set and
# report the mean/std accuracy over `model_id_list`.
model_id_list = [1]
vbranch_acc_list = []

for model_id in model_id_list:
    tf.reset_default_graph()
    # Same 'vb-...' naming scheme as the multi-branch case in the model-path cell.
    # NOTE(review): the Baseline cell asserts NUM_BRANCHES == 1, while this cell
    # builds a multi-branch path from NUM_BRANCHES — the two cells presumably
    # need different configurations; confirm.
    model_name = 'vb-{}-{}-B{:d}-S{:.2f}_{:d}'.format(DATASET, ARCHITECTURE,
                                            NUM_BRANCHES, SHARED_FRAC, model_id)
    model_path = os.path.join('models', model_name)
    
    with TFSessionGrow() as sess:
        restore_sess(sess, model_path)
        # NOTE(review): `vbranch_classification` is not imported or defined in any
        # cell visible in this notebook; import it before running on a fresh kernel.
        acc, branch_acc = vbranch_classification(sess, X_test, y_test, 
                                     model_name='model_'+str(model_id)+'_1', 
                                     num_classes=NUM_CLASSES, 
                                     n_branches=NUM_BRANCHES)
        print('Model {} acc:'.format(model_id), acc, branch_acc)
        vbranch_acc_list.append(acc)
        
print('Mean acc:', np.mean(vbranch_acc_list), ', std:', np.std(vbranch_acc_list))

### Feature Visualization

In [None]:
def get_tsne(features, perplexity=40, n_iter=300, random_state=None):
    """Project `features` to two dimensions with t-SNE and report the run time.

    Args:
        features: data passed unchanged to ``TSNE.fit_transform``.
        perplexity: t-SNE perplexity (default 40, the previously hard-coded value).
        n_iter: number of optimisation iterations (default 300, as before).
        random_state: optional seed forwarded to ``TSNE`` so an embedding can be
            reproduced; ``None`` keeps the previous unseeded behavior.

    Returns:
        The 2-D embedding returned by ``TSNE.fit_transform``.
    """
    start = time.time()
    tsne = TSNE(n_components=2, verbose=1, perplexity=perplexity,
                n_iter=n_iter, random_state=random_state)
    tsne_results = tsne.fit_transform(features)

    print('t-SNE done! Time elapsed: {} seconds'.format(time.time()-start))
    return tsne_results

In [None]:
# Draw one bagged sample set from the training data for the t-SNE plots.
# NOTE(review): the training-feed cell indexes bag_samples' results per branch
# (x_train_list[i]); if lists are returned here as well, `.shape` below will
# fail — confirm the return type for a single branch.
X_sample, y_sample = bag_samples(X_train, y_train, 1, max_samples=250)
print(X_sample.shape, y_sample.shape)

In [None]:
# Restore the baseline checkpoint and evaluate its output tensor on the sample.
tf.reset_default_graph()
model_name = '{}-{}_{:d}'.format(DATASET, ARCHITECTURE, MODEL_ID)
model_path = os.path.join('models', model_name)

with TFSessionGrow() as sess:
    restore_sess(sess, model_path)
    # NOTE(review): the model built in this notebook prints its output as
    # 'model/output/output:0' and is fed through 'x:0'; the 'model_<id>_1/...'
    # and 'x_test:0' names used here presumably belong to checkpoints saved
    # with a different graph layout — confirm before running.
    baseline_features = sess.run('model_{}_1/output/output:0'.format(MODEL_ID), 
                                 feed_dict={'x_test:0':X_sample})

In [None]:
tf.reset_default_graph()
model_name = 'vb-{}-{}-B{:d}-S{:.2f}_{:d}'.format(DATASET, ARCHITECTURE,
                                                  NUM_BRANCHES, SHARED_FRAC, MODEL_ID)
model_path = os.path.join('models', model_name)

# One output tensor name per branch (branches are numbered from 1). These are
# graph tensor names, not filesystem paths, so plain string formatting is used.
outputs = ['model_{}_1/output/vb{}/output:0'.format(MODEL_ID, branch)
           for branch in range(1, NUM_BRANCHES + 1)]

with TFSessionGrow() as sess:
    restore_sess(sess, model_path)
    vbranch_features = sess.run(outputs, feed_dict={'x_test:0':X_sample})

# Average the per-branch outputs into a single feature matrix.
mean_vbranch_features = np.mean(vbranch_features, axis=0)

In [None]:
# Embed both feature sets, then colour each point by its class index.
baseline_tsne = get_tsne(baseline_features)
vbranch_tsne = get_tsne(mean_vbranch_features)
sample_labels = np.argmax(y_sample, axis=-1)

plt.figure(figsize=(15,5))

# Side-by-side panels sharing the same colouring.
panels = [('Baseline', baseline_tsne), ('Virtual Branching', vbranch_tsne)]
for position, (panel_title, embedding) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.scatter(embedding[:,0], embedding[:,1], c=sample_labels, cmap=plt.cm.jet)
    plt.colorbar()
    plt.title(panel_title)

plt.show()

### Correlation and Strength

In [None]:
from vbranch.utils.generic_utils import get_model_path, get_vb_model_path
from vbranch.utils.test_utils import compute_correlation_strength, compute_acc_from_logits

In [None]:
# Class indices of the one-hot test labels.
y_labels = np.argmax(y_test, axis=-1)

model_id_list = range(1, 9)
output_list, acc_list, pred_list = [], [], []

for model_id in model_id_list:
    tf.reset_default_graph()

    with tf.Session() as sess:
        model_path = get_model_path(DATASET, ARCHITECTURE, NUM_CLASSES,
                                    SAMPLES_PER_CLASS, model_id)
        restore_sess(sess, model_path)
        output_tensor = 'model_{}_1/output:0'.format(model_id)
        output = sess.run(output_tensor, feed_dict={'x_test:0':X_test})

    output_list.append(output)
    acc_list.append(compute_acc_from_logits(output, y_test, NUM_CLASSES))
    pred_list.append(np.argmax(output, axis=1))

# Swap axes so each row holds one sample and each column one model's prediction.
model_preds = np.array(pred_list).T
baseline_corr, baseline_strength = compute_correlation_strength(
    model_preds, y_labels, NUM_CLASSES, len(model_id_list))

print('Mean correlation:', baseline_corr)
print('Strength:', baseline_strength)

In [None]:
# For every shared fraction, collect correlation and strength over model ids 1-4
# and keep their mean and standard deviation.
shared_frac_list = [0, 0.25, 0.5, 0.75, 1]
shared_correlation_list = []
shared_strength_list = []

for shared in shared_frac_list:
    mean_correlation_list = []
    strength_list = []

    # Output tensors are named 'output_vb<k>' when shared == 0 and 'output_<k>' otherwise.
    tensor_template = 'model_{}_1/output_vb{}:0' if shared == 0 else 'model_{}_1/output_{}:0'

    for model_id in range(1, 5):
        model_path = get_vb_model_path(DATASET, ARCHITECTURE, NUM_BRANCHES, shared,
                                       NUM_CLASSES, SAMPLES_PER_CLASS, model_id)

        tensors = [tensor_template.format(model_id, branch)
                   for branch in range(1, NUM_BRANCHES + 1)]

        tf.reset_default_graph()

        with tf.Session() as sess:
            restore_sess(sess, model_path)
            feed_dict = {'x_test:0': X_test, 'y_test:0': y_test}
            outputs, acc = sess.run([tensors, 'acc_ensemble_1:0'], feed_dict=feed_dict)

        # Rows are samples, columns are branches.
        preds = np.array([np.argmax(x, axis=1) for x in outputs]).T
        mean_correlation, strength = compute_correlation_strength(
            preds, y_labels, NUM_CLASSES, NUM_BRANCHES)

        mean_correlation_list.append(mean_correlation)
        strength_list.append(strength)

    shared_correlation_list.append([np.mean(mean_correlation_list), np.std(mean_correlation_list)])
    shared_strength_list.append([np.mean(strength_list), np.std(strength_list)])

In [None]:
# Column 0 of each stats array is the mean over model ids (column 1 is the std).
correlation_stats = np.array(shared_correlation_list)
strength_stats = np.array(shared_strength_list)

plt.plot(shared_frac_list, correlation_stats[:, 0], label='correlation')
plt.plot(shared_frac_list, strength_stats[:, 0], label='strength')

# Baseline: horizontal reference lines spanning the same x positions.
num_points = len(shared_frac_list)
plt.plot(shared_frac_list, [baseline_corr] * num_points,
         label='baseline correlation', linestyle='--')
plt.plot(shared_frac_list, [baseline_strength] * num_points,
         label='baseline strength', linestyle='--')

plt.xlabel('shared frac')
plt.title('Correlation and Strength')
plt.legend()

# savefig does not create missing directories, so make sure figs/ exists first.
os.makedirs('figs', exist_ok=True)
plt.savefig('figs/correlation-strength.png')
plt.show()