# Benchmarks – Hyperparameter optimisation

**Main considerations when implementing hyperparameter (HP) optimisation**
- We made kernels pytrees, so we should be able to compute gradients with respect to them and optimise them directly.


---
## Setup

In [1]:
# Standard library
import os
from typing import NamedTuple

# Enable 64-bit floats in JAX. This must run before `jax` is first imported (next cell),
# because JAX reads this environment variable at import time.
os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
import jax.numpy as jnp
import jax.random as jrd
from jax.tree_util import tree_flatten
import optax
import optax.tree_utils as otu
import chex

import numpy as np
import pandas as pd

In [3]:
# Local
from Kernax import SEMagmaKernel, NoisySEMagmaKernel
from MagmaClustPy.utils import preprocess_db
from MagmaClustPy.hyperpost import hyperpost

In [4]:
# Config
# Root PRNG key; the distinct-HP sections split it to draw per-task initial length scales.
key = jrd.PRNGKey(0)
# Dataset size prefix, used as `../datasets/{test_db_size}_<scenario>.csv` in each section.
test_db_size = "small"

---
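## Data

Each comparison section below loads its own dataset from `../datasets/{test_db_size}_<scenario>.csv` and preprocesses it with `preprocess_db`.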

---
## Current implementation

In [5]:
import MagmaClustPy
# Import the submodule explicitly. Attribute access on the bare package (`MagmaClustPy.hp_optimisation`)
# only works if the package's `__init__` or an earlier import has already loaded that submodule.
from MagmaClustPy.hp_optimisation import optimise_hyperparameters as optimise_hyperparameters_old

---
## Custom implementation(s)

In [6]:
from typing import NamedTuple

import chex
import jax
import jax.numpy as jnp
import optax
import optax.tree_utils as otu

from MagmaClustPy.likelihoods import magma_neg_likelihood


# Taken from optax doc (https://optax.readthedocs.io/en/latest/_collections/examples/lbfgs.html#l-bfgs-solver)
class InfoState(NamedTuple):
	"""State of the `print_info` transformation: a step counter, started at 0."""
	iter_num: chex.Numeric  # incremented by one on every update call


def print_info():
	"""Build a pass-through optax transformation that logs optimiser progress.

	The transformation never modifies ``updates``. On every update call it prints the
	iteration number, the objective value and the norm of the gradient with
	``jax.debug.print``, so it also prints from inside traced code such as ``lax.while_loop``.
	``value`` and ``grad`` must be forwarded as keyword extra arguments by the caller.

	Returns:
		optax.GradientTransformationExtraArgs
	"""

	def _init(params):
		# The state only holds a step counter; it does not depend on the parameters.
		del params
		return InfoState(iter_num=0)

	def _update(updates, state, params, *, value, grad, **extra_args):
		del params, extra_args
		step = state.iter_num
		jax.debug.print(
			'Iteration: {i}, Value: {v}, Gradient norm: {e}',
			i=step, v=value, e=otu.tree_norm(grad),
		)
		# Updates pass through unchanged; only the counter advances.
		return updates, state._replace(iter_num=step + 1)

	return optax.GradientTransformationExtraArgs(_init, _update)


# Adapted from optax doc (https://optax.readthedocs.io/en/latest/_collections/examples/lbfgs.html#l-bfgs-solver)
def run_opt(init_params, fun, opt, max_iter, tol):
	"""Minimise ``fun`` with an optax optimiser inside a ``jax.lax.while_loop``.

	Args:
		init_params: initial parameters (any pytree, e.g. a kernel).
		fun: scalar objective to minimise, called as ``fun(params)``.
		opt: optax ``GradientTransformationExtraArgs`` whose state contains ``count`` and
			``value`` entries (e.g. ``optax.lbfgs()``, optionally chained after ``print_info()``).
		max_iter: maximum number of optimiser steps. At least one step is always taken.
		tol: the loop stops once the absolute change of the objective between two
			consecutive iterates drops below ``tol``.

	Returns:
		``(final_params, final_state, final_value)``, where ``final_value`` is the objective
		at ``final_params``.
	"""
	value_and_grad_fun = optax.value_and_grad_from_state(fun)

	def step(carry):
		params, state, _ = carry
		# Reuses the value/gradient cached in `state` when available.
		value, grad = value_and_grad_fun(params, state=state)
		updates, state = opt.update(grad, state, params, value=value, grad=grad, value_fn=fun)
		params = optax.apply_updates(params, updates)
		# Carry the objective at the *previous* iterate so the stopping test can compare it with
		# the value at the new iterate stored in `state`.
		return params, state, value

	def continuing_criterion(carry):
		# tol is not computed on the gradients but on the difference between current and previous likelihoods, to
		# prevent overfitting on ill-defined likelihood functions where variance can blow up.
		_, state, prev_value = carry
		iter_num = otu.tree_get(state, 'count')
		value = otu.tree_get(state, 'value')
		diff = jnp.abs(value - prev_value)
		return (iter_num == 0) | ((iter_num < max_iter) & (diff >= tol))

	# Carry: (params, optimiser state, objective at the previous iterate).
	init_carry = (init_params, opt.init(init_params), jnp.array(jnp.inf))
	final_params, final_state, _ = jax.lax.while_loop(continuing_criterion, step, init_carry)
	# The carried value is the objective *before* the last update. Report the one at `final_params`,
	# which optax's line search stores in the optimiser state (the same entry the stopping test reads).
	final_value = otu.tree_get(final_state, 'value')
	return final_params, final_state, final_value


def optimise_hyperparameters(mean_kernel, task_kernel, inputs, outputs, all_inputs, prior_mean, post_mean, post_cov,
                             mappings, nugget=jnp.array(1e-10), max_iter=100, tol=1e-3, verbose=False):
	"""Fit the mean and task kernels by minimising ``magma_neg_likelihood`` with L-BFGS.

	The two kernels are optimised independently, mean kernel first, each through ``run_opt``
	with a fresh ``optax.lbfgs()`` optimiser.

	Args:
		mean_kernel: initial kernel of the mean process.
		task_kernel: initial task kernel; its hyperparameters may be scalars or per-task arrays
			(as in the distinct-HP benchmark cells).
		inputs, outputs: task data (the benchmark cells pass the padded arrays from ``preprocess_db``).
		all_inputs: inputs on which ``post_mean``/``post_cov`` are expressed.
		prior_mean: prior mean passed to the mean-kernel likelihood.
		post_mean, post_cov: hyper-posterior mean and covariance (held fixed here).
		mappings: forwarded to ``magma_neg_likelihood`` for the task term (the benchmarks pass the masks).
		nugget: jitter forwarded to ``magma_neg_likelihood``.
		max_iter, tol: stopping parameters forwarded to ``run_opt``.
		verbose: if True, log every optimiser iteration with ``print_info``.

	Returns:
		``(new_mean_kernel, new_task_kernel, mean_llh, task_llh)``.
	"""

	def build_optimiser():
		# Optionally prefix L-BFGS with the logging transformation.
		if not verbose:
			return optax.lbfgs()
		return optax.chain(print_info(), optax.lbfgs())

	def neg_llh_mean(kern):
		return magma_neg_likelihood(kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget)

	def neg_llh_tasks(kern):
		# `.sum()` reduces the task term to a scalar objective (presumably one value per task).
		return magma_neg_likelihood(kern, inputs, outputs, post_mean, post_cov, mappings, nugget=nugget).sum()

	new_mean_kernel, _, mean_llh = run_opt(mean_kernel, neg_llh_mean, build_optimiser(), max_iter=max_iter, tol=tol)
	new_task_kernel, _, task_llh = run_opt(task_kernel, neg_llh_tasks, build_optimiser(), max_iter=max_iter, tol=tol)

	return new_mean_kernel, new_task_kernel, mean_llh, task_llh

In [7]:
# Alias the notebook implementation so the comparison cells can call "old" and "new" side by side.
optimise_hyperparameters_new = optimise_hyperparameters

---
## Comparison

In [8]:
# Jitter passed to `hyperpost` and to both optimisers (same value as `optimise_hyperparameters`' default).
nugget = jnp.array(1e-10)

### Shared input, shared HP

In [9]:
# Load this scenario's dataset and pad it into per-task arrays plus masks (shapes shown below).
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_shared_hp.csv")
padded_inputs, padded_outputs, masks, all_inputs = preprocess_db(db)
prior_mean = jnp.array(0)  # constant zero prior mean
all_inputs.shape, padded_inputs.shape

((15, 1), (20, 15, 1))

In [10]:
# Initial hyperparameters, shared by the old and new implementations.
# NOTE(review): values such as noise=-2.5 suggest unconstrained (e.g. log-scale) parameters — confirm against Kernax.
mean_kern = SEMagmaKernel(length_scale=.3, variance=1.)
task_kern = NoisySEMagmaKernel(length_scale=.6, variance=1., noise=-2.5)

In [11]:
# Hyper-posterior mean/covariance on `all_inputs` at the initial hyperparameters; held fixed during optimisation.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, prior_mean, mean_kern, task_kern, all_inputs=all_inputs, nugget=nugget)
post_mean.shape, post_cov.shape

((15,), (15, 15))

In [12]:
# Library implementation; verbose=True logs one line per iteration (mean kernel first, then task kernel).
optimized_mean_kern_old, optimized_task_kern_old, _, _ = optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 1084.5687500717868, Gradient norm: 1080.7032171959315
Iteration: 1, Value: 409.95587984144805, Gradient norm: 385.5030801587251
Iteration: 2, Value: 246.86250594196835, Gradient norm: 216.6894138247774
Iteration: 3, Value: 137.6512580287872, Gradient norm: 101.8502177865421
Iteration: 4, Value: 91.00434593361065, Gradient norm: 50.67620059246727
Iteration: 5, Value: 68.46835251070897, Gradient norm: 24.038302169580163
Iteration: 6, Value: 58.88841157718085, Gradient norm: 11.058114972926345
Iteration: 7, Value: 55.240766286669405, Gradient norm: 4.785243082447486
Iteration: 8, Value: 54.852822463406476, Gradient norm: 14.295815102158798
Iteration: 9, Value: 53.98775673645667, Gradient norm: 0.7948074113548755
Iteration: 10, Value: 53.961940867368675, Gradient norm: 0.4244808202510211
Iteration: 11, Value: 53.957169239573815, Gradient norm: 0.08596669492212126
Iteration: 0, Value: 661.1004876873047, Gradient norm: 193.50584064365333
Iteration: 1, Value: 587.45196070

In [13]:
# Optimised kernels from the library implementation.
optimized_mean_kern_old, optimized_task_kern_old

(SEMagmaKernel(length_scale=1.001548464785164, variance=6.1442103965391865),
 NoisySEMagmaKernel(length_scale=0.8098982795863472, variance=2.0507074807095433, noise=-1.746510595053677))

In [14]:
# Notebook implementation on identical inputs; its log should match the one above.
optimized_mean_kern_new, optimized_task_kern_new, _, _ = optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 1084.5687500717868, Gradient norm: 1080.7032171959315
Iteration: 1, Value: 409.95587984144805, Gradient norm: 385.5030801587251
Iteration: 2, Value: 246.86250594196835, Gradient norm: 216.6894138247774
Iteration: 3, Value: 137.6512580287872, Gradient norm: 101.8502177865421
Iteration: 4, Value: 91.00434593361065, Gradient norm: 50.67620059246727
Iteration: 5, Value: 68.46835251070897, Gradient norm: 24.038302169580163
Iteration: 6, Value: 58.88841157718085, Gradient norm: 11.058114972926345
Iteration: 7, Value: 55.240766286669405, Gradient norm: 4.785243082447486
Iteration: 8, Value: 54.852822463406476, Gradient norm: 14.295815102158798
Iteration: 9, Value: 53.98775673645667, Gradient norm: 0.7948074113548755
Iteration: 10, Value: 53.961940867368675, Gradient norm: 0.4244808202510211
Iteration: 11, Value: 53.957169239573815, Gradient norm: 0.08596669492212126
Iteration: 0, Value: 661.1004876873047, Gradient norm: 193.50584064365333
Iteration: 1, Value: 587.45196070

In [15]:
# Fail loudly if the notebook implementation diverges from the library one,
# instead of relying on comparing the printed reprs by eye.
chex.assert_trees_all_close(
    (optimized_mean_kern_new, optimized_task_kern_new),
    (optimized_mean_kern_old, optimized_task_kern_old),
)
optimized_mean_kern_new, optimized_task_kern_new

(SEMagmaKernel(length_scale=1.001548464785164, variance=6.1442103965391865),
 NoisySEMagmaKernel(length_scale=0.8098982795863472, variance=2.0507074807095433, noise=-1.746510595053677))

In [16]:
%%timeit -n 3 -r 2
optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

894 ms ± 26.1 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


In [17]:
%%timeit -n 3 -r 2
optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

868 ms ± 1.89 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


### Shared input, distinct HP

In [18]:
# Load this scenario's dataset and pad it into per-task arrays plus masks (shapes shown below).
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_distinct_hp.csv")
padded_inputs, padded_outputs, masks, all_inputs = preprocess_db(db)
prior_mean = jnp.array(0)  # constant zero prior mean
all_inputs.shape, padded_inputs.shape

((15, 1), (20, 15, 1))

In [19]:
# NOTE(review): this advances the global `key`, so re-running the cell draws different length scales.
key, subkey = jrd.split(key)
# One initial task length scale per entry of the first axis of `padded_outputs`, uniform in [0.1, 1).
distinct_length_scales = jrd.uniform(subkey, shape=(padded_outputs.shape[0],), dtype=jnp.float64, minval=.1, maxval=1)
mean_kern = SEMagmaKernel(length_scale=.3, variance=1.)
task_kern = NoisySEMagmaKernel(length_scale=distinct_length_scales, variance=1., noise=-2.5)

In [20]:
# Hyper-posterior mean/covariance on `all_inputs` at the initial hyperparameters; held fixed during optimisation.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, prior_mean, mean_kern, task_kern, all_inputs=all_inputs, nugget=nugget)
post_mean.shape, post_cov.shape

((15,), (15, 15))

In [21]:
# Library implementation; verbose=True logs one line per iteration (mean kernel first, then task kernel).
optimized_mean_kern_old, optimized_task_kern_old, _, _ = optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 2171.3401026970214, Gradient norm: 2292.932573561129
Iteration: 1, Value: 772.2235778726487, Gradient norm: 774.9208807017385
Iteration: 2, Value: 465.45374367751475, Gradient norm: 446.4539415651758
Iteration: 3, Value: 237.81603938688698, Gradient norm: 209.83364989406658
Iteration: 4, Value: 139.84666615858947, Gradient norm: 106.29635352398901
Iteration: 5, Value: 90.24385621796176, Gradient norm: 51.66078153563649
Iteration: 6, Value: 67.56695471283092, Gradient norm: 24.964188274601558
Iteration: 7, Value: 57.513897692075, Gradient norm: 11.443416048585918
Iteration: 8, Value: 53.90527695746368, Gradient norm: 9.527867511759366
Iteration: 9, Value: 53.018747284850456, Gradient norm: 4.0072953382639485
Iteration: 10, Value: 52.33267227490659, Gradient norm: 3.23567685069672
Iteration: 11, Value: 52.22491924518004, Gradient norm: 1.6050725498024332
Iteration: 12, Value: 52.21527601311167, Gradient norm: 0.15462279751671215
Iteration: 0, Value: 650.5853278758524

In [22]:
# Optimised kernels from the library implementation.
optimized_mean_kern_old, optimized_task_kern_old

(SEMagmaKernel(length_scale=1.17433086080464, variance=6.796909630639868),
 NoisySEMagmaKernel(length_scale=[ 1.214734    1.17188189  1.23820641  0.97111138  0.46351171  1.40309787
   0.70215186  1.49056998  0.97386099 -0.05430196  1.08243636  1.0596998
   1.56520916  0.82855608  0.65751824  1.46100638  1.34999146  0.72772361
   1.23905972  1.04273212], variance=2.15882736931959, noise=-2.1751331459541348))

In [23]:
# Notebook implementation on identical inputs; its log should match the one above.
optimized_mean_kern_new, optimized_task_kern_new, _, _ = optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 2171.3401026970214, Gradient norm: 2292.932573561129
Iteration: 1, Value: 772.2235778726487, Gradient norm: 774.9208807017385
Iteration: 2, Value: 465.45374367751475, Gradient norm: 446.4539415651758
Iteration: 3, Value: 237.81603938688698, Gradient norm: 209.83364989406658
Iteration: 4, Value: 139.84666615858947, Gradient norm: 106.29635352398901
Iteration: 5, Value: 90.24385621796176, Gradient norm: 51.66078153563649
Iteration: 6, Value: 67.56695471283092, Gradient norm: 24.964188274601558
Iteration: 7, Value: 57.513897692075, Gradient norm: 11.443416048585918
Iteration: 8, Value: 53.90527695746368, Gradient norm: 9.527867511759366
Iteration: 9, Value: 53.018747284850456, Gradient norm: 4.0072953382639485
Iteration: 10, Value: 52.33267227490659, Gradient norm: 3.23567685069672
Iteration: 11, Value: 52.22491924518004, Gradient norm: 1.6050725498024332
Iteration: 12, Value: 52.21527601311167, Gradient norm: 0.15462279751671215
Iteration: 0, Value: 650.5853278758524

In [24]:
# Fail loudly if the notebook implementation diverges from the library one,
# instead of relying on comparing the printed reprs by eye.
chex.assert_trees_all_close(
    (optimized_mean_kern_new, optimized_task_kern_new),
    (optimized_mean_kern_old, optimized_task_kern_old),
)
optimized_mean_kern_new, optimized_task_kern_new

(SEMagmaKernel(length_scale=1.17433086080464, variance=6.796909630639868),
 NoisySEMagmaKernel(length_scale=[ 1.214734    1.17188189  1.23820641  0.97111138  0.46351171  1.40309787
   0.70215186  1.49056998  0.97386099 -0.05430196  1.08243636  1.0596998
   1.56520916  0.82855608  0.65751824  1.46100638  1.34999146  0.72772361
   1.23905972  1.04273212], variance=2.15882736931959, noise=-2.1751331459541348))

In [25]:
%%timeit -n 3 -r 2
optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

891 ms ± 5.32 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


In [26]:
%%timeit -n 3 -r 2
optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

943 ms ± 29.1 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


### Distinct input, shared HP

In [27]:
# Load this scenario's dataset and pad it into per-task arrays plus masks (shapes shown below).
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_shared_hp.csv")
padded_inputs, padded_outputs, masks, all_inputs = preprocess_db(db)
prior_mean = jnp.array(0)  # constant zero prior mean
all_inputs.shape, padded_inputs.shape

((41, 1), (20, 15, 1))

In [28]:
# Initial hyperparameters, shared by the old and new implementations.
mean_kern = SEMagmaKernel(length_scale=.3, variance=1.)
# NOTE(review): noise starts at +2.5 here, whereas the shared-input sections start at -2.5 — confirm intended.
task_kern = NoisySEMagmaKernel(length_scale=.6, variance=1., noise=2.5)

In [29]:
# Hyper-posterior mean/covariance on `all_inputs` at the initial hyperparameters; held fixed during optimisation.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, prior_mean, mean_kern, task_kern, all_inputs=all_inputs, nugget=nugget)
post_mean.shape, post_cov.shape

((41,), (41, 41))

In [30]:
# Library implementation; verbose=True logs one line per iteration (mean kernel first, then task kernel).
optimized_mean_kern_old, optimized_task_kern_old, _, _ = optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 6093142.851951117, Gradient norm: 8763936.849021327
Iteration: 1, Value: 2340994.5359735494, Gradient norm: 2344526.408400721
Iteration: 2, Value: 1559915.7975703704, Gradient norm: 1561428.527367833
Iteration: 3, Value: 695187.8203177949, Gradient norm: 695411.7132746985
Iteration: 4, Value: 363608.7626415531, Gradient norm: 363592.6437455191
Iteration: 5, Value: 178880.26675022885, Gradient norm: 178789.94383728367
Iteration: 6, Value: 90135.26700047683, Gradient norm: 90020.03255345317
Iteration: 7, Value: 45037.24029636477, Gradient norm: 44906.257731033584
Iteration: 8, Value: 22623.340091339625, Gradient norm: 22477.855201064976
Iteration: 9, Value: 11393.875287353994, Gradient norm: 11233.68894949135
Iteration: 10, Value: 5790.255162347844, Gradient norm: 5615.32187262639
Iteration: 11, Value: 2994.134381965308, Gradient norm: 2804.6869890194375
Iteration: 12, Value: 1602.377297467958, Gradient norm: 1399.3743934217848
Iteration: 13, Value: 911.7212089540516

In [31]:
# Optimised kernels from the library implementation.
optimized_mean_kern_old, optimized_task_kern_old

(SEMagmaKernel(length_scale=0.8870253568076676, variance=15.502081965764084),
 NoisySEMagmaKernel(length_scale=0.7865613276229895, variance=-0.413673332360428, noise=14.103401509480907))

In [32]:
# Notebook implementation on identical inputs; its log should match the one above.
optimized_mean_kern_new, optimized_task_kern_new, _, _ = optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 6093142.851951117, Gradient norm: 8763936.849021327
Iteration: 1, Value: 2340994.5359735494, Gradient norm: 2344526.408400721
Iteration: 2, Value: 1559915.7975703704, Gradient norm: 1561428.527367833
Iteration: 3, Value: 695187.8203177949, Gradient norm: 695411.7132746985
Iteration: 4, Value: 363608.7626415531, Gradient norm: 363592.6437455191
Iteration: 5, Value: 178880.26675022885, Gradient norm: 178789.94383728367
Iteration: 6, Value: 90135.26700047683, Gradient norm: 90020.03255345317
Iteration: 7, Value: 45037.24029636477, Gradient norm: 44906.257731033584
Iteration: 8, Value: 22623.340091339625, Gradient norm: 22477.855201064976
Iteration: 9, Value: 11393.875287353994, Gradient norm: 11233.68894949135
Iteration: 10, Value: 5790.255162347844, Gradient norm: 5615.32187262639
Iteration: 11, Value: 2994.134381965308, Gradient norm: 2804.6869890194375
Iteration: 12, Value: 1602.377297467958, Gradient norm: 1399.3743934217848
Iteration: 13, Value: 911.7212089540516

In [33]:
# Fail loudly if the notebook implementation diverges from the library one,
# instead of relying on comparing the printed reprs by eye.
chex.assert_trees_all_close(
    (optimized_mean_kern_new, optimized_task_kern_new),
    (optimized_mean_kern_old, optimized_task_kern_old),
)
optimized_mean_kern_new, optimized_task_kern_new

(SEMagmaKernel(length_scale=0.8870253568076676, variance=15.502081965764084),
 NoisySEMagmaKernel(length_scale=0.7865613276229895, variance=-0.413673332360428, noise=14.103401509480907))

In [34]:
%%timeit -n 3 -r 2
optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

980 ms ± 3.04 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


In [35]:
%%timeit -n 3 -r 2
optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

964 ms ± 1.83 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


### Distinct input, distinct HP

In [36]:
# Load this scenario's dataset and pad it into per-task arrays plus masks (shapes shown below).
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_distinct_hp.csv")
padded_inputs, padded_outputs, masks, all_inputs = preprocess_db(db)
prior_mean = jnp.array(0)  # constant zero prior mean
all_inputs.shape, padded_inputs.shape

((41, 1), (20, 19, 1))

In [37]:
# NOTE(review): this advances the global `key`, so re-running the cell draws different length scales.
key, subkey = jrd.split(key)
# One initial task length scale per entry of the first axis of `padded_outputs`, uniform in [0.1, 1).
distinct_length_scales = jrd.uniform(subkey, shape=(padded_outputs.shape[0],), dtype=jnp.float64, minval=.1, maxval=1)
mean_kern = SEMagmaKernel(length_scale=.3, variance=1.)
# NOTE(review): noise starts at +2.5 here, whereas the shared-input sections start at -2.5 — confirm intended.
task_kern = NoisySEMagmaKernel(length_scale=distinct_length_scales, variance=1., noise=2.5)

In [38]:
# Hyper-posterior mean/covariance on `all_inputs` at the initial hyperparameters; held fixed during optimisation.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, prior_mean, mean_kern, task_kern, all_inputs=all_inputs, nugget=nugget)
post_mean.shape, post_cov.shape

((41,), (41, 41))

In [39]:
# Library implementation; verbose=True logs one line per iteration (mean kernel first, then task kernel).
optimized_mean_kern_old, optimized_task_kern_old, _, _ = optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 121256.90521232686, Gradient norm: 239445.49750500478
Iteration: 1, Value: 44906.01166080011, Gradient norm: 45107.31180058163
Iteration: 2, Value: 35170.11259415954, Gradient norm: 35221.13430678241
Iteration: 3, Value: 14885.785003174364, Gradient norm: 14812.375319465467
Iteration: 4, Value: 8052.220794125237, Gradient norm: 7966.635421411894
Iteration: 5, Value: 3988.9650448283796, Gradient norm: 3890.4109467181443
Iteration: 6, Value: 2081.5627079409305, Gradient norm: 1968.3249217039634
Iteration: 7, Value: 1109.1017308074431, Gradient norm: 980.4500332356429
Iteration: 8, Value: 633.5674137907808, Gradient norm: 490.6468542759428
Iteration: 9, Value: 400.7458609801983, Gradient norm: 245.74035566461282
Iteration: 10, Value: 289.0195878811828, Gradient norm: 126.2899979534387
Iteration: 11, Value: 234.5989267250031, Gradient norm: 73.60916045824995
Iteration: 12, Value: 128.93569313767566, Gradient norm: 101.90337244275251
Iteration: 13, Value: 126.1212994641

In [40]:
# Optimised kernels from the library implementation.
optimized_mean_kern_old, optimized_task_kern_old

(SEMagmaKernel(length_scale=0.6922244696763898, variance=11.242240419368127),
 NoisySEMagmaKernel(length_scale=[0.70265355 0.77952927 0.15812568 0.91845118 0.55242708 0.78221392
  0.56850079 0.32273545 0.27346032 0.61414702 0.13470765 0.74981194
  0.84749205 0.49536833 0.65845461 0.83339771 0.69183551 0.91259354
  0.63182436 0.40323832], variance=0.5340373895271299, noise=9.850625972686874))

In [41]:
# Notebook implementation on identical inputs; its log should match the one above.
optimized_mean_kern_new, optimized_task_kern_new, _, _ = optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget, verbose=True)

Iteration: 0, Value: 121256.90521232686, Gradient norm: 239445.49750500478
Iteration: 1, Value: 44906.01166080011, Gradient norm: 45107.31180058163
Iteration: 2, Value: 35170.11259415954, Gradient norm: 35221.13430678241
Iteration: 3, Value: 14885.785003174364, Gradient norm: 14812.375319465467
Iteration: 4, Value: 8052.220794125237, Gradient norm: 7966.635421411894
Iteration: 5, Value: 3988.9650448283796, Gradient norm: 3890.4109467181443
Iteration: 6, Value: 2081.5627079409305, Gradient norm: 1968.3249217039634
Iteration: 7, Value: 1109.1017308074431, Gradient norm: 980.4500332356429
Iteration: 8, Value: 633.5674137907808, Gradient norm: 490.6468542759428
Iteration: 9, Value: 400.7458609801983, Gradient norm: 245.74035566461282
Iteration: 10, Value: 289.0195878811828, Gradient norm: 126.2899979534387
Iteration: 11, Value: 234.5989267250031, Gradient norm: 73.60916045824995
Iteration: 12, Value: 128.93569313767566, Gradient norm: 101.90337244275251
Iteration: 13, Value: 126.1212994641

In [42]:
# Fail loudly if the notebook implementation diverges from the library one,
# instead of relying on comparing the printed reprs by eye.
chex.assert_trees_all_close(
    (optimized_mean_kern_new, optimized_task_kern_new),
    (optimized_mean_kern_old, optimized_task_kern_old),
)
optimized_mean_kern_new, optimized_task_kern_new

(SEMagmaKernel(length_scale=0.6922244696763898, variance=11.242240419368127),
 NoisySEMagmaKernel(length_scale=[0.70265355 0.77952927 0.15812568 0.91845118 0.55242708 0.78221392
  0.56850079 0.32273545 0.27346032 0.61414702 0.13470765 0.74981194
  0.84749205 0.49536833 0.65845461 0.83339771 0.69183551 0.91259354
  0.63182436 0.40323832], variance=0.5340373895271299, noise=9.850625972686874))

In [43]:
%%timeit -n 3 -r 2
optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

1.05 s ± 21.2 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


In [44]:
%%timeit -n 3 -r 2
optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

1.04 s ± 11.6 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


---
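## Conclusion

- **Identical results.** On all four scenarios (shared/distinct inputs × shared/distinct HPs), the notebook implementation reproduces the library implementation: identical per-iteration logs and identical optimised hyperparameters.
- **No consistent speed difference.** Recorded run times differ by at most about 6% (891 ms vs 943 ms in the shared-input, distinct-HP case).
- **Negative optimised values.** Some optimised values are negative (a task length scale of -0.054, a task variance of -0.41). **TODO:** confirm whether Kernax stores these as unconstrained (e.g. log-scale) parameters; if not, the optimisation is leaving the valid domain.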

---