### Imports

In [1]:
import jax.numpy as jnp
from parsmooth.parallel import ekf, eks
from parsmooth.sequential import ekf as seq_ekf, eks as seq_eks, ckf as seq_ckf, cks as seq_cks
from parsmooth.models.linear import get_data, make_parameters
from parsmooth.utils import MVNormalParameters

### Input parameters

In [2]:
# Simulation / model settings.
T = 1000                     # number of observations
x0 = jnp.array([0.0, 0.0])   # initial true location
r, q = 0.5, 0.1              # scalars forwarded to `make_parameters(r, q)` below

2024-08-09 20:57:28.400845: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.5 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


### Get parameters

In [3]:
# Build the model from the scalars r and q: returns A, H, Q, R plus the observation
# and transition callables (the next cell wraps the callables with `jnp.vectorize`).
A, H, Q, R, observation_function, transition_function = make_parameters(r, q)

In [4]:
# `jnp.vectorize` with a core signature makes each function accept a single state
# vector of size m and broadcast over any extra leading axes; the observation map
# produces a size-d vector, the transition map a size-m vector.
observation_function, transition_function = (
    jnp.vectorize(observation_function, signature="(m)->(d)"),
    jnp.vectorize(transition_function, signature="(m)->(m)"),
)

### Get data

In [5]:
# Simulate T steps of the model starting from x0 and collect the observations.
# NOTE(review): the last positional argument of `get_data` appears to be the random
# seed of the simulation; it is named here instead of being a bare magic number.
DATA_SEED = 42
ts, true_states, observations = get_data(x0, A, H, R, Q, T, DATA_SEED)

### We can now run the filter

Initial state guess

In [6]:
# Initial state guess: zero mean vector and identity covariance in 2 dimensions.
m, P = jnp.array([0.0, 0.0]), jnp.eye(2)
initial_guess = MVNormalParameters(m, P)

### Run the filters

Run the parallel EKF alongside the sequential EKF and the sequential CKF, all starting from the same initial guess.

In [7]:
# Every filter receives the same inputs: the initial guess, the observations, the
# transition function with Q, and the observation function with R.
filter_args = (initial_guess, observations, transition_function, Q, observation_function, R)

# The parallel EKF returns the filtered states only; the sequential filters return a
# (log-likelihood, filtered states) pair.
par_ekf_filtered = ekf(*filter_args)
seq_ekf_ll, seq_ekf_filtered = seq_ekf(*filter_args)
seq_ckf_ll, seq_ckf_filtered = seq_ckf(*filter_args)

# This log-likelihood comes from the *sequential* CKF, but it was previously bound to
# the misleading name `par_ckf_ll`; keep that name as an alias so existing references
# to it keep working.
par_ckf_ll = seq_ckf_ll

Compare the log-likelihoods, then the largest absolute differences in the filtered means and covariances:

In [8]:
# Log-likelihoods reported by the sequential EKF and the sequential CKF.
print(seq_ekf_ll, par_ckf_ll)

# Largest absolute deviation of each sequential filter from the parallel EKF:
# first for the filtered means, then for the filtered covariances.
for attr in ("mean", "cov"):
    par_value = getattr(par_ekf_filtered, attr)
    for seq_filtered in (seq_ekf_filtered, seq_ckf_filtered):
        print(jnp.max(jnp.abs(par_value - getattr(seq_filtered, attr))))

-1178.5336 -1178.5336
1.7851591e-05
1.7851591e-05
2.346933e-06
2.346933e-06


### Run the smoothers

All three smoothers take the parallel EKF's filtered output as their input.

In [9]:
# All three smoothers consume the parallel EKF's filtered output, so the comparison
# below reflects only differences between the smoothing passes. The parallel EKS also
# receives the filtered means as an extra argument (presumably its linearization
# points — confirm against `eks`).
par_eks_smoothed, seq_eks_smoothed, seq_cks_smoothed = (
    eks(transition_function, Q, par_ekf_filtered, par_ekf_filtered.mean),
    seq_eks(transition_function, Q, par_ekf_filtered),
    seq_cks(transition_function, Q, par_ekf_filtered),
)

In [10]:
# Largest absolute deviation of each sequential smoother from the parallel EKS:
# first for the smoothed means, then for the smoothed covariances.
for attr in ("mean", "cov"):
    par_value = getattr(par_eks_smoothed, attr)
    for seq_smoothed in (seq_eks_smoothed, seq_cks_smoothed):
        print(jnp.max(jnp.abs(par_value - getattr(seq_smoothed, attr))))

6.023445e-05
6.01972e-05
2.7239323e-05
2.7239323e-05
