# Spectral Restoration: Deconvolution

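Images distorted by a known linear blur and additive white noise, i.e., $g=f*h+n$, can be partially recovered by deconvolution. The Wiener filter produces the linear estimate $\hat{f} = w * g$ that minimizes the mean squared error $E[(f-\hat{f})^2]$. The parametric version used here adds a user-defined weight $\gamma$; with $\gamma = 1$ and the true signal and noise power spectra, it reduces to the classical Wiener filter. The computation can be done in the Fourier domain: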

$$ \hat{f} = {\cal F}^{-1}\{ W\, {\cal F} \{ g \} \}$$

$$W(u,v) = \dfrac{H^*(u,v)}{|H(u,v)|^2 + \gamma\, |R(u,v)|^2}$$

where $\gamma$ is a user-defined scalar and $|R(u,v)|^2 = 1/\text{PSNR}(u,v)$. Here PSNR denotes the power signal-to-noise ratio, not the peak SNR.
A simple version might use the constant approximation $\text{PSNR} =\sigma^2_f/\sigma^2_n$. A more sophisticated version
would use the ratio of the actual signal and noise power spectra, namely,
$\text{PSNR}(u,v) = S_\text{ff}(u,v)/S_\text{nn}(u,v)$. These spectra must be estimated if they are not known. A third version ignores the noise statistics and simply penalizes high-frequency content by using $R(u,v) = -4\pi^2(u^2+v^2)$, which you may recognize as the Fourier transform of the Laplacian. This choice corresponds to constrained least-squares filtering.
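A minimal 1D sketch of how $W$ behaves for these choices, on a hypothetical 64-sample grid with a hypothetical 3-tap kernel (not the notebook's image). With $\gamma = 0$ it is the inverse filter $1/H$. A constant $\gamma|R|^2 = c$ caps the gain at $1/(2\sqrt{c})$, because $|H|^2 + c \ge 2\sqrt{c}\,|H|$. And the discrete Laplacian stencil `[1, -2, 1]` agrees with $-4\pi^2u^2$ at low frequencies. The helper name `parametric_wiener_W` is hypothetical.

```python
import numpy as np

def parametric_wiener_W(H, R2, gamma):
    # W = conj(H) / (|H|^2 + gamma*|R|^2), evaluated elementwise
    return np.conj(H) / (np.abs(H)**2 + gamma * R2)

L = 64
u = np.fft.fftfreq(L)               # frequencies in cycles/sample, zero at index 0

h = np.zeros(L)
h[:3] = [0.2, 0.6, 0.2]             # hypothetical 3-tap blur, |H| >= 0.2 everywhere
H = np.fft.fft(h)

# gamma = 0: plain inverse filter, so W*H == 1 wherever H != 0
W_inv = parametric_wiener_W(H, 1.0, 0.0)

# Constant |R|^2 = 1/PSNR: the gain never exceeds 1/(2*sqrt(gamma/PSNR))
PSNR = 100.0
W_const = parametric_wiener_W(H, 1.0 / PSNR, 1.0)

# Laplacian regularizer, 1D version: R(u) = -4*pi^2*u^2.
# The discrete stencil [1, -2, 1] has response 2*cos(2*pi*u) - 2 = -4*sin(pi*u)^2.
R_cont = -4 * np.pi**2 * u**2
R_disc = 2 * np.cos(2 * np.pi * u) - 2
W_lap = parametric_wiener_W(H, R_cont**2, 1e-3)

print(np.allclose(W_inv * H, 1.0))                               # True
print(np.abs(W_const).max() <= 1 / (2 * np.sqrt(1.0 / PSNR)))    # True
low = np.abs(u) <= 0.05
print(np.allclose(R_disc[low], R_cont[low], rtol=0.01))          # True
print(np.isclose(W_lap[0], 1.0))    # True: R(0) = 0 leaves the DC gain at 1/H(0)
```

Near the Nyquist frequency the two Laplacian forms diverge ($-4$ versus $-\pi^2 \approx -9.87$ at $u = 0.5$), so the continuous form penalizes the highest frequencies more strongly than the discrete stencil would.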

In [None]:
%matplotlib inline

import numpy as np

import matplotlib.image as img
import matplotlib.pyplot as plt

from skimage import io, exposure

from scipy.fft import fft2, ifft2, fftshift
from scipy.ndimage import convolve

from skimage.util import random_noise
from skimage.util import img_as_float32 as img_as_float

In [None]:
def show_images(I, titles=None):
    fig, ax = plt.subplots(1, len(I), figsize=(12,5))

    for i in range(len(I)):
        ax[i].imshow(I[i], cmap='gray')
        ax[i].set_axis_off()
        if titles is not None:
            ax[i].set_title(titles[i])

    plt.tight_layout()

In [None]:
def show_plots(I, titles=None):
    # Plot the intensity profile along the middle row of each image
    fig, ax = plt.subplots(1, len(I), figsize=(12,1))

    for i in range(len(I)):
        if titles is not None:
            ax[i].set_title(titles[i])

        r = I[i].shape[0]//2
        ax[i].plot(I[i][r,:])
        ax[i].set_xticks([])
        ax[i].set_yticks([])

    plt.tight_layout()

In [None]:
def nextpow2(N):
    n = 1
    while (n<N):
        n *= 2
    return n

In [None]:
I1 = io.imread("../../images/parrot.jpg", as_gray=True)
I1 = img_as_float(I1)

In [None]:
M, N = I1.shape
M2, N2 = (2*nextpow2(M), 2*nextpow2(N))
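Why pad to `2*nextpow2(M)`? Multiplying FFTs computes a circular convolution, which matches the linear convolution only when the FFT length is at least $M+K-1$; power-of-two lengths are also typically fast for FFT implementations. A minimal 1D check with hypothetical random signals; `nextpow2_bits` is a hypothetical bit-length variant of `nextpow2` (same result for `n >= 1`).

```python
import numpy as np
from scipy.fft import fft, ifft

def nextpow2_bits(n):
    # Smallest power of two >= n (for n >= 1), via the integer bit length
    return 1 << (int(n) - 1).bit_length()

rng = np.random.default_rng(1)
f = rng.random(100)          # hypothetical 1D signal, length M = 100
k = rng.random(9)            # hypothetical kernel, length K = 9
M, K = f.size, k.size

linear = np.convolve(f, k)   # full linear convolution, length M + K - 1

# Too short: with FFT length M the tail of the result wraps onto its head
short = np.real(ifft(fft(f, M) * fft(k, M)))

# Long enough: 2*nextpow2(M) >= 2*M >= M + K - 1 whenever K <= M + 1
L = 2 * nextpow2_bits(M)
padded = np.real(ifft(fft(f, L) * fft(k, L)))[:M + K - 1]

print(L)                               # 256
print(np.allclose(padded, linear))     # True
print(np.allclose(short, linear[:M]))  # False
```

The factor of two is more padding than strictly required for small kernels, but it keeps the size rule independent of the kernel.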

## Parametric Wiener Filter

In [None]:
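def wiener_filter(I, h, gamma, PSNR=None, clip=False):
    # Zero-pad image and kernel to a common FFT size to avoid wrap-around
    M, N = I.shape
    M2, N2 = (2*nextpow2(M), 2*nextpow2(N))
    F = fft2(I, (M2,N2))
    H = fft2(h, (M2,N2))
    
    if PSNR is None:
        # Laplacian regularizer R(u,v) = -4*pi^2*(u^2+v^2), u,v in cycles/sample,
        # shifted so that zero frequency sits at index (0,0) (M2, N2 are even)
        u, v = np.mgrid[-M2//2:M2//2,-N2//2:N2//2]
        R = -4*np.pi**2*(fftshift((u/M2)**2+(v/N2)**2))
        W = np.conj(H)/(np.abs(H)**2 + gamma*R**2)
    else:
        # |R|^2 = 1/PSNR; PSNR is a scalar or an (M2, N2) array
        if clip:
            PSNR = np.clip(PSNR, 0, 1)
        W = np.conj(H)/(np.abs(H)**2 + gamma*(1/PSNR))
    
    Ip = np.real(ifft2(W*F))
    
    # h sits at the array origin while the blur was applied centered, so the
    # result comes out advanced by (KM//2, KN//2) and wrapped around;
    # fftshift plus the (M0, N0) offset crops it back into place
    KM, KN = h.shape
    M0, N0 = (M2//2-KM//2, N2//2-KN//2)
    Ip = fftshift(Ip)[M0:M0+M,N0:N0+N]
    
    Ip = np.clip(Ip, 0, 1)

    return img_as_float(Ip)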
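The `fftshift`-and-crop step at the end of `wiener_filter` is the least obvious part. A minimal 1D sketch checks it, assuming a noise-free signal that is zero near both ends (so the centered, `'same'`-mode blur loses nothing at the borders) and a hypothetical 3-tap kernel with no spectral zeros. With a nearly zero regularizer, the offset `L//2 - K//2` puts the restored signal back in place.

```python
import numpy as np
from scipy.fft import fft, ifft, fftshift

M = 50
L = 2 * 64                      # 2*nextpow2(50), as the notebook would choose
rng = np.random.default_rng(2)

f = np.zeros(M)
f[10:40] = rng.random(30)       # hypothetical signal, zero near both ends

k = np.array([0.2, 0.6, 0.2])   # hypothetical kernel, |DFT| >= 0.2 everywhere
K = k.size

# Centered ("same") blur, matching scipy.ndimage.convolve for an odd kernel
g = np.convolve(f, k, mode="same")

# Kernel placed at the array origin, as fft2(h, (M2, N2)) does in wiener_filter
H = fft(k, L)
W = np.conj(H) / (np.abs(H)**2 + 1e-10)   # nearly the inverse filter
r = np.real(ifft(W * fft(g, L)))

# r holds f advanced by K//2 samples, wrapped to the end of the array;
# fftshift plus the same offset used in wiener_filter undoes that
M0 = L // 2 - K // 2
f_hat = fftshift(r)[M0:M0 + M]

print(np.allclose(f_hat, f, atol=1e-6))   # True
```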

## Example: Gaussian Blur

In [None]:
def gaussian(sigma=1.0, truncate=4.0):
    # Normalized (2K+1) x (2K+1) Gaussian kernel with K = ceil(truncate*sigma)
    K = int(np.ceil(truncate*sigma))
    
    u, v = np.mgrid[-K:K+1,-K:K+1]
    
    h = np.exp(-0.5*(u**2+v**2)/(sigma**2))
    h /= h.sum()

    return h
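As a sanity check, a kernel built this way should reproduce SciPy's separable `gaussian_filter` with zero padding. SciPy picks the radius as `int(truncate*sigma + 0.5)`, which equals `ceil(truncate*sigma)` for $\sigma = 2$ but can differ by one sample for other values. The sketch below copies the construction into a standalone `gaussian_kernel` (a hypothetical name) and uses a hypothetical random image.

```python
import numpy as np
from scipy.ndimage import convolve, gaussian_filter

def gaussian_kernel(sigma=1.0, truncate=4.0):
    # Same construction as gaussian() above
    K = int(np.ceil(truncate * sigma))
    u, v = np.mgrid[-K:K + 1, -K:K + 1]
    h = np.exp(-0.5 * (u**2 + v**2) / sigma**2)
    return h / h.sum()

rng = np.random.default_rng(3)
img = rng.random((40, 50))     # hypothetical test image

h = gaussian_kernel(2.0)
a = convolve(img, h, mode="constant")
b = gaussian_filter(img, sigma=2.0, mode="constant", truncate=4.0)

print(h.shape, np.isclose(h.sum(), 1.0))   # (17, 17) True
print(np.allclose(a, b))                   # True
```

The 2D kernel is the outer product of two normalized 1D Gaussians, which is why the separable filter gives the same result.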

In [None]:
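sigma_f = np.std(I1)   # signal standard deviation, used for the constant PSNR

sigma_h = 2      # standard deviation of the Gaussian blur, in pixels
sigma_n = 0.20   # standard deviation of the additive noise (intensities in [0, 1])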

In [None]:
h = gaussian(sigma_h)
Ih = convolve(I1, h, mode='constant')

In [None]:
In = random_noise(np.zeros_like(I1), mode='gaussian', var=sigma_n**2, clip=False)

In [None]:
# Clip to the valid intensity range and keep the noise that actually remains
Ic = np.clip(Ih + In, 0, 1)
In = Ic - Ih
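Clipping to $[0, 1]$ makes the noise signal-dependent: in bright or dark regions it is truncated, so its variance drops and its mean shifts away from zero. That is why `In` is recomputed as `Ic - Ih` rather than reusing the Gaussian draw. A minimal sketch on a hypothetical flat patch at intensity 0.9, drawing zero-mean Gaussian noise with NumPy (the same kind of noise that `random_noise(..., mode='gaussian')` adds):

```python
import numpy as np

rng = np.random.default_rng(4)
sigma_n = 0.20
base = np.full((256, 256), 0.9)   # hypothetical bright, flat image patch

noise = rng.normal(0.0, sigma_n, base.shape)  # zero-mean Gaussian noise
clipped = np.clip(base + noise, 0, 1)
effective = clipped - base                    # analogous to In = Ic - Ih

print(f"nominal std {noise.std():.3f}, effective std {effective.std():.3f}, "
      f"effective mean {effective.mean():.3f}")
```

Here roughly a third of the samples hit the upper bound, so the effective noise is noticeably weaker than $\sigma_n$ and biased negative.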

In [None]:
# Gaussian blur, no noise, PSNR=const approximation
# IP: |R|^2 = 1/PSNR with PSNR = 1; IR: Laplacian regularizer (set I2 = IR to compare)

IP = wiener_filter(Ih, h, 0.001, PSNR=1)
IR = wiener_filter(Ih, h, 0.001)

I2 = IP

show_images([I1, Ih, I2], ['Original','Blur','Wiener Filter'])
show_plots([I1, Ih, I2])

In [None]:
# Gaussian blur, white noise, PSNR=const

IP = wiener_filter(Ic, h, 1, PSNR=sigma_f**2/sigma_n**2)
IR = wiener_filter(Ic, h, 1)   # Laplacian regularizer; set I2 = IR to compare

I2 = IP

show_images([I1, Ic, I2], ['Original','Blur+Noise','Wiener Filter'])
show_plots([I1, Ic, I2])
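Once noise is present, $\gamma$ (together with PSNR) controls how strongly the filter amplifies it where $|H|$ is small. A 1D sketch with a hypothetical sinusoidal signal, a circular Gaussian blur ($\sigma = 2$ samples) and the constant-PSNR approximation compares the restoration error for a few $\gamma$ values; only the trend is meaningful, since the numbers depend on the random draw.

```python
import numpy as np
from scipy.fft import fft, ifft

rng = np.random.default_rng(6)
L = 256
x = np.arange(L)

# Hypothetical smooth test signal and a normalized Gaussian kernel (sigma = 2)
f = 0.5 + 0.4 * np.sin(2 * np.pi * x / 64)
k = np.exp(-0.5 * (np.arange(-8, 9) / 2.0)**2)
k /= k.sum()

# Circular blur: the kernel is rolled so that its center sits at index 0
h = np.zeros(L)
h[:k.size] = k
h = np.roll(h, -(k.size // 2))
H = fft(h)

sigma_n = 0.2
g = np.real(ifft(fft(f) * H)) + rng.normal(0, sigma_n, L)

def restore(gamma, PSNR):
    # Parametric Wiener filter with constant |R|^2 = 1/PSNR
    W = np.conj(H) / (np.abs(H)**2 + gamma / PSNR)
    return np.real(ifft(W * fft(g)))

PSNR = f.var() / sigma_n**2
mse = {}
for gamma in [1e-6, 1e-2, 1.0]:
    mse[gamma] = np.mean((restore(gamma, PSNR) - f)**2)
    print(f"gamma={gamma:g}  MSE={mse[gamma]:.3g}")
```

A tiny $\gamma$ behaves almost like the inverse filter and blows up the high-frequency noise. A large one suppresses the noise but also shrinks the signal: with $\gamma/\text{PSNR} = 0.5$, even the DC gain is $1/(1+0.5) \approx 0.67$.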

In [None]:
# Gaussian blur, white noise, PSNR=known (oracle: Sff uses the original image)

Sff = np.abs(fft2(I1, (M2,N2)))**2
Snn = np.abs(fft2(In, (M2,N2)))**2

IP = wiener_filter(Ic, h, 1, PSNR=Sff/Snn, clip=False)
IR = wiener_filter(Ic, h, 2)   # Laplacian regularizer; set I2 = IR to compare

I2 = IP

show_images([I1, Ic, I2], ['Original','Blur+Noise','Wiener Filter'])
show_plots([I1, Ic, I2])
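In practice $S_\text{nn}$ is rarely measured from the noise realization itself as above. For white noise of variance $\sigma_n^2$ its expected value is flat: with SciPy's unnormalized FFT and an $M \times N$ noise field zero-padded to $M_2 \times N_2$, $E\,|{\cal F}\{n\}|^2 = MN\sigma_n^2$ at every frequency. A minimal sketch with hypothetical sizes checks this (via Parseval, the mean of $|{\cal F}\{n\}|^2$ over all bins equals $\sum n^2$); $S_\text{ff}$ would still need its own estimate.

```python
import numpy as np
from scipy.fft import fft2

rng = np.random.default_rng(5)
M, N = 64, 80                      # hypothetical image size
M2, N2 = 128, 256                  # 2*nextpow2(64), 2*nextpow2(80)
sigma_n = 0.20

n = rng.normal(0.0, sigma_n, (M, N))
Snn = np.abs(fft2(n, (M2, N2)))**2

# Parseval (unnormalized FFT): the mean of |FFT|^2 over all bins is sum(n^2)
print(np.isclose(Snn.mean(), np.sum(n**2)))   # True

# Flat white-noise estimate that could replace the measured Snn
Snn_flat = M * N * sigma_n**2
print(Snn.mean() / Snn_flat)       # close to 1
```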