**CV-DESPECKNET**

**Version**: v1.2 \\
**Date**: 2021-01-12 \\
**Author**: Mullissa A.G. \\
**Description**: This script trains a complex-valued multistream fully convolutional network for despeckling a polarimetric SAR covariance matrix as discussed in our paper A. G. Mullissa, C. Persello and J. Reiche, "Despeckling Polarimetric SAR Data Using a Multistream Complex-Valued Fully Convolutional Network," in IEEE Geoscience and Remote Sensing Letters, doi: 10.1109/LGRS.2021.3066311. Some utility functions are adopted from https://github.com/cszn/DnCNN


**SETTING UP THE ENVIORNMENT**

In [1]:
#mount drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
#check GPU
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Sun Mar 28 15:44:02 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.56       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
#add path
import sys
sys.path.append('/content/drive/My Drive')

In [4]:
#Install the right tensorflow and keras versions
%tensorflow_version 1.x
!pip uninstall keras
!pip install keras==2.2.3

TensorFlow 1.x selected.
Uninstalling Keras-2.3.1:
  Would remove:
    /tensorflow-1.15.2/python3.7/Keras-2.3.1.dist-info/*
    /tensorflow-1.15.2/python3.7/docs/*
    /tensorflow-1.15.2/python3.7/keras/*
Proceed (y/n)? y
  Successfully uninstalled Keras-2.3.1
Collecting keras==2.2.3
[?25l  Downloading https://files.pythonhosted.org/packages/06/ea/ad52366ce566f7b54d36834f98868f743ea81a416b3665459a9728287728/Keras-2.2.3-py2.py3-none-any.whl (312kB)
[K     |████████████████████████████████| 317kB 18.5MB/s 
Installing collected packages: keras
  Found existing installation: Keras 2.4.3
    Uninstalling Keras-2.4.3:
      Successfully uninstalled Keras-2.4.3
Successfully installed keras-2.2.3


**HELPERS**

In [6]:
import glob
import os
import cv2
import numpy as np
import tifffile

patch_size, stride = 40, 9
aug_times = 1
scales = [1, 0.9, 0.8, 0.7]
batch_size = 128

def gen_patches(file_name):

    # read image
    img = tifffile.imread(file_name) 
    img = np.array(img)
    h, w, d = img.shape
    patches = []
    for s in scales:
        h_scaled, w_scaled = int(h*s),int(w*s)
        img_scaled = cv2.resize(img, (h_scaled,w_scaled), interpolation=cv2.INTER_CUBIC)
        # extract patches
        for i in range(0, h_scaled-patch_size+1, stride):
            for j in range(0, w_scaled-patch_size+1, stride):
                x = img_scaled[i:i+patch_size, j:j+patch_size,:]
                patches.append(x)        

                
    return patches

def make_dataTensor(data_dir,verbose=False):
    
    file_list = glob.glob(data_dir+'/*.tif')  # get name list of all .tif files
    # initrialize
    data = []
    # generate patches
    for i in range(len(file_list)):
        patch = gen_patches(file_list[i])
        data.append(patch)
        if verbose:
            print(str(i+1)+'/'+ str(len(file_list)) + ' is done ^_^')
    data = np.array(data)
    data = data.reshape((data.shape[0]*data.shape[1],data.shape[2],data.shape[3],6))
    discard_n = len(data)-len(data)//batch_size*batch_size;
    data = np.delete(data,range(discard_n),axis = 0)
    print("Finished generating data from {}".format(data_dir))
    return data

def get_steps(data_dir, batch_size=128):
    if os.path.isfile(data_dir):
        noisy_files = [data_dir]
    else:
        noisy_files = glob.glob(data_dir + '/*.tif')
    num = 0
    #get number of steps per epoch to use in training
    for data_file in noisy_files:
        xs = make_dataTensor(data_dir)
        if xs is not None: 
            num += len(xs)
    print("total number of patches: {}".format(num))
    print("steps per epoch: {}".format(num//batch_size))
    print("")
    return num // batch_size

**DO THE JOB**

In [7]:

import complexnn
import argparse
import re
import os, glob, datetime
from keras.layers import  Input,Conv2D,BatchNormalization,Activation,Multiply, Add
from keras.models import Model, load_model
from keras.callbacks import CSVLogger, ModelCheckpoint, LearningRateScheduler
from keras.optimizers import Adam
import keras.backend as K


save_dir = os.path.join('models','/content/drive/My Drive/modelCPLX_despecknet_SSE') 

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

def cv-deSpeckNet(depth,filters=48,image_channels=6, use_bnorm=True):
    layer_count = 0
    inpt = Input(shape=(None,None,image_channels),name = 'input'+str(layer_count))
    # 1st layer, CV-Conv+Crelu
    layer_count += 1
    x0 = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1), activation='relu', padding='same',name = 'conv'+str(layer_count))(inpt)
    # depth-2 layers, CV-Conv+CV-BN+Crelu
    for i in range(depth-2):
        layer_count += 1
        x0 = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1),activation='relu', padding='same',name = 'conv'+str(layer_count))(x0)
        if use_bnorm:
            layer_count += 1
        x0 = complexnn.bn.ComplexBatchNormalization(name = 'bn'+str(layer_count))(x0)
    # last layer, CV-Conv+Crelu
    layer_count += 1
    x0 = complexnn.conv.ComplexConv2D(filters=3, kernel_size=(3,3), strides=(1,1),padding='same',name = 'speckle'+str(1))(x0)
    layer_count += 1
    
    # 1st layer, CV-Conv+Crelu
    x = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1), activation='relu', padding='same',name = 'conv'+str(layer_count))(inpt)
    # depth-2 layers, CV-Conv+CV-BN+Crelu
    for i in range(depth-2):
        layer_count += 1
        x = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1),activation='relu', padding='same',name = 'conv'+str(layer_count))(x)
        if use_bnorm:
            layer_count += 1
        x = complexnn.bn.ComplexBatchNormalization(name = 'bn'+str(layer_count))(x)
    # last layer, CV-Conv
    layer_count += 1
    x = complexnn.conv.ComplexConv2D(filters=3, kernel_size=(3,3), strides=(1,1),padding='same',name = 'clean'+str(1))(x)
    layer_count += 1
    x_orig = Add(name = 'noisy' +  str(1))([x0,x])
    
    model = Model(inputs=inpt, outputs=[x,x_orig])
    
    return model


def findLastCheckpoint(save_dir):
    file_list = glob.glob(os.path.join(save_dir,'model_*.hdf5'))  # get name list of all .hdf5 files
    #file_list = os.listdir(save_dir)
    if file_list:
        epochs_exist = []
        for file_ in file_list:
            result = re.findall(".*model_(.*).hdf5.*",file_)
            #print(result[0])
            epochs_exist.append(int(result[0]))
        initial_epoch=max(epochs_exist)   
    else:
        initial_epoch = 0
    return initial_epoch

def log(args,kwargs):
     print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"),args,kwargs)

def lr_schedule(epoch):
    initial_lr = 1e-3
    if epoch<=30:
        lr = initial_lr
    elif epoch<=60:
        lr = initial_lr/10
    elif epoch<=80:
        lr = initial_lr/20 
    else:
        lr = initial_lr/20 
    #log('current learning rate is %2.8f' %lr)
    return lr

def train_datagen(epoch_iter=2000,epoch_num=5,batch_size=64,data_dir='/content/drive/My Drive/data_cplx/New_train',label_dir='/content/drive/My Drive/data_cplx/New_label_final'):
    while(True):
        n_count = 0
        if n_count == 0:
            #print(n_count)
            xs = make_dataTensor(data_dir)
            xy = make_dataTensor(label_dir)
            assert len(xs)%batch_size ==0, \
            log('make sure the last iteration has a full batchsize, this is important if you use batch normalization!')
            xs = xs.astype('float32')
            xy = xy.astype('float32')
            indices = list(range(xs.shape[0]))
            n_count = 1
        for _ in range(epoch_num):
            np.random.shuffle(indices)    # shuffle
            for i in range(0, len(indices), batch_size):
                batch_x = xs[indices[i:i+batch_size]]
                batch_y = xy[indices[i:i+batch_size]]
                yield batch_x, [batch_y, batch_x]
        
# sum square error loss function
def sum_squared_error(y_true, y_pred):
    return K.sum(K.square(y_pred - y_true))/2
    
if __name__ == '__main__':
    # model selection
    model = cv-deSpeckNet(depth=17,filters=48,image_channels=6,use_bnorm=True)
    model.summary()
    
    # load the last model in matconvnet style
    initial_epoch = findLastCheckpoint(save_dir=save_dir)
    if initial_epoch > 0:  
        print('resuming by loading epoch %03d'%initial_epoch)
        model = load_model(os.path.join(save_dir,'model_%03d.hdf5'%initial_epoch), custom_objects={'ComplexConv2D': complexnn.conv.ComplexConv2D, 'ComplexBatchNormalization': complexnn.bn.ComplexBatchNormalization, 'sum_squared_error': sum_squared_error})

    loss_funcs = {
        'clean1': sum_squared_error,
        'noisy1' : sum_squared_error}
    
    loss_weights = {'clean1': 100.0, 'noisy1': 1.0}
    
    # compile the model
    model.compile(optimizer=Adam(0.001), loss=loss_funcs, loss_weights=loss_weights)
    
    # use call back functions
    checkpointer = ModelCheckpoint(os.path.join(save_dir,'model_{epoch:03d}.hdf5'), 
                verbose=1, save_weights_only=False, period=1)
    csv_logger = CSVLogger(os.path.join(save_dir,'log.csv'), append=True, separator=',')
    lr_scheduler = LearningRateScheduler(lr_schedule)

    nsteps = get_steps(data_dir='/content/drive/My Drive/data_cplx/New_train', batch_size=128)
    
    history = model.fit_generator(train_datagen(batch_size=64),
                steps_per_epoch=nsteps, epochs=52, verbose=1, initial_epoch=initial_epoch,
                callbacks=[checkpointer,csv_logger,lr_scheduler])



Using TensorFlow backend.



Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input0 (InputLayer)             (None, None, None, 6 0                                            
__________________________________________________________________________________________________
conv33 (ComplexConv2D)          (None, None, None, 9 2688        input0[0][0]                     
__________________________________________________________________________________________________
conv1 (ComplexConv2D)           (None, None, None, 9 2688        input0[0][0]                     
__________________________________________________________________________________________________
conv34 (ComplexConv2D)          (None, None, None, 9 4

  self.data = h5py.File(path,)









Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Finished generating data from /content/drive/My Drive/data_cplx/New_train
Finished generating data from /content/drive/My Drive/data_cplx/New_train
total number of patches: 133376
steps per epoch: 1042

Epoch 52/52
Finished generating data from /content/drive/My Drive/data_cplx/New_train
Finished generating data from /content/drive/My Drive/data_cplx/New_label_final
 226/1042 [=====>........................] - ETA: 18:29 - loss: 435660.3356 - clean1_loss: 4335.5028 - noisy1_loss: 2110.0556

KeyboardInterrupt: ignored