# Benchmarks - Hyper-parameters optimisation

**Main considerations when implementing HPs optimisation**
- we made kernels pytrees, so we should be able to compute gradient and optimise for them directly


---
## Setup

In [1]:
# Standard library
import os

os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
import jax.random as jrd
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.stats.multivariate_normal import logpdf
from jax import lax, tree
import optax
import optax.tree_utils as otu

import numpy as np
import pandas as pd

In [3]:
# Local
from MagmaClustPy.kernels import RBFKernel, SEMagmaKernel
from MagmaClustPy.utils import preprocess_db
from MagmaClustPy.hyperpost import hyperpost

In [4]:
# Config
key = jrd.PRNGKey(0)
test_db_size = "medium"

---
## Data

---
## Current implementation

In [5]:
@jit
def solve_right_cholesky(A, B, nugget):
	""" Solves for X in X @ A = B """
	# For X @ A = B, we can transpose both sides: A.T @ X.T = B.T
	# As A and B are symmetric, this simplifies to A @ X.T = B
	# Then solve for X.T and transpose the result
	return cho_solve(cho_factor(A + nugget), B).T

@jit
def magma_neg_likelihood_on_cov(covar, inputs, outputs, mean, mean_process_cov, mask=None, nugget=jnp.array(1e-10)):
	nugget_matrix = jnp.eye(inputs.shape[0]) * nugget

	if mask is not None:
		# Mask the covariance matrix and outputs
		mask_2D = mask[:, None] & mask[None, :]
		covar = jnp.where(mask_2D, covar, jnp.eye(inputs.shape[0]))
		outputs = jnp.where(mask, outputs, 0)

	# Compute log-likelihood
	multiv_log_lik = logpdf(outputs, mean, covar + nugget_matrix)

	# Compute correction term
	correction = 0.5 * jnp.trace(solve_right_cholesky(covar, mean_process_cov, nugget))

	if mask is not None:
		# Correct log-likelihood for padding
		# The logpdf is computed as:
		# -0.5 * (N * log(2 * pi) + log(det(cov)) + (outputs - mean).T @ inv(cov) @ (outputs - mean))
		# det(cov) and the Mahalanobis distance are not affected by our padding
		# We only have to correct for the -0.5 * N * log(2 * pi) term, as N is bigger with padding
		multiv_log_lik += 0.5 * jnp.log(2 * jnp.pi) * jnp.sum(~mask, axis=0)

		# We also need to correct the correction term, as padding adds 1s to the diagonal and hence 1 to the trace
		correction -= 0.5 * jnp.sum(~mask, axis=0)
	return - (multiv_log_lik - correction)


@jit
def magma_neg_likelihood(kernel, inputs, outputs: jnp.array, mean: jnp.array, mean_process_cov: jnp.array, mask=None, nugget=jnp.array(1e-10)):
	"""
	Computes the MAGMA log-likelihood.

	:param kernel: the kernel containing HPs to optimise. This kernel is used to compute the covariance (matrix `S`)
	:param inputs: inputs on which to compute the covariance matrix (shape (N, ))
	:param mask: boolean masks indicating which inputs and outputs to consider (shape (N, ))
	:param outputs: the observed values (shape (N, ))
	:param mean: the mean over the inputs (scalar or vector of shape (N, ))
	:param mean_process_cov: the hypper-posterior mean process covariance (matrix K^t)
	:param nugget: the nugget, for numerical stability

	:return: the negative log-likelihood (scalar)
	"""
	covar = kernel(inputs)

	# check if we need to vmap
	if inputs.ndim == 1:
		return magma_neg_likelihood_on_cov(covar, inputs, outputs, mean, mean_process_cov, mask, nugget)
	elif inputs.ndim == 2:
		return vmap(magma_neg_likelihood_on_cov, in_axes=(0, 0, 0, None, None, 0, None))(covar, inputs, outputs, mean, mean_process_cov, mask, nugget)

---
## Custom implementation(s)

In [6]:
db = pd.read_csv(f"./dummy_datasets/{test_db_size}_distinct_input_common_hp.csv")
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((401,), (200, 401))

In [7]:
nugget = jnp.array(1e-6)
prior_mean = jnp.zeros_like(all_inputs)

### Baseline Adam

In [8]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = RBFKernel(length_scale=jnp.array(.6), variance=jnp.array(1.))

In [9]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)

In [10]:
start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)
opt_state = optimizer.init(mean_kern)

In [11]:
previous_likelihood = jnp.inf
conv_threshold = 1e-5

# A simple update loop.
for i in range(10000):
	likelihood, grads = jax.value_and_grad(magma_neg_likelihood)(mean_kern, all_inputs, prior_mean, post_mean, post_cov, nugget=nugget)
	updates, opt_state = optimizer.update(grads, opt_state)
	mean_kern = optax.apply_updates(mean_kern, updates)

	if i % 100 == 0:
		print(f"Iter:{i:5}\t Neg Log-Likelihood: {likelihood:12.4f}\t HPs: {mean_kern}")

	if jnp.abs(previous_likelihood - likelihood) < conv_threshold:
		print(f"Converged at iteration {i}\t Neg Log-Likelihood: {likelihood:12.4f}\t HPs: {mean_kern}")
		break

	previous_likelihood = likelihood

Iter:    0	 Neg Log-Likelihood:   17848.5853	 HPs: SEMagmaKernel(length_scale=0.3999999999998523, variance=1.0999999999999455)
Iter:  100	 Neg Log-Likelihood:     341.8618	 HPs: SEMagmaKernel(length_scale=0.4441807841214679, variance=4.602767161536733)
Iter:  200	 Neg Log-Likelihood:     187.6302	 HPs: SEMagmaKernel(length_scale=0.46715963691364437, variance=5.268742081389386)
Iter:  300	 Neg Log-Likelihood:     159.2269	 HPs: SEMagmaKernel(length_scale=0.4790159152494407, variance=5.599543044545862)
Iter:  400	 Neg Log-Likelihood:     152.8114	 HPs: SEMagmaKernel(length_scale=0.4852799265802575, variance=5.773519589027679)
Iter:  500	 Neg Log-Likelihood:     151.3450	 HPs: SEMagmaKernel(length_scale=0.4879188533021306, variance=5.862980168672663)
Iter:  600	 Neg Log-Likelihood:     151.1004	 HPs: SEMagmaKernel(length_scale=0.48972171819176563, variance=5.9062426337739025)
Converged at iteration 679	 Neg Log-Likelihood:     150.9648	 HPs: SEMagmaKernel(length_scale=0.4907364932686867, 

### L-BFGS

In [12]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = RBFKernel(length_scale=jnp.array(.6), variance=jnp.array(1.))

In [13]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)

In [14]:
optimizer = optax.lbfgs()
opt_state = optimizer.init(mean_kern)

In [15]:
previous_likelihood = jnp.inf
conv_threshold = 1e-5

# A simple update loop.
for i in range(100):
	likelihood, grads = jax.value_and_grad(magma_neg_likelihood)(mean_kern, all_inputs, prior_mean, post_mean, post_cov, nugget=nugget)
	fun_wrapper = lambda kern: magma_neg_likelihood(kern, all_inputs, prior_mean, post_mean, post_cov, nugget=nugget)
	updates, opt_state = optimizer.update(grads, opt_state, mean_kern, value=likelihood, grad=grads, value_fn=fun_wrapper)
	mean_kern = optax.apply_updates(mean_kern, updates)

	print(f"Iter:{i:5}\t Neg Log-Likelihood: {likelihood:12.4f}\t HPs: {mean_kern}")

	if jnp.abs(previous_likelihood - likelihood) < conv_threshold:
		print(f"Converged at iteration {i}\t Neg Log-Likelihood: {likelihood:12.4f}\t HPs: {mean_kern}")
		break

	previous_likelihood = likelihood

Iter:    0	 Neg Log-Likelihood:   17848.5853	 HPs: SEMagmaKernel(length_scale=0.47330374343774223, variance=1.469005130580109)
Iter:    1	 Neg Log-Likelihood:   12393.8630	 HPs: SEMagmaKernel(length_scale=0.39845433448555545, variance=1.6737251378483244)
Iter:    2	 Neg Log-Likelihood:    8725.7159	 HPs: SEMagmaKernel(length_scale=0.39778276962505754, variance=2.21774073616172)
Iter:    3	 Neg Log-Likelihood:    4893.0008	 HPs: SEMagmaKernel(length_scale=0.39968225809419816, variance=2.9485555681072912)
Iter:    4	 Neg Log-Likelihood:    2200.6780	 HPs: SEMagmaKernel(length_scale=0.4042968755076525, variance=3.587388084849157)
Iter:    5	 Neg Log-Likelihood:    1070.0281	 HPs: SEMagmaKernel(length_scale=0.41352815552162486, variance=4.223307060217768)
Iter:    6	 Neg Log-Likelihood:     519.3376	 HPs: SEMagmaKernel(length_scale=0.42947694242168744, variance=4.795398487784609)
Iter:    7	 Neg Log-Likelihood:     283.6355	 HPs: SEMagmaKernel(length_scale=0.45532617881530535, variance=5.2

### L-BFGS with lax while_loop

In [16]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = RBFKernel(length_scale=jnp.array(.6), variance=jnp.array(1.))

In [17]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)

In [18]:
optimizer = optax.lbfgs()
opt_state = optimizer.init(mean_kern)

In [19]:
# Taken from optax doc (https://optax.readthedocs.io/en/latest/_collections/examples/lbfgs.html#l-bfgs-solver)
def run_opt(init_params, fun, opt, max_iter, tol):
  value_and_grad_fun = optax.value_and_grad_from_state(fun)

  def step(carry):
    params, state = carry
    value, grad = value_and_grad_fun(params, state=state)
    updates, state = opt.update(grad, state, params, value=value, grad=grad, value_fn=fun)
    params = optax.apply_updates(params, updates)
    return params, state

  def continuing_criterion(carry):
    _, state = carry
    iter_num = otu.tree_get(state, 'count')
    grad = otu.tree_get(state, 'grad')
    err = otu.tree_l2_norm(grad)
    return (iter_num == 0) | ((iter_num < max_iter) & (err >= tol))

  init_carry = (init_params, opt.init(init_params))
  final_params, final_state = jax.lax.while_loop(
      continuing_criterion, step, init_carry
  )
  return final_params, final_state

In [20]:
def optimise_hyperparameters(mean_kernel, task_kernel, inputs, outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=jnp.array(1e-10)):
	big_nugget = jnp.eye(all_inputs.shape[0]) * nugget

	# Optimise mean kernel
	mean_opt = optax.lbfgs()
	mean_fun_wrapper = lambda kern: magma_neg_likelihood(kern, all_inputs, prior_mean, post_mean, post_cov, mask=None, nugget=nugget)

	new_mean_kernel, _ = run_opt(mean_kernel, mean_fun_wrapper, mean_opt, max_iter=100, tol=1e-3)

	# Optimise task kernel
	task_opt = optax.lbfgs()
	task_fun_wrapper = lambda kern: magma_neg_likelihood(kern, inputs, outputs, prior_mean, post_cov, mask=masks, nugget=nugget).mean()

	new_task_kernel, _ = run_opt(task_kernel, task_fun_wrapper, task_opt, max_iter=100, tol=1e-3)

	return new_mean_kernel, new_task_kernel

In [21]:
fun_wrapper = lambda kern: magma_neg_likelihood(kern, padded_inputs, padded_outputs, prior_mean, post_cov, masks, nugget=nugget).mean()

In [22]:
new_mean_kern, new_task_kern = optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)

In [24]:
new_mean_kern

SEMagmaKernel(length_scale=0.4909671890368215, variance=5.937835806359825)

In [25]:
new_task_kern

RBFKernel(length_scale=1.0132287879863946, variance=693.2209726643256)

---
## Comparison

### Common Input, Common HP

---
## Conclusion

---