# Gradient-free Global Maximum Likelihood Finding: Neural Networks

Neural network training is typically done with maximum likelihood estimation. Because neural network architectures have many parameter invariances and symmetries, the likelihood surface usually has a large number of local optima, which makes global optimisation very difficult.

JAXNS uses slice sampling as a gradient-free way to sample from the prior subject to a hard likelihood constraint (L > L_min). It starts from low likelihoods and raises the constraint monotonically towards the maximum likelihood. This means that JAXNS is effectively performing global maximisation of the likelihood. The prior can be seen as a measure that guides where JAXNS looks first. An attractive idea is to think of the prior as a guide for efficient global maximisation with JAXNS, but that's for another tutorial ;).

## Overview

In this tutorial we'll cover:
1. How to build a JAXNS model of a neural network using [Haiku](https://github.com/deepmind/dm-haiku)
2. How to do global likelihood maximisation with JAXNS

In [1]:
# for parallel sampling: expose 4 host "devices" to XLA. XLA_FLAGS must be in the
# environment before JAX initialises its backend, so set it before anything
# (including haiku, which itself imports jax) pulls jax in.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

# Optional dependencies: fail with an actionable message, but only for a genuinely
# missing package -- any other error during import is allowed to surface unchanged.
try:
    import haiku as hk
except ImportError as err:
    raise ImportError("You must `pip install dm-haiku` first.") from err

try:
    from sklearn.metrics import roc_curve
except ImportError as err:
    raise ImportError("You must `pip install scikit-learn`") from err

from jax import numpy as jnp, random, vmap
import jax
from jax.flatten_util import ravel_pytree

from jaxns.prior_transforms import UniformPrior, PriorChain
from jaxns import GlobalOptimiser
from jax.scipy.optimize import minimize
from itertools import product
import pylab as plt



In [2]:
# Generate data

def xor_reduce(x):
    """
    Computes the XOR reduction (parity) of a sequence of bits along the last axis.

    Examples:
        100 -> xor(xor(1,0),0) = 1
        001 -> xor(xor(0,0),1) = 1
        110 -> xor(xor(1,1),0) = 0
        011 -> xor(xor(0,1),1) = 0

    Args:
        x: boolean array of bits with shape (..., n), n >= 1. A 1-D vector is
            reduced to a scalar; any leading axes are treated as a batch.

    Returns:
        bool array of shape x.shape[:-1] (a scalar for 1-D input).

    Raises:
        ValueError: if the last axis has length zero.
    """
    num_bits = x.shape[-1]
    if num_bits == 0:
        raise ValueError("xor_reduce requires at least one bit along the last axis.")
    # Index the same axis the loop runs over (the last one); indexing with x[i]
    # would walk the leading axis instead, which is only correct for 1-D input.
    output = x[..., 0]
    for i in range(1, num_bits):
        output = jnp.logical_xor(output, x[..., i])
    return output


# Dataset: every one of the 2**num_variables bit patterns, labelled by its parity
# (XOR reduction of its bits).
num_variables = 8
options = [True, False]
x = jnp.asarray(list(product(options, repeat=num_variables)))  # (2**num_variables, num_variables)
y = vmap(xor_reduce)(x)[:, None]  # (2**num_variables, 1)
x = x.astype(jnp.float32)
print("Data:")

for bits, label in zip(x, y):
    print(f"{bits} -> {label}")




Data:
[1. 1. 1. 1. 1. 1. 1. 1.] -> [False]
[1. 1. 1. 1. 1. 1. 1. 0.] -> [ True]
[1. 1. 1. 1. 1. 1. 0. 1.] -> [ True]
[1. 1. 1. 1. 1. 1. 0. 0.] -> [False]
[1. 1. 1. 1. 1. 0. 1. 1.] -> [ True]
[1. 1. 1. 1. 1. 0. 1. 0.] -> [False]
[1. 1. 1. 1. 1. 0. 0. 1.] -> [False]
[1. 1. 1. 1. 1. 0. 0. 0.] -> [ True]
[1. 1. 1. 1. 0. 1. 1. 1.] -> [ True]
[1. 1. 1. 1. 0. 1. 1. 0.] -> [False]
[1. 1. 1. 1. 0. 1. 0. 1.] -> [False]
[1. 1. 1. 1. 0. 1. 0. 0.] -> [ True]
[1. 1. 1. 1. 0. 0. 1. 1.] -> [False]
[1. 1. 1. 1. 0. 0. 1. 0.] -> [ True]
[1. 1. 1. 1. 0. 0. 0. 1.] -> [ True]
[1. 1. 1. 1. 0. 0. 0. 0.] -> [False]
[1. 1. 1. 0. 1. 1. 1. 1.] -> [ True]
[1. 1. 1. 0. 1. 1. 1. 0.] -> [False]
[1. 1. 1. 0. 1. 1. 0. 1.] -> [False]
[1. 1. 1. 0. 1. 1. 0. 0.] -> [ True]
[1. 1. 1. 0. 1. 0. 1. 1.] -> [False]
[1. 1. 1. 0. 1. 0. 1. 0.] -> [ True]
[1. 1. 1. 0. 1. 0. 0. 1.] -> [ True]
[1. 1. 1. 0. 1. 0. 0. 0.] -> [False]
[1. 1. 1. 0. 0. 1. 1. 1.] -> [False]
[1. 1. 1. 0. 0. 1. 1. 0.] -> [ True]
[1. 1. 1. 0. 0. 1. 0. 1.] -> [ T

In [3]:
# Define the likelihood, using Haiku as our framework for neural networks

def model(x, is_training=False):
    """
    Forward pass of a small MLP producing one logit per input row.

    Architecture: Linear(3) -> sigmoid -> Linear(1).

    Args:
        x: input matrix, one example per row.
        is_training: unused; the network has no train/eval-dependent layers.

    Returns:
        logits with a trailing dimension of size 1.
    """
    layers = [
        hk.Linear(3),
        jax.nn.sigmoid,
        hk.Linear(1),
    ]
    return hk.Sequential(layers)(x)

# Convert the forward function into a pure (init, apply) pair. without_apply_rng
# drops the rng argument from apply, since the network (Linear -> sigmoid -> Linear)
# has no stochastic layers. NOTE: this rebinds the name `model` from the plain
# function to the transformed object used below.
model = hk.without_apply_rng(hk.transform(model))
# We must call the model once to get the params shape and type as a big pytree
# We then use ravel_pytree to flatten and get the unflatten function.
init_params = model.init(random.PRNGKey(2345), x)
# init_params_flat: 1-D vector of all parameters; unravel_func maps such a vector
# back to the Haiku params pytree.
init_params_flat, unravel_func = ravel_pytree(init_params)
# With 8 inputs: (8*3 + 3) + (3*1 + 1) = 31 parameters.
n_dims = init_params_flat.size
print("Number of parameters:", n_dims)

def softplus(x):
    """
    Numerically stable softplus, log(1 + exp(x)).

    The naive form jnp.log1p(jnp.exp(x)) overflows to inf once exp(x) exceeds the
    float range (x > ~88 in float32), which would turn the log-likelihood into
    -inf/NaN for large logits (e.g. during unconstrained BFGS). logaddexp(x, 0)
    computes the same quantity without overflow.

    Args:
        x: scalar or array of real values.

    Returns:
        softplus(x), elementwise, with the same shape as x.
    """
    return jnp.logaddexp(x, 0.)

def log_likelihood(params, **kwargs):
    """
    Mean Bernoulli log-likelihood of the labels `y` under the network.

    With p = sigmoid(logits) = 1 / (1 + exp(-logits)):

        log P(y | params) = y * log(p) + (1 - y) * log(1 - p)
        log(p)     = -softplus(-logits)
        log(1 - p) = -softplus(logits)

    The per-example log-probabilities are averaged over the dataset `x`, `y`.

    Args:
        params: flat parameter vector, as produced by ravel_pytree.
        **kwargs: ignored (presumably accepted so callers may pass extra
            keywords -- confirm against the callers).

    Returns:
        scalar log-likelihood in JAX's default float dtype (float64 when
        jax_enable_x64 is on, float32 otherwise).
    """
    params_dict = unravel_func(params)
    logits = model.apply(params_dict, x)
    log_prob0, log_prob1 = -softplus(logits), -softplus(-logits)
    # log(p) where y is True, log(1-p) where y is False
    log_prob = jnp.mean(jnp.where(y, log_prob1, log_prob0))
    # dtype=float resolves to the widest enabled float type. An explicit
    # jnp.float64 triggers a truncation UserWarning on every call when x64 is off.
    return jnp.asarray(log_prob, dtype=float)


Number of parameters: 31


In [None]:
# Let us compare the results of nested sampling to optimisation done with BFGS:
# run BFGS from `num_random_init` random starting points and keep the best run.
num_random_init = 100
init_keys = random.split(random.PRNGKey(42), num_random_init)


def _bfgs_from_key(key):
    """Minimise -log_likelihood with BFGS from a standard-normal start drawn with `key`."""
    init = random.normal(key, shape=(n_dims,))
    return minimize(lambda p: -log_likelihood(p), init, method='BFGS').x


params_bfgs = vmap(_bfgs_from_key)(init_keys)
log_L_bfgs = vmap(log_likelihood)(params_bfgs)
# A diverged run can yield a NaN log-likelihood, and jnp.argmax treats NaN as the
# maximum; nanargmax ignores such runs when selecting the best one.
idx_max = jnp.nanargmax(log_L_bfgs)
log_L_bfgs_max = log_L_bfgs[idx_max]
params_bfgs_max = params_bfgs[idx_max]
print(f"BFGS maximum likelihood solution of {num_random_init} tries: log(L) = {log_L_bfgs_max}")

  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "asarray")
  lax._check_user_dtype_supported(dtype, "asarray")


In [None]:
# Build the model

# Prior over the flattened network parameters: independent uniform bounds of
# [-10, 10] on each of the n_dims parameters, registered under the name 'params'.
with PriorChain() as prior_chain:
    # we'll effectively place no prior on the parameters, other than requiring them to be within [-10,10]
    UniformPrior('params', -10.*jnp.ones(n_dims), 10.*jnp.ones(n_dims))

# NOTE(review): no `num_slices` argument is passed here, so the sampler uses its
# default slice settings. Lowering it would increase correlation between samples
# and degrade the evidence estimate -- acceptable when only the maximum
# likelihood solution is wanted.
# num_parallel_samplers=4 presumably matches the 4 host devices requested via
# XLA_FLAGS; gradient_boost=False presumably disables gradient-based proposals so
# the search stays gradient-free -- confirm both against the jaxns docs.
go = GlobalOptimiser(loglikelihood=log_likelihood, prior_chain=prior_chain,
                   num_parallel_samplers=4, sampler_kwargs=dict(gradient_boost=False))


In [None]:
# Let's test the model with a small sanity check.
# NOTE(review): presumably draws 10 samples from the prior with this key and
# evaluates log_likelihood on them -- confirm against PriorChain.test_prior.
prior_chain.test_prior(random.PRNGKey(42), 10, log_likelihood)

In [None]:
# Run the global optimiser. Termination arguments as named in the jaxns API:
# fractional likelihood-improvement threshold, patience, and a hard cap on the
# number of likelihood evaluations.
results = go(random.PRNGKey(42),
             termination_frac_likelihood_improvement=1e-3,
             termination_patience=3,
             # An evaluation count, so pass an int (same value as the float 10e6).
             termination_max_num_likelihood_evaluations=10_000_000)

In [None]:
# The maximum likelihood solution found by nested sampling (global optimisation):
# the flat parameter vector of the best sample and its log-likelihood.
params_max = results.sample_L_max['params']
log_L_max = results.log_L_max
print(f"Nested sampling maximum likelihood solution: log(L) = {log_L_max}")

In [None]:
# Report on the optimisation run (presumably prints a summary of `results` --
# see the jaxns GlobalOptimiser docs).
go.summary(results)

In [None]:

def predict(params, inputs=None):
    """
    Predicted probability of the positive (True) label for each input row.

    Args:
        params: flat parameter vector, as produced by ravel_pytree.
        inputs: optional input matrix (one example per row) to evaluate;
            defaults to the training inputs `x`.

    Returns:
        1-D array of sigmoid(logit) values, one per row of `inputs`.
    """
    if inputs is None:
        inputs = x
    params_dict = unravel_func(params)
    logits = model.apply(params_dict, inputs)
    # The network has a single output unit, so drop the trailing size-1 axis.
    return jax.nn.sigmoid(logits)[:, 0]



In [None]:
predictions = predict(params_max)
print("Predictions of globally optimised NN:")
for i, (x_i, y_i, pred_i) in enumerate(zip(x, y, predictions)):
    print(f"{i}: {x_i} -> {y_i} | pred: {pred_i}")

# ROC curve of the predicted probabilities against the true labels.
fpr, tpr, thresholds = roc_curve(y_true=y[:, 0], y_score=predictions, pos_label=1)
# Choose the ROC point closest (L1 distance) to the ideal corner fpr=0, tpr=1.
metric = jnp.abs(1 - tpr) + jnp.abs(fpr)
idx = int(jnp.argmin(metric))
plt.plot(fpr, tpr)
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.scatter(fpr[idx], tpr[idx], label=f'optimal threshold {thresholds[idx]}')
plt.legend()
plt.show()

In [None]:
# roc_curve reports (fpr[i], tpr[i]) for the rule "score >= thresholds[i]", so
# classify with >= to reproduce the selected ROC point exactly. A strict ">"
# would drop the sample(s) whose score equals the threshold from the positives.
optimal_thresh = thresholds[idx]
classifications = predictions >= optimal_thresh
accuracy = jnp.mean(classifications == y[:, 0])
print(f"accuracy of globally optimised NN with optimal threshold: {accuracy}")

In [None]:
predictions = predict(params_bfgs_max)
print("Predictions of BFGS optimised NN:")
for i, (x_i, y_i, pred_i) in enumerate(zip(x, y, predictions)):
    print(f"{i}: {x_i} -> {y_i} | pred: {pred_i}")

# ROC curve of the predicted probabilities against the true labels.
fpr, tpr, thresholds = roc_curve(y_true=y[:, 0], y_score=predictions, pos_label=1)
# Choose the ROC point closest (L1 distance) to the ideal corner fpr=0, tpr=1.
metric = jnp.abs(1 - tpr) + jnp.abs(fpr)
idx = int(jnp.argmin(metric))
plt.plot(fpr, tpr)
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.scatter(fpr[idx], tpr[idx], label=f'optimal threshold {thresholds[idx]}')
plt.legend()
plt.show()

In [None]:
# roc_curve reports (fpr[i], tpr[i]) for the rule "score >= thresholds[i]", so
# classify with >= to reproduce the selected ROC point exactly. A strict ">"
# would drop the sample(s) whose score equals the threshold from the positives.
optimal_thresh = thresholds[idx]
classifications = predictions >= optimal_thresh
accuracy = jnp.mean(classifications == y[:, 0])
print(f"accuracy of BFGS optimised NN with optimal threshold: {accuracy}")