In [None]:
import numpy as np
from matplotlib import pyplot as plt
from scipy.io import loadmat
from skimage.io import imread

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [None]:
# root folder containing the `data/` directory (images and dictionaries)
rootfolder = ".."

# Fix the global NumPy seed: the noise realisation and the removed-pixel mask
# are drawn from np.random, so without a seed every run gives different results
SEED = 0
np.random.seed(SEED)

Helper function to display a 2D dictionary (one vectorized square patch per column) as a grid of image tiles


In [None]:
def get_dictionary_img(D):
    """Tile the atoms of a patch dictionary into a single image for display.

    Each column of ``D`` is reshaped into a ``side x side`` patch
    (``side = round(sqrt(M))``), min-max rescaled to [0, 1] when it is not
    constant, and placed on a square grid with ``ceil(sqrt(N))`` tiles per
    side. Tiles are filled column by column (down first, then right) and are
    separated by a 2-pixel border of ones; unused grid cells also stay at one.

    Parameters
    ----------
    D : ndarray, shape (M, N)
        Dictionary with one vectorized square patch per column.

    Returns
    -------
    ndarray
        2-D image containing all the tiled atoms.
    """
    n_pixels, n_atoms = D.shape
    side = int(round(np.sqrt(n_pixels)))
    grid = int(np.ceil(np.sqrt(n_atoms)))
    gap = 2
    stride = side + gap
    canvas_size = grid * side + gap * (grid - 1)
    canvas = np.ones((canvas_size, canvas_size))

    for k in range(n_atoms):
        # k runs down the grid columns: row index cycles fastest
        col, row = divmod(k, grid)
        top, left = row * stride, col * stride
        patch = D[:, k].reshape((side, side))
        lo, hi = patch.min(), patch.max()
        if lo < hi:  # constant atoms (e.g. the DC atom) are left unscaled
            patch = (patch - lo) / (hi - lo)
        canvas[top : top + side, left : left + side] = patch

    return canvas

Define a function that implements Orthogonal Matching Pursuit (OMP)


In [None]:
def OMP(s, D, L, tau):
    """Orthogonal Matching Pursuit (OMP) sparse coding.

    Greedily selects the atom (column of ``D``) most correlated with the
    current residual, then re-fits the coefficients of all selected atoms by
    least squares, until the sparsity or the residual criterion is met.

    Parameters
    ----------
    s : ndarray, shape (M,) or (M, 1)
        Signal to be coded.
    D : ndarray, shape (M, N)
        Dictionary whose columns are the atoms.
    L : int or float
        Maximum number of atoms to select (sparsity level).
    tau : float
        Residual threshold: iterations stop once ``||s - D x||_2 <= tau``.

    Returns
    -------
    x_OMP : ndarray, shape (N,)
        Sparse coefficient vector; ``D @ x_OMP`` approximates ``s``.
    """
    _, N = D.shape
    s_vec = np.ravel(s)  # accept both 1-D signals and (M, 1) column vectors
    r = s_vec.copy()  # initial residual
    omega = []  # support set (indices of selected atoms)
    a = np.zeros(0)  # least-squares coefficients on the current support
    x_OMP = np.zeros(N)  # final sparse code

    while len(omega) < L and np.linalg.norm(r) > tau:
        # SWEEP STEP: absolute correlation between the residual and every atom
        e = np.abs(D.T @ r)

        # column index with maximum correlation
        jStar = int(np.argmax(e))

        # If the best atom is already selected, the least-squares fit (and so
        # the residual) cannot change any more: stop instead of looping forever.
        if jStar in omega:
            break
        omega.append(jStar)

        # update coefficients of the selected atoms by least squares
        D_omega = D[:, omega]
        a, *_ = np.linalg.lstsq(D_omega, s_vec, rcond=None)

        # update residual
        r = s_vec - D_omega @ a

    # scatter the support coefficients into the full sparse vector
    if omega:
        x_OMP[omega] = a

    return x_OMP

Load the image and rescale it to $[0,1]$


In [None]:
# Load the test image and rescale it to [0, 1]
# (assumes 8-bit intensities in [0, 255] — TODO confirm for other images)
img = imread(f"{rootfolder}/data/peppers256.png") / 255
# other test images:
# img = imread(f'{rootfolder}/data/barbara.png') / 255
# img = imread(f'{rootfolder}/data/Lena512.png') / 255

# image shape; later cells index imsz[0], imsz[1] and so expect a 2-D image
imsz = img.shape

p = 8  # patch side length: patches are p x p
M = p * p  # number of pixels in a patch

Corrupt the image with additive white Gaussian noise


In [None]:
# noise standard deviation: 20 grey levels, expressed on the [0, 1] scale
sigma_noise = 20 / 255
noisy_img = img + sigma_noise * np.random.normal(size=imsz)

Fraction of pixels to remove (0.25 means 25%)


In [None]:
perc_of_removed_pixels = 25 / 100  # fraction of pixels to remove (25%)

Randomly remove pixels by setting them to zero


In [None]:
# create a vector with all the indexes of the image
idx = np.arange(img.size)

# shuffle it and take the target percentage of indexes
np.random.shuffle(idx)
idx = idx[: int(perc_of_removed_pixels * img.size)]

# the mask is 0 for the chosen idx, 1 elsewhere
msk = np.ones(img.shape)
msk.flat[idx] = 0

# apply the mask: set to 0 some elements in the noisy image
noisy_img *= msk

Compute the PSNR of the noisy input (peak value 1, since the image lies in $[0,1]$)


In [None]:
# PSNR = 10 log10(peak^2 / MSE) with peak = 1 for images in [0, 1]
mse_noisy = np.mean((img - noisy_img) ** 2)
psnr_noisy = 10 * np.log10(1**2 / mse_noisy)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(20, 10))

# corrupted input; the title must be an f-string to show the PSNR value
ax[0].imshow(noisy_img, cmap="gray")
ax[0].set_title(f"Noisy image before inpainting, PSNR = {psnr_noisy:.2f}")

# mask: black = removed pixel (0), white = observed pixel (1)
ax[1].imshow(msk, cmap="gray")
ax[1].set_title("Dead pixels");

Load and display the dictionary learned (with KSVD) from image patches


In [None]:
# dictionary with one atom (vectorized p x p patch) per column
D = loadmat(f"{rootfolder}/data/dict_nat_img.mat")["D"]

# KSVD was trained over zero-mean patches, so D cannot represent the patch
# mean: append a constant (DC) atom, then rescale every atom to unit l2 norm
dc = np.full((M, 1), 1 / np.sqrt(M))
D = np.column_stack((D, dc))
D = D / np.linalg.norm(D, axis=0)

# show all atoms tiled on a single grid
D_img = get_dictionary_img(D)
plt.figure(figsize=(10, 10))
plt.imshow(D_img, cmap="gray");

## Inpainting


In [None]:
# Stopping criteria of OMP: at most L atoms per patch (sparsity level), or stop
# earlier once the residual norm drops below tau (tau is set per patch in the loop).
# L is a count of atoms, so keep it an integer.
L = M // 2

# accumulator for the weighted sum of reconstructed patches
img_hat = np.zeros_like(img)

# accumulator for the total weight received by each pixel
weights = np.zeros_like(img)

# stride between consecutive patches (STEP = p for non-overlapping patches);
# STEP = 1 might be very time consuming, start with larger STEP
STEP = 4

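Process the image patch by patch: sparse-code each patch with OMP using only its observed pixels, then synthesize the full patch (including removed pixels) with the complete dictionary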


In [None]:
def patch_starts(length, patch, step):
    """Start coordinates of `patch`-sized windows sliding with stride `step`
    along an axis of `length` pixels. The last window is always flush with the
    border, so every pixel is covered even when (length - patch) % step != 0."""
    last = length - patch
    starts = list(range(0, last + 1, step))
    if starts and starts[-1] != last:
        starts.append(last)
    return starts


row_starts = patch_starts(imsz[0], p, STEP)
col_starts = patch_starts(imsz[1], p, STEP)

for i in row_starts:
    for j in col_starts:
        # extract the patch with the top left corner at pixel (i, j), as a column vector
        s = noisy_img[i : i + p, j : j + p].reshape(-1, 1)

        # patch extracted from the mask (1 = observed pixel, 0 = removed pixel)
        m = msk[i : i + p, j : j + p].reshape(-1, 1)

        # tau should be proportional to the number of pixels remaining in the patch
        tau = 1e-3 * np.sqrt(np.sum(m))

        # The projection operator over the patch is diag(m); applying it is an
        # element-wise product, so broadcast m instead of building an M x M matrix.
        # Sparse coding w.r.t. PD, the dictionary restricted to the observed pixels.
        PD = m * D
        x = OMP(m * s, PD, L, tau)

        # reconstruction: synthesis w.r.t. the full dictionary D fills the removed pixels
        s_hat = D @ x

        # uniform weights for aggregation
        w = 1

        # accumulate the reconstructed patch and its weight
        img_hat[i : i + p, j : j + p] += w * s_hat.reshape(p, p)
        weights[i : i + p, j : j + p] += w

Normalize the estimated image by the per-pixel sum of weights, so that overlapping patch estimates are averaged


In [None]:
# Every pixel must be covered by at least one patch; otherwise the division
# below yields 0 / 0 = NaN and the PSNR silently becomes NaN.
assert np.all(weights > 0), "some pixels are not covered by any patch: check p and STEP"
img_hat = img_hat / weights

Compute the PSNR of the estimated image


In [None]:
# PSNR = 10 log10(peak^2 / MSE) with peak = 1 for images in [0, 1]
mse_hat = np.mean((img - img_hat) ** 2)
psnr_hat = 10 * np.log10(1**2 / mse_hat)

In [None]:
# side-by-side comparison: original, corrupted input, inpainted estimate
panels = [
    (img, "Original image"),
    (noisy_img, f"Corrupted image, PSNR = {psnr_noisy:.2f}"),
    (img_hat, f"Estimated Image,\nPSNR = {psnr_hat:.2f}"),
]

fig, ax = plt.subplots(1, 3, figsize=(20, 10))
for panel_ax, (panel_img, panel_title) in zip(ax, panels):
    panel_ax.imshow(panel_img, cmap="gray")
    panel_ax.set_title(panel_title)