In [2]:
''' import '''
import scipy.io as io
import numpy as np
import numpy.linalg as la
import matplotlib.pyplot as plt

In [3]:
''' load data '''
# Problem data for min_x 1/2 ||A x - b||^2 + LAMBDA ||x||_1.
# x_orig is the reference signal the reconstructions are compared against below.
# NOTE(review): b and x_orig are presumably (M, 1) / (N, 1) columns — confirm in the .mat file.
data = io.loadmat('./data/hw3_prob1a.mat')
A, b, x_orig = data['A'], data['b'], data['x_orig']

# A has M rows and N columns; x lives in R^N.
M, N = A.shape

In [4]:
''' hyper param '''
# Solver settings shared by the ISTA and FISTA cells below.
LAMBDA = 2         # weight of the l1 term: f(x) = 1/2 ||A x - b||^2 + LAMBDA ||x||_1
MAXITERS = 10000   # iteration cap for each solver
CRIT = 1e-4        # stopping tolerance on the relative change of x between iterations

# Fixed step size t = 1 / L, where L = ||A||_2^2 (squared largest singular value)
# is the Lipschitz constant of grad g(x) = A^T (A x - b).
t = 1 / la.norm(A, ord=2) ** 2

In [5]:
''' function '''
def soft_thres(z, t):
    """Elementwise soft-thresholding, the proximal operator of t * ||.||_1.

    Every entry of ``z`` is shrunk toward zero by ``t``; entries whose magnitude
    is at most ``t`` become zero:  sign(z) * max(|z| - t, 0).

    Parameters
    ----------
    z : array_like
        Values to shrink (any shape; the output has the same shape).
    t : float
        Threshold. In this notebook it is always LAMBDA * step size.

    Returns
    -------
    numpy.ndarray (numpy scalar for scalar input)
    """
    shrunk_magnitude = np.maximum(np.abs(z) - t, 0)
    direction = np.sign(z)
    return direction * shrunk_magnitude

In [None]:
''' soft-threshold (ISTA) '''
# ISTA (proximal gradient) for  min_x f(x) = g(x) + h(x)  with
#   g(x) = 1/2 ||A x - b||^2   (smooth part, gradient A^T (A x - b))
#   h(x) = LAMBDA ||x||_1      (prox = soft-thresholding with threshold LAMBDA * t)

PRINT_EVERY = 100  # progress-print interval; printing every iteration floods the cell output

f_hist_ista = []      # f(x_k) for every iterate produced by the loop
xcond_hist_ista = []  # ||x_k - x_orig|| for the same iterates

x = np.zeros((N, 1))
x_prev = x  # set before the loop so the stop test never depends on earlier kernel state

for k in range(MAXITERS):

    # gradient step on g followed by the prox of h
    grad_g = A.T @ (A @ x - b)
    x = soft_thres(x - t * grad_g, LAMBDA * t)

    # Objective at the new iterate. For an (n, 1) column the default (Frobenius)
    # norm equals the Euclidean norm, whereas ord=2 on a 2-D array runs an SVD.
    # la.norm(x, 1) on a single column is the sum of |x_i| (max column abs-sum).
    f = 0.5 * la.norm(A @ x - b)**2 + LAMBDA * la.norm(x, 1)

    # history is recorded before the stop test so the final iterate is included
    # and both series refer to the same x_k
    f_hist_ista.append(f)
    xcond_hist_ista.append(la.norm(x - x_orig))

    if k % PRINT_EVERY == 0:
        print(f'i = {k}, f = {f}')

    # Relative-change stop, written as a product to avoid 0/0 when x == 0
    # (e.g. a large LAMBDA drives the solution to zero).
    if la.norm(x - x_prev) <= CRIT * la.norm(x):
        print(f'stop: i = {k}, f = {f}')
        break

    x_prev = x

# optimal
x_ista = x

In [None]:
# ISTA reconstruction vs. reference signal, entry by entry.
fig, ax = plt.subplots()
ax.set_title('ISTA')
ax.plot(x_ista, linestyle='', marker='x', label='$x_{recon}$')
ax.plot(x_orig, linestyle='', marker='.', label='$x_{orig}$')
ax.set_xlabel('Index')
ax.set_ylabel('Value')
ax.legend()
plt.show()

In [None]:
''' FISTA '''
# FISTA: the ISTA step is taken at an extrapolated point y, then
#   y_{k+1} = x_{k+1} + k/(k+3) * (x_{k+1} - x_k).

PRINT_EVERY = 100  # progress-print interval; printing every iteration floods the cell output

f_hist_fista = []      # f(x_k) for every iterate produced by the loop
xcond_hist_fista = []  # ||x_k - x_orig|| for the same iterates

x = np.zeros((N, 1))
# The momentum step at k = 0 reads x_prev. Without this line the cell only ran
# because the ISTA cell's x_prev leaked into the kernel (NameError on a fresh kernel).
x_prev = x
y = x

for k in range(MAXITERS):

    # proximal-gradient step evaluated at the extrapolated point y
    grad_gy = A.T @ (A @ y - b)
    x = soft_thres(y - t * grad_gy, LAMBDA * t)

    # Objective at the new iterate. The default (Frobenius) norm equals the Euclidean
    # norm for an (n, 1) column and avoids the SVD that ord=2 runs on 2-D arrays.
    f = 0.5 * la.norm(A @ x - b)**2 + LAMBDA * la.norm(x, 1)

    # history is recorded before the stop test so the final iterate is included
    f_hist_fista.append(f)
    xcond_hist_fista.append(la.norm(x - x_orig))

    if k % PRINT_EVERY == 0:
        print(f'i = {k}, f = {f}')

    # relative-change stop, written as a product to avoid 0/0 when x == 0
    if la.norm(x - x_prev) <= CRIT * la.norm(x):
        print(f'stop: i = {k}, f = {f}')
        break

    # momentum update; must use x_prev before it is overwritten
    y = x + k / (k + 3) * (x - x_prev)
    x_prev = x

# optimal
x_fista = x

In [None]:
# FISTA reconstruction vs. reference signal, entry by entry.
fig, ax = plt.subplots()
ax.set_title('FISTA')
ax.plot(x_fista, linestyle='', marker='x', label='$x_{recon}$')
ax.plot(x_orig, linestyle='', marker='.', label='$x_{orig}$')
ax.set_xlabel('Index')
ax.set_ylabel('Value')
ax.legend()
plt.show()

In [None]:
# Distance to the reference signal per iteration (not the objective f), log scale.
fig, ax = plt.subplots()
ax.set_title('Reconstruction error $||x_k - x_{orig}||_2$')
ax.semilogy(xcond_hist_ista, label='ISTA')
ax.semilogy(xcond_hist_fista, label='FISTA')

ax.set_xlabel('Iteration')
ax.set_ylabel('$||x_k - x_{orig}||_2$')
ax.legend()
plt.show()