In [1]:
import sys
sys.path.append('..//')

import jax
from jax import jit
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from parsmooth._base import MVNStandard, FunctionalModel, MVNSqrt
from parsmooth.linearization import cubature, extended, gauss_hermite
from parsmooth.methods import iterated_smoothing
from bearing_data import get_data, make_parameters


In [1]:
# --- Simulation settings -------------------------------------------------------
# Everything below is a module-level name that later cells read directly.
s1 = jnp.array([-1.5, 0.5])  # First sensor location
s2 = jnp.array([1., 1.])  # Second sensor location
r = 0.5  # Observation noise (stddev)
# Initial true state handed to get_data. It has 4 components whereas the
# smoother prior mean m0 below has 5.
# NOTE(review): presumably get_data augments the state internally -- confirm.
x0 = jnp.array([0.1, 0.2, 1, 0])
dt = 0.01  # discretization time step
# qc and qw are the two process-noise parameters expected by make_parameters.
# NOTE(review): their exact meaning is defined in bearing_data, not here -- confirm there.
qc = 0.01
qw = 0.1

# Noise covariances and the transition / observation functions of the model.
Q, R, observation_function, transition_function = make_parameters(qc, qw, r, dt, s1, s2)

# Cholesky factors of the noise covariances, used by the MVNSqrt models below.
chol_Q = jnp.linalg.cholesky(Q)
chol_R = jnp.linalg.cholesky(R)

# Prior on the initial state: 5-dimensional mean with identity covariance.
# The identity is its own Cholesky factor, so both priors describe the same Gaussian.
m0 = jnp.array([-4., -1., 2., 7., 3.])
chol_P0 = jnp.eye(5)
P0 = jnp.eye(5)

init = MVNStandard(m0, P0)
chol_init = MVNSqrt(m0, chol_P0)


# Transition model: zero-mean noise of dimension 5, in square-root and covariance form.
sqrt_transition_model = FunctionalModel(transition_function, MVNSqrt(jnp.zeros((5,)), chol_Q))
transition_model = FunctionalModel(transition_function, MVNStandard(jnp.zeros((5,)), Q))

# Observation model: zero-mean noise of dimension 2, in square-root and covariance form.
sqrt_observation_model = FunctionalModel(observation_function, MVNSqrt(jnp.zeros((2,)), chol_R))
observation_model = FunctionalModel(observation_function, MVNStandard(jnp.zeros((2,)), R))

In [3]:
# Parallel smoothers: log-likelihood experiment

# Horizons T to sweep. For each T, func simulates a data set with get_data and
# builds a nominal trajectory of T + 1 states.
Ts = [100, 200, 300, 400, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 6500, 7000, 7500, 8000]


def func(method, Ts, runtime=15, n_iter=20, sqrt=True):
    """Collect the log-likelihood returned by ``method`` on freshly simulated data.

    For every horizon ``T`` in ``Ts`` the experiment is repeated ``runtime``
    times. Each repetition draws a new data set with ``get_data`` (using the
    module-level ``x0, dt, r, s1, s2``) and calls
    ``method(ys, initial_states, n_iter)``.

    Parameters
    ----------
    method : callable
        Called as ``method(ys, initial_states, n_iter)``. It must return a pair
        whose second element is the quantity to record.
    Ts : iterable of int
        Horizons to sweep. The nominal trajectory has ``T + 1`` states.
    runtime : int, optional
        Number of repetitions per horizon (a count, not a duration).
    n_iter : int, optional
        Forwarded to ``method`` as its third positional argument.
    sqrt : bool, optional
        If true the nominal trajectory is an ``MVNSqrt``, otherwise an
        ``MVNStandard``.

    Returns
    -------
    list of list
        ``result[i][j]`` is the value recorded for repetition ``j`` of ``Ts[i]``.

    Notes
    -----
    No random seed is passed to ``get_data`` here.
    NOTE(review): if ``get_data`` draws random numbers, results differ between
    runs -- confirm and seed if reproducibility matters.
    """
    # The nominal trajectory is numerically the same in both parametrisations
    # (the identity matrix is its own Cholesky factor); only the container differs.
    make_nominal = MVNSqrt if sqrt else MVNStandard
    nominal_mean = jnp.array([[-1., -1., 6., 4., 2.]])
    nominal_cov = jnp.eye(5).reshape(1, 5, 5)

    ell_par = []
    for T in Ts:
        # Depends on T only, so build it once per horizon rather than per repetition.
        initial_states = make_nominal(jnp.repeat(nominal_mean, T + 1, axis=0),
                                      jnp.repeat(nominal_cov, T + 1, axis=0))

        ell_par_res = []
        for _ in range(runtime):
            # Only the observations are needed; the other two outputs are discarded.
            _, _, ys = get_data(x0, dt, r, T, s1, s2)
            _, ell = method(ys, initial_states, n_iter)
            ell_par_res.append(ell)

        ell_par.append(ell_par_res)

    return ell_par


# Extended linearization (IEKS): standard and square-root iterated smoothers

In [4]:
def IEKS_std_par(observations, initial_points, iteration):
    """Run ``iterated_smoothing`` with the ``extended`` linearization, covariance form.

    Thin wrapper that plugs in the module-level ``MVNStandard`` prior ``init``
    and the covariance-form ``transition_model`` / ``observation_model``.

    Parameters
    ----------
    observations
        Forwarded unchanged as the first argument of ``iterated_smoothing``.
    initial_points
        Forwarded as the sixth positional argument (``func`` passes an
        ``MVNStandard`` trajectory here).
    iteration
        Bound used by the stopping criterion, which evaluates ``i < iteration``.

    Returns
    -------
    tuple
        The two values returned by ``iterated_smoothing`` when called with
        ``return_loglikelihood=True``, in the same order.
    """
    def below_iteration_cap(i, *_):
        # Only the first argument is inspected; anything else is ignored.
        return i < iteration

    smoothed, loglik = iterated_smoothing(
        observations, init, transition_model, observation_model,
        extended, initial_points,
        True,  # seventh positional flag; presumably selects the parallel implementation -- confirm
        criterion=below_iteration_cap,
        return_loglikelihood=True,
    )
    return smoothed, loglik


def IEKS_sqrt_par(observations, initial_points_sqrt, iteration):
    """Run ``iterated_smoothing`` with the ``extended`` linearization, square-root form.

    Thin wrapper that plugs in the module-level ``MVNSqrt`` prior ``chol_init``
    and the square-root ``sqrt_transition_model`` / ``sqrt_observation_model``.

    Parameters
    ----------
    observations
        Forwarded unchanged as the first argument of ``iterated_smoothing``.
    initial_points_sqrt
        Forwarded as the sixth positional argument (``func`` passes an
        ``MVNSqrt`` trajectory here).
    iteration
        Bound used by the stopping criterion, which evaluates ``i < iteration``.

    Returns
    -------
    tuple
        The two values returned by ``iterated_smoothing`` when called with
        ``return_loglikelihood=True``, in the same order.
    """
    def below_iteration_cap(i, *_):
        # Only the first argument is inspected; anything else is ignored.
        return i < iteration

    smoothed, loglik = iterated_smoothing(
        observations, chol_init, sqrt_transition_model, sqrt_observation_model,
        extended, initial_points_sqrt,
        True,  # seventh positional flag; presumably selects the parallel implementation -- confirm
        criterion=below_iteration_cap,
        return_loglikelihood=True,
    )
    return smoothed, loglik

In [5]:
# Compile both IEKS wrappers for the GPU backend (this cell needs a GPU-enabled JAX install).
# No argument is marked static, so `iteration` is traced like the arrays;
# a new horizon T changes the input shapes and therefore triggers a fresh compilation.
gpu_IEKS_std_par = jit(IEKS_std_par, backend="gpu")
gpu_IEKS_sqrt_par = jit(IEKS_sqrt_par, backend="gpu")

In [2]:
# Covariance-form IEKS over all horizons, using func's defaults (runtime=15 repetitions, n_iter=20).
gpu_IEKS_std_par_ell = func(gpu_IEKS_std_par, Ts, sqrt=False)

In [3]:
# Square-root IEKS over all horizons, using func's defaults (runtime=15 repetitions, n_iter=20).
# func calls get_data again for every repetition, so this run does not reuse the
# data sets generated for the covariance-form run.
gpu_IEKS_sqrt_par_ell = func(gpu_IEKS_sqrt_par, Ts, sqrt=True)

In [8]:
# Persist both IEKS result sets in one .npz archive; the keyword names become the archive keys.
# The file name records the settings used: float64 (jax_enable_x64=True) and runtime=15.
jnp.savez("ell_float64_extended_runtime15",
          gpu_IEKS_std_par_ell = gpu_IEKS_std_par_ell,
          gpu_IEKS_sqrt_par_ell = gpu_IEKS_sqrt_par_ell)

# Cubature linearization (ICKS): standard and square-root iterated smoothers

In [3]:
def ICKS_std_par(observations, initial_points, iteration):
    """Run ``iterated_smoothing`` with the ``cubature`` linearization, covariance form.

    Thin wrapper that plugs in the module-level ``MVNStandard`` prior ``init``
    and the covariance-form ``transition_model`` / ``observation_model``.

    Parameters
    ----------
    observations
        Forwarded unchanged as the first argument of ``iterated_smoothing``.
    initial_points
        Forwarded as the sixth positional argument (``func`` passes an
        ``MVNStandard`` trajectory here).
    iteration
        Bound used by the stopping criterion, which evaluates ``i < iteration``.

    Returns
    -------
    tuple
        The two values returned by ``iterated_smoothing`` when called with
        ``return_loglikelihood=True``, in the same order.
    """
    def below_iteration_cap(i, *_):
        # Only the first argument is inspected; anything else is ignored.
        return i < iteration

    smoothed, loglik = iterated_smoothing(
        observations, init, transition_model, observation_model,
        cubature, initial_points,
        True,  # seventh positional flag; presumably selects the parallel implementation -- confirm
        criterion=below_iteration_cap,
        return_loglikelihood=True,
    )
    return smoothed, loglik


def ICKS_sqrt_par(observations, initial_points_sqrt, iteration):
    """Run ``iterated_smoothing`` with the ``cubature`` linearization, square-root form.

    Thin wrapper that plugs in the module-level ``MVNSqrt`` prior ``chol_init``
    and the square-root ``sqrt_transition_model`` / ``sqrt_observation_model``.

    Parameters
    ----------
    observations
        Forwarded unchanged as the first argument of ``iterated_smoothing``.
    initial_points_sqrt
        Forwarded as the sixth positional argument (``func`` passes an
        ``MVNSqrt`` trajectory here).
    iteration
        Bound used by the stopping criterion, which evaluates ``i < iteration``.

    Returns
    -------
    tuple
        The two values returned by ``iterated_smoothing`` when called with
        ``return_loglikelihood=True``, in the same order.
    """
    def below_iteration_cap(i, *_):
        # Only the first argument is inspected; anything else is ignored.
        return i < iteration

    smoothed, loglik = iterated_smoothing(
        observations, chol_init, sqrt_transition_model, sqrt_observation_model,
        cubature, initial_points_sqrt,
        True,  # seventh positional flag; presumably selects the parallel implementation -- confirm
        criterion=below_iteration_cap,
        return_loglikelihood=True,
    )
    return smoothed, loglik

In [5]:
# Compile both ICKS wrappers for the GPU backend (this cell needs a GPU-enabled JAX install).
# No argument is marked static, so `iteration` is traced like the arrays;
# a new horizon T changes the input shapes and therefore triggers a fresh compilation.
gpu_ICKS_std_par = jit(ICKS_std_par, backend="gpu")
gpu_ICKS_sqrt_par = jit(ICKS_sqrt_par, backend="gpu")

In [2]:
# Covariance-form ICKS over all horizons, using func's defaults (runtime=15 repetitions, n_iter=20).
gpu_ICKS_std_par_ell = func(gpu_ICKS_std_par, Ts, sqrt=False)

In [3]:
# Square-root ICKS over all horizons, using func's defaults (runtime=15 repetitions, n_iter=20).
# func calls get_data again for every repetition, so this run does not reuse the
# data sets generated for the covariance-form run.
gpu_ICKS_sqrt_par_ell = func(gpu_ICKS_sqrt_par, Ts, sqrt=True)

In [8]:
# Persist both ICKS result sets in one .npz archive; the keyword names become the archive keys.
# The file name records the settings used: float64 (jax_enable_x64=True) and runtime=15.
jnp.savez("ell_float64_cubature_runtime15",
          gpu_ICKS_std_par_ell = gpu_ICKS_std_par_ell,
          gpu_ICKS_sqrt_par_ell = gpu_ICKS_sqrt_par_ell)

In [10]:
# NOTE(review): the file name says "runtime1", but the func calls in this notebook
# use the default runtime=15. Presumably a leftover from an earlier session that ran
# with runtime=1. Run top to bottom, this cell writes the runtime=15 results
# under a misleading name -- confirm, then remove the cell or regenerate the data.
jnp.savez("ell_float64_cubature_runtime1",
          gpu_ICKS_std_par_ell = gpu_ICKS_std_par_ell,
          gpu_ICKS_sqrt_par_ell = gpu_ICKS_sqrt_par_ell)

In [8]:
# NOTE(review): the file name says "float32", but this notebook sets
# jax_enable_x64=True in its first cell. Presumably a leftover from a session
# that ran with x64 disabled. Run top to bottom, this cell writes the float64
# results under a misleading name -- confirm, then remove the cell or rerun
# with x64 turned off.
jnp.savez("ell_float32_cubature_runtime15",
          gpu_ICKS_std_par_ell = gpu_ICKS_std_par_ell,
          gpu_ICKS_sqrt_par_ell = gpu_ICKS_sqrt_par_ell)
