# Benchmarks - Hyper-posterior distributions

**Main considerations when implementing hyper-post**
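* **Shared vs. distinct inputs**: when every task is observed on the same inputs, no padding is needed. Otherwise task covariances are padded and mapped onto the union of all inputs.
* **Shared vs. distinct hyperparameters**: with shared hyperparameters, a single task covariance (and a single Cholesky factorisation) serves every task. With distinct ones, each task needs its own.
* **Output grid**: the hyper-posterior can be computed either on the union of all inputs or on a custom grid that extends it.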


---
## Setup

In [46]:
# Standard library
import os

# Enable 64-bit floats in JAX. This is set before `jax` is imported (next cell) so that JAX picks it up
# when its configuration is initialised; the float64 dtypes used below depend on it.
os.environ['JAX_ENABLE_X64'] = "True"

In [47]:
# Third party
import jax
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax import lax

import numpy as np
import pandas as pd

In [48]:
# Local
from MagmaClustPy.kernels import RBFKernel
from MagmaClustPy.utils import generate_dummy_db, preprocess_db, map_to_full_matrix, map_to_full_matrix_batch, map_to_full_array_batch
from MagmaClustPy.legacy import _legacy_preprocess_db
from MagmaClustPy.hyperpost import hyperpost

In [49]:
# Config
test_db_size = "medium"  # size prefix of the CSV files read from ../dummy_datasets/
key = jax.random.PRNGKey(0)  # root PRNG key, split whenever random hyperparameters are drawn

---
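## Data

Each scenario in the **Comparison** section loads its own dummy dataset from `../dummy_datasets/`. The dataset size is set by `test_db_size` in the config cell.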

---
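## Current implementation

The reference implementation is `MagmaClustPy.hyperpost.hyperpost`. It is imported in the setup cells and called as `hyperpost(...)` in the comparisons below.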

---
## Custom implementation(s)

In [50]:
from jax import jit
from jax import numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.tree_util import tree_flatten

@jit
def hyperpost_shared_input_shared_hp(outputs, prior_mean, mean_cov_u, mean_cov_inv, task_cov, inputs_to_grid=None,
                                     nugget=jnp.array(1e-10)):
	"""
	Hyper-posterior of the mean process when all tasks share the same inputs and the same task hyperparameters.

	Every task has the same covariance matrix. The sum of the task precisions is therefore the number of tasks times a
	single precision, and the sum of the weighted task outputs needs only one solve on the summed outputs.

	:param outputs: task outputs, shape (M, N) for M tasks observed on the same N inputs
	:param prior_mean: prior mean of the mean process on the grid, shape (G, )
	:param mean_cov_u: upper Cholesky factor of the mean-process covariance on the grid, shape (G, G)
	:param mean_cov_inv: inverse of the mean-process covariance on the grid, shape (G, G)
	:param task_cov: covariance shared by all tasks, shape (N, N)
	:param inputs_to_grid: optional indices of the N task inputs in the grid. If None, the grid is the task inputs
	(G == N)
	:param nugget: jitter added to the diagonal of `task_cov` before its Cholesky factorisation
	:return: a 2-tuple (post_mean, post_cov) of shapes (G, ) and (G, G)
	"""
	task_eye = jnp.eye(task_cov.shape[-1])
	# The posterior lives on the grid, which is larger than the task inputs when `inputs_to_grid` is given
	grid_eye = jnp.eye(mean_cov_inv.shape[-1])

	# Cholesky factor and inverse of the shared task covariance
	task_cov_u, _ = cho_factor(task_cov + task_eye * nugget)
	task_cov_inv = cho_solve((task_cov_u, False), task_eye)

	if inputs_to_grid is not None:
		# Embed the (N, N) precision into the (G, G) grid; unobserved grid points get zero task precision
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	# All tasks share the same precision, so the sum over tasks is a single multiplication by M
	post_prec_u, _ = cho_factor(mean_cov_inv + outputs.shape[0] * task_cov_inv)
	post_cov = cho_solve((post_prec_u, False), grid_eye)

	# Posterior mean: (K_mean^-1 + M K_task^-1)^-1 (K_mean^-1 m_0 + K_task^-1 sum_i y_i)
	weighted_prior_mean = cho_solve((mean_cov_u, False), prior_mean)
	weighted_tasks = cho_solve((task_cov_u, False), outputs.sum(axis=0))

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_prec_u, False), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov


@jit
def hyperpost_shared_input_distinct_hp(outputs, prior_mean, mean_cov_u, mean_cov_inv, task_covs, inputs_to_grid=None,
                                       nugget=jnp.array(1e-10)):
	"""
	Hyper-posterior of the mean process when all tasks share the same inputs but have their own hyperparameters.

	Each task has its own covariance matrix. The task precisions and the precision-weighted outputs are computed with
	one batched Cholesky factorisation, then summed over tasks.

	:param outputs: task outputs, shape (M, N) for M tasks observed on the same N inputs
	:param prior_mean: prior mean of the mean process on the grid, shape (G, )
	:param mean_cov_u: upper Cholesky factor of the mean-process covariance on the grid, shape (G, G)
	:param mean_cov_inv: inverse of the mean-process covariance on the grid, shape (G, G)
	:param task_covs: one covariance per task, shape (M, N, N)
	:param inputs_to_grid: optional indices of the N task inputs in the grid. If None, the grid is the task inputs
	(G == N)
	:param nugget: jitter added to the diagonal of every task covariance before its Cholesky factorisation
	:return: a 2-tuple (post_mean, post_cov) of shapes (G, ) and (G, G)
	"""
	task_eyes = jnp.broadcast_to(jnp.eye(task_covs.shape[-1]), task_covs.shape)
	# The posterior lives on the grid, which is larger than the task inputs when `inputs_to_grid` is given
	grid_eye = jnp.eye(mean_cov_inv.shape[-1])

	# Batched Cholesky factors (one per task) and the sum of the task precisions
	task_covs_u, _ = cho_factor(task_covs + task_eyes * nugget)
	task_cov_inv = cho_solve((task_covs_u, False), task_eyes).sum(axis=0)

	if inputs_to_grid is not None:
		# Embed the (N, N) summed precision into the (G, G) grid; unobserved grid points get zero task precision
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	post_prec_u, _ = cho_factor(mean_cov_inv + task_cov_inv)
	post_cov = cho_solve((post_prec_u, False), grid_eye)

	# Posterior mean: (K_mean^-1 + sum_i K_i^-1)^-1 (K_mean^-1 m_0 + sum_i K_i^-1 y_i)
	weighted_prior_mean = cho_solve((mean_cov_u, False), prior_mean)
	# Batched solve of each task covariance against that task's outputs, then summed over tasks
	weighted_tasks = cho_solve((task_covs_u, False), outputs).sum(axis=0)

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_prec_u, False), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov


@jit
def hyperpost_distinct_input(outputs, mappings, prior_mean, mean_cov_u, mean_cov_inv, task_covs, inputs_to_grid=None,
                             nugget=jnp.array(1e-10), all_inputs=None):
	"""
	Computes the hyper-posterior of the mean process when tasks are observed on distinct inputs.

	The per-task covariances are NaN-padded to a common size. The padding is temporarily replaced by identity
	rows/columns so that every matrix can be Cholesky-factorised. Per-task quantities are then scattered onto the union
	of all inputs (`map_to_full_matrix_batch` / `map_to_full_array_batch`) and summed over tasks.

	:param outputs: padded task outputs, shape (M, N)
	:param mappings: per-task mapping of padded positions onto `all_inputs` (the notebook passes the `mappings`
	returned by `preprocess_db`)
	:param prior_mean: prior mean of the mean process on the grid, shape (G, )
	:param mean_cov_u: upper Cholesky factor of the mean-process covariance on the grid, shape (G, G)
	:param mean_cov_inv: inverse of the mean-process covariance on the grid, shape (G, G)
	:param task_covs: (M, N, N), batch of unaligned, NaN-padded task covariances
	:param inputs_to_grid: optional indices of `all_inputs` in the grid. If None, the grid is `all_inputs`
	:param nugget: jitter added to the diagonal of each task covariance before its Cholesky factorisation
	:param all_inputs: union of all task inputs, shape (P, ). Required. It is an explicit argument rather than a
	notebook global: `jit` bakes globals into the compiled function at trace time, so a later call with same-shaped
	arguments would silently reuse a stale value.
	:return: a 2-tuple (post_mean, post_cov) of shapes (G, ) and (G, G)
	:raises ValueError: if `all_inputs` is not provided
	"""
	if all_inputs is None:
		raise ValueError("hyperpost_distinct_input requires `all_inputs` (the union of all task inputs)")

	pad_eye = jnp.broadcast_to(jnp.eye(task_covs.shape[-1]), task_covs.shape)
	full_eye = jnp.eye(all_inputs.shape[-1])
	pad_mask = jnp.isnan(task_covs)

	# task_covs is padded with NaNs. Replace them by their corresponding identity rows/cols so each matrix stays
	# factorisable
	eyed_task_covs = jnp.where(pad_mask, pad_eye, task_covs)

	# Posterior covariance
	task_covs_u, _ = cho_factor(eyed_task_covs + pad_eye * nugget)
	task_covs_inv = cho_solve((task_covs_u, False), pad_eye)
	# Correction on the diagonal: the padded identity block inverts to (roughly) identity, which is removed here
	task_covs_inv -= jnp.where(pad_mask, pad_eye, 0)
	task_covs_inv = map_to_full_matrix_batch(task_covs_inv, all_inputs, mappings)
	# Entries a task does not observe are NaN after mapping; they contribute nothing to the sum
	task_cov_inv = jnp.nan_to_num(task_covs_inv).sum(axis=0)

	if inputs_to_grid is not None:
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	post_prec_u, _ = cho_factor(mean_cov_inv + task_cov_inv)
	post_cov = cho_solve((post_prec_u, False), jnp.eye(mean_cov_u.shape[-1]))

	# Posterior mean
	weighted_prior_mean = cho_solve((mean_cov_u, False), prior_mean)
	mapped_outputs = jnp.nan_to_num(map_to_full_array_batch(outputs, all_inputs, mappings))
	padded_task_covs_u = map_to_full_matrix_batch(task_covs_u, all_inputs, mappings)
	eyed_task_covs_u = jnp.where(jnp.isnan(padded_task_covs_u), full_eye, padded_task_covs_u)
	weighted_tasks = cho_solve((eyed_task_covs_u, False), mapped_outputs).sum(axis=0)

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_prec_u, False), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov


# General function
def hyperpost_new(inputs, outputs, masks, prior_mean, mean_kernel, task_kernel, all_inputs=None, grid=None,
              nugget=jnp.array(1e-10)):
	"""
	Computes the posterior mean and covariance of a Magma GP given the inputs, outputs, masks, prior mean and kernels.

	Dispatches to a specialised implementation depending on whether tasks share their inputs and/or their
	hyperparameters. Providing a grid always uses the distinct-input path, because covariances must then be padded to
	the full grid.

	:param inputs: the preprocessed (padded and aligned) inputs
	:param outputs: the preprocessed outputs
	:param masks: per-task masks/mappings of the padded inputs (the notebook passes the `mappings` returned by
	`preprocess_db`). Only forwarded to `hyperpost_distinct_input`, i.e. used when inputs are not shared
	:param prior_mean: the prior mean, as a scalar or a vector of shape (N, ), where N is the length of the union of all
	inputs and the grid
	:param mean_kernel: kernel of the mean process, with hyperparameters loaded as attributes
	:param task_kernel: kernel of the task process, with hyperparameters loaded as attributes
	:param all_inputs: all distinct inputs. If not provided, it will be computed from the inputs
	:param grid: the grid on which the GP is defined. If not provided, the GP is defined on all distinct inputs
	:param nugget: nugget term to ensure numerical stability. Default is 1e-10
	:return: a 2-tuple of the posterior mean and covariance
	"""
	# Task hyperparameters are shared when every leaf of the kernel pytree is a scalar
	shared_hp = all(jnp.ndim(hp) == 0 for hp in tree_flatten(task_kernel)[0])

	if all_inputs is None:
		# Union of the observed inputs. Padding entries are dropped: the task covariances built from these inputs are
		# NaN-padded, so the padding values are presumably NaN as well
		flat_inputs = inputs.ravel()
		all_inputs = jnp.unique(flat_inputs[~jnp.isnan(flat_inputs)])  # jnp.unique returns sorted values

	# A task row as long as the union of inputs means every task is observed on every input
	shared_input = len(inputs[0]) == len(all_inputs)

	if grid is None:
		grid = all_inputs
		inputs_to_grid = None
	else:
		# Merge inputs and grid (sorted, without duplicates)
		grid = jnp.unique(jnp.concatenate([all_inputs, grid]))
		inputs_to_grid = jnp.searchsorted(grid, all_inputs)
		shared_input = False  # We need to pad the cov matrices to compute on the full grid

	prior_mean = jnp.asarray(prior_mean)
	if prior_mean.ndim == 0:
		prior_mean = jnp.broadcast_to(prior_mean, (len(grid),))

	# Mean covariance on the grid, its Cholesky factor and its inverse; the nugget is for numerical stability
	eye = jnp.eye(grid.shape[0])
	mean_cov = mean_kernel(grid, grid)
	mean_cov_u, _ = cho_factor(mean_cov + eye * nugget)
	mean_cov_inv = cho_solve((mean_cov_u, False), eye)

	if not shared_input:
		# No shared input: we have to pad and mask
		task_covs = task_kernel(inputs)
		return hyperpost_distinct_input(outputs, masks, prior_mean, mean_cov_u, mean_cov_inv, task_covs,
		                                inputs_to_grid, nugget, all_inputs=all_inputs)

	if shared_hp:
		# A single task covariance serves every task
		task_cov = task_kernel(grid)  # Shape: (N, N)
		return hyperpost_shared_input_shared_hp(outputs, prior_mean, mean_cov_u, mean_cov_inv, task_cov,
		                                        inputs_to_grid, nugget)

	# Distinct HPs: we have to compute every task covariance but no padding is required
	task_covs = task_kernel(inputs)  # Shape: (M, N, N)
	return hyperpost_shared_input_distinct_hp(outputs, prior_mean, mean_cov_u, mean_cov_inv, task_covs,
	                                          inputs_to_grid, nugget)

---
## Comparison

### Shared input, shared HP

In [51]:
# Each implementation gets its own preprocessing: legacy (padded inputs + masks) vs new (padded inputs + mappings)
db = pd.read_csv(f"../dummy_datasets/{test_db_size}_shared_input_shared_hp.csv")
all_inputs, padded_inputs, padded_outputs, masks = _legacy_preprocess_db(db)
all_inputs_new, padded_inputs_new, padded_outputs_new, mappings = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((150,), (200, 150))

In [53]:
# Scalar hyperparameters: one task kernel shared by every task
mean_kern = RBFKernel(length_scale=.3, variance=1.)
task_kern = RBFKernel(length_scale=.6, variance=1.)

In [54]:
# Reference: legacy implementation
mean_1, cov_1 = hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [55]:
# Candidate: new implementation
mean_2, cov_2 = hyperpost_new(padded_inputs_new, padded_outputs_new, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs_new)

In [56]:
# Do both implementations agree on the posterior mean and covariance?
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [57]:
# Mean absolute difference between the two implementations
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(0., dtype=float64), Array(0., dtype=float64))

In [58]:
%%timeit -n 5 -r 5
# block_until_ready() waits for JAX's asynchronous dispatch, so the whole computation is timed
hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

1.94 ms ± 1.11 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [59]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs_new, padded_outputs_new, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs_new)[0].block_until_ready()

1.31 ms ± 133 μs per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Shared input, distinct HP

In [60]:
# Each implementation gets its own preprocessing: legacy (padded inputs + masks) vs new (padded inputs + mappings)
db = pd.read_csv(f"../dummy_datasets/{test_db_size}_shared_input_distinct_hp.csv")
all_inputs, padded_inputs, padded_outputs, masks = _legacy_preprocess_db(db)
all_inputs_new, padded_inputs_new, padded_outputs_new, mappings = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((150,), (200, 150))

In [61]:
mean_kern = RBFKernel(length_scale=.3, variance=1.)

# Distinct hyperparameters: one task length scale per task, drawn uniformly in [0.1, 1)
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)
task_kern = RBFKernel(length_scale=distinct_length_scales, variance=1.)

In [62]:
# Reference: legacy implementation
mean_1, cov_1 = hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [63]:
# Candidate: new implementation
mean_2, cov_2 = hyperpost_new(padded_inputs_new, padded_outputs_new, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs_new)

In [64]:
# Do both implementations agree on the posterior mean and covariance?
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [65]:
# Mean absolute difference between the two implementations
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(0., dtype=float64), Array(0., dtype=float64))

In [66]:
%%timeit -n 5 -r 5
# block_until_ready() waits for JAX's asynchronous dispatch, so the whole computation is timed
hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

47.5 ms ± 785 μs per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [67]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs_new, padded_outputs_new, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs_new)[0].block_until_ready()

46.7 ms ± 242 μs per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Distinct input, shared HP

In [83]:
# NOTE(review): execution count [83] shows this cell was re-run after the cells below; Restart & Run All to confirm their outputs
db = pd.read_csv(f"../dummy_datasets/{test_db_size}_distinct_input_shared_hp.csv")
all_inputs, padded_inputs, padded_outputs, masks = _legacy_preprocess_db(db)
all_inputs_new, padded_inputs_new, padded_outputs_new, mappings = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((401,), (200, 401))

In [69]:
# Scalar hyperparameters: one task kernel shared by every task
mean_kern = RBFKernel(length_scale=.3, variance=1.)
task_kern = RBFKernel(length_scale=.6, variance=1.)

In [70]:
# Reference: legacy implementation
mean_1, cov_1 = hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [71]:
# Candidate: new implementation
mean_2, cov_2 = hyperpost_new(padded_inputs_new, padded_outputs_new, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs_new)

In [72]:
# Do both implementations agree on the posterior mean and covariance?
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [73]:
# Mean absolute difference between the two implementations
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(1.26776596e-08, dtype=float64), Array(7.77561785e-15, dtype=float64))

In [74]:
%%timeit -n 5 -r 5
# block_until_ready() waits for JAX's asynchronous dispatch, so the whole computation is timed
hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

400 ms ± 87.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [75]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs_new, padded_outputs_new, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs_new)[0].block_until_ready()

187 ms ± 47.7 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Distinct input, distinct HP

In [76]:
# Each implementation gets its own preprocessing: legacy (padded inputs + masks) vs new (padded inputs + mappings)
db = pd.read_csv(f"../dummy_datasets/{test_db_size}_distinct_input_distinct_hp.csv")
all_inputs, padded_inputs, padded_outputs, masks = _legacy_preprocess_db(db)
all_inputs_new, padded_inputs_new, padded_outputs_new, mappings = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((401,), (200, 401))

In [77]:
mean_kern = RBFKernel(length_scale=.3, variance=1.)

# Distinct hyperparameters: one task length scale per task, drawn uniformly in [0.1, 1)
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)
task_kern = RBFKernel(length_scale=distinct_length_scales, variance=1.)

In [78]:
# Reference: legacy implementation
mean_1, cov_1 = hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [79]:
# Candidate: new implementation
mean_2, cov_2 = hyperpost_new(padded_inputs_new, padded_outputs_new, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs_new)

In [80]:
# Do both implementations agree on the posterior mean and covariance?
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [81]:
# Mean absolute difference between the two implementations
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(3.58775768e-09, dtype=float64), Array(6.2736897e-15, dtype=float64))

In [82]:
%%timeit -n 5 -r 5
# NOTE(review): this run was interrupted (KeyboardInterrupt below), so no legacy timing is recorded for this scenario
hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

KeyboardInterrupt: 

In [37]:
%%timeit -n 5 -r 5
# NOTE(review): execution count [37] predates the cells above; re-run top to bottom to refresh this timing
hyperpost_new(padded_inputs_new, padded_outputs_new, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs_new)[0].block_until_ready()

138 ms ± 2.96 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Using a custom grid

In [38]:
# Same dataset as the "Distinct input, shared HP" scenario, evaluated on a regular grid spanning the observed inputs
db = pd.read_csv(f"../dummy_datasets/{test_db_size}_distinct_input_shared_hp.csv")
all_inputs, padded_inputs, padded_outputs, masks = _legacy_preprocess_db(db)
all_inputs_new, padded_inputs_new, padded_outputs_new, mappings = preprocess_db(db)

# Array reductions instead of Python's builtin min/max, which iterate over the JAX array one element at a time
GRID_SIZE = 500
grid = jnp.linspace(all_inputs.min(), all_inputs.max(), GRID_SIZE)
all_inputs.shape, padded_inputs.shape, grid.shape

((401,), (200, 401), (500,))

In [39]:
# Scalar hyperparameters: one task kernel shared by every task
mean_kern = RBFKernel(length_scale=.3, variance=1.)
task_kern = RBFKernel(length_scale=.6, variance=1.)

In [40]:
# Reference: legacy implementation, evaluated on the custom grid
mean_1, cov_1 = hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs, grid=grid)

In [41]:
# Candidate: new implementation, evaluated on the custom grid
mean_2, cov_2 = hyperpost_new(padded_inputs_new, padded_outputs_new, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs_new, grid=grid)

In [42]:
# Do both implementations agree on the posterior mean and covariance?
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [43]:
# Mean absolute difference between the two implementations
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(1.2658448e-08, dtype=float64), Array(7.78045934e-15, dtype=float64))

In [44]:
%%timeit -n 5 -r 5
# block_until_ready() waits for JAX's asynchronous dispatch, so the whole computation is timed
hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs, grid=grid)[0].block_until_ready()

368 ms ± 19.6 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [45]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs_new, padded_outputs_new, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs_new, grid=grid)[0].block_until_ready()

179 ms ± 10.8 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


---
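## Conclusion

In every scenario, the new implementation matches the legacy one. `jnp.allclose` is `True` for both the posterior mean and the covariance, and the mean absolute differences are at most about 1.3e-8.

Timings (mean ± std. dev. of 5 runs, 5 loops each):

| Scenario | Legacy | New |
|---|---|---|
| Shared input, shared HP | 1.94 ms | 1.31 ms |
| Shared input, distinct HP | 47.5 ms | 46.7 ms |
| Distinct input, shared HP | 400 ms | 187 ms |
| Distinct input, distinct HP | interrupted | 138 ms |
| Custom grid (distinct input, shared HP) | 368 ms | 179 ms |

The largest gains are in the distinct-input scenarios. With shared inputs and distinct hyperparameters, both implementations take about the same time.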

---