# Benchmarks - Hyper-posterior distributions

**Main considerations when implementing the hyper-posterior (`hyperpost`)**
* For now, we only focus on making the previous implementation work with our new kernels.
* At first, this implementation does not take the grid into account.


---
## Setup

In [1]:
# Standard library
import os

# Enable 64-bit floats in JAX; this must be set before `jax` is imported (done in the next cell)
os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax import lax

import numpy as np
import pandas as pd

In [3]:
# Local
from MagmaClustPy.kernels import RBFKernel
from MagmaClustPy.utils import generate_dummy_db, preprocess_db
from MagmaClustPy.hyperpost import hyperpost

In [4]:
# Config
key = jax.random.PRNGKey(0)  # Fixed PRNG key, passed to generate_dummy_db below
M = 50  # Number of sequences
MIN_N = 10  # Minimum sequence length
MAX_N = 100  # Maximum sequence length
grid = jnp.arange(-200., 200, 1)  # Grid to pick inputs from: 400 points, -200 to 199 with unit spacing

---
## Data

In [5]:
db = generate_dummy_db(M, MIN_N, MAX_N, grid, key)
db

Unnamed: 0,ID,Input,Output
0,0,8.0,-0.957674
1,0,44.0,2.742172
2,0,189.0,-4.571886
3,0,-29.0,0.661192
4,0,82.0,-2.854472
...,...,...,...
2706,49,153.0,-1.899048
2707,49,28.0,3.695970
2708,49,58.0,-4.940126
2709,49,-71.0,2.483545


In [6]:
#common_input = False
#common_hp = True
#db = pd.read_csv(f"./dummy_datasets/medium_{'common_input' if common_input else 'distinct_input'}_{'common_hp' if common_hp else 'distinct_hp'}.csv")

In [7]:
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((398,), (50, 398))

---
## Current implementation

In [8]:
# Jitted, specialized functions
@jit
def hyperpost_common_input_common_hp(outputs, prior_mean, mean_cov_L, mean_cov_inv, task_cov, inputs_to_grid=None, nugget=jnp.array(1e-10)):
	"""
	Hyper-posterior of the mean process when all tasks share the same inputs and hyper-parameters.

	Args:
		outputs: observed outputs, one row per task, (M, N)
		prior_mean: prior mean on the working grid, (G,)
		mean_cov_L: LOWER Cholesky factor of the mean-process covariance, (G, G)
		mean_cov_inv: inverse of the mean-process covariance, (G, G)
		task_cov: task covariance shared by every task, (N, N)
		inputs_to_grid: optional positions of the N inputs inside the grid, (N,). When None, G must equal N
		nugget: jitter added to the diagonal of task_cov before factorisation

	Returns:
		(post_mean, post_cov) with shapes (G,) and (G, G)
	"""
	small_eye = jnp.eye(task_cov.shape[-1])
	# The posterior lives on the grid, which is larger than the inputs when inputs_to_grid is given
	big_eye = jnp.eye(mean_cov_inv.shape[-1])

	# Compute task covariance and its Cholesky factor
	task_cov_L, _ = cho_factor(task_cov + small_eye * nugget, lower=True)
	task_cov_inv = cho_solve((task_cov_L, True), small_eye)

	if inputs_to_grid is not None:
		# Scatter the (N, N) precision into the (G, G) grid-sized matrix
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	# All tasks share same inputs and hyper-parameters, so their inverse covs are the same and we can compute one and multiply rather than compute all and sum
	post_prec_L, _ = cho_factor(mean_cov_inv + len(outputs) * task_cov_inv, lower=True)
	post_cov = cho_solve((post_prec_L, True), big_eye)

	# Compute posterior mean
	weighted_prior_mean = cho_solve((mean_cov_L, True), prior_mean)
	weighted_tasks = cho_solve((task_cov_L, True), outputs.sum(axis=0))

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_prec_L, True), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov

@jit
def hyperpost_common_input_distinct_hp(outputs, prior_mean, mean_cov_L, mean_cov_inv, task_covs, inputs_to_grid=None, nugget=jnp.array(1e-10)):
	"""
	Hyper-posterior of the mean process when tasks share inputs but have their own hyper-parameters.

	Args:
		outputs: observed outputs, one row per task, (M, N)
		prior_mean: prior mean on the working grid, (G,)
		mean_cov_L: LOWER Cholesky factor of the mean-process covariance, (G, G)
		mean_cov_inv: inverse of the mean-process covariance, (G, G)
		task_covs: one covariance per task, (M, N, N)
		inputs_to_grid: optional positions of the N inputs inside the grid, (N,). When None, G must equal N
		nugget: jitter added to the diagonal of every task covariance before factorisation

	Returns:
		(post_mean, post_cov) with shapes (G,) and (G, G)
	"""
	small_eye = jnp.eye(task_covs.shape[-1])
	# The posterior lives on the grid, which is larger than the inputs when inputs_to_grid is given
	big_eye = jnp.eye(mean_cov_inv.shape[-1])

	# Factorise every task covariance, then sum the per-task precisions
	task_covs_L = vmap(lambda x: cho_factor(x + small_eye * nugget, lower=True)[0])(task_covs)
	task_cov_inv = vmap(lambda L: cho_solve((L, True), small_eye))(task_covs_L).sum(axis=0)

	if inputs_to_grid is not None:
		# Scatter the (N, N) precision into the (G, G) grid-sized matrix
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	post_prec_L, _ = cho_factor(mean_cov_inv + task_cov_inv, lower=True)
	post_cov = cho_solve((post_prec_L, True), big_eye)

	# Compute posterior mean
	weighted_prior_mean = cho_solve((mean_cov_L, True), prior_mean)
	weighted_tasks = vmap(lambda L, o: cho_solve((L, True), o))(task_covs_L, outputs).sum(axis=0)

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_prec_L, True), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov

@jit
def hyperpost_distinct_input(outputs, masks, prior_mean, mean_cov_L, mean_cov_inv, task_covs, inputs_to_grid=None, nugget=jnp.array(1e-10)):
	"""
	Hyper-posterior of the mean process for tasks observed on distinct (padded) inputs.

	Args:
		outputs: padded observed outputs, (M, N); padded entries are ignored
		masks: boolean array (M, N), True where the entry is a real observation
		prior_mean: prior mean on the working grid, (G,)
		mean_cov_L: LOWER Cholesky factor of the mean-process covariance, (G, G)
		mean_cov_inv: inverse of the mean-process covariance, (G, G)
		task_covs: (M, N, N), batch of unaligned covariances; padded rows/cols may hold NaNs
		inputs_to_grid: optional positions of the N inputs inside the grid, (N,). When None, G must equal N
		nugget: jitter added to the diagonal of every task covariance before factorisation

	Returns:
		(post_mean, post_cov) with shapes (G,) and (G, G)
	"""
	small_eye = jnp.eye(task_covs.shape[-1])
	big_eye = jnp.eye(mean_cov_L.shape[-1])

	# task_covs is padded with NaNs. Replace them by their corresponding identity rows/cols
	masks_2D = masks[:, :, None] & masks[:, None, :]
	task_covs = jnp.where(masks_2D, task_covs, small_eye)

	# Posterior covariance (task factors are upper-triangular: cho_factor default)
	task_covs_U = vmap(lambda x: cho_factor(x + small_eye * nugget)[0])(task_covs)
	task_covs_inv = vmap(lambda U: cho_solve((U, False), small_eye))(task_covs_U)
	task_covs_inv -= jnp.where(masks_2D, 0, small_eye)  # Remove the identity entries introduced by the padding
	task_cov_inv = task_covs_inv.sum(axis=0)

	if inputs_to_grid is not None:
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	post_prec_U, _ = cho_factor(mean_cov_inv + task_cov_inv)
	post_cov = cho_solve((post_prec_U, False), big_eye)

	# Posterior mean
	# mean_cov_L is a LOWER factor (like in the common-input variants): solving it with lower=False
	# would only read its diagonal and silently give a wrong prior term whenever prior_mean != 0
	weighted_prior_mean = cho_solve((mean_cov_L, True), prior_mean)
	outputs = jnp.where(masks, outputs, 0)
	weighted_tasks = vmap(lambda U, o: cho_solve((U, False), o))(task_covs_U, outputs).sum(axis=0)

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_prec_U, False), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov

In [9]:
# General function
def hyperpost(inputs, outputs, masks, prior_mean, mean_kernel, task_kernel, all_inputs=None, grid=None, nugget=jnp.array(1e-10)):
	"""
	Compute hyper-posterior using Cholesky decomposition for numerical stability.

	NOTE: this notebook-local definition shadows the `hyperpost` imported from
	`MagmaClustPy.hyperpost` in the imports cell.

	Args:
		inputs: padded inputs of every task, (M, N)
		outputs: padded observed outputs, (M, N)
		masks: boolean array (M, N), True where the entry is a real observation.
			Tasks are considered to share their inputs when every entry is True
		prior_mean: prior mean, either a scalar or a vector over the working grid
		mean_kernel: kernel of the mean process, called as `mean_kernel(grid, grid)`
		task_kernel: kernel of the tasks. Hyper-parameters are considered shared when every
			attribute of the kernel is 0-dimensional
		all_inputs: optional union of the inputs of every task; derived from `inputs` when None
		grid: optional extra points, merged with `all_inputs` to form the working grid
		nugget: jitter added to the diagonal of covariances before factorisation

	Returns:
		(post_mean, post_cov) evaluated on the working grid
	"""
	common_input = bool(jnp.all(masks))
	common_hp = all(hp.ndim == 0 for hp in task_kernel.__dict__.values())

	# Merge inputs and grid to create all_inputs
	if all_inputs is None:
		if common_input:
			all_inputs = inputs[0]
		else:
			# Only keep observed entries: padded slots (NaN or any other filler) must not leak into the grid.
			# jnp.unique already returns sorted values
			all_inputs = jnp.unique(inputs[masks])

	if grid is not None:
		grid = jnp.unique(jnp.concatenate([all_inputs, grid]))
		inputs_to_grid = jnp.searchsorted(grid, all_inputs)
		common_input = False  # We need to pad the cov matrices to compute on the full grid
	else:
		grid = all_inputs
		inputs_to_grid = None

	prior_mean = jnp.asarray(prior_mean)  # Also accept plain Python scalars
	if prior_mean.ndim == 0:
		prior_mean = jnp.broadcast_to(prior_mean, (len(grid),))

	# Numerical stability terms
	eye = jnp.eye(grid.shape[0])

	# Compute mean covariance and its Cholesky factor
	# The factor is LOWER-triangular: every consumer has to solve it with lower=True
	mean_cov = mean_kernel(grid, grid)
	mean_cov_L, _ = cho_factor(mean_cov + eye * nugget, lower=True)
	mean_cov_inv = cho_solve((mean_cov_L, True), eye)

	if common_input:
		if common_hp:
			task_cov = task_kernel(all_inputs)  # Shape: (N, N)
			return hyperpost_common_input_common_hp(outputs, prior_mean, mean_cov_L, mean_cov_inv, task_cov, inputs_to_grid, nugget)

		# distinct HPs, we have to compute every task covariance but no padding is required
		task_covs = task_kernel(inputs)  # Shape: (M, N, N)
		return hyperpost_common_input_distinct_hp(outputs, prior_mean, mean_cov_L, mean_cov_inv, task_covs, inputs_to_grid, nugget)

	# No common input: we have to pad and mask
	task_covs = task_kernel(inputs)
	return hyperpost_distinct_input(outputs, masks, prior_mean, mean_cov_L, mean_cov_inv, task_covs, inputs_to_grid, nugget)


In [10]:
# Scalar (0-d) hyper-parameters. `hyperpost` treats the task kernel as shared across tasks
# when every entry of `task_kernel.__dict__` has ndim == 0 (presumably where RBFKernel stores these — TODO confirm)
mean_kern = RBFKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = RBFKernel(length_scale=jnp.array(.6), variance=jnp.array(1.))

In [11]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs, grid=grid)

In [12]:
np.asarray(post_mean)

array([ 4.76827875e-01, -3.82572914e-01,  6.71700107e-01,  6.68929179e-01,
       -1.61778671e-01, -2.25330496e-01,  2.29367270e-01, -4.78624682e-01,
       -2.10586507e-01, -8.02063493e-01, -3.41226362e-01,  1.15179957e+00,
       -6.11421575e-01, -4.63581972e-01,  3.62309700e-01, -3.75224485e-01,
        4.06570286e-01, -6.03555201e-01,  1.00086362e+00, -1.04201837e+00,
       -6.76914216e-01, -9.07909213e-01,  2.28139532e-01,  2.05874182e+00,
        4.62433162e-01, -1.79285822e+00,  4.49769136e-01,  1.14967660e+00,
       -1.84373185e+00,  7.04613109e-01, -1.01280214e+00,  1.29808586e+00,
       -7.05488283e-01, -2.46238529e-01,  6.66241288e-02, -2.32342244e-01,
        4.89774049e-01,  8.08436488e-01, -4.02551642e-01,  3.45664909e-01,
       -1.32317118e-01,  3.25861855e-02, -1.82340811e-01,  8.45597279e-01,
        2.05515554e-01,  6.85896626e-01,  5.18542893e-01, -1.27527027e+00,
       -3.00876015e-01, -2.70756101e-01, -4.32127557e-02, -3.68096901e-01,
        4.75224489e-01,  

In [13]:
np.asarray(post_cov)

array([[ 0.14169616,  0.00470982,  0.00029938, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.00470982,  0.12329948,  0.00548576, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.00029938,  0.00548576,  0.16508814, ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [ 0.        ,  0.        ,  0.        , ...,  0.12332094,
         0.00605498, -0.00127593],
       [ 0.        ,  0.        ,  0.        , ...,  0.00605498,
         0.08965122,  0.00500166],
       [ 0.        ,  0.        ,  0.        , ..., -0.00127593,
         0.00500166,  0.19752131]])

---
## Custom implementation(s)

In [14]:
# Jitted, specialized functions
@jit
def _hyperpost_common_input_common_hp(outputs, prior_mean, mean_cov_U, mean_cov_inv, task_cov, inputs_to_grid=None,nugget=jnp.array(1e-10)):
	eye = jnp.eye(task_cov.shape[-1])

	# Compute task covariance and its Cholesky factor
	task_cov_U, _ = cho_factor(task_cov + eye * nugget)
	task_cov_inv = cho_solve((task_cov_U, False), eye)

	if inputs_to_grid is not None:
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	# All tasks share same inputs and hyper-parameters, so their inverse covs are the same and we can compute one and multiply rather than compute all and sum
	post_cov_inv, _ = cho_factor(mean_cov_inv + len(outputs) * task_cov_inv,)
	post_cov = cho_solve((post_cov_inv, False), eye)

	# Compute posterior mean
	weighted_prior_mean = cho_solve((mean_cov_U, False), prior_mean)
	weighted_tasks = cho_solve((task_cov_U, False), outputs.sum(axis=0))

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_cov_inv, False), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov

@jit
def _hyperpost_common_input_distinct_hp(outputs, prior_mean, mean_cov_U, mean_cov_inv, task_covs, inputs_to_grid=None,nugget=jnp.array(1e-10)):
	eye = jnp.broadcast_to(jnp.eye(task_covs.shape[-1]), task_covs.shape)

	# Compute task covariance and its Cholesky factor
	#task_covs_L = vmap(lambda x: cho_factor(x + eye * nugget, lower=True)[0])(task_covs)
	task_covs_U, _ = cho_factor(task_covs + eye * nugget)
	#task_cov_inv = vmap(lambda L: cho_solve((L, True), eye))(task_covs_L).sum(axis=0)
	task_cov_inv = cho_solve((task_covs_U, False), eye)

	if inputs_to_grid is not None:
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	post_cov_inv, _ = cho_factor(mean_cov_inv + task_cov_inv)
	post_cov = cho_solve((post_cov_inv, False), eye)

	# Compute posterior mean
	weighted_prior_mean = cho_solve((mean_cov_U, False), prior_mean)
	#weighted_tasks = vmap(lambda L, o: cho_solve((L, True), o))(task_covs_L, outputs).sum(axis=0)
	weighted_tasks = cho_solve((task_covs_U, False), outputs).sum(axis=0)

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_cov_inv, False), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov

@jit
def _hyperpost_distinct_input(outputs, masks, prior_mean, mean_cov_U, mean_cov_inv, task_covs, inputs_to_grid=None,nugget=jnp.array(1e-10)):
	"""
	computes the hyperpost on distinct inputs

	task_covs: (M, N, N), batch of unaligned covariances
	"""
	small_eye = jnp.broadcast_to(jnp.eye(task_covs.shape[-1]), task_covs.shape)
	big_eye = jnp.eye(mean_cov_U.shape[-1])

	# task_covs is padded with NaNs. Replace them by their corresponding identity rows/cols
	masks_2D = masks[:, :, None] & masks[:, None, :]
	task_covs = jnp.where(masks_2D, task_covs, small_eye)

	# Posterior covariance
	#task_covs_L = vmap(lambda x: cho_factor(x + small_eye * nugget)[0])(task_covs)
	task_covs_U, _ = cho_factor(task_covs + small_eye * nugget)
	#task_covs_inv = vmap(lambda L: cho_solve((L, False), small_eye))(task_covs_L)
	task_covs_inv = cho_solve((task_covs_U, False), small_eye)
	task_covs_inv -= jnp.where(masks_2D, 0, small_eye)  # Correction on the diagonal
	task_cov_inv = task_covs_inv.sum(axis=0)

	if inputs_to_grid is not None:
		task_cov_inv = jnp.zeros_like(mean_cov_inv).at[jnp.ix_(inputs_to_grid, inputs_to_grid)].set(task_cov_inv)

	post_cov_inv, _ = cho_factor(mean_cov_inv + task_cov_inv)
	post_cov = cho_solve((post_cov_inv, False), big_eye)

	# Posterior mean
	weighted_prior_mean = cho_solve((mean_cov_U, False), prior_mean)
	outputs = jnp.where(masks, outputs, 0)
	#weighted_tasks = vmap(lambda L, o: cho_solve((L, False), o))(task_covs_L, outputs).sum(axis=0)
	weighted_tasks = cho_solve((task_covs_U, False), outputs).sum(axis=0)

	if inputs_to_grid is not None:
		weighted_tasks = jnp.zeros_like(prior_mean).at[inputs_to_grid].set(weighted_tasks)

	post_mean = cho_solve((post_cov_inv, False), weighted_prior_mean + weighted_tasks)

	return post_mean, post_cov

In [15]:
# General function
def _hyperpost(inputs, outputs, masks, prior_mean, mean_kernel, task_kernel, all_inputs=None, grid=None, nugget=jnp.array(1e-10)):
	"""
	Compute hyper-posterior using (upper-triangular) Cholesky decompositions for numerical stability.

	Args:
		inputs: padded inputs of every task, (M, N)
		outputs: padded observed outputs, (M, N)
		masks: boolean array (M, N), True where the entry is a real observation.
			Tasks are considered to share their inputs when every entry is True
		prior_mean: prior mean, either a scalar or a vector over the working grid
		mean_kernel: kernel of the mean process, called as `mean_kernel(grid, grid)`
		task_kernel: kernel of the tasks. Hyper-parameters are considered shared when every
			attribute of the kernel is 0-dimensional
		all_inputs: optional union of the inputs of every task; derived from `inputs` when None
		grid: optional extra points, merged with `all_inputs` to form the working grid
		nugget: jitter added to the diagonal of covariances before factorisation

	Returns:
		(post_mean, post_cov) evaluated on the working grid
	"""
	common_input = bool(jnp.all(masks))
	common_hp = all(hp.ndim == 0 for hp in task_kernel.__dict__.values())

	# Merge inputs and grid to create all_inputs
	if all_inputs is None:
		if common_input:
			all_inputs = inputs[0]
		else:
			# Only keep observed entries: padded slots (NaN or any other filler) must not leak into the grid.
			# jnp.unique already returns sorted values
			all_inputs = jnp.unique(inputs[masks])

	if grid is not None:
		grid = jnp.unique(jnp.concatenate([all_inputs, grid]))
		inputs_to_grid = jnp.searchsorted(grid, all_inputs)
		common_input = False  # We need to pad the cov matrices to compute on the full grid
	else:
		grid = all_inputs
		inputs_to_grid = None

	prior_mean = jnp.asarray(prior_mean)  # Also accept plain Python scalars
	if prior_mean.ndim == 0:
		prior_mean = jnp.broadcast_to(prior_mean, (len(grid),))

	# Numerical stability terms
	eye = jnp.eye(grid.shape[0])

	# Compute mean covariance and its Cholesky factor
	# The factor is UPPER-triangular (cho_factor default): every consumer has to solve it with lower=False
	mean_cov = mean_kernel(grid, grid)
	mean_cov_U, _ = cho_factor(mean_cov + eye * nugget)
	mean_cov_inv = cho_solve((mean_cov_U, False), eye)

	if common_input:
		if common_hp:
			task_cov = task_kernel(all_inputs)  # Shape: (N, N)
			return _hyperpost_common_input_common_hp(outputs, prior_mean, mean_cov_U, mean_cov_inv, task_cov, inputs_to_grid, nugget)

		# distinct HPs, we have to compute every task covariance but no padding is required
		task_covs = task_kernel(inputs)  # Shape: (M, N, N)
		return _hyperpost_common_input_distinct_hp(outputs, prior_mean, mean_cov_U, mean_cov_inv, task_covs, inputs_to_grid, nugget)

	# No common input: we have to pad and mask
	task_covs = task_kernel(inputs)
	return _hyperpost_distinct_input(outputs, masks, prior_mean, mean_cov_U, mean_cov_inv, task_covs, inputs_to_grid, nugget)


---
## Comparison

In [16]:
mean_1, cov_1 = hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [17]:
mean_2, cov_2 = _hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [18]:
mean_1

Array([ 4.76827875e-01, -3.82572914e-01,  6.71700107e-01,  6.68929179e-01,
       -1.61778671e-01, -2.25330496e-01,  2.29367270e-01, -4.78624682e-01,
       -2.10586507e-01, -8.02063493e-01, -3.41226362e-01,  1.15179957e+00,
       -6.11421575e-01, -4.63581972e-01,  3.62309700e-01, -3.75224485e-01,
        4.06570286e-01, -6.03555201e-01,  1.00086362e+00, -1.04201837e+00,
       -6.76914216e-01, -9.07909213e-01,  2.28139532e-01,  2.05874182e+00,
        4.62433162e-01, -1.79285822e+00,  4.49769136e-01,  1.14967660e+00,
       -1.84373185e+00,  7.04613109e-01, -1.01280214e+00,  1.29808586e+00,
       -7.05488283e-01, -2.46238529e-01,  6.66241288e-02, -2.32342244e-01,
        4.89774049e-01,  8.08436488e-01, -4.02551642e-01,  3.45664909e-01,
       -1.32317118e-01,  3.25861855e-02, -1.82340811e-01,  8.45597279e-01,
        2.05515554e-01,  6.85896626e-01,  5.18542893e-01, -1.27527027e+00,
       -3.00876015e-01, -2.70756101e-01, -4.32127557e-02, -3.68096901e-01,
        4.75224489e-01,  

In [19]:
mean_2

Array([ 4.76827875e-01, -3.82572914e-01,  6.71700107e-01,  6.68929179e-01,
       -1.61778671e-01, -2.25330496e-01,  2.29367270e-01, -4.78624682e-01,
       -2.10586507e-01, -8.02063493e-01, -3.41226362e-01,  1.15179957e+00,
       -6.11421575e-01, -4.63581972e-01,  3.62309700e-01, -3.75224485e-01,
        4.06570286e-01, -6.03555201e-01,  1.00086362e+00, -1.04201837e+00,
       -6.76914216e-01, -9.07909213e-01,  2.28139532e-01,  2.05874182e+00,
        4.62433162e-01, -1.79285822e+00,  4.49769136e-01,  1.14967660e+00,
       -1.84373185e+00,  7.04613109e-01, -1.01280214e+00,  1.29808586e+00,
       -7.05488283e-01, -2.46238529e-01,  6.66241288e-02, -2.32342244e-01,
        4.89774049e-01,  8.08436488e-01, -4.02551642e-01,  3.45664909e-01,
       -1.32317118e-01,  3.25861855e-02, -1.82340811e-01,  8.45597279e-01,
        2.05515554e-01,  6.85896626e-01,  5.18542893e-01, -1.27527027e+00,
       -3.00876015e-01, -2.70756101e-01, -4.32127557e-02, -3.68096901e-01,
        4.75224489e-01,  

In [20]:
jnp.allclose(mean_1, mean_2), jnp.allclose(cov_1, cov_2)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [21]:
jnp.mean(jnp.abs(mean_1 - mean_2)), jnp.mean(jnp.abs(cov_1 - cov_2))

(Array(0., dtype=float64), Array(0., dtype=float64))

In [22]:
%%timeit -n 5 -r 5
hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

227 ms ± 22.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [23]:
%%timeit -n 5 -r 5
_hyperpost(padded_inputs, padded_outputs, masks, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)[0].block_until_ready()

223 ms ± 10.3 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


---
## Conclusion

---