# Dictionary Training

## Library Imports

In [None]:
import numpy as np
from scipy.optimize import minimize
from scipy.sparse import csr_matrix
from sklearn import linear_model
from scipy import signal
from scipy.signal import convolve2d
import scipy.io as sio
import cvxpy as cp
import time
import random
import cv2
import os

## .mat to .npy Conversion

In [None]:
# Read the dictionaries stored in the MATLAB file
mat_data = sio.loadmat('D_512_0.15_5.mat')

# Pull out the high- and low-resolution dictionaries and report their shapes
Dh = mat_data['Dh']
Dl = mat_data['Dl']
print(Dh.shape, Dl.shape)

# Write each dictionary to its own .npy file
for npy_path, dictionary in (('Dh_512_0.15_5.npy', Dh), ('Dl_512_0.15_5.npy', Dl)):
    np.save(npy_path, dictionary)

## Sample Patches from Images

In [None]:
def sample_patches(im, patch_size, patch_num, upscale):
    """Sample paired high-/low-resolution training patches from one image.

    The image is reduced to a single channel, smoothed with a 3x3 mean
    filter, shrunk by ``upscale`` and enlarged back to its original size
    (nearest-neighbour interpolation both ways) to obtain the low-resolution
    counterpart. High-resolution patches are stored with their mean removed;
    low-resolution patches are described by first- and second-order
    derivative responses of the low-resolution image.

    Args:
        im: Image array, either 2-D or 3-D with one or three channels.
            Three-channel input is converted with ``cv2.COLOR_RGB2GRAY``.
            NOTE(review): ``cv2.imread`` returns BGR channel order, so images
            loaded that way get the red/blue weights swapped -- confirm this
            is acceptable for training.
        patch_size: Side length of the square patches.
        patch_num: Maximum number of patches to return; fewer are returned
            when the image offers fewer candidate positions.
        upscale: Down-/up-sampling factor used to build the low-resolution
            image.

    Returns:
        ``(H, L)`` where ``H`` has shape ``(patch_size**2, n)`` and ``L`` has
        shape ``(4 * patch_size**2, n)``. Column ``k`` of both arrays comes
        from the same image location; patches are flattened column-major.
    """
    # Reduce to a single channel. Checking ndim first keeps 2-D (already
    # grayscale) input from failing on im.shape[2].
    hIm = im
    if im.ndim == 3:
        if im.shape[2] == 3:
            hIm = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
        elif im.shape[2] == 1:
            hIm = im[:, :, 0]

    # Blur the high-resolution image slightly before down-sampling
    blur_kernel = np.ones(shape=(3, 3)) / 9
    blurred_hIm = convolve2d(hIm, blur_kernel, mode='same')

    # Build the low-resolution image: shrink, then enlarge back.
    # cv2.resize expects (width, height), hence the reversed shape.
    small_size = tuple(int(x * (1 / upscale)) for x in blurred_hIm.shape)[::-1]
    lIm = cv2.resize(blurred_hIm, small_size, interpolation=cv2.INTER_NEAREST)
    lIm = cv2.resize(lIm, blurred_hIm.shape[::-1], interpolation=cv2.INTER_NEAREST)

    nrow, ncol = hIm.shape

    # Possible top-left corners of a patch, kept patch_size away from the
    # image border, in random order
    x = np.random.permutation(np.arange(0, nrow - 2 * patch_size - 1)) + patch_size
    y = np.random.permutation(np.arange(0, ncol - 2 * patch_size - 1)) + patch_size

    # Enumerate every (row, col) pair. With column-major flattening the
    # column index varies fastest, so consecutive samples share a row index.
    X, Y = np.meshgrid(x, y)
    xrow = X.flatten(order='F')
    ycol = Y.flatten(order='F')

    if patch_num < len(xrow):
        xrow = xrow[:patch_num]
        ycol = ycol[:patch_num]

    patch_num = len(xrow)

    H = np.zeros(shape=(patch_size ** 2, patch_num))
    L = np.zeros(shape=(4 * (patch_size ** 2), patch_num))

    # First-order derivative filters (horizontal and vertical)
    hf1 = np.array([-1, 0, 1]).reshape((1, -1))
    vf1 = hf1.T

    # Second-order derivative filters (horizontal and vertical)
    hf2 = np.array([1, 0, -2, 0, 1]).reshape((1, -1))
    vf2 = hf2.T

    # Kernels are flipped before convolve2d, so each filter is applied as a
    # correlation. Order: row-wise 1st, column-wise 1st, row-wise 2nd,
    # column-wise 2nd derivative.
    feature_maps = [
        signal.convolve2d(lIm, kernel[::-1, ::-1], 'same')
        for kernel in (hf1, vf1, hf2, vf2)
    ]

    for idx, (row, col) in enumerate(zip(xrow, ycol)):
        rows = slice(row, row + patch_size)
        cols = slice(col, col + patch_size)

        # High-resolution patch with its mean removed
        Hpatch = hIm[rows, cols].flatten(order='F')
        H[:, idx] = Hpatch - np.mean(Hpatch)

        # Low-resolution feature vector: the four derivative patches stacked
        L[:, idx] = np.concatenate(
            [fmap[rows, cols].flatten(order='F') for fmap in feature_maps]
        )

    return H, L

## Randomly Sample Patches

In [None]:
def _read_training_image(path):
    """Load one image with OpenCV, raising a clear error if it cannot be read."""
    im = cv2.imread(path)
    if im is None:
        # cv2.imread signals failure by returning None instead of raising
        raise IOError(f"Could not read training image: {path}")
    return im


#Randomly Sample Patches
def rnd_smp_patch(img_path, img_type, patch_size, num_patch, upscale):
    """Sample training patches from every matching image in a directory.

    The patch budget is shared between the images in proportion to their
    array size, and each image is then sampled with ``sample_patches``.

    Args:
        img_path: Directory containing the training images.
        img_type: File-name suffix used to select images (e.g. ``'.bmp'``).
        patch_size: Side length of the square patches.
        num_patch: Total number of patches requested over all images. The
            per-image share is rounded down, so the total returned can be
            smaller.
        upscale: Upscale factor forwarded to ``sample_patches``.

    Returns:
        ``(Xh, Xl)``: the high- and low-resolution patch matrices of all
        images, concatenated column-wise.

    Raises:
        ValueError: If no file in ``img_path`` ends with ``img_type``.
        IOError: If a selected file cannot be read as an image.
    """
    # Sort the names: os.listdir order is arbitrary, which would make the
    # column order of the result differ between machines.
    img_dir = sorted(file for file in os.listdir(img_path) if file.endswith(img_type))
    img_paths = [os.path.join(img_path, file) for file in img_dir]

    img_num = len(img_paths) #Total number of images
    print("Number of Images From Training Dataset we have Sampled: ", img_num)

    if img_num == 0:
        raise ValueError(f"No '{img_type}' images found in '{img_path}'")

    # First pass: record only the image sizes so that the images themselves
    # do not all have to be held in memory at once.
    nper_img = np.zeros(shape=(img_num,))
    for idx, path in enumerate(img_paths):
        nper_img[idx] = _read_training_image(path).size

    #number of patches per image, proportional to image size
    nper_img = np.floor(nper_img * num_patch / np.sum(nper_img)).astype(int)

    Xh = [] #Store High Resolution Patches
    Xl = [] #Store Low Resolution Patches

    # Second pass: sample each image with its share of the budget
    for path, patch_num in zip(img_paths, nper_img):
        im = _read_training_image(path)
        H, L = sample_patches(im, patch_size, patch_num, upscale)
        Xh.append(H)
        Xl.append(L)

    Xh, Xl = np.concatenate(Xh, axis=1), np.concatenate(Xl, axis=1)
    return Xh, Xl

## Jointly Train Dictionaries

In [None]:
def train_coupled_dict(Xh, Xl, dict_size, step_size, lamb, threshold, max_iter):
    """Jointly train one dictionary over stacked high-/low-resolution patches.

    The two patch matrices are weighted by ``1/sqrt(rows)`` and stacked
    row-wise. Training then alternates ``max_iter`` times between a sparse
    coding step (``lasso_optimization``) and a dictionary update
    (``quadratic_programming``, run for 30 inner iterations), printing the
    relative reconstruction error after each step.

    Args:
        Xh: High-resolution patch matrix, one patch per column.
        Xl: Low-resolution feature matrix, one patch per column.
        dict_size: Number of dictionary columns (atoms).
        step_size: Step size passed to the dictionary update.
        lamb: Sparsity weight passed to the sparse coding step.
        threshold: Gradient-norm stopping threshold of the dictionary update.
        max_iter: Number of alternations to perform.

    Returns:
        The stacked dictionary with ``Xh.shape[0] + Xl.shape[0]`` rows and
        ``dict_size`` columns.
    """
    print("STARTING TRAINING")
    hr_dim = Xh.shape[0]
    lr_dim = Xl.shape[0]

    # Weight each modality by 1/sqrt(dimension) before stacking
    hr_weight = 1 / np.sqrt(hr_dim)
    lr_weight = 1 / np.sqrt(lr_dim)
    Xc = np.concatenate((hr_weight * Xh, lr_weight * Xl), axis=0)
    print(f"Xc shape: {Xc.shape}")

    # Start from a random Gaussian dictionary passed through normalize
    Dc = normalize(np.random.normal(size=(hr_dim + lr_dim, dict_size)))
    print(f"Dc shape: {Dc.shape}")

    # Denominator of the relative error; Xc never changes
    xc_norm = np.linalg.norm(Xc)

    for round_no in range(1, max_iter + 1):
        # Sparse coding step with the dictionary held fixed
        Z = lasso_optimization(Xc, Dc, lamb)
        coding_err = np.linalg.norm(Xc - Dc @ Z) / xc_norm
        print(f"Iteration {round_no}/{max_iter} Linear Programming Stat: {coding_err}")

        # Dictionary update step with the codes held fixed
        n_rows, n_atoms = Dc.shape
        Dc = quadratic_programming(Xc, Z, n_rows, n_atoms, step_size, threshold, 30)
        update_err = np.linalg.norm(Xc - Dc @ Z) / xc_norm
        print(f"Iteration {round_no}/{max_iter} Quadratic Programming Stat: {update_err}")

    return Dc

#Lasso Optimization
def lasso_optimization(Xc, Dc, lamb):
    """Sparse-code ``Xc`` over the dictionary ``Dc`` with scikit-learn's Lasso.

    ``Dc`` is used as the design matrix and ``Xc`` as the regression target;
    no intercept is fitted.

    Args:
        Xc: Data to encode (passed as the target of the fit).
        Dc: Dictionary (passed as the design matrix of the fit).
        lamb: Value given to Lasso's ``alpha`` regularisation parameter.

    Returns:
        The transposed coefficient array of the fitted model.
    """
    solver_options = {"alpha": lamb, "max_iter": 100000, "fit_intercept": False}
    sparse_coder = linear_model.Lasso(**solver_options)
    sparse_coder.fit(Dc, Xc)
    codes = sparse_coder.coef_
    return codes.T
        
#prox operator
def prox(x, alpha):
    """Shrink every entry of ``x`` towards zero by ``alpha`` (soft threshold).

    Entries below ``-alpha`` are increased by ``alpha``, entries at or above
    ``alpha`` are decreased by ``alpha``, and everything else becomes zero.
    The input is not modified; the result has the shape and dtype of ``x``.
    """
    below = x < -alpha
    above = x >= alpha

    values = np.asanyarray(x)
    shrunk = np.zeros_like(values)  # entries in neither region stay zero
    shrunk[below] = values[below] + alpha
    shrunk[above] = values[above] - alpha
    return shrunk

## Solve the sparse coding portion of Joint Dictionary Training
## (an L1-regularised least-squares problem, despite the function name)
## Goal: Find Z that minimizes || X - DZ||_2^2 + lambda * ||Z||_1
def linear_programming(X: np.ndarray, D: np.ndarray, Zr, Zc, step_size, lamb, threshold, max_iter):
    """Estimate sparse codes ``Z`` by proximal gradient descent.

    Starting from a random Gaussian ``Z``, each iteration takes a gradient
    step on ``||X - D Z||_2^2`` and then applies ``prox`` with threshold
    ``step_size * lamb``.

    Args:
        X: Data matrix to be approximated by ``D @ Z``.
        D: Dictionary matrix.
        Zr: Number of rows of ``Z``.
        Zc: Number of columns of ``Z``.
        step_size: Gradient step size.
        lamb: Weight of the L1 penalty.
        threshold: Iteration stops once the gradient norm of the smooth term
            is at or below this value.
        max_iter: Maximum number of iterations.

    Returns:
        The code matrix ``Z`` of shape ``(Zr, Zc)``.
    """
    Z = np.random.normal(size=(Zr, Zc))

    # Neither product depends on Z, so compute them once instead of on
    # every iteration.
    DtX = D.T @ X
    DtD = D.T @ D

    for _ in range(max_iter):
        # Gradient of ||X - D Z||^2 with respect to Z
        grad = (-2 * DtX) + (2 * (DtD @ Z))

        # Gradient step followed by the proximal (shrinkage) step
        Z = prox(Z - (step_size * grad), step_size * lamb)

        # The update above is still applied before stopping
        if np.linalg.norm(grad) <= threshold:
            break

    return Z

def normalize(D: np.ndarray):
    """Rescale, in place, every column of ``D`` whose L2 norm exceeds 1.

    Columns with norm at most 1 are left untouched. ``D`` itself is modified
    and also returned.
    """
    column_norms = np.linalg.norm(D, axis=0)
    oversized = np.flatnonzero(column_norms > 1)  # indices of columns to shrink
    D[:, oversized] /= column_norms[oversized]
    return D

def quadratic_programming(X: np.ndarray, Z: np.ndarray, Dr, Dc, step_size, threshold, max_iter):
    """Update the dictionary by projected gradient descent on ``||X - D Z||^2``.

    ``D`` starts from a random Gaussian matrix. Every iteration takes a
    gradient step and then passes the result through ``normalize``, which
    rescales the columns whose norm exceeds 1.

    Args:
        X: Data matrix to be approximated by ``D @ Z``.
        Z: Code matrix, held fixed.
        Dr: Number of rows of ``D``.
        Dc: Number of columns of ``D``.
        step_size: Gradient step size.
        threshold: Iteration stops once the gradient norm is at or below
            this value.
        max_iter: Maximum number of iterations.

    Returns:
        The dictionary ``D`` of shape ``(Dr, Dc)``.
    """
    D = normalize(np.random.normal(size=(Dr, Dc)))

    # X Z^T and Z Z^T do not depend on D. Computing them once means each
    # step no longer multiplies through the full sample dimension.
    XZt = X @ Z.T
    ZZt = Z @ Z.T

    for _ in range(max_iter):
        # Gradient of ||X - D Z||^2 with respect to D
        grad = 2 * ((D @ ZZt) - XZt)

        # Gradient step, then projection of over-long columns
        D = normalize(D - (step_size * grad))

        # The update above is still applied before stopping
        if np.linalg.norm(grad) <= threshold:
            break

    return D
    
#Quadratic Programming
def quadratic_objective_function(D_flat, X, Z):
    """Return ``||X - D Z||^2`` for a dictionary given as a flat vector.

    ``D_flat`` is reshaped to ``X.shape[0]`` rows; the number of columns is
    inferred from its length.
    """
    n_rows = X.shape[0]
    dictionary = D_flat.reshape(n_rows, -1)
    residual = X - np.dot(dictionary, Z)
    return np.linalg.norm(residual) ** 2

def quadratic_constraints(D_flat, n_rows=None):
    """Return ``||d_j||^2 - 1`` for every column ``d_j`` of the dictionary.

    Args:
        D_flat: Dictionary flattened to a 1-D vector.
        n_rows: Number of rows of the dictionary. A flat vector does not
            carry the original shape, so pass this to get one value per
            column. When omitted, the vector is reshaped to a single column
            (the previous behaviour) and the result has one entry: the
            squared norm of the whole vector minus 1.

    Returns:
        1-D array of squared column norms minus 1. A value <= 0 means the
        column lies inside the unit ball.

    NOTE(review): SciPy's ``'ineq'`` constraints are satisfied when the
    function is non-negative, the opposite of the sign returned here.
    Negate the result before using it that way.
    """
    if n_rows is None:
        n_rows = D_flat.shape[0]
    # Restore the 2-D layout so the norms are taken per column
    D = D_flat.reshape(n_rows, -1)
    norm_squared = np.sum(D**2, axis=0)
    return norm_squared - 1

def quadratic_programming_scipy(X, Z):
    """Solve the dictionary update with ``scipy.optimize.minimize``.

    Minimises ``||X - D Z||^2`` subject to every column of ``D`` having
    squared norm at most 1, starting from a uniform random guess.

    Args:
        X: Data matrix to be approximated by ``D @ Z``.
        Z: Code matrix, held fixed.

    Returns:
        The optimised dictionary of shape ``(X.shape[0], Z.shape[0])``.
    """
    n_rows, n_atoms = X.shape[0], Z.shape[0]

    # Initial guess for D, flattened because minimize works on 1-D vectors
    initial_guess_D = np.random.rand(n_rows * n_atoms)

    def unit_norm_slack(D_flat):
        # SciPy treats an 'ineq' constraint as satisfied when the function is
        # non-negative, so return 1 - ||d_j||^2 for each column.
        D = D_flat.reshape(n_rows, n_atoms)
        return 1.0 - np.sum(D ** 2, axis=0)

    # SLSQP is used because BFGS does not support constraints
    result = minimize(
        quadratic_objective_function,
        initial_guess_D,
        args=(X, Z),
        method='SLSQP',
        constraints={'type': 'ineq', 'fun': unit_norm_slack},
    )

    # Restore the 2-D layout of the optimised dictionary
    optimized_D = result.x.reshape(n_rows, n_atoms)

    return optimized_D

def quadratic_programming_CVXPY(X, Z):
    """Solve the dictionary update as a convex program with CVXPY.

    Minimises ``||X - D Z||^2`` subject to every column of ``D`` having
    squared norm at most 1, matching the column-wise constraint used by the
    other dictionary-update routines in this file.

    Args:
        X: Data matrix to be approximated by ``D @ Z``.
        Z: Code matrix, held fixed.

    Returns:
        ``D.value`` after solving; the dictionary has shape
        ``(X.shape[0], Z.shape[0])``. NOTE(review): presumably ``None`` when
        the solver does not find a solution -- confirm and handle at the
        call site.
    """
    X = cp.Constant(X)
    Z = cp.Constant(Z)

    # Optimization variable: the dictionary
    D = cp.Variable((X.shape[0], Z.shape[0]))

    # Objective: squared reconstruction error
    objective = cp.Minimize(cp.sum_squares(X - D @ Z))

    # One constraint per column: squared column norm at most 1
    norm_squared = cp.sum(D**2, axis=0)
    constraints = [norm_squared <= 1]

    # Formulate and solve the problem
    problem = cp.Problem(objective, constraints)
    problem.solve()

    optimized_D = D.value

    return optimized_D

## Train Dictionaries

In [None]:
# Driver cell: sample patches from the training images, prune them, jointly
# train the coupled dictionary, and save its two halves to disk.
training_image_path = "../Data/Training"

dict_size = 512 #Dictionary Size will be 512
lamb = 0.15 #sparsity regularization
patch_size = 5 #size of patches will be 5 x 5
nSmp = 10000 #number of patches to sample
upscale = 4 #upscale factor

#randomly generate patches
print("Going to Randomly Generate Patches")
Xh, Xl = rnd_smp_patch(training_image_path, '.bmp', patch_size, nSmp, upscale)

#Prune patches with small variance
# NOTE(review): patch_pruning is not defined in the visible part of this
# file -- confirm it is provided by another cell or module before running.
print("Going to Prune Patches")
Xh, Xl = patch_pruning(Xh, Xl, threshold = 1)

print(Xh.shape, Xl.shape)

#Joint Dictionary Training
# The timer starts before the hyper-parameters below are assigned, so the
# reported time covers those assignments as well as the training call.
start_time = time.time()
step_size = 0.0001 #step size forwarded to the dictionary update
threshold = 0.0001 #gradient-norm stopping threshold of the dictionary update
max_iter = 10 #number of sparse-coding / dictionary-update alternations
print("Going to Jointly Train Dictionaries")
Dc = train_coupled_dict(Xh, Xl, dict_size, step_size, lamb, threshold, max_iter)
end_time = time.time()

elapsed_time = end_time - start_time
print(f"Time taken to Jointly Train Dictionaries: {elapsed_time}")

# Split the stacked dictionary back into its two parts: the first N rows
# belong to the high-resolution patches, the remaining rows to the
# low-resolution features.
# NOTE(review): Dc was fitted to the weighted stack built inside
# train_coupled_dict, so these halves presumably still carry the
# 1/sqrt(N) and 1/sqrt(M) weights -- confirm downstream code expects that.
N, M = Xh.shape[0], Xl.shape[0]
Dh, Dl = Dc[:N], Dc[N:]

#Save Dh and Dl to npy files
np.save('Dh.npy', Dh)
np.save('Dl.npy', Dl)