In [1]:
import jax 
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import jaxopt

In [2]:
@jax.jit
def f(params, k):
    """
    Generates the dual of the problem with inequality constraints on Lagrange multipliers 
    to be non-negative as a penalty such that the constrained dual becomes an unconstrained
    problem
    
    Args:
        - parameters[float]: Lagrangian multipliers
        - k(float): penalty coefficient
    returns:
        (float): penalized objective function
    """
    Q_0 = jnp.array([[1, 0], [0, 9]])
    P_1 = -jnp.array([2, 1])
    P_2 = -jnp.array([1, 3])
    P_3 = -jnp.array([1, 0])
    P_4 = -jnp.array([0, 1])
    P = params[0]*P_1 + params[1]*P_2 + params[2]*P_3 + params[3]*P_4
    c_1 = 1
    c_2 = 1
    c = params[0]*c_1 + params[1]*c_2
    x = -0.5*jnp.linalg.inv(Q_0)@P
    D = x.T@Q_0@x + P@x + c
    return (-D + k*(jnp.sum((params**2)*(params<0))))

In [3]:
og_params = jnp.array([1., 1., 1., 1.]).astype("float32")

In [28]:
t = jnp.linspace(0, 5, 101)
mesh = np.zeros((101, 101))
mesh_2 = np.zeros((101,101))
for i in range(len(t)):
    for j in range(len(t)):
        mesh[i][j] = f(jnp.array([t[i], t[j]]))[0]
        mesh_2[i][j] = f(jnp.array([t[i], t[j]]))[1]

In [47]:
jnp.where((mesh_2>-0.01)*(mesh_2<0.01)*(mesh<0.01)), f(jnp.array([t[1], t[19]]))

((Array([ 0,  1,  5,  6,  7, 11, 12, 13, 17], dtype=int32),
  Array([20, 19, 14, 13, 12,  7,  6,  5,  0], dtype=int32)),
 (Array(0.4907639, dtype=float32),
  Array([0.525    , 0.1611111], dtype=float32)))

In [4]:
optimizer = jaxopt.GradientDescent(f, maxiter = 1)
params = og_params.copy()
for i in range(15):
    for j in range(20):
        print(-f(params, 0))
        params, state = optimizer.run(params, (2**i)*0.001)

-2.6944447
0.022322953
0.1712749
0.3745766
0.4803932
0.6374689
0.74557704
0.8408406
0.9916612
1.093663
1.21962
1.315411
1.4775839
1.606946
1.6967707
1.8589412
1.9917314
2.0762782
2.2385418
2.3743918
2.4823372
2.5727987
2.7119985
2.8113
2.901536
3.041582
3.140039
3.2260828
3.3718634
3.46341
3.5913253
3.6720052
3.8248253
3.9554152
4.031134
4.1833262
4.316162
4.4128246
4.501557
4.6336155
4.7314634
4.805916
4.931739
5.0112896
5.090705
5.208804
5.2955217
5.365861
5.494175
5.5672307
5.6823254
5.745105
5.8721642
5.988129
6.0467935
6.172292
6.288806
6.3437715
6.4676943
6.5844746
6.636114
6.706106
6.763817
6.802195
6.8874283
6.9245706
6.97311
7.042256
7.094511
7.130742
7.2125206
7.247156
7.3259487
7.3499284
7.4206696
7.4977155
7.5201464
7.588852
7.664114
7.68514
7.751861
7.723434
7.6905813
7.6695895
7.6598563
7.629554
7.61031
7.6011834
7.5730705
7.5552897
7.5467057
7.5205173
7.504017
7.4959335
7.471491
7.4643774
7.4395046
7.4343486
7.408293
7.407635
7.391005
7.0855427
7.0833225
6.741599
6.81620

In [113]:
params

Array([-0.00470802,  1.015672  , -0.01504018, -0.00514522], dtype=float32)

In [5]:
Q = jnp.array([[1, 0], [0, 9]])
P = jnp.array([[2, 1],[1, 3]])
ideal_x = jnp.array([0.5, 1/6])
Q_0 = jnp.array([[1, 0], [0, 9]])
P_1 = -jnp.array([2, 1])
P_2 = -jnp.array([1, 3])
P_3 = -jnp.array([1, 0])
P_4 = -jnp.array([0, 1])
P = params[0]*P_1 + params[1]*P_2 + params[2]*P_3 + params[3]*P_4
x = -0.5*jnp.linalg.inv(Q_0)@P
x.T@Q_0@x, ideal_x.T@Q_0@ideal_x, x, ideal_x

(DeviceArray(0.501835, dtype=float32),
 DeviceArray(0.5, dtype=float32),
 DeviceArray([0.49559584, 0.16872719], dtype=float32),
 DeviceArray([0.5       , 0.16666667], dtype=float32))

In [55]:
1 - jnp.array([2, 1])@x, 1 - jnp.array([1, 3])@x

(Array(-0.02247119, dtype=float32), Array(0.08356863, dtype=float32))