In [20]:
import warnings
warnings.filterwarnings('ignore')

# %%capture --no-stderr
# %%capture output
# https://ipython.readthedocs.io/en/stable/interactive/magics.html#cellmagic-capture

Reference implementation (LSC-CNN): https://github.com/TIMOLEEGO/LSC-CNN

Associated paper (IEEE Xplore): https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9432740

# Utilities (weight initialisation, metrics, augmentation)

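Older code imported the PSNR helper from `skimage.measure.simple_metrics`, which no longer exists in current scikit-image; `skimage.metrics.peak_signal_noise_ratio` is imported below instead (see https://stackoverflow.com/questions/71182396/modulenotfounderror-no-module-named-skimage-measure-simple-metrics).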

In [21]:
import math
import torch
import torch.nn as nn
import numpy as np
from skimage.metrics import peak_signal_noise_ratio

In [22]:
def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        # nn.init.uniform(m.weight.data, 1.0, 0.02)
        m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025,0.025)
        nn.init.constant(m.bias.data, 0.0)

In [23]:
def batch_PSNR(img, imclean, data_range):
    """Average PSNR of a batch of reconstructions against their references.

    Args:
        img: reconstructed batch tensor, batch axis first; each sample is
            indexed as ``[i, :, :, :]``.
        imclean: reference batch tensor with the same layout.
        data_range: dynamic range passed through to ``peak_signal_noise_ratio``.
    Returns:
        Sum of per-sample PSNR values divided by the batch size.
    """
    restored = img.data.cpu().numpy().astype(np.float32)
    reference = imclean.data.cpu().numpy().astype(np.float32)
    n = restored.shape[0]
    scores = (
        peak_signal_noise_ratio(reference[idx, :, :, :], restored[idx, :, :, :], data_range=data_range)
        for idx in range(n)
    )
    return sum(scores) / n

def batch_SSIM(img, imclean, data_range):
    """Average SSIM of a batch of reconstructions against their references.

    SSIM is defined on 2-D images, so each sample is scored channel by channel
    with ``skimage.metrics.structural_similarity``. The channel scores are
    averaged, then the sample scores are averaged over the batch.

    Args:
        img: reconstructed batch tensor, batch axis first; each sample is
            iterated along its first (channel) axis.
        imclean: reference batch tensor with the same layout.
        data_range: dynamic range passed through to ``structural_similarity``.
    Returns:
        Mean SSIM over the batch, as a float.
    """
    # Local import: structural_similarity is not in the notebook's import cell.
    from skimage.metrics import structural_similarity

    restored = img.data.cpu().numpy().astype(np.float32)
    reference = imclean.data.cpu().numpy().astype(np.float32)
    n = restored.shape[0]
    total = 0.0
    for i in range(n):
        per_channel = [
            structural_similarity(ref_plane, out_plane, data_range=data_range)
            for ref_plane, out_plane in zip(reference[i], restored[i])
        ]
        total += float(np.mean(per_channel))
    return total / n

def batch_MSE(img, imclean, data_range):
    """Average mean-squared error of a batch against its references.

    Args:
        img: reconstructed batch tensor, batch axis first.
        imclean: reference batch tensor of the same shape.
        data_range: accepted only for signature parity with ``batch_PSNR`` /
            ``batch_SSIM``. MSE does not depend on it.
    Returns:
        Mean over samples of each sample's mean squared error, as a float.
    Raises:
        ValueError: if the batch is empty.
    """
    restored = img.data.cpu().numpy().astype(np.float32)
    reference = imclean.data.cpu().numpy().astype(np.float32)
    n = restored.shape[0]
    if n == 0:
        raise ValueError("batch_MSE: empty batch")
    # Flatten everything but the batch axis, take each sample's MSE, then
    # average across the batch (same per-sample weighting as batch_PSNR).
    per_sample = ((reference - restored) ** 2).reshape(n, -1).mean(axis=1)
    return float(per_sample.mean())

In [24]:
def data_augmentation(image, mode):
    """Apply one of eight rotation/flip transforms to a channel-first image.

    The image is moved to (rows, cols, channels) layout. Mode ``m`` in 0..7
    applies ``m // 2`` counter-clockwise quarter turns (``np.rot90``) and then,
    for odd ``m``, an up-down flip. The result is moved back to channel-first.
    Any other ``mode`` returns the image unchanged.

    Args:
        image: 3-D array, channels first.
        mode: transform selector, compared with ``==`` against 0..7.
    Returns:
        The transformed array, channels first.
    """
    # (mode, quarter turns, flip up-down), checked in order.
    transforms = (
        (0, 0, False),
        (1, 0, True),
        (2, 1, False),
        (3, 1, True),
        (4, 2, False),
        (5, 2, True),
        (6, 3, False),
        (7, 3, True),
    )
    hwc = np.transpose(image, (1, 2, 0))
    for key, turns, flip in transforms:
        if mode == key:
            if turns:
                hwc = np.rot90(hwc, k=turns)
            if flip:
                hwc = np.flipud(hwc)
            break
    return np.transpose(hwc, (2, 0, 1))


# Dataset preparation

In [25]:
import os
import os.path
import numpy as np
import random
import h5py
import torch
import cv2
import glob
import torch.utils.data as udata

In [26]:
def normalize(data):
    """Map 8-bit intensities to [0, 1] by dividing by 255."""
    scale = 255.0
    return data / scale

In [27]:
def Im2Patch(img, win, stride=1):
    """Extract every ``win`` x ``win`` patch on a ``stride`` grid.

    Args:
        img: array indexed as (channels, rows, cols).
        win: patch side length.
        stride: step between neighbouring patch origins along rows and cols.
    Returns:
        float32 array of shape (channels, win, win, n_patches). Patches are
        ordered row-major by their top-left corner.
    """
    n_ch, n_rows, n_cols = img.shape[0], img.shape[1], img.shape[2]
    grid = img[:, 0:n_rows - win + 1:stride, 0:n_cols - win + 1:stride]
    n_patches = grid.shape[1] * grid.shape[2]
    out = np.zeros([n_ch, win * win, n_patches], np.float32)
    # Slot k holds, for every patch, the pixel at offset (k // win, k % win).
    for k in range(win * win):
        di, dj = divmod(k, win)
        shifted = img[:, di:n_rows - win + di + 1:stride, dj:n_cols - win + dj + 1:stride]
        out[:, k, :] = np.array(shifted[:]).reshape(n_ch, n_patches)
    return out.reshape([n_ch, win, win, n_patches])

In [28]:
def prepare_data(data_path, patch_size, stride, aug_times=1,
                 train_h5='train.h5', val_h5='val.h5'):
    """Write the HDF5 training and validation sets from two image folders.

    Training: each file in ``<data_path>/train`` (sorted) is read with
    ``cv2.imread`` and rescaled by every factor in ``scales``. Only channel 0 is
    kept; it is normalised to [0, 1] and cut into ``patch_size`` x
    ``patch_size`` patches spaced ``stride`` apart. Every patch is stored as its
    own dataset, followed by ``aug_times - 1`` randomly augmented copies
    (``data_augmentation`` with a random mode in 1..7).

    Validation: each file in ``<data_path>/val`` (sorted) is stored whole,
    channel 0 only, normalised.

    Both output files are created or overwritten ('w' mode). Files that
    ``cv2.imread`` cannot decode (it returns ``None``) are skipped with a message.

    Args:
        data_path: root directory containing ``train/`` and ``val/``.
        patch_size: patch side length passed to ``Im2Patch``.
        stride: patch stride passed to ``Im2Patch``.
        aug_times: total copies per patch (1 = no augmentation).
        train_h5: output path for the training patches.
        val_h5: output path for the validation images.
    """
    scales = [1, 0.9, 0.8, 0.7]

    # train
    print('process training data')
    files = sorted(glob.glob(os.path.join(data_path, 'train', '*')))
    train_num = 0
    # 'w' is required: datasets are created below, which a read-only handle rejects.
    with h5py.File(train_h5, 'w') as h5f:
        for path in files:
            img = cv2.imread(path)
            if img is None:
                print("skipping unreadable file: %s" % path)
                continue
            h, w, _ = img.shape
            for scale in scales:
                # cv2.resize takes dsize as (width, height), not (rows, cols).
                Img = cv2.resize(img, (int(w * scale), int(h * scale)),
                                 interpolation=cv2.INTER_CUBIC)
                Img = np.expand_dims(Img[:, :, 0].copy(), 0)
                Img = np.float32(normalize(Img))
                patches = Im2Patch(Img, win=patch_size, stride=stride)
                print("file: %s scale %.1f # samples: %d" % (path, scale, patches.shape[3] * aug_times))
                for n in range(patches.shape[3]):
                    data = patches[:, :, :, n].copy()
                    h5f.create_dataset(str(train_num), data=data)
                    train_num += 1
                    for m in range(aug_times - 1):
                        data_aug = data_augmentation(data, np.random.randint(1, 8))
                        h5f.create_dataset(str(train_num) + "_aug_%d" % (m + 1), data=data_aug)
                        train_num += 1

    # val
    print('\nprocess validation data')
    files = sorted(glob.glob(os.path.join(data_path, 'val', '*')))
    val_num = 0
    with h5py.File(val_h5, 'w') as h5f:
        for path in files:
            print("file: %s" % path)
            img = cv2.imread(path)
            if img is None:
                print("skipping unreadable file: %s" % path)
                continue
            img = np.expand_dims(img[:, :, 0], 0)
            img = np.float32(normalize(img))
            h5f.create_dataset(str(val_num), data=img)
            val_num += 1
    print('training set, # samples %d\n' % train_num)
    print('val set, # samples %d\n' % val_num)

In [29]:
class Dataset(udata.Dataset):
    """Map-style torch dataset over the datasets stored in one HDF5 file.

    The key list is read once and shuffled in ``__init__``. Each item re-opens
    the file, reads one dataset and returns it as a ``torch.Tensor``. No file
    handle is kept on the object.
    """

    def __init__(self, train=True, train_path='train.h5', val_path='val.h5'):
        """
        Args:
            train: read ``train_path`` when True, otherwise ``val_path``.
            train_path: HDF5 file with the training patches.
            val_path: HDF5 file with the validation images.
        """
        super(Dataset, self).__init__()
        self.train = train
        self.train_path = train_path
        self.val_path = val_path
        with h5py.File(self._path(), 'r') as h5f:
            self.keys = list(h5f.keys())
        random.shuffle(self.keys)

    def _path(self):
        # Chosen on every call so the file follows the current value of self.train.
        return self.train_path if self.train else self.val_path

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        key = self.keys[index]
        # The context manager closes the file even if the key lookup raises.
        with h5py.File(self._path(), 'r') as h5f:
            data = np.array(h5f[key])
        return torch.Tensor(data)

# Model

In [30]:
from tensorflow.keras.layers import \
    Conv2D,Input,add,Activation,BatchNormalization,Multiply,Subtract,Concatenate,Conv2DTranspose
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Model

import tensorflow as tf
import six

try:
    from keras import initializations
except ImportError:
    from keras import initializers as initializations
import keras.backend as K

In [31]:
def tanh_layer(x):
    """Element-wise hyperbolic tangent of ``x`` computed with TensorFlow."""
    activated = tf.math.tanh(x)
    return activated

In [32]:
class ADresnetBuilder(object):
    """Factory for the Keras denoising network used in this notebook.

    ``build`` wires up:
      * a stem of three conv -> BN -> ReLU layers: 7x7/128, 3x3/64, and a
        stride-2 3x3/64;
      * ``len(repetitions)`` residual stages made of ``block_fn`` blocks with
        ``filters=64``;
      * BN -> ReLU -> stride-2 ``Conv2DTranspose`` (64 filters) -> BN -> ReLU
        -> 3x3 conv with 1 filter, producing ``block``;
      * ``out = conv3x3(relu(concat([input, block]))) * block``, and the model
        output ``input - out``.

    The helper methods read the module-level ``CHANNEL_AXIS`` that
    ``_handle_dim_ordering`` sets at the start of ``build``.
    """

    def build(self, block_fn, repetitions):
        """Assemble and return the (uncompiled) Keras ``Model``.

        Args:
            block_fn: block factory (e.g. ``self.basic_block`` or
                ``self.bottleneck``), or the name of one as a string.
            repetitions: number of blocks in each residual stage.
        Returns:
            A ``Model`` taking single-channel images of shape (None, None, 1).
            NOTE(review): the stride-2 down/up-sampling presumably requires even
            height and width for the Concatenate/Subtract shapes to match.
        """
        input_shape = (None, None, 1)
        self._handle_dim_ordering()
        block_fn = self._get_block(block_fn)

        inputs = Input(shape=input_shape)
        conv1 = self._conv_bn_relu(filters=128, kernel_size=(7, 7), strides=(1, 1), padding='same')(inputs)
        conv2 = self._conv_bn_relu(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same')(conv1)
        # Stride 2: the residual stages operate at reduced resolution.
        conv3 = self._conv_bn_relu(filters=64, kernel_size=(3, 3), strides=(2, 2), padding='same')(conv2)

        block = conv3
        filters = 64
        for i, r in enumerate(repetitions):
            block = self._residual_block(block_fn, filters=filters, repetitions=r,
                                         is_first_layer=(i == 0))(block)
        block = self._bn_relu(block)
        # Stride-2 transposed conv upsamples to undo conv3's stride.
        block = Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding='same')(block)
        block = self._bn_relu(block)
        block = Conv2D(filters=1, kernel_size=3, padding='same')(block)
        out = Concatenate(3)([inputs, block])
        out = Activation('relu')(out)
        out = Conv2D(filters=1, kernel_size=3, padding='same')(out)
        out = Multiply()([out, block])
        out2 = Subtract()([inputs, out])
        return Model(inputs=inputs, outputs=out2)

    def build_resnet18(self):
        """Basic blocks, stage repetitions [2, 2, 2, 2]."""
        return self.build(self.basic_block, [2, 2, 2, 2])

    def build_resnet34(self):
        """Basic blocks, stage repetitions [3, 4, 6, 3]."""
        return self.build(self.basic_block, [3, 4, 6, 3])

    def build_resnet50(self):
        """Bottleneck blocks, stage repetitions [3, 4, 6, 3]."""
        return self.build(self.bottleneck, [3, 4, 6, 3])

    def build_resnet101(self):
        """Bottleneck blocks, stage repetitions [3, 4, 23, 3]."""
        return self.build(self.bottleneck, [3, 4, 23, 3])

    def build_resnet152(self):
        """Bottleneck blocks, stage repetitions [3, 8, 36, 3]."""
        return self.build(self.bottleneck, [3, 8, 36, 3])

    def _bn_relu(self, input):
        """BatchNormalization over CHANNEL_AXIS followed by ReLU."""
        norm = BatchNormalization(axis=CHANNEL_AXIS)(input)
        return Activation("relu")(norm)

    def _conv_bn_relu(self, **conv_params):
        """Return a callable applying Conv2D -> BN -> ReLU.

        Required keys: ``filters``, ``kernel_size``. Optional keys (with
        defaults): ``strides`` (1, 1), ``padding`` "same", ``kernel_regularizer``
        l2(1e-4).
        """
        filters = conv_params["filters"]
        kernel_size = conv_params["kernel_size"]
        strides = conv_params.setdefault("strides", (1, 1))
        padding = conv_params.setdefault("padding", "same")
        kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4))

        def f(x):
            conv = Conv2D(filters=filters, kernel_size=kernel_size,
                          strides=strides, padding=padding,
                          kernel_regularizer=kernel_regularizer)(x)
            return self._bn_relu(conv)

        return f

    def _bn_relu_conv(self, **conv_params):
        """Return a callable applying BN -> ReLU -> Conv2D (pre-activation).

        This is the improved ordering proposed in http://arxiv.org/pdf/1603.05027v2.pdf.
        It takes the same keys and defaults as ``_conv_bn_relu``.
        """
        filters = conv_params["filters"]
        kernel_size = conv_params["kernel_size"]
        strides = conv_params.setdefault("strides", (1, 1))
        padding = conv_params.setdefault("padding", "same")
        kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4))

        def f(x):
            activation = self._bn_relu(x)
            return Conv2D(filters=filters, kernel_size=kernel_size,
                          strides=strides, padding=padding,
                          kernel_regularizer=kernel_regularizer)(activation)

        return f

    def _shortcut(self, input, residual):
        """Add ``input`` to ``residual``, projecting ``input`` when needed.

        A 1x1 conv projection is inserted when the channel counts differ.
        Strides are fixed at 1: the blocks never downsample, and the spatial
        dims may be None, so they cannot be compared.
        """
        input_shape = K.int_shape(input)
        residual_shape = K.int_shape(residual)
        stride_width = 1
        stride_height = 1
        equal_channels = input_shape[CHANNEL_AXIS] == residual_shape[CHANNEL_AXIS]

        shortcut = input
        # 1x1 conv if the shape differs, identity otherwise.
        if stride_width > 1 or stride_height > 1 or not equal_channels:
            shortcut = Conv2D(filters=residual_shape[CHANNEL_AXIS],
                              kernel_size=(1, 1),
                              strides=(stride_width, stride_height),
                              padding="valid",
                              kernel_regularizer=l2(0.0001))(input)

        return add([shortcut, residual])

    def _residual_block(self, block_function, filters, repetitions, is_first_layer=False):
        """Return a callable stacking ``repetitions`` blocks of ``block_function``."""
        def f(x):
            for i in range(repetitions):
                # Every block uses stride (1, 1); the only downsampling is conv3
                # in the stem.
                x = block_function(filters=filters, init_strides=(1, 1),
                                   is_first_block_of_first_layer=(is_first_layer and i == 0))(x)
            return x

        return f

    def basic_block(self, filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
        """Two 3x3 convolutions with a shortcut (pre-activation scheme).

        Follows the improved scheme in http://arxiv.org/pdf/1603.05027v2.pdf.
        """
        def f(x):
            if is_first_block_of_first_layer:
                # The stem already ended with BN -> ReLU, so skip the pre-activation.
                conv1 = Conv2D(filters=filters, kernel_size=(3, 3),
                               strides=init_strides,
                               padding="same",
                               kernel_regularizer=l2(1.e-4))(x)
            else:
                conv1 = self._bn_relu_conv(filters=filters, kernel_size=(3, 3),
                                           strides=init_strides)(x)

            residual = self._bn_relu_conv(filters=filters, kernel_size=(3, 3))(conv1)
            return self._shortcut(x, residual)

        return f

    def bottleneck(self, filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
        """1x1 -> 3x3 -> 1x1 (``filters * 4`` outputs) block with a shortcut.

        Follows the improved scheme in http://arxiv.org/pdf/1603.05027v2.pdf.
        """
        def f(x):
            if is_first_block_of_first_layer:
                # The stem already ended with BN -> ReLU, so skip the pre-activation.
                conv_1_1 = Conv2D(filters=filters, kernel_size=(1, 1),
                                  strides=init_strides,
                                  padding="same",
                                  kernel_regularizer=l2(1.e-4))(x)
            else:
                conv_1_1 = self._bn_relu_conv(filters=filters, kernel_size=(1, 1),
                                              strides=init_strides)(x)

            conv_3_3 = self._bn_relu_conv(filters=filters, kernel_size=(3, 3))(conv_1_1)
            residual = self._bn_relu_conv(filters=filters * 4, kernel_size=(1, 1))(conv_3_3)
            return self._shortcut(x, residual)

        return f

    def _handle_dim_ordering(self):
        """Set module-level axis indices for channels-last tensors."""
        global ROW_AXIS
        global COL_AXIS
        global CHANNEL_AXIS
        ROW_AXIS = 1
        COL_AXIS = 2
        CHANNEL_AXIS = -1

    def _get_block(self, identifier):
        """Resolve ``identifier`` to a block factory.

        A string is looked up first as an attribute of this builder (e.g.
        ``'basic_block'``, ``'bottleneck'``, which are methods, not globals),
        then as a module-level name. Anything else is returned unchanged.

        Raises:
            ValueError: if a string resolves to nothing.
        """
        if isinstance(identifier, str):
            res = getattr(self, identifier, None) or globals().get(identifier)
            if not res:
                raise ValueError('Invalid {}'.format(identifier))
            return res
        return identifier

# Train

In [33]:
import h5py
import numpy as np
from tensorflow.keras import callbacks
import random
import tensorflow.keras.optimizers as opt
from tensorflow.keras.callbacks import LearningRateScheduler,ReduceLROnPlateau
import datetime
import time
import cv2
import keras.backend as K
import tensorflow as tf

In [34]:
def step_decay(epoch):
    """Step learning-rate schedule: 1e-3, multiplied by 0.1 every 40 epochs.

    The drop count is ``floor((epoch + 1) / 40)``, so the first drop happens
    at epoch index 39.
    """
    base_lr, factor, period = 0.001, 0.1, 40
    n_drops = np.floor((epoch + 1) / period)
    return base_lr * factor ** n_drops

# Keras callback that sets the optimizer's learning rate from step_decay(epoch).
lr = LearningRateScheduler(schedule=step_decay)

def custom_loss(y_true, y_pred):
    """Sum-of-squares loss normalised by twice the batch size.

    Computes ``sum((y_true - y_pred) ** 2) / (2 * N)``, where N is the runtime
    size of the leading (batch) axis of ``y_pred``. This replaces a hard-coded
    N = 128, so the loss scale stays correct for any batch size.
    """
    diff = y_true - y_pred
    batch = tf.cast(tf.shape(y_pred)[0], diff.dtype)
    return tf.reduce_sum(diff * diff) / (2.0 * batch)

In [35]:
def data_generator(data_path, batch_size, sigma=15):
    """Yield (noisy, clean) batches from an HDF5 file of patches.

    Keys are shuffled once per call, so one generator is one pass over the
    file. A trailing partial batch is dropped. Each stored patch is transposed
    with ``(1, 2, 0)`` (assumes 3-D channel-first data, as written by
    ``prepare_data``). It then receives column-structured Gaussian noise: one
    N(0, sigma/255) sample per column, repeated down every row.

    Args:
        data_path: HDF5 file with one dataset per patch.
        batch_size: number of patches per yielded batch.
        sigma: noise standard deviation on the 0..255 scale (default 15).
    Yields:
        Tuple ``(noisy, clean)`` of stacked numpy arrays.
    """
    # The file is closed when the generator is exhausted or garbage-collected.
    with h5py.File(data_path, 'r') as h5f:
        keys = list(h5f.keys())
        random.shuffle(keys)
        batch_img = []
        batch_noiseimg = []
        for k in keys:
            img = np.array(h5f[k]).transpose((1, 2, 0))
            rows, cols = img.shape[0], img.shape[1]
            # Noise size follows the patch instead of assuming 50x50.
            G_col = np.random.normal(0, sigma / 255., (1, cols))
            noise = np.expand_dims(np.tile(G_col, (rows, 1)), axis=2)
            batch_img.append(img)
            batch_noiseimg.append(img + noise)
            if len(batch_img) == batch_size:
                yield (np.array(batch_noiseimg), np.array(batch_img))
                batch_img = []
                batch_noiseimg = []

In [47]:
# Timestamped TensorBoard run directory under ../log/.
log_dir = '../log/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Keep the full model with the lowest training loss seen so far.
check_point = callbacks.ModelCheckpoint('../model.h5', monitor='loss', verbose=1,
                                        save_best_only=True, save_weights_only=False, mode='min')

# write_grads / embeddings_layer_names / embeddings_data are TF1-era options:
# TF2 ignores them with a warning, and newer Keras rejects them, so they are omitted.
vis = callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True,
                            write_images=True, embeddings_freq=0, update_freq='epoch')

builder = ADresnetBuilder()
model = builder.build_resnet18()
# `learning_rate` replaces the deprecated `lr` keyword.
model.compile(optimizer=opt.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1.e-8),
              loss=custom_loss)





In [None]:
batch_size = 128
# Count batches in the same file the training loop streams
# (data_generator('../train.h5', ...)). Counting a validation file at a
# machine-specific absolute path gave a steps_per_epoch unrelated to the data
# actually consumed.
train_dir = '../train.h5'
with h5py.File(train_dir, 'r') as h5f:
    keys = list(h5f.keys())
# Full batches only: data_generator drops the trailing partial batch.
steps = len(keys) // batch_size

In [None]:
epochs = 50
summary_write = tf.summary.create_file_writer(log_dir)
# NOTE(review): Keras resets ReduceLROnPlateau's state in on_train_begin, i.e.
# at every model.fit call below. With one fit call per epoch it cannot build up
# patience. LearningRateScheduler also re-sets the LR at each epoch start.
reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.1, patience=3, mode='min', verbose=1)
for epoch in range(epochs):
    train_data = data_generator('../train.h5', batch_size=batch_size)
    # initial_epoch/epochs pass the real epoch index to the callbacks. Without
    # them every call runs "epoch 0", so step_decay never decays.
    model.fit(train_data, steps_per_epoch=steps, initial_epoch=epoch, epochs=epoch + 1,
              use_multiprocessing=False, shuffle=True,
              callbacks=[check_point, vis, reduce_lr, lr])

    # ---- per-epoch evaluation on a video, one frame at a time ----
    video_path = '../dym4.avi'
    cap = cv2.VideoCapture(video_path)
    total_psnr = 0.0
    total_mse = 0.0
    total_ssim = 0.0
    n_frames = 0
    starttime = time.perf_counter()  # time.clock() was removed in Python 3.8
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        rows, cols = frame.shape
        # Column noise: one N(0, 15/255) sample per column, repeated down the rows.
        G_col = np.random.normal(0, 15 / 255., (1, cols))
        noise = np.expand_dims(np.tile(G_col, (rows, 1)), axis=2)
        # np.float (removed in NumPy 1.24) was an alias of float, i.e. float64.
        img = np.expand_dims(frame, axis=2).astype(np.float64) / 255.
        noise_img = np.expand_dims(img + noise, axis=0)

        # -------------------------------PSNR-MSE-----------------------------------
        result = model.predict(noise_img)
        raw_img_data = np.array(img, dtype=np.float32)
        denoise_img_data = np.array(result, dtype=np.float32)
        mse = (np.abs(raw_img_data - denoise_img_data) ** 2.).mean()
        psnr_denoise_raw = 20 * np.log10(1. / np.sqrt(mse))

        # -------------------------------SSIM-------------------------------------
        im1 = tf.image.convert_image_dtype(raw_img_data, tf.float32)
        im2 = tf.image.convert_image_dtype(denoise_img_data, tf.float32)
        ssim_tf = tf.image.ssim(im1, im2, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
        ssim = ssim_tf.numpy()[0]
        print("PSNR: ", psnr_denoise_raw, "    SSIM: ", ssim)
        with summary_write.as_default():
            tf.summary.scalar('PSNR step=epoch', psnr_denoise_raw, step=epoch)
            tf.summary.scalar('SSIM step=epoch', ssim, step=epoch)

        total_mse += mse
        total_psnr += psnr_denoise_raw
        total_ssim += ssim
        n_frames += 1

    endtime = time.perf_counter()
    cap.release()
    if n_frames == 0:
        raise RuntimeError("no frames could be read from %s" % video_path)
    # Average over the frames actually read, not a hard-coded 269.
    AVG_TIME = (endtime - starttime) / n_frames
    AVG_PSNR = total_psnr / n_frames
    AVG_MSE = total_mse / n_frames
    AVG_SSIM = total_ssim / n_frames

    print("AVG_PSNR: ", AVG_PSNR)
    print("AVG_MSE: ", AVG_MSE)
    print("AVG_TIME: ", AVG_TIME)
    print("AVG_SSIM: ", AVG_SSIM)

    with summary_write.as_default():
        tf.summary.scalar('AVG_PSNR step=epoch', AVG_PSNR, step=epoch)
        tf.summary.scalar('AVG_MSE step=epoch', AVG_MSE, step=epoch)
        tf.summary.scalar('AVG_SSIM step=epoch', AVG_SSIM, step=epoch)
        tf.summary.scalar('AVG_TIME step=epoch', AVG_TIME, step=epoch)
    if AVG_PSNR >= 37:
        model.save('../model_37.h5')
    if AVG_PSNR >= 37.2:
        model.save('../model_37_2.h5')
    if AVG_PSNR >= 37.4:
        model.save('../model_37_4.h5')

<HDF5 file "val.h5" (mode r)>

# Test

In [49]:
import cv2
import numpy as np
import time
import math
import tensorflow as tf

In [None]:
builder = ADresnetBuilder()
# The functional model returned by build_resnet18() is already built, so no
# separate model.build(...) call is needed.
model = builder.build_resnet18()
# NOTE(review): the training loop saves to '../model_37_4.h5'; confirm this
# relative path points at the same file.
model.load_weights('model_37_4.h5')
video_path = 'real.avi'  # change test file
cap = cv2.VideoCapture(video_path)
noise_dir = './test/noise'
out_dir = './test/real'
# cv2.imwrite returns False (writes nothing) when the directory does not exist.
os.makedirs(noise_dir, exist_ok=True)
os.makedirs(out_dir, exist_ok=True)

i = 0
total_psnr = 0.0
total_ssim = 0.0
total_time = 0.0
while True:
    ret, frame = cap.read()
    if not ret:
        break
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    rows, cols = frame.shape
    # Column noise: one N(0, 15/255) sample per column, repeated down the rows.
    G_col = np.random.normal(0, 15 / 255., (1, cols))
    noise = np.expand_dims(np.tile(G_col, (rows, 1)), axis=2)
    # np.float (removed in NumPy 1.24) was an alias of float, i.e. float64.
    img = np.expand_dims(frame, axis=2).astype(np.float64) / 255.
    noisy = img + noise
    noise_img = np.expand_dims(noisy, axis=0)
    # Save as 8-bit. imwrite converts a float image in [0, 1] to 8-bit without
    # rescaling, which produces an almost black picture.
    cv2.imwrite(noise_dir + '/' + '%d.png' % i, np.clip(noisy * 255, 0, 255).astype(np.uint8))
    starttime = time.perf_counter()  # time.clock() was removed in Python 3.8
    result = model.predict(noise_img)
    single_time = time.perf_counter() - starttime
    print(single_time)
    #-------------------------------PSNR------------------------------------
    raw_img_data = np.array(img, dtype=np.float64)
    denoise_img_data = np.array(result, dtype=np.float64)
    #----------------------------MSE----------------------------------------
    mse = (np.abs(raw_img_data - denoise_img_data) ** 2.).mean()
    psnr_denoise_raw = 20 * np.log10(1. / np.sqrt(mse))
    print(psnr_denoise_raw)
    # -------------------------SSIM------------------------------------------
    im1 = tf.image.convert_image_dtype(raw_img_data, tf.float32)
    im2 = tf.image.convert_image_dtype(denoise_img_data, tf.float32)
    ssim_tf = tf.image.ssim(im1, im2, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
    ssim = ssim_tf.numpy()[0]
    print("ssim:", ssim)
    total_psnr += psnr_denoise_raw
    total_ssim += ssim
    #-----------------------------------------------------------------------
    result = np.clip(np.squeeze(result) * 255, 0, 255).astype(np.uint8)
    total_time += single_time
    cv2.imwrite(out_dir + '/' + '%d.png' % i, result)
    i += 1
cap.release()

if i == 0:
    raise RuntimeError("no frames could be read from %s" % video_path)
# Average over the frames actually processed instead of a hard-coded 269.
print("avg_psnr:", total_psnr / i)
print("avg_ssim:", total_ssim / i)
print("avg_time:", total_time / i)