In [1]:
%matplotlib qt
import numpy as np
from skimage import io as skio
import scipy.sparse
import matplotlib.pyplot as plt
import time

In [2]:
def make_nabla(M,N):
    """Sparse forward-difference gradient for an M x N image flattened row-major.

    Returns a (2*M*N, M*N) scipy.sparse matrix. The first M*N rows compute the
    horizontal differences u[i, j+1] - u[i, j]; the last M*N rows compute the
    vertical differences u[i+1, j] - u[i, j]. In the last column / last row the
    neighbour index is clamped to the pixel itself, so those differences are 0.
    """
    n_px = M*N
    pixel = np.arange(n_px).reshape(M, N)

    # flat index of the right / lower neighbour, clamped at the image border
    right = np.concatenate([pixel[:, 1:], pixel[:, -1:]], axis=1).ravel()
    below = np.concatenate([pixel[1:, :], pixel[-1:, :]], axis=0).ravel()

    rows = np.arange(n_px)
    ones = np.ones(n_px)
    identity = scipy.sparse.coo_matrix((ones, (rows, rows)), shape=(n_px, n_px))

    # each difference operator is "select neighbour" minus identity
    diffs = [
        scipy.sparse.coo_matrix((ones, (rows, nbr)), shape=(n_px, n_px)) - identity
        for nbr in (right, below)
    ]
    return scipy.sparse.vstack(diffs)

In [3]:
# Read the test image and normalize its intensities to [0, 1].
# Other test images used with this notebook: 'einstein_gray.png', 'disks.png'.
image_path = 'affine.png'

img = skio.imread(image_path)

# Scale integer images by their dtype's maximum (255 for uint8, 65535 for uint16);
# a hard-coded /255 mis-scales anything that is not 8-bit.
if np.issubdtype(img.dtype, np.integer):
    g = img / np.iinfo(img.dtype).max
else:
    g = img.astype(float)

# Colour images: drop an alpha channel (LA / RGBA) and average the remaining
# channels to a single gray channel, so the 2-D unpack below does not fail.
if g.ndim == 3:
    if g.shape[-1] in (2, 4):
        g = g[..., :-1]
    g = g.mean(axis=-1)
assert g.ndim == 2, f"expected a 2-D grayscale image, got shape {g.shape}"

print(g.shape)
M,N=g.shape

(335, 335)


In [4]:
# --- Configuration -----------------------------------------------------------
# Noise model: 1 ... Gaussian, 2 ... S&P, 3 ... Gamma
noise = 1

# Data term: 1 ... quad, 2 ... l1, 3 ... entropy
dataterm = 1

lamb = 5 # weight of data term
alpha1 = 1 # weight of 1st order term
alpha2 = 2 # weight of 2nd order term

maxiter = 500 # number of iterations
plot_every = 10 # refresh the figure (and print progress) every `plot_every` iterations
seed = 0 # seed for the synthetic noise; set to None for a new realization each run

# Fail fast instead of running with no noise / an undefined step size.
if noise not in (1, 2, 3):
    raise ValueError(f"unknown noise model {noise!r}; expected 1, 2 or 3")
if dataterm not in (1, 2, 3):
    raise ValueError(f"unknown dataterm {dataterm!r}; expected 1, 2 or 3")

n_px = M*N # number of pixels


def add_noise(clean, model, rng, std=0.1, sp_fraction=0.2, gamma_shape=10):
    """Return a noisy copy of the flat image `clean`.

    model 1: additive Gaussian noise with standard deviation `std`.
    model 2: salt & pepper -- pixels with uniform draw <= sp_fraction/2 are set
             to 0, those with draw >= 1 - sp_fraction/2 are set to 1.
    model 3: multiplicative gamma noise with shape `gamma_shape` and scale
             1/gamma_shape (mean 1).
    """
    noisy = np.array(clean, dtype=float)  # always a fresh copy
    if model == 1:
        noisy += std*rng.standard_normal(noisy.size)
    elif model == 2:
        r = rng.random(noisy.size)
        noisy[r <= sp_fraction/2] = 0.0
        noisy[r >= 1 - sp_fraction/2] = 1.0
    elif model == 3:
        noisy *= rng.gamma(gamma_shape, 1/gamma_shape, noisy.size)
    return noisy


def prox_data(v, f, tl, dataterm):
    """Pixel-wise proximal map of tl * D(., f), evaluated at v.

    dataterm 1: D(u) = 1/2 * ||u - f||^2
    dataterm 2: D(u) = ||u - f||_1   (soft-thresholding around f)
    dataterm 3: D(u) = sum(u - f*log(u)), result clipped at 0;
                needs f >= 0, otherwise the square root can produce NaN.
    """
    if dataterm == 1:
        return (v + tl*f)/(1 + tl)
    if dataterm == 2:
        d = v - f
        return f + np.maximum(0.0, np.abs(d) - tl)*np.sign(d)
    if dataterm == 3:
        t = tl - v
        return np.maximum(0.0, (np.sqrt(t**2 + 4*tl*f) - t)/2)
    raise ValueError(f"unknown dataterm {dataterm!r}; expected 1, 2 or 3")


def project_dual(q, n_components, radius):
    """Project each pixel's n_components-vector onto the l2 ball of `radius`.

    `q` is flat and component-major: q.reshape(n_components, -1)[c] is
    component c for all pixels. Returns a new flat array of the same length.
    """
    q = q.reshape(n_components, -1)
    scale = np.maximum(1.0, np.linalg.norm(q, axis=0)/radius)
    return (q/scale[np.newaxis, :]).ravel()


# generate noisy image (flat, row-major); seeded so runs are reproducible
rng = np.random.default_rng(seed)
f = add_noise(g.flatten(), noise, rng)

# primal variable x = (u, v1, v2): u is the clean image, v the auxiliary field
u = np.zeros(3*n_px)
u[:n_px] = f

# dual variable
p = np.zeros(6*n_px)

# make nabla operator
nabla = make_nabla(M,N)

Z = scipy.sparse.coo_matrix((2*n_px, n_px))
I = scipy.sparse.eye(2*n_px)

# K (u, v1, v2) = (nabla u - v, nabla v1, nabla v2)
K1 = scipy.sparse.hstack([nabla, -I])
K2 = scipy.sparse.hstack([Z, nabla, Z])
K3 = scipy.sparse.hstack([Z, Z, nabla])

# CSR copies of K and K^T are built once, outside the loop, for fast mat-vecs
K = scipy.sparse.vstack([K1,K2,K3]).tocsr()
KT = K.T.tocsr()

# assumed upper bound on the operator norm ||K||; PDHG needs tau*sigma*||K||^2 <= 1
L = np.sqrt(12)

# primal step size per data term; dual step size chosen so tau*sigma*L^2 = 1
tau = {1: 0.01, 2: 1/L, 3: 0.1/L}[dataterm]
sigma = 1/tau/L**2

# plot the image
plt.close("all")
fig = plt.figure(1, figsize=(15,15))
im = plt.imshow(f.reshape(M,N), cmap="gray", vmin=0, vmax=1)

for it in range(maxiter):

    # primal step; the data-term prox acts on the image part u only
    u_old = u
    u = u - tau*(KT @ p)
    u[:n_px] = prox_data(u[:n_px], f, tau*lamb, dataterm)

    # overrelaxation
    u_ = 2*u - u_old

    # dual step, then pixel-wise projection onto the alpha1 / alpha2 balls
    p = p + sigma*(K @ u_)
    p[:2*n_px] = project_dual(p[:2*n_px], 2, alpha1)
    p[2*n_px:] = project_dual(p[2*n_px:], 4, alpha2)

    # refresh periodically and always after the final iteration
    if it % plot_every == 0 or it == maxiter - 1:
        print("TGV2: iter = ", it)
        im.set_data(u[:n_px].reshape(M,N))
        fig.canvas.draw_idle()
        fig.canvas.flush_events()

TGV2: iter =  0
TGV2: iter =  10
TGV2: iter =  20
TGV2: iter =  30
TGV2: iter =  40
TGV2: iter =  50
TGV2: iter =  60
TGV2: iter =  70
TGV2: iter =  80
TGV2: iter =  90
TGV2: iter =  100
TGV2: iter =  110
TGV2: iter =  120
TGV2: iter =  130
TGV2: iter =  140
TGV2: iter =  150
TGV2: iter =  160
TGV2: iter =  170
TGV2: iter =  180
TGV2: iter =  190


KeyboardInterrupt: 