## TDOA example

In [None]:
%load_ext autoreload
%autoreload 2
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import numpy as np
import matplotlib.pyplot as plt
from ssmjax.types import MVNormal
from ssmjax.types import StateSpaceModel
from ssmjax.types import options as options
from ssmjax import algs
import ssmjax.examples.tdoa as tdoa
from tqdm.notebook import tqdm
from ssmjax.utility.pytrees import tree_stack
import json

In [None]:
def _load_json_arrays(path):
    """Read the JSON object stored at *path* and turn every value into a NumPy array."""
    with open(path, 'r') as fh:
        raw = json.load(fh)
    return {name: np.array(values) for name, values in raw.items()}


# Recorded experiment data and the separate microphone-calibration recording.
data = _load_json_arrays('../data/tdoa_data.json')
calibration_data = _load_json_arrays('../data/tdoa_calibration.json')

### Construct a noise model for the microphones

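The model for the calibration data (pulse arrival times converted to distances by multiplying with the speed of sound $v$) is given by
$$ y_{i,k} = r + kTv + e_{i,k}, $$
where $y_{i,k}$ is the measurement of pulse $k$ at microphone $i$, $r$ is the distance between the source and the microphones (which is identical for all microphones), $T$ is the period with which the pulses are emitted, and $e_{i,k}$ is the measurement noise of microphone $i$.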
Hence, the joint model is given by
$$ y = Hx + e $$
where $x = [r,\ Tv]^T$ can be estimated using least squares; the fitted values $Hx^{LS}$ are then subtracted from $y$ to form an estimate of $e$, i.e.,
$$ e = y - Hx^{LS} $$
where
$$ x^{LS} = (H^TH)^{-1}H^Ty $$

In [None]:
# Fit y = r + k*(T*v) jointly for all microphones and use the per-microphone residual
# variance as that microphone's measurement-noise variance.
cal_times = calibration_data['pulse_times']  # rows: microphones, columns: pulses
n_mics, N = cal_times.shape
H = np.ones((N, 2))
H[:, -1] = np.arange(1, N + 1)
# Measurements are stacked time-major ([pulse 1: mic 1..M, pulse 2: mic 1..M, ...]),
# so each regressor row is repeated once per microphone.
H = np.kron(H, np.ones((n_mics, 1)))
y_cal = cal_times.T.flatten() * tdoa.v
x = np.linalg.lstsq(H, y_cal, rcond=None)[0]
# Undo the time-major stacking: (N, n_mics) -> transpose so that row m holds the
# residuals of microphone m. (A plain reshape(n_mics, -1) would mix microphones.)
e = (y_cal - H @ x).reshape(N, n_mics).T
s2 = np.var(e, axis=1)

### Plot out the scenario

In [None]:
%matplotlib widget
# Only the first four microphones are used; the last recorded pulse is bad and dropped.
mic_locations = data['mic_locations'][:, :4]
pulse_times = data['pulse_times'][:4, :-1]
N = pulse_times.shape[1]

# Microphones (green triangles), ground-truth path (dashed; rows 1 and 2 of ground_truth),
# and the initial state (red star).
fig, ax = plt.subplots()
ax.plot(mic_locations[0, :], mic_locations[1, :], '^', color='tab:green')
ax.plot(data['ground_truth'][1, :], data['ground_truth'][2, :], '--k', lw=.5)
ax.plot(data['initial_state'][0], data['initial_state'][1], '*', color='tab:red')
ax.set_xlim([-3, 3])
ax.set_ylim([-3.5, 3.5])
plt.show()

### Build state-space model

In [None]:
# Model used by the snapshot localization below: nominal sampling period tdoa.T,
# noise levels 1e-3 and 1e-4 (the positions taken by q1 and q2 in the sweep further
# down), and the calibrated noise variances of the four microphones in use.
ssm = tdoa.build_model(
    tdoa.T,
    1e-3,
    1e-4,
    s2[:4],
)

### Simple snapshot localization approach
Localizes the RC car from each set of pulse arrivals independently, i.e., without a dynamic model. This is primarily used to synchronize the audio with the ground-truth trajectory.

In [None]:
import warnings

import scipy.optimize  # `import scipy` alone does not guarantee the submodule is loaded


def loss(x, y):
    """Covariance-weighted squared misfit between measured range differences and the model.

    Args:
        x: Candidate source position passed to ``ssm.observation_function``.
        y: Measured range differences (distance units).

    Returns:
        ``diff^T R^{-1} diff`` with ``R = ssm.observation_covariance``, summed over axis 0
        (a scalar for a single 1-D observation vector).
    """
    diff = y - ssm.observation_function(x, mic_locations)
    return jnp.sum(jnp.linalg.solve(ssm.observation_covariance, diff.T) * diff.T, axis=0)


# Independent position fix for every pulse, each started from the origin.
opt_xhat = np.zeros((N, 2))
n_failed = 0
for k in range(N):
    # Differences of arrival times between consecutive microphones, converted to distance.
    yi = np.diff(pulse_times[:, k]) * tdoa.v
    sol = scipy.optimize.minimize(loss, x0=np.array([0, 0]), args=(yi,))
    opt_xhat[k] = sol.x
    if not sol.success:
        n_failed += 1
if n_failed:
    # The estimates are still stored, but report that some fixes may be unreliable.
    warnings.warn(f"{n_failed}/{N} snapshot optimizations did not report success")

### Sync the ground truth to the audio
To sync the ground truth to the audio, the ground truth trajectory (along $x$) is assumed to be measured as
$$x = f(t) + e$$
where $t$ is the time and $e\sim\mathcal{N}(0,\sigma^2)$.
$f$ is modelled as a Gaussian process with kernel $k$ (a Matérn-7/2 kernel is used below). The kernel parameters and noise variance are found by minimizing the negative log-likelihood of the ground-truth trajectory data.

The first estimate of the snapshot approach is then used as a "pseudo-measurement" to sync the trajectory by minimizing the negative log predictive density of the trained GP model (for the $x$ coordinate) with respect to the input time. This yields the most likely initial time of the audio trajectory. The subsequent times of the audio trajectory are assumed to be equidistantly sampled with a 0.5 second interval. The "true" ground truth is then taken as the predictive mean of the fitted GP model at these times. An independent GP is fit to the $y$ coordinate in the same manner. Further GPs are fit to finite-difference velocities and heading rate of the predicted trajectory, to provide ground truth for those states as well.

In [None]:
import bayesnewton
import objax
import scipy

def train_gp(X, Y, lr=0.1, iters=30):
    """Fit a GP regression model (Matern-7/2 kernel, Gaussian likelihood) to (X, Y).

    The hyperparameters are optimised with Adam on ``model.energy``; ``model.inference()``
    is re-run before every optimisation step.

    Args:
        X: Training inputs (time stamps in this notebook).
        Y: Training targets.
        lr: Adam learning rate for the hyperparameters (default 0.1).
        iters: Number of optimisation steps (default 30).

    Returns:
        The trained ``bayesnewton.models.NewtonGP`` model.
    """
    kern = bayesnewton.kernels.Matern72()
    lik = bayesnewton.likelihoods.Gaussian()
    model = bayesnewton.models.NewtonGP(kernel=kern, likelihood=lik, X=X, Y=Y)
    opt_hypers = objax.optimizer.Adam(model.vars())
    energy = objax.GradValues(model.energy, model.vars())

    @objax.Function.with_vars(model.vars() + opt_hypers.vars())
    def train():
        model.inference()
        dE, E = energy()
        opt_hypers(lr, dE)
        return E

    train = objax.Jit(train)

    for _ in tqdm(range(iters)):
        train()
    return model
# Independent GPs for the ground-truth x and y coordinates as functions of time (row 0).
gt_time = data['ground_truth'][0, :]
x_model = train_gp(gt_time, data['ground_truth'][1, :])
y_model = train_gp(gt_time, data['ground_truth'][2, :])

In [None]:
# Dense predictive-mean trajectory on a regular time grid.
t = np.linspace(0, data['ground_truth'][0, :].max(), 1000)
mu_x, _ = x_model.predict(X=t)
mu_y, _ = y_model.predict(X=t)
# Central differences (np.gradient) estimate the derivative at each t[k] itself. A forward
# difference estimates it at the midpoint between samples, which would be half a step off
# when paired with t[k]. Every second sample is kept (presumably to limit GP training cost).
vx = np.gradient(mu_x, t)[::2]
vy = np.gradient(mu_y, t)[::2]
vx_model = train_gp(t[::2], vx)
vy_model = train_gp(t[::2], vy)

In [None]:
# Heading of the predictive-mean trajectory, evaluated at the grid points t.
d = np.arctan2(np.gradient(mu_y, t), np.gradient(mu_x, t))
# arctan2 wraps at +-pi; unwrap so these jumps do not show up as spikes in the heading rate.
dfix = np.unwrap(d)
# Heading rate via central differences, aligned with t; every second sample is kept,
# matching the velocity GPs.
vd = np.gradient(dfix, t)[::2]
vd_model = train_gp(t[::2], vd)

### Build model and form ground truth
The sampling time of the model is most easily varied here; the ground truth is formed according to the chosen sampling time.

In [None]:
import scipy.optimize  # make sure the submodule is loaded before use below


def get_obs(T):
    """Return the TDOA observations (distance units) for a sampling interval of *T* seconds.

    Pulses are recorded every 0.5 s, so every ``T / 0.5``-th pulse is kept. Row k holds the
    arrival-time differences between consecutive microphones for time step k, times tdoa.v.

    Raises:
        ValueError: if *T* is not a positive multiple of 0.5 s. Otherwise the observations
            would not line up with the ground truth from ``get_gt``.
    """
    step = int(round(T / .5))
    if step < 1 or not np.isclose(step * .5, T):
        raise ValueError(f"T must be a positive multiple of 0.5 s, got {T}")
    return np.diff(pulse_times[:, ::step], axis=0).T * tdoa.v


# Align the audio clock with the ground-truth clock. Find the ground-truth time at which
# the x-coordinate GP best explains the first snapshot x-estimate (minimum negative log
# predictive density). The first pulse time of the first microphone is the initial guess.
opt = scipy.optimize.minimize(
    lambda t0: x_model.negative_log_predictive_density(X=t0, Y=opt_xhat[0, [0]]),
    x0=pulse_times[0, [0]],
)


def get_gt(T, N):
    """Ground truth at the first *N* audio sample times for sampling interval *T*.

    Sample k (k = 0..N-1) is taken at ``opt.x + k*T`` on the ground-truth time axis. The GP
    predictive means are evaluated there.

    Returns:
        Array with columns x, y, vx, vy and heading rate, one row per sample (assuming the
        GP ``predict`` calls return 1-D means — TODO confirm).
    """
    t_audio = opt.x + np.arange(N) * T
    gp_models = (x_model, y_model, vx_model, vy_model, vd_model)
    means = [model.predict(X=t_audio)[0] for model in gp_models]
    return np.vstack(means).T

### Run filters

Run the filters with parameter sweeps over different sampling intervals and process noise levels. In this way, the differences between the three types of filters are clearly illustrated.

In [None]:
import itertools
T = np.arange(.5, 4.1, step=.5)
q1 = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]
q2 = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]
configs = list(itertools.product(q1, q2))
lsopts = options.LineSearchIterationOptions(options.LineSearchOptions(gamma=0.25, beta=0.9))

prior = tdoa.build_prior('../data/tdoa_data.json', P0=np.diag([1e-1, 1e-1, 1, 1, 1e-2]))
theta = MVNormal(mic_locations.flatten(), None)

# Filters that are called with the line-search options.
linesearch_filters = {"lsiekf", "lsdiekf"}

results = {}
for Ti in tqdm(T, desc="Sampling interval: "):
    ssm = tdoa.build_model(Ti, 1e-2, 1e-2, s2[:4])
    # Filters are built once per sampling interval. The covariances of `ssm` are then
    # overwritten in place for every (q1, q2) configuration (the filters presumably read
    # them through `ssm` — NOTE(review): confirm in ssmjax).
    filters = {
        "ekf": algs.ekf(ssm),
        "iekf": algs.iekf(ssm),
        "lsiekf": algs.lsiekf(ssm),
        "diekf": algs.diekf(ssm),
        "lsdiekf": algs.lsdiekf(ssm, linesearch_method="backtracking"),
    }
    observations = get_obs(Ti)
    for q1_i, q2_i in tqdm(configs, desc="Config: "):
        tmp_ssm = tdoa.build_model(Ti, q1_i, q2_i, s2[:4])
        ssm.transition_covariance.value = tmp_ssm.transition_covariance.value
        ssm.observation_covariance.value = tmp_ssm.observation_covariance.value
        estimates = {}
        for name, run_filter in filters.items():
            extra = {"options": lsopts} if name in linesearch_filters else {}
            _, estimates[name], _ = run_filter(prior, theta, observations, **extra)
        results[(q1_i, q2_i, Ti)] = estimates

In [None]:
def se(gt, mu):
    """Element-wise squared error between ground truth *gt* and estimate *mu*."""
    return (gt - mu) ** 2


def pl(ax, x, **kwargs):
    """Plot column 1 of *x* against column 0 on *ax*, forwarding any plot keyword arguments."""
    return ax.plot(x[:, 0], x[:, 1], **kwargs)
def calc_rmse(results, gt_fn=None):
    """Position RMSE of every estimator in *results* against the ground truth.

    Args:
        results: Mapping ``config -> {estimator name: estimate}``. The last element of
            ``config`` is the sampling interval T. Each estimate exposes ``.mean`` with time
            steps along axis 0 and the x/y position in the first two columns.
        gt_fn: Callable ``(T, n_steps) -> ground-truth array``; defaults to ``get_gt``.

    Returns:
        Mapping ``config -> {estimator name: RMSE}``. The squared x/y errors are summed per
        time step, averaged over time steps, and square-rooted. Configurations with no
        estimators map to an empty dict.
    """
    if gt_fn is None:
        gt_fn = get_gt
    rmse = {}
    gt_cache = {}  # the ground truth only depends on T and the number of time steps
    for config, result in results.items():
        if not result:
            rmse[config] = {}
            continue
        T = config[-1]
        n_steps = next(iter(result.values())).mean.shape[0]
        key = (T, n_steps)
        if key not in gt_cache:
            gt_cache[key] = gt_fn(T, n_steps)
        gt = gt_cache[key]
        rmse[config] = {
            model: np.sqrt(np.mean(np.sum(((gt - est.mean) ** 2)[:, :2], axis=1)))
            for model, est in result.items()
        }
    return rmse
# Position RMSE for every (q1, q2, T) configuration and every filter.
rmse = calc_rmse(results=results)

### Plot RMSE per configuration

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib as mpl
sns.set()
# Long format: one row per (q1, q2, T, algorithm) holding that run's RMSE.
df = (
    pd.DataFrame(rmse).T
    .melt(ignore_index=False)
    .sort_index()
    .rename_axis(['$q_1$', '$q_2$', "T"])
    .reset_index()
    .rename(columns={"variable": "Alg.", "value": "RMSE"})
)

### Double row plot -- only suitable for evenly divisible number of plots

In [None]:
def _select_rmse_rows(df, algs, T):
    """Return the rows of *df* with q1 < 1, an algorithm in *algs* and sampling time *T*.

    "Alg." becomes a categorical ordered like *algs*; rows are sorted by q1, then algorithm.

    Raises:
        ValueError: if no rows match.
    """
    df_dec = df.loc[(df['$q_1$'] < 1) & (df["Alg."].isin(algs)) & (df["T"] == T)]
    if df_dec.empty:
        raise ValueError(f"No RMSE rows for T={T} and algorithms {list(algs)}")
    df_dec = df_dec.assign(**{"Alg.": pd.Categorical(df_dec["Alg."], categories=algs)})
    return df_dec.sort_values(["$q_1$", "Alg."]).reset_index(drop=True)


def _set_paper_font_sizes(fs):
    """Set tick, axes and legend font sizes via plt.rc (these settings persist afterwards)."""
    plt.rc('ytick', labelsize=fs)
    plt.rc('xtick', labelsize=fs)
    plt.rc('axes', labelsize=fs, titlesize=fs)
    plt.rc('legend', fontsize=fs)


def _logfmt(x):
    """Format a power of ten as LaTeX, e.g. 1e-3 -> '$10^{-3}$'."""
    return f'$10^{{{np.log10(x):n}}}$'


def _make_panel_grid(n_panels, cols, figsize):
    """Create a constrained-layout grid with enough rows for *n_panels* panels.

    Always returns a 2-D array of axes (``squeeze=False``), also for a single row. Grid
    cells beyond *n_panels* are hidden.
    """
    rows = int(np.ceil(n_panels / cols))
    fig, ax = plt.subplots(rows, cols, figsize=figsize, layout="constrained", squeeze=False)
    for unused in ax.flat[n_panels:]:
        unused.set_visible(False)
    return fig, ax


def _tidy_panel(a, j, n_panels, cols):
    """Remove duplicate legends and interior tick labels from panel *j* (row-major order)."""
    if j > 0:
        # Only the first panel keeps its legend; it becomes the figure legend later.
        legend = a.get_legend()
        if legend is not None:
            legend.remove()
    if j + cols < n_panels:
        # Another panel sits directly below this one: drop its x label and tick labels.
        a.set_xlabel("")
        a.tick_params(labelbottom=False)
    if j % cols > 0:
        # Not in the first column: y tick labels are shown by the leftmost panel only.
        a.tick_params(labelleft=False)
        a.set_ylabel("")


def _finish_with_top_legend(fig, ax, n_entries):
    """Replace the first panel's legend by one figure legend centred above the panels.

    Returns:
        The tight bounding boxes (figure coordinates) of the visible axes, measured after
        the panel legend was removed.
    """
    handles, labels = ax[0, 0].get_legend_handles_labels()
    first_legend = ax[0, 0].get_legend()
    if first_legend is not None:
        first_legend.remove()
    # Draw once so that the tight bounding boxes include tick labels and titles.
    fig.canvas.draw()
    renderer = fig.canvas.get_renderer()
    to_fig = fig.transFigure.inverted()
    bboxes = [a.get_tightbbox(renderer).transformed(to_fig) for a in ax.flat if a.get_visible()]
    x_center = np.mean([[b.x0 for b in bboxes], [b.x1 for b in bboxes]])
    fig.legend(handles, [label.upper() for label in labels], ncols=n_entries,
               bbox_to_anchor=(x_center, 1.03), loc="center", frameon=False)
    return bboxes


# Plot for paper on unification
def plot_rmse(df, algs, T, filename=None, figsize=(8, 5.5)):
    """Scatter RMSE against q2 on log-log axes, one panel per q1 < 1, in a two-column grid.

    Args:
        df: Long-format RMSE table with columns '$q_1$', '$q_2$', 'T', 'Alg.' and 'RMSE'.
        algs: Algorithms to show, in legend order (earlier entries get larger markers).
        T: Sampling interval to show.
        filename: If given, the figure is also saved there at 300 dpi.
        figsize: Figure size in inches.

    Runs with RMSE > 1 are treated as diverged and left out. Any number of q1 values is
    supported; unused grid cells are hidden. The "RMSE" y label is kept on the middle row.

    Returns:
        The matplotlib figure.
    """
    df_dec = _select_rmse_rows(df, algs, T)
    q1_J = df_dec['$q_1$'].unique()
    q2_vals = np.sort(df_dec['$q_2$'].unique())
    df_dec = df_dec.loc[~(df_dec["RMSE"] > 1)]  # leave out diverged runs
    cols = 2
    n_panels = len(q1_J)
    plt.close("all")
    with sns.axes_style("whitegrid"):
        fs = 16
        _set_paper_font_sizes(fs)
        fig, ax = _make_panel_grid(n_panels, cols, figsize)
        rows = ax.shape[0]
        sizes = (100 * np.flip(np.arange(1, len(algs) + 1))).tolist()
        for j, q in enumerate(q1_J):
            a = ax.flat[j]
            sns.scatterplot(ax=a, data=df_dec.loc[df_dec['$q_1$'] == q], x="$q_2$", hue="Alg.",
                            style="Alg.", y="RMSE", legend="full", hue_order=algs,
                            style_order=algs, size="Alg.", size_order=algs, sizes=sizes,
                            edgecolor="w", linestyle="--", linewidth=.25)
            a.set(yscale="log", xscale="log", xticks=q2_vals, yticks=[1e-2, 1e-1, 1e-0],
                  ylim=[1e-2 * 0.5, 5], xlim=[q2_vals[0] * 0.5, q2_vals[-1] * 1.5])
            a.set_title("$q_1=${}".format(_logfmt(q)), y=1.02, x=0.8, pad=-20)
            _tidy_panel(a, j, n_panels, cols)
            if j // cols != rows // 2:
                # A single "RMSE" label on the middle row serves the whole grid.
                a.set_ylabel("")
        _finish_with_top_legend(fig, ax, len(algs))
        if filename is not None:
            plt.savefig(filename, dpi=300, bbox_inches="tight")
    return fig


# Plotting for paper on damped
def plot_rmse_ddif(df, algs, T, filename=None, figsize=(8, 5.5)):
    """Scatter RMSE per q2 (categorical x axis), one panel per q1 < 1, in a three-column grid.

    The q2 values are placed at integer positions. Each algorithm is offset horizontally
    (symmetrically around the tick, 0.15 apart) so that markers do not overlap.

    Args:
        df: Long-format RMSE table with columns '$q_1$', '$q_2$', 'T', 'Alg.' and 'RMSE'.
        algs: Algorithms to show, in legend/offset order; any number is supported.
        T: Sampling interval to show.
        filename: If given, the figure is also saved there at 300 dpi.
        figsize: Figure size in inches.

    Runs with RMSE > 1 are treated as diverged and left out. A single figure-level "RMSE"
    y label replaces the per-panel labels.

    Returns:
        The matplotlib figure.
    """
    df_dec = _select_rmse_rows(df, algs, T)
    q2_vals = np.sort(df_dec['$q_2$'].unique())
    q2_index = dict(zip(q2_vals, range(len(q2_vals))))
    offsets = (np.arange(len(algs)) - (len(algs) - 1) / 2) * 0.15
    shift = dict(zip(algs, offsets))
    x_pos = df_dec['$q_2$'].map(q2_index) + df_dec["Alg."].astype(str).map(shift)
    df_dec = df_dec.assign(**{'$q_2$': x_pos})
    q1_J = df_dec['$q_1$'].unique()
    df_dec = df_dec.loc[~(df_dec["RMSE"] > 1)]  # leave out diverged runs
    cols = 3
    n_panels = len(q1_J)
    plt.close("all")
    with sns.axes_style("whitegrid"):
        fs = 16
        _set_paper_font_sizes(fs)
        fig, ax = _make_panel_grid(n_panels, cols, figsize)
        sizes = (150 * np.ones(len(algs))).tolist()
        for j, q in enumerate(q1_J):
            a = ax.flat[j]
            sns.scatterplot(ax=a, data=df_dec.loc[df_dec['$q_1$'] == q], x="$q_2$", hue="Alg.",
                            style="Alg.", y="RMSE", legend="full", hue_order=algs,
                            style_order=algs, size="Alg.", size_order=algs, sizes=sizes,
                            edgecolor="w", linestyle="-", linewidth=.5)
            a.set(yscale="log", yticks=[1e-2, 1e-1, 1e-0], ylim=[1e-2 * 0.5, 5],
                  xticks=range(len(q2_vals)), xlim=[-0.5, len(q2_vals) - 0.5],
                  xticklabels=[_logfmt(x) for x in q2_vals])
            a.set_title("$q_1=${}".format(_logfmt(q)), y=1.02, x=0.8, pad=-20)
            _tidy_panel(a, j, n_panels, cols)
            a.set_ylabel("")  # replaced by the figure-level supylabel below
        bboxes = _finish_with_top_legend(fig, ax, len(algs))
        y_center = np.mean([[b.y0 for b in bboxes], [b.y1 for b in bboxes]])
        fig.supylabel("RMSE", fontsize=fs, y=y_center)
        if filename is not None:
            plt.savefig(filename, dpi=300, bbox_inches="tight")
    return fig

### Unification plot

In [None]:
# Named alg_names (not algs) so that the `algs` module imported from ssmjax is not
# shadowed; otherwise re-running the filter sweep cell would fail.
alg_names = ["ekf", "iekf", "diekf"]
T_rmse = 1.5
plot_rmse(df, alg_names, T_rmse)
plt.show()

### Damping plot

In [None]:
# Named alg_names (not algs) so that the `algs` module imported from ssmjax is not
# shadowed; otherwise re-running the filter sweep cell would fail.
alg_names = ["ekf", "iekf", "diekf", "lsiekf", "lsdiekf"]
T_rmse = 1.5
plot_rmse_ddif(df, alg_names, T_rmse, figsize=(14, 4))
plt.show()

### Trajectory plots
Plots of the trajectories for all process noise configurations for a specific sampling time.

In [None]:
from matplotlib.gridspec import GridSpec

def plot_traj(T, result, order, ax=None):
    """Plot estimated trajectories, the ground truth and the microphones on one axis.

    Args:
        T: Sampling interval, used to build the matching ground truth with ``get_gt``.
        result: Mapping ``estimator name -> estimate``; each estimate has ``.mean`` with the
            x/y position in columns 0 and 1.
        order: Estimators to draw, in this order; ``None`` draws all of *result* in its
            own order. Names not present in *result* are ignored.
        ax: Target axis; defaults to the current axis at call time.

    Returns:
        The axis drawn on.

    Raises:
        ValueError: if none of the requested estimators is present in *result*.
    """
    if ax is None:
        # Resolve per call. A `plt.gca()` default would be evaluated once at definition
        # time and keep drawing into that (possibly closed) figure.
        ax = plt.gca()
    if order is None:
        order = list(result.keys())
    names = [name for name in dict.fromkeys(order) if name in result]
    if not names:
        raise ValueError("None of the requested estimators are present in result")
    for name in names:
        est = result[name].mean
        ax.plot(est[:, 0], est[:, 1], lw=1, label=name.upper())
    gt = get_gt(T, result[names[-1]].mean.shape[0])
    ax.plot(gt[:, 0], gt[:, 1], color='k', marker='.', linestyle='', label="GT", zorder=0)
    ax.plot(mic_locations[0, :], mic_locations[1, :], 'k*', label="Microphones")
    # Frame the microphone array with a 20 % margin on each side.
    axmin, axmax = mic_locations.min(axis=1), mic_locations.max(axis=1)
    margin = np.abs(axmax - axmin) / 5
    ax.set_xlim([axmin[0] - margin[0], axmax[0] + margin[0]])
    ax.set_ylim([axmin[1] - margin[1], axmax[1] + margin[1]])
    return ax

def plot_all_trajs(results, order=None):
    """Grid of trajectory plots: one panel per configuration, rows by q1 and columns by q2.

    Args:
        results: Mapping ``(q1, q2, T) -> {estimator name: estimate}``.
        order: Estimators to draw and their order (passed to ``plot_traj``); ``None``
            draws all.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: if *results* is empty.
    """
    if not results:
        raise ValueError("results is empty; nothing to plot")
    q1 = np.unique([config[0] for config in results])
    q2 = np.unique([config[1] for config in results])
    # Roughly 1.5 inches of height per four panels, but never a zero-height figure.
    fig = plt.figure(figsize=(16, max(int(len(results) / 4), 1) * 1.5))
    gs = GridSpec(len(q1), len(q2), figure=fig)
    for config, result in results.items():
        # Place each panel by its (q1, q2) value so the layout does not depend on dict order.
        row = int(np.searchsorted(q1, config[0]))
        col = int(np.searchsorted(q2, config[1]))
        ax = fig.add_subplot(gs[row, col])
        plot_traj(config[-1], result, order, ax)
        # Plain floats keep the title readable (NumPy 2 reprs scalars as np.float64(...)).
        ax.set_title("Config: {}".format(tuple(float(c) for c in config)))
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncols=len(labels))
    gs.tight_layout(fig)
    return fig

In [None]:
# Trajectories of all (q1, q2) configurations for one sampling interval.
T_plot = 4.0
traj_order = ["ekf", "iekf", "diekf", "lsiekf", "lsdiekf"]
plt.close('all')
plot_all_trajs(
    {key: res for key, res in results.items() if key[-1] == T_plot},
    order=traj_order,
)
plt.show()