In [None]:
import numpy as np
from matplotlib import pyplot as plt
from scipy.io import loadmat
from skimage.io import imread

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [None]:
# Base folder prepended to the data file paths read below (f"{rootfolder}/data/...")
rootfolder = ".."

Utility function to plot a 2D dictionary


In [None]:
def get_dictionary_img(D):
    """Arrange the atoms (columns) of ``D`` as square tiles of a single mosaic image.

    Parameters
    ----------
    D : 2-D array of shape (M, N)
        Dictionary with one atom per column. Each atom is reshaped to a
        ``p x p`` tile with ``p = round(sqrt(M))``.

    Returns
    -------
    2-D array
        Square image holding a ``ceil(sqrt(N)) x ceil(sqrt(N))`` grid of tiles
        separated by 2-pixel gaps. Gaps and unused cells have value 1. The grid
        is filled column by column. Non-constant atoms are rescaled to [0, 1];
        constant atoms are pasted unchanged.
    """
    n_pixels, n_atoms = D.shape
    side = int(round(np.sqrt(n_pixels)))
    grid = int(np.ceil(np.sqrt(n_atoms)))
    gap = 2
    stride = side + gap  # distance between the top-left corners of adjacent tiles
    canvas_side = grid * side + gap * (grid - 1)
    mosaic = np.ones((canvas_side, canvas_side))

    for idx, column in enumerate(D.T):
        # the row slot cycles fastest, so consecutive atoms go down a grid column
        col_slot, row_slot = divmod(idx, grid)
        top = row_slot * stride
        left = col_slot * stride

        tile = column.reshape((side, side))
        lowest, highest = tile.min(), tile.max()
        if lowest < highest:
            # stretch the tile to the full [0, 1] range
            tile = (tile - lowest) / (highest - lowest)
        mosaic[top : top + side, left : left + side] = tile

    return mosaic

Define a function that implements Orthogonal Matching Pursuit (OMP)


In [None]:
def OMP(s, D, L, tau):
    """Sparse-code the signal ``s`` over the dictionary ``D`` with Orthogonal Matching Pursuit.

    Atoms are added greedily to the support, one per iteration, until either
    ``L`` atoms have been selected (never more than the number of atoms in
    ``D``) or the norm of the residual is no larger than ``tau``. After each
    selection the coefficients on the support are refit by least squares.

    Parameters
    ----------
    s : 1-D array of length M
        Signal to approximate. It is not modified.
    D : 2-D array of shape (M, N)
        Dictionary with one atom per column. The sweep uses ``D[:, j] @ r`` as
        the projection coefficient, which is the optimal one only for unit-norm
        atoms -- NOTE(review): confirm the dictionary columns are normalized.
    L : int
        Maximum number of nonzero coefficients.
    tau : float
        Threshold on the residual norm below which the pursuit stops.

    Returns
    -------
    x : 1-D array of length N
        Coefficient vector, zero outside the selected support.
    """
    N = D.shape[1]
    x = np.zeros(N)
    r = s.copy()  # residual
    omega = []  # support set, in order of selection
    selected = np.zeros(N, dtype=bool)  # mask mirror of omega for fast exclusion

    # Each atom may enter the support only once. Without this cap, asking for
    # L > N would re-select an atom already in the support, giving duplicate
    # columns in the least-squares problem and a wrong coefficient in x.
    max_atoms = min(L, N)

    while len(omega) < max_atoms and np.linalg.norm(r) > tau:
        # SWEEP STEP: residual error obtained by removing from r its component
        # along each atom, computed for all atoms at once
        proj_coeff = D.T @ r
        e = np.sum((r[:, None] - D * proj_coeff) ** 2, axis=0)
        e[selected] = np.inf  # exclude already selected atoms

        # find the atom that minimizes the residual error
        jStar = int(np.argmin(e))

        # add the selected atom to the support set
        omega.append(jStar)
        selected[jStar] = True

        # solve the least squares problem over the selected atoms
        D_omega = D[:, omega]
        x_omega = np.linalg.lstsq(D_omega, s, rcond=None)[0]

        # update the coefficient vector
        x = np.zeros(N)
        x[omega] = x_omega

        # update the residual
        r = s - D_omega @ x_omega

    return x

Load the image and rescale it to $[0,1]$


In [None]:
# read the test image and divide by 255
img_path = f"{rootfolder}/data/Lena512.png"
img = imread(img_path) / 255

# size of the image
imsz = img.shape

# side length of the square patches
p = 8

# number of pixels in a patch
M = p * p

Corrupt the image with white Gaussian noise


In [None]:
# standard deviation of the noise (20 on the 0-255 scale, divided by 255 like the image)
sigma_noise = 20 / 255

# draw the noise from a seeded generator so that the noisy image, and every
# PSNR computed from it, is identical after Restart Kernel -> Run All
rng = np.random.default_rng(0)
noisy_img = img + sigma_noise * rng.normal(size=imsz)

Compute the PSNR of the noisy input


In [None]:
# PSNR of the noisy image with respect to the clean one (peak value 1)
mse_noisy = np.mean((noisy_img - img) ** 2)
psnr_noisy = 10 * np.log10(1 / mse_noisy)
psnr_noisy

In [None]:
# show the clean and the noisy image side by side
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
left_panel, right_panel = ax

left_panel.imshow(img, cmap="gray")
left_panel.set_title("Original image")

right_panel.imshow(noisy_img, cmap="gray")
right_panel.set_title(f"Noisy image, PSNR = {psnr_noisy:.2f}")

Load and display the dictionary learned from patches


In [None]:
# read the variable "D" stored in the .mat file
dict_path = f"{rootfolder}/data/dict_nat_img.mat"
D = loadmat(dict_path)["D"]

# render the atoms of the dictionary as a single mosaic image and show it
D_img = get_dictionary_img(D)
plt.figure(figsize=(10, 10))
plt.imshow(D_img, cmap="gray")

## Denoising


In [None]:
# initialize the estimated image
img_hat = np.zeros(imsz)

# initialize the weight matrix
weights = np.zeros(imsz)

# set the threshold
tau = 1.15 * p * sigma_noise

# define the step (=p for non overlapping paches)
STEP = 2  # STEP = 1 might be very time consuming, start with larger STEP

Operate patchwise


In [None]:
# Top-left corners of the patches along each axis. The regular grid stops at
# the last multiple of STEP that still fits a whole patch; when (size - p) is
# not a multiple of STEP a final patch flush with the border is added.
# Otherwise the last rows/columns would be covered by no patch, keep a zero
# weight and turn into NaN when the estimate is divided by the weights.
row_starts = list(range(0, imsz[0] - p + 1, STEP))
if row_starts and row_starts[-1] != imsz[0] - p:
    row_starts.append(imsz[0] - p)

col_starts = list(range(0, imsz[1] - p + 1, STEP))
if col_starts and col_starts[-1] != imsz[1] - p:
    col_starts.append(imsz[1] - p)

for i in row_starts:
    for j in col_starts:
        # extract the patch with the top left corner at pixel (i, j);
        # flatten() returns a copy, so the in-place subtraction below does
        # not modify noisy_img
        s = noisy_img[i : i + p, j : j + p].flatten()

        # store and subtract the mean
        s_mean = s.mean()
        s -= s_mean

        # perform the sparse coding (at most 10 atoms per patch)
        x = OMP(s, D, L=10, tau=tau)

        # perform the reconstruction
        s_hat = D @ x

        # add back the mean
        s_hat += s_mean

        # put the denoised patch into the estimated image using uniform weights
        img_hat[i : i + p, j : j + p] += s_hat.reshape(p, p)

        # store the weight of the current patch in the weight matrix
        weights[i : i + p, j : j + p] += 1

Normalize the estimated image with the computed weights


In [None]:
# Turn the accumulated sum of patches into a per-pixel average by dividing
# by the accumulated weights
img_hat = np.divide(img_hat, weights)

# PSNR of the estimate with respect to the clean image (peak value 1)
mse_hat = np.mean((img_hat - img) ** 2)
psnr_hat = 10 * np.log10(1 / mse_hat)

Display the estimated image together with its PSNR


In [None]:
# show the denoised estimate, reporting its PSNR in the title
estimate_title = f"Estimated Image,\nPSNR = {psnr_hat:.2f}"
plt.figure(figsize=(10, 10))
plt.imshow(img_hat, cmap="gray")
plt.title(estimate_title)