In [7]:
from jax import jit, random, numpy as jnp
from lqg.tracking import BoundedActor
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatLogSlider

## Influence of initial uncertainty
Starting with an initial belief $p(x_0) = \mathcal{N}(\hat x_0, \Sigma_0)$, we can compute the next belief $p(x_{t+1} | y_{1:t}) = \mathcal{N}(\hat x_{t+1}, \Sigma_{t+1})$ using the Kalman filter update equations

$$\begin{aligned}
\hat{x}_{t+1} &= A_t \hat{x}_t + B_t u_t + K_t (y_t - H_t \hat{x}_t), \\
\Sigma_{t+1} &= (A_t - K_t H_t) \Sigma_t A_t^T + V_t V_t^T,
\end{aligned}$$

where $K_t$ is the Kalman gain.


Here, we are interested in the effect of changing the initial uncertainty $\Sigma_0$ on tracking behavior. We use the bounded actor model from our eLife paper (Straub & Rothkopf, 2022).

By changing the initial uncertainty about the target position, we can change the agent's behavior, especially at the beginning of the trial. The idea is to run an experiment in which we first train participants in different environments, in which there is either high or low uncertainty about the initial target position. After they have learned the distribution of initial target positions, we assume that they use it as their initial uncertainty $\Sigma_0$ in the Kalman filter. We then test both groups in the same environment, e.g. one in which the target always starts in the middle, and see whether their initial belief about the target position has an effect on their behavior.

The size of this effect probably depends on the other parameters of the model:

- cost $c$
- motor noise $\sigma_c$
- and especially perceptual uncertainty $\sigma$

The goal of this notebook is to find out which summary statistics of the tracking data (tracking errors, cross-correlations, etc.) might be affected by the initial uncertainty.

In [11]:
# Simulation time base: dt = 1/60 means 60 steps per unit of `duration`
# (presumably seconds — TODO confirm against lqg.tracking).
dt = 1. / 60.
duration = 2.5
# Number of time steps. Use round() rather than int(): int() truncates, so a float
# ratio that lands just below an integer (e.g. 0.7 / 0.1 == 6.999...) would silently
# lose a step. round() on a float already returns an int.
T = round(duration / dt)

In [12]:
# Bounded actor (Straub & Rothkopf, 2022) with a fixed cost, motor noise and
# perceptual uncertainty; only the initial belief covariance is varied below.
model = BoundedActor(T=T, sigma=60., motor_noise=0.5, c=0.5)


@jit
def simulate(init_std=1.):
    """Simulate 20 trials of ``model`` for a given initial uncertainty.

    The initial belief covariance is ``Sigma0 = diag([init_std, 0.5]) ** 2``, so
    only its first diagonal entry is controlled by the caller.

    Parameters
    ----------
    init_std : float
        Standard deviation for the first diagonal entry of ``Sigma0``.

    Returns
    -------
    The output of ``model.simulate``. The plotting cell indexes it as a 3-D array
    (presumably trial x time x state dimension — TODO confirm).
    """
    # The PRNG key is fixed, so every call reuses the same seed and differences between
    # calls should come from Sigma0 alone (assuming model.simulate draws all of its
    # randomness from this key).
    initial_stds = jnp.array([init_std, 0.5])
    initial_cov = jnp.diag(initial_stds ** 2)
    prng_key = random.PRNGKey(0)
    return model.simulate(prng_key, n=20, Sigma0=initial_cov)

In [13]:
@interact(init_std=FloatLogSlider(value=1., min=-2, max=2))
def plot(init_std):
    """Plot the first simulated trial for a given initial uncertainty and report the error.

    ``FloatLogSlider`` is logarithmic (base 10 by default): ``min``/``max`` are
    exponents, so ``init_std`` ranges over [0.01, 100].

    Parameters
    ----------
    init_std : float
        Passed to ``simulate`` as the standard deviation of the first diagonal
        entry of the initial belief covariance ``Sigma0``.
    """
    x = simulate(init_std=init_std)

    # Explicit figure/axes instead of the pyplot state machine.
    fig, ax = plt.subplots()
    # NOTE(review): state dims 0 and 1 are presumably target and cursor position,
    # and axis 1 is presumably time — confirm against lqg.tracking.
    ax.plot(x[0, :, 0], label="state dim 0")
    ax.plot(x[0, :, 1], label="state dim 1")
    ax.set(xlabel="time step", ylabel="value",
           title=f"Trial 0, init_std = {init_std:.3g}")
    ax.legend()
    # Render inside the widget's output area on each slider update.
    plt.show()

    # Squared difference between dims 0 and 1, summed over ALL simulated trials and
    # time steps (not only the plotted trial 0).
    sse = float(jnp.sum((x[..., 0] - x[..., 1]) ** 2))
    print(f"sum of squared errors (all trials): {sse:.4g}")


interactive(children=(FloatLogSlider(value=1.0, description='init_std', max=2.0, min=-2.0), Output()), _dom_cl…