# Volume 4: Linear Quadratic Gaussian Control
    <Name>
    <Class>
    <Date>

In [None]:
import numpy as np
from jax import numpy as jnp
from scipy.linalg import inv
from matplotlib import pyplot as plt

from utils import Simulator, Estimator, Controller
from animate import animate2d

# Plotting Helper Function

In [None]:
def plot(
    T: float,
    N: int,
    xs: np.ndarray,
    us: np.ndarray,
    ests: np.ndarray | None = None,
):
    r"""Plot position, velocity, and controls in x, y, and z.

    Draws a 3x3 grid of subplots against time ``np.linspace(0, T, N + 1)``:
    row 0 holds the positions (sx, sy, sz), row 1 the velocities
    (vx, vy, vz), and row 2 the controls (ux, uy, uz). Every subplot gets a
    horizontal reference line at zero.

    Parameters
    ----------
        T
            length of time
        N
            number of steps 0, ..., N
        xs
            sequence of states x_k = (sx, sy, sz, vx, vy, vz) for k in
            (0, ..., N)
            shape: (N+1, 6)
        us
            sequence of controls u_k = (ux, uy, uz) for k in (0, ..., N-1)
            shape: (N, 3)
        ests
            optional sequence of state estimates \hat x_k with the same
            coordinates and shape as xs; drawn as small scatter points over
            the true position/velocity curves

    Returns
    -------
        fig
            the pyplot figure
        axs
            the (3, 3) array of pyplot axes
    """
    fig, axs = plt.subplots(3, 3, figsize=(8, 6))
    marker_size = 0.1  # tiny markers so estimates don't hide the true curves

    tls = np.linspace(0, T, N + 1)

    var_list = ["$x$", "$y$", "$z$"]
    word_list = ["Position", "Velocity", "Control"]

    # Position (row 0) uses state columns 0-2; velocity (row 1) uses 3-5.
    for i in range(3):
        axs[0, i].plot(tls, xs[:, i], color="blue")
        axs[1, i].plot(tls, xs[:, i + 3], color="orange")
        axs[0, i].set_title(var_list[i])
        axs[i, 0].set_ylabel(word_list[i])

    # Overlay state estimates (if provided).
    if ests is not None:
        for i in range(3):
            axs[0, i].scatter(tls, ests[:, i], color="blue", s=marker_size)
            axs[1, i].scatter(
                tls, ests[:, i + 3], color="orange", s=marker_size
            )

    # There is one fewer control than states, so controls are drawn at the
    # first N time points t_0, ..., t_{N-1}.
    for i in range(3):
        axs[2, i].scatter(tls[:-1], us[:, i], color="red", s=1)
        axs[2, i].set_xlabel("Time")

    # Zero reference line on every subplot.
    for ax in axs.flat:
        ax.axhline(0, color="black", lw=1)

    fig.suptitle("Return of the Shuttle")
    fig.tight_layout()

    # Put the three row labels at the same horizontal offset so they line up.
    for row in range(3):
        axs[row, 0].get_yaxis().set_label_coords(-0.26, 0.5)

    return fig, axs

# Problem 1

In [None]:
class LQR:
    """Linear quadratic regulator (template stub).

    Given a transition system and its cost matrices, compute an optimal
    trajectory, the control associated with it, and an optimal feedback rule
    that maps the current state to a control. Every method currently raises
    ``NotImplementedError``.
    """

    def __init__(self):
        raise NotImplementedError

    def fit(self):
        """Compute the gain matrices Ks and save them.

        Not implemented yet; raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def compute_control(self):
        """Return the optimal control u_k.

        Intended parameters (not yet part of the signature)
        ----------------------------------------------------
            k : int
                index of the state vector
            x_k : ndarray
                state at index k

        Returns
        -------
            u_k : ndarray
                optimal control at index k
        """
        raise NotImplementedError

# Problem 2

# Problem 3

# Problem 4

In [None]:
class KalmanFilter:
    """Kalman filter (template stub).

    Given observations and a model, estimate the true states by blending the
    model's prediction with the observations according to the approximated
    error. Every method currently raises ``NotImplementedError``.
    """

    def __init__(self):
        raise NotImplementedError

    def fit(self):
        """Compute the gain matrices Lks and save them.

        Not implemented yet; raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def predict_state(self):
        """Return the next predicted state.

        Intended parameters (not yet part of the signature)
        ----------------------------------------------------
            x_k : ndarray
                the last estimated state
            u_k : ndarray
                the control at the given index

        Returns
        -------
            x_k1|k : ndarray
                the next predicted state
        """
        raise NotImplementedError

    def update_state(self):
        """Return the next estimated state.

        Intended parameters (not yet part of the signature)
        ----------------------------------------------------
            k : int
                index of the state
            x_k1|k : ndarray
                the next predicted state
            z_k1 : ndarray
                the observations at the given index

        Returns
        -------
            x_k1 : ndarray
                the next estimated state
        """
        raise NotImplementedError

# Problem 5