# Unicycle Optimal Control

Consider the kinematic model of a unicycle:
\begin{equation}
\dot{x} = v \cos(\theta), \quad \dot{y} = v \sin(\theta), \quad \dot{\theta} = \omega.
\end{equation}
where $(x,y)$ is the position, $\theta$ is the heading angle with respect to the x-axis, and the control inputs are the speed $v$ and the angular velocity $\omega$. We define the state vector $\textbf{x} = (x, y, \theta)$.
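
As a worked example of the model (an illustration only, not part of the exercise): holding $v$ and $\omega \neq 0$ constant gives $\theta(t) = \theta_0 + \omega t$. Integrating the first two equations from $(x_0, y_0, \theta_0)$ then yields a circular arc of radius $|v/\omega|$:
\begin{equation}
x(t) = x_0 + \frac{v}{\omega}\left(\sin(\theta_0 + \omega t) - \sin\theta_0\right), \quad
y(t) = y_0 - \frac{v}{\omega}\left(\cos(\theta_0 + \omega t) - \cos\theta_0\right).
\end{equation}
For instance, $v = \omega = 1$ starting from $(0, 0, \pi/2)$ traces a left half-circle and reaches $(-2, 0, 3\pi/2)$ at $t = \pi$.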

In this problem we will compute the optimal control to drive the unicycle vehicle from a starting configuration $\textbf{x}(0) = (0, 0, \pi/2)$ to the target configuration $\textbf{x}(T) = (5, 5, \pi/2)$. We seek to find a solution that minimizes time to the target configuration while also minimizing the control effort required. Specifically, we want to minimize:
\begin{equation}
J(x,u) = \int_0^T \left( \alpha + v(t)^2 + \omega(t)^2 \right) dt,
\end{equation}
where $\alpha > 0$ is a constant weighting factor that trades off travel time against control effort, and $T$ is the free final time.
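
To see how $\alpha$ trades travel time against control effort, consider a simplified, hypothetical case that is not this problem's optimal path: a straight leg of length $d$ driven at constant speed $v$ with $\omega = 0$. Then $T = d/v$ and $J = (\alpha + v^2)\,d/v$, and setting $dJ/dv = d\,(1 - \alpha/v^2) = 0$ gives $v = \sqrt{\alpha}$. In other words, a larger $\alpha$ (time is more expensive) favors driving faster. The sketch below checks this numerically; `straight_line_cost` is a hypothetical helper.

```python
import jax.numpy as jnp


def straight_line_cost(v, d, alpha):
    """J for a straight leg of length d at constant speed v with ω = 0 (illustrative only)."""
    T = d / v                     # travel time grows as the speed drops
    return (alpha + v**2) * T     # constant integrand α + v² + 0² integrated over [0, T]


alpha = 0.25                                # the weight used later in `hamiltonian`
d = 5.0                                     # hypothetical leg length [m]
speeds = jnp.linspace(0.1, 2.0, 1901)       # candidate speeds, 0.001 m/s apart
costs = straight_line_cost(speeds, d, alpha)
v_best = speeds[jnp.argmin(costs)]
print(v_best, jnp.sqrt(alpha))              # both approximately 0.5
```

The minimum cost in this toy case is $2d\sqrt{\alpha}$, which the grid search reproduces.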

#### Exercise 5.2: Define the dynamics, Hamiltonian, and necessary optimality conditions

NOTE: It can be worthwhile to use both `jax.numpy` and regular `numpy`, since `numpy` operations are typically much faster than their `jax.numpy` counterparts (at least without `jax.jit`) whenever JAX-specific features (e.g., automatic differentiation, vectorization) are not required. In this exercise, however, we encourage you to use `jax.numpy` for practice.

In [None]:
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
import matplotlib.pyplot as plt

First, implement the function `unicycle_dynamics`:

In [None]:
def unicycle_dynamics(x: jnp.array, u: jnp.array) -> jnp.array:
    """
    Evaluate the continuous-time dynamics of a unicycle.

    Parameters
    ----------
    x : An array of shape (3,) containing the unicycle pose (x, y, θ).
    u : An array of shape (2,) containing the velocity controls (v, ω).

    Returns
    -------
    dx/dt : An array of shape (3,) containing the time derivative of the state (dx/dt, dy/dt, dθ/dt)
    """
    _, _, θ = x
    v, ω = u

    return dx
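
One possible way to fill in the template, given as a minimal sketch: stack the three right-hand sides of the kinematic model into a single array.

```python
import jax.numpy as jnp


def unicycle_dynamics(x: jnp.ndarray, u: jnp.ndarray) -> jnp.ndarray:
    """Right-hand side f(x, u) of the unicycle kinematics."""
    _, _, θ = x
    v, ω = u
    dx = jnp.array([v * jnp.cos(θ),   # dx/dt
                    v * jnp.sin(θ),   # dy/dt
                    ω])               # dθ/dt
    return dx


# Heading straight up (θ = π/2) at 1 m/s while turning at 0.5 rad/s:
print(unicycle_dynamics(jnp.array([0., 0., jnp.pi / 2]), jnp.array([1., 0.5])))
```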

Next, derive the Hamiltonian for the optimal control problem and implement the function `hamiltonian` below to compute it:

In [None]:
def hamiltonian(x: jnp.array, p: jnp.array, u: jnp.array) -> float:
    """
    Evaluate the Hamiltonian.

    Parameters
    ----------
    x : An array of shape (3,) containing the unicycle pose (x, y, θ).
    p : An array of shape (3,) containing the co-state (px, py, pθ).
    u : An array of shape (2,) containing the velocity controls (v, ω).

    Returns
    -------
    The value of the Hamiltonian.
    """
    α = 0.25 # we will keep this value of α constant

    return H
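
With running cost $\alpha + v^2 + \omega^2$ and dynamics $f(\textbf{x}, u)$, the Hamiltonian is $H = \alpha + v^2 + \omega^2 + p^\top f(\textbf{x}, u) = \alpha + v^2 + \omega^2 + p_x v\cos\theta + p_y v\sin\theta + p_\theta\,\omega$. A minimal sketch of `hamiltonian` along these lines, with the dynamics written inline so the block runs on its own:

```python
import jax.numpy as jnp


def hamiltonian(x, p, u):
    """H(x, p, u) = α + v² + ω² + pᵀ f(x, u) for the unicycle."""
    α = 0.25  # same constant weight as in the template
    _, _, θ = x
    v, ω = u
    f = jnp.array([v * jnp.cos(θ), v * jnp.sin(θ), ω])  # unicycle dynamics f(x, u)
    H = α + v**2 + ω**2 + jnp.dot(p, f)
    return H


print(hamiltonian(jnp.array([0., 0., 0.]), jnp.array([1., 2., 3.]), jnp.array([1., 1.])))  # 6.25
```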

From the Hamiltonian, we can now derive the optimal control $u^*(t)$, together with the ordinary differential equations and boundary conditions that constitute the necessary optimality conditions (NOCs). Implement $u^*$ as a function of the state and co-state in `optimal_control`, and the NOC ordinary differential equations in `noc_ode`.

In [None]:
def optimal_control(x: jnp.array, p: jnp.array) -> jnp.array:
    """
    Compute an optimal control as a function of the state and co-state.

    Parameters
    ----------
    x : An array of shape (3,) containing the unicycle pose (x, y, θ).
    p : An array of shape (3,) containing the co-state (px, py, pθ).

    Returns
    -------
    An array of shape (2,) containing an optimal control u* = (v*, ω*).
    """
    _, _, θ = x
    px, py, pθ = p

    return u
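
Since $u = (v, \omega)$ is unconstrained and $H$ is strictly convex in $u$ (its Hessian with respect to $u$ is $2I$), the optimal control is the stationary point of $\partial H/\partial u = 0$: $v^* = -\tfrac{1}{2}\left(p_x\cos\theta + p_y\sin\theta\right)$ and $\omega^* = -\tfrac{1}{2}p_\theta$. Below is a sketch of `optimal_control` based on this derivation. Here `jax.grad` serves only to confirm stationarity at one hypothetical $(\textbf{x}, p)$:

```python
import jax
import jax.numpy as jnp


def optimal_control(x, p):
    """Minimizer of H over unconstrained u = (v, ω), obtained from ∂H/∂u = 0."""
    _, _, θ = x
    px, py, pθ = p
    v = -0.5 * (px * jnp.cos(θ) + py * jnp.sin(θ))
    ω = -0.5 * pθ
    u = jnp.array([v, ω])
    return u


def hamiltonian(x, p, u, α=0.25):
    """H = α + v² + ω² + p·f(x, u), repeated here so the check runs on its own."""
    _, _, θ = x
    v, ω = u
    return α + v**2 + ω**2 + p[0] * v * jnp.cos(θ) + p[1] * v * jnp.sin(θ) + p[2] * ω


# Stationarity check at an arbitrary (hypothetical) state and co-state.
x = jnp.array([1.0, 2.0, 0.3])
p = jnp.array([0.7, -1.2, 0.4])
u_star = optimal_control(x, p)
grad_u = jax.grad(hamiltonian, argnums=2)(x, p, u_star)
print(grad_u)  # approximately [0, 0]
```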

In [None]:
def noc_ode(x_and_p: tuple[jnp.array, jnp.array],
            t: float) -> tuple[jnp.array, jnp.array]:
    """
    Evaluate the ODE that an optimal state and co-state must obey.

    Parameters
    ----------
    x_and_p : A tuple of arrays (x, p), where x is an array of shape
              (3,) containing the unicycle pose (x, y, θ), and p is
              an array of shape (3,) containing the co-state (px, py, pθ).
    t : The current time (required for use with `odeint`, but can be ignored here).

    Returns
    -------
    A tuple of arrays (dx/dt, dp/dt) containing the time derivatives of the
    state and co-state, respectively.
    """
    x, p = x_and_p

    # Hint: Use `jax.grad` for dp/dt
    return dx, dp
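
A sketch of `noc_ode` follows, with compact versions of the earlier functions repeated so the block runs on its own. The state obeys $\dot{\textbf{x}} = f(\textbf{x}, u^*)$ and the co-state obeys $\dot{p} = -\partial H/\partial \textbf{x}$. Written out, this gives $\dot{p}_x = \dot{p}_y = 0$ and $\dot{p}_\theta = p_x v \sin\theta - p_y v\cos\theta$; `jax.grad` produces these without writing them by hand.

```python
import jax
import jax.numpy as jnp


def unicycle_dynamics(x, u):
    _, _, θ = x
    v, ω = u
    return jnp.array([v * jnp.cos(θ), v * jnp.sin(θ), ω])


def hamiltonian(x, p, u, α=0.25):
    return α + jnp.dot(u, u) + jnp.dot(p, unicycle_dynamics(x, u))


def optimal_control(x, p):
    _, _, θ = x
    px, py, pθ = p
    return jnp.array([-0.5 * (px * jnp.cos(θ) + py * jnp.sin(θ)), -0.5 * pθ])


def noc_ode(x_and_p, t):
    """dx/dt = f(x, u*) and dp/dt = -∂H/∂x evaluated at u = u*(x, p)."""
    x, p = x_and_p
    u = optimal_control(x, p)
    dx = unicycle_dynamics(x, u)
    dp = -jax.grad(hamiltonian, argnums=0)(x, p, u)  # u is held fixed here
    return dx, dp


dx, dp = noc_ode((jnp.array([0., 0., jnp.pi / 2]), jnp.array([1., 1., -1.])), 0.0)
print(dx, dp)  # dp[0] = dp[1] = 0 because H does not depend on x or y
```

Holding $u$ fixed while differentiating is enough here. Because $\partial H/\partial u = 0$ at $u^*$, the total derivative of $H(\textbf{x}, p, u^*(\textbf{x}, p))$ with respect to $\textbf{x}$ equals the partial one.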

#### Exercise 5.3: Define the boundary condition residual function
Now, complete the `noc_trajectories` function below and use it to complete the code for `boundary_residual`, which computes the residual error of the boundary conditions for the NOCs. In other words, we want to express the boundary conditions in the standard form $l(z(t_0), z(t_f)) = 0$, where $z$ stacks the states and co-states, and have `boundary_residual` return $l(z(t_0), z(t_f))$.
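
For reference, here is one way to write the boundary conditions in this form. It assumes $\textbf{x}(0) = \textbf{x}_0$ is enforced by starting the integration there, which leaves three terminal-state conditions plus the transversality condition $H(T) = 0$ for the free final time (there is no terminal cost). The result is four equations for the four unknowns $p(0)$ and $T$, consistent with the shape (4,) returned by `boundary_residual`:
\begin{equation}
l\big(z(0), z(T)\big) =
\begin{bmatrix}
\textbf{x}(T) - \textbf{x}_T \\
H\big(\textbf{x}(T), p(T), u^*(T)\big)
\end{bmatrix} = \mathbf{0}, \qquad \textbf{x}_T = (5, 5, \pi/2).
\end{equation}
Neither the dynamics nor the running cost depends explicitly on $t$, so $H$ is constant along solutions of the NOC ODE. The last entry could therefore equivalently be evaluated at $t = 0$.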

In [None]:
def noc_trajectories(x0: jnp.array,
                     p0: jnp.array,
                     T: float,
                     N: int = 20) -> tuple[jnp.array, jnp.array, jnp.array,
                                           jnp.array]:
    """
    Integrate the optimal state and co-state ODE forward in time.

    Parameters
    ----------
    x0 : An array of shape (3,) containing the initial state (x(0), y(0), θ(0)).
    p0 : An array of shape (3,) containing the initial co-state (px(0), py(0), pθ(0)).
    T : The final time T.
    N : The number of nodes along the ODE solution at which to report the
        solution values.

    Returns
    -------
    A tuple of arrays (ts, xs, us, ps) where:
        ts : An array of shape (N,) containing a sequence of times
             spanning [0, T].
        xs : An array of shape (N, 3) containing the states at ts.
        us : An array of shape (N, 2) containing the optimal control
             inputs at ts.
        ps : An array of shape (N, 3) containing the co-states at ts.
    """
    ts = jnp.linspace(0, T, N)

    return ts, xs, us, ps

def boundary_residual(p0: jnp.array,
                      T: float,
                      x0: jnp.array,
                      xT: jnp.array) -> jnp.array:
    """
    Compute the residual error of the boundary conditions for the NOCs.

    Parameters
    ----------
    p0 : An array of shape (3,) containing the initial co-state (px(0), py(0), pθ(0)) 
         estimate.
    T :  The final time T estimate.
    x0 : An array of shape (3,) containing the initial state (x(0), y(0), θ(0)).
    xT : An array of shape (3,) containing the final state (x(T), y(T), θ(T)).

    Returns
    -------
    The array of shape (4,) we want to drive to zero through appropriate
    selection of p0 and T.
    """
    # Hint: Use `noc_trajectories` here.
    # xs, ps = odeint(noc_ode, (x0, p0), jnp.array([0., T]))
    # x_final, p_final = xs[-1], ps[-1]  # avoid overwriting the target state `xT`
    
    return residual
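
Below is a sketch of `noc_trajectories` and `boundary_residual`, under a few stated assumptions. The NOC ODE is repeated in flattened form, with state and co-state stacked into one vector of shape (6,). In place of `odeint`, it uses a fixed-step fourth-order Runge-Kutta integrator built with `jax.lax.scan` and `jax.lax.fori_loop`, so the step count is explicit and the block does not depend on `jax.experimental.ode`. The helpers `noc_rhs` and `rk4_step` and the `substeps` parameter are hypothetical. The residual stacks $\textbf{x}(T) - \textbf{x}_T$ with $H(T)$.

```python
import jax
import jax.numpy as jnp

α = 0.25  # running-cost weight, as in the template


def f(x, u):
    """Unicycle dynamics dx/dt = f(x, u)."""
    return jnp.array([u[0] * jnp.cos(x[2]), u[0] * jnp.sin(x[2]), u[1]])


def hamiltonian(x, p, u):
    return α + jnp.dot(u, u) + jnp.dot(p, f(x, u))


def optimal_control(x, p):
    return jnp.array([-0.5 * (p[0] * jnp.cos(x[2]) + p[1] * jnp.sin(x[2])),
                      -0.5 * p[2]])


def noc_rhs(z):
    """NOC right-hand side for z = (x, p) stacked into shape (6,)."""
    x, p = z[:3], z[3:]
    u = optimal_control(x, p)
    return jnp.concatenate([f(x, u), -jax.grad(hamiltonian)(x, p, u)])


def rk4_step(z, h):
    """One classical fourth-order Runge-Kutta step of size h."""
    k1 = noc_rhs(z)
    k2 = noc_rhs(z + 0.5 * h * k1)
    k3 = noc_rhs(z + 0.5 * h * k2)
    k4 = noc_rhs(z + h * k3)
    return z + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def noc_trajectories(x0, p0, T, N=20, substeps=10):
    """Integrate the NOC ODE over [0, T], reporting the solution at N nodes."""
    ts = jnp.linspace(0.0, T, N)
    h = T / ((N - 1) * substeps)  # inner RK4 step size

    def node_step(z, _):
        # Advance from one reporting node to the next with `substeps` RK4 steps.
        z = jax.lax.fori_loop(0, substeps, lambda i, zz: rk4_step(zz, h), z)
        return z, z

    z0 = jnp.concatenate([x0, p0])
    _, zs = jax.lax.scan(node_step, z0, None, length=N - 1)
    zs = jnp.concatenate([z0[None, :], zs], axis=0)  # prepend the initial node
    xs, ps = zs[:, :3], zs[:, 3:]
    us = jax.vmap(optimal_control)(xs, ps)
    return ts, xs, us, ps


def boundary_residual(p0, T, x0, xT):
    """Terminal-state error (3 entries) and H(T) for the free final time (1 entry)."""
    _, xs, us, ps = noc_trajectories(x0, p0, T)
    x_final, p_final, u_final = xs[-1], ps[-1], us[-1]  # keep the target `xT` intact
    H_final = hamiltonian(x_final, p_final, u_final)
    return jnp.concatenate([x_final - xT, jnp.atleast_1d(H_final)])
```

Both functions stay differentiable in `p0` and `T` (the step size is `T / ((N - 1) * substeps)`), which `jax.jacobian` requires in the Newton step. `odeint`, by contrast, adapts its step size to meet its `rtol`/`atol` tolerances.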

#### Exercise 5.4: Implement the Newton-Raphson method and solve the optimal control problem
Now that we have implemented `boundary_residual`, our goal is to satisfy the boundary conditions, i.e., to make `boundary_residual` return a zero vector. To do this, we need a method that searches over the initial co-state values and the final time $T$ for a solution with zero residual. We will use the Newton-Raphson method to accomplish this. Implement the Newton-Raphson step in the function `newton_step` below.

In [None]:
# `@jax.jit` gives a speedier runtime per iteration (after compilation) but makes
# debugging harder; comment it out while debugging.
@jax.jit
def newton_step(p0: jnp.array,
                T: float,
                x0: jnp.array,
                xT: jnp.array) -> tuple[jnp.array, float]:
    """
    Implement a step of the Newton-Raphson method for `boundary_residual`.

    Parameters
    ----------
    p0 : An array of shape (3,) containing the current initial co-state
         (px(0), py(0), pθ(0)) estimate.
    T : The current final time T estimate.
    x0 : An array of shape (3,) containing the initial state (x(0), y(0), θ(0)).
    xT : An array of shape (3,) containing the final state (x(T), y(T), θ(T)).

    Returns
    -------
    A tuple containing the next estimate of p0 and T computed by the
    Newton-Raphson method.
    """
    # Hint: Use `jax.jacobian` and `jnp.linalg.solve`.

    return p0, T
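
A minimal sketch of the update, assuming the unknowns are stacked as $z = (p(0), T) \in \mathbb{R}^4$. Each step solves $J(z_k)\,\Delta z = r(z_k)$ with $J = \partial r/\partial z$ and sets $z_{k+1} = z_k - \Delta z$. Unlike the template, this `newton_step` takes `boundary_residual` as an extra argument so the block runs on its own, and it is not wrapped in `jax.jit`. The name `newton_update` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def newton_update(residual_fn, z):
    """One Newton-Raphson step for a square system residual_fn: R^n -> R^n."""
    r = residual_fn(z)
    J = jax.jacobian(residual_fn)(z)   # (n, n) Jacobian at the current iterate
    return z - jnp.linalg.solve(J, r)  # solve J Δz = r instead of forming J^{-1}


def newton_step(p0, T, x0, xT, boundary_residual):
    """Newton-Raphson step on the stacked unknowns z = (p0, T) of shape (4,)."""
    z = jnp.concatenate([p0, jnp.atleast_1d(T)])
    z_next = newton_update(lambda z: boundary_residual(z[:3], z[3], x0, xT), z)
    return z_next[:3], z_next[3]


# Usage on a scalar toy problem: the iterates approach sqrt(2).
z = jnp.array([1.0])
for _ in range(6):
    z = newton_update(lambda z: z**2 - 2.0, z)
print(z)
```

Solving the linear system with `jnp.linalg.solve` avoids forming $J^{-1}$ explicitly. As with any Newton iteration, convergence is only local, so the initial guesses `p0_init` and `T_init` matter.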

Now, run the code below to perform Newton-Raphson iterations that search for the initial co-state and final time satisfying the necessary optimality conditions. These, in turn, give the optimal control for the unicycle problem.

In [None]:
def single_shooting(p0: jnp.array,
                    T: float,
                    x0: jnp.array,
                    xT: jnp.array,
                    max_iters: int = 10,
                    tol: float = 1e-4) -> tuple[jnp.array, float]:
    """
    Single shooting for the unicycle using a Newton-Raphson root finding method.

    Parameters
    ----------
    p0 : An array of shape (3,) containing an initial guess for the
         initial co-state (px(0), py(0), pθ(0)).
    T : An initial guess for the final time T.
    x0: An array of shape (3,) containing the initial state (x(0), y(0), θ(0)).
    xT: An array of shape (3,) containing the desired final state (x(T), y(T), θ(T)).
    max_iters : The maximum number of Newton-Raphson steps to take.
    tol : The convergence tolerance.

    Returns
    -------
    A tuple containing the optimized initial co-state p0 and final time T.
    """
    converged = False
    for k in range(max_iters):
        p0, T = newton_step(p0, T, x0, xT)
        r = boundary_residual(p0, T, x0, xT)
        error = jnp.max(jnp.abs(r))
        print('Iteration {} (error = {})'.format(k, error))
        if error < tol:
            converged = True
            break
    if not converged:
        raise RuntimeError('Single shooting did not converge!')
    return p0, T


p0_init = jnp.array([1., 1., -1.])
T_init = 20.

x0 = jnp.array([0, 0, jnp.pi / 2])
xT = jnp.array([5, 5, jnp.pi / 2])
p0, T = single_shooting(p0_init, T_init, x0, xT)
ts, xs, us, ps = noc_trajectories(x0, p0, T)

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(xs[:, 0], xs[:, 1], 'k-', linewidth=2)
ax[0].quiver(xs[:, 0], xs[:, 1], jnp.cos(xs[:, 2]), jnp.sin(xs[:, 2]))
ax[0].plot(x0[0], x0[1], 'go', markerfacecolor='green', markersize=15)
ax[0].plot(xT[0], xT[1], 'ro', markerfacecolor='red', markersize=15)
ax[0].grid(True)
ax[0].axis([-1, 6, -1, 6])
ax[0].set_xlabel(r'$x$ [m]')
ax[0].set_ylabel(r'$y$ [m]')
ax[0].set_title('Optimal trajectory')

ax[1].plot(ts, us[:, 0], linewidth=2)
ax[1].plot(ts, us[:, 1], linewidth=2)
ax[1].grid(True)
ax[1].set_xlabel(r'$t$ [s]')
ax[1].legend([r'$v^*$ [m/s]', r'$\omega^*$ [rad/s]'])
ax[1].set_title('Optimal control sequence')

plt.tight_layout()
plt.show()