In [1]:
%matplotlib qt
import numpy as np
from skimage import io as skio
import scipy.sparse
import matplotlib.pyplot as plt
import time

In [2]:
def make_nabla(M,N):
    """Build the sparse forward-difference gradient operator of an M x N image.

    Pixel (i, j) is addressed by the row-major index i*N + j, which matches
    ``ndarray.flatten()``.  For every pixel the "next" neighbour index is
    clamped to the last column (x) / last row (y), so the difference at the
    right and bottom border is zero (Neumann boundary).

    Parameters
    ----------
    M, N : int
        Number of image rows and columns.

    Returns
    -------
    scipy.sparse matrix, shape (2*M*N, M*N)
        The x-difference operator stacked on top of the y-difference
        operator, so ``nabla @ u.flatten()`` yields ``[u_x; u_y]``.
    """
    size = M*N
    shape = (size, size)
    rows = np.arange(size)
    ones = np.ones(size)
    idx = np.arange(size).reshape(M, N)

    # index of the right neighbour; the last column points to itself
    right = np.concatenate([idx[:, 1:], idx[:, -1:]], axis=1)
    # index of the lower neighbour; the last row points to itself
    below = np.concatenate([idx[1:, :], idx[-1:, :]], axis=0)

    def forward_diff(neighbour):
        # +1 at the neighbour column, -1 on the diagonal, one entry each per row
        plus = scipy.sparse.coo_matrix((ones, (rows, neighbour.ravel())), shape=shape)
        minus = scipy.sparse.coo_matrix((ones, (rows, idx.ravel())), shape=shape)
        return plus - minus

    return scipy.sparse.vstack([forward_diff(right), forward_diff(below)])

In [3]:
# Load the test image used for the denoising experiments.
# Dividing by 255 assumes 8-bit pixel values and maps them to [0, 1];
# the shape unpacking below requires a single-channel (2-D) image.
# Alternative test image: 'einstein_gray.png'
g = skio.imread('disks.png') / 255

print(g.shape)
M, N = g.shape

(425, 445)


In [4]:
# --- Experiment configuration ---------------------------------------------
# Noise model: 1 ... Gaussian, 2 ... salt & pepper, 3 ... multiplicative Gamma
noise = 1

# Data term D(u): 1 ... quadratic (lamb/2)*||u-f||^2,
#                 2 ... l1 lamb*||u-f||_1,
#                 3 ... entropy lamb*sum(u - f*log(u))
dataterm = 2

lamb = 0.1      # weight of data term
eps = 0.0       # Huber epsilon (0 gives plain TV)
maxiter = 2000  # number of iterations
seed = 0        # RNG seed so the noisy image is reproducible

# Fail loudly on an unknown option: otherwise `tau` would be undefined, or
# silently reused from a previous run of this cell with no data-term prox.
if noise not in (1, 2, 3):
    raise ValueError(f"unknown noise model {noise!r}; expected 1, 2 or 3")
if dataterm not in (1, 2, 3):
    raise ValueError(f"unknown data term {dataterm!r}; expected 1, 2 or 3")

rng = np.random.default_rng(seed)

# --- Generate the noisy observation f --------------------------------------
f = g.flatten()  # flatten() returns a copy, so g is never modified below

if noise == 1:
    # Additive Gaussian noise.
    # NOTE: a standard deviation of 0.0 leaves f identical to g.
    noise_std = 0.0
    f = f + rng.standard_normal(M*N)*noise_std
elif noise == 2:
    # Salt and pepper noise: about a fraction sp_ratio of the pixels is set
    # to 0 (half of them) or 1 (the other half)
    n = rng.random(M*N)
    sp_ratio = 0.2
    f[n <= sp_ratio/2] = 0.0
    f[n >= (1 - sp_ratio/2)] = 1.0
else:
    # Multiplicative Gamma noise with mean 1 and variance 1/k
    k = 10
    f = f*rng.gamma(k, 1/k, M*N)

# --- Primal-dual solver for min_u TV(u) + D(u) -----------------------------
# primal variable (restored image), initialised with the observation
u = f.copy()
# dual variable, stored as the stacked pair [p_x; p_y]
p = np.zeros(M*N*2)

# forward-difference gradient and its transpose (hoisted out of the loop)
nabla = make_nabla(M,N)
nabla_t = nabla.T

# upper bound on the operator norm of the forward-difference gradient
L = np.sqrt(8)

# primal step size per data term; the dual step gives tau*sigma*L^2 = 1
tau = {1: 0.01, 2: 1/L, 3: 0.1/L}[dataterm]
sigma = 1/tau/L**2

# --- Live display ----------------------------------------------------------
plt.close("all")
fig, ax = plt.subplots(num=1, figsize=(15,15))
im = ax.imshow(f.reshape(M,N), cmap="gray", vmin=0, vmax=1)


def refresh(image):
    """Push the flattened M*N `image` into the live figure and repaint it."""
    im.set_data(image.reshape(M,N))
    # draw first, then let the GUI process the paint event, so the new frame
    # is shown now rather than on the next refresh
    fig.canvas.draw()
    fig.canvas.flush_events()


for it in range(maxiter):
    # primal step; u is only ever rebound (never modified in place), so the
    # previous iterate can be kept without a copy
    u_prev = u
    v = u - tau*(nabla_t @ p)

    # proximal map of the data term
    if dataterm == 1:
        # prox of (lamb/2)*||u-f||^2
        u = (v + tau*lamb*f)/(1 + tau*lamb)
    elif dataterm == 2:
        # prox of lamb*||u-f||_1: soft-thresholding around f
        u = f + np.maximum(0.0, np.abs(v - f) - tau*lamb)*np.sign(v - f)
    else:
        # prox of lamb*sum(u - f*log(u)): positive root of
        # u^2 + (tau*lamb - v)*u - tau*lamb*f = 0
        b = tau*lamb - v
        u = (np.sqrt(b**2 + 4*tau*lamb*f) - b)/2

    # over-relaxation
    u_ = 2*u - u_prev

    # dual step with Huber smoothing (factor 1/(1+sigma*eps)) ...
    p = (p + sigma*(nabla @ u_))/(1 + sigma*eps)

    # ... followed by projecting each (p_x, p_y) pair onto the unit disk
    p = p.reshape(2, M*N)
    p = p/np.maximum(1, np.hypot(p[0], p[1]))
    p = p.ravel()

    if it % 10 == 0:
        print("iter = ", it)
        refresh(u)

# the loop only refreshes every 10 iterations; show the final iterate too
refresh(u)
        

iter =  0
iter =  10
iter =  20
iter =  30
iter =  40
iter =  50
iter =  60
iter =  70
iter =  80
iter =  90
iter =  100
iter =  110
iter =  120
iter =  130
iter =  140
iter =  150
iter =  160
iter =  170
iter =  180
iter =  190
iter =  200
iter =  210
iter =  220
iter =  230
iter =  240
iter =  250
iter =  260
iter =  270
iter =  280
iter =  290
iter =  300
iter =  310
iter =  320
iter =  330
iter =  340
iter =  350
iter =  360
iter =  370
iter =  380
iter =  390
iter =  400
iter =  410
iter =  420
iter =  430
iter =  440
iter =  450
iter =  460
iter =  470
iter =  480
iter =  490
iter =  500
iter =  510
iter =  520
iter =  530
iter =  540
iter =  550
iter =  560
iter =  570
iter =  580
iter =  590
iter =  600
iter =  610
iter =  620
iter =  630
iter =  640
iter =  650
iter =  660
iter =  670
iter =  680
iter =  690
iter =  700
iter =  710
iter =  720
iter =  730
iter =  740
iter =  750
iter =  760
iter =  770
iter =  780
iter =  790
iter =  800
iter =  810
iter =  820
iter =  830
ite