In [1]:
from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as np 
import numpy as onp
import jax
from jax import jacfwd, grad, jit, vmap, hessian
from jaxlie import SE2, SO2, manifold
import meshcat
from meshcat import transformations as tfm
from meshcat import geometry as geom 
import time
import matplotlib.pyplot as plt
import tqdm
from IPython.display import clear_output

## Dynamics Model 
We will be using a 2D (planar) drone system with two force inputs, one at each rotor (see below).

![drone](quadrotor2d.png)


The manipulator equations of the drone are given as 
\begin{align*}
    m \ddot{p_x} &= - (u_1 + u_2) \sin(\theta) \\ 
    m \ddot{p_y} &= (u_1 + u_2) \cos(\theta) - m g \\ 
    I \ddot{\theta} &= r (u_1 - u_2)
\end{align*}
where $m$ is the mass, $I$ is the inertia, $g$ is gravity, $r$ is the distance from the center to the base of the propeller, and the state is given as $x=[p_x, p_y, \theta, \dot{p_x}, \dot{p_y}, \dot{\theta}]$.
The degrees of freedom are the $x,y$ position and the drone rotation relative to the world, $\theta$; there are two force inputs $u_1, u_2$, one for each rotor. 

As we are using a direct transcription approach, we need to write the dynamics as a discrete-time system, i.e., $x[k+1] = x[k] + dt * f(x[k], u[k])$


In [2]:
def euler_step(x, f, dt, *args):
    """
        Advance x by one explicit (forward) Euler step of size dt.

        Input: state x, derivative function f (called as f(x, *args)), step size dt
        output: x + dt * f(x, *args)
    """
    x_dot = f(x, *args)
    return x + x_dot * dt

_dt = 0.1   # integration time step used by F (and as the frame delay when animating)
_g  = 9.81  # gravity g in the manipulator equations
_c1 = 0.02  # NOTE(review): not referenced in the cells visible here — confirm whether it is still needed
_c2 = 0.02  # NOTE(review): not referenced in the cells visible here — confirm whether it is still needed
_r = 0.1    # distance r from the drone center to the base of a propeller
_I = 0.1    # inertia I
_m = 0.1    # mass m

def f(x, u):
    r"""
        Continuous-time dynamics of the planar drone (manipulator equations solved
        for the accelerations).

        Input: state x=[px, py, theta, pxdt, pydt, thetadt], control u = [u1, u2]
        output: \dot{x} = [pxdt, pydt, thetadt, pxddt, pyddt, thetaddt]
    """
    # The positions do not enter the dynamics; unpacking all six entries still
    # enforces the expected state length.
    _, _, th, pxt, pyt, tht = x
    u1, u2 = u
    # Named `thrust`/`torque` (not F/T) so the module-level step function F is not shadowed.
    thrust = u1 + u2
    torque = u1 - u2
    xtt = - thrust * np.sin(th) / _m
    ytt = thrust * np.cos(th) / _m - _g
    thtt = _r * torque / _I
    return np.array([pxt, pyt, tht, xtt, ytt, thtt])

def F(x, u):
    """
        Discrete-time dynamics: one forward Euler step of f with step size _dt,
        i.e. x[t+dt] = x[t] + dt * f(x[t], u[t])
    """
    x_next = euler_step(x, f, _dt, u)
    return x_next

# JIT-compiled version of the one-step dynamics
F_jitted = jit(F)

## Trajectory rollout

First we will reduce the problem to a shooting-based objective with soft constraints. Let's construct a function that, given an initial state and a control sequence, returns the state and control trajectories. We will also incorporate a soft control constraint into the forward shooting using a differentiable saturation function `umax*(tanh(u)/2 + 0.5)`, which is a smooth approximation of clipping the control between 0 and umax=2. We will use this saturated value in the objective function to compute the cost of control.

In [3]:
def u_sat(u, umax=2.0):
    """
        Smooth, differentiable saturation of a raw control to values between 0 and umax.

        Input: unconstrained control u (scalar or array), upper bound umax (default 2)
        output: umax * (tanh(u)/2 + 0.5), a smooth stand-in for clipping u to [0, umax]
    """
    return umax * (np.tanh(u) / 2 + 0.5)

def shoot_F(x0, U):
    """
        Roll the discrete dynamics F forward from x0 under the control sequence U.

        Each raw control is passed through u_sat before it is applied.

        Input: initial state x0, raw control sequence U (iterated along its first axis)
        output: (X, U_sat) — the stacked states, starting with x0 and holding one more
                entry than U, and the stacked saturated controls
    """
    states = [x0.copy()]
    controls = []
    for u_raw in U:
        u_applied = u_sat(u_raw)
        controls.append(u_applied)
        states.append(F(states[-1], u_applied))
    return np.array(states), np.array(controls)

## Define Objective

Here, we will need to redefine the cost function and terminal cost defined previously, as MPC plans on smaller time-scales with a receding horizon (the terminal time is always moving away from the controller). As a result, we need to inform the controller of what the task is throughout planning (especially since MPC methods are simpler and need a lot more help to find solutions). We will define a quadratic cost with a terminal condition 
$$
    \sqrt{J} = \frac{1}{N} \left[ \sum_{k=0}^{N-2} \left( (x_k - x_d)^\top Q (x_k - x_d) + u_k^\top R u_k \right) + (x_{N-1} - x_d)^\top Q_f (x_{N-1} - x_d) \right]
$$
where $Q, Q_f, R$ are defined below, and $N$ is the discrete time horizon (used to normalize the objective value).

In addition, this function will take in a control sequence and an initial state, implicitly simulate the state trajectory using the shooting function, and return a scalar loss value for the state/control trajectories.

In [4]:
# Desired state x_d, ordered like the state [px, py, theta, pxdt, pydt, thetadt]:
# zero position and zero velocities, with theta = 2*pi.
_xd = np.array([0.,0.,2*np.pi,0.,0.,0.])
# Diagonal weights on the per-step state error (x - _xd).
_Q = np.diag(np.array([2., 2., 80., .01, .01, .001]))
# Diagonal weights on the error of the last state of the rollout.
_Qf = np.diag(np.array([4., 4., 80., .01, .01, .01]))
# Diagonal weights on the saturated controls [u1, u2].
_R = np.diag(np.array([0.0001, 0.0001]))

def soft_objective(x0, U):
    """
        Scalar cost of a raw control sequence U applied from the initial state x0.

        The trajectory is simulated with shoot_F; the quadratic running cost
        (weights _Q, _R, using the saturated controls) and the terminal cost
        (weight _Qf) are summed, divided by the number of states, and squared.

        Input: initial state x0, raw control sequence U
        output: scalar loss
    """
    X, U_sat = shoot_F(x0, U)
    running = 0.0
    for k in range(len(U_sat)):
        err = X[k] - _xd
        ctrl = U_sat[k]
        running = running + (err.T @ _Q @ err + ctrl.T @ _R @ ctrl)
    err_final = X[-1] - _xd
    terminal = err_final.T @ _Qf @ err_final
    return ((running + terminal) / len(X))**2

In [5]:
# Gradient of the scalar cost with respect to the control sequence U (argument 1).
grad_soft_obj = grad(soft_objective, argnums=1)
jitted_grad_soft_obj = jit(grad_soft_obj)

# Batched cost evaluation: x0 is shared, U is batched along its leading axis.
vmap_soft_obj = vmap(soft_objective, in_axes=(None, 0))
jitted_vmap_soft_obj = jit(vmap_soft_obj)

## Define Langevin dynamics update

The Unadjusted Langevin Algorithm (ULA) is given by

$ U_{k+1} = U_k - \eta \nabla J(U_k) + \sqrt{2 \eta} z_k$,

where $\eta$ is the step size and $z_k \sim \mathcal{N}(0, I)$ is standard Gaussian noise.

In [6]:
step_size = 1e-5  # initial ULA step size (eta)
tH = 2            # length of the planning horizon, in the same time units as _dt
# Number of discrete steps in the horizon. Round before converting to int so that
# floating-point error in tH/_dt (e.g. 0.3/0.1 == 2.9999999999999996) cannot
# silently drop a step.
N = int(round(tH/_dt))

# Multiplies the cost gradient in ULA; determines how strongly the samples of U
# concentrate around argmin J.
J_scale = 100

def ULA(x0, U, step_size, random_key):
    """
        One Unadjusted Langevin Algorithm step on a single control sequence.

        Computes
            U - step_size * J_scale * grad_U soft_objective(x0, U) + sqrt(2 * step_size) * z
        where z has i.i.d. standard normal entries drawn with random_key.

        Input: initial state x0, control sequence U, step size step_size, PRNG key random_key
        output: updated control sequence with the same shape as U
    """
    # Sample the noise directly in U's shape: this keeps the update independent of the
    # global horizon N, and i.i.d. standard normals are exactly an identity-covariance
    # Gaussian, so no covariance matrix has to be built or factorized.
    z = jax.random.normal(random_key, U.shape)
    return U - step_size * grad_soft_obj(x0, U) * J_scale + (2 * step_size)**0.5 * z

# Batched over control sequences, each with its own PRNG key; x0 and step_size are shared.
jitted_vmap_ULA = jit(vmap(ULA, in_axes=(None, 0, None, 0)))

In [7]:
def mpld(x0, Us, random_key, step_size, n_steps=100, step_decay=0.95):
    """
        Refine a batch of control sequences with repeated ULA steps.

        The batch is first shifted one step forward in time (entry k takes the value
        of entry k+1, the last entry is kept) and is then updated n_steps times with
        jitted_vmap_ULA, multiplying the step size by step_decay after every update.

        Input: initial state x0, batch of control sequences Us (batch along axis 0,
               time along axis 1), PRNG key random_key, initial step size step_size,
               number of ULA updates n_steps, per-update step size factor step_decay
        output: the updated batch of control sequences, same shape as Us
    """
    Us = Us.at[:, :-1].set(Us[:, 1:])
    n_samples = Us.shape[0]
    for _ in range(n_steps):
        # Draw fresh keys for every update. Reusing one set of keys would inject the
        # very same noise z at each step instead of independent Gaussian noise.
        random_key, step_key = jax.random.split(random_key)
        sample_keys = jax.random.split(step_key, n_samples)
        Us = jitted_vmap_ULA(x0, Us, step_size, sample_keys)
        # decrease the step size to reduce the bias of the sampled distribution
        step_size = step_size * step_decay
    return Us

In [8]:
x0 = np.zeros(6)

N_samples = 1000
random_key = jax.random.PRNGKey(0)

# PRNG discipline: random_key is only ever split; every consumer of randomness gets
# its own subkey (key_to_use), so no key is used twice.
random_key, key_to_use = jax.random.split(random_key)
Us = jax.random.multivariate_normal(key_to_use, mean=np.zeros((N-1) * 2), cov=np.eye((N-1) * 2), shape=(N_samples,))
Us = Us.reshape(N_samples, N-1, 2)

# The result of this call is discarded; presumably a warm-up so that JIT compilation
# happens before the timed loop below.
random_key, key_to_use = jax.random.split(random_key)
jitted_vmap_ULA(x0, Us, step_size, jax.random.split(key_to_use, N_samples))

# visited configurations [px, py, theta], starting with the initial one
qs = [x0[:3]]
# per-step costs of all sampled control sequences
costs_list = []

for t in tqdm.tqdm(range(100)):
    random_key, key_to_use = jax.random.split(random_key)
    Us = mpld(x0, Us, key_to_use, step_size)
    # get the best U
    costs = jitted_vmap_soft_obj(x0, Us)
    costs_list.append(costs)
    # apply only the first (saturated) control of the lowest-cost sequence
    x0 = F_jitted(x0, u_sat(Us[np.argmin(costs), 0]))
    q, qdot = np.split(x0, 2)
    qs.append(q)

100%|██████████| 150/150 [00:26<00:00,  5.62it/s]


In [9]:
viz = meshcat.Visualizer()

# Scene tree: everything hangs off "drone", so one transform moves the whole vehicle.
drone = viz["drone"]
drone_body = drone["body"]
drone_body.set_object(geom.Box([0.1, 0.5, 0.02]))

# Both propellers are identical cylinders with the same rotation; they differ only in
# the sign of their offset along the body's second axis.
prop_rotation = tfm.rotation_matrix(np.pi/2, [1, 0, 0])
drone_propFL = drone["propFL"]
drone_propFR = drone["propFR"]
for prop, prop_offset in ((drone_propFL, -0.25), (drone_propFR, 0.25)):
    prop.set_transform(tfm.translation_matrix([0., prop_offset, 0.05]) @ prop_rotation)
    prop.set_object(geom.Cylinder(height=0.01, radius=0.2))
viz.jupyter_cell()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/


In [10]:
# Replay the recorded configurations, one frame every _dt seconds.
for q in qs:
    px, py, theta = q[0], q[1], q[2]
    # The planar (px, py) motion is drawn in the viewer's second and third axes,
    # with theta as a rotation about the first axis.
    pose = tfm.translation_matrix([0, px, py]) @ tfm.rotation_matrix(theta, [1, 0, 0])
    drone.set_transform(pose)
    time.sleep(_dt)