# Benchmarks - Hyper-posterior distributions

**Main considerations when implementing the hyper-posterior (`hyperpost`)**
* ...


---
## Setup

In [1]:
# Standard library
import os

# Request 64-bit precision from JAX. Set before `jax` is imported so the flag
# is read at JAX start-up.
os.environ["JAX_ENABLE_X64"] = "True"

In [2]:
# Third party
import jax
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax import lax

import numpy as np
import pandas as pd

In [3]:
# Local
from MagmaClustPy.kernels import RBFKernel
from MagmaClustPy.utils import generate_dummy_db, preprocess_db
from MagmaClustPy.hyperpost import hyperpost

In [4]:
# Config
# Root PRNG key; later cells split it to draw random hyper-parameters.
key = jax.random.PRNGKey(seed=0)
# Size prefix of the dummy dataset CSV files loaded below.
test_db_size = "medium"

---
## Data

---
## Current implementation

---
## Custom implementation(s)

In [5]:
# Placeholder for a candidate implementation: `_hyperpost` is currently a plain
# alias of the library `hyperpost`, so every comparison below checks the
# function against itself (identical outputs; timing differences are noise).
_hyperpost = hyperpost

---
## Comparison

### Common Input, Common HP

In [6]:
# Load the "common input, common HP" dummy dataset and preprocess it into the
# arrays consumed by `hyperpost`.
db = pd.read_csv(
    f"./dummy_datasets/{test_db_size}_common_input_common_hp.csv"
)
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)

# Show the shapes of the input grid and of the padded inputs.
(all_inputs.shape, padded_inputs.shape)

((200,), (200, 200))

In [7]:
# RBF kernels handed to `hyperpost`, both with scalar hyper-parameters
# (the task kernel's values are therefore common to every task).
mean_kern = RBFKernel(
    length_scale=jnp.array(0.3),
    variance=jnp.array(1.0),
)
task_kern = RBFKernel(
    length_scale=jnp.array(0.6),
    variance=jnp.array(1.0),
)

In [8]:
# Reference run: the library `hyperpost`, whose two return values are unpacked
# as a mean and a covariance.
# NOTE(review): the 4th positional argument (0.) is presumably the prior mean —
# confirm against `hyperpost`'s signature.
mean_1, cov_1 = hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [9]:
# Candidate run: same arguments as the reference call, through `_hyperpost`.
mean_2, cov_2 = _hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [10]:
# Do both runs agree (within `jnp.allclose` default tolerances) on mean and covariance?
tuple(jnp.allclose(ref, new) for ref, new in ((mean_1, mean_2), (cov_1, cov_2)))

(Array(True, dtype=bool), Array(True, dtype=bool))

In [11]:
# Mean absolute difference between the two runs, for the mean and the covariance.
jnp.abs(mean_1 - mean_2).mean(), jnp.abs(cov_1 - cov_2).mean()

(Array(0., dtype=float64), Array(0., dtype=float64))

In [12]:
%%timeit -n 5 -r 5
# Time the reference implementation. `block_until_ready()` waits for JAX's
# asynchronously dispatched computation, so the work itself is measured.
hyperpost(
    padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern,
    all_inputs=all_inputs,
)[0].block_until_ready()

292 ms ± 10.3 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [13]:
%%timeit -n 5 -r 5
# Time the candidate implementation. `block_until_ready()` waits for JAX's
# asynchronously dispatched computation, so the work itself is measured.
_hyperpost(
    padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern,
    all_inputs=all_inputs,
)[0].block_until_ready()

283 ms ± 3.97 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Common Input, Distinct HP

In [14]:
# Load the "common input, distinct HP" dummy dataset and preprocess it into the
# arrays consumed by `hyperpost`.
db = pd.read_csv(
    f"./dummy_datasets/{test_db_size}_common_input_distinct_hp.csv"
)
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)

# Show the shapes of the input grid and of the padded inputs.
(all_inputs.shape, padded_inputs.shape)

((200,), (200, 200))

In [15]:
# The mean kernel keeps scalar hyper-parameters.
mean_kern = RBFKernel(
    length_scale=jnp.array(0.3),
    variance=jnp.array(1.0),
)

# Draw one length scale per row of `padded_outputs` (presumably one row per
# task), uniformly in [0.1, 1).
# NOTE: `key` is rebound here, so re-running this cell alone changes the draws.
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(
    subkey,
    shape=(padded_outputs.shape[0],),
    dtype=jnp.float64,
    minval=.1,
    maxval=1,
)
task_kern = RBFKernel(
    length_scale=distinct_length_scales,
    variance=jnp.array(1.0),
)

In [16]:
# Reference run with per-task length scales in `task_kern`.
mean_1, cov_1 = hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [17]:
# Candidate run with the same arguments as the reference call.
mean_2, cov_2 = _hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [18]:
# Do both runs agree (within `jnp.allclose` default tolerances) on mean and covariance?
tuple(jnp.allclose(ref, new) for ref, new in ((mean_1, mean_2), (cov_1, cov_2)))

(Array(True, dtype=bool), Array(True, dtype=bool))

In [19]:
# Mean absolute difference between the two runs, for the mean and the covariance.
jnp.abs(mean_1 - mean_2).mean(), jnp.abs(cov_1 - cov_2).mean()

(Array(0., dtype=float64), Array(0., dtype=float64))

In [20]:
%%timeit -n 5 -r 5
# Time the reference implementation; `block_until_ready()` makes the timing
# wait for the asynchronously dispatched JAX computation.
hyperpost(
    padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern,
    all_inputs=all_inputs,
)[0].block_until_ready()

325 ms ± 18.5 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [21]:
%%timeit -n 5 -r 5
# Time the candidate implementation; `block_until_ready()` makes the timing
# wait for the asynchronously dispatched JAX computation.
_hyperpost(
    padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern,
    all_inputs=all_inputs,
)[0].block_until_ready()

302 ms ± 21.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Distinct Input, Common HP

In [22]:
# Load the "distinct input, common HP" dummy dataset and preprocess it into the
# arrays consumed by `hyperpost`.
db = pd.read_csv(
    f"./dummy_datasets/{test_db_size}_distinct_input_common_hp.csv"
)
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)

# Show the shapes of the input grid and of the padded inputs.
(all_inputs.shape, padded_inputs.shape)

((401,), (200, 401))

In [23]:
# RBF kernels handed to `hyperpost`, both with scalar hyper-parameters
# (the task kernel's values are therefore common to every task).
mean_kern = RBFKernel(
    length_scale=jnp.array(0.3),
    variance=jnp.array(1.0),
)
task_kern = RBFKernel(
    length_scale=jnp.array(0.6),
    variance=jnp.array(1.0),
)

In [24]:
# Reference run on the distinct-input dataset.
mean_1, cov_1 = hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [25]:
# Candidate run with the same arguments as the reference call.
mean_2, cov_2 = _hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [26]:
# Do both runs agree (within `jnp.allclose` default tolerances) on mean and covariance?
tuple(jnp.allclose(ref, new) for ref, new in ((mean_1, mean_2), (cov_1, cov_2)))

(Array(True, dtype=bool), Array(True, dtype=bool))

In [27]:
# Mean absolute difference between the two runs, for the mean and the covariance.
jnp.abs(mean_1 - mean_2).mean(), jnp.abs(cov_1 - cov_2).mean()

(Array(0., dtype=float64), Array(0., dtype=float64))

In [28]:
%%timeit -n 5 -r 5
# Time the reference implementation; `block_until_ready()` makes the timing
# wait for the asynchronously dispatched JAX computation.
hyperpost(
    padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern,
    all_inputs=all_inputs,
)[0].block_until_ready()

880 ms ± 23.3 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [29]:
%%timeit -n 5 -r 5
# Time the candidate implementation; `block_until_ready()` makes the timing
# wait for the asynchronously dispatched JAX computation.
_hyperpost(
    padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern,
    all_inputs=all_inputs,
)[0].block_until_ready()

857 ms ± 18.5 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Distinct Input, Distinct HP

In [30]:
# Load the "distinct input, distinct HP" dummy dataset and preprocess it into
# the arrays consumed by `hyperpost`.
db = pd.read_csv(
    f"./dummy_datasets/{test_db_size}_distinct_input_distinct_hp.csv"
)
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)

# Show the shapes of the input grid and of the padded inputs.
(all_inputs.shape, padded_inputs.shape)

((401,), (200, 401))

In [31]:
# The mean kernel keeps scalar hyper-parameters.
mean_kern = RBFKernel(
    length_scale=jnp.array(0.3),
    variance=jnp.array(1.0),
)

# Draw one length scale per row of `padded_outputs` (presumably one row per
# task), uniformly in [0.1, 1).
# NOTE: `key` is rebound here, so re-running this cell alone changes the draws.
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(
    subkey,
    shape=(padded_outputs.shape[0],),
    dtype=jnp.float64,
    minval=.1,
    maxval=1,
)
task_kern = RBFKernel(
    length_scale=distinct_length_scales,
    variance=jnp.array(1.0),
)

In [32]:
# Reference run on distinct inputs with per-task length scales.
mean_1, cov_1 = hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [33]:
# Candidate run with the same arguments as the reference call.
mean_2, cov_2 = _hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
)

In [34]:
# Do both runs agree (within `jnp.allclose` default tolerances) on mean and covariance?
tuple(jnp.allclose(ref, new) for ref, new in ((mean_1, mean_2), (cov_1, cov_2)))

(Array(True, dtype=bool), Array(True, dtype=bool))

In [35]:
# Mean absolute difference between the two runs, for the mean and the covariance.
jnp.abs(mean_1 - mean_2).mean(), jnp.abs(cov_1 - cov_2).mean()

(Array(0., dtype=float64), Array(0., dtype=float64))

In [36]:
%%timeit -n 5 -r 5
# Time the reference implementation; `block_until_ready()` makes the timing
# wait for the asynchronously dispatched JAX computation.
hyperpost(
    padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern,
    all_inputs=all_inputs,
)[0].block_until_ready()

867 ms ± 22.8 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [37]:
%%timeit -n 5 -r 5
# Time the candidate implementation; `block_until_ready()` makes the timing
# wait for the asynchronously dispatched JAX computation.
_hyperpost(
    padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern,
    all_inputs=all_inputs,
)[0].block_until_ready()

847 ms ± 17.6 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Using a custom grid

In [38]:
# Reload the "distinct input, common HP" dummy dataset for the custom-grid
# benchmark and preprocess it into the arrays consumed by `hyperpost`.
db = pd.read_csv(f"./dummy_datasets/{test_db_size}_distinct_input_common_hp.csv")
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)

# Evenly spaced 500-point grid spanning the observed input range.
# Use array reductions rather than the Python builtins `min`/`max`, which
# iterate over the array one element at a time; the resulting bounds are the
# same for this 1-D array.
grid = jnp.linspace(jnp.min(all_inputs), jnp.max(all_inputs), 500)
all_inputs.shape, padded_inputs.shape, grid.shape

((401,), (200, 401), (500,))

In [39]:
# RBF kernels handed to `hyperpost`, both with scalar hyper-parameters.
mean_kern = RBFKernel(
    length_scale=jnp.array(0.3),
    variance=jnp.array(1.0),
)
task_kern = RBFKernel(
    length_scale=jnp.array(0.6),
    variance=jnp.array(1.0),
)

In [40]:
# Reference run, additionally passing the custom `grid`.
mean_1, cov_1 = hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
    grid=grid,
)

In [41]:
# Candidate run with the same arguments (including the custom `grid`).
mean_2, cov_2 = _hyperpost(
    padded_inputs,
    padded_outputs,
    masks,
    jnp.array(0.),
    mean_kern,
    task_kern,
    all_inputs=all_inputs,
    grid=grid,
)

In [42]:
# Do both runs agree (within `jnp.allclose` default tolerances) on mean and covariance?
tuple(jnp.allclose(ref, new) for ref, new in ((mean_1, mean_2), (cov_1, cov_2)))

(Array(True, dtype=bool), Array(True, dtype=bool))

In [43]:
# Mean absolute difference between the two runs, for the mean and the covariance.
jnp.abs(mean_1 - mean_2).mean(), jnp.abs(cov_1 - cov_2).mean()

(Array(0., dtype=float64), Array(0., dtype=float64))

In [44]:
%%timeit -n 5 -r 5
# Time the reference implementation on the custom grid; `block_until_ready()`
# makes the timing wait for the asynchronously dispatched JAX computation.
hyperpost(
    padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern,
    all_inputs=all_inputs,
    grid=grid,
)[0].block_until_ready()

986 ms ± 91.4 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [45]:
%%timeit -n 5 -r 5
# Time the candidate implementation on the custom grid; `block_until_ready()`
# makes the timing wait for the asynchronously dispatched JAX computation.
_hyperpost(
    padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern,
    all_inputs=all_inputs,
    grid=grid,
)[0].block_until_ready()

974 ms ± 26.5 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


---
## Conclusion

---