# Benchmarks - Likelihoods

**Main considerations when implementing Likelihoods**
- must be efficient, as they are evaluated in the main bottleneck of the code: hyperparameter (HP) optimisation
- must work for both common HPs (shared by all tasks) and distinct HPs (one set per task)

---
## Setup

In [1]:
# Standard library
import os

# Ask JAX for 64-bit precision. This is set here, before the cell that imports jax,
# so the flag is already in the environment when JAX is first loaded.
os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.stats.multivariate_normal import logpdf

import pandas as pd
import numpy as np

In [3]:
# Local
from MagmaClustPy.hyperpost import hyperpost
from MagmaClustPy.kernels import SEMagmaKernel
from MagmaClustPy.utils import preprocess_db, extract_from_full_array, extract_from_full_matrix

from MagmaClustPy.legacy import _legacy_preprocess_db

In [4]:
# Config
# Jitter value forwarded as `nugget=` to every likelihood call in this notebook.
nugget=jnp.array(1e-10)

---
## Data

In [37]:
# NOTE(review): this cell's execution count (37) is out of sequence with its neighbours;
# re-run the notebook top to bottom to confirm the recorded outputs still reproduce.
# Size prefix of the CSV files read from ../dummy_datasets/ in the comparison sections.
test_db_size = "medium"
# Root PRNG key; it is split before drawing the length scales in the distinct-HP sections.
key = jax.random.PRNGKey(42)

---
## Current implementation

In [6]:
from MagmaClustPy.likelihoods import magma_neg_likelihood

---
## Custom implementation(s)

In [7]:
import jax.numpy as jnp
from jax import jit, vmap
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.stats.multivariate_normal import logpdf


@jit
def solve_right_cholesky(A, B, nugget=jnp.array(1e-10)):
	"""
	Solves for X in X @ A = B, where A is symmetric positive definite.

	:param A: symmetric positive-definite matrix (shape (N, N))
	:param B: right-hand side (shape (M, N)); it does not need to be symmetric
	:param nugget: value added to the diagonal of A before factorisation, for numerical stability

	:return: X such that X @ (A + nugget * I) = B (shape (M, N))
	"""
	# Transposing X @ A = B gives A.T @ X.T = B.T, and A.T = A because A is symmetric.
	# So we solve A @ X.T = B.T with the Cholesky factor of A and transpose the solution back.
	# Passing B.T (rather than B) keeps the result correct when B is not symmetric;
	# for a symmetric B the two are identical.
	nugget_matrix = jnp.eye(A.shape[0]) * nugget
	return cho_solve(cho_factor(A + nugget_matrix), B.T).T


@jit
def magma_neg_likelihood_on_cov(covar, outputs, mean, mean_process_cov, mapping, nugget=jnp.array(1e-10)):
	"""
	Computes the MAGMA negative log-likelihood of one set of outputs from an already-computed covariance.

	NaN entries in `covar` and `outputs` are treated as padding and their contribution is removed
	from the result.

	:param covar: covariance matrix of the outputs; may contain NaNs where padded
	:param outputs: the observed values; may contain NaNs where padded
	:param mean: the mean; used as given when `mapping` is None, otherwise first passed through
		`extract_from_full_array(mean, outputs, mapping)`
	:param mean_process_cov: the mean process covariance; used as given when `mapping` is None, otherwise
		first passed through `extract_from_full_matrix(mean_process_cov, outputs, mapping)`
	:param mapping: mapping forwarded to the extraction helpers, or None to skip the extraction
	:param nugget: value added to the diagonal of the covariance, for numerical stability

	:return: the negative log-likelihood (scalar)
	"""
	pad_mask = jnp.isnan(covar)
	identity = jnp.eye(covar.shape[0])
	n_padded = jnp.sum(jnp.isnan(outputs))

	# Align the mean and the mean process covariance with the outputs when a mapping is given
	if mapping is None:
		task_mean = mean
		task_mean_cov = mean_process_cov
	else:
		task_mean = extract_from_full_array(mean, outputs, mapping)
		task_mean_cov = extract_from_full_matrix(mean_process_cov, outputs, mapping)

	# Neutralise padding: NaN vector entries become 0, and matrix entries take the value of the
	# identity wherever `covar` is NaN
	safe_outputs = jnp.nan_to_num(outputs)
	safe_mean = jnp.nan_to_num(task_mean)
	safe_covar = jnp.where(pad_mask, identity, covar)
	safe_mean_cov = jnp.where(pad_mask, identity, task_mean_cov)

	# Gaussian negative log-likelihood on the padded system
	jitter = jnp.eye(outputs.shape[0]) * nugget
	gaussian_nll = -logpdf(safe_outputs, safe_mean, safe_covar + jitter)

	# Trace correction term: 0.5 * tr(mean_cov @ inv(covar))
	trace_term = 0.5 * jnp.trace(solve_right_cholesky(safe_covar, safe_mean_cov, nugget=nugget))

	# Padding inflates the dimension N used in the 0.5 * N * log(2 * pi) term of the logpdf;
	# the identity padding leaves log(det(cov)) and the Mahalanobis distance unchanged
	log_2pi_excess = 0.5 * jnp.log(2 * jnp.pi) * n_padded

	# Each padded entry also puts a 1 on the diagonal, hence adds 1 to the trace
	trace_excess = 0.5 * n_padded

	return (gaussian_nll - log_2pi_excess) + (trace_term - trace_excess)


@jit
def magma_neg_likelihood_new(kernel, inputs, outputs: jnp.ndarray, mean: jnp.ndarray, mean_process_cov: jnp.ndarray, mappings: jnp.ndarray, nugget=jnp.array(1e-10)):
	"""
	Computes the MAGMA negative log-likelihood.

	:param kernel: the kernel containing HPs to optimise. This kernel is used to compute the covariance (matrix `S`)
	:param inputs: inputs on which to compute the covariance matrix (shape (N, ), or (T, N) for a batch of T tasks)
	:param outputs: the observed values (shape (N, ), or (T, N) when `inputs` is 2D)
	:param mean: the mean over the inputs (scalar or vector); shared by all tasks when `inputs` is 2D
	:param mean_process_cov: the hyper-posterior mean process covariance (matrix K^t); shared by all tasks when
		`inputs` is 2D
	:param mappings: mappings forwarded to `magma_neg_likelihood_on_cov` (batched along axis 0 when `inputs`
		is 2D), or None to use `mean` and `mean_process_cov` as given
	:param nugget: the nugget, for numerical stability

	:return: the negative log-likelihood (scalar for 1D inputs, one value per task for 2D inputs)

	:raises ValueError: if `inputs` is neither 1D nor 2D
	"""
	# Validate the rank before evaluating the kernel, so that an unsupported shape always fails
	# with this error rather than with whatever the kernel raises on it
	if inputs.ndim not in (1, 2):
		raise ValueError("inputs must be either 1D or 2D")

	covar = kernel(inputs)

	# Single task: no batching needed
	if inputs.ndim == 1:
		return magma_neg_likelihood_on_cov(covar, outputs, mean, mean_process_cov, mappings, nugget)

	# Batch of tasks: map over axis 0 of the covariances, outputs and mappings,
	# while the mean, the mean process covariance and the nugget are shared
	return vmap(magma_neg_likelihood_on_cov, in_axes=(0, 0, None, None, 0, None))(covar, outputs, mean, mean_process_cov, mappings, nugget)

---
## Comparison

### Sample test

In [8]:
# Toy database: two IDs (1 and 2) with four observations each, observed at different inputs.
db = pd.DataFrame({
	'ID': [1, 1, 1, 1, 2, 2, 2, 2],
	'Input': [0.40, 4.45, 7.60, 8.30, 3.50, 5.10, 8.85, 9.35],
	'Output': [59.81620, 67.13694, 78.32495, 81.83590, 62.04943, 67.31932, 85.94063, 86.76426]
})

# Preprocess with both pipelines: the legacy one yields `masks` (passed as `mask=` to the current
# likelihood), the new one yields `mappings` (passed to the custom likelihood).
_legacy_all_inputs, _legacy_padded_inputs, _legacy_padded_outputs, masks = _legacy_preprocess_db(db)
all_inputs, padded_inputs, padded_outputs, mappings = preprocess_db(db)

# Zero prior mean with the same shape as `all_inputs`.
prior_mean = jnp.zeros_like(all_inputs)
all_inputs.shape, padded_inputs.shape

((8,), (2, 4))

In [9]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

In [10]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)
post_mean.shape, post_cov.shape

((8,), (8, 8))

In [11]:
_legacy_mean_llh = magma_neg_likelihood(mean_kern, _legacy_all_inputs, post_mean, prior_mean, post_cov, nugget=nugget)
_legacy_mean_llh.item()

1215.4708686599558

In [12]:
mean_llh = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget)
mean_llh.item()

1215.4708686599558

In [13]:
jnp.allclose(_legacy_mean_llh, mean_llh)

Array(True, dtype=bool)

In [14]:
%timeit magma_neg_likelihood(mean_kern, _legacy_all_inputs, post_mean, prior_mean, post_cov, nugget=nugget).block_until_ready()

11.6 μs ± 271 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [15]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

14.2 μs ± 217 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [16]:
_legacy_task_llhs = magma_neg_likelihood(task_kern, _legacy_padded_inputs, _legacy_padded_outputs, post_mean, post_cov, mask=masks, nugget=nugget)
np.asarray(_legacy_task_llhs)

array([443.00084239, 405.85107611])

In [17]:
task_llhs = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs)

array([443.00084239, 405.85107611])

In [18]:
jnp.allclose(_legacy_task_llhs, task_llhs)

Array(True, dtype=bool)

In [19]:
%timeit magma_neg_likelihood(task_kern, _legacy_padded_inputs, _legacy_padded_outputs, post_mean, post_cov, mask=masks, nugget=nugget)[0].block_until_ready()

63.9 μs ± 486 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [20]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)[0].block_until_ready()

61.8 μs ± 418 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### Common Input, Common HP

In [21]:
db = pd.read_csv(f"../dummy_datasets/{test_db_size}_common_input_common_hp.csv")

_legacy_all_inputs, _legacy_padded_inputs, _legacy_padded_outputs, masks = _legacy_preprocess_db(db)
all_inputs, padded_inputs, padded_outputs, mappings = preprocess_db(db)

prior_mean = jnp.zeros_like(all_inputs)
all_inputs.shape, padded_inputs.shape

((150,), (200, 150))

In [22]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

In [23]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)
post_mean.shape, post_cov.shape

((150,), (150, 150))

In [24]:
# Same hyper-posterior call as the previous cell, but with a scalar prior mean (jnp.array(0.)).
# NOTE(review): mean_2 / cov_2 are not used by any later cell visible here — confirm whether this cell is still needed.
mean_2, cov_2 = hyperpost(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [25]:
_legacy_mean_llh = magma_neg_likelihood(mean_kern, _legacy_all_inputs, post_mean, prior_mean, post_cov, nugget=nugget)
_legacy_mean_llh.item()

912067.9333631559

In [26]:
mean_llh = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget)
mean_llh.item()

912067.9333631559

In [27]:
jnp.allclose(_legacy_mean_llh, mean_llh)

Array(True, dtype=bool)

In [28]:
%timeit magma_neg_likelihood(mean_kern, _legacy_all_inputs, post_mean, prior_mean, post_cov, nugget=nugget).block_until_ready()

394 μs ± 13.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [29]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

423 μs ± 8.15 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [30]:
_legacy_task_llhs = magma_neg_likelihood(task_kern, _legacy_padded_inputs, _legacy_padded_outputs, post_mean, post_cov, mask=masks, nugget=nugget)
np.asarray(_legacy_task_llhs)

array([ 605.23211278,  343.51185004,  918.80977743,  513.24387158,
        697.57622048,  951.96126093,  981.29273262, 1269.94340467,
        463.54045028, 1050.45871158,  961.64437667, 2003.26792708,
       1749.46618771, 1312.85185097,  445.83816778,  835.35699651,
       1463.91059083,  427.26689016,  700.89894481,  671.70252754,
        534.18847225,  926.73962846,  661.4772385 ,  308.82867287,
       1376.61098244, 2059.56683147, 2367.20428303,  581.85247446,
        400.1809817 , 1553.99525814,  563.99556873,  847.75774284,
        771.59365077, 1191.41168731,  761.09971567,  956.93556212,
        454.71559202,  424.41881608,  457.05561287, 1187.59726143,
       1127.71906547, 6487.12506158, 1503.98788529, 1000.09652318,
        396.8322629 , 2294.87962713, 3830.93031689,  416.98068956,
        593.50319673, 1251.04797794,  397.32060574,  455.53276014,
        469.18918144, 1671.10147422, 1350.63926341, 1008.98592561,
        662.63825595, 1155.21163564, 1781.42530833,  506.91403

In [31]:
task_llhs = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs)

array([ 605.23211278,  343.51185004,  918.80977743,  513.24387158,
        697.57622048,  951.96126093,  981.29273262, 1269.94340467,
        463.54045028, 1050.45871158,  961.64437667, 2003.26792708,
       1749.46618771, 1312.85185097,  445.83816778,  835.35699651,
       1463.91059083,  427.26689016,  700.89894481,  671.70252754,
        534.18847225,  926.73962846,  661.4772385 ,  308.82867287,
       1376.61098244, 2059.56683147, 2367.20428303,  581.85247446,
        400.1809817 , 1553.99525814,  563.99556873,  847.75774284,
        771.59365077, 1191.41168731,  761.09971567,  956.93556212,
        454.71559202,  424.41881608,  457.05561287, 1187.59726143,
       1127.71906547, 6487.12506158, 1503.98788529, 1000.09652318,
        396.8322629 , 2294.87962713, 3830.93031689,  416.98068956,
        593.50319673, 1251.04797794,  397.32060574,  455.53276014,
        469.18918144, 1671.10147422, 1350.63926341, 1008.98592561,
        662.63825595, 1155.21163564, 1781.42530833,  506.91403

In [32]:
jnp.allclose(_legacy_task_llhs, task_llhs)

Array(True, dtype=bool)

In [33]:
%timeit magma_neg_likelihood(task_kern, _legacy_padded_inputs, _legacy_padded_outputs, post_mean, post_cov, mask=masks, nugget=nugget).block_until_ready()

61.9 ms ± 2.88 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [34]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget).block_until_ready()

62.1 ms ± 2.45 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Common Input, Distinct HP

In [35]:
db = pd.read_csv(f"../dummy_datasets/{test_db_size}_common_input_distinct_hp.csv")

_legacy_all_inputs, _legacy_padded_inputs, _legacy_padded_outputs, masks = _legacy_preprocess_db(db)
all_inputs, padded_inputs, padded_outputs, mappings = preprocess_db(db)

prior_mean = jnp.zeros_like(all_inputs)
all_inputs.shape, padded_inputs.shape

((150,), (200, 150))

In [38]:
# Mean-process kernel with a single set of HPs
mean_kern = SEMagmaKernel(length_scale=0.3, variance=1.0)

# Task kernel with one length scale per row of `padded_outputs`, drawn uniformly in [0.1, 1)
# from a fresh subkey
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(
	subkey,
	shape=(padded_outputs.shape[0],),
	dtype=jnp.float64,
	minval=0.1,
	maxval=1,
)
task_kern = SEMagmaKernel(length_scale=distinct_length_scales, variance=1.0)

In [39]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)
post_mean.shape, post_cov.shape

((150,), (150, 150))

In [40]:
# Same hyper-posterior call as the previous cell, but with a scalar prior mean (jnp.array(0.)).
# NOTE(review): mean_2 / cov_2 are not used by any later cell visible here — confirm whether this cell is still needed.
mean_2, cov_2 = hyperpost(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [41]:
_legacy_mean_llh = magma_neg_likelihood(mean_kern, _legacy_all_inputs, post_mean, prior_mean, post_cov, nugget=nugget)
_legacy_mean_llh.item()

10841.32844777948

In [42]:
mean_llh = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget)
mean_llh.item()

10841.32844777948

In [43]:
jnp.allclose(_legacy_mean_llh, mean_llh)

Array(True, dtype=bool)

In [44]:
%timeit magma_neg_likelihood(mean_kern, _legacy_all_inputs, post_mean, prior_mean, post_cov, nugget=nugget).block_until_ready()

370 μs ± 6.36 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [45]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

417 μs ± 14.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [46]:
_legacy_task_llhs = magma_neg_likelihood(task_kern, _legacy_padded_inputs, _legacy_padded_outputs, post_mean, post_cov, mask=masks, nugget=nugget)
np.asarray(_legacy_task_llhs)

array([7.42752083e+06, 7.39515317e+04, 2.38880759e+03, 2.13662388e+04,
       1.28303513e+06, 8.92513826e+02, 1.67626051e+04, 1.88331096e+03,
       1.02485976e+05, 4.18526467e+02, 3.85688142e+02, 9.88478320e+04,
       1.52705574e+06, 1.29254404e+06, 1.12086683e+03, 4.56577939e+03,
       9.66140077e+06, 1.94997496e+07, 1.06940088e+06, 1.94067773e+04,
       3.41712434e+03, 1.44115752e+06, 2.68515133e+06, 3.17008131e+04,
       2.70968708e+05, 1.46381934e+06, 2.52224037e+05, 1.69374190e+07,
       1.49539519e+04, 3.37590899e+06, 8.01090604e+03, 2.18886997e+06,
       2.96718690e+02, 5.87446719e+05, 2.47186046e+04, 9.88561293e+06,
       2.28354104e+04, 4.14281567e+02, 6.48085619e+02, 2.35004480e+04,
       2.65077600e+03, 1.59778036e+03, 1.84399764e+03, 1.72859162e+07,
       6.60580087e+03, 2.04610046e+06, 2.25396858e+03, 5.08897147e+03,
       4.13002073e+06, 1.08157645e+07, 2.37473111e+04, 3.78983455e+05,
       4.95973686e+04, 9.85670171e+06, 2.28536600e+03, 7.92248346e+07,
      

In [47]:
task_llhs = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs)

array([7.42752083e+06, 7.39515317e+04, 2.38880759e+03, 2.13662388e+04,
       1.28303513e+06, 8.92513826e+02, 1.67626051e+04, 1.88331096e+03,
       1.02485976e+05, 4.18526467e+02, 3.85688142e+02, 9.88478320e+04,
       1.52705574e+06, 1.29254404e+06, 1.12086683e+03, 4.56577939e+03,
       9.66140077e+06, 1.94997496e+07, 1.06940088e+06, 1.94067773e+04,
       3.41712434e+03, 1.44115752e+06, 2.68515133e+06, 3.17008131e+04,
       2.70968708e+05, 1.46381934e+06, 2.52224037e+05, 1.69374190e+07,
       1.49539519e+04, 3.37590899e+06, 8.01090604e+03, 2.18886997e+06,
       2.96718690e+02, 5.87446719e+05, 2.47186046e+04, 9.88561293e+06,
       2.28354104e+04, 4.14281567e+02, 6.48085619e+02, 2.35004480e+04,
       2.65077600e+03, 1.59778036e+03, 1.84399764e+03, 1.72859162e+07,
       6.60580087e+03, 2.04610046e+06, 2.25396858e+03, 5.08897147e+03,
       4.13002073e+06, 1.08157645e+07, 2.37473111e+04, 3.78983455e+05,
       4.95973686e+04, 9.85670171e+06, 2.28536600e+03, 7.92248346e+07,
      

In [48]:
jnp.allclose(_legacy_task_llhs, task_llhs)

Array(True, dtype=bool)

In [49]:
%timeit magma_neg_likelihood(task_kern, _legacy_padded_inputs, _legacy_padded_outputs, post_mean, post_cov, mask=masks, nugget=nugget).block_until_ready()

58.7 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [50]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget).block_until_ready()

61.8 ms ± 737 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Distinct Input, Common HP

In [51]:
db = pd.read_csv(f"../dummy_datasets/{test_db_size}_distinct_input_common_hp.csv")

_legacy_all_inputs, _legacy_padded_inputs, _legacy_padded_outputs, masks = _legacy_preprocess_db(db)
all_inputs, padded_inputs, padded_outputs, mappings = preprocess_db(db)

prior_mean = jnp.zeros_like(all_inputs)
all_inputs.shape, padded_inputs.shape

((401,), (200, 200))

In [52]:
mean_kern = SEMagmaKernel(length_scale=.3, variance=1.)
task_kern = SEMagmaKernel(length_scale=.6, variance=1.)

In [53]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)
post_mean.shape, post_cov.shape

((401,), (401, 401))

In [54]:
# Same hyper-posterior call as the previous cell, but with a scalar prior mean (jnp.array(0.)).
# NOTE(review): mean_2 / cov_2 are not used by any later cell visible here — confirm whether this cell is still needed.
mean_2, cov_2 = hyperpost(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [55]:
_legacy_mean_llh = magma_neg_likelihood(mean_kern, _legacy_all_inputs, post_mean, prior_mean, post_cov, nugget=nugget)
_legacy_mean_llh.item()

100728.89680477855

In [56]:
mean_llh = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget)
mean_llh.item()

100728.89680477855

In [57]:
jnp.allclose(_legacy_mean_llh, mean_llh)

Array(True, dtype=bool)

In [58]:
%timeit magma_neg_likelihood(mean_kern, _legacy_all_inputs, post_mean, prior_mean, post_cov, nugget=nugget).block_until_ready()

2.22 ms ± 24.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [59]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

2.1 ms ± 57.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [60]:
_legacy_task_llhs = magma_neg_likelihood(task_kern, _legacy_padded_inputs, _legacy_padded_outputs, post_mean, post_cov, mask=masks, nugget=nugget)
np.asarray(_legacy_task_llhs)

array([1620.59576247, 1559.10140927, 1843.65037662, 2288.065937  ,
       2032.08888116, 3306.10935403, 1457.61282617, 2635.59986815,
       1822.09007064, 1727.60431391, 1281.3295019 , 2309.62716782,
       1845.11046147, 3012.58117582, 2394.85447177, 3082.20699176,
       1841.36304504, 2112.71134361, 2336.32898443, 3344.00452381,
       3775.46732848, 2988.26350803, 2314.58691239, 1766.78388941,
       3016.68748913, 1386.59116452, 1216.53230809, 1895.01858391,
       1890.30950937, 1523.81115934, 1358.36132817, 2714.32529714,
       1924.16628601, 1743.24967547, 4943.64719125, 1908.69540937,
       2329.56996731, 1872.99675956, 2305.43519083, 2008.53597867,
       1855.17807788, 2476.09159329, 2201.07184333, 3035.32641757,
       2280.16340875, 1566.30028072, 3734.98244   , 2404.75166391,
       2016.16924355, 2948.63941793, 3074.66352591, 1527.08189636,
       1808.57486048, 3315.9825258 , 1725.41041715, 2579.06820098,
       1148.48103476, 3113.60767788, 1601.05658351, 1633.44100

In [61]:
task_llhs = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs)

array([1620.59576255, 1559.1014091 , 1843.6503761 , 2288.0659359 ,
       2032.08888339, 3306.10935727, 1457.61282614, 2635.59986817,
       1822.09007057, 1727.60431388, 1281.32950187, 2309.62716824,
       1845.11046159, 3012.58118527, 2394.85447362, 3082.20698962,
       1841.36304531, 2112.71136897, 2336.32898398, 3344.0045241 ,
       3775.46734683, 2988.26351032, 2314.5869128 , 1766.78388949,
       3016.68748927, 1386.59116447, 1216.53230814, 1895.01858411,
       1890.30948867, 1523.81115805, 1358.36132818, 2714.32529729,
       1924.16628702, 1743.24967547, 4943.6471951 , 1908.6954095 ,
       2329.56996788, 1872.99675979, 2305.43519089, 2008.53597867,
       1855.1780779 , 2476.09158825, 2201.07184315, 3035.32641787,
       2280.16340632, 1566.3002807 , 3734.98244038, 2404.75166492,
       2016.1692437 , 2948.63942009, 3074.66352646, 1527.08189633,
       1808.57485966, 3315.98250197, 1725.41041715, 2579.06819962,
       1148.48103476, 3113.60767892, 1601.05658555, 1633.44100

In [62]:
jnp.allclose(_legacy_task_llhs, task_llhs)

Array(True, dtype=bool)

In [63]:
%timeit magma_neg_likelihood(task_kern, _legacy_padded_inputs, _legacy_padded_outputs, post_mean, post_cov, mask=masks, nugget=nugget).block_until_ready()

384 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [64]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget).block_until_ready()

102 ms ± 1.77 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Distinct Input, Distinct HP

In [65]:
db = pd.read_csv(f"../dummy_datasets/{test_db_size}_distinct_input_distinct_hp.csv")

_legacy_all_inputs, _legacy_padded_inputs, _legacy_padded_outputs, masks = _legacy_preprocess_db(db)
all_inputs, padded_inputs, padded_outputs, mappings = preprocess_db(db)

prior_mean = jnp.zeros_like(all_inputs)
all_inputs.shape, padded_inputs.shape

((401,), (200, 200))

In [66]:
# Mean-process kernel with a single set of HPs
mean_kern = SEMagmaKernel(length_scale=0.3, variance=1.0)

# Task kernel with one length scale per row of `padded_outputs`, drawn uniformly in [0.1, 1)
# from a fresh subkey
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(
	subkey,
	shape=(padded_outputs.shape[0],),
	dtype=jnp.float64,
	minval=0.1,
	maxval=1,
)
task_kern = SEMagmaKernel(length_scale=distinct_length_scales, variance=1.0)

In [67]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, mappings, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)
post_mean.shape, post_cov.shape

((401,), (401, 401))

In [68]:
# Same hyper-posterior call as the previous cell, but with a scalar prior mean (jnp.array(0.)).
# NOTE(review): mean_2 / cov_2 are not used by any later cell visible here — confirm whether this cell is still needed.
mean_2, cov_2 = hyperpost(padded_inputs, padded_outputs, mappings, jnp.array(0.), mean_kern, task_kern, all_inputs=all_inputs)

In [69]:
_legacy_mean_llh = magma_neg_likelihood(mean_kern, _legacy_all_inputs, post_mean, prior_mean, post_cov, nugget=nugget)
_legacy_mean_llh.item()

281229024.390255

In [70]:
mean_llh = magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget)
mean_llh.item()

281229024.390255

In [71]:
jnp.allclose(_legacy_mean_llh, mean_llh)

Array(True, dtype=bool)

In [72]:
%timeit magma_neg_likelihood(mean_kern, _legacy_all_inputs, post_mean, prior_mean, post_cov, nugget=nugget).block_until_ready()

2.3 ms ± 37.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [73]:
%timeit magma_neg_likelihood_new(mean_kern, all_inputs, post_mean, prior_mean, post_cov, None, nugget=nugget).block_until_ready()

2.11 ms ± 52.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [74]:
_legacy_task_llhs = magma_neg_likelihood(task_kern, _legacy_padded_inputs, _legacy_padded_outputs, post_mean, post_cov, mask=masks, nugget=nugget)
np.asarray(_legacy_task_llhs)

array([2.15528394e+06, 1.41433905e+06, 7.90032583e+07, 4.57753932e+04,
       4.20217934e+05, 5.87935376e+07, 4.18657825e+04, 2.83264963e+05,
       2.82857465e+04, 9.96957243e+04, 4.03959868e+06, 8.02133040e+04,
       6.54473691e+07, 9.19280924e+06, 6.27233467e+04, 8.54088913e+06,
       1.06232489e+07, 1.07835547e+07, 2.18737166e+05, 4.00450799e+05,
       3.33860836e+04, 4.47344829e+05, 1.60443868e+04, 2.19676429e+06,
       1.09338126e+05, 3.48773764e+04, 1.15695060e+04, 4.69819227e+04,
       4.57103413e+05, 9.60006070e+05, 1.15119247e+04, 1.22195331e+07,
       7.28752811e+05, 6.43583180e+07, 7.22727943e+05, 9.06446405e+03,
       2.54677435e+04, 1.15018060e+07, 3.65203271e+06, 2.04149322e+05,
       1.34748586e+05, 3.83383994e+04, 4.15220087e+05, 2.86688915e+04,
       6.75913288e+03, 1.31192934e+05, 3.98962749e+07, 2.05650698e+07,
       6.03562080e+06, 8.20556616e+05, 4.01106211e+05, 7.79667751e+05,
       1.27053341e+07, 1.45374718e+05, 4.93995970e+05, 1.69143564e+07,
      

In [75]:
task_llhs = magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget)
np.asarray(task_llhs)

array([2.15528394e+06, 1.41433906e+06, 7.90032936e+07, 4.57753932e+04,
       4.20217934e+05, 5.87935374e+07, 4.18657825e+04, 2.83264963e+05,
       2.82857465e+04, 9.96957242e+04, 4.03959859e+06, 8.02133040e+04,
       6.54473757e+07, 9.19280677e+06, 6.27233467e+04, 8.54089653e+06,
       1.06232489e+07, 1.07835546e+07, 2.18737166e+05, 4.00450799e+05,
       3.33860836e+04, 4.47344829e+05, 1.60443868e+04, 2.19676429e+06,
       1.09338126e+05, 3.48773764e+04, 1.15695060e+04, 4.69819227e+04,
       4.57103413e+05, 9.60006070e+05, 1.15119247e+04, 1.22195336e+07,
       7.28752811e+05, 6.43583154e+07, 7.22727943e+05, 9.06446405e+03,
       2.54677435e+04, 1.15018064e+07, 3.65203272e+06, 2.04149322e+05,
       1.34748586e+05, 3.83383994e+04, 4.15220087e+05, 2.86688915e+04,
       6.75913288e+03, 1.31192934e+05, 3.98962713e+07, 2.05650677e+07,
       6.03562043e+06, 8.20556616e+05, 4.01106214e+05, 7.79667754e+05,
       1.27053338e+07, 1.45374718e+05, 4.93995971e+05, 1.69143562e+07,
      

In [76]:
jnp.allclose(_legacy_task_llhs, task_llhs)

Array(True, dtype=bool)

In [77]:
%timeit magma_neg_likelihood(task_kern, _legacy_padded_inputs, _legacy_padded_outputs, post_mean, post_cov, mask=masks, nugget=nugget).block_until_ready()

378 ms ± 11.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [78]:
%timeit magma_neg_likelihood_new(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mappings, nugget=nugget).block_until_ready()

100 ms ± 1.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


---