# Benchmarks - Preprocess DB

**Main considerations when implementing Preprocess DB**

Magma and MagmaClust both work on unaligned sequences of varying lengths.
Yet we want to perform as many operations as possible in a vectorisable and jittable way.

This means that we need to pad the sequences to the same length, aligning them in the process, and then mask the padded values.
The mask should be computed once, at the first step of the process, so that it does not have to be recomputed every time padded values need to be masked.

This way, each operation can have custom logic to handle padded values, all while preserving vectorisation and jitting.
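
As a minimal sketch of this idea (the toy values and the helper name `masked_mean` are assumptions made for illustration, not part of the notebook), two sequences of different lengths can be padded with NaNs to a common length and reduced under `jit` and `vmap` with a mask that is computed only once:

```python
import jax
import jax.numpy as jnp

# Two toy sequences aligned on a shared grid of 4 inputs (illustrative values)
padded_outputs = jnp.array([
    [1.0, jnp.nan, 3.0, jnp.nan],   # sequence 0 is observed at positions 0 and 2
    [jnp.nan, 4.0, 6.0, 8.0],       # sequence 1 is observed at positions 1, 2 and 3
])

# The mask is computed once and reused by every later operation
masks = ~jnp.isnan(padded_outputs)


@jax.jit
def masked_mean(row, mask):
    # Replace padded values by 0 so they do not contribute to the sum,
    # then divide by the number of valid points
    return jnp.sum(jnp.where(mask, row, 0.0)) / jnp.sum(mask)


# One mean per sequence, computed without any Python loop
means = jax.vmap(masked_mean)(padded_outputs, masks)
```

Because every row has the same shape, the same compiled function handles all sequences, and the custom logic for padded values lives entirely in the `jnp.where` call.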


---
## Setup

In [1]:
# Standard library
import os

# Config
os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
from jax import vmap, jit
import jax.numpy as jnp

import numpy as np
import pandas as pd

# Initialize random key
key = jax.random.PRNGKey(0)

In [3]:
# Local


In [4]:
# Set constants
M = 500  # Number of sequences
MIN_N = 100  # Minimum sequence length
MAX_N = 1000  # Maximum sequence length
grid = jnp.arange(-2000, 2000, 1, dtype=jnp.float64)  # Grid to pick inputs from

---
## Data

In [5]:
def generate_dummy_db(M: int, MIN_N: int, MAX_N: int, grid: jax.Array, key: jax.Array) -> pd.DataFrame:
	# We fill DB with random sequences
	data = []
	for m in range(M):
		key, subkey = jax.random.split(key)
		# Sequence length is drawn from [MIN_N, MAX_N): the upper bound is exclusive
		n_points = int(jax.random.randint(subkey, (), MIN_N, MAX_N))
		for n in range(n_points):
			key, subkey1, subkey2 = jax.random.split(key, 3)
			data.append({
				"ID": m,
				"Input": jax.random.choice(subkey1, grid, (1,))[0].item(),
				"Output": jax.random.uniform(subkey2, (), jnp.float64, -5, 5).item()
			})
	return pd.DataFrame(data)
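
The loop above makes several JAX calls per generated point, which can be slow for a database of this size. A possible alternative, sketched here under stated assumptions, draws everything in bulk with NumPy. The name `generate_dummy_db_fast` and the `seed` parameter are hypothetical, and the random stream differs from the JAX one, so the generated values will not match the table shown in this notebook.

```python
import numpy as np
import pandas as pd


def generate_dummy_db_fast(M, MIN_N, MAX_N, grid, seed=0):
    # Hypothetical helper: same table layout as generate_dummy_db, different random stream
    rng = np.random.default_rng(seed)
    # One length per sequence, drawn from [MIN_N, MAX_N) like jax.random.randint
    n_points = rng.integers(MIN_N, MAX_N, size=M)
    total = int(n_points.sum())
    return pd.DataFrame({
        # Each ID is repeated as many times as its sequence has points
        "ID": np.repeat(np.arange(M), n_points),
        # Inputs are drawn independently from the grid, so repeats are possible
        "Input": rng.choice(np.asarray(grid), size=total),
        "Output": rng.uniform(-5, 5, size=total),
    })


small_db = generate_dummy_db_fast(M=5, MIN_N=3, MAX_N=10, grid=np.arange(-20.0, 20.0))
```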

In [6]:
db = generate_dummy_db(M, MIN_N, MAX_N, grid, key)
db

| | ID | Input | Output |
|---|---|---|---|
| 0 | 0 | 1408.0 | -0.957674 |
| 1 | 0 | -1356.0 | 2.742172 |
| 2 | 0 | 1989.0 | -4.571886 |
| 3 | 0 | -1029.0 | 0.661192 |
| 4 | 0 | 282.0 | -2.854472 |
| ... | ... | ... | ... |
| 270084 | 499 | 1575.0 | 0.043677 |
| 270085 | 499 | -1662.0 | 1.769157 |
| 270086 | 499 | -1604.0 | 2.336939 |
| 270087 | 499 | 1080.0 | 1.489557 |


---
## Current implementation

In [7]:
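# Operates on the raw .values array rather than on the DataFrame
@jit
def extract_id_data(_id, values, all_inputs):
	"""
	Extract data for a given ID from the values array and return a row of padded inputs, padded outputs and mask.

	:param _id: the ID of the sequence to extract
	:param values: an array of shape (N, 3) whose columns are ID, Input and Output, in that order
	:param all_inputs: a sorted vector of shape (P, ) with all distinct inputs
	:return: a tuple of (padded_input, padded_output, mask), each of shape (P, )
	"""
	padded_input = jnp.full((len(all_inputs),), jnp.nan)
	padded_output = jnp.full((len(all_inputs),), jnp.nan)
	mask = jnp.zeros((len(all_inputs),), dtype=bool)

	# Inputs of rows belonging to other IDs are replaced by NaN, which searchsorted maps to index P (out of bounds).
	# The updates below drop out-of-bounds indices, so only the rows of this ID are written.
	idx = jnp.searchsorted(all_inputs, jnp.where(values[:, 0] == _id, values[:, 1], jnp.nan))

	return padded_input.at[idx].set(values[:, 1]), padded_output.at[idx].set(values[:, 2]), mask.at[idx].set(True)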

def preprocess_db(db: pd.DataFrame):
	"""
	Pad and align all sequences of the db on the set of distinct inputs.

	:param db: the db to process, with columns "ID", "Input" and "Output", in that order
	:return: a tuple of (all_inputs, padded_inputs, padded_outputs, masks)
	   - all_inputs: a vector of shape (P, ) with all distinct inputs, sorted in ascending order
	   - padded_inputs: a matrix of shape (M, P) where M is the number of sequences and P is the number of distinct
	   inputs. Missing inputs for each sequence are represented as NaNs.
	   - padded_outputs: a matrix of shape (M, P) with corresponding output for each input and NaNs for missing inputs
	   - masks: a boolean matrix of shape (M, P) with True where the input is valid and False where it is padded
	"""
	# Get all distinct IDs and all distinct inputs (sorted, as required by searchsorted)
	all_ids = jnp.array(db["ID"].unique())
	all_inputs = jnp.sort(jnp.array(db["Input"].unique()))

	# Compute padded inputs, padded outputs and masks for all IDs at once
	padded_inputs, padded_outputs, masks = vmap(extract_id_data, in_axes=(0, None, None))(all_ids, db[["ID", "Input", "Output"]].values, all_inputs)

	return all_inputs, padded_inputs, padded_outputs, masks
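
The core of `extract_id_data` is the combination of `jnp.where`, `jnp.searchsorted` and an indexed update. The toy example below (illustrative values, not taken from the benchmark) isolates that step for a single ID. Passing `mode="drop"` is an addition made here to state explicitly that out-of-bounds indices are ignored, which the notebook's version relies on implicitly.

```python
import jax.numpy as jnp

all_inputs = jnp.array([10.0, 20.0, 30.0, 40.0])   # sorted distinct inputs, P = 4
ids = jnp.array([0, 0, 1])                          # toy "ID" column
inputs = jnp.array([30.0, 10.0, 20.0])              # toy "Input" column
outputs = jnp.array([0.5, -1.0, 2.0])               # toy "Output" column

# Keep only the inputs of ID 0; rows of other IDs become NaN
own_inputs = jnp.where(ids == 0, inputs, jnp.nan)

# Position of each input in the sorted grid; NaN queries fall past the last entry
idx = jnp.searchsorted(all_inputs, own_inputs)

# mode="drop" ignores out-of-bounds indices, so rows of other IDs are never written
padded_output = jnp.full((4,), jnp.nan).at[idx].set(outputs, mode="drop")
mask = jnp.zeros((4,), dtype=bool).at[idx].set(True, mode="drop")
```

For ID 0, only the grid positions of inputs 10.0 and 30.0 end up valid; the row that belongs to ID 1 leaves no trace in either array. Since every ID goes through the same fixed-shape computation, the function can be vmapped over all IDs.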

---
## Custom implementation(s)

---
## Comparison

In [9]:
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)

In [10]:
all_inputs_new, padded_inputs_new, padded_outputs_new, masks_new = preprocess_db_new(db)

In [11]:
print(f"{len(all_inputs)} distinct inputs, covering {len(all_inputs)/len(grid)*100:.2f}% of the grid")

4000 distinct inputs, covering 100.00% of the grid


In [12]:
jnp.where(db["ID"].values == 0)[0].shape[0]

248

In [13]:
jnp.allclose(all_inputs, all_inputs_new), jnp.allclose(padded_inputs, padded_inputs_new, equal_nan=True), jnp.allclose(padded_outputs, padded_outputs_new, equal_nan=True), jnp.allclose(masks, masks_new)

(Array(True, dtype=bool),
 Array(True, dtype=bool),
 Array(True, dtype=bool),
 Array(True, dtype=bool))
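
The `equal_nan=True` argument matters for the padded arrays because NaN never compares equal to itself. A small sketch with toy values shows the difference:

```python
import jax.numpy as jnp

a = jnp.array([1.0, jnp.nan, 3.0])
b = jnp.array([1.0, jnp.nan, 3.0])

# Without equal_nan, a padded (NaN) slot makes the comparison fail
strict = bool(jnp.allclose(a, b))

# With equal_nan=True, NaNs located at the same positions count as equal
lenient = bool(jnp.allclose(a, b, equal_nan=True))
```

The masks are boolean and contain no NaN, so they can be compared without this flag.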

---
## Conclusion

The itertuples version is roughly 4-5x faster. Using a dictionary rather than searchsorted doesn't change performance significantly.

Vectorised preprocessing is in its own league, reducing the processing time from 7 min to 0.1 s.
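
The timing cells are not shown above. As a hedged sketch of how such measurements could be taken for jitted code, the hypothetical helper `time_jax` below excludes the first call (which includes tracing and compilation) and waits for results with `jax.block_until_ready`, since JAX dispatches work asynchronously. The workload `toy_step` is a stand-in, not one of the benchmarked functions.

```python
import time

import jax
import jax.numpy as jnp


def time_jax(fn, *args, repeats=5):
    # Hypothetical helper: warm-up call, excluded because it includes compilation
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the result before stopping the clock
        jax.block_until_ready(fn(*args))
    # Average wall-clock time per call, in seconds
    return (time.perf_counter() - start) / repeats


@jax.jit
def toy_step(x):
    # Stand-in workload; replace with the function being benchmarked
    return jnp.sum(jnp.where(jnp.isnan(x), 0.0, x))


elapsed = time_jax(toy_step, jnp.arange(1000.0))
```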

---