In [12]:
from jax.config import config

# Enable 64-bit floats for more stable matrix inversions. This is done first,
# before the gpjax imports further down, so the flag is already set when they load.
config.update("jax_enable_x64", True)
from dataclasses import dataclass

from jax import hessian
from jax.config import config
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import erf
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
)
from matplotlib import rcParams
import matplotlib.pyplot as plt
import optax as ox
import pandas as pd
import tensorflow_probability as tfp

# Every gpjax import lives inside the jaxtyping hook so that gpjax is loaded with
# beartype runtime checks. Importing any gpjax submodule before this point would
# load the package un-instrumented and turn the hook into a no-op.
with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
    from gpjax.base import param_field
    from gpjax.typing import ScalarFloat

# Fixed-seed PRNG key so random draws are reproducible across runs.
key = jr.PRNGKey(123)

# Apply the GPJax matplotlib style (downloaded from the GPJax GitHub repository).
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
# Colour cycle of the style just applied, used for consistent plot colours.
colors = rcParams["axes.prop_cycle"].by_key()["color"]

# Linear Response Models

$$
\frac{dx}{dt} = B + Sf(t) - Dx(t) \quad \mid \quad x(0) = \frac{B}{D}
$$

Model $f(t)$ as a Gaussian process with kernel $k_{f}(t,t')$. Because the differential equation is linear in $f$, $x(t)$ is also a Gaussian process, with a different kernel $k_{x}(t,t')$ that we now derive.

Using the initial condition, solving the differential equation yields

$$
x(t) = \frac{B}{D} + S e^{-Dt}\int_{0}^tf(u)e^{Du}du
$$

and so 

$$
k_x(t,t') = \textrm{Cov}(x(t), x(t'))=S^2e^{-D(t+t')}\int_{0}^t\int_{0}^{t'}e^{D(u+u')}k_f(u,u')dudu'
$$

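which, for an RBF latent kernel $k_f(u,u') = \sigma^2\exp\left(-(u-u')^2/\ell^2\right)$, can be calculated in closed form:

$$
k_x(t,t') = \frac{\sigma^2 S^2\sqrt{\pi}\,\ell}{2}\left[h(t',t) + h(t,t')\right],
$$

where $\gamma = D\ell/2$ and

$$
h(t',t) = \frac{e^{\gamma^2}}{2D}\Big\{e^{-D(t'-t)}\Big[\textrm{erf}\Big(\frac{t'-t}{\ell}-\gamma\Big) + \textrm{erf}\Big(\frac{t}{\ell}+\gamma\Big)\Big] - e^{-D(t'+t)}\Big[\textrm{erf}\Big(\frac{t'}{\ell}-\gamma\Big) + \textrm{erf}(\gamma)\Big]\Big\}.
$$

GPJax's RBF kernel is parameterised as $\sigma^2\exp\left(-(u-u')^2/(2\ell_{\textrm{GPJax}}^2)\right)$, so $\ell = \sqrt{2}\,\ell_{\textrm{GPJax}}$.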

In [14]:
@dataclass
class LinearResponseKernel(gpx.kernels.AbstractKernel):
    """Kernel of the output x(t) of the linear ODE dx/dt = B + S f(t) - D x(t).

    The latent force f(t) is a GP with an RBF kernel. With x(0) = B/D the output
    x(t) is also a GP, whose covariance is the double integral
    S^2 exp(-D(t+t')) int_0^t int_0^t' exp(D(u+u')) k_f(u,u') du du',
    evaluated here in closed form.

    Attributes:
        latent_force_kernel: RBF kernel of the latent force f; only RBF is supported.
        S: gain on the latent force f in the ODE.
        B: constant term in the ODE. It only shifts the mean of x (by B/D) and
            does not appear in the covariance.
        D: decay coefficient on x in the ODE. Must be non-zero: the closed form
            divides by 2D.
    """

    latent_force_kernel: gpx.kernels.RBF = gpx.kernels.RBF()
    S: ScalarFloat = param_field(jnp.array(1.0))
    B: ScalarFloat = param_field(jnp.array(1.0))
    D: ScalarFloat = param_field(jnp.array(1.0))

    def __post_init__(self):
        """Validate that the latent force kernel is an RBF kernel.

        Raises:
            NotImplementedError: if ``latent_force_kernel`` is not a GPJax RBF kernel.
        """
        # Run the parent's __post_init__ if it defines one, so overriding it here
        # does not skip any base-class initialisation.
        parent_post_init = getattr(super(), "__post_init__", None)
        if parent_post_init is not None:
            parent_post_init()
        if not isinstance(self.latent_force_kernel, gpx.kernels.RBF):
            raise NotImplementedError("We only support RBF kernels ATM")

    def __call__(
        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:
        """Evaluate k_x(x, y) for two time points x and y.

        Uses the closed form for an RBF latent kernel written as
        variance * exp(-(u - u')^2 / ell^2):

            k_x(t, t') = variance * S^2 * sqrt(pi) * ell / 2 * (h(t', t) + h(t, t'))

        with gamma = D * ell / 2. GPJax's RBF is variance * exp(-0.5 (u - u')^2 /
        lengthscale^2), hence ell = sqrt(2) * lengthscale.

        Returns:
            The covariance, squeezed to drop singleton dimensions.
        """
        # Convert GPJax's lengthscale to the convention of the closed form.
        ell = jnp.sqrt(2.0) * self.latent_force_kernel.lengthscale
        variance = self.latent_force_kernel.variance
        decay = self.D
        gamma = 0.5 * decay * ell

        def h(a, b):
            # One half of the closed form, h(a, b). exp(gamma^2) is folded into
            # each exponent rather than multiplied in separately, which avoids
            # forming a huge factor next to a tiny one.
            diff_term = jnp.exp(gamma**2 - decay * (a - b)) * (
                erf((a - b) / ell - gamma) + erf(b / ell + gamma)
            )
            sum_term = jnp.exp(gamma**2 - decay * (a + b)) * (
                erf(a / ell - gamma) + erf(gamma)
            )
            return (diff_term - sum_term) / (2.0 * decay)

        # Adding both argument orders makes the kernel symmetric in (x, y).
        K = (
            variance
            * self.S**2
            * jnp.sqrt(jnp.pi)
            * ell
            / 2.0
            * (h(x, y) + h(y, x))
        )
        return K.squeeze()