In [1]:
import sys
sys.path.append('..//')

import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg 
from functools import partial
import matplotlib.pyplot as plt

# Numeric precision switch. With False, JAX keeps its default 32-bit dtypes,
# so every result below is float32. The sequential ground-truth cell saves to a
# file named "float64"; presumably this must be set to True (and the kernel
# restarted) for that run — TODO confirm with the author.
ENABLE_X64 = False
jax.config.update("jax_enable_x64", ENABLE_X64)

from parsmooth.linearization import extended
from parsmooth.methods import iterated_smoothing
from parsmooth._base import MVNStandard, FunctionalModel, MVNSqrt
from tests._lgssm import transition_function as lgssm_f, observation_function as lgssm_h, get_data
#from LGSSM import transition_function as lgssm_f, observation_function as lgssm_h, get_data

Log-likelihood (`ell`) of the sequential method, used as ground-truth data.

In [2]:
# Linear-Gaussian constant-velocity model, parameters taken from the
# "Temporal parallelization of Bayesian smoothers" paper.
# State layout: [pos_x, pos_y, vel_x, vel_y]; only positions are observed.
dt = 0.1      # time step
sigma = 0.5   # observation noise standard deviation (per axis)
q = 1.        # process-noise intensity, scaled by 100 in Q below

# Transition: positions advance by dt * velocity, velocities persist.
F = jnp.block([[jnp.eye(2), dt * jnp.eye(2)],
               [jnp.zeros((2, 2)), jnp.eye(2)]])
b = jnp.zeros((4,))
# Process-noise covariance: the 2x2 (position, velocity) block is repeated
# independently for each axis via a Kronecker product with I_2.
Q = 100 * q * np.kron(np.array([[dt**3/3, dt**2/2],
                                [dt**2/2, dt]]),
                      np.eye(2))
cholQ = scipy.linalg.cholesky(Q, lower=True)

# Observation: pick out the two position coordinates, isotropic noise.
H = jnp.eye(2, 4)
c = jnp.zeros((2,))
R = sigma**2 * jnp.eye(2)
cholR = scipy.linalg.cholesky(R, lower=True)

# Prior: start at the origin moving with velocity (1, -1), unit covariance.
m0 = jnp.array([0., 0., 1., -1.])
chol_P0 = jnp.eye(4)
P0 = jnp.eye(4)

init = MVNStandard(m0, P0)
chol_init = MVNSqrt(m0, chol_P0)




In [3]:
# initial_states = MVNStandard(jnp.repeat(jnp.array([[0. , 0., 1., -1.]]),T + 1, axis=0),
#                                                      jnp.repeat(jnp.eye(4).reshape(1, 4, 4), T + 1, axis=0))
# sqrt_initial_states = MVNSqrt(jnp.repeat(jnp.array([[0. , 0., 1., -1.]]),T + 1, axis=0),
#                               jnp.repeat(jnp.eye(4).reshape(1, 4, 4), T + 1, axis=0))

# Full-covariance models, consumed by the standard-form smoother.
transition_model = FunctionalModel(partial(lgssm_f, A=F),
                                   MVNStandard(b, Q))
observation_model = FunctionalModel(partial(lgssm_h, H=H),
                                    MVNStandard(c, R))

# Cholesky-factor counterparts, consumed by the square-root smoother.
sqrt_transition_model = FunctionalModel(partial(lgssm_f, A=F),
                                        MVNSqrt(b, cholQ))
sqrt_observation_model = FunctionalModel(partial(lgssm_h, H=H),
                                         MVNSqrt(c, cholR))


In [None]:
# Sequential (parallel=False) IEKS log-likelihoods for every length in Ts,
# used as ground truth for the parallel GPU runs further down.
# NOTE(review): `Ts` is only assigned in the "Parallel" cell below, and
# `x0`, `r`, `s1`, `s2` are not defined in any visible cell, so this cell
# fails on Restart & Run All until they are defined earlier — TODO.
# NOTE(review): confirm these get_data arguments match tests._lgssm.get_data.
# NOTE(review): results are labelled float64, but the config cell disables
# jax_enable_x64, so they are computed in float32 unless that is changed.
n_iter_seq = 20  # same iteration budget as the parallel runs (func's n_iter default)
ell_seq_fl64 = []
for T in Ts:
    _, _, ys = get_data(x0, dt, r, T, s1, s2)
    _, ell = iterated_smoothing(ys, init, transition_model, observation_model,
                                extended, None, False,
                                criterion=lambda i, *_: i < n_iter_seq,
                                return_loglikelihood=True)
    ell_seq_fl64.append(ell)
    

In [None]:
# Persist the sequential log-likelihoods (one entry per T) under key "ell_seq".
jnp.savez("ell_float64_extended_runtime1", ell_seq=ell_seq_fl64)

In [5]:
# std_par_res_32_8e3, ell_par_32_8e3 = iterated_smoothing(observations, init, transition_model, observation_model,
#                                           extended, None, True,
#                                           criterion=lambda i, *_: i < 2,
#                                           return_loglikelihood = True)

In [8]:
# sqr_par_res_32, sqr_ell_par_32 = iterated_smoothing(observations, chol_init, sqrt_transition_model, sqrt_observation_model,
#                                           extended, None, True,
#                                           criterion=lambda i, *_: i < 1,
#                                           return_loglikelihood = True)


In [6]:
# Parallel

# Trajectory lengths to benchmark.
Ts = [100, 200, 300, 400, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 6500, 7000, 7500, 8000]


def func(method, Ts, runtime=1, n_iter=20, sqrt=True, get_observations=None):
    """Collect the log-likelihoods produced by ``method`` for several lengths.

    Parameters
    ----------
    method : callable
        Called as ``method(ys, n_iter)``; must return a pair whose second
        element is the log-likelihood.
    Ts : sequence of int
        Trajectory lengths; fresh observations are generated for each one.
    runtime : int, default 1
        Repetitions per length; each repetition draws new observations.
    n_iter : int, default 20
        Iteration budget forwarded to ``method``.
    sqrt : bool, default True
        Kept for backward compatibility only. The standard and square-root
        methods take the same ``(ys, n_iter)`` arguments, so this flag does
        not change the call.
    get_observations : callable, optional
        ``get_observations(T)`` returns the observations for length ``T``.
        Defaults to the third output of ``get_data(x0, dt, r, T, s1, s2)``.

    Returns
    -------
    list of list
        ``result[k][j]`` is the log-likelihood for ``Ts[k]``, repetition ``j``.
    """
    if get_observations is None:
        def get_observations(T):
            # NOTE(review): x0, r, s1, s2 are globals not defined in the visible cells.
            _, _, ys = get_data(x0, dt, r, T, s1, s2)
            return ys

    ell_par = []
    for T in Ts:
        ell_per_run = []
        for _ in range(runtime):
            ys = get_observations(T)
            _, ell = method(ys, n_iter)
            ell_per_run.append(ell)
        ell_par.append(ell_per_run)

    return ell_par

In [7]:
def IEKS_std_par(observations, iteration):
    """Run the parallel iterated extended smoother in standard (covariance) form.

    Iteration continues while the counter ``i`` passed to the criterion is
    below ``iteration``. Returns ``(smoothed_result, log_likelihood)``.
    """
    def keep_going(i, *_):
        return i < iteration

    smoothed, loglik = iterated_smoothing(
        observations, init, transition_model, observation_model,
        extended, None, True,
        criterion=keep_going,
        return_loglikelihood=True,
    )
    return smoothed, loglik


def IEKS_sqrt_par(observations, iteration):
    """Run the parallel iterated extended smoother in square-root (Cholesky) form.

    Iteration continues while the counter ``i`` passed to the criterion is
    below ``iteration``. Returns ``(smoothed_result, log_likelihood)``.
    """
    def keep_going(i, *_):
        return i < iteration

    smoothed, loglik = iterated_smoothing(
        observations, chol_init, sqrt_transition_model, sqrt_observation_model,
        extended, None, True,
        criterion=keep_going,
        return_loglikelihood=True,
    )
    return smoothed, loglik



In [None]:
# JIT-compile both parallel smoothers for the GPU backend. `jit` was never
# imported, so use `jax.jit` from the already-imported `jax` module.
# jax.jit retraces for each new observation shape, i.e. once per T in Ts.
# NOTE(review): the `backend=` argument may be deprecated depending on the
# installed JAX version; it also fails on machines without a GPU.
gpu_IEKS_std_par = jax.jit(IEKS_std_par, backend="gpu")
gpu_IEKS_sqrt_par = jax.jit(IEKS_sqrt_par, backend="gpu")

In [None]:
gpu_IEKS_std_par_ell = func(gpu_IEKS_std_par, Ts=Ts, sqrt=False)  # standard-form log-likelihoods per T

In [15]:
gpu_IEKS_sqrt_par_ell = func(gpu_IEKS_sqrt_par, Ts=Ts, sqrt=True)  # square-root-form log-likelihoods per T

4294967000.0

In [None]:
# Save both parallel log-likelihood tables into a single .npz archive.
jnp.savez(
    "ell_float32_extended_runtime1",
    gpu_IEKS_std_par_ell=gpu_IEKS_std_par_ell,
    gpu_IEKS_sqrt_par_ell=gpu_IEKS_sqrt_par_ell,
)