# Benchmarks - Hyper-posterior distributions

**Main considerations when implementing hyper-post**
* ...


---
## Setup

In [1]:
# Standard library
import os

os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
#from jax.scipy.linalg import cho_factor, cho_solve
from jax import lax

import numpy as np
import pandas as pd

In [3]:
# Local
from Kernax import RBFKernel
from MagmaClustPy.utils import preprocess_db
from MagmaClustPy.linalg import map_to_full_matrix_batch, map_to_full_array_batch
from MagmaClustPy.hyperpost import hyperpost

In [4]:
# Config
key = jax.random.PRNGKey(0)
test_db_size = "medium"

---
## Data

---
## Current implementation

In [5]:
import MagmaClustPy

hyperpost_old = MagmaClustPy.hyperpost.hyperpost

---
## Custom implementation(s)

*Start by copy-pasting the original function from the MagmaClustPy module, then bring modifications*

In [None]:
hyperpost_new = hyperpost

---
## Comparison

### shared Input, shared HP

In [8]:
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_shared_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((150, 1), (200, 150, 1))

In [9]:
mean_kern = RBFKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = RBFKernel(length_scale=jnp.array(.6), variance=jnp.array(1.))

In [None]:
mean_1, cov_1 = hyperpost_old(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)

In [None]:
mean_2, cov_2 = hyperpost_new(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)

In [None]:
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

In [None]:
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

In [None]:
%%timeit -n 5 -r 5
hyperpost_old(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)[0].block_until_ready()

In [None]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)[0].block_until_ready()

### shared Input, Distinct HP

In [16]:
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_distinct_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((150, 1), (200, 150, 1))

In [17]:
mean_kern = RBFKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)
task_kern = RBFKernel(length_scale=distinct_length_scales, variance=jnp.array(1.))

In [None]:
mean_1, cov_1 = hyperpost_old(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)

In [None]:
mean_2, cov_2 = hyperpost_new(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)

In [None]:
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

In [None]:
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

In [None]:
%%timeit -n 5 -r 5
hyperpost_old(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)[0].block_until_ready()

In [None]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)[0].block_until_ready()

### Distinct Input, shared HP

In [24]:
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_shared_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((401, 1), (200, 190, 1))

In [25]:
mean_kern = RBFKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = RBFKernel(length_scale=jnp.array(.6), variance=jnp.array(1.))

In [None]:
mean_1, cov_1 = hyperpost_old(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)

In [None]:
mean_2, cov_2 = hyperpost_new(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)

In [None]:
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

In [None]:
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

In [None]:
%%timeit -n 5 -r 5
hyperpost_old(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)[0].block_until_ready()

In [None]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)[0].block_until_ready()

### Distinct Input, Distinct HP

In [32]:
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_distinct_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((401, 1), (200, 188, 1))

In [33]:
mean_kern = RBFKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)
task_kern = RBFKernel(length_scale=distinct_length_scales, variance=jnp.array(1.))

In [None]:
mean_1, cov_1 = hyperpost_old(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)

In [None]:
mean_2, cov_2 = hyperpost_new(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)

In [None]:
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

In [None]:
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

In [None]:
%%timeit -n 5 -r 5
hyperpost_old(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)[0].block_until_ready()

In [None]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)[0].block_until_ready()

### Using a custom grid

In [40]:
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_shared_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

grid = jnp.linspace(jnp.min(all_inputs, axis=1), jnp.max(all_inputs, axis=1), 500)
all_inputs.shape, padded_inputs.shape, grid.shape

((401, 1), (200, 190, 1), (500, 401))

In [41]:
mean_kern = RBFKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = RBFKernel(length_scale=jnp.array(.6), variance=jnp.array(1.))

In [None]:
mean_1, cov_1 = hyperpost_old(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)

In [None]:
mean_2, cov_2 = hyperpost_new(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)

In [None]:
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

In [None]:
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

In [None]:
%%timeit -n 5 -r 5
hyperpost_old(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)[0].block_until_ready()

In [None]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs, padded_outputs, mappings, all_inputs, jnp.array(0.), mean_kern, task_kern)[0].block_until_ready()

---
## Conclusion

---