# Benchmarks - Hyper-parameter optimisation

**Main considerations when implementing HP optimisation**
- Kernels are now pytrees, so we should be able to compute gradients with respect to them and optimise them directly.


---
## Setup

In [None]:
# Standard library
import os
from typing import NamedTuple

# Request 64-bit precision from JAX. This is set here, before `jax` is imported
# in the next cell, so the flag is already in place when JAX starts up.
# Later cells rely on it (they draw values with dtype jnp.float64).
os.environ['JAX_ENABLE_X64'] = "True"

In [None]:
# Third party
import jax
import jax.numpy as jnp
import jax.random as jrd
from jax.tree_util import tree_flatten
import optax
import optax.tree_utils as otu
import chex

import numpy as np
import pandas as pd

In [None]:
# Local
from Kernax import SEMagmaKernel, DiagKernel, ExpKernel
from MagmaClustPy.utils import preprocess_db
from MagmaClustPy.hyperpost import hyperpost

In [None]:
# Config
# Base PRNG key; it is split in the "Distinct HP" sections before drawing random values.
key = jrd.PRNGKey(0)
# Size tag interpolated into the dataset file names ("../datasets/{test_db_size}_*.csv").
test_db_size = "small"

---
## Data

---
## Current implementation

In [None]:
# Keep a handle on the reference implementation shipped with MagmaClustPy so it
# can be compared against the custom implementation defined further down.
# The submodule is imported explicitly: a bare `import MagmaClustPy` only exposes
# `hp_optimisation` as an attribute if something else has already imported it.
# This form still binds the top-level name `MagmaClustPy`.
import MagmaClustPy.hp_optimisation
optimise_hyperparameters_old = MagmaClustPy.hp_optimisation.optimise_hyperparameters

---
## Custom implementation(s)

*Start by copy-pasting the original function from the MagmaClustPy module, then make your modifications.*

In [None]:
# Alias the custom implementation under the name used by the comparison cells below.
# NOTE(review): `optimise_hyperparameters` is not defined in any visible cell. It must be
# defined above (e.g. pasted from MagmaClustPy and then modified) before this cell is run,
# otherwise this line raises NameError.
optimise_hyperparameters_new = optimise_hyperparameters

---
## Comparison

In [None]:
# Small constant. No other visible cell references it.
# NOTE(review): presumably consumed by the custom implementation pasted above — confirm or remove.
jitter = jnp.array(1e-10)

### Shared Input, Shared HP

In [None]:
# Load the dataset selected by `test_db_size` and preprocess it into the arrays
# consumed by `hyperpost` and by both optimisers.
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_shared_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
prior_mean = jnp.array(0)  # scalar prior mean, passed unchanged to every call below
all_inputs.shape, padded_inputs.shape  # trailing expression: displayed as the cell output

In [None]:
# Initial kernels, used as the starting point by both implementations.
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.)) + DiagKernel(ExpKernel(jnp.array(2.5)))

In [None]:
# Posterior mean and covariance obtained from `hyperpost` with the initial kernels;
# both optimisers receive the same values.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, mean_kern, task_kern)
post_mean.shape, post_cov.shape

In [None]:
# Reference implementation (verbose run). The last two returned values are discarded.
optimized_mean_kern_old, optimized_task_kern_old, _, _ = optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov, verbose=True)

In [None]:
# Display the kernels returned by the reference implementation.
optimized_mean_kern_old, optimized_task_kern_old

In [None]:
# Custom implementation (verbose run), called with exactly the same arguments.
optimized_mean_kern_new, optimized_task_kern_new, _, _ = optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov, verbose=True)

In [None]:
# Display the kernels returned by the custom implementation, to compare with the ones above.
optimized_mean_kern_new, optimized_task_kern_new

In [None]:
%%timeit -n 3 -r 2
# Timing of the reference implementation (3 loops per run, 2 runs).
# `[0]` is the optimised mean kernel; `block_until_ready()` waits for the computation
# behind its length scale to finish, so the measurement includes the actual compute.
optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov)[0].length_scale.block_until_ready()

In [None]:
%%timeit -n 3 -r 2
# Timing of the custom implementation, with the same settings and arguments as above.
optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov)[0].length_scale.block_until_ready()

### Shared Input, Distinct HP

In [None]:
# Load the dataset selected by `test_db_size` and preprocess it into the arrays
# consumed by `hyperpost` and by both optimisers.
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_distinct_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
prior_mean = jnp.array(0)  # scalar prior mean, passed unchanged to every call below
all_inputs.shape, padded_inputs.shape  # trailing expression: displayed as the cell output

In [None]:
# Initial kernels, used as the starting point by both implementations.
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

# Draw one length scale per row of `padded_outputs`, uniformly in [.1, 1).
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)
# NOTE(review): `distinct_length_scales` is drawn above but never passed to `task_kern`,
# which is built with the single length scale .3 exactly as in the shared-HP sections.
# Confirm whether this benchmark is meant to start from distinct hyper-parameters.
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.)) + DiagKernel(ExpKernel(jnp.array(2.5)))

In [None]:
# Posterior mean and covariance obtained from `hyperpost` with the initial kernels;
# both optimisers receive the same values.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, mean_kern, task_kern)
post_mean.shape, post_cov.shape

In [None]:
# Reference implementation (verbose run). The last two returned values are discarded.
optimized_mean_kern_old, optimized_task_kern_old, _, _ = optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov, verbose=True)

In [None]:
# Display the kernels returned by the reference implementation.
optimized_mean_kern_old, optimized_task_kern_old

In [None]:
# Custom implementation (verbose run), called with exactly the same arguments.
optimized_mean_kern_new, optimized_task_kern_new, _, _ = optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov, verbose=True)

In [None]:
# Display the kernels returned by the custom implementation, to compare with the ones above.
optimized_mean_kern_new, optimized_task_kern_new

In [None]:
%%timeit -n 3 -r 2
# Timing of the reference implementation (3 loops per run, 2 runs).
# `[0]` is the optimised mean kernel; `block_until_ready()` waits for the computation
# behind its length scale to finish, so the measurement includes the actual compute.
optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov)[0].length_scale.block_until_ready()

In [None]:
%%timeit -n 3 -r 2
# Timing of the custom implementation, with the same settings and arguments as above.
optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov)[0].length_scale.block_until_ready()

### Distinct Input, Shared HP

In [None]:
# Load the dataset selected by `test_db_size` and preprocess it into the arrays
# consumed by `hyperpost` and by both optimisers.
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_shared_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
prior_mean = jnp.array(0)  # scalar prior mean, passed unchanged to every call below
all_inputs.shape, padded_inputs.shape  # trailing expression: displayed as the cell output

In [None]:
# Initial kernels, used as the starting point by both implementations.
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.)) + DiagKernel(ExpKernel(jnp.array(2.5)))

In [None]:
# Posterior mean and covariance obtained from `hyperpost` with the initial kernels;
# both optimisers receive the same values.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, mean_kern, task_kern)
post_mean.shape, post_cov.shape

In [None]:
# Reference implementation (verbose run). The last two returned values are discarded.
optimized_mean_kern_old, optimized_task_kern_old, _, _ = optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov, verbose=True)

In [None]:
# Display the kernels returned by the reference implementation.
optimized_mean_kern_old, optimized_task_kern_old

In [None]:
# Custom implementation (verbose run), called with exactly the same arguments.
optimized_mean_kern_new, optimized_task_kern_new, _, _ = optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov, verbose=True)

In [None]:
# Display the kernels returned by the custom implementation, to compare with the ones above.
optimized_mean_kern_new, optimized_task_kern_new

In [None]:
%%timeit -n 3 -r 2
# Timing of the reference implementation (3 loops per run, 2 runs).
# `[0]` is the optimised mean kernel; `block_until_ready()` waits for the computation
# behind its length scale to finish, so the measurement includes the actual compute.
optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov)[0].length_scale.block_until_ready()

In [None]:
%%timeit -n 3 -r 2
# Timing of the custom implementation, with the same settings and arguments as above.
optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov)[0].length_scale.block_until_ready()

### Distinct Input, Distinct HP

In [None]:
# Load the dataset selected by `test_db_size` and preprocess it into the arrays
# consumed by `hyperpost` and by both optimisers.
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_distinct_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
prior_mean = jnp.array(0)  # scalar prior mean, passed unchanged to every call below
all_inputs.shape, padded_inputs.shape  # trailing expression: displayed as the cell output

In [None]:
# Initial kernels, used as the starting point by both implementations.
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

# Draw one length scale per row of `padded_outputs`, uniformly in [.1, 1).
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)
# NOTE(review): `distinct_length_scales` is drawn above but never passed to `task_kern`,
# which is built with the single length scale .3 exactly as in the shared-HP sections.
# Confirm whether this benchmark is meant to start from distinct hyper-parameters.
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.)) + DiagKernel(ExpKernel(jnp.array(2.5)))

In [None]:
# Posterior mean and covariance obtained from `hyperpost` with the initial kernels;
# both optimisers receive the same values.
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, mean_kern, task_kern)
post_mean.shape, post_cov.shape

In [None]:
# Reference implementation (verbose run). The last two returned values are discarded.
optimized_mean_kern_old, optimized_task_kern_old, _, _ = optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov, verbose=True)

In [None]:
# Display the kernels returned by the reference implementation.
optimized_mean_kern_old, optimized_task_kern_old

In [None]:
# Custom implementation (verbose run), called with exactly the same arguments.
optimized_mean_kern_new, optimized_task_kern_new, _, _ = optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov, verbose=True)

In [None]:
# Display the kernels returned by the custom implementation, to compare with the ones above.
optimized_mean_kern_new, optimized_task_kern_new

In [None]:
%%timeit -n 3 -r 2
# Timing of the reference implementation (3 loops per run, 2 runs).
# `[0]` is the optimised mean kernel; `block_until_ready()` waits for the computation
# behind its length scale to finish, so the measurement includes the actual compute.
optimise_hyperparameters_old(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov)[0].length_scale.block_until_ready()

In [None]:
%%timeit -n 3 -r 2
# Timing of the custom implementation, with the same settings and arguments as above.
optimise_hyperparameters_new(mean_kern, task_kern, padded_inputs, padded_outputs, mappings, all_inputs, prior_mean, post_mean, post_cov)[0].length_scale.block_until_ready()

---
## Conclusion

---