In [None]:
import numpy as np 
import pandas as pd 
from tensorflow.python.client import device_lib
import tensorflow as tf
import os
from keras.models import *
from keras.layers import Input,Conv2D,BatchNormalization,Activation,Lambda,Subtract
import glob
import cv2
from multiprocessing import Pool
import matplotlib.pyplot as plt
from keras.optimizers import Adam
from keras.callbacks import LearningRateScheduler
from sklearn.metrics import mean_squared_error
# Print the entries of the current working directory (the notebook uses
# relative paths such as ./data/..., so this is a quick sanity check).
print(os.listdir('.'))

In [None]:

# Ask TensorFlow for the list of local compute devices. Being the last
# expression in the cell, the returned list is displayed as the cell output.
device_lib.list_local_devices()

# Neural Network

In [None]:
def DnCNN():
    """Build the DnCNN model for single-channel images of arbitrary size.

    Architecture (all convolutions are 3x3, stride 1, 'same' padding):
      * 1 x (Conv 64 filters -> ReLU)
      * 15 x (Conv 64 filters -> BatchNorm -> ReLU)
      * 1 x Conv with a single filter
    The output of the last convolution is subtracted from the network input,
    so the model's output is ``input - conv_stack(input)``.

    Returns:
        An uncompiled keras ``Model`` mapping a (batch, H, W, 1) tensor to a
        tensor of the same shape.
    """
    # Settings shared by every hidden convolution.
    hidden_conv = dict(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same')

    noisy_input = Input(shape=(None, None, 1))

    # First layer: convolution + ReLU, without batch normalisation.
    features = Activation('relu')(Conv2D(**hidden_conv)(noisy_input))

    # Middle layers: convolution + batch normalisation + ReLU.
    for _ in range(15):
        features = Conv2D(**hidden_conv)(features)
        features = BatchNormalization(axis=-1, epsilon=1e-3)(features)
        features = Activation('relu')(features)

    # Last layer collapses back to one channel; it is subtracted from the input.
    residual = Conv2D(filters=1, kernel_size=(3, 3), strides=(1, 1), padding='same')(features)
    denoised = Subtract()([noisy_input, residual])

    return Model(inputs=noisy_input, outputs=denoised)

# Dataset

In [None]:
def data_aug(img, mode=0):
    """Apply one of eight flip/rotation augmentations to ``img``.

    ``mode`` encodes the transform as ``2 * quarter_turns + flip``:
    even modes only rotate counter-clockwise by ``mode // 2`` quarter turns,
    odd modes additionally flip the rotated image upside down.

    Args:
        img: array to augment.
        mode: integer in 0..7 selecting the transform; 0 returns ``img`` itself.

    Returns:
        The transformed array, or ``None`` when ``mode`` is not in 0..7.
    """
    if mode not in (0, 1, 2, 3, 4, 5, 6, 7):
        return None

    quarter_turns, flip = divmod(mode, 2)

    # Skip np.rot90 entirely for zero turns so mode 0 hands back the same object.
    rotated = np.rot90(img, k=quarter_turns) if quarter_turns else img
    return np.flipud(rotated) if flip else rotated
    
def gen_patches(file_name):
    """Cut one grayscale image into augmented square training patches.

    The image is read as grayscale, rescaled by each factor in
    ``[1, 0.9, 0.8, 0.7]`` and tiled into ``patch_size`` x ``patch_size``
    windows taken every ``stride`` pixels. Each window is appended
    ``aug_times`` times, each time after a randomly chosen ``data_aug`` mode.

    Relies on the notebook-level settings ``patch_size``, ``stride`` and
    ``aug_times`` being defined before it is called.

    Args:
        file_name: path of the image file to read.

    Returns:
        A list of 2-D arrays, one per generated patch.

    Raises:
        IOError: if OpenCV cannot read ``file_name``.
    """
    img = cv2.imread(file_name, 0)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('Could not read image: {}'.format(file_name))

    h, w = img.shape
    scales = [1, 0.9, 0.8, 0.7]
    patches = []

    for scale in scales:
        h_scaled, w_scaled = int(h*scale), int(w*scale)
        # cv2.resize expects dsize as (width, height); passing (height, width)
        # would swap the dimensions of every non-square image.
        img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC)

        for i in range(0, h_scaled-patch_size+1, stride):
            for j in range(0, w_scaled-patch_size+1, stride):
                x = img_scaled[i:i+patch_size, j:j+patch_size]

                for _ in range(aug_times):
                    x_aug = data_aug(x, mode=np.random.randint(0, 8))
                    patches.append(x_aug)

    return patches

In [None]:
# Patch extraction settings read by gen_patches.
patch_size = 20   # side length of each square patch, in pixels
stride = 20       # step between neighbouring patches, in pixels
aug_times = 1     # number of augmented copies stored per patch

# Data locations (save_dir is not referenced by the cells shown here).
src_dir = './data/Train400/'
save_dir = './data/npy_data/'
file_list = glob.glob('{}*.png'.format(src_dir))

# Number of worker processes used for patch generation.
num_threads = 16

In [None]:
# Fail early with a clear message instead of training on an empty array later.
if not file_list:
    raise FileNotFoundError('No training images found in ' + src_dir)

res = []

# A single pool serves the whole run; the context manager shuts the worker
# processes down afterwards (the previous version opened a new, never-closed
# pool for every chunk of files).
with Pool(num_threads) as pool:
    for start in range(0, len(file_list), num_threads):
        stop = min(start + num_threads, len(file_list))

        # Each gen_patches call returns a list of patches for one image.
        for image_patches in pool.map(gen_patches, file_list[start:stop]):
            res += image_patches

        # Report the real upper bound, which is shorter for the last chunk.
        print('Picture '+str(start)+' to '+str(stop)+' are finished.')

res = np.array(res, dtype='uint8')
print('Shape of result: ' + str(res.shape))

# Train network

In [None]:
# Experiment identifiers; `name` is the file-name prefix used when the model is saved.
models = 'DnCNN'
name = 'dncnn'

# Data sources.
train_data = res
test_dir = './data/Test/'

# Noise level: divided by 255 wherever it is used as the Gaussian standard deviation.
sigma = 25

# Optimisation settings.
batch_size = 1024
epoch = 10
lr = 1e-3

In [None]:
def learning_rate(epoch):
    """Step-decay schedule: start at 1e-3 and halve every 10 epochs.

    The previous version halved the rate only on epochs that are exact
    multiples of 10 (including epoch 0) and returned the full 1e-3 on every
    other epoch, so the rate never actually decayed.

    Args:
        epoch: zero-based epoch index, as passed by Keras'
            ``LearningRateScheduler``.

    Returns:
        The learning rate to use for that epoch.
    """
    base_lr = 1e-3
    # epoch // 10 counts how many full 10-epoch stages have been completed.
    return base_lr * 0.5 ** (epoch // 10)


def train_generator(data, batch_size=8, noise_sigma=None):
    """Endlessly yield ``(noisy_batch, clean_batch)`` pairs for training.

    The sample order is reshuffled at the start of every pass over ``data``.
    Fresh Gaussian noise with standard deviation ``noise_sigma / 255`` is
    drawn for every batch. The last batch of a pass may be smaller than
    ``batch_size``.

    Args:
        data: array of clean samples, indexed along its first axis.
        batch_size: maximum number of samples per yielded batch.
        noise_sigma: noise level on a 0-255 scale. When ``None`` (default) the
            notebook-level ``sigma`` is used, which keeps existing calls working.

    Yields:
        Tuple ``(noisy_batch, clean_batch)``; the noisy batch is float32.

    Raises:
        ValueError: if ``data`` contains no samples (the generator would
            otherwise loop forever without yielding).
    """
    if data.shape[0] == 0:
        raise ValueError('train_generator needs at least one sample')

    indices = np.arange(data.shape[0])
    while True:
        np.random.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            batch = data[indices[start:start+batch_size]]
            # Resolve the fallback here so a changed global `sigma` is still
            # picked up between batches, as before.
            level = sigma if noise_sigma is None else noise_sigma
            noise = np.random.normal(0, level/255.0, batch.shape)
            # float32 matches the dtype used for validation and test inputs.
            transformed_batch = (batch + noise).astype('float32')
            yield transformed_batch, batch

In [None]:
# `data` was never defined before this cell; build it from the extracted
# patches. Pixels are scaled to [0, 1] because the noise everywhere in this
# notebook uses a standard deviation of sigma/255, and the test cell feeds
# the model images divided by 255.
data = res.astype('float32') / 255.0
if data.ndim == 3:
    # Add the channel axis expected by the model input of shape (None, None, 1).
    data = data[..., np.newaxis]

# Shuffle once, then keep 85% for training and 15% for validation.
n = len(data)
np.random.shuffle(data)
split = int(n*0.85)
train_data = data[:split]
valid_y = data[split:]

# Noisy validation inputs, built in float32. Writing the noisy values back
# into a uint8 copy (as before) would have truncated the noise away.
valid_x = (valid_y + np.random.normal(0, sigma/255.0, valid_y.shape)).astype('float32')

In [None]:
# Instantiate the network; no earlier cell creates `model`.
model = DnCNN()
model.compile(optimizer=Adam(), loss='mse')

# Keras callbacks must be callback objects: wrap the schedule function rather
# than passing the float `lr`.
lr_schedule = LearningRateScheduler(learning_rate)

# steps_per_epoch is derived from the training split (not the full dataset),
# with a floor of 1 so tiny datasets still run one step per epoch.
# NOTE(review): fit_generator is deprecated in newer Keras releases in favour
# of model.fit; confirm against the installed version.
hist = model.fit_generator(train_generator(train_data, batch_size=batch_size),
            steps_per_epoch=max(1, len(train_data)//batch_size),
            validation_data=(valid_x, valid_y),
            epochs=epoch,
            verbose=1,
            callbacks=[lr_schedule])

In [None]:
import matplotlib.pyplot as plt

# Per-epoch losses recorded by Keras during training.
loss_values = hist.history['loss']
val_loss_values = hist.history['val_loss']

# The History object has no `epochs` attribute; number the epochs from 1
# based on how many loss values were recorded.
epochs = range(1, len(loss_values) + 1)

plt.plot(epochs, loss_values, 'bo', label = 'Training loss')
plt.plot(epochs, val_loss_values, 'b', label = 'Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
file_list = glob.glob('{}/*.png'.format(test_dir))

for file in file_list:
    # Read with OpenCV (already imported, and the same reader used for the
    # training images); PIL's `Image` is never imported in this notebook.
    img_clean = cv2.imread(file, 0)
    if img_clean is None:
        # cv2.imread returns None for files it cannot decode.
        print('Skipping unreadable image: ' + file)
        continue

    # Scale to [0, 1] and add Gaussian noise with std sigma/255.
    img_clean_scaled = img_clean.astype('float32') / 255.0
    img_test = img_clean_scaled + np.random.normal(0, sigma/255.0, img_clean.shape)
    img_test = img_test.astype('float32')

    # The model expects a (batch, height, width, channel) tensor.
    x_test = img_test.reshape(1, img_test.shape[0], img_test.shape[1], 1)
    y_predict = model.predict(x_test)

    # Back to an 8-bit image on the 0-255 scale.
    img_out = np.clip(y_predict.reshape(img_clean.shape), 0, 1)
    img_out = (img_out*255).astype('uint8')

    # Compare as floats: uint8 arithmetic would wrap around on subtraction.
    print('MSE test: ', mean_squared_error(img_clean.astype('float64'),
                                           img_out.astype('float64')))

In [None]:
# Serialise the architecture first, then write it and the weights to files
# whose names are prefixed with `name`.
model_json = model.to_json()
architecture_path = name + ".json"
weights_path = name + ".h5"

with open(architecture_path, "w") as json_file:
    json_file.write(model_json)

model.save_weights(weights_path)