In [1]:
import numpy as np
import scipy as sp

from cuqi.solver import CGLS, LM, FISTA, ADMM, ProximalL1, ProjectNonnegative
from scipy.optimize import lsq_linear


def test_ADMM_matrix_form():
    """Regression test for ADMM when the forward operator is an explicit matrix.

    Builds a small seeded random problem (A is 10x5, b has 10 entries, L is 4x5)
    with two penalty terms: ProximalL1 paired with the identity, and a
    nonnegativity projection paired with L. It then runs ADMM (maxit=100,
    adaptive=True) and compares the returned solution to a stored reference.

    Raises:
        AssertionError: if the solution differs from the reference by more than
            the tolerance used below.
    """
    # Fixed seed so the random problem, and hence the reference, is reproducible.
    rng = np.random.default_rng(seed=42)
    m, n = 10, 5
    A = rng.standard_normal((m, n))
    b = rng.standard_normal(m)

    k = 4
    L = rng.standard_normal((k, n))

    # Penalty terms are given as (proximal operator, linear operator) pairs.
    # The lambda gives ProjectNonnegative a two-argument (z, _) signature and
    # ignores the second argument. Presumably the solver calls every proximal
    # operator as prox(z, t); confirm against the cuqi ADMM documentation.
    penalty_terms = [
        (ProximalL1, np.eye(n)),
        (lambda z, _: ProjectNonnegative(z), L),
    ]

    x0 = np.zeros(n)
    # NOTE(review): the positional 10 is presumably ADMM's penalty parameter;
    # confirm against the cuqi ADMM signature.
    sol, _ = ADMM(A, b, penalty_terms, x0, 10, maxit=100, adaptive=True).solve()

    print(sol)
    # Reference solution recorded from a previous run of this exact setup.
    ref_sol = np.array([-3.99513417e-03, -1.32339656e-01, -4.52822633e-02,
                        -7.44973888e-02, -3.35005208e-11])
    # This matches the acceptance criterion of np.allclose(sol, ref_sol, atol=1e-4)
    # (whose default rtol is 1e-5). Unlike a bare `assert`, it is not stripped
    # under `python -O`, and it reports the mismatching entries on failure.
    np.testing.assert_allclose(sol, ref_sol, rtol=1e-5, atol=1e-4)

test_ADMM_matrix_form()

[-3.99513417e-03 -1.32339656e-01 -4.52822633e-02 -7.44973888e-02
 -3.35005208e-11]


In [2]:


def test_ADMM_function_form():
    """Regression test for ADMM when the forward operator is a callable.

    Same seeded problem as the matrix-form test, but A is passed as a
    matrix-free function ``A_fun(x, flag)``. The result must match the same
    reference solution.

    Raises:
        AssertionError: if the solution differs from the reference by more than
            the tolerance used below.
        ValueError: if the solver calls ``A_fun`` with a flag other than 1 or 2.
    """
    # Fixed seed so the random problem, and hence the reference, is reproducible.
    rng = np.random.default_rng(seed=42)
    m, n = 10, 5
    A = rng.standard_normal((m, n))

    def A_fun(x, flag):
        """Apply A (flag == 1) or its transpose A.T (flag == 2) to x."""
        if flag == 1:
            return A @ x
        if flag == 2:
            return A.T @ x
        # Previously this fell through and returned None, which would only fail
        # later and far from the cause.
        raise ValueError(f"A_fun: unsupported flag {flag!r}; expected 1 or 2")

    # Drawn after A so the random stream (and thus b and L) matches the
    # matrix-form test.
    b = rng.standard_normal(m)

    k = 4
    L = rng.standard_normal((k, n))

    # Penalty terms are given as (proximal operator, linear operator) pairs.
    # The lambda gives ProjectNonnegative a two-argument (z, _) signature and
    # ignores the second argument. Presumably the solver calls every proximal
    # operator as prox(z, t); confirm against the cuqi ADMM documentation.
    penalty_terms = [
        (ProximalL1, np.eye(n)),
        (lambda z, _: ProjectNonnegative(z), L),
    ]

    x0 = np.zeros(n)
    # NOTE(review): the positional 10 is presumably ADMM's penalty parameter;
    # confirm against the cuqi ADMM signature.
    sol, _ = ADMM(A_fun, b, penalty_terms, x0, 10, maxit=100, adaptive=True).solve()

    print(sol)
    # Reference solution recorded from a previous run (same as the matrix form).
    ref_sol = np.array([-3.99513417e-03, -1.32339656e-01, -4.52822633e-02,
                        -7.44973888e-02, -3.35005208e-11])
    # This matches the acceptance criterion of np.allclose(sol, ref_sol, atol=1e-4)
    # (whose default rtol is 1e-5). Unlike a bare `assert`, it is not stripped
    # under `python -O`, and it reports the mismatching entries on failure.
    np.testing.assert_allclose(sol, ref_sol, rtol=1e-5, atol=1e-4)

test_ADMM_function_form()

[-3.99513417e-03 -1.32339656e-01 -4.52822633e-02 -7.44973888e-02
 -3.35005152e-11]


In [3]:

# Non-adaptive variant: same seeded problem as the tests above, solved with
# adaptive=False. Checked against the same reference so the cell fails loudly
# if the two variants diverge, instead of only printing.
rng = np.random.default_rng(seed=42)
m, n = 10, 5
A = rng.standard_normal((m, n))
b = rng.standard_normal(m)

k = 4
L = rng.standard_normal((k, n))

# (proximal operator, linear operator) pairs. The lambda adapts
# ProjectNonnegative to the two-argument (z, _) form and ignores the second
# argument.
penalty_terms = [
    (ProximalL1, np.eye(n)),
    (lambda z, _: ProjectNonnegative(z), L),
]

x0 = np.zeros(n)
# NOTE(review): the positional 10 is presumably ADMM's penalty parameter;
# confirm against the cuqi ADMM signature.
sol, _ = ADMM(A, b, penalty_terms, x0, 10, maxit=100, adaptive=False).solve()

print(sol)
ref_sol = np.array([-3.99513417e-03, -1.32339656e-01, -4.52822633e-02,
                    -7.44973888e-02, -3.35005208e-11])
# Same tolerance as np.allclose(sol, ref_sol, atol=1e-4) with its default rtol.
np.testing.assert_allclose(sol, ref_sol, rtol=1e-5, atol=1e-4)

[-3.99513417e-03 -1.32339656e-01 -4.52822633e-02 -7.44973888e-02
 -3.35005208e-11]
