# Gradient-free Global Maximum Likelihood Finding: Neural Networks

Neural network training is typically done with maximum likelihood estimation. Because neural network architectures have many parameter invariances and symmetries, the likelihood surface often has a large number of local optima (local minima of the loss), making global optimisation very difficult.

JAXNS uses slice sampling as a gradient-free way to sample within hard likelihood constraints, starting from small likelihoods and strictly increasing towards the maximum likelihood. This means that JAXNS is effectively performing global maximisation of the likelihood. The prior can be seen as a measure which guides where JAXNS looks first. An attractive idea is to think about the prior as a guide for efficient global maximisation with JAXNS, but that's for another tutorial ;).

## Overview

In this tutorial we'll cover:
1. How to build a JAXNS model of a neural network using [Haiku](https://github.com/deepmind/dm-haiku)
2. How to do global likelihood maximisation with JAXNS

In [1]:
try:
    import haiku as hk
except ImportError:
    print("You must `pip install dm-haiku` first.")

from jax import numpy as jnp, random, jit, vmap
import jax
from jax.flatten_util import ravel_pytree

from jaxns.prior_transforms import UniformPrior, PriorChain
from jaxns.nested_sampler.nested_sampling import NestedSampler
from jaxns.plotting import plot_diagnostics
from jaxns.utils import summary
from jax.scipy.optimize import minimize
from itertools import product


In [2]:
# Generate data

def xor_reduce(x):
    """
    Computes the XOR reduction (parity) over the last axis of an array of bits.

    Examples:
        100 -> xor(xor(1,0),0) = 1
        001 -> xor(xor(0,0),1) = 1
        110 -> xor(xor(1,1),0) = 0
        011 -> xor(xor(0,1),1) = 0

    Args:
        x: boolean array of bits with shape [..., num_bits]. A plain 1D vector
            of bits is the common case; leading axes are treated as batch axes.

    Returns:
        boolean array of shape [...] (a scalar for a 1D input) holding the XOR
        of all bits along the last axis.
    """
    # Index the same (last) axis whose length drives the loop. Indexing the
    # first axis while looping over `x.shape[-1]` is only correct for 1D input.
    output = x[..., 0]
    for i in range(1, x.shape[-1]):
        output = jnp.logical_xor(output, x[..., i])
    return output


num_variables = 7
options = [True, False]
# Enumerate every bit pattern of length `num_variables`, one pattern per row.
x = jnp.asarray(list(product(options, repeat=num_variables)))  # N, num_variables
# Label each row with the XOR of its bits; keep a trailing axis of size 1.
y = vmap(xor_reduce)(x)[:, None]  # N, 1
x = x.astype(jnp.float32)
print("Data:")

# Loop names deliberately avoid `input`, which would shadow the builtin for
# the rest of the session.
for bits, label in zip(x, y):
    print(f"{bits} -> {label}")


INFO[2022-03-04 15:01:40,804]: Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO[2022-03-04 15:01:40,804]: Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
INFO[2022-03-04 15:01:40,805]: Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.


Data:
[1. 1. 1. 1. 1. 1. 1.] -> [ True]
[1. 1. 1. 1. 1. 1. 0.] -> [False]
[1. 1. 1. 1. 1. 0. 1.] -> [False]
[1. 1. 1. 1. 1. 0. 0.] -> [ True]
[1. 1. 1. 1. 0. 1. 1.] -> [False]
[1. 1. 1. 1. 0. 1. 0.] -> [ True]
[1. 1. 1. 1. 0. 0. 1.] -> [ True]
[1. 1. 1. 1. 0. 0. 0.] -> [False]
[1. 1. 1. 0. 1. 1. 1.] -> [False]
[1. 1. 1. 0. 1. 1. 0.] -> [ True]
[1. 1. 1. 0. 1. 0. 1.] -> [ True]
[1. 1. 1. 0. 1. 0. 0.] -> [False]
[1. 1. 1. 0. 0. 1. 1.] -> [ True]
[1. 1. 1. 0. 0. 1. 0.] -> [False]
[1. 1. 1. 0. 0. 0. 1.] -> [False]
[1. 1. 1. 0. 0. 0. 0.] -> [ True]
[1. 1. 0. 1. 1. 1. 1.] -> [False]
[1. 1. 0. 1. 1. 1. 0.] -> [ True]
[1. 1. 0. 1. 1. 0. 1.] -> [ True]
[1. 1. 0. 1. 1. 0. 0.] -> [False]
[1. 1. 0. 1. 0. 1. 1.] -> [ True]
[1. 1. 0. 1. 0. 1. 0.] -> [False]
[1. 1. 0. 1. 0. 0. 1.] -> [False]
[1. 1. 0. 1. 0. 0. 0.] -> [ True]
[1. 1. 0. 0. 1. 1. 1.] -> [ True]
[1. 1. 0. 0. 1. 1. 0.] -> [False]
[1. 1. 0. 0. 1. 0. 1.] -> [False]
[1. 1. 0. 0. 1. 0. 0.] -> [ True]
[1. 1. 0. 0. 0. 1. 1.] -> [False]
[1. 1. 0

In [3]:
# Define the likelihood, using Haiku as our framework for neural networks

def model(x, is_training=False):
    """
    Forward pass of a small MLP: Linear(4) -> sigmoid -> Linear(1).

    Args:
        x: input array passed to the first linear layer.
        is_training: accepted but not used by this network.

    Returns:
        the output of the final linear layer (one value per input row).
    """
    # Layers are created in the same order as they are applied.
    layers = [
        hk.Linear(4),
        jax.nn.sigmoid,
        hk.Linear(1),
    ]
    network = hk.Sequential(layers)
    return network(x)

# Rebind `model` from the plain forward function to its Haiku-transformed
# counterpart, which exposes `init` and `apply` (the latter without an rng argument).
model = hk.without_apply_rng(hk.transform(model))
# We must call the model once to get the params shape and type as a big pytree.
# We then use ravel_pytree to flatten it and to get the matching unflatten function,
# so that the sampler/optimiser can work with a single flat parameter vector.
init_params = model.init(random.PRNGKey(2345), x)
init_params_flat, unravel_func = ravel_pytree(init_params)
# Total number of scalar parameters in the network.
n_dims = init_params_flat.size
print("Number of parameters:", n_dims)

def softplus(x):
    """
    Numerically stable softplus, log(1 + exp(x)).

    Args:
        x: scalar or array.

    Returns:
        log(1 + exp(x)) evaluated elementwise.
    """
    # log(exp(x) + exp(0)) is computed without forming exp(x) directly, so
    # large positive x returns ~x instead of overflowing to inf.
    return jnp.logaddexp(x, 0.)

def log_likelihood(params, **kwargs):
    """
    Mean Bernoulli log-likelihood of the labels `y` given flat network parameters.

    With p = sigmoid(logits) = 1 / (1 + exp(-logits)):
        log(p)     = -softplus(-logits)
        log(1 - p) = -softplus(logits)
    and, per example,
        log P(y | params) = y * log(p) + (1 - y) * log(1 - p)

    Args:
        params: flat parameter vector; it is unflattened with `unravel_func`
            before being handed to the network.
        **kwargs: accepted and ignored.

    Returns:
        the mean (not the sum) of the per-example log-probabilities, passed
        through `jnp.asarray(..., jnp.float64)`.
    """
    network_params = unravel_func(params)
    logits = model.apply(network_params, x)
    # Log-probability of the label being True / False respectively.
    log_p_true = -softplus(-logits)
    log_p_false = -softplus(logits)
    # Pick the log-probability that matches each observed label.
    per_example = jnp.where(y, log_p_true, log_p_false)
    return jnp.asarray(jnp.mean(per_example), jnp.float64)


Number of parameters: 37


In [4]:
# Let us compare the results of nested sampling to optimisation done with BFGS.
# BFGS minimises, so we hand it the negative log-likelihood, starting from a
# single random normal draw of the flat parameter vector. `.x` is the solution found.

params_bfgs = minimize(lambda p: -log_likelihood(p),
                   random.normal(random.PRNGKey(2435), shape=(n_dims,)),
                   method='BFGS').x
print(f"BFGS maximum likelihood solution: log(L) = {log_likelihood(params_bfgs)}")

  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "asarray")


BFGS maximum likelihood solution: log(L) = -0.6931479573249817


  lax._check_user_dtype_supported(dtype, "asarray")


In [5]:
# Build the model: a prior over the flat parameter vector plus the nested sampler.

with PriorChain() as prior_chain:
    # we'll effectively place no prior on the parameters, other than requiring them to be within [-10,10]
    # (one independent uniform per network parameter, registered under the name 'params')
    UniformPrior('params', -10.*jnp.ones(n_dims), 10.*jnp.ones(n_dims))

# We'll do some strange things here.
# num_slices -> low: We'll make the sampler do very few slices. This will lead to large correlation between samples, and poor estimate of the evidence.
# This is alright, because we'll be looking for the maximum likelihood solution.
# NOTE(review): num_slices is set to prior_chain.U_ndims here; whether that counts as "low" depends on the JAXNS default -- confirm.
ns = NestedSampler(loglikelihood=log_likelihood, prior_chain=prior_chain, dynamic=True, sampler_kwargs=dict(num_slices=prior_chain.U_ndims))


  lax._check_user_dtype_supported(dtype, "zeros")


In [6]:
# Let's test the model with a small sanity check.
# Presumably this draws 10 samples from the prior and evaluates log_likelihood on each
# (the cell output logs one log-likelihood value per draw) -- confirm against the JAXNS docs.
prior_chain.test_prior(random.PRNGKey(42), 10, log_likelihood)

  lax._check_user_dtype_supported(dtype, "asarray")
INFO[2022-03-04 15:01:46,769]: Log-likelihood: -2.900855541229248
  lax._check_user_dtype_supported(dtype, "asarray")
INFO[2022-03-04 15:01:46,781]: Log-likelihood: -4.551949501037598
  lax._check_user_dtype_supported(dtype, "asarray")
INFO[2022-03-04 15:01:46,791]: Log-likelihood: -2.183030605316162
  lax._check_user_dtype_supported(dtype, "asarray")
INFO[2022-03-04 15:01:46,809]: Log-likelihood: -8.703441619873047
  lax._check_user_dtype_supported(dtype, "asarray")
INFO[2022-03-04 15:01:46,825]: Log-likelihood: -8.483926773071289
  lax._check_user_dtype_supported(dtype, "asarray")
INFO[2022-03-04 15:01:46,841]: Log-likelihood: -2.484107494354248
  lax._check_user_dtype_supported(dtype, "asarray")
INFO[2022-03-04 15:01:46,857]: Log-likelihood: -3.687208890914917
  lax._check_user_dtype_supported(dtype, "asarray")
INFO[2022-03-04 15:01:46,872]: Log-likelihood: -5.963511943817139
  lax._check_user_dtype_supported(dtype, "asarray")
INFO

In [7]:
# We do another strange thing here: we set the number of live points really low, which is okay because we don't care about evidence (yet)
# `ns` is rebound to its jit-compiled version; `maximise_likelihood` is marked static,
# so it is treated as a compile-time constant rather than a traced value.
ns = jit(ns, static_argnames='maximise_likelihood')
# Both live-point settings scale with the prior dimension (prior_chain.U_ndims * 10).
# NOTE(review): termination_likelihood_frac_increase=0.1 presumably stops the run once the
# likelihood stops improving by that fraction -- confirm against the JAXNS docs.
results = ns(random.PRNGKey(42),
                  num_live_points=prior_chain.U_ndims*10,
                  delta_num_live_points=prior_chain.U_ndims*10,
                  termination_likelihood_frac_increase=0.1,
             maximise_likelihood=True)

  lax._check_user_dtype_supported(dtype, "zeros")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "astype")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "full")
  lax._check_user_dtype_supported(dtype, "full")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "full")
  lax._check_user_dtype_supported(dtype, "zeros")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "astype")
  lax._check_user_dtype_supported(dtype, "full")
  lax._check_user_dtype_supported(dtype, "zeros")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "full")
  lax._check_user_dtype_supported(dtype, "zeros")
  lax._check_user_dtype_supported(dty

KeyboardInterrupt: 

In [None]:
# The maximum likelihood solution from nested sampling:
# pick the sample with the largest recorded log-likelihood ...
i_max = jnp.argmax(results.log_L_samples)
# ... take its flat parameter vector ...
params_max = results.samples['params'][i_max]
# ... and re-evaluate the log-likelihood there, for comparison with the BFGS value.
print("log L_max(L)", log_likelihood(params_max))

In [None]:
# Report on the run using the JAXNS helpers (presumably a printed summary followed by diagnostic plots).
summary(results)
plot_diagnostics(results)