### 1. Import the deep-learning libraries and modules

In [None]:
import tensorflow as tf
from tensorflow import keras
from utils import *
import time
import os
from IPython import display
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.pyplot as plt

### 2. Set global hyperparameters: epochs, learning rate, batch size, image shape, etc.

In [None]:
# Training hyperparameters: number of epochs, RMSprop learning rate, batch size.
EPOCHS, LEARNING_RATE, BATCH_SIZE = 300, 0.0002, 6
# Image shape fed to the networks: [height, width, channels].
SHAPE = [256, 256, 3]

### 3. Data preprocessing

#### 3.1. Set the directories of the training and validation sets

In [None]:
# Dataset directories (placeholders: replace with the real paths).
# Each split has a directory of input images and one of matching target images.
train_input_img_dir, train_output_img_dir = '#########', '#########'
validation_input_img_dir, validation_output_img_dir = '#########', '#########'

In [None]:
# Re-assigns the same four directory variables as the previous cell; whichever
# of the two cells runs last determines the paths actually used.
train_input_img_dir, train_output_img_dir = '#########', '#########'
validation_input_img_dir, validation_output_img_dir = '#########', '#########'

In [None]:
def _sorted_image_paths(directory):
    """Return the full paths of the non-hidden entries of *directory*, sorted by name.

    The input and output path lists are later paired element-by-element
    (``tf.data.Dataset.from_tensor_slices``). ``os.listdir`` returns entries in
    arbitrary order, so without sorting an input image could be paired with the
    wrong target. Sorting assumes matching images share file names (or at least
    sort into the same order). TODO confirm for the dataset in use.
    Hidden entries such as ``.DS_Store`` or ``.ipynb_checkpoints`` are skipped
    because a stray one in only one directory would shift every later pair.
    """
    return [os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if not name.startswith('.')]


train_input_img_path = _sorted_image_paths(train_input_img_dir)
train_output_img_path = _sorted_image_paths(train_output_img_dir)

validation_input_img_path = _sorted_image_paths(validation_input_img_dir)
validation_output_img_path = _sorted_image_paths(validation_output_img_dir)

#### 3.2. Show the number of images in the training and validation sets

In [None]:
# Print the number of files in each list, one per line, in the order:
# train inputs, train outputs, validation inputs, validation outputs.
# Matching input/output counts are required for the index-based pairing.
print(*map(len, (train_input_img_path, train_output_img_path,
                 validation_input_img_path, validation_output_img_path)),
      sep='\n')

In [None]:
# Display the full list of training input paths (echoed as the cell output).
train_input_img_path

In [None]:
# Display the full list of training output (target) paths.
train_output_img_path

In [None]:
# Display the full list of validation input paths.
validation_input_img_path

In [None]:
# Display the full list of validation output (target) paths.
validation_output_img_path

#### 3.3. Check the loaded image shapes (including the number of channels)

In [None]:
# Show the first training input path; compare it with the first output path to check the pairing.
train_input_img_path[0]

In [None]:
# Show the first training output path; it should correspond to the first input path.
train_output_img_path[0]

In [None]:
# Sanity check: load up to 15 training pairs and print their shapes. The last
# dimension (channels) should be the same for inputs and outputs.
# NOTE: `load` is defined in section 3.4 below (unless `utils` already provides
# it), so that cell must have been run before this one.
print(load(train_input_img_path[0],train_output_img_path[0]))
# Slicing the zipped pairs (instead of range(15)) avoids an IndexError when the
# training set has fewer than 15 pairs.
for input_path, output_path in list(zip(train_input_img_path, train_output_img_path))[:15]:
    input_image, output_image = load(input_path, output_path)
    print(input_image.shape,output_image.shape)

In [None]:
# Sanity check: load up to 3 validation pairs and print their shapes.
# NOTE: `load` is defined in section 3.4 below (unless `utils` already provides
# it), so that cell must have been run before this one.
print(load(validation_input_img_path[0],validation_output_img_path[0]))
# Slicing the zipped pairs (instead of range(3)) avoids an IndexError when the
# validation set has fewer than 3 pairs.
for input_path, output_path in list(zip(validation_input_img_path, validation_output_img_path))[:3]:
    input_image, output_image = load(input_path, output_path)
    print(input_image.shape,output_image.shape)

#### 3.4. Image loading, resizing and normalization

In [None]:
def load(input_image_file,output_image_file):
    """Read, decode and resize an (input, target) image pair.

    Args:
        input_image_file: path (Python string or string tensor) of the input image.
        output_image_file: path of the matching target image.

    Returns:
        ``(input_image, output_image)``, each a tensor of shape
        ``(SHAPE[0], SHAPE[1], SHAPE[2])`` produced by ``tf.image.resize``.
        Values are still on the decoded 0-255 scale; ``normalize`` rescales them.
    """
    return _load_resized(input_image_file), _load_resized(output_image_file)


def _load_resized(image_file):
    """Decode one image file with exactly SHAPE[2] channels and resize it to SHAPE[:2]."""
    raw = tf.io.read_file(image_file)
    # Force the channel count. Without `channels`, grayscale files decode to
    # 1 channel and RGBA files to 4, which breaks the 3-channel networks and
    # the per-channel SSIM.
    image = tf.image.decode_jpeg(raw, channels=SHAPE[2])
    return tf.image.resize(image, size=(SHAPE[0], SHAPE[1]))

In [None]:
def normalize(input_image,output_image):
    """Rescale an (input, target) image pair from the 0-255 range to 0-1.

    Returns:
        The two rescaled images as a tuple, in the same order as the arguments.
    """
    scale = 255.0
    return input_image / scale, output_image / scale

In [None]:
def load_image_train(input_image_path,output_image_path):
    """Load and normalize one training (input, target) pair; used with ``Dataset.map``."""
    return normalize(*load(input_image_path, output_image_path))


def load_image_validation(input_image_path,output_image_path):
    """Load and normalize one validation pair (same pipeline as training, no augmentation)."""
    return normalize(*load(input_image_path, output_image_path))

In [None]:
_AUTOTUNE = tf.data.experimental.AUTOTUNE

# Training pipeline. Shuffle the (cheap) path pairs with a buffer covering the
# whole dataset so each epoch sees a properly random order; a 16-element buffer
# only shuffles locally. Then decode/normalize in parallel, batch, and prefetch
# so input loading overlaps with the training step.
train_dataset = tf.data.Dataset.from_tensor_slices((train_input_img_path,train_output_img_path))
train_dataset = train_dataset.shuffle(buffer_size=max(1, len(train_input_img_path)))
train_dataset = train_dataset.map(load_image_train,num_parallel_calls=_AUTOTUNE)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(_AUTOTUNE)

# Validation pipeline: deliberately not shuffled, so the first batch (used for
# the per-epoch visualization in `fit`) contains the same images every epoch.
validation_dataset = tf.data.Dataset.from_tensor_slices((validation_input_img_path,validation_output_img_path))
validation_dataset = validation_dataset.map(load_image_validation,num_parallel_calls=_AUTOTUNE)
validation_dataset = validation_dataset.batch(BATCH_SIZE)
validation_dataset = validation_dataset.prefetch(_AUTOTUNE)

### 4. Build the network architectures

#### 4.1. The generator

In [None]:
def _generator_residual_block(skip):
    """One residual block: conv-BN-ReLU-conv-BN on a branch, add the skip, then ReLU."""
    branch = tf.keras.layers.Conv2D(64,3,padding='same',use_bias=False)(skip)
    branch = tf.keras.layers.BatchNormalization()(branch)
    branch = tf.keras.layers.Activation('relu')(branch)
    branch = tf.keras.layers.Conv2D(64,3,padding='same',use_bias=False)(branch)
    branch = tf.keras.layers.BatchNormalization()(branch)
    merged = skip + branch
    return tf.keras.layers.Activation('relu')(merged)


def Generator():
    """Build the residual generator as a functional Keras model.

    Layout: a 3x3, 64-filter conv + ReLU stem, 15 residual blocks of 64
    filters, and a final 3x3 conv to 3 channels with no activation (linear
    output). The input shape is fixed to 256x256x3.
    """
    image = tf.keras.layers.Input(shape=[256,256,3])
    features = tf.keras.layers.Conv2D(64,3,padding='same',activation='relu')(image)
    for _ in range(15):
        features = _generator_residual_block(features)
    restored = tf.keras.layers.Conv2D(3,3,padding='same')(features)
    return tf.keras.Model(inputs=image, outputs=restored)

#### 4.2. The discriminator

In [None]:
def lrelu(x):
    """Leaky ReLU with negative slope 0.2, written without branching.

    Uses a*x + b*|x| with a = (1 + slope)/2 and b = (1 - slope)/2. This gives
    x for x >= 0 and slope*x for x < 0, and works for Python numbers and
    tensors alike.
    """
    slope = 0.2
    pos_coef = (1 + slope) / 2
    neg_coef = (1 - slope) / 2
    return pos_coef * x + neg_coef * abs(x)

In [None]:
def Discriminator():
    """Build the convolutional discriminator as a functional Keras model.

    Six 3x3 conv layers with leaky-ReLU activation (``lrelu``), three of them
    with stride 2, followed by a 3x3 conv to one channel with sigmoid
    activation. The output is a spatial map of scores rather than a single
    scalar. The input shape is fixed to 256x256x3.
    """
    image = tf.keras.layers.Input(shape=[256,256,3])
    # (filters, strides) of the leaky-ReLU layers; with 'same' padding each
    # stride-2 layer halves the spatial size (256 -> 128 -> 64 -> 32).
    conv_specs = [(64, 1), (64, 2), (128, 1), (128, 2), (256, 1), (256, 2)]
    hidden = image
    for filters, stride in conv_specs:
        hidden = tf.keras.layers.Conv2D(filters,3,padding='same',strides=stride,activation=lrelu)(hidden)
    score_map = tf.keras.layers.Conv2D(1,3,padding='same',strides=1,activation='sigmoid')(hidden)
    return tf.keras.Model(inputs=image, outputs=score_map)

In [None]:
# Instantiate both networks as module-level globals; train_step, fit and the
# visualization code below reference them by these names.
generator = Generator()
discriminator = Discriminator()

### 5. Loss functions of the generator and discriminator

#### 5.1. Generator loss

In [None]:
# Weights of the three terms summed in generator_loss.
# The PSNR and SSIM weights are negative because higher PSNR/SSIM means a
# better reconstruction, so the loss being minimized uses their negatives.
ADVERSARIAL_LOSS_FACTOR = 0.5
PSNR_LOSS_FACTOR = -1.0
SSIM_LOSS_FACTOR = -0.1

In [None]:
def PSNR(y_true, y_pred, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB between two images or image batches.

    The MSE is averaged over every element of the inputs (the whole batch), so
    this is the PSNR of the batch-mean squared error, not the mean of
    per-image PSNRs.

    Args:
        y_true: reference tensor.
        y_pred: tensor of the same shape as ``y_true``.
        max_pixel: maximum possible pixel value. It defaults to 1.0 because the
            pipeline scales images to [0, 1] (``normalize`` divides by 255).
            Using 255 with [0, 1] data adds a constant 20*log10(255) ~= 48.1 dB
            to every value. The gradients are unaffected, but the reported
            PSNR is wrong.

    Returns:
        Scalar tensor; +inf when the inputs are identical (zero MSE).
    """
    mse = tf.reduce_mean(tf.square(y_pred - y_true))
    return 10.0 * tf.math.log((max_pixel ** 2) / mse) / tf.math.log(10.0)

In [None]:
def _tf_fspecial_gauss(size, sigma=1.5):
    """Build a normalized 2-D Gaussian window, mimicking MATLAB's fspecial('gaussian').

    Returns:
        A float32 tensor of shape (size, size, 1, 1), usable as a ``tf.nn.conv2d``
        filter with one input and one output channel, whose entries sum to 1.
    """
    # Precedence note: -size//2 is (-size)//2, so for an odd size the grid runs
    # symmetrically from -(size//2) to size//2.
    lo, hi = -size//2 + 1, size//2 + 1
    grid_x, grid_y = np.mgrid[lo:hi, lo:hi]

    # Append two singleton axes (in/out channels) to form a conv2d filter.
    x = tf.constant(grid_x[:, :, np.newaxis, np.newaxis], dtype=tf.float32)
    y = tf.constant(grid_y[:, :, np.newaxis, np.newaxis], dtype=tf.float32)

    kernel = tf.exp(-((x**2 + y**2)/(2.0*sigma**2)))
    return kernel / tf.reduce_sum(kernel)

In [None]:
def SSIM_one(img1, img2, k1=0.01, k2=0.03, L=1, window_size=11):
    """
    Compute the mean SSIM between two single-channel image batches.

    Implements the SSIM index of Wang et al. (2004) with a Gaussian window
    (``_tf_fspecial_gauss``, sigma 1.5) and VALID convolution, so only
    positions where the whole window fits inside the image are scored.

    Args:
        img1, img2: 3-D tensors (presumably batch, height, width). A channel
            axis is appended so they can be filtered with ``tf.nn.conv2d``.
        k1, k2: stabilizing constants. The SSIM definition (and MATLAB's
            ``ssim`` and ``tf.image.ssim``) uses K1 = 0.01 and K2 = 0.03; a K2
            of 0.02 is non-standard.
        L: dynamic range of the pixel values (1 for images scaled to [0, 1]).
        window_size: side length of the Gaussian window.

    Returns:
        Scalar tensor: the mean of the SSIM map over batch and positions.
    """
    img1 = tf.expand_dims(img1, -1)
    img2 = tf.expand_dims(img2, -1)

    window = _tf_fspecial_gauss(window_size)

    def local_mean(t):
        # Gaussian-weighted local average at every valid window position.
        return tf.nn.conv2d(t, window, strides=[1, 1, 1, 1], padding='VALID')

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)

    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance as E[xy] - E[x]E[y].
    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma1_2 = local_mean(img1 * img2) - mu1_mu2

    c1 = (k1 * L) ** 2
    c2 = (k2 * L) ** 2

    ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma1_2 + c2)) / (
        (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
    return tf.reduce_mean(ssim_map)

In [None]:
def SSIM_three(img1, img2):
    """Average the single-channel SSIM over the first three channels (R, G, B).

    Args:
        img1, img2: 4-D tensors whose axis 3 holds the color channels
            (presumably NHWC batches); only channels 0, 1 and 2 are scored.

    Returns:
        Scalar tensor: the mean of the three per-channel SSIM values. The value
        is also printed with ``tf.print`` on every call.
    """
    channels_1 = tf.unstack(img1, axis=3)
    channels_2 = tf.unstack(img2, axis=3)

    # Score R, G and B, in that order, with the single-channel SSIM.
    ssim_r, ssim_g, ssim_b = (SSIM_one(channels_1[c], channels_2[c]) for c in range(3))

    ssim = tf.reduce_mean(ssim_r + ssim_g + ssim_b) / 3
    tf.print("ssim:",ssim)
    return ssim

In [None]:
def generator_loss(disc_generated_output, gen_output, target):
    """Weighted generator objective to minimize.

    Sums three terms:
      * adversarial: mean of the negated discriminator scores on generated
        images, so higher scores lower the loss;
      * PSNR between target and generated images;
      * SSIM (RGB average) between target and generated images;
    weighted by ADVERSARIAL_LOSS_FACTOR, PSNR_LOSS_FACTOR and SSIM_LOSS_FACTOR.
    """
    adversarial = tf.reduce_mean(-disc_generated_output)
    psnr = PSNR(target, gen_output)
    ssim = SSIM_three(target, gen_output)
    return (ADVERSARIAL_LOSS_FACTOR * adversarial
            + PSNR_LOSS_FACTOR * psnr
            + SSIM_LOSS_FACTOR * ssim)

#### 5.2. Discriminator loss

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
    """Discriminator loss to minimize: mean score on generated images minus mean score on real ones.

    The generator's adversarial term is ``-mean(D(generated))``, so the
    generator is rewarded when the discriminator scores generated images HIGH.
    To oppose it, the discriminator must push real scores up and generated
    scores down, i.e. minimize ``mean(D(generated)) - mean(D(real))``.
    The reversed form ``mean(D(real)) - mean(D(generated))`` would drive the
    discriminator to score generated images high as well, which helps the
    generator instead of opposing it.
    """
    return tf.reduce_mean(disc_generated_output) - tf.reduce_mean(disc_real_output)

### 6. Optimizers for the generator and discriminator

In [None]:
# One RMSprop optimizer per network, both using the global learning rate.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.RMSprop(learning_rate=LEARNING_RATE),
    tf.keras.optimizers.RMSprop(learning_rate=LEARNING_RATE),
)

### 7. Visualize generator predictions on validation images

In [None]:
def generate_images(model, test_input, target):
    """Plot input / ground truth / prediction for up to three samples and save the figure.

    Each row of the 3x3 grid shows one sample of the batch. The figure is saved
    as './#########/<timestamp>image.jpg' (placeholder directory), shown, and
    then closed so repeated calls (one per epoch in ``fit``) do not accumulate
    open figures. Batches with fewer than three samples leave the unused rows
    empty instead of raising IndexError.

    Args:
        model: callable Keras model mapping the input batch to predictions.
        test_input: batch of input images (presumably (batch, H, W, 3) in [0, 1]).
        target: batch of ground-truth images aligned with ``test_input``.
    """
    # NOTE(review): training=True runs BatchNormalization in training mode
    # (batch statistics). It may also update the moving statistics; this is kept
    # from the original, so confirm it is intended.
    prediction = model(test_input, training=True)
    num_rows = min(3, int(test_input.shape[0]))
    plt.figure(figsize=(15,15))

    titles = ['Input Image', 'Ground Truth', 'Predicted Image']
    localtime = time.strftime('%Y_%m_%d_%H_%M_%S_')
    for row in range(num_rows):
        for col, image in enumerate((test_input[row], target[row], prediction[row])):
            plt.subplot(3, 3, row * 3 + col + 1)
            plt.title(titles[col])
            plt.imshow(image)
            plt.axis('off')
    plt.savefig('./#########/'+localtime+'image.jpg')
    plt.show()
    plt.close()

### 8. Training

In [None]:
@tf.function(experimental_relax_shapes=True)
def train_step(input_image, target, epoch):
    """Run one adversarial update of the generator and the discriminator.

    Both losses come from one forward pass recorded on two gradient tapes.
    Both gradient sets are computed before either optimizer step is applied.
    The losses are printed every step. ``epoch`` is accepted but not used here.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake = generator(input_image, training=True)
        real_score = discriminator(target, training=True)
        fake_score = discriminator(fake, training=True)

        gen_loss = generator_loss(fake_score, fake, target)
        disc_loss = discriminator_loss(real_score, fake_score)

    tf.print('g_loss:',gen_loss,'d_loss:',disc_loss)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

In [None]:
def fit(train_ds, epochs, test_ds):
    """Train the GAN for ``epochs`` epochs.

    Before every epoch, and once more after the last one, the first batch of
    ``test_ds`` is visualized with ``generate_images`` and the generator is
    saved as a timestamped .h5 file. The final snapshot ensures the fully
    trained generator is also saved. Progress is printed as one dot per
    training batch, with a line break every 100 batches.
    """
    for epoch in range(epochs):
        _snapshot_generator(test_ds)
        print("Epoch:",epoch)
        # train_step is a tf.function: passing a new Python int each epoch
        # would trigger a fresh trace every epoch, while a scalar tensor of a
        # fixed dtype reuses the same trace.
        epoch_tensor = tf.constant(epoch)
        for n,(input_image,target) in train_ds.enumerate():
            print('.',end='')
            if (n+1)%100 == 0:
                print()
            train_step(input_image,target,epoch_tensor)
    if epochs > 0:
        _snapshot_generator(test_ds)


def _snapshot_generator(test_ds):
    """Visualize the first batch of ``test_ds`` and save the current generator weights."""
    for example_input, example_target in test_ds.take(1):
        generate_images(generator,example_input,example_target)
        localtime1 = time.strftime('%Y_%m_%d_%H_%M_%S')
        generator.save('./#########/generator_detectpos_{}.h5'.format(localtime1))

#### Run training

In [None]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Allocate GPU memory on demand instead of reserving it all up front.
# In TF2 eager mode this is controlled by tf.config.experimental.set_memory_growth.
# That call only takes effect before the GPUs are initialized and raises
# RuntimeError afterwards (e.g. when the models above were already built), so
# ideally run this cell before any model is created.
for _gpu in tf.config.experimental.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(_gpu, True)
    except RuntimeError as err:
        print('Could not enable GPU memory growth:', err)

# Legacy TF1-style session configuration, kept for backward compatibility.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

In [None]:
# Train for EPOCHS epochs; the (unshuffled) validation set provides the images
# visualized and used for the generator snapshots.
fit(train_dataset,EPOCHS,validation_dataset)