# Bayes By Backprop

* [**Paper**](https://arxiv.org/abs/1505.05424)

In this paper, the authors describe a method for weight regularisation which provides direct information on the uncertainty of the weights. In essence, we try to approximate the posterior $p(w | D)$ of weights $w$, given the data $D$.

From Bayes' theorem, we know:

$$
p(w | D) = \frac{p(D | w)p(w)}{p(D)}
$$
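For a model with a single weight, the evidence $p(D) = \int p(D | w) p(w)\,dw$ can be computed directly on a grid, which makes the role of each term in Bayes' theorem concrete. The NumPy sketch below assumes a hypothetical toy model (not from the paper) with prior $w \sim \mathcal{N}(0, 1)$ and three observations $y_i \sim \mathcal{N}(w, 1)$:

```python
import numpy as np


def log_normal(x, mu, sigma):
    # log density of N(mu, sigma^2) evaluated at x
    return -0.5 * (np.log(2 * np.pi * sigma**2) + ((x - mu) / sigma) ** 2)


# hypothetical toy model: prior w ~ N(0, 1), observations y_i ~ N(w, 1)
y = np.array([0.8, 1.3, 0.4])
w_grid = np.linspace(-6.0, 6.0, 4001)
dw = w_grid[1] - w_grid[0]

# log p(D | w) + log p(w) at every grid point
log_joint = log_normal(y[:, None], w_grid[None, :], 1.0).sum(axis=0) + log_normal(w_grid, 0.0, 1.0)
joint = np.exp(log_joint)

evidence = joint.sum() * dw   # p(D): Riemann-sum approximation of the integral over w
posterior = joint / evidence  # p(w | D) on the grid

# for this conjugate model the exact posterior mean is sum(y) / (n + 1)
post_mean = (posterior * w_grid).sum() * dw
print(post_mean, y.sum() / (len(y) + 1))
```

A grid like this needs $n^d$ points for $d$ weights, so the approach does not carry over to neural networks.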

However, computing $p(w | D)$ directly is intractable because of the evidence $p(D) = \int p(D | w) p(w) \, dw$, an integral over every possible setting of the weights. Instead, we approximate $p(w | D)$ with a variational posterior $q(w | \theta)$, choosing the parameters $\theta$ that minimise the Kullback-Leibler divergence between the two distributions:

$$
\theta^* = \text{argmin}_{\theta} KL[q(w | \theta) || p(w | D)]
$$

The divergence can be rewritten as:

$$\begin{align}
KL[q(w | \theta) || p(w | D)] &= \int q(w | \theta) \log \frac{q(w | \theta)}{p(w | D)} \\
&= \int q(w | \theta) \log \frac{q(w | \theta) p(D)}{p(D | w) p(w)} \\
&= \int q(w | \theta) \log \frac{q(w | \theta)}{p(D | w) p(w)} + \mathbb{E}_{q(w | \theta)}[\log p(D)] \\
&= \int q(w | \theta) \log \frac{q(w | \theta)}{p(D | w) p(w)} + \log p(D)
\end{align}$$

Since $\log p(D)$ does not depend on $\theta$, it can be dropped from the minimisation. This reduces the objective to:

$$\begin{align}
\text{argmin}_{\theta} KL[q(w | \theta) || p(w | D)] &= \text{argmin}_{\theta} \int q(w | \theta) \log \frac{q(w | \theta)}{p(D | w) p(w)} \\
&= \text{argmin}_{\theta} \int q(w | \theta) \log \frac{q(w | \theta)}{p(w)} - \int q(w | \theta) \log p(D | w) \\
&= \text{argmin}_{\theta} KL[q(w | \theta) || p(w)] - \mathbb{E}_{q(w | \theta)}[\log p(D | w)]
\end{align}$$
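The derivation can be sanity-checked in a conjugate toy model where every term has a closed form. The sketch below assumes, purely for illustration, a prior $w \sim \mathcal{N}(0, 1)$, a single observation $y \sim \mathcal{N}(w, 1)$ and $q(w | \theta) = \mathcal{N}(m, s^2)$, so that $p(w | D) = \mathcal{N}(y/2, 1/2)$ and $p(D) = \mathcal{N}(y; 0, 2)$. It checks that $KL[q(w | \theta) || p(w | D)]$ equals the objective plus $\log p(D)$, and that a Monte Carlo estimate from reparametrised samples $w = m + s\epsilon$ (the kind of estimator Bayes by Backprop relies on) approaches the closed-form objective. All names and numbers are hypothetical.

```python
import math
import random


def log_normal(x, mu, var):
    # log density of N(mu, var) evaluated at x
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)


def kl_gauss(m_q, v_q, m_p, v_p):
    """Closed-form KL[N(m_q, v_q) || N(m_p, v_p)], parametrised by variances."""
    return 0.5 * (math.log(v_p / v_q) + (v_q + (m_q - m_p) ** 2) / v_p - 1)


# hypothetical toy model: prior w ~ N(0, 1), one observation y ~ N(w, 1)
y = 1.5
# hypothetical variational parameters theta = (m, s) of q(w | theta) = N(m, s^2)
m, s = 0.3, 0.9


def objective_exact(m, s, y):
    # KL[q || p(w)] - E_q[log p(D | w)], using E_q[(y - w)^2] = (y - m)^2 + s^2
    expected_loglik = -0.5 * (math.log(2 * math.pi) + (y - m) ** 2 + s**2)
    return kl_gauss(m, s**2, 0.0, 1.0) - expected_loglik


def objective_mc(m, s, y, n_samples, rng):
    # Monte Carlo estimate of the same objective from reparametrised samples w = m + s * eps
    total = 0.0
    for _ in range(n_samples):
        w = m + s * rng.gauss(0.0, 1.0)
        total += log_normal(w, m, s**2) - log_normal(w, 0.0, 1.0) - log_normal(y, w, 1.0)
    return total / n_samples


# exact posterior N(y / 2, 1 / 2) and log-evidence log N(y; 0, 2) for this conjugate model
kl_to_posterior = kl_gauss(m, s**2, y / 2, 0.5)
log_evidence = log_normal(y, 0.0, 2.0)

print(kl_to_posterior, objective_exact(m, s, y) + log_evidence)  # equal up to rounding
print(objective_mc(m, s, y, 100_000, random.Random(0)), objective_exact(m, s, y))
```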

```python
from typing import Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random
import matplotlib.pyplot as plt
import tensorflow as tf
import tqdm
from sklearn import metrics
```

```python
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
```

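```python
# flatten each 28x28 image into a 784-dimensional vector and rescale the pixel intensities
X_train = jnp.array(X_train.reshape(len(X_train), -1) / 126)
X_test = jnp.array(X_test.reshape(len(X_test), -1) / 126)

# one-hot encode the 10 digit classes
y_train = jax.nn.one_hot(y_train, 10)
y_test = jax.nn.one_hot(y_test, 10)
```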

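```python
def kaiming_sigma(n):
    # note: 2 / n is the Kaiming (He) *variance*; the corresponding standard deviation is sqrt(2 / n)
    return 2 / n


def inv_t(sigma):
    # inverse of the softplus transform sigma = log(1 + exp(rho)) used in sample_w
    return jnp.log(jnp.exp(sigma) - 1)


def init_mu(shape, rng):
    # posterior means start as small Gaussian noise
    return 0.1 * jax.random.normal(rng, shape)


def init_rho(shape, rng):
    # rho is chosen so that softplus(rho) is close to kaiming_sigma(shape[-1]), plus a little noise
    return inv_t(kaiming_sigma(shape[-1])) + 0.1 * jax.random.normal(rng, shape)


def init_theta(shape, rng):
    # variational parameters theta = (mu, rho) for one weight tensor
    a, b = jax.random.split(rng)
    return (init_mu(shape, a), init_rho(shape, b))


def init_Wb(shape, rng):
    # (theta_W, theta_b) for a dense layer with weight shape `shape` and bias shape `shape[-1:]`
    a, b = jax.random.split(rng)
    return (init_theta(shape, a), init_theta(shape[-1:], b))
```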
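The standard deviation is parametrised as $\sigma = \log(1 + e^{\rho})$ (softplus), so it stays positive for any real $\rho$, and `inv_t` maps a desired initial $\sigma$ back to $\rho$. A small NumPy sketch of that round trip; `softplus` and `inv_softplus` are illustrative stand-ins for the transforms in `sample_w` and `inv_t`, written with numerically stable NumPy primitives:

```python
import numpy as np


def softplus(rho):
    # numerically stable log(1 + exp(rho))
    return np.logaddexp(0.0, rho)


def inv_softplus(sigma):
    # log(exp(sigma) - 1), valid for sigma > 0; the same formula as inv_t
    return np.log(np.expm1(sigma))


sigma = np.array([1e-3, 0.05, 2 / 784, 1.0, 5.0])
rho = inv_softplus(sigma)
print(rho)            # small sigmas map to negative rho
print(softplus(rho))  # recovers sigma
```

Optimising $\rho$ instead of $\sigma$ lets plain gradient descent run on an unconstrained parameter.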


```python
import functools

Theta = Tuple[jnp.ndarray, jnp.ndarray]
Params = Tuple[Theta, ...]


def sample_w(mu: jnp.ndarray, rho: jnp.ndarray, rng_key: jnp.ndarray) -> jnp.ndarray:
    # reparametrisation trick: w = mu + sigma * eps, with sigma = softplus(rho) = log(1 + exp(rho))
    eps: jnp.ndarray = jax.random.normal(rng_key, mu.shape)
    w = mu + jax.nn.softplus(rho) * eps
    return w


def bbb_mlp(params: Params, X: jnp.ndarray, rng_key: jnp.ndarray) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...]]:
    # 784 -> 512 -> 256 -> 10 MLP; every weight and bias is freshly sampled from q(w | theta)
    theta_W0, theta_b0, theta_W1, theta_b1, theta_W2, theta_b2 = params
    k0, k1, k2, k3, k4, k5 = jax.random.split(rng_key, 6)

    W0 = sample_w(*theta_W0, k0)
    b0 = sample_w(*theta_b0, k1)

    W1 = sample_w(*theta_W1, k2)
    b1 = sample_w(*theta_b1, k3)

    W2 = sample_w(*theta_W2, k4)
    b2 = sample_w(*theta_b2, k5)
    return nn.relu(nn.relu(X @ W0 + b0) @ W1 + b1) @ W2 + b2, (W0, W1, W2)


def kl_div(p: Params) -> jnp.ndarray:
    # closed-form KL[q(w | theta) || p(w)] between the factorised Gaussian posterior
    # and a N(0, sigma_p^2) prior with sigma_p = exp(-2), summed over all parameters
    kl = jnp.array(0.0)

    mu_p, sigma_p = jnp.array(0.0), jnp.exp(-2.0)
    for (mu_q, rho_q) in p:
        sigma_q = jax.nn.softplus(rho_q)
        kl += jnp.sum(
            2 * jnp.log(sigma_p / sigma_q)
            - 1 + (sigma_q / sigma_p) ** 2
            + ((mu_p - mu_q) / sigma_p) ** 2
        )
    return 0.5 * kl


# n_posterior_samples drives a Python loop, so it is marked static to allow passing it explicitly
@functools.partial(jax.jit, static_argnames=("n_posterior_samples",))
def train_step(
    params: Params,
    X: jnp.ndarray,
    y: jnp.ndarray,
    rng_key: jnp.ndarray,
    n_posterior_samples: int = 10,
    eta: float = 1e-3,
    beta: float = 0.05,
) -> Tuple[Params, jnp.ndarray, jnp.ndarray]:
    def loss_fn(p: Params, k: jnp.ndarray) -> jnp.ndarray:
        y_hat, _ = bbb_mlp(p, X, k)
        loss = (
            # beta * KL[q(w | theta) || p(w)]
            beta * kl_div(p)
            # - log p(D | w) for one weight sample (averaged over the batch)
            - jnp.mean(
                jnp.sum(y * nn.log_softmax(y_hat, axis=-1), axis=-1)
            )
        )
        return loss

    G = jax.tree_util.tree_map(jnp.zeros_like, params)
    f = jax.value_and_grad(loss_fn)
    L = jnp.array(0.0)
    for i in range(n_posterior_samples):
        rng_key, key = jax.random.split(rng_key)
        l, g = f(params, key)
        L += l
        G = jax.tree_util.tree_map(jnp.add, G, g)

    # even though we are trying to approximate the expectation of the gradient,
    # there is no point in normalising by the number of samples, as it would be
    # equivalent to using a lower learning rate
    # G = jax.tree_util.tree_map(lambda g: g / n_posterior_samples, G)
    return update_params(params, G, eta=eta), L / n_posterior_samples, rng_key


def update_params(params: Params, gradients: Params, eta: float) -> Params:
    # plain SGD step on the variational parameters theta = (mu, rho)
    return jax.tree_util.tree_map(lambda w, g: w - eta * g, params, gradients)
```
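`kl_div` uses the closed-form KL divergence between Gaussians (see the appendix), whereas the paper estimates $\log q(w | \theta) - \log p(w)$ from the sampled weights, which is what lets it use a scale-mixture prior. For a Gaussian prior the two should agree in expectation. A NumPy sketch for a single weight, with $\sigma_p = e^{-2}$ as in `kl_div` and otherwise arbitrary illustrative values:

```python
import numpy as np


def kl_closed_form(mu_q, sigma_q, mu_p, sigma_p):
    # same expression as kl_div, for a single weight
    return 0.5 * (
        2 * np.log(sigma_p / sigma_q) - 1 + (sigma_q / sigma_p) ** 2 + ((mu_p - mu_q) / sigma_p) ** 2
    )


def log_normal(x, mu, sigma):
    # log density of N(mu, sigma^2) evaluated at x
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2


rng = np.random.default_rng(0)
mu_q, sigma_q = 0.1, 0.05          # illustrative variational parameters
mu_p, sigma_p = 0.0, np.exp(-2.0)  # prior used in kl_div

w = mu_q + sigma_q * rng.standard_normal(200_000)  # reparametrised samples from q
kl_mc = np.mean(log_normal(w, mu_q, sigma_q) - log_normal(w, mu_p, sigma_p))

print(kl_mc, kl_closed_form(mu_q, sigma_q, mu_p, sigma_p))
```

The closed form removes the sampling noise from the KL term; the Monte Carlo version trades that noise for freedom in the choice of prior.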


```python
def get_batch_indices(rng: jnp.ndarray, dataset_size: int, batch_size: int) -> jnp.ndarray:
    steps_per_epoch = dataset_size // batch_size

    perms = jax.random.permutation(rng, dataset_size)
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    return perms
```

```python
rng = jax.random.PRNGKey(0)
```

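```python
layer_shapes = [
    (28 ** 2, 512),
    (512, 256),
    (256, 10),
]

# give every layer its own PRNG key rather than reusing `rng` for all of them
p0 = ()
for shape, key in zip(layer_shapes, jax.random.split(rng, len(layer_shapes))):
    p0 = p0 + init_Wb(shape, key)
```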

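```python
def with_beta(I):
    # minibatch KL re-weighting from the paper: the i-th of M minibatches gets
    # beta_i = 2^(M - i) / (2^M - 1), so the weights sum to one over an epoch
    M = len(I)
    for ix, i in enumerate(I, 1):
        yield (2 ** (M - ix)) / (2 ** M - 1), i
```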
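`with_beta` assigns the $i$-th of $M$ minibatches the weight $\beta_i = \frac{2^{M-i}}{2^M - 1}$. Because $\sum_{i=1}^{M} 2^{M-i} = 2^M - 1$, the weights sum to exactly one, so the KL term is counted once per pass over the data, with most of it placed on the first minibatches. A pure-Python check with exact fractions (the helper name `kl_weights` is illustrative):

```python
from fractions import Fraction


def kl_weights(M):
    # beta_i = 2^(M - i) / (2^M - 1) for i = 1..M, as exact fractions
    return [Fraction(2 ** (M - i), 2 ** M - 1) for i in range(1, M + 1)]


betas = kl_weights(5)
print(betas)       # each weight is half of the previous one
print(sum(betas))  # the weights sum to one
```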


```python
r = rng
p = p0
for e in range(50):
    r, r0 = jax.random.split(r, 2)
    ix = get_batch_indices(r0, len(X_train), 128)
    I = tqdm.tqdm(ix, desc=f"Epoch {e}")
    for beta, i in with_beta(I):
        p, l, r = train_step(p, X_train[i], y_train[i], r, eta=1e-3, beta=beta)
        I.set_description(f"Epoch {e} (loss={l.item():.3f})")
        I.refresh()
```

```
Epoch 0 (loss=0.982): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:45<00:00, 10.28it/s]
Epoch 1 (loss=0.794): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:29<00:00, 15.70it/s]
Epoch 2 (loss=0.588): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:29<00:00, 16.11it/s]
Epoch 3 (loss=0.664): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:28<00:00, 16.32it/s]
Epoch 4 (loss=0.805): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:29<00:00, 15.83it/s]
Epoch 5 (loss=0.755): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:28<00:00, 16.31it/s]
Epoch 6 (loss=0.723): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/
```

```python
y_hat, _ = bbb_mlp(p, X_test, r)
```

```python
print(
    metrics.classification_report(
        y_test.argmax(axis=1),
        y_hat.argmax(axis=1)
    )
)
```

```
              precision    recall  f1-score   support

           0       0.82      0.98      0.90       980
           1       0.84      0.99      0.91      1135
           2       0.87      0.78      0.82      1032
           3       0.74      0.87      0.80      1010
           4       0.86      0.80      0.83       982
           5       0.87      0.49      0.63       892
           6       0.81      0.93      0.87       958
           7       0.93      0.80      0.86      1028
           8       0.74      0.74      0.74       974
           9       0.77      0.78      0.77      1009

    accuracy                           0.82     10000
   macro avg       0.83      0.82      0.81     10000
weighted avg       0.83      0.82      0.82     10000
```



```python
metrics.roc_auc_score(y_test, nn.softmax(y_hat, axis=-1))
```

```
0.9792882105162253
```
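The evaluation above uses a single weight sample (`bbb_mlp(p, X_test, r)`). A common alternative is to average the softmax outputs over $S$ weight samples, $p(y | x, D) \approx \frac{1}{S} \sum_s \mathrm{softmax}(f(x; w_s))$, and to use the entropy of that average as a per-example uncertainty score. A NumPy sketch in which `sample_logits` is a hypothetical stand-in for one stochastic forward pass (in this notebook that role would be played by `bbb_mlp` with a fresh key):

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def predictive(sample_logits, X, n_samples, rng):
    """Average softmax outputs over n_samples stochastic forward passes."""
    probs = np.stack([softmax(sample_logits(X, rng)) for _ in range(n_samples)])
    mean_probs = probs.mean(axis=0)
    # predictive entropy: larger values mean a less certain averaged prediction
    entropy = -(mean_probs * np.log(mean_probs + 1e-12)).sum(axis=-1)
    return mean_probs, entropy


# hypothetical stand-in for bbb_mlp: a linear model whose weights are resampled on every call
W_mu = np.array([[2.0, -1.0, 0.0], [0.0, 0.5, 0.5]])
W_sigma = 0.5


def sample_logits(X, rng):
    W = W_mu + W_sigma * rng.standard_normal(W_mu.shape)  # one weight sample
    return X @ W


X = np.array([[3.0, 0.0], [0.1, 0.1]])  # one clear-cut and one ambiguous input
mean_probs, entropy = predictive(sample_logits, X, n_samples=100, rng=np.random.default_rng(0))
print(mean_probs.argmax(axis=-1), entropy)
```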

## Thoughts

While theoretically neat, this method is involved to set up and seems _very_ sensitive to the hyper-parameter settings and the initial parameters $\theta_{t = 0}$. I suspect the authors of the paper spent quite a bit of time tuning these hyper-parameters.


## Appendix

Here I show some additional working.

### KL-divergence of two univariate Gaussians

We start by defining the KL divergence between two continuous probability distributions $p(x)$ and $q(x)$:

$$
KL[p(x) || q(x)] = \int p(x) \log \frac{p(x)}{q(x)}
$$

and the univariate Gaussian probability (density) function:

$$\begin{align}
p(x) &= \frac{1}{\sqrt{2 \pi} \sigma} \exp \left(- \frac{z^2}{2}\right) \\
z &= \frac{x - \mu}{\sigma}
\end{align}$$

Taking the logarithm of the Gaussian:

$$
\log p(x) = - \frac{1}{2} \left( \log 2 \pi + \log \sigma^2 + z^2 \right)
$$

Splitting the logarithm inside the KL divergence, we get:

$$
KL[p(x) || q(x)] = \int p(x) \log p(x) - \int p(x) \log q(x)
$$

Let us first focus on the first term of the RHS, $\int p(x) \log p(x)$, which is the negative (differential) entropy of $p(x)$:

$$\begin{align}
\int p(x) \log p(x) &= \int - \frac{1}{2} p(x) \left( \log 2 \pi + \log \sigma^2 + z^2 \right) \\
&= - \frac{1}{2} \int p(x) \left( \log 2 \pi + \log \sigma^2 + z^2 \right) \\
&= - \frac{1}{2} \left( \log 2 \pi + \log \sigma^2 + \int p(x) z^2 \right)\\
&= - \frac{1}{2} \left( \log 2 \pi + \log \sigma^2 + \mathbb{E}[z^2] \right)
\end{align}$$

Now, since
$$\begin{align}
\mathbb{V}[x] = \mathbb{E}[x^2] - \mathbb{E}[x]^2 \iff \mathbb{E}[x^2] = \mathbb{V}[x] + \mathbb{E}[x]^2
\end{align}$$

and $\mathbb{E}[z] = 0$ and $\mathbb{V}[z] = 1$, we have $\mathbb{E}[z^2] = 1$ and hence:

$$
\int p(x) \log p(x) = - \frac{1}{2} \left( \log 2 \pi + \log \sigma^2 + 1 \right)
$$
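A quick Monte Carlo check of this result in plain Python, with arbitrary illustrative values for $\mu$ and $\sigma$: averaging $\log p(x)$ over samples $x \sim p(x)$ should approach $-\frac{1}{2}(\log 2\pi + \log \sigma^2 + 1)$.

```python
import math
import random


def log_normal(x, mu, sigma):
    # log of the univariate Gaussian density, in the form derived above
    z = (x - mu) / sigma
    return -0.5 * (math.log(2 * math.pi) + math.log(sigma**2) + z**2)


mu, sigma = 1.5, 0.7  # arbitrary illustrative parameters
rng = random.Random(0)
n = 200_000

# Monte Carlo estimate of int p(x) log p(x) = E_p[log p(x)]
mc = sum(log_normal(rng.gauss(mu, sigma), mu, sigma) for _ in range(n)) / n
closed_form = -0.5 * (math.log(2 * math.pi) + math.log(sigma**2) + 1)
print(mc, closed_form)
```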

Next, we focus on the second term of the RHS, which is the negative cross-entropy between $p(x)$ and $q(x)$. From here on I use subscripts to make the distinction between $p(x)$ and $q(x)$ (and their moments) clear:

$$\begin{align}
\int p(x) \log q(x) &= \int - \frac{1}{2} p(x) \left( \log 2 \pi + \log \sigma_q^2 + z_q^2 \right) \\
&= - \frac{1}{2} \left( \log 2 \pi + \log \sigma_q^2 + \mathbb{E}_p[z_q^2] \right)
\end{align}$$

Focusing just on $\mathbb{E}_p[z_q^2]$:

$$\begin{align}
\mathbb{E}_p[z_q^2] &= \int p(x) \left( \frac{x - \mu_q}{\sigma_q} \right)^2 \\
&= \int p(x) \frac{x^2 - 2x\mu_q + \mu_q^2}{\sigma_q^2} \\
&= \frac{1}{\sigma_q^2} \left( \int p(x) x^2  - \int p(x) 2 x \mu_q  + \int p(x) \mu_q^2 \right) \\
&= \frac{1}{\sigma_q^2} \left( \mathbb{E}_p[x^2] - 2 \mu_p \mu_q + \mu_q^2 \right) \\
&= \frac{1}{\sigma_q^2} \left( \sigma_p^2 + \mu_p^2 - 2 \mu_p \mu_q + \mu_q^2 \right) \\
&= \frac{1}{\sigma_q^2} \left( \sigma_p^2 + (\mu_p - \mu_q)^2 \right)
\end{align}$$
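The final expression for $\mathbb{E}_p[z_q^2]$ can be sanity-checked by Monte Carlo: draw $x \sim p(x) = \mathcal{N}(\mu_p, \sigma_p^2)$ and average $z_q^2 = ((x - \mu_q)/\sigma_q)^2$. A NumPy sketch with arbitrary illustrative parameters for $p$ and $q$:

```python
import numpy as np

mu_p, sigma_p = 0.5, 1.2   # arbitrary illustrative parameters of p
mu_q, sigma_q = -0.3, 0.8  # arbitrary illustrative parameters of q

rng = np.random.default_rng(0)
x = rng.normal(mu_p, sigma_p, size=500_000)  # samples from p(x)

mc = np.mean(((x - mu_q) / sigma_q) ** 2)    # Monte Carlo estimate of E_p[z_q^2]
closed_form = (sigma_p**2 + (mu_p - mu_q) ** 2) / sigma_q**2
print(mc, closed_form)
```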