# JAX Implementation of Deep Variational Information Bottleneck

This notebook serves as a modern JAX remake of the code that powered the [Deep Variational Information Bottleneck](https://arxiv.org/abs/1612.00410) paper.

## Setup

I'd recommend using a GPU kernel.

In [221]:
#@title requirements
!pip install flax



In [222]:
#@title imports
import functools

import jax
import jax.numpy as np
from jax import grad, vmap, jit, random
from typing import Any
from pprint import pprint

import flax
import flax.linen as nn

import matplotlib.pyplot as plt
plt.style.use('default')

import tensorflow_datasets as tfds

print("JAX Devices: ", jax.devices())

JAX Devices:  [GpuDevice(id=0)]


In [223]:
#@title data
dataset = tfds.load('mnist', split='train').batch(60_000).cache()
data = jax.device_put(next(dataset.as_numpy_iterator()))
batch = jax.tree_map(lambda x: x[:100], data)

eval_dataset = tfds.load('mnist', split='test').batch(10_000).cache()
eval_ds = jax.device_put(next(eval_dataset.as_numpy_iterator()))

## Code

In [224]:
#@title Distributions

@flax.struct.dataclass
class MultivariateNormalDiag():
  locs: np.ndarray
  scales: np.ndarray

  def log_prob(self, x):
    return jax.scipy.stats.norm.logpdf(
        x, loc=self.locs, scale=self.scales).sum(-1)
  
  def sample(self, rng, shape=()):
    return self.locs + self.scales * random.normal(
        rng, shape + self.locs.shape)

@flax.struct.dataclass
class Categorical():
  logits: np.ndarray

  def log_prob(self, x):
    @functools.partial(np.vectorize, signature='(k),()->()')
    def f(logits, x):
      logits = jax.nn.log_softmax(logits, axis=-1)
      return logits[x]
    return f(self.logits, x)

  def sample(self, rng, shape=()):
    return random.categorical(
        rng, self.logits, axis=-1, shape=shape)

In [225]:
#@title Model

kernel_init = jax.nn.initializers.xavier_uniform()
bias_init = jax.nn.initializers.zeros


class Encoder(nn.Module):
  embedding_width: int = 256
  nonlinearity: Any = jax.nn.relu

  @nn.compact
  def __call__(self, x):
    # rescale
    x = x / 128.0 - 1.0
    x = x.reshape((-1, 28 * 28))
    x = self.nonlinearity(
        nn.Dense(1024,
                 kernel_init=kernel_init,
                 bias_init=bias_init)(x))
    x = self.nonlinearity(
        nn.Dense(1024,
                 kernel_init=kernel_init,
                 bias_init=bias_init)(x))
    means = nn.Dense(
        self.embedding_width,
        kernel_init=kernel_init,
        bias_init=bias_init)(x)
    rhos = nn.Dense(
        self.embedding_width,
        kernel_init=kernel_init,
        bias_init=bias_init)(x)

    return MultivariateNormalDiag(
        means, jax.nn.softplus(rhos - 5.0))


class Decoder(nn.Module):
  classes: int = 10

  @nn.compact
  def __call__(self, z):
    logits = nn.Dense(
        self.classes,
        kernel_init=kernel_init,
        bias_init=bias_init)(z)
    return Categorical(logits=logits)

bits = np.log(2)

class VIB(nn.Module):
  width: int = 256
  num_samples: int = 16
  num_classes: int = 10
  beta: float = 1e-3

  def setup(self):
    self.encoder = Encoder(self.width)
    self.decoder = Decoder(self.num_classes)
    self.prior = MultivariateNormalDiag(
        np.zeros(self.width), np.ones(self.width))

  def __call__(self, batch, rng):
    image = batch['image']
    z_dist = self.encoder(image)
    z_samples = z_dist.sample(rng, (self.num_samples,))
    pred_dist = self.decoder(z_samples)

    class_loss = -pred_dist.log_prob(batch['label']) / bits
    rate = (z_dist.log_prob(z_samples) -
            self.prior.log_prob(z_samples)) / bits
    loss = class_loss + self.beta * rate

    # metrics 
    err = 1-(pred_dist.logits.argmax(-1) == batch['label']).mean()
    avg_logits = jax.nn.logsumexp(
        jax.nn.log_softmax(pred_dist.logits, axis=-1), axis=0)
    avg_err = 1-(avg_logits.argmax(-1) == batch['label']).mean()
    avg_loss = -Categorical(logits=avg_logits).log_prob(
        batch['label']) / bits

    return loss.mean(), {
        'c': class_loss,
        'r': rate,
        'loss': loss,
        'err': err,
        'avg_err': avg_err,
        'avg_loss': avg_loss
    }

In [226]:
#@title Training

@jax.jit
def train_step(optimizer, params_ema, batch, rng, learning_rate):
  """Train for a single step."""

  def loss_fn(params):
    return vib.apply(params, batch, rng)

  (loss, aux), grad = jax.value_and_grad(
      loss_fn, has_aux=True)(
          optimizer.target)
  optimizer = optimizer.apply_gradient(
      grad, learning_rate=learning_rate)
  params_ema = jax.tree_multimap(
      lambda p_ema, p: p_ema * 0.999 + p * 0.001,
      params_ema, optimizer.target)
  return optimizer, params_ema, aux

reshaped_eval_ds = {
    'image': eval_ds['image'].reshape(
        (100, -1, 28, 28, 1)),
    'label': eval_ds['label'].reshape(
        (100, -1))
}


@jax.jit
def evaluate(params, dataset, rng):

  def f(part):
    _, aux = vib.clone(num_samples=1024).apply(
        params, part, rng)
    return aux

  return jax.tree_map(np.mean, jax.lax.map(f, dataset))

def epoch(optimizer,
          params_ema,
          data,
          rng,
          learning_rate,
          eval_rng=None,
          batch_size=100):
  if eval_rng is None:
    eval_rng = random.PRNGKey(0)
  steps_per_epoch = len(data['image']) // batch_size

  rng, spl = random.split(rng)
  perms = random.permutation(spl, len(data['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  @jax.jit
  def segment(state, perm):
    (optimizer, params_ema, rng) = state
    rng, spl = random.split(rng)
    batch = {k: v[perm, ...] for k, v in data.items()}
    optimizer, params_ema, aux = train_step(
        optimizer, params_ema, batch, spl, learning_rate)
    return (optimizer, params_ema,
            rng), jax.tree_map(np.mean, aux)

  (optimizer, params_ema,
   rng), batch_stats = jax.lax.scan(
       segment, (optimizer, params_ema, rng), perms)

  eval_aux = evaluate(params_ema,
                      reshaped_eval_ds,
                      eval_rng)

  return (optimizer, params_ema, spl,
          batch_stats, eval_aux)

## Main

In [227]:
#@title Init
seed = 32828
rng = random.PRNGKey(seed)
rng, init_rng, eval_rng = random.split(rng, 3)
vib = VIB(num_samples=12, beta=1e-3)

params = vib.init(init_rng, batch, rng)
params_ema = vib.init(init_rng, batch, rng)

optimizer_def = flax.optim.Adam()
optimizer = optimizer_def.create(params)
learning_rate = 2e-4

In [228]:
#@title run
prep = lambda x: f"{float(np.mean(x)):.4}"

counter = 0
eval_stats = []
for i in range(25):
  optimizer, params_ema, rng, batch_stats, eval_aux = epoch(
      optimizer, params_ema, data, rng, learning_rate, eval_rng)
  counter += 1
  if (counter > 0) and (counter % 2 == 0):
    learning_rate *= 0.95 # 0.97
  print(counter, flush=True)
  eval_stats.append(jax.tree_map(lambda x: float(np.mean(x)), eval_aux))
  print("TRAIN:", jax.tree_map(prep, batch_stats), flush=True)
  print("EVAL: ", jax.tree_map(prep, eval_aux), flush=True)

1
TRAIN: {'avg_err': '0.0807', 'avg_loss': '0.3949', 'c': '0.4444', 'err': '0.0946', 'loss': '0.6215', 'r': '177.1'}
EVAL:  {'avg_err': '0.1439', 'avg_loss': '0.9107', 'c': '0.9112', 'err': '0.1441', 'loss': '2.041', 'r': '1.13e+03'}
2
TRAIN: {'avg_err': '0.03227', 'avg_loss': '0.165', 'c': '0.1937', 'err': '0.04085', 'loss': '0.2817', 'r': '88.05'}
EVAL:  {'avg_err': '0.0464', 'avg_loss': '0.2713', 'c': '0.2731', 'err': '0.0478', 'loss': '0.8839', 'r': '610.8'}
3
TRAIN: {'avg_err': '0.02038', 'avg_loss': '0.1096', 'c': '0.1291', 'err': '0.02718', 'loss': '0.2082', 'r': '79.1'}
EVAL:  {'avg_err': '0.0282', 'avg_loss': '0.1419', 'c': '0.1465', 'err': '0.03004', 'loss': '0.4569', 'r': '310.4'}
4
TRAIN: {'avg_err': '0.01583', 'avg_loss': '0.08485', 'c': '0.1014', 'err': '0.02144', 'loss': '0.1741', 'r': '72.67'}
EVAL:  {'avg_err': '0.0208', 'avg_loss': '0.1014', 'c': '0.109', 'err': '0.02357', 'loss': '0.2834', 'r': '174.3'}
5
TRAIN: {'avg_err': '0.01018', 'avg_loss': '0.06078', 'c': '0.0

In [229]:
# final evaluation
result = evaluate(params_ema, eval_ds, eval_rng)
print(result)
print(f"{result['avg_err']:.2%}\t{result['avg_loss']:.4}")

{'avg_err': DeviceArray(0.0118, dtype=float32), 'avg_loss': DeviceArray(0.05816593, dtype=float32), 'c': DeviceArray(0.09675445, dtype=float32), 'err': DeviceArray(0.01501162, dtype=float32), 'loss': DeviceArray(0.12789465, dtype=float32), 'r': DeviceArray(31.140131, dtype=float32)}
1.18%	0.05817
