# Differential Dynamic Programming
This section solves the Bellman equation using differential dynamic programming (DDP).
See [Differentiable Optimal Control via Differential Dynamic Programming](https://arxiv.org/pdf/2209.01117.pdf) for background and [ddp-gym](https://github.com/neka-nat/ddp-gym/blob/master/ddp_gym.py) for a Python implementation.

DDP is designed for optimal control of nonlinear dynamic systems. It builds a quadratic approximation of the cost-to-go around a nominal trajectory using a second-order Taylor series.




## System Dynamics

\begin{align}
x_{t+1} &= f(x_t, u_t) \\
x_0 &= x(0)
\end{align}
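
As a minimal sketch of how these dynamics are used, the snippet below rolls a control sequence through a generic step function `f`. The helper `rollout` and the toy dynamics `toy_f` are hypothetical names introduced only for illustration.

```python
import numpy as np

def rollout(f, x0, u_seq):
    """Simulate x_{t+1} = f(x_t, u_t) from x0 and return the N+1 visited states."""
    x_seq = [np.asarray(x0, dtype=float)]
    for u in u_seq:
        x_seq.append(f(x_seq[-1], u))
    return np.stack(x_seq)

# Hypothetical toy dynamics with a one-dimensional state and control.
def toy_f(x, u):
    return x + 0.1 * np.sin(x) + 0.1 * u

x_seq = rollout(toy_f, x0=[1.0], u_seq=np.zeros((5, 1)))
print(x_seq.shape)  # (6, 1): N = 5 controls give N + 1 states
```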

## Cost Function

The `cost-to-go` function at time `t`, with running cost $\ell$, terminal cost $\ell_f$ and horizon $N$, is:

\begin{align}
    J_t(x_t, U_t) &= \sum_{k=t}^{N-1}\ell(x_k, u_k) + \ell_f(x_N) \\
    U_t &= \{u_t, ..., u_{N-1}\}
\end{align}

The minimum `cost-to-go`, or `value function`, at time `t` is then:

\begin{align}
    V_t(x_t) &= \min_{U_t} J_t(x_t, U_t) \\
            &= \min_{u_t, ..., u_{N-1}} \sum_{k=t}^{N-1} \ell(x_k, u_k) + \ell_f(x_{N}) \\
            &= \min_{u_t} [\ell(x_t, u_t) + V_{t+1}(x_{t+1})] \\
            &= \min_{u_t} [\ell(x_t, u_t) + V_{t+1}(f(x_t, u_t))] \\
            &= \min_{u_t} Q_t(x_t, u_t)
\end{align}
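
The recursive structure used in the third line can be illustrated for a fixed control sequence, without the minimization: the cost-to-go at time $t$ is the running cost at $t$ plus the cost-to-go at $t+1$. The sketch below assumes a hypothetical scalar trajectory and cost functions chosen only for illustration.

```python
import numpy as np

def cost_to_go(l, lf, x_seq, u_seq):
    """Return J_t for t = 0..N along a fixed trajectory via J_t = l(x_t, u_t) + J_{t+1}."""
    N = len(u_seq)
    J = np.zeros(N + 1)
    J[N] = lf(x_seq[N])                    # terminal cost
    for t in range(N - 1, -1, -1):         # accumulate backward in time
        J[t] = l(x_seq[t], u_seq[t]) + J[t + 1]
    return J

# Hypothetical scalar trajectory and costs.
x_seq = np.array([2.0, 1.0, 0.5])
u_seq = np.array([-1.0, -0.5])
J = cost_to_go(lambda x, u: x**2 + u**2, lambda x: 10.0 * x**2, x_seq, u_seq)
print(J)  # J[2] = 2.5, J[1] = 3.75, J[0] = 8.75
```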

## Second-order Approximation of Q-function
The last line above defines the action-value function $Q_t(x_t, u_t) = \ell(x_t, u_t) + V_{t+1}(f(x_t, u_t))$.
For a linear system with a quadratic cost function $\ell$, $Q$ is exactly quadratic.
When the system is nonlinear, $Q$ is no longer quadratic.
DDP approximates $Q$ by a quadratic function using a second-order Taylor series.

DDP iterates between
- (i) minimizing the quadratic approximation of the Q-function in a `backward pass`, and 
- (ii) integrating the system dynamics in a `forward pass`

Given a nominal control trajectory $U_t^r = \{u_t^r, ..., u_{N-1}^r\}$ and the corresponding nominal state trajectory $X_t^r = \{x_t^r, ..., x_N^r\}$, we can approximate $Q_t(x_t, u_t)$ around the nominal trajectory with a second-order Taylor series.
Let's drop the subscript `t` on $Q$ and its derivatives for brevity:

\begin{equation}
\hat Q(x_t,u_t) = Q(x_t^r, u_t^r) + Q_x^T \Delta x_t + Q_u^T \Delta u_t + \frac{1}{2}\Delta x_t^T Q_{xx} \Delta x_t + \frac{1}{2}\Delta u_t^T Q_{uu} \Delta u_t + \Delta u_t^T Q_{ux} \Delta x_t
\end{equation}

where $\Delta u_t = u_t - u_t^r$, $\Delta x_t = x_t - x_t^r$, and all derivatives of $Q$ are evaluated at $(x_t^r, u_t^r)$.

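The first-order and second-order partial derivatives of $Q$ follow from the chain rule. The value function depends only on the state, so only $V_x$ and $V_{xx}$ appear, and $V_{x,t+1} \cdot f_{xx}$ denotes the contraction of $V_{x,t+1}$ with the output index of the second-derivative tensor of $f$:

\begin{align}
Q_x &= \ell_x + f_x^T V_{x,t+1} \\
Q_u &= \ell_u + f_u^T V_{x,t+1} \\
Q_{xx} &= \ell_{xx} + f_x^T V_{xx,t+1} f_x + V_{x,t+1} \cdot f_{xx} \\
Q_{uu} &= \ell_{uu} + f_u^T V_{xx,t+1} f_u + V_{x,t+1} \cdot f_{uu} \\
Q_{ux} &= \ell_{ux} + f_u^T V_{xx,t+1} f_x + V_{x,t+1} \cdot f_{ux}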
\end{align}
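
A minimal NumPy sketch of this expansion is given below. The function name `q_expansion`, the array shapes, and the scalar example are assumptions made for illustration; the second-order dynamics terms are computed with `np.tensordot`, which contracts $V_x$ with the first (output) axis of each tensor.

```python
import numpy as np

def q_expansion(l_x, l_u, l_xx, l_uu, l_ux,
                f_x, f_u, f_xx, f_uu, f_ux, V_x, V_xx):
    """Derivatives of Q(x, u) = l(x, u) + V_next(f(x, u)) for n states, m controls.

    Assumed shapes: f_x (n, n), f_u (n, m), f_xx (n, n, n), f_uu (n, m, m),
    f_ux (n, m, n), V_x (n,), V_xx (n, n).
    """
    Q_x = l_x + f_x.T @ V_x
    Q_u = l_u + f_u.T @ V_x
    # tensordot(..., axes=1) sums V_x[i] * (second derivative of f_i) over i
    Q_xx = l_xx + f_x.T @ V_xx @ f_x + np.tensordot(V_x, f_xx, axes=1)
    Q_uu = l_uu + f_u.T @ V_xx @ f_u + np.tensordot(V_x, f_uu, axes=1)
    Q_ux = l_ux + f_u.T @ V_xx @ f_x + np.tensordot(V_x, f_ux, axes=1)
    return Q_x, Q_u, Q_xx, Q_uu, Q_ux

# Hypothetical scalar example (n = m = 1):
#   f(x, u) = x + 0.1 sin(x) + 0.1 u,  l(x, u) = x^2 + u^2,  V_next(z) = z^2
x0, u0 = 0.5, 0.2
z = x0 + 0.1 * np.sin(x0) + 0.1 * u0
Q_x, Q_u, Q_xx, Q_uu, Q_ux = q_expansion(
    l_x=np.array([2 * x0]), l_u=np.array([2 * u0]),
    l_xx=np.array([[2.0]]), l_uu=np.array([[2.0]]), l_ux=np.array([[0.0]]),
    f_x=np.array([[1 + 0.1 * np.cos(x0)]]), f_u=np.array([[0.1]]),
    f_xx=np.array([[[-0.1 * np.sin(x0)]]]),
    f_uu=np.zeros((1, 1, 1)), f_ux=np.zeros((1, 1, 1)),
    V_x=np.array([2 * z]), V_xx=np.array([[2.0]]),
)
print(Q_x, Q_u, Q_xx, Q_uu, Q_ux)
```

For this example $Q(x, u) = x^2 + u^2 + f(x, u)^2$ in closed form, so the results can be checked against finite differences.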


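Minimizing $\hat Q$ over $\Delta u$ then gives an affine feedback law, with feedforward term $k = Q_{uu}^{-1}Q_u$ and feedback gain $K = Q_{uu}^{-1}Q_{ux}$:
\begin{equation}
    \Delta u^* = \arg\min_{\Delta u} \hat Q(x, u) = -Q_{uu}^{-1}Q_u - Q_{uu}^{-1}Q_{ux}\Delta x = -k - K\Delta x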
\end{equation}
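
The gain computation can be sketched as follows, assuming $Q_{uu}$ is positive definite. The numbers are hypothetical, and `np.linalg.solve` is used in place of an explicit inverse. The final line checks that the gradient of the quadratic model with respect to $\Delta u$ vanishes at $\Delta u^*$.

```python
import numpy as np

def gains(Q_u, Q_uu, Q_ux):
    """Return (k, K) such that the minimizer is du* = -k - K @ dx."""
    k = np.linalg.solve(Q_uu, Q_u)    # feedforward term, shape (m,)
    K = np.linalg.solve(Q_uu, Q_ux)   # feedback gain, shape (m, n)
    return k, K

# Hypothetical values with m = 2 controls and n = 3 states.
Q_u = np.array([1.0, -2.0])
Q_uu = np.array([[2.0, 0.0], [0.0, 4.0]])
Q_ux = np.array([[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]])

k, K = gains(Q_u, Q_uu, Q_ux)
dx = np.array([0.1, -0.2, 0.3])
du_star = -k - K @ dx

# Stationarity: d(Q_hat)/d(du) = Q_u + Q_uu du + Q_ux dx should be zero.
gradient = Q_u + Q_uu @ du_star + Q_ux @ dx
print(k, gradient)
```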

Using this optimal $\Delta u^*$ and the known nominal control inputs $u^r$, we can then calculate the optimal cost-to-go $V$ and its derivatives in a backward pass from $t = N-1$ down to $t = 0$, starting from the terminal condition $V_N = \ell_f$:

\begin{align}
u_t^* &= u_t^r + \Delta u_t^* \\
V(x_t^r) &\approx \hat Q(x_t^r, u_t^*) = Q(x_t^r, u_t^r) - \frac{1}{2} Q_u^T Q_{uu}^{-1} Q_u \\
V_x &= Q_x - Q_{ux}^T Q_{uu}^{-1} Q_u \\
V_{xx} &= Q_{xx} - Q_{ux}^T Q_{uu}^{-1} Q_{ux}
\end{align}
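
The sketch below evaluates this update for hypothetical numbers, treating $Q_x$ and $Q_u$ as column vectors and assuming a symmetric positive-definite $Q_{uu}$. The function name `value_update` is illustrative. Substituting $\Delta u^* = -k - K\Delta x$ into the quadratic model of $Q$ should reproduce the quadratic model of $V$ for any $\Delta x$, which the last lines check numerically.

```python
import numpy as np

def value_update(Q_x, Q_u, Q_xx, Q_uu, Q_ux):
    """Return (dV, V_x, V_xx), where dV is the change added to Q(x^r, u^r)."""
    k = np.linalg.solve(Q_uu, Q_u)
    K = np.linalg.solve(Q_uu, Q_ux)
    dV = -0.5 * Q_u @ k          # -1/2 Q_u^T Q_uu^{-1} Q_u
    V_x = Q_x - Q_ux.T @ k       # Q_x - Q_ux^T Q_uu^{-1} Q_u
    V_xx = Q_xx - Q_ux.T @ K     # Q_xx - Q_ux^T Q_uu^{-1} Q_ux
    return dV, V_x, V_xx

# Hypothetical values with n = 2 states and m = 1 control.
Q_x = np.array([1.0, 0.5])
Q_u = np.array([2.0])
Q_xx = np.array([[3.0, 0.2], [0.2, 1.0]])
Q_uu = np.array([[4.0]])
Q_ux = np.array([[1.0, -1.0]])

dV, V_x, V_xx = value_update(Q_x, Q_u, Q_xx, Q_uu, Q_ux)

# Compare the model of Q at du* with the model of V for an arbitrary dx.
dx = np.array([0.3, -0.1])
du = -np.linalg.solve(Q_uu, Q_u) - np.linalg.solve(Q_uu, Q_ux) @ dx
q_model = (Q_x @ dx + Q_u @ du + 0.5 * dx @ Q_xx @ dx
           + 0.5 * du @ Q_uu @ du + du @ Q_ux @ dx)
v_model = dV + V_x @ dx + 0.5 * dx @ V_xx @ dx
print(q_model, v_model)
```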

We then perform a forward pass from $t = 0$ to $t = N-1$ to compute the new nominal state trajectory under the updated control inputs:

\begin{align}
\hat x_0 &= x(0) \\
\Delta \hat u_t &= \hat u_t - u_t^r = -k_t - K_t(\hat x_t - x_t^r) \\
\hat x_{t+1} &= f(\hat x_t, \hat u_t)
\end{align}

The backward pass and forward pass are iterated until convergence.
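
A minimal sketch of the forward pass, using the sign convention $\Delta \hat u_t = -k_t - K_t(\hat x_t - x_t^r)$ from the equations above. The function name `forward_pass`, the linear toy dynamics, and the constant gains are hypothetical and chosen only so the result is easy to check by hand.

```python
import numpy as np

def forward_pass(f, x0, x_ref, u_ref, k_seq, K_seq):
    """Roll out the updated controls and return the new state and control trajectories."""
    x_hat = [np.asarray(x0, dtype=float)]
    u_hat = []
    for t in range(len(u_ref)):
        du = -k_seq[t] - K_seq[t] @ (x_hat[t] - x_ref[t])
        u_hat.append(u_ref[t] + du)
        x_hat.append(f(x_hat[t], u_hat[t]))
    return np.stack(x_hat), np.stack(u_hat)

# Hypothetical toy problem: x_{t+1} = x_t + 0.1 u_t with a zero nominal control.
f = lambda x, u: x + 0.1 * u
u_ref = np.zeros((3, 1))
x_ref = np.ones((4, 1))               # nominal states produced by u_ref from x0 = 1
k_seq = [np.array([0.5])] * 3         # constant gains, for illustration only
K_seq = [np.array([[1.0]])] * 3

x_hat, u_hat = forward_pass(f, [1.0], x_ref, u_ref, k_seq, K_seq)
print(x_hat.ravel())  # approximately [1.0, 0.95, 0.905, 0.8645]
```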


**Procedure**
- generate a nominal trajectory for control inputs $\{u_0^r, u_1^r, ..., u_{N-1}^r\}$
- forward pass to calculate the nominal trajectory for system states $\{x_0^r, x_1^r, ..., x_N^r\}$
- set $V_N(x_N) = \ell_f(x_N)$, and calculate $V_x$ and $V_{xx}$ at $t = N$
- backward pass: for t = N, ..., 1:
  - calculate $Q_x$, $Q_u$, $Q_{xx}$, $Q_{uu}$ and $Q_{ux}$ at $t-1$
  - calculate $k_{t-1}, K_{t-1}$
- initialize $\hat x_0 = x(0)$
- forward pass: for t = 0, ..., N-1:
  - calculate $\Delta x_t = \hat x_t - x_t^r$
  - calculate control input changes $\Delta \hat u_t$ based on $k_t$, $K_t$ and $\Delta x_t$
  - calculate improved control input $\hat u_t = u_t^r + \Delta \hat u_t$
  - calculate improved state trajectory $\hat x_{t+1}$
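- repeat the backward and forward passes until a stopping criterion is met, for example:
  - $V(x_0)$ has converged (its change between iterations falls below a tolerance)
  - the maximum number of iterations is reached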
  
  


## Development Ideas
- DDP class
- interact with a Gymnasium environment

## Example

### Problem 1

Solve the pendulum problem described in the [Gymnasium Pendulum documentation](https://gymnasium.farama.org/environments/classic_control/pendulum/). The system is nonlinear, with angular acceleration:


\begin{align}
\ddot \theta = \frac{3g}{2l}\sin(\theta) + \frac{3}{ml^2}\tau
\end{align}

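Here $\theta$ is the pendulum angle, $\tau$ the applied torque, $m$ the mass, $l$ the length, and $g$ the gravitational acceleration. With state $(\theta, \dot \theta)$ and the torque $\tau$ as control input, the state-space model can be expressed as: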

\begin{align}
    \begin{bmatrix}
    \dot \theta \\
    \ddot \theta
    \end{bmatrix} 
    = 
    \begin{bmatrix}
    \dot \theta \\
    \frac{3g}{2l}\sin(\theta)
    \end{bmatrix}
    +
    \begin{bmatrix}
    0  \\
    \frac{3}{ml^2}
    \end{bmatrix}
    \tau
\end{align}







The discrete-time dynamics, obtained with the semi-implicit Euler method (the velocity is updated first and the new velocity is then used to update the angle), are:

\begin{align}
    \dot \theta_{t+1} &= \dot \theta_t + \frac{3g}{2l}\sin(\theta_t) \Delta t + \frac{3}{ml^2} \tau_t\Delta t \\
    \theta_{t+1} &= \theta_t + \dot \theta_{t+1} \Delta t
\end{align}
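
A minimal sketch of this update and its Jacobians in NumPy. The function names and the default step size `dt=0.05` are assumptions, and the sketch omits the torque and speed clipping that the Gymnasium environment applies.

```python
import numpy as np

def pendulum_step(state, tau, m=1.0, l=1.0, g=9.8, dt=0.05):
    """One semi-implicit Euler step; state = (theta, theta_dot)."""
    theta, theta_dot = state
    theta_dot_next = theta_dot + (3 * g / (2 * l) * np.sin(theta)
                                  + 3.0 / (m * l**2) * tau) * dt
    theta_next = theta + theta_dot_next * dt   # uses the updated velocity
    return np.array([theta_next, theta_dot_next])

def pendulum_jacobians(state, m=1.0, l=1.0, g=9.8, dt=0.05):
    """Return f_x (2, 2) and f_u (2, 1) of pendulum_step at the given state."""
    theta, _ = state
    a = 3 * g / (2 * l) * np.cos(theta)        # d(theta_ddot)/d(theta)
    b = 3.0 / (m * l**2)                       # d(theta_ddot)/d(tau)
    f_x = np.array([[1.0 + a * dt**2, dt],
                    [a * dt, 1.0]])
    f_u = np.array([[b * dt**2],
                    [b * dt]])
    return f_x, f_u

print(pendulum_step(np.array([0.0, 0.0]), 1.0))  # approximately [0.0075, 0.15]
```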

Let's define the cost function, where $U_t = \{\tau_t, ..., \tau_{N-1}\}$:

\begin{align}
    J_t((\theta_t, \dot \theta_t), U_t) = \sum_{k=t}^{N-1} (\theta_k^2 + 0.1 \dot \theta_k^2 + 0.001 \tau_k^2)
\end{align}
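
Because this running cost is quadratic, its derivatives are constant or linear in the state and control, and can be written by hand. The sketch below uses illustrative function names and returns the derivatives in the shapes used by the Q-function expansion, with two states and one control.

```python
import numpy as np

def running_cost(state, tau):
    """l(x, u) = theta^2 + 0.1 theta_dot^2 + 0.001 tau^2."""
    theta, theta_dot = state
    return theta**2 + 0.1 * theta_dot**2 + 0.001 * tau**2

def running_cost_derivatives(state, tau):
    """Return l_x (2,), l_u (1,), l_xx (2, 2), l_uu (1, 1), l_ux (1, 2)."""
    theta, theta_dot = state
    l_x = np.array([2.0 * theta, 0.2 * theta_dot])
    l_u = np.array([0.002 * tau])
    l_xx = np.diag([2.0, 0.2])
    l_uu = np.array([[0.002]])
    l_ux = np.zeros((1, 2))      # no cross term between state and control
    return l_x, l_u, l_xx, l_uu, l_ux

print(running_cost(np.array([1.0, 2.0]), 3.0))  # approximately 1.409
```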





In [25]:
m = 1    # pendulum mass
g = 9.8  # gravitational acceleration
l = 1    # pendulum length


### Problem 2:

Continuous cart-pole: find the optimal force to apply to the cart.





In [1]:
import math

import gymnasium as gym
import jax.numpy as jnp
import numpy as np
from jax import jacfwd

import env  # local module (not shown); presumably registers 'CartPoleContinuous-v0'

class DDP:
    def __init__(self, next_state, running_cost, final_cost,
                 umax, state_dim, pred_time=50):
        self.pred_time = pred_time
        self.umax = umax
        self.v = [0.0 for _ in range(pred_time + 1)]
        self.v_x = [np.zeros(state_dim) for _ in range(pred_time + 1)]
        self.v_xx = [np.zeros((state_dim, state_dim))
                     for _ in range(pred_time + 1)]
        self.f = next_state
        self.lf = final_cost
        self.l = running_cost
        self.lf_x = jacfwd(self.lf)
        self.lf_xx = jacfwd(self.lf_x)
        self.l_x = jacfwd(self.l, 0)
        self.l_u = jacfwd(self.l, 1)
        self.l_xx = jacfwd(self.l_x, 0)
        self.l_uu = jacfwd(self.l_u, 1)
        self.l_ux = jacfwd(self.l_u, 0)
        self.f_x = jacfwd(self.f, 0)
        self.f_u = jacfwd(self.f, 1)
        self.f_xx = jacfwd(self.f_x, 0)
        self.f_uu = jacfwd(self.f_u, 1)
        self.f_ux = jacfwd(self.f_u, 0)

    def backward(self, x_seq, u_seq):
        # Terminal condition: V_N = lf(x_N)
        self.v[-1] = self.lf(x_seq[-1])
        self.v_x[-1] = self.lf_x(x_seq[-1])
        self.v_xx[-1] = self.lf_xx(x_seq[-1])
        k_seq = []
        kk_seq = []
        for t in range(self.pred_time - 1, -1, -1):
            x, u = x_seq[t], u_seq[t]
            v_x, v_xx = self.v_x[t + 1], self.v_xx[t + 1]
            f_x_t = self.f_x(x, u)
            f_u_t = self.f_u(x, u)
            q_x = self.l_x(x, u) + np.matmul(f_x_t.T, v_x)
            q_u = self.l_u(x, u) + np.matmul(f_u_t.T, v_x)
            # tensordot(..., axes=1) contracts v_x with the first (output)
            # axis of each second-derivative tensor of f
            q_xx = self.l_xx(x, u) + \
                np.matmul(np.matmul(f_x_t.T, v_xx), f_x_t) + \
                np.tensordot(v_x, self.f_xx(x, u), axes=1)
            tmp = np.matmul(f_u_t.T, v_xx)
            q_uu = self.l_uu(x, u) + np.matmul(tmp, f_u_t) + \
                np.tensordot(v_x, self.f_uu(x, u), axes=1)
            q_ux = self.l_ux(x, u) + np.matmul(tmp, f_x_t) + \
                np.tensordot(v_x, self.f_ux(x, u), axes=1)
            inv_q_uu = np.linalg.inv(q_uu)
            # k and kk store -Q_uu^{-1} Q_u and -Q_uu^{-1} Q_ux, so the
            # control update in forward() is du = k + kk @ dx
            k = -np.matmul(inv_q_uu, q_u)
            kk = -np.matmul(inv_q_uu, q_ux)
            # V_t = Q(x_t, u_t) - 1/2 Q_u^T Q_uu^{-1} Q_u, with Q = l + V_{t+1}
            dv = 0.5 * np.matmul(q_u, k)
            self.v[t] = self.l(x, u) + self.v[t + 1] + dv
            self.v_x[t] = q_x - np.matmul(np.matmul(q_u, inv_q_uu), q_ux)
            self.v_xx[t] = q_xx + np.matmul(q_ux.T, kk)
            k_seq.append(k)
            kk_seq.append(kk)
        # Gains were collected from t = N-1 down to 0; put them in time order
        k_seq.reverse()
        kk_seq.reverse()
        return k_seq, kk_seq

    def forward(self, x_seq, u_seq, k_seq, kk_seq):
        x_seq_hat = np.array(x_seq)
        u_seq_hat = np.array(u_seq)
        for t in range(len(u_seq)):
            # Feedforward term plus feedback on the deviation from the nominal state
            control = k_seq[t] + \
                np.matmul(kk_seq[t], (x_seq_hat[t] - x_seq[t]))
            # Saturate the control at the actuator limit
            u_seq_hat[t] = np.clip(u_seq[t] + control, -self.umax, self.umax)
            x_seq_hat[t + 1] = self.f(x_seq_hat[t], u_seq_hat[t])
        return x_seq_hat, u_seq_hat


In [2]:
def dynamics(env, state, action):
    x, x_dot, theta, theta_dot = state
    force = action[0]
    # Use jnp here (not math or np) so that jacfwd can trace through this
    # function: NumPy calls on traced values raise
    # jax.errors.TracerArrayConversionError
    costheta = jnp.cos(theta)
    sintheta = jnp.sin(theta)

    # For the interested reader:
    # https://coneural.org/florian/papers/05_cart_pole.pdf
    temp = (
        force + env.polemass_length * theta_dot**2 * sintheta
    ) / env.total_mass
    thetaacc = (env.gravity * sintheta - costheta * temp) / (
        env.length * (4.0 / 3.0 - env.masspole *
                    costheta**2 / env.total_mass)
    )
    xacc = temp - env.polemass_length * thetaacc * costheta / env.total_mass

    if env.kinematics_integrator == "euler":
        x = x + env.tau * x_dot
        x_dot = x_dot + env.tau * xacc
        theta = theta + env.tau * theta_dot
        theta_dot = theta_dot + env.tau * thetaacc
    else:  # semi-implicit euler
        x_dot = x_dot + env.tau * xacc
        x = x + env.tau * x_dot
        theta_dot = theta_dot + env.tau * thetaacc
        theta = theta + env.tau * theta_dot

    return jnp.array([x, x_dot, theta, theta_dot])

# .unwrapped exposes the physical parameters of the underlying environment
env = gym.make('CartPoleContinuous-v0').unwrapped

obs, _ = env.reset()  # Gymnasium's reset() returns (observation, info)
ddp = DDP(lambda x, u: dynamics(env, x, u),  # x(i+1) = f(x(i), u)
          lambda x, u: 0.5 * jnp.sum(jnp.square(u)),  # l(x, u)
          lambda x: 0.5 * \
          (jnp.square(1.0 - jnp.cos(x[2])) + \
           jnp.square(x[1]) + jnp.square(x[3])),  # lf(x)
          env.force_max,
          env.observation_space.shape[0])
u_seq = [np.zeros(1) for _ in range(ddp.pred_time)]
x_seq = [obs]

# Initial nominal state trajectory from the zero-control rollout
for t in range(ddp.pred_time):
    x_seq.append(dynamics(env, x_seq[-1], u_seq[t]))

cnt = 0
while True:
    #env.render()
    #import pyglet
    #pyglet.image.get_buffer_manager().get_color_buffer().save('frame_%04d.png' % cnt)
    # Refine the plan with a few DDP iterations, then apply its first control
    for _ in range(3):
        k_seq, kk_seq = ddp.backward(x_seq, u_seq)
        x_seq, u_seq = ddp.forward(x_seq, u_seq, k_seq, kk_seq)

    print(u_seq.T)
    obs, *_ = env.step(u_seq[0])  # keep only the observation
    x_seq[0] = obs.copy()
    cnt += 1