In [None]:
import concurrent.futures

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.ndimage
import scipy.sparse.linalg
from keras.datasets import mnist
from skimage import metrics

In [None]:
# MNIST dataset

np.random.seed(100)

# Import MNIST
(a_tr, labels_train), (a_test, labels_test) = mnist.load_data()

# Smaller sets
N_test_set = 100
N_train_set = 7000
N_val = 3000
N = a_tr.shape[1]  # image side length; the code below treats every image as N x N

# Clean images: 3-D stacks of shape (number of images, N, N), divided by 255.
# The validation images are the N_val images that follow the training ones in
# the MNIST training data, so the two sets do not overlap.
a_test = a_test[0:N_test_set, :, :] / 255.
a_train = a_tr[0:N_train_set, :, :] / 255.
a_val = a_tr[N_train_set:N_train_set + N_val, :, :] / 255.

# Add noise
noise_lev = 0.1

# Noisy observations: same shapes as the clean stacks, with Gaussian noise of
# standard deviation noise_lev added to every pixel.
b_test = a_test + noise_lev * np.random.randn(N_test_set, N, N)
b_train = a_train + noise_lev * np.random.randn(N_train_set, N, N)
b_val = a_val + noise_lev * np.random.randn(N_val, N, N)

In [None]:
#Algorithm functions

def prox_g_d(x, Lambda):
    """Clamp every entry of ``x`` to the interval ``[-Lambda, Lambda]``.

    Entries above ``Lambda`` become ``Lambda``, entries below ``-Lambda``
    become ``-Lambda``, everything else is returned unchanged.
    """
    lower_bounded = np.maximum(x, -Lambda)
    return np.minimum(lower_bounded, Lambda)

# Create the operator D (and D^T)
def image_grad(x):
    """Forward differences of an N x N image along both axes (the operator D).

    ``x`` is reshaped to ``(N, N)``. The differences along axis 0 come first,
    followed by the differences along axis 1, each flattened to length
    ``(N - 1) * N``.

    Returns a one-element list holding the resulting 1-D array of length
    ``2 * N * (N - 1)``.
    """
    img = x.reshape(N, N)
    half = (N - 1) * N
    vertical = np.diff(img, axis=0).reshape(half)
    horizontal = np.diff(img, axis=1).reshape(half)
    stacked = np.concatenate((vertical, horizontal))
    return [stacked]

def image_div(x):
    """Apply the transpose of the difference operator built by ``image_grad`` (D^T).

    The first ``(N - 1) * N`` entries of ``x`` are read as the axis-0
    differences (shape ``(N - 1, N)``), the remaining ones as the axis-1
    differences (shape ``(N, N - 1)``). Each part is zero-padded on either
    side of its difference axis and the two shifted copies are subtracted.

    Returns a one-element list holding the 1-D array of length ``N ** 2``.
    """
    split = (N - 1) * N
    vert = x[0:split].reshape(N - 1, N)
    horiz = x[split:].reshape(N, N - 1)

    zero_row = np.zeros((1, N))
    zero_col = np.zeros((N, 1))

    vert_term = (np.concatenate((zero_row, vert), axis=0)
                 - np.concatenate((vert, zero_row), axis=0))
    horiz_term = (np.concatenate((zero_col, horiz), axis=1)
                  - np.concatenate((horiz, zero_col), axis=1))

    total = vert_term + horiz_term
    return [total.reshape(N ** 2)]

# Matrix-free version of D: `matvec` applies image_grad (D), `rmatvec` applies
# image_div (D^T). The operator maps N**2 pixels to 2*N*(N-1) differences.
# `scipy.sparse.linalg` must be imported explicitly for this attribute lookup
# to be reliable. The dtype is passed explicitly so that it does not depend on
# how SciPy probes `matvec` when the operator is constructed.
D_Op = scipy.sparse.linalg.LinearOperator(
    (2 * N * (N - 1), N ** 2),
    matvec=image_grad,
    rmatvec=image_div,
    dtype=np.float64,
)

# We compute here D_TV(x, f_\lambda(y)). Inputs vectors, outputs number
def breg_dist_tv(x, y):
    """Return ``||D x||_1 - <sign(D y), D x>`` for two flattened images.

    Both inputs are vectors accepted by ``D_Op.matvec``; the result is a
    single number.
    """
    dx = D_Op.matvec(x)
    dy = D_Op.matvec(y)
    tv_of_x = np.linalg.norm(dx, ord=1)
    linear_part = np.dot(np.sign(dy), dx)
    return tv_of_x - linear_part


# We run FISTA on the dual of the TV-denoising problem  min_x 0.5*||x-y||^2 + lam*TV(x)

# First the inertial function

iner = 12  # inertial parameter used by beta()


def beta(tk, k):
    """Return the momentum coefficient and the updated inertial value.

    The new inertial value is ``(k + iner - 1) / iner``; the momentum
    coefficient is ``(tk - 1)`` divided by that new value.

    Returns the pair ``(momentum, new_inertial_value)``.
    """
    t_new = (k + iner - 1) / iner
    momentum = (tk - 1) / t_new
    return momentum, t_new

# Now the algorithm

def dual_fista(lam, y, tol=1e-8, max_iter=None):
    """Denoise one image with an inertial projected-gradient iteration on the dual variable.

    The dual variable ``u`` is kept in the box ``[-lam, lam]`` by ``prox_g_d``
    and the image iterate is recovered as ``x = y - D^T u``.

    Parameters
    ----------
    lam : float
        Regularisation parameter; bound of the box applied to the dual variable.
    y : ndarray
        Noisy image with ``N ** 2`` entries (it is reshaped to a vector).
    tol : float, optional
        Iterations stop once the l1 norm of the change in ``x`` between two
        consecutive iterations is below ``tol``. Defaults to ``1e-8``.
    max_iter : int or None, optional
        Upper bound on the number of iterations. ``None`` (the default)
        means no bound: the loop runs until the tolerance is met.

    Returns
    -------
    ndarray
        The last image iterate, reshaped to ``(N, N)``.
    """
    y = y.reshape(N ** 2)
    # Step size of the gradient step on the dual variable.
    # NOTE(review): 1/8 is presumably 1/||D||^2 for 2-D forward differences - confirm.
    gamma = 1 / 8
    u = np.zeros((2 * N * (N - 1)))
    x = np.zeros((N ** 2))
    tk = 1
    z = u.copy()
    nor1 = np.inf
    # NOTE(review): the counter starts at 0, so the first value returned by
    # beta() is (iner - 1) / iner < 1 and the momentum used in the second
    # iteration is slightly negative - confirm this is intended.
    t = 0
    while nor1 >= tol and (max_iter is None or t < max_iter):
        beta_n, tk = beta(tk, t)
        # `u` and `x` are rebound to freshly computed arrays below and never
        # modified in place, so a plain reference to the previous iterate suffices.
        u_prev = u
        gz = D_Op.matvec(y - D_Op.rmatvec(z))
        u = prox_g_d(z + gamma * gz, lam)
        z = u + beta_n * (u - u_prev)
        x_old = x
        x = y - D_Op.rmatvec(u)
        nor1 = np.linalg.norm(x - x_old, ord=1)
        t += 1
    return x.reshape((N, N))


In [None]:
# lambda_Lambda functions

def train_TV(lamb):
    """Mean of ``breg_dist_tv`` between clean and denoised images over the training set.

    Every noisy training image ``b_train[i]`` is denoised with
    ``dual_fista(lamb, ...)`` and compared with ``a_train[i]``; the
    ``N_train_set`` per-image values are computed through a thread pool and
    then averaged.
    """
    def _sample_error(idx):
        denoised = dual_fista(lamb, b_train[idx, :, :]).reshape(N ** 2)
        clean = a_train[idx, :, :].reshape(N ** 2)
        return breg_dist_tv(clean, denoised)

    with concurrent.futures.ThreadPoolExecutor() as pool:
        errors = list(pool.map(_sample_error, range(N_train_set)))

    return np.mean(np.array(errors))


def get_lambda_star(lamb):  # here lambda is a vector
    """Evaluate ``train_TV`` on every candidate and pick the best one.

    Returns the pair ``(best candidate, smallest mean training error)``.
    """
    errors = np.array([train_TV(candidate) for candidate in lamb])
    best_lambda = lamb[np.argmin(errors)]
    return best_lambda, errors.min()

In [None]:
# Candidate regularisation parameters: 50 values, log-spaced between 1e-4 and 1e-1
lamb = np.logspace(-4, -1, num=50)

# Candidate with the smallest mean training error, together with that error
lambda_star, L_lambda_star = get_lambda_star(lamb)

print(r'$\lambda^*$: ', lambda_star)

In [None]:
# ER functions

def val_TVn(n, lamb):
    """Mean of ``breg_dist_tv`` over a random subsample of ``n`` validation images.

    A fresh random permutation of the validation indices is drawn on every
    call and its first ``n`` entries select the images. Each selected noisy
    image is denoised with ``dual_fista(lamb, ...)`` and compared with its
    clean counterpart; the per-image values are computed through a thread
    pool and then averaged.
    """
    picked = np.random.permutation(N_val)[:n]
    clean_sub = a_val[picked, :, :]
    noisy_sub = b_val[picked, :, :]

    def _sample_error(idx):
        denoised = dual_fista(lamb, noisy_sub[idx, :, :]).reshape(N ** 2)
        clean = clean_sub[idx, :, :].reshape(N ** 2)
        return breg_dist_tv(clean, denoised)

    with concurrent.futures.ThreadPoolExecutor() as pool:
        errors = list(pool.map(_sample_error, range(n)))

    return np.mean(np.array(errors))


def cvlambda(n, lamb):  # here lambda is a vector
    """Evaluate ``val_TVn(n, .)`` on every candidate and pick the best one.

    Each candidate is scored by its own call to ``val_TVn``, which draws a
    new random subsample every time.

    Returns the pair ``(smallest mean validation error, best candidate)``.
    """
    errors = np.array([val_TVn(n, candidate) for candidate in lamb])
    smallest = errors.min()
    return smallest, lamb[np.argmin(errors)]

def get_L_hat(L_lambda_star, lamb, N_vec, n_it):
    """Absolute gap between ``L_lambda_star`` and the cross-validated minimum error.

    For every sample size in ``N_vec`` the cross-validation ``cvlambda`` is
    repeated ``n_it`` times. Entry ``[i, j]`` of the result is
    ``|L_lambda_star - (minimum error of repetition j at size N_vec[i])|``.

    Returns an array of shape ``(len(N_vec), n_it)``.
    """
    gaps = np.zeros((len(N_vec), n_it))
    for row, n_samples in enumerate(N_vec):
        for col in range(n_it):
            cv_min_error = cvlambda(n_samples, lamb)[0]
            gaps[row, col] = np.abs(L_lambda_star - cv_min_error)
    return gaps
                             

In [None]:
# ER plot: inputs for the gap between L_lambda_star and the cross-validated error
N_vec = np.arange(10, 160, 10)  # validation sample sizes 10, 20, ..., 150
n_it = 30  # repetitions per sample size (30 is the value the author settled on)

# One row per sample size, one column per repetition
deltaTV = get_L_hat(L_lambda_star, lamb, N_vec, n_it)


In [None]:
# Scale row i of deltaTV by sqrt(N_vec[i]); the column vector of square roots
# broadcasts across the repetitions stored in each row.
DeltaTV = deltaTV * np.sqrt(N_vec)[:, np.newaxis]

In [None]:
# Figure: row-wise mean of DeltaTV against n, with a 5%-95% quantile band
dfTV = pd.DataFrame(DeltaTV)
meanDelTV = dfTV.mean(axis='columns')  # one mean per row, i.e. per entry of N_vec
lowerTV = np.quantile(DeltaTV, 0.05, axis=1)
upperTV = np.quantile(DeltaTV, 0.95, axis=1)

plt.close('all')
fig, ax1 = plt.subplots(figsize=(20, 5), dpi=100)
# fig.suptitle("Excess risk behaviour", fontsize=20)
ax1.plot(N_vec, meanDelTV, '-')
ax1.scatter(N_vec, meanDelTV, color='red', s=50)
ax1.fill_between(N_vec, lowerTV, upperTV, alpha=0.2)
ax1.set_ylabel(r'$\Delta(n)\sqrt{n}$', fontsize=25)
ax1.set_xlabel(r'$n$', fontsize=25)
# ax1.legend(fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
# Save the figure before calling show
plt.savefig("./ER_TVdenoising.pdf", bbox_inches='tight')
plt.show(block=False)

In [None]:
# Compute four regularisation parameters; each call to cvlambda draws fresh random validation subsamples, so the estimates can differ.

In [None]:
# First estimate: candidate with the smallest mean validation error, using random subsamples of size 100
L_lambda_hat1, lambda_hat1 = cvlambda(100, lamb)

print(r'$\widehat{\lambda}$:', lambda_hat1)

In [None]:
# Second estimate: same procedure, new random subsamples of size 100
L_lambda_hat2, lambda_hat2 = cvlambda(100, lamb)

print(r'$\widehat{\lambda}$:', lambda_hat2)

In [None]:
# Third estimate: same procedure, new random subsamples of size 100
L_lambda_hat3, lambda_hat3 = cvlambda(100, lamb)

# The label is printed verbatim (no str.format call), so braces are not doubled.
print(r'$\widehat{\lambda}$:', lambda_hat3)

In [None]:
# Fourth estimate: same procedure, new random subsamples of size 100
L_lambda_hat4, lambda_hat4 = cvlambda(100, lamb)

# The label is printed verbatim (no str.format call), so braces are not doubled.
print(r'$\widehat{\lambda}$:', lambda_hat4)

In [None]:
# Plot two test examples, each denoised with two of the cross-validated parameters.
# Row 0 shows test image 42 (lambda_hat1, lambda_hat2), row 1 shows test image 55
# (lambda_hat3, lambda_hat4). Columns: clean image, noisy image, two reconstructions.

a1 = a_test[42, :, :]
b1 = b_test[42, :, :]
a1_tv = dual_fista(lambda_hat1, b1)
a2_tv = dual_fista(lambda_hat2, b1)

a2 = a_test[55, :, :]
b2 = b_test[55, :, :]
a3_tv = dual_fista(lambda_hat3, b2)
a4_tv = dual_fista(lambda_hat4, b2)


fig, axs = plt.subplots(2, 4, figsize=(20, 8))

# The parameter shown in each title is formatted from the value actually used
# for the reconstruction, so the titles cannot drift from the computed
# lambda_hat values when the notebook is re-run.
axs[0, 0].imshow(a1, cmap='gray', vmin=0, vmax=1)
axs[0, 0].set_title('Original', fontsize=20)
axs[0, 1].imshow(b1, cmap='gray', vmin=0, vmax=1)
axs[0, 1].set_title(r'Noisy, $D_{{\mathrm{{TV}}}}$: {:.4f}'
                    .format(breg_dist_tv(a1.reshape(N**2), b1.reshape(N**2))), fontsize=20)
axs[0, 2].imshow(a1_tv, cmap='gray', vmin=0, vmax=1)
axs[0, 2].set_title(r'$\widehat{{\lambda}}_1={:.4f}$, $D_{{\mathrm{{TV}}}}$: {:.4f}'
                    .format(lambda_hat1,
                            breg_dist_tv(a1.reshape(N**2), a1_tv.reshape(N**2))), fontsize=20)
axs[0, 3].imshow(a2_tv, cmap='gray', vmin=0, vmax=1)
axs[0, 3].set_title(r'$\widehat{{\lambda}}_2={:.4f}$, $D_{{\mathrm{{TV}}}}$: {:.4f}'
                    .format(lambda_hat2,
                            breg_dist_tv(a1.reshape(N**2), a2_tv.reshape(N**2))), fontsize=20)

axs[1, 0].imshow(a2, cmap='gray', vmin=0, vmax=1)
axs[1, 0].set_title('Original', fontsize=20)
axs[1, 1].imshow(b2, cmap='gray', vmin=0, vmax=1)
axs[1, 1].set_title(r'Noisy, $D_{{\mathrm{{TV}}}}$: {:.4f}'
                    .format(breg_dist_tv(a2.reshape(N**2), b2.reshape(N**2))), fontsize=20)
axs[1, 2].imshow(a3_tv, cmap='gray', vmin=0, vmax=1)
axs[1, 2].set_title(r'$\widehat{{\lambda}}_3={:.4f}$, $D_{{\mathrm{{TV}}}}$: {:.4f}'
                    .format(lambda_hat3,
                            breg_dist_tv(a2.reshape(N**2), a3_tv.reshape(N**2))), fontsize=20)
axs[1, 3].imshow(a4_tv, cmap='gray', vmin=0, vmax=1)
axs[1, 3].set_title(r'$\widehat{{\lambda}}_4={:.4f}$, $D_{{\mathrm{{TV}}}}$: {:.4f}'
                    .format(lambda_hat4,
                            breg_dist_tv(a2.reshape(N**2), a4_tv.reshape(N**2))), fontsize=20)

plt.tight_layout()
plt.savefig("./tvrecov_test_combined.pdf", dpi=300, bbox_inches='tight')
plt.show()