# Simulated tracking example

In [None]:
import numpy as np
import os
import jax
import jax.numpy as jnp
import pandas as pd
import yaml
import json
import matplotlib.pyplot as plt
import seaborn as sns
# Notebook-wide plotting defaults: seaborn white-grid style and the "deep" palette.
sns.set_style('whitegrid')
sns.set_palette('deep')
# IPython magic selecting matplotlib's interactive "widget" backend; only valid inside Jupyter/IPython.
%matplotlib widget
datadir = '../data/tracking/'
# Each numbered sub-directory of `datadir` is one run (it holds `.hydra/config.yaml` and
# `result.npz`, read in the next cell). Only purely numeric directory names are kept:
# anything else (e.g. `.ipynb_checkpoints`) is not a run and would crash the integer sort.
# Runs are ordered by their integer id, so run 10 comes after run 9 rather than after run 1.
dirs = [
    os.path.join(datadir, name)
    for name in sorted(
        (entry for entry in next(os.walk(datadir))[1] if entry.isdecimal()),
        key=lambda entry: int(os.path.basename(entry)),
    )
]

### Load simulation results from the data directory

In [None]:
from ssmjax.types import MVNormal
statistics = {}
# results[(q1, s2, T)][alg] -> dict holding the filter output of one run under
# 'state_trajectory' (an MVNormal built from the stored state mean/covariance) plus
# every other array stored in that run's result.npz.
results = {}
for path in dirs:
    # Per-run configuration file written alongside the results.
    conf = os.path.join(path, '.hydra', 'config.yaml')
    with open(conf, "r") as file:
        d = yaml.safe_load(file)
    # Read every array eagerly inside `with` so the .npz archive is closed again;
    # dict(np.load(...)) on its own leaves the underlying file handle open.
    with np.load(os.path.join(path, 'result.npz')) as npz:
        data = dict(npz)
    # Algorithm name = last dotted component of the configured `_target_`.
    alg = d['alg']['_target_'].split('.')[-1]
    q1 = d['sim']['q1']
    T = d['sim']['T']
    s2 = d['sim']['s2']
    # pop() so the moments are stored only once, inside 'state_trajectory'.
    states = MVNormal(data.pop('state_mean'), cov=data.pop('state_cov'))
    run = dict(state_trajectory=states)
    run.update(data)
    results.setdefault((q1, s2, T), {})[alg] = run

## Trajectory plots
Plots five sample trajectories, one drawn at random from each consecutive block of 20 simulations. Both the ground-truth positions (colored lines) and the measurements (black dots) are shown.

In [None]:
# Reference states `x` (compared against the filter mean in config_mse) and measurements `y`
# of the EKF run at configuration (q1, s2, T) = (0.001, 10, 1).
x = results[(0.001, 10, 1)]['ekf']['x']
y = results[(0.001, 10, 1)]['ekf']['y']

# Pick one simulation index uniformly from each block of 20: [0, 20), [20, 40), ..., [80, 100).
rkey = jax.random.PRNGKey(19)
inds = jax.random.randint(rkey, shape=(5,), minval=np.arange(0, 100, 20), maxval=np.arange(20, 101, 20))

sns.set_palette('deep')
with sns.axes_style("whitegrid"):
    with sns.plotting_context("paper"):
        # NOTE(review): these rc changes are made inside seaborn's plotting_context, which
        # restores its own keys (tick/axes label sizes included) when the `with` exits.
        plt.rc('ytick', labelsize=16)
        plt.rc('xtick', labelsize=16)
        plt.rc('axes', labelsize=16)
        plt.figure(figsize=(7, 6))
        # True positions: state columns 0 and 2 are p_x and p_y (see the column names in config_mse).
        plt.plot(x[inds, :, 0].T, x[inds, :, 2].T, marker='.', lw=3, markersize=10)
        # plt.plot(x[:20:4, :, 0].T, x[:20:4, :, 2].T, '.', lw=3, markersize=10)
        plt.xlabel(r'$p^x~$[m]')
        plt.xticks([-500, 0, 500, 1000])
        plt.yticks([-300, -100, 100, 300, 500])
        # plt.xticks([-800, -400, 0, 400])
        # plt.yticks([0, 400, 800, 1200])
        plt.ylabel(r'$p^y~$[m]')
        # Measurements (first two columns of y) drawn as small black dots on the same axes.
        plt.plot(y[inds, :, 0].T, y[inds, :, 1].T, 'k.', markersize=5)
plt.tight_layout()
# NOTE(review): bbox_inches=0 is falsy, so no tight-bbox cropping is applied; use 'tight' if cropping was intended.
plt.savefig('trackingtrajectories.eps', bbox_inches=0)
plt.show()

In [None]:
from itertools import product
def config_mse(config_results):
    """Compute the per-simulation, time-averaged squared error of every algorithm in one configuration.

    Parameters
    ----------
    config_results : dict
        Maps algorithm name -> result dict. ``res['x']`` holds the reference states and
        ``res['state_trajectory'].mean`` the filter means. Both are indexed as
        (simulation, time, state component).

    Returns
    -------
    pandas.DataFrame
        One row per (Algorithm, Simulation) and one column per state component
        (p_x, v_x, p_y, v_y, delta). Empty if ``config_results`` is empty.
    """
    cols = [r'$p_x$', r'$v_x$', r'$p_y$', r'$v_y$', r'$\delta$']
    frames = []
    for alg, res in config_results.items():
        se = (res['x'] - res['state_trajectory'].mean) ** 2
        mse = np.mean(se, axis=1)  # average over the time axis
        index = pd.MultiIndex.from_tuples(
            product([alg], np.arange(mse.shape[0])), names=['Algorithm', 'Simulation']
        )
        frames.append(pd.DataFrame(mse, columns=cols, index=index))
    # One concat at the end instead of re-copying the growing frame on every iteration.
    return pd.concat(frames, axis=0) if frames else pd.DataFrame()

def calc_mse(results):
    """Stack the MSE tables of all simulation configurations into one long-format frame.

    Parameters
    ----------
    results : dict
        Maps a configuration key ``(q1, s2, T)`` to that configuration's
        ``{algorithm: result}`` dict, as accepted by ``config_mse``.

    Returns
    -------
    pandas.DataFrame
        The ``config_mse`` tables concatenated row-wise. The configuration is prepended
        to the row index as levels ``Q``, ``R`` and ``T``, giving levels
        (Q, R, T, Algorithm, Simulation). Empty if ``results`` is empty.
    """
    frames = []
    for config, config_result in results.items():
        stats = config_mse(config_result)
        # Prepend the configuration as three new outermost index levels (Q, R, T).
        idx = stats.index.to_frame()
        idx.insert(0, 'T', config[2])
        idx.insert(0, 'R', config[1])
        idx.insert(0, 'Q', config[0])
        stats.index = pd.MultiIndex.from_frame(idx)
        frames.append(stats)
    # One concat at the end instead of re-copying the growing frame on every iteration.
    return pd.concat(frames, axis=0) if frames else pd.DataFrame()
# Long-format table of time-averaged squared errors, indexed by (Q, R, T, Algorithm, Simulation).
df = calc_mse(results)

In [None]:
# Average the per-simulation errors over all simulations of each (Q, R, T, Algorithm).
means = df.groupby(level=['Q', 'R', 'T', 'Algorithm']).mean()
def relative_performance(mean_mses, base_alg, iter_alg, name):
    """Select one baseline/iterated filter pair and relabel it for side-by-side comparison.

    Parameters
    ----------
    mean_mses : pandas.DataFrame
        Mean errors indexed by (Q, R, T, Algorithm).
    base_alg, iter_alg : str
        Algorithm labels of the baseline and of the iterated filter.
    name : str
        Value stored in the new ``Transform`` index level.

    Returns
    -------
    pandas.DataFrame
        A copy of the rows of the two algorithms, with ``base_alg`` relabelled
        ``'Baseline'`` and ``iter_alg`` relabelled ``'Iterated'``. A ``Transform`` level
        holding ``name`` is inserted just before ``Algorithm``. The Iterated/Baseline
        ratio itself is formed later, in ``plot_relperf``.
    """
    pair = mean_mses.loc[pd.IndexSlice[:, :, :, [base_alg, iter_alg]], :].copy()
    pair = pair.rename(index={base_alg: 'Baseline', iter_alg: 'Iterated'})
    levels = pair.index.to_frame()
    levels.insert(3, 'Transform', name)
    pair.index = pd.MultiIndex.from_frame(levels)
    return pair

# One (baseline, iterated, label) triple per transform. The posterior-linearization
# filter ('diplf') is measured against the UKF baseline.
me, mc, mu, mplf = (
    relative_performance(means, base_alg, iter_alg, label)
    for base_alg, iter_alg, label in (
        ('ekf', 'diekf', 'Extended'),
        ('ckf', 'dickf', 'Cubature'),
        ('ukf', 'diukf', 'Unscented'),
        ('ukf', 'diplf', 'Posterior Linearization'),
    )
)
# All transforms in one table, sorted by its (Q, R, T, Transform, Algorithm) index.
relperf = pd.concat([me, mc, mu, mplf], axis=0).sort_index()

In [None]:
from matplotlib import cm
from matplotlib.colors import ListedColormap
import matplotlib as mpl

def _tex_power_label(value):
    r"""Format *value* as a LaTeX tick label in powers of ten, e.g. 0.001 -> '$10^{-3}$'.

    A value whose mantissa is not 1 is written as '$m\cdot 10^{e}$'. It is not
    mangled into a wrong power of ten.
    """
    text = np.format_float_scientific(value, exp_digits=1)
    if 'e' not in text:  # nan / inf have no exponent part
        return '$' + text + '$'
    mantissa, exponent = text.split('e')
    mantissa = mantissa.rstrip('.')
    power = '10^{' + str(int(exponent)) + '}'
    if mantissa == '1':
        return '$' + power + '$'
    return '$' + mantissa + r'\cdot ' + power + '$'


def plot_relperf(relperf, transform, ax, index):
    """Draw Iterated-vs-Baseline RMSE heat maps for one transform into column ``index``.

    Parameters
    ----------
    relperf : pandas.DataFrame
        Mean squared errors for a single T, indexed by (Q, R, Transform, Algorithm),
        with Algorithm in {'Baseline', 'Iterated'}. Columns are (p_x, v_x, p_y, v_y, delta).
        Every (Q, R) pair must be present, because the values are reshaped onto a
        Q x R grid.
    transform : str
        Value of the Transform level to plot; also used as the axes title.
    ax : sequence of axes arrays
        ``ax[0]`` receives the position panel (p_x, p_y) and ``ax[1]`` the velocity
        panel (v_x, v_y).
    index : int
        Column of each ``ax[k]`` to draw into.

    Returns
    -------
    list
        One ``AxesImage`` per entry of ``ax`` (e.g. for colorbars).

    Each cell is colored by Iterated RMSE / Baseline RMSE on a [0.6, 1] color scale
    and annotated with the fraction Iterated/Baseline. A filter whose rounded position
    RMSE exceeds sqrt(R) counts as diverged and is shown as '-'. Cells where the
    iterated filter diverged are set to 1, the top of the color scale.
    """
    index_tuples = relperf.index.to_numpy()
    Qvals = np.unique([key[0] for key in index_tuples])
    Rvals = np.unique([key[1] for key in index_tuples])
    nq, nr = len(Qvals), len(Rvals)

    # Rows are flipped so that the largest Q is drawn in the top row.
    Qvals_txt = np.flip([_tex_power_label(q) for q in Qvals])
    Rvals_txt = [_tex_power_label(r) for r in Rvals]

    cmap = ListedColormap(sns.cubehelix_palette(start=1.9, rot=0., gamma=1.0, hue=0.7, light=1.0, dark=0.5, reverse=True, n_colors=8).as_hex())

    # State columns averaged per panel: (p_x, p_y) for position, (v_x, v_y) for velocity.
    component_cols = [[0, 2], [1, 3]]
    ims = []
    for panel, axes_row in enumerate(ax):
        axis = axes_row[index]
        dat = relperf.loc[:, :, transform].iloc[:, component_cols[panel]].mean(axis=1)

        basedat = np.flip(np.sqrt(dat.loc[:, :, 'Baseline'].to_numpy()).reshape(nq, nr), axis=0)
        iterdat = np.flip(np.sqrt(dat.loc[:, :, 'Iterated'].to_numpy()).reshape(nq, nr), axis=0)
        if panel == 0:
            # Divergence is decided once, from the position RMSE, and reused for the
            # velocity panel. `thresh` broadcasts along the R columns.
            thresh = np.sqrt(Rvals)
            iter_valid = np.round(iterdat, 0) <= thresh
            base_valid = np.round(basedat, 0) <= thresh
        reldat = iterdat / basedat
        reldat[~iter_valid] = 1  # If the iterated filter has diverged, we want the image to be white.

        ims.append(axis.imshow(reldat, cmap=cmap, vmin=0.6, vmax=1))
        axis.set_xlim([-.5, nr - .5])
        axis.set_ylim([nq - .5, -.5])
        axis.set_xticks(np.arange(nr))
        axis.set_yticks(np.arange(nq))
        axis.set_xticklabels(Rvals_txt)
        axis.set_yticklabels(Qvals_txt)

        # Cell borders. Scalar extents span the whole image and work for any Q x R
        # grid shape, including non-square ones.
        axis.hlines(y=np.arange(nq - 1) + 0.5, xmin=-0.5, xmax=nr - 0.5, color="black", lw=1)
        axis.vlines(x=np.arange(nr - 1) + 0.5, ymin=-0.5, ymax=nq - 0.5, color="black", lw=1)

        # Annotate every cell with the fraction Iterated/Baseline RMSE ('-' = diverged).
        for i in range(nq):
            for j in range(nr):
                iter_txt = np.format_float_positional(iterdat[i, j], precision=3, trim='-') if iter_valid[i, j] else '-'
                base_txt = np.format_float_positional(basedat[i, j], precision=3, trim='-') if base_valid[i, j] else '-'
                label = r"$\frac{{{}}}{{{}}}$".format(iter_txt, base_txt)
                if iter_valid[i, j] and base_valid[i, j]:
                    axis.text(j, i, label, ha="center", va="center", color="k")
                else:
                    # Fractions containing '-' render slightly off-center; nudge them down a little.
                    axis.annotate(xy=(j, i), xytext=(j, i + 0.03), text=label,
                                  ha="center", va="center", color="k")
        axis.set_xlabel(r'$\sigma^2$')
        axis.set_title(transform)
    return ims

# Only the T = 1 configurations are tabulated below.
T = 1
plt.close('all')
# Global rc settings for the publication figures (text is rendered with LaTeX via usetex).
sns.set_context("paper")
plt.rc('xtick', labelsize=14)
plt.rc('ytick', labelsize=14)
plt.rc('axes', labelsize=18, titlesize=18)
plt.rc('figure', titlesize=20)
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble='')
plt.rc('font', size=20)
# PostScript output is post-processed with the xpdf distiller.
plt.rc('ps', usedistiller='xpdf')
# NOTE(review): the family is "sans-serif" but the listed face is "Computer Modern Serif" — confirm intent.
plt.rcParams.update({
    "font.family": "sans-serif",
    "font.sans-serif": "Computer Modern Serif",
})
# fig[0]/ax[0]: position RMSE tables; fig[1]/ax[1]: velocity RMSE tables.
# Each figure has one column per transform (Extended, Unscented, Posterior Linearization).
fig = []
ax = []
with sns.axes_style('white'):
    for i in range(2):
        fig.append(plt.figure(figsize=(12, 6), constrained_layout=True))
        ax.append(fig[i].subplots(1, 3))
    
    # Keep the images of the first column for the colorbars; every panel uses the same vmin/vmax.
    ims = plot_relperf(relperf.loc[:,:,T], 'Extended', ax, 0)
    # plot_relperf(relperf, 'Cubature', ax[1])
    plot_relperf(relperf.loc[:,:,T], 'Unscented', ax, 1)
    plot_relperf(relperf.loc[:,:,T], 'Posterior Linearization', ax, 2)
    # fig.suptitle('T = {}'.format(T))
    for i in range(2):
        ax[i][0].set_ylabel(r'$q_1$', rotation='horizontal', y=1, horizontalalignment='left')
        # The q_1 tick labels are shown only on the first column.
        ax[i][1].set_yticklabels([])
        ax[i][2].set_yticklabels([])
        # One shared horizontal colorbar below all three columns of the figure.
        cbar = fig[i].colorbar(ims[i],orientation='horizontal', ax=ax[i], location='bottom', shrink=.75, aspect=60)
        # Thicken the frame of every panel.
        for axi in ax[i]:
            plt.setp(axi.spines.values(), linewidth=3)
        # cbar.set_label(r'Relative \textsc{RMSE} Iterated/Baseline')
    # fig[0].suptitle(r'Position \textsc{RMSE} $\left[m^2\right]$', y=1)
    # fig[1].suptitle(r'Velocity \textsc{RMSE} $\left[\frac{{{m^2}}}{{{s^2}}}\right]$', y=1)

# plt.tight_layout()
plt.show()

In [None]:
# Export the position (fig[0]) and velocity (fig[1]) RMSE tables as EPS files.
# NOTE(review): bbox_inches=0 is falsy, so no tight-bbox cropping is applied; use 'tight' if cropping was intended.
fig[0].savefig('position-rmse.eps', bbox_inches=0)
fig[1].savefig('velocity-rmse.eps', bbox_inches=0)