# Benchmarks - Preprocess DB

**Main considerations when implementing Preprocess DB**

The goal of preprocess_db is to go from a pandas dataframe with columns "Task_ID", "Input", "Input_ID", "Output" and "Output_ID" to a set of tensors containing all inputs, padded inputs, padded outputs and mappings, that can be used throughout the MagmaClustPy library.

Magma and MagmaClust both work on unaligned sequences of varying sizes.
Yet, we want to be able to perform as many operations in a vectorisable and jittable way.

This mean that we need to pad the sequences to the same length.
The most straightforward way to do this is to align the sequences to the union of all distinct inputs, but this leads to a huge memory footprint and costly data movement in and out of the GPU.
A better way is to pad the sequences to the maximum length among all task sequences, and use a mapping of indices when we need to align the inputs to the union of all inputs.

As MagmaClust can handle multi-input tasks, we need multiple mappings, one for each input.
The union of all inputs from a specific dimension will have a specific length, but once again we want to be able to store them in a single tensor.
This means that even the grid of all_inputs must be padded to the maximum length across all dimensions of inputs.

---
## Setup

In [1]:
USE_X64 = False  # Set to True to use 64-bit precision, False for 32-bit precision

In [2]:
# Standard library
import os

# Config
os.environ['JAX_ENABLE_X64'] = str(USE_X64).lower()

In [3]:
# Third party
import jax
from jax import vmap, jit
import jax.numpy as jnp
import jax.random as jr

import numpy as np
import pandas as pd

# Initialize random key
key = jax.random.PRNGKey(42)

In [4]:
# Local


In [5]:
# Set constants
M = 2  # Number of tasks
INPUTS_ID = ["x", "y", "z"]  # Each dimension of inputs
MIN_N = 3  # Minimum inputs per task
MAX_N = 5  # Maximum inputs per task
OUTPUTS_ID = ["a", "b"]  # Each dimension of outputs
GRIDS = [jnp.arange(-5., 5., 1.), jnp.arange(-1., 1., 0.5), jnp.arange(0., 2., 1)]  # Grid to pick inputs from, for each input dimension
OUTPUT_RANGES = [(-5, 5), (-10, 10)]  # Ranges for outputs, for each output dimension

---
## Data

In [6]:
def generate_dummy_db(M: int, INPUTS_ID: [str], MIN_N: int, MAX_N: int, OUTPUTS_ID: [str], GRIDS: [jnp.array], drop_output_rate: float = 0., key: jnp.array = jax.random.PRNGKey(41)):
	"""
	Generate a dummy database with random inputs and outputs, following the expected structure for MagmaClustPy.

	:param M: Number of tasks
	:param INPUTS_ID: List of input IDs, each representing a dimension of inputs
	:param MIN_N: Minimum number of inputs per task
	:param MAX_N: Maximum number of inputs per task
	:param OUTPUTS_ID: List of output IDs, each representing a dimension of outputs
	:param GRIDS: List of grids to pick inputs from, one for each input dimension
	:param drop_output_rate: Probability of dropping an output value. Default is 0, meaning no outputs are dropped.
	:param key: JAX random key for reproducibility

	:return: A pandas DataFrame with columns "ID", "Input", "Input_ID", "Output", "Output_ID"
	"""
	data = []

	for m in range(M):
		key, subkey1, subkey2 = jr.split(key, 3)

		n_points = jr.randint(subkey1, (), MIN_N, MAX_N)  # This task's number of points
		inputs = [jr.choice(subkey2, grid, (n_points,), replace=g != 0) for g, grid in enumerate(GRIDS)]  # Randomly pick inputs from the grids
		# We set replace=False for the first grid, to ensure we have distinct inputs in at least one dimension.

		for n in range(n_points):
			for o, output_id in enumerate(OUTPUTS_ID):
				key, subkey1, subkey2 = jr.split(key, 3)

				if jr.uniform(subkey1) < drop_output_rate:
					# Drop output value with a certain probability
					continue

				output_val = jr.uniform(subkey2, (), jnp.float32, *OUTPUT_RANGES[o])

				for i, input_id in enumerate(INPUTS_ID):
						data.append({
							"Task_ID": m,
							"Input": inputs[i][n].item(),
							"Input_ID": input_id,
							"Output": output_val.item(),
							"Output_ID": output_id
						})

	return pd.DataFrame(data)

In [7]:
db = generate_dummy_db(M, INPUTS_ID, MIN_N, MAX_N, OUTPUTS_ID, GRIDS, 0.15, key)
db

Unnamed: 0,Task_ID,Input,Input_ID,Output,Output_ID
0,0,0.0,x,-4.660727,a
1,0,-0.5,y,-4.660727,a
2,0,1.0,z,-4.660727,a
3,0,0.0,x,2.210305,b
4,0,-0.5,y,2.210305,b
5,0,1.0,z,2.210305,b
6,0,2.0,x,-2.132901,a
7,0,-0.5,y,-2.132901,a
8,0,1.0,z,-2.132901,a
9,0,2.0,x,0.64188,b


---
## Implementation

In [8]:
def extract_all_inputs(db: pd.DataFrame):
	"""
	Extract all distinct inputs from the database.

	:param db: The database to process, with columns "Task_ID", "Input", "Input_ID", "Output", "Output_ID"
	:return: A sorted array of all distinct inputs
	"""
	db_sorted = db.sort_values(by=["Task_ID", "Output_ID", "Output", "Input_ID"]).reset_index(drop=True)

	INPUT_DIM = len(db_sorted["Input_ID"].unique())

	triplets = jnp.stack([jnp.array(db_sorted["Input"][i::INPUT_DIM]) for i in range(INPUT_DIM)])

	return jnp.unique(triplets, axis=1).T

In [16]:
db = pd.read_csv("../datasets/large_shared_input_shared_hp.csv")

In [17]:
%timeit extract_all_inputs(db).block_until_ready()

122 ms ± 5.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
from MagmaClustPy.utils import preprocess_db

In [19]:
%timeit preprocess_db(db)[0].block_until_ready()

5.69 s ± 631 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
@jit
def extract_task_data(_id, values, all_inputs, to_fill):
	"""
	Extract data for a given task ID from the values array and return a row of padded inputs, padded outputs and index_mappings.

	:param _id:
	:param id_index:
	:param values:
	:param all_inputs:
	:return:
	"""
	inputs_i = jnp.where(values[:,0] == _id, values[:,1], jnp.nan)
	outputs_i = jnp.where(values[:,0] == _id, values[:,2], jnp.nan)
	mappings_i = jnp.searchsorted(all_inputs, inputs_i)

	# Compute index among the whole dataset
	idx_i = jnp.where(jnp.isnan(inputs_i), to_fill.shape[0] + 1, jnp.cumsum(~jnp.isnan(inputs_i)) - 1)

	# Create padded inputs and outputs
	padded_input = jnp.full(to_fill.shape[0], jnp.nan).at[idx_i].set(inputs_i)
	padded_output = jnp.full(to_fill.shape[0], jnp.nan).at[idx_i].set(outputs_i)
	index_mappings = jnp.full(to_fill.shape[0], all_inputs.shape[0] + 1).at[idx_i].set(mappings_i).astype(int)

	return padded_input, padded_output, index_mappings


def preprocess_db(db: pd.DataFrame):
	"""

	:param db: the db to process, with columns "ID", "Input" and "Output", in that order
	:return: a tuple of (all_inputs, padded_inputs, padded_outputs, masks)
	   - all_inputs: a matrix of shape (P, ) with all distinct inputs
	   - padded_inputs: a matrix of shape (M, MAX_N) where M is the number of sequences and MAX_N is the max number of points among all sequences. Missing inputs for each sequence are represented as NaNs.
	   - padded_outputs: a matrix of shape (M, MAX_N) with corresponding output for each input and NaNs for missing inputs
	   - index_mappings: a matrix of shape (M, MAX_N) with indices of the inputs in the all_inputs array. Missing inputs for each sequence are represented as -1.
	"""
	# Get all distinct inputs
	db_sorted = db.sort_values(['ID', 'Input'])
	all_ids = jnp.array(db_sorted["ID"].unique())
	all_inputs = jnp.sort(jnp.array(db_sorted["Input"].unique()))
	MAX_N = db_sorted.groupby("ID")["Input"].count().max()  # Maximum number of points in a sequence
	to_fill = jnp.full((MAX_N), jnp.nan)  # Placeholder for padded inputs and outputs

	# Initialise padded inputs, padded outputs and masks
	padded_inputs, padded_outputs, index_mappings = vmap(extract_id_data, in_axes=(0, None, None, None))(all_ids,
																											 db_sorted[
																												 ["ID",
																												  "Input",
																												  "Output"]].values,
																											 all_inputs,
																											 to_fill)

	return all_inputs, padded_inputs, padded_outputs, index_mappings


---
## Custom implementation(s)

Defaults we wish to fix:

For M individuals, each having between MIN_N and MAX_N observations over a grid of size G, we currently pad the inputs to the size of the union of all distinct inputs, named P. With a big M, P is likely to approach G, even though MAX_N is much smaller than G.

Later in the process, this leads to inversion of huge padded matrices. As those matrices are padded with identity vectors, the inversion in itself is not made terribly worse. However, the memory footprint is much larger than it needs to be, and the movement of data in and out of the GPU is much more costly.

Rather than "aligning" the inputs to the union of distinct inputs, we can just compute the indices of the inputs in the grid, and use those indices to map individual inputs/covariance matrices/precision matrices to the grid when needed. By filling indices whith nans for smaller sequences, we ensure they can still be processed in a vectorised way, while limiting the memory footprint to it's theoretical minimum.

In this approach, we don't carry a mask arround, but only a mapping of indices to inputs in the distinct inputs grid. This mapping is used to compute the covariance matrices.

In [14]:
@jit
def extract_id_data_new(_id, values, all_inputs, to_fill):
	"""
	Extract data for a given ID from the values array and return a row of padded inputs, padded outputs and index_mappings.

	:param _id:
	:param id_index:
	:param values:
	:param all_inputs:
	:return:
	"""
	inputs_i = jnp.where(values[:,0] == _id, values[:,1], jnp.nan)
	outputs_i = jnp.where(values[:,0] == _id, values[:,2], jnp.nan)
	mappings_i = jnp.searchsorted(all_inputs, inputs_i)

	# Compute index among the whole dataset
	idx_i = jnp.where(jnp.isnan(inputs_i), to_fill.shape[0] + 1, jnp.cumsum(~jnp.isnan(inputs_i)) - 1)

	# Create padded inputs and outputs
	padded_input = jnp.full(to_fill.shape[0], jnp.nan).at[idx_i].set(inputs_i)
	padded_output = jnp.full(to_fill.shape[0], jnp.nan).at[idx_i].set(outputs_i)
	index_mappings = jnp.full(to_fill.shape[0], all_inputs.shape[0] + 1).at[idx_i].set(mappings_i).astype(int)

	return padded_input, padded_output, index_mappings


def preprocess_db_new(db: pd.DataFrame):
	"""

	:param db: the db to process, with columns "ID", "Input" and "Output", in that order
	:return: a tuple of (all_inputs, padded_inputs, padded_outputs, masks)
	   - all_inputs: a matrix of shape (P, ) with all distinct inputs
	   - padded_inputs: a matrix of shape (M, MAX_N) where M is the number of sequences and MAX_N is the max number of points among all sequences. Missing inputs for each sequence are represented as NaNs.
	   - padded_outputs: a matrix of shape (M, MAX_N) with corresponding output for each input and NaNs for missing inputs
	   - index_mappings: a matrix of shape (M, MAX_N) with indices of the inputs in the all_inputs array. Missing inputs for each sequence are represented as -1.
	"""
	# Get all distinct inputs
	db_sorted = db.sort_values(['ID', 'Input'])
	all_ids = jnp.array(db_sorted["ID"].unique())
	all_inputs = jnp.sort(jnp.array(db_sorted["Input"].unique()))
	MAX_N = db_sorted.groupby("ID")["Input"].count().max()  # Maximum number of points in a sequence
	to_fill = jnp.full((MAX_N), jnp.nan)  # Placeholder for padded inputs and outputs

	# Initialise padded inputs, padded outputs and masks
	padded_inputs, padded_outputs, index_mappings = vmap(extract_id_data_new, in_axes=(0, None, None, None))(all_ids,
																											 db_sorted[
																												 ["ID",
																												  "Input",
																												  "Output"]].values,
																											 all_inputs,
																											 to_fill)

	return all_inputs, padded_inputs, padded_outputs, index_mappings

---
## Comparison

In [15]:
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
all_inputs.shape, padded_inputs.shape, padded_outputs.shape, masks.shape

KeyError: 'ID'

In [11]:
all_inputs_new, padded_inputs_new, padded_outputs_new, index_mappings_new = preprocess_db_new(db)
all_inputs_new.shape, padded_inputs_new.shape, padded_outputs_new.shape, index_mappings_new.shape

((21,), (2, 24), (2, 24), (2, 24))

In [12]:
print(f"{len(all_inputs)} distinct inputs, covering {len(all_inputs) / len(grid) * 100:.2f}% of the grid")

NameError: name 'grid' is not defined

In [189]:
#jnp.allclose(all_inputs, all_inputs_new), jnp.allclose(padded_inputs, padded_inputs_new, equal_nan=True), jnp.allclose(padded_outputs, padded_outputs_new, equal_nan=True), jnp.allclose(masks, masks_new)

## Performance comparison on dummy tasks

### Covariance matrix computation

In [190]:
from Kernax import RBFKernel

kern = RBFKernel(length_scale=jnp.array(0.3), variance=jnp.array(1.))

In [191]:
padded_inputs[0], padded_inputs_new[0], index_mappings_new[0]

(Array([  nan,   nan,   nan,   nan, -496.,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan, -468.,   nan,   nan,   nan,
          nan,   nan, -462.,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan, -441.,   nan,   nan,   nan,
          nan,   nan, -435.,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan, -417., -416.,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan, -397.,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
          nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan, -375.,
      

In [192]:
@jit
def map_to_full_cov(dense_cov, all_inputs, mapping):
	# return jnp.full((len(all_inputs), len(all_inputs)), jnp.nan).at[tuple(jnp.meshgrid(mapping, mapping))].set(dense_cov)
	return jnp.full((len(all_inputs), len(all_inputs)), jnp.nan).at[jnp.ix_(mapping, mapping)].set(dense_cov)

In [193]:
@jit
def map_to_full_batch(dense_covs, all_inputs, mappings):
    return vmap(map_to_full_cov, in_axes=(0, None, 0))(dense_covs, all_inputs, mappings)

In [194]:
# On single input
np.asarray(kern(padded_inputs[0]))

array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan,  1., nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)

In [195]:
np.asarray(map_to_full_cov(kern(padded_inputs_new[0]), all_inputs, index_mappings_new[0]))

array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan,  1., nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)

In [196]:
jnp.allclose(kern(padded_inputs[0]), map_to_full_cov(kern(padded_inputs_new[0]), all_inputs, index_mappings_new[0]), equal_nan=True)

Array(True, dtype=bool)

In [197]:
%timeit kern(padded_inputs[0]).block_until_ready()

276 μs ± 12.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [198]:
%timeit map_to_full_cov(kern(padded_inputs_new[0]), all_inputs, index_mappings_new[0]).block_until_ready()

361 μs ± 13 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [199]:
a = kern(padded_inputs).block_until_ready()

In [200]:
b = map_to_full_batch(kern(padded_inputs_new), all_inputs, index_mappings_new).block_until_ready()

In [201]:
jnp.allclose(a, b, equal_nan=True)

Array(True, dtype=bool)

In [202]:
%timeit kern(padded_inputs).block_until_ready()

92.9 ms ± 23.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [203]:
%timeit map_to_full_batch(kern(padded_inputs_new), all_inputs, index_mappings_new).block_until_ready()

160 ms ± 5.88 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Sum of invs

In [204]:
from jax.scipy.linalg import cho_factor, cho_solve

In [205]:
@jit
def full_pad_sum_of_inv(inputs, masks, all_inputs, kernel):
	"""
	compute the sum of inverses of all cross-covariance matrices for each input in inputs.
	It uses a full padding approach, where the inputs are padded to the size of all_inputs and aligned.
	masks gives indices where the inputs are valid (True) or padded (False).

	:param inputs:
	:param masks:
	:param all_inputs:
	:param kernel:
	:return:
	"""
	nugget = 1e-8  # Small value to ensure numerical stability
	covs = kernel(inputs)

	small_eye = jnp.broadcast_to(jnp.eye(covs.shape[-1]), covs.shape)

	# covs is padded with NaNs. Replace them by their corresponding identity rows/cols
	masks_2D = masks[:, :, None] & masks[:, None, :]
	covs = jnp.where(masks_2D, covs, small_eye)

	covs_U, _ = cho_factor(covs + small_eye * nugget)
	covs_inv = cho_solve((covs_U, False), small_eye)
	covs_inv -= jnp.where(masks_2D, 0, small_eye)  # Correction on the diagonal
	return covs_inv.sum(axis=0)

In [206]:
@jit
def dense_pad_sum_of_inv(inputs, mappings, all_inputs, kernel):
	"""
	compute the sum of inverses of all cross-covariance matrices for each input in inputs.
	It uses a dense padding approach, where the inputs are padded to the size of MAX_N and not aligned. Their positions in the all_inputs array are given by mappings.

	:param inputs:
	:param mappings:
	:param all_inputs:
	:param kernel:
	:return:
	"""
	nugget = 1e-8  # Small value to ensure numerical stability
	covs = kernel(inputs)

	small_eye = jnp.broadcast_to(jnp.eye(covs.shape[-1]), covs.shape)

	# Some covs may still end with a few NaNs, so we replace them by their corresponding identity rows/cols
	eyed_covs = jnp.where(jnp.isnan(covs), small_eye, covs)

	covs_U, _ = cho_factor(eyed_covs + small_eye * nugget)
	covs_inv = cho_solve((covs_U, False), small_eye)
	covs_inv -= jnp.where(jnp.isnan(covs), small_eye, 0)  # Correction on the diagonal

	# Now we need to map the covs_inv to the all_inputs array
	return jnp.nan_to_num(map_to_full_batch(covs_inv, all_inputs, mappings)).sum(axis=0)

In [207]:
a = full_pad_sum_of_inv(padded_inputs, masks, all_inputs, kern)

In [208]:
b = dense_pad_sum_of_inv(padded_inputs_new, index_mappings_new, all_inputs, kern)

In [213]:
jnp.allclose(a, b, equal_nan=True, atol=1e-6)

Array(True, dtype=bool)

In [210]:
%timeit full_pad_sum_of_inv(padded_inputs, masks, all_inputs, kern).block_until_ready()

4.28 s ± 32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [211]:
%timeit dense_pad_sum_of_inv(padded_inputs_new, index_mappings_new, all_inputs, kern).block_until_ready()

256 ms ± 5.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
