# Sandbox - Likelihood

This notebook is used to prototype the likelihood function, used to generate all scenarios of use of the algorithm.

---

## Imports & Config

In [1]:
! pwd

/Users/simonlejoly/Documents/Work/mimosa/tests


In [2]:
! export XLA_PYTHON_CLIENT_MEM_FRACTION=.25

In [3]:
# Jax configuration
USE_JIT = False
USE_X64 = True
DEBUG_NANS = False
VERBOSE = False

In [4]:
# Standard library imports
import os
os.environ['JAX_ENABLE_X64'] = str(USE_X64).lower()

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

from typing import Tuple

In [5]:
# Third party
import jax
jax.config.update("jax_disable_jit", not USE_JIT)
jax.config.update("jax_debug_nans", DEBUG_NANS)
import jax.random as jr
import jax.numpy as jnp
import jax.scipy as jsp
from jax import vmap, jit, Array
from jax.tree_util import tree_map_with_path, GetAttrKey, tree_unflatten, tree_flatten

import matplotlib.pyplot as plt
import equinox as eqx
from equinox import filter_jit
import numpy as np

from kernax import WhiteNoiseKernel, VarianceKernel, AbstractKernel, SEKernel, AbstractMean, AffineMean

In [6]:
# Local imports
from mimosa.linalg import scatter_to_grid_1d, scatter_to_grid_2d, cho_factor, cho_solve
from mimosa.generate_data import generate_data
from mimosa.hyperpost import hyperpost

INFO:2026-02-27 20:49:15,772:jax._src.xla_bridge:834: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/miniconda3/envs/mimosa/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)
2026-02-27 20:49:15,772 - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/miniconda3/envs/mimosa/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)


In [7]:
# Config
key = jr.PRNGKey(42)

T=9 ; K=3 ; F=1 ; N=5 ; I=1 ; O=2 ; gs=10 if I == 1 else 40

sth=True ; sch=True ; chit=True ; fh=False ; soh=True ; siit=True ; siif=True

mean = AffineMean(slope=0., intercept=0.)
mean_kernel = VarianceKernel(20.) * SEKernel(length_scale=10.)
task_kernel = VarianceKernel(.2) * SEKernel(length_scale=9.) + WhiteNoiseKernel(noise=.01)

mean_priors = {
	"slope": (-.2, .2),
	"intercept": (-2.5, 2.5)
}

mean_kernel_priors = {
	"variance": (5, 10.),
	"length_scale": (2.5, 10.)
}

task_kernel_priors = {
	"variance": (0.25, 1.),
	"length_scale": (2., 8.),
	"noise": (0.01, 0.1)
}

jax.devices()

[CpuDevice(id=0)]

In [8]:
inputs, outputs, maps, grid, m_p_means, m_p_covs, m_p, mix, t_m, m, m_k, t_k = generate_data(key, T, K, F, N,  I, O, gs, mean, mean_kernel, task_kernel, mean_priors, mean_kernel_priors, task_kernel_priors, sth, sch, chit, fh, soh, siit, siif)

In [9]:
mix_coeffs = jnp.eye(K)[mix]
mix_coeffs.shape

(9, 3)

In [10]:
p_m, p_c = hyperpost(inputs, outputs, maps, grid, mix_coeffs, m, m_k, t_k)
p_m.shape, p_c.shape

((3, 2, 10), (3, 1, 10, 10))

---

## Likelihood

In [11]:
@filter_jit
def mvn_nll(inputs: Array, outputs: Array, mean: Array, cov: Array, optim=False):
	"""
	Negative log-likelihood of a multivariate normal distribution, that works on padded data and multi-outputs.

	:param inputs: inputs points on which the mean and covariance were computed. Used for masking NaNs. Shape (F*N, I)
	:param outputs: outputs points corresponding to each input. Shape (F*N, O)
	:param mean: the mean of the multivariate normal distribution. Shape (O, F*N), with O=1 if shared_outputs_hps
	:param cov: the covariance of the multivariate normal distribution. Shape (O, F*N, F*N), with O=1 if shared_outputs_hps
	:param optim: when optimizing mean-function/kernel hyperparameters, we can ignore the constant term in the log-likelihood, as it does not depend on the hyperparameters. Setting `optim=True` will ignore this constant term, which can help with performance and numerical stability when optimizing.

	:return: the negative log-likelihood of the multivariate normal distribution. Shape (O,)
	"""
	nan_mask = jnp.isnan(inputs[:, 0])

	cov_u = jnp.where(nan_mask[None, :] | nan_mask[:, None], jnp.eye(cov.shape[-1])[None, :, :], cho_factor(cov))  # Shape (O, F*N, F*N), with O=1 if shared_outputs_hps
	diff = jnp.where(nan_mask, 0., outputs.T - mean)  # Shape (O, F*N)
	y = cho_solve(jnp.broadcast_to(cov_u, (diff.shape[0], cov_u.shape[1], cov_u.shape[2])), diff)  # Shape (O, F*N)

	quad = jnp.sum(diff * y, axis=-1)
	log_det = 2 * jnp.sum(jnp.log(jnp.diagonal(cov_u, axis1=-1, axis2=-2)), axis=-1)  # Shape (O,)

	if optim:
		return 0.5 * (quad + log_det)
	constant = (inputs.shape[0] - jnp.sum(nan_mask)) * jnp.log(2 * jnp.pi)
	return 0.5 * (quad + log_det + constant)

In [12]:
@jit
def trace_correction(inputs: Array, post_cov: Array, cov: Array):
	"""
	Computes the trace correction term to adapt the negative log-likelihood of a MVN to the Magma algorithm. Works on padded data and multi-outputs.

	:param inputs: inputs points on which the mean and covariance were computed. Used for masking NaNs. Shape (F*N, I)
	:param post_cov: the posterior covariance of a specific mean process. Shape (O, F*G, F*G), with O=1 if shared_outputs_hps
	:param cov: the covariance of the task or mean process. Shape (O, F*N, F*N), with O=1 if shared_outputs_hps

	:return: the trace correction term, defined as 0.5 * trace(post_cov @ inv(cov)). Shape (O,), with O=1 if shared_outputs_hps
	"""
	nan_mask_1d = jnp.isnan(inputs[:, 0])
	nan_mask_2d = nan_mask_1d[None, :] | nan_mask_1d[:, None]

	post_cov_u = jnp.where(nan_mask_2d, jnp.eye(cov.shape[-1])[None, :, :], cho_factor(post_cov))  # Shape (O, F*N, F*N), with O=1 if shared_outputs_hps
	cov_u = jnp.where(nan_mask_2d, jnp.eye(cov.shape[-1])[None, :, :], cho_factor(cov))  # Shape (O, F*N, F*N), with O=1 if shared_outputs_hps

	if cov_u.shape[0] > post_cov.shape[0]:
		post_cov_u = jnp.broadcast_to(post_cov_u, cov_u.shape)
	elif post_cov_u.shape[0] > cov_u.shape[0]:
		cov_u = jnp.broadcast_to(cov_u, post_cov_u.shape)

	v = jsp.linalg.solve_triangular(cov_u, post_cov_u, lower=False)
	return 0.5 * (jnp.sum(v**2, axis=(-2, -1)) - jnp.sum(nan_mask_1d))  # Shape (O,)

In [13]:
@filter_jit
def full_nll(inputs: Array, outputs: Array, post_mean: Array, post_cov: Array, cov: Array, optim=False):
	"""
	Full negative log-likelihood of a mean process in the Magma algorithm, including the trace correction term. Works on padded data and multi-outputs.

	:param inputs: inputs points on which the mean and covariance were computed. Used for masking NaNs. Shape (F*N, I)
	:param outputs: outputs points corresponding to each input. Shape (F*N, O)
	:param post_mean: the posterior mean of a specific mean process. Shape (O, F*G), with O=1 if shared_outputs_hps
	:param post_cov: the posterior covariance of a specific mean process. Shape (O, F*G, F*G), with O=1 if shared_outputs_hps
	:param cov: the covariance of the task or mean process. Shape (O, F*N, F*N), with O=1 if shared_outputs_hps
	:param optim: when optimizing mean-function/kernel hyperparameters, we can ignore the constant term in the log-likelihood, as it does not depend on the hyperparameters. Setting `optim=True` will ignore this constant term, which can help with performance and numerical stability when optimizing.

	:return: the full negative log-likelihood of a mean process in the Magma algorithm, including the trace correction term. Shape (O,)
	"""
	return mvn_nll(inputs, outputs, post_mean, post_cov, optim) + trace_correction(inputs, post_cov, cov)

## Une formule de llh, deux appels

### Pour les mean-process
* post_mean : (K, O, F*G)
* post_cov : (K, O, F*G, F*G)
* mean : (K*, O*, F*G)
* cov : (K*, O*, F*G, F*G)

### Pour les tâches
* outputs : (T, F, N, O)
* task_post_covs : (T, K, O, F*N, F*N)
* task_post_mean : (T, K, O, F*N)
* task_cov : (T*, K*, O*, F*N, F*N)

In [14]:
@filter_jit
def means_nlls(post_means: Array, post_covs: Array, grid: Array, cluster_means: Array, cluster_covs: Array, optim=False):
	"""
	Computes the negative log-likelihood of every cluster, for each output

	:param post_means: Shape (K, O, F*G)
	:param post_covs: Shape (K, O, F*G, F*G)
	:param grid: Shape (F*G, I)
	:param cluster_means: Shape (K, O, F*G) with K=1 if shared_cluster_hps and O=1 if shared_outputs_hps
	:param cluster_covs: Shape (K, O, F*G, F*G) with K=1 if shared_cluster_hps and O=1 if shared_outputs_hps
	:param optim: when optimizing mean-function/kernel hyperparameters, we can ignore the constant term in the log-likelihood, as it does not depend on the hyperparameters. Setting `optim=True` will ignore this constant term, which can help with performance and numerical stability when optimizing.

	:return: the negative log-likelihood of every cluster, for each output. Shape (K, O)
	"""
	cluster_means = jnp.broadcast_to(cluster_means, post_means.shape)
	cluster_covs = jnp.broadcast_to(cluster_covs, post_covs.shape)

	return vmap(full_nll, in_axes=(None, 0, 0, 0, 0, None))(grid, jnp.swapaxes(post_means, -1, -2), cluster_means, post_covs, cluster_covs, optim)

In [27]:
@filter_jit
def tasks_nlls(inputs: Array, outputs: Array, mappings: Array, post_means: Array, post_covs: Array, task_covs: Array, optim=False):
	"""
	Computes the negative log-likelihood of every task, for each cluster, for each output

	:param inputs: Shape (T, F*N, I) if not shared_inputs_in_tasks else (F*N, I)
	:param outputs: Shape (T, F*N, O)
	:param mappings: Shape (T, F*N) if not shared_inputs_in_tasks else (F*N,), with values in [0, F*G-1] and padded with F*G for missing inputs, mapping each input to a point in the grid on which the post_means and post_covs are computed
	:param post_means: Shape (K, O, F*G)
	:param post_covs: Shape (K, O, F*G, F*G)
	:param task_covs: Shape (T, K, O, F*N, F*N) with T=1 is shared_task_hps, K=1 if shared_cluster_hps and O=1 if shared_output_hps
	:param optim: when optimizing mean-function/kernel hyperparameters, we can ignore the constant term in the log-likelihood, as it does not depend on the hyperparameters. Setting `optim=True` will ignore this constant term, which can help with performance and numerical stability when optimizing.

	:return: the negative log-likelihood of every task, for each cluster, for each output. Shape (T, K, O)
	"""
	# A nice trick we can use in this function is that it can just be a vmap over `means_nlls`, providing only the right portions of post_means and post_covs to each task according to the mappings.
	outputs = jnp.broadcast_to(jnp.swapaxes(outputs[:, None, :, :], -1, -2), (outputs.shape[0], post_covs.shape[0], outputs.shape[-1], outputs.shape[-2]))
	task_covs = jnp.broadcast_to(task_covs, (outputs.shape[0],)+task_covs.shape[1:])
	if inputs.ndim == 3:
		return vmap(lambda o, i, m, tc: means_nlls(
			o,
			post_covs[:, :, m, :][:, :, :, m],
			i,
			post_means[:, :, m],
			tc,
			optim))(outputs, inputs, mappings, task_covs)
	return vmap(lambda o, tc: means_nlls(
			o,
			post_covs[:, :, mappings, :][:, :, :, mappings],
			inputs,
			post_means[:, :, mappings],
			tc,
			optim))(outputs, task_covs)

In [22]:
means_nlls(p_m, p_c, grid, m(grid), m_k(grid))

Array([[110.80523187,  92.64266713],
       [ 21.83380364,  33.60233158],
       [ 50.85005138, 157.16090781]], dtype=float64)

In [23]:
means_nlls(p_m, p_c, grid, m(grid), m_k(grid), optim=True)

Array([[101.61584654,  83.4532818 ],
       [ 12.64441831,  24.41294625],
       [ 41.66066605, 147.97152247]], dtype=float64)

In [24]:
%timeit means_nlls(p_m, p_c, grid, m(grid), m_k(grid)).block_until_ready()

31 ms ± 334 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
%timeit means_nlls(p_m, p_c, grid, m(grid), m_k(grid), optim=True).block_until_ready()

31.9 ms ± 991 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [28]:
tasks_nlls(inputs, outputs, maps, p_m, p_c, t_k(inputs))

Array([[[  4.76169803,   2.95345784],
        [116.34803409,  58.93544729],
        [111.47247128, 138.33628054]],

       [[  7.06922728,   2.41098816],
        [ 96.54910642,  71.66479221],
        [115.14459553, 149.32385728]],

       [[  4.39896258,   5.98004561],
        [ 82.13425045,  87.11179884],
        [ 74.10104929, 218.39023547]],

       [[147.9221594 ,  92.45043689],
        [  4.20830073,   7.99017715],
        [ 72.09632952, 180.95551307]],

       [[113.54324952, 117.18599386],
        [  2.53657382,   5.4566458 ],
        [ 52.68736007, 133.83774039]],

       [[121.14017606,  64.93409634],
        [  5.33313043,   7.40264826],
        [ 63.30130996, 110.92191627]],

       [[137.22380593, 204.9878509 ],
        [ 62.65987869, 143.24156641],
        [  3.6536106 ,   6.36472294]],

       [[ 92.11833589, 200.66995025],
        [ 42.06756648, 132.47938923],
        [  4.68258506,   6.13966282]],

       [[162.48358465, 228.61061282],
        [ 72.23169286, 111.3022067

In [29]:
tasks_nlls(inputs, outputs, maps, p_m, p_c, t_k(inputs), optim=True)

Array([[[ 1.67005359e-01, -1.64123482e+00],
        [ 1.11753341e+02,  5.43407546e+01],
        [ 1.06877779e+02,  1.33741588e+02]],

       [[ 2.47453461e+00, -2.18370451e+00],
        [ 9.19544138e+01,  6.70700995e+01],
        [ 1.10549903e+02,  1.44729165e+02]],

       [[-1.95730083e-01,  1.38535294e+00],
        [ 7.75395578e+01,  8.25171062e+01],
        [ 6.95063566e+01,  2.13795543e+02]],

       [[ 1.43327467e+02,  8.78557442e+01],
        [-3.86391937e-01,  3.39548448e+00],
        [ 6.75016368e+01,  1.76360820e+02]],

       [[ 1.08948557e+02,  1.12591301e+02],
        [-2.05811885e+00,  8.61953132e-01],
        [ 4.80926674e+01,  1.29243048e+02]],

       [[ 1.16545483e+02,  6.03394037e+01],
        [ 7.38437763e-01,  2.80795559e+00],
        [ 5.87066173e+01,  1.06327224e+02]],

       [[ 1.32629113e+02,  2.00393158e+02],
        [ 5.80651860e+01,  1.38646874e+02],
        [-9.41082064e-01,  1.77003028e+00]],

       [[ 8.75236432e+01,  1.96075258e+02],
        [ 3.747287

In [54]:
%timeit tasks_nlls(inputs, outputs, maps, p_m, p_c, t_k(inputs)).block_until_ready()

91.8 μs ± 187 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [55]:
%timeit tasks_nlls(inputs, outputs, maps, p_m, p_c, t_k(inputs), optim=True).block_until_ready()

91.6 μs ± 199 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


---

## Sandbox

---