### Tutorial 3: Fast Ergodic Search with Kernel Functions

References:

[1] *Sun, M., Gaggar, A., Trautman, P. and Murphey, T.*, 2024. **Fast Ergodic Search with Kernel Functions**. arXiv preprint arXiv:2403.01536. [[Link](https://arxiv.org/abs/2403.01536)]

*(This tutorial will use JAX for auto-differentiation and accelerated computation.)*

In [1]:
import numpy as np 
np.set_printoptions(precision=4)
from tqdm import tqdm

import jax.numpy as jnp 
from jax import jit, vmap, grad
from jax.scipy.stats import multivariate_normal as mvn

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['axes.linewidth'] = 3
mpl.rcParams['axes.titlesize'] = 20
mpl.rcParams['axes.labelsize'] = 20
mpl.rcParams['axes.titlepad'] = 8.0
mpl.rcParams['xtick.major.size'] = 6
mpl.rcParams['xtick.major.width'] = 3
mpl.rcParams['xtick.labelsize'] = 20
mpl.rcParams['ytick.major.size'] = 6
mpl.rcParams['ytick.major.width'] = 3
mpl.rcParams['ytick.labelsize'] = 20
mpl.rcParams['lines.markersize'] = 5
mpl.rcParams['lines.linewidth'] = 5
mpl.rcParams['legend.fontsize'] = 15

#### Kernel ergodic metric

Recall that the original ergodic metric is defined as:

$$
\mathcal{E}(p(x), s(t)) = \sum_{k} \lambda_k \cdot (c_k - \phi_k)^2
$$

where

$$
\begin{aligned}
    \phi_k = \int_{\mathcal{X}} f_k(x) p(x) dx, \quad c_k = \frac{1}{T} \int_{0}^{T} f_k(s(t)) dt
\end{aligned}
$$

Here we introduce another formula for the ergodic metric, named the *kernel ergodic metric*:

$$
\mathcal{E}_g(p(x), s(t)) = -\frac{2}{T} \int_{0}^{T} p(s(t)) dt + \frac{1}{T^2} \int_{0}^{T} \int_{0}^{T} g(s(t), s(\tau); \theta) dt d\tau
$$

This form follows from expanding the squared $L^2$ distance between $p(x)$ and the time-averaged distribution of the trajectory, then dropping the term that does not depend on $s(t)$. Here $g(s, s^\prime; \theta)$ is a kernel function that converges asymptotically to a Dirac delta function. In practice, it is often modeled as a squared exponential (Gaussian) kernel:

$$
g(s, s^\prime; \theta) = \theta_1 \cdot \exp\left( -\theta_2 \cdot \vert s {-} s^\prime \vert^2 \right)
$$

For each spatial probability distribution $p(x)$, there exists a parameter $\theta$ such that optimizing the original ergodic metric and optimizing the kernel ergodic metric with kernel $g(\cdot, \cdot ; \theta)$ converge to the same trajectory. Furthermore, such a parameter $\theta$ can be found by solving the following optimization problem:

$$
\begin{aligned}
    \theta^* & = {\arg\min}_{\theta} \left\vert \frac{d}{d\theta} \left( -\frac{2}{N}\sum_{i=1}^{N} p(x_i) + \frac{1}{N^2} \sum_{i=1}^{N} \sum_{j=1}^{N} g(x_i, x_j; \theta) \right) \right\vert^2, \quad \{x_i\}_{i=1}^{N} \sim p(x)
\end{aligned}
$$

where each $x_i$ is a sample from the spatial probability distribution $p(x)$. 
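
The discretized objective can be sketched in JAX as follows. This is a minimal illustration, not the tutorial's own implementation: it uses the sign convention obtained from expanding the squared $L^2$ distance ($-2/N$ on the density term, $+1/N^2$ on the kernel term), and the names `gaussian_kernel`, `discrete_kernel_metric`, `theta_objective`, as well as the toy density and toy points, are hypothetical placeholders.

```python
import jax.numpy as jnp
from jax import grad, vmap


def gaussian_kernel(x1, x2, theta):
    # theta[0] scales the kernel, theta[1] controls how fast it decays with distance.
    return theta[0] * jnp.exp(-theta[1] * jnp.sum(jnp.square(x1 - x2)))


def discrete_kernel_metric(theta, samples, density):
    # samples: (N, d) array; density: function mapping one point to a scalar.
    n = samples.shape[0]
    density_term = jnp.sum(vmap(density)(samples))
    # Pairwise squared distances via broadcasting: entry (i, j) = |x_i - x_j|^2.
    diffs = samples[:, None, :] - samples[None, :, :]
    gram = theta[0] * jnp.exp(-theta[1] * jnp.sum(jnp.square(diffs), axis=-1))
    return -2.0 / n * density_term + jnp.sum(gram) / n**2


def theta_objective(theta, samples, density):
    # Squared norm of the gradient of the discretized metric with respect to theta.
    g = grad(discrete_kernel_metric, argnums=0)(theta, samples, density)
    return jnp.sum(jnp.square(g))


# Toy usage: a hypothetical standard-normal density and four placeholder points.
toy_density = lambda x: jnp.exp(-0.5 * jnp.sum(x**2)) / (2.0 * jnp.pi)
toy_samples = jnp.array([[0.0, 0.0], [0.5, -0.2], [-0.3, 0.4], [0.1, 0.9]])
toy_theta = jnp.array([1.0, 10.0])
metric_value = discrete_kernel_metric(toy_theta, toy_samples, toy_density)
objective_value = theta_objective(toy_theta, toy_samples, toy_density)
```

In practice `samples` would be drawn from $p(x)$ and `theta_objective` would be passed to a gradient-based optimizer; both choices are left open here.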

In [2]:
# Define a Gaussian-mixture spatial distribution
mean1 = jnp.array([0.35, 0.38])
cov1 = jnp.array([
    [0.01, 0.004],
    [0.004, 0.01]
])
w1 = 0.5

mean2 = jnp.array([0.68, 0.25])
cov2 = jnp.array([
    [0.005, -0.003],
    [-0.003, 0.005]
])
w2 = 0.2

mean3 = jnp.array([0.56, 0.64])
cov3 = jnp.array([
    [0.008, 0.0],
    [0.0, 0.004]
])
w3 = 0.3


def pdf(x):
    return w1 * mvn.pdf(x, mean1, cov1) + \
           w2 * mvn.pdf(x, mean2, cov2) + \
           w3 * mvn.pdf(x, mean3, cov3)


# Sample from the spatial distribution
rng = 
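
One way the sampling step could look is sketched below. Everything specific in it is an assumption rather than the original cell's content: the use of NumPy's `default_rng`, the seed, the sample count, and the helper name `sample_mixture` are all hypothetical. The mixture parameters are copied from the definitions above.

```python
import numpy as np

# Mixture parameters copied from the Gaussian-mixture definition.
means = np.array([[0.35, 0.38], [0.68, 0.25], [0.56, 0.64]])
covs = np.array([
    [[0.01, 0.004], [0.004, 0.01]],
    [[0.005, -0.003], [-0.003, 0.005]],
    [[0.008, 0.0], [0.0, 0.004]],
])
weights = np.array([0.5, 0.2, 0.3])


def sample_mixture(generator, num_samples):
    # Pick a component for each sample according to the mixture weights,
    # then draw from that component's Gaussian.
    comps = generator.choice(len(weights), size=num_samples, p=weights)
    out = np.empty((num_samples, 2))
    for k in range(len(weights)):
        idx = np.where(comps == k)[0]
        out[idx] = generator.multivariate_normal(means[k], covs[k], size=len(idx))
    return out


example_rng = np.random.default_rng(0)               # seed 0 is an arbitrary assumption
example_samples = sample_mixture(example_rng, 1000)  # sample count is an assumption
```

Choosing the component first and then sampling from it is the standard ancestral-sampling scheme for mixtures, so the empirical mean of `example_samples` should approach the weighted mean of the three component means as the sample count grows.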

In [3]:
# Define the kernel function
def kernel(x1, x2, theta):
    return theta[0] * jnp.exp(-theta[1] * jnp.sum(jnp.square(x1 - x2)))
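
To evaluate the kernel on every pair of points at once, the function can be vectorized with `vmap`. The sketch below assumes the kernel includes the width parameter $\theta_2$ as in the formula for $g$; the helper `gram_matrix` and the demo points are hypothetical and serve only to show how $\theta_2$ changes the kernel's reach.

```python
import jax.numpy as jnp
from jax import jit, vmap


def gaussian_kernel(x1, x2, theta):
    return theta[0] * jnp.exp(-theta[1] * jnp.sum(jnp.square(x1 - x2)))


@jit
def gram_matrix(points, theta):
    # Inner vmap sweeps the second argument, outer vmap sweeps the first,
    # so entry (i, j) of the result is g(points[i], points[j]; theta).
    row = vmap(gaussian_kernel, in_axes=(None, 0, None))
    return vmap(row, in_axes=(0, None, None))(points, points, theta)


demo_points = jnp.array([[0.2, 0.3], [0.25, 0.3], [0.8, 0.7]])
wide = gram_matrix(demo_points, jnp.array([1.0, 1.0]))      # slow decay
narrow = gram_matrix(demo_points, jnp.array([1.0, 500.0]))  # fast decay
```

With a large $\theta_2$ the off-diagonal entries shrink toward zero while the diagonal stays at $\theta_1$, which is the sense in which the kernel approaches a Dirac delta.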