In [None]:
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.optimizers import AdamW

In [None]:
from enhanced_model import enhance_net
from utils.training_process import training_process

# Load Data

In [None]:
epochs = 100
batch_size = 8
# Maximum usable batch size per model was determined on an RTX 3080 Ti (12 GB).
# TODO: the per-model limits themselves are not recorded here; keep batch_size within them.
model_names = ['DCE', 'CSP_DCE', "MSP_DCE", 'DCE++']
model_name = model_names[1]  # architecture to train ('CSP_DCE')
# NOTE(review): add_noise is not referenced in the cells shown here — presumably used elsewhere; confirm.
add_noise = True
lr = 0.0001
# Directory that checkpoints are written to.
model_path = './model/'

In [None]:
def process(image):
    """Scale an image tensor to float32 in [0, 1].

    Assumes pixel values in [0, 255], as produced by
    ``image_dataset_from_directory``. The cast happens *before* the division,
    so the arithmetic is always float32. It then does not depend on how the
    division would promote an integer (e.g. uint8) or float64 input.
    """
    return tf.cast(image, tf.float32) / 255.

def process2(image, image2):
    """Normalise both elements of an (input, label) image pair with ``process``.

    Returns:
        The tuple ``(process(image), process(image2))``.
    """
    return process(image), process(image2)

In [None]:
# Alternative dataset layout (not used in this run):
# train_path = './Dataset/train'
# validation_path = './Dataset/validation/'
# validation_label_path = './Dataset/validation_label/'

# LOL dataset: training uses the low-light images only (no targets). Validation
# pairs each low-light test image with its normal-light ("high") counterpart.
train_path = './Dataset/denoise/LOL/train/low/'
validation_path = './Dataset/denoise/LOL/test/low/'
validation_label_path = './Dataset/denoise/LOL/test/high/'

# Every image is resized to (size, size) on load.
# NOTE(review): the reason for 510 is not recorded here — confirm.
size = 510


def _load_image_batches(path, shuffle=False):
    """Load all images under ``path`` as unlabelled RGB batches resized to (size, size)."""
    return image_dataset_from_directory(path,
                                        labels=None,
                                        label_mode=None,
                                        color_mode='rgb',
                                        class_names=None,
                                        image_size=(size, size),
                                        batch_size=batch_size,
                                        shuffle=shuffle)


# NOTE(review): training batches are not shuffled (unchanged from before), so
# show_predictions(trainset) always previews the same first batch — confirm intended.
trainset = _load_image_batches(train_path).map(process)

# After zipping, validation inputs and labels are paired purely by position.
# Both are therefore read unshuffled, i.e. in the same alphanumeric file-name order.
# Previously two independently shuffled datasets sharing a seed were relied on to
# stay aligned, which is fragile.
val_input = _load_image_batches(validation_path)
val_label = _load_image_batches(validation_label_path)
# Dataset.zip silently truncates to the shorter input, so catch a size mismatch here.
assert len(val_input) == len(val_label), (
    'validation inputs and labels have different numbers of batches')
valset = tf.data.Dataset.zip((val_input, val_label)).map(process2)

In [None]:
# Number of batches per training epoch. trainset is already batched by
# image_dataset_from_directory, so len() counts batches rather than images.
# NOTE(review): not referenced by the cells shown here (the training loop uses len(trainset)).
iterators = len(trainset)

# Build Model

In [None]:
# Build the selected architecture. Height and width are left unspecified (None);
# only the 3 colour channels are fixed.
model = enhance_net(model_name=model_name,
                    input_shape=(None, None, 3))
# Print the enhancement sub-network's layer summary.
model.enhancement_net.summary()

In [None]:
# AdamW with decoupled weight decay (tensorflow_addons). clipvalue clips each
# gradient element to [-0.1, 0.1].
model.compile(
    optimizer=AdamW(
        weight_decay=0.0001,
        learning_rate=lr,
        clipvalue=0.1,
    )
)

In [None]:
from IPython.display import clear_output


def display(display_list, titles=None):
    """Plot images side by side in one row of a 15x15-inch matplotlib figure.

    Args:
        display_list: images to show, one per panel, in any form accepted by
            ``tf.keras.utils.array_to_img`` (e.g. an HxWxC tensor).
        titles: optional panel titles, matched to ``display_list`` by position.
            Defaults to ``('Input Image', 'Enhanced Image', 'Ground Truth')``.
            Panels without a title are labelled ``'Image <n>'`` instead of
            raising IndexError.

    NOTE: this name shadows IPython's own ``display`` in the notebook namespace.
    """
    if titles is None:
        titles = ('Input Image', 'Enhanced Image', 'Ground Truth')

    n_panels = len(display_list)
    plt.figure(figsize=(15, 15))
    for i, img in enumerate(display_list):
        plt.subplot(1, n_panels, i + 1)
        plt.title(titles[i] if i < len(titles) else 'Image {}'.format(i + 1))
        plt.imshow(tf.keras.utils.array_to_img(img))
        plt.axis('off')
    plt.show()


def show_predictions(dataset=None, num=1):
    """Clear the cell output and plot the model's output for the first image of each batch.

    Args:
        dataset: a batched ``tf.data.Dataset`` yielding either image batches
            (like ``trainset``) or ``(input, label)`` batch pairs (like ``valset``).
            For pairs, the label image is shown as a third panel.
        num: number of batches to take from ``dataset``.

    Raises:
        ValueError: if ``dataset`` is None.

    Uses the global ``model``. Its ``predict`` is unpacked into two values; the
    first is displayed as the enhanced batch and the second is ignored.
    """
    if dataset is None:
        raise ValueError('show_predictions requires a dataset')
    clear_output(wait=True)
    for batch in dataset.take(num):
        # Paired datasets yield (inputs, labels); plain ones yield the image batch itself.
        if isinstance(batch, (tuple, list)):
            inputs, labels = batch[0], batch[1]
        else:
            inputs, labels = batch, None
        enhanced, _ = model.predict(inputs)
        panels = [inputs[0], enhanced[0]]
        if labels is not None:
            panels.append(labels[0])
        display(panels, titles=('Input Image', 'Enhanced Image', 'Ground Truth'))

In [None]:
def _evaluate(dataset):
    """Return the batch-averaged (ssim, psnr) of ``model.validation_step`` over ``dataset``.

    ``dataset`` must yield ``(inputs, labels)`` pairs. The averages are taken over
    the batches actually iterated and returned as Python floats. An empty dataset
    gives ``(0.0, 0.0)`` instead of dividing by zero.
    NOTE(review): assumes validation_step returns one scalar per metric per batch.
    Every batch is weighted equally, including a smaller final batch — confirm.
    """
    ssim_total = 0.0
    psnr_total = 0.0
    n_batches = 0
    for inputs, labels in dataset:
        ssim, psnr = model.validation_step(inputs, labels)
        ssim_total += float(ssim)
        psnr_total += float(psnr)
        n_batches += 1
    if n_batches == 0:
        return 0.0, 0.0
    return ssim_total / n_batches, psnr_total / n_batches


best_ssim = 0
best_psnr = 0

for epoch in range(epochs):
    if epoch != 0:
        print()
    print('Epoch:{0}/{1}'.format(epoch + 1, epochs))

    start = time.time()
    # Accumulator passed to training_process; by default at most 10 metrics are tracked.
    mean_loss = np.zeros(10)
    n_steps = len(trainset)

    # Iterate over the batches of the dataset.
    logs = None
    for step, x_batch_train in enumerate(trainset):
        # Run one training step; whatever train_step returns is handed to training_process.
        logs = model.train_step(x_batch_train)
        # Report training progress (epoch, step, time, total loss, and other metrics).
        mean_loss = training_process(step, mean_loss, logs, n_steps, start, mode=1)
    if logs is None:
        raise RuntimeError('trainset produced no batches; check train_path')
    # Called once more after the last step (presumably the end-of-epoch summary).
    training_process(step, mean_loss, logs, n_steps, start, mode=2)

    show_predictions(trainset)

    mean_ssim, mean_psnr = _evaluate(valset)

    print()

    # An epoch counts as the new best only when SSIM *and* PSNR both improve.
    if mean_ssim > best_ssim and mean_psnr > best_psnr:
        best_ssim = mean_ssim
        best_psnr = mean_psnr
        model.model_save(epoch, model_path)
        print('save_model', end=' ')
        print('ssim: {0:6f} - psnr: {1:6f}'.format(best_ssim, best_psnr))
    else:
        print('ssim: {0:6f} - psnr: {1:6f}'.format(mean_ssim, mean_psnr))

    # Periodic checkpoint every 10 epochs, regardless of metrics.
    if (epoch + 1) % 10 == 0:
        print()
        model.model_save(epoch, model_path)