# Benchmarks - Preprocess DB

**Main considerations when implementing Preprocess DB**

Magma and MagmaClust both work on unaligned sequences of varying sizes.
Yet, we want to be able to perform as many operations as possible in a vectorisable and jittable way.

This means that we need to pad the sequences to the same length, aligning them in the process, and then mask the padded values.
The mask should be computed once, at the first step of the process, so that it does not have to be re-computed every time we need to mask the values.

This way, each operation can have custom logic to handle padded values, all while preserving vectorisation and jitting.
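
As an illustration of what such custom logic could look like, here is a minimal sketch of a masked mean. The function name `masked_mean` and the toy arrays are assumptions made for this example only; they are not part of Magma or MagmaClust.

```python
import jax
import jax.numpy as jnp


@jax.jit
def masked_mean(values, mask):
    """Mean of `values` over the entries where `mask` is True."""
    # Replace padded entries (NaN) with 0 so they cannot contaminate the sum
    safe_values = jnp.where(mask, values, 0.0)
    # Guard against a division by zero for a fully padded sequence
    count = jnp.maximum(mask.sum(), 1)
    return safe_values.sum() / count


# Two toy sequences, padded with NaNs to a common length of 4
values = jnp.array([[1.0, jnp.nan, 3.0, jnp.nan],
                    [jnp.nan, 2.0, 4.0, 6.0]])
mask = ~jnp.isnan(values)  # computed once, reused by every operation

# vmap applies the same jitted logic to every sequence at once
means = jax.vmap(masked_mean)(values, mask)
```

Because every sequence has the same padded shape, the function is traced once and reused for all sequences. `jnp.where` keeps the NaN padding out of the result without any Python-level branching.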


---
## Setup

In [1]:
# Standard library
import os

# Config
os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
import jax.numpy as jnp

import numpy as np
import pandas as pd

# Initialize random key
key = jax.random.PRNGKey(0)

In [3]:
# Local

In [4]:
# Set constants
M = 50  # Number of sequences
MIN_N = 10  # Minimum sequence length
MAX_N = 100  # Maximum sequence length (exclusive upper bound of jax.random.randint)
grid = jnp.arange(-200, 200, 1, dtype=jnp.float64)  # Grid to pick inputs from

---
## Data

In [5]:
db = pd.DataFrame({"ID": [], "Input": [], "Output": []})

In [6]:
# We fill DB with random sequences
data = []
for m in range(M):
	key, subkey = jax.random.split(key)
	n_points = jax.random.randint(subkey, (), MIN_N, MAX_N)
	for n in range(n_points):
		key, subkey1, subkey2 = jax.random.split(key, 3)
		data.append({
			"ID": m,
			"Input": jax.random.choice(subkey1, grid, (1,))[0].item(),
			"Output": jax.random.uniform(subkey2, (), jnp.float64, -5, 5).item()
		})
db = pd.DataFrame(data)

In [7]:
db

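|      | ID  | Input | Output    |
|------|-----|-------|-----------|
| 0    | 0   | 8.0   | -0.957674 |
| 1    | 0   | 44.0  | 2.742172  |
| 2    | 0   | 189.0 | -4.571886 |
| 3    | 0   | -29.0 | 0.661192  |
| 4    | 0   | 82.0  | -2.854472 |
| ...  | ... | ...   | ...       |
| 2706 | 49  | 153.0 | -1.899048 |
| 2707 | 49  | 28.0  | 3.695970  |
| 2708 | 49  | 58.0  | -4.940126 |
| 2709 | 49  | -71.0 | 2.483545  |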


---
## Current implementation

No comparison with the current implementation is possible, since it does not use padding.

---
## Custom implementation(s)

In [8]:
def preprocess_db(db: pd.DataFrame):
	"""
	Pad and align all sequences of a database on the grid of distinct inputs.

	:param db: the db to process, with columns "ID", "Input" and "Output"
	:return: a tuple of (all_inputs, padded_inputs, padded_outputs, masks)
	   - all_inputs: a vector of shape (P,) with all distinct inputs, sorted in ascending order
	   - padded_inputs: a matrix of shape (M, P) where M is the number of sequences and P is the number of distinct
	   inputs. Missing inputs for each sequence are represented as NaNs.
	   - padded_outputs: a matrix of shape (M, P) with corresponding output for each input and NaNs for missing inputs
	   - masks: a boolean matrix of shape (M, P), True where the input is valid and False where it is padded
	"""
	# Get all distinct IDs (one row per sequence) and all distinct inputs (sorted)
	all_ids = jnp.array(db["ID"].unique())
	all_inputs = jnp.sort(jnp.array(db["Input"].unique()))

	# Initialise padded inputs, padded outputs and masks
	padded_inputs = jnp.full((len(all_ids), len(all_inputs)), jnp.nan)
	padded_outputs = jnp.full((len(all_ids), len(all_inputs)), jnp.nan)
	masks = jnp.zeros((len(all_ids), len(all_inputs)), dtype=bool)

	# Fill padded inputs, padded outputs and masks
	for i, _id in enumerate(db["ID"].unique()):
		sub_db = db[db["ID"] == _id]
		idx = jnp.searchsorted(all_inputs, jnp.array(sub_db["Input"]))
		padded_inputs = padded_inputs.at[i, idx].set(sub_db["Input"].values)
		padded_outputs = padded_outputs.at[i, idx].set(sub_db["Output"].values)
		masks = masks.at[i, idx].set(jnp.ones(len(sub_db), dtype=bool))

	return all_inputs, padded_inputs, padded_outputs, masks
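
The loop in `preprocess_db` performs three `.at[...].set(...)` updates per sequence, and outside of `jit` each of them returns a new array. A possible loop-free alternative is sketched below, under the assumption that preprocessing may run in NumPy before the arrays are handed to JAX. The name `preprocess_db_vectorised` and the toy DataFrame are hypothetical and only serve this example.

```python
import numpy as np
import pandas as pd


def preprocess_db_vectorised(db: pd.DataFrame):
    """Hypothetical loop-free variant of the padding step, using NumPy only."""
    # Sorted distinct inputs, and the column of each observation on that grid
    all_inputs, cols = np.unique(db["Input"].to_numpy(), return_inverse=True)
    # Row of each observation; IDs keep their order of first appearance
    rows, all_ids = pd.factorize(db["ID"])

    shape = (len(all_ids), len(all_inputs))
    padded_inputs = np.full(shape, np.nan)
    padded_outputs = np.full(shape, np.nan)
    masks = np.zeros(shape, dtype=bool)

    # A single fancy-indexed assignment per array replaces the Python loop
    padded_inputs[rows, cols] = db["Input"].to_numpy()
    padded_outputs[rows, cols] = db["Output"].to_numpy()
    masks[rows, cols] = True
    return all_inputs, padded_inputs, padded_outputs, masks


# Toy database: sequence 0 has 2 points, sequence 1 has 3 points
toy_db = pd.DataFrame({
    "ID": [0, 0, 1, 1, 1],
    "Input": [2.0, 5.0, 1.0, 2.0, 9.0],
    "Output": [0.1, 0.2, 0.3, 0.4, 0.5],
})
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db_vectorised(toy_db)
```

On the toy database, the grid of distinct inputs is `[1, 2, 5, 9]`, so both sequences are padded to length 4. The resulting NumPy arrays can then be converted with `jnp.asarray` before entering jitted code.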


---
## Comparison

In [9]:
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)

In [10]:
print(f"{len(all_inputs)} distinct inputs, {len(all_inputs)/len(grid)*100:.2f}% of grid")

398 distinct inputs, 99.50% of grid


In [11]:
np.asarray(padded_inputs)

array([[  nan, -199.,   nan, ...,   nan,   nan,   nan],
       [-200.,   nan,   nan, ...,   nan,   nan,   nan],
       [  nan,   nan,   nan, ...,   nan,   nan,   nan],
       ...,
       [  nan,   nan,   nan, ...,   nan,   nan,   nan],
       [  nan,   nan,   nan, ...,   nan,   nan,   nan],
       [  nan,   nan,   nan, ...,   nan,   nan,   nan]])

In [12]:
np.asarray(padded_outputs)

array([[        nan,  4.90038074,         nan, ...,         nan,
                nan,         nan],
       [-3.93707068,         nan,         nan, ...,         nan,
                nan,         nan],
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan],
       ...,
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan],
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan],
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan]])

In [13]:
np.asarray(masks)

array([[False,  True, False, ..., False, False, False],
       [ True, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]])
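
Since the printed arrays are mostly elided, a few programmatic checks may be easier to read than the raw output. The sketch below runs on small toy arrays that mimic the layout returned by `preprocess_db`. The variable names `masks_match_nans`, `n_points` and `fill_ratio` are chosen for this example only.

```python
import jax.numpy as jnp

# Toy arrays with the same layout as the outputs of preprocess_db
padded_inputs = jnp.array([[jnp.nan, 2.0, 5.0, jnp.nan],
                           [1.0, 2.0, jnp.nan, 9.0]])
masks = jnp.array([[False, True, True, False],
                   [True, True, False, True]])

# The mask should agree exactly with the NaN pattern of the padded inputs
masks_match_nans = bool(jnp.array_equal(masks, ~jnp.isnan(padded_inputs)))

# Number of observed points per sequence
n_points = masks.sum(axis=1)

# Fraction of entries holding real data (the rest is padding)
fill_ratio = float(masks.mean())
```

The fill ratio gives a rough idea of the memory and compute spent on padding, which is the price paid for fixed shapes.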

---
## Conclusion

---