# MNIST VIB Example

Here I demonstrate the Variational Information Bottleneck (VIB) method on the MNIST dataset.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf

In [None]:
# Throw away anything already in the default graph before the model is built.
tf.reset_default_graph()

# Session configuration with XLA JIT compilation enabled
# (global_jit_level = ON_1), written as nested protos.
config = tf.ConfigProto(
    graph_options=tf.GraphOptions(
        optimizer_options=tf.OptimizerOptions(
            global_jit_level=tf.OptimizerOptions.ON_1)))

# An InteractiveSession installs itself as the default session, which is what
# lets later cells call .run() on an op without naming the session.
sess = tf.InteractiveSession(config=config)

In [3]:
from tensorflow.examples.tutorials.mnist import input_data
# Read MNIST from /tmp/mnistdata (the log below shows the four .gz archives
# being extracted). validation_size=0 presumably leaves no examples held out
# for validation, so the whole training set stays in mnist_data.train —
# confirm against the read_data_sets documentation.
mnist_data = input_data.read_data_sets('/tmp/mnistdata', validation_size=0)

Extracting /tmp/mnistdata/train-images-idx3-ubyte.gz
Extracting /tmp/mnistdata/train-labels-idx1-ubyte.gz
Extracting /tmp/mnistdata/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnistdata/t10k-labels-idx1-ubyte.gz


In [4]:
# Graph inputs: a batch of flattened images (784 values each) and one integer
# class id per image. The batch dimension is left unspecified.
images = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='images')
labels = tf.placeholder(dtype=tf.int64, shape=[None], name='labels')

# One-hot view of the labels over 10 classes, consumed by the cross-entropy loss.
one_hot_labels = tf.one_hot(indices=labels, depth=10)

In [5]:
# Short aliases for the two contrib modules used by the encoder, decoder,
# prior and KL term in the cells that follow.
layers = tf.contrib.layers
ds = tf.contrib.distributions

def encoder(images, hidden_units=1024, latent_dim=256, scale_shift=5.0):
    """Map a batch of images to a diagonal Gaussian over the latent code.

    The defaults reproduce the original hard-coded architecture: two ReLU
    layers of 1024 units followed by a linear layer of 512 outputs that is
    split into 256 means and 256 raw scale parameters.

    Args:
        images: float tensor of flattened images, one row per example. The
            input is rescaled as 2*images - 1 before the first layer
            (assumes pixel values lie in [0, 1], which would map them to
            [-1, 1] — TODO confirm against the data loader).
        hidden_units: width of each of the two ReLU layers.
        latent_dim: dimensionality of the latent code; the linear layer
            emits 2 * latent_dim values per example.
        scale_shift: constant subtracted from the raw scale parameters
            before the softplus (presumably so the scale starts out small —
            confirm).

    Returns:
        A ds.NormalWithSoftplusScale distribution whose location is the first
        latent_dim columns of the linear output and whose softplus-scale
        input is the remaining columns minus scale_shift.
    """
    net = layers.relu(2*images-1, hidden_units)
    net = layers.relu(net, hidden_units)
    # A single linear layer produces both halves of the Gaussian parameters.
    params = layers.linear(net, 2 * latent_dim)
    mu, rho = params[:, :latent_dim], params[:, latent_dim:]
    encoding = ds.NormalWithSoftplusScale(mu, rho - scale_shift)
    return encoding


def decoder(encoding_sample):
    """Turn a latent sample into class logits.

    Applies a single `layers.linear` layer with 10 outputs (matching the
    10-way one-hot labels) to `encoding_sample` and returns the result.
    """
    return layers.linear(encoding_sample, 10)

# Zero-mean, unit-scale normal; the encoder distribution is compared against
# it in the KL term of the loss.
prior = ds.Normal(loc=0.0, scale=1.0)

In [6]:
import math

# Build the encoder once; `encoding` is the distribution over latent codes.
with tf.variable_scope('encoder'):
    encoding = encoder(images)
    
# Logits used for the loss and for `accuracy`: one latent sample per input.
with tf.variable_scope('decoder'):
    logits = decoder(encoding.sample())
    
# Re-enter the decoder scope with reuse=True so the same decoder weights are
# applied to 12 samples per input; these feed `avg_accuracy`.
with tf.variable_scope('decoder', reuse=True):
    many_logits = decoder(encoding.sample(12))

# Classification loss; dividing by ln(2) converts it from nats to bits.
class_loss = tf.losses.softmax_cross_entropy(
    logits=logits, onehot_labels=one_hot_labels) / math.log(2)

# Weight applied to the information (KL) term in the total loss.
BETA = 1e-3

# KL(encoding || prior), averaged over axis 0 (the batch axis of the
# [None, 784] input) and then summed over what remains, expressed in bits.
info_loss = tf.reduce_sum(tf.reduce_mean(
    ds.kl_divergence(encoding, prior), 0)) / math.log(2)

# Objective being minimized: classification loss plus BETA times the KL term.
total_loss = class_loss + BETA * info_loss

In [7]:
# Accuracy when classifying from a single sample of the latent code.
accuracy = tf.reduce_mean(
    tf.cast(tf.equal(tf.argmax(logits, axis=1), labels), dtype=tf.float32))

# Accuracy when the class probabilities are first averaged over the sampled
# latent codes (axis 0 of many_logits) and only then arg-maxed.
avg_accuracy = tf.reduce_mean(
    tf.cast(
        tf.equal(
            tf.argmax(
                tf.reduce_mean(tf.nn.softmax(many_logits), axis=0),
                axis=1),
            labels),
        dtype=tf.float32))

# Reported quantities, both in bits: log2(10) minus the classification loss,
# and the KL term itself.
IZY_bound = math.log(10, 2) - class_loss
IZX_bound = info_loss

In [8]:
# Number of training examples fed to each optimizer step.
batch_size = 100
# Despite the name, this is the number of mini-batches in one pass over the
# training set (i.e. steps per epoch): the training loop runs this many steps
# per epoch and the learning-rate schedule decays every 2 * this many steps.
steps_per_batch = int(mnist_data.train.num_examples / batch_size)

In [9]:
global_step = tf.contrib.framework.get_or_create_global_step()
# Learning rate starts at 1e-4 and is multiplied by 0.97 every
# 2*steps_per_batch steps; staircase=True makes the decay happen in discrete
# jumps rather than continuously.
learning_rate = tf.train.exponential_decay(1e-4, global_step,
                                           decay_steps=2*steps_per_batch,
                                           decay_rate=0.97, staircase=True)
# The second positional argument (0.5) is presumably Adam's beta1 — confirm
# against the AdamOptimizer signature.
opt = tf.train.AdamOptimizer(learning_rate, 0.5)

# Exponential moving average (decay 0.999) over the model variables;
# ma_update is the op that refreshes the averages.
ma = tf.train.ExponentialMovingAverage(0.999, zero_debias=True)
ma_update = ma.apply(tf.model_variables())

# `saver` is built with no arguments after ma.apply() and before
# create_train_op().
# NOTE(review): variables created later by the optimizer are therefore
# presumably not covered by it — confirm if training must be resumed from
# this checkpoint.
saver = tf.train.Saver()
# `saver_polyak` uses the name map from ma.variables_to_restore(), so
# restoring with it should load the averaged values into the model variables
# (verify against the ExponentialMovingAverage docs).
saver_polyak = tf.train.Saver(ma.variables_to_restore())

# Training op for total_loss with `opt`; ma_update is passed as an update op
# so the moving averages are refreshed as part of each training step.
train_tensor = tf.contrib.training.create_train_op(total_loss, opt,
                                                   global_step,
                                                   update_ops=[ma_update])

In [10]:
# Initialise every global variable in the interactive session.
sess.run(tf.global_variables_initializer())

In [11]:
def evaluate():
    """Score the weights currently loaded in `sess` on the MNIST test set.

    The whole test set is fed in a single `sess.run` call.

    Returns:
        A 6-tuple (IZY, IZX, acc, avg_acc, err, avg_err), where err and
        avg_err are 1 - acc and 1 - avg_acc respectively.
    """
    test_feed = {
        images: mnist_data.test.images,
        labels: mnist_data.test.labels,
    }
    fetches = [IZY_bound, IZX_bound, accuracy, avg_accuracy]
    izy, izx, single_acc, mean_acc = sess.run(fetches, feed_dict=test_feed)
    return izy, izx, single_acc, mean_acc, 1 - single_acc, 1 - mean_acc

In [12]:
import sys

# Train for 200 epochs. After each pass over the training set, report the
# test-set metrics returned by evaluate().
for epoch in range(200):
    for step in range(steps_per_batch):
        im, ls = mnist_data.train.next_batch(batch_size)
        sess.run(train_tensor, feed_dict={images: im, labels: ls})
    # Call form with a single parenthesised argument: valid on both Python 2
    # and Python 3 and prints exactly the same text, whereas the bare
    # `print "..."` statement is a SyntaxError on Python 3.
    print("{}: IZY={:.2f}\tIZX={:.2f}\tacc={:.4f}\tavg_acc={:.4f}\terr={:.4f}\tavg_err={:.4f}".format(
        epoch, *evaluate()))
    # Flush so each epoch's line shows up immediately in the notebook output.
    sys.stdout.flush()

# Checkpoint the final weights; passing global_step makes the saver append the
# step count to the file name (e.g. /tmp/mnistvib-120000).
savepth = saver.save(sess, '/tmp/mnistvib', global_step)

0: IZY=2.98	IZX=134.31	acc=0.9232	avg_acc=0.9379	err=0.0768	avg_err=0.0621
1: IZY=3.09	IZX=93.38	acc=0.9489	avg_acc=0.9604	err=0.0511	avg_err=0.0396
2: IZY=3.14	IZX=81.95	acc=0.9600	avg_acc=0.9708	err=0.0400	avg_err=0.0292
3: IZY=3.15	IZX=85.19	acc=0.9636	avg_acc=0.9730	err=0.0364	avg_err=0.0270
4: IZY=3.19	IZX=80.84	acc=0.9724	avg_acc=0.9769	err=0.0276	avg_err=0.0231
5: IZY=3.17	IZX=86.21	acc=0.9694	avg_acc=0.9733	err=0.0306	avg_err=0.0267
6: IZY=3.19	IZX=71.26	acc=0.9718	avg_acc=0.9782	err=0.0282	avg_err=0.0218
7: IZY=3.19	IZX=67.85	acc=0.9733	avg_acc=0.9780	err=0.0267	avg_err=0.0220
8: IZY=3.21	IZX=69.18	acc=0.9746	avg_acc=0.9793	err=0.0254	avg_err=0.0207
9: IZY=3.21	IZX=63.90	acc=0.9756	avg_acc=0.9813	err=0.0244	avg_err=0.0187
10: IZY=3.21	IZX=68.74	acc=0.9776	avg_acc=0.9815	err=0.0224	avg_err=0.0185
11: IZY=3.21	IZX=58.74	acc=0.9747	avg_acc=0.9812	err=0.0253	avg_err=0.0188
12: IZY=3.21	IZX=60.24	acc=0.9767	avg_acc=0.9806	err=0.0233	avg_err=0.0194
13: IZY=3.21	IZX=64.56	acc=0.9761	

110: IZY=3.21	IZX=25.31	acc=0.9815	avg_acc=0.9867	err=0.0185	avg_err=0.0133
111: IZY=3.21	IZX=24.20	acc=0.9820	avg_acc=0.9866	err=0.0180	avg_err=0.0134
112: IZY=3.21	IZX=24.25	acc=0.9818	avg_acc=0.9860	err=0.0182	avg_err=0.0140
113: IZY=3.21	IZX=26.17	acc=0.9821	avg_acc=0.9871	err=0.0179	avg_err=0.0129
114: IZY=3.22	IZX=24.54	acc=0.9818	avg_acc=0.9866	err=0.0182	avg_err=0.0134
115: IZY=3.21	IZX=24.21	acc=0.9811	avg_acc=0.9874	err=0.0189	avg_err=0.0126
116: IZY=3.20	IZX=25.15	acc=0.9814	avg_acc=0.9869	err=0.0186	avg_err=0.0131
117: IZY=3.21	IZX=23.58	acc=0.9822	avg_acc=0.9880	err=0.0178	avg_err=0.0120
118: IZY=3.21	IZX=24.28	acc=0.9833	avg_acc=0.9865	err=0.0167	avg_err=0.0135
119: IZY=3.22	IZX=25.01	acc=0.9831	avg_acc=0.9867	err=0.0169	avg_err=0.0133
120: IZY=3.21	IZX=23.58	acc=0.9829	avg_acc=0.9865	err=0.0171	avg_err=0.0135
121: IZY=3.21	IZX=24.13	acc=0.9811	avg_acc=0.9868	err=0.0189	avg_err=0.0132
122: IZY=3.21	IZX=24.14	acc=0.9829	avg_acc=0.9871	err=0.0171	avg_err=0.0129
123: IZY=3.1

In [13]:
# Load the checkpoint through the moving-average saver, so the model
# variables take their exponentially averaged values, then score on the test set.
saver_polyak.restore(sess, save_path=savepth)
evaluate()

INFO:tensorflow:Restoring parameters from /tmp/mnistvib-120000


(3.2056963,
 22.165133,
 0.98189998,
 0.98789996,
 0.01810002326965332,
 0.012100040912628174)

In [14]:
# Load the same checkpoint through the plain saver (the raw, non-averaged
# weights) and score on the test set for comparison.
saver.restore(sess, save_path=savepth)
evaluate()

INFO:tensorflow:Restoring parameters from /tmp/mnistvib-120000


(3.2040105,
 22.224438,
 0.98229998,
 0.98579997,
 0.017700016498565674,
 0.014200031757354736)