In [None]:
import numpy as np
from scipy.linalg import svd
from omp import orthogonal_matching_pursuit_cholensky, orthogonal_matching_pursuit, orthogonal_matching_pursuit_cholensky_batches
import cv2
from utils import generate_data, extract_blocks, compute_metrics, normalize_image
from omp import unsparse, orthogonal_matching_pursuit_cholensky
import random
from sklearn.feature_extraction.image import extract_patches_2d, reconstruct_from_patches_2d
from tqdm import tqdm
import scipy
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.linear_model import OrthogonalMatchingPursuit

In [None]:
def dct2(a):
    """Orthonormal 2-D DCT-II of ``a``: 1-D DCT along axis 0, then along axis 1."""
    out = a
    for axis in (0, 1):
        out = scipy.fftpack.dct(out, axis=axis, norm='ortho')
    return out

def idct2(a):
    """Orthonormal 2-D inverse DCT of ``a`` (inverse of ``dct2``): axis 0 first, then axis 1."""
    result = a
    for axis in (0, 1):
        result = scipy.fftpack.idct(result, axis=axis, norm='ortho')
    return result

def learn_overcomplete_dictionary(image, block_size=(8, 8), num_atoms=128):
    """Build an initial dictionary from randomly chosen blocks of ``image``.

    Despite the name, nothing is learned here: each atom is one block returned
    by ``extract_blocks``, flattened into a column of raw pixel values (the
    DCT-based variant is no longer used). The result is meant as the starting
    dictionary for ``ksvd``.

    Parameters
    ----------
    image : array-like
        Source image passed to ``extract_blocks``.
    block_size : tuple of int
        Block height and width; each atom has ``block_size[0] * block_size[1]``
        entries.
    num_atoms : int
        Number of dictionary columns to produce.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(block_size[0] * block_size[1], num_atoms)``.
    """
    blocks = extract_blocks(image, block_size)
    n_blocks = len(blocks)

    # Draw distinct blocks whenever there are enough of them. Independent
    # draws with replacement make duplicate atoms very likely (birthday
    # problem), and a duplicated atom is wasted dictionary capacity.
    indices = np.random.choice(n_blocks, size=num_atoms, replace=num_atoms > n_blocks)

    D_overcomplete = np.zeros((block_size[0] * block_size[1], num_atoms))
    for col, block_idx in enumerate(indices):
        D_overcomplete[:, col] = np.asarray(blocks[block_idx]).flatten()

    return D_overcomplete

# Load the noisy input image and convert it to a grayscale ('L') array.
im = Image.open("Figures/noisy_image0_noise_lvl25.jpg")
im_gray = np.array(im.convert('L'))

# Patches are 8x8 (n = 64 pixels per signal). The number of training patches
# (K) and of dictionary atoms actually used are set in the cells below.


In [None]:
fig, ax = plt.subplots()
ax.imshow(im_gray, cmap='gray')  # grayscale colormap for the single-channel image
ax.axis('off')
plt.show()

In [None]:
# Training data Y for K-SVD, drawn from the noisy image.
from utils import generate_data  # already imported in the first cell

block_size = (8, 8)  # patch size
K = 4000  # passed to generate_data; presumably the number of training patches, confirm in utils
sparsity_level = 8  # NOTE(review): not referenced below (ksvd is called with sparsity=16)
training_data = generate_data(im_gray, K, block_size)

# Mean intensity of the training set; the DC-offset subtraction is currently disabled.
dc_offset = np.mean(training_data)
training_data.shape

In [None]:
# Sparse-coding settings and a first initial dictionary. Any pursuit algorithm
# can compute the representation vectors x_i for the training examples y_i.
from sklearn.linear_model import OrthogonalMatchingPursuit  # already imported in the first cell

T0 = 10  # target sparsity; NOTE(review): not referenced elsewhere in this section
# Dictionary of 128 random 8x8 blocks; recreated again right before ksvd is run.
D = learn_overcomplete_dictionary(im_gray, block_size=(8, 8), num_atoms=128)
training_data.shape

In [None]:
def ksvd(Y, sparsity, initial_D,
    maxiter=10, etol=1e-10):
    """Learn a dictionary for ``Y`` with approximate K-SVD.

    Alternates two stages:
      * sparse coding: batched Cholesky OMP (local ``omp`` module) computes
        codes X with the dictionary fixed;
      * codebook update: each atom and its nonzero coefficients are refined
        with one power-iteration step on the residual that excludes that
        atom (the "approximate K-SVD" update), instead of a full SVD.

    Parameters
    ----------
    Y : numpy.ndarray
        Training signals stored as COLUMNS, shape ``(n_features, n_signals)``.
    sparsity : int
        Maximum number of nonzero coefficients per signal (``K`` for OMP).
        ``sparsity=1`` gives a K-means-like clustering.
    initial_D : numpy.ndarray
        Initial dictionary, shape ``(n_features, n_atoms)``. Its columns are
        scaled to unit l2 norm in a new array; the caller's array is not
        modified. All-zero columns are left at zero (and are re-seeded with
        noise on the first codebook update, since no signal can use them).
    maxiter : int
        Maximum number of outer iterations.
    etol : float
        Stop early once the Frobenius norm of ``Y - D @ X`` drops below this.

    Returns
    -------
    D : numpy.ndarray
        Learned dictionary with unit-norm columns.
    X : numpy.ndarray
        Sparse codes, shape ``(n_atoms, n_signals)``, with ``Y ≈ D @ X``.
    """
    # Unit-normalize atoms, guarding against all-zero columns (e.g. a black
    # image block), which would otherwise become NaN and poison OMP.
    norms = np.linalg.norm(initial_D, axis=0)
    D = initial_D / np.where(norms == 0, 1.0, norms)
    n_atoms = D.shape[1]

    # repeat until convergence or stopping criteria
    for _ in tqdm(range(1, maxiter + 1)):
        # Sparse coding stage: estimate the columns of X with D fixed.
        x, idx = orthogonal_matching_pursuit_cholensky_batches(D, Y, K=sparsity)
        # Expand the compact (values, indices) OMP output into a dense
        # (n_atoms, n_signals) code matrix.
        # NOTE(review): apply_along_axis passes the *entire* idx array to every
        # unsparse call, not the column matching each slice of x; confirm this
        # is what omp.unsparse expects for batched output.
        X = np.apply_along_axis(unsparse, 0, x, idx, n_atoms)

        # Codebook update stage: refine each atom and its coefficients in turn.
        for j in range(n_atoms):
            # Signals whose code uses atom j.
            index_set = np.nonzero(X[j, :])[0]
            if len(index_set) == 0:
                # Unused atom: replace it with unit-norm white noise.
                D[:, j] = np.random.randn(*D[:, j].shape)
                D[:, j] = D[:, j] / np.linalg.norm(D[:, j])
                continue

            x_j = X[j, index_set]
            # Residual of those signals WITHOUT atom j's own contribution.
            # Omitting the outer-product term leaves the full residual, for which
            # E @ x_j is merely a gradient direction, not the K-SVD atom update.
            E = Y[:, index_set] - D @ X[:, index_set] + np.outer(D[:, j], x_j)
            atom = E @ x_j
            atom_norm = np.linalg.norm(atom)
            if atom_norm == 0:
                # Degenerate residual: keep the current atom and coefficients.
                continue
            D[:, j] = atom / atom_norm
            X[j, index_set] = E.T @ D[:, j]

        # Stopping condition: absolute reconstruction error.
        err = np.linalg.norm(Y - D @ X, 'fro')
        if err < etol:
            break

    return D, X

In [None]:
# Fresh initial dictionary of 128 random 8x8 blocks, then K-SVD on the training
# patches. ksvd indexes signals as columns, hence the transpose.
# Earlier experiments noted block size (8, 8), num_atoms=256 and sparsity=8 as best.
D = learn_overcomplete_dictionary(im_gray, block_size=(8, 8), num_atoms=128)
D_learned, X = ksvd(training_data.T, initial_D=D, sparsity=16)

print(D_learned.shape)


In [None]:
# All 8x8 patches of the noisy image, one flattened patch per row, standardized
# per pixel position (zero mean, unit std across patches).
# NOTE(review): this per-position mean/std is never added back after denoising,
# so the reconstruction stays in standardized units until it is min-max rescaled.
print("Extracting reference patches...")
patch_size = (8, 8)
patches = extract_patches_2d(im_gray, patch_size)
data = patches.astype(np.float64).reshape(patches.shape[0], -1)
data = data - np.mean(data, axis=0)
data = data / np.std(data, axis=0)

In [None]:
# Sparse-code every standardized patch over the learned dictionary with
# scikit-learn's OMP (4 atoms per patch) and rebuild it from that code.
denoised_patches = []
# One estimator suffices: each fit() call overwrites coef_ and intercept_.
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=4)

for i, reference_patch in enumerate(data):
    if i % 1000 == 0:
        print(f"Processing reference patch {i + 1}/{len(data)}...")

    omp.fit(D_learned, reference_patch)
    gamma = omp.coef_

    # With the default fit_intercept=True, sklearn fits the centered atoms to the
    # centered patch, so the reconstruction is D @ coef_ + intercept_ (what
    # omp.predict(D_learned) returns). Adding np.mean(reference_patch) instead
    # double-counts the atoms' own DC component.
    denoised_patch = D_learned @ gamma.flatten() + omp.intercept_

    denoised_patches.append(denoised_patch.reshape(patch_size))

    

In [None]:
# Average the overlapping denoised patches back into a full-size image.
denoised_patches = np.array(denoised_patches)
denoised_image_omp = reconstruct_from_patches_2d(denoised_patches, im_gray.shape)

fig, ax = plt.subplots()
ax.imshow(denoised_image_omp, cmap='gray')
ax.set_title('Denoised Image')
ax.axis('off')
plt.show()

In [None]:
def normalize_image(image, max_value=255.0):
    """
    Linearly rescale an image so its values span [0, max_value].

    The darkest pixel maps to 0 and the brightest to ``max_value`` (255 by
    default, the 8-bit display range).

    Parameters:
        image (numpy.ndarray): Input image array of any numeric dtype.
        max_value (float): Value assigned to the brightest pixel. Defaults to 255.

    Returns:
        numpy.ndarray: float64 array of the same shape with values in
        [0, max_value]. A constant image (max == min) maps to all zeros
        instead of the NaNs a 0/0 division would produce.
    """
    # Compute in float64 so signed-integer inputs cannot overflow on the
    # subtraction (e.g. int8 [-100, 100] -> 200).
    img = np.asarray(image, dtype=np.float64)
    lo, hi = img.min(), img.max()
    if hi == lo:
        return np.zeros_like(img)
    return (img - lo) / (hi - lo) * max_value

In [None]:
def show_with_diff(image, reference):
    """Plot ``image`` next to its pixel-wise difference from ``reference``.

    The right panel's title reports the l2 norm of the difference divided by
    the pixel count (height * width of ``image``).

    Parameters:
        image (numpy.ndarray): 2-D image to display (e.g. the denoised result).
        reference (numpy.ndarray): Image of the same shape to compare against.
    """
    # Work in floating point: subtracting or squaring uint8 arrays wraps around.
    image = np.asarray(image, dtype=np.float64)
    difference = image - np.asarray(reference, dtype=np.float64)
    norm = np.sqrt(np.sum(difference ** 2)) / image.shape[0] / image.shape[1]

    plt.figure(figsize=(5, 3.3))
    plt.subplot(1, 2, 1)
    plt.title("Image")
    plt.imshow(image, cmap='gray')
    plt.subplot(1, 2, 2)
    plt.title(f"Difference (norm/pixel): {np.round(norm, 3)}")
    plt.imshow(difference, cmap='gray')
    
# Clean reference image, converted to grayscale and rescaled like the result.
# NOTE(review): a later cell uses "Data/example_image_original.jpg" as ground
# truth instead; confirm which file matches the noisy input.
im_clean = np.array(Image.open("Figures/image2.jpg").convert('L'))
denoised_image = normalize_image(denoised_image_omp)
im_clean = normalize_image(im_clean)

show_with_diff(denoised_image, im_gray)
rmse, psnr = compute_metrics(im_clean, denoised_image)

print(f"RMSE: {rmse:.4f}")
print(f"PSNR: {psnr:.4f}")

In [None]:
# Same patch-wise denoising, this time with the local Cholesky-based OMP.
denoised_patches = []
from omp import orthogonal_matching_pursuit_cholensky, unsparse

n_atoms = D_learned.shape[1]
for i, reference_patch in enumerate(data):
    if i % 1000 == 0:
        print(f"Processing reference patch {i + 1}/{len(data)}...")

    # 4-sparse code of the patch, expanded to a dense length-n_atoms vector.
    coeffs, support = orthogonal_matching_pursuit_cholensky(D_learned, reference_patch, K=4)
    gamma = unsparse(coeffs, support, n_atoms)

    denoised_patch = D_learned @ gamma.flatten()
    # NOTE(review): if this OMP fits the raw (uncentered) patch, D @ gamma already
    # carries the patch mean and adding it again double-counts the DC offset;
    # confirm against the omp module.
    denoised_patch += np.mean(reference_patch)

    denoised_patches.append(denoised_patch.reshape(patch_size))

In [None]:
# Average the overlapping Cholesky-OMP patches back into a full-size image.
denoised_patches = np.array(denoised_patches)
denoised_image_chol = reconstruct_from_patches_2d(denoised_patches, im_gray.shape)

fig, ax = plt.subplots()
ax.imshow(denoised_image_chol, cmap='gray')
ax.set_title('Denoised Image')
ax.axis('off')
plt.show()

In [None]:
# NOTE: this import replaces the show_with_diff defined earlier in the notebook.
from utils import show_with_diff

# Min-max rescale both images before comparing them.
im_gray = normalize_image(im_gray)
denoised_image_chol = normalize_image(denoised_image_chol)
show_with_diff(denoised_image_chol, im_gray)

In [None]:
# Ground-truth (clean) image, converted to grayscale.
# NOTE(review): ground_truth is not min-max rescaled while the denoised image
# is; confirm compute_metrics expects that.
ground_truth = np.array(Image.open("Data/example_image_original.jpg").convert('L'))

# Re-apply the min-max rescaling (effectively a no-op if the previous cell ran).
im_gray = normalize_image(im_gray)
denoised_image_chol = normalize_image(denoised_image_chol)

rmse, psnr = compute_metrics(ground_truth, denoised_image_chol)

print(f"RMSE: {rmse:.4f}")
print(f"PSNR: {psnr:.4f}")

In [None]:
# Save the Cholesky-OMP result as an 8-bit grayscale ('L') JPEG.
im = Image.fromarray(denoised_image_chol).convert("L")
im.save("denoised_image_chol.jpg")

In [None]:
# Optionally blend the denoised result with the noisy input: lamb = 0 keeps
# the denoised image unchanged, larger values pull it toward the noisy image.
lamb = 0.00  # weight of the noisy image, in [0, 1]

# Convex combination. The previous "/ (1 - lamb)" only agreed with this at
# lamb == 0; otherwise it over-scaled the result (x2 at lamb = 0.5) and
# divided by zero at lamb = 1.
reconstructed = lamb * im_gray + (1 - lamb) * denoised_image

# Display the blended result (identical to denoised_image when lamb == 0).
fig, ax = plt.subplots()
ax.imshow(reconstructed, cmap='gray')
ax.set_title('Denoised Image')
ax.axis('off')
plt.show()
