In [1]:
''' import '''
import scipy.io as io
import numpy as np
import numpy.linalg as la
import scipy.sparse as sp
import matplotlib.pyplot as plt

In [2]:
''' load data '''
# Read the MATLAB file once and pull out the two arrays used by the notebook:
# 'f' is the array the data-fidelity term is fitted to, 'x_orig' is only displayed.
data = io.loadmat('./data/hw3_prob2.mat')
f_orig, x_orig = data['f'], data['x_orig']

# Row/column counts of f_orig; flattened image vectors below have M*N entries.
M, N = f_orig.shape

In [3]:
''' hyper param '''
# Objective weights:
#   MU     scales the data-fidelity term (MU/2)*||x - f||^2
#   LAMBDA scales the quadratic coupling terms; 1/LAMBDA is the shrinkage threshold
MU, LAMBDA = 0.02, 0.0002

# Iteration control:
#   MAXITERS caps the number of passes of every solver loop
#   CRIT     is the tolerance on the relative change ||x_k - x_{k-1}|| / ||x_k||
MAXITERS, CRIT = 300, 1e-4

In [4]:
''' function '''
def soft_thres(z, t):
    """Element-wise soft-thresholding (shrinkage) of ``z`` by ``t``.

    Every entry keeps its sign while its magnitude is reduced by ``t``;
    entries whose magnitude does not exceed ``t`` become zero.

    Parameters
    ----------
    z : scalar or array
        Values to shrink.
    t : scalar or array broadcastable against ``z``
        Threshold subtracted from the magnitudes.

    Returns
    -------
    Same shape as the broadcast of ``z`` and ``t``.
    """
    shrunk = np.abs(z) - t
    return np.maximum(shrunk, 0) * np.sign(z)

# ADMM (1d)

In [5]:
''' TV (D_h, D_v) '''
# Variant without wrap-around (kept for reference, not used):
# Im = sp.eye(M)
# Dn = sp.diags([-1, 1], offsets=(0, 1), shape=(N-1, N))
# D_h = sp.kron(Im, Dn)
#
# In = sp.eye(N)
# Dm = sp.diags([-1, 1], offsets=(0, 1), shape=(M-1, M))
# D_v = sp.kron(Dm, In)

# Active variant: square forward-difference matrices (-1 on the diagonal, +1 on the
# first super-diagonal) whose extra +1 at offset -(n-1) makes the last row wrap
# around to the first column.
Dn = sp.diags([1, -1, 1], offsets=(1 - N, 0, 1), shape=(N, N))
Dm = sp.diags([1, -1, 1], offsets=(1 - M, 0, 1), shape=(M, M))

# NOTE: the name `In` also shadows IPython's input-history list inside this kernel.
Im, In = sp.eye(M), sp.eye(N)

# Kronecker products lift the 1-D differences to the row-major flattened M*N vector:
# D_h differences within each row, D_v differences across rows.
D_h = sp.kron(Im, Dn)
D_v = sp.kron(Dm, In)

In [None]:
''' ADMM (1d) '''
# Iterative minimisation of
#     (MU/2)*||x - f||^2 + ||d_h||_1 + ||d_v||_1
# where d_h, d_v are tied to D_h@x, D_v@x through quadratic penalties weighted by
# LAMBDA, and q_h, q_v accumulate the remaining constraint residuals.

fn_admm = []   # objective value logged at every iteration
xn_admm = []   # relative change ||x_k - x_{k-1}|| / ||x_k|| logged at every iteration

f = f_orig.reshape(-1,1)
x = np.zeros((M*N,1))
x_prev = x

d_h = np.zeros((M*N,1))
q_h = np.zeros((M*N,1))
d_v = np.zeros((M*N,1))
q_v = np.zeros((M*N,1))

for k in range(MAXITERS):

    # objective, evaluated on the iterates left by the previous pass
    # (la.norm(., 1) of an (M*N, 1) column equals the sum of absolute values)
    fn = (MU/2)*la.norm(x - f)**2 + la.norm(d_h, 1) + la.norm(d_v, 1) \
        + (LAMBDA/2)*la.norm(d_h - D_h@x - q_h)**2 + (LAMBDA/2)*la.norm(d_v - D_v@x - q_v)**2

    # x-update: one fixed-point step of the stationarity condition
    #     MU*(x - f) = LAMBDA*(D_h.T@(d_h - D_h@x - q_h) + D_v.T@(d_v - D_v@x - q_v))
    # with the right-hand side evaluated at the previous x (not an exact solve)
    x = f + (LAMBDA/MU)*(D_h.T@(d_h - D_h@x - q_h) + D_v.T@(d_v - D_v@x - q_v))
    # alternative update kept for reference (unused; D_vh is not defined in the visible cells):
    # x = MU*f/(MU + 4*LAMBDA) + LAMBDA/(MU + 4*LAMBDA)*(D_vh(x) + D_h.T@(d_h - q_h) + D_v.T@(d_v - q_v))

    # d-update: shrinkage with threshold 1/LAMBDA
    d_h = soft_thres(D_h@x + q_h, 1/LAMBDA)
    d_v = soft_thres(D_v@x + q_v, 1/LAMBDA)
    # q-update: accumulate the constraint residual D@x - d
    q_h = q_h + (D_h@x - d_h)
    q_v = q_v + (D_v@x - d_v)

    # relative change of x; computed once and reused for logging and the stop test
    xcon = la.norm(x - x_prev)/la.norm(x)

    # history -- logged before the stop test so the converged iteration is recorded too
    fn_admm.append(fn)
    xn_admm.append(xcon)
    print(f'i = {k:<4d}, |x-x|/|x| = {xcon:.8f}, f = {fn:.8f}')

    # stop condition (never triggers on the very first pass)
    if (k >= 1) and (xcon < CRIT):
        break

    x_prev = x

# optimal
x_admm = x

In [None]:
''' ADMM + reweighted L1 (1d) '''
# Same iteration as the plain ADMM cell, but the differences D_h@x, D_v@x are scaled
# element-wise by weights W_h, W_v that are refreshed after every pass.

DELTA = 3.9   # offset added to |d| in the weight denominator

fn_admm_rew = []   # objective value logged at every iteration
xn_admm_rew = []   # relative change of x logged at every iteration

f = f_orig.reshape(-1,1)

# warm start from the plain ADMM result
x = x_admm
x_prev = x

d_h = np.zeros((M*N,1))
q_h = np.zeros((M*N,1))
d_v = np.zeros((M*N,1))
q_v = np.zeros((M*N,1))

# weights start at 1, i.e. the first pass is unweighted
W_h = np.ones((M*N,1))
W_v = np.ones((M*N,1))

for k in range(MAXITERS):

    # objective, evaluated on the iterates left by the previous pass
    fn = (MU/2)*la.norm(x - f)**2 + la.norm(d_h, 1) + la.norm(d_v, 1) \
        + (LAMBDA/2)*la.norm(d_h - W_h*(D_h@x) - q_h)**2 + (LAMBDA/2)*la.norm(d_v - W_v*(D_v@x) - q_v)**2

    # x-update: fixed-point step; the weights appear twice because the adjoint of
    # x -> W*(D@x) is r -> D.T@(W*r)
    x = f + (LAMBDA/MU)*(D_h.T@(W_h*(d_h - W_h*(D_h@x) - q_h)) + D_v.T@(W_v*(d_v - W_v*(D_v@x) - q_v)))
    # alternative update kept for reference (unused; D_vh is not defined in the visible cells):
    # x = MU*f/(MU + 4*LAMBDA) + LAMBDA/(MU + 4*LAMBDA)*(D_vh@x + D_h.T@(d_h - q_h) + D_v.T@(d_v - q_v))

    # d-update (shrinkage with threshold 1/LAMBDA) and q-update (residual accumulation)
    d_h = soft_thres(W_h*(D_h@x) + q_h, 1/LAMBDA)
    d_v = soft_thres(W_v*(D_v@x) + q_v, 1/LAMBDA)
    q_h = q_h + (W_h*(D_h@x) - d_h)
    q_v = q_v + (W_v*(D_v@x) - d_v)

    # reweighting: entries with small |d| receive the largest weight (at most 1/DELTA)
    W_h = 1 / (np.abs(d_h) + DELTA)
    W_v = 1 / (np.abs(d_v) + DELTA)

    # relative change of x
    xcon = la.norm(x - x_prev)/la.norm(x)

    # history -- logged before the stop test so the converged iteration is recorded too
    fn_admm_rew.append(fn)
    xn_admm_rew.append(xcon)
    print(f'i = {k:<4d}, |x-x|/|x| = {xcon:.8f}, f = {fn:.8f}')

    # stop condition (skipped on the very first pass)
    if (k >= 1) and (xcon < CRIT):
        break

    x_prev = x

# optimal
x_admm_rew = x

In [None]:
''' plot '''
# Side-by-side comparison: both reconstructions (reshaped back to M x N),
# the input f and the reference x_orig.
panels = [
    (x_admm.reshape(M,N), 'ADMM'),
    (x_admm_rew.reshape(M,N), 'ADMM + reweighted'),
    (f_orig, 'f'),
    (x_orig, 'x_orig'),
]

_, axs = plt.subplots(1,4, figsize=(15,5))
for ax, (img, title) in zip(axs, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
''' find delta : reweighted L1 (ADMM)  '''
# Sweep the reweighting offset `delta` over a grid, running the reweighted ADMM
# iteration from the same warm start for every value.
# MAXITERS comes from the hyper-parameter cell; the sweep uses `delta`, not DELTA,
# so neither constant is re-declared here.

fn_hist = []   # objective value of the last executed pass, one entry per delta
xn_hist = []   # relative change of x at the last executed pass, one entry per delta
x_hist = []    # final iterate, one entry per delta

f = f_orig.reshape(-1,1)
delta_list = np.linspace(0.2, 1, 12)

for delta in delta_list:

    # identical starting point for every delta: plain ADMM result, zero d/q, unit weights
    x = x_admm
    x_prev = x

    d_h = np.zeros((M*N,1))
    q_h = np.zeros((M*N,1))
    d_v = np.zeros((M*N,1))
    q_v = np.zeros((M*N,1))

    W_h = np.ones((M*N,1))
    W_v = np.ones((M*N,1))

    for k in range(MAXITERS):

        # objective, evaluated on the iterates left by the previous pass
        fn = (MU/2)*la.norm(x - f)**2 + la.norm(d_h, 1) + la.norm(d_v, 1) \
            + (LAMBDA/2)*la.norm(d_h - W_h*(D_h@x) - q_h)**2 + (LAMBDA/2)*la.norm(d_v - W_v*(D_v@x) - q_v)**2

        # x-update: fixed-point step with the weights applied on both sides of D
        x = f + (LAMBDA/MU)*(D_h.T@(W_h*(d_h - W_h*(D_h@x) - q_h)) + D_v.T@(W_v*(d_v - W_v*(D_v@x) - q_v)))
        # alternative update kept for reference (unused; D_vh is not defined in the visible cells):
        # x = MU*f/(MU + 4*LAMBDA) + LAMBDA/(MU + 4*LAMBDA)*(D_vh@x + D_h.T@(d_h - q_h) + D_v.T@(d_v - q_v))

        # d-update (shrinkage with threshold 1/LAMBDA) and q-update (residual accumulation)
        d_h = soft_thres(W_h*(D_h@x) + q_h, 1/LAMBDA)
        d_v = soft_thres(W_v*(D_v@x) + q_v, 1/LAMBDA)
        q_h = q_h + (W_h*(D_h@x) - d_h)
        q_v = q_v + (W_v*(D_v@x) - d_v)

        # reweighting with the delta currently being swept
        W_h = 1 / (np.abs(d_h) + delta)
        W_v = 1 / (np.abs(d_v) + delta)

        # stop condition (skipped on the very first pass)
        xcon = la.norm(x - x_prev)/la.norm(x)
        if (k >= 1) and (xcon < CRIT):
            break

        x_prev = x

    # history: one summary record per delta, taken from the last executed pass
    fn_hist.append(fn)
    xn_hist.append(xcon)
    x_hist.append(x)
    print(f'delta = {delta:.2f}, i = {k:<4d}, |x-x|/|x| = {xcon:.6f}, f = {fn:.2f}')



In [None]:
# argmin
# Index of the smallest recorded objective / relative change over the whole sweep.
# The full lists are searched, so the result stays correct for any grid length.
fn_min_idx = np.argmin(fn_hist)
xn_min_idx = np.argmin(xn_hist)

print(f'fn_min : {delta_list[fn_min_idx]:.3f}')
print(f'xn_min : {delta_list[xn_min_idx]:.3f}')

In [None]:
# plot graph
# Sweep summary: one point per delta. The curves are drawn against the actual delta
# values so the shared x-axis agrees with its '$\delta$' label.
_, axs = plt.subplots(2,1, sharex=True, figsize=(10,5))

axs[0].semilogy(delta_list, fn_hist, label='f')
# optional marker for the minimiser:
# axs[0].axvline(delta_list[fn_min_idx], color='red', linestyle='--')
axs[0].set_ylabel('$f$')
axs[0].legend()
axs[0].grid()

axs[1].semilogy(delta_list, xn_hist, label='$|x_k - x_{k-1}|_2/|x_k|_2$')
# optional marker for the minimiser:
# axs[1].axvline(delta_list[xn_min_idx], color='red', linestyle='--')
axs[1].set_ylabel('$|x-x|/|x|$')
axs[1].set_xlabel('$\\delta$')
axs[1].legend()
axs[1].grid()

plt.show()

In [None]:
# plot images
# Grid of the final iterates from the delta sweep, six per row.
num_img = len(x_hist)
cols = 6
rows = -(-num_img // cols)   # ceiling division: enough rows to hold every image

fig, axs = plt.subplots(rows, cols, figsize=(15, 2.6 * rows))

for i, ax in enumerate(axs.flat):
    # cells beyond the last image stay empty; every cell has its axes hidden
    if i < num_img:
        ax.imshow(x_hist[i].reshape(M,N), cmap='gray')
        ax.set_title(f'$\\delta={delta_list[i]:.2f}$')
    ax.axis('off')

plt.tight_layout()
plt.show()

# ADMM (1d)
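Stack $D_h$ and $D_v$ vertically into a single operator $D = \begin{bmatrix} D_h \\ D_v \end{bmatrix}$, so that both difference directions are handled by one pair of variables.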

In [13]:
''' TV (D_h, D_v) '''
# Stack D_h on top of D_v so both directions are handled as a single operator
# with 2*M*N rows.
D = sp.vstack((D_h, D_v))

In [None]:
''' ADMM (1d) '''
# Same iteration as the split ADMM cell, written with the stacked operator D so a
# single pair (d, q) of length 2*M*N replaces (d_h, d_v) and (q_h, q_v).

fn_admm2 = []   # objective value logged at every iteration
xn_admm2 = []   # relative change ||x_k - x_{k-1}|| / ||x_k|| logged at every iteration

f = f_orig.reshape(-1,1)
x = np.zeros((M*N,1))
x_prev = x

d = np.zeros((2*(M*N),1))
q = np.zeros((2*(M*N),1))

for k in range(MAXITERS):

    # objective, evaluated on the iterates left by the previous pass
    # (la.norm(., 1) of a single column equals the sum of absolute values)
    fn = (MU/2)*la.norm(x - f)**2 + la.norm(d, 1) + (LAMBDA/2)*la.norm(d - D@x - q)**2

    # x-update: one fixed-point step of MU*(x - f) = LAMBDA*D.T@(d - D@x - q),
    # with the right-hand side evaluated at the previous x (not an exact solve)
    x = f + (LAMBDA/MU)*(D.T@(d - D@x - q))

    # d-update: shrinkage with threshold 1/LAMBDA; q-update: accumulate D@x - d
    d = soft_thres(D@x + q, 1/LAMBDA)
    q = q + (D@x - d)

    # relative change of x; computed once and reused for logging and the stop test
    xcon = la.norm(x - x_prev)/la.norm(x)

    # history -- logged before the stop test so the converged iteration is recorded too
    fn_admm2.append(fn)
    xn_admm2.append(xcon)
    print(f'i = {k:<4d}, |x-x|/|x| = {xcon:.8f}, f = {fn:.8f}')

    # stop condition (never triggers on the very first pass)
    if (k >= 1) and (xcon < CRIT):
        break

    x_prev = x

# optimal
x_admm2 = x

In [None]:
''' ADMM + reweighted L1 (1d) '''
# Reweighted iteration with the stacked operator D; mirrors the split reweighted
# cell, with one weight vector W of length 2*M*N.

DELTA = 3.9   # offset added to |d| in the weight denominator

fn_admm_rew2 = []   # objective value logged at every iteration
xn_admm_rew2 = []   # relative change of x logged at every iteration

f = f_orig.reshape(-1,1)

# warm start from the stacked plain ADMM result
x = x_admm2
x_prev = x

d = np.zeros((2*(M*N),1))
q = np.zeros((2*(M*N),1))
W = np.ones((2*(M*N),1))   # unit weights: the first pass is unweighted

for k in range(MAXITERS):

    # objective, evaluated on the iterates left by the previous pass
    fn = (MU/2)*la.norm(x - f)**2 + la.norm(d, 1) + (LAMBDA/2)*la.norm(d - W*(D@x) - q)**2

    # x-update: fixed-point step of MU*(x - f) = LAMBDA*D.T@(W*(d - W*(D@x) - q)).
    # The residual must be multiplied by W before applying D.T, because the adjoint
    # of x -> W*(D@x) is r -> D.T@(W*r).
    x = f + (LAMBDA/MU)*(D.T@(W*(d - W*(D@x) - q)))

    # d-update: shrinkage with threshold 1/LAMBDA; q-update: accumulate W*(D@x) - d
    d = soft_thres(W*(D@x) + q, 1/LAMBDA)
    q = q + (W*(D@x) - d)

    # reweighting on the magnitude of d: keeps every weight in (0, 1/DELTA] and
    # avoids a zero or negative denominator when d has negative entries
    W = 1 / (np.abs(d) + DELTA)

    # relative change of x; computed once and reused for logging and the stop test
    xcon = la.norm(x - x_prev)/la.norm(x)

    # history -- logged before the stop test so the converged iteration is recorded too
    fn_admm_rew2.append(fn)
    xn_admm_rew2.append(xcon)
    print(f'i = {k:<4d}, |x-x|/|x| = {xcon:.8f}, f = {fn:.8f}')

    # stop condition (skipped on the very first pass)
    if (k >= 1) and (xcon < CRIT):
        break

    x_prev = x

# optimal
x_admm_rew2 = x