In [2]:
import sys
sys.path.append('..//')

import jax
import numpy as np
import jax.numpy as jnp
import scipy
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import pandas as pd

jax.config.update("jax_enable_x64", True)

from parsmooth._base import MVNStandard, FunctionalModel, MVNSqrt
from parsmooth.linearization import extended, cubature, unscented, gauss_hermite
from parsmooth.methods import iterated_smoothing
from bearing_data_pe import get_data, make_parameters, inverse_bearings


In [3]:
# --- Simulation settings ------------------------------------------------------
s1 = jnp.array([-1., 0.5])  # First sensor location
s2 = jnp.array([1., 1.])  # Second sensor location
x0 = jnp.array([0.1, 0.2, 1, 0])  # initial true location
dt = 0.01  # discretization time step
# Noise parameters handed to make_parameters.
# NOTE(review): both were originally commented as "discretization noise";
# confirm what distinguishes qc from qw.
qc = 0.1
qw = 0.1

r_true = 0.05  # noise level used to simulate the data (passed to get_data / make_parameters)
T = 50  # the nominal trajectory below has T + 1 rows
_, true_states, ys = get_data(x0, dt, r_true, T, s1, s2)
Q, _, observation_function, transition_function = make_parameters(qc, qw, r_true, dt, s1, s2)

chol_Q = jnp.linalg.cholesky(Q)

# --- Prior on the initial state -----------------------------------------------
m0 = jnp.array([2., 0, 0, 0, 0])
P0 = jnp.diag(jnp.array([0.5**2, 0.5**2, 0.5**2, 0.5**2, 1.]))
# Same factorisation routine as chol_Q. scipy.linalg.cholesky returns the UPPER
# factor by default and relied on `import scipy` exposing the linalg submodule;
# for this diagonal P0 both routines give the same matrix.
chol_P0 = jnp.linalg.cholesky(P0)

init = MVNStandard(m0, P0)
chol_init = MVNSqrt(m0, chol_P0)

# --- Initial nominal trajectory -----------------------------------------------
# Positions recovered from the bearings fill the first two state components
# (with a zero row prepended); the remaining three components start at zero.
positions = inverse_bearings(ys, s1, s2)
states = jnp.concatenate([jnp.concatenate([jnp.zeros((1,2)), positions], axis = 0),
                          jnp.zeros((T+1,3))], axis = 1)

# One 5x5 identity per time step; used both as covariance and as its square root.
unit_blocks = jnp.repeat(jnp.eye(5).reshape(1, 5, 5), T + 1, axis=0)
initial_states = MVNStandard(states, unit_blocks)
initial_states_sqrt = MVNSqrt(states, unit_blocks)

# --- Transition models (square-root and standard parameterisations) ------------
sqrt_transition_model = FunctionalModel(transition_function, MVNSqrt(jnp.zeros((5,)), chol_Q))
transition_model = FunctionalModel(transition_function, MVNStandard(jnp.zeros((5,)), Q))


In [5]:
#standard
def _neg_ell_std(prec_r, linearization_method, fixed_std=0.1, n_iter=50):
    """Negative log-likelihood from the standard (covariance-form) iterated smoother.

    Parameters
    ----------
    prec_r : array
        Only ``prec_r[0]`` is read; the first observation-noise standard
        deviation is ``r = 1 / prec_r[0]``.
    linearization_method :
        Linearization passed to ``iterated_smoothing`` (extended, cubature, ...).
    fixed_std : float
        Standard deviation of the second observation component, held constant.
    n_iter : int
        The smoother's criterion is ``i < n_iter``.

    Returns
    -------
    The negated log-likelihood returned by ``iterated_smoothing``.
    """
    r = 1 / prec_r[0]
    R = jnp.diag(jnp.array([r ** 2, fixed_std ** 2]))
    observation_model = FunctionalModel(observation_function, MVNStandard(jnp.zeros((2,)), R))

    # The positional `True` is kept exactly as in the original call
    # (presumably the `parallel` flag -- confirm against parsmooth).
    _, ell = iterated_smoothing(ys, init, transition_model, observation_model,
                                linearization_method, initial_states, True,
                                criterion=lambda i, *_: i < n_iter,
                                return_loglikelihood=True)
    return -ell


# All four objectives are jitted the same way (previously only the extended one was).
@jax.jit
def get_ell_std_extended(prec_r):
    """Negative log-likelihood using the extended linearization."""
    return _neg_ell_std(prec_r, extended)


@jax.jit
def get_ell_std_cubature(prec_r):
    """Negative log-likelihood using the cubature linearization."""
    return _neg_ell_std(prec_r, cubature)


@jax.jit
def get_ell_std_unscented(prec_r):
    """Negative log-likelihood using the unscented linearization."""
    return _neg_ell_std(prec_r, unscented)


@jax.jit
def get_ell_std_gh(prec_r):
    """Negative log-likelihood using the Gauss-Hermite linearization."""
    return _neg_ell_std(prec_r, gauss_hermite)


# Each of these returns (value, gradient w.r.t. prec_r) in one call.
grad_ell_std_extended = jax.jit(jax.value_and_grad(get_ell_std_extended))
grad_ell_std_cubature = jax.jit(jax.value_and_grad(get_ell_std_cubature))
grad_ell_std_unscented = jax.jit(jax.value_and_grad(get_ell_std_unscented))
grad_ell_std_gh = jax.jit(jax.value_and_grad(get_ell_std_gh))

In [13]:
# Prefer the GPU, but keep the notebook runnable on machines without one:
# jax.devices("gpu") raises RuntimeError when no GPU backend is available.
try:
    jax.devices("gpu")
    _std_backend = "gpu"
except RuntimeError:
    _std_backend = None  # backend=None lets JAX use its default backend
    print("No GPU backend found; gpu_grad_ell_std_* will run on JAX's default backend.")

gpu_grad_ell_std_extended = jax.jit(grad_ell_std_extended, backend=_std_backend)
gpu_grad_ell_std_cubature = jax.jit(grad_ell_std_cubature, backend=_std_backend)
gpu_grad_ell_std_unscented = jax.jit(grad_ell_std_unscented, backend=_std_backend)
gpu_grad_ell_std_gh = jax.jit(grad_ell_std_gh, backend=_std_backend)


NameError: name 'grad_ell_std_extended' is not defined

In [7]:
def _as_scipy_pair(value_and_grad_fn, r):
    """Evaluate ``value_and_grad_fn(r)`` and return both results as float64 NumPy arrays."""
    value, gradient = value_and_grad_fn(r)
    return np.array(value, dtype=np.float64), np.array(gradient, dtype=np.float64)


def wrap_func_std_extended(r):
    """(loss, gradient) of the extended objective as NumPy float64, for ``minimize(jac=True)``."""
    return _as_scipy_pair(gpu_grad_ell_std_extended, r)


def wrap_func_std_cubature(r):
    """(loss, gradient) of the cubature objective as NumPy float64, for ``minimize(jac=True)``."""
    return _as_scipy_pair(gpu_grad_ell_std_cubature, r)


def wrap_func_std_unscented(r):
    """(loss, gradient) of the unscented objective as NumPy float64, for ``minimize(jac=True)``."""
    return _as_scipy_pair(gpu_grad_ell_std_unscented, r)


def wrap_func_std_gh(r):
    """(loss, gradient) of the Gauss-Hermite objective as NumPy float64, for ``minimize(jac=True)``."""
    return _as_scipy_pair(gpu_grad_ell_std_gh, r)


      fun: -112.99485133901851
 hess_inv: <1x1 LbfgsInvHessProduct with dtype=float64>
      jac: array([-0.03703517])
  message: 'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nfev: 5
      nit: 4
     njev: 5
   status: 0
  success: True
        x: array([20.])

In [None]:
r0 = np.array([10.])  # starting point for the optimised precision (the objectives use r = 1 / prec_r[0])

# The objectives invert the optimised variable, so the lower bound must be
# strictly positive: with a bound of exactly 0 L-BFGS-B may evaluate 1 / 0.
prec_bounds = [(1e-6, 100)]


def _estimate_r(objective):
    """Minimise ``objective`` over the precision with L-BFGS-B and return ``1 / argmin``."""
    result = minimize(objective, r0, jac=True, method="L-BFGS-B", bounds=prec_bounds)
    if not result.success:
        # Surface optimiser failures instead of silently using the last iterate.
        print(f"L-BFGS-B did not converge for {objective.__name__}: {result.message}")
    return 1 / result.x


r_extended = _estimate_r(wrap_func_std_extended)
r_cubature = _estimate_r(wrap_func_std_cubature)
r_unscented = _estimate_r(wrap_func_std_unscented)
r_gh = _estimate_r(wrap_func_std_gh)

In [63]:
# cubature
def wrap_func_std(r):
    """(loss, gradient) of the cubature objective as NumPy float64, for ``minimize(jac=True)``."""
    # Was `gpu_grad_ell_std`, which no cell defines (NameError on a fresh kernel);
    # the cubature value-and-grad callable is `gpu_grad_ell_std_cubature`.
    loss, grad_val = gpu_grad_ell_std_cubature(r)
    return np.array(loss, dtype=np.float64), np.array(grad_val, dtype=np.float64)
r0 = np.array([10.])
# NOTE(review): the lower bound 0 is a singular point of the objective (r = 1 / prec_r[0]).
minimize(wrap_func_std, r0, jac=True, method="L-BFGS-B", bounds=[(0,20)])

      fun: -98.3942657017294
 hess_inv: <1x1 LbfgsInvHessProduct with dtype=float64>
      jac: array([-0.00301808])
  message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
     nfev: 17
      nit: 7
     njev: 17
   status: 0
  success: True
        x: array([19.26860337])

In [8]:
#square-root
@jax.jit
def get_ell_sqrt(prec_r):
    """Negative log-likelihood from the square-root iterated smoother.

    Parameters
    ----------
    prec_r : array
        Only ``prec_r[0]`` is read; ``r = 1 / prec_r[0]`` is the first diagonal
        entry of the observation-noise Cholesky factor (the second is fixed at 0.1).

    Returns
    -------
    The negated log-likelihood returned by ``iterated_smoothing``.
    """
    r = 1 / prec_r[0]
    chol_R = jnp.diag(jnp.array([r, 0.1]))
    observation_model_sqrt = FunctionalModel(observation_function, MVNSqrt(jnp.zeros((2,)), chol_R))

    # The original passed a global `linearization_method` that no cell defines
    # (NameError at trace time on a fresh kernel). `extended` is used because the
    # recorded optimum of this cell matches the extended standard run and the
    # saved sweeps are named "sqrt_extended".
    _, ell = iterated_smoothing(ys, chol_init, sqrt_transition_model, observation_model_sqrt,
                                extended, initial_states_sqrt, True,
                                criterion=lambda i, *_: i < 50,
                                return_loglikelihood=True)
    return -ell

# Returns (value, gradient w.r.t. prec_r) in one call.
grad_ell_sqrt = jax.jit(jax.value_and_grad(get_ell_sqrt))


In [9]:
# Prefer the GPU, but keep the notebook runnable on machines without one:
# jax.devices("gpu") raises RuntimeError when no GPU backend is available.
try:
    jax.devices("gpu")
    _sqrt_backend = "gpu"
except RuntimeError:
    _sqrt_backend = None  # backend=None lets JAX use its default backend
    print("No GPU backend found; gpu_grad_ell_sqrt will run on JAX's default backend.")

gpu_grad_ell_sqrt = jax.jit(grad_ell_sqrt, backend=_sqrt_backend)

In [12]:
def wrap_func_sqrt(r):
    """(loss, gradient) of the square-root objective as NumPy float64, for ``minimize(jac=True)``."""
    value, gradient = gpu_grad_ell_sqrt(r)
    value_np = np.array(value, dtype=np.float64)
    gradient_np = np.array(gradient, dtype=np.float64)
    return value_np, gradient_np


# Start from a precision of 15 and search within [0, 20]; the result object
# is the cell's displayed output.
r0 = np.array([15.])
minimize(
    wrap_func_sqrt,
    r0,
    method="L-BFGS-B",
    jac=True,
    bounds=[(0,20)],
)


      fun: -112.99485133901852
 hess_inv: <1x1 LbfgsInvHessProduct with dtype=float64>
      jac: array([-0.03703516])
  message: 'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nfev: 4
      nit: 3
     njev: 4
   status: 0
  success: True
        x: array([20.])

In [9]:
# Pre-computed sweep for the standard smoother: values, gradients and a theta array.
with np.load("pe_ell_std_extended_T50_par.npz") as std_archive:
    ell_std_extended_T50_par, grad_ell_std_extended_T50_par, theta_par_std = (
        std_archive[key]
        for key in ("ell_std_extended_T50_par",
                    "grad_ell_std_extended_T50_par",
                    "theta_par_std")
    )

# Same three arrays for the square-root smoother.
with np.load("pe_ell_sqrt_extended_T50_par.npz") as sqrt_archive:
    ell_sqrt_extended_T50_par, grad_ell_sqrt_extended_T50_par, theta_par_sqrt = (
        sqrt_archive[key]
        for key in ("ell_sqrt_extended_T50_par",
                    "grad_ell_sqrt_extended_T50_par",
                    "theta_par_sqrt")
    )



FileNotFoundError: [Errno 2] No such file or directory: 'pe_ell_std_extended_T50_par.npz'

In [None]:
# x grid shared by both curves: reciprocals of 50 evenly spaced values in [10, 50],
# flipped so it is increasing (presumably the precision grid the saved sweeps
# were computed on -- confirm against the script that wrote the .npz files).
r_grid = np.flip(1 / np.linspace(10, 50))

plt.figure(figsize=(10,8))
plt.plot(r_grid, np.flip(ell_std_extended_T50_par), label="standard")
plt.plot(r_grid, np.flip(ell_sqrt_extended_T50_par), label="square-root")
# Labels and legend so the two curves can be told apart when skimming the notebook.
plt.xlabel("r (= 1 / prec_r)")
plt.ylabel("ell")
plt.title("ell, extended linearization, T = 50")
plt.legend()
plt.grid()

print(theta_par_std)
print(theta_par_sqrt)



In [None]:
# x grid shared by both curves: reciprocals of 50 evenly spaced values in [10, 50],
# flipped so it is increasing (presumably the precision grid the saved sweeps
# were computed on -- confirm against the script that wrote the .npz files).
r_grid = np.flip(1 / np.linspace(10, 50))

plt.figure(figsize=(10,8))
plt.plot(r_grid, np.flip(grad_ell_std_extended_T50_par), label="standard")
plt.plot(r_grid, np.flip(grad_ell_sqrt_extended_T50_par), label="square-root")
# Zero line: where a gradient curve crosses it, the objective is stationary.
plt.axhline(y = 0, color = 'r', linestyle = '--')
plt.xlabel("r (= 1 / prec_r)")
plt.ylabel("gradient of ell")
plt.title("Gradient of ell, extended linearization, T = 50")
plt.legend()
plt.grid()

In [None]:
import os  # local import: only this export cell needs it

# Select column 0 first, then flip, for BOTH gradient arrays. The original
# flipped the square-root array over all axes before indexing, which only
# agrees with the standard column when the array has a single column.
data = np.stack([np.flip(1/np.linspace(10, 50)),
                 np.flip(ell_std_extended_T50_par),
                 np.flip(ell_sqrt_extended_T50_par),
                 np.flip(grad_ell_std_extended_T50_par[:,0]),
                 np.flip(grad_ell_sqrt_extended_T50_par[:,0])],
                 axis=1)

# NOTE(review): the first column holds the x grid (1 / linspace), not a
# likelihood, yet it is named "ell". The name is kept so existing consumers
# of the CSV keep working -- confirm before renaming.
columns = ["ell",
           "ell_std_extended_par",
           "ell_sqrt_extended_par",
           "grad_ell_std_extended_par",
           "grad_ell_sqrt_extended_par"]

df = pd.DataFrame(data=data, columns=columns)
# to_csv does not create missing directories.
os.makedirs("outputs", exist_ok=True)
df.to_csv("outputs/pe_extended_ell_par.csv")