# Extended Kalman Filter
We implement a basic [Kalman filter](https://en.wikipedia.org/wiki/Kalman_filter) that uses the nonlinear dynamics linearized about the current operating point. This variant is known as the [Extended Kalman Filter](https://en.wikipedia.org/wiki/Extended_Kalman_filter). Although it comes with no guarantee that the estimate is optimal, it is extremely popular because it is effective and easy to implement.
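
Concretely, write the discrete-time dynamics as $x_{k+1} = f(x_k, u_k)$. Linearizing about an operating point $(\bar{x}, \bar{u})$ means keeping the first-order Taylor terms. $A$ and $B$ below are the matrices returned by the finite-difference routine used later. (In the textbook EKF, the operating point is the current estimate $\hat{x}_k$.)

$$
x_{k+1} = f(x_k, u_k) \approx f(\bar{x}, \bar{u}) + A\,(x_k - \bar{x}) + B\,(u_k - \bar{u}),
\qquad
A = \left.\frac{\partial f}{\partial x}\right|_{(\bar{x}, \bar{u})},
\quad
B = \left.\frac{\partial f}{\partial u}\right|_{(\bar{x}, \bar{u})}.
$$

Suppose $(\bar{x}, \bar{u}) = (0, 0)$ is an equilibrium, i.e. $f(0, 0) = 0$, as at an upright balance point. Then this reduces to $x_{k+1} \approx A x_k + B u_k$, which is the form the filter and the LQR controller operate on.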

We use the acrobot and the cartpole as model systems. The goal is to balance each of them with LQR while facing the following challenges, which are absent in the ideal full-state-feedback case:
1. We only get to measure the position variables and consider the velocities unknown.
2. Our sensor measurements are subject to observation noise.
3. Our dynamics are subject to process noise, which we take to be random white noise injected through our actuator dynamics.
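
Challenge 1 can only be overcome if the velocities can be reconstructed from a history of position measurements, i.e. if the pair $(A, C)$ is observable. Below is a minimal NumPy sketch of the standard rank test. It runs on a hypothetical one-axis double integrator (state = position and velocity, time step `dt`) rather than on either MuJoCo model:

```python
import numpy as np

def observability_matrix(A, C):
    """Stack C, CA, ..., CA^(n-1); (A, C) is observable iff this has rank n."""
    n = A.shape[0]
    blocks = [C]
    for _ in range(n - 1):
        blocks.append(blocks[-1] @ A)
    return np.vstack(blocks)

# Hypothetical double integrator: state [position, velocity], Euler-discretized.
dt = 0.01
A = np.array([[1.0, dt],
              [0.0, 1.0]])
C = np.array([[1.0, 0.0]])  # measure position only

O = observability_matrix(A, C)
rank = np.linalg.matrix_rank(O)
print(rank)  # 2: the velocity is recoverable from positions
```

For the acrobot and cartpole, one would apply the same test to the `A` returned by the finite-difference linearization, with `C = np.hstack((np.eye(2), np.zeros((2, 2))))`.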

In the linear setting, the combined problem of optimal control and optimal estimation is known as the [Linear Quadratic Gaussian](https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic%E2%80%93Gaussian_control) (LQG) problem.
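
The discrete-time LQG problem below uses the notation of the code that follows: $A$, $B$, $C$, process noise covariance $W$, measurement noise covariance $V$, and LQR weights $Q$, $R$.

$$
x_{k+1} = A x_k + B u_k + w_k,\quad w_k \sim \mathcal{N}(0, W),
\qquad
z_k = C x_k + v_k,\quad v_k \sim \mathcal{N}(0, V),
$$

$$
\min_{u_0, u_1, \dots}\; \lim_{N\to\infty} \frac{1}{N}\,\mathbb{E}\!\left[\sum_{k=0}^{N-1} x_k^\top Q x_k + u_k^\top R u_k\right].
$$

By the separation principle, the optimal controller applies the LQR gain $K$ (designed as if the state were measured) to the Kalman estimate, $u_k = -K\hat{x}_k$. The controller and the filter can therefore be designed independently. For the linearized nonlinear models in this notebook, that optimality guarantee no longer holds, but the same structure is used.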

In [239]:
import matplotlib.pyplot as plt
import mujoco
import mujoco.viewer
import numpy as np
import os
from pathlib import Path
import time
import control

xml = """
<mujoco model='test_cartpole'>

  <size nkey="1"/>

  <option integrator="implicitfast" timestep='0.01'/>

  <default>
    <joint damping='0.05' solreflimit='.08 1'/>
  </default>

  <worldbody>
    <camera name='fixed' pos='0 -2.5 0' quat='0.707 0.707 0 0'/>
    <body name='cart' pos='0 0 0'>
      <camera name='cart' pos='0 -2.5 0' quat='0.707 0.707 0 0' />
      <joint name='slider' type='slide' limited='false' pos='0 0 0'
               axis='1 0 0' />
      <geom name='cart' type='box' pos='0 0 0'
              size='0.2 0.1 0.05' rgba='0.7 0.7 0 1' mass="1" />
      <site name='cart sensor' type='box' pos='0 0 0'
              size='0.2 0.1 0.05' rgba='0.7 0.7 0 0' />
      <body name='pole' pos='0 0 0'>
        <camera name='pole'  pos='0 -2.5 0' quat='0.707 0.707 0 0' />
        <joint name='hinge' type='hinge' pos='0 0 0' axis='0 1 0'/>
        <geom name='cpole' type='capsule' fromto='0 0 0 0 0 0.6'
                size='0.01 0.6' rgba='0 0.7 0.7 1' mass="0.0001"/>
        <body name='ball' pos='0 0 0.6'>
          <geom name='ball' type='sphere' size='0.05' rgba='0 0.7 0.7 1' mass="0.2"/>
        </body>
      </body>
    </body>

  </worldbody>

  <actuator>
    <motor name='slide' joint='slider' gear='1' ctrllimited='false' />
  </actuator>

  <sensor>
    <accelerometer name="accelerometer" site="cart sensor"/>
    <touch name="collision" site="cart sensor"/>
  </sensor>

  <keyframe>
    <key name="hanging_down" qpos="0 3.14159"/>
  </keyframe>

</mujoco>
"""
cartpole = mujoco.MjModel.from_xml_string(xml)

In [240]:
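# Select the EGL backend for headless (GPU) rendering.
# MuJoCo reads MUJOCO_GL when it is first imported, so this line only takes
# effect if it runs before `import mujoco` (here it runs after).
os.environ["MUJOCO_GL"] = "egl"

model_name = "double_pendulum"

q_ref = [0.5, 1.0]  # reference joint positions (not used below)
model_path = Path("mujoco_models") / f"{model_name}.xml"
# Load the acrobot (double pendulum) model
acrobot = mujoco.MjModel.from_xml_path(str(model_path.absolute()))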

### Linearized Dynamics:
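MuJoCo's built-in finite-difference routine `mjd_transitionFD` linearizes the discrete-time dynamics around the current state (`data.qpos`, `data.qvel`) and control (`data.ctrl`), returning the matrices $A$ and $B$.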
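
`mjd_transitionFD` obtains $A$ and $B$ by perturbing each state and control component and differencing the resulting one-step transitions. The idea can be sketched in plain NumPy with centered differences. The example uses a hypothetical frictionless pendulum (angle measured from upright, semi-implicit Euler step). `pendulum_step` and its parameters are illustrative assumptions, not MuJoCo's integrator:

```python
import numpy as np

def pendulum_step(x, u, dt=0.01, g=9.81, l=1.0):
    """Hypothetical discrete dynamics: x = [theta, omega], theta measured from upright."""
    theta, omega = x
    omega_next = omega + dt * (g / l * np.sin(theta) + u[0])
    theta_next = theta + dt * omega_next
    return np.array([theta_next, omega_next])

def fd_jacobians(f, x, u, eps=1e-5):
    """Centered finite-difference Jacobians A = df/dx and B = df/du at (x, u)."""
    n, m = x.size, u.size
    A = np.zeros((n, n))
    B = np.zeros((n, m))
    for i in range(n):
        dx = np.zeros(n)
        dx[i] = eps
        A[:, i] = (f(x + dx, u) - f(x - dx, u)) / (2 * eps)
    for j in range(m):
        du = np.zeros(m)
        du[j] = eps
        B[:, j] = (f(x, u + du) - f(x, u - du)) / (2 * eps)
    return A, B

# Linearize about the upright equilibrium with zero input
A, B = fd_jacobians(pendulum_step, np.zeros(2), np.zeros(1))
print(A)
print(B)
```

One eigenvalue of `A` lies outside the unit circle. This reflects the instability of the upright equilibrium and is why feedback is needed.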

In [241]:
def discrete_jacobian(model, data):
    """
    Linearize the discrete-time dynamics x_{k+1} = f(x_k, u_k) about the
    current state (qpos, qvel) and control (ctrl) using finite differences.

    Returns A (n x n) and B (n x m) with n = 2 * nv and m = nu. This sizing
    assumes the model has no actuator activation states (model.na == 0).
    """
    n = model.nv * 2
    m = model.nu
    A = np.zeros((n, n))  # state Jacobian df/dx
    B = np.zeros((n, m))  # control Jacobian df/du
    mujoco.mj_forward(model, data)
    # Compute finite-difference Jacobians using MuJoCo
    eps = 1e-5
    flg_centered = 1  # centered differences: more accurate, roughly twice the cost
    mujoco.mjd_transitionFD(model, data, eps, flg_centered, A, B, None, None)

    return A, B

### Standard LQR Controller:

In [242]:
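def controller(model, data, x, A, B, Q, R):
    """Compute the LQR input for the state estimate x and write it to data.ctrl.

    The gain is recomputed from (A, B) on every call, i.e. the LQR problem is
    re-solved for the current linearization.
    """
    K, S, E = control.dlqr(A, B, Q, R)  # gain, Riccati solution, closed-loop eigenvalues
    u = -K @ x
    data.ctrl[0] = u[0]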
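
`control.dlqr` solves the discrete algebraic Riccati equation (DARE) for $S$. It returns $K = (R + B^\top S B)^{-1} B^\top S A$ together with $S$ and the closed-loop eigenvalues. A minimal sketch of the same computation iterates the Riccati recursion to a fixed point, a simple but slower alternative to a dedicated DARE solver. The `A`, `B` below are a hypothetical pendulum linearization (angle from upright, semi-implicit Euler, `dt = 0.01`, `g/l = 9.81`), not one of the MuJoCo models:

```python
import numpy as np

def dlqr_iterative(A, B, Q, R, max_iters=100_000, rtol=1e-10):
    """Discrete LQR via fixed-point iteration of the Riccati recursion.

    Returns (K, S) such that u = -K x is the LQR law and S solves the DARE.
    """
    S = np.array(Q, dtype=float)
    for _ in range(max_iters):
        # Gain for the current cost-to-go: K = (R + B'SB)^{-1} B'SA
        K = np.linalg.solve(R + B.T @ S @ B, B.T @ S @ A)
        # Riccati update: S' = Q + A'S(A - BK)
        S_next = Q + A.T @ S @ (A - B @ K)
        S_next = 0.5 * (S_next + S_next.T)  # keep S symmetric despite round-off
        done = np.max(np.abs(S_next - S)) <= rtol * np.max(np.abs(S_next))
        S = S_next
        if done:
            break
    else:
        raise RuntimeError("Riccati iteration did not converge")
    K = np.linalg.solve(R + B.T @ S @ B, B.T @ S @ A)
    return K, S

dt, g_over_l = 0.01, 9.81
A = np.array([[1 + dt * dt * g_over_l, dt],
              [dt * g_over_l, 1.0]])
B = np.array([[dt * dt],
              [dt]])
Q = np.eye(2)
R = np.array([[1.0]])

K, S = dlqr_iterative(A, B, Q, R)
print(K, np.max(np.abs(np.linalg.eigvals(A - B @ K))))  # spectral radius < 1
```

Because `controller` re-solves the DARE for a fresh linearization at every step, the problem it solves changes as the state moves away from upright. The `LinAlgError: Failed to find a finite solution.` at the end of the run below matches the message SciPy's `solve_discrete_are` raises when it cannot compute a finite Riccati solution.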

### Kalman Filter:

Because the Kalman Filter is recursive, we only need to keep track of the current state estimate and its covariance. Everything else can be computed inside a function that is called once per step to update these two quantities. The following function performs that operation.
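
Written out, one call of the update below computes the following. $\hat{x}^-_k$ and $P^-_k$ correspond to `x_pred` and `P_pred`, $y_k$ and $S_k$ to `y` and `S`, $L_k$ to the Kalman gain `L`, and $\hat{x}_k$ and $P_k$ to `x_post` and `P_post`. In the standard indexing, the prediction uses the control $u_{k-1}$ applied since the previous measurement:

$$
\begin{aligned}
\hat{x}^-_k &= A\,\hat{x}_{k-1} + B\,u_{k-1}, & P^-_k &= A\,P_{k-1}A^\top + W,\\
y_k &= z_k - C\,\hat{x}^-_k, & S_k &= C\,P^-_k C^\top + V,\\
L_k &= P^-_k C^\top S_k^{-1}, & &\\
\hat{x}_k &= \hat{x}^-_k + L_k\,y_k, & P_k &= (I - L_k C)\,P^-_k.
\end{aligned}
$$

The code evaluates the gain as `np.linalg.solve(S, C @ P_pred).T`, which relies on the symmetry of $S_k$ and $P^-_k$ to avoid forming an explicit inverse.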

In [243]:
def kalman_filter_update(x_prev, P_prev, z, u, A, B, C, W, V):
    """
    Perform one step of the Kalman Filter update.
    
    Parameters:
    - x_prev: np.array, previous state estimate (n,)
    - P_prev: np.array, previous covariance estimate (n,n)
    - z: np.array, current measurement (m,)
    - u: np.array, control input (l,)
    - A: np.array, state transition matrix (n,n)
    - B: np.array, control input matrix (n,l)
    - C: np.array, observation matrix (m,n)
    - W: np.array, process noise covariance (n,n)
    - V: np.array, measurement noise covariance (m,m)

    Returns:
    - x_post: np.array, updated state estimate (n,)
    - P_post: np.array, updated covariance estimate (n,n)
    """
    
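    # Prediction: propagate the estimate and its covariance through the linear model
    x_pred = A @ x_prev + B @ u
    P_pred = A @ P_prev @ A.T + W

    # Innovation (measurement residual) and its covariance
    y = z - C @ x_pred
    S = C @ P_pred @ C.T + V

    # Kalman gain L = P_pred C^T S^{-1}, computed with a linear solve
    # (valid because S and P_pred are symmetric)
    L = (np.linalg.solve(S, C @ P_pred)).T

    # Update: correct the prediction with the weighted innovation
    x_post = x_pred + L @ y
    P_post = (np.eye(P_prev.shape[0]) - L @ C) @ P_pred

    return x_post, P_post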
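
Here is a quick sanity check of `kalman_filter_update` on a case where the answer is known. A hypothetical 1-D cart coasts at constant velocity and is measured in position only, mirroring challenge 1. The function is repeated in condensed form so the snippet runs on its own. The system, noise levels, and random seed are illustrative assumptions:

```python
import numpy as np

def kalman_filter_update(x_prev, P_prev, z, u, A, B, C, W, V):
    """Condensed copy of the predict/correct step defined above."""
    x_pred = A @ x_prev + B @ u
    P_pred = A @ P_prev @ A.T + W
    y = z - C @ x_pred
    S = C @ P_pred @ C.T + V
    L = np.linalg.solve(S, C @ P_pred).T
    x_post = x_pred + L @ y
    P_post = (np.eye(P_prev.shape[0]) - L @ C) @ P_pred
    return x_post, P_post

rng = np.random.default_rng(0)
dt = 0.01
A = np.array([[1.0, dt],
              [0.0, 1.0]])          # state [position, velocity], constant velocity
B = np.zeros((2, 1))                # no actuation in this check
C = np.array([[1.0, 0.0]])          # position-only measurement
W = 1e-8 * np.eye(2)                # small assumed process noise
V = np.array([[1e-4]])              # measurement std of 0.01

x_true = np.array([0.0, 0.5])       # true velocity 0.5, never measured directly
x_hat, P = np.zeros(2), np.eye(2)   # broad prior
for _ in range(500):
    x_true = A @ x_true
    z = C @ x_true + rng.normal(0.0, np.sqrt(V[0, 0]), size=1)
    x_hat, P = kalman_filter_update(x_hat, P, z, np.zeros(1), A, B, C, W, V)

print(x_hat[1])  # velocity estimate, close to 0.5
```

The velocity is never measured, yet its estimate converges. The prediction step couples position and velocity through `A`, so a consistent drift in the measured positions can only be explained by a nonzero velocity.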


### Simulate Model
Feel free to try both models and play around with initial conditions, noise variances, LQR gains, etc.

In [249]:
def simulate_model(model):
    data = mujoco.MjData(model)

    # Initial Joint Positions
    data.qpos[0] = 0.0
    data.qpos[1] = 0.0
    
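    # Process noise variance (white noise added to the actuator input)
    W_var = 0.0001

    # Observation noise variance (per measured joint position)
    V_var = 0.0005
    V = V_var * np.eye(2)
    
    # Initialize filter state: zero-mean estimate with unit covariance
    P_prev = np.eye(4)
    x_prev = np.zeros(4)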

    # Measurement Model (only measure positions)
    C = np.hstack((np.eye(2), np.zeros((2,2))))

    # LQR Gains
    Q = np.diag([1,1,1,1])
    R = 1

    model.dof_damping[:] = 0.00
    with mujoco.viewer.launch_passive(model, data) as viewer:
        viewer.cam.azimuth = 90
        viewer.cam.elevation = 5
        viewer.cam.distance =  8
        viewer.cam.lookat = np.array([0.0 , 0.0 , 0.0])
        while viewer.is_running():
            step_start = time.time()

            # Linearize about the current simulator state (the true state, not the estimate x_prev)
            A, B = discrete_jacobian(model, data)

            # LQR control computed from the previous state estimate, not the true state
            controller(model, data, x_prev, A, B, Q, R)

            # Take a "measurement" (here we just add noise)
            z = data.qpos + np.random.normal(0.0, np.sqrt(V_var), 2)

            # Process noise is white noise through the actuator, so we model its covariance as follows:
            W = B @ B.T * W_var

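            # Kalman step. Note the timing: the prediction A @ x_prev + B @ u advances the
            # estimate one step using the control about to be applied, while z measures the
            # state before that control acts, so x_post mixes a one-step-ahead prediction
            # with a measurement of the current state.
            x_post, P_post = kalman_filter_update(x_prev, P_prev, z, data.ctrl, A, B, C, W, V)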

            
            print("Estimation Error: ", x_post - np.hstack((data.qpos, data.qvel)))

            # store new estimations
            x_prev = x_post
            P_prev = P_post
            
            # Add process noise: perturb the applied input (the filter only saw the commanded value)
            data.ctrl += np.random.normal(0.0, np.sqrt(W_var))

            # Step simulation
            mujoco.mj_step(model, data)
            

            # Pick up changes to the physics state, apply perturbations, update options from GUI.
            viewer.sync()

            # Rudimentary time keeping, will drift relative to wall clock.
            time_until_next_step = model.opt.timestep - (time.time() - step_start)
            if time_until_next_step > 0:
                time.sleep(time_until_next_step)
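
The process-noise covariance `W = B @ B.T * W_var` follows from modeling the noise as a perturbation of the commanded input. If the applied control is $u_k + \delta_k$ with $\delta_k \sim \mathcal{N}(0, \sigma^2 I)$, the state picks up $B\delta_k$, whose covariance is $\sigma^2 B B^\top$. Below is a short Monte Carlo check under that assumption, with an arbitrary hypothetical `B` (four states, one actuator):

```python
import numpy as np

rng = np.random.default_rng(1)
sigma2 = 1e-4                                   # plays the role of W_var
B = np.array([[0.0], [0.0], [0.01], [-0.02]])   # hypothetical (n=4, m=1) input matrix

deltas = rng.normal(0.0, np.sqrt(sigma2), size=(200_000, 1))  # actuator noise samples
state_noise = deltas @ B.T                                    # each row is B @ delta
W_empirical = np.cov(state_noise, rowvar=False)
W_model = sigma2 * B @ B.T

print(np.max(np.abs(W_empirical - W_model)))  # small relative to the entries of W_model
```

`W` has rank equal to the number of actuators (here one), so it is singular. The filter still works because the measurement noise `V` keeps the innovation covariance `S` invertible.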

In [250]:
if __name__ == "__main__":
    # Simulate the model
    # simulate_model(cartpole)
    simulate_model(acrobot)

Estimation Error:  [ 0.02177354 -0.02584596  0.00617334 -0.01571513]
Estimation Error:  [ 0.01267785 -0.03049425 -0.27797149  0.3304579 ]
Estimation Error:  [ 0.01486348 -0.03636049  0.4590458  -1.82300108]
Estimation Error:  [ 0.026403   -0.03574127  1.1851668  -2.84371231]
Estimation Error:  [ 7.22416808e-04 -3.64894473e-03 -9.68609099e-01  2.33893983e+00]
Estimation Error:  [-0.01273336  0.01813709 -1.31346571  3.0872093 ]
Estimation Error:  [  0.04304171  -0.13610193   7.7019237  -25.03633021]
Estimation Error:  [-0.02888793  0.10968114 -9.55398042 28.00589965]
Estimation Error:  [-0.0686771   0.28626597 -8.00057789 24.73952491]
Estimation Error:  [-0.11172828  0.42996778 -6.68618695 21.7638616 ]
Estimation Error:  [-0.10851163  0.46401025 -3.00820647 11.14554937]
Estimation Error:  [-0.0946052   0.41848928 -0.76200046  3.68522389]
Estimation Error:  [-0.06509152  0.33249122  0.59572656 -0.80029857]
Estimation Error:  [-0.051109    0.23101264  1.14307822 -3.37820747]
Estimation Err

LinAlgError: Failed to find a finite solution.