In [None]:
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.optimizers import Adam

In [None]:
from model2 import enhance_net
from utils.training_process import training_process

# Load Data

In [None]:
epochs = 10
batch_size = 6
# Largest batch size that fit per model on an RTX 3080 Ti (12 GB): 10, 20, 7
# NOTE(review): the list does not say which model each limit belongs to.
model_names = ['DCE_Denoised_noactivate']
# Model trained by this notebook (a string; kept separate from the candidate list
# so the name is not rebound from list to str).
model_name = model_names[0]
add_noise = True
lr = 1e-4
# Training images are resized to 512x512 by the loader; the model itself is
# built with input_shape=(None, None, 3), so no fixed input shape is set here.
model_path = './model/'

In [None]:
def process(image):
    """Scale a batch of images to float32 in [0, 1].

    Assumes pixel values in [0, 255]. The input is cast to float32 *before*
    dividing, so the scaling is always done in float32 whatever the incoming
    dtype (e.g. uint8).
    """
    return tf.cast(image, tf.float32) / 255.0


def process2(image, image2):
    """Apply ``process`` to both elements of an (input, reference) image pair."""
    return process(image), process(image2)

In [None]:
train_path = './Dataset/train'
validation_path = './Dataset/validation/'
validation_label_path = './Dataset/validation_label/'


def load_images(directory, **kwargs):
    """Load all images under ``directory`` as an unlabeled, batched RGB dataset.

    Images are resized to 512x512 and grouped into batches of ``batch_size``.
    Extra keyword arguments (e.g. ``shuffle``, ``seed``) are forwarded to
    ``image_dataset_from_directory``.
    """
    return image_dataset_from_directory(directory,
                                        labels=None,
                                        label_mode=None,
                                        color_mode='rgb',
                                        class_names=None,
                                        image_size=(512,512),
                                        batch_size=batch_size,
                                        **kwargs)


trainset = load_images(train_path)

# Validation inputs and their reference (label) images are paired positionally
# by zip, so both must be read in the same deterministic order: shuffling is
# disabled (files are then read in alphanumeric order). The labels must come
# from validation_label_path; previously they were read from validation_path,
# so every input was scored against itself.
# NOTE(review): assumes each input and its label share a filename — confirm.
val_input = load_images(validation_path, shuffle=False)
val_label = load_images(validation_label_path, shuffle=False)
assert len(val_input) == len(val_label), 'validation inputs and labels have different batch counts'

valset = tf.data.Dataset.zip((val_input, val_label))
valset = valset.map(process2)
trainset = trainset.map(process)

In [None]:
# Number of batches in the (batched) training dataset.
# NOTE(review): not referenced by any visible cell — the training loop calls
# len(trainset) itself; consider removing or reusing this value.
iterators = len(trainset)

# Build Model

In [None]:
# Build the enhancement model for RGB inputs of any height/width, then print
# the inner enhancement network's summary followed by the wrapper model's.
model = enhance_net(model_name=model_name,
                    add_noise=add_noise,
                    input_shape=(None, None, 3))
for net in (model.enhancement_net, model):
    net.summary()

In [None]:
# Adam at the configured learning rate, with each gradient value clipped to [-1, 1].
optimizer = Adam(learning_rate=lr, clipvalue=1.0)
model.compile(optimizer=optimizer)

In [None]:
def evaluate_model(net, dataset):
    """Return the mean (ssim, psnr) from ``net.validation_step`` over all batches.

    ``dataset`` must yield ``(input, label)`` batch pairs. Divides by the number
    of batches actually iterated, and raises ValueError if there were none
    (instead of dividing by zero).

    NOTE(review): this averages per-batch values, so a smaller final batch is
    weighted like a full one — confirm what validation_step returns.
    """
    total_ssim = 0
    total_psnr = 0
    n_batches = 0
    for val, label in dataset:
        ssim, psnr = net.validation_step(val, label)
        total_ssim += ssim
        total_psnr += psnr
        n_batches += 1
    if n_batches == 0:
        raise ValueError('validation dataset yielded no batches')
    return total_ssim / n_batches, total_psnr / n_batches


# Constant across steps and epochs, so compute it once.
steps_per_epoch = len(trainset)
if steps_per_epoch == 0:
    # Otherwise `step`/`logs` would be unbound after the inner loop.
    raise ValueError('training dataset yielded no batches')

best_ssim = 0
best_psnr = 0
for epoch in range(epochs):
    if epoch != 0:
        print()
    print('Epoch:{0}/{1}'.format(epoch+1,epochs))

    start_time = time.time()
    # Running-loss buffer passed to training_process; room for at most 10 metrics.
    mean_loss = np.zeros(10)

    # Iterate over the batches of the dataset.
    for step, x_batch_train in enumerate(trainset):
        # Run one training step; `logs` holds what train_step returns
        # (renamed from `dict`, which shadowed the builtin).
        logs = model.train_step(x_batch_train, epoch, epochs)

        # Print training progress (epoch, step, time, total loss, etc.).
        mean_loss = training_process(step, mean_loss, logs, steps_per_epoch, start_time, mode=1)
    # mode=2 is presumably the end-of-epoch report (see utils.training_process — confirm).
    training_process(step, mean_loss, logs, steps_per_epoch, start_time, mode=2)

    mean_ssim, mean_psnr = evaluate_model(model, valset)

    print()

    # Checkpoint only when BOTH metrics beat their best values so far.
    if mean_ssim > best_ssim and mean_psnr > best_psnr:
        best_ssim = mean_ssim
        best_psnr = mean_psnr
        model.model_save(epoch, model_path)
        print('save_model', end=' ')
        print('ssim: {0:6f} - psnr: {1:6f}'.format(best_ssim, best_psnr))
    else:
        print('ssim: {0:6f} - psnr: {1:6f}'.format(mean_ssim, mean_psnr))

In [None]:
# After the last epoch, save the inner enhancement network's weights under
# <model_path>/<model_name>/enhancment/weights/finish/.
# NOTE(review): 'enhancment' is misspelled in the path; left unchanged in case
# loading code elsewhere expects this exact directory — confirm before renaming.
final_weights_dir = model_path + model_name + '/enhancment/weights/finish/'
model.enhancement_net.save_weights(final_weights_dir)