# Bayesian computations with Neural Networks

Neural networks are typically trained by maximum likelihood estimation. Neural network architectures have many parameter invariances (for example, permuting the hidden units leaves the output unchanged), so the likelihood surface often has a large number of local optima, which makes global optimisation very difficult.
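
To make the idea of a parameter invariance concrete, here is a minimal pure-Python sketch. The weights, the two-hidden-unit network, and the helper names `sigmoid` and `mlp` are hypothetical and are not part of the notebook's model. It shows two parameter transformations that leave the output of a one-hidden-layer sigmoid network unchanged: permuting the hidden units, and flipping the sign of one unit using `sigmoid(-z) = 1 - sigmoid(z)`.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def mlp(x, W1, b1, w2, b2):
    """One-hidden-layer network: logit = w2 . sigmoid(W1 x + b1) + b2."""
    hidden = [sigmoid(sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(W1, b1)]
    return sum(w * h for w, h in zip(w2, hidden)) + b2


# Hypothetical weights: 3 inputs, 2 hidden units, 1 output logit.
W1 = [[0.5, -1.0, 2.0], [1.5, 0.3, -0.7]]
b1 = [0.1, -0.4]
w2 = [2.0, -3.0]
b2 = 0.25
x = [1.0, 0.0, 1.0]

original = mlp(x, W1, b1, w2, b2)

# Permutation symmetry: swap the two hidden units everywhere.
permuted = mlp(x, W1[::-1], b1[::-1], w2[::-1], b2)

# Sign-flip symmetry: negate unit 0's incoming weights and bias, negate its
# outgoing weight, and absorb the constant w2[0] into the output bias.
W1_flip = [[-w for w in W1[0]], W1[1]]
b1_flip = [-b1[0], b1[1]]
w2_flip = [-w2[0], w2[1]]
b2_flip = b2 + w2[0]
flipped = mlp(x, W1_flip, b1_flip, w2_flip, b2_flip)

print(original, permuted, flipped)  # three equal values up to rounding
```

Each such symmetry maps one optimum of the likelihood onto another distinct point in parameter space with the same likelihood value.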

JAXNS can compute the evidence and posterior of a neural network. Because the number of parameters is large, we should be careful about how much accuracy we ask for: higher precision requires more likelihood evaluations.
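
For reference, these two quantities can be written in standard notation. The symbols are assumptions introduced here rather than taken from the notebook: $\theta$ denotes the network parameters, $D$ the data, $L(\theta) = p(D \mid \theta)$ the likelihood, and $\pi(\theta)$ the prior.

$$
Z = p(D) = \int L(\theta)\,\pi(\theta)\,\mathrm{d}\theta,
\qquad
p(\theta \mid D) = \frac{L(\theta)\,\pi(\theta)}{Z}
$$

The evidence $Z$ is an integral over the whole parameter space, which is why the cost grows with the number of parameters and with the precision requested.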

## What we'll do in this notebook
1. Define a neural network model
2. Find its maximum likelihood parameters using Global Optimisation
3. Redefine it as a Bayesian neural network allowing biases to be sampled from a prior
4. Compute the evidence and posterior of the parameters.


## Data

We'll use the N-bit majority problem as our data. This is a binary classification problem where the input is a sequence of bits and the output is 1 if the majority of the bits are 1, and 0 otherwise.
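
As a quick check on the size and balance of this dataset, the following pure-Python sketch enumerates all bit patterns for 7 bits (the value used later in this notebook) and counts the positive labels. The names `num_bits`, `rows` and `labels` are illustrative only.

```python
from itertools import product

num_bits = 7

# Every possible input: 2**num_bits bit patterns.
rows = list(product([1, 0], repeat=num_bits))

# Label is True when strictly more than half of the bits are 1.
labels = [sum(bits) > num_bits / 2 for bits in rows]

print(len(rows), sum(labels))  # prints: 128 64
```

With an odd number of bits there are no ties, so the two classes are exactly balanced. With an even number of bits, a tie would be labelled 0 under the strict-majority rule.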

In [1]:
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=6"

try:
    import haiku as hk
except ImportError:
    print("You must `pip install dm-haiku` first.")
    raise

from itertools import product

import jax
import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp
from jax import vmap

from jaxns import resample

tfpd = tfp.distributions


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [2]:
# Generate data

def n_bit_majority(x):
    # True when strictly more than half of the bits are set
    return jnp.sum(x) > (x.size / 2)


num_variables = 7
options = [True, False]
x = jnp.asarray(list(product(options, repeat=num_variables)))  # (N, num_variables)
y = vmap(n_bit_majority)(x)[:, None]  # (N, 1)
x = x.astype(jnp.float32)
print("Data:")

# `inputs` rather than `input`, to avoid shadowing the Python builtin
for inputs, output in zip(x, y):
    print(f"{inputs} -> {output}")


Data:
[1. 1. 1. 1. 1. 1. 1.] -> [ True]
[1. 1. 1. 1. 1. 1. 0.] -> [ True]
[1. 1. 1. 1. 1. 0. 1.] -> [ True]
[1. 1. 1. 1. 1. 0. 0.] -> [ True]
[1. 1. 1. 1. 0. 1. 1.] -> [ True]
[1. 1. 1. 1. 0. 1. 0.] -> [ True]
[1. 1. 1. 1. 0. 0. 1.] -> [ True]
[1. 1. 1. 1. 0. 0. 0.] -> [ True]
[1. 1. 1. 0. 1. 1. 1.] -> [ True]
[1. 1. 1. 0. 1. 1. 0.] -> [ True]
[1. 1. 1. 0. 1. 0. 1.] -> [ True]
[1. 1. 1. 0. 1. 0. 0.] -> [ True]
[1. 1. 1. 0. 0. 1. 1.] -> [ True]
[1. 1. 1. 0. 0. 1. 0.] -> [ True]
[1. 1. 1. 0. 0. 0. 1.] -> [ True]
[1. 1. 1. 0. 0. 0. 0.] -> [False]
[1. 1. 0. 1. 1. 1. 1.] -> [ True]
[1. 1. 0. 1. 1. 1. 0.] -> [ True]
[1. 1. 0. 1. 1. 0. 1.] -> [ True]
[1. 1. 0. 1. 1. 0. 0.] -> [ True]
[1. 1. 0. 1. 0. 1. 1.] -> [ True]
[1. 1. 0. 1. 0. 1. 0.] -> [ True]
[1. 1. 0. 1. 0. 0. 1.] -> [ True]
[1. 1. 0. 1. 0. 0. 0.] -> [False]
[1. 1. 0. 0. 1. 1. 1.] -> [ True]
[1. 1. 0. 0. 1. 1. 0.] -> [ True]
[1. 1. 0. 0. 1. 0. 1.] -> [ True]
[1. 1. 0. 0. 1. 0. 0.] -> [False]
[1. 1. 0. 0. 0. 1. 1.] -> [ True]
[1. 1. 0

In [3]:
from jaxns.internals.maps import pytree_unravel
from jaxns import Prior, Model
import jaxns.framework.context as ctx


def prior_model():
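    def compute_logits(x):
        # One sigmoid hidden unit followed by a linear read-out.
        # With 7 inputs this gives 7 + 1 + 1 + 1 = 10 parameters.
        mlp = hk.Sequential([hk.Linear(1),
                             jax.nn.sigmoid,
                             hk.Linear(1)])
        return mlp(x)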

    init, apply = hk.transform(compute_logits)
    init_params = init(random.PRNGKey(0), x)
    # Convert haiku to jaxns params
    ctx_params = ctx.convert_external_params(init_params, prefix='haiku_model')
    # Flatten, model, then unflatten to use
    ravel_fn, unravel_fn = pytree_unravel(ctx_params)
    ndims = ravel_fn(init_params).size
    flat_params = yield Prior(tfpd.Uniform(-10. * jnp.ones(ndims), 10. * jnp.ones(ndims)), name='flat_params')
    params = unravel_fn(flat_params)
    logits = apply(params, jax.random.PRNGKey(0), x)
    return logits


def log_likelihood(logits):
    # Classification problem, so we use a Bernoulli likelihood
    # (mean log-probability over all data points)
    return tfpd.Bernoulli(logits=logits).log_prob(y).mean()


model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
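
To show what the likelihood above computes, here is a minimal pure-Python sketch of a Bernoulli log-probability parameterised by logits, averaged over the data. It is a hand-written stand-in for illustration, not the TensorFlow Probability implementation, and the helper names `softplus`, `bernoulli_log_prob` and `mean_log_likelihood` are hypothetical.

```python
import math


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bernoulli_log_prob(logit, label):
    # log p(label | logit) with p(label=1) = sigmoid(logit):
    #   label = 1 -> log sigmoid(logit)     = logit - softplus(logit)
    #   label = 0 -> log(1 - sigmoid(logit)) = -softplus(logit)
    return label * logit - softplus(logit)


def mean_log_likelihood(logits, labels):
    # Average (not sum) of the per-example log-probabilities.
    total = sum(bernoulli_log_prob(l, y) for l, y in zip(logits, labels))
    return total / len(labels)


# A logit of 0 is a coin flip, so each example contributes log(0.5).
print(mean_log_likelihood([0.0, 0.0], [1, 0]))

# Confident and correct logits push the mean towards 0 from below.
print(mean_log_likelihood([10.0, -10.0], [1, 0]))
```

Because the notebook's likelihood is a mean over data points rather than a sum, any log-likelihood threshold applied to it reads as an average per-example log-probability.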




In [4]:

model.sanity_check(random.PRNGKey(0), S=100)


INFO:jaxns:Sanity check...
INFO:jaxns:Sanity check passed


In [5]:
from jaxns.experimental import DefaultGlobalOptimisation, GlobalOptimisationTerminationCondition

go = DefaultGlobalOptimisation(model=model,
                               num_search_chains=model.U_ndims * 10,
                               num_parallel_workers=len(jax.devices()))

results = go(
    random.PRNGKey(0),
    GlobalOptimisationTerminationCondition(log_likelihood_contour=-0.01,
                                           max_likelihood_evaluations=2e6)
)
go.summary(results)


INFO:jaxns:Using 6 parallel workers, each running identical samplers.


--------
Termination Conditions:
Replica 0:
On plateau (possibly local minimum, or due to numerical issues)
Replica 1:
On plateau (possibly local minimum, or due to numerical issues)
Replica 2:
On plateau (possibly local minimum, or due to numerical issues)
Replica 3:
On plateau (possibly local minimum, or due to numerical issues)
Replica 4:
On plateau (possibly local minimum, or due to numerical issues)
Replica 5:
Reached goal log-likelihood contour
On plateau (possibly local minimum, or due to numerical issues)
--------
likelihood evals: 504890
samples: 20100
likelihood evals / sample: 25.1
--------
max(log_L)=-0.028
relative spread: 6.1e-05
absolute spread: 1.7e-06
--------
flat_params[#]: max(L) est.
flat_params[0]: 10.0
flat_params[1]: -2.91
flat_params[2]: -2.91
flat_params[3]: -2.91
flat_params[4]: -2.91
flat_params[5]: -2.91
flat_params[6]: -2.91
flat_params[7]: -2.91
flat_params[8]: 4.74
flat_params[9]: -10.0
--------


In [6]:
logits = model.prepare_input(results.U_solution)[0]
predictions = jax.nn.sigmoid(logits)
for i in range(len(y)):
    pred = predictions[i] > 0.5
    print(f"{i}: {x[i]} -> {y[i]} | pred: {pred} {'✓' if pred == y[i] else '✗'}")

accuracy = jnp.mean((predictions > 0.5) == y)
print(f"Accuracy: {accuracy * 100:.1f}%")

0: [1. 1. 1. 1. 1. 1. 1.] -> [ True] | pred: [ True] ✓
1: [1. 1. 1. 1. 1. 1. 0.] -> [ True] | pred: [ True] ✓
2: [1. 1. 1. 1. 1. 0. 1.] -> [ True] | pred: [ True] ✓
3: [1. 1. 1. 1. 1. 0. 0.] -> [ True] | pred: [ True] ✓
4: [1. 1. 1. 1. 0. 1. 1.] -> [ True] | pred: [ True] ✓
5: [1. 1. 1. 1. 0. 1. 0.] -> [ True] | pred: [ True] ✓
6: [1. 1. 1. 1. 0. 0. 1.] -> [ True] | pred: [ True] ✓
7: [1. 1. 1. 1. 0. 0. 0.] -> [ True] | pred: [ True] ✓
8: [1. 1. 1. 0. 1. 1. 1.] -> [ True] | pred: [ True] ✓
9: [1. 1. 1. 0. 1. 1. 0.] -> [ True] | pred: [ True] ✓
10: [1. 1. 1. 0. 1. 0. 1.] -> [ True] | pred: [ True] ✓
11: [1. 1. 1. 0. 1. 0. 0.] -> [ True] | pred: [ True] ✓
12: [1. 1. 1. 0. 0. 1. 1.] -> [ True] | pred: [ True] ✓
13: [1. 1. 1. 0. 0. 1. 0.] -> [ True] | pred: [ True] ✓
14: [1. 1. 1. 0. 0. 0. 1.] -> [ True] | pred: [ True] ✓
15: [1. 1. 1. 0. 0. 0. 0.] -> [False] | pred: [False] ✓
16: [1. 1. 0. 1. 1. 1. 1.] -> [ True] | pred: [ True] ✓
17: [1. 1. 0. 1. 1. 1. 0.] -> [ True] | pred: [ True] ✓
18