In [1]:
from skimage.io import imread
import numpy as np
from scipy.io import loadmat
from matplotlib import pyplot as plt
import time
import lib

In [2]:
# Base folder of the project; the image is later read from f'{rootfolder}/data/...'.
rootfolder = '..'

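Utility function for plotting a dictionary (e.g. the 2D DCT dictionary) as an image: each atom is reshaped to a $p \times p$ tile and the tiles are arranged on a square grid.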

In [3]:
def get_dictionary_img(D):
    """Arrange the atoms (columns) of a dictionary into a single mosaic image.

    Each column of ``D`` is reshaped into a square ``side x side`` tile, where
    ``side = round(sqrt(D.shape[0]))``, and rescaled to [0, 1] unless the atom
    is constant. Tiles are laid out on a ``grid x grid`` lattice with
    ``grid = ceil(sqrt(D.shape[1]))``, separated by a 2-pixel border of ones.
    Consecutive atoms fill the lattice going down a column of tiles first.

    Parameters
    ----------
    D : ndarray of shape (M, N)
        Dictionary whose N columns are atoms of M entries each.

    Returns
    -------
    ndarray
        Square 2-D array with side ``grid * side + 2 * (grid - 1)``; cells
        that are not covered by any atom keep the value 1.
    """
    n_pixels, n_atoms = D.shape
    side = int(round(np.sqrt(n_pixels)))
    grid = int(np.ceil(np.sqrt(n_atoms)))
    gap = 2

    # distance between the top-left corners of two neighbouring tiles
    stride = side + gap
    canvas_side = grid * side + gap * (grid - 1)
    canvas = np.ones((canvas_side, canvas_side))

    for idx in range(n_atoms):
        # the remainder selects the tile row, the quotient the tile column
        tile_col, tile_row = divmod(idx, grid)
        top = tile_row * stride
        left = tile_col * stride

        tile = D[:, idx].reshape((side, side))
        lo, hi = tile.min(), tile.max()
        if lo < hi:
            # stretch the atom to the full [0, 1] range; constant atoms are left as-is
            tile = (tile - lo) / (hi - lo)

        canvas[top: top + side, left: left + side] = tile

    return canvas

Define a function that implements Orthogonal Matching Pursuit (OMP)

In [4]:
# NOTE: this OMP stub is left commented out; the sparse-coding step of the
# KSVD loop further down in this notebook calls lib.OMP instead.
# def OMP(s, D, L, tau):
    
#     return x

Load the image and rescale it to $[0,1]$

In [5]:
# Test image, divided by 255. Alternatives stored in the same data folder:
#   cameraman.png, Lena512.png
img = imread(f'{rootfolder}/data/barbara.png') / 255
imsz = img.shape

# side length of the square patches, and number of pixels in each patch
p = 8
M = p * p

Extract a bunch of random patches from the image

In [6]:
# number of training patches to draw
npatch = 10000

# All patches of the image, one per column (the (64, npatch) shape of S shown
# below confirms the column layout; the total number of columns depends on
# lib.img_to_patches -- TODO confirm whether patches overlap).
patches_all = lib.img_to_patches(img, p)

# Draw the column indices over the WHOLE patch set. Using the image width
# (imsz[1]) as the range would restrict the sample to the first imsz[1]
# columns of the patch matrix instead of covering the entire image.
# Sampling is with replacement (np.random.choice default).
patch_idx = np.random.choice(patches_all.shape[1], npatch)

# fancy indexing returns a copy, so S can be modified in place afterwards
S = patches_all[:, patch_idx]
S.shape

Remove the mean from the patches (each column of $S$ must have zero-mean)

In [None]:
# Subtract from every column (patch) its own mean, writing the result back into S.
np.subtract(S, S.mean(axis=0), out=S)

# sanity check: shape is unchanged and every column mean is numerically zero
S.shape, np.allclose(S.mean(axis=0), 0)

((64, 10000), True)

Initialize the dictionary and the coefficient matrix

In [None]:
# KSVD settings
N = 256        # number of atoms (columns) in the dictionary
max_iter = 10  # number of KSVD iterations
L = 4          # maximum number of nonzero coefficients in each sparse code

# Random Gaussian initial dictionary; every atom is then centred (zero mean)
# and scaled to unit Euclidean norm.
D = np.random.randn(M, N)
D -= D.mean(axis=0, keepdims=True)
D /= np.linalg.norm(D, axis=0, keepdims=True)

# coefficient matrix: one column of N coefficients per training patch
X = np.zeros((N, npatch))

Main KSVD loop


In [None]:
# KSVD: alternate between sparse coding (D fixed) and atom-by-atom dictionary
# update (support of X fixed). The loop variable is named `it` so that the
# builtin `iter` is not shadowed in the notebook namespace.
for it in range(max_iter):
    time_start = time.time()
    print(f'iter {it}')

    # Sparse coding step: code every column of S via OMP with at most L atoms.
    # The last argument (0) is presumably the residual tolerance `tau` of
    # lib.OMP -- TODO confirm against lib.
    for n in range(npatch):
        X[:, n] = lib.OMP(S[:, n], D, L, 0)

    # Dictionary update step: refine one atom at a time.
    for j in range(N):
        # indices of the signals whose sparse code uses the j-th atom
        omega = np.nonzero(X[j, :])[0]

        if omega.size == 0:
            # the atom is never used: leave it unchanged
            continue

        # Residual of the signals in omega with the contribution of the j-th
        # atom added back. Only the omega columns are ever needed, so the
        # residual is computed on that subset directly instead of forming the
        # full M x npatch matrix S - D X for every atom (same values, far
        # less work per atom).
        Eomega = S[:, omega] - D @ X[:, omega] + np.outer(D[:, j], X[j, omega])

        # Best rank-1 approximation of Eomega via its leading singular triplet.
        U, Sigma, VT = np.linalg.svd(Eomega, full_matrices=False)

        # new atom: leading left singular vector (unit norm by construction)
        D[:, j] = U[:, 0]

        # new coefficients on the support: sigma_0 times the leading right singular vector
        X[j, omega] = Sigma[0] * VT[0]

    time_end = time.time()
    # elapsed wall-clock time of this iteration, in seconds
    print(f'{time_end - time_start:.0f}')

iter 0


KeyboardInterrupt: 

Show the learned dictionary

In [None]:
# Tile the learned atoms into one mosaic and display it in grayscale.
img_dict = get_dictionary_img(D)
fig = plt.figure()
ax = fig.gca()
ax.imshow(img_dict, cmap='gray');