In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import warnings

# Dependency imports
from absl import flags
import matplotlib
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

from models.bayesian_resnet import bayesian_resnet
from models.bayesian_vgg import bayesian_vgg

matplotlib.use("Agg")
warnings.simplefilter(action="ignore")
tfd = tfp.distributions

IMAGE_SHAPE = [32, 32, 3]

In [2]:
def build_input_pipeline(x_train, x_test, y_train, y_test,
                         batch_size, valid_size):
  """Build an Iterator switching between train and heldout data."""
  x_train = x_train.astype("float32")
  x_test = x_test.astype("float32")

  x_train /= 255
  x_test /= 255

  y_train = y_train.flatten()
  y_test = y_test.flatten()

  if subtract_pixel_mean:
    x_train_mean = np.mean(x_train, axis=0)
    x_train -= x_train_mean
    x_test -= x_train_mean

  print("x_train shape:" + str(x_train.shape))
  print(str(x_train.shape[0]) + " train samples")
  print(str(x_test.shape[0]) + " test samples")

  # Build an iterator over training batches.
  training_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, np.int32(y_train)))
  training_batches = training_dataset.shuffle(
      50000, reshuffle_each_iteration=True).repeat().batch(batch_size)
  training_iterator = tf.compat.v1.data.make_one_shot_iterator(training_batches)

  # Build a iterator over the heldout set with batch_size=heldout_size,
  # i.e., return the entire heldout set as a constant.
  heldout_dataset = tf.data.Dataset.from_tensor_slices(
      (x_test, np.int32(y_test)))
  heldout_batches = heldout_dataset.repeat().batch(valid_size)
  heldout_iterator = tf.compat.v1.data.make_one_shot_iterator(heldout_batches)

  # Combine these into a feedable iterator that can switch between training
  # and validation inputs.
  handle = tf.compat.v1.placeholder(tf.string, shape=[])
  feedable_iterator = tf.compat.v1.data.Iterator.from_string_handle(
      handle, training_batches.output_types, training_batches.output_shapes)
  images, labels = feedable_iterator.get_next()

  return images, labels, handle, training_iterator, heldout_iterator

In [17]:
def build_fake_data():
  """Build fake CIFAR10-style data for unit testing."""
  num_examples = 10
  x_train = np.random.rand(num_examples, *IMAGE_SHAPE).astype(np.float32)
  y_train = np.random.permutation(np.arange(num_examples)).astype(np.int32)
  x_test = np.random.rand(num_examples, *IMAGE_SHAPE).astype(np.float32)
  y_test = np.random.permutation(np.arange(num_examples)).astype(np.int32)
  return (x_train, y_train), (x_test, y_test)

In [3]:
model_dir = "bnnmodels/"

### Hyperparams

In [49]:
#Generate fake data for now before switching to CIFAR10
fake_data = True
batch_size = 128
learning_rate = 0.0001
epochs = 10
data_dir = "data/"
eval_freq = 400
num_monte_carlo = 50
architecture = "resnet" # or "vgg"
kernel_posterior_scale_mean = 0.9
kernel_posterior_scale_constraint = 0.2
kl_annealing = 50
subtract_pixel_mean = True

In [50]:
nb_runs = 1
seed0 = 23456

In [51]:
with tf.Session() as sess:
    if fake_data:
        (x_train, y_train), (x_test, y_test) = build_fake_data()
    else:
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

    (images, labels, handle,
     training_iterator,
     heldout_iterator) = build_input_pipeline(x_train, x_test, y_train, y_test,
                                              batch_size, 500)

x_train shape:(10, 32, 32, 3)
10 train samples
10 test samples


In [56]:
def run_experiment(algo,fake_data, batch_size, epochs, verbose):
    with tf.Session() as sess:
        
        model_fn = bayesian_resnet
        model = model_fn(
            IMAGE_SHAPE,
            num_classes=10,
            kernel_posterior_scale_mean=kernel_posterior_scale_mean,
            kernel_posterior_scale_constraint=kernel_posterior_scale_constraint)
        logits = model(images)
        labels_distribution = tfd.Categorical(logits=logits)
        t = tf.compat.v2.Variable(0.0)
        kl_regularizer = t / (kl_annealing * len(x_train) / batch_size)

        log_likelihood = labels_distribution.log_prob(labels)
        neg_log_likelihood = -tf.reduce_mean(input_tensor=log_likelihood)
        kl = sum(model.losses) / len(x_train) * tf.minimum(1.0, kl_regularizer)
        loss = neg_log_likelihood + kl

        predictions = tf.argmax(input=logits, axis=1)

        with tf.compat.v1.name_scope("train"):
          train_accuracy, train_accuracy_update_op = tf.compat.v1.metrics.accuracy(
              labels=labels, predictions=predictions)
        if algo=="adam":
          opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
        if algo=="bbb":
          opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=learning_rate)
        if algo=="misso":
          opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=learning_rate)
        if algo=="momentum":
          opt = tf.compat.v1.train.MomentumOptimizer(learning_rate=learning_rate, momentum=5e-5)

        train_op = opt.minimize(loss)
        update_step_op = tf.compat.v1.assign(t, t + 1)

        with tf.compat.v1.name_scope("valid"):
          valid_accuracy, valid_accuracy_update_op = tf.compat.v1.metrics.accuracy(
              labels=labels, predictions=predictions)

        init_op = tf.group(tf.compat.v1.global_variables_initializer(),
                           tf.compat.v1.local_variables_initializer())

        stream_vars_valid = [
            v for v in tf.compat.v1.local_variables() if "valid/" in v.name
        ]
        reset_valid_op = tf.compat.v1.variables_initializer(stream_vars_valid)
    
    with tf.compat.v1.Session() as sess:
        sess.run(init_op)

        # Run the training loop
        train_handle = sess.run(training_iterator.string_handle())
        heldout_handle = sess.run(heldout_iterator.string_handle())
        training_steps = int(
          round(epochs * (len(x_train) / batch_size)))
        listloss = []
        listaccuracy = []
        print(training_steps)
        for step in range(training_steps):
            _ = sess.run([train_op,
                      train_accuracy_update_op,
                      update_step_op],
                     feed_dict={handle: train_handle})
            # Print loss values
            loss_value, accuracy_value, kl_value = sess.run(
                  [loss, train_accuracy, kl], feed_dict={handle: train_handle})
            print(
                  "Step: {:>3d} Loss: {:.3f} Accuracy: {:.3f} KL: {:.3f}".format(
                      step, loss_value, accuracy_value, kl_value))
            listloss.append(loss_value)
            listaccuracy.append(accuracy_value)

            if (step + 1) % eval_freq == 0:
              # Compute log prob of heldout set by averaging draws from the model:
              # p(heldout | train) = int_model p(heldout|model) p(model|train)
              #                   ~= 1/n * sum_{i=1}^n p(heldout | model_i)
              # where model_i is a draw from the posterior
              # p(model|train).
              probs = np.asarray([sess.run((labels_distribution.probs),
                                           feed_dict={handle: heldout_handle})
                                  for _ in range(num_monte_carlo)])
              mean_probs = np.mean(probs, axis=0)

              _, label_vals = sess.run(
                  (images, labels), feed_dict={handle: heldout_handle})
              heldout_lp = np.mean(np.log(mean_probs[np.arange(mean_probs.shape[0]),
                                                     label_vals.flatten()]))
              print(" ... Held-out nats: {:.3f}".format(heldout_lp))

          # Calculate validation accuracy
          #for _ in range(20):
           # sess.run(
            #    valid_accuracy_update_op, feed_dict={handle: heldout_handle})
          #valid_value = sess.run(
           #   valid_accuracy, feed_dict={handle: heldout_handle})

    #      print(" ... Validation Accuracy: {:.3f}".format(valid_value))
            print(loss_value)
        sess.run(reset_valid_op)
    return listloss
    

### ADAM

In [57]:
adam = []
for run in range(nb_runs):
    tf.random.set_random_seed(_*seed0)
    run = run_experiment(algo='adam', fake_data=fake_data, batch_size = batch_size, epochs=epochs, verbose= True)
    adam.append(run)

1
Step:   0 Loss: 309593.406 Accuracy: 0.117 KL: 153682.656
309593.4


### BBB

In [60]:
bbb = []
for run in range(nb_runs):
    tf.random.set_random_seed(_*seed0)
    run = run_experiment(algo='bbb', fake_data=fake_data, batch_size = batch_size, epochs=epochs, verbose= True)
    bbb.append(run)

1


KeyboardInterrupt: 

### Momentum

In [None]:
momentum = []
for run in range(nb_runs):
    tf.random.set_random_seed(_*seed0)
    run = run_experiment(algo='momentum', fake_data=fake_data, batch_size = batch_size, epochs=epochs, verbose= True)
    momentum.append(run)

In [None]:
##SAVE LOSSES
with open('losses/adam', 'wb') as fp: 
    pickle.dump(adam, fp)
with open('losses/momentum', 'wb') as fp: 
    pickle.dump(momentum, fp)
with open('losses/bbb', 'wb') as fp: 
    pickle.dump(bbb, fp)

In [None]:
##LOAD LOSSES
with open ('losses/bbbloss', 'rb') as fp:
    bbb = pickle.load(fp)
with open ('losses/missoloss', 'rb') as fp:
    misso = pickle.load(fp)
with open ('losses/adamloss', 'rb') as fp:
    adam = pickle.load(fp)