# In [None]:  (Jupyter notebook cell marker left over from export)
import math
import cv2
import os, glob, re
import numpy as np
from tensorflow import keras
from sklearn.feature_extraction import image
from keras.layers import  Input,Conv2D,BatchNormalization,Activation,Add,Subtract,MaxPooling2D
from keras.layers.merge import concatenate
from keras.models import Model, load_model
from keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, CSVLogger
from skimage.metrics import peak_signal_noise_ratio


def data_aug(img, mode=0):
    """Return one of eight flip/rotation variants of ``img``.

    The variants are the combinations of ``np.rot90`` applied 0-3 times and
    an optional ``np.flipud`` applied afterwards:

        mode 0: unchanged           mode 1: flipud
        mode 2: rot90               mode 3: rot90 then flipud
        mode 4: rot90 twice         mode 5: rot90 twice then flipud
        mode 6: rot90 three times   mode 7: rot90 three times then flipud

    Args:
        img: Array with at least two dimensions.
        mode: Integer in ``0..7`` selecting the variant.

    Returns:
        The transformed array. Mode 0 returns ``img`` itself; the other
        modes return whatever ``np.rot90`` / ``np.flipud`` produce.

    Raises:
        ValueError: If ``mode`` is not one of ``0..7``. (Previously an
            unknown mode silently returned ``None``.)
    """
    if mode not in range(8):
        raise ValueError('mode must be an integer in 0..7, got %r' % (mode,))

    # mode = 2 * (number of quarter turns) + (1 if flipped else 0)
    quarter_turns, flip = divmod(mode, 2)
    out = np.rot90(img, k=quarter_turns) if quarter_turns else img
    return np.flipud(out) if flip else out

# Build the clean training set: random 40x40 patches cut from each training
# image, each patch then put through a randomly chosen data_aug() variant.
NUM_TRAIN_IMAGES = 400
PATCHES_PER_IMAGE = 512
PATCH_SIZE = 40

# Seed NumPy's global RNG before patch sampling and augmentation.
np.random.seed(0)
train_target = np.zeros(
    (NUM_TRAIN_IMAGES * PATCHES_PER_IMAGE, PATCH_SIZE, PATCH_SIZE),
    dtype='float32',
)
for i in range(NUM_TRAIN_IMAGES):
    # Files are named 001.png ... 400.png.
    name_img = './drive/MyDrive/train400/%03d.png' % (i + 1)
    img = cv2.imread(name_img, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        raise OSError('could not read training image: ' + name_img)
    # max_patches is passed by keyword: it is keyword-only in current
    # scikit-learn releases.
    patch = image.extract_patches_2d(
        img, (PATCH_SIZE, PATCH_SIZE), max_patches=PATCHES_PER_IMAGE
    )
    first = PATCHES_PER_IMAGE * i
    # Assigning into the float32 array converts the patch values.
    train_target[first:first + PATCHES_PER_IMAGE, :, :] = patch

# One RNG draw per patch, in patch order, picks the augmentation mode.
for i in range(len(train_target)):
    train_target[i] = data_aug(train_target[i], mode=np.random.randint(0, 8))

# Append a trailing channel axis of size 1.
train_target = train_target.reshape((-1, PATCH_SIZE, PATCH_SIZE, 1))

# Re-seed without a fixed value so later random draws (e.g. the training
# noise) are not tied to the seed used above.
np.random.seed(None)

def train_datagen(epoch_num=100, batch_size=128, sigma=15, data=None):
    """Yield ``(noisy, clean)`` batches for ``epoch_num`` passes over the data.

    At the start of every pass the sample order is reshuffled. For each batch
    fresh Gaussian noise with standard deviation ``sigma`` is added to the
    clean samples, and both arrays are divided by 255 before being yielded.

    Args:
        epoch_num: Number of full passes over the data before the generator
            is exhausted.
        batch_size: Number of samples per batch. When the number of samples
            is not a multiple of ``batch_size`` the last batch of each pass
            is smaller.
        sigma: Standard deviation of the added noise, expressed on the same
            scale as ``data`` (i.e. before the division by 255).
        data: Array of clean samples indexed along its first axis. Defaults
            to the module-level ``train_target``.

    Yields:
        Tuple ``(batch_y, batch_x)``: the noisy batch and the matching clean
        batch, both divided by 255.
    """
    if data is None:
        data = train_target

    indices = list(range(len(data)))
    for _ in range(epoch_num):
        np.random.shuffle(indices)    # new sample order for this pass
        for start in range(0, len(indices), batch_size):
            # Indexing with a list of positions yields a new array, so the
            # source data is never modified by the scaling below.
            batch_x = data[indices[start:start + batch_size]]
            noise = np.random.normal(0, sigma, batch_x.shape)
            batch_y = batch_x + noise
            yield batch_y / 255, batch_x / 255

def naive_inception_module(layer_in, f1, f2, f3):
    """Apply four parallel branches to ``layer_in`` and concatenate them.

    The branches are a 1x1, a 3x3 and a 5x5 convolution (``f1``, ``f2`` and
    ``f3`` filters, ELU activation, 'same' padding) plus a 3x3 max pooling
    with stride 1. Their outputs are joined along the last axis, so the
    tensors are assumed to be channels-last.

    Args:
        layer_in: Input tensor shared by all branches.
        f1: Number of filters of the 1x1 convolution.
        f2: Number of filters of the 3x3 convolution.
        f3: Number of filters of the 5x5 convolution.

    Returns:
        The concatenated output tensor.
    """
    # Layers are instantiated in the order 1x1, 3x3, 5x5, pooling, concat.
    conv_branches = [
        Conv2D(n_filters, (size, size), padding='same', activation='elu')(layer_in)
        for n_filters, size in ((f1, 1), (f2, 3), (f3, 5))
    ]
    pooled = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(layer_in)
    return concatenate(conv_branches + [pooled], axis=-1)

def inception_module(layer_in, f1, f2_in, f2_out, f3_in, f3_out, f4_out):
    """Inception block whose larger branches use 1x1 convolutions first.

    Four branches read ``layer_in``:
      * a 1x1 convolution with ``f1`` filters;
      * a 1x1 convolution (``f2_in``) followed by a 3x3 one (``f2_out``);
      * a 1x1 convolution (``f3_in``) followed by a 5x5 one (``f3_out``);
      * a 3x3 max pooling with stride 1 followed by a 1x1 convolution
        (``f4_out``).
    Every convolution uses 'same' padding and ELU activation. The branch
    outputs are concatenated along the last axis (channels-last assumed).

    Args:
        layer_in: Input tensor shared by all branches.
        f1: Filters of the stand-alone 1x1 convolution.
        f2_in: Filters of the 1x1 convolution feeding the 3x3 convolution.
        f2_out: Filters of the 3x3 convolution.
        f3_in: Filters of the 1x1 convolution feeding the 5x5 convolution.
        f3_out: Filters of the 5x5 convolution.
        f4_out: Filters of the 1x1 convolution applied after pooling.

    Returns:
        The concatenated output tensor.
    """
    def conv(tensor, n_filters, size):
        # One 'same'-padded ELU convolution with a square kernel.
        return Conv2D(n_filters, (size, size), padding='same', activation='elu')(tensor)

    # Layers are instantiated branch by branch, in the order listed above.
    branch_1x1 = conv(layer_in, f1, 1)
    branch_3x3 = conv(conv(layer_in, f2_in, 1), f2_out, 3)
    branch_5x5 = conv(conv(layer_in, f3_in, 1), f3_out, 5)
    pooled = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(layer_in)
    branch_pool = conv(pooled, f4_out, 1)

    return concatenate([branch_1x1, branch_3x3, branch_5x5, branch_pool], axis=-1)

def DnCNN(filters=64, image_channels=1):
    """Build the denoising network and return it as an uncompiled ``Model``.

    Structure, in order:
      1. ``naive_inception_module(inpt, 64, 128, 32)``, added to the input.
      2. Five stages of Conv2D -> BatchNormalization -> ELU, each followed by
         an ``Add`` with the network input. The stages use kernel/dilation
         (3, 2), (5, 1), (3, 1), (5, 1), (3, 2) and ``filters`` filters.
      3. ``inception_module(x, 64, 96, 128, 16, 32, 32)``, added to the input.
      4. A final 3x3 convolution with ``image_channels`` filters, whose
         output is the model output (no input-minus-output subtraction).

    Explicitly named layers carry a running index as suffix: ``input0``,
    ``add2``, ``conv3``, ``bn4``, ``elu5``, ``add6``, ... ``add24``,
    ``conv25``. Indices 1 and 23 belong to the two inception blocks, whose
    inner layers are created without explicit names.

    Args:
        filters: Number of filters of the convolutions in the five stages.
        image_channels: Number of channels of the input and output images.

    Returns:
        A ``Model`` mapping an input of shape
        ``(None, None, image_channels)`` to the final convolution output.
    """
    # NOTE(review): every Add merges the image_channels-wide input with a much
    # wider tensor; this presumably relies on broadcasting of a single input
    # channel, i.e. image_channels == 1 -- confirm before using other values.
    stages = ((3, 2), (5, 1), (3, 1), (5, 1), (3, 2))    # (kernel size, dilation rate)

    layer_count = 0
    inpt = Input(shape=(None, None, image_channels), name='input' + str(layer_count))

    layer_count += 1    # index taken by the first inception block
    x = naive_inception_module(inpt, 64, 128, 32)
    layer_count += 1
    x = Add(name='add' + str(layer_count))([inpt, x])

    for kernel, dilation in stages:
        layer_count += 1
        x = Conv2D(
            filters=int(filters),
            kernel_size=(kernel, kernel),
            strides=(1, 1),
            kernel_initializer='Orthogonal',
            padding='same',
            dilation_rate=dilation,
            use_bias=False,
            name='conv' + str(layer_count),
        )(x)
        layer_count += 1
        # NOTE(review): momentum=0.0 and epsilon=0.0 are kept exactly as they
        # were; confirm that a zero epsilon is intended.
        x = BatchNormalization(axis=3, momentum=0.0, epsilon=0.0, name='bn' + str(layer_count))(x)
        layer_count += 1
        x = Activation('elu', name='elu' + str(layer_count))(x)
        layer_count += 1
        x = Add(name='add' + str(layer_count))([inpt, x])

    layer_count += 1    # index taken by the second inception block
    x = inception_module(x, 64, 96, 128, 16, 32, 32)
    layer_count += 1
    x = Add(name='add' + str(layer_count))([inpt, x])

    layer_count += 1
    x = Conv2D(
        filters=image_channels,
        kernel_size=(3, 3),
        strides=(1, 1),
        kernel_initializer='Orthogonal',
        padding='same',
        dilation_rate=1,
        use_bias=False,
        name='conv' + str(layer_count),
    )(x)
    return Model(inputs=inpt, outputs=x)

def lr_schedule(epoch):
    """Return the learning rate to use for ``epoch``: a constant ``1e-4``.

    Earlier revisions computed a step schedule and then an exponential decay
    here, but both results were overwritten by a final ``lr = 1e-4`` before
    returning, so they never had any effect. That dead code has been removed;
    the returned value is unchanged.

    Args:
        epoch: Epoch index supplied by the caller. It does not influence the
            result.

    Returns:
        The learning rate, ``1e-4``.
    """
    return 1e-4

# Directory that receives the per-epoch checkpoints and the CSV training log.
save_dir = os.path.join('./drive/MyDrive/models')

# makedirs also creates missing parent directories, and exist_ok=True makes
# the call a no-op when the directory is already there (no separate
# exists() check that could race with another process).
os.makedirs(save_dir, exist_ok=True)

def findLastCheckpoint(save_dir):
    """Return the highest epoch number among the checkpoints in ``save_dir``.

    Checkpoints are files named ``model_<epoch>.hdf5`` where ``<epoch>`` is a
    run of decimal digits. Files that match ``model_*.hdf5`` but whose middle
    part is not purely numeric (for example ``model_best.hdf5``) are ignored
    instead of aborting with a ``ValueError``.

    Args:
        save_dir: Directory to search (not recursive).

    Returns:
        The largest epoch number found, or ``0`` when there is no numbered
        checkpoint (including when the directory does not exist).
    """
    epochs_exist = []
    for path in glob.glob(os.path.join(save_dir, 'model_*.hdf5')):
        # Match on the file name only, so the directory part of the path
        # cannot influence which number is extracted.
        match = re.fullmatch(r'model_(\d+)\.hdf5', os.path.basename(path))
        if match:
            epochs_exist.append(int(match.group(1)))
    return max(epochs_exist, default=0)


# Continue from the newest checkpoint in save_dir when there is one;
# otherwise start from a freshly built network.
initial_epoch = findLastCheckpoint(save_dir=save_dir)
if initial_epoch <= 0:
    model = DnCNN(filters=48, image_channels=1)
else:
    print('resuming by loading epoch %03d'%initial_epoch)
    model = load_model(
        os.path.join(save_dir, 'model_%03d.hdf5'%initial_epoch),
        compile=False,    # the model is compiled just below in both cases
    )
model.compile(optimizer=keras.optimizers.Adam(0.001), loss=keras.losses.mse)
model.summary()

# Callbacks: one full-model checkpoint per epoch, a CSV log that is appended
# to across runs, and the learning-rate schedule from lr_schedule().
checkpointer = ModelCheckpoint(
    os.path.join(save_dir, 'model_{epoch:03d}.hdf5'),
    verbose=1,
    save_weights_only=False,
    save_freq="epoch",
)
csv_logger = CSVLogger(
    os.path.join(save_dir, 'log.csv'),
    append=True,
    separator=',',
)
lr_scheduler = LearningRateScheduler(lr_schedule)

# The generator is asked for as many passes as there are epochs, and
# steps_per_epoch is the number of batches it produces per pass.
network_history = model.fit(
    train_datagen(epoch_num=500, batch_size=128),
    steps_per_epoch=1600,
    epochs=500,
    verbose=1,
    initial_epoch=initial_epoch,
    callbacks=[checkpointer, csv_logger, lr_scheduler],
)

# Load the test images "test (1).png" ... "test (68).png" as grayscale and
# store each one with a leading and a trailing axis of size 1, i.e. with
# shape (1, height, width, 1).
images = []
for i in range(1, 69):
    name_img = './drive/MyDrive/Set68/test ('+str(i)+').png'
    img = cv2.imread(name_img, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals an unreadable/missing file by returning None;
        # fail here with the file name instead of on the attribute access.
        raise OSError('could not read test image: ' + name_img)
    images.append(img[np.newaxis, :, :, np.newaxis])

# Evaluate the model: repeat the whole test set `ite` times with freshly drawn
# noise, record the mean PSNR of each repetition, and report the mean and
# standard deviation of those per-repetition means.
ite = 10
psnrs = []
for _ in range(ite):
    psnr = []
    # Iterate over whatever was loaded instead of a hard-coded image count.
    for test_target in images:
        # Noise with standard deviation 15 is added on the 0-255 scale, then
        # the input is divided by 255 before being passed to the model.
        test_input = test_target + np.random.normal(0, 15, test_target.shape)
        test_input = test_input.astype('float32') / 255
        test_output = model.predict(test_input) * 255
        # Limit the prediction to the valid pixel range [0, 255]. The cast
        # keeps the float64 dtype that the previous arithmetic clipping
        # expression produced, so the PSNR values are unchanged.
        test_output = np.clip(test_output, 0, 255).astype('float64')
        # Drop the two size-1 axes before comparing.
        height, width = test_target.shape[1:3]
        psnr_x = peak_signal_noise_ratio(
            test_target.reshape((height, width)),
            test_output.reshape((height, width)),
            data_range=255,
        )
        psnr.append(psnr_x)
    ave = np.mean(psnr)
    psnrs.append(ave)
print("Denoising result on testset is",np.mean(psnrs),"+-",np.std(psnrs))