In [2]:
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt

In [3]:
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.optimizers import Adam

In [4]:
from denoise_model import denoising_net
from utils.training_process import training_process

# Config

In [10]:
# Values forwarded to denoising_net as Lambda1 / Lambda2
# (presumably weights of the loss terms -- TODO confirm in denoise_model).
Lambda1 = 1
Lambda2 = 2

# Training schedule.
epochs = 1000
lr = 1e-4
# Max batch size per model (3080Ti-12GB): 10, 20, 7,
batch_size = 16

# Run identity and output location; model_name is used to build the
# weight-saving paths under model_path.
model_name = 'Small_Denoised_sigmoid_LOL_Lambda2_2_guassion25'
model_path = './model/'
add_noise = True
# input_shape = (512,512,3)

# NOTE(review): model_name starts with 'Small' while index 0 below selects
# 'Large' -- confirm which variant this run is meant to train.
available_modes = ['Large', 'Small']
mode = available_modes[0]

# Load Data

In [6]:
def process(image):
    """Divide ``image`` by 255 and cast the result to ``tf.float32``.

    Used as the ``map`` function of the training dataset.
    """
    return tf.cast(image / 255., tf.float32)

def process2(image, image2):
    """Divide both inputs by 255 and cast each to ``tf.float32``.

    Used as the ``map`` function of the zipped validation dataset, so it
    receives and returns a pair of elements in the same order.
    """
    return tf.cast(image / 255., tf.float32), tf.cast(image2 / 255., tf.float32)

In [7]:
train_path = './Dataset/denoise/train/'
# train_path = './Dataset/train/'
# NOTE(review): validation_path is defined but never used below -- both
# validation streams are read from train_path. Confirm whether val_label was
# meant to come from validation_path.
validation_path = './Dataset/denoise/train_label/'


def _load_image_batches(directory, **overrides):
    """Build an unlabeled, batched RGB image dataset from ``directory``.

    Every dataset in this cell shares the same options (no labels, 256x256
    images, aspect-ratio-preserving crop, the global ``batch_size``);
    ``overrides`` adds or replaces individual keyword arguments such as
    ``seed``.
    """
    options = {
        'labels': None,
        'label_mode': None,
        'color_mode': 'rgb',
        'class_names': None,
        'image_size': (256, 256),
        'batch_size': batch_size,
        'crop_to_aspect_ratio': True,
    }
    options.update(overrides)
    return image_dataset_from_directory(directory, **options)


# Training stream: scaled by `process`.
trainset = _load_image_batches(train_path).map(process)

# Two validation streams built with the same seed, then zipped into
# (input, label) pairs and scaled together by `process2`.
val_inputs = _load_image_batches(train_path, seed=1)
val_label = _load_image_batches(train_path, seed=1)
valset = tf.data.Dataset.zip((val_inputs, val_label)).map(process2)

Found 485 files belonging to 1 classes.
Found 485 files belonging to 1 classes.
Found 485 files belonging to 1 classes.


In [8]:
# Number of batches yielded by the (batched) training dataset.
# Not referenced by the other visible cells.
iterators = len(trainset)

# Build Model

In [11]:
# Constructor arguments for the network; the spatial dimensions are left as
# None and only the 3-channel depth is fixed.
net_kwargs = {
    'input_shape': (None, None, 3),
    'add_noise': add_noise,
    'model_name': model_name,
    'Lambda1': Lambda1,
    'Lambda2': Lambda2,
    'mode': mode,
}
model = denoising_net(**net_kwargs)

# Print the layer table of the inner `denoising_net` attribute.
model.denoising_net.summary()
# model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 conv2d_27 (Conv2D)             (None, None, None,   1344        ['input_3[0][0]']                
                                48)                                                               
                                                                                                  
 leaky_re_lu_26 (LeakyReLU)     (None, None, None,   0           ['conv2d_27[0][0]']              
                                48)                                                         

# Compile

In [9]:
# Scale factor passed to lr_callable, which multiplies its 20/40/60/80
# milestones by this value to turn them into epoch indices.
ratio = epochs / 100

def lr_callable(ratio, epoch, gamma=0.5):
    """Multiply the optimizer learning rate by ``gamma`` at fixed milestones.

    The milestones are the epoch indices ``int(p * ratio) - 1`` for
    ``p`` in 20, 40, 60 and 80. On any other epoch only a blank line is
    printed. Operates on the module-level ``model`` and its ``optimizer``.

    Args:
        ratio: Factor that scales the 20/40/60/80 milestones to epoch indices.
        epoch: Zero-based index of the epoch that has just been trained.
        gamma: Multiplicative decay applied to the learning rate.
    """
    print()
    milestones = [int(percent * ratio) - 1 for percent in (20, 40, 60, 80)]
    if epoch not in milestones:
        return
    optimizer = model.optimizer
    optimizer.learning_rate = optimizer.learning_rate * gamma
    print('upgard lr!!!')
        
# Adam with the configured learning rate; clipvalue=1.0 is passed through to
# the optimizer.
adam_optimizer = Adam(learning_rate=lr, clipvalue=1.0)
model.compile(optimizer=adam_optimizer)

# Training

In [10]:
# Custom training loop: for each epoch, run every training batch through
# model.train_step, apply the learning-rate schedule, evaluate SSIM/PSNR on the
# validation pairs, and save the inner network's weights whenever BOTH metrics
# beat the best values seen so far.
best_ssim = 0
best_psnr = 0

# Batch counts do not change between epochs, so compute them once.
steps_per_epoch = len(trainset)
val_batches = len(valset)

for epoch in range(epochs):
    if epoch != 0:
        print()
    print('Epoch:{0}/{1}'.format(epoch+1,epochs))

    epoch_start = time.time()
    # By default there are at most 10 evaluation metrics.
    mean_loss = np.zeros(10)

    # Iterate over the batches of the dataset.
    for step, x_batch_train in enumerate(trainset):
        # Run one training step. The result is named `logs` rather than `dict`
        # so the builtin `dict` is not shadowed for the rest of the notebook.
        logs = model.train_step(x_batch_train, epoch, epochs)

        # Report training progress (epoch, step, time, total loss, etc.).
        mean_loss = training_process(step, mean_loss, logs, steps_per_epoch, epoch_start, mode=1)
    training_process(step, mean_loss, logs, steps_per_epoch, epoch_start, mode=2)

    lr_callable(ratio, epoch, gamma=0.5)

    # Average the per-batch validation metrics over all validation batches.
    mean_ssim = 0
    mean_psnr = 0
    for val, label in valset:
        ssim, psnr = model.validation_step(val, label)
        mean_ssim += ssim
        mean_psnr += psnr
    mean_ssim /= val_batches
    mean_psnr /= val_batches

    print()

    # Checkpoint only when SSIM and PSNR both improve on the best so far.
    if mean_ssim > best_ssim and mean_psnr > best_psnr:
        best_ssim = mean_ssim
        best_psnr = mean_psnr
        checkpoint_dir = (model_path + model_name +
                          '/weights/epoch{0}_ssim{1:6f}_psnr{2:6f}/'.format(epoch+1, best_ssim, best_psnr))
        model.denoising_net.save_weights(checkpoint_dir)
        print('save_model', end=' ')
        print('ssim: {0:6f} - psnr: {1:6f}'.format(best_ssim, best_psnr))
    else:
        print('ssim: {0:6f} - psnr: {1:6f}'.format(mean_ssim, mean_psnr))

Epoch:1/1000
31/31 [████████████████████] 7.24s  - train_loss:0.087435 - l_rec1:0.087086 - l_rec2:0.000174

save_model ssim: 0.336793 - psnr: 11.904207

Epoch:2/1000
31/31 [████████████████████] 4.57s  - train_loss:0.070142 - l_rec1:0.069585 - l_rec2:0.000278

save_model ssim: 0.348873 - psnr: 13.319426

Epoch:3/1000
31/31 [████████████████████] 4.60s  - train_loss:0.050046 - l_rec1:0.049453 - l_rec2:0.000297

save_model ssim: 0.357869 - psnr: 14.803411

Epoch:4/1000
31/31 [████████████████████] 4.69s  - train_loss:0.046262 - l_rec1:0.045534 - l_rec2:0.000364

save_model ssim: 0.360669 - psnr: 15.010697

Epoch:5/1000
31/31 [████████████████████] 4.63s  - train_loss:0.043397 - l_rec1:0.042546 - l_rec2:0.000425

save_model ssim: 0.361939 - psnr: 15.020649

Epoch:6/1000
31/31 [████████████████████] 4.61s  - train_loss:0.043541 - l_rec1:0.042520 - l_rec2:0.000510

save_model ssim: 0.369406 - psnr: 15.177604

Epoch:7/1000
31/31 [████████████████████] 4.61s  - train_loss:0.045057 - l_rec1:0.

In [11]:
# Save the weights at the end of training.
# NOTE(review): this saves through `model`, whereas the per-epoch checkpoints
# in the training loop save through `model.denoising_net` -- confirm the
# difference is intended.
final_weights_dir = model_path + model_name + '/weights/finish/'
model.save_weights(final_weights_dir)