# In [None]:
import argparse
from argparse import ArgumentParser
import glob
import cv2
import re
import os, glob, datetime
import numpy as np
import tensorflow as tf
import numpy as np
import time
import math
from keras.layers import  Input,Conv2D,BatchNormalization,Activation,Subtract, Reshape
from keras.models import Model, load_model
from tensorflow.python.keras.utils import conv_utils
from keras.callbacks import CSVLogger, ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.optimizers import Adam
import keras.backend as K

from skimage.transform import rescale
from scipy.io import loadmat
import scipy.io as sio
from scipy.io import savemat
from matplotlib import pyplot as plt
import skimage
from skimage.metrics import structural_similarity as ssim
from skimage.io import imread, imsave

#from skimage.measure import compare_psnr, compare_ssim



#########--------------           IMPORTANT NOTE       ----------------###############

# Select the path of the 'set_dir', 'set_names', 'model_dir', 'model_name' and 'sigma' properly to give access to the models, datasets and results

# Check whether the training images are normalized or not.

# Select the path of the directory properly to give access to the models, datasets and results
# Root directory that contains the 'data/' and 'models/' folders.
directory_path =  '_________________________' # choose your directory


parser = argparse.ArgumentParser(description='Keras DIVA2D test')
parser.add_argument('--model', default='DIVA2D', type=str, help='choose a type of model')
parser.add_argument('--kernel_size', default=5, type=int, help='kernel size')

parser.add_argument('--set_dir', default=os.path.join(directory_path, 'data/'), type=str, help='directory of test dataset')
# nargs='+' (instead of type=list) so that `--set_names Set12 BSD68` yields
# ['Set12', 'BSD68']; type=list would split a single name into characters.
parser.add_argument('--set_names', default=['Set12'], nargs='+', type=str, help='name(s) of test dataset(s)')

parser.add_argument('--sigma', default=25, type=int, help='noise level - Choose a sigma value from 10, 15, 25, 50, 75, 100')
parser.add_argument('--model_dir', default=os.path.join(directory_path, 'models/DIVA_models_sigma_10_to_100'), type=str, help='directory of the model')
# Left unset by default so that the pretrained model matching --sigma is
# picked automatically (see below); an explicit --model_name is honoured.
parser.add_argument('--model_name', default=None, type=str, help='the model name (default: model_sigma<sigma>.hdf5)')
parser.add_argument('--save_result', default=1, type=int, help='save the denoised image, 1 or 0')

# Absorbs the '-f <kernel file>' argument passed when running inside Jupyter.
parser.add_argument('-f', '--file', required=False)
args = parser.parse_args()

#########--------------           IMPORTANT NOTE       ----------------###############

# Pretrained models exist for sigma = 10, 15, 25, 50, 75, 100.
# Select one with --sigma; the matching 'model_sigma<sigma>.hdf5' inside
# --model_dir is loaded unless --model_name is given explicitly.
if args.model_name is None:
    args.model_name = 'model_sigma' + str(args.sigma) + '.hdf5'

# Results are written next to the model directory actually used for loading.
args.result_dir = os.path.join(args.model_dir, 'results')

print(args.sigma)
print(args.model_name)

# True: rebuild DIVA2D in code and load only the weights from the .hdf5 file.
# False: load the complete saved model with keras `load_model` (called below
# without `custom_objects`, so it presumably only works for models that do
# not contain the custom Hamiltonian_Conv2D layer).
use_model = True


##--------------------------------------------------------------------------------------------------------
##--------------------------------------------------------------------------------------------------------

class Hamiltonian_Conv2D(Conv2D):
    """2-D convolution whose kernel is assembled in ``build`` from several parts.

    In ``build`` the layer assigns
    ``self.kernel = weights_1 * weights_2 + kernel_3 + kernel_4`` where

    * ``weights_1`` is a fixed 3x3 stencil repeated ``filters`` times along
      its last axis (shape ``(3, 3, filters)``),
    * ``weights_2`` is a trainable weight of the same shape named
      ``'kernel_h^2/2m'`` (orthogonal initialisation),
    * ``kernel_3`` / ``kernel_4`` are whatever the caller passes to the
      constructor; in this file ``DIVA2D`` passes the outputs of two
      ``MaxPooling2D`` layers.

    The layer never uses a bias (``use_bias=False`` is forced).

    NOTE(review): ``build`` finishes with ``super().build(input_shape)``. In
    stock Keras, ``Conv2D.build`` creates its own ``self.kernel`` weight,
    which would replace the Hamiltonian kernel assigned just before it --
    confirm which tensor ``call`` actually convolves with. Do not reorder the
    weight creation here: saved ``.hdf5`` weights are presumably matched in
    creation order.
    """

    def __init__(self, filters, kernel_size, kernel_3=None, kernel_4=None, activation=None, **kwargs):
        """Store the extra kernel parts and initialise the base ``Conv2D``.

        Args:
            filters: number of output filters (also the repeat count of the
                fixed stencil built in ``build``).
            kernel_size: int or tuple, normalised to a 2-tuple.
            kernel_3: tensor added to the kernel in ``build`` ("original
                potential" term).
            kernel_4: tensor added to the kernel in ``build`` ("interaction"
                term).
            activation: forwarded to ``Conv2D``; note that ``call`` below does
                not apply it.
            **kwargs: forwarded to ``Conv2D``. Must not contain ``use_bias``,
                which is already passed explicitly (would raise TypeError).
        """

        self.rank = 2               # Dimension of the kernel
        self.num_filters = filters  # Number of filters in the convolution layer
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, self.rank, 'kernel_size')
        self.kernel_3 = kernel_3    # Contribution from the original potential
        self.kernel_4 = kernel_4    # Contribution from the interactions

        super(Hamiltonian_Conv2D, self).__init__(self.num_filters, self.kernel_size, 
              activation=activation, use_bias=False, **kwargs)
        
    def build(self, input_shape):
        """Create the trainable ``weights_2`` and assemble ``self.kernel``.

        Raises:
            ValueError: if the channel dimension of ``input_shape`` is None.
        """
        if K.image_data_format() == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                     'should be defined. Found `None`.')

        #don't use bias:
        self.bias = None

        #consider the layer built (set again before super().build below)
        self.built = True


        # Fixed 3x3 stencil (called the "nabla operator" by the authors).
        # NOTE(review): this is always 3x3, independent of self.kernel_size
        # (DIVA2D uses 5x5 by default) -- confirm the intended kernel shape.
        weights_1 = tf.constant([[ 2.,-1., 0.],
                                 [-1., 4.,-1.],
                                 [ 0.,-1., 2.]])
        

        # (3, 3) -> (3, 3, 1) -> (3, 3, num_filters)
        weights_1 = tf.reshape(weights_1 , [3,3, 1])
        weights_1 = tf.repeat(weights_1 , repeats=self.num_filters, axis=2)
        #print('kernel shape of weights_1:',weights_1.get_shape())

        # Trainable weights for h^2/2m (same shape as the stencil above)
        weights_2 = self.add_weight(shape=weights_1.get_shape(),
                                      initializer= 'Orthogonal',
                                      name='kernel_h^2/2m',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        #print('kernel shape of weights_2:',weights_2.get_shape())

        
        # Define the Hamiltonian kernel. kernel_3/kernel_4 are broadcast
        # against the (3, 3, num_filters) product; their shapes come from the
        # caller -- NOTE(review): confirm the result is a valid conv kernel.
        self.kernel = weights_1*weights_2 + self.kernel_3 + self.kernel_4
        #print('self.kernel',self.kernel.get_shape())

        self.built = True
        # See the class NOTE: the base build may (re)create self.kernel.
        super(Hamiltonian_Conv2D, self).build(input_shape)

    # Do the 2D convolution using the Hamiltonian kernel
    def convolution_op(self, inputs, kernel):
        """Convolve ``inputs`` with ``kernel`` via ``tf.nn.convolution``.

        Uses the layer's strides, padding, dilation rate and data format;
        string padding is upper-cased ('same' -> 'SAME'), and 'causal' is
        mapped to 'VALID'.
        """
        if self.padding == "causal":
            tf_padding = "VALID"  # Causal padding handled in `call`.
        elif isinstance(self.padding, str):
            tf_padding = self.padding.upper()
        else:
            tf_padding = self.padding


        return tf.nn.convolution(
            inputs,
            kernel,
            strides=list(self.strides),
            padding=tf_padding,
            dilations=list(self.dilation_rate),
            data_format=self._tf_data_format,
            name=self.__class__.__name__,
        )

    def call(self, inputs):
        """Return ``inputs`` convolved with ``self.kernel``.

        NOTE(review): no bias and no ``self.activation`` is applied here, so
        the ``activation`` given to ``__init__`` ('relu' in DIVA2D) appears
        to have no effect. Confirm before changing: pretrained weights were
        produced with this behaviour.
        """
        outputs = self.convolution_op(inputs, self.kernel)
        return outputs


      

# -------------------------------------------------------------------------------------------------------------------

def DIVA2D(depth,filters=64,image_channels=1, kernel_size= args.kernel_size, use_bnorm=True):
    """Build the DIVA2D denoising network as an uncompiled Keras ``Model``.

    The network estimates the noise in its input and returns
    ``input - estimated_noise`` (residual learning through the final
    ``Subtract`` layer).

    Args:
        depth: network depth; ``depth - 2`` convolution(+BN)+ReLU
            "thresholding" stages follow the Hamiltonian projection layer.
        filters: number of feature maps in every hidden layer.
        image_channels: number of channels of the input and output images.
        kernel_size: spatial size of the square convolution kernels. The
            default is taken from the command-line ``args`` at definition time.
        use_bnorm: insert a ``BatchNormalization`` layer after each
            thresholding convolution. ``True`` (the default) is the
            architecture the pretrained weights in this script are loaded into.

    Returns:
        A ``keras.models.Model`` mapping images of shape
        ``(H, W, image_channels)`` (any H, W) to images of the same shape.
    """
    layer_count = 0
    inpt = Input(shape=(None, None, image_channels), name='input' + str(layer_count))

    # Initial patches: first feature-extraction stage.
    initial_patches = Conv2D(filters=filters, kernel_size=(kernel_size, kernel_size), strides=(1, 1),
                             kernel_initializer='Orthogonal', padding='same',
                             name='initial_patches')(inpt)
    initial_patches = Activation('relu', name='initial_patch_acti')(initial_patches)

    # Interaction features computed from the initial patches.
    inter = Conv2D(filters=filters, kernel_size=(kernel_size, kernel_size), strides=(1, 1),
                   kernel_initializer='Orthogonal', padding='same',
                   name='interactions')(initial_patches)
    inter = Activation('relu', name='interaction_acti' + str(layer_count))(inter)

    # Contributions of the original potential to the Hamiltonian kernel.
    ori_poten_kernel = tf.keras.layers.MaxPooling2D(pool_size=(21, 21), strides=(15, 15), padding='same',
                                                    name='ori_poten_ker', data_format=None)(initial_patches)

    # Contributions of the interactions to the Hamiltonian kernel.
    inter_kernel = tf.keras.layers.MaxPooling2D(pool_size=(21, 21), strides=(15, 15), padding='same',
                                                name='inter_ker', data_format=None)(inter)

    # Projection coefficients of the initial patches on the Hamiltonian kernel.
    x = Hamiltonian_Conv2D(filters=filters, kernel_size=(kernel_size, kernel_size),
                           kernel_3=ori_poten_kernel, kernel_4=inter_kernel,
                           strides=(1, 1), activation='relu',
                           kernel_initializer='Orthogonal', padding='same',
                           name='proj_coef')(initial_patches)

    # Thresholding stages (depth depends on the noise intensity).
    for _ in range(depth - 2):
        layer_count += 1
        x = Conv2D(filters=filters, kernel_size=(kernel_size, kernel_size), strides=(1, 1),
                   kernel_initializer='Orthogonal', padding='same', use_bias=False,
                   name='conv' + str(layer_count))(x)

        # Incremented unconditionally so layer names are identical to the
        # always-BN architecture when use_bnorm is True.
        layer_count += 1
        if use_bnorm:
            x = BatchNormalization(axis=3, momentum=0.1, epsilon=0.0001,
                                   name='bn' + str(layer_count))(x)

        x = Activation('relu', name='Thresholding' + str(layer_count))(x)

    # Inverse projection back to image space (estimated noise).
    x = Conv2D(filters=image_channels, kernel_size=(kernel_size, kernel_size), strides=(1, 1),
               kernel_initializer='Orthogonal', padding='same', use_bias=False,
               name='inv_trans')(x)

    # Residual learning: denoised = input - estimated noise.
    x = Subtract(name='subtract')([inpt, x])

    return Model(inputs=inpt, outputs=x)

##----------------------------------------------------------------------------------------------------------------------



def to_tensor(img):
    """Convert an image to the 4-D ``(batch, H, W, 1)`` layout fed to the model.

    A 2-D ``(H, W)`` image becomes a batch of one, ``(1, H, W, 1)``. A 3-D
    ``(H, W, C)`` image is split channel-wise into a batch of ``C``
    single-channel images, ``(C, H, W, 1)``; ``from_tensor`` reverses this.

    Raises:
        ValueError: if ``img`` is neither 2-D nor 3-D. Returning ``None`` here
            would only surface later as an obscure error inside the model.
    """
    if img.ndim == 2:
        return img[np.newaxis, ..., np.newaxis]
    if img.ndim == 3:
        return np.moveaxis(img, 2, 0)[..., np.newaxis]
    raise ValueError('to_tensor expects a 2-D or 3-D image, got shape %r' % (img.shape,))


def from_tensor(img):
    """Undo ``to_tensor``: turn a ``(N, H, W, 1)`` batch back into an image.

    The batch axis is moved to the end, ``(H, W, N)``, and singleton axes are
    squeezed away, so a batch of one yields a plain ``(H, W)`` image.
    """
    first_channel = img[..., 0]
    batch_last = np.moveaxis(first_channel, source=0, destination=-1)
    return batch_last.squeeze()

def log(*values, **kwargs):
    """Print ``values`` preceded by the current local time.

    The prefix has the form ``YYYY-mm-dd HH:MM:SS:``; keyword arguments
    (``sep``, ``end``, ``file``, ``flush``) are passed straight to ``print``.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:")
    print(timestamp, *values, **kwargs)

def save_result(result,path):
    """Save ``result`` to ``path``, picking the writer from the file extension.

    ``.txt`` / ``.dlm`` paths are written as text with ``np.savetxt``
    (format ``%2.4f``); any other extension is written as an image with
    ``imsave`` after clipping the values to [0, 1]. A path whose file name
    has no extension gets ``.png`` appended.
    """
    # Inspect only the extension of the final path component: a '.' in a
    # directory name (e.g. './results/img' or 'run.v2/img') must not count
    # as the file having an extension.
    if not os.path.splitext(path)[1]:
        path = path + '.png'
    ext = os.path.splitext(path)[1]
    if ext in ('.txt', '.dlm'):
        np.savetxt(path, result, fmt='%2.4f')
    else:
        imsave(path, np.clip(result, 0, 1))


def show(x, title=None, cbar=False, figsize=None):
    """Display ``x`` as a greyscale image in a new matplotlib figure.

    Args:
        x: image array to display (nearest-neighbour interpolation, 'gray'
            colormap).
        title: optional figure title; skipped when empty or None.
        cbar: when true, a colour bar is added next to the image.
        figsize: forwarded to ``plt.figure``.
    """
    # Imported lazily so matplotlib is only needed when images are shown.
    from matplotlib import pyplot as mpl_plt

    mpl_plt.figure(figsize=figsize)
    mpl_plt.imshow(x, cmap='gray', interpolation='nearest')
    if title:
        mpl_plt.title(title)
    if cbar:
        mpl_plt.colorbar()
    mpl_plt.show()

def psnr(target, ref, data_range=1.):
    """Peak signal-to-noise ratio in dB between ``target`` and ``ref``.

    ``PSNR = 20 * log10(data_range / RMSE)`` with the RMSE taken over all
    elements (any shape: grey or colour) in float32.

    Args:
        target: reference array (the clean image in this script).
        ref: array compared against ``target``; same shape.
        data_range: peak signal value. The default 1.0 matches images scaled
            to [0, 1]; use 255 for 8-bit-range data.

    Returns:
        The PSNR in dB, or ``inf`` when the two inputs are identical.
    """
    diff = ref.astype(np.float32) - target.astype(np.float32)
    rmse = np.sqrt(np.mean(diff.ravel() ** 2.))
    if rmse == 0:
        # Identical inputs: infinite PSNR, without a divide-by-zero warning.
        return float('inf')
    return 20 * np.log10(data_range / rmse)


def snr(target, ref):
    """Signal-to-noise ratio in dB of ``ref`` relative to the clean ``target``.

    ``SNR = 20 * log10(RMS(target) / RMS(ref - target))``, computed in
    float32 over all elements of the arrays.
    """
    clean = target.astype(np.float32)
    error = ref.astype(np.float32) - clean

    def _rms(arr):
        # Root-mean-square over every element, in C order.
        return np.sqrt(np.mean(arr.ravel('C') ** 2.))

    return 20 * np.log10(_rms(clean) / _rms(error))


##----------------------------------------------------------------------------------------------------------------
## -------------------------------------------------------------------------------------------------------------
## -------------------------------------------------------------------------------------------------------------


# File extensions picked up from each test-set directory.
_TEST_EXTENSIONS = ('.mat', '.jpg', '.bmp', '.png')


def _load_test_image(path):
    """Load one test image as a float32 array scaled to [0, 1].

    ``.mat`` files must store the image under the key ``'img'`` and are
    min-max normalised; every other supported format is read with ``imread``
    and divided by 255 (8-bit data assumed).
    """
    if os.path.splitext(path)[1].lower() == '.mat':
        img = np.asarray(loadmat(path)['img'], dtype=np.float32)
        img = img - np.min(img)
        peak = np.max(img)
        # A constant image would otherwise divide 0 by 0.
        return img / peak if peak > 0 else img
    return np.array(imread(path), dtype=np.float32) / 255.0


def _ssim(ref, img):
    """SSIM of ``img`` against ``ref`` for images scaled to [0, 1].

    ``data_range`` is passed explicitly because recent scikit-image releases
    refuse to infer it for floating-point inputs (older releases derived it
    from the dtype rather than from the [0, 1] scaling used here).
    """
    extra = {'channel_axis': -1} if ref.ndim == 3 else {}
    return ssim(ref, img, data_range=1.0, **extra)


if __name__ == '__main__':

    model_path = os.path.join(args.model_dir, args.model_name)
    if use_model:
        # Rebuild the architecture in code; only the weights come from file.
        model = DIVA2D(depth=10, filters=64, image_channels=1, use_bnorm=True)
        model.load_weights(model_path)
        log('load trained model architecture')
    else:
        print('Model- ', args.model)
        model = load_model(model_path, compile=False)
        log('load trained model')

    # makedirs creates missing parents and tolerates an existing directory.
    os.makedirs(args.result_dir, exist_ok=True)

    for set_cur in args.set_names:
        set_path = os.path.join(args.set_dir, set_cur)
        set_result_dir = os.path.join(args.result_dir, set_cur)
        os.makedirs(set_result_dir, exist_ok=True)

        snrs_in, psnrs_in, ssims_in = [], [], []   # metrics of the noisy inputs
        psnrs, ssims = [], []                      # metrics of the denoised outputs

        # sorted() keeps the per-image order in the results files reproducible.
        for im in sorted(os.listdir(set_path)):
            name, ext = os.path.splitext(im)
            if ext.lower() not in _TEST_EXTENSIONS:
                continue

            x = _load_test_image(os.path.join(set_path, im))

            # Gaussian noise case: add noise without clipping.
            y = (x + np.random.normal(0, args.sigma / 255.0, x.shape)).astype(np.float32)

            start_time = time.time()
            x_ = model.predict(to_tensor(y))  # inference
            elapsed_time = time.time() - start_time
            print('%10s : image:%10s : time:%2.4f second' % (set_cur, im, elapsed_time))

            x_ = from_tensor(x_)

            snr_y = snr(x, y)          # input SNR
            psnr_y = psnr(x, y)        # input PSNR
            ssim_y = _ssim(x, y)       # input SSIM
            psnr_x_ = psnr(x, x_)      # output PSNR
            ssim_x_ = _ssim(x, x_)     # output SSIM
            print('%10s : psnr = %2.4f : ssim = %1.4f' % (set_cur, psnr_x_, ssim_x_))

            if args.save_result:
                # noisy | denoised | clean, side by side
                show(np.hstack((y, x_, x)), figsize=(14, 4), cbar=True)

                if ext.lower() == '.mat':
                    # save_result writes non-text files with imsave, which
                    # cannot produce .mat files; keep MATLAB data as .mat.
                    savemat(os.path.join(set_result_dir, name + '_denoised_DIVA2D_sigma' + str(args.sigma) + '.mat'),
                            {'img_noi': y, 'img_denoi': x_})
                else:
                    save_result(y, path=os.path.join(set_result_dir, name + '_noisy_' + str(args.sigma) + ext))  # noisy image
                    save_result(x_, path=os.path.join(set_result_dir, name + '_denoised_DIVA_' + str(args.sigma) + ext))  # denoised image

            snrs_in.append(snr_y)
            psnrs_in.append(psnr_y)
            ssims_in.append(ssim_y)
            psnrs.append(psnr_x_)
            ssims.append(ssim_x_)

        if not psnrs:
            # Averaging empty lists would only produce NaN and warnings.
            log('Dataset {0}: no {1} files found in {2}'.format(set_cur, '/'.join(_TEST_EXTENSIONS), set_path))
            continue

        snr_in_avg = np.mean(snrs_in)     # average input SNR
        psnr_in_avg = np.mean(psnrs_in)   # average input PSNR
        ssim_in_avg = np.mean(ssims_in)   # average input SSIM
        psnr_avg = np.mean(psnrs)         # average output PSNR
        ssim_avg = np.mean(ssims)         # average output SSIM

        # The averages are appended as the last entry of each results list.
        snrs_in.append(snr_in_avg)
        psnrs_in.append(psnr_in_avg)
        ssims_in.append(ssim_in_avg)
        psnrs.append(psnr_avg)
        ssims.append(ssim_avg)

        if args.save_result:
            save_result(np.hstack((psnrs, ssims)), path=os.path.join(set_result_dir, 'sigma_' + str(args.sigma) + '_results.txt'))
            save_result(np.hstack((snrs_in, psnrs_in, ssims_in)), path=os.path.join(set_result_dir, 'sigma_' + str(args.sigma) + '_results_input.txt'))

        log('Datset: {0:10s} \n  SNR_in = {1:2.2f} dB, PSNR_in = {2:2.2f} dB, SSIM_in = {3:1.4f}'.format(set_cur, snr_in_avg, psnr_in_avg, ssim_in_avg))
        log('PSNR = {0:2.2f} dB, SSIM = {1:1.4f}'.format(psnr_avg, ssim_avg))

