# Motivational example for damped DIF

In [None]:
%load_ext autoreload
%autoreload 2
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.widgets import Slider
import ssmjax.algs as algs
import ssmjax.transforms.linearization as linearization
import ssmjax.types as types
%matplotlib widget

### Set up the problem and the algorithms

The transition and measurement functions can be changed, which naturally changes the end result. At the end of the notebook there is an interactive tool that lets you vary the model parameters to find parameter values and initial points that cause divergence for the different filters.

In [None]:
def transition(x, *args, **kwargs):
    """Test-model transition f(x) = cos(x) * sin(x) * x**2, applied elementwise.

    Extra positional/keyword arguments are accepted and ignored. Size-1 axes
    of the result are squeezed away and the output is promoted to at least 1-D,
    so a scalar or length-1 input yields an array of shape (1,).
    """
    return jnp.atleast_1d(jnp.squeeze(jnp.cos(x) * jnp.sin(x) * x**2))
def measurement(x, *args, **kwargs):
    """Test-model measurement h(x) = arctan(x), applied elementwise.

    Extra positional/keyword arguments are accepted and ignored. The result
    has size-1 axes squeezed away and is promoted to at least 1-D.
    """
    return jnp.atleast_1d(jnp.squeeze(jnp.arctan(x)))

# Default parameters; the divergence-finder cell at the end of the notebook
# reads Q, R, x0, P0 and x from here. (y, x1 and x2 are reassigned by later
# cells before they are used there.)
# Unit process, measurement and prior covariances.
Q = jnp.array([[1.]])
R = jnp.array([[1.]])
P0 = jnp.array([[1.]])
# Prior mean, true state, and an observation value.
x0 = jnp.array([1.])
x = jnp.array([0.])
y = jnp.array([2.])
# Grids for the two state components of the loss.
x1 = np.linspace(-5.0, 5.0, num=100)
x2 = np.linspace(-5.0, 5.0, num=100)
def loss(x1, x2, y, Q, R, x0, P0, x):
    """Quadratic cost of the candidate state pair (x1, x2).

    Sum of three covariance-weighted squared residuals:

    * transition: ``x2 - transition(x1)`` weighted by ``inv(Q)``;
    * measurement: ``measurement(transition(x)) - measurement(x2)`` weighted by
      ``inv(R)`` -- the noise-free observation produced by the true state ``x``;
    * prior: ``x1 - x0`` weighted by ``inv(P0)``.

    NOTE(review): ``y`` is accepted but never read; the observation is rebuilt
    from ``x`` instead. Callers pass both, so the argument is kept.

    Returns the squeezed (scalar) cost.
    """
    def weighted_square(residual, cov):
        # r^T cov^{-1} r, via a linear solve rather than an explicit inverse.
        residual = jnp.atleast_1d(residual)
        return residual @ jnp.linalg.solve(cov, residual[:, None])

    total = (
        weighted_square(x2 - transition(x1), Q)
        + weighted_square(measurement(transition(x)) - measurement(x2), R)
        + weighted_square(x1 - x0, P0)
    )
    return total.squeeze()

# mloss: map x1 and x2 together along axis 0, giving one cost per pair
# (x1[i], x2[i]) -- e.g. a sequence of filter iterates. All other arguments
# (y, Q, R, x0, P0, x) are broadcast.
mloss = jax.jit(jax.vmap(loss, in_axes=(0, 0) + (None,) * 6, out_axes=0))
# vloss: cost on the full grid. The inner vmap sweeps x1 and the outer vmap
# sweeps x2, so the result is indexed V[i_x2, i_x1] -- the same layout as
# jnp.meshgrid(x1, x2) with its default 'xy' indexing.
vloss = jax.jit(
    jax.vmap(
        jax.vmap(loss, in_axes=(0,) + (None,) * 7, out_axes=0),
        in_axes=(None, 0) + (None,) * 6,
        out_axes=0,
    )
)

import jaxopt
# Parameter pytree bound as the `theta` keyword of every algorithm step:
# a length-1 zero mean with no covariance.
theta = types.MVNormal(jnp.zeros(1), cov=None)


def _jit_first_taylor(builder, fn):
    """Build an algorithm step for `fn` with first-order Taylor linearization,
    bind `theta`, and jit the result."""
    step = builder(types.LinearizationMethod(linearization.first_taylor), fn)
    return jax.jit(jax.tree_util.Partial(step, theta=theta))


propagate = _jit_first_taylor(algs.build_propagate, transition)
update = _jit_first_taylor(algs.build_update, measurement)
smooth = _jit_first_taylor(algs.build_smooth, transition)

def ls_loss(
    prior,
    observation,
    observation_covariance,
    transition_covariance,
    updated_iterate,
    smoothed_iterate,
):
    """Merit function minimised by the line search in `find_stepsize`.

    Sum of three covariance-weighted squared residuals:
    the observation vs. ``measurement(updated_iterate)`` (observation noise),
    ``updated_iterate`` vs. ``transition(smoothed_iterate)`` (process noise),
    and ``prior.mean`` vs. ``smoothed_iterate`` (prior covariance).
    """
    def quadratic_form(residual, cov):
        return residual.T @ jnp.linalg.solve(cov, residual)

    observation_term = quadratic_form(
        observation - measurement(updated_iterate), observation_covariance
    )
    transition_term = quadratic_form(
        updated_iterate - transition(smoothed_iterate), transition_covariance
    )
    prior_term = quadratic_form(prior.mean - smoothed_iterate, prior.cov)
    return observation_term + transition_term + prior_term

def find_stepsize(previous_state, observation, observation_covariance, transition_covariance, old_updated, old_smoothed, current_updated, current_smoothed, options):
    """Line search along the step from the old to the current DIEKF iterate.

    The search variable stacks the (updated, smoothed) means. The search
    direction is the difference between the current and old means, and the
    merit function is `ls_loss`. Uses jaxopt's BacktrackingLineSearch with the
    strong-Wolfe condition; `options.linesearch.beta` is the decrease factor and
    `options.linesearch.gamma` is c1. Returns the step size reported by jaxopt.
    """
    def merit(stacked):
        upd_part, smth_part = jnp.split(stacked, 2)
        return ls_loss(
            previous_state,
            observation,
            observation_covariance,
            transition_covariance,
            upd_part.squeeze(),
            smth_part.squeeze(),
        )

    def stack(first, second):
        return jnp.vstack([first, second]).squeeze()

    start = stack(old_updated.mean, old_smoothed.mean)
    direction = stack(
        current_updated.mean - old_updated.mean,
        current_smoothed.mean - old_smoothed.mean,
    )

    searcher = jaxopt.BacktrackingLineSearch(
        fun=merit,
        maxiter=20,
        condition="strong-wolfe",
        decrease_factor=options.linesearch.beta,
        c1=options.linesearch.gamma,
    )
    stepsize, _ = searcher.run(init_stepsize=1.0, params=start, descent_direction=direction)
    return stepsize

def compute_ekf_solution(y, Q, R, x0, P0):
    """One plain (non-iterated) EKF step starting from the prior N(x0, P0).

    Predicts through `transition` linearized at the prior, then updates with
    observation `y` linearized at the prediction.

    Returns ``(prior, updated)``.
    """
    prior = types.MVNormal(x0, cov=P0)
    predicted = propagate(
        prior=prior,
        transition_covariance=Q,
        linearization_point=types.MVNormal(x0, cov=P0),
    )[0]
    updated = update(
        predicted=predicted,
        observation=y,
        observation_covariance=R,
        linearization_point=predicted,
    )[1][0]
    return prior, updated

def iekf_inner_loop(carry, _):
    """`jax.lax.scan` body: re-run the measurement update once (IEKF step).

    The carry is ``(predicted, y, R, updated)``. Only the last entry changes:
    the update is re-linearized about the previous updated estimate. The new
    estimate is both the next carry entry and the per-step output.
    """
    predicted, y, R, linearize_at = carry
    refined = update(
        predicted=predicted,
        observation=y,
        observation_covariance=R,
        linearization_point=linearize_at,
    )[1][0]
    return (predicted, y, R, refined), refined

def compute_iekf_solution(y, Q, R, x0, P0, iterations):
    """Run the iterated EKF (IEKF) on the one-step problem.

    The prior N(x0, P0) is propagated once (linearized at the prior). The
    measurement update is first linearized at the prediction, then re-linearized
    ``iterations - 1`` times via `iekf_inner_loop`.

    Returns ``(seq_prior, seq_updated)``: MVNormal pytrees whose leaves have a
    leading iterate axis of length ``max(iterations, 1)``. ``seq_prior`` is the
    unchanged prior repeated once per iterate, so both can be plotted as
    (x1, x2) pairs.
    """
    def prepend(first, rest):
        # Put the initial iterate in front of the stacked scan outputs.
        # jax.tree_util.tree_map is used because jax.tree_map is deprecated
        # (JAX 0.4.25) and removed in newer releases.
        return jax.tree_util.tree_map(
            lambda head, tail: jnp.concatenate([head[None, ...], tail], 0),
            first,
            rest,
        )

    prior = types.MVNormal(x0, cov=P0)
    predicted, _ = propagate(prior=prior, transition_covariance=Q, linearization_point=prior)
    _, (updated, _) = update(predicted=predicted, observation=y, observation_covariance=R, linearization_point=predicted)

    _, scanned_updated = jax.lax.scan(
        iekf_inner_loop, (predicted, y, R, updated), jnp.arange(iterations - 1)
    )
    seq_updated = prepend(updated, scanned_updated)
    n_iterates = seq_updated.mean.shape[0]
    seq_prior = jax.tree_util.tree_map(
        lambda leaf: jnp.repeat(leaf[None, ...], n_iterates, 0), prior
    )
    return seq_prior, seq_updated

def diekf_inner_loop(carry, _):
    """`jax.lax.scan` body: one DIEKF iteration (predict, update, smooth).

    The carry is ``(prior, y, R, Q, updated, smoothed)``. The prediction is
    linearized about the current smoothed estimate and the update about the
    current updated estimate. Emits ``(new_updated, new_smoothed)`` per step.
    """
    prior, y, R, Q, upd_prev, smth_prev = carry
    predicted = propagate(
        prior=prior, transition_covariance=Q, linearization_point=smth_prev
    )[0]
    upd_next = update(
        predicted=predicted,
        observation=y,
        observation_covariance=R,
        linearization_point=upd_prev,
    )[1][0]
    # NOTE(review): the smoother is fed the *previous* updated estimate
    # (upd_prev), whereas compute_diekf_solution smooths against the freshly
    # computed update. Confirm that this lagged step is intended.
    smth_next = smooth(
        filtered_state=prior,
        smoothed_state=upd_prev,
        transition_covariance=Q,
        linearization_point=smth_prev,
    )[0]
    return (prior, y, R, Q, upd_next, smth_next), (upd_next, smth_next)

def compute_diekf_solution(y, Q, R, x0, P0, iterations):
    """Run the DIEKF on the one-step problem.

    Initialises with an EKF predict/update (linearized at the prior and the
    prediction respectively) followed by a smoothing step
    (``filtered_state=prior``, ``smoothed_state=updated``). It then performs
    ``iterations - 1`` `diekf_inner_loop` steps.

    Returns ``(seq_smoothed, seq_updated)``: MVNormal pytrees whose leaves have
    a leading iterate axis of length ``max(iterations, 1)`` (the initial
    iterate followed by the scanned ones).
    """
    def prepend(first, rest):
        # Put the initial iterate in front of the stacked scan outputs.
        # jax.tree_util.tree_map is used because jax.tree_map is deprecated
        # (JAX 0.4.25) and removed in newer releases.
        return jax.tree_util.tree_map(
            lambda head, tail: jnp.concatenate([head[None, ...], tail], 0),
            first,
            rest,
        )

    prior = types.MVNormal(x0, cov=P0)
    predicted, _ = propagate(prior=prior, transition_covariance=Q, linearization_point=prior)
    _, (updated, _) = update(predicted=predicted, observation=y, observation_covariance=R, linearization_point=predicted)
    smoothed, _ = smooth(filtered_state=prior, smoothed_state=updated, transition_covariance=Q, linearization_point=prior)

    _, (scanned_updated, scanned_smoothed) = jax.lax.scan(
        diekf_inner_loop,
        (prior, y, R, Q, updated, smoothed),
        jnp.arange(iterations - 1),
    )
    seq_updated = prepend(updated, scanned_updated)
    seq_smoothed = prepend(smoothed, scanned_smoothed)
    return seq_smoothed, seq_updated

def lsdiekf_inner_loop(carry, _):
    """`jax.lax.scan` body: one line-searched DIEKF iteration.

    Computes the full step with the same predict/update/smooth calls as
    `diekf_inner_loop`. It then moves the means only a fraction ``stepsize``
    (from `find_stepsize`) of the way from the previous iterate towards that
    step. Covariances are taken from the full step unchanged.
    """
    prior, y, R, Q, options, upd_prev, smth_prev = carry
    predicted = propagate(
        prior=prior, transition_covariance=Q, linearization_point=smth_prev
    )[0]
    upd_full = update(
        predicted=predicted,
        observation=y,
        observation_covariance=R,
        linearization_point=upd_prev,
    )[1][0]
    # NOTE(review): as in diekf_inner_loop, the smoother uses the previous
    # updated estimate (upd_prev) rather than upd_full -- confirm intended.
    smth_full = smooth(
        filtered_state=prior,
        smoothed_state=upd_prev,
        transition_covariance=Q,
        linearization_point=smth_prev,
    )[0]
    alpha = find_stepsize(prior, y, R, Q, upd_prev, smth_prev, upd_full, smth_full, options)

    def damp(old, full):
        # Convex combination of the means; keep the full-step covariance.
        return types.MVNormal((1-alpha)*old.mean + alpha*full.mean, cov=full.cov)

    upd_next = damp(upd_prev, upd_full)
    smth_next = damp(smth_prev, smth_full)
    return (prior, y, R, Q, options, upd_next, smth_next), (upd_next, smth_next)

def compute_lsdiekf_solution(y, Q, R, x0, P0, iterations, options=types.options.LineSearchIterationOptions()):
    """Run the line-searched (damped) DIEKF on the one-step problem.

    Initialisation is identical to `compute_diekf_solution`. It is followed by
    ``iterations - 1`` `lsdiekf_inner_loop` steps, which use `options` for the
    line-search parameters.

    Returns ``(seq_smoothed, seq_updated)``: MVNormal pytrees whose leaves have
    a leading iterate axis of length ``max(iterations, 1)`` (the initial
    iterate followed by the scanned ones).
    """
    def prepend(first, rest):
        # Put the initial iterate in front of the stacked scan outputs.
        # jax.tree_util.tree_map is used because jax.tree_map is deprecated
        # (JAX 0.4.25) and removed in newer releases.
        return jax.tree_util.tree_map(
            lambda head, tail: jnp.concatenate([head[None, ...], tail], 0),
            first,
            rest,
        )

    prior = types.MVNormal(x0, cov=P0)
    predicted, _ = propagate(prior=prior, transition_covariance=Q, linearization_point=prior)
    _, (updated, _) = update(predicted=predicted, observation=y, observation_covariance=R, linearization_point=predicted)
    smoothed, _ = smooth(filtered_state=prior, smoothed_state=updated, transition_covariance=Q, linearization_point=prior)

    _, (scanned_updated, scanned_smoothed) = jax.lax.scan(
        lsdiekf_inner_loop,
        (prior, y, R, Q, options, updated, smoothed),
        jnp.arange(iterations - 1),
    )
    seq_updated = prepend(updated, scanned_updated)
    seq_smoothed = prepend(smoothed, scanned_smoothed)
    return seq_smoothed, seq_updated

### Compute solutions for the paper example
These values only reproduce the paper example **if** you have not changed the transition and measurement function definitions above.

In [None]:
# Paper example: small process noise, unit measurement and prior covariances.
q, r, p = jnp.array([[0.1]]), jnp.array([[1.]]), jnp.array([[1.]])
# The true state xk generates a noise-free observation y; x0n is the prior mean.
xk = jnp.array([-3.2])
y = measurement(transition(xk))
x0n = jnp.array([-2.9])

ekf_opt = compute_ekf_solution(y, q, r, x0n, p)
# Ten iterates for each of the iterated variants.
iekf_opt, diekf_opt, lsdiekf_opt = (
    solver(y, q, r, x0n, p, 10)
    for solver in (compute_iekf_solution, compute_diekf_solution, compute_lsdiekf_solution)
)

### Produce the paper example plot

In [None]:
# Loss surface on a grid: x1 over [-5, -2] and x2 over [-5, 5]. vloss maps x2
# on its outer axis, so V[i, j] is the loss at (x1[j], x2[i]) -- the same
# layout as X1/X2 from jnp.meshgrid with its default 'xy' indexing.
x1 = np.arange(-5., -2.01, 0.05)
x2 = np.arange(-5., 5.01, 0.05)
X1, X2 = jnp.meshgrid(x1, x2)
V = vloss(x1, x2, y, q, r, x0n, p, xk)
# Grid index of the smallest loss value (row indexes x2, col indexes x1).
(row, col) = np.unravel_index(V.argmin(), V.shape)
# Loss at the prior point (x0n, transition(x0n)).
# NOTE(review): Vx0 is not used anywhere in this cell.
Vx0 = loss(x0n, transition(x0n), y, q, r, x0n, p, xk)
plt.close("all")
fs = 16
plt.rc('ytick', labelsize=fs)
plt.rc('xtick', labelsize=fs)
# 4x4 grid layout: the joint contour plot fills the top-right 3x3 cells, the
# x2 marginal the left column and the x1 marginal the bottom row.
fig = plt.figure(figsize=(8, 8))#, layout="constrained")
gs = GridSpec(nrows=4, ncols=4, figure=fig, left=0.075, right=.99,
            bottom=.075, top=.99,
                      hspace=0.0, wspace=0.0)
ax = []
joint = fig.add_subplot(gs[:-1, 1:])
marg2 = fig.add_subplot(gs[:-1, 0], sharey=joint)
marg1 = fig.add_subplot(gs[-1, 1:], sharex=joint)

# Joint plot: loss contours, the grid optimum with guide lines, and each
# filter's iterate path as (x1, x2) = (prior or smoothed mean, updated mean).
joint.contourf(X1, X2, V, levels=15, cmap="Oranges")
joint.plot(X1[row, col], X2[row, col], '*', color='tab:orange', label='Optima', markersize=14)
joint.hlines(X2[row, col], xmin=x1.min(), xmax=x1.max(), color='tab:orange', lw=.5)
joint.vlines(X1[row, col], ymin=x2.min(), ymax=x2.max(), color='tab:orange', lw=.5)
joint.plot(iekf_opt[0].mean, iekf_opt[1].mean, '--', marker='.', color='tab:blue', markersize=14, label='IEKF')
joint.plot(diekf_opt[0].mean, diekf_opt[1].mean, '--', marker='.', color='tab:green', markersize=14, label='DIEKF')
joint.plot(lsdiekf_opt[0].mean, lsdiekf_opt[1].mean, '--', marker='.', color='tab:red', markersize=14, label='LSDIEKF')
# Prior point (x0n, transition(x0n)) with black guide lines.
joint.plot(x0n, transition(x0n), 'k', marker='*', lw=.5, label='Prior')
joint.vlines(x0n, ymin=x2.min(), ymax=x2.max(), color='k', lw=.5)
joint.hlines(transition(x0n), xmin=x1.min(), xmax=x1.max(), color='k', lw=.5)
joint.set(ylim=[x2.min(), x2.max()], 
          xlim=[x1.min(), x1.max()],)
joint.tick_params(labelbottom=False, bottom=False, left=False, labelleft=False)
# The legend is anchored outside the joint axes (bottom-left corner region);
# exclude it from layout computations.
leg = joint.legend(fontsize=16, loc='lower left', bbox_to_anchor=(-.35, -.35))
leg.set_in_layout(False)

# Marginals: loss slices through the grid optimum -- along x2 (left panel)
# and along x1 (bottom panel).
marg2.plot(V[:, col], X2[:, col], 'k')
marg2.plot(V[row, col], X2[row, col], '*', color='tab:orange', markersize=14)
marg2.set(xticklabels='', xticks=[], xlim=[V[:, col].min()-0.1*V[:,col].mean(), V[:, col].max()])
marg2.set_ylabel("$X_1$", fontsize=16)
marg1.plot(X1[row, :], V[row, :], 'k')
marg1.plot(X1[row, col], V[row, col], '*', color='tab:orange', markersize=14, label='Optima')
marg1.set(yticklabels='', yticks=[], ylim=[V[row, :].min()-0.1*V[row, :].mean(), V[row, :].max()+0.1*V[row,:].mean()])
marg1.set_xlabel("$X_0$", fontsize=16)
marg1.vlines(X1[row, col], ymin=marg1.get_ylim()[0], ymax=marg1.get_ylim()[1], color='tab:orange', lw=.5)
marg2.hlines(X2[row, col], xmin=marg2.get_xlim()[0], xmax=marg2.get_xlim()[1], color='tab:orange', lw=.5)

# Loss of each filter's iterates (mloss evaluates one cost per iterate),
# plotted against the corresponding state coordinate on each marginal.
iekf_cost = mloss(iekf_opt[0].mean, iekf_opt[1].mean, y, q, r, x0n, p, xk)
marg1.plot(iekf_opt[0].mean, iekf_cost, '--', marker='.', markersize=14, color='tab:blue', label='IEKF')
marg2.plot(iekf_cost, iekf_opt[1].mean, '--', marker='.', markersize=14, color='tab:blue')
diekf_cost = mloss(diekf_opt[0].mean, diekf_opt[1].mean, y, q, r, x0n, p, xk)
marg1.plot(diekf_opt[0].mean, diekf_cost, '--', marker='.', markersize=14, color='tab:green', label='DIEKF')
marg2.plot(diekf_cost, diekf_opt[1].mean, '--', marker='.', markersize=14, color='tab:green')
lsdiekf_cost = mloss(lsdiekf_opt[0].mean, lsdiekf_opt[1].mean, y, q, r, x0n, p, xk)
marg1.plot(lsdiekf_opt[0].mean, lsdiekf_cost, '--', marker='.', markersize=14, color='tab:red', label='LSDIEKF')
marg2.plot(lsdiekf_cost, lsdiekf_opt[1].mean, '--', marker='.', markersize=14, color='tab:red')
marg2.yaxis.set_label_coords(-.15, 0.5)
# Prior-point guide lines on the marginals.
marg1.vlines(x0n, ymin=marg1.get_ylim()[0], ymax=marg1.get_ylim()[1], color='k', lw=.5)
marg2.hlines(transition(x0n), xmin=marg2.get_xlim()[0], xmax=marg2.get_xlim()[1], color='k', lw=.5)

plt.savefig("damped_dif_motivation.eps")
plt.show()

### "Divergence finder"

An interactive tool for finding specific examples that diverge. It plots the loss together with the EKF, IEKF and DIEKF iterates.

In [None]:
# Initial loss surface for the default parameters Q, R, x0, P0 and true state
# x. x1/x2 are whichever grids were defined last (the paper-plot cell when the
# notebook is run in order). `loss` ignores its y argument.
X1, X2 = jnp.meshgrid(x1, x2)
V = vloss(x1, x2, y, Q, R, x0, P0, x)
# Grid index of the smallest loss value (row indexes x2, col indexes x1).
(row, col) = np.unravel_index(V.argmin(), V.shape)

plt.close("all")
fig = plt.figure(figsize=(10, 8))
gs = GridSpec(5, 6, figure=fig)

ax = []
joint = fig.add_subplot(gs[:-1, 1:-1])
marg2 = fig.add_subplot(gs[:-1, 0])
marg1 = fig.add_subplot(gs[-1, 1:-1])

# Marginal loss slices through the grid optimum. The Line2D handles are kept
# so that slider_update can replace their data.
marg1_loss, = marg1.plot(X1[row, :], V[row, :])
marg1_opt, = marg1.plot(X1[row, col], V[row, col], '*', color='tab:orange')
marg2_loss, = marg2.plot(V[:, col], X2[:, col])
marg2_opt, = marg2.plot(V[row, col], X2[row, col], '*', color='tab:orange')

marg1.set(xlim=[x1.min(), x1.max()], yticklabels='', ylim=[V[row, :].min(), V[row, :].max()])
marg2.set(ylim=[x2.min(), x2.max()], xticklabels='', xlim=[V[:, col].min(), V[:, col].max()])

# Filter iterates for the noise-free observation of the true state x, and the
# loss of each iterate.
ekf_opt = compute_ekf_solution(measurement(transition(x)), Q, R, x0, P0)
iekf_opt = compute_iekf_solution(measurement(transition(x)), Q, R, x0, P0, 3)
diekf_opt = compute_diekf_solution(measurement(transition(x)), Q, R, x0, P0, 3)
ekf_cost = mloss(ekf_opt[0].mean, ekf_opt[1].mean, y, Q, R, x0, P0, x)
iekf_cost = mloss(iekf_opt[0].mean, iekf_opt[1].mean, y, Q, R, x0, P0, x)
diekf_cost = mloss(diekf_opt[0].mean, diekf_opt[1].mean, y, Q, R, x0, P0, x)

marg1_ekf, = marg1.plot(ekf_opt[0].mean, ekf_cost, '--', marker='^', color='tab:red')
marg1_iekf, = marg1.plot(iekf_opt[0].mean, iekf_cost, '--', marker='^', color='tab:purple')
marg1_diekf, = marg1.plot(diekf_opt[0].mean, diekf_cost, '--', marker='^', color='tab:green')
marg2_ekf, = marg2.plot(ekf_cost, ekf_opt[1].mean, '--', marker='^', color='tab:red')
marg2_iekf, = marg2.plot(iekf_cost, iekf_opt[1].mean, '--', marker='^', color='tab:purple')
marg2_diekf, = marg2.plot(diekf_cost, diekf_opt[1].mean, '--', marker='^', color='tab:green')

ekfsol, = joint.plot(ekf_opt[0].mean, ekf_opt[1].mean, '--', marker='^', color='tab:red', label='EKF')
iekfsol, = joint.plot(iekf_opt[0].mean, iekf_opt[1].mean, '--', marker='^', color='tab:purple', label='IEKF')
diekfsol, = joint.plot(diekf_opt[0].mean, diekf_opt[1].mean, '--', marker='^', color='tab:green', label='DIEKF')

# Joint plot: contours, the optimum, guide lines for x0 / the optimum / the
# true state, and the transition and measurement curves.
C = joint.contourf(X1, X2, V, levels=20)
joint.set(xlabel="$x_1$", ylabel="$x_2$")
optima, = joint.plot(X1[row, col], X2[row, col], '*', color='tab:orange', label='Optima')
joint.vlines(x0, ymin=x2.min(), ymax=x2.max(), color='k', label='$x_0$')
joint.hlines(X2[row, col], xmin=x1.min(), xmax=x1.max(), color='tab:orange', lw=.5)
joint.vlines(X1[row, col], ymin=x2.min(), ymax=x2.max(), color='tab:orange', lw=.5)
joint.hlines(transition(x), xmin=x1.min(), xmax=x1.max(), color='w', lw=.5)
joint.vlines(x, ymin=x2.min(), ymax=x2.max(), color='w', lw=.5)
xt, = joint.plot(x, transition(x), '*', color='w', label='True state')
obs, = joint.plot(measurement(transition(x)), transition(x), 'y*', label='$y$')
joint.plot(x1, transition(x1), color='w', label='Transition function')
joint.plot(measurement(x2), x2, color='y', label='Measurement function')
joint.set(title="Optimal point, $(x_1,x_2)=({:.2f},{:.2f})$\nTrue point, $(x_1,x_2)=({:.2f},{:.2f})$".format(X1[row, col], X2[row, col], x[0], transition(x)[0]), ylim=[x2.min(), x2.max()], xlim=[x1.min(), x1.max()])

joint.legend(fontsize=8, loc='upper left', bbox_to_anchor=(1, 1))
plt.suptitle("Loss contours")
joint.set(xticklabels='', yticklabels='')

# Leave room at the bottom of the figure for the sliders.
fig.subplots_adjust(bottom=0.25)

# Prior-mean slider, over the x1 grid range.
x0obs = fig.add_axes([0.1, 0.1, 0.3, 0.03])
x0_slider = Slider(
    ax=x0obs,
    label='$x_0$',
    valmin=x1.min(),
    valmax=x1.max(),
    valinit=x0[0],
    valstep=.1,
)

# True-state slider, over the x1 grid range.
xobs = fig.add_axes([0.1, 0.05, 0.3, 0.03])
x_slider = Slider(
    ax=xobs,
    label='$x$',
    valmin=x1.min(),
    valmax=x1.max(),
    valinit=x[0],
    valstep=.1,
)

# Covariance sliders. Each slider position is log10 of the scalar covariance:
# slider_update maps a position back with 10**val, so the initial position
# must use log10 (not the natural log). The range [-4, 4] spans 1e-4 to 1e4.
# The value label shows the covariance itself rather than its log10 position.
Qobs = fig.add_axes([0.6, 0.15, 0.3, 0.03])
Q_slider = Slider(
    ax=Qobs,
    label='$Q$',
    valmin=-4,
    valmax=4,
    valinit=np.log10(Q[0, 0]),
    valstep=.5,
)
Q_slider.valtext.set_text(Q[0,0])

Robs = fig.add_axes([0.6, 0.1, 0.3, 0.03])
R_slider = Slider(
    ax=Robs,
    label='$R$',
    valmin=-4,
    valmax=4,
    valinit=np.log10(R[0, 0]),
    valstep=.5,
)
R_slider.valtext.set_text(R[0,0])

P0obs = fig.add_axes([0.6, 0.05, 0.3, 0.03])
P0_slider = Slider(
    ax=P0obs,
    label='$P_0$',
    valmin=-4,
    valmax=4,
    valinit=np.log10(P0[0, 0]),
    valstep=.5,
)
P0_slider.valtext.set_text(P0[0,0])

def slider_update(val):
    """Redraw the divergence-finder figure for the current slider settings.

    Registered as the ``on_changed`` callback of every slider. The argument
    ``val`` is ignored; all five sliders are read directly. The Q, R and P0
    sliders hold log10 values.
    """
    q = 10**Q_slider.val
    Q_slider.valtext.set_text(q)
    r = 10**R_slider.val
    R_slider.valtext.set_text(r)
    p = 10**P0_slider.val
    P0_slider.valtext.set_text(p)

    # Noise-free observation generated by the slider-selected true state.
    xk = transition(x_slider.val)
    y = measurement(xk)

    # Current slider parameters, shaped for the loss and the filters.
    q_mat, r_mat, p_mat = np.atleast_2d(q), np.atleast_2d(r), np.atleast_2d(p)
    x0_vec = np.atleast_1d(x0_slider.val)
    x_vec = np.atleast_1d(x_slider.val)

    V = vloss(x1, x2, y, q_mat, r_mat, x0_slider.val, p_mat, x_slider.val)
    (row, col) = np.unravel_index(V.argmin(), V.shape)

    # Remove the old contours and guide lines. Iterate over a snapshot:
    # removing an artist mutates joint.collections, and removing while
    # iterating the live container skips items, leaving stale artists behind.
    for coll in list(joint.collections):
        coll.remove()
    joint.contourf(X1, X2, V, levels=20)
    optima.set_data((X1[row, col],), (X2[row, col],))
    obs.set_data((y,), (xk,))
    xt.set_data((x_slider.val,), (xk,))
    joint.vlines(x0_slider.val, ymin=x2.min(), ymax=x2.max(), color='k', label='$x_0$')
    joint.hlines(X2[row, col], xmin=x1.min(), xmax=x1.max(), color='tab:orange', lw=.5)
    joint.vlines(X1[row, col], ymin=x2.min(), ymax=x2.max(), color='tab:orange', lw=.5)
    joint.hlines(xk, xmin=x1.min(), xmax=x1.max(), color='w', lw=.5)
    joint.vlines(x_slider.val, ymin=x2.min(), ymax=x2.max(), color='w', lw=.5)

    marg1_loss.set_data((X1[row, :],), (V[row, :],))
    marg1_opt.set_data((X1[row, col],), (V[row, col],))
    marg2_loss.set_data((V[:, col],), (X2[:, col],))
    marg2_opt.set_data((V[row, col],), (X2[row, col],))
    marg1.set(ylim=[V[row, :].min(), V[row, :].max()])
    marg2.set(xlim=[V[:, col].min(), V[:, col].max()])

    # Iterate costs use the same slider parameters as the loss surface V.
    # Previously these were evaluated with the initial Q, R, x0, P0, x.
    esol = compute_ekf_solution(np.atleast_1d(y), q_mat, r_mat, x0_vec, p_mat)
    ekf_cost = mloss(esol[0].mean, esol[1].mean, y, q_mat, r_mat, x0_vec, p_mat, x_vec)
    marg1_ekf.set_data((esol[0].mean,), (ekf_cost,))
    marg2_ekf.set_data((ekf_cost,), (esol[1].mean,))
    ekfsol.set_data((esol[0].mean,), (esol[1].mean,))
    ###
    iesol = compute_iekf_solution(np.atleast_1d(y), q_mat, r_mat, x0_vec, p_mat, 5)
    iekf_cost = mloss(iesol[0].mean, iesol[1].mean, y, q_mat, r_mat, x0_vec, p_mat, x_vec)
    marg1_iekf.set_data((iesol[0].mean,), (iekf_cost,))
    marg2_iekf.set_data((iekf_cost,), (iesol[1].mean,))
    iekfsol.set_data((iesol[0].mean,), (iesol[1].mean,))
    ###
    diesol = compute_diekf_solution(np.atleast_1d(y), q_mat, r_mat, x0_vec, p_mat, 5)
    diekf_cost = mloss(diesol[0].mean, diesol[1].mean, y, q_mat, r_mat, x0_vec, p_mat, x_vec)
    diekfsol.set_data((diesol[0].mean,), (diesol[1].mean,))
    marg1_diekf.set_data((diesol[0].mean,), (diekf_cost,))
    marg2_diekf.set_data((diekf_cost,), (diesol[1].mean,))

    joint.set_title("Optimal point, $(x_1,x_2)=({:.2f},{:.2f})$\nTrue point, $(x_1,x_2)=({:.2f},{:.2f})$".format(X1[row, col], X2[row, col], x_slider.val, transition(x_slider.val)[0]))
    plt.draw()

# Any slider change triggers a full redraw via slider_update
# (registered in the order x0, x, Q, R, P0).
for _slider in (x0_slider, x_slider, Q_slider, R_slider, P0_slider):
    _slider.on_changed(slider_update)

plt.show()