# Benchmarks - Likelihoods

**Main considerations when implementing Likelihoods**
- must be efficient, as they are used in the main bottleneck of the code: HP optimisation
- must work for both common and distinct HPs

---
## Setup

In [1]:
# Standard library
import os

from numpy.f2py.auxfuncs import throw_error

os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax.numpy as jnp
from jax import jit, vmap
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.stats.multivariate_normal import logpdf

import pandas as pd
import numpy as np

In [3]:
# Local
from MagmaClustPy.hyperpost import hyperpost
from MagmaClustPy.kernels import SEMagmaKernel
from MagmaClustPy.utils import preprocess_db

In [4]:
# Config
nugget=jnp.array(1e-10)

---
## Data

---
## Current implementation

In [5]:
@jit
def solve_right_cholesky(A, B, nugget=jnp.array(1e-10)):
	""" Solves for X in X @ A = B """
	# For X @ A = B, we can transpose both sides: A.T @ X.T = B.T
	# As A and B are symmetric, this simplifies to A @ X.T = B
	# Then solve for X.T and transpose the result
	return cho_solve(cho_factor(A + nugget), B).T


@jit
def magma_neg_likelihood_on_cov(covar, outputs, mean, mean_process_cov, mask=None, nugget=jnp.array(1e-10)):
	nugget_matrix = jnp.eye(outputs.shape[0]) * nugget

	if mask is not None:
		# Mask the covariance matrix and outputs
		mask_2D = mask[:, None] & mask[None, :]
		covar = jnp.where(mask_2D, covar, jnp.eye(outputs.shape[0]))
		outputs = jnp.where(mask, outputs, 0)
		mean = jnp.where(mask, mean, 0)
		mean_process_cov = jnp.where(mask_2D, mean_process_cov, jnp.eye(outputs.shape[0]))


	# Compute log-likelihood
	multiv_neg_log_lik = -logpdf(outputs, mean, covar + nugget_matrix)

	# Compute correction term
	correction = 0.5 * jnp.trace(solve_right_cholesky(covar, mean_process_cov, nugget=nugget))

	if mask is not None:
		# Correct log-likelihood for padding
		# The logpdf is computed as:
		# -0.5 * (N * log(2 * pi) + log(det(cov)) + (outputs - mean).T @ inv(cov) @ (outputs - mean))
		# det(cov) and the Mahalanobis distance are not affected by our padding
		# We only have to correct for the -0.5 * N * log(2 * pi) term, as N is bigger with padding
		nll_pad_correction = 0.5 * jnp.log(2 * jnp.pi) * jnp.sum(~mask, axis=0)

		# We also need to correct the correction term, as padding adds 1s to the diagonal and hence 1 to the trace
		corr_pad_correction = 0.5 * jnp.sum(~mask, axis=0)
	else:
		nll_pad_correction = 0
		corr_pad_correction = 0

	return (multiv_neg_log_lik - nll_pad_correction) + (correction - corr_pad_correction)


@jit
def magma_neg_likelihood(kernel, inputs, outputs: jnp.array, mean: jnp.array, mean_process_cov: jnp.array, mask=None,
                         nugget=jnp.array(1e-10)):
	"""
	Computes the MAGMA log-likelihood.

	:param kernel: the kernel containing HPs to optimise. This kernel is used to compute the covariance (matrix `S`)
	:param inputs: inputs on which to compute the covariance matrix (shape (N, ))
	:param mask: boolean masks indicating which inputs and outputs to consider (shape (N, ))
	:param outputs: the observed values (shape (N, ))
	:param mean: the mean over the inputs (scalar or vector of shape (N, ))
	:param mean_process_cov: the hypper-posterior mean process covariance (matrix K^t)
	:param nugget: the nugget, for numerical stability

	:return: the negative log-likelihood (scalar)
	"""
	covar = kernel(inputs)

	# check if we need to vmap
	if inputs.ndim == 1:
		return magma_neg_likelihood_on_cov(covar, outputs, mean, mean_process_cov, mask, nugget)
	elif inputs.ndim == 2:
		return vmap(magma_neg_likelihood_on_cov, in_axes=(0, 0, None, None, 0, None))(covar, outputs, mean,
		                                                                              mean_process_cov, mask, nugget).sum()
	else:
		raise ValueError("inputs must be either 1D or 2D")


---
## Custom implementation(s)

---
## Comparison

### Sample test

In [6]:
db = pd.DataFrame({
	'ID': [1, 1, 1, 1, 2, 2, 2, 2],
	'Input': [0.40, 4.45, 7.60, 8.30, 3.50, 5.10, 8.85, 9.35],
	'Output': [59.81620, 67.13694, 78.32495, 81.83590, 62.04943, 67.31932, 85.94063, 86.76426]
})
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
prior_mean = jnp.zeros_like(all_inputs)
all_inputs.shape, padded_inputs.shape

((8,), (2, 8))

In [7]:
mean_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
task_kern = SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

In [8]:
post_mean, post_cov = hyperpost(padded_inputs, padded_outputs, masks, prior_mean, mean_kern, task_kern, all_inputs=all_inputs)

In [9]:
np.asarray(post_mean)

array([30.15131174, 38.3724416 , 47.07685095, 43.32491082, 44.93087555,
       52.22658963, 53.66208663, 50.03343896])

In [10]:
np.asarray(post_cov)

array([[ 1.35880061e+00,  2.13949374e-02, -2.59934644e-03,
         2.59717184e-03,  2.78766148e-04,  1.99815157e-05,
         1.00589065e-04,  1.34217805e-04],
       [ 2.13949374e-02,  1.13375955e+00,  6.28684547e-01,
         2.58187871e-01,  1.14224881e-02,  2.10924104e-03,
        -1.23947114e-03, -1.44853546e-03],
       [-2.59934644e-03,  6.28684547e-01,  9.60931463e-01,
         7.53038338e-01,  6.20443346e-03,  1.67904577e-03,
        -3.24402810e-03, -5.33656101e-03],
       [ 2.59717184e-03,  2.58187871e-01,  7.53038338e-01,
         1.03195229e+00,  8.03054697e-02,  1.59016757e-02,
         1.68051941e-02,  1.67388740e-02],
       [ 2.78766148e-04,  1.14224881e-02,  6.20443346e-03,
         8.03054697e-02,  1.12600160e+00,  8.34191485e-01,
         5.19943111e-01,  3.24579521e-01],
       [ 1.99815157e-05,  2.10924104e-03,  1.67904577e-03,
         1.59016757e-02,  8.34191485e-01,  9.39931650e-01,
         8.10230861e-01,  6.10321928e-01],
       [ 1.00589065e-04, -1.239471

In [11]:
mean_llh = magma_neg_likelihood(mean_kern, all_inputs, post_mean, jnp.array([0]), post_cov, nugget=nugget)
mean_llh.item()

ValueError: multivariate_normal.logpdf got incompatible shapes

In [12]:
task_llhs = magma_neg_likelihood(task_kern, padded_inputs, padded_outputs, post_mean, post_cov, mask=masks, nugget=nugget)
np.asarray(task_llhs)

array(848.85191789)

In [13]:
task_llhs.sum().item()

848.8519178862825

### Unpadded

In [14]:
task_1_inputs = padded_inputs[0][masks[0]]
task_1_outputs = padded_outputs[0][masks[0]]
task_1_inputs, task_1_outputs

(Array([0.4 , 4.45, 7.6 , 8.3 ], dtype=float64, weak_type=True),
 Array([59.8162 , 67.13694, 78.32495, 81.8359 ], dtype=float64, weak_type=True))

In [15]:
task_1_post_mean = post_mean[masks[0]]
task_1_post_mean

Array([30.15131174, 47.07685095, 44.93087555, 52.22658963], dtype=float64)

In [16]:
task_1_post_cov = post_cov[masks[0]][:, masks[0]]
np.asarray(task_1_post_cov)

array([[ 1.35880061e+00, -2.59934644e-03,  2.78766148e-04,
         1.99815157e-05],
       [-2.59934644e-03,  9.60931463e-01,  6.20443346e-03,
         1.67904577e-03],
       [ 2.78766148e-04,  6.20443346e-03,  1.12600160e+00,
         8.34191485e-01],
       [ 1.99815157e-05,  1.67904577e-03,  8.34191485e-01,
         9.39931650e-01]])

In [17]:
task_1_llh = magma_neg_likelihood(task_kern, task_1_inputs, task_1_outputs, task_1_post_mean, task_1_post_cov, nugget=nugget)
task_1_llh

Array(443.00084209, dtype=float64)

### Padded

In [18]:
i = 1

In [19]:
task_0_llh = magma_neg_likelihood(task_kern, padded_inputs[i], padded_outputs[i], post_mean, post_cov, mask=masks[i], nugget=nugget)
task_0_llh

Array(405.8510758, dtype=float64)

In [20]:
outputs = jnp.where(masks, padded_outputs, 0)
outputs[i]

Array([ 0.     , 62.04943,  0.     , 67.31932,  0.     ,  0.     ,
       85.94063, 86.76426], dtype=float64, weak_type=True)

In [21]:
covar = task_kern(padded_inputs[i])
mask_2D = masks[i][:, None] & masks[i][None, :]
covar = jnp.where(mask_2D, covar, jnp.eye(padded_outputs[i].shape[0]))
np.asarray(covar)

array([[1.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 2.71828183e+00, 0.00000000e+00, 1.05311525e+00,
        0.00000000e+00, 0.00000000e+00, 6.75910648e-05, 8.49251154e-06],
       [0.00000000e+00, 0.00000000e+00, 1.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 1.05311525e+00, 0.00000000e+00, 2.71828183e+00,
        0.00000000e+00, 0.00000000e+00, 1.48630336e-02, 3.37785443e-03],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        1.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 6.75910648e-05, 0.00000000e+00, 1.48630336e-02,
        0.00000000e+00, 0.00000000e+00, 2.71828183e+00, 2.

In [22]:
np.asarray(jnp.linalg.inv(covar[masks[i]][:, masks[i]]+nugget))

array([[ 0.43285552, -0.1677153 ,  0.00424407, -0.00366165],
       [-0.1677153 ,  0.4329043 , -0.01107839,  0.00956115],
       [ 0.00424407, -0.01107839,  2.17623895, -1.98374986],
       [-0.00366165,  0.00956115, -1.98374986,  2.17616661]])

In [23]:
nugget_matrix = jnp.eye(outputs[i].shape[0]) * nugget

In [24]:
post_mean_i = jnp.where(masks[i], post_mean, 0)
post_mean_i

Array([ 0.        , 38.3724416 ,  0.        , 43.32491082,  0.        ,
        0.        , 53.66208663, 50.03343896], dtype=float64)

In [25]:
multiv_neg_log_lik = -logpdf(outputs[i], post_mean_i, covar + nugget_matrix)
cor = 0.5 * jnp.log(2 * jnp.pi) * jnp.sum(~masks[i], axis=0)
multiv_neg_log_lik - cor, multiv_neg_log_lik

(Array(405.01630572, dtype=float64), Array(408.69205985, dtype=float64))

In [26]:
outputs[i]

Array([ 0.     , 62.04943,  0.     , 67.31932,  0.     ,  0.     ,
       85.94063, 86.76426], dtype=float64, weak_type=True)

In [27]:
post_mean

Array([30.15131174, 38.3724416 , 47.07685095, 43.32491082, 44.93087555,
       52.22658963, 53.66208663, 50.03343896], dtype=float64)

In [28]:
post_cov_i = jnp.where(mask_2D, post_cov, jnp.eye(padded_outputs[i].shape[0]))
np.asarray(post_cov_i)

array([[ 1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  1.13375955,  0.        ,  0.25818787,  0.        ,
         0.        , -0.00123947, -0.00144854],
       [ 0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.25818787,  0.        ,  1.03195229,  0.        ,
         0.        ,  0.01680519,  0.01673887],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         1.        ,  0.        ,  0.        ],
       [ 0.        , -0.00123947,  0.        ,  0.01680519,  0.        ,
         0.        ,  0.94959489,  0.91003982],
       [ 0.        , -0.00144854,  0.        ,  0.01673887,  0.        ,
         0.        ,  0.91003982,  1.08573632]])

In [29]:
correction = 0.5 * jnp.trace(solve_right_cholesky(covar, post_cov_i, nugget=nugget))
corr_pad_correction = 0.5 * jnp.sum(~masks[i], axis=0)
correction - corr_pad_correction, correction, corr_pad_correction

(Array(0.83477008, dtype=float64),
 Array(2.83477008, dtype=float64),
 Array(2., dtype=float64, weak_type=True))

In [30]:
0.5 * jnp.trace(jnp.linalg.inv(covar) @ post_cov)

Array(3.02760274, dtype=float64)

In [31]:
0.5 * jnp.trace(jnp.linalg.inv(covar[masks[i]][:, masks[i]]) @ post_cov[masks[i]][:, masks[i]])

Array(0.83477008, dtype=float64)

In [32]:
0.5 * jnp.trace(solve_right_cholesky(covar[masks[i]][:, masks[i]], post_cov[masks[i]][:, masks[i]], nugget=nugget))

Array(0.83477008, dtype=float64)

### Batched unpadded

In [33]:
tasks_inputs = padded_inputs[masks].reshape(padded_inputs.shape[0], -1)
tasks_outputs = padded_outputs[masks].reshape(padded_outputs.shape[0], -1)
tasks_inputs, tasks_outputs

(Array([[0.4 , 4.45, 7.6 , 8.3 ],
        [3.5 , 5.1 , 8.85, 9.35]], dtype=float64, weak_type=True),
 Array([[59.8162 , 67.13694, 78.32495, 81.8359 ],
        [62.04943, 67.31932, 85.94063, 86.76426]],      dtype=float64, weak_type=True))

In [34]:
idx = jnp.searchsorted(all_inputs, tasks_inputs)
tasks_post_means = post_mean[idx]
tasks_post_means

Array([[30.15131174, 47.07685095, 44.93087555, 52.22658963],
       [38.3724416 , 43.32491082, 53.66208663, 50.03343896]],      dtype=float64)

In [35]:
tasks_post_covs = jnp.stack([post_cov[m][:, m] for m in masks])
np.asarray(tasks_post_covs)

array([[[ 1.35880061e+00, -2.59934644e-03,  2.78766148e-04,
          1.99815157e-05],
        [-2.59934644e-03,  9.60931463e-01,  6.20443346e-03,
          1.67904577e-03],
        [ 2.78766148e-04,  6.20443346e-03,  1.12600160e+00,
          8.34191485e-01],
        [ 1.99815157e-05,  1.67904577e-03,  8.34191485e-01,
          9.39931650e-01]],

       [[ 1.13375955e+00,  2.58187871e-01, -1.23947114e-03,
         -1.44853546e-03],
        [ 2.58187871e-01,  1.03195229e+00,  1.68051941e-02,
          1.67388740e-02],
        [-1.23947114e-03,  1.68051941e-02,  9.49594895e-01,
          9.10039821e-01],
        [-1.44853546e-03,  1.67388740e-02,  9.10039821e-01,
          1.08573632e+00]]])

In [36]:
covar = task_kern(tasks_inputs)
np.asarray(covar)

array([[[2.71828183e+00, 6.24711556e-03, 1.24442362e-08, 2.48089245e-10],
        [6.24711556e-03, 2.71828183e+00, 6.88803459e-02, 1.12162794e-02],
        [1.24442362e-08, 6.88803459e-02, 2.71828183e+00, 2.26709559e+00],
        [2.48089245e-10, 1.12162794e-02, 2.26709559e+00, 2.71828183e+00]],

       [[2.71828183e+00, 1.05311525e+00, 6.75910648e-05, 8.49251154e-06],
        [1.05311525e+00, 2.71828183e+00, 1.48630336e-02, 3.37785443e-03],
        [6.75910648e-05, 1.48630336e-02, 2.71828183e+00, 2.47786604e+00],
        [8.49251154e-06, 3.37785443e-03, 2.47786604e+00, 2.71828183e+00]]])

In [37]:
tasks_llhs = magma_neg_likelihood(task_kern, tasks_inputs, tasks_outputs, tasks_post_means, tasks_post_covs, nugget=nugget)
tasks_llhs

ValueError: The arguments to solve must have shapes a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a=(4, 4) and b=(2, 4, 4)

### Unpadded---
## Conclusion

---