In [None]:
# Widen the notebook cells to the full browser width.
# NOTE(review): the `.container` selector presumably only affects the classic
# Notebook UI; confirm it is still needed in the frontend in use.
# `display` is imported from `IPython.display` (importing it from
# `IPython.core.display` is deprecated).
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import skimage as ski
import scipy.sparse
from scipy import signal
import time

In [None]:
# Sanity check: with norm="ortho" the 2D FFT is unitary, so it preserves the
# l2 norm (Parseval). The solver below applies the data-term prox frequency by
# frequency on orthonormal FFT coefficients, which is only valid because of this.
u = np.ones((20,20))
print(np.linalg.norm(u.ravel()))
uf = np.fft.fft2(u, norm="ortho")
print(np.linalg.norm(uf.ravel()))
np.testing.assert_allclose(np.linalg.norm(uf.ravel()), np.linalg.norm(u.ravel()))

In [None]:
def forward_differences(M,N):
    """Sparse forward-difference (discrete gradient) operator for an M x N image.

    The image is flattened in row-major (C) order, so pixel (i, j) has index
    i*N + j. The returned operator has shape (2*M*N, M*N):

    * rows 0 .. M*N-1 give u[i, j+1] - u[i, j] (difference along axis 1),
    * rows M*N .. 2*M*N-1 give u[i+1, j] - u[i, j] (difference along axis 0).

    In the last column / last row the neighbour is clamped to the pixel itself,
    so those differences are zero (Neumann boundary conditions).
    """
    n_pix = M * N
    pixel = np.arange(n_pix).reshape(M, N)

    # index of the right / lower neighbour; the last column / row maps onto itself
    right = np.concatenate((pixel[:, 1:], pixel[:, -1:]), axis=1)
    below = np.concatenate((pixel[1:, :], pixel[-1:, :]), axis=0)

    rows = np.arange(n_pix)
    ones = np.ones(n_pix)

    def selector(cols):
        # n_pix x n_pix matrix with exactly one 1 per row, located at column cols[row]
        return scipy.sparse.coo_matrix((ones, (rows, cols.flatten())), shape=(n_pix, n_pix))

    identity = selector(pixel)
    grad_x = selector(right) - identity
    grad_y = selector(below) - identity

    return scipy.sparse.vstack([grad_x, grad_y])

In [None]:
def prox_conv_fft(uf, ff, kernelf, tau):
    """Proximal map of tau * 0.5 * ||k (*) u - f||^2, evaluated in the Fourier domain.

    For every frequency independently it solves

        argmin_x  0.5 * |x - uf|^2 + tau * 0.5 * |kernelf * x - ff|^2,

    whose closed form is x = (uf + tau*conj(kernelf)*ff) / (1 + tau*|kernelf|^2).

    Parameters
    ----------
    uf : ndarray (complex)
        Fourier coefficients of the point at which the prox is evaluated.
    ff : ndarray (complex)
        Fourier coefficients of the data, same shape as ``uf``.
    kernelf : ndarray (complex)
        Transfer function of the convolution kernel, broadcastable to ``uf``.
    tau : float
        Step size.

    Returns
    -------
    ndarray
        Fourier coefficients of the prox result.
    """
    weighted_conj = tau * np.conj(kernelf)
    numerator = uf + weighted_conj * ff
    denominator = 1 + weighted_conj * kernelf
    return numerator / denominator
        
def proj_inf_l2(p, tau, axis=0):
    """Project p onto the mixed (l2, inf) ball of radius tau.

    Every slice of ``p`` along ``axis`` is treated as one vector; with the
    default ``axis=0`` and ``p`` of shape (K, N), each of the N columns is a
    K-vector. Vectors with Euclidean norm larger than ``tau`` are rescaled to
    norm ``tau``; vectors already inside the ball are left unchanged.

    Parameters
    ----------
    p : array_like
        Real array to project.
    tau : float
        Radius of the ball; must be positive (tau == 0 divides by zero).
    axis : int, optional
        Axis along which the l2 norm is taken (default 0).

    Returns
    -------
    ndarray
        The projected array (floating point). The input array is NOT modified.
    """
    p = np.asarray(p)
    norm_p = np.sqrt(np.sum(p**2, axis=axis, keepdims=True))
    # out-of-place division: does not clobber the caller's array and also
    # works for integer input (in-place `/=` would fail on int dtypes)
    return p / np.maximum(1, norm_p/tau)

def proj_inf(p, tau):
    """Project p onto the l-infinity ball of radius tau (entry-wise clipping to [-tau, tau])."""
    lower, upper = -tau, tau
    return np.clip(p, lower, upper)

In [None]:
# primal-dual algorithm for TV-regularized deconvolution
def tv_deconv_pd(d, kernel, lamb=1.0, maxit=1000, check=100, verbose=0):
    """TV deconvolution with the (non-accelerated) primal-dual algorithm.

    Approximately solves

        min_u  lamb * TV(u) + 0.5 * || kernel (*) u - d ||_2^2

    where (*) is circular convolution (computed via FFT) and TV is the
    isotropic total variation built from ``forward_differences``.

    Parameters
    ----------
    d : (M, N) array_like
        Observed (blurred, noisy) image; converted to float64.
    kernel : (M, N) array_like
        Blur kernel, zero-padded to the shape of ``d``. Its origin is taken to
        be index (0, 0) because its FFT is used directly as transfer function.
    lamb : float
        Weight of the TV term; also the radius of the dual ball. Must be > 0.
    maxit : int
        Number of iterations.
    check : int
        Progress is printed every ``check`` iterations when ``verbose > 0``;
        must then be positive.
    verbose : int
        If > 0, print a progress line (overwritten in place with "\\r").

    Returns
    -------
    u : (M, N) ndarray
        Final primal iterate (restored image).
    p : (2, M, N) ndarray
        Final dual variable; p[0] / p[1] pair with the differences along axis 1 / axis 0.
    E : (maxit,) ndarray
        Primal energy after every iteration.

    Raises
    ------
    ValueError
        If ``kernel`` does not have the shape of ``d``, or if ``verbose > 0``
        and ``check <= 0``.
    """
    # float copy-source: integer images would otherwise break the in-place `u -= ...`
    d = np.asarray(d, dtype=float)
    M, N = d.shape
    kernel = np.asarray(kernel)
    if kernel.shape != d.shape:
        # a smaller kernel FFT could silently broadcast against (M, N) and give wrong results
        raise ValueError("kernel must have the same shape as d (zero-pad it first)")
    if verbose > 0 and check <= 0:
        raise ValueError("check must be a positive integer when verbose > 0")

    # Orthonormal FFT for images (unitary, so the data prox acts per frequency);
    # unnormalized FFT for the kernel, so that for circular convolution
    # fft2(kernel (*) u, ortho) == fft2(kernel) * fft2(u, ortho).
    df = np.fft.fft2(d, norm="ortho")
    kernelf = np.fft.fft2(kernel)

    # primal variable, flattened in row-major order (copy: d is not modified)
    u = d.reshape(M*N).copy()

    # dual variable: first M*N entries pair with D's axis-1 differences, last M*N with axis-0
    p = np.zeros(M*N*2)

    # discrete gradient, sparse (2*M*N) x (M*N)
    D = forward_differences(M, N)

    # primal and dual step sizes with tau * sigma * L^2 = 1, where
    # L = sqrt(8) is the standard bound on ||D|| for 2D forward differences
    L = np.sqrt(8)
    tau = 1/L
    sigma = 1/tau/L**2
    theta = 1.0  # over-relaxation parameter (fixed: no acceleration)

    t0 = time.time()
    E = []
    for it in range(maxit):

        # remember the previous primal iterate for the over-relaxation step
        u_old = u.copy()

        # primal step: descent on <D u, p> ...
        u -= tau*(D.T@p)

        # ... followed by the prox of the data term, computed in the Fourier domain
        uf = np.fft.fft2(u.reshape(M, N), norm="ortho")
        uf = prox_conv_fft(uf, df, kernelf, tau)
        u = np.real(np.fft.ifft2(uf, norm="ortho")).reshape(M*N)

        # over-relaxation
        u_bar = u + theta*(u - u_old)

        # dual step: ascent, then projection onto {|p_i|_2 <= lamb} (l2 over the 2 components)
        p += sigma*(D@u_bar)
        p = proj_inf_l2(p.reshape(2, M*N), lamb).reshape(2*M*N)

        # primal energy of the current iterate: lamb*TV(u) + 0.5*||kernel (*) u - d||^2
        tv = lamb*np.sum(np.sqrt(np.sum(((D@u).reshape(2, M*N))**2, axis=0)))
        data = 0.5*np.sum((np.real(np.fft.ifft2(uf*kernelf, norm="ortho")) - d)**2)
        energy = tv + data
        E.append(energy)

        if verbose > 0 and it % check == check - 1:
            print("iter = ", it,
                  ", tau = ", "{:.3f}".format(tau),
                  ", sigma = ", "{:.3f}".format(sigma),
                  ", time = ", "{:.3f}".format(time.time()-t0),
                  ", E = ", "{:.6f}".format(energy),
                  end="\r")

    if verbose > 0:
        # terminate the "\r"-overwritten progress line so later output starts on a new line
        print()

    return u.reshape(M, N), p.reshape(2, M, N), np.array(E)

In [None]:
def make_motion_blur(N_k, sigma_1=1.0, sigma_2=10.0, direction=(1.0, 1.0)):
    """Build an N_k x N_k elongated (anisotropic) Gaussian kernel used as a motion-blur model.

    The kernel is exp(-0.5 * z^T P z), normalized to sum 1, where z = (row, column)
    offset from the kernel origin and

        P = v1 v1^T / sigma_1^2 + v2 v2^T / sigma_2^2,

    with v1 = direction / |direction| and v2 perpendicular to v1. So the kernel
    has standard deviation sigma_1 along ``direction`` and sigma_2 across it.
    The defaults reproduce the original kernel (narrow along (1, 1), wide
    along (1, -1) in (row, column) coordinates).

    The origin (offset 0) sits at index (N_k - 1) // 2 along both axes.

    Parameters
    ----------
    N_k : int
        Kernel size (>= 1).
    sigma_1, sigma_2 : float
        Positive standard deviations along / perpendicular to ``direction``.
    direction : sequence of 2 floats
        Non-zero (row, column) direction vector; it is normalized internally.

    Returns
    -------
    (N_k, N_k) ndarray
        Kernel with entries summing to 1.

    Raises
    ------
    ValueError
        For N_k < 1, non-positive sigmas, or a direction that is not a non-zero 2-vector.
    """
    if N_k < 1:
        raise ValueError("N_k must be a positive integer")
    if sigma_1 <= 0 or sigma_2 <= 0:
        raise ValueError("sigma_1 and sigma_2 must be positive")

    v1 = np.asarray(direction, dtype=float)
    if v1.shape != (2,):
        raise ValueError("direction must be a 2-vector")
    norm_v1 = np.linalg.norm(v1)
    if norm_v1 == 0:
        raise ValueError("direction must be non-zero")
    v1 = v1 / norm_v1
    # unit vector perpendicular to v1
    v2 = np.array([v1[1], -v1[0]])

    # integer offsets; with meshgrid's default 'xy' indexing, x varies along
    # axis 0 (row offset) and y along axis 1 (column offset)
    coords = np.arange(-N_k//2+1, N_k//2+1, 1)
    y, x = np.meshgrid(coords, coords)

    # precision matrix of the Gaussian
    Sigma = np.outer(v1, v1)/sigma_1**2 + np.outer(v2, v2)/sigma_2**2

    # q = 0.5 * [x y] Sigma [x y]^T (Sigma is symmetric, so the cross term appears once)
    q = x**2*Sigma[0,0]/2 + x*y*Sigma[0,1] + y**2*Sigma[1,1]/2
    kernel = np.exp(-q)
    kernel /= kernel.sum()

    return kernel

In [None]:
# Load the test image as a 2D float image in [0, 1]:
# as_gray collapses RGB(A) files to one channel, and img_as_float rescales
# integer images by their dtype range (uint8 -> /255, uint16 -> /65535).
g = ski.img_as_float(ski.io.imread("watercastle.png", as_gray=True))
M,N = g.shape

# construct blur kernel (N_k x N_k, origin at index (N_k - 1) // 2)
N_k = 15
kernel = make_motion_blur(N_k)

# Zero-pad the kernel to the image size and circularly shift its origin to
# index (0, 0), so that its FFT is the transfer function of a centred
# circular convolution.
kernel_full = np.zeros_like(g)
kernel_full[:N_k, :N_k] = kernel
kernel_full = np.roll(kernel_full, -N_k//2+1, axis=(0,1))

# Blur in the Fourier domain: orthonormal FFT for the image, unnormalized FFT
# for the kernel (convolution theorem), matching the convention of tv_deconv_pd.
gf = np.fft.fft2(g, norm="ortho")
kernelf = np.fft.fft2(kernel_full)
f = np.real(np.fft.ifft2(gf*kernelf, norm="ortho"))

# Add Gaussian noise; the fixed seed makes the experiment reproducible.
noise_std = 0.01
rng = np.random.default_rng(0)
f = f + rng.standard_normal((M, N))*noise_std

In [None]:
# Solve the TV-regularized deconvolution problem
#     min_u  lamb_tv * TV(u) + 0.5 * || kernel_full (*) u - f ||^2
# with the primal-dual solver tv_deconv_pd (fixed step sizes, theta = 1, i.e. not accelerated).
lamb_tv = 0.0005
u, p, energy = tv_deconv_pd(f, kernel_full, maxit=2000, check=100, lamb=lamb_tv, verbose=1)

In [None]:
# Compare the degraded input with the reconstruction on a common gray scale
# [0, 1] (the range of the loaded image); without vmin/vmax, imshow rescales
# each panel independently and hides contrast differences between them.
fig, (ax_in, ax_out) = plt.subplots(1, 2, sharex=True, sharey=True)

ax_in.imshow(f, cmap="gray", vmin=0.0, vmax=1.0)
ax_in.set_title("Blurred + noisy input")

ax_out.imshow(u, cmap="gray", vmin=0.0, vmax=1.0)
ax_out.set_title("TV deconvolution")

for ax in (ax_in, ax_out):
    ax.set_axis_off()
fig.tight_layout()