In [3]:
import sys
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg 

# Make the parent directory importable (presumably where `parsmooth` and `LGSSM` live).
sys.path.append('..//')

# jax.config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", False)

from parsmooth._base import MVNStandard, FunctionalModel, MVNSqrt
from parsmooth.linearization import extended
from parsmooth.methods import iterated_smoothing
# from tests._lgssm import transition_function as lgssm_f, observation_function as lgssm_h, get_data
from LGSSM import transition_function as lgssm_f, observation_function as lgssm_h, get_data

Log-likelihood (`ell`) of the sequential method, used as ground-truth data

In [4]:
# Model parameters from "Temporal parallelization of Bayesian smoothers" paper.
# State layout follows F: [x, y, vx, vy] -- each position advances by dt * its velocity.
T = 2000      # number of time steps simulated by get_data
dt = 0.1      # sampling period
sigma = 0.5   # observation noise standard deviation (R = sigma**2 * I)
q = 1         # scale factor of the process noise covariance Q

# All matrices use float literals so every array is floating point; integer literals
# would silently produce int32 arrays (with x64 disabled) that clash with float states.
F = jnp.array([[1., 0., dt, 0.],
               [0., 1., 0., dt],
               [0., 0., 1., 0.],
               [0., 0., 0., 1.]])
b = jnp.zeros((4,))
Q = q * np.array([[dt**3 / 3, 0.,        dt**2 / 2, 0.       ],
                  [0.,        dt**3 / 3, 0.,        dt**2 / 2],
                  [dt**2 / 2, 0.,        dt,        0.       ],
                  [0.,        dt**2 / 2, 0.,        dt       ]])
cholQ = scipy.linalg.cholesky(Q, lower=True)

# Only the two position components are observed.
H = jnp.array([[1., 0., 0., 0.],
               [0., 1., 0., 0.]])
c = jnp.zeros((2,))
R = sigma**2 * jnp.eye(2)
cholR = scipy.linalg.cholesky(R, lower=True)

# Initial state: position (0, 0), velocity (1, -1). Must be float, like every other state
# mean in this notebook, otherwise it would be an int32 array.
m0 = jnp.array([0., 0., 1., -1.])
P0 = jnp.eye(4)
# Derive the Cholesky factor from P0 so the two stay consistent if P0 is changed.
chol_P0 = jnp.linalg.cholesky(P0)

init = MVNStandard(m0, P0)
chol_init = MVNSqrt(m0, chol_P0)

true_states, observations = get_data(m0, F, H, R, Q, b, c, T)



In [5]:
# Plain callables with the model matrices bound. FunctionalModel's first field is the
# function the smoother evaluates, so these must be callables, not FunctionalModels.
transition_function = partial(lgssm_f, A=F)
observation_function = partial(lgssm_h, H=H)

# Covariance-form models (used by IEKS_std_par).
transition_model = FunctionalModel(transition_function, MVNStandard(b, Q))
observation_model = FunctionalModel(observation_function, MVNStandard(c, R))

# Cholesky-factor (square-root) models (used by IEKS_sqrt_par).
sqrt_transition_model = FunctionalModel(transition_function, MVNSqrt(b, cholQ))
sqrt_observation_model = FunctionalModel(observation_function, MVNSqrt(c, cholR))

# Constant nominal trajectories: T + 1 copies of mean [0, 0, 1, -1] with identity
# covariance / identity Cholesky factor.
initial_guess = jnp.array([[0., 0., 1., -1.]])
initial_states = MVNStandard(jnp.repeat(initial_guess, T + 1, axis=0),
                             jnp.repeat(jnp.eye(4).reshape(1, 4, 4), T + 1, axis=0))
initial_states_sqrt = MVNSqrt(jnp.repeat(initial_guess, T + 1, axis=0),
                              jnp.repeat(jnp.eye(4).reshape(1, 4, 4), T + 1, axis=0))

In [6]:
# Time-series lengths at which the log-likelihood is evaluated:
# 100..500 in steps of 100, then 1000..8000 in steps of 500.
Ts = [100, 200, 300, 400, 500] + list(range(1000, 8001, 500))


def func(method, Ts, runtime=15, n_iter=20, sqrt=True, init_mean=(-1., -1., 6., 4.)):
    """Run a smoother on freshly simulated LGSSM data and collect its log-likelihoods.

    For every length ``T`` in ``Ts`` the LGSSM ``get_data`` helper is called ``runtime``
    times with the notebook-level model ``m0, F, H, R, Q, b, c`` (the same call as in the
    data cell). Each call is followed by ``method(ys, nominal, n_iter)``, where
    ``nominal`` is a constant trajectory of ``T + 1`` states.

    Parameters
    ----------
    method : callable
        Takes ``(observations, nominal_trajectory, n_iter)`` and returns a pair whose
        second element is the log-likelihood.
    Ts : iterable of int
        Time-series lengths to evaluate.
    runtime : int
        Number of repetitions (independent data draws) per length.
    n_iter : int
        Iteration count forwarded to ``method``.
    sqrt : bool
        Build the nominal trajectory as ``MVNSqrt`` (True) or ``MVNStandard`` (False).
    init_mean : sequence of float
        Constant mean of the nominal trajectory; its covariance / Cholesky factor is
        the identity of matching size.

    Returns
    -------
    list of list
        ``result[i][j]`` is the log-likelihood for ``Ts[i]``, repetition ``j``.
    """
    state_cls = MVNSqrt if sqrt else MVNStandard
    mean_row = jnp.asarray(init_mean).reshape(1, -1)
    dim = mean_row.shape[1]

    ell_par = []
    for T in Ts:
        # The nominal trajectory depends only on T, so build it once per length.
        nominal = state_cls(jnp.repeat(mean_row, T + 1, axis=0),
                            jnp.repeat(jnp.eye(dim).reshape(1, dim, dim), T + 1, axis=0))
        ell_par_res = []
        for _ in range(runtime):
            # Same LGSSM data generator and model as the data cell; true states unused.
            _, ys = get_data(m0, F, H, R, Q, b, c, T)
            _, ell = method(ys, nominal, n_iter)
            ell_par_res.append(ell)
        ell_par.append(ell_par_res)

    return ell_par

In [7]:
def IEKS_std_par(observations, initial_points, iteration):
    """Iterated extended Kalman smoother in covariance (standard) form.

    Calls ``iterated_smoothing`` with the notebook-level ``init``, ``transition_model``
    and ``observation_model``, the ``extended`` linearization, and ``initial_points`` as
    the nominal trajectory. The next positional argument is passed as ``False``; this is
    presumably the ``parallel`` flag, i.e. the sequential variant despite the ``_par``
    suffix (TODO confirm against the parsmooth signature). Iteration continues while the
    iteration counter is below ``iteration``.

    Returns the pair (smoother result, log-likelihood) produced by ``iterated_smoothing``.
    """
    def keep_iterating(i, *_):
        return i < iteration

    smoothed, loglik = iterated_smoothing(
        observations,
        init,
        transition_model,
        observation_model,
        extended,
        initial_points,
        False,
        criterion=keep_iterating,
        return_loglikelihood=True,
    )
    return smoothed, loglik


def IEKS_sqrt_par(observations, initial_points_sqrt, iteration):
    """Iterated extended Kalman smoother in square-root (Cholesky) form.

    Calls ``iterated_smoothing`` with the notebook-level ``chol_init``,
    ``sqrt_transition_model`` and ``sqrt_observation_model``, the ``extended``
    linearization, and ``initial_points_sqrt`` as the nominal trajectory. The next
    positional argument is ``False``; this is presumably the ``parallel`` flag (TODO
    confirm). Iteration continues while the counter is below ``iteration``.

    Returns the pair (smoother result, log-likelihood) produced by ``iterated_smoothing``.
    """
    def keep_iterating(i, *_):
        return i < iteration

    smoothed, loglik = iterated_smoothing(
        observations,
        chol_init,
        sqrt_transition_model,
        sqrt_observation_model,
        extended,
        initial_points_sqrt,
        False,
        criterion=keep_iterating,
        return_loglikelihood=True,
    )
    return smoothed, loglik



In [None]:
# Compile both smoothers with XLA for the GPU backend (needs a GPU-enabled jaxlib).
# `jit` is not imported by name in this notebook, so reference it through `jax`.
# NOTE(review): the `backend=` argument of jax.jit is deprecated in recent JAX releases;
# consider placing inputs on the GPU with jax.device_put instead.
gpu_IEKS_std_par = jax.jit(IEKS_std_par, backend="gpu")
gpu_IEKS_sqrt_par = jax.jit(IEKS_sqrt_par, backend="gpu")

In [None]:
# Covariance-form IEKS: log-likelihoods for every length in Ts (default repetitions/iterations).
gpu_IEKS_std_par_ell = func(method=gpu_IEKS_std_par, Ts=Ts, sqrt=False)