# Benchmarks – Hyper-parameter optimisation

**Main considerations when implementing hyper-parameter (HP) optimisation**
- Kernels are JAX pytrees, so we can compute gradients with respect to them and optimise their hyper-parameters directly.


---
## Setup

In [1]:
# Standard library
import os
from typing import NamedTuple

# Turn on float64 support in JAX. JAX reads this flag when it is imported, which happens in the next cell.
os.environ.update(JAX_ENABLE_X64="True")

In [2]:
# Third party
import jax
import jax.numpy as jnp
import jax.random as jrd
from jax.tree_util import tree_flatten
import optax
import optax.tree_utils as otu
import chex

import numpy as np
import pandas as pd

In [3]:
# Local
from Kernax import SEMagmaKernel, NoisySEMagmaKernel
from MagmaClustPy.utils import preprocess_db
from MagmaClustPy.hyperpost import hyperpost
from MagmaClustPy.hp_optimisation import optimise_hyperparameters
from MagmaClustPy.likelihoods import magma_neg_likelihood

In [4]:
# Config
test_db_size = "small"  # size prefix of the dummy CSV datasets loaded below
key = jrd.PRNGKey(0)    # fixed root seed; split before each random draw

---
## Data

---
## Current implementation

---
## Custom implementation(s)

---
## Comparison

In [5]:
# Small jitter passed as `nugget` to every optimise_hyperparameters call below.
# NOTE(review): presumably added to covariance diagonals for numerical stability — confirm.
nugget = jnp.array(1e-10)

### Shared input, shared HP

In [6]:
# Load the "shared input, shared HP" dummy dataset.
csv_path = f"./dummy_datasets/{test_db_size}_shared_input_shared_hp.csv"
db = pd.read_csv(csv_path)
# Unpack what preprocess_db returns: all inputs, padded inputs/outputs and their masks.
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
# Zero prior mean, defined on all_inputs.
prior_mean = jnp.zeros_like(all_inputs)
all_inputs.shape, padded_inputs.shape

((15,), (20, 15))

In [7]:
# Starting hyper-parameters; optimise_hyperparameters refines them further down.
# NOTE(review): the negative noise suggests Kernax stores it on a log scale — confirm.
mean_kern = SEMagmaKernel(
    length_scale=jnp.array(.3),
    variance=jnp.array(1.),
)
task_kern = NoisySEMagmaKernel(
    length_scale=jnp.array(.6),
    variance=jnp.array(1.),
    noise=jnp.array(-2.5),
)

In [8]:
# Hyper-posterior mean and covariance on all_inputs, given the starting kernels.
post_mean, post_cov = hyperpost(
    padded_inputs, padded_outputs, masks, prior_mean,
    mean_kern, task_kern,
    all_inputs=all_inputs,
)

In [9]:
# Inspect the hyper-posterior mean (one value per entry of all_inputs).
post_mean

Array([ 35.48962234,  32.96258847,  14.53163018,   7.39134076,
         6.9727986 ,   2.33313928,  -0.79272859,   2.13320892,
        -2.40420656,  -6.34208034, -37.30466657, -36.48716408,
       -39.22947049, -44.80626263, -44.6835497 ], dtype=float64)

In [10]:
# Inspect the hyper-posterior covariance as a NumPy array (NumPy's repr is easier to read for a matrix).
np.asarray(post_cov)

array([[ 1.33106999e-01,  1.20663638e-01,  2.28073151e-02,
         4.33130755e-04,  8.57652668e-05,  1.89168920e-06,
        -4.27880897e-06,  6.85963538e-07, -3.88982663e-06,
         1.99620969e-06,  3.04845942e-07, -1.27455522e-08,
        -1.24341070e-08,  2.65360842e-09, -2.72407120e-09],
       [ 1.20663638e-01,  1.33092938e-01,  4.25316404e-02,
         1.48362128e-03,  4.26979313e-04,  5.77622317e-06,
        -3.15875361e-06,  5.24790251e-07, -2.93921597e-06,
         1.50860499e-06,  2.30332115e-07, -9.63157142e-09,
        -9.39626608e-09,  2.00529397e-09, -2.05854155e-09],
       [ 2.28073151e-02,  4.25316404e-02,  1.33002674e-01,
         4.25175890e-02,  2.27987636e-02,  1.48400879e-03,
         1.20509378e-04, -7.93658534e-07,  4.11481840e-06,
        -2.11379794e-06, -3.21586773e-07,  1.34797439e-08,
         1.31515068e-08, -2.80676765e-09,  2.88129400e-09],
       [ 4.33130755e-04,  1.48362128e-03,  4.25175890e-02,
         1.32977462e-01,  1.20598371e-01,  4.25616111

In [11]:
# Kernels are registered pytrees: flattening exposes their hyper-parameters as leaves,
# which is what makes gradient-based optimisation of the kernel object possible.
tree_flatten(task_kern)

([Array(0.6, dtype=float64, weak_type=True),
  Array(1., dtype=float64, weak_type=True),
  Array(-2.5, dtype=float64, weak_type=True)],
 PyTreeDef(CustomNode(NoisySEMagmaKernel[None], [*, *, *])))

In [12]:
# Same hyper-parameters viewed by attribute name, to match them with the flattened leaves above.
task_kern.__dict__

{'length_scale': Array(0.6, dtype=float64, weak_type=True),
 'variance': Array(1., dtype=float64, weak_type=True),
 'noise': Array(-2.5, dtype=float64, weak_type=True)}

In [13]:
# One optimisation pass over both kernels, using the hyper-posterior computed above.
new_mean_kern, new_task_kern = optimise_hyperparameters(
    mean_kern, task_kern,
    padded_inputs, padded_outputs, all_inputs,
    prior_mean, post_mean, post_cov, masks,
    nugget=nugget,
    verbose=True,
)

Iteration: 0, Value: 1084.5687542988799, Gradient norm: 1080.70322094892
Iteration: 1, Value: 409.9558815385026, Gradient norm: 385.50308175436663
Iteration: 2, Value: 246.8625068327388, Gradient norm: 216.68941466303235
Iteration: 3, Value: 137.6512584822056, Gradient norm: 101.85021819665475
Iteration: 4, Value: 91.00434617848073, Gradient norm: 50.67620079745729
Iteration: 5, Value: 68.4683526489665, Gradient norm: 24.038302273306
Iteration: 6, Value: 58.88841165956168, Gradient norm: 11.05811502791275
Iteration: 7, Value: 55.24076634040253, Gradient norm: 4.785243101089721
Iteration: 8, Value: 54.85282246326992, Gradient norm: 14.295814514839375
Iteration: 9, Value: 53.98775677225567, Gradient norm: 0.794807427567745
Iteration: 10, Value: 53.96194090517091, Gradient norm: 0.42448082778252616
Iteration: 11, Value: 53.957169277751184, Gradient norm: 0.08596668675362666
Iteration: 0, Value: 661.1005156615454, Gradient norm: 193.50585235358955
Iteration: 1, Value: 587.4519748282615, Gr

In [14]:
# Optimised kernels (compare against the starting values set above).
new_mean_kern, new_task_kern

(SEMagmaKernel(length_scale=1.00154846620888, variance=6.144210403793922),
 NoisySEMagmaKernel(length_scale=0.8098983971429901, variance=2.0507074907476484, noise=-1.7465099667483082))

In [15]:
%%timeit -n 3 -r 2
optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

3.8 s ± 528 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


### Shared input, distinct HP

In [16]:
# Load the "shared input, distinct HP" dummy dataset.
csv_path = f"./dummy_datasets/{test_db_size}_shared_input_distinct_hp.csv"
db = pd.read_csv(csv_path)
# Unpack what preprocess_db returns: all inputs, padded inputs/outputs and their masks.
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
# Zero prior mean, defined on all_inputs.
prior_mean = jnp.zeros_like(all_inputs)
all_inputs.shape, padded_inputs.shape

((15,), (20, 15))

In [35]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

# One starting length scale per row of padded_outputs (presumably one per task), drawn uniformly in [0.1, 1).
# NOTE(review): re-running this cell advances `key`, so the draws differ on re-run.
key, subkey = jrd.split(key)
distinct_length_scales = jrd.uniform(
    subkey,
    shape=(padded_outputs.shape[0],),
    dtype=jnp.float64,
    minval=.1,
    maxval=1,
)
task_kern = NoisySEMagmaKernel(length_scale=distinct_length_scales, variance=jnp.array(1.), noise=jnp.array(-2.5))

In [36]:
# Sanity check: one length scale per row of padded_outputs.
distinct_length_scales.shape

(20,)

In [37]:
# Hyper-posterior mean and covariance on all_inputs, given the starting kernels.
post_mean, post_cov = hyperpost(
    padded_inputs, padded_outputs, masks, prior_mean,
    mean_kern, task_kern,
    all_inputs=all_inputs,
)

In [38]:
# One optimisation pass over both kernels, using the hyper-posterior computed above.
new_mean_kern, new_task_kern = optimise_hyperparameters(
    mean_kern, task_kern,
    padded_inputs, padded_outputs, all_inputs,
    prior_mean, post_mean, post_cov, masks,
    nugget=nugget,
    verbose=True,
)

Iteration: 0, Value: 626.7167066691293, Gradient norm: 706.0380411925394
Iteration: 1, Value: 451.806311523143, Gradient norm: 543.7230623461476
Iteration: 2, Value: 147.07110873131407, Gradient norm: 232.69181618348313
Iteration: 3, Value: 64.83777429432344, Gradient norm: 113.89407899490469
Iteration: 4, Value: 27.28139245658874, Gradient norm: 225.25012557657044
Iteration: 5, Value: 7.7382040165459545, Gradient norm: 30.940870358614667
Iteration: 6, Value: 1.8335457629512169, Gradient norm: 20.842558323537606
Iteration: 7, Value: 0.13444766232710137, Gradient norm: 2.698937705762504
Iteration: 8, Value: -0.06728720345797257, Gradient norm: 3.296152944075638
Iteration: 9, Value: -0.08372363766263291, Gradient norm: 0.6710941541660963
Iteration: 10, Value: -0.08473203042976607, Gradient norm: 0.6142705295891342
Iteration: 0, Value: 789.7604152108095, Gradient norm: 224.5022628356022
Iteration: 1, Value: 687.9245227152583, Gradient norm: 20.95826515519451


In [39]:
# Optimised kernels; the task kernel keeps one length scale per row of padded_outputs.
new_mean_kern, new_task_kern

(SEMagmaKernel(length_scale=0.5054967309781664, variance=4.907526076005412),
 NoisySEMagmaKernel(length_scale=[ 0.11174793 -0.2183297   0.61641025  0.34626178  0.6645803   0.81149112
   0.5203386   0.25153783  0.73876548  0.0906732   1.35127933  0.37090433
   0.36570312  0.42206205  1.45831781  0.32074166  1.2330959   0.49874442
   0.56897336  0.84063259], variance=1.932440115089407, noise=-2.349786151816508))

In [22]:
%%timeit -n 3 -r 2
optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

3.72 s ± 785 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


### Distinct input, shared HP

In [23]:
# Load the "distinct input, shared HP" dummy dataset.
csv_path = f"./dummy_datasets/{test_db_size}_distinct_input_shared_hp.csv"
db = pd.read_csv(csv_path)
# Unpack what preprocess_db returns: all inputs, padded inputs/outputs and their masks.
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
# Zero prior mean, defined on all_inputs.
prior_mean = jnp.zeros_like(all_inputs)
all_inputs.shape, padded_inputs.shape

((41,), (20, 41))

In [24]:
# Starting hyper-parameters for this scenario.
# NOTE(review): the task kernel here is the noiseless SEMagmaKernel, unlike the
# NoisySEMagmaKernel used in the shared-input sections — confirm this is intended.
mean_kern = SEMagmaKernel(
    length_scale=jnp.array(.3),
    variance=jnp.array(1.),
)
task_kern = SEMagmaKernel(
    length_scale=jnp.array(.6),
    variance=jnp.array(1.),
)

In [25]:
# Hyper-posterior mean and covariance on all_inputs, given the starting kernels.
post_mean, post_cov = hyperpost(
    padded_inputs, padded_outputs, masks, prior_mean,
    mean_kern, task_kern,
    all_inputs=all_inputs,
)

In [26]:
# One optimisation pass over both kernels, using the hyper-posterior computed above.
new_mean_kern, new_task_kern = optimise_hyperparameters(
    mean_kern, task_kern,
    padded_inputs, padded_outputs, all_inputs,
    prior_mean, post_mean, post_cov, masks,
    nugget=nugget,
    verbose=True,
)

Iteration: 0, Value: 93883.35546043565, Gradient norm: 1194745.0537562363
Iteration: 1, Value: 2974.744120614934, Gradient norm: 3148.214376785104
Iteration: 2, Value: 2491.7251499679633, Gradient norm: 2609.0515343296424
Iteration: 3, Value: 1116.2495461511437, Gradient norm: 1082.403136794949
Iteration: 4, Value: 764.0029201857294, Gradient norm: 1105.0290129108764
Iteration: 5, Value: 570.921255489159, Gradient norm: 1403.87102684516
Iteration: 6, Value: 356.6130936636686, Gradient norm: 992.0110973415899
Iteration: 7, Value: 205.05306836019986, Gradient norm: 461.03370646928903
Iteration: 8, Value: 142.8731738210606, Gradient norm: 214.79124111175432
Iteration: 9, Value: 115.70298059020287, Gradient norm: 85.1437965271968
Iteration: 10, Value: 106.25762591722116, Gradient norm: 22.86426903099017
Iteration: 11, Value: 103.7650586126552, Gradient norm: 8.736323760266655
Iteration: 12, Value: 102.95825498057987, Gradient norm: 1.301709460177779
Iteration: 13, Value: 102.8991340876219,

In [27]:
# Optimised kernels (compare against the starting values set above).
new_mean_kern, new_task_kern

(SEMagmaKernel(length_scale=-0.0071476759712833186, variance=6.639570750925357),
 SEMagmaKernel(length_scale=-0.5745746785013888, variance=2.9418057340078563))

In [28]:
%%timeit -n 3 -r 2
optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

1.55 s ± 27.7 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)


### Distinct input, distinct HP

In [29]:
# Load the "distinct input, distinct HP" dummy dataset.
csv_path = f"./dummy_datasets/{test_db_size}_distinct_input_distinct_hp.csv"
db = pd.read_csv(csv_path)
# Unpack what preprocess_db returns: all inputs, padded inputs/outputs and their masks.
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
# Zero prior mean, defined on all_inputs.
prior_mean = jnp.zeros_like(all_inputs)
all_inputs.shape, padded_inputs.shape

((41,), (20, 41))

In [30]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

# One starting length scale per row of padded_outputs (presumably one per task), drawn uniformly in [0.1, 1).
# NOTE(review): re-running this cell advances `key`, so the draws differ on re-run.
key, subkey = jrd.split(key)
distinct_length_scales = jrd.uniform(
    subkey,
    shape=(padded_outputs.shape[0],),
    dtype=jnp.float64,
    minval=.1,
    maxval=1,
)
task_kern = SEMagmaKernel(length_scale=distinct_length_scales, variance=jnp.array(1.))

In [31]:
# Hyper-posterior mean and covariance on all_inputs, given the starting kernels.
post_mean, post_cov = hyperpost(
    padded_inputs, padded_outputs, masks, prior_mean,
    mean_kern, task_kern,
    all_inputs=all_inputs,
)

In [32]:
# One optimisation pass over both kernels, using the hyper-posterior computed above.
new_mean_kern, new_task_kern = optimise_hyperparameters(
    mean_kern, task_kern,
    padded_inputs, padded_outputs, all_inputs,
    prior_mean, post_mean, post_cov, masks,
    nugget=nugget,
    verbose=True,
)

Iteration: 0, Value: 63693.24263182963, Gradient norm: 935204.7971765217
Iteration: 1, Value: 936.9165952740885, Gradient norm: 930.1424022688172
Iteration: 2, Value: 826.1892973711264, Gradient norm: 813.0332957764756
Iteration: 3, Value: 361.23238370298935, Gradient norm: 317.8693494976895
Iteration: 4, Value: 233.44345265802076, Gradient norm: 208.48996186308386
Iteration: 5, Value: 168.2633939143924, Gradient norm: 282.21037786145416
Iteration: 6, Value: 141.5093737836419, Gradient norm: 340.7081490306301
Iteration: 7, Value: 110.34916051831858, Gradient norm: 240.78517808066476
Iteration: 8, Value: 90.96996170282848, Gradient norm: 113.10981083458589
Iteration: 9, Value: 85.10452211994607, Gradient norm: 47.36631658124324
Iteration: 10, Value: 83.81965938398909, Gradient norm: 16.27514127248603
Iteration: 11, Value: 83.69133665039448, Gradient norm: 5.6383183429130925
Iteration: 12, Value: 83.67553022693143, Gradient norm: 1.120123616884129
Iteration: 0, Value: 470621.1104742168, 

In [33]:
# Optimised kernels; the task kernel keeps one length scale per row of padded_outputs.
new_mean_kern, new_task_kern

(SEMagmaKernel(length_scale=-0.0567281240906174, variance=5.454397379716924),
 SEMagmaKernel(length_scale=[-0.39953694 -0.16537672 -0.37609432 -0.18648492 -0.02030922 -0.07931304
   0.11920396 -0.17766618 -0.20457625 -0.03514375  0.28585494 -0.38560367
  -0.36277171 -0.72086503 -0.40399476  0.00614847 -0.58879221  0.14658601
  -0.23616103 -0.15323145], variance=3.009735760523688))

In [34]:
%%timeit -n 3 -r 2
optimise_hyperparameters(mean_kern, task_kern, padded_inputs, padded_outputs, all_inputs, prior_mean, post_mean, post_cov, masks, nugget=nugget)[0].length_scale.block_until_ready()

1.71 s ± 861 μs per loop (mean ± std. dev. of 2 runs, 3 loops each)


---
## Conclusion

---