In [None]:
import time

import numpy as np
from matplotlib import pyplot as plt
from scipy.io import loadmat
from skimage.io import imread

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [None]:
# Root folder of the project; every data file below is read from f"{rootfolder}/data/..."
rootfolder = ".."

Utility function for plotting a 2D dictionary


In [None]:
def get_dictionary_img(D):
    """Tile the atoms (columns) of ``D`` into one 2D image for display.

    Every column of the ``M x N`` dictionary is reshaped into a ``p x p``
    patch, with ``p = round(sqrt(M))``, and rescaled to ``[0, 1]`` unless it
    is constant. The patches are laid out on a ``ceil(sqrt(N))``-wide square
    grid, filled column by column, with a 2-pixel separator of ones between
    neighbouring cells. Unused cells stay equal to one.
    """
    n_pixels, n_atoms = D.shape
    side = int(round(np.sqrt(n_pixels)))
    grid = int(np.ceil(np.sqrt(n_atoms)))
    gap = 2
    stride = side + gap  # distance between the corners of adjacent cells
    canvas_side = grid * side + gap * (grid - 1)
    canvas = np.ones((canvas_side, canvas_side))

    for k in range(n_atoms):
        # consecutive atoms go down a grid column before moving to the next one
        grid_col, grid_row = divmod(k, grid)
        top = grid_row * stride
        left = grid_col * stride

        tile = D[:, k].reshape((side, side))
        lo, hi = tile.min(), tile.max()
        if lo < hi:
            tile = (tile - lo) / (hi - lo)
        canvas[top : top + side, left : left + side] = tile

    return canvas

Define a function that implements Orthogonal Matching Pursuit (OMP)


In [None]:
def OMP(s, D, L, tau):
    """Sparse-code ``s`` over the dictionary ``D`` with Orthogonal Matching Pursuit.

    At every iteration the atom with the largest absolute inner product with
    the current residual joins the support, the coefficients on the support
    are refitted by least squares against ``s``, and the residual is updated.
    The selection compares raw inner products, so atoms with different norms
    are not compared on an equal footing.

    Parameters
    ----------
    s : 1D array of length M
        Signal to approximate. It is not modified.
    D : 2D array of shape (M, N)
        Dictionary whose columns are the atoms.
    L : int
        Maximum number of nonzero coefficients.
    tau : float
        The iterations stop as soon as the residual norm is not larger than ``tau``.

    Returns
    -------
    x_OMP : 1D array of length N
        Coefficients of ``s`` w.r.t. ``D``; at most ``L`` entries are nonzero.
    """
    _, N = D.shape
    r = s.copy()  # initial residual
    omega = []  # support set
    x_omega = np.zeros(0)  # coefficients on the support
    x_OMP = np.zeros(N)  # final sparse code

    while len(omega) < L and np.linalg.norm(r) > tau:
        # SWEEP STEP: correlations between the residual and all the atoms at once
        e = D.T @ r

        # find the column index with maximum correlation
        jStar = int(np.argmax(np.abs(e)))

        # An atom that is already in the support cannot reduce the residual any
        # further: neither the support nor the residual would change, so the
        # loop would never terminate. Stop here instead.
        if jStar in omega:
            break

        # UPDATE support set
        omega.append(jStar)

        # update coefficients using least squares
        D_omega = D[:, omega]
        x_omega, _, _, _ = np.linalg.lstsq(D_omega, s, rcond=None)

        # update residual
        r = s - D_omega @ x_omega

    # construct full sparse vector
    if omega:
        x_OMP[omega] = x_omega

    return x_OMP

## Dictionary Learning


Load the image and rescale it in $[0,1]$


In [None]:
# Image used to learn the first dictionary
path_image = (
    f"{rootfolder}/data/barbara.png"  #  barbara.png, cameraman.png, Lena512.png
)

# dividing by 255 rescales the pixel values to [0, 1] (assumes an 8-bit image — TODO confirm)
img = imread(path_image) / 255

# image size, used below to bound the random patch positions
imsz = img.shape

# patch size
p = 8

# number of elements in the patch
M = p**2

Extract a bunch of random patches from the image and build the training set $S$


In [None]:
# number of training patches
npatch = 10000

# training set: each column holds one vectorized p x p patch
S = np.zeros((M, npatch))
for i in range(npatch):
    # Random top-left corner for patch
    # (the upper bound of randint is exclusive, so the patch always fits inside the image)
    row = np.random.randint(0, imsz[0] - p + 1)
    col = np.random.randint(0, imsz[1] - p + 1)
    patch = img[row : row + p, col : col + p]
    S[:, i] = patch.flatten()

Remove the mean from the patches (each column of $S$ must have zero-mean)


In [None]:
# subtract from every column (patch) its own mean; keepdims keeps a (1, npatch) row that broadcasts over S
S = S - np.mean(S, axis=0, keepdims=True)

Define a function that implements the KSVD


In [None]:
def ksvd(S, M, N, max_iter, npatch, L, print_time=False):
    """Learn a dictionary from the training set ``S`` with K-SVD.

    Each iteration alternates a sparse coding step (OMP on every column of
    ``S`` with the current dictionary) and a dictionary update step in which
    every atom, together with the coefficients that use it, is replaced by
    the best rank-1 approximation of the residual of the signals using it.

    Parameters
    ----------
    S : 2D array of shape (M, npatch)
        Training signals, one per column.
    M : int
        Dimension of the signals (number of rows of ``S`` and of the dictionary).
    N : int
        Number of atoms of the dictionary.
    max_iter : int
        Number of K-SVD iterations.
    npatch : int
        Number of training signals (columns of ``S``).
    L : int
        Maximum number of nonzero coefficients used by OMP.
    print_time : bool, optional
        If True, print the runtime of every iteration.

    Returns
    -------
    D : 2D array of shape (M, N)
        Learned dictionary with unit-norm columns.
    """
    # initialize the dictionary with random atoms
    D = np.random.randn(M, N)

    # normalize each column of D (zero mean and unit norm)
    D = D - np.mean(D, axis=0, keepdims=True)
    D = D / np.linalg.norm(D, axis=0, keepdims=True)

    # initialize the coefficient matrix
    X = np.zeros((N, npatch))

    # Main KSVD loop
    for it in range(max_iter):
        time_start = time.time()

        # Sparse coding step
        # perform the sparse coding via OMP of all the columns of S
        for n in range(npatch):
            X[:, n] = OMP(S[:, n], D, L, 1e-6)

        # Dictionary update step
        # iterate over the columns of D
        for j in range(N):
            # find which signals use the j-th atom in the sparse coding
            omega = np.flatnonzero(X[j, :])

            if omega.size == 0:
                # the atom is never used: substitute it with a random unit-norm vector
                D[:, j] = np.random.randn(M)
                D[:, j] = D[:, j] / np.linalg.norm(D[:, j])
                continue

            # Residual of the signals in omega, ignoring the j-th atom.
            # Only the columns in omega are ever needed, so the residual is
            # computed on those columns alone instead of on the whole training set.
            X_omega = X[:, omega]
            E_omega = S[:, omega] - D @ X_omega + np.outer(D[:, j], X_omega[j, :])

            # Compute the best rank-1 approximation
            U, Sigma, Vt = np.linalg.svd(E_omega, full_matrices=False)

            # update the dictionary (left singular vectors have unit norm)
            D[:, j] = U[:, 0]

            # update the coefficient matrix
            X[j, omega] = Sigma[0] * Vt[0, :]

        time_end = time.time()
        if print_time:
            print(f"Iteration {it} runtime: {time_end - time_start}")

    return D

In [None]:
# number of columns of the dictionary
N = 256

# number of iteration of the KSVD
max_iter = 10

# maximum number of nonzero coefficients for the sparse coding
L = 4


# Call the KSVD implementation
# (S, M and npatch come from the patch-extraction cells above)
D = ksvd(S, M, N, max_iter, npatch, L, print_time=True)

Show the learned dictionary


In [None]:
# Side-by-side figure: training image on the left, learned atoms on the right
fig, ax = plt.subplots(1, 2, figsize=(20, 10))

ax[0].imshow(img, cmap="gray")
# the title shows only the file name, i.e. the last component of the path
ax[0].set_title(f"Image {path_image.split('/')[-1]}")
ax[0].axis("off")

# tile the atoms of D into a single image
img_dict = get_dictionary_img(D)
ax[1].imshow(img_dict, cmap="gray")
ax[1].set_title(f"Dictionary learned from {path_image.split('/')[-1]}")
ax[1].axis("off")

plt.tight_layout()

## OMP denoising with learned dictionaries


In [None]:
# Clean reference image, rescaled by 255 as the other images of the notebook
img_clean = imread(f"{rootfolder}/data/barbara.png") / 255

# Corrupt the image with additive Gaussian noise of standard deviation sigma_noise.
# The noise takes its shape from img_clean itself, not from the global `imsz`,
# which describes the image loaded in the dictionary-learning section and may
# not match if a different training image is selected there.

sigma_noise = 20 / 255
noisy_img = img_clean + np.random.normal(size=img_clean.shape) * sigma_noise

In [None]:
# PSNR (in dB) of the noisy image w.r.t. the clean one; the peak value is 1 since images are divided by 255
psnr_noisy = 10 * np.log10(1 / np.mean((noisy_img - img_clean) ** 2))

In [None]:
# Side-by-side figure: clean image on the left, noisy image (with its PSNR) on the right
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(img_clean, cmap="gray")
ax[0].set_title("Original image (barbara.png)")
ax[0].axis("off")

ax[1].imshow(noisy_img, cmap="gray")
ax[1].set_title(f"Noisy image, PSNR = {psnr_noisy:.2f}")
ax[1].axis("off")

plt.tight_layout()


In [None]:
# Parameters for the dictionaries learned in this section
# (same values as in the dictionary-learning section above, restated here)

# patch size
p = 8

# number of elements in the patch
M = p**2

# number of columns of the dictionary
N = 256

# number of iteration of the KSVD
max_iter = 10

# maximum number of nonzero coefficients for the sparse coding
L = 4

Generic dictionary


In [None]:
# Pre-computed dictionary, read from the "D" entry of the .mat file
D_generic = loadmat(f"{rootfolder}/data/dict_nat_img.mat")["D"]

Dictionary learned from a different image


In [None]:
# Training image different from the one to be denoised
# NOTE: this rebinds the notebook-level `img` (and `S` below)
img = imread(f"{rootfolder}/data/cameraman.png") / 255

# Extract random patches
npatch = 10000

# each column of S holds one vectorized p x p patch
S = np.zeros((M, npatch))
for i in range(npatch):
    # Random top-left corner for patch
    row = np.random.randint(0, img.shape[0] - p + 1)
    col = np.random.randint(0, img.shape[1] - p + 1)
    patch = img[row : row + p, col : col + p]
    S[:, i] = patch.flatten()

# remove from each patch its own mean
S = S - np.mean(S, axis=0, keepdims=True)

# Learn the dictionary
D_diff = ksvd(S, M, N, max_iter, npatch, L, print_time=True)

Dictionary learned from the noisy image


In [None]:
# Extract random patches
npatch = 10000

# each column of S holds one vectorized p x p patch of the noisy image
S = np.zeros((M, npatch))
for i in range(npatch):
    # Random top-left corner for patch
    row = np.random.randint(0, noisy_img.shape[0] - p + 1)
    col = np.random.randint(0, noisy_img.shape[1] - p + 1)
    patch = noisy_img[row : row + p, col : col + p]
    S[:, i] = patch.flatten()

# remove from each patch its own mean
S = S - np.mean(S, axis=0, keepdims=True)

# Learn the dictionary
D_noisy = ksvd(S, M, N, max_iter, npatch, L, print_time=True)

Dictionary learned from the clean image


In [None]:
# Extract random patches
npatch = 10000

# each column of S holds one vectorized p x p patch of the clean image
S = np.zeros((M, npatch))
for i in range(npatch):
    # Random top-left corner for patch
    row = np.random.randint(0, img_clean.shape[0] - p + 1)
    col = np.random.randint(0, img_clean.shape[1] - p + 1)
    patch = img_clean[row : row + p, col : col + p]
    S[:, i] = patch.flatten()

# remove from each patch its own mean
S = S - np.mean(S, axis=0, keepdims=True)

# Learn the dictionary
D_clean = ksvd(S, M, N, max_iter, npatch, L, print_time=True)

OMP denoising


In [None]:
def omp_denoising(noisy_img, D, step, tau, L=10):
    """Denoise an image by sparse-coding its patches with OMP and averaging them.

    Patches are taken on a grid of stride ``step``. Each patch has its mean
    removed, is sparse-coded w.r.t. ``D``, reconstructed, and given its mean
    back; overlapping reconstructions are averaged with uniform weights.

    Parameters
    ----------
    noisy_img : 2D array
        Image to denoise.
    D : 2D array of shape (M, N)
        Dictionary; ``M`` must be the number of pixels of a square patch.
    step : int
        Stride of the patch grid (the patch side gives non-overlapping patches).
    tau : float
        Residual-norm threshold passed to OMP.
    L : int, optional
        Maximum number of nonzero coefficients passed to OMP (default 10).

    Returns
    -------
    img_hat : 2D array with the shape of ``noisy_img``
        Denoised estimate. It is all zeros if the image is smaller than a patch.
    """
    # Get image dimensions and patch size
    imsz = noisy_img.shape
    M, _ = D.shape
    p = int(round(np.sqrt(M)))  # patch size (assuming square patches)

    # Initialize the estimated image and weight matrix
    img_hat = np.zeros(imsz)
    weights = np.zeros(imsz)

    # Top-left corners of the patches. When step does not divide (size - p) the
    # regular grid stops short of the border, which would leave pixels with zero
    # weight (NaN after normalization): add the last admissible position.
    rows = list(range(0, imsz[0] - p + 1, step))
    cols = list(range(0, imsz[1] - p + 1, step))
    if rows and rows[-1] != imsz[0] - p:
        rows.append(imsz[0] - p)
    if cols and cols[-1] != imsz[1] - p:
        cols.append(imsz[1] - p)

    # Operate patchwise
    for i in rows:
        for j in cols:
            # Extract the patch with the top left corner at pixel (i, j);
            # astype returns a float copy, so noisy_img is never modified
            s = noisy_img[i : i + p, j : j + p].astype(float).ravel()

            # Store and subtract the mean
            s_mean = s.mean()
            s -= s_mean

            # Perform the sparse coding
            x = OMP(s, D, L=L, tau=tau)

            # Perform the reconstruction and add back the mean
            s_hat = D @ x + s_mean

            # Put the denoised patch into the estimated image
            img_hat[i : i + p, j : j + p] += s_hat.reshape(p, p)

            # Store the weight of the current patch in the weight matrix
            weights[i : i + p, j : j + p] += 1

    # Normalize the estimated image with the computed weights
    covered = weights > 0
    img_hat[covered] /= weights[covered]

    return img_hat

Denoising using the learned dictionaries


In [None]:
# set the threshold: p * sigma_noise = sqrt(M) * sigma_noise is the root-mean-square
# norm of the noise over a p x p patch, enlarged here by the factor 1.15
tau = 1.15 * p * sigma_noise

# define the step (=p for non overlapping patches)
STEP = 4  # STEP = 1 might be very time consuming, start with larger STEP

Solve the four denoising problems


In [None]:
# The same noisy image, step and threshold are used for all four dictionaries,
# so the results differ only because of the dictionary.

# Denoising with dictionary D_generic
img_hat_generic = omp_denoising(noisy_img, D_generic, STEP, tau)

# Denoising with dictionary D_diff
img_hat_diff = omp_denoising(noisy_img, D_diff, STEP, tau)

# Denoising with dictionary D_noisy
img_hat_noisy = omp_denoising(noisy_img, D_noisy, STEP, tau)

# Denoising with dictionary D_clean
img_hat_clean = omp_denoising(noisy_img, D_clean, STEP, tau)

Visualize the results


In [None]:
# Left panel: atoms of the generic dictionary tiled into one image
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(get_dictionary_img(D_generic), cmap="gray")
ax[0].set_title("Dictionary (Generic)")
ax[0].axis("off")

# PSNR (in dB) of the estimate w.r.t. the clean image, with peak value 1
psnr_hat = 10 * np.log10(1 / np.mean((img_hat_generic - img_clean) ** 2))

# Right panel: image denoised with that dictionary
ax[1].imshow(img_hat_generic, cmap="gray")
ax[1].set_title(f"Denoised image, PSNR = {psnr_hat:.2f}")
ax[1].axis("off")

plt.tight_layout()

In [None]:
# Left panel: atoms of the dictionary learned from a different image
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(get_dictionary_img(D_diff), cmap="gray")
ax[0].set_title("Dictionary (From a different image)")
ax[0].axis("off")

# PSNR (in dB) of the estimate w.r.t. the clean image, with peak value 1
psnr_hat = 10 * np.log10(1 / np.mean((img_hat_diff - img_clean) ** 2))

# Right panel: image denoised with that dictionary
ax[1].imshow(img_hat_diff, cmap="gray")
ax[1].set_title(f"Denoised image, PSNR = {psnr_hat:.2f}")
ax[1].axis("off")

plt.tight_layout()

In [None]:
# Left panel: atoms of the dictionary learned from the noisy image
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(get_dictionary_img(D_noisy), cmap="gray")
ax[0].set_title("Dictionary (From the noisy image)")
ax[0].axis("off")

# PSNR (in dB) of the estimate w.r.t. the clean image, with peak value 1
psnr_hat = 10 * np.log10(1 / np.mean((img_hat_noisy - img_clean) ** 2))

# Right panel: image denoised with that dictionary
ax[1].imshow(img_hat_noisy, cmap="gray")
ax[1].set_title(f"Denoised image, PSNR = {psnr_hat:.2f}")
ax[1].axis("off")

plt.tight_layout()

In [None]:
# Left panel: atoms of the dictionary learned from the clean image
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(get_dictionary_img(D_clean), cmap="gray")
ax[0].set_title("Dictionary (from the clean image)")
ax[0].axis("off")

# PSNR (in dB) of the estimate w.r.t. the clean image, with peak value 1
psnr_hat = 10 * np.log10(1 / np.mean((img_hat_clean - img_clean) ** 2))

# Right panel: image denoised with that dictionary
ax[1].imshow(img_hat_clean, cmap="gray")
ax[1].set_title(f"Denoised image, PSNR = {psnr_hat:.2f}")
ax[1].axis("off")

plt.tight_layout()

- Download a few texture-rich images from the Brodatz dataset
- Use KSVD to learn dictionaries from these images
- Try with different patch sizes


In [None]:
# Load Brodatz texture image
# (dividing by 255 rescales to [0, 1] assuming an 8-bit single-channel file — TODO confirm)
brodatz_img = imread(f"{rootfolder}/data/1.1.03.tiff") / 255
plt.figure(figsize=(8, 8))
plt.imshow(brodatz_img, cmap="gray")
plt.title("Brodatz Texture Image")
plt.axis("off")
plt.show()

In [None]:
# Dictionary learning parameters
N = 256  # dictionary size
max_iter = 10
L = 4  # sparsity level
npatch = 10000

# Try different patch sizes
patch_sizes = [4, 8, 16]
dictionaries = {}  # patch size -> learned dictionary

# The loop uses its own names (patch_size, n_pixels) so that the notebook-level
# `p` and `M` keep the values the denoising cells rely on (e.g. tau depends on p).
for patch_size in patch_sizes:
    print(f"\nLearning dictionary with patch size {patch_size}x{patch_size}")

    n_pixels = patch_size**2  # number of elements in patch

    # Extract random patches
    S = np.zeros((n_pixels, npatch))
    for i in range(npatch):
        row = np.random.randint(0, brodatz_img.shape[0] - patch_size + 1)
        col = np.random.randint(0, brodatz_img.shape[1] - patch_size + 1)
        patch = brodatz_img[row : row + patch_size, col : col + patch_size]
        S[:, i] = patch.flatten()

    # Remove mean from patches
    S = S - np.mean(S, axis=0, keepdims=True)

    # Learn dictionary using K-SVD
    D_brodatz = ksvd(S, n_pixels, N, max_iter, npatch, L, print_time=True)
    dictionaries[patch_size] = D_brodatz

In [None]:
# Visualize learned dictionaries, one panel per patch size.
# squeeze=False always returns a 2D array of axes, so the cell also works when
# patch_sizes holds a single value.
fig, axes = plt.subplots(1, len(patch_sizes), figsize=(20, 6), squeeze=False)

# The loop variable is not called `p`, so the notebook-level patch size is left untouched.
for panel, patch_size in zip(axes[0], patch_sizes):
    dict_img = get_dictionary_img(dictionaries[patch_size])
    panel.imshow(dict_img, cmap="gray")
    panel.set_title(
        f"Dictionary learned from Brodatz texture\nPatch size: {patch_size}x{patch_size}"
    )
    panel.axis("off")

plt.tight_layout()
plt.show()

- Use KSVD to learn the dictionary $D$ from the clean image
- Use this image-specific dictionary to perform inpainting


In [None]:
# Load clean image for inpainting
img_inpaint = imread(f"{rootfolder}/data/barbara.png") / 255

# Create a mask for inpainting (simulate missing pixels)
# 1 = known pixel, 0 = missing pixel
mask = np.ones_like(img_inpaint)
np.random.seed(42)  # for reproducibility
missing_ratio = 0.3  # 30% missing pixels
# draw distinct flat indices (replace=False) of the pixels to remove
missing_indices = np.random.choice(
    img_inpaint.size, int(missing_ratio * img_inpaint.size), replace=False
)
# flatten() returns a copy: zero the selected entries there, then restore the image shape
mask_flat = mask.flatten()
mask_flat[missing_indices] = 0
mask = mask_flat.reshape(img_inpaint.shape)
# Create damaged image
# (missing pixels are set to 0 by the multiplication)
damaged_img = img_inpaint * mask

In [None]:
# Three panels: reference image, damaged image, and the mask that produced it
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

axes[0].imshow(img_inpaint, cmap="gray")
axes[0].set_title("Original Image")
axes[0].axis("off")

axes[1].imshow(damaged_img, cmap="gray")
axes[1].set_title("Damaged Image (30% missing pixels)")
axes[1].axis("off")

axes[2].imshow(mask, cmap="gray")
axes[2].set_title("Mask (white = known, black = missing)")
axes[2].axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Reuse the dictionary already learned with KSVD from the clean image (D_clean);
# the copy leaves D_clean itself untouched
D_inpaint = D_clean.copy()
# Prepare dictionary for inpainting (add DC component)
# NOTE: this rebinds the notebook-level M and N to the shape of D_clean
M, N = D_inpaint.shape
# constant atom with unit norm; omp_inpainting does not remove the patch mean,
# so this atom is what lets the mean of a patch be represented
dc = np.ones((M, 1)) / np.sqrt(M)
D_inpaint = np.hstack([D_inpaint, dc])
# rescale every column to unit l2 norm
D_inpaint = D_inpaint / np.linalg.norm(D_inpaint, axis=0)

In [None]:
def omp_inpainting(damaged_img, mask, D, step, sigma_noise=0.02):
    """Fill the missing pixels of an image by patch-wise sparse coding with OMP.

    For every patch the sparse code is computed using only the known pixels
    (the rows of ``D`` and the entries of the patch at missing pixels are
    zeroed), then the whole patch is synthesized with the full dictionary.
    Overlapping reconstructions are averaged with uniform weights. The patch
    mean is not removed, so ``D`` is expected to be able to represent it.

    Parameters
    ----------
    damaged_img : 2D array
        Image with missing pixels.
    mask : 2D array with the shape of ``damaged_img``
        Nonzero where the pixel is known, zero where it is missing.
    D : 2D array of shape (M, N)
        Dictionary; ``M`` must be the number of pixels of a square patch.
    step : int
        Stride of the patch grid.
    sigma_noise : float, optional
        Noise level used to set the OMP residual threshold of each patch.

    Returns
    -------
    img_hat : 2D array with the shape of ``damaged_img``
        Reconstructed image. Pixels covered by no processed patch are zero.
    """
    imsz = damaged_img.shape
    M, _ = D.shape
    p = int(round(np.sqrt(M)))  # patch size

    # SET stopping criteria of OMP
    L = M // 2  # maximum sparsity level

    # Initialize the estimated image and weight matrix (float accumulators,
    # whatever the dtype of the input image)
    img_hat = np.zeros(imsz)
    weights = np.zeros(imsz)

    # Top-left corners of the patches. When step does not divide (size - p) the
    # regular grid stops short of the border and the last rows/columns would
    # never be reconstructed: add the last admissible position.
    rows = list(range(0, imsz[0] - p + 1, step))
    cols = list(range(0, imsz[1] - p + 1, step))
    if rows and rows[-1] != imsz[0] - p:
        rows.append(imsz[0] - p)
    if cols and cols[-1] != imsz[1] - p:
        cols.append(imsz[1] - p)

    # Operate patchwise
    for i in rows:
        for j in cols:
            # Extract the patch with the top left corner at pixel (i, j)
            s = damaged_img[i : i + p, j : j + p].ravel()

            # Patch extracted from the mask
            m = mask[i : i + p, j : j + p].ravel()

            # Only process the patch if at least one of its pixels is known
            known_pixels = np.count_nonzero(m)
            if known_pixels == 0:
                continue

            # Tau should be proportional to the number of pixels remaining in the patch
            delta_i = 1.15 * p * sigma_noise * np.sqrt(known_pixels / p**2)

            # Sparse coding w.r.t. PD (projected dictionary). The projection is
            # diagonal, so it is applied as a row-wise scaling instead of
            # building a p^2 x p^2 matrix for every patch.
            PD = m[:, None] * D
            x = OMP(m * s, PD, L, delta_i)

            # Reconstruction: synthesis w.r.t. D yielding sparse representation
            s_hat = D @ x

            # Use uniform weights for aggregation
            w = 1

            # Put the reconstructed patch into the estimated image
            img_hat[i : i + p, j : j + p] += w * s_hat.reshape(p, p)

            # Store the weight of the current patch in the weight matrix
            weights[i : i + p, j : j + p] += w

    # Normalize the estimated image with the computed weights
    # Avoid division by zero
    weights[weights == 0] = 1
    img_hat = img_hat / weights

    return img_hat

In [None]:
# Perform inpainting
STEP_INPAINT = 2  # Step size for patch processing
# sigma_noise is left at the default of omp_inpainting
img_inpainted = omp_inpainting(damaged_img, mask, D_inpaint, STEP_INPAINT)

# Calculate PSNR
# (both w.r.t. the reference image img_inpaint, with peak value 1)
psnr_damaged = 10 * np.log10(1 / np.mean((damaged_img - img_inpaint) ** 2))
psnr_inpainted = 10 * np.log10(1 / np.mean((img_inpainted - img_inpaint) ** 2))

# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

axes[0].imshow(img_inpaint, cmap="gray")
axes[0].set_title("Original Image")
axes[0].axis("off")

axes[1].imshow(damaged_img, cmap="gray")
axes[1].set_title(f"Damaged Image\nPSNR = {psnr_damaged:.2f} dB")
axes[1].axis("off")

axes[2].imshow(img_inpainted, cmap="gray")
axes[2].set_title(f"Inpainted Image\nPSNR = {psnr_inpainted:.2f} dB")
axes[2].axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Absolute error map of the inpainted image. The reference is img_inpaint, the
# image the damaged input was built from and the PSNR above is computed against,
# rather than a variable belonging to the denoising section.
plt.imshow(np.abs(img_inpainted - img_inpaint), cmap="hot")
plt.colorbar()
plt.title("Absolute Error After Inpainting")
plt.show()