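# Parallel vs. sequential Kalman filtering and smoothing

This notebook checks that the parallel EKF/EKS from `pekf.parallel` agree with the sequential EKF/EKS and CKF/CKS from `pekf.sequential` on data generated with `pekf.models.linear`.

### Imports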

In [1]:
import jax.numpy as jnp
from jax import jit
from pekf.parallel import ekf, eks
from pekf.sequential import ekf as seq_ekf, eks as seq_eks, ckf as seq_ckf, cks as seq_cks
from pekf.models.linear import get_data, make_parameters
from pekf.utils import MVNormalParameters

### Input parameters

In [2]:
# Noise parameters forwarded to make_parameters below (presumably the
# observation and transition noise scales — confirm in pekf.models.linear).
r, q = 0.5, 0.1

x0 = jnp.zeros(2)  # initial true location: the origin of the 2-D state
T = 1_000  # number of observations



### Get parameters

In [3]:
# Build the model from the noise parameters. NOTE(review): the unpacking order
# below (A, H, Q, R, observation, transition) must match make_parameters'
# return order — confirm against pekf.models.linear.
model_parameters = make_parameters(r, q)
A, H, Q, R, observation_function, transition_function = model_parameters

In [4]:
# Wrap the single-state model functions with gufunc-style signatures so they
# broadcast over any leading batch dimensions of the state argument.
# NOTE: this rebinds the names, so re-running this cell wraps them a second time.
observation_function, transition_function = (
    jnp.vectorize(observation_function, signature="(m)->(d)"),
    jnp.vectorize(transition_function, signature="(m)->(m)"),
)

### Get data

In [5]:
# Generate T observations starting from x0.
# NOTE(review): R is passed before Q here, while make_parameters unpacks them
# as Q, R — confirm against get_data's signature that this order is intended.
data_seed = 42  # presumably the RNG seed used by get_data — confirm
ts, true_states, observations = get_data(x0, A, H, R, Q, T, data_seed)

### Set up the filters

Initial state guess: a zero mean with identity covariance.

In [6]:
# Prior handed to every filter: zero mean and identity covariance over the
# 2-D state (the true initial state x0 is also the origin).
m, P = jnp.zeros(2), jnp.eye(2)
initial_guess = MVNormalParameters(m, P)

### We can now run the filters

Run the parallel EKF and the sequential EKF and CKF on the same observations, starting from the same initial guess.

In [7]:
# All three filter implementations take the same positional arguments, so
# build the argument tuple once and feed it to each of them.
filter_args = (initial_guess, observations, transition_function, Q, observation_function, R)
par_ekf_filtered = ekf(*filter_args)
seq_ekf_filtered = seq_ekf(*filter_args)
seq_ckf_filtered = seq_ckf(*filter_args)

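Compare the parallel EKF against the two sequential filters: the maximum absolute difference of the filtered means, then of the filtered covariances.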

In [8]:
# Compare the parallel EKF against both sequential filters: means first, then
# covariances, printing the largest absolute difference for each pair.
# The recorded differences are ~1e-8 to 1e-7 (round-off level), so anything
# above filter_atol signals a real disagreement. Assert rather than relying on
# a reader to eyeball the printed numbers.
filter_atol = 1e-5
for field in ("mean", "cov"):
    for seq_filtered in (seq_ekf_filtered, seq_ckf_filtered):
        diff = jnp.max(jnp.abs(getattr(par_ekf_filtered, field) - getattr(seq_filtered, field)))
        print(diff)
        assert diff < filter_atol, f"parallel EKF {field} differs from a sequential filter by {diff}"

8.940697e-08
8.940697e-08
1.4901161e-08
1.4901161e-08


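### We can now run the smoothers

All three smoothers are fed the same parallel-EKF filtering result, so any difference between them comes from the smoother implementations themselves.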

In [9]:
# Every smoother receives the *parallel* EKF result as its filtered input, so
# the comparison below isolates differences between the smoothers.
smoother_args = (transition_function, Q, par_ekf_filtered)
# The parallel EKS also takes a trajectory (here the filtered means), which
# is presumably its linearisation points — confirm in pekf.parallel.eks.
par_eks_smoothed = eks(*smoother_args, par_ekf_filtered.mean)
seq_eks_smoothed = seq_eks(*smoother_args)
seq_cks_smoothed = seq_cks(*smoother_args)

In [10]:
# Compare the parallel EKS against both sequential smoothers: means first, then
# covariances, printing the largest absolute difference for each pair.
# The recorded differences are ~1e-8 to 1e-7 (round-off level), so anything
# above smoother_atol signals a real disagreement. Assert rather than relying
# on a reader to eyeball the printed numbers.
smoother_atol = 1e-5
for field in ("mean", "cov"):
    for seq_smoothed in (seq_eks_smoothed, seq_cks_smoothed):
        diff = jnp.max(jnp.abs(getattr(par_eks_smoothed, field) - getattr(seq_smoothed, field)))
        print(diff)
        assert diff < smoother_atol, f"parallel EKS {field} differs from a sequential smoother by {diff}"

5.9604645e-08
5.9604645e-08
2.9802322e-08
2.9802322e-08
