# Dictionary Training

## Library Imports

In [1]:
import numpy as np
from scipy.optimize import minimize
from sklearn import linear_model
from scipy import signal
from scipy.signal import convolve2d
import cvxpy as cp
import time
import cv2
import os

## Sample Patch from Image

In [2]:
#hIm: high resolution image
#patch_size: patch size that we want to retrieve from image
#num_patches: number of patches we want to sample
#upscale: upscale factor from low resolution image to high resolution image
def sample_patches(hIm, patch_size, num_patches, upscale):
    #Convert RBG to Grayscale Images
    if hIm.shape[2] == 3:
        hIm = cv2.cvtColor(hIm, cv2.COLOR_RGB2GRAY)
        
    #Blur the High Resolution Image a bit
    blur_kernel = np.ones(shape = (3, 3)) / 9
    blurred_hIm = convolve2d(hIm, blur_kernel, mode = 'same')
        
    #Generate Low Resolution Image
    lIm = cv2.resize(blurred_hIm, tuple(int(x * (1/upscale)) for x in blurred_hIm.shape)[::-1], interpolation = cv2.INTER_NEAREST)
    lIm = cv2.resize(lIm, blurred_hIm.shape[::-1], interpolation = cv2.INTER_NEAREST)
    
    #Get dimensions of High Resolution Image
    nrow, ncol = hIm.shape
    
    #Get posible values of (x, y) that is top left corner of patch. The (x,y) coordinates are in the coordinate space of the High Resolution Image
    x = np.random.permutation(np.arange(0, nrow - 2 * patch_size - 1)) + patch_size
    y = np.random.permutation(np.arange(0, ncol - 2 * patch_size - 1)) + patch_size
    
    #Generated Meshgrid
    X, Y = np.meshgrid(x, y)
    
    #Flatten X and Y column wise
    xrow = X.flatten(order = 'F')
    ycol = Y.flatten(order = 'F')
    
    #If we have less patches than potential (x, y) coordinates, we have to truncate the list of (x, y) coordinates
    if num_patches < len(xrow):
        xrow = xrow[: num_patches]
        ycol = ycol[: num_patches]
    
    num_patches = len(xrow)
    
    #Store High and Low Resolution Patches
    H = np.zeros(shape = (patch_size ** 2, num_patches))
    L = np.zeros(shape = (4 * (patch_size ** 2), num_patches))
    
    #Compute first order derivatives
    hf1 = np.array([-1,0,1]).reshape((1, -1))
    vf1 = hf1.T
    
    lImG11 = signal.convolve2d(lIm, hf1[::-1, ::-1],'same') #row wise 1st order derivative
    lImG12 = signal.convolve2d(lIm, vf1[::-1, ::-1],'same') #column wise 1st order derivative
    
    #Compute second order derivatives
    hf2 = np.array([1,0,-2,0,1]).reshape((1, -1))
    vf2 = hf2.T
    
    lImG21 = signal.convolve2d(lIm, hf2[::-1, ::-1], 'same') #row wise 2nd order derivative
    lImG22 = signal.convolve2d(lIm, vf2[::-1, ::-1], 'same') #column wise 2nd order derivative
    
    #Extract Patches
    for idx in range(num_patches):
        row, col = xrow[idx], ycol[idx]
        
        #Get the patch from High Resolution Image
        Hpatch = hIm[row: row + patch_size, col: col + patch_size].flatten(order = 'F')
        H[:, idx] = Hpatch - np.mean(Hpatch) #Store High Resolution Patch
        
        #Get the patch from Low Resolution Image
        Lpatch1 = lImG11[row:row+patch_size,col:col+patch_size].flatten(order = 'F')
        Lpatch2 = lImG12[row:row+patch_size,col:col+patch_size].flatten(order = 'F')
        Lpatch3 = lImG21[row:row+patch_size,col:col+patch_size].flatten(order = 'F')
        Lpatch4 = lImG22[row:row+patch_size,col:col+patch_size].flatten(order = 'F')

        Lpatch = np.concatenate((Lpatch1, Lpatch2, Lpatch3, Lpatch4))
        L[:, idx] = Lpatch #Store Low Resolution Patch
    
    return H, L

## Randomly Sample Patches From Training Images

In [3]:
#Randomly Sample Patches
#img_path: path of image
#img_type: type of image
#patch_size: size of patch
#num_patch: number of patches we want to sample
#upscale: upscale factor from low resolution image to high resolution image
def rnd_smp_patch(img_path, img_type, patch_size, num_patch, upscale):
    #image directory with training image path
    img_dir = [file for file in os.listdir(img_path) if file.endswith(img_type)]
        
    Xh = [] #Store High Resolution Patches
    Xl = [] #Store Low Resolution Patches
    
    img_num = len(img_dir) #Total number of images
    nper_img = np.zeros(shape = (img_num, )) #number of patches per image
    print("Number of Images From Training Dataset we have Sampled: ", img_num)
    
    #Store total size of all images
    for idx in range(img_num):
        im = cv2.imread(os.path.join(img_path, img_dir[idx]))
        nper_img[idx] = im.size
    
    nper_img = np.floor(nper_img * num_patch / np.sum(nper_img)).astype(int) #number of patches per image
    
    #iterate through images
    for idx in range(img_num):
        patch_num = nper_img[idx] #number of patches from this image to select
        im = cv2.imread(os.path.join(img_path, img_dir[idx])) #Get image
        
        #Sample the Patches
        H, L = sample_patches(im, patch_size, patch_num, upscale)
                
        #Append to Xh and Xl
        Xh.append(H)
        Xl.append(L)
    
    Xh, Xl = np.concatenate(Xh, axis = 1), np.concatenate(Xl, axis = 1) #Concatenate Patches into numpy array
    return Xh, Xl

## Jointly Train Dictionaries

In [4]:
#jointly train dictionaries
#Xh: High Resolution Patches
#Xl: Low Resolution Patches
#dict_size: size of dictionary
#step size: step size for Quadratic Programming step
#lamb: lambda for linear programming step
#threshold: threshold for iterations in quadratic programming
#max_iter: maximum number of iterations
def train_coupled_dict(Xh, Xl, dict_size, step_size, lamb, threshold, max_iter):
    print("STARTING TRAINING")
    
    #Get shape of Patch Data
    N, M = Xh.shape[0], Xl.shape[0]
    
    #Get Constants
    a1 = 1 / np.sqrt(N)
    a2 = 1 / np.sqrt(M)
    
    #Initialize Xc
    Xc = np.concatenate((a1 * Xh, a2 * Xl), axis = 0)
    print(f"Xc shape: {Xc.shape}")
    
    #Initialize D as a random Gaussian Matrix
    Dc = np.random.normal(size = (N + M, dict_size))
    Dc = normalize(Dc)
    print(f"Dc shape: {Dc.shape}")
    
    #cap maximum iterations at max_iter
    for iter in range(max_iter):
        Z = lasso_optimization(Xc, Dc, lamb) #Run lasso Optimization
        Xc_pred = Dc @ Z #Get XC_pred to compute loss metric
        print(f"Iteration {iter + 1}/{max_iter} Linear Programming Stat: {np.linalg.norm(Xc - Xc_pred) / np.linalg.norm(Xc)}")
        
        Dc = quadratic_programming(Xc, Z, Dc.shape[0], Dc.shape[1], step_size, threshold, 30) #Run Quadratic Programming Step
        Xc_pred = Dc @ Z #Get XC_pred to compute loss metric
        print(f"Iteration {iter + 1}/{max_iter} Quadratic Programming Stat: {np.linalg.norm(Xc - Xc_pred) / np.linalg.norm(Xc)}")
    
    return Dc

#Lasso Optimization
#Goal: Solve for value of Z to minimize ||Xc - Dc Z||_2^2 + lamb * ||Z||_1
def lasso_optimization(Xc, Dc, lamb):
    clf = linear_model.Lasso(alpha = lamb, max_iter = 100000, fit_intercept = False)
    clf.fit(Dc, Xc)
    return clf.coef_.T
        
#prox operator given x and alpha
def prox(x, alpha):
    return np.piecewise(x, [x < -alpha, (x >= -alpha) & (x <= alpha), x >= alpha], [lambda x: x + alpha, 0, lambda x: x - alpha])

## Solve the Linear Programming Portion of Joint Dictionary Training
## Goal: Find Z that minimizes || X - DZ||_2^2 + lambda * ||Z||_1
def linear_programming(X: np.ndarray, D: np.ndarray, Zr, Zc, step_size, lamb, threshold, max_iter):
    Z = np.random.normal(size = (Zr, Zc))
    
    #Run Proximal Gradient Descent
    loss = (np.linalg.norm(X - (D @ Z)) ** 2) + (lamb * np.sum(np.abs(Z)))
    
    for iter in range(max_iter):
        grad = (-2 * (D.T @ X)) + (2 * (D.T @ D @ Z))
        
        #Update Z
        Z = Z - (step_size * grad)
        Z = prox(Z, step_size * lamb)
        
        loss = (np.linalg.norm(X - (D @ Z)) ** 2) + (lamb * np.sum(np.abs(Z)))
        if np.linalg.norm(grad) <= threshold:
            break
    
    # print(f"Loss at Iteration {iter} = {loss}, Magnitude of Gradient = {np.linalg.norm(grad)}")
    return Z

#Normalize a matrix D such that its column norms <= 1
def normalize(D: np.ndarray):
    norms = np.linalg.norm(D, axis=0)  # Calculate column norms
    mask = norms > 1  # Find columns with norms > 1
    D[:, mask] /= norms[mask]  # Normalize only columns with norms > 1
    return D

## Solve the Quadratic Programming Portion of Joint Dictionary Training
## Goal: Find D that minimizes || X - DZ||_2^2
def quadratic_programming(X: np.ndarray, Z: np.ndarray, Dr, Dc, step_size, threshold, max_iter):    
    #Run Projected Gradient Descent
    D = np.random.normal(size = (Dr, Dc))
    D = normalize(D)
    loss = (np.linalg.norm(X - (D @ Z)) ** 2) / (np.linalg.norm(X))
    
    for iter in range(max_iter):
        grad = (-2 * (X @ Z.T)) + (2 * (D @ Z @ Z.T))
        
        D = D - (step_size * grad)
        D = normalize(D)
        loss = (np.linalg.norm(X - (D @ Z)) ** 2) / (np.linalg.norm(X))
        
        if np.linalg.norm(grad) <= threshold:
            break
    
        # print(f"Loss at Iteration {iter} = {loss}, Magnitude of Gradient = {np.linalg.norm(grad)}")
    return D

## Variance-Thresholding Based Patch Pruning

In [5]:
#Prune patches whose variances are below a certain threshold
#Xh: High Resolution Patch
#Xl: Low Resolution Patch
#variance_threshold: variance threshold
def patch_pruning(Xh, Xl, variance_threshold):
    patch_variances = np.var(Xh, 0)
    idx = patch_variances > variance_threshold
    Xh = Xh[:, idx]
    Xl = Xl[:, idx]
    return Xh, Xl

## Train Dictionaries

In [6]:
training_image_path = "../Data/Training" #path that has all training images

dict_size = 512 #Dictionary Size will be 512
lamb = 0.15 #sparsity regularization
patch_size = 5 #size of patches will be 3 x 3
nSmp = 10000 #number of patches to sample
upscale = 4 #upscale factor

#randomly sample patches from training images
print("Going to Randomly Generate Patches")
Xh, Xl = rnd_smp_patch(training_image_path, '.bmp', patch_size, nSmp, upscale)

#Prune patches with small variance
print("Going to Prune Patches")
Xh, Xl = patch_pruning(Xh, Xl, variance_threshold = 1)

print(Xh.shape, Xl.shape)

#Joint Dictionary Training
start_time = time.time()
step_size = 0.0001
variance_threshold = 0.0001
max_iter = 10
print("Going to Jointly Train Dictionaries")
Dc = train_coupled_dict(Xh, Xl, dict_size, step_size, lamb, variance_threshold, max_iter)
end_time = time.time()

elapsed_time = end_time - start_time
print(f"Time taken to Jointly Train Dictionaries: {elapsed_time}")

N, M = Xh.shape[0], Xl.shape[0]
Dh, Dl = Dc[:N], Dc[N:]

#Save Dh and Dl to npy files
np.save('Dh.npy', Dh)
np.save('Dl.npy', Dl)

Going to Randomly Generate Patches
Number of Images From Training Dataset we have Sampled:  69
Going to Prune Patches
(25, 9888) (100, 9888)
Going to Jointly Train Dictionaries
STARTING TRAINING
Xc shape: (125, 9888)
Dc shape: (125, 512)
Iteration 1/10 Linear Programming Stat: 0.9788898842951262
Iteration 1/10 Quadratic Programming Stat: 0.9397385749116924
Iteration 2/10 Linear Programming Stat: 0.7132265317645815
Iteration 2/10 Quadratic Programming Stat: 0.694919479996787
Iteration 3/10 Linear Programming Stat: 0.6590612304242257
Iteration 3/10 Quadratic Programming Stat: 0.6545793492907397
Iteration 4/10 Linear Programming Stat: 0.6436152338585733
Iteration 4/10 Quadratic Programming Stat: 0.6424265986867981
Iteration 5/10 Linear Programming Stat: 0.6372965643571701
Iteration 5/10 Quadratic Programming Stat: 0.6383524212626221
Iteration 6/10 Linear Programming Stat: 0.6341268198519402
Iteration 6/10 Quadratic Programming Stat: 0.6376219764542974
Iteration 7/10 Linear Programming Sta