# Exercise 4: Cartpole iLQR Control
In this problem we consider controlling a cart-pole system, which consists of a "cart" that can travel linearly along a one-dimensional track, and has a pendulum attached to it.
The objective is to implement a controller that solves the "swing up" problem: bring the pendulum from the downward hanging position to a stabilized upright position.

In [None]:
import time
import matplotlib.pyplot as plt
import numpy as np
import jax

from cartpole_utils import *

#### Exercise 4.1: Linearizing the Cartpole Dynamics
Implement the function `linearize` to linearize the function `f(s,u)` about the point `(s,u)` using JAX.

In [None]:
def linearize(f, s, u):
    """
    Linearize the function `f(s, u)` around `(s, u)`.

    Arguments
    ---------
    f : callable
        A nonlinear, JAX-differentiable function with call signature
        `f(s, u)`.
    s : numpy.ndarray
        The state (1-D). Must be floating point, since JAX cannot
        differentiate with respect to integer inputs.
    u : numpy.ndarray
        The control input (1-D). Must be floating point.

    Returns
    -------
    A : jax.Array
        The Jacobian of `f` at `(s, u)`, with respect to `s`.
    B : jax.Array
        The Jacobian of `f` at `(s, u)`, with respect to `u`.
    """
    ############################## Code starts here ##############################
    # Passing a tuple to `argnums` makes JAX return one Jacobian per listed
    # argument, so both matrices come out of a single call.
    A, B = jax.jacobian(f, argnums=(0, 1))(s, u)
    ############################## Code ends here ##############################
    return A, B

#### Exercise 4.3 and 4.4: Implement iLQR
Implement the function `discreteIterativeLQR.solve` below to compute the iLQR controller gain matrices and offset terms. Note here we are referring to the gain matrix as $Y$ and the offset term as $y$ to avoid confusion with respect to the time step variable $k$. Concretely, the controller we are looking to define is of the form:
\begin{equation}
    u_k = \bar u_k - y_k - Y_k(s_k - \bar s_k),
\end{equation}

Then, implement the function `discreteIterativeLQR.compute_control` to compute the control value based on either the open or closed-loop case.

In [None]:
class discreteIterativeLQR(object):
    """
    Discrete-time iterative LQR (iLQR) set-point tracking controller.

    The controller minimizes the quadratic cost

        1/2 (s_N - s_goal)^T QN (s_N - s_goal)
        + sum_{k<N} 1/2 (s_k - s_goal)^T Q (s_k - s_goal) + 1/2 u_k^T R u_k

    and produces the time-varying affine policy

        u_k = u_bar[k] - y[k] - Y[k] @ (s_k - s_bar[k]).
    """

    def __init__(self, N, Q, R, QN, eps=1e-3, max_iters=1000):
        """
        Initialize the discrete-time iLQR problem with cost function and horizon.

        Arguments
        ---------
        N : int
            The time horizon of the LQR cost function.
        Q : numpy.ndarray
            The state cost matrix (2-D).
        R : numpy.ndarray
            The control cost matrix (2-D).
        QN : numpy.ndarray
            The terminal state cost matrix (2-D).
        eps : float, optional
            Termination threshold for iLQR.
        max_iters : int, optional
            Maximum number of iLQR iterations.

        Raises
        ------
        ValueError
            If `max_iters` is smaller than 1.
        """
        # A single iteration is a legitimate budget, so only reject values
        # strictly below 1 (this matches the error message).
        if max_iters < 1:
            raise ValueError("Argument `max_iters` must be at least 1.")
        self.N = N
        self.Q = Q
        self.R = R
        self.QN = QN
        self.eps = eps
        self.max_iters = max_iters
        self.n = Q.shape[0]  # state dimension
        self.m = R.shape[0]  # control dimension

        # Initialize gains `Y` and offsets `y` for the policy
        self.Y = np.zeros((self.N, self.m, self.n))
        self.y = np.zeros((self.N, self.m))

        # Initialize the nominal trajectory `(s_bar, u_bar)`
        self.u_bar = np.zeros((self.N, self.m))
        self.s_bar = np.zeros((self.N + 1, self.n))

    def solve(self, f, s0, s_goal):
        """
        Compute the iLQR set-point tracking solution for given discrete-time dynamics and start
        goal states. Sets the s_bar, u_bar, Y, and y member variables and also returns them.

        Each iteration linearizes the dynamics along the current nominal
        trajectory, runs a backward Riccati-style recursion to obtain the
        policy `du[k] = -y[k] - Y[k] @ ds[k]`, and then rolls that policy out
        on the nonlinear dynamics to obtain the next nominal trajectory.
        Iteration stops once the largest control update is below `eps`.

        Arguments
        ---------
        f : callable
            A function describing the discrete-time dynamics, such that
            `s[k+1] = f(s[k], u[k])`.
        s0 : numpy.ndarray
            The initial state (1-D).
        s_goal : numpy.ndarray
            The goal state (1-D).

        Returns
        -------
        s_bar : numpy.ndarray
            A 2-D array where `s_bar[k]` is the nominal state at time step `k`,
            for `k = 0, 1, ..., N`
        u_bar : numpy.ndarray
            A 2-D array where `u_bar[k]` is the nominal control at time step `k`,
            for `k = 0, 1, ..., N-1`
        Y : numpy.ndarray
            A 3-D array where `Y[k]` is the matrix gain term of the iLQR control
            law at time step `k`, for `k = 0, 1, ..., N-1`
        y : numpy.ndarray
            A 2-D array where `y[k]` is the offset term of the iLQR control law
            at time step `k`, for `k = 0, 1, ..., N-1`

        Raises
        ------
        RuntimeError
            If the control update does not fall below `eps` within
            `max_iters` iterations.
        """
        n, m, N = self.n, self.m, self.N
        Q, R, QN = self.Q, self.R, self.QN
        # Initialize gains `Y` and offsets `y` for the policy
        Y = np.zeros((N, m, n))
        y = np.zeros((N, m))

        # Initialize the nominal trajectory `(s_bar, u_bar)`, and the
        # deviations `(ds, du)`
        u_bar = np.zeros((N, m))
        s_bar = np.zeros((N + 1, n))
        s_bar[0] = s0
        for k in range(N):
            s_bar[k + 1] = f(s_bar[k], u_bar[k])
        ds = np.zeros((N + 1, n))
        du = np.zeros((N, m))

        # iLQR loop
        converged = False
        for _ in range(self.max_iters):
            # Linearize the dynamics at each step `k` of `(s_bar, u_bar)`
            A, B = jax.vmap(linearize, in_axes=(None, 0, 0))(f, s_bar[:-1], u_bar)
            A, B = np.array(A), np.array(B)

            ############################## Code starts here ##############################
            # Update `Y`, `y`, `ds`, `du`, `s_bar`, and `u_bar`.

            # Backward pass: propagate the quadratic value-function model
            # `1/2 ds^T V ds + v^T ds` from the terminal cost back to k = 0.
            V = QN
            v = QN @ (s_bar[N] - s_goal)
            for k in reversed(range(N)):
                Ak, Bk = A[k], B[k]
                # Gradient and Hessian blocks of the state-action value.
                q_s = Q @ (s_bar[k] - s_goal) + Ak.T @ v
                q_u = R @ u_bar[k] + Bk.T @ v
                Q_ss = Q + Ak.T @ V @ Ak
                Q_uu = R + Bk.T @ V @ Bk
                Q_us = Bk.T @ V @ Ak
                # Minimizing over `du` gives `du = -y - Y ds`; solving the
                # linear systems avoids forming an explicit inverse.
                Y[k] = np.linalg.solve(Q_uu, Q_us)
                y[k] = np.linalg.solve(Q_uu, q_u)
                v = q_s - Q_us.T @ y[k]
                V = Q_ss - Q_us.T @ Y[k]
                # Remove round-off asymmetry so it cannot accumulate over
                # the horizon.
                V = 0.5 * (V + V.T)

            # Forward pass: roll the updated policy out on the nonlinear
            # dynamics, measuring deviations from the previous nominal.
            s_new = np.zeros((N + 1, n))
            u_new = np.zeros((N, m))
            s_new[0] = s0
            for k in range(N):
                ds[k] = s_new[k] - s_bar[k]
                du[k] = -y[k] - Y[k] @ ds[k]
                u_new[k] = u_bar[k] + du[k]
                s_new[k + 1] = f(s_new[k], u_new[k])
            ds[N] = s_new[N] - s_bar[N]
            s_bar, u_bar = s_new, u_new
            ############################## Code ends here ##############################

            if np.max(np.abs(du)) < self.eps:
                converged = True
                break
        if not converged:
            raise RuntimeError("iLQR did not converge!")
        self.s_bar = s_bar
        self.u_bar = u_bar
        self.Y = Y
        self.y = y
        return s_bar, u_bar, Y, y

    def compute_control(self, k: int, s: np.ndarray, closed_loop: bool = True) -> np.ndarray:
        """
        Compute the control value for the iLQR controller. If `closed_loop` is false
        just return the open loop u_bar value.

        Arguments
        ---------
        k: current time step
        s: current state
        closed_loop: whether to run in closed-loop or not

        Returns
        -------
        u: control value to apply
        """
        ############################## Code starts here ##############################
        # Copy so the caller can never mutate the stored nominal controls.
        u = np.array(self.u_bar[k])
        if closed_loop:
            # Affine feedback law: u = u_bar - y - Y (s - s_bar).
            u = u - self.y[k] - self.Y[k] @ (s - self.s_bar[k])
        ############################## Code ends here ##############################
        return u

Finally, run the code below to simulate the iLQR controller in both open-loop and closed-loop.

In [None]:
# Physical parameters of the cart-pole system
mp = 2.0  # pendulum mass
mc = 10.0  # cart mass
L = 1.0  # pendulum length
g = 9.81  # gravitational acceleration
cartpole = CartPole(mp, mc, L, g)

# Boundary states and timing of the swing-up problem
s0 = np.zeros(4)  # initial state
s_goal = np.array([0.0, np.pi, 0.0, 0.0])  # goal state
T = 10.0  # simulation time
dt = 0.1  # sampling time

# Dimensions, horizon, and cost weights of the iLQR problem
n = 4  # state dimension
m = 1  # control dimension
t = np.arange(0.0, T, dt)
N = len(t) - 1
Q = np.diag([10.0, 10.0, 2.0, 2.0])  # state cost matrix
R = 1e-2 * np.eye(m)  # control cost matrix
QN = 1e2 * np.eye(n)  # terminal state cost matrix
ilqr = discreteIterativeLQR(N, Q, R, QN)

# Continuous-time dynamics; the cart-pole object is passed as a static argument
f = jax.jit(cartpole_dynamics, static_argnums=(2,))


# Discretized dynamics: a single forward-Euler step of the continuous-time model
@jax.jit
def fd(s, u, dt=dt):
    return s + dt * f(s, u, cartpole)


# Run iLQR on the discretized dynamics and report the wall-clock time taken
print("Computing iLQR solution ... ", end="", flush=True)
start = time.time()
ilqr.solve(fd, s0, s_goal)
print("done! ({:.2f} s)".format(time.time() - start), flush=True)

In [None]:
# Simulate on the true continuous-time system and plot results
# This cell is the open-loop run: with `closed_loop=False` the controller's
# `compute_control` is documented to return just the nominal `u_bar` value.
# `simulate_cartpole`, `plot_state_and_control_history` and `animate_cartpole`
# are not defined in this notebook; presumably they come from the
# `cartpole_utils` star import -- confirm.
closed_loop = False  # open-loop
s, u = simulate_cartpole(cartpole, t, s0, ilqr, closed_loop=closed_loop)
# NOTE(review): the closed-loop cell passes the same "cartpole_swingup" label.
# If the helper saves a figure under that name, the two runs would overwrite
# each other -- confirm against cartpole_utils.
plot_state_and_control_history(s, u, t, ilqr.s_bar, "cartpole_swingup")
# Columns 0 and 1 of `s` are presumably the cart position and the pendulum
# angle (the goal state puts `pi` in component 1) -- TODO confirm.
animate_cartpole(t, s[:, 0], s[:, 1])

In [None]:
# Simulate on the true continuous-time system and plot results
# This cell is the closed-loop run: `closed_loop=True` asks the controller's
# `compute_control` for the closed-loop control value instead of the
# open-loop `u_bar` value.
# `simulate_cartpole`, `plot_state_and_control_history` and `animate_cartpole`
# are not defined in this notebook; presumably they come from the
# `cartpole_utils` star import -- confirm.
closed_loop = True  # closed-loop
s, u = simulate_cartpole(cartpole, t, s0, ilqr, closed_loop=closed_loop)
# NOTE(review): the open-loop cell passes the same "cartpole_swingup" label.
# If the helper saves a figure under that name, the two runs would overwrite
# each other -- confirm against cartpole_utils.
plot_state_and_control_history(s, u, t, ilqr.s_bar, "cartpole_swingup")
# Columns 0 and 1 of `s` are presumably the cart position and the pendulum
# angle (the goal state puts `pi` in component 1) -- TODO confirm.
animate_cartpole(t, s[:, 0], s[:, 1])