# Benchmarks - Hyper-posterior distributions

**Main considerations when implementing hyper-post**
* ...


---
## Setup

In [1]:
# Standard library
import os

os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax import lax

import numpy as np
import pandas as pd

In [3]:
# Local
from Kernax import RBFKernel
from MagmaClustPy.utils import preprocess_db
from MagmaClustPy.linalg import map_to_full_matrix_batch, map_to_full_array_batch
from MagmaClustPy.hyperpost import hyperpost

In [4]:
# Config
key = jax.random.PRNGKey(0)
test_db_size = "medium"

---
## Data

---
## Current implementation

In [5]:
import MagmaClustPy

hyperpost_old = MagmaClustPy.hyperpost.hyperpost

---
## Custom implementation(s)

In [6]:
from jax import jit
from jax import numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.tree_util import tree_flatten

@jit
def hyperpost_shared_input_shared_hp(outputs, prior_mean, mean_cov_u, mean_cov_inv, task_cov, inputs_to_grid=None,
                                     nugget=jnp.array(1e-10)):
	eye = jnp.eye(task_cov.shape[-1])

	# Compute task covariance and its Cholesky factor
	task_cov_u, _ = cho_factor(task_cov + eye * nugget)
	task_cov_inv = cho_solve((task_cov_u, False), eye)

	if inputs_to_grid is not None:
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	# All tasks share same inputs and hyperparameters, so their inverse covariances are the same, and we can compute
	# one then multiply rather than compute all then sum
	post_cov_inv, _ = cho_factor(mean_cov_inv + len(outputs) * task_cov_inv, )
	post_cov = cho_solve((post_cov_inv, False), eye)

	# Compute posterior mean
	weighted_prior_mean = cho_solve((mean_cov_u, False), prior_mean)
	weighted_tasks = cho_solve((task_cov_u, False), outputs.sum(axis=0))

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_cov_inv, False), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov


@jit
def hyperpost_shared_input_distinct_hp(outputs, prior_mean, mean_cov_u, mean_cov_inv, task_covs, inputs_to_grid=None,
                                       nugget=jnp.array(1e-10)):
	eye = jnp.broadcast_to(jnp.eye(task_covs.shape[-1]), task_covs.shape)

	# Compute task covariance and its Cholesky factor
	task_covs_u, _ = cho_factor(task_covs + eye * nugget)
	task_cov_inv = cho_solve((task_covs_u, False), eye)

	task_cov_inv = task_cov_inv.sum(axis=0)

	if inputs_to_grid is not None:
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	post_cov_inv, _ = cho_factor(mean_cov_inv + task_cov_inv)
	post_cov = cho_solve((post_cov_inv, False), eye[0])

	# Compute posterior mean
	weighted_prior_mean = cho_solve((mean_cov_u, False), prior_mean)
	# weighted_tasks = vmap(lambda L, o: cho_solve((L, True), o))(task_covs_L, outputs).sum(axis=0)
	weighted_tasks = cho_solve((task_covs_u, False), outputs).sum(axis=0)

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_cov_inv, False), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov


@jit
def hyperpost_distinct_input(outputs, mappings, prior_mean, mean_cov_u, mean_cov_inv, task_covs, all_inputs, inputs_to_grid=None,
                             nugget=jnp.array(1e-10)):
	"""
	computes the hyperpost on distinct inputs

	task_covs: (M, N, N), batch of unaligned covariances
	"""
	small_eye = jnp.broadcast_to(jnp.eye(task_covs.shape[-1]), task_covs.shape)

	# task_covs is padded with NaNs. Replace them by their corresponding identity rows/cols
	eyed_task_covs = jnp.where(jnp.isnan(task_covs), small_eye, task_covs)

	# Posterior covariance
	task_covs_U, _ = cho_factor(eyed_task_covs + small_eye * nugget)
	task_covs_inv = cho_solve((task_covs_U, False), small_eye)
	task_covs_inv -= jnp.where(jnp.isnan(task_covs), small_eye, 0)  # Correction on the diagonal
	task_covs_inv = map_to_full_matrix_batch(task_covs_inv, all_inputs, mappings)
	task_cov_inv = jnp.nan_to_num(task_covs_inv).sum(axis=0)

	if inputs_to_grid is not None:
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	post_cov_inv, _ = cho_factor(mean_cov_inv + task_cov_inv)
	post_cov = cho_solve((post_cov_inv, False), jnp.eye(mean_cov_u.shape[-1]))

	# Posterior mean
	weighted_prior_mean = cho_solve((mean_cov_u, False), prior_mean)
	mapped_outputs = jnp.nan_to_num(map_to_full_array_batch(outputs, all_inputs, mappings))
	padded_task_covs_U = map_to_full_matrix_batch(task_covs_U, all_inputs, mappings)
	eyed_task_covs_U = jnp.where(jnp.isnan(padded_task_covs_U), jnp.eye(all_inputs.shape[-1]), padded_task_covs_U)
	weighted_tasks = cho_solve((eyed_task_covs_U, False), mapped_outputs).sum(axis=0)

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_cov_inv, False), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov


# General function
def hyperpost(inputs, outputs, masks, prior_mean, mean_kernel, task_kernel, all_inputs=None, grid=None,
              nugget=jnp.array(1e-10)):
	"""
	Computes the posterior mean and covariance of a Magma GP given the inputs, outputs, masks, prior mean and kernels.

	:param inputs: the preprocessed (padded and aligned) inputs
	:param outputs: the preprocessed outputs
	:param masks: the masks indicating which inputs are valid
	:param prior_mean: the prior mean, as a scalar or a vector of shape (N, ), where N is the length of the union of all
	inputs and the grid
	:param mean_kernel: kernel of the mean process, with hyperparameters loaded as attributes
	:param task_kernel: kernel of the task process, with hyperparameters loaded as attributes
	:param all_inputs: all distinct inputs. If not provided, it will be computed from the inputs
	:param grid: the grid on which the GP is defined. If not provided, the GP is defined on all distinct inputs
	:param nugget: nugget term to ensure numerical stability. Default is 1e-10

	:return: a 2-tuple of the posterior mean and covariance
	"""
	#TODO: maybe with multi-output we will have multiple post_mean to compute. We should discuss how to unify the return shape of hyperpost
	outputs = outputs.squeeze()  # For now this will do the trick

	shared_hp = all([hp.ndim == 0 for hp in tree_flatten(task_kernel)[0]])

	# Merge inputs and grid to create all_inputs
	if all_inputs is None:
		all_inputs = jnp.sort(jnp.unique(inputs.flatten()))

	shared_input = len(inputs[0]) == len(all_inputs)

	if grid is None:
		grid = all_inputs
		inputs_to_grid = None
	else:
		grid = jnp.sort(jnp.unique(jnp.concatenate([all_inputs, grid])))
		inputs_to_grid = jnp.searchsorted(grid, all_inputs)
		shared_input = False  # We need to pad the cov matrices to compute on the full grid

	if prior_mean.ndim == 0:
		prior_mean = jnp.broadcast_to(prior_mean, (len(grid),))

	# Numerical stability terms
	eye = jnp.eye(grid.shape[0])

	# Compute mean covariance and its Cholesky factor
	mean_cov = mean_kernel(grid, grid)
	mean_cov_u, _ = cho_factor(mean_cov + eye * nugget)
	mean_cov_inv = cho_solve((mean_cov_u, False), eye)

	if shared_input:
		if shared_hp:
			task_cov = task_kernel(grid)  # Shape: (N, N)
			return hyperpost_shared_input_shared_hp(outputs, prior_mean, mean_cov_u, mean_cov_inv, task_cov,
			                                        inputs_to_grid, nugget)

		else:  # distinct HPs, we have to compute every task covariance but no padding is required
			task_covs = task_kernel(inputs)  # Shape: (M, N, N)
			return hyperpost_shared_input_distinct_hp(outputs, prior_mean, mean_cov_u, mean_cov_inv, task_covs,
			                                          inputs_to_grid, nugget)

	else:  # No shared input: we have to pad and mask
		# task_covs = task_kernel(jnp.broadcast_to(all_inputs, (len(inputs), len(all_inputs))))
		task_covs = task_kernel(inputs)
		return hyperpost_distinct_input(outputs, masks, prior_mean, mean_cov_u, mean_cov_inv, task_covs, all_inputs,
		                                inputs_to_grid, nugget)


In [7]:
hyperpost_new = hyperpost

---
## Comparison

### shared Input, shared HP

In [8]:
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_shared_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((150, 1), (200, 150, 1))

In [9]:
mean_kern = RBFKernel(length_scale=.3, variance=1.)
task_kern = RBFKernel(length_scale=.6, variance=1.)

In [10]:
mean_1, cov_1 = hyperpost_old(padded_inputs, padded_outputs.squeeze(), mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [11]:
mean_2, cov_2 = hyperpost_new(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [12]:
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [13]:
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(0., dtype=float64), Array(0., dtype=float64))

In [14]:
%%timeit -n 5 -r 5
hyperpost_old(padded_inputs.squeeze(), padded_outputs.squeeze(), mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

The slowest run took 4.77 times longer than the fastest. This could mean that an intermediate result is being cached.
2.56 ms ± 2.13 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [15]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

1.45 ms ± 200 μs per loop (mean ± std. dev. of 5 runs, 5 loops each)


### shared Input, Distinct HP

In [16]:
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_distinct_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((150, 1), (200, 150, 1))

In [17]:
mean_kern = RBFKernel(length_scale=.3, variance=1.)

key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)
task_kern = RBFKernel(length_scale=distinct_length_scales, variance=1.)

In [18]:
mean_1, cov_1 = hyperpost_old(padded_inputs, padded_outputs.squeeze(), mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [19]:
mean_2, cov_2 = hyperpost_new(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [20]:
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [21]:
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(0., dtype=float64), Array(0., dtype=float64))

In [22]:
%%timeit -n 5 -r 5
hyperpost_old(padded_inputs, padded_outputs.squeeze(), mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

53.3 ms ± 1.5 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [23]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

52.4 ms ± 696 μs per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Distinct Input, shared HP

In [24]:
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_shared_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((401, 1), (200, 190, 1))

In [25]:
mean_kern = RBFKernel(length_scale=.3, variance=1.)
task_kern = RBFKernel(length_scale=.6, variance=1.)

In [26]:
mean_1, cov_1 = hyperpost_old(padded_inputs, padded_outputs.squeeze(), mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [27]:
mean_2, cov_2 = hyperpost_new(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [28]:
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [29]:
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(0., dtype=float64), Array(0., dtype=float64))

In [30]:
%%timeit -n 5 -r 5
hyperpost_old(padded_inputs, padded_outputs.squeeze(), mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

150 ms ± 3.26 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [31]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

153 ms ± 4.36 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Distinct Input, Distinct HP

In [32]:
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_distinct_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((401, 1), (200, 188, 1))

In [33]:
mean_kern = RBFKernel(length_scale=.3, variance=1.)

key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)
task_kern = RBFKernel(length_scale=distinct_length_scales, variance=1.)

In [34]:
mean_1, cov_1 = hyperpost_old(padded_inputs, padded_outputs.squeeze(), mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [35]:
mean_2, cov_2 = hyperpost_new(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [36]:
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [37]:
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(0., dtype=float64), Array(0., dtype=float64))

In [38]:
%%timeit -n 5 -r 5
hyperpost_old(padded_inputs, padded_outputs.squeeze(), mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

174 ms ± 23.4 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [39]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

174 ms ± 21.3 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


### Using a custom grid

In [40]:
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_shared_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

grid = jnp.linspace(jnp.min(all_inputs, axis=1), jnp.max(all_inputs, axis=1), 500)
all_inputs.shape, padded_inputs.shape, grid.shape

((401, 1), (200, 190, 1), (500, 401))

In [41]:
mean_kern = RBFKernel(length_scale=.3, variance=1.)
task_kern = RBFKernel(length_scale=.6, variance=1.)

In [42]:
mean_1, cov_1 = hyperpost_old(padded_inputs, padded_outputs.squeeze(), mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [43]:
mean_2, cov_2 = hyperpost_new(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [44]:
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [45]:
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(0., dtype=float64), Array(0., dtype=float64))

In [46]:
%%timeit -n 5 -r 5
hyperpost_old(padded_inputs, padded_outputs.squeeze(), mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

178 ms ± 21.7 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [47]:
%%timeit -n 5 -r 5
hyperpost_new(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

157 ms ± 3.58 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


---
## Conclusion

---