# Benchmarks - Preprocess DB

**Main considerations when implementing Preprocess DB**

The goal of preprocess_db is to go from a pandas dataframe with columns "Task_ID", "Input", "Input_ID", "Output" and "Output_ID" to a set of tensors containing all inputs, padded inputs, padded outputs and mappings, that can be used throughout the MagmaClustPy library.

Magma and MagmaClust both work on unaligned sequences of varying sizes.
Yet, we want to be able to perform as many operations in a vectorisable and jittable way.

This mean that we need to pad the sequences to the same length.
The most straightforward way to do this is to align the sequences to the union of all distinct inputs, but this leads to a huge memory footprint and costly data movement in and out of the GPU.
A better way is to pad the sequences to the maximum length among all task sequences, and use a mapping of indices when we need to align the inputs to the union of all inputs.

As MagmaClust can handle multi-input tasks, we need multiple mappings, one for each input.
The union of all inputs from a specific dimension will have a specific length, but once again we want to be able to store them in a single tensor.
This means that even the grid of all_inputs must be padded to the maximum length across all dimensions of inputs.

---
## Setup

In [1]:
USE_X64 = False  # Set to True to use 64-bit precision, False for 32-bit precision

In [2]:
# Standard library
import os

# Config
os.environ['JAX_ENABLE_X64'] = str(USE_X64).lower()

from functools import partial

In [3]:
# Third party
import jax
from jax import vmap, jit
import jax.numpy as jnp
import jax.random as jr
from jax.lax import cond

import numpy as np
import pandas as pd

# Initialize random key
key = jax.random.PRNGKey(42)

In [4]:
# Local


In [5]:
# Set constants
M = 2  # Number of tasks
INPUTS_ID = ["x", "y", "z"]  # Each dimension of inputs
MIN_N = 3  # Minimum inputs per task
MAX_N = 5  # Maximum inputs per task
OUTPUTS_ID = ["a", "b"]  # Each dimension of outputs
GRIDS = [jnp.arange(-5., 5., 1.), jnp.arange(-1., 1., 0.5), jnp.arange(0., 2., 1)]  # Grid to pick inputs from, for each input dimension
OUTPUT_RANGES = [(-5, 5), (-10, 10)]  # Ranges for outputs, for each output dimension

---
## Data

In [6]:
def generate_dummy_db(M: int, INPUTS_ID: list[str], MIN_N: int, MAX_N: int, OUTPUTS_ID: list[str],
                      GRIDS: list[jnp.array], OUTPUT_RANGES: list[jnp.array], drop_output_rate: float = 0.,
                      key: jnp.array = jax.random.PRNGKey(42)):
	"""
	Generate a dummy database with random inputs and outputs, following the expected structure for MagmaClustPy.

	:param M: Number of tasks
	:param INPUTS_ID: List of input IDs, each representing a dimension of inputs
	:param MIN_N: Minimum number of inputs per task
	:param MAX_N: Maximum number of inputs per task
	:param OUTPUTS_ID: List of output IDs, each representing a dimension of outputs
	:param GRIDS: List of grids to pick inputs from, one for each input dimension
	:param drop_output_rate: Probability of dropping an output value. Default is 0, meaning no outputs are dropped.
	:param key: JAX random key for reproducibility

	:return: A pandas DataFrame with columns "ID", "Input", "Input_ID", "Output", "Output_ID"
	"""
	data = []

	for m in range(M):
		key, subkey1, subkey2 = jr.split(key, 3)

		n_points = jr.randint(subkey1, (), MIN_N, MAX_N).item()  # This task's number of points
		inputs = [jr.choice(subkey2, grid, (n_points,), replace=g != 0) for g, grid in
		          enumerate(GRIDS)]  # Randomly pick inputs from the grids
		# We set replace=False for the first grid, to ensure we have distinct inputs in at least one dimension.

		for n in range(n_points):
			for o, output_id in enumerate(OUTPUTS_ID):
				key, subkey1, subkey2 = jr.split(key, 3)

				if jr.uniform(subkey1) < drop_output_rate:
					# Drop output value with a certain probability
					continue

				output_val = jr.uniform(subkey2, (), jnp.float32, *OUTPUT_RANGES[o])

				for i, input_id in enumerate(INPUTS_ID):
					data.append({
						"Task_ID": m,
						"Input": inputs[i][n].item(),
						"Input_ID": input_id,
						"Output": output_val.item(),
						"Output_ID": output_id
					})

	return pd.DataFrame(data)

In [7]:
db = generate_dummy_db(M, INPUTS_ID, MIN_N, MAX_N, OUTPUTS_ID, GRIDS, 0.25, key)
db

Unnamed: 0,Task_ID,Input,Input_ID,Output,Output_ID
0,0,0.0,x,-4.660727,a
1,0,-0.5,y,-4.660727,a
2,0,1.0,z,-4.660727,a
3,0,0.0,x,2.210305,b
4,0,-0.5,y,2.210305,b
5,0,1.0,z,2.210305,b
6,0,2.0,x,-2.132901,a
7,0,-0.5,y,-2.132901,a
8,0,1.0,z,-2.132901,a
9,0,2.0,x,0.64188,b


---
## Implementation

---
## Custom implementation(s)

We transition to a multi-input multi-output setting

In [8]:
def pivot_db(df: pd.DataFrame) -> pd.DataFrame:
	"""
	Flatten a dataframe so that every line corresponds to a single observation.
	For example, if the "Input_ID" column has a multi-index with "x", "y", and "z", the resulting DataFrame will have columns like "Input_x", "Input_y", and "Input_z". If the "Output_ID" column has a multi-index with "a" and "b", the resulting DataFrame will have columns like "Output_a" and "Output_b".

	When an output dimension is missing for a given observation, the corresponding column will be filled with NaN.

	:param df: DataFrame with columns "Task_ID", "Input_ID", "Input", "Output", "Output_ID"
	:return: DataFrame "Task_ID" and "Input_*" and "Output_*" columns for each "Input_ID" and "Output_ID"
	"""
	# Ensure the dataframe has the expected columns
	if not all(col in df.columns for col in ["Task_ID", "Input_ID", "Input", "Output", "Output_ID"]):
		raise ValueError("DataFrame must contain 'Task_ID', 'Input_ID', 'Input', 'Output', and 'Output_ID' columns.")

	# Ensure the DataFrame is sorted by "Task_ID", "Input_ID", "Input", and "Output"
	df_flat = df.sort_values(by=["Task_ID", "Output_ID", "Output", "Input_ID"]).reset_index(drop=True)

	# Flatten Inputs
	df_flat = df_flat.pivot_table(
		index=['Task_ID', 'Output', 'Output_ID'],
	    columns='Input_ID',
	    values='Input',
	    aggfunc='first').reset_index()

	df_flat.columns = [f"Input_{col}" if col not in ["Task_ID", "Input_ID", "Output", "Output_ID"] else col for col in df_flat.columns]
	df_flat.columns.name = None

	# Flatten Outputs
	df_flat = df_flat.pivot_table(
	    index=['Task_ID'] + [col for col in df_flat.columns if col.startswith("Input_")],
	    columns='Output_ID',
	    values='Output',
	    aggfunc='first'
	).reset_index()

	df_flat.columns = [f'Output_{col}' if col not in ['Task_ID', 'Output_ID'] and not col.startswith("Input_") else col
	                   for col in df_flat.columns]

	df_flat.columns.name = None

	return df_flat

In [9]:
flat_db = pivot_db(db)
flat_db

Unnamed: 0,Task_ID,Input_x,Input_y,Input_z,Output_a,Output_b
0,0,-5.0,0.5,1.0,3.073529,
1,0,-4.0,-1.0,0.0,4.421373,3.004329
2,0,0.0,-0.5,1.0,-4.660727,2.210305
3,0,2.0,-0.5,1.0,-2.132901,0.64188
4,1,-5.0,0.0,0.0,,9.247931
5,1,-4.0,0.0,0.0,4.234164,7.660635
6,1,0.0,-1.0,0.0,4.972271,-7.645895


In [10]:
# Dimensions:
# T : the tasks
# N : the number of all observations across all tasks
# MAX_N_I : the maximum number of inputs per tasks
# I : dimensions of inputs
# O : dimensions of outputs

In [11]:
import jax.numpy as jnp
from jax.lax import cond, while_loop


@jit
def searchsorted_2D(vector, matrix):
	"""
	Search along axis 1 for a vector in a matrix. If found, return the index of the vector.
	If not found, return len(matrix).

	For this function to work, the vectors in the matrix must be sorted lexicographically.
	ex:
	[[1, 1, 0],
	 [1, 2, 1],
	 [1, 2, 2],
	 [2, 1, 3],
	 [2, 2, 1]]

	:param vector: the vector to search for
	:param matrix: the matrix to search in
	:return: the index of the vector in the matrix, or len(matrix) if not found
	"""
	@jit
	def compare_vectors(v1, v2):
	    """Compare two vectors lexicographically. Returns -1 if v1 < v2, 0 if equal, 1 if v1 > v2"""
	    diff = v1 - v2
	    # Find first non-zero element
	    nonzero_mask = diff != 0
	    # If all elements are zero, vectors are equal
	    first_nonzero_idx = jnp.argmax(nonzero_mask)

	    return cond(
	        jnp.any(nonzero_mask),
	        lambda: jnp.array(jnp.sign(diff[first_nonzero_idx]), dtype=jnp.int32),
	        lambda: 0
	    )

	@jit
	def search_condition(state):
	    start, end, found = state
	    return (start < end) & (~found)

	@jit
	def search_step(state):
	    start, end, found = state
	    mid = (start + end) // 2

	    comparison = compare_vectors(vector, matrix[mid])

	    # If vectors are equal, we found it
	    new_found = comparison == 0
	    new_start = cond(comparison < 0, lambda: start, lambda: mid + 1)
	    new_end = cond(comparison < 0, lambda: mid, lambda: end)

	    # If found, return the index in start position
	    final_start = cond(new_found, lambda: mid, lambda: new_start)

	    return final_start, new_end, new_found

	# Initial state: (start, end, found)
	initial_state = (0, len(matrix), False)
	final_start, final_end, found = while_loop(search_condition, search_step, initial_state)

	# Return the found index or len(matrix) if not found
	return cond(found, lambda: final_start, lambda: len(matrix))


searchsorted_2D_vectorized = jit(vmap(searchsorted_2D, in_axes=(0, None)))

In [12]:
@partial(jit, static_argnames=["max_n_i"])
def extract_task_data(_id, task_ids, input_values, output_values, all_inputs, max_n_i):
	"""
	Extract data for a given task ID from the values array and return a row of padded inputs, padded outputs and index_mappings.

	:param _id: the task ID to extract data for
	:param task_ids: the array of the task id of each observation (shape: (N, 1))
	:param input_values: the input values for all tasks  (shape: (T, N, I))
	:param output_values: the output values for all tasks (shape: (T, N, O))
	:param all_inputs: the array of all distinct inputs (shape: (P, I))
	:param max_n_i: the maximum number of inputs per task (scalar)

	:return: a tuple of (padded_input, padded_output, index_mappings)
	   - padded_input: a matrix of shape (MAX_N_I, I) with inputs for the task, padded with NaNs
	   - padded_output: a matrix of shape (MAX_N_I, O) with corresponding outputs for each input, padded with NaNs
	   - index_mappings: a matrix of shape (MAX_N_I,) with indices of the inputs in the all_inputs array. Missing inputs for the task are represented as NaNs.
	"""
	inputs_i = jnp.where(task_ids == _id, input_values, jnp.nan)
	outputs_i = jnp.where(task_ids == _id, output_values, jnp.nan)
	mappings_i = searchsorted_2D_vectorized(inputs_i, all_inputs)

	# Compute index among the whole dataset
	idx_i = jnp.where(jnp.isnan(inputs_i[:, 0]), max_n_i + 1, jnp.cumsum(~jnp.isnan(inputs_i[:, 0])) - 1)

	# Create padded inputs and outputs
	padded_input = jnp.full((max_n_i, inputs_i.shape[-1]), jnp.nan).at[idx_i].set(inputs_i)
	padded_output = jnp.full((max_n_i, outputs_i.shape[-1]), jnp.nan).at[idx_i].set(outputs_i)
	index_mappings = jnp.full((max_n_i,), all_inputs.shape[0] + 1).at[idx_i].set(mappings_i).astype(int)

	return padded_input, padded_output, index_mappings

In [13]:
def preprocess_db(db: pd.DataFrame):
	"""

	:param db:
	:return:
	"""
	# Pivot the database
	db_flat = pivot_db(db)

	# Get task IDs
	task_ids = jnp.array(db_flat["Task_ID"].values, dtype=jnp.int32)[:, None]  # Convert to column vector

	# Get inputs and outputs
	inputs = jnp.array(db_flat.filter(like="Input_").values, dtype=jnp.float32)
	outputs = jnp.array(db_flat.filter(like="Output_").values, dtype=jnp.float32)

	# Get all distinct inputs
	all_inputs = jnp.unique(db_flat.sort_values(by=[col for col in db_flat if col.startswith("Input_")]).filter(like="Input_").values, axis=0)

	# Get maximum number of inputs per task
	MAX_N = jnp.max(jnp.sum(task_ids == task_ids[0], axis=0)).item()  # Maximum number of inputs per task

	# Recover padded inputs, padded outputs and index mappings
	return vmap(extract_task_data, in_axes=(0, None, None, None, None, None))(jnp.unique(task_ids), task_ids, inputs, outputs, all_inputs, MAX_N) + (all_inputs,)

In [14]:
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)

In [15]:
np.asarray(padded_inputs)

array([[[-5. ,  0.5,  1. ],
        [-4. , -1. ,  0. ],
        [ 0. , -0.5,  1. ],
        [ 2. , -0.5,  1. ]],

       [[-5. ,  0. ,  0. ],
        [-4. ,  0. ,  0. ],
        [ 0. , -1. ,  0. ],
        [ nan,  nan,  nan]]], dtype=float32)

In [16]:
np.asarray(padded_outputs)

array([[[ 3.073529  ,         nan],
        [ 4.4213734 ,  3.0043292 ],
        [-4.6607265 ,  2.2103047 ],
        [-2.132901  ,  0.64188004]],

       [[        nan,  9.247931  ],
        [ 4.2341638 ,  7.6606345 ],
        [ 4.972271  , -7.6458955 ],
        [        nan,         nan]]], dtype=float32)

In [17]:
np.asarray(mappings)

array([[1, 2, 5, 6],
       [0, 3, 4, 8]], dtype=int32)

In [18]:
np.asarray(all_inputs)

array([[-5. ,  0. ,  0. ],
       [-5. ,  0.5,  1. ],
       [-4. , -1. ,  0. ],
       [-4. ,  0. ,  0. ],
       [ 0. , -1. ,  0. ],
       [ 0. , -0.5,  1. ],
       [ 2. , -0.5,  1. ]], dtype=float32)

In [19]:
large_db = pd.read_csv("../datasets/large_distinct_input_distinct_hp.csv")
large_db

Unnamed: 0,Task_ID,Input,Output,se_variance,se_lengthscale,noise,Input_ID,Output_ID
0,1,-300.0,-145.746528,1.24,0.76,0.08,x,a
1,1,-299.0,-151.023153,1.24,0.76,0.08,x,a
2,1,-296.5,-131.905148,1.24,0.76,0.08,x,a
3,1,-295.5,-132.007637,1.24,0.76,0.08,x,a
4,1,-293.5,-128.979291,1.24,0.76,0.08,x,a
...,...,...,...,...,...,...,...,...
299957,600,297.0,211.076208,1.40,0.93,0.14,x,a
299958,600,297.5,217.069705,1.40,0.93,0.14,x,a
299959,600,298.0,221.574248,1.40,0.93,0.14,x,a
299960,600,298.5,224.321243,1.40,0.93,0.14,x,a


In [27]:
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(large_db)

In [21]:
small_db = pd.read_csv("../datasets/small_distinct_input_distinct_hp.csv")
small_db

Unnamed: 0,Task_ID,Input,Output,se_variance,se_lengthscale,noise,Input_ID,Output_ID
0,1,-10.0,18.113684,1.70,0.11,0.16,x,a
1,1,-9.5,17.385721,1.70,0.11,0.16,x,a
2,1,-6.5,14.471432,1.70,0.11,0.16,x,a
3,1,-5.5,5.737220,1.70,0.11,0.16,x,a
4,1,-4.0,-9.900233,1.70,0.11,0.16,x,a
...,...,...,...,...,...,...,...,...
394,20,7.0,-26.722314,1.09,0.19,0.15,x,a
395,20,7.5,-31.959543,1.09,0.19,0.15,x,a
396,20,8.0,-35.560319,1.09,0.19,0.15,x,a
397,20,9.0,-40.899537,1.09,0.19,0.15,x,a


In [22]:
preprocess_db(small_db)

(Array([[[-10. ],
         [ -9.5],
         [ -6.5],
         [ -5.5],
         [ -4. ],
         [ -3.5],
         [ -3. ],
         [ -2.5],
         [ -1. ],
         [ -0.5],
         [  1. ],
         [  1.5],
         [  2. ],
         [  3. ],
         [  5. ],
         [  5.5],
         [  6. ],
         [  7. ],
         [  7.5]],
 
        [[ -9.5],
         [ -8.5],
         [ -7.5],
         [ -6. ],
         [ -5.5],
         [ -4.5],
         [ -4. ],
         [ -3.5],
         [ -2.5],
         [ -1.5],
         [ -1. ],
         [  1. ],
         [  2. ],
         [  3. ],
         [  4. ],
         [  4.5],
         [  6. ],
         [  7. ],
         [  8.5]],
 
        [[ -9. ],
         [ -8. ],
         [ -5. ],
         [ -4. ],
         [ -3. ],
         [ -2.5],
         [ -2. ],
         [ -0.5],
         [  0.5],
         [  1. ],
         [  2. ],
         [  2.5],
         [  3.5],
         [  4. ],
         [  4.5],
         [  6.5],
         [  7. ],
    

---
## Comparison

In [15]:
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
all_inputs.shape, padded_inputs.shape, padded_outputs.shape, masks.shape

KeyError: 'ID'

In [11]:
all_inputs_new, padded_inputs_new, padded_outputs_new, index_mappings_new = preprocess_db_new(db)
all_inputs_new.shape, padded_inputs_new.shape, padded_outputs_new.shape, index_mappings_new.shape

((21,), (2, 24), (2, 24), (2, 24))

In [12]:
print(f"{len(all_inputs)} distinct inputs, covering {len(all_inputs) / len(grid) * 100:.2f}% of the grid")

NameError: name 'grid' is not defined

In [189]:
#jnp.allclose(all_inputs, all_inputs_new), jnp.allclose(padded_inputs, padded_inputs_new, equal_nan=True), jnp.allclose(padded_outputs, padded_outputs_new, equal_nan=True), jnp.allclose(masks, masks_new)

## Performance comparison on dummy tasks

### Covariance matrix computation

In [190]:
from Kernax import RBFKernel

kern = RBFKernel(length_scale=jnp.array(0.3), variance=jnp.array(1.))

In [191]:
padded_inputs[0], padded_inputs_new[0], index_mappings_new[0]

(Array([  nan,   nan,   nan,   nan, -496.,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan, -468.,   nan,   nan,   nan,
          nan,   nan, -462.,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan, -441.,   nan,   nan,   nan,
          nan,   nan, -435.,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan, -417., -416.,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan, -397.,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan, -375.,
      

In [192]:
@jit
def map_to_full_cov(dense_cov, all_inputs, mapping):
	# return jnp.full((len(all_inputs), len(all_inputs)), jnp.nan).at[tuple(jnp.meshgrid(mapping, mapping))].set(dense_cov)
	return jnp.full((len(all_inputs), len(all_inputs)), jnp.nan).at[jnp.ix_(mapping, mapping)].set(dense_cov)

In [193]:
@jit
def map_to_full_batch(dense_covs, all_inputs, mappings):
    return vmap(map_to_full_cov, in_axes=(0, None, 0))(dense_covs, all_inputs, mappings)

In [194]:
# On single input
np.asarray(kern(padded_inputs[0]))

array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan,  1., nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)

In [195]:
np.asarray(map_to_full_cov(kern(padded_inputs_new[0]), all_inputs, index_mappings_new[0]))

array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan,  1., nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)

In [196]:
jnp.allclose(kern(padded_inputs[0]), map_to_full_cov(kern(padded_inputs_new[0]), all_inputs, index_mappings_new[0]), equal_nan=True)

Array(True, dtype=bool)

In [197]:
%timeit kern(padded_inputs[0]).block_until_ready()

276 μs ± 12.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [198]:
%timeit map_to_full_cov(kern(padded_inputs_new[0]), all_inputs, index_mappings_new[0]).block_until_ready()

361 μs ± 13 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [199]:
a = kern(padded_inputs).block_until_ready()

In [200]:
b = map_to_full_batch(kern(padded_inputs_new), all_inputs, index_mappings_new).block_until_ready()

In [201]:
jnp.allclose(a, b, equal_nan=True)

Array(True, dtype=bool)

In [202]:
%timeit kern(padded_inputs).block_until_ready()

92.9 ms ± 23.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [203]:
%timeit map_to_full_batch(kern(padded_inputs_new), all_inputs, index_mappings_new).block_until_ready()

160 ms ± 5.88 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Sum of invs

In [204]:
from jax.scipy.linalg import cho_factor, cho_solve

In [205]:
@jit
def full_pad_sum_of_inv(inputs, masks, all_inputs, kernel):
	"""
	compute the sum of inverses of all cross-covariance matrices for each input in inputs.
	It uses a full padding approach, where the inputs are padded to the size of all_inputs and aligned.
	masks gives indices where the inputs are valid (True) or padded (False).

	:param inputs:
	:param masks:
	:param all_inputs:
	:param kernel:
	:return:
	"""
	jitter = 1e-8  # Small value to ensure numerical stability
	covs = kernel(inputs)

	small_eye = jnp.broadcast_to(jnp.eye(covs.shape[-1]), covs.shape)

	# covs is padded with NaNs. Replace them by their corresponding identity rows/cols
	masks_2D = masks[:, :, None] & masks[:, None, :]
	covs = jnp.where(masks_2D, covs, small_eye)

	covs_U, _ = cho_factor(covs + small_eye * jitter)
	covs_inv = cho_solve((covs_U, False), small_eye)
	covs_inv -= jnp.where(masks_2D, 0, small_eye)  # Correction on the diagonal
	return covs_inv.sum(axis=0)

In [206]:
@jit
def dense_pad_sum_of_inv(inputs, mappings, all_inputs, kernel):
	"""
	compute the sum of inverses of all cross-covariance matrices for each input in inputs.
	It uses a dense padding approach, where the inputs are padded to the size of MAX_N and not aligned. Their positions in the all_inputs array are given by mappings.

	:param inputs:
	:param mappings:
	:param all_inputs:
	:param kernel:
	:return:
	"""
	jitter = 1e-8  # Small value to ensure numerical stability
	covs = kernel(inputs)

	small_eye = jnp.broadcast_to(jnp.eye(covs.shape[-1]), covs.shape)

	# Some covs may still end with a few NaNs, so we replace them by their corresponding identity rows/cols
	eyed_covs = jnp.where(jnp.isnan(covs), small_eye, covs)

	covs_U, _ = cho_factor(eyed_covs + small_eye * jitter)
	covs_inv = cho_solve((covs_U, False), small_eye)
	covs_inv -= jnp.where(jnp.isnan(covs), small_eye, 0)  # Correction on the diagonal

	# Now we need to map the covs_inv to the all_inputs array
	return jnp.nan_to_num(map_to_full_batch(covs_inv, all_inputs, mappings)).sum(axis=0)

In [207]:
a = full_pad_sum_of_inv(padded_inputs, masks, all_inputs, kern)

In [208]:
b = dense_pad_sum_of_inv(padded_inputs_new, index_mappings_new, all_inputs, kern)

In [213]:
jnp.allclose(a, b, equal_nan=True, atol=1e-6)

Array(True, dtype=bool)

In [210]:
%timeit full_pad_sum_of_inv(padded_inputs, masks, all_inputs, kern).block_until_ready()

4.28 s ± 32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [211]:
%timeit dense_pad_sum_of_inv(padded_inputs_new, index_mappings_new, all_inputs, kern).block_until_ready()

256 ms ± 5.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
