# Probabilistic ODE solver

## About

This notebook was created to accompany a presentation of
probabilistic ODE solvers for the seminar "Physics-informed Machine Learning" at the University of Tübingen (winter term 2024/2025).

It intends to explain the concept of probabilistic ODE solvers (ODE filters and smoothers) and will build a simple implementation of a probabilistic ODE solver.

## References

This tutorial is based on:

- Mostly, esp. the theoretical parts:
  -  Philipp Hennig, Michael A. Osborne, and Hans P. Kersting
    - Probabilistic Numerics: Computation as Machine Learning. Cambridge University Press, 2022
    - DOI [10.1017/9781316681411](https://doi.org/10.1017/9781316681411)
    - Also available for free personal use here: [https://www.probabilistic-numerics.org/textbooks/](https://www.probabilistic-numerics.org/textbooks/)
- And also in parts:
  - Filip Tronarp et al.
    - Probabilistic Solutions to Ordinary Differential Equations as Nonlinear Bayesian Filtering: A New Perspective
    - In: Statistics and Computing 29.6 (Nov. 2019), pp. 1297–1315. issn: 1573-1375.
    - DOI [10.1007/s11222-019-09900-1](https://doi.org/10.1007/s11222-019-09900-1)
- The latter parts on numerical stability and their implementation are based on:
  - Nathanael Bosch, Philipp Hennig, Filip Tronarp
  - Calibrated Adaptive Probabilistic ODE Solvers
  - On ArXiv, Feb. 2021
  - DOI [10.48550/arXiv.2012.08202](https://doi.org/10.48550/arXiv.2012.08202)

Where other sources are used, a reference/link is given.

In [None]:
# Install dependencies here (uncomment first)
# cuda
#!pip install numpy scipy matplotlib 'jax[cuda]' diffrax

# no cuda
#!pip install numpy scipy matplotlib jax diffrax

JAX preallocates 75% of GPU memory.

If you run into OOM errors at the start,
you can try to circumvent this by running this process
(ie. the jupyter server or whatever you use to run the notebook)
with one of these environment variables:

- XLA_PYTHON_CLIENT_MEM_FRACTION=.50 (% of preallocated mem; or set it even lower); or:
- XLA_PYTHON_CLIENT_PREALLOCATE=false (disable preallocation entirely)

See also: [JAX GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html)

In [None]:
## Setup
import jax
import scipy
import numpy as np
import jax.numpy as jnp
from functools import partial
import matplotlib.pyplot as plt
from jax.experimental import jet
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5
# Use float64 precision
jax.config.update("jax_enable_x64", True)
# higher plot resolution
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

## Background

### Ordinary Differential Equations (ODEs)

_Here_, an ODE is an equation
$$ x'(t) = f(x(t)) \in \mathbb{R}^d\,, t \in [0, T]\,, $$
and we want to solve it by finding a curve
$$ x: [0, T] \rightarrow \mathbb{R}^d $$
with initial value
$$ x(0) = x_0 $$
that satisfies this equation.

This means that, as opposed to general ODEs, we require,
mainly to keep the theoretical parts simple:

- The ODE only contains the first derivative $x'(t)$, no higher-order derivatives
  - ie. we only consider _first-order_ ODEs
  - This comes without loss of generality:
    - all higher-order ODEs can be rewritten as first-order ODEs
    - But: the solver _could_ be extended to also consider higher-order ODE information
- $f$ does not depend on $t$ ($\rightarrow$ autonomous ODE)
  - Also no loss of generality
    - Again, every ODE can be transformed into autonomous form
- The equation given above is in _explicit_ form
  - we won't consider _implicit_ ODEs (ie. of form $ f(x(t), x'(t)) = 0 $)
    - This is a loss of generality; the solver presented here works only for explicit ODEs
- A known initial value, so we are actually interested in an ODE _initial value problem_ (IVP)
  - I'll just refer to this as ODE anyways in this tutorial

#### Lotka-Volterra system

The Lotka-Volterra system/equations model predator/prey-populations as a 2-dimensional ODE:
\begin{align*}
    y_1'(t) &= \alpha y_1(t) - \beta y_1(t) y_2(t) \\
    y_2'(t) &= -\gamma y_2(t) + \delta y_1(t) y_2(t)
\end{align*}

Where:
- $y_1$ is the population density/size of the prey
- $y_2$ is the population density of the predator
- $\alpha$: prey growth rate
- $\gamma$: predator death rate
- $\beta$: effect of predator population density on prey death rate
- $\delta$: effect of prey population density on predator growth rate

For more details, see eg. [Lotka–Volterra equations (Wikipedia)](https://en.wikipedia.org/wiki/Lotka%E2%80%93Volterra_equations).

Let's see an example of it, using a classical ODE solver
(code taken from [Diffrax documentation](https://docs.kidger.site/diffrax/examples/coupled_odes/)):

In [None]:
@partial(jax.jit, static_argnames=["args"])
def lotka_volterra_vf(t, y, args):
    prey, predator = y[..., 0], y[..., 1]
    alpha, beta, gamma, delta = args
    # (no type annotations here: `Array` is never imported in this notebook)
    d_prey = alpha * prey - beta * prey * predator
    d_predator = -gamma * predator + delta * prey * predator
    return jnp.array([d_prey, d_predator])

In [None]:
t0 = 0
t1 = 140
y0 = jnp.array([10.0, 10.0])
lv_args = (0.1, 0.02, 0.4, 0.02)

In [None]:
term = ODETerm(lotka_volterra_vf)
solver = Tsit5()
saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))
dt0 = 0.1
sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=lv_args, saveat=saveat)

In [None]:
plt.plot(sol.ts, sol.ys, label=["Prey", "Predator"])
plt.legend()
plt.show()

### Classical ODE solvers

#### How do they work?

- Work on discrete time steps:
  $$ 0 = t_0 < t_1 < \ldots < t_N = T\,, \text{ with step sizes: } h_n = t_n - t_{n - 1} $$
- Iteratively perform steps:
  - At step $n$, we have $\hat{x}(t_{n-1}) \approx x(t_{n-1})$
  - Estimate derivative $x'(t_{n-1}) = f(x(t_{n-1})) \approx f(\hat{x}(t_{n-1}))$
    - Potentially use more information
    - Eg. approximate higher-order derivatives
  - Construct approximation to $\hat{x}(t_n) = \hat{x}(t_{n-1} + h_n)$
    - Using derivative information/other information
    - Example: Euler's method
      - $\hat{x}(t_{n-1} + h_n) \approx \hat{x}(t_{n-1}) + h_n f(\hat{x}(t_{n-1}))$
      - First-order Taylor expansion (linear approximation) of $\hat{x}(t_{n-1} + h_n)$
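
The Euler iteration above can be sketched in a few lines of plain Python. This is only a minimal illustration: the toy vector field $f(x) = -x$ (with known solution $x(t) = e^{-t}$) and the helper name `euler_solve` are assumptions made for this sketch, not part of the notebook's solver.

```python
import math


def euler_solve(f, x0, t_end, n_steps):
    """Integrate x'(t) = f(x(t)) from t = 0 to t_end with constant step size."""
    h = t_end / n_steps  # constant step size h_n = h
    x = x0
    for _ in range(n_steps):
        x = x + h * f(x)  # Euler step: x(t + h) ~ x(t) + h * f(x(t))
    return x


def decay(x):
    # Toy vector field (an assumption for this sketch): x' = -x
    return -x


approx = euler_solve(decay, x0=1.0, t_end=1.0, n_steps=1000)
exact = math.exp(-1.0)  # closed-form solution x(t) = exp(-t) at t = 1
print(f"Euler: {approx:.6f}, exact: {exact:.6f}, error: {abs(approx - exact):.2e}")
```

Each pass through the loop corresponds to one step $t_{n-1} \rightarrow t_n$ from the list above.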


#### Order

An ODE solver of order $q$

- essentially fits a $q$-th order Taylor expansion
- thus has _local_ error (or convergence rate) $\mathcal{O}(h^{q+1})$
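
As a rough numerical illustration of the local error rate, one can look at a single Euler step ($q = 1$): halving $h$ should shrink the one-step error by a factor of about $2^{q+1} = 4$. The toy problem $x' = -x$, $x(0) = 1$ and the helper name `euler_step_error` are assumptions made for this sketch.

```python
import math


def euler_step_error(h):
    """Error of a single Euler step for x' = -x, x(0) = 1 (exact: exp(-h))."""
    euler = 1.0 + h * (-1.0)  # x_0 + h * f(x_0) with f(x) = -x
    return abs(math.exp(-h) - euler)


err_h = euler_step_error(0.1)
err_half = euler_step_error(0.05)
ratio = err_h / err_half  # should be close to 2^(q+1) = 4 for q = 1
print(f"local errors: {err_h:.3e}, {err_half:.3e}, ratio: {ratio:.2f}")
```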

### Probabilistic viewpoint: Viewing ODE solving as regression

One can view solving an ODE as regression on
$$ \mathcal{D} = \left\{x(0) = x_0, x'(t_n) = f(\hat{x}(t_n)) \mid
n \in \{0, \ldots, N\}\right\} $$
(that is, all the data available to the solver),
where the goal is to find a function $x$ that fulfills all of these equations,
and the data is revealed to the solver progressively.

## Probabilistic ODE solvers

### GP Regression

Jointly model $y(t)$ and $y'(t)$ in a state vector:
$$ \begin{pmatrix} y(t) \\ y'(t) \end{pmatrix} = x(t) \sim \mathcal{GP}(\bar{x}, k) $$
with mean function $\bar{x}(t)$ and kernel function $k(t_1, t_2)$.

How to perform inference efficiently?
- In complexity $\mathcal{O}(N)$
  - instead of $\mathcal{O}(N^3)$ for traditional GP regression
- How to take advantage of the derivative information?

$\rightarrow$ we need a sensible prior

### Model ODE solution in Continuous-Time State Space Model (SSM)

We can model the state vector $x(t) \in \mathbb{R}^{q+1}$ (for now, with $d = 1$) by a stochastic process:
$$ x(t) \sim X(t) $$
and introduce projections $H_0, H \in \mathbb{R}^{d \times (q+1)}$ such that

\begin{align*}
    y(t) &= H_0 x(t) \\
    y'(t) &= H x(t) \,.
\end{align*}

Since we are interested in IVPs, we also model the initial value by
$$ X(0) \sim \mathcal{N}(m_0, P_0) $$
(more details later on).

An example where $x$ models $y(t)$ and its first three derivatives (ie. $q = 3$):

\begin{align*}
    H_0 &= \begin{bmatrix} 1 & 0 & 0 & 0 \end{bmatrix} \\
    H &= \begin{bmatrix} 0 & 1 & 0 & 0 \end{bmatrix} \\
    x(t) &= \begin{pmatrix} y(t) \\ y'(t) \\ y''(t) \\ y'''(t) \\ \end{pmatrix} \sim X(t)
\end{align*}

We can build the projection matrices as follows:

In [None]:
def ssm_projection_matrices(d, q):
    # Projections: project SSM state -> y (H0) or SSM state -> y' (H)
    H0 = jnp.zeros((1, q + 1)).at[0, 0].set(1.0)
    H = jnp.zeros((1, q + 1)).at[0, 1].set(1.0)
    Id = jnp.eye(d)  # Ignore this and below for now.
    H0 = jnp.kron(H0, Id)
    H = jnp.kron(H, Id)
    return H0, H

#### Dynamic model / State

We now let $X(t)$ follow this linear, time-invariant SDE:
$$ dx(t) = Fx(t)dt + Ld\omega_t $$
where $ F = \begin{bmatrix} 0&1&0&0 \\ 0&0&1&0 \\ 0&0&0&1 \\ 0&0&0&0 \end{bmatrix} $
and $ L = \begin{pmatrix} 0 \\ 0 \\ 0 \\ \sigma \end{pmatrix} $, and $d\omega_t$
can be informally thought of as "increment of the Wiener process".

This gives rise to an interesting structure:

- each element of $x(t)$ is the derivative of the element before it, for all $t$
- $y'''(t)$ is modeled as a (scaled) Wiener process, ie. a random walk
  (as its increment is the scaled increment of the Wiener process)

Because:
$$
dx(t)
= \begin{bmatrix} 0&1&0&0 \\ 0&0&1&0 \\ 0&0&0&1 \\ 0&0&0&0 \end{bmatrix}
\begin{pmatrix} y(t) \\ y'(t) \\ y''(t) \\ y'''(t) \\ \end{pmatrix} dt
+ \begin{pmatrix} 0 \\ 0 \\ 0 \\ \sigma \end{pmatrix} d\omega_t
= \begin{pmatrix} y'(t)dt \\ y''(t)dt \\ y'''(t)dt \\ \sigma d\omega_t \end{pmatrix}
$$

This means that $x(t)$ now models some function $y(t)$ and its first $q=3$
derivatives, and $y'''(t)$ is modeled as a Wiener process/random walk.

Thus, this prior is known as the $q$-times integrated Wiener process (IWP) with scale $\sigma$.

Also, as is evident above, the $q$-times IWP prior is a GP that satisfies the Markov property
(referred to as Gauss-Markov process).

#### Measurement model (Information operator)

To our dynamic model, we now add a measurement model.

We define:
$$ z(t) \sim Z(t) := g(X(t)) := HX(t) - f(H_0 X(t))\,, $$
so
$$ z(t) = g(x(t)) = Hx(t) - f(H_0 x(t)) = y'(t) - f(y(t))\,, $$
which is referred to as the _state misalignment_.

$Z$ now serves as an _information operator_:
It extracts information about the state misalignment $z(t)$ from the ODE,
quantifying the "correctness" of $y$.
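
To make the state misalignment concrete, here is a small sketch in plain NumPy for $q = 1$, $d = 1$. The toy ODE $y' = -y$ (solution $y(t) = e^{-t}$) and the two candidate states are assumptions chosen for illustration.

```python
import numpy as np

H0 = np.array([[1.0, 0.0]])  # picks y from the state x = (y, y')
H = np.array([[0.0, 1.0]])   # picks y' from the state x = (y, y')


def f(y):
    return -y  # toy vector field (assumption): y' = -y


def g(x):
    # state misalignment z = H x - f(H0 x) = y' - f(y)
    return H @ x - f(H0 @ x)


t = 0.7
x_true = np.array([np.exp(-t), -np.exp(-t)])  # state of the true solution at t
x_wrong = np.array([np.exp(-t), 0.0])         # same y, but a wrong derivative

z_true = g(x_true)
z_wrong = g(x_wrong)
print("misalignment (true state): ", z_true)
print("misalignment (wrong state):", z_wrong)
```

The state built from the true solution has zero misalignment, while the state with the wrong derivative does not.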

#### Likelihood

Now, we have our prior (the $q$-times IWP), and a measurement model (through the information operator).

As seen above, $z(t) = y'(t) - f(y(t))$, which means that iff $y$ is a solution to the ODE (ie. $y'(t) = f(y(t))$ for all $t$; and of course also $y(0) = y_0$),
then $z(t) = 0$.

Thus, if we use the likelihood
$$ p(z(t) \mid x(t)) = \delta(g(x(t))) = \delta(Hx(t) - f(H_0 x(t))) = \delta(y'(t) - f(y(t)))\,, $$
then, if we simply observe the data $0 = z(t)$ for all $t$,
the posterior $p(x \mid z = 0)$ contains the solution (also with zero uncertainty).

### Discrete-time version of the SSM

Of course, inference in this continuous-time model is intractable:
we can only work on finitely many discrete points $t_n$.
For tractable inference, we need to discretize the above continuous-time SSM.

With functions of $t_n$ now denoted by subscript $n$ (eg. $x(t_n) := x_n$),
the above can be discretized into
\begin{align}
    p(x_0) &= \mathcal{N}(x_0; m_0, P_0) &\text{Initial value}\\
    p(x_{n+1} \mid x_n) &= \mathcal{N}(x_{n+1}; A(h_{n+1})x_n, Q(h_{n+1})) &\text{Dynamics}\\
    p(z_n \mid x_n) &= \delta(z_n - g(x_n)) &\text{Observations (information operator)}
\end{align}

Or, alternatively,
\begin{align}
    p(z_n \mid x_n) &= \mathcal{N}(z_n; g(x_n), R) &\text{Observations (information operator)}\,,
\end{align}
which contains the Dirac likelihood above as the special case $R = 0$ (a nonzero variance $R$ can help in some cases).
If the vector field $f$ is linear, this is a linear Gaussian SSM (as $g$ then is linear too).

In the discretized SSM, we have $A(h_n), Q(h_n)$ given by
\begin{align*}
    A(h_n)_{i,j} &= \mathbb{I}(j \geq i) \frac{h_n^{j-i}}{(j-i)!} &\text{Transition matrix}\\
    Q(h_n)_{i,j} &= \sigma^2 \frac{h_n^{2q + 3 - i - j}}{(2q + 3 - i - j)(q + 1 - i)!(q + 1 - j)!} &\text{Diffusion matrix}
\end{align*}

What does this mean?

Consider again the example from above:
$$ x_n = \begin{pmatrix} y_n \\ y'_n \\ y''_n \\ y'''_n \end{pmatrix}\,, $$
then, with $h := h_{n+1}, A := A(h_{n+1})$ for less clutter:
$$ (Ax_n)_1 =
\begin{bmatrix}\frac{h^0}{0!} & \frac{h^1}{1!} & \frac{h^2}{2!} & \frac{h^3}{3!}\end{bmatrix}
\begin{pmatrix} y_n \\ y'_n \\ y''_n \\ y'''_n \end{pmatrix}
= y_n + h y'_n + \frac{h^2}{2} y''_n + \frac{h^3}{6} y'''_n\,, $$
which is a 3rd-order ($q=3$) Taylor expansion of $y(t_n + h_{n+1})$.

- Thus, we'll also get a solver of order $q$ (and local error $\mathcal{O}(h^{q+1})$)
  - This is _just for illustration_, absolutely not a proof
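
As a sanity check of the closed forms for $A(h)$ and $Q(h)$, the following NumPy sketch compares them with the standard discretization of a linear time-invariant SDE, $A(h) = \exp(Fh)$ and $Q(h) = \int_0^h \exp(F\tau) L L^\top \exp(F\tau)^\top d\tau$. That these integrals are the right discretization is assumed here as a known result, not derived; the helper `expm_nilpotent` and the choice $q = 3$, $h = 0.1$, $\sigma = 1$ are assumptions for this sketch.

```python
import math
import numpy as np

q, h, sigma = 3, 0.1, 1.0
F = np.diag(np.ones(q), k=1)  # ones on the first superdiagonal
L = np.zeros((q + 1, 1))
L[-1, 0] = sigma              # noise enters the highest derivative only


def expm_nilpotent(M, n_terms):
    """Matrix exponential via its power series; exact here since F^(q+1) = 0."""
    out = np.zeros_like(M)
    term = np.eye(M.shape[0])
    for k in range(n_terms):
        out = out + term
        term = term @ M / (k + 1)
    return out


# Closed-form entries from the text, written with zero-based indices i, j
A_formula = np.array([[h ** (j - i) / math.factorial(j - i) if j >= i else 0.0
                       for j in range(q + 1)] for i in range(q + 1)])
Q_formula = np.array([[sigma ** 2 * h ** (2 * q + 1 - i - j)
                       / ((2 * q + 1 - i - j)
                          * math.factorial(q - i) * math.factorial(q - j))
                       for j in range(q + 1)] for i in range(q + 1)])

# A(h) = exp(F h)
A_series = expm_nilpotent(F * h, q + 1)

# Q(h) via Gauss-Legendre quadrature (exact for this polynomial integrand)
nodes, weights = np.polynomial.legendre.leggauss(q + 1)
Q_quad = np.zeros((q + 1, q + 1))
for x_k, w_k in zip(nodes, weights):
    tau = 0.5 * h * (x_k + 1.0)  # map node from [-1, 1] to [0, h]
    v = expm_nilpotent(F * tau, q + 1) @ L
    Q_quad += 0.5 * h * w_k * (v @ v.T)

print("A matches:", np.allclose(A_series, A_formula, rtol=1e-10, atol=0.0))
print("Q matches:", np.allclose(Q_quad, Q_formula, rtol=1e-10, atol=0.0))
```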

In [None]:
def discrete_transition_matrix(d, q, h):
    A_exponents = jnp.stack([jnp.arange(q + 1) - i for i in range(q + 1)])
    A = jnp.triu(h ** A_exponents / jax.scipy.special.factorial(A_exponents))
    return jnp.kron(A, jnp.eye(d))  # Ignore everything about d/this line for now

def discrete_diffusion_matrix(d, q, h):
    Q_j = jnp.stack((jnp.arange(q + 1) + 1,) * (q + 1))
    Q_exponent = 2 * q + 3 - Q_j.T - Q_j
    Q_divisor_j = jax.scipy.special.factorial(q + 1 - Q_j)
    Q = (h ** Q_exponent  # Note the absence of the \sigma^2 factor,
         / (Q_exponent    # we'll come back to this
             * Q_divisor_j
             * Q_divisor_j.T))
    return jnp.kron(Q, jnp.eye(d))  # Ignore everything about d/this line for now

In [None]:
# Here's an example with q=2 and h=0.1:
A_ex = discrete_transition_matrix(1, 2, 0.1)
Q_ex = discrete_diffusion_matrix(1, 2, 0.1)
print(f"A:\n{A_ex}")
print(f"Q:\n{Q_ex}")

#### Why use state misalignment $z$ and not condition on $f(H_0 x_n)$ directly?

- This separates the observations from the implementation/algorithm, ie. we observe the data $z_n = 0$ for _any_ implementation
  - If we changed the observations to:
    - $\tilde{z}_n = Hx_n$
    - and wanted to observe that $\tilde{z}_n = f(H_0 x_n)$ directly
  - then the observations/data that we condition on would depend on the algorithm
    - as $x_n$, the current solution/state, depends on the algorithm in use
    - Every algorithm would operate on different observations

## Filtering and Smoothing

### Kalman (ODE) Filter

For now, let's first consider the linear Gaussian SSM case from above:
\begin{align}
    p(x_0) &= \mathcal{N}(x_0; m_0, P_0) &\text{Initial value}\\
    p(x_{n+1} \mid x_n) &= \mathcal{N}(x_{n+1}; A(h_{n+1})x_n, Q(h_{n+1})) &\text{Dynamics}\\
    p(z_n \mid x_n) &= \mathcal{N}(z_n; g(x_n), R) &\text{Observations (information operator)}
\end{align}

On this, we can directly perform Bayesian inference.
As the SSM fulfills the Markov property, we can perform filtering by iterating two steps:

- Predict: $p(x_{n+1} | z_{0:n}) = \int p(x_{n+1} | x_n)p(x_n | z_{0:n}) dx_n$ (Chapman-Kolmogorov Eq.)
- Observe the next $z_{n+1}$
  - Actually, observe that $z_{n+1} = 0$
  - So we don't actually need to do anything
- Update: $p(x_{n+1} | z_{0:n+1}) = \frac{p(z_{n+1} | x_{n+1})p(x_{n+1} | z_{0:n})}{p(z_{n+1}|z_{0:n})}$

As (for now) everything is linear Gaussian, the above just boils down
to a little bit of linear algebra:

#### Prediction Step

$$ p(x_{n+1} | z_{0:n}) = \mathcal{N}(x_{n+1}; m_{n+1}^p, P_{n+1}^p) $$
where
\begin{align*}
    m_{n+1}^p &= A(h_{n+1}) m_n^f &\text{predictive mean}\\
    P_{n+1}^p &= A(h_{n+1}) P_n^f A(h_{n+1})^\top + Q(h_{n+1}) &\text{predictive cov}\,,
\end{align*}
and this is simply implemented by

In [None]:
def predict(m_f, P_f, A, Q):
    m_p = A @ m_f  # predictive mean
    P_p = A @ P_f @ A.T + Q  # predictive cov
    return m_p, P_p

#### Exact Update Step ($f$ linear)

$$ p(x_{n+1} | z_{0:n+1}) = \mathcal{N}(x_{n+1}; m_{n+1}^f, P_{n+1}^f) $$
where
\begin{align*}
    \hat{z}_{n+1} &= f(H_0 m_{n+1}^p) - Hm_{n+1}^p &\text{innovation residual} \\
    \tilde{H} &= H &\text{(for now)} \\
    S_{n+1} &= \tilde{H} P_{n+1}^p \tilde{H}^\top + R &\text{innovation cov} \\
    K_{n+1} &= P_{n+1}^p \tilde{H}^\top S_{n+1}^{-1} &\text{Kalman gain} \\
    m_{n+1}^f &= m_{n+1}^p + K_{n+1} \hat{z}_{n+1}  &\text{filtering mean} \\
    P_{n+1}^f &= (I_D - K_{n+1} \tilde{H}) P_{n+1}^p &\text{filtering cov} \,.
\end{align*}

As long as $f$ is linear, everything stays linear Gaussian.
- This is known as the _Kalman filter_
- And is exact

Before we implement this: What can we do for nonlinear $f$?

#### Extended Kalman ODE Filter (EKF0, EKF1)

We need to stay linear Gaussian, but our $f$ now is nonlinear :(

What to do now?
- We can just approximate $f$ around $H_0 m_{n+1}^p$ by a linear function, then do the above
- Thus, the update is now no longer exact, but:
  - This is very cheap
  - and often accurate enough

This is known as Extended Kalman (ODE) Filter (EKF).

We have two choices for our approximation of $f$:
- constant function (order 0 Taylor expansion)
  - EKF0
  - $f(y) \approx f(H_0 m_{n+1}^p)$
- linear function (order 1 Taylor expansion)
  - EKF1
  - $f(y) \approx f(H_0 m_{n+1}^p) + J_f(H_0 m_{n+1}^p)(y - H_0 m_{n+1}^p)$

From this, we obtain two choices for the EKF update step:

\begin{align*}
    \hat{z}_{n+1} &= f(H_0 m_{n+1}^p) - Hm_{n+1}^p &\text{innovation residual} \\
    \tilde{H} &= \begin{cases}H &\text{EKF0} \\ H - J_f(H_0 m_{n+1}^p) H_0 &\text{EKF1}\end{cases} &\text{linearized measurement matrix}\\
    S_{n+1} &= \tilde{H} P_{n+1}^p \tilde{H}^\top + R &\text{innovation cov} \\
    K_{n+1} &= P_{n+1}^p \tilde{H}^\top S_{n+1}^{-1} &\text{Kalman gain} \\
    m_{n+1}^f &= m_{n+1}^p + K_{n+1} \hat{z}_{n+1}  &\text{filtering mean} \\
    P_{n+1}^f &= (I_D - K_{n+1} \tilde{H}) P_{n+1}^p &\text{filtering cov} \,.
\end{align*}
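
To see the two variants side by side, here is a sketch of a single predict/update step in plain NumPy for $q = 1$, $d = 1$, $\sigma = 1$ and $R = 0$. The toy linear vector field $f(y) = -y$, the step size, and the helper name `ekf_update` are assumptions made for illustration. With these choices, the EKF0 update sets the derivative entry of the filtering mean to $f(H_0 m^p)$, and the EKF1 update (exact here, as $f$ is linear) drives the misalignment of the filtering mean to zero.

```python
import numpy as np

h = 0.1
H0 = np.array([[1.0, 0.0]])  # picks y from x = (y, y')
H = np.array([[0.0, 1.0]])   # picks y' from x = (y, y')
A = np.array([[1.0, h], [0.0, 1.0]])                 # A(h) for q = 1
Q = np.array([[h**3 / 3, h**2 / 2], [h**2 / 2, h]])  # Q(h) for q = 1, sigma = 1
R = np.zeros((1, 1))         # noise-free observations


def f(y):
    return -y                # toy linear vector field (assumption)


J_f = np.array([[-1.0]])     # Jacobian of f (constant, as f is linear)

# Predict from the exact initial state (y0, f(y0)) with zero covariance
m_f_prev, P_f_prev = np.array([1.0, -1.0]), np.zeros((2, 2))
m_p = A @ m_f_prev
P_p = A @ P_f_prev @ A.T + Q


def ekf_update(H_tilde):
    z_hat = f(H0 @ m_p) - H @ m_p           # innovation residual
    S = H_tilde @ P_p @ H_tilde.T + R       # innovation cov
    K = P_p @ H_tilde.T @ np.linalg.inv(S)  # Kalman gain (uses predictive cov)
    m_f = m_p + K @ z_hat                   # filtering mean
    P_f = (np.eye(2) - K @ H_tilde) @ P_p   # filtering cov
    return m_f, P_f


m_ekf0, P_ekf0 = ekf_update(H)             # EKF0
m_ekf1, P_ekf1 = ekf_update(H - J_f @ H0)  # EKF1
misalignment_ekf1 = H @ m_ekf1 - f(H0 @ m_ekf1)

print("EKF0 mean:", m_ekf0)
print("EKF1 mean:", m_ekf1)
print("EKF1 misalignment of filtering mean:", misalignment_ekf1)
```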

With that, we can directly implement the update:

Calculate:
\begin{align*}
    \hat{z}_{n+1} &= f(H_0 m_{n+1}^p) - Hm_{n+1}^p &\text{innovation residual} \\
    \tilde{H} &= \begin{cases}H &\text{EKF0} \\ H - J_f(H_0 m_{n+1}^p) H_0 &\text{EKF1}\end{cases} &\text{linearized measurement matrix}\\
\end{align*}

In [None]:
def residual(f, t, m_p, H, H0, f_args, ekf_order):
    # returns the innovation residual/state misalignment
    m_p0 = H0 @ m_p
    f_at_m_p = f(t, m_p0, f_args)
    if ekf_order == 0:
        H_hat = H
    else:
        # ekf_order == 1 (if invalid, just treat as 1)
        # Calculate Jacobian of f with forward mode AD
        J_f = jax.jacfwd(f, argnums=1)(t, m_p0, f_args)
        H_hat = H - J_f @ H0  # regular * in 1-d case
    z_hat = f_at_m_p - H @ m_p  # residual
    return z_hat, H_hat  # H_hat: \tilde{H} from above

Calculate:
\begin{align*}
    S_{n+1} &= \tilde{H} P_{n+1}^p \tilde{H}^\top + R &\text{innovation cov (intermediate)} \\
    K_{n+1} &= P_{n+1}^p \tilde{H}^\top S_{n+1}^{-1} &\text{Kalman gain (intermediate)} \\
    m_{n+1}^f &= m_{n+1}^p + K_{n+1} \hat{z}_{n+1}  &\text{filtering mean (result)} \\
    P_{n+1}^f &= (I_D - K_{n+1} \tilde{H}) P_{n+1}^p &\text{filtering cov (result)} \,.
\end{align*}

In [None]:
def next_filtering_mean_and_cov(m_p, P_p, z_hat, H_hat, R):
    # Kalman filter statistics
    S = H_hat @ P_p @ H_hat.T + R
    S_inv = jnp.linalg.inv(S)  # Yes, inverse here. We'll improve this later.
    K = P_p @ H_hat.T @ S_inv
    # Filtering mean/cov
    m_f = m_p + K @ z_hat
    P_f = (jnp.eye(P_p.shape[1]) - K @ H_hat) @ P_p
    return m_f, P_f

#### Filter step

Thus, we can build a full filter step (predict, then update) with:

In [None]:
# jit compile this function for speedup
@partial(jax.jit, static_argnames=["f", "f_args", "d", "q", "ekf_order"])
def ode_filter_step(m_f_prev, P_f_prev, f, f_args,
                    t_prev, h, d, q,
                    H0, H, R, ekf_order):
    # calculate next time step
    t = t_prev + h
    # Get discretization for stepsize h:
    A = discrete_transition_matrix(d, q, h)
    Q = discrete_diffusion_matrix(d, q, h)
    # predict
    m_p, P_p = predict(m_f_prev, P_f_prev, A, Q)
    # update
    z_hat, H_hat = residual(f, t, m_p, H, H0, f_args, ekf_order)
    m_f, P_f = next_filtering_mean_and_cov(m_p, P_p, z_hat, H_hat, R)
    # Also extract and return the current (filtering) y and its stddev
    y = H0 @ m_f
    stddev = jnp.sqrt(jnp.diag(H0 @ P_f @ H0.T))
    return m_f, P_f, m_p, P_p, y, stddev

#### How to get the initial state vector?

Remember, the state vector $x_n$ contains $y_n$ and $y'_n$,
and, if we want, derivatives up to $y^{(q)}_n$.

If $q = 1$, then it is obvious that $x_0 = \begin{pmatrix} y_0 & f(y_0) \end{pmatrix}^\top$,
but how do we handle higher $q$?
- The elements of $x$ are, and stay, derivatives of each other, so we need to initialize
  the first mean with higher-order derivatives as follows:
  $$ m_0^f = \begin{pmatrix} y_0 & f^{\langle 1 \rangle}(y_0) & f^{\langle 2 \rangle}(y_0) & \ldots & f^{\langle q \rangle}(y_0)\end{pmatrix}^\top\,, $$
  where
  \begin{align*}
    f^{\langle 1 \rangle}(y) &= f(y)\,, \\
    f^{\langle i \rangle}(y) &= \left( \frac{d}{dy} f^{\langle i-1 \rangle}(y) \right) \cdot f(y) & \text{(for $i>1$, by the chain rule)}\,.
  \end{align*}
  This is exponentially (in $q$) expensive with regular AD,
  but can be more efficiently calculated using Taylor-mode AD, in JAX implemented in ```jet.jet()```.


If interested, more details can be found in [Bettencourt, Johnson, Duvenaud (2019)](https://openreview.net/pdf?id=SkxEF3FNPH).
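
For intuition, the chain-rule recursion can be written out explicitly for a scalar polynomial vector field: each higher time derivative is the $y$-derivative of the previous one, multiplied by $f(y)$. The sketch below uses NumPy's `Polynomial` class instead of `jet.jet()` and is only meant as a check on a toy problem; the choice $f(y) = y^2$, $y_0 = 1$ (solution $y(t) = 1/(1-t)$, so $y^{(k)}(0) = k!$) is an assumption made for illustration.

```python
import numpy as np
from numpy.polynomial import Polynomial

f = Polynomial([0.0, 0.0, 1.0])  # toy vector field f(y) = y^2 (assumption)
y0, q = 1.0, 3

derivs = [f]  # f^<1> = f
for _ in range(q - 1):
    # chain rule: d/dt g(y(t)) = g'(y) * y'(t) = g'(y) * f(y)
    derivs.append(derivs[-1].deriv() * f)

# Initial mean: y0 followed by the first q time derivatives at y0
m0 = np.array([y0] + [p(y0) for p in derivs])
print(m0)
```

This explicit approach only works for scalar polynomials and grows quickly in cost; it is not a replacement for Taylor-mode AD.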

- This is the only thing we need to worry about (apart from stability issues, maybe addressed at the end)

Here's an implementation:

In [None]:
def initial_state(f, t0, y0, q, f_args):
    # Return [y0, f^<i>(y0)] as defined above
    def f_wrapped(y):
        # make f a function of y only (so it works with jet)
        return f(t0, y, f_args)
    y0_tup = (y0,)  # jet wants this in a tuple
    # y_is will now iteratively be filled like this:
    # (f(y0), f^<1>(y0), f^<2>(y0))  ->  (f(y0),) + (f^<1>(y0), f^<2>(y0), f^<3>(y0))
    y_is = (f_wrapped(y0),)
    for _ in range(q - 1):
        y_i_new, y_is_new = jet.jet(f_wrapped, y0_tup, (y_is,))
        # above returns tuple[Array, list[Array]], put all into one tuple
        y_is = (y_i_new,) + (*y_is_new,)
    return jnp.stack((y0_tup) + y_is).flatten()

#### How to fit higher-dimensional y? (ie. an ODE system)

If $y$ is $d$-dimensional, then we can just stack everything, using the Kronecker product $\otimes$,
with
$$
    \mathbf{A} = A \otimes I_d
       = \begin{bmatrix}
          a_{1,1}I_d & \cdots & a_{1,q+1}I_d \\
          \vdots & \ddots & \vdots \\
          a_{q+1,1}I_d & \cdots & a_{q+1,q+1}I_d
         \end{bmatrix} \in \mathbb{R}^{d(q+1)\times d(q+1)}\,.
$$
For example, for $d=2, q=1$, we have
$$ \mathbf{x} = \begin{pmatrix} y_1 \\ y_2 \\ y_1' \\ y_2' \end{pmatrix}\,, $$
and $\mathbf{A} = A \otimes I_2$.

The functions given above for $A, Q, H, H_0$ already do this, so let's see what this means:

In [None]:
x1 = jnp.array([0.0, 2.0])  # the first dimension
x2 = jnp.array([1.0, 3.0])  # the second dimension
x_s = jnp.stack([x1, x2]).T.flatten()
A = discrete_transition_matrix(1, 1, 0.1)
A_s = discrete_transition_matrix(2, 1, 0.1)
print(f"x1: {x1}, x2: {x2}")
print(f"x_s: {x_s}")
print(f"A:\n{A}")
print(f"A_s:\n{A_s}")
print(f"Ax1: {A@x1}, Ax2: {A@x2}")
print(f"Ax1, Ax2 stacked: {jnp.stack([A@x1, A@x2]).T.flatten()}")
print(f"A_s x_s: {A_s@x_s}")

Essentially, this means that all of the stacked operators
just operate _per dimension_ as they would normally
in the one-dimensional case.

- This means that we don't need to care about $d$ anywhere,
  as long as we stack the state and the operators correctly,
  the math stays exactly the same.
- To my understanding, this also means that each dimension
  is modeled independently of the others
  - For example, we don't get covariances between different dimensions,
    or rather, they are zero
  - In the prior, every dimension is a $q$-times IWP independently
    of every other dimension

#### ODE Filter

Now, we can build our first probabilistic ODE solver: a simple version of an ODE filter
- with constant stepsize $h$
- without error estimation etc.
- with lots of other points that we can address later

In [None]:
def ode_filter(y0, f, f_args,
               t0, t1,
               q,  # order
               h,
               ekf_order=1):
    h, t0 = jnp.array(h), jnp.array(t0)  # Ensure these are arrays
    d = y0.shape[-1]  # dimension of y
    # Initialization
    x0 = initial_state(f, t0, y0, q, f_args)
    stddev0 = jnp.zeros_like(y0)
    P0 = jnp.diag(jnp.repeat(stddev0, q + 1))  # Assume zero initial cov.
    H0, H = ssm_projection_matrices(d, q)
    # R=0 -> Dirac likelihood (noise-free observations z_n = 0)
    R = jnp.zeros((d, d))
    # Init. objects for storing results
    rejected_steps = 0
    accepted_steps = 0
    ts, ys, stddevs = [t0], [y0], [stddev0]
    filtering_means, filtering_covs, predictive_means, predictive_covs = [x0], [P0], [], []
    # Init. loop variables
    t_prev = t0
    m_f_prev, P_f_prev = x0, P0  # Filtering parameters (of previous step)

    # The ODE filter loop
    while t_prev < t1:
        m_f, P_f, m_p, P_p, y, stddev = \
            ode_filter_step(m_f_prev, P_f_prev, f, f_args,
                            t_prev, h, d, q,
                            H0, H, R, ekf_order)
        # Store results
        t = t_prev + h
        ts.append(t)
        ys.append(y)
        stddevs.append(stddev)
        filtering_means.append(m_f)
        filtering_covs.append(P_f)
        predictive_means.append(m_p)
        predictive_covs.append(P_p)
        # Update loop variables
        t_prev, m_f_prev, P_f_prev = t, m_f, P_f

    # Return results in dict
    return {
        "ts": jnp.asarray(ts).squeeze(),
        "ys": jnp.asarray(ys).squeeze(),
        "stddevs": jnp.asarray(stddevs).squeeze(),
        "predictive_means": predictive_means,
        "predictive_covs": predictive_covs,
        "filtering_means": filtering_means,
        "filtering_covs": filtering_covs,
        "H0": H0, "d": d, "q": q}

In [None]:
q = 1
h = 1e-1  # If we decrease the stepsize, this is very accurate already
#h = 1e-2
ekf_order = 1
f_sol = ode_filter(
    y0, lotka_volterra_vf, lv_args,
    t0, t1, q, h, ekf_order=ekf_order)

In [None]:
def plot_lv_with_uncertainty(prob_sol, num_stddevs=2, c_sol=None, title=None):
    predator_col = "red"
    prey_col = "green"
    if c_sol is not None:
        plt.plot(c_sol.ts, c_sol.ys, label=None, linewidth=1, color="grey")
    plt.plot(prob_sol["ts"], prob_sol["ys"][:, 0], label="Prey",
             color=prey_col, linewidth=1)
    plt.plot(prob_sol["ts"], prob_sol["ys"][:, 1], label="Predator",
             color=predator_col, linewidth=1)
    plt.fill_between(prob_sol["ts"],
                     prob_sol["ys"][:, 0] - num_stddevs*prob_sol["stddevs"][:, 0],
                     prob_sol["ys"][:, 0] + num_stddevs*prob_sol["stddevs"][:, 0],
                     label=f"Prey $\\pm {num_stddevs}\\sigma$",
                     linewidth=0, color=prey_col, alpha=0.4)
    plt.fill_between(prob_sol["ts"],
                     prob_sol["ys"][:, 1] - num_stddevs*prob_sol["stddevs"][:, 1],
                     prob_sol["ys"][:, 1] + num_stddevs*prob_sol["stddevs"][:, 1],
                     label=f"Predator $\\pm {num_stddevs}\\sigma$",
                     linewidth=0, color=predator_col, alpha=0.4)
    if title is not None:
        plt.title(title)
    plt.legend(prop={'size': 6})
    plt.show()

def plot_filter_and_smoother(f_out, s_out, num_stddevs=2, c_sol=None, title=None):
    predator_col = "red"
    prey_col = "green"
    filter_col = "grey"
    if c_sol is not None:
        plt.plot(c_sol.ts, c_sol.ys, label=None, linewidth=1, color="black")
    plt.plot(f_out["ts"], f_out["ys"][:, 0], label="Prey (Filter)",
             color=filter_col, linewidth=1, linestyle="-")
    plt.plot(f_out["ts"], f_out["ys"][:, 1], label="Predator (Filter)",
             color=filter_col, linewidth=1, linestyle="-")
    plt.plot(s_out["ts"], s_out["ys"][:, 0], label="Prey (Smoother)",
             color=prey_col, linewidth=1, linestyle="-")
    plt.plot(s_out["ts"], s_out["ys"][:, 1], label="Predator (Smoother)",
             color=predator_col, linewidth=1, linestyle="-")
    plt.fill_between(f_out["ts"],
                     f_out["ys"][:, 0] - num_stddevs*f_out["stddevs"][:, 0],
                     f_out["ys"][:, 0] + num_stddevs*f_out["stddevs"][:, 0],
                     #label=f"Prey (Filter) $\\pm {num_stddevs}\\sigma$",
                     label=None,
                     linewidth=0, color=filter_col, alpha=0.3)
    plt.fill_between(f_out["ts"],
                     f_out["ys"][:, 1] - num_stddevs*f_out["stddevs"][:, 1],
                     f_out["ys"][:, 1] + num_stddevs*f_out["stddevs"][:, 1],
                     #label=f"Predator (Filter) $\\pm {num_stddevs}\\sigma$",
                     label=None,
                     linewidth=0, color=filter_col, alpha=0.3)
    plt.fill_between(s_out["ts"],
                     s_out["ys"][:, 0] - num_stddevs*s_out["stddevs"][:, 0],
                     s_out["ys"][:, 0] + num_stddevs*s_out["stddevs"][:, 0],
                     label=f"Prey (Smoother) $\\pm {num_stddevs}\\sigma$",
                     linewidth=0, color=prey_col, alpha=0.4)
    plt.fill_between(s_out["ts"],
                     s_out["ys"][:, 1] - num_stddevs*s_out["stddevs"][:, 1],
                     s_out["ys"][:, 1] + num_stddevs*s_out["stddevs"][:, 1],
                     label=f"Predator (Smoother) $\\pm {num_stddevs}\\sigma$",
                     linewidth=0, color=predator_col, alpha=0.4)
    if title is not None:
        plt.title(title)
    plt.legend(prop={'size': 6})
    plt.show()

In [None]:
plot_lv_with_uncertainty(f_sol, c_sol=sol, title=f"ODE Filter, q={q}, h={h}")

### Smoother

Great, our filter seems to be working decently!

But, it only calculates the filtering distribution $p(x_n | z_{0:n})$.

As we can see above, this is pretty useful already.
However, at every $x_n$, it only considers information up to that time point.

We can also perform a smoothing pass, where we go backwards in time
to obtain the full smoothing posterior
$$ p(x_n | z_{0:N}) = \mathcal{N}(x_n; m_n^S, P_n^S) $$

#### Smoother step

We start with $m_N^S = m_N^f, P_N^S = P_N^f$ (ie. set the last smoothing parameters to the last filtering parameters), and then iteratively perform the step

\begin{align*}
    G_n &= P_n^f A(h_{n+1})^\top \left(P_{n+1}^p\right)^{-1} &\text{Gain} \\
    m_n^S &= m_n^f + G_n(m_{n+1}^S - m_{n+1}^p) &\text{smoothing mean} \\
    P_n^S &= P_n^f + G_n(P_{n+1}^S - P_{n+1}^p) G_n^\top &\text{smoothing cov}\,,
\end{align*}

which we can easily implement:

In [None]:
@partial(jax.jit, static_argnames=["d", "q"])
def smoother_step(
        m_f, P_f,
        m_p_next, P_p_next,
        m_s_next, P_s_next,
        H0, h, d, q):
    A = discrete_transition_matrix(d, q, h)
    G = P_f @ A.T @ jnp.linalg.inv(P_p_next)  # again, we'll improve this later.
    m_s = m_f + G @ (m_s_next - m_p_next)  # posterior smoothing mean
    P_s = P_f + G @ (P_s_next - P_p_next) @ G.T  # posterior smoothing cov
    y_s = H0 @ m_s  # Again extract current y and stddev
    stddev_s = jnp.sqrt(jnp.diag(H0 @ P_s @ H0.T))
    return m_s, P_s, y_s, stddev_s
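
To check the smoother equations by hand, here is a single smoothing step on a scalar toy model in plain NumPy. All numbers are made up for illustration (a random walk with $A = 1$, $Q = 1$), and this sketch does not use the ODE-specific matrices from above.

```python
import numpy as np

A = np.array([[1.0]])                                    # scalar toy model (assumption)
m_f, P_f = np.array([0.0]), np.array([[1.0]])            # filtering params at step n
m_p_next, P_p_next = np.array([0.0]), np.array([[2.0]])  # predictive params at n+1
m_s_next, P_s_next = np.array([1.0]), np.array([[0.5]])  # smoothing params at n+1

G = P_f @ A.T @ np.linalg.inv(P_p_next)      # gain
m_s = m_f + G @ (m_s_next - m_p_next)        # smoothing mean
P_s = P_f + G @ (P_s_next - P_p_next) @ G.T  # smoothing cov
print("smoothing mean:", m_s, "smoothing cov:", P_s)
```

Here the gain is $G = 0.5$, so the smoothing mean moves halfway towards the later information, and the smoothing covariance is smaller than the filtering covariance.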

With this step, we can already implement our ODE smoother:

In [None]:
def ode_smoother(f_out):
    # Setup, retrieve parameters
    predictive_means = f_out["predictive_means"]
    predictive_covs = f_out["predictive_covs"]
    filtering_means = f_out["filtering_means"]
    filtering_covs = f_out["filtering_covs"]
    ts = f_out["ts"]
    N, d, q = len(ts), f_out["d"], f_out["q"]
    H0 = f_out["H0"]
    # Init. loop variables
    m_s_next = filtering_means[-1]
    P_s_next = filtering_covs[-1]
    smoothing_ts = [ts[-1]]  # Init. last smoothing parameters as last filtering parameters
    smoothing_ys = [f_out["ys"][-1]]
    smoothing_stddevs = [f_out["stddevs"][-1]]

    for n in range(N - 2, -1, -1):  # shift by 1 due to zero-based indexing
        # [N - 2, 0]
        # (N-1 is last, and that step was already initialized as the
        # last filtering parameters).
        # Predictive params hold one less element at the start,
        # so eg predictive_means[n] actually corresponds to m_p[n+1]
        m_p_next, P_p_next = predictive_means[n], predictive_covs[n]
        m_f, P_f = filtering_means[n], filtering_covs[n]
        t = ts[n]
        h = ts[n + 1] - t
        m_s, P_s, y, stddev = smoother_step(
            m_f, P_f, m_p_next, P_p_next, m_s_next, P_s_next, H0, h, d, q)
        # Store results
        smoothing_ts.insert(0, t)
        smoothing_ys.insert(0, y)
        smoothing_stddevs.insert(0, stddev)
        m_s_next, P_s_next = m_s, P_s

    return {
            "ts": jnp.asarray(smoothing_ts).squeeze(),
            "ys": jnp.asarray(smoothing_ys).squeeze(),
            "stddevs": jnp.asarray(smoothing_stddevs).squeeze()}

In [None]:
s_sol = ode_smoother(f_sol)

In [None]:
#plot_lv_with_uncertainty(s_sol, c_sol=sol, title=f"ODE Smoother")
plot_filter_and_smoother(f_sol, s_sol, c_sol=sol, title=f"ODE Filter and Smoother")

## Extending the simple implementation

### Uncertainty estimation

#### Calibrated Local Uncertainty

Remember the IWP scale $\sigma$ in [Dynamic model / State](#Dynamic-model-/-State)?

So far, we've just used $\sigma = 1$.

We can get a _local_ (quasi-ML) estimate $\hat{\sigma}_n^2$ via
$$ \hat{\sigma}_n^2 = \hat{z}_n^\top \left(H \tilde{Q}(h_n) H^\top \right)^{-1} \hat{z}_n\,, $$
where $\tilde{Q}(h_n)$ is $Q(h_n)$ with $\sigma^2 = 1$
(as, remember, $Q(h_n)$ does depend on $\sigma^2$).

This is actually what the function ```discrete_diffusion_matrix(d, q, h)``` does
(and the reason why we ignored the factor $\sigma^2$).

So, for calibrated (local) uncertainty, we can just run the filter like before,
and, after we've calculated $\hat{z}_n$, we can get our uncertainty estimate $\hat{\sigma}_n^2$,
and just rescale $Q(h_n) = \hat{\sigma}_n^2 \tilde{Q}(h_n)$, and have calibrated local uncertainty!
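
A minimal numerical sketch of this estimate in plain NumPy, for $q = 1$, $d = 1$. The residual value is made up for illustration; in the filter it would come from the residual computation of the update step.

```python
import numpy as np

h = 0.1
H = np.array([[0.0, 1.0]])                                 # q = 1, d = 1
Q_tilde = np.array([[h**3 / 3, h**2 / 2], [h**2 / 2, h]])  # Q(h) with sigma^2 = 1
z_hat = np.array([0.1])                                    # made-up residual (assumption)

# sigma_hat^2 = z_hat^T (H Q_tilde H^T)^{-1} z_hat, using a linear solve
sigma_sq = z_hat @ np.linalg.solve(H @ Q_tilde @ H.T, z_hat)
Q = sigma_sq * Q_tilde                                     # calibrated diffusion matrix
print("sigma_hat^2:", sigma_sq)
```

For $q = 1$ and $d = 1$, $H \tilde{Q}(h) H^\top = h$, so the estimate reduces to $\hat{z}_n^2 / h$.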

#### Local Error Estimation

Additionally, we can now use our calibrated $Q(h_n) = \hat{\sigma}_n^2 \tilde{Q}(h_n)$
to obtain an estimate for the local error with
$$ D(h_n) = \sqrt{\tilde{H} Q(h_n) \tilde{H}^\top}\,, $$
which is essentially just the square root of the contribution of $Q(h_n)$ to the covariance of $\hat{z}_n$.

If we have a multi-dimensional ODE, then the above estimates an error per dimension.

I just took the max error over the dimension as estimate
- Although I'm not sure this is the best way to do this
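
The per-dimension error estimate and the max over dimensions can be sketched in plain NumPy for $d = 2$, $q = 1$. The Jacobian and the value of $\hat{\sigma}_n^2$ are made up for illustration, and taking the square root of the diagonal only (one value per dimension) is how this sketch reads the formula above.

```python
import numpy as np

h, d, sigma_sq = 0.1, 2, 1.0  # sigma_sq: assumed to be estimated already
Q_1d = np.array([[h**3 / 3, h**2 / 2], [h**2 / 2, h]])
Q = sigma_sq * np.kron(Q_1d, np.eye(d))  # calibrated Q(h) for q = 1, d = 2
H0 = np.kron(np.array([[1.0, 0.0]]), np.eye(d))
H = np.kron(np.array([[0.0, 1.0]]), np.eye(d))
J_f = np.diag([-1.0, -2.0])              # made-up Jacobian of f (assumption)
H_tilde = H - J_f @ H0                   # EKF1 linearization

D = np.sqrt(np.diag(H_tilde @ Q @ H_tilde.T))  # one error estimate per dimension
local_error = np.max(D)                        # scalar summary: worst dimension
print("per-dimension estimates:", D, "max:", local_error)
```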

In [None]:
@partial(jax.jit, static_argnames=["f", "f_args", "d", "q", "ekf_order"])
def ode_filter_step(m_f_prev, P_f_prev, f, f_args,
                    t_prev, h, d, q,
                    H0, H, R, ekf_order):
    t = t_prev + h
    A = discrete_transition_matrix(d, q, h)
    Q = discrete_diffusion_matrix(d, q, h)
    # predict
    m_p, _ = predict(m_f_prev, P_f_prev, A, Q)  # Don't need P_p now
    # residual
    z_hat, H_hat = residual(f, t, m_p, H, H0, f_args, ekf_order)
    # local uncertainty est.
    sigma_sq = z_hat.T @ jnp.linalg.inv(H @ Q @ H.T) @ z_hat
    Q = sigma_sq * Q  # update Q
    # est. local error: sqrt of the diagonal gives one value per dimension, then take the max
    local_error = jnp.max(jnp.sqrt(jnp.diag(H_hat @ Q @ H_hat.T)))
    _, P_p = predict(m_f_prev, P_f_prev, A, Q)  # predict again (P_p depends on Q)

    m_f, P_f = next_filtering_mean_and_cov(m_p, P_p, z_hat, H_hat, R)

    y = H0 @ m_f
    stddev = jnp.sqrt(jnp.diag(H0 @ P_f @ H0.T))
    return m_f, P_f, m_p, P_p, y, stddev, local_error  # report the local error est.

#### Step size adaptation

### Improving numerical Stability

#### Nordsieck-like Rescaling

#### Square-Root Kalman Filter