In [None]:
def wnnm(img, patchRadius, delta, c, K, sigma_n,N_threshold):
    """Denoise an image with weighted nuclear norm minimization (WNNM).

    For every pixel, the patch anchored there is compared with every patch in
    the surrounding search window. The ``N_threshold`` most similar patches are
    stacked as the columns of a matrix whose singular values are then shrunk
    with weights inversely proportional to their estimated noise-free size.
    The denoised reference patches are accumulated on top of the previous
    estimate and averaged, and the whole pass is repeated ``K`` times.

    Parameters
    ----------
    img : numpy.ndarray
        Noisy image of shape ``(H, W, 1)``. The patch reshapes below only work
        with a single trailing channel. Denoised patches are clipped to
        ``[0, 1]`` before aggregation.
    patchRadius : int
        Patches are ``2*patchRadius`` pixels wide; the search window radius is
        ``3*patchRadius``.
    delta : float
        Fraction of the residual (noisy image minus current estimate) that is
        added back to the estimate before each pass.
    c : float
        Scale factor of the singular-value weights.
    K : int
        Number of outer denoising passes. With ``K == 0`` ``img`` is returned
        unchanged.
    sigma_n : float
        Noise standard deviation used to estimate the noise-free singular
        values.
    N_threshold : int
        Number of most similar patches grouped with each reference patch.

    Returns
    -------
    numpy.ndarray
        Denoised image with the same shape as ``img``.

    Notes
    -----
    Relies on a module-level ``svd`` callable returning ``(U, S, V_T)`` and
    accepting ``full_matrices`` (presumably ``numpy.linalg.svd`` or
    ``scipy.linalg.svd`` -- confirm against the file's imports).
    """
    # Number of fixed-point iterations used to estimate the weights and the
    # singular values of the denoised patch matrix
    N_iter = 3

    # Geometry of patches and of the search window
    searchWindowRadius = patchRadius*3
    patchWidth = 2*patchRadius
    patchLen = patchWidth**2
    nShifts = 2*searchWindowRadius + 1

    # np.pad pads every axis, so the padding added to the channel axis is
    # stripped again right away
    pad = searchWindowRadius + patchRadius
    imgPad = np.pad(img, pad_width = pad)
    imgPad = imgPad[..., pad:-pad]

    # Running estimate of the denoised image (np.pad below copies it, so the
    # caller's array is never written to)
    xhat_iter = img

    for n in range(K):
        # Pad the current estimate the same way as the noisy image
        xhat_iter = np.pad(xhat_iter, pad_width = pad)
        xhat_iter = xhat_iter[..., pad:-pad]

        # Iterative regularization: feed part of the residual back in
        y_iter = xhat_iter + delta*(imgPad - xhat_iter)

        # Starts at one because xhat_iter already holds the previous estimate,
        # which counts as one contribution to every pixel
        pixel_contribution_matrix = np.ones_like(imgPad)

        for j in range(img.shape[0]):
            rows = slice(j+searchWindowRadius, j+searchWindowRadius+patchWidth)
            for i in range(img.shape[1]):
                cols = slice(i+searchWindowRadius, i+searchWindowRadius+patchWidth)

                # Reference patch, flattened to a vector
                centerPatch = y_iter[rows, cols, :].reshape(patchLen)

                # Distances to, and contents of, every patch in the search window
                dists = np.ones(nShifts**2)
                patches = np.zeros((nShifts**2, patchLen))

                # Partially vectorized: for each horizontal shift take a
                # patch-wide vertical strip of the search window and extract
                # all vertically shifted patches from it with one fancy index
                for k in range(nShifts):
                    strip = y_iter[j:j+2*pad, i+k:i+k+patchWidth, :]

                    # Row r of the indexer addresses the patch whose top row is
                    # row r of the (flattened) strip
                    indexer = (np.arange(patchLen)[None, :]
                               + patchWidth*np.arange(strip.shape[0]-patchWidth+1)[:, None])
                    stripPatches = strip.flatten()[indexer]

                    block = slice(k*nShifts, (k+1)*nShifts)
                    dists[block] = np.sum((centerPatch-stripPatches)**2, axis=1)/patchLen
                    patches[block, :] = stripPatches

                # Keep the N_threshold nearest patches as columns of Yj
                indcs = np.argsort(dists)
                Yj = (patches[indcs[:N_threshold], :]).transpose()

                # Center the columns. This must be the mean (not the sum),
                # otherwise a constant offset is left in the matrix and gets
                # shrunk together with the signal.
                Yj_means = np.mean(Yj, axis=0)
                Yj_center = Yj - Yj_means

                # Initial estimate of the singular values of the clean matrix
                U,S,V_T = svd(Yj_center, full_matrices=False)
                sing_val = np.sqrt(np.maximum(S**2-N_threshold*sigma_n**2, 0))

                # Alternate between weights and shrunk singular values; both
                # stay 1-D vectors throughout
                for m in range(N_iter):
                    w = c*np.sqrt(N_threshold)/(sing_val+10**(-6))
                    sing_val = np.maximum(S-w, 0)

                # Only the first column of U diag(sing_val) V_T (the denoised
                # reference patch) is used, so only that column is formed
                center_hat = (U*sing_val)@V_T[:, 0] + Yj_means[0]

                # Accumulate the denoised patch, clipped to [0, 1]
                xhat_iter[rows, cols, :] = (xhat_iter[rows, cols, :]
                                            + np.clip(center_hat.reshape((patchWidth, patchWidth, 1)), 0, 1))

                # Keep track of how many times each pixel has been added to
                pixel_contribution_matrix[rows, cols, :] += 1

        # Remove the padding and average the contributions to each pixel
        xhat_iter = xhat_iter[pad:-pad, pad:-pad, :]/pixel_contribution_matrix[pad:-pad, pad:-pad, :]

    # Produce the final output
    out = xhat_iter
    return out
                    
        

In [None]:
import numpy as np
def WNNM(img, patchSize, delta, c, K, sigma_n, nThreshold):
    """Denoise a grayscale image with weighted nuclear norm minimization.

    For every pixel, the patch centered there is grouped with its
    ``nThreshold`` most similar patches from the surrounding search window.
    The mean patch is removed, the singular values of the group matrix are
    shrunk with weights inversely proportional to their estimated noise-free
    size, and the denoised reference patch is accumulated on top of the
    previous estimate. Overlapping contributions are averaged, and the pass is
    repeated ``K`` times.

    Parameters
    ----------
    img : array_like
        Noisy 2-D image. It is converted to a float array and not modified.
    patchSize : int
        Patch radius; patches are ``2*patchSize + 1`` pixels wide. The search
        window radius is ``5*patchSize``.
    delta : float
        Fraction of the residual (noisy image minus current estimate) added
        back to the estimate before each pass.
    c : float
        Scale factor of the singular-value weights.
    K : int
        Number of denoising passes. With ``K == 0`` the input is returned as a
        float array.
    sigma_n : float
        Noise standard deviation, in the same units as ``img``.
    nThreshold : int
        Number of similar patches per group (capped by the number of patches
        available in the search window).

    Returns
    -------
    numpy.ndarray
        Denoised image with the same shape as ``img``.

    Raises
    ------
    ValueError
        If ``img`` is not 2-D or ``nThreshold`` is smaller than 1.
    """
    img = np.asarray(img, dtype=float)
    if img.ndim != 2:
        raise ValueError("WNNM expects a 2-D grayscale image, got shape %s" % (img.shape,))
    if nThreshold < 1:
        raise ValueError("nThreshold must be at least 1")

    height, width = img.shape
    windowSize = 5 * patchSize          # search window radius
    pad = windowSize                    # keeps every search window in bounds
    patchWidth = 2 * patchSize + 1
    patchLen = patchWidth ** 2
    # Number of candidate patch positions along each axis of a search window
    span = 2 * (windowSize - patchSize) + 1
    # Flat index of the reference patch inside its own search window
    centerIdx = (span * span) // 2

    paddedImg = np.pad(img, pad_width=pad)
    X_hat = img
    for _ in range(K):
        X_pad = np.pad(X_hat, pad_width=pad)

        # Iterative regularization: feed part of the residual back in
        Y_hat = X_pad + delta * (paddedImg - X_pad)

        # Starts at one because X_pad already holds the previous estimate,
        # which counts as one contribution to every pixel
        counting_update_numbers = np.ones_like(paddedImg)

        # allPatches[a, b] is the patch of Y_hat whose top-left corner is (a, b)
        allPatches = np.lib.stride_tricks.sliding_window_view(
            Y_hat, (patchWidth, patchWidth))

        for i in range(pad, pad + height):
            top = i - windowSize
            rows = slice(i - patchSize, i + patchSize + 1)
            for j in range(pad, pad + width):
                left = j - windowSize
                cols = slice(j - patchSize, j + patchSize + 1)

                # Every patch that fits inside the search window, one per row
                candidates = allPatches[top:top + span, left:left + span].reshape(-1, patchLen)

                # Mean squared difference to the reference patch
                dists = np.mean((candidates - candidates[centerIdx]) ** 2, axis=1)
                # Force the reference patch itself into the first column
                dists[centerIdx] = -1.0
                nearest = np.argsort(dists, kind="stable")[:nThreshold]
                patches_matrix = candidates[nearest].T
                num_cols = patches_matrix.shape[1]

                # Remove the mean patch so only the variation is shrunk
                meanPatch = patches_matrix.mean(axis=1, keepdims=True)

                # Step 1: singular value decomposition of the centered group
                U, Sigma, Vt = np.linalg.svd(patches_matrix - meanPatch, full_matrices=False)

                # Step 2: weights from the estimated noise-free singular values;
                # large singular values get small weights
                Sigma_X = np.sqrt(np.maximum(Sigma ** 2 - num_cols * sigma_n ** 2, 0))
                weights = c * np.sqrt(num_cols) / (Sigma_X + 1e-16)

                # Step 3: weighted soft-thresholding of the singular values
                Sigma_hat = np.maximum(Sigma - weights, 0)

                # Step 4: rebuild only the first column, i.e. the reference patch
                reconstructed_patch = (U * Sigma_hat) @ Vt[:, 0] + meanPatch[:, 0]

                # Step 5: accumulate the estimate and count the contribution
                X_pad[rows, cols] += reconstructed_patch.reshape(patchWidth, patchWidth)
                counting_update_numbers[rows, cols] += 1

        # Step 6: average overlapping contributions, then drop the padding
        # (explicit end indices so pad == 0 does not produce an empty slice)
        X_hat = (X_pad / counting_update_numbers)[pad:pad + height, pad:pad + width]

    return X_hat

In [None]:
def WNNM(img,patchSize,c,k,delta,sigma,nThreshold):
    windowSize=3*patchSize
    paddedImg=np.pad(img,pad_width=windowSize,mode='constant')
    numberOfIt=3
    x_hat=paddedImg
    for n in range(K):
        y_hat=x_hat+delta*(paddedImg - x_hat)
        pixel_contribution_matrix = np.ones_like(paddedImg)
        for i in range(windowSize, img.shape[0]-windowSize):
            for j in range(windowSize,img.shape[1]-windowSize):
                window=y_hat[i - windowSize:i+windowSize + 1,j-windowSize : j+windowSize+1]
                mainPatch=y_hat[i -patchSize:i+patchSize + 1,j-patchSize : j+patchSize+1]
                distances=[]
                for k in range(patchSize,windowSize-patchSize):
                    distance=[]
                    for l in range(patchSize,windowSize-patchSize):
                        similarPatch=window[k-patchSize:k+patchSize+1,l-patchSize:l+patchSize+1]
                        dis=np.sum((mainPatch-similarPatch)**2)
                        distance.append(dis)
                    distances.append(distance)