In [3]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
# if using a jupyter notebook
%matplotlib inline

In [13]:
import time
from jaxopt import OSQP
import jax.numpy as jnp

# Overview of Quadratic Programming (QP)
# https://www.youtube.com/watch?v=GZb9647X8sg
#
# cost = 0.4x1^2 -5x1 + x2^2 -6x2 + 50
start = time.time()

# Define problem data
P = jnp.array([[0.4, 0.], [0., 1.]]) * 2
q = jnp.array([-5., -6.])
A = jnp.array([[1., -1.], [-0.3, -1.], [1., 0.], [0., 1.]])
l = jnp.array([-np.inf, -np.inf, 0, 0])
u = jnp.array([-2., -8., 10., 10.])

# Create an OSQP object
prob = OSQP()
sol = prob.run(params_obj=(P, q), params_ineq=(A, u)).params
print(sol.primal)
# # Setup workspace and change alpha parameter
# prob.init_params(params_obj=(P, q), params_ineq=(A, u))



# # Solve problem
# res = prob.run()
# print(res.x)

# Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
# c = jnp.array([1.0, 1.0])
# A = jnp.array([[1.0, 1.0]])
# b = jnp.array([1.0])
# G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
# h = jnp.array([0.0, 0.0])

# qp = OSQP()
# sol = qp.run(params_obj=(Q, c), params_ineq=(G, h)).params

# print(sol.primal)
# print(sol.dual_eq)
# print(sol.dual_ineq)

end = time.time()
print(end - start)

TypeError: init_params() missing 2 required positional arguments: 'init_x' and 'params_eq'

In [14]:
import time
from jaxopt import OSQP
import jax.numpy as jnp
import numpy as np
import scipy as sp
from scipy import sparse

# Discrete time model of a quadcopter
Ad = sparse.csc_matrix([
    [1.,      0.,     0., 0., 0., 0., 0.1,     0.,     0.,  0.,     0.,     0.],
    [0.,      1.,     0., 0., 0., 0., 0.,      0.1,    0.,  0.,     0.,     0.],
    [0.,      0.,     1., 0., 0., 0., 0.,      0.,     0.1, 0.,     0.,     0.],
    [0.0488,  0.,     0., 1., 0., 0., 0.0016,  0.,     0.,  0.0992, 0.,     0.],
    [0.,     -0.0488, 0., 0., 1., 0., 0.,     -0.0016, 0.,  0.,     0.0992, 0.],
    [0.,      0.,     0., 0., 0., 1., 0.,      0.,     0.,  0.,     0.,     0.0992],
    [0.,      0.,     0., 0., 0., 0., 1.,      0.,     0.,  0.,     0.,     0.],
    [0.,      0.,     0., 0., 0., 0., 0.,      1.,     0.,  0.,     0.,     0.],
    [0.,      0.,     0., 0., 0., 0., 0.,      0.,     1.,  0.,     0.,     0.],
    [0.9734,  0.,     0., 0., 0., 0., 0.0488,  0.,     0.,  0.9846, 0.,     0.],
    [0.,     -0.9734, 0., 0., 0., 0., 0.,     -0.0488, 0.,  0.,     0.9846, 0.],
    [0.,      0.,     0., 0., 0., 0., 0.,      0.,     0.,  0.,     0.,     0.9846]
])
Bd = sparse.csc_matrix([
    [0.,      -0.0726,  0.,     0.0726],
    [-0.0726,  0.,      0.0726, 0.],
    [-0.0152,  0.0152, -0.0152, 0.0152],
    [-0.,     -0.0006, -0.,     0.0006],
    [0.0006,   0.,     -0.0006, 0.0000],
    [0.0106,   0.0106,  0.0106, 0.0106],
    [0,       -1.4512,  0.,     1.4512],
    [-1.4512,  0.,      1.4512, 0.],
    [-0.3049,  0.3049, -0.3049, 0.3049],
    [-0.,     -0.0236,  0.,     0.0236],
    [0.0236,   0.,     -0.0236, 0.],
    [0.2107,   0.2107,  0.2107, 0.2107]])
[nx, nu] = Bd.shape

# Constraints
u0 = 10.5916
umin = np.array([9.6, 9.6, 9.6, 9.6]) - u0
umax = np.array([13., 13., 13., 13.]) - u0
xmin = np.array([-np.pi/6, -np.pi/6, -np.inf, -np.inf, -np.inf, -1.,
                 -np.inf, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf])
xmax = np.array([np.pi/6, np.pi/6, np.inf, np.inf, np.inf, np.inf,
                 np.inf, np.inf, np.inf, np.inf, np.inf, np.inf])

# Objective function
Q = sparse.diags([0., 0., 10., 10., 10., 10., 0., 0., 0., 5., 5., 5.])
QN = Q
R = 0.1*sparse.eye(4)

# Initial and reference states
x0 = np.zeros(12)
xr = np.array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

# Prediction horizon
N = 10

# Cast MPC problem to a QP: x = (x(0),x(1),...,x(N),u(0),...,u(N-1))
# - quadratic objective
P = jnp.block([sparse.kron(sparse.eye(N), Q), QN,
                       sparse.kron(sparse.eye(N), R)], format='csc')
# - linear objective
q = jnp.hstack([np.kron(np.ones(N), -Q.dot(xr)), -QN.dot(xr),
               np.zeros(N*nu)])
# - linear dynamics
Ax = sparse.kron(sparse.eye(N+1), -sparse.eye(nx)) + \
    sparse.kron(sparse.eye(N+1, k=-1), Ad)
Bu = sparse.kron(sparse.vstack([sparse.csc_matrix((1, N)), sparse.eye(N)]), Bd)
Aeq = jnp.hstack([Ax, Bu])
leq = jnp.hstack([-x0, np.zeros(N*nx)])
ueq = leq
# - input and state constraints
Aineq = jnp.eye((N+1)*nx + N*nu)
lineq = jnp.hstack([np.kron(np.ones(N+1), xmin), np.kron(np.ones(N), umin)])
uineq = jnp.hstack([np.kron(np.ones(N+1), xmax), np.kron(np.ones(N), umax)])
# - OSQP constraints
A = jnp.vstack([Aeq, Aineq], format='csc')
l = jnp.hstack([leq, lineq])
u = jnp.hstack([ueq, uineq])

# Create an OSQP object
prob = OSQP()

# Setup workspace
# prob.setup(P, q, A, l, u, warm_start=True)

# Simulate in closed loop
nsim = 15
last_ctrl = []
x0_last = np.zeros(12)
for i in range(nsim):
    # Solve
    res = prob.run(params_obj=(P, q), params_eq=(Aeq, ueq), params_ineq=(Aineq, uineq)).params

    # Check solver status
    # if res.info.status != 'solved':
    #     raise ValueError('OSQP did not solve the problem!')

    # Apply first control input to the plant
    ctrl = res.primal[-N*nu:-(N-1)*nu]
    x0 = Ad.dot(x0) + Bd.dot(ctrl)
    if i == nsim - 1:
        x0_last = Ad.dot(x0_last) + Bd.dot(ctrl)
        plt.plot(xr)
        plt.plot(x0_last)
        plt.show()

    # Update initial state
    l[:nx] = -x0
    u[:nx] = -x0
    prob.update(l=l, u=u)
    # plt.plot(xr)
    # # plt.plot(res.x)
    # plt.plot(x0)
    # plt.show()


TypeError: Argument 'csc' of type <class 'str'> is not a valid JAX type