In [None]:
import numpy as np
import gzip
import sys

import mxfusion as mf
import mxnet as mx

import logging
# Lower the root logger's threshold to DEBUG so debug-level records are not filtered out.
logging.getLogger().setLevel(logging.DEBUG)

# Pick the compute context: a GPU when MXNet reports at least one, otherwise the CPU.
if mx.test_utils.list_gpus():
    ctx = mx.gpu()
else:
    ctx = mx.cpu()

In [None]:
class SplitMnistGenerator:
    """Iterate over the five binary "split MNIST" tasks (0 vs 1, 2 vs 3, ..., 8 vs 9).

    Parameters
    ----------
    data : mapping
        Must provide the keys 'train_data', 'train_label', 'test_data' and
        'test_label'; the label entries are compared against digit values.
    batch_size : int or None
        Batch size handed to every iterator. ``None`` means "one batch holding
        the whole split" (the size is taken from the split itself).
    """

    def __init__(self, data, batch_size):
        self.batch_size = batch_size
        self.data = data

    def _binary_split(self, data_key, label_key, positive_digit, negative_digit):
        """Return (x, y) for one digit pair: positive rows first, then negative rows."""
        labels = self.data[label_key]
        rows_pos = np.where(labels == positive_digit)[0]
        rows_neg = np.where(labels == negative_digit)[0]

        samples = self.data[data_key]
        x = np.vstack((samples[rows_pos], samples[rows_neg]))
        # Targets are a column vector: +1 for the even digit, -1 for the odd digit.
        y = np.vstack((np.ones((rows_pos.shape[0], 1)), -np.ones((rows_neg.shape[0], 1))))
        return x, y

    def _effective_batch_size(self, x):
        """Use the configured batch size, or the full split when none was given."""
        if self.batch_size is None:
            return x.shape[0]
        return self.batch_size

    def __iter__(self):
        """Yield one ``(train_iter, test_iter)`` pair of NDArrayIter objects per task."""
        for even_digit in range(0, 10, 2):
            odd_digit = even_digit + 1

            x_train, y_train = self._binary_split('train_data', 'train_label', even_digit, odd_digit)
            x_test, y_test = self._binary_split('test_data', 'test_label', even_digit, odd_digit)

            # Only the training iterator is shuffled.
            train_iter = mx.io.NDArrayIter(x_train, y_train, self._effective_batch_size(x_train), shuffle=True)
            test_iter = mx.io.NDArrayIter(x_test, y_test, self._effective_batch_size(x_test))

            yield train_iter, test_iter

# Load MNIST through MXNet's test utility and derive the flattened size of one image.
mnist = mx.test_utils.get_mnist()
in_dim = np.prod(mnist['train_data'][0].shape)

# Sanity check: walk through every task once and report the array shapes it provides.
gen = SplitMnistGenerator(mnist, batch_size=None)
for task_id, (train, test) in enumerate(gen):
    print("Task", task_id)
    shape_report = (
        ("Train data shape", train.data),
        ("Train label shape", train.label),
        ("Test data shape", test.data),
        ("Test label shape", test.label),
    )
    for caption, named_arrays in shape_report:
        # Each entry is indexed as [0][1] to reach the underlying array of the iterator.
        print(caption, named_arrays[0][1].shape)
    print()

In [None]:
def rand_from_batch(x_coreset, y_coreset, x_train, y_train, coreset_size):
    """Move a uniformly random subset of the training set into the coreset.

    Parameters
    ----------
    x_coreset, y_coreset : list
        Running lists of coreset arrays; one new array is appended to each
        (the lists are modified in place and also returned).
    x_train, y_train : array
        Training inputs / targets, indexed along axis 0. Not modified.
    coreset_size : int
        Number of rows to draw, without replacement.

    Returns
    -------
    tuple
        ``(x_coreset, y_coreset, x_train, y_train)`` where the training arrays
        are copies with the selected rows removed.
    """
    num_examples = x_train.shape[0]
    chosen = np.random.choice(num_examples, size=coreset_size, replace=False)

    x_coreset.append(x_train[chosen, :])
    y_coreset.append(y_train[chosen, :])

    x_remaining = np.delete(x_train, chosen, axis=0)
    y_remaining = np.delete(y_train, chosen, axis=0)
    return x_coreset, y_coreset, x_remaining, y_remaining

def k_center(x_coreset, y_coreset, x_train, y_train, coreset_size):
    """Greedy K-center coreset selection.

    Starting from row 0, repeatedly picks the training row that is farthest
    from all rows selected so far, then moves the selected rows from the
    training set into the coreset.

    Parameters
    ----------
    x_coreset, y_coreset : list
        Running lists of coreset arrays; one new array is appended to each
        (the lists are modified in place and also returned).
    x_train, y_train : array
        Training inputs / targets, indexed along axis 0. Not modified.
    coreset_size : int
        Number of centers to select. For ``coreset_size <= 0`` nothing is
        selected and an empty array is appended, matching what
        ``rand_from_batch`` does for a size of 0 (previously one point was
        always selected).

    Returns
    -------
    tuple
        ``(x_coreset, y_coreset, x_train, y_train)`` where the training arrays
        are copies with the selected rows removed.
    """
    selected = []
    if coreset_size > 0:
        # dists[i] = distance from row i to its nearest already-selected center.
        dists = np.full(x_train.shape[0], np.inf)
        current_id = 0
        dists = update_distance(dists, x_train, current_id)
        selected.append(current_id)

        for _ in range(1, coreset_size):
            # The next center is the point worst covered by the current centers.
            current_id = np.argmax(dists)
            dists = update_distance(dists, x_train, current_id)
            selected.append(current_id)

    # An explicit integer dtype keeps the empty selection a valid index array.
    idx = np.asarray(selected, dtype=int)

    x_coreset.append(x_train[idx, :])
    y_coreset.append(y_train[idx, :])
    x_train = np.delete(x_train, idx, axis=0)
    y_train = np.delete(y_train, idx, axis=0)
    return x_coreset, y_coreset, x_train, y_train

def update_distance(dists, x_train, current_id):
    """Fold a newly selected center into the nearest-center distance table.

    For every row ``i`` of ``x_train`` this sets
    ``dists[i] = min(dists[i], ||x_train[i] - x_train[current_id]||)``, where
    the norm is the Euclidean norm over all of the row's elements.

    Parameters
    ----------
    dists : numpy.ndarray
        1-D array with one entry per training row; updated in place.
    x_train : numpy.ndarray
        Training inputs, one example per index along axis 0 (trailing
        dimensions are flattened for the distance computation).
    current_id : int
        Row index of the newly selected center.

    Returns
    -------
    numpy.ndarray
        The same ``dists`` object, after the in-place update.
    """
    num_points = x_train.shape[0]
    if num_points == 0:
        # Nothing to update (the loop-based version was a no-op here as well).
        return dists

    # One vectorized norm over all rows instead of a Python loop per row.
    flat = x_train.reshape(num_points, -1)
    to_center = np.linalg.norm(flat - flat[current_id], axis=1)
    np.minimum(dists, to_center, out=dists)
    return dists

In [None]:
def run_vcl(network_shape, no_epochs, data_gen, coreset_method, coreset_size=0, batch_size=None, single_head=True):
    """Run the continual-learning loop over every task produced by ``data_gen``.

    For each ``(train, test)`` pair yielded by ``data_gen`` the function records
    the test arrays, optionally moves part of the training data into a coreset,
    fits an ``MFVINN`` model seeded with the previous task's weights, and
    accumulates the scores returned by ``utils.get_scores``.

    Parameters
    ----------
    network_shape : passed unchanged to ``VanillaNN`` and ``MFVINN``.
    no_epochs : passed to the models' ``train`` calls and to ``utils.get_scores``.
    data_gen : iterable yielding ``(train, test)`` pairs; ``test.data[0][1]`` and
        ``test.label[0][1]`` are read from each pair.
    coreset_method : callable with the signature
        ``(x_coresets, y_coresets, x_train, y_train, coreset_size)`` returning the
        same four collections; only called when ``coreset_size > 0``.
    coreset_size : number of points to move into the coreset per task.
    batch_size : forwarded to ``utils.get_scores``.
    single_head : when true every task trains head 0, otherwise head ``task_id``.

    Returns
    -------
    The accumulated result of ``utils.concatenate_results`` over all tasks.

    NOTE(review): this function cannot run as written.
      * ``x_train``, ``y_train`` and ``bsize`` are read before any assignment
        (``train`` from the loop is never unpacked; the ``bsize`` line is
        commented out).
      * ``hidden_size`` is neither a parameter nor a local.
      * ``MFVINN`` and ``utils`` are not defined or imported in the visible cells.
      * ``VanillaNN`` as defined in this notebook inherits
        ``train(train_iter, val_iter, ctx)`` and has no ``get_weights`` /
        ``close_session``, so the calls below do not match it.
    """
    x_coresets, y_coresets = [], []
    x_testsets, y_testsets = [], []

    all_acc = np.array([])

    for task_id, (train, test) in enumerate(data_gen):
        # Keep every task's test arrays so later models are scored on all tasks seen so far.
        x_testsets.append(test.data[0][1])
        y_testsets.append(test.label[0][1])

        # Set the readout head to train
        head = 0 if single_head else task_id
        # NOTE(review): `bsize` is used below but its only definition is this disabled line:
        # bsize = x_train.shape[0] if (batch_size is None) else batch_size

        # Train network with maximum likelihood to initialize first model
        if task_id == 0:
            ml_model = VanillaNN(network_shape)
            # NOTE(review): x_train / y_train / bsize are unbound at this point.
            ml_model.train(x_train, y_train, task_id, no_epochs, bsize)
            mf_weights = ml_model.get_weights()
            mf_variances = None
            ml_model.close_session()

        # Select coreset if needed
        if coreset_size > 0:
            x_coresets, y_coresets, x_train, y_train = coreset_method(x_coresets, y_coresets, x_train, y_train, coreset_size)

        # Train on non-coreset data
        mf_model = MFVINN(network_shape, prev_means=mf_weights, prev_log_variances=mf_variances)
        mf_model.train(x_train, y_train, head, no_epochs, bsize)
        # The fitted weights/variances are carried over as the next task's starting point.
        mf_weights, mf_variances = mf_model.get_weights()

        # Incorporate coreset data and make prediction
        # NOTE(review): `hidden_size` is undefined here; presumably `network_shape` was meant — confirm.
        acc = utils.get_scores(mf_model, x_testsets, y_testsets, x_coresets, y_coresets, hidden_size, no_epochs, single_head, batch_size)
        all_acc = utils.concatenate_results(acc, all_acc)

        mf_model.close_session()

    return all_acc

In [None]:
class BaseNN:
    """Base class that trains and evaluates ``self.net`` through an MXNet Module.

    ``self.net`` starts out as ``None``; it has to be set to the network symbol
    before ``train`` is called.
    """

    def __init__(self, network_shape):
        """Create the task-index placeholder; ``network_shape`` is not used here."""
        # input and output placeholders
        self.task_idx = mx.sym.Variable(name='task_idx', dtype=np.float32)
        self.net = None

    def train(self, train_iter, val_iter, ctx):
        """Fit ``self.net`` on ``train_iter`` and score it on ``val_iter``.

        Parameters
        ----------
        train_iter : data iterator used for binding and fitting; must expose
            ``provide_data``, ``provide_label`` and ``batch_size``.
        val_iter : data iterator used as validation data during fitting and
            for the final accuracy score.
        ctx : compute context the module is created on.

        Returns
        -------
        The ``mx.metric.Accuracy`` object after scoring on ``val_iter``.

        Raises
        ------
        ValueError
            If ``self.net`` has not been set.
        """
        if self.net is None:
            raise ValueError("self.net is not set; assign the network symbol before calling train()")

        # create a trainable module on compute context
        self.model = mx.mod.Module(symbol=self.net, context=ctx)
        self.model.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
        init = mx.init.Xavier(factor_type="in", magnitude=2.34)
        self.model.init_params(initializer=init, force_init=True)
        self.model.fit(train_iter,  # train data
            eval_data=val_iter,  # validation data
            optimizer='adam',  # train with Adam
            optimizer_params={'learning_rate': 0.001},  # use fixed learning rate
            eval_metric='acc',  # report accuracy during training
            # Report progress every 100 batches. The batch size comes from the
            # iterator itself instead of an unrelated notebook-level variable.
            batch_end_callback=mx.callback.Speedometer(train_iter.batch_size, 100),
            num_epoch=10)  # train for at most 10 dataset passes
        # Final accuracy on the validation iterator (the previous code referenced
        # an undefined `test_iter` here).
        acc = mx.metric.Accuracy()
        self.model.score(val_iter, acc)
        return acc

    def prediction_prob(self, test_iter, task_idx):
        """Return the model's predictions for ``test_iter``.

        ``task_idx`` is accepted but currently unused. Requires ``train`` to
        have been called first so that ``self.model`` exists.
        """
        prob = self.model.predict(test_iter)
        return prob

def log_loss(output, y):
    """Binary cross-entropy between raw network outputs and targets.

    ``output`` is squashed with the logistic (sigmoid) function and compared
    against ``y``; NaN terms are skipped by the summation.

    Parameters
    ----------
    output : NDArray of raw (pre-sigmoid) scores.
    y : NDArray of targets with the same shape as ``output``.

    Returns
    -------
    The summed negative log-likelihood.

    NOTE(review): the formula treats ``y`` as 0/1 targets, whereas
    ``SplitMnistGenerator`` in this notebook emits +1/-1 labels — confirm which
    encoding is intended before wiring this loss in.
    """
    # `logistic` and `nd` were never defined or imported in this notebook; use the
    # equivalent operators from the already-imported `mx` package.
    yhat = mx.nd.sigmoid(output)
    return -mx.nd.nansum(y * mx.nd.log(yhat) + (1 - y) * mx.nd.log(1 - yhat))
    
class VanillaNN(BaseNN):
    """Plain fully-connected network assembled from Gluon ``Dense`` layers.

    ``network_shape`` lists the layer widths: every entry strictly between the
    first and the last becomes a ReLU hidden layer, and the last entry is the
    width of the final (activation-free) layer. ``prev_weights`` and
    ``learning_rate`` are accepted but not used.

    NOTE(review): the assembled network lives only in a local variable and is
    dropped when ``__init__`` returns; ``self.net`` keeps the ``None`` set by
    ``BaseNN``. An earlier symbolic (``mx.sym.FullyConnected``) variant that fed
    its output into ``log_loss`` and stored it in ``self.net`` used to sit here
    as commented-out code.
    """

    def __init__(self, network_shape, prev_weights=None, learning_rate=0.001):
        super().__init__(network_shape)

        # Assemble the layer stack under the 'vanilla_' parameter prefix.
        net = mx.gluon.nn.HybridSequential(prefix='vanilla_')
        with net.name_scope():
            for width in network_shape[1:-1]:
                hidden = mx.gluon.nn.Dense(width, activation="relu")
                net.add(hidden)
            # Output layer; its input width is the size of the last hidden layer.
            output_layer = mx.gluon.nn.Dense(network_shape[-1], flatten=True, in_units=network_shape[-2])
            net.add(output_layer)

        xavier = mx.init.Xavier(magnitude=2.34)
        net.initialize(xavier)

In [None]:
# Experiment configuration shared by every VCL run below.
single_head = False  # False: run_vcl trains head `task_id` for each task instead of always head 0
no_epochs = 120
batch_size = None  # None: SplitMnistGenerator puts each whole split into a single batch
# Layer widths: flattened input, two hidden layers of 256 units, 2 outputs (binary classification).
network_shape = (in_dim,) + (256, 256) + (2,)

In [None]:
# Baseline run: VCL with no coreset (coreset_size = 0, so the coreset method is never invoked).
np.random.seed(42)
mx.random.seed(42)

coreset_size = 0
data_gen = SplitMnistGenerator(mnist, batch_size)
vcl_result = run_vcl(
    network_shape,
    no_epochs,
    data_gen,
    coreset_method=rand_from_batch,
    coreset_size=coreset_size,
    batch_size=batch_size,
    single_head=single_head,
)
print(vcl_result)

In [None]:
# Run random coreset VCL
mx.random.seed(42)
np.random.seed(42)

coreset_size = 40
data_gen = SplitMnistGenerator(mnist, batch_size)
# `run_vcl` and `rand_from_batch` are defined in this notebook (there are no `vcl` /
# `coreset` modules imported), and the network is described by `network_shape`
# (`hidden_size` is not defined anywhere) — same call form as the vanilla run.
rand_vcl_result = run_vcl(network_shape, no_epochs, data_gen,
    rand_from_batch, coreset_size, batch_size, single_head)
print(rand_vcl_result)

In [None]:
# Run k-center coreset VCL
mx.random.seed(42)
np.random.seed(42)

# Set explicitly so this cell does not depend on which earlier cell ran last.
coreset_size = 40
data_gen = SplitMnistGenerator(mnist, batch_size)
# `run_vcl` and `k_center` are defined in this notebook (there are no `vcl` /
# `coreset` modules imported), and the network is described by `network_shape`
# (`hidden_size` is not defined anywhere) — same call form as the vanilla run.
kcen_vcl_result = run_vcl(network_shape, no_epochs, data_gen,
    k_center, coreset_size, batch_size, single_head)
print(kcen_vcl_result)

In [None]:
# Average each run's accuracies along axis 1 (NaN entries are ignored), then plot the three curves.
# NOTE(review): `utils` is not imported in any visible cell — confirm where `utils.plot` comes from.
vcl_avg, rand_vcl_avg, kcen_vcl_avg = [
    np.nanmean(result, axis=1)
    for result in (vcl_result, rand_vcl_result, kcen_vcl_result)
]
utils.plot('results/split.jpg', vcl_avg, rand_vcl_avg, kcen_vcl_avg)