# Benchmark notebook

This notebook evaluates the speed gain of the Python translation (MagmaClustPy) by timing the whole train + predict pipeline on shared toy datasets.

---

## Setup

In [1]:
# Global switches for this benchmark run.
USE_JIT = True  # False disables JIT compilation (negated into JAX's "jax_disable_jit" option)
USE_X64 = True  # 64-bit floats; True matches the float64 hyper-parameters in the recorded training output
DEBUG_NANS = False  # forwarded to JAX's "jax_debug_nans" option
VERBOSE = False  # forwarded to optimise_hyperparameters(verbose=...)

In [2]:
import os

# Honour the USE_X64 switch from the configuration cell instead of always forcing 64-bit mode.
# The variable is set before `import jax` so that JAX reads it when it initialises its configuration.
os.environ['JAX_ENABLE_X64'] = str(USE_X64)

import time

import jax
# Apply the notebook-level switches to JAX's global configuration.
jax.config.update("jax_disable_jit", not USE_JIT)  # the JAX option is phrased as "disable", hence the negation
jax.config.update("jax_debug_nans", DEBUG_NANS)
from jax import jit, vmap, lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy.linalg import cho_solve, cho_factor
from jax.scipy.optimize import minimize
from jax.scipy.stats.multivariate_normal import logpdf

import pandas as pd

from MagmaClustPy.kernels import SEMagmaKernel, NoisySEMagmaKernel
from MagmaClustPy.hyperpost import hyperpost
from MagmaClustPy.hp_optimisation import optimise_hyperparameters
from MagmaClustPy.utils import preprocess_db

In [3]:
# Settings shared by the training loop.
MAX_ITER = 25  # upper bound on the number of training-loop iterations
CONVERG_THRESHOLD = 1e-3  # training stops once the relative change of the summed llh values drops below this
nugget = jnp.array(1e-6)  # passed as `nugget=` to hyperpost and optimise_hyperparameters; presumably a small diagonal jitter — confirm in MagmaClustPy

In [4]:
# Which toy dataset to benchmark, and which variant of it.
dataset = "medium"
shared_input = False  # picks the "shared_input" / "distinct_input" CSV variant
shared_hp = True  # picks the "shared_hp" / "distinct_hp" CSV variant and the shape of the task-kernel prior

# Each grid is symmetric around 0 with a step of 0.5; only the half-width depends on the dataset.
grid_bounds = {"small": 10, "medium": 100, "large": 500, "custom": 20}
grids = {name: jnp.arange(-bound, bound, 0.5) for name, bound in grid_bounds.items()}

# Dataset names without a dedicated grid fall back to the "custom" one.
grid = grids.get(dataset, grids["custom"])

---

## Start timer

In [5]:
# Wall-clock start of the benchmark: every cell run between here and the "End timer" cell counts towards the reported duration.
start = time.time()

---

## Data import

In [5]:
# The CSV name encodes the dataset size and the two variant switches chosen in the setup cells.
input_tag = 'shared_input' if shared_input else 'distinct_input'
hp_tag = 'shared_hp' if shared_hp else 'distinct_hp'
db = pd.read_csv(f"../dummy_datasets/{dataset}_{input_tag}_{hp_tag}.csv")
# db has 3 columns: ID, Input, Output

In [6]:
# Split by ID: the first 90% of the distinct IDs are used for training, the last 10% for testing.
unique_ids = db["ID"].unique()
n_train_ids = int(0.9 * db["ID"].nunique())
train_ids, test_ids = unique_ids[:n_train_ids], unique_ids[n_train_ids:]

is_train_row = db["ID"].isin(train_ids)
is_test_row = db["ID"].isin(test_ids)
db_train = db[is_train_row]
db_test = db[is_test_row]

# N.B.: the toy datasets are already sorted by ID and Input; in a real-case scenario the data would need to be sorted first

In [7]:
# Sanity check: number of IDs in the training set and in the test set
len(train_ids), len(test_ids)

(180, 20)

---

## Data preprocessing

In [8]:
# We need to convert the dataframe into jax arrays
# inputs: (M, N) timestamps
# outputs: (M, N) observed outputs
# unique_inputs: (P,) unique timestamps (if shared_input, P = N)
all_inputs_train, padded_inputs_train, padded_outputs_train, mappings_train = preprocess_db(db_train)
all_inputs_train.shape, padded_inputs_train.shape

((401,), (180, 200))

---

## Training

In [9]:
# Priors
prior_mean = jnp.zeros_like(all_inputs_train)  # zero prior mean, same shape as all_inputs_train
mean_kernel = SEMagmaKernel(length_scale=0.9, variance=1.5)

# Initial task-kernel hyper-parameters: plain scalars when they are shared, otherwise the same
# starting value repeated once per row of padded_inputs_train.
task_hps = {"length_scale": 0.3, "variance": 1., "noise": -2.5}
if not shared_hp:
	n_tasks = padded_inputs_train.shape[0]
	task_hps = {hp_name: jnp.array([hp_value] * n_tasks) for hp_name, hp_value in task_hps.items()}
task_kernel = NoisySEMagmaKernel(**task_hps)

In [10]:
# Training loop: alternate the e-step (hyper-posterior) and the m-step (hyper-parameter update)
# until the summed llh values stabilise or MAX_ITER iterations have been run.
prev_mean_llh = prev_task_llh = conv_ratio = jnp.inf

for i in range(MAX_ITER):
	# Progress line: llh values, ratio and kernels as they stand before this iteration's update
	print(f"Iteration {i:4}\tLlhs: {prev_mean_llh:12.4f}, {prev_task_llh:12.4f}\tConv. Ratio: {conv_ratio:.5f}\t\n\tMean: {mean_kernel}\t\n\tTask: {task_kernel}")

	# e-step: compute hyper-posterior
	post_mean, post_cov = hyperpost(
		padded_inputs_train, padded_outputs_train, mappings_train,
		prior_mean, mean_kernel, task_kernel,
		all_inputs=all_inputs_train, nugget=nugget,
	)

	# m-step: update hyperparameters
	mean_kernel, task_kernel, mean_llh, task_llh = optimise_hyperparameters(
		mean_kernel, task_kernel,
		padded_inputs_train, padded_outputs_train, all_inputs_train,
		prior_mean, post_mean, post_cov, mappings_train,
		nugget=nugget, verbose=VERBOSE,
	)

	# Check convergence: relative change of the summed llh values.
	# Skipped on the first pass, where there is no previous value to compare against.
	if i > 0:
		prev_total_llh = prev_mean_llh + prev_task_llh
		conv_ratio = jnp.abs(prev_total_llh - (mean_llh + task_llh)) / jnp.abs(prev_total_llh)
		if conv_ratio < CONVERG_THRESHOLD:
			print(f"Convergence reached after {i+1} iterations.\tLlhs: {mean_llh:12.4f}, {task_llh:12.4f}\n\tMean: {mean_kernel}\n\tTask: {task_kernel}")
			break

	# Only reached on the last iteration when the convergence test above did not break out
	if i == MAX_ITER - 1:
		print(f"WARNING: Maximum number of iterations reached.\nLast modif: {jnp.abs(prev_mean_llh - mean_llh).item()} & {jnp.abs(prev_task_llh - task_llh).item()}")

	# Remember this iteration's values for the next comparison
	prev_mean_llh, prev_task_llh = mean_llh, task_llh

Iteration    0	Llhs:          inf,          inf	Conv. Ratio: inf	
	Mean: SEMagmaKernel(length_scale=0.9, variance=1.5)	
	Task: NoisySEMagmaKernel(length_scale=0.3, variance=1.0, noise=-2.5)
Iteration    1	Llhs:    -198.0491,   56116.0753	Conv. Ratio: inf	
	Mean: SEMagmaKernel(length_scale=1.209892414332191, variance=8.319056111936368)	
	Task: NoisySEMagmaKernel(length_scale=0.07597577774571919, variance=1.955996724726523, noise=-21.455860338161056)
Iteration    2	Llhs:     -72.5236,   54086.0882	Conv. Ratio: 0.03406	
	Mean: SEMagmaKernel(length_scale=1.072244269932549, variance=8.301175932754619)	
	Task: NoisySEMagmaKernel(length_scale=0.12571368838942806, variance=1.933763250272137, noise=-21.45586048541071)
Iteration    3	Llhs:      18.9627,   53320.9271	Conv. Ratio: 0.01247	
	Mean: SEMagmaKernel(length_scale=0.9623432597675167, variance=8.163951244069654)	
	Task: NoisySEMagmaKernel(length_scale=0.15328625390476597, variance=1.943326782011153, noise=-21.455860607625105)
Iteration    

---

## Prediction

---

## End timer

In [11]:
# Stop the wall-clock timer started in the "Start timer" section
end = time.time()

In [12]:
# Report the elapsed wall-clock time, in seconds, between the `start` and `end` timestamps
print(f"Magma finished in {end - start}s")

Magma finished in 32.67691493034363s
