# Benchmarks - Hyper-parameters optimisation

**Main considerations when implementing HPs optimisation**
- we made kernels pytrees, so we should be able to compute gradients with respect to their hyper-parameters and optimise them directly


---
## Setup

In [1]:
# Standard library
import os

# Request 64-bit floats in JAX. This is set here, in its own cell, so that it
# is already in the environment when `jax` is imported in the next cell.
os.environ["JAX_ENABLE_X64"] = "True"

In [2]:
# Third party
import jax
import jax.numpy as jnp
import jax.random as jrd

import numpy as np
import pandas as pd

In [3]:
# Local
from MagmaClustPy.kernels import RBFKernel, SEMagmaKernel
from MagmaClustPy.utils import preprocess_db
from MagmaClustPy.hyperpost import hyperpost
from MagmaClustPy.hp_optimisation import optimise_hyperparameters

In [4]:
# Config
# Seed of the root PRNG key; the key is split before each random draw below.
SEED = 0
key = jrd.PRNGKey(SEED)

# Size prefix of the CSV files read from ./dummy_datasets/ in every section.
test_db_size = "medium"

---
## Data

---
## Current implementation

---
## Custom implementation(s)

---
## Comparison

In [5]:
# Prior mean passed to `optimise_hyperparameters`. A float literal is used so
# that it has the same dtype as the `jnp.array(0.)` prior mean given to
# `hyperpost` in the cells below (`jnp.array(0)` would create an integer array).
prior_mean = jnp.array(0.)

# Passed to `optimise_hyperparameters` as `nugget`.
# NOTE(review): presumably a small jitter added to covariance diagonals for
# numerical stability — confirm in MagmaClustPy.hp_optimisation.
nugget = jnp.array(1e-10)

### Common Input, Common HP

In [6]:
# Read the dataset for this scenario and run it through `preprocess_db`.
dataset_path = f"./dummy_datasets/{test_db_size}_common_input_common_hp.csv"
db = pd.read_csv(dataset_path)
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)

# Display the shapes as a quick sanity check.
all_inputs.shape, padded_inputs.shape

((150,), (200, 150))

In [7]:
# Initial kernels: same variance, different length scales (0.3 vs 0.6).
# The task kernel gets a single scalar length scale ("Common HP" scenario).
mean_kern = SEMagmaKernel(
    length_scale=jnp.array(0.3),
    variance=jnp.array(1.0),
)
task_kern = SEMagmaKernel(
    length_scale=jnp.array(0.6),
    variance=jnp.array(1.0),
)

In [8]:
# Posterior mean / covariance returned by `hyperpost` for the initial kernels,
# with the prior mean fixed to 0.
post_mean, post_cov = hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [9]:
# First call: also triggers compilation (see the XLA constant-folding log in
# the saved output), so it is kept separate from the timed call further down.
new_mean_kern, new_task_kern = optimise_hyperparameters(
    mean_kern,
    task_kern,
    padded_inputs,
    padded_outputs,
    all_inputs,
    prior_mean,
    post_mean,
    post_cov,
    masks,
    nugget=nugget,
)

2025-03-12 14:09:43.880663: E external/xla/xla/service/slow_operation_alarm.cc:73] Constant folding an instruction is taking > 1s:

  %reduce-window.8 = f64[200,5,5]{2,1,0} reduce-window(f64[200,150,150]{2,1,0} %broadcast.1530, f64[] %constant.255), window={size=1x32x32 stride=1x32x32 pad=0_0x5_5x5_5}, to_apply=%region_6.576

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2025-03-12 14:09:43.881199: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.00559s
Constant folding an instruction is taking > 1s:

  %reduce-window.8 = f64[200,5,5]{2,1,0} reduce-window(f64[200,150,150]{2,1,0} %broadcast.1530, f64[] %constant

In [10]:
# jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

In [11]:
# jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

In [12]:
%%timeit -n 3 -r 2
# Block on the whole returned pytree so the timing covers every optimised
# hyper-parameter of both kernels, not only the mean kernel's `length_scale`.
jax.block_until_ready(
    optimise_hyperparameters(
        mean_kern,
        task_kern,
        padded_inputs,
        padded_outputs,
        all_inputs,
        prior_mean,
        post_mean,
        post_cov,
        masks,
        nugget=nugget,
    )
)

35.7 s ± 4.77 s per loop (mean ± std. dev. of 2 runs, 3 loops each)


In [13]:
# %%timeit -n 5 -r 5
# _hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

### Common Input, Distinct HP

In [14]:
# Read the dataset for this scenario and run it through `preprocess_db`.
dataset_path = f"./dummy_datasets/{test_db_size}_common_input_distinct_hp.csv"
db = pd.read_csv(dataset_path)
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)

# Display the shapes as a quick sanity check.
all_inputs.shape, padded_inputs.shape

((150,), (200, 150))

In [15]:
mean_kern = SEMagmaKernel(
    length_scale=jnp.array(0.3),
    variance=jnp.array(1.0),
)

# Draw one task length scale per row of `padded_outputs`, uniformly in
# [0.1, 1), from a fresh sub-key so the root key is never reused.
key, subkey = jrd.split(key)
distinct_length_scales = jrd.uniform(
    subkey,
    shape=(padded_outputs.shape[0],),
    dtype=jnp.float64,
    minval=0.1,
    maxval=1.0,
)
task_kern = SEMagmaKernel(
    length_scale=distinct_length_scales,
    variance=jnp.array(1.0),
)

In [19]:
# Sanity check: expect one length scale per row of `padded_outputs`.
distinct_length_scales.shape

(200,)

In [16]:
# Posterior mean / covariance returned by `hyperpost` for the initial kernels,
# with the prior mean fixed to 0.
post_mean, post_cov = hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [17]:
# NOTE(review): the saved output of this cell is a ValueError ("Parameter
# length_scale must be a scalar or a 1D array, got shape (10, 200)"), i.e. the
# optimiser currently fails when the task kernel has one length scale per row.
# The cells below this one in the section therefore cannot have run after it.
new_mean_kern, new_task_kern = optimise_hyperparameters(
    mean_kern,
    task_kern,
    padded_inputs,
    padded_outputs,
    all_inputs,
    prior_mean,
    post_mean,
    post_cov,
    masks,
    nugget=nugget,
)

ValueError: Parameter length_scale must be a scalar or a 1D array, got shape (10, 200).

In [10]:
# jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [11]:
# jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(0., dtype=float64), Array(0., dtype=float64))

In [12]:
%%timeit -n 3 -r 2
# Block on the whole returned pytree so the timing covers every optimised
# hyper-parameter of both kernels, not only the mean kernel's `length_scale`.
jax.block_until_ready(
    optimise_hyperparameters(
        mean_kern,
        task_kern,
        padded_inputs,
        padded_outputs,
        all_inputs,
        prior_mean,
        post_mean,
        post_cov,
        masks,
        nugget=nugget,
    )
)

292 ms ± 10.3 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [None]:
# %%timeit -n 5 -r 5
# _hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

### Distinct Input, Common HP

In [22]:
# Read the dataset for this scenario and run it through `preprocess_db`.
dataset_path = f"./dummy_datasets/{test_db_size}_distinct_input_common_hp.csv"
db = pd.read_csv(dataset_path)
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)

# Display the shapes as a quick sanity check.
all_inputs.shape, padded_inputs.shape

((401,), (200, 401))

In [23]:
# Initial kernels: same variance, different length scales (0.3 vs 0.6).
# The task kernel gets a single scalar length scale ("Common HP" scenario).
mean_kern = SEMagmaKernel(
    length_scale=jnp.array(0.3),
    variance=jnp.array(1.0),
)
task_kern = SEMagmaKernel(
    length_scale=jnp.array(0.6),
    variance=jnp.array(1.0),
)

In [24]:
# Posterior mean / covariance returned by `hyperpost` for the initial kernels,
# with the prior mean fixed to 0.
post_mean, post_cov = hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [9]:
# First (untimed) call for this scenario; the timed call follows further down.
new_mean_kern, new_task_kern = optimise_hyperparameters(
    mean_kern,
    task_kern,
    padded_inputs,
    padded_outputs,
    all_inputs,
    prior_mean,
    post_mean,
    post_cov,
    masks,
    nugget=nugget,
)

In [10]:
# jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [11]:
# jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(0., dtype=float64), Array(0., dtype=float64))

In [12]:
%%timeit -n 3 -r 2
# Block on the whole returned pytree so the timing covers every optimised
# hyper-parameter of both kernels, not only the mean kernel's `length_scale`.
jax.block_until_ready(
    optimise_hyperparameters(
        mean_kern,
        task_kern,
        padded_inputs,
        padded_outputs,
        all_inputs,
        prior_mean,
        post_mean,
        post_cov,
        masks,
        nugget=nugget,
    )
)

292 ms ± 10.3 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [None]:
# %%timeit -n 5 -r 5
# _hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

### Distinct Input, Distinct HP

In [30]:
# Read the dataset for this scenario and run it through `preprocess_db`.
dataset_path = f"./dummy_datasets/{test_db_size}_distinct_input_distinct_hp.csv"
db = pd.read_csv(dataset_path)
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)

# Display the shapes as a quick sanity check.
all_inputs.shape, padded_inputs.shape

((401,), (200, 401))

In [31]:
mean_kern = SEMagmaKernel(
    length_scale=jnp.array(0.3),
    variance=jnp.array(1.0),
)

# Draw one task length scale per row of `padded_outputs`, uniformly in
# [0.1, 1), from a fresh sub-key so the root key is never reused.
key, subkey = jrd.split(key)
distinct_length_scales = jrd.uniform(
    subkey,
    shape=(padded_outputs.shape[0],),
    dtype=jnp.float64,
    minval=0.1,
    maxval=1.0,
)
task_kern = SEMagmaKernel(
    length_scale=distinct_length_scales,
    variance=jnp.array(1.0),
)

In [32]:
# Posterior mean / covariance returned by `hyperpost` for the initial kernels,
# with the prior mean fixed to 0.
post_mean, post_cov = hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [9]:
# First (untimed) call for this scenario; the timed call follows further down.
new_mean_kern, new_task_kern = optimise_hyperparameters(
    mean_kern,
    task_kern,
    padded_inputs,
    padded_outputs,
    all_inputs,
    prior_mean,
    post_mean,
    post_cov,
    masks,
    nugget=nugget,
)

In [10]:
# jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [11]:
# jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(0., dtype=float64), Array(0., dtype=float64))

In [12]:
%%timeit -n 3 -r 2
# Block on the whole returned pytree so the timing covers every optimised
# hyper-parameter of both kernels, not only the mean kernel's `length_scale`.
jax.block_until_ready(
    optimise_hyperparameters(
        mean_kern,
        task_kern,
        padded_inputs,
        padded_outputs,
        all_inputs,
        prior_mean,
        post_mean,
        post_cov,
        masks,
        nugget=nugget,
    )
)

292 ms ± 10.3 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [None]:
# %%timeit -n 5 -r 5
# _hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

---
## Conclusion

---