In [None]:
# Mount Google Drive at /content/my_drive; the dataset paths and the
# checkpoint directory configured later in this notebook live under it.
from google.colab import drive
drive.mount('/content/my_drive')

Mounted at /content/my_drive


In [None]:
!nvidia-smi

Mon May 27 13:56:58 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0              45W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
%pdb

Automatic pdb calling has been turned ON


In [None]:
'''Import Libraries'''
import argparse
from argparse import ArgumentParser
import glob
import cv2
import re
import random
import math
import os, glob, datetime
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from keras.layers import  Input,Conv2D,BatchNormalization,Activation,Subtract, Reshape, Attention
from keras.models import Model, load_model
from tensorflow.python.keras.utils import conv_utils
from keras.callbacks import CSVLogger, ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.optimizers import Adam

#import data_generator as dg
import keras.backend as K
import skimage
from skimage.metrics import structural_similarity as ssim
from skimage.io import imread, imsave


In [None]:
# Sanity check: show which GPU devices TensorFlow can see before building the model.
tf.config.experimental.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [None]:
'''Set Parameters'''
## Params
# Every tunable setting is collected in one argparse namespace (`args`) that
# the rest of the notebook reads. The defaults below are the values used when
# no option is supplied.
parser = argparse.ArgumentParser()
# `args.model` is only used to build the name of the checkpoint directory.
parser.add_argument('--model', default='Deep_DeQuIP', type=str, help='choose a type of model')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')

# Folders holding the paired clean / noisy training images (on the mounted Drive).
parser.add_argument('--train_data_clean', default='/content/my_drive/MyDrive/CNRS Research/data/400_CLEAN', type=str, help='path of train data clean')
parser.add_argument('--train_data_noisy', default='/content/my_drive/MyDrive/CNRS Research/data/400_FOCUS', type=str, help='path of train data noisy')

parser.add_argument('--kernel_size', default=5, type=int, help='Hamiltonian kernel size')
parser.add_argument('--patches_size', default=50, type=int, help='patch size')

parser.add_argument('--epoch', default=50, type=int, help='number of train epoches')
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
parser.add_argument('--save_every', default=1, type=int, help='save model at every x epoches')
# Accepted and never read. Presumably this swallows the `-f <connection file>`
# argument the notebook kernel is launched with, so parse_args() does not
# abort -- TODO confirm.
parser.add_argument('-f', '--file', required=False)

# NOTE(review): in the cells visible here `args.patches_size`, `args.epoch`
# and `args.save_every` are never read. The patch size comes from the
# `patch_size` global, and the epoch count / checkpoint period are hard-coded
# in the training cell.
args = parser.parse_args()
#args.save_every = args.epoch


In [None]:
'''Set Save Dir for Models'''
# Checkpoints and the CSV training log of this run are written to a
# model-specific folder on the mounted Drive.
save_dir = os.path.join(
    '/content/my_drive/MyDrive/CNRS Research',
    args.model + '_NLResAttn_Memory',
)

# First run only: report the folder that is about to be created, then create it.
# os.mkdir (not makedirs) is kept on purpose: it fails if the parent is missing.
if not os.path.exists(save_dir):
    print(save_dir)
    os.mkdir(save_dir)


In [None]:
'''utility functions'''

def findLastCheckpoint(save_dir):
    """Return the highest epoch number among the checkpoints stored in `save_dir`.

    Checkpoints are expected to be named ``model_<epoch>.hdf5`` where
    ``<epoch>`` consists of digits only (e.g. ``model_007.hdf5``).

    Args:
        save_dir: directory that is searched (non-recursively).

    Returns:
        The largest epoch number found, or 0 when the directory contains no
        matching checkpoint (or does not exist).
    """
    # Match against the file name only, with the dots escaped and the pattern
    # anchored. A "model_" fragment elsewhere in the path cannot interfere, and
    # files such as "model_best.hdf5" are skipped instead of crashing int().
    checkpoint_name = re.compile(r'model_(\d+)\.hdf5$')

    epochs_exist = []
    for path in glob.glob(os.path.join(save_dir, 'model_*.hdf5')):
        match = checkpoint_name.match(os.path.basename(path))
        if match:
            epochs_exist.append(int(match.group(1)))

    return max(epochs_exist, default=0)

def log(*args,**kwargs):
    """Print `args` preceded by the current local time.

    The timestamp is formatted as ``YYYY-MM-DD HH:MM:SS:`` and passed to
    print() as the first positional argument. All keyword arguments are
    forwarded to print() unchanged.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:")
    print(timestamp, *args, **kwargs)

def lr_schedule(epoch):
    """Step-wise learning-rate schedule passed to LearningRateScheduler.

    The rate is derived from `args.lr`:
      * epoch <= 20 : args.lr
      * epoch <= 30 : args.lr / 10
      * otherwise   : args.lr / 20

    The chosen rate is logged through `log` and returned.
    """
    # The original listed a separate "epoch <= 40" tier, but it used the same
    # divisor (20) as the final tier, so the two are merged here.
    if epoch <= 20:
        divisor = 1
    elif epoch <= 30:
        divisor = 10
    else:
        divisor = 20

    lr = args.lr / divisor
    log('current learning rate is %2.8f' %lr)
    return lr


In [None]:
def train_datagen(epoch_iter=2000, epoch_num=5, batch_size=128, data_dir=args.train_data_noisy):
    """Endless generator of (noisy_batch, clean_batch) training pairs.

    On the first `next()` call the whole patch set is built once with
    `speckled_datagenerator(data_dir)` and scaled from uint8 to float32 in
    [0, 1]. After that the generator loops forever, reshuffling the patch
    order before every pass over the data.

    Args:
        epoch_iter: unused; kept so existing calls keep working.
        epoch_num: number of shuffled passes per cycle of the outer endless
            loop (it does not bound the generator).
        batch_size: number of patches per yielded batch.
        data_dir: directory containing the noisy training images.

    Yields:
        Tuple ``(noisy_batch, clean_batch)`` of float32 arrays indexed with
        the same shuffled patch indices.
    """
    # A generator body only starts running on the first next(), so loading
    # here is just as lazy as guarding it with a flag inside the loop.
    clean_data, noisy_data = speckled_datagenerator(data_dir)

    # Check against the batch size that is actually used for slicing below
    # (the original checked args.batch_size, which may differ from the argument).
    assert len(clean_data) % batch_size == 0, 'make sure the last iteration has a full batchsize, this is important if you use batch normalization!'

    # normalize the pixel values between 0 and 1
    clean_data = clean_data.astype('float32')/255.0
    noisy_data = noisy_data.astype('float32')/255.0

    indices = list(range(clean_data.shape[0]))

    while True:
        for _ in range(epoch_num):
            np.random.shuffle(indices)
            for start in range(0, len(indices), batch_size):
                batch_indices = indices[start:start+batch_size]
                yield noisy_data[batch_indices], clean_data[batch_indices]

In [None]:
def make_batch(data):
  """Flatten per-image patch lists into one uint8 array that is a whole number of batches.

  `data` is converted to a uint8 array and its first two axes are merged, with
  a trailing channel axis of size 1 appended. Leading entries are then dropped
  so the length is a multiple of the global `batch_size`.
  """
  stacked = np.array(data, dtype='uint8')
  shape = stacked.shape
  flat = stacked.reshape((shape[0]*shape[1], shape[2], shape[3], 1))

  # Number of patches that do not fill a complete batch; they are removed
  # from the front of the array.
  leftover = len(flat) % batch_size
  return np.delete(flat, range(leftover), axis = 0)

In [None]:
def speckled_datagenerator(data_dir, verbose=False):
    """Build the paired clean / noisy patch arrays used for training.

    Every ``*.png`` file in `data_dir` is passed to
    `gen_speckled_image_patches`, which returns the patches of that noisy
    image and of its clean counterpart. The per-image patch lists are then
    flattened and trimmed to a whole number of batches by `make_batch`.

    Args:
        data_dir: directory containing the noisy images.
        verbose: when True, print a progress line per processed image.

    Returns:
        Tuple ``(data_clean, data)`` of uint8 arrays with identical shapes
        (clean patches first, noisy patches second).

    Raises:
        FileNotFoundError: if `data_dir` contains no ``.png`` file.
    """
    # Sorted so the patch order (and therefore which patches get trimmed)
    # does not depend on the file system's listing order.
    file_list = sorted(glob.glob(data_dir+'/*.png'))
    if not file_list:
        # Without this the reshape below fails with an opaque IndexError.
        raise FileNotFoundError('no .png training images found in ' + str(data_dir))

    data = []
    data_clean = []

    # generate patches for all images in the directory
    for i, file_name in enumerate(file_list):
        clean_patch, patch = gen_speckled_image_patches(file_name)

        data.append(patch)
        data_clean.append(clean_patch)

        if verbose:
            print('image :',str(i+1)+'/'+ str(len(file_list)))

    # Same flatten-and-trim step for both sets, shared through make_batch
    # instead of two hand-copied versions of its body.
    data = make_batch(data)
    data_clean = make_batch(data_clean)

    print('-----training data finished-----')
    print('noisy image shape:',data.shape)
    print('clean image shape:',data_clean.shape)

    assert data.shape == data_clean.shape

    return data_clean, data

In [None]:
import matplotlib.pyplot as plt
def gen_speckled_image_patches(file_name):
    """Cut one noisy image and its clean counterpart into aligned, augmented patches.

    The clean image is looked up in `args.train_data_clean` under the part of
    the noisy file's base name that follows its last underscore. Both images
    are read as grayscale, resized for every factor in the global `scales`,
    and cut into `patch_size` x `patch_size` patches on a grid with step
    `stride`. Each patch pair is augmented `aug_times` times, using the same
    randomly drawn mode for the noisy and the clean patch.

    Args:
        file_name: path of the noisy image.

    Returns:
        Tuple ``(clean_patches, patches)`` of equally long lists of 2-D arrays.

    Raises:
        FileNotFoundError: if either image cannot be read.
    """
    # Split the base name only: an underscore in a directory name (such as
    # ".../400_FOCUS/") must not leak into the derived clean file name.
    last_name = os.path.basename(file_name).split('_')[-1]
    clean_image_file_name = os.path.join(args.train_data_clean, last_name)

    img = cv2.imread(file_name, 0) # noisy image
    clean_img = cv2.imread(clean_image_file_name, 0) # clean image

    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError('could not read noisy image: ' + str(file_name))
    if clean_img is None:
        raise FileNotFoundError('could not read clean image: ' + str(clean_image_file_name))

    h, w = img.shape

    patches = []
    clean_patches = []

    for s in scales: # scaling the images
        h_scaled, w_scaled = int(h*s), int(w*s)
        # cv2.resize expects dsize as (width, height); passing (h, w) would
        # transpose the target size for non-square images.
        img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC)
        clean_img_scaled = cv2.resize(clean_img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC)

        # extract patches on a regular grid; i walks rows, j walks columns
        for i in range(0, h_scaled-patch_size+1, stride):
            for j in range(0, w_scaled-patch_size+1, stride):
                patch = img_scaled[i:i+patch_size, j:j+patch_size]
                clean_patch = clean_img_scaled[i:i+patch_size, j:j+patch_size]

                # data augmentation: one shared random mode per pair keeps
                # the noisy and clean patches aligned
                for _ in range(aug_times):
                    mode_k = np.random.randint(0, 8)
                    patches.append(data_augmentation(patch, mode=mode_k))
                    clean_patches.append(data_augmentation(clean_patch, mode=mode_k))

    return clean_patches, patches

In [None]:
'''Data Augmentation'''
def data_augmentation(img, mode=0):
    """Apply one of eight flip/rotation augmentations to `img`.

    ``mode // 2`` is the number of 90-degree rotations (np.rot90) and an odd
    mode additionally flips the rotated image upside down (np.flipud):

        0: identity          1: flipud
        2: rot90             3: flipud(rot90)
        4: rot180            5: flipud(rot180)
        6: rot270            7: flipud(rot270)

    Args:
        img: array to transform.
        mode: integer in [0, 7] selecting the transform.

    Returns:
        The transformed array. For mode 0 the input object itself is returned.

    Raises:
        ValueError: if `mode` is not one of 0..7 (the original fell through
            and silently returned None).
    """
    if mode not in range(8):
        raise ValueError('mode must be an integer in [0, 7], got %r' % (mode,))

    quarter_turns, flip = divmod(int(mode), 2)
    out = np.rot90(img, k=quarter_turns) if quarter_turns else img
    return np.flipud(out) if flip else out

In [None]:

'''Show Images'''
def show(x,title=None,cbar=False,figsize=None):
    """Display `x` as a grayscale image in a new matplotlib figure.

    Args:
        x: image data handed to plt.imshow.
        title: optional figure title; skipped when falsy.
        cbar: when truthy, a colorbar is added.
        figsize: forwarded to plt.figure.
    """
    render_options = {'interpolation': 'nearest', 'cmap': 'gray'}

    plt.figure(figsize=figsize)
    plt.imshow(x, **render_options)

    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()

    plt.show()

In [None]:
'''Loss Function'''
def sum_squared_error(y_true, y_pred):
    """Half of the sum of squared differences between prediction and target.

    The sum runs over every element of the tensors (no axis is given), so the
    result is a single scalar rather than a per-sample value.
    """
    residual = y_pred - y_true
    return K.sum(K.square(residual)) / 2

In [None]:
import math
import tensorflow as tf
from tensorflow.keras import layers, models, activations, initializers

class NonLocalResAttentionBlock(tf.keras.layers.Layer):
    """Residual block: attention -> conv -> non-local block -> conv, added back to the input.

    Forward pass (see `call`):
        1. single-head MultiHeadAttention with the input as both query and value,
        2. Conv2D (+ optional BatchNormalization and ReLU),
        3. NonLocalBlock2D,
        4. Conv2D (+ optional BatchNormalization and ReLU),
        5. ``x + res * res_scale``.

    Because the result is added to the input, the input's channel count has
    to match `out_channels`.

    Args:
        in_channels: stored and reported by `get_config` only; no sub-layer
            is configured from it.
        out_channels: number of filters of both convolutions, the attention
            `key_dim` and the channel count given to the non-local block.
        kernel_size: kernel size of both convolutions.
        use_bias: whether the convolutions use a bias.
        bn: add BatchNormalization after each convolution.
        act: add a ReLU after each convolution.
        res_scale: factor applied to the residual branch before the addition.
        **kwargs: standard Keras layer arguments (``name``, ``dtype``,
            ``trainable``, ...) forwarded to the base class.
    """

    def __init__(self, in_channels, out_channels, kernel_size, use_bias=True, bn=True, act=True, res_scale=1, **kwargs):
        # **kwargs must be accepted and forwarded: get_config() below includes
        # the base-class config (name, trainable, dtype), and from_config()
        # passes those keys straight back into this constructor.
        super(NonLocalResAttentionBlock, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.use_bias = use_bias
        self.bn = bn
        self.act = act
        self.res_scale = res_scale

        self.multiheadattn = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=out_channels)
        self.conv1 = Conv2D(out_channels, kernel_size, padding='same', use_bias=use_bias)
        self.bn1 = BatchNormalization() if bn else None
        self.act1 = Activation('relu') if act else None
        self.nl = NonLocalBlock2D(out_channels)
        self.conv2 = Conv2D(out_channels, kernel_size, padding='same', use_bias=use_bias)
        self.bn2 = BatchNormalization() if bn else None
        self.act2 = Activation('relu') if act else None

    def call(self, x):
        """Run the block on `x` and return ``x + residual * res_scale``."""
        attended = self.multiheadattn(x, x)  # query = value = x

        res = self.conv1(attended)
        if self.bn1 is not None:
            res = self.bn1(res)
        if self.act1 is not None:
            res = self.act1(res)

        res = self.conv2(self.nl(res))
        if self.bn2 is not None:
            res = self.bn2(res)
        if self.act2 is not None:
            res = self.act2(res)

        # The skip connection uses the block input, not the attention output.
        return x + res * self.res_scale

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = {
            'in_channels': self.in_channels,
            'out_channels': self.out_channels,
            'kernel_size': self.kernel_size,
            'use_bias': self.use_bias,
            'bn': self.bn,
            'act': self.act,
            'res_scale': self.res_scale
        }
        base_config = super(NonLocalResAttentionBlock, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def conv(in_channels, out_channels, kernel_size, use_bias=True):
    """Create a 'same'-padded Conv2D layer with `out_channels` filters.

    `in_channels` is accepted but not used: the returned layer is configured
    from `out_channels`, `kernel_size` and `use_bias` only.
    """
    conv_layer = layers.Conv2D(out_channels, kernel_size, padding='same', use_bias=use_bias)
    return conv_layer


class NonLocalBlock2D(tf.keras.layers.Layer): # long-range dependency
    """Non-local block: every spatial position attends to every other position.

    Three 1x1 convolutions (`theta`, `phi`, `g`) project the input to
    `inter_channels` channels. The softmax of ``theta . phi^T`` weights the
    `g` features, a final 1x1 convolution `W` maps the result back to
    `in_channels`, and the input is added (residual connection).

    Args:
        in_channels: channel count of the input (and of the output).
        inter_channels: channel count of the internal projections; defaults
            to ``in_channels // 2`` (at least 1).
        use_bias: whether the 1x1 convolutions use a bias.
        **kwargs: standard Keras layer arguments forwarded to the base class.
    """

    def __init__(self, in_channels, inter_channels=None, use_bias=True, **kwargs):
        super(NonLocalBlock2D, self).__init__(**kwargs)
        self.in_channels = in_channels
        # in_channels // 2 is 0 for a single-channel input; a convolution
        # needs at least one filter, so clamp the fallback to 1.
        self.inter_channels = inter_channels if inter_channels else max(1, in_channels // 2)
        self.use_bias = use_bias
        self.g = Conv2D(self.inter_channels, (1, 1), use_bias=use_bias)
        self.theta = Conv2D(self.inter_channels, (1, 1), use_bias=use_bias)
        self.phi = Conv2D(self.inter_channels, (1, 1), use_bias=use_bias)
        self.W = Conv2D(self.in_channels, (1, 1), use_bias=use_bias)

    def call(self, x):
        """Return ``W(softmax(theta(x) . phi(x)^T) . g(x)) + x``."""
        # Dynamic shape: height and width may be unknown when the graph is built.
        shape = tf.shape(x)
        batch_size, h, w = shape[0], shape[1], shape[2]

        # Flatten the spatial axes: (batch, h*w, inter_channels).
        g_x = tf.reshape(self.g(x), (batch_size, -1, self.inter_channels))
        theta_x = tf.reshape(self.theta(x), (batch_size, -1, self.inter_channels))
        phi_x = tf.reshape(self.phi(x), (batch_size, -1, self.inter_channels))

        # Pairwise similarities between positions, normalized over the last axis.
        theta_phi = tf.matmul(theta_x, phi_x, transpose_b=True)
        theta_phi = tf.nn.softmax(theta_phi)

        y = tf.matmul(theta_phi, g_x)
        y = tf.reshape(y, (batch_size, h, w, self.inter_channels))
        W_y = self.W(y)
        z = W_y + x
        return z

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = {
            'in_channels': self.in_channels,
            'inter_channels': self.inter_channels,
            'use_bias': self.use_bias,
        }
        base_config = super(NonLocalBlock2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


In [None]:
'''DIVA2D Model'''
def DIVA2D(depth,filters=64,image_channels=1, kernel_size=5, use_bnorm=True):
    """Build the DIVA2D denoising network as a Keras functional model.

    Pipeline:
        1. `initial_patches`: Conv2D + ReLU feature extraction of the input.
        2. `inter`: NonLocalResAttentionBlock + ReLU on those features.
        3. Two max-pooled summaries (of the initial features and of `inter`)
           are handed to Hamiltonian_Conv2D as `kernel_3` / `kernel_4`.
        4. `proj_coef`: Hamiltonian_Conv2D applied to `inter`.
        5. `depth` repetitions of Conv2D -> BatchNormalization -> ReLU.
        6. `inv_trans`: Conv2D back to `image_channels`.
        7. Three further Conv2D layers (`deconv*`).
        8. The result is subtracted from the network input.

    Args:
        depth: number of Conv/BN/ReLU stages in step 5.
        filters: number of feature channels.
        image_channels: channels of the input and output image.
        kernel_size: side length of all square convolution kernels.
        use_bnorm: currently ignored; BatchNormalization is always applied.

    Returns:
        An uncompiled ``Model`` mapping the input image to ``input - x``.
    """
    square_kernel = (kernel_size, kernel_size)

    layer_count = 0
    inpt = Input(shape=(None,None,image_channels),name = 'input'+str(layer_count))

    # Initial feature maps; 'same' padding keeps the spatial size of the input.
    initial_patches = Conv2D(filters=filters, kernel_size=square_kernel, strides=(1,1),kernel_initializer='Orthogonal', padding='same',name = 'initial_patches')(inpt)
    initial_patches = Activation('relu',name = 'initial_patch_acti')(initial_patches)

    # The first argument is the channel count of the incoming features.
    # (The original passed the helper function `conv` here by mistake.)
    inter = NonLocalResAttentionBlock(filters, filters, kernel_size, use_bias=True, bn=False, act=True, res_scale=1)(initial_patches)
    inter = Activation('relu',name = 'inter_acti')(inter)

    # Get contributions of the original potential in the Hamiltonian kernel ANEESH: Ja from DIVA Diagram
    ori_poten_kernel = tf.keras.layers.MaxPooling2D (pool_size=(21,21), strides=(15,15), padding='same', name = 'ori_poten_ker', data_format=None )(initial_patches)

    # Get contributions of the interactions in the Hamiltonian kernel ANEESH: Ia from DIVA Diagram
    inter_kernel = tf.keras.layers.MaxPooling2D (pool_size=(21,21), strides=(15,15), padding='same', name = 'inter_ker', data_format=None )(inter)

    # Get projection coefficients of the initial patches on the Hamiltonian kernel
    x = Hamiltonian_Conv2D(filters=filters, kernel_size=square_kernel, kernel_3 = ori_poten_kernel, kernel_4 = inter_kernel, strides=(1,1), activation='relu',
                              kernel_initializer='Orthogonal', padding='same', name = 'proj_coef')(inter)

    # Do Thresholding (depth depends on the noise intensity)
    for i in range(depth):
        layer_count += 1
        x = Conv2D(filters=filters, kernel_size=square_kernel, strides=(1,1),kernel_initializer='Orthogonal', padding='same',use_bias = False,name = 'conv'+str(layer_count))(x)

        layer_count += 1
        x = BatchNormalization(axis=3, momentum=0.1,epsilon=0.0001, name = 'bn'+str(layer_count))(x)

        # Thresholding
        x = Activation('relu',name = 'Thresholding'+str(layer_count))(x)

    # Inverse projection
    x = Conv2D(filters=image_channels, kernel_size=square_kernel, strides=(1,1), kernel_initializer='Orthogonal',padding='same',use_bias = False,name = 'inv_trans')(x)

    # Deconvolution layers. They use this function's `kernel_size` argument
    # like every other layer (the original read the global args.kernel_size
    # here; both default to 5).
    layer_count += 1
    x = Conv2D(filters=filters, kernel_size=square_kernel, strides=(1,1), kernel_initializer='Orthogonal',padding='same',use_bias = False,name = 'deconv'+str(layer_count))(x)
    layer_count += 1
    x = Conv2D(filters=filters, kernel_size=square_kernel, strides=(1,1), kernel_initializer='Orthogonal',padding='same',use_bias = False,name = 'deconv'+str(layer_count))(x)
    layer_count += 1
    x = Conv2D(filters=image_channels, kernel_size=square_kernel, strides=(1,1), kernel_initializer='Orthogonal',padding='same',use_bias = False,name = 'deconv'+str(layer_count))(x)

    # Residual formulation: output = network input - x.
    x = Subtract(name = 'subtract')([inpt, x])

    model = Model(inputs=inpt, outputs=x)

    return model

In [None]:
'''Hamiltonian convolution layer'''
class Hamiltonian_Conv2D(Conv2D):
    """Conv2D subclass whose `build` assembles a "Hamiltonian" kernel.

    In `build`, a fixed 3x3 stencil is multiplied element-wise by a trainable
    weight of the same shape, and the two externally supplied tensors
    `kernel_3` and `kernel_4` are added. The result is stored in
    `self.kernel`, after which the parent class's `build` is invoked.

    NOTE(review): stock Keras `Conv2D.build` is expected to create its own
    `self.kernel` weight. If it does, the kernel assembled here is replaced
    before `call` ever runs and the layer behaves like a plain convolution --
    confirm which kernel is actually used.

    NOTE(review): `call` returns the raw convolution result; the `activation`
    handed to the parent constructor is not applied in this override.
    """

    def __init__(self, filters, kernel_size, kernel_3=None, kernel_4=None, activation=None, use_bias = False, **kwargs):
        """Store the extra kernel terms and initialise the parent Conv2D.

        Args:
            filters: number of output filters.
            kernel_size: int or pair, normalized to a 2-tuple.
            kernel_3: tensor added to the kernel in `build` (original potential).
            kernel_4: tensor added to the kernel in `build` (interaction).
            activation: forwarded to the parent constructor.
            use_bias: accepted but ignored; the parent is always created with
                use_bias=False.
            **kwargs: remaining Conv2D arguments (strides, padding, name, ...).
        """

        self.rank = 2               # Dimension of the kernel
        self.num_filters = filters  # Number of filter in the convolution layer
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, self.rank, 'kernel_size')
        self.kernel_3 = kernel_3    # Weights from original potential
        self.kernel_4 = kernel_4    # Weights from interaction

        super(Hamiltonian_Conv2D, self).__init__(self.num_filters, self.kernel_size,
              activation=activation, use_bias=False, **kwargs)

    def build(self, input_shape):
        """Create the trainable weight and assemble `self.kernel`.

        Raises:
            ValueError: if the channel dimension of `input_shape` is None.
        """
        if K.image_data_format() == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                     'should be defined. Found `None`.')

        #don't use bias:
        self.bias = None

        #consider the layer built
        # (set here and again just before the parent build() call below)
        self.built = True


        # Define nabla operator
        # Fixed (non-trainable) 3x3 stencil.
        weights_1 = tf.constant([[ 2.,-1., 0.],
                                 [-1., 4.,-1.],
                                 [ 0.,-1., 2.]])


        # Replicate the stencil along a third axis: shape (3, 3, num_filters).
        weights_1 = tf.reshape(weights_1 , [3,3, 1])
        weights_1 = tf.repeat(weights_1 , repeats=self.num_filters, axis=2)
        #print('kernel shape of weights_1:',weights_1.get_shape())

        # Define Weights for h^2/2m  (size should be same as the nabla operator)
        weights_2 = self.add_weight(shape=weights_1.get_shape(),
                                      initializer= 'Orthogonal',
                                      name='kernel_h^2/2m',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        #print('kernel shape of weights_2:',weights_2.get_shape())
        #reshaped_weights = tf.expand_dims(weights_1*weights_2, axis=0)


        # Define the Hamiltonian kernel
        # Element-wise product of stencil and trainable weight, plus the two
        # tensors supplied at construction time (relies on broadcasting).
        self.kernel = weights_1*weights_2 + self.kernel_3 + self.kernel_4
        #print('self.kernel',self.kernel.get_shape())

        self.built = True
        # NOTE(review): the parent build() runs after self.kernel was assigned
        # above; verify that it does not overwrite self.kernel.
        super(Hamiltonian_Conv2D, self).build(input_shape)

    # Do the 2D convolution using the Hamiltonian kernel
    def convolution_op(self, inputs, kernel):
        """Convolve `inputs` with `kernel` using this layer's strides, padding and dilation."""
        if self.padding == "causal":
            tf_padding = "VALID"  # Causal padding handled in `call`.
        elif isinstance(self.padding, str):
            tf_padding = self.padding.upper()
        else:
            tf_padding = self.padding


        return tf.nn.convolution(
            inputs,
            kernel,
            strides=list(self.strides),
            padding=tf_padding,
            dilations=list(self.dilation_rate),
            name=self.__class__.__name__,
        )

    def call(self, inputs):
        """Return the convolution of `inputs` with `self.kernel` (no bias, no activation applied here)."""
        outputs = self.convolution_op(inputs, self.kernel)
        return outputs


In [None]:
# Build the network once so its architecture can be inspected in the next cell.
# NOTE: the training cell constructs a fresh model with the same arguments.
model = DIVA2D(depth=15,filters=96,image_channels=1,use_bnorm=True)


In [None]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input0 (InputLayer)         [(None, None, None, 1)]      0         []                            
                                                                                                  
 initial_patches (Conv2D)    (None, None, None, 96)       2496      ['input0[0][0]']              
                                                                                                  
 initial_patch_acti (Activa  (None, None, None, 96)       0         ['initial_patches[0][0]']     
 tion)                                                                                            
                                                                                                  
 non_local_res_attention_bl  (None, None, None, 96)       516912    ['initial_patch_acti[0][0]

In [None]:
'''Hyperparameters'''
# Globals read by the data pipeline, so this cell has to run before training starts.
# Patches are patch_size x patch_size pixels, sampled every `stride` pixels.
# NOTE(review): patch_size (25) is independent of the `--patches_size`
# option (default 50), which is not read in the cells visible here.
patch_size, stride = 25, 10
# Number of random augmentations drawn for each extracted patch.
aug_times = 1
# Resize factors applied to each image before patch extraction.
scales = [1] # [1, 0.9, 0.8, 0.7]
# Used by make_batch / speckled_datagenerator to trim the patch set to whole
# batches; keep it equal to args.batch_size, which train_datagen asserts against.
batch_size = 128

In [None]:
if __name__ == '__main__':
    # Training entry point: build the model, resume from the newest checkpoint
    # in `save_dir` if there is one, then train from the patch generator.

    # model selection
    model = DIVA2D(depth=15,filters=96,image_channels=1,use_bnorm=True)
    #model.summary()

    # Resume: checkpoints are named model_<epoch>.hdf5; 0 means "start fresh".
    initial_epoch = findLastCheckpoint(save_dir=save_dir)
    if initial_epoch > 0:
        print('resuming by loading epoch %03d'%initial_epoch)
        model.load_weights(os.path.join(save_dir,'model_%03d.hdf5'%initial_epoch))

    # compile the model
    # The optimizer starts from the configured args.lr (default 1e-3) instead
    # of a second hard-coded 0.001. The LearningRateScheduler callback below
    # derives its schedule from the same value.
    model.compile(optimizer=Adam(args.lr), loss= tf.keras.losses.MeanSquaredError(), #tf.keras.losses.CosineSimilarity (axis=-1, reduction="auto", name="cosine_similarity"),
                  metrics=[tf.keras.metrics.MeanSquaredError(),
                           tf.keras.metrics.RootMeanSquaredError(),
                           tf.keras.metrics.MeanSquaredLogarithmicError(),
                           tf.keras.metrics.MeanAbsoluteError(),
                           sum_squared_error])

    # tf.keras.metrics.MeanAbsolutePercentageError(), tf.keras.metrics.CosineSimilarity(name="cosine_similarity", dtype=None, axis=-1),
    # tf.keras.metrics.LogCoshError(),

    # use call back functions
    # `period` is deprecated in ModelCheckpoint; save_freq='epoch' gives the
    # same "save after every epoch" behaviour.
    checkpointer = ModelCheckpoint(os.path.join(save_dir,'model_{epoch:03d}.hdf5'),
                verbose=1, save_weights_only=False, save_freq='epoch')
    csv_logger = CSVLogger(os.path.join(save_dir,'log.csv'), append=True, separator=',')
    lr_scheduler = LearningRateScheduler(lr_schedule)

    print('batch_size = ',args.batch_size)
    # NOTE(review): steps_per_epoch and epochs are hard-coded; args.epoch is
    # not consulted here.
    history = model.fit(train_datagen(batch_size=args.batch_size),
                steps_per_epoch=4750, epochs=80, verbose=1, initial_epoch=initial_epoch,
                callbacks=[checkpointer,csv_logger,lr_scheduler])
    #steps_per_epoch = 7000, epochs = 50
    #steps_per_epoch = 7000, epochs = 50

In [None]:
# Release the Colab runtime once the cells above have finished.
# Presumably this stops an idle GPU session after training -- TODO confirm.
# Everything not saved to Drive is expected to be lost with the runtime.
from google.colab import runtime
runtime.unassign()