In [1]:
from typing import Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random
import jax.scipy.stats as jst
import tensorflow as tf
import tqdm
from sklearn import metrics

In [42]:
def z_score(X):
    return (X - jnp.mean(X, axis=0, keepdims=True)) / (jnp.std(X, axis=0, keepdims=True, ddof=1) + jnp.finfo(float).eps)


In [43]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

In [44]:
X_train = z_score(X_train.reshape(len(X_train), -1))
X_test = z_score(X_test.reshape(len(X_test), -1))

y_train = jax.nn.one_hot(y_train, 10)
y_test = jax.nn.one_hot(y_test, 10)

In [52]:
def init_mu(shape, rng):
    return 0.1 * jax.random.normal(rng, shape)


def init_rho(shape, rng):
    return -3 + jax.random.normal(rng, shape)


def init_theta(shape, rng):
    a, b = jax.random.split(rng)
    return (init_mu(shape, a), init_rho(shape, b))


def init_Wb(shape, rng):
    a, b = jax.random.split(rng)
    return (init_theta(shape, a), init_theta(shape[-1:], b))


$$
f(w, \theta) = \log q(w | \theta) - \log p(w) - \log p(D | w)
$$



In [25]:
Theta = Tuple[jnp.ndarray, jnp.ndarray]
Params = Tuple[Theta, Theta, Theta]


def sample_w(mu: jnp.ndarray, rho: jnp.ndarray, rng_key: jnp.ndarray) -> jnp.ndarray:
    eps: jnp.ndarray = jax.random.normal(rng_key, mu.shape)
    w = mu + jnp.log(1 + jnp.exp(rho)) * eps
    return w


def bbb_mlp(params: Params, X: jnp.ndarray, rng_key: jnp.ndarray) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...]]:
    theta_W0, theta_b0, theta_W1, theta_b1, theta_W2, theta_b2 = params
    k0, k1, k2, k3, k4, k5 = jax.random.split(rng_key, 6)

    W0 = sample_w(*theta_W0, k0)
    b0 = sample_w(*theta_b0, k1)
    
    W1 = sample_w(*theta_W1, k2)
    b1 = sample_w(*theta_b1, k3)
    
    W2 = sample_w(*theta_W2, k4)
    b2 = sample_w(*theta_b2, k5)
    return nn.relu(nn.relu(X @ W0 + b0) @ W1 + b1) @ W2 + b2, (W0, W1, W2)


def z_score(w, mu, sigma):
    return (w - mu) / sigma


def kl_div(p: Params) -> jnp.ndarray:
    kl = jnp.array(0)
    
    mu_p, sigma_p = jnp.array(0), jnp.exp(-2)
    for (mu_q, rho_q) in p:
        sigma_q = jnp.log(1 + jnp.exp(rho_q))
        kl += jnp.sum(2 * jnp.log(sigma_p / sigma_q) - 1 + (sigma_q / sigma_p) ** 2 + ((mu_p - mu_q) / sigma_p) ** 2)
    return 0.5 * kl


@jax.jit
def train_step(
    params: Params,
    X: jnp.ndarray,
    y: jnp.ndarray,
    rng_key: jnp.ndarray,
    n_posterior_samples: int = 10,
    eta: float = 1e-3,
    beta: float = 0.05,
) -> Tuple[Params, jnp.ndarray, jnp.ndarray]:
    def loss_fn(p: Params, k: jnp.ndarray) -> jnp.ndarray:
        y_hat, _ = bbb_mlp(p, X, k)
        loss = (
            # log q(w | theta) / p(w)
            beta * kl_div(p)
            # log p(D | theta)
            - jnp.mean(
                jnp.sum(y * nn.log_softmax(y_hat, axis=-1), axis=-1)
            )
        )
        return loss

    G = jax.tree_map(lambda _: jnp.zeros_like(_), params)
    f = jax.value_and_grad(loss_fn)
    L = jnp.array(0)
    for i in range(n_posterior_samples):
        rng_key, key = jax.random.split(rng_key)
        l, g = f(params, key)
        L += l
        G = jax.tree_map(lambda c, k: c + k, g, G)

    # G = jax.tree_map(lambda g: g / n_posterior_samples, G)
    return update_params(params, G, eta=eta), L / n_posterior_samples, rng_key


def update_params(params: Params, gradients: Params, eta: float) -> Params:
    return jax.tree_map(lambda w, g: w - eta * g, params, gradients)


In [6]:
def get_batch_indices(rng: jnp.ndarray, dataset_size: int, batch_size: int) -> jnp.ndarray:
    steps_per_epoch = dataset_size // batch_size

    perms = jax.random.permutation(rng, dataset_size)
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    return perms

In [53]:
l0 = (28 ** 2, 512)
l1 = (512, 256)
l2 = (256, 10)

p0 = ()
for l in [
    (28 ** 2, 512),
    (512, 256),
    (256, 10),
]:
    p0 = p0 + init_Wb(l, rng)

In [8]:
rng = jax.random.PRNGKey(0)

In [9]:
def with_beta(I):
    M = len(I)
    for ix, i in enumerate(I, 1):
        yield (2 ** (M - ix)) / (2 ** M - 1), i


In [55]:
r = rng
p = p0
# p = (W0, W1)
for e in range(50):
    ix = get_batch_indices(r, len(X_train), 128)
    I = tqdm.tqdm(ix, desc=f"Epoch {e}")
    for beta, i in with_beta(I):
        p, l, r = train_step(p, X_train[i], y_train[i], r, eta=1e-3, beta=beta)
        I.set_description(f"Epoch {e} (loss={l.item():.3f})")
        I.refresh()

Epoch 0 (loss=2.405): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:29<00:00, 15.74it/s]
Epoch 1 (loss=2.266): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:29<00:00, 15.77it/s]
Epoch 2 (loss=2.063): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:33<00:00, 14.17it/s]
Epoch 3 (loss=2.038): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:33<00:00, 13.98it/s]
Epoch 4 (loss=1.907): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:33<00:00, 14.06it/s]
Epoch 5 (loss=1.885): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:33<00:00, 13.98it/s]
Epoch 6 (loss=1.797): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/

KeyboardInterrupt: 

In [56]:
y_hat, _ = bbb_mlp(p, X_test, r)

In [57]:
print(
    metrics.classification_report(
        y_test.argmax(axis=1),
        y_hat.argmax(axis=1)
    )
)

              precision    recall  f1-score   support

           0       0.35      0.89      0.50       980
           1       0.81      0.96      0.88      1135
           2       0.27      0.25      0.26      1032
           3       0.35      0.07      0.12      1010
           4       0.24      0.06      0.09       982
           5       0.14      0.13      0.14       892
           6       0.26      0.07      0.11       958
           7       0.31      0.88      0.45      1028
           8       0.21      0.14      0.17       974
           9       0.25      0.02      0.04      1009

    accuracy                           0.36     10000
   macro avg       0.32      0.35      0.28     10000
weighted avg       0.33      0.36      0.29     10000



In [58]:
metrics.roc_auc_score(y_test, nn.softmax(y_hat, axis=-1))

0.8164770253238245

In [14]:
def calculate_kl(p, w):
    kl = jnp.array(0)
    mu_p, sigma_p = jnp.array(0), jnp.exp(-2)
    for W, (mu_q, rho_q) in zip(w, p):
        sigma_q = jnp.log(1 + jnp.exp(rho_q))
        kl += jnp.sum(2 * jnp.log(sigma_p / sigma_q) - 1 + (sigma_q / sigma_p) ** 2 + ((mu_p - mu_q) / sigma_p) ** 2)
    return 0.5 * kl

In [16]:
_, w = bbb_mlp(p, X_test, rng)

In [18]:
kl_div(p, w)

Array(3495.823, dtype=float32)

In [19]:
calculate_kl(p, w)

Array(3549.7795, dtype=float32)