### Training BNN using SAEM

In [16]:
import warnings
warnings.simplefilter(action="ignore")

In [17]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import warnings

# Dependency imports
from absl import app
from absl import flags
import matplotlib
matplotlib.use('Agg')
from matplotlib import figure  # pylint: disable=g-import-not-at-top
from matplotlib.backends import backend_agg
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tf.enable_v2_behavior()

In [18]:
try:
    import seaborn as sns  # pylint: disable=g-import-not-at-top
    HAS_SEABORN = True
except ImportError:
    HAS_SEABORN = False

tfd = tfp.distributions

IMAGE_SHAPE = [28, 28, 1]
NUM_TRAIN_EXAMPLES = 60000
NUM_HELDOUT_EXAMPLES = 10000
NUM_CLASSES = 10

In [24]:
def plot_weight_posteriors(names, qm_vals, qs_vals, fname):
  """Save a PNG plot with histograms of weight means and stddevs.

  Args:
    names: A Python `iterable` of `str` variable names.
      qm_vals: A Python `iterable`, the same length as `names`,
      whose elements are Numpy `array`s, of any shape, containing
      posterior means of weight varibles.
    qs_vals: A Python `iterable`, the same length as `names`,
      whose elements are Numpy `array`s, of any shape, containing
      posterior standard deviations of weight varibles.
    fname: Python `str` filename to save the plot to.
  """
  fig = figure.Figure(figsize=(6, 3))
  canvas = backend_agg.FigureCanvasAgg(fig)

  ax = fig.add_subplot(1, 2, 1)
  for n, qm in zip(names, qm_vals):
    sns.distplot(tf.reshape(qm, shape=[-1]), ax=ax, label=n)
  ax.set_title('weight means')
  ax.set_xlim([-1.5, 1.5])
  ax.legend()

  ax = fig.add_subplot(1, 2, 2)
  for n, qs in zip(names, qs_vals):
    sns.distplot(tf.reshape(qs, shape=[-1]), ax=ax)
  ax.set_title('weight stddevs')
  ax.set_xlim([0, 1.])

  fig.tight_layout()
  canvas.print_figure(fname, format='png')
  print('saved {}'.format(fname))


def plot_heldout_prediction(input_vals, probs,
                            fname, n=10, title=''):
  """Save a PNG plot visualizing posterior uncertainty on heldout data.

  Args:
    input_vals: A `float`-like Numpy `array` of shape
      `[num_heldout] + IMAGE_SHAPE`, containing heldout input images.
    probs: A `float`-like Numpy array of shape `[num_monte_carlo,
      num_heldout, num_classes]` containing Monte Carlo samples of
      class probabilities for each heldout sample.
    fname: Python `str` filename to save the plot to.
    n: Python `int` number of datapoints to vizualize.
    title: Python `str` title for the plot.
  """
  fig = figure.Figure(figsize=(9, 3*n))
  canvas = backend_agg.FigureCanvasAgg(fig)
  for i in range(n):
    ax = fig.add_subplot(n, 3, 3*i + 1)
    ax.imshow(input_vals[i, :].reshape(IMAGE_SHAPE[:-1]), interpolation='None')

    ax = fig.add_subplot(n, 3, 3*i + 2)
    for prob_sample in probs:
      sns.barplot(np.arange(10), prob_sample[i, :], alpha=0.1, ax=ax)
      ax.set_ylim([0, 1])
    ax.set_title('posterior samples')

    ax = fig.add_subplot(n, 3, 3*i + 3)
    sns.barplot(np.arange(10), tf.reduce_mean(probs[:, i, :], axis=0), ax=ax)
    ax.set_ylim([0, 1])
    ax.set_title('predictive probs')
  fig.suptitle(title)
  fig.tight_layout()

  canvas.print_figure(fname, format='png')
  print('saved {}'.format(fname))


def create_model():
  """Creates a Keras model using the LeNet-5 architecture.

  Returns:
      model: Compiled Keras model.
  """
  # KL divergence weighted by the number of training samples, using
  # lambda function to pass as input to the kernel_divergence_fn on
  # flipout layers.
  kl_divergence_function = (lambda q, p, _: tfd.kl_divergence(q, p) /  # pylint: disable=g-long-lambda
                            tf.cast(NUM_TRAIN_EXAMPLES, dtype=tf.float32))

  # Define a LeNet-5 model using three convolutional (with max pooling)
  # and two fully connected dense layers. We use the Flipout
  # Monte Carlo estimator for these layers, which enables lower variance
  # stochastic gradients than naive reparameterization.
  model = tf.keras.models.Sequential([
      tfp.layers.Convolution2DFlipout(
          6, kernel_size=5, padding='SAME',
          kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tf.keras.layers.MaxPooling2D(
          pool_size=[2, 2], strides=[2, 2],
          padding='SAME'),
      tfp.layers.Convolution2DFlipout(
          16, kernel_size=5, padding='SAME',
          kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tf.keras.layers.MaxPooling2D(
          pool_size=[2, 2], strides=[2, 2],
          padding='SAME'),
      tfp.layers.Convolution2DFlipout(
          120, kernel_size=5, padding='SAME',
          kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tf.keras.layers.Flatten(),
      tfp.layers.DenseFlipout(
          84, kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tfp.layers.DenseFlipout(
          NUM_CLASSES, kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.softmax)
  ])

  # Model compilation.
  optimizer = tf.keras.optimizers.Adam(lr=learning_rate)
  # We use the categorical_crossentropy loss since the MNIST dataset contains
  # ten labels. The Keras API will then automatically add the
  # Kullback-Leibler divergence (contained on the individual layers of
  # the model), to the cross entropy loss, effectively
  # calcuating the (negated) Evidence Lower Bound Loss (ELBO)
  model.compile(optimizer, loss='categorical_crossentropy',
                metrics=['accuracy'], experimental_run_tf_function=False)
  return model

In [25]:
from mnist import *

In [26]:
batch_size = 128
model_dir = 'bayesian_neural_network/'
learning_rate = 0.001

In [27]:
tf.io.gfile.makedirs(model_dir)
train_set, heldout_set = tf.keras.datasets.mnist.load_data()
train_seq = MNISTSequence(data=train_set, batch_size=batch_size)
heldout_seq = MNISTSequence(data=heldout_set, batch_size=batch_size)

model = create_model()
# TODO(b/149259388): understand why Keras does not automatically build the
# model correctly.
model.build(input_shape=[None, 28, 28, 1])

Instructions for updating:
Please use `layer.add_weight` method instead.


In [29]:
num_epochs= 2
num_monte_carlo=10
viz_steps=100

In [30]:
for epoch in range(num_epochs):
  epoch_accuracy, epoch_loss = [], []
  for step, (batch_x, batch_y) in enumerate(train_seq):
    batch_loss, batch_accuracy = model.train_on_batch(
        batch_x, batch_y)
    epoch_accuracy.append(batch_accuracy)
    epoch_loss.append(batch_loss)

    if step % 100 == 0:
      print('Epoch: {}, Batch index: {}, '
            'Loss: {:.3f}, Accuracy: {:.3f}'.format(
                epoch, step,
                tf.reduce_mean(epoch_loss),
                tf.reduce_mean(epoch_accuracy)))

    if (step+1) % viz_steps == 0:
      # Compute log prob of heldout set by averaging draws from the model:
      # p(heldout | train) = int_model p(heldout|model) p(model|train)
      #                   ~= 1/n * sum_{i=1}^n p(heldout | model_i)
      # where model_i is a draw from the posterior p(model|train).
      print(' ... Running monte carlo inference')
      probs = tf.stack([model.predict(heldout_seq, verbose=1)
                        for _ in range(num_monte_carlo)], axis=0)
      mean_probs = tf.reduce_mean(probs, axis=0)
      heldout_log_prob = tf.reduce_mean(tf.math.log(mean_probs))
      print(' ... Held-out nats: {:.3f}'.format(heldout_log_prob))

      if HAS_SEABORN:
        names = [layer.name for layer in model.layers
                 if 'flipout' in layer.name]
        qm_vals = [layer.kernel_posterior.mean()
                   for layer in model.layers
                   if 'flipout' in layer.name]
        qs_vals = [layer.kernel_posterior.stddev()
                   for layer in model.layers
                   if 'flipout' in layer.name]
        plot_weight_posteriors(names, qm_vals, qs_vals,
                               fname=os.path.join(
                                   model_dir,
                                   'epoch{}_step{:05d}_weights.png'.format(
                                       epoch, step)))
        plot_heldout_prediction(heldout_seq.images, probs,
                                fname=os.path.join(
                                    model_dir,
                                    'epoch{}_step{}_pred.png'.format(
                                        epoch, step)),
                                title='mean heldout logprob {:.2f}'
                                .format(heldout_log_prob))

Epoch: 0, Batch index: 0, Loss: 27.690, Accuracy: 0.109
 ... Running monte carlo inference
 ... Held-out nats: -5.500
saved bayesian_neural_network/epoch0_step00099_weights.png
saved bayesian_neural_network/epoch0_step99_pred.png
Epoch: 0, Batch index: 100, Loss: 22.153, Accuracy: 0.537
 ... Running monte carlo inference
 ... Held-out nats: -6.769
saved bayesian_neural_network/epoch0_step00199_weights.png
saved bayesian_neural_network/epoch0_step199_pred.png
Epoch: 0, Batch index: 200, Loss: 21.465, Accuracy: 0.692
 ... Running monte carlo inference

KeyboardInterrupt: 