# Benchmarks - Hyper-parameters optimisation

**Main considerations when implementing hyper-parameter (HP) optimisation**
- We made kernels pytrees, so we should be able to compute gradients with respect to their hyper-parameters and optimise them directly.


---
## Setup

In [1]:
# Standard library
import os
from typing import NamedTuple

# Enable 64-bit floats in JAX. This is set here, before the cell that imports jax, so the flag is already present
# in the environment when JAX reads its configuration.
os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
import jax.numpy as jnp
import jax.random as jrd
import optax
import optax.tree_utils as otu
import chex

import numpy as np
import pandas as pd

In [3]:
# Local
from MagmaClustPy.kernels import RBFKernel, SEMagmaKernel
from MagmaClustPy.utils import preprocess_db
from MagmaClustPy.hyperpost import hyperpost
#from MagmaClustPy.hp_optimisation import optimise_hyperparameters
from MagmaClustPy.likelihoods import magma_neg_likelihood

In [4]:
# Config
# Root PRNG key; later cells split it to sample the per-task initial length-scales.
key = jrd.PRNGKey(0)
# Size prefix of the dummy dataset files read below (./dummy_datasets/<size>_*.csv).
test_db_size = "small"

---
## Data

---
## Current implementation

In [5]:
# Taken from optax doc (https://optax.readthedocs.io/en/latest/_collections/examples/lbfgs.html#l-bfgs-solver)
class InfoState(NamedTuple):
	"""State carried by the `print_info` transformation."""
	# Number of updates logged so far: initialised to 0 and incremented by one on every update call.
	iter_num: chex.Numeric


def print_info():
	"""Build an optax transformation that logs optimisation progress.

	The transformation forwards the incoming updates unchanged. On every update it prints the iteration
	counter held in its state, the objective `value` and the L2 norm of `grad` (both received as extra
	keyword arguments), using `jax.debug.print`.

	Returns:
		An `optax.GradientTransformationExtraArgs` whose state is an `InfoState`.
	"""

	def _init(params):
		# The parameters are not needed: the only state is the iteration counter.
		del params
		return InfoState(iter_num=0)

	def _log_and_forward(updates, state, params, *, value, grad, **extra_args):
		del params, extra_args

		grad_norm = otu.tree_l2_norm(grad)
		jax.debug.print(
			'Iteration: {i}, Value: {v}, Gradient norm: {e}',
			i=state.iter_num,
			v=value,
			e=grad_norm,
		)
		# Updates pass through untouched; only the counter advances.
		next_state = InfoState(iter_num=state.iter_num + 1)
		return updates, next_state

	return optax.GradientTransformationExtraArgs(_init, _log_and_forward)


# Adapted from optax doc (https://optax.readthedocs.io/en/latest/_collections/examples/lbfgs.html#l-bfgs-solver)
def run_opt(init_params, fun, opt, max_iter, tol):
	"""Minimise `fun` with the optax solver `opt` inside a `jax.lax.while_loop`.

	The loop stops once the absolute change of the objective between two consecutive iterates falls below
	`tol`, or once the solver's iteration counter reaches `max_iter`. The first step is always taken,
	whatever the value of `max_iter`.

	Args:
		init_params: pytree of parameters to start from.
		fun: scalar objective taking the parameter pytree as its only argument.
		opt: optax solver; its state must expose 'count' and 'value' entries (they are read with
			`otu.tree_get`), as is the case for `optax.lbfgs()`.
		max_iter: maximum number of iterations.
		tol: threshold on the absolute objective change under which the optimisation stops.

	Returns:
		A tuple `(final_params, final_state, final_llh)`. `final_llh` is the objective value stored in
		`final_state`, i.e. the value the solver recorded for `final_params`.
	"""
	value_and_grad_fun = optax.value_and_grad_from_state(fun)

	def step(carry):
		params, state, _ = carry
		# `value` is the objective at the current (pre-update) params; it becomes the "previous" value
		# against which the next stopping test is made.
		value, grad = value_and_grad_fun(params, state=state)
		updates, state = opt.update(grad, state, params, value=value, grad=grad, value_fn=fun)
		params = optax.apply_updates(params, updates)
		return params, state, value

	def continuing_criterion(carry):
		# tol is not computed on the gradients but on the difference between current and previous likelihoods, to
		# prevent overfitting on ill-defined likelihood functions where variance can blow up.
		_, state, prev_llh = carry
		iter_num = otu.tree_get(state, 'count')
		val = otu.tree_get(state, 'value')
		diff = jnp.abs(val - prev_llh)
		# Note: a NaN objective makes `diff` NaN, so `diff >= tol` is False and the loop stops after the
		# mandatory first step.
		return (iter_num == 0) | ((iter_num < max_iter) & (diff >= tol))

	# Carry: (params, solver state, objective value at the previous iterate).
	init_carry = (init_params, opt.init(init_params), jnp.array(jnp.inf))
	final_params, final_state, _ = jax.lax.while_loop(
		continuing_criterion, step, init_carry
	)
	# The last carry element is the objective *before* the final update, so it does not correspond to
	# `final_params`. Return the value stored by the solver for the final iterate instead.
	final_llh = otu.tree_get(final_state, 'value')
	return final_params, final_state, final_llh


def optimise_hyperparameters(mean_kernel, task_kernel, inputs, outputs, all_inputs, prior_mean, post_mean, post_cov,
                             masks, nugget=jnp.array(1e-10), max_iter=100, tol=1e-3, verbose=False):
	"""Optimise the mean-process kernel and the task kernel with L-BFGS, one after the other.

	Each kernel is optimised by `run_opt` against `magma_neg_likelihood`, starting from the kernel given
	as argument. The two optimisations are independent: the task optimisation does not use the updated
	mean kernel.

	Args:
		mean_kernel: initial kernel for the mean process.
		task_kernel: initial kernel for the tasks.
		inputs, outputs, masks: passed to the task objective (`masks` as its `mask` argument).
		all_inputs, prior_mean, post_mean, post_cov: passed to the mean objective, which uses `mask=None`.
		nugget: forwarded to `magma_neg_likelihood` in both objectives.
		max_iter, tol: stopping parameters forwarded to `run_opt`.
		verbose: when True, chain `print_info()` before the solver to log each iteration.

	Returns:
		The tuple `(new_mean_kernel, new_task_kernel)`.
	"""

	def make_solver():
		# A fresh solver per optimisation; logging is prepended only in verbose mode.
		solver = optax.lbfgs()
		return optax.chain(print_info(), solver) if verbose else solver

	def mean_objective(kern):
		return magma_neg_likelihood(kern, all_inputs, prior_mean, post_mean, post_cov, mask=None,
		                            nugget=nugget)

	def task_objective(kern):
		# NOTE(review): this objective receives `prior_mean` and never uses `post_mean`; confirm against
		# magma_neg_likelihood's signature that this is the intended mean argument.
		per_task = magma_neg_likelihood(kern, inputs, outputs, prior_mean, post_cov, mask=masks,
		                                nugget=nugget)
		# Sum the result so that the objective is a scalar.
		return per_task.sum()

	# Optimise mean kernel
	new_mean_kernel = run_opt(mean_kernel, mean_objective, make_solver(), max_iter=max_iter, tol=tol)[0]

	# Optimise task kernel
	new_task_kernel = run_opt(task_kernel, task_objective, make_solver(), max_iter=max_iter, tol=tol)[0]

	return new_mean_kernel, new_task_kernel


---
## Custom implementation(s)

---
## Comparison

In [6]:
# Shared settings for every comparison below.
# A float literal gives prior_mean a floating dtype, consistent with the `jnp.array(0.)` passed to hyperpost;
# `jnp.array(0)` would create an integer array.
prior_mean = jnp.array(0.)
# Forwarded as the `nugget` argument of optimise_hyperparameters.
nugget = jnp.array(1e-10)

### Common Input, Common HP

In [7]:
# Load the "common input, common HP" dataset and unpack the four arrays returned by preprocess_db.
db = pd.read_csv(f"./dummy_datasets/{test_db_size}_common_input_common_hp.csv")
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
# Displayed as a quick sanity check of the array shapes.
all_inputs.shape, padded_inputs.shape

((15,), (20, 15))

In [8]:
# Initial kernels. Both use scalar hyper-parameters (presumably shared by all tasks for task_kern).
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = SEMagmaKernel(length_scale=jnp.array(.6), variance=jnp.array(1.))

In [9]:
# Compute post_mean and post_cov from the initial kernels; they are held fixed during the optimisation below.
# NOTE(review): the literal jnp.array(0.) is presumably the prior mean, which the notebook also defines
# separately as `prior_mean`. Confirm, and consider using a single source for it.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [10]:
# Optimise both kernels from their initial values; verbose=True logs value and gradient norm at each iteration.
new_mean_kern, new_task_kern = optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 1088.8059707099683, Gradient norm: 1083.5718317034205
Iteration: 1, Value: 412.13104179487385, Gradient norm: 387.0471430737794
Iteration: 2, Value: 247.97383677138006, Gradient norm: 217.46216412345564
Iteration: 3, Value: 138.26870194194757, Gradient norm: 102.20707751192019
Iteration: 4, Value: 91.42961012748887, Gradient norm: 50.843206830962345
Iteration: 5, Value: 68.80705807098248, Gradient norm: 24.111700148530847
Iteration: 6, Value: 59.19113709675716, Gradient norm: 11.093214281074845
Iteration: 7, Value: 55.52616623403778, Gradient norm: 4.754958469521489
Iteration: 8, Value: 54.89939469111188, Gradient norm: 10.87976967787057
Iteration: 9, Value: 54.25388373605932, Gradient norm: 1.0986602413817192
Iteration: 10, Value: 54.23977621944245, Gradient norm: 0.48360126120269276
Iteration: 11, Value: 54.237009371681104, Gradient norm: 0.01882526795878109


In [11]:
# Display the optimised mean and task kernels.
new_mean_kern, new_task_kern

(SEMagmaKernel(length_scale=0.9826130454639493, variance=6.151066513607268),
 SEMagmaKernel(length_scale=1.5999490215189538, variance=1.0100972453316084))

In [12]:
%%timeit -n 3 -r 2
optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

1.72 s ± 25.3 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


### Common Input, Distinct HP

In [13]:
# Load the "common input, distinct HP" dataset and unpack the four arrays returned by preprocess_db.
db = pd.read_csv(f"./dummy_datasets/{test_db_size}_common_input_distinct_hp.csv")
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
# Displayed as a quick sanity check of the array shapes.
all_inputs.shape, padded_inputs.shape

((15,), (20, 15))

In [14]:
# Initial kernel of the mean process (scalar hyper-parameters).
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

# Draw one initial length-scale per row of padded_outputs, uniformly in [0.1, 1); the variance stays scalar.
key, subkey = jrd.split(key)
distinct_length_scales = jrd.uniform(subkey, shape=(padded_outputs.shape[0],), dtype=jnp.float64, minval=.1, maxval=1)
task_kern = SEMagmaKernel(length_scale=distinct_length_scales, variance=jnp.array(1.))

In [15]:
# Sanity check: there should be one length-scale per row of padded_outputs.
distinct_length_scales.shape

(20,)

In [16]:
# Compute post_mean and post_cov from the initial kernels; they are held fixed during the optimisation below.
# NOTE(review): the literal jnp.array(0.) is presumably the prior mean, which the notebook also defines
# separately as `prior_mean`. Confirm, and consider using a single source for it.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [17]:
# Optimise both kernels from their initial values; verbose=True logs value and gradient norm at each iteration.
new_mean_kern, new_task_kern = optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 2180.25408700812, Gradient norm: 2298.8298728512023
Iteration: 1, Value: 776.9481659133234, Gradient norm: 778.3285480218974
Iteration: 2, Value: 467.7553198900272, Gradient norm: 448.609466510178
Iteration: 3, Value: 238.95826450931128, Gradient norm: 210.80130576390275
Iteration: 4, Value: 140.4660045516055, Gradient norm: 106.80064567606223
Iteration: 5, Value: 90.56710491302316, Gradient norm: 51.951200661182305
Iteration: 6, Value: 67.70868190765934, Gradient norm: 25.11142488647024
Iteration: 7, Value: 57.53952924566286, Gradient norm: 11.552518431963833
Iteration: 8, Value: 53.68055712757789, Gradient norm: 7.226308603566132
Iteration: 9, Value: 53.12588641306984, Gradient norm: 9.58137547069042
Iteration: 10, Value: 52.29040622809569, Gradient norm: 1.3656537336425274
Iteration: 11, Value: 52.14878717110657, Gradient norm: 0.26733119477027734
Iteration: 12, Value: 52.14265071871399, Gradient norm: 0.03848019574499035
Iteration: 0, Value: 9.471764371406239e+

In [18]:
# Display the optimised mean and task kernels.
new_mean_kern, new_task_kern

(SEMagmaKernel(length_scale=1.1856918246895949, variance=6.816049817155252),
 SEMagmaKernel(length_scale=[1.43308798 1.44393616 1.43577344 1.43015607 1.43034229 1.4352805
  1.43638783 1.44390116 1.44705345 1.43860163 1.4090773  1.43410564
  1.4182356  1.4333545  1.42750424 1.42261946 1.44203239 1.43022241
  1.44429227 1.44275739], variance=1.3756287739182211))

In [19]:
%%timeit -n 3 -r 2
optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

1.78 s ± 31.1 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


### Distinct Input, Common HP

In [20]:
# Load the "distinct input, common HP" dataset and unpack the four arrays returned by preprocess_db.
db = pd.read_csv(f"./dummy_datasets/{test_db_size}_distinct_input_common_hp.csv")
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
# Displayed as a quick sanity check of the array shapes.
all_inputs.shape, padded_inputs.shape

((41,), (20, 41))

In [21]:
# Initial kernels. Both use scalar hyper-parameters (presumably shared by all tasks for task_kern).
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = SEMagmaKernel(length_scale=jnp.array(.6), variance=jnp.array(1.))

In [22]:
# Compute post_mean and post_cov from the initial kernels; they are held fixed during the optimisation below.
# NOTE(review): the literal jnp.array(0.) is presumably the prior mean, which the notebook also defines
# separately as `prior_mean`. Confirm, and consider using a single source for it.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [23]:
# Optimise both kernels from their initial values; verbose=True logs value and gradient norm at each iteration.
new_mean_kern, new_task_kern = optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 93883.35546043565, Gradient norm: 1194745.0537562363
Iteration: 1, Value: 2974.744120614934, Gradient norm: 3148.214376785104
Iteration: 2, Value: 2491.7251499679633, Gradient norm: 2609.0515343296424
Iteration: 3, Value: 1116.2495461511437, Gradient norm: 1082.403136794949
Iteration: 4, Value: 764.0029201857294, Gradient norm: 1105.0290129108764
Iteration: 5, Value: 570.921255489159, Gradient norm: 1403.87102684516
Iteration: 6, Value: 356.6130936636686, Gradient norm: 992.0110973415899
Iteration: 7, Value: 205.05306836019986, Gradient norm: 461.03370646928903
Iteration: 8, Value: 142.8731738210606, Gradient norm: 214.79124111175432
Iteration: 9, Value: 115.70298059020287, Gradient norm: 85.1437965271968
Iteration: 10, Value: 106.25762591722116, Gradient norm: 22.86426903099017
Iteration: 11, Value: 103.7650586126552, Gradient norm: 8.736323760266655
Iteration: 12, Value: 102.95825498057987, Gradient norm: 1.301709460177779
Iteration: 13, Value: 102.8991340876219,

In [24]:
# Display the optimised mean and task kernels.
# NOTE(review): in the saved output the task kernel equals its initial value (length_scale=0.6, variance=1.0),
# so the task optimisation seemingly made no progress on this dataset. Investigate (e.g. a NaN objective).
# NOTE(review): the mean kernel's length_scale is negative in the saved output. Presumably this is valid only if
# SEMagmaKernel stores hyper-parameters on a log scale; confirm.
new_mean_kern, new_task_kern

(SEMagmaKernel(length_scale=-0.0071476759712833186, variance=6.639570750925357),
 SEMagmaKernel(length_scale=0.6, variance=1.0))

In [25]:
%%timeit -n 3 -r 2
optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

1.26 s ± 87.6 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


### Distinct Input, Distinct HP

In [26]:
# Load the "distinct input, distinct HP" dataset and unpack the four arrays returned by preprocess_db.
db = pd.read_csv(f"./dummy_datasets/{test_db_size}_distinct_input_distinct_hp.csv")
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
# Displayed as a quick sanity check of the array shapes.
all_inputs.shape, padded_inputs.shape

((41,), (20, 41))

In [27]:
# Initial kernel of the mean process (scalar hyper-parameters).
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

# Draw one initial length-scale per row of padded_outputs, uniformly in [0.1, 1); the variance stays scalar.
key, subkey = jrd.split(key)
distinct_length_scales = jrd.uniform(subkey, shape=(padded_outputs.shape[0],), dtype=jnp.float64, minval=.1, maxval=1)
task_kern = SEMagmaKernel(length_scale=distinct_length_scales, variance=jnp.array(1.))

In [28]:
# Compute post_mean and post_cov from the initial kernels; they are held fixed during the optimisation below.
# NOTE(review): the literal jnp.array(0.) is presumably the prior mean, which the notebook also defines
# separately as `prior_mean`. Confirm, and consider using a single source for it.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [29]:
# Optimise both kernels from their initial values; verbose=True logs value and gradient norm at each iteration.
new_mean_kern, new_task_kern = optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 63693.24263182963, Gradient norm: 935204.7971765217
Iteration: 1, Value: 936.9165952740885, Gradient norm: 930.1424022688172
Iteration: 2, Value: 826.1892973711264, Gradient norm: 813.0332957764756
Iteration: 3, Value: 361.23238370298935, Gradient norm: 317.8693494976895
Iteration: 4, Value: 233.44345265802076, Gradient norm: 208.48996186308386
Iteration: 5, Value: 168.2633939143924, Gradient norm: 282.21037786145416
Iteration: 6, Value: 141.5093737836419, Gradient norm: 340.7081490306301
Iteration: 7, Value: 110.34916051831858, Gradient norm: 240.78517808066476
Iteration: 8, Value: 90.96996170282848, Gradient norm: 113.10981083458589
Iteration: 9, Value: 85.10452211994607, Gradient norm: 47.36631658124324
Iteration: 10, Value: 83.81965938398909, Gradient norm: 16.27514127248603
Iteration: 11, Value: 83.69133665039448, Gradient norm: 5.6383183429130925
Iteration: 12, Value: 83.67553022693143, Gradient norm: 1.120123616884129


In [30]:
# Display the optimised mean and task kernels.
# NOTE(review): in the saved output the task kernel's variance is exactly its initial 1.0 and every length-scale
# lies inside the initial sampling range [0.1, 1). This suggests the task optimisation did not move; investigate.
# NOTE(review): the mean kernel's length_scale is negative in the saved output. Presumably this is valid only if
# SEMagmaKernel stores hyper-parameters on a log scale; confirm.
new_mean_kern, new_task_kern

(SEMagmaKernel(length_scale=-0.0567281240906174, variance=5.454397379716924),
 SEMagmaKernel(length_scale=[0.70407369 0.78190355 0.15792831 0.92380933 0.55947948 0.78384972
  0.57594738 0.32603756 0.279596   0.62387592 0.14051311 0.74971683
  0.84952631 0.49412175 0.66092181 0.83649758 0.69980315 0.91777237
  0.63878216 0.40331981], variance=1.0))

In [31]:
%%timeit -n 3 -r 2
optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

1.22 s ± 42.1 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


---
## Conclusion

---