# Debug notebook

## Imports and config

In [1]:
# Debug switches consumed by the setup cell:
#   USE_JIT    -> when False, "jax_disable_jit" is turned on
#   USE_X64    -> when True, JAX_ENABLE_X64 is exported before JAX is imported
#   DEBUG_NANS -> forwarded to "jax_debug_nans"
USE_JIT, USE_X64, DEBUG_NANS = False, True, False

In [2]:
# Standard imports
import os

# JAX reads its environment-variable configuration when it is imported, so 64-bit mode must be
# requested here, before the `import jax` below, to take effect
if USE_X64:
    os.environ['JAX_ENABLE_X64'] = "True"

import time
from typing import NamedTuple

# JAX imports
import jax

# Debug switches: with USE_JIT False, jit is disabled so jitted code runs op-by-op (easier to inspect);
# jax_debug_nans makes JAX error out as soon as a computation produces a NaN
jax.config.update("jax_disable_jit", not USE_JIT)
jax.config.update("jax_debug_nans", DEBUG_NANS)
from jax import vmap, jit
from jax.lax import cond
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy.linalg import cho_solve, cho_factor
from jax.scipy.optimize import minimize
from jax.scipy.stats.multivariate_normal import logpdf
from jax.tree_util import register_pytree_node_class, tree_flatten
import chex
import optax
import optax.tree_utils as otu

import pandas as pd
import numpy as np

from MagmaClustPy.utils import preprocess_db
from MagmaClustPy.kernels import SEMagmaKernel, NoisySEMagmaKernel
from MagmaClustPy.hyperpost import hyperpost

## Likelihood

In [3]:
def solve_right_cholesky(A, B, nugget=jnp.array(1e-10)):
	"""
	Solves for X in X @ A = B, with A symmetric positive (semi-)definite and B symmetric.

	The Cholesky factorisation is computed on a jittered A, i.e. the system actually solved is
	X @ (A + nugget * I) = B.

	:param A: symmetric PSD matrix (shape (N, N))
	:param B: symmetric right-hand side (shape (N, N))
	:param nugget: jitter for numerical stability. A scalar (or a vector of shape (N, )) is placed on the
		diagonal of A; a full (N, N) matrix is added as-is.

	:return: X (shape (N, N))
	"""
	nugget = jnp.asarray(nugget)
	# A jitter must only regularise the diagonal: `A + scalar` would shift every entry of A (a rank-one
	# perturbation) instead of making it better conditioned
	jitter = nugget if nugget.ndim == 2 else nugget * jnp.eye(A.shape[-1])

	# X @ A = B  <=>  A.T @ X.T = B.T. As A and B are symmetric this is A @ X.T = B:
	# solve for X.T, then transpose the result
	return cho_solve(cho_factor(A + jitter), B).T

In [26]:
def magma_neg_likelihood_on_cov(covar, outputs, mean, mean_process_cov, mask=None, nugget=jnp.array(1e-10)):
	"""
	Computes the MAGMA negative log-likelihood of a single task from its covariance matrix.

	The value is -log N(outputs | mean, S) + 0.5 * tr(S^-1 @ K), where S is `covar` plus a diagonal nugget
	and K is `mean_process_cov`.

	:param covar: covariance matrix S of the task (shape (N, N))
	:param outputs: the observed values (shape (N, ))
	:param mean: the mean over the inputs (scalar or vector of shape (N, ))
	:param mean_process_cov: the hyper-posterior mean process covariance K (shape (N, N))
	:param mask: optional boolean mask (shape (N, )); False entries are padding and do not contribute
	:param nugget: jitter added to the diagonal of S, for numerical stability

	:return: the negative log-likelihood (scalar)
	"""
	identity = jnp.eye(outputs.shape[0])
	# logpdf treats a 0-d mean as a univariate density and would return one value per output instead of
	# a scalar, so always hand it a mean vector
	mean = jnp.broadcast_to(mean, outputs.shape)

	if mask is not None:
		# Replace padded rows/columns by the identity and padded values by 0, so that padding contributes a
		# known, constant amount that is removed below
		mask_2D = mask[:, None] & mask[None, :]
		covar = jnp.where(mask_2D, covar, identity)
		outputs = jnp.where(mask, outputs, 0)
		mean = jnp.where(mask, mean, 0)
		mean_process_cov = jnp.where(mask_2D, mean_process_cov, identity)
		n_pad = jnp.sum(~mask, axis=0)
	else:
		n_pad = 0

	# The same jittered covariance is used for both terms of the likelihood
	jittered_covar = covar + identity * nugget

	# The logpdf is computed as:
	# -0.5 * (N * log(2 * pi) + log(det(cov)) + (outputs - mean).T @ inv(cov) @ (outputs - mean))
	# det(cov) and the Mahalanobis distance are not affected by the padding (identity block, zero residuals);
	# only the -0.5 * N * log(2 * pi) term grows, by 0.5 * log(2 * pi) per padded entry
	multiv_neg_log_lik = -logpdf(outputs, mean, jittered_covar) - 0.5 * jnp.log(2 * jnp.pi) * n_pad

	# Correction term 0.5 * tr(S^-1 @ K). Padding puts identity blocks in both matrices, which adds 1 to the
	# trace per padded entry, hence the -0.5 * n_pad
	correction = 0.5 * jnp.trace(cho_solve(cho_factor(jittered_covar), mean_process_cov)) - 0.5 * n_pad

	return multiv_neg_log_lik + correction

In [5]:
def magma_neg_likelihood(kernel, inputs, outputs: jnp.array, mean: jnp.array, mean_process_cov: jnp.array, mask=None,
                         nugget=jnp.array(1e-10)):
	"""
	Computes the MAGMA negative log-likelihood of one task, or of a batch of (padded) tasks.

	:param kernel: the kernel containing HPs to optimise. It is called on `inputs` to compute the covariance (matrix `S`)
	:param inputs: inputs on which to compute the covariance matrix; shape (N, ) for a single task, or (T, N) for a
		batch of T tasks
	:param outputs: the observed values; same shape as `inputs`
	:param mean: the mean over the inputs (scalar or vector of shape (N, )), shared by all tasks
	:param mean_process_cov: the hyper-posterior mean process covariance (matrix K^t, shape (N, N)), shared by all tasks
	:param mask: optional boolean mask indicating which inputs and outputs to consider; same shape as `inputs`
	:param nugget: the nugget, for numerical stability

	:return: the negative log-likelihood; a scalar for 1D inputs, a vector of shape (T, ) for 2D inputs

	:raises ValueError: if `inputs` is neither 1D nor 2D
	"""
	# Validate before evaluating the (potentially expensive) kernel
	if inputs.ndim not in (1, 2):
		raise ValueError("inputs must be either 1D or 2D")

	covar = kernel(inputs)

	if inputs.ndim == 1:
		return magma_neg_likelihood_on_cov(covar, outputs, mean, mean_process_cov, mask, nugget)

	# Batch of tasks: map over the leading axis of covar, outputs and mask; mean, K and the nugget are shared.
	# A missing mask is passed through unmapped.
	mask_axis = None if mask is None else 0
	return vmap(magma_neg_likelihood_on_cov, in_axes=(0, 0, None, None, mask_axis, None))(
		covar, outputs, mean, mean_process_cov, mask, nugget)

## Data

In [6]:
dataset = "small"
# Candidate input grids keyed by dataset size; an unknown dataset name falls back to the "custom" grid
grids = dict(
	small=jnp.arange(-10, 10, 0.5),
	medium=jnp.arange(-100, 100, 0.5),
	large=jnp.arange(-500, 500, 0.5),
	custom=jnp.arange(-20, 20, 0.5),
)
grid = grids.get(dataset, grids["custom"])
common_input = True
common_hp = True

In [7]:
# File name pattern: <dataset>_<common|distinct>_input_<common|distinct>_hp.csv
db = pd.read_csv(
	"../dummy_datasets/"
	+ "_".join([
		dataset,
		"common_input" if common_input else "distinct_input",
		"common_hp" if common_hp else "distinct_hp",
	])
	+ ".csv"
)

In [8]:
# Pre-process the raw dataframe with MagmaClustPy's preprocess_db. The padded arrays and masks are
# presumably one row per task (they are consumed row-wise by magma_neg_likelihood's batched path) — confirm in utils
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
all_inputs.shape, padded_inputs.shape

((15,), (20, 15))

In [9]:
# Initial hyper-parameters for the mean-process kernel and the (noisy) task kernel
mean_kernel = SEMagmaKernel(
	variance=1.5,
	length_scale=0.9,
)
task_kernel = NoisySEMagmaKernel(
	noise=-2.5,
	variance=1.,
	length_scale=0.3,
)

In [10]:
# Convert to a NumPy array for display
np.asarray(padded_inputs)

array([[-9.5, -9. , -7. , -5. , -4.5, -3. , -2. ,  0. ,  0.5,  1. ,  4.5,
         7. ,  7.5,  8.5,  9.5],
       [-9.5, -9. , -7. , -5. , -4.5, -3. , -2. ,  0. ,  0.5,  1. ,  4.5,
         7. ,  7.5,  8.5,  9.5],
       [-9.5, -9. , -7. , -5. , -4.5, -3. , -2. ,  0. ,  0.5,  1. ,  4.5,
         7. ,  7.5,  8.5,  9.5],
       [-9.5, -9. , -7. , -5. , -4.5, -3. , -2. ,  0. ,  0.5,  1. ,  4.5,
         7. ,  7.5,  8.5,  9.5],
       [-9.5, -9. , -7. , -5. , -4.5, -3. , -2. ,  0. ,  0.5,  1. ,  4.5,
         7. ,  7.5,  8.5,  9.5],
       [-9.5, -9. , -7. , -5. , -4.5, -3. , -2. ,  0. ,  0.5,  1. ,  4.5,
         7. ,  7.5,  8.5,  9.5],
       [-9.5, -9. , -7. , -5. , -4.5, -3. , -2. ,  0. ,  0.5,  1. ,  4.5,
         7. ,  7.5,  8.5,  9.5],
       [-9.5, -9. , -7. , -5. , -4.5, -3. , -2. ,  0. ,  0.5,  1. ,  4.5,
         7. ,  7.5,  8.5,  9.5],
       [-9.5, -9. , -7. , -5. , -4.5, -3. , -2. ,  0. ,  0.5,  1. ,  4.5,
         7. ,  7.5,  8.5,  9.5],
       [-9.5, -9. , -7. , -5. , -4.5,

In [11]:
# Convert to a NumPy array for display
np.asarray(padded_outputs)

array([[ 3.53182625e+01,  3.34515056e+01,  1.67775979e+01,
         3.71360051e+00,  1.82028297e+00, -1.78950679e-01,
        -1.30837658e+00, -5.92804753e-01, -3.95319763e+00,
        -7.37933523e+00, -4.05817522e+01, -3.95456973e+01,
        -4.13294558e+01, -4.63642047e+01, -4.45789070e+01],
       [ 3.44942189e+01,  3.14225276e+01,  1.41650837e+01,
         5.45369044e+00,  4.43092816e+00, -1.81285319e-01,
        -2.76611562e+00,  1.44283520e+00, -1.45869619e+00,
        -3.90911690e+00, -3.60206620e+01, -3.79515950e+01,
        -4.10708799e+01, -4.84715708e+01, -4.85742134e+01],
       [ 3.76835402e+01,  3.38058916e+01,  1.20463116e+01,
         5.04069808e+00,  4.00728990e+00, -1.02734904e+00,
        -5.00072108e+00, -1.39984116e+00, -6.43193977e+00,
        -1.02935493e+01, -3.90425931e+01, -3.61386786e+01,
        -3.81880358e+01, -4.45150539e+01, -4.62967048e+01],
       [ 3.66460647e+01,  3.52008497e+01,  1.77202135e+01,
         1.22521608e+01,  1.06979053e+01,  5.89386782

In [12]:
# Hyper-posterior of the mean process under the initial kernels
post_mean, post_cov = hyperpost(
	padded_inputs,
	padded_outputs,
	masks,
	jnp.zeros_like(all_inputs),  # presumably the prior mean (all zeros) — check hyperpost's signature
	mean_kernel,
	task_kernel,
	nugget=jnp.array(1e-10),
)

In [13]:
# Convert the hyper-posterior mean to a NumPy array for display
np.asarray(post_mean)

array([ 36.23375514,  33.82153928,  15.12422913,   7.74338919,
         7.28902395,   2.29461547,  -0.31348368,   1.7170999 ,
        -2.10841717,  -6.35002582, -38.16737531, -37.92542171,
       -40.56204432, -46.25863362, -45.83719182])

In [14]:
# Convert the hyper-posterior covariance to a NumPy array for display
np.asarray(post_cov)

array([[ 1.35537172e-01,  1.20316552e-01,  1.38863136e-02,
        -1.33052605e-04,  9.83387017e-05, -6.52661025e-05,
         8.81779850e-06, -3.94529343e-05,  1.09896439e-05,
        -1.19928149e-07,  2.14394650e-06, -6.06641530e-07,
         5.26806478e-07, -6.96095449e-08, -1.67101880e-07],
       [ 1.20316552e-01,  1.34614318e-01,  3.17614293e-02,
        -2.44037699e-04,  2.76924311e-04, -1.64586067e-04,
         1.85384441e-05, -1.04335119e-04,  2.83539033e-05,
        -8.45312990e-07,  5.56121318e-06, -1.57894807e-06,
         1.36967981e-06, -1.80713318e-07, -4.34597234e-07],
       [ 1.38863136e-02,  3.17614293e-02,  1.34537523e-01,
         3.16641575e-02,  1.34107299e-02,  4.49815966e-04,
         2.34558052e-05,  8.18025394e-05, -1.85414777e-05,
         3.41116995e-06, -3.79560315e-06,  1.10613735e-06,
        -9.51734104e-07,  1.24145946e-07,  3.02712401e-07],
       [-1.33052605e-04, -2.44037699e-04,  3.16641575e-02,
         1.34684671e-01,  1.20299972e-01,  3.11696139

## Debugging

In [15]:
# Per-task negative log-likelihoods of the initial task kernel
magma_neg_likelihood(
	kernel=task_kernel,
	inputs=padded_inputs,
	outputs=padded_outputs,
	mean=post_mean,
	mean_process_cov=post_cov,
	mask=masks,
	nugget=jnp.array(1e-10),
)

Array([26.451526  , 24.82051496, 31.05315503, 32.07589838, 32.72740246,
       29.90736642, 32.8977099 , 36.70330398, 39.31643226, 34.44775989,
       31.83923734, 25.25092831, 27.86794231, 38.3862114 , 36.28863916,
       30.66437474, 33.41162565, 31.34470312, 27.30478917, 39.583976  ],      dtype=float64)

In [16]:
# Negative log-likelihood of the mean process: the hyper-posterior mean is scored against a zero mean
magma_neg_likelihood(
	kernel=mean_kernel,
	inputs=all_inputs,
	outputs=post_mean,
	mean=jnp.zeros_like(all_inputs),
	mean_process_cov=post_cov,
	mask=None,
	nugget=jnp.array(1e-10),
)

Array(640.7648989, dtype=float64)

In [17]:
# Evaluate the task kernel on the padded inputs and display the covariance matrix at index 0
task_covars = task_kernel(padded_inputs)
np.asarray(task_covars[0])

array([[2.80036683e+00, 2.47786604e+00, 2.68459037e-01, 1.50226023e-03,
        2.58600086e-04, 4.34189435e-07, 2.42966498e-09, 8.24291618e-15,
        2.22654366e-16, 4.99744738e-18, 8.02480483e-32, 4.34864282e-44,
        8.78693960e-47, 2.05828495e-52, 2.29847661e-58],
       [2.47786604e+00, 2.80036683e+00, 6.17771617e-01, 7.25150848e-03,
        1.50226023e-03, 4.39638488e-06, 3.56311206e-08, 2.53569629e-13,
        8.24291618e-15, 2.22654366e-16, 1.30722600e-29, 1.78828412e-41,
        4.34864282e-44, 1.47532679e-49, 2.38610445e-55],
       [2.68459037e-01, 6.17771617e-01, 2.80036683e+00, 6.17771617e-01,
        2.68459037e-01, 7.25150848e-03, 2.58600086e-04, 3.56311206e-08,
        2.42966498e-09, 1.37667045e-10, 1.44438281e-21, 8.02480483e-32,
        4.09340734e-34, 6.11062994e-39, 4.34864282e-44],
       [1.50226023e-03, 7.25150848e-03, 6.17771617e-01, 2.80036683e+00,
        2.47786604e+00, 6.17771617e-01, 9.69381635e-02, 2.58600086e-04,
        3.69895326e-05, 4.39638488e-0

In [18]:
# Same likelihood on the first row only (1D path), to cross-check the first entry of the batched result
magma_neg_likelihood(
	kernel=task_kernel,
	inputs=padded_inputs[0],
	outputs=padded_outputs[0],
	mean=post_mean,
	mean_process_cov=post_cov,
	mask=masks[0],
	nugget=jnp.array(1e-10),
)

Array(26.451526, dtype=float64)

In [19]:
# Hyper-parameters presumably found by the R implementation's optimiser (per the `r_optim_` prefix)
r_optim_mean_kernel = SEMagmaKernel(
	variance=6.209469,
	length_scale=0.9944771,
)
r_optim_task_kernel = NoisySEMagmaKernel(
	noise=-1.579317,
	variance=1.765444,
	length_scale=0.6692923,
)

In [27]:
# Hyper-parameters presumably found by the Python implementation's optimiser (per the `py_optim_` prefix)
py_optim_mean_kernel = SEMagmaKernel(
	variance=6.232192516326904,
	length_scale=0.9975587129592896,
)
py_optim_task_kernel = NoisySEMagmaKernel(
	noise=-1.7901536226272583,
	variance=1.7702873945236206,
	length_scale=0.6336096525192261,
)

In [28]:
# Hyper-posterior under the R-optimised kernels
r_post_mean, r_post_cov = hyperpost(
	padded_inputs,
	padded_outputs,
	masks,
	jnp.zeros_like(all_inputs),
	r_optim_mean_kernel,
	r_optim_task_kernel,
	nugget=jnp.array(1e-10),
)

In [29]:
# Hyper-posterior under the Python-optimised kernels
py_post_mean, py_post_cov = hyperpost(
	padded_inputs,
	padded_outputs,
	masks,
	jnp.zeros_like(all_inputs),
	py_optim_mean_kernel,
	py_optim_task_kernel,
	nugget=jnp.array(1e-10),
)

In [30]:
# mean likelihood comparison: each optimised mean kernel is scored against its own hyper-posterior
(
	magma_neg_likelihood(
		kernel=r_optim_mean_kernel,
		inputs=all_inputs,
		outputs=r_post_mean,
		mean=jnp.zeros_like(all_inputs),
		mean_process_cov=r_post_cov,
		mask=None,
		nugget=jnp.array(1e-10),
	),
	magma_neg_likelihood(
		kernel=py_optim_mean_kernel,
		inputs=all_inputs,
		outputs=py_post_mean,
		mean=jnp.zeros_like(all_inputs),
		mean_process_cov=py_post_cov,
		mask=None,
		nugget=jnp.array(1e-10),
	),
)

(Array(54.75174095, dtype=float64), Array(54.76679924, dtype=float64))

In [31]:
# task likelihood comparison: each optimised task kernel is scored against its own hyper-posterior
(
	magma_neg_likelihood(
		kernel=r_optim_task_kernel,
		inputs=padded_inputs,
		outputs=padded_outputs,
		mean=r_post_mean,
		mean_process_cov=r_post_cov,
		mask=masks,
		nugget=jnp.array(1e-10),
	),
	magma_neg_likelihood(
		kernel=py_optim_task_kernel,
		inputs=padded_inputs,
		outputs=padded_outputs,
		mean=py_post_mean,
		mean_process_cov=py_post_cov,
		mask=masks,
		nugget=jnp.array(1e-10),
	),
)

KeyboardInterrupt: 

In [25]:
# Summed task likelihoods of both optimised task kernels.
# NOTE(review): both use post_mean / post_cov, i.e. the hyper-posterior of the *initial* kernels (In [12]),
# not r_post_* / py_post_* — intentional if the goal is a comparison under a common hyper-posterior
(
	magma_neg_likelihood(
		kernel=r_optim_task_kernel,
		inputs=padded_inputs,
		outputs=padded_outputs,
		mean=post_mean,
		mean_process_cov=post_cov,
		mask=masks,
		nugget=jnp.array(1e-10),
	).sum(),
	magma_neg_likelihood(
		kernel=py_optim_task_kernel,
		inputs=padded_inputs,
		outputs=padded_outputs,
		mean=post_mean,
		mean_process_cov=post_cov,
		mask=masks,
		nugget=jnp.array(1e-10),
	).sum(),
)

(Array(581.69145319, dtype=float64), Array(584.24812619, dtype=float64))