In [None]:
import numpy as np
from matplotlib import pyplot as plt
from skimage.io import imread
from skimage.util import view_as_windows

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [None]:
# Base folder (relative path) under which the `data/` directory with the
# test images is looked up.
rootfolder = ".."

Useful function for plotting a 2D dictionary


In [None]:
def get_dictionary_img(D):
    """Tile the atoms (columns) of a dictionary into one image for display.

    Parameters
    ----------
    D : array_like, shape (M, N)
        Dictionary whose N columns are vectorized square patches of
        M = p * p pixels each.

    Returns
    -------
    numpy.ndarray
        Square float image in which the atoms are laid out on a
        ceil(sqrt(N)) x ceil(sqrt(N)) grid, filled column by column and
        separated by a 2-pixel border of ones. Every non-constant atom is
        min-max rescaled to [0, 1]; constant atoms are copied unchanged.

    Raises
    ------
    ValueError
        If the atom length M is not a perfect square.
    """
    D = np.asarray(D)
    M, n_atoms = D.shape
    p = int(round(np.sqrt(M)))
    if p * p != M:
        raise ValueError(f"atom length {M} is not a perfect square")

    bound = 2  # width in pixels of the separator between neighbouring atoms

    # Size the grid from the number of atoms (columns), not from the atom
    # length, so that non-square dictionaries are rendered too. For a square
    # dictionary (N == M) this gives the same p x p grid as before.
    grid = max(1, int(np.ceil(np.sqrt(n_atoms))))
    side = grid * p + bound * (grid - 1)
    img = np.ones((side, side))

    for i in range(n_atoms):
        # column-major placement: consecutive atoms go down a grid column
        grid_col, grid_row = divmod(i, grid)
        top = grid_row * (p + bound)
        left = grid_col * (p + bound)

        atom = D[:, i].reshape((p, p))
        # stretch the contrast of each atom independently; skip constant atoms
        # to avoid a division by zero
        if atom.min() < atom.max():
            atom = (atom - atom.min()) / (atom.max() - atom.min())
        img[top : top + p, left : left + p] = atom

    return img

Load the image and rescale it to $[0,1]$


In [None]:
# Select the test image by keeping exactly one of the lines below active;
# dividing by 255 rescales the intensities to [0, 1] (assumes 8-bit images).
# img = imread(f'{rootfolder}/data/cameraman.png') / 255
img = imread(f"{rootfolder}/data/barbara.png") / 255
# img = imread(f'{rootfolder}/data/Lena512.png') / 255

# side length of the square patches and number of pixels in each patch
p = 8
M = p * p

# shape of the loaded image
imsz = img.shape

Corrupt the image with additive white Gaussian noise


In [None]:
# Standard deviation of the noise, expressed on the [0, 1] intensity scale.
sigma_noise = 20 / 255

# Draw the noise from a seeded generator so that the noisy image (and every
# PSNR computed from it) is identical after "Restart Kernel -> Run All".
SEED = 0
rng = np.random.default_rng(SEED)
noisy_img = img + sigma_noise * rng.normal(size=imsz)

Compute the PSNR of the noisy input


In [None]:
# PSNR in dB for a peak signal value of 1: 10 * log10(1 / MSE)
mse_noisy = np.mean((noisy_img - img) ** 2)
psnr_noisy = 10 * np.log10(1 / mse_noisy)
psnr_noisy

In [None]:
# Clean reference (left) next to its noisy version (right).
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
ax_clean, ax_noisy = ax

ax_clean.imshow(img, cmap="gray")
ax_clean.set_title("Original image")

ax_noisy.imshow(noisy_img, cmap="gray")
ax_noisy.set_title(f"Noisy image, PSNR = {psnr_noisy:.2f}")

## Generate the Global PCA basis for this image


Set the parameters for denoising


In [None]:
# set the threshold for the Hard Thresholding
# Coefficients whose magnitude does not exceed tau are zeroed in the
# denoising loop. The 3 * sigma rule used here is close to the universal
# threshold quoted in the trailing comment, which evaluates to about
# 2.88 * sigma for p = 8 (natural logarithm).
tau = 3 * sigma_noise  # Donoho says: sigma * sqrt(2*log(p^2))

Stack all the image patches in a large matrix $S$. Each patch goes in a column of $S$


In [None]:
# Slide a p x p window over the image; the first two axes of the result index
# the window position and the last two the pixels inside the window.
# NOTE(review): the patches are taken from the clean `img`, not from
# `noisy_img` — confirm that learning the basis from the noise-free image
# is intended.
windows = view_as_windows(img, (p, p))
patches = windows.reshape(-1, p, p)

# One vectorized patch per column of S -> shape (M, number_of_patches)
S = patches.reshape(len(patches), M).T

Compute $\tilde S$ by zero-centering $S$, i.e. by subtracting the average patch from every column of $S$


In [None]:
# Average patch, kept as an (M, 1) column so it broadcasts over the columns of S
avg_patch = S.mean(axis=1, keepdims=True)

# centered patch matrix: every column has the average patch removed
Stilde = S - avg_patch

Compute the PCA transformation via SVD


In [None]:
# Reduced SVD of the centered patch matrix. The columns of U are the left
# singular vectors, used below as the PCA basis; Sigma holds the singular
# values in descending order.
# Note: np.linalg.svd returns the right singular vectors already transposed,
# so the array bound to V here is V^T (V^H in general), not V.
U, Sigma, V = np.linalg.svd(Stilde, full_matrices=False)

Show the learned PCA basis


In [None]:
# Open the figure first, then tile the basis vectors (columns of U) into a
# single image and draw it.
plt.figure(figsize=(10, 10))
U_img = get_dictionary_img(U)
plt.imshow(U_img, cmap="gray")

## Patch-based denoising

Initialize the variables


In [None]:
# stride (in pixels) between the top-left corners of consecutive patches
STEP = 1

# Accumulators with the same shape and dtype as the image:
#   weights - sum of the weights of all patches covering each pixel
#   img_hat - weighted sum of the denoised patches (normalized after the loop)
weights = np.zeros_like(img)
img_hat = np.zeros_like(img)

In [None]:
# Start from empty accumulators so that re-running this cell does not add a
# second pass on top of an already normalized img_hat / non-zero weights.
img_hat = np.zeros_like(img)
weights = np.zeros_like(img)

# Loop invariants: the average patch as a flat vector and the analysis
# operator (transpose of the PCA basis).
avg_vec = avg_patch.ravel()
Ut = U.T

for i in range(0, imsz[0] - p + 1, STEP):
    for j in range(0, imsz[1] - p + 1, STEP):
        # extract the patch with the top left corner at pixel (i, j)
        s = noisy_img[i : i + p, j : j + p].flatten()

        # Preprocessing: subtract the avg_patch (same preprocessing used for PCA)
        s = s - avg_vec

        # compute the representation w.r.t. the PCA basis
        x_hat = Ut @ s

        # hard thresholding: keep only coefficients with magnitude above tau
        x_hat = x_hat * (np.abs(x_hat) > tau)

        # synthesis: perform the reconstruction
        y_hat = U @ x_hat

        # add the avg patch back
        y_hat = y_hat + avg_vec

        # weight of the reconstructed patch (uniform aggregation)
        w = 1

        # put the denoised patch into the denoised image using the computed weight
        img_hat[i : i + p, j : j + p] += w * y_hat.reshape(p, p)

        # store the weight of the current patch in the weight matrix
        weights[i : i + p, j : j + p] += w

# Normalize the estimated image with the accumulated weights.
# With STEP == 1 every pixel is covered by at least one patch, so the weights
# are strictly positive; a larger STEP can leave border pixels uncovered
# (weight 0), which would produce NaN here.
img_hat = img_hat / weights

Compute the PSNR of the estimated image


In [None]:
# PSNR (dB, peak value 1) of the denoised estimate against the clean image
mse_hat = np.mean((img_hat - img) ** 2)
psnr_hat = 10 * np.log10(1 / mse_hat)

# show the estimate with its PSNR in the title
plt.figure(figsize=(10, 10))
plt.imshow(img_hat, cmap="gray")
plt.title(f"Estimated Image,\nPSNR = {psnr_hat:.2f}")