In [1]:
import os
import random 
import numpy as np

# uncomment to run everything on cpu
# os.environ["JAX_PLATFORM_NAME"] = "cpu"


# stock jax solvers
import jax 
import jax.numpy as jnp 
from jax import jit, grad, vmap  
from jaxopt import OSQP, BoxOSQP, CvxpyQP
import timeit

# kevin's interior point solver
import qpax

# CPU solvers
import scipy.sparse as sp
import osqp

# random stuff
import numpy as np

In [2]:
# form non-negative least squares problems
def form_nnls(m, n, N):
    """Build a batch of random non-negative least squares problems as QPs.

    Each problem is ``min_x 0.5 * ||F x - g||^2  s.t.  x >= 0`` with the
    entries of ``F`` and ``g`` drawn from a standard normal through NumPy's
    global random state. It is returned in the form
    ``min_x 0.5 * x'Qx + q'x`` with (empty) equality data ``A, b`` and
    inequality data ``G, h``.

    Args:
        m: rows in F.
        n: size of x (columns in F).
        N: batch size.

    Returns:
        Tuple ``(Qs, qs, As, bs, Gs, hs)``; every array has a leading batch
        axis of length N. Shapes: Qs (N, n, n), qs (N, n), As (N, 0, n),
        bs (N, 0), Gs (N, n, n), hs (N, n).
    """
    # draw the problem data exactly once; F is consumed before g
    Fs = jnp.array(np.random.randn(N, m, n))
    gs = jnp.array(np.random.randn(N, m))

    @jit
    def form_qp(F, g):
        # convert the least squares to qp form
        n = F.shape[1]
        Q = F.T @ F
        q = -F.T @ g
        # x >= 0 is encoded with G = -I and h = 0
        G = -jnp.eye(n)
        h = jnp.zeros(n)
        # no equality constraints: zero-row A and empty b
        A = jnp.zeros((0, n))
        b = jnp.zeros(0)
        return Q, q, A, b, G, h

    # create the QPs in a batched fashion
    Qs, qs, As, bs, Gs, hs = vmap(form_qp, in_axes=(0, 0))(Fs, gs)

    return Qs, qs, As, bs, Gs, hs

# qpax: jit-compiled solve, mapped over the leading axis of all six inputs
batch_qpax = jit(vmap(qpax.solve_qp_x, in_axes=(0,) * 6))

# jaxopt OSQP: the solver object used by the batched wrapper below
jax_osqp = OSQP()


def _osqp_single(Qs, qs, As, bs, Gs, hs):
    """Run jaxopt OSQP on one problem.

    Takes the same six arguments as the qpax solver so both batched functions
    share a call signature; As and bs are accepted but not forwarded.
    """
    objective = (Qs, qs)
    inequality = (Gs, hs)
    return jax_osqp.run(params_obj=objective, params_ineq=inequality)


# batched + jitted OSQP, mapped over the leading axis of all six inputs
batch_osqp = jit(vmap(_osqp_single, in_axes=(0,) * 6))

In [3]:
# single 50x50 problem
Qs, qs, As, bs, Gs, hs = form_nnls(100, 50, 1)
x_qpax = batch_qpax(Qs, qs, As, bs, Gs, hs)
x_osqp = batch_osqp(Qs, qs, As, bs, Gs, hs).params
%timeit batch_qpax(Qs, qs, As, bs, Gs, hs)
%timeit batch_osqp(Qs, qs, As, bs, Gs, hs)

1.83 ms ± 6.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
11.5 ms ± 4.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
# 128 batch of 50x50 problems
Qs, qs, As, bs, Gs, hs = form_nnls(100, 50, 128)
x_qpax = batch_qpax(Qs, qs, As, bs, Gs, hs)
x_osqp = batch_osqp(Qs, qs, As, bs, Gs, hs).params
%timeit batch_qpax(Qs, qs, As, bs, Gs, hs)
%timeit batch_osqp(Qs, qs, As, bs, Gs, hs)

46.4 ms ± 2.31 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
15.5 ms ± 11.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
# single 500x500 problem
Qs, qs, As, bs, Gs, hs = form_nnls(100, 500, 1)
x_qpax = batch_qpax(Qs, qs, As, bs, Gs, hs)
x_osqp = batch_osqp(Qs, qs, As, bs, Gs, hs).params
%timeit batch_qpax(Qs, qs, As, bs, Gs, hs)
%timeit batch_osqp(Qs, qs, As, bs, Gs, hs)

42.7 ms ± 4.99 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.26 s ± 550 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
# some quadruped problems

In [8]:
# non batched, just jitted
# for some reason qpax fails on all these problems (returns nans)
# qpax_jit = jit(qpax.solve_qp_x)
osqp_jit = jit(jax_osqp.run)

# CPU OSQP
cvxpy = CvxpyQP(solver='OSQP')


probs = ["problems/QUADCMPC1.npz", "problems/QUADCMPC2.npz", "problems/QUADCMPC3.npz", "problems/QUADCMPC4.npz"]
# these were taken from https://github.com/qpsolvers/mpc_qpbenchmark/

for prob in probs:
    data = np.load(prob)
    P = data["P"]
    q = data["q"]
    A = data["A"]
    b = data["b"]
    G = data["G"]
    h = data["h"]
    
    x_osqp = osqp_jit(params_obj=(P, q), params_eq=(A,b), params_ineq=(G, h)).params.primal
    x_cvxpy_osqp = cvxpy.run(init_params=[], params_obj=(P, q), params_eq=(A,b), params_ineq=(G, h)).params.primal
    print(prob)
    
    if jnp.isnan(x_osqp).any():
        print('osqp failed')
    if jnp.isnan(x_cvxpy_osqp).any():
        print('cvxpy osqp failed')
    
    print('jax osqp time')
    %timeit osqp_jit(params_obj=(P, q), params_eq=(A,b), params_ineq=(G, h))
    print('CVXPY OSQP time')
    %timeit cvxpy.run(init_params=[], params_obj=(P, q), params_eq=(A,b), params_ineq=(G, h))
    print()

problems/QUADCMPC1.npz
jax osqp time
4.59 s ± 9.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
CVXPY OSQP time
55 ms ± 1.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

problems/QUADCMPC2.npz
jax osqp time
2.7 s ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
CVXPY OSQP time
23.6 ms ± 349 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

problems/QUADCMPC3.npz
jax osqp time
1.19 s ± 1.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
CVXPY OSQP time
15.8 ms ± 187 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

problems/QUADCMPC4.npz
jax osqp time
1.19 s ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
CVXPY OSQP time
15.8 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)



In [9]:
# atlas problems

box_osqp = BoxOSQP()
box_osqp_jit = jit(box_osqp.run)


probs = ["problems/ATLAS_2.npz", "problems/ATLAS_10.npz", "problems/ATLAS_30.npz", "problems/ATLAS_40.npz", "problems/ATLAS_50.npz"]
for prob in probs:
    data = np.load(prob)
    P = data["P"]
    q = data["q"]
    A = data["A"]
    l = data["l"]
    u = data["u"]
    
    x_osqp = box_osqp_jit(params_obj=(P, q), params_eq=A, params_ineq=(l,u)).params.primal[0]
    
    if jnp.isnan(x_osqp).any():
        print('boxosqp failed')
    
    print(prob)
    print('jax osqp time')
    %timeit box_osqp_jit(params_obj=(P, q), params_eq=A, params_ineq=(l,u))
    print('solving with CPU OSQP')
    P = sp.csc_matrix(P)
    A = sp.csc_matrix(A)
    
    cpu_osqp = osqp.OSQP()
    cpu_osqp.setup(P=P, q=q, A=A, l=l, u=u)
    cpu_osqp.solve()    
    print('\n\n')
    
    
    

problems/ATLAS_2.npz
jax osqp time
34.3 ms ± 172 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
solving with CPU OSQP
-----------------------------------------------------------------
           OSQP v0.6.3  -  Operator Splitting QP Solver
              (c) Bartolomeo Stellato,  Goran Banjac
        University of Oxford  -  Stanford University 2021
-----------------------------------------------------------------
problem:  variables n = 58, constraints m = 58
          nnz(P) + nnz(A) = 2610
settings: linear system solver = qdldl,
          eps_abs = 1.0e-03, eps_rel = 1.0e-03,
          eps_prim_inf = 1.0e-04, eps_dual_inf = 1.0e-04,
          rho = 1.00e-01 (adaptive),
          sigma = 1.00e-06, alpha = 1.60, max_iter = 4000
          check_termination: on (interval 25),
          scaling: on, scaled_termination: off
          warm start: on, polish: off, time_limit: off

iter   objective    pri res    dua res    rho        time
   1  -1.6213e-23   3.17e+02   3.15e+01   1.0

KeyboardInterrupt: 