# Benchmarks - Likelihoods

**Main considerations when implementing Likelihoods**
- must be efficient, as they are used in the main bottleneck of the code: HP optimisation
- must work for both shared and distinct HPs

---
## Setup

In [1]:
# Standard library
import os

from MagmaClustPy.likelihoods import magma_neg_likelihood_on_cov

os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.stats.multivariate_normal import logpdf

import pandas as pd
import numpy as np

In [3]:
# Local
from MagmaClustPy.hyperpost import hyperpost
from Kernax import SEMagmaKernel, DiagKernel, ExpKernel
from MagmaClustPy.utils import preprocess_db

In [4]:
# Config
nugget=jnp.array(1e-4)

---
## Data

In [5]:
test_db_size = "medium"
key = jax.random.PRNGKey(42)

---
## Current implementation

In [6]:
import MagmaClustPy

magma_neg_likelihood_old = MagmaClustPy.likelihoods.magma_neg_likelihood
magma_neg_likelihood_on_cov_old = MagmaClustPy.likelihoods.magma_neg_likelihood_on_cov

---
## Custom implementation(s)

In [7]:
import jax.numpy as jnp
from jax import jit, vmap
from jax.scipy.stats.multivariate_normal import logpdf

from MagmaClustPy.linalg import extract_from_full_array, extract_from_full_matrix, solve_right_cholesky


@jit
def magma_neg_likelihood_on_cov(covar, outputs, mean, mean_process_cov, mapping, nugget=jnp.array(1e-10)):
	outputs = outputs.ravel()  # For multi-output, we want to flatten the outputs.
	mean = mean.ravel()  # As the goal of likelihood is to see if the mean is close to the outputs, we want to flatten it too.

	nugget_matrix = jnp.eye(outputs.shape[0]) * nugget

	eyed_covar = jnp.where(jnp.isnan(covar), jnp.eye(covar.shape[0]), covar)
	zeroed_outputs = jnp.nan_to_num(outputs)
	if mapping is not None:
		zeroed_mean = jnp.nan_to_num(extract_from_full_array(mean, outputs, mapping))
		eyed_mean_cov = jnp.where(jnp.isnan(covar), jnp.eye(covar.shape[0]), extract_from_full_matrix(mean_process_cov, outputs, mapping))
	else:
		zeroed_mean = jnp.nan_to_num(mean)
		eyed_mean_cov = jnp.where(jnp.isnan(covar), jnp.eye(covar.shape[0]), mean_process_cov)


	# Compute log-likelihood
	multiv_neg_log_lik = -logpdf(zeroed_outputs, zeroed_mean, eyed_covar + nugget_matrix)

	# Compute correction term
	correction = 0.5 * jnp.trace(solve_right_cholesky(eyed_covar, eyed_mean_cov, nugget=nugget))

	# Compute padding corrections
	# The logpdf is computed as:
	# -0.5 * (N * log(2 * pi) + log(det(cov)) + (outputs - mean).T @ inv(cov) @ (outputs - mean))
	# det(cov) and the Mahalanobis distance are not affected by our padding
	# We only have to correct for the -0.5 * N * log(2 * pi) term, as N is bigger with padding
	nll_pad_correction = 0.5 * jnp.log(2 * jnp.pi) * jnp.sum(jnp.isnan(outputs))

	# We also need to correct the correction term, as padding adds 1s to the diagonal and hence 1 to the trace
	corr_pad_correction = 0.5 * jnp.sum(jnp.isnan(outputs))

	return (multiv_neg_log_lik - nll_pad_correction) + (correction - corr_pad_correction)


@jit
def magma_neg_likelihood(kernel, inputs, outputs: jnp.array, mean: jnp.array, mean_process_cov: jnp.array, mappings: jnp.array, nugget=jnp.array(1e-10)):
	"""
	Computes the MAGMA log-likelihood.

	:param kernel: the kernel containing HPs to optimise. This kernel is used to compute the covariance (matrix `S`)
	:param inputs: inputs on which to compute the covariance matrix (shape (N, ))
	:param mask: boolean masks indicating which inputs and outputs to consider (shape (N, ))  #TODO: fix
	:param outputs: the observed values (shape (N, ))
	:param mean: the mean over the inputs (scalar or vector of shape (N, ))
	:param mean_process_cov: the hypper-posterior mean process covariance (matrix K^t)
	:param nugget: the nugget, for numerical stability

	:return: the negative log-likelihood (scalar)
	"""
	# In multi-output, we want to flatten the outputs.
	# The user should provide a specific Kernel to compute a cross-covariance with the right shape too
	outputs = outputs.reshape(outputs.shape[0], -1)

	if mean.ndim == 0:
		mean = jnp.broadcast_to(mean[None], outputs.shape)

	covar = kernel(inputs)

	# check if we need to vmap
	if inputs.ndim == 2:
		return magma_neg_likelihood_on_cov(covar, outputs, mean, mean_process_cov, mappings, nugget)
	elif inputs.ndim == 3:
		return vmap(magma_neg_likelihood_on_cov, in_axes=(0, 0, None, None, 0, None))(covar, outputs, mean, mean_process_cov, mappings, nugget)
	else:
		raise ValueError("inputs must be either 1D or 2D")

In [8]:
magma_neg_likelihood_new = magma_neg_likelihood
magma_neg_likelihood_on_cov_new = magma_neg_likelihood_on_cov

---
## Comparison

### Sample test

In [9]:
db = pd.DataFrame({
	'Task_ID': [1, 1, 1, 1, 2, 2, 2, 2],
	'Input': [0.40, 4.45, 7.60, 8.30, 3.50, 5.10, 8.85, 9.35],
	'Input_ID': ['x', 'x', 'x', 'x', 'x', 'x', 'x', 'x'],
	'Output': [59.81620, 67.13694, 78.32495, 81.83590, 62.04943, 67.31932, 85.94063, 86.76426],
	'Output_ID': ['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a'],
})

padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

prior_mean = jnp.array(0)
all_inputs.shape, padded_inputs.shape

((8, 1), (2, 4, 1))

In [10]:
mean_kern = SEMagmaKernel(length_scale=.3, variance=1.)
task_kern = SEMagmaKernel(length_scale=.3, variance=1.) + DiagKernel(ExpKernel(2.5))

In [11]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs, nugget=nugget)
post_mean.shape, post_cov.shape

((8,), (8, 8))

In [12]:
mean_llh_old = magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, jnp.full(all_inputs.squeeze().shape, prior_mean), post_cov, None, nugget=nugget)
mean_llh_old.item()

1784.2900390625

In [13]:
mean_llh_new = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget)
mean_llh_new.item()

1784.2900390625

In [14]:
jnp.allclose(mean_llh_old, mean_llh_new)

Array(True, dtype=bool)

In [15]:
%timeit magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

9.88 μs ± 307 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [16]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

9.69 μs ± 205 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [17]:
task_llhs_old = magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs_old)

array([ 979.3764, 1425.9928], dtype=float32)

In [18]:
task_llhs_new = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs_new)

array([ 979.3764, 1425.9928], dtype=float32)

In [19]:
jnp.allclose(task_llhs_old, task_llhs_new)

Array(True, dtype=bool)

In [20]:
%timeit magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, nugget=nugget)[0].block_until_ready()

59 μs ± 2.47 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [21]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)[0].block_until_ready()

53.5 μs ± 2.27 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### shared Input, shared HP

In [22]:
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_shared_hp.csv")

padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

prior_mean = jnp.array(0)
all_inputs.shape, padded_inputs.shape

((150, 1), (200, 150, 1))

In [23]:
mean_kern = SEMagmaKernel(length_scale=.3, variance=1.)
task_kern = SEMagmaKernel(length_scale=.3, variance=1.) + DiagKernel(ExpKernel(2.5))

In [24]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs, nugget=nugget)
post_mean.shape, post_cov.shape

((150,), (150, 150))

In [25]:
mean_llh_old = magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, jnp.full(all_inputs.squeeze().shape, prior_mean), post_cov, None, nugget=nugget)
mean_llh_old.item()

872167.75

In [26]:
mean_llh_new = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget)
mean_llh_new.item()

872167.75

In [27]:
jnp.allclose(mean_llh_old, mean_llh_new)

Array(True, dtype=bool)

In [28]:
%timeit magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

249 μs ± 7.76 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [29]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

245 μs ± 6.11 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [30]:
task_llhs_old = magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs_old)

array([443.60623, 454.33514, 457.4909 , 455.43054, 462.82095, 463.07523,
       455.4749 , 455.17505, 445.59637, 448.59924, 454.36047, 455.02356,
       452.37225, 458.5221 , 466.32495, 443.44504, 446.11185, 461.92776,
       450.13547, 459.11014, 457.8158 , 447.38824, 453.94556, 447.88974,
       447.32092, 458.37628, 452.34546, 453.8296 , 458.3998 , 441.71176,
       444.6034 , 446.18707, 448.72382, 445.65454, 439.13727, 453.74615,
       465.72504, 457.96024, 461.77377, 464.66547, 453.10135, 460.84247,
       454.2546 , 446.46265, 458.79422, 455.19415, 447.37918, 463.7545 ,
       454.7738 , 450.6353 , 448.41592, 465.51456, 463.16537, 445.1893 ,
       460.41043, 454.84998, 448.88474, 455.19244, 445.32806, 469.64545,
       459.3202 , 463.84857, 456.37054, 460.3284 , 457.97467, 432.5199 ,
       447.77032, 458.05536, 446.3839 , 436.29736, 470.28336, 449.6275 ,
       455.74258, 454.30447, 460.1877 , 442.18527, 455.8119 , 456.85345,
       458.8584 , 460.14972, 456.49634, 452.11597, 

In [31]:
task_llhs_new = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs_new)

array([443.60623, 454.33514, 457.4909 , 455.43054, 462.82095, 463.07523,
       455.4749 , 455.17505, 445.59637, 448.59924, 454.36047, 455.02356,
       452.37225, 458.5221 , 466.32495, 443.44504, 446.11185, 461.92776,
       450.13547, 459.11014, 457.8158 , 447.38824, 453.94556, 447.88974,
       447.32092, 458.37628, 452.34546, 453.8296 , 458.3998 , 441.71176,
       444.6034 , 446.18707, 448.72382, 445.65454, 439.13727, 453.74615,
       465.72504, 457.96024, 461.77377, 464.66547, 453.10135, 460.84247,
       454.2546 , 446.46265, 458.79422, 455.19415, 447.37918, 463.7545 ,
       454.7738 , 450.6353 , 448.41592, 465.51456, 463.16537, 445.1893 ,
       460.41043, 454.84998, 448.88474, 455.19244, 445.32806, 469.64545,
       459.3202 , 463.84857, 456.37054, 460.3284 , 457.97467, 432.5199 ,
       447.77032, 458.05536, 446.3839 , 436.29736, 470.28336, 449.6275 ,
       455.74258, 454.30447, 460.1877 , 442.18527, 455.8119 , 456.85345,
       458.8584 , 460.14972, 456.49634, 452.11597, 

In [32]:
jnp.allclose(task_llhs_old, task_llhs_new)

Array(True, dtype=bool)

In [33]:
%timeit magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, nugget=nugget)[0].block_until_ready()

21.6 ms ± 672 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [34]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)[0].block_until_ready()

21.8 ms ± 146 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### shared Input, Distinct HP

In [35]:
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_distinct_hp.csv")

padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

prior_mean = jnp.array(0)
all_inputs.shape, padded_inputs.shape

((150, 1), (200, 150, 1))

In [36]:
mean_kern = SEMagmaKernel(length_scale=.3, variance=1.)

key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)
task_kern = SEMagmaKernel(length_scale=.3, variance=1.) + DiagKernel(ExpKernel(2.5))

  distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)


In [37]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs, nugget=nugget)
post_mean.shape, post_cov.shape

((150,), (150, 150))

In [38]:
mean_llh_old = magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, jnp.full(all_inputs.squeeze().shape, prior_mean), post_cov, None, nugget=nugget)
mean_llh_old.item()

10194.2333984375

In [39]:
mean_llh_new = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget)
mean_llh_new.item()

10194.2333984375

In [40]:
jnp.allclose(mean_llh_old, mean_llh_new)

Array(True, dtype=bool)

In [41]:
%timeit magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

244 μs ± 3.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [42]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

243 μs ± 5.56 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [43]:
task_llhs_old = magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs_old)

array([368.95862, 350.5709 , 384.86948, 353.88013, 380.0682 , 358.41782,
       357.35513, 353.2122 , 353.47562, 372.60168, 361.04303, 354.94748,
       377.4451 , 354.9243 , 357.47787, 357.51254, 366.8542 , 358.49026,
       368.32343, 366.5707 , 350.92822, 355.50162, 356.96152, 363.69498,
       359.6304 , 356.46902, 368.06277, 363.49258, 353.89468, 351.7763 ,
       371.23563, 355.09842, 360.22958, 363.93686, 357.47513, 358.38312,
       363.08444, 366.7314 , 352.72983, 361.117  , 354.43442, 353.285  ,
       351.80478, 354.34006, 360.6326 , 357.77896, 358.61356, 352.2974 ,
       363.70712, 360.25967, 359.95428, 358.1223 , 358.04218, 361.2101 ,
       359.2108 , 366.46176, 351.2713 , 362.37592, 363.10132, 356.16336,
       355.89346, 373.53323, 358.27518, 358.0164 , 357.06   , 368.3975 ,
       359.00458, 354.35678, 373.04636, 351.7203 , 362.15994, 374.7101 ,
       355.917  , 351.41705, 366.3649 , 358.41385, 362.53186, 359.20724,
       370.32977, 375.80643, 357.18503, 362.45016, 

In [44]:
task_llhs_new = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs_new)

array([368.95862, 350.5709 , 384.86948, 353.88013, 380.0682 , 358.41782,
       357.35513, 353.2122 , 353.47562, 372.60168, 361.04303, 354.94748,
       377.4451 , 354.9243 , 357.47787, 357.51254, 366.8542 , 358.49026,
       368.32343, 366.5707 , 350.92822, 355.50162, 356.96152, 363.69498,
       359.6304 , 356.46902, 368.06277, 363.49258, 353.89468, 351.7763 ,
       371.23563, 355.09842, 360.22958, 363.93686, 357.47513, 358.38312,
       363.08444, 366.7314 , 352.72983, 361.117  , 354.43442, 353.285  ,
       351.80478, 354.34006, 360.6326 , 357.77896, 358.61356, 352.2974 ,
       363.70712, 360.25967, 359.95428, 358.1223 , 358.04218, 361.2101 ,
       359.2108 , 366.46176, 351.2713 , 362.37592, 363.10132, 356.16336,
       355.89346, 373.53323, 358.27518, 358.0164 , 357.06   , 368.3975 ,
       359.00458, 354.35678, 373.04636, 351.7203 , 362.15994, 374.7101 ,
       355.917  , 351.41705, 366.3649 , 358.41385, 362.53186, 359.20724,
       370.32977, 375.80643, 357.18503, 362.45016, 

In [45]:
jnp.allclose(task_llhs_old, task_llhs_new)

Array(True, dtype=bool)

In [46]:
%timeit magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, nugget=nugget)[0].block_until_ready()

21.3 ms ± 281 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [47]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)[0].block_until_ready()

21.2 ms ± 280 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Distinct Input, shared HP

In [48]:
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_shared_hp.csv")

padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

prior_mean = jnp.array(0)
all_inputs.shape, padded_inputs.shape

((401, 1), (200, 190, 1))

In [49]:
mean_kern = SEMagmaKernel(length_scale=.3, variance=1.)
task_kern = SEMagmaKernel(length_scale=.3, variance=1.) + DiagKernel(ExpKernel(2.5))

In [50]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs, nugget=nugget)
post_mean.shape, post_cov.shape

((401,), (401, 401))

In [51]:
mean_llh_old = magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, jnp.full(all_inputs.squeeze().shape, prior_mean), post_cov, None, nugget=nugget)
mean_llh_old.item()

inf

In [52]:
mean_llh_new = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget)
mean_llh_new.item()

inf

In [53]:
jnp.allclose(mean_llh_old, mean_llh_new)

Array(True, dtype=bool)

In [54]:
%timeit magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

1.04 ms ± 18.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [55]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

1.11 ms ± 92.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [56]:
task_llhs_old = magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs_old)

array([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, in

In [57]:
task_llhs_new = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs_new)

array([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, in

In [58]:
jnp.allclose(task_llhs_old, task_llhs_new)

Array(True, dtype=bool)

In [59]:
%timeit magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, nugget=nugget)[0].block_until_ready()

34.6 ms ± 779 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [60]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)[0].block_until_ready()

35.6 ms ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Distinct Input, Distinct HP

In [61]:
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_distinct_hp.csv")

padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

prior_mean = jnp.array(0)
all_inputs.shape, padded_inputs.shape

((401, 1), (200, 188, 1))

In [62]:
mean_kern = SEMagmaKernel(length_scale=.3, variance=1.)

key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float32, .1, 1)
task_kern = SEMagmaKernel(length_scale=.3, variance=1.) + DiagKernel(ExpKernel(2.5))

In [63]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs, nugget=nugget)
post_mean.shape, post_cov.shape

((401,), (401, 401))

In [64]:
mean_llh_old = magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, jnp.full(all_inputs.squeeze().shape, prior_mean), post_cov, None, nugget=nugget)
mean_llh_old.item()

inf

In [65]:
mean_llh_new = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget)
mean_llh_new.item()

inf

In [66]:
jnp.allclose(mean_llh_old, mean_llh_new)

Array(True, dtype=bool)

In [67]:
%timeit magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

1.03 ms ± 8.77 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [68]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

1.27 ms ± 318 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [69]:
task_llhs_old = magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs_old)

array([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, in

In [70]:
task_llhs_new = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs_new)

array([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, in

In [71]:
jnp.allclose(task_llhs_old, task_llhs_new)

Array(True, dtype=bool)

In [72]:
%timeit magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, nugget=nugget)[0].block_until_ready()

33.5 ms ± 1.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [73]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)[0].block_until_ready()

36.6 ms ± 4.39 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


---