In [1]:
from typing import Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random
import jax.scipy.stats as jst
import tensorflow as tf
import tqdm
from sklearn import metrics

In [2]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

In [3]:
X_train = jnp.array(X_train.reshape(len(X_train), -1) / 255)
X_test = jnp.array(X_test.reshape(len(X_test), -1) / 255)

y_train = jax.nn.one_hot(y_train, 10)
y_test = jax.nn.one_hot(y_test, 10)

In [4]:
def init_mu(shape):
    return jnp.zeros(shape)


def init_rho(shape):
    sigma2 = jnp.ones(shape) * jnp.exp(-2)
    return jnp.log(jnp.exp(sigma2) - 1)

$$
f(w, \theta) = \log q(w | \theta) - \log p(w) - \log p(D | w)
$$



In [5]:

Theta = Tuple[jnp.ndarray, jnp.ndarray]
Params = Tuple[Theta, Theta, Theta]


def sample_w(mu: jnp.ndarray, rho: jnp.ndarray, rng_key: jnp.ndarray) -> jnp.ndarray:
    eps: jnp.ndarray = jax.random.normal(rng_key, mu.shape)
    w = mu + jnp.log(1 + jnp.exp(rho)) * eps
    return w


def bbb_mlp(params: Params, X: jnp.ndarray, rng_key: jnp.ndarray) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...]]:
    theta_0, theta_1, theta_2 = params
    k0, k1, k2 = jax.random.split(rng_key, 3)

    W0 = sample_w(*theta_0, k0)
    W1 = sample_w(*theta_1, k1)
    W2 = sample_w(*theta_2, k2)
    return nn.relu(nn.relu(X @ W0) @ W1) @ W2, (W0, W1, W2)


def z_score(w, mu, sigma):
    return (w - mu) / sigma


def kl_div(p: Params, w: Tuple[jnp.ndarray, ...]) -> jnp.ndarray:
    kl = jnp.array(0)
    
    mu_p, sigma_p = jnp.array(0), jnp.exp(-2)
    for W, (mu_q, rho_q) in zip(w, p):
        sigma_q = jnp.log(1 + jnp.exp(rho_q))

        z_p = z_score(W, mu_p, sigma_p)
        z_q = z_score(W, mu_q, sigma_q)
        kl += jnp.sum(2 * jnp.log(sigma_p / sigma_q) + z_p ** 2 - z_q ** 2)
    return 0.5 * kl


@jax.jit
def train_step(
    params: Params,
    X: jnp.ndarray,
    y: jnp.ndarray,
    rng_key: jnp.ndarray,
    n_posterior_samples: int = 10,
    eta: float = 1e-3,
    beta: float = 0.05,
) -> Tuple[Params, jnp.ndarray, jnp.ndarray]:
    def loss_fn(p: Params, k: jnp.ndarray) -> jnp.ndarray:
        y_hat, w = bbb_mlp(p, X, k)
        loss = (
            # log q(w | theta) / p(w)
            beta * kl_div(p, w)
            # log p(D | theta)
            - jnp.mean(
                jnp.sum(y * nn.log_softmax(y_hat, axis=-1), axis=-1)
            )
        )
        return loss

    G = jax.tree_map(lambda _: jnp.zeros_like(_), params)
    f = jax.value_and_grad(loss_fn)
    L = jnp.array(0)
    for i in range(n_posterior_samples):
        rng_key, key = jax.random.split(rng_key)
        l, g = f(params, key)
        L += l
        G = jax.tree_map(lambda c, k: c + k, g, G)

    # G = jax.tree_map(lambda g: g / n_posterior_samples, G)
    return update_params(params, G, eta=eta), L / n_posterior_samples, rng_key


def update_params(params: Params, gradients: Params, eta: float) -> Params:
    return jax.tree_map(lambda w, g: w - eta * g, params, gradients)


In [6]:
def get_batch_indices(rng: jnp.ndarray, dataset_size: int, batch_size: int) -> jnp.ndarray:
    steps_per_epoch = dataset_size // batch_size

    perms = jax.random.permutation(rng, dataset_size)
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    return perms

In [7]:
l0 = (28 ** 2, 512)
l1 = (512, 256)
l2 = (256, 10)

t0 = (init_mu(l0), init_rho(l0))
t1 = (init_mu(l1), init_rho(l1))
t2 = (init_mu(l2), init_rho(l2))

p0 = (t0, t1, t2)

In [8]:
rng = jax.random.PRNGKey(0)

In [9]:
def with_beta(I):
    M = len(I)
    for ix, i in enumerate(I, 1):
        yield (2 ** (M - ix)) / (2 ** M - 1), i


In [10]:
r = rng
p = p0
# p = (W0, W1)
for e in range(50):
    ix = get_batch_indices(r, len(X_train), 128)
    I = tqdm.tqdm(ix, desc=f"Epoch {e}")
    for beta, i in with_beta(I):
        p, l, r = train_step(p, X_train[i], y_train[i], r, eta=1e-3, beta=beta)
        I.set_description(f"Epoch {e} (loss={l.item():.3f})")
        I.refresh()

Epoch 0 (loss=3.355): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:52<00:00,  8.95it/s]
Epoch 1 (loss=2.969): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:41<00:00, 11.18it/s]
Epoch 2 (loss=2.967): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:42<00:00, 11.12it/s]
Epoch 3 (loss=2.903): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:42<00:00, 11.05it/s]
Epoch 4 (loss=3.016): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:42<00:00, 11.11it/s]
Epoch 5 (loss=2.928): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:42<00:00, 11.04it/s]
Epoch 6 (loss=2.911): 100%|█████████████████████████████████████████████████████████████████████████████████████| 468/

KeyboardInterrupt: 

In [11]:
y_hat, _ = bbb_mlp(p, X_test, r)

In [12]:
print(
    metrics.classification_report(
        y_test.argmax(axis=1),
        y_hat.argmax(axis=1)
    )
)

              precision    recall  f1-score   support

           0       0.18      0.75      0.29       980
           1       0.00      0.00      0.00      1135
           2       0.05      0.04      0.05      1032
           3       0.13      0.06      0.08      1010
           4       0.02      0.02      0.02       982
           5       0.13      0.14      0.13       892
           6       0.01      0.00      0.00       958
           7       0.35      0.49      0.41      1028
           8       0.07      0.09      0.07       974
           9       0.00      0.00      0.00      1009

    accuracy                           0.16     10000
   macro avg       0.09      0.16      0.11     10000
weighted avg       0.09      0.16      0.10     10000



In [13]:
metrics.roc_auc_score(y_test, nn.softmax(y_hat, axis=-1))

0.5435239434705862

In [None]:
log_posterior_w(sample_w(*p[0], rng), p[0])

Array(104708.195, dtype=float32)

In [None]:
jst.norm.pdf(sample_w(*p[0], rng), p[0][0], jnp.log(1 + jnp.exp(p[0][1])))

Array([[11.716147 ,  5.2126284, 10.182134 , ...,  2.3938844, 13.124783 ,
         7.7830815],
       [13.260883 , 10.124926 , 13.173701 , ..., 13.188732 ,  3.9522772,
         9.572952 ],
       [ 3.792557 , 10.843338 ,  6.59099  , ..., 10.800665 , 12.89541  ,
         4.5610533],
       ...,
       [11.91493  ,  4.299404 , 11.691259 , ..., 12.0596   ,  2.1139517,
        10.35854  ],
       [13.077661 , 12.847717 ,  4.188637 , ...,  9.286474 ,  4.4286294,
        12.834122 ],
       [ 2.8425922,  8.49259  ,  6.8283567, ...,  6.215066 ,  8.3357935,
        13.013833 ]], dtype=float32)

In [None]:
jnp.log(1 + jnp.exp(p0[0][1]))

Array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)

In [None]:
jax.grad(lambda p: - jnp.mean(jnp.sum(y_test * nn.log_softmax(bbb_mlp(p, X_test, rng)[0]), axis=-1)))(p0)

((Array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
  Array([[-0.,  0.,  0., ..., -0.,  0.,  0.],
         [-0., -0.,  0., ..., -0., -0.,  0.],
         [ 0.,  0.,  0., ...,  0.,  0.,  0.],
         ...,
         [-0., -0.,  0., ..., -0., -0., -0.],
         [ 0.,  0.,  0., ...,  0., -0.,  0.],
         [-0.,  0., -0., ...,  0.,  0., -0.]], dtype=float32)),
 (Array([[ 6.22701608e-02, -3.78969833e-02, -3.39822331e-03,
           5.19318581e-02, -7.94331655e-02,  2.86801029e-02,
           2.34241262e-02, -1.71592563e-01,  7.88178816e-02,
           4.71968241e-02],
         [ 2.33181510e-02, -5.22711407e-03, -5.96698746e-02,
           4.36440073e-02, -1.08306244e-01, -3.67969554e-03,
          -9.23269056e-03,  5.04258312e-02,  4.19892445e-02,
           2.67384015e-02],

In [None]:
yhat, (W0, W1) = bbb_mlp(p0, X_test, rng)

In [None]:
nn.relu(X_test @ W0) @ W1

Array([[   0.75118184,   12.745816  ,   39.13999   , ...,  -11.419679  ,
         -27.065598  ,  -11.102991  ],
       [  -3.7900229 ,  -38.241516  ,   75.909615  , ...,    9.601346  ,
        -111.15672   ,   25.467857  ],
       [   3.63172   ,   15.827387  ,    1.7404846 , ...,   26.774538  ,
         -52.632885  ,   22.031975  ],
       ...,
       [ -72.39879   ,   26.60979   ,   67.92412   , ...,  -25.590082  ,
         -77.43488   ,  -71.45912   ],
       [  -0.95404387,    9.030612  ,   33.93918   , ...,   23.07801   ,
         -74.80061   ,  -38.07965   ],
       [ -71.674355  ,  -56.237396  ,   40.2892    , ...,   37.686012  ,
         -81.4109    ,   61.43129   ]], dtype=float32)

In [None]:
yhat

Array([[   0.75118184,   12.745816  ,   39.13999   , ...,  -11.419679  ,
         -27.065598  ,  -11.102991  ],
       [  -3.7900229 ,  -38.241516  ,   75.909615  , ...,    9.601346  ,
        -111.15672   ,   25.467857  ],
       [   3.63172   ,   15.827387  ,    1.7404846 , ...,   26.774538  ,
         -52.632885  ,   22.031975  ],
       ...,
       [ -72.39879   ,   26.60979   ,   67.92412   , ...,  -25.590082  ,
         -77.43488   ,  -71.45912   ],
       [  -0.95404387,    9.030612  ,   33.93918   , ...,   23.07801   ,
         -74.80061   ,  -38.07965   ],
       [ -71.674355  ,  -56.237396  ,   40.2892    , ...,   37.686012  ,
         -81.4109    ,   61.43129   ]], dtype=float32)

In [None]:
def test(p):
    tau = 10
    k0, k1 = jax.random.split(rng)

    W0 = p[0] + jnp.log(1 + jnp.exp(p[1])) * jax.random.normal(k0, p[0].shape)
    # W1 = p[1][0] + jnp.log(1 + jnp.exp(p[1][1])) * jax.random.normal(k1, l1)
    y_hat = X_test @ W0
    L = -jnp.mean(jnp.sum(y_test * nn.log_softmax(y_hat * tau), axis=-1)) #+ 0.5 * (jnp.sum(W0 ** 2) + jnp.sum(W1 ** 2))
    return L

jax.grad(test)((init_mu((28 ** 2, 10)), init_rho((28 ** 2, 10))))

(Array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 Array([[-0.,  0.,  0., ..., -0., -0., -0.],
        [ 0.,  0., -0., ..., -0.,  0.,  0.],
        [ 0.,  0., -0., ...,  0.,  0.,  0.],
        ...,
        [ 0.,  0., -0., ..., -0.,  0., -0.],
        [-0., -0.,  0., ..., -0.,  0.,  0.],
        [-0.,  0., -0., ..., -0.,  0.,  0.]], dtype=float32))

In [None]:
jax.grad(lambda mu: mu + jax.random.normal(rng))(2.0)

Array(1., dtype=float32, weak_type=True)

In [None]:
def softmax(x):
    return np.exp(x) / np.sum(np.exp(x))


In [None]:
jax.grad(lambda mu: softmax(jnp.array([5.0, 1.0]) @ (mu + jax.random.normal(rng, mu.shape))))(jnp.array([[2.0, 3.0],[0.4, 1.0]]))

TypeError: Gradient only defined for scalar-output functions. Output had shape: (2,).

In [None]:
jax.grad(jnp.exp)(3.0)

Array(20.085537, dtype=float32, weak_type=True)

In [None]:
mu = (np.array([[2.0, 3.0],[0.4, 1.0]]) + np.random.randn(2, 2))
x = np.array([5.0, 1.0])
o = x @ mu
z = softmax(o)

In [None]:
jnp.fill_diagonal(-jnp.expand_dims(z, 1) * z, z * (1 - z))

NotImplementedError: Numpy function <function fill_diagonal at 0x7fe10848fee0> not yet implemented

In [None]:
import numpy as np

In [None]:
Z = -np.expand_dims(z, 1) * z
np.fill_diagonal(Z, z * (1 - z))

array([ 5.97289804e-06, -5.97289804e-06])

In [None]:
jnp.exp(-2)

Array(0.13533528, dtype=float32, weak_type=True)

In [None]:
jnp.sum(y_test * jnp.log(nn.softmax(y_hat)), axis=-1)

Array([-2.0488215e-04, -1.1140065e+00, -9.1537694e-03, ...,
       -8.7561179e-03, -2.9766288e+00, -2.9802368e-06], dtype=float32)