# Benchmark notebook

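This notebook measures the speed gain of this Python translation by timing the whole train + predict pipeline on common toy datasets. Note: the prediction step is not implemented yet, so the reported time currently covers data loading, preprocessing and training only.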

---

## Setup

In [1]:
import os

# Must be set before jax is imported for 64-bit precision to take effect.
os.environ['JAX_ENABLE_X64'] = "True"

import time

import jax
from jax import jit, vmap, lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy.linalg import cho_solve, cho_factor
from jax.scipy.optimize import minimize
from jax.scipy.stats.multivariate_normal import logpdf

import pandas as pd

from MagmaClustPy.kernels import SEMagmaKernel
from MagmaClustPy.hyperpost import hyperpost
from MagmaClustPy.hp_optimisation import optimise_hyperparameters
from MagmaClustPy.utils import preprocess_db

In [2]:
# EM stopping rules: run at most MAX_ITER e-step/m-step rounds, and stop early once
# both log-likelihoods change by less than CONVERG_THRESHOLD between two rounds.
MAX_ITER, CONVERG_THRESHOLD = 25, 1e-3

# Small jitter forwarded to hyperpost / optimise_hyperparameters
# (presumably added to covariance diagonals for numerical stability — TODO confirm in MagmaClustPy).
nugget = jnp.array(1e-6)

In [3]:
# Toy dataset to benchmark; this name also selects the CSV file loaded in the data-import cell.
dataset = "small"

# Each grid spans [-half_width, half_width) with a 0.5 step.
_grid_half_widths = {"small": 10, "medium": 100, "large": 500, "custom": 20}
grids = {name: jnp.arange(-half_width, half_width, 0.5) for name, half_width in _grid_half_widths.items()}

# Input grid (presumably for the not-yet-implemented prediction step);
# dataset names without a dedicated grid fall back to the "custom" one.
grid = grids.get(dataset, grids["custom"])

# Variant flags, also used to build the CSV file name.
common_input = True
common_hp = True

---

## Start timer

In [4]:
# Wall-clock reference point for the whole pipeline; the matching `end` is taken in the "End timer" section.
start: float = time.time()

---

## Data import

In [5]:
# Toy CSV files are named "<dataset>_<input layout>_<hp layout>.csv".
input_layout = "common_input" if common_input else "distinct_input"
hp_layout = "common_hp" if common_hp else "distinct_hp"
db = pd.read_csv(f"./dummy_datasets/{dataset}_{input_layout}_{hp_layout}.csv")
# db has 3 columns: ID, Input, Output

In [6]:
# Split tasks by ID: the first 90% of the distinct IDs (in order of first appearance)
# are used for training, the remaining 10% for testing.
TRAIN_FRACTION = 0.9

unique_ids = db["ID"].unique()
# Size the split from the very array being sliced: `nunique()` drops NaN by default
# while `unique()` keeps it, so mixing the two could make them disagree.
n_train_ids = int(TRAIN_FRACTION * len(unique_ids))
train_ids = unique_ids[:n_train_ids]
test_ids = unique_ids[n_train_ids:]
assert len(train_ids) > 0, "Training split is empty: not enough distinct IDs in the dataset."

db_train = db[db["ID"].isin(train_ids)]
db_test = db[db["ID"].isin(test_ids)]

# N.B.: the toy datasets are already sorted by ID and Input; real data would need
# to be sorted first (e.g. db.sort_values(["ID", "Input"])).

---

## Data preprocessing

In [7]:
# Convert the long-format training dataframe into JAX arrays. Per the shapes printed below
# (and the original notes):
# - all_inputs_train: (P,) distinct inputs over all training tasks (P = N if common_input)
# - padded_inputs_train: (M, N) per-task inputs, M tasks padded to N points
# - padded_outputs_train: presumably (M, N) observed outputs aligned with the inputs — TODO confirm
# - masks_train: presumably (M, N) flags separating real from padded entries — TODO confirm in preprocess_db
train_arrays = preprocess_db(db_train)
all_inputs_train, padded_inputs_train, padded_outputs_train, masks_train = train_arrays
(all_inputs_train.shape, padded_inputs_train.shape)

((15,), (18, 15))

---

## Training

In [8]:
# Priors: constant prior mean and initial kernels for the mean process and the tasks.
prior_mean = jnp.array(0.)
mean_kernel = SEMagmaKernel(length_scale=0.9, variance=1.5)

n_tasks = padded_inputs_train.shape[0]
if not common_hp:
	# Distinct hyperparameters: one length scale / variance per training task.
	task_kernel = SEMagmaKernel(length_scale=jnp.array([0.3] * n_tasks), variance=jnp.array([1.] * n_tasks))
else:
	# Shared hyperparameters: a single scalar pair for every task.
	task_kernel = SEMagmaKernel(length_scale=0.3, variance=1.)

In [9]:
# EM training loop: alternate the e-step (hyper-posterior) and the m-step (hyperparameter
# optimisation) until both log-likelihoods stabilise or MAX_ITER iterations have run.
# Previous llhs start at +inf so the very first iteration can never be flagged as converged.
prev_mean_llh = jnp.inf
prev_task_llh = jnp.inf

for i in range(MAX_ITER):
	# Logs the kernels entering iteration i next to the llhs from the previous m-step.
	print(f"Iteration {i:4}\tLlhs: {prev_mean_llh:12.4f}, {prev_task_llh:12.4f}\tMean: {mean_kernel}\t Task: {task_kernel}")
	# e-step: compute hyper-posterior
	post_mean, post_cov = hyperpost(padded_inputs_train, padded_outputs_train, masks_train, prior_mean, mean_kernel, task_kernel, all_inputs=all_inputs_train, nugget=nugget)

	# m-step: update hyperparameters
	mean_kernel, task_kernel, mean_llh, task_llh = optimise_hyperparameters(mean_kernel, task_kernel, padded_inputs_train, padded_outputs_train, all_inputs_train, prior_mean, post_mean, post_cov, masks_train, nugget=nugget)

	# Absolute change of each llh since the previous iteration (computed once, reused below).
	mean_llh_change = jnp.abs(prev_mean_llh - mean_llh)
	task_llh_change = jnp.abs(prev_task_llh - task_llh)

	# Check convergence
	if mean_llh_change < CONVERG_THRESHOLD and task_llh_change < CONVERG_THRESHOLD:
		# `i` is 0-based, so i + 1 e/m iterations have been completed at this point.
		print(f"Convergence reached after {i + 1} iterations.\tLlhs: {mean_llh:12.4f}, {task_llh:12.4f}\tMean: {mean_kernel}\t Task: {task_kernel}")
		break

	if i == MAX_ITER - 1:
		print(f"WARNING: Maximum number of iterations reached. Last modif: {mean_llh_change.item()} & {task_llh_change.item()}")

	prev_mean_llh = mean_llh
	prev_task_llh = task_llh

Iteration    0	Llhs:          inf,          inf	Mean: SEMagmaKernel(length_scale=0.9, variance=1.5)	 Task: SEMagmaKernel(length_scale=0.3, variance=1.0)
Iteration    1	Llhs:      53.1633, 29768226375863459117006848.0000	Mean: SEMagmaKernel(length_scale=1.083940058105724, variance=6.171792930276515)	 Task: SEMagmaKernel(length_scale=1.2999720349727222, variance=1.0074785876014987)
Iteration    2	Llhs:      54.8903,   32057.1045	Mean: SEMagmaKernel(length_scale=0.9883446383339876, variance=6.24711033663178)	 Task: SEMagmaKernel(length_scale=2.7684807525772075, variance=15.36753159080174)
Iteration    3	Llhs:      54.1757,   32036.9198	Mean: SEMagmaKernel(length_scale=0.8816390098356836, variance=5.993002970966149)	 Task: SEMagmaKernel(length_scale=2.780984725134694, variance=15.32805067880792)
Iteration    4	Llhs:      52.9538,   31910.4342	Mean: SEMagmaKernel(length_scale=0.8242526734854575, variance=5.750857118472989)	 Task: SEMagmaKernel(length_scale=2.788835660182731, variance=15.247

---

## Prediction

---

## End timer

In [10]:
# JAX dispatches work asynchronously: wait until the last training results are actually
# computed before reading the clock, otherwise pending device work escapes the benchmark.
# Non-array leaves are passed through unchanged by block_until_ready.
jax.block_until_ready((post_mean, post_cov, mean_kernel, task_kernel, mean_llh, task_llh))
end = time.time()

In [11]:
# Total wall-clock duration of the pipeline, in seconds.
elapsed_s = end - start
print(f"Magma finished in {elapsed_s}s")

Magma finished in 54.86550998687744s
