In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import pickle
import warnings

# Dependency imports
from absl import flags
import matplotlib
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

from models.bayesian_resnet import bayesian_resnet
from models.bayesian_vgg import bayesian_vgg

# Select the non-interactive "Agg" matplotlib backend.
matplotlib.use("Agg")
# Ignore every warning raised through the `warnings` module from here on.
warnings.simplefilter(action="ignore")
# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions

# Per-example image shape: passed to the model constructor and used to size
# the fake data.
IMAGE_SHAPE = [32, 32, 3]

In [2]:
def build_input_pipeline(x_train, x_test, y_train, y_test,
                         batch_size, valid_size):
  """Build a feedable iterator that can switch between train and heldout data.

  Images are cast to float32 and divided by 255; labels are flattened and cast
  to int32. When the module-level flag `subtract_pixel_mean` is truthy, the
  mean of the scaled training images over axis 0 is subtracted from both
  splits. The caller's arrays are not modified (`astype` returns copies).

  Args:
    x_train: NumPy array of training images, one example per row of axis 0.
    x_test: NumPy array of heldout images, same layout as `x_train`.
    y_train: NumPy array of training labels (flattened internally).
    y_test: NumPy array of heldout labels (flattened internally).
    batch_size: Number of examples per training batch.
    valid_size: Number of examples per heldout batch.

  Returns:
    A tuple `(images, labels, handle, training_iterator, heldout_iterator)`.
    `images` and `labels` are the tensors yielded by the feedable iterator,
    `handle` is the scalar string placeholder that selects which iterator
    feeds them, and the last two items are the one-shot iterators whose
    string handles can be fed into `handle`.
  """
  x_train = x_train.astype("float32")
  x_test = x_test.astype("float32")

  x_train /= 255
  x_test /= 255

  y_train = y_train.flatten()
  y_test = y_test.flatten()

  if subtract_pixel_mean:
    # Centre both splits with statistics computed on the training set only.
    x_train_mean = np.mean(x_train, axis=0)
    x_train -= x_train_mean
    x_test -= x_train_mean

  print("x_train shape:" + str(x_train.shape))
  print(str(x_train.shape[0]) + " train samples")
  print(str(x_test.shape[0]) + " test samples")

  # Build an iterator over shuffled, endlessly repeated training batches.
  training_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, np.int32(y_train)))
  training_batches = training_dataset.shuffle(
      50000, reshuffle_each_iteration=True).repeat().batch(batch_size)
  training_iterator = tf.compat.v1.data.make_one_shot_iterator(training_batches)

  # Build an iterator over the heldout set. It is not shuffled, repeats
  # endlessly, and yields batches of `valid_size` examples.
  heldout_dataset = tf.data.Dataset.from_tensor_slices(
      (x_test, np.int32(y_test)))
  heldout_batches = heldout_dataset.repeat().batch(valid_size)
  heldout_iterator = tf.compat.v1.data.make_one_shot_iterator(heldout_batches)

  # Combine these into a feedable iterator that can switch between training
  # and validation inputs. The dataset structure is queried through the
  # `get_output_*` helpers because the `Dataset.output_types` /
  # `Dataset.output_shapes` properties are deprecated.
  handle = tf.compat.v1.placeholder(tf.string, shape=[])
  feedable_iterator = tf.compat.v1.data.Iterator.from_string_handle(
      handle,
      tf.compat.v1.data.get_output_types(training_batches),
      tf.compat.v1.data.get_output_shapes(training_batches))
  images, labels = feedable_iterator.get_next()

  return images, labels, handle, training_iterator, heldout_iterator

In [3]:
def build_fake_data():
  """Create a tiny random dataset shaped like CIFAR-10 for quick smoke tests.

  Returns:
    `((x_train, y_train), (x_test, y_test))`. Each `x_*` is a float32 array of
    10 uniform-random images of shape `IMAGE_SHAPE`; each `y_*` is an int32
    random permutation of the integers 0..9.
  """
  num_examples = 10
  batch_shape = [num_examples] + list(IMAGE_SHAPE)

  def _random_split():
    # Images are drawn before labels so the RNG is consumed in the same
    # order for every split.
    split_images = np.random.rand(*batch_shape).astype(np.float32)
    split_labels = np.random.permutation(
        np.arange(num_examples)).astype(np.int32)
    return split_images, split_labels

  train_split = _random_split()
  test_split = _random_split()
  return train_split, test_split

In [4]:
# Not referenced elsewhere in the visible cells; presumably the directory for
# saved models - TODO confirm.
model_dir = "bnnmodels/"
# Path prefix under which the pickled loss curves are written.
loss_dir = 'lossesFAKE/'

### Hyperparams

In [6]:
# Use the small random dataset from build_fake_data() for now, before
# switching to CIFAR10.
fake_data = True
# Gradient-descent step size.
learning_rate = 0.001
# Not referenced in the visible cells; presumably a dataset cache directory -
# TODO confirm.
data_dir = "data/"
# Not referenced in the visible cells (no periodic evaluation is run here).
eval_freq = 400
# Not referenced in the visible cells.
num_monte_carlo = 50
# NOTE(review): not read anywhere in the visible cells; the experiment
# function instantiates bayesian_resnet directly regardless of this value.
architecture = "resnet" # or "vgg"
# Both values are forwarded to the Bayesian ResNet constructor.
kernel_posterior_scale_mean = 0.9
kernel_posterior_scale_constraint = 0.2
# KL warm-up length: the KL weight is t / (kl_annealing * len(x_train) /
# batch_size), capped at 1, where t is incremented once per training step.
kl_annealing = 50
# Read as a global by build_input_pipeline to enable training-mean centring.
subtract_pixel_mean = True

In [7]:
# Batch size used when the training input pipeline is built in the next cell.
batch_size = 1

In [8]:
# Load either the tiny random dataset or the real CIFAR-10 arrays.
if fake_data:
    (x_train, y_train), (x_test, y_test) = build_fake_data()
else:
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Only graph construction happens here (nothing is evaluated), so no
# tf.Session needs to be opened for this cell. The heldout batch size is 500.
(images, labels, handle,
 training_iterator,
 heldout_iterator) = build_input_pipeline(x_train, x_test, y_train, y_test,
                                          batch_size, 500)

x_train shape:(10, 32, 32, 3)
10 train samples
10 test samples
Instructions for updating:
Use `tf.compat.v1.data.get_output_types(dataset)`.
Instructions for updating:
Use `tf.compat.v1.data.get_output_shapes(dataset)`.


In [9]:
# Number of training examples (size of axis 0 of x_train).
total = x_train.shape[0]

In [10]:
from IPython.core.debugger import set_trace

In [11]:
def run_experiment_misso(algo,fake_data, batch_size, epochs, verbose):
    """Train a Bayesian ResNet on the global input pipeline and record the loss.

    The model is built on the module-level `images` / `labels` tensors, so the
    data and the actual mini-batch size are whatever the input pipeline was
    built with; `batch_size` here only enters the KL-annealing schedule and
    the number of steps per epoch.

    NOTE(review): despite the name, the update that is executed is a plain
    gradient step `var <- var - learning_rate * grad` on the current
    mini-batch. The original code also assembled per-sample aggregate updates
    inside the training loop but never ran them; that dead graph construction
    (which added new gradient ops on every step) has been removed. The
    incremental-surrogate aggregation still needs to be implemented.

    Args:
      algo: Unused; kept for interface compatibility.
      fake_data: Unused; kept for interface compatibility.
      batch_size: Batch size assumed by the KL schedule and the epoch length.
      epochs: Number of passes of `len(x_train) // batch_size` steps.
      verbose: When truthy, print the total step count and a progress line
        every 100 steps of each epoch.

    Returns:
      A tuple `(listloss, listkl)`: the loss and the annealed KL term
      evaluated after every training step, in order.
    """
    with tf.Session() as sess:
        # Instantiate the model (Bayesian ResNet).
        model = bayesian_resnet(
            IMAGE_SHAPE,
            num_classes=10,
            kernel_posterior_scale_mean=kernel_posterior_scale_mean,
            kernel_posterior_scale_constraint=kernel_posterior_scale_constraint)
        logits = model(images)
        labels_distribution = tfd.Categorical(logits=logits)

        num_train = len(x_train)

        # KL annealing: `t` counts training steps and the KL weight grows
        # linearly with it until it is capped at 1.
        t = tf.compat.v2.Variable(0.0)
        kl_regularizer = t / (kl_annealing * num_train / batch_size)

        log_likelihood = labels_distribution.log_prob(labels)
        neg_log_likelihood = -tf.reduce_mean(input_tensor=log_likelihood)
        kl = sum(model.losses) / num_train * tf.minimum(1.0, kl_regularizer)
        loss = neg_log_likelihood + kl

        predictions = tf.argmax(input=logits, axis=1)

        with tf.compat.v1.name_scope("train"):
            train_accuracy, train_accuracy_update_op = tf.compat.v1.metrics.accuracy(
                labels=labels, predictions=predictions)

            # Build the gradient-step op once. Each gradient is paired with
            # its variable BEFORE dropping the `None` entries: the collection
            # of trainable variables also holds variables from earlier calls
            # in the same graph, whose gradients are `None`, and filtering
            # first would apply this model's gradients to the wrong variables.
            trainable_vars = tf.trainable_variables()
            all_grads = tf.gradients(loss, trainable_vars)
            grads_and_vars = [(grad, var)
                              for grad, var in zip(all_grads, trainable_vars)
                              if grad is not None]
            train_op = tf.group(
                *[var.assign_sub(learning_rate * grad)
                  for grad, var in grads_and_vars])
            update_step_op = tf.compat.v1.assign(t, t + 1)

        with tf.compat.v1.name_scope("valid"):
            # The validation metric is created but no validation pass is run
            # in this function; only its reset op is executed at the end.
            valid_accuracy, valid_accuracy_update_op = tf.compat.v1.metrics.accuracy(
                labels=labels, predictions=predictions)

            init_op = tf.group(tf.compat.v1.global_variables_initializer(),
                               tf.compat.v1.local_variables_initializer())

            stream_vars_valid = [v for v in tf.compat.v1.local_variables() if "valid/" in v.name]
            reset_valid_op = tf.compat.v1.variables_initializer(stream_vars_valid)

        sess.run(init_op)

        # Run the training loop.
        train_handle = sess.run(training_iterator.string_handle())
        steps_per_epoch = num_train // batch_size
        training_steps = int(round(epochs * (num_train / batch_size)))
        if verbose:
            print(training_steps)

        listloss = []
        listkl = []
        for epoch in range(epochs):
            for step in range(steps_per_epoch):
                sess.run([train_op,
                          train_accuracy_update_op,
                          update_step_op],
                         feed_dict={handle: train_handle})
                # The reported values come from a second run, which draws the
                # next batch from the iterator (same pattern as the original).
                loss_value, accuracy_value, kl_value = sess.run(
                    [loss, train_accuracy, kl], feed_dict={handle: train_handle})
                if verbose and step % 100 == 0:
                    print("Step: {:>3d} Loss: {:.3f} Accuracy: {:.3f} KL: {:.3f}".format(
                        step, loss_value, accuracy_value, kl_value))
                listloss.append(loss_value)
                listkl.append(kl_value)

        sess.run(reset_valid_op)
    return listloss, listkl
    

### MISSO

In [12]:
# Number of passes over the training set per run.
epochs = 60
# NOTE(review): the input pipeline was already built with the batch_size set
# in an earlier cell; changing the value here does not rebuild it.
batch_size = 1
# Number of independent runs of the experiment.
nb_runs = 3
# Base value for the per-run TensorFlow graph seed (run index * seed0).
seed0 = 23456

In [None]:
# One loss curve and one KL curve per run.
misso_loss = []
misso_kl = []
# A named loop variable is used instead of `_`, which IPython reserves for the
# last cell output.
for run_idx in range(nb_runs):
    print("Run Number: {:.0f}".format(run_idx))
    # Graph-level seeds are 0, seed0, 2 * seed0, ... for successive runs.
    tf.random.set_random_seed(run_idx*seed0)
    run_loss, run_kl = run_experiment_misso(algo='misso', fake_data=fake_data, batch_size = batch_size, epochs=epochs, verbose= True)

    misso_loss.append(run_loss)
    misso_kl.append(run_kl)

In [None]:
##SAVE LOSSES
# Create the output directory first so open() does not fail when it is missing.
if not os.path.isdir(loss_dir):
    os.makedirs(loss_dir)
with open(loss_dir+'missoloss', 'wb') as fp:
    pickle.dump(misso_loss, fp)

In [None]:
# Create the output directory first so open() does not fail when it is missing.
if not os.path.isdir(loss_dir):
    os.makedirs(loss_dir)
with open(loss_dir+'missokl', 'wb') as fp:
    pickle.dump(misso_kl, fp)

## PLOTS

In [14]:
%matplotlib inline
import matplotlib.pyplot as plt
import pylab

In [15]:
def tsplotseveral(x, y, n=20, percentile_min=1, percentile_max=99, color='r', plot_mean=True, plot_median=False, line_color='k', labels=None, **kwargs):
    """Plot several groups of curves as mean/median lines with percentile bands.

    For every group in `y`, `n` nested bands are filled between matching lower
    and upper percentiles (taken across axis 0 of the group), and the mean
    and/or median curve is drawn on top.

    Args:
      x: x-coordinates shared by all curves.
      y: Iterable of groups; each group is array-like with the repeated
        curves stacked along axis 0.
      n: Number of percentile bands per group.
      percentile_min: Lowest percentile of the outermost band.
      percentile_max: Highest percentile of the outermost band.
      color: Unused; kept for interface compatibility.
      plot_mean: Draw the mean across axis 0 of each group.
      plot_median: Draw the median across axis 0 of each group.
      line_color: Unused; kept for interface compatibility.
      labels: Legend label for each group, in order. Defaults to the original
        built-in list of optimizer names; groups beyond the end of the list
        are drawn without a label.
      **kwargs: `alpha` sets the opacity of each band (default 0.005); any
        other keyword is ignored.

    Returns:
      The matplotlib Axes the curves were drawn on.
    """
    palette = ['r', 'b', 'g', 'y', 'black']
    if labels is None:
        labels = ['ADAM', 'ADAGRAD', 'ADADELTA', 'RMSPROP', 'SAG']
    # Previously an `alpha` keyword was parsed and then always overwritten by
    # 0.005; it is honoured now, with 0.005 kept as the default.
    alpha = kwargs.pop('alpha', 0.005)

    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(16, 3.5))
    axes.set_facecolor('white')
    axes.grid(linestyle='-', linewidth='0.2', color='grey')
    for side in ('bottom', 'top', 'right', 'left'):
        axes.spines[side].set_color('black')

    # The percentile levels are the same for every group, so compute them once.
    lower_levels = np.linspace(percentile_min, 50, num=n, endpoint=False)
    upper_levels = np.linspace(50, percentile_max, num=n+1)[1:]

    for idx, element in enumerate(y):
        # Cycle through the palette so more than five groups do not fail.
        group_color = palette[idx % len(palette)]
        group_label = labels[idx] if idx < len(labels) else None

        perc1 = np.percentile(element, lower_levels, axis=0)
        perc2 = np.percentile(element, upper_levels, axis=0)

        # Fill matching lower and upper percentile bands.
        for p1, p2 in zip(perc1, perc2):
            axes.fill_between(x, p1, p2, alpha=alpha, color=group_color, edgecolor=None)

        if plot_mean:
            axes.plot(x, np.mean(element, axis=0), color=group_color, label=group_label)

        if plot_median:
            axes.plot(x, np.median(element, axis=0), color=group_color, label=group_label)

    leg = axes.legend(fontsize=18, fancybox=True, loc=0, ncol=3)
    leg.get_frame().set_alpha(0.5)
    axes.set_xlabel('Epoch', fontsize=20)
    axes.set_ylabel('ELBO', fontsize=20)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    # This second grid call overrides the thin grey grid configured above.
    axes.grid(linestyle='dotted', linewidth=2)
    axes.ticklabel_format(axis='y', style='sci', scilimits=(1,4))
    fig.tight_layout()
    return axes

In [None]:
# Number of recorded training steps in the first run.
iterations = len(misso_loss[0])
# x-coordinates for the loss curves: `iterations` evenly spaced points on
# [0, iterations].
itera = np.linspace(0,iterations,iterations)

In [None]:
# One group per algorithm; each group holds the curves of all runs.
toplotloss = [misso_loss]
# Built but not plotted in the visible cells.
toplotkl = [misso_kl]
# NOTE(review): no labels are passed, so the legend uses tsplotseveral's
# default first label ('ADAM') for the MISSO curves.
tsplotseveral(itera,toplotloss, n=100)