# Benchmarks - Likelihoods

**Main considerations when implementing Likelihoods**
- must be efficient, as they are used in the main bottleneck of the code: HP optimisation
- must work for both shared and distinct HPs

---
## Setup

In [1]:
# Standard library
import os

from MagmaClustPy.likelihoods import magma_neg_likelihood_on_cov

os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
import jax.numpy as jnp
from jax import jit, vmap

import pandas as pd
import numpy as np

In [3]:
# Local
from MagmaClustPy.hyperpost import hyperpost
from Kernax import SEMagmaKernel, DiagKernel, ExpKernel
from MagmaClustPy.utils import preprocess_db

In [4]:
# Config
jitter=jnp.array(1e-10)

---
## Data

In [5]:
test_db_size = "medium"
key = jax.random.PRNGKey(42)

---
## Current implementation

In [6]:
import MagmaClustPy

magma_neg_likelihood_old = MagmaClustPy.likelihoods.magma_neg_likelihood
magma_neg_likelihood_on_cov_old = MagmaClustPy.likelihoods.magma_neg_likelihood_on_cov

---
## Custom implementation(s)

In [7]:
import jax.numpy as jnp
from jax import jit, vmap

from MagmaClustPy.linalg import extract_from_full_array, extract_from_full_matrix, solve_right_cholesky, logpdf


@jit
def magma_neg_likelihood_on_cov(covar, outputs, mean, mean_process_cov, mapping):
	outputs = outputs.ravel()  # For multi-output, we want to flatten the outputs.
	mean = mean.ravel()  # As the goal of likelihood is to see if the mean is close to the outputs, we want to flatten
	# it too.

	eyed_covar = jnp.where(jnp.isnan(covar), jnp.eye(covar.shape[0]), covar)
	zeroed_outputs = jnp.nan_to_num(outputs)
	if mapping is not None:
		zeroed_mean = jnp.nan_to_num(extract_from_full_array(mean, outputs, mapping))
		eyed_mean_cov = jnp.where(jnp.isnan(covar), jnp.eye(covar.shape[0]),
		                          extract_from_full_matrix(mean_process_cov, outputs, mapping))
	else:
		zeroed_mean = jnp.nan_to_num(mean)
		eyed_mean_cov = jnp.where(jnp.isnan(covar), jnp.eye(covar.shape[0]), mean_process_cov)

	# Compute log-likelihood
	multiv_neg_log_lik = -logpdf(zeroed_outputs, zeroed_mean, eyed_covar)

	# Compute correction term
	correction = 0.5 * jnp.trace(solve_right_cholesky(eyed_covar, eyed_mean_cov))

	# Compute padding corrections
	# The logpdf is computed as:
	# -0.5 * (N * log(2 * pi) + log(det(cov)) + (outputs - mean).T @ inv(cov) @ (outputs - mean))
	# det(cov) and the Mahalanobis distance are not affected by our padding
	# We only have to correct for the -0.5 * N * log(2 * pi) term, as N is bigger with padding
	nll_pad_correction = 0.5 * jnp.log(2 * jnp.pi) * jnp.sum(jnp.isnan(outputs))

	# We also need to correct the correction term, as padding adds 1s to the diagonal and hence 1 to the trace
	corr_pad_correction = 0.5 * jnp.sum(jnp.isnan(outputs))

	return (multiv_neg_log_lik - nll_pad_correction) + (correction - corr_pad_correction)


@jit
def magma_neg_likelihood(kernel, inputs, outputs: jnp.array, mean: jnp.array, mean_process_cov: jnp.array,
                         mappings: jnp.array):
	"""
	Computes the MAGMA log-likelihood.

	:param kernel: The kernel to optimise. This kernel is used to compute the covariance (matrix `S`).
	:param inputs: Inputs on which to compute the covariance matrix (shape (N, I)) or (T, Max_N_i, I).
	:param outputs: The observed values for each input (shape (N, O) or (T, Max_N_i, O)).
	:param mean: The mean over the inputs (scalar or vector of shape (N,)).
	:param mean_process_cov: The hyperpost mean process covariance (matrix K^t)
	:param mappings: The indices of the inputs in the all_inputs array, if we compute the likelihood on a batch of
	inputs. Shape (T, Max_N_i)

	:return: The negative log-likelihood (scalar)
	"""
	# In multi-output, we want to flatten the outputs.
	# The user should provide a specific Kernel to compute a cross-covariance with the right shape too
	outputs = outputs.reshape(outputs.shape[0], -1)

	if mean.ndim == 0:
		mean = jnp.broadcast_to(mean[None], outputs.shape)

	covar = kernel(inputs)

	# check if we need to vmap
	if inputs.ndim == 2:
		return magma_neg_likelihood_on_cov(covar, outputs, mean, mean_process_cov, mappings)
	elif inputs.ndim == 3:
		return vmap(magma_neg_likelihood_on_cov, in_axes=(0, 0, None, None, 0))(covar, outputs, mean,
		                                                                              mean_process_cov, mappings)
	else:
		raise ValueError("inputs must be either 1D or 2D")


In [8]:
magma_neg_likelihood_new = magma_neg_likelihood
magma_neg_likelihood_on_cov_new = magma_neg_likelihood_on_cov

---
## Comparison

### Sample test

In [9]:
db = pd.DataFrame({
	'Task_ID': [1, 1, 1, 1, 2, 2, 2, 2],
	'Input': [0.40, 4.45, 7.60, 8.30, 3.50, 5.10, 8.85, 9.35],
	'Input_ID': ['x', 'x', 'x', 'x', 'x', 'x', 'x', 'x'],
	'Output': [59.81620, 67.13694, 78.32495, 81.83590, 62.04943, 67.31932, 85.94063, 86.76426],
	'Output_ID': ['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a'],
})

padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

prior_mean = jnp.array(0)
all_inputs.shape, padded_inputs.shape

((8, 1), (2, 4, 1))

In [10]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.)) + DiagKernel(ExpKernel(jnp.array(2.5)))

In [11]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)
post_mean.shape, post_cov.shape

((8,), (8, 8))

In [12]:
mean_llh_old = magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, jnp.full(all_inputs.squeeze().shape, prior_mean), post_cov, None, jitter=jitter)
mean_llh_old.item()

1784.1385498046875

In [13]:
mean_llh_new = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None)
mean_llh_new.item()

1784.1385498046875

In [14]:
jnp.allclose(mean_llh_old, mean_llh_new)

Array(True, dtype=bool)

In [15]:
%timeit magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, jitter=jitter).block_until_ready()

13.1 μs ± 721 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [16]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None).block_until_ready()

17.4 μs ± 2.68 μs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [17]:
task_llhs_old = magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, jitter=jitter)
np.asarray(task_llhs_old)

array([ 979.3787, 1425.9596], dtype=float32)

In [18]:
task_llhs_new = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings)
np.asarray(task_llhs_new)

array([ 979.3787, 1425.9596], dtype=float32)

In [19]:
jnp.allclose(task_llhs_old, task_llhs_new)

Array(True, dtype=bool)

In [20]:
%timeit magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, jitter=jitter)[0].block_until_ready()

77.6 μs ± 6.05 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [21]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings)[0].block_until_ready()

69.8 μs ± 1.37 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### shared Input, shared HP

In [22]:
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_shared_hp.csv")

padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

prior_mean = jnp.array(0)
all_inputs.shape, padded_inputs.shape

((150, 1), (200, 150, 1))

In [23]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.)) + DiagKernel(ExpKernel(jnp.array(2.5)))

In [24]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)
post_mean.shape, post_cov.shape

((150,), (150, 150))

In [25]:
mean_llh_old = magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, jnp.full(all_inputs.squeeze().shape, prior_mean), post_cov, None, jitter=jitter)
mean_llh_old.item()

872183.5

In [26]:
mean_llh_new = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None)
mean_llh_new.item()

872183.5

In [27]:
jnp.allclose(mean_llh_old, mean_llh_new)

Array(True, dtype=bool)

In [28]:
%timeit magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, jitter=jitter).block_until_ready()

306 μs ± 5.07 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [29]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None).block_until_ready()

324 μs ± 11.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [30]:
task_llhs_old = magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, jitter=jitter)
np.asarray(task_llhs_old)

array([443.6553 , 454.38562, 457.5374 , 455.47693, 462.87177, 463.13123,
       455.5172 , 455.21786, 445.64392, 448.64728, 454.41046, 455.07117,
       452.42175, 458.5739 , 466.3769 , 443.49414, 446.15723, 461.97308,
       450.1917 , 459.15515, 457.86658, 447.43152, 454.0028 , 447.9323 ,
       447.3684 , 458.4273 , 452.39642, 453.87402, 458.4544 , 441.76248,
       444.64832, 446.22702, 448.77118, 445.6949 , 439.18494, 453.7881 ,
       465.7704 , 458.00897, 461.81604, 464.71808, 453.14685, 460.88922,
       454.30603, 446.512  , 458.84503, 455.2395 , 447.42224, 463.80176,
       454.82025, 450.68652, 448.46313, 465.56555, 463.21326, 445.2347 ,
       460.45862, 454.88947, 448.9272 , 455.2377 , 445.3708 , 469.69214,
       459.364  , 463.89844, 456.41785, 460.37604, 458.02527, 432.56708,
       447.81824, 458.10626, 446.43527, 436.3448 , 470.3333 , 449.6783 ,
       455.79224, 454.3545 , 460.23633, 442.2329 , 455.8617 , 456.8969 ,
       458.90454, 460.1977 , 456.53622, 452.1601 , 

In [31]:
task_llhs_new = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings)
np.asarray(task_llhs_new)

array([443.6553 , 454.38562, 457.5374 , 455.47693, 462.87177, 463.13123,
       455.5172 , 455.21786, 445.64392, 448.64728, 454.41046, 455.07117,
       452.42175, 458.5739 , 466.3769 , 443.49414, 446.15723, 461.97308,
       450.1917 , 459.15515, 457.86658, 447.43152, 454.0028 , 447.9323 ,
       447.3684 , 458.4273 , 452.39642, 453.87402, 458.4544 , 441.76248,
       444.64832, 446.22702, 448.77118, 445.6949 , 439.18494, 453.7881 ,
       465.7704 , 458.00897, 461.81604, 464.71808, 453.14685, 460.88922,
       454.30603, 446.512  , 458.84503, 455.2395 , 447.42224, 463.80176,
       454.82025, 450.68652, 448.46313, 465.56555, 463.21326, 445.2347 ,
       460.45862, 454.88947, 448.9272 , 455.2377 , 445.3708 , 469.69214,
       459.364  , 463.89844, 456.41785, 460.37604, 458.02527, 432.56708,
       447.81824, 458.10626, 446.43527, 436.3448 , 470.3333 , 449.6783 ,
       455.79224, 454.3545 , 460.23633, 442.2329 , 455.8617 , 456.8969 ,
       458.90454, 460.1977 , 456.53622, 452.1601 , 

In [32]:
jnp.allclose(task_llhs_old, task_llhs_new)

Array(True, dtype=bool)

In [33]:
%timeit magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, jitter=jitter)[0].block_until_ready()

33.5 ms ± 3.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [34]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings)[0].block_until_ready()

31.5 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### shared Input, Distinct HP

In [35]:
db = pd.read_csv(f"../datasets/{test_db_size}_shared_input_distinct_hp.csv")

padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

prior_mean = jnp.array(0)
all_inputs.shape, padded_inputs.shape

((150, 1), (200, 150, 1))

In [36]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.)) + DiagKernel(ExpKernel(jnp.array(2.5)))

  distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float64, .1, 1)


In [37]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)
post_mean.shape, post_cov.shape

((150,), (150, 150))

In [38]:
mean_llh_old = magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, jnp.full(all_inputs.squeeze().shape, prior_mean), post_cov, None, jitter=jitter)
mean_llh_old.item()

10192.369140625

In [39]:
mean_llh_new = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None)
mean_llh_new.item()

10192.369140625

In [40]:
jnp.allclose(mean_llh_old, mean_llh_new)

Array(True, dtype=bool)

In [41]:
%timeit magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, jitter=jitter).block_until_ready()

304 μs ± 8.56 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [42]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None).block_until_ready()

335 μs ± 19.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [43]:
task_llhs_old = magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, jitter=jitter)
np.asarray(task_llhs_old)

array([368.96872, 350.56775, 384.8568 , 353.88297, 380.07193, 358.41577,
       357.34866, 353.21313, 353.47253, 372.61047, 361.04248, 354.94385,
       377.4471 , 354.91913, 357.47842, 357.51373, 366.86664, 358.49194,
       368.3207 , 366.55603, 350.9216 , 355.50748, 356.96408, 363.69537,
       359.62677, 356.46295, 368.06085, 363.49857, 353.8957 , 351.7763 ,
       371.24283, 355.0993 , 360.23676, 363.94568, 357.47717, 358.37836,
       363.0932 , 366.7271 , 352.73233, 361.1225 , 354.43344, 353.2872 ,
       351.8119 , 354.34683, 360.62082, 357.78052, 358.61105, 352.29794,
       363.7009 , 360.263  , 359.9521 , 358.13187, 358.04987, 361.21466,
       359.20905, 366.4635 , 351.26895, 362.3736 , 363.09778, 356.1594 ,
       355.89258, 373.53943, 358.28003, 358.01212, 357.05145, 368.40247,
       359.0025 , 354.35956, 373.04333, 351.71677, 362.16464, 374.7091 ,
       355.91803, 351.41528, 366.36914, 358.41656, 362.5278 , 359.21255,
       370.31894, 375.80215, 357.18744, 362.4392 , 

In [44]:
task_llhs_new = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings)
np.asarray(task_llhs_new)

array([368.96872, 350.56775, 384.8568 , 353.88297, 380.07193, 358.41577,
       357.34866, 353.21313, 353.47253, 372.61047, 361.04248, 354.94385,
       377.4471 , 354.91913, 357.47842, 357.51373, 366.86664, 358.49194,
       368.3207 , 366.55603, 350.9216 , 355.50748, 356.96408, 363.69537,
       359.62677, 356.46295, 368.06085, 363.49857, 353.8957 , 351.7763 ,
       371.24283, 355.0993 , 360.23676, 363.94568, 357.47717, 358.37836,
       363.0932 , 366.7271 , 352.73233, 361.1225 , 354.43344, 353.2872 ,
       351.8119 , 354.34683, 360.62082, 357.78052, 358.61105, 352.29794,
       363.7009 , 360.263  , 359.9521 , 358.13187, 358.04987, 361.21466,
       359.20905, 366.4635 , 351.26895, 362.3736 , 363.09778, 356.1594 ,
       355.89258, 373.53943, 358.28003, 358.01212, 357.05145, 368.40247,
       359.0025 , 354.35956, 373.04333, 351.71677, 362.16464, 374.7091 ,
       355.91803, 351.41528, 366.36914, 358.41656, 362.5278 , 359.21255,
       370.31894, 375.80215, 357.18744, 362.4392 , 

In [45]:
jnp.allclose(task_llhs_old, task_llhs_new)

Array(True, dtype=bool)

In [46]:
%timeit magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, jitter=jitter)[0].block_until_ready()

34.1 ms ± 2.91 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [47]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings)[0].block_until_ready()

31.1 ms ± 2.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Distinct Input, shared HP

In [48]:
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_shared_hp.csv")

padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

prior_mean = jnp.array(0)
all_inputs.shape, padded_inputs.shape

((401, 1), (200, 190, 1))

In [49]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.)) + DiagKernel(ExpKernel(jnp.array(2.5)))

In [50]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)
post_mean.shape, post_cov.shape

((401,), (401, 401))

In [51]:
mean_llh_old = magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, jnp.full(all_inputs.squeeze().shape, prior_mean), post_cov, None, jitter=jitter)
mean_llh_old.item()

nan

In [52]:
mean_llh_new = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None)
mean_llh_new.item()

inf

In [53]:
jnp.allclose(mean_llh_old, mean_llh_new)

Array(False, dtype=bool)

In [54]:
%timeit magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, jitter=jitter).block_until_ready()

2.67 ms ± 230 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [55]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None).block_until_ready()

3.44 ms ± 216 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [56]:
task_llhs_old = magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, jitter=jitter)
np.asarray(task_llhs_old)

array([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, in

In [57]:
task_llhs_new = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings)
np.asarray(task_llhs_new)

array([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, in

In [58]:
jnp.allclose(task_llhs_old, task_llhs_new)

Array(True, dtype=bool)

In [59]:
%timeit magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, jitter=jitter)[0].block_until_ready()

56.4 ms ± 1.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [60]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings)[0].block_until_ready()

59.5 ms ± 4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Distinct Input, Distinct HP

In [61]:
db = pd.read_csv(f"../datasets/{test_db_size}_distinct_input_distinct_hp.csv")

padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

prior_mean = jnp.array(0)
all_inputs.shape, padded_inputs.shape

((401, 1), (200, 188, 1))

In [62]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (padded_outputs.shape[0],), jnp.float32, .1, 1)
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.)) + DiagKernel(ExpKernel(jnp.array(2.5)))

In [63]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)
post_mean.shape, post_cov.shape

((401,), (401, 401))

In [64]:
mean_llh_old = magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, jnp.full(all_inputs.squeeze().shape, prior_mean), post_cov, None, jitter=jitter)
mean_llh_old.item()

nan

In [65]:
mean_llh_new = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None)
mean_llh_new.item()

inf

In [66]:
jnp.allclose(mean_llh_old, mean_llh_new)

Array(False, dtype=bool)

In [67]:
%timeit magma_neg_likelihood_old(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, jitter=jitter).block_until_ready()

2.4 ms ± 110 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [68]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None).block_until_ready()

3.21 ms ± 125 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [69]:
task_llhs_old = magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, jitter=jitter)
np.asarray(task_llhs_old)

array([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, in

In [70]:
task_llhs_new = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings)
np.asarray(task_llhs_new)

array([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
       inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, in

In [71]:
jnp.allclose(task_llhs_old, task_llhs_new)

Array(True, dtype=bool)

In [72]:
%timeit magma_neg_likelihood_old(task_kern, padded_inputs, padded_outputs.squeeze(), post_mean, post_cov, mappings, jitter=jitter)[0].block_until_ready()

58.3 ms ± 4.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [73]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings)[0].block_until_ready()

61.7 ms ± 4.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


---