#### Libraries

In [1]:
import os
import numpy as np
import cv2
from glob import glob
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Conv2D, Activation, BatchNormalization, Lambda, concatenate, Add
from tensorflow.keras.layers import UpSampling2D, SeparableConv2D, Input, Concatenate, MaxPooling2D, Conv2DTranspose
from tensorflow.keras.layers import Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import Sequential
from tensorflow.keras.layers.experimental import preprocessing
from skimage.metrics import structural_similarity
from tensorflow import keras
#Report the installed TensorFlow version so results can be tied to a specific release.
print("TensorFlow Version: ", tf.__version__)
#Report how many GPU devices TensorFlow detects.
#NOTE(review): tensorflow is already imported as tf at the top of this cell, so this import is redundant.
import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

TensorFlow Version:  2.6.0
Num GPUs Available:  1


#### Process Dataset

In [3]:
#Side length, in pixels, that every image is resized to when it is read.
IMG_SIZE = 256
#Directory with the network input images (passed to load_img as path_img).
path_read_scene_and_flares =  "/PATH/Merged_images/"
#Directory with the target images (passed to load_img as path_target).
path_read_scenes =  "/PATH/Merged_images_saturated/"

#Training hyper-parameters: maximum number of epochs, batch size and initial learning rate of the optimizer.
EPOCHS = 300
BATCH = 12
LR = 1e-3

def load_img(path_img, path_target, split):
    """Collect input/target file paths and split them into train/validation/test sets.

    Files of both directories are listed in sorted order and paired by position,
    i.e. the i-th input is taken to belong to the i-th target.

    Args:
        path_img: Directory containing the input images.
        path_target: Directory containing the target images.
        split: Fraction of the whole dataset that is held out for validation and,
            a second time, for testing (0.1 -> 10% validation, 10% test).

    Returns:
        ((train_x, train_y), (valid_x, valid_y), (test_x, test_y)) where every
        element is a list of file paths; x are inputs and y the paired targets.

    Raises:
        ValueError: If the two directories contain a different number of files,
            because positional pairing would then silently mix up samples.
    """
    #Obtain all the file paths for the input images and output targets.
    images = sorted(glob(os.path.join(path_img, "*")))
    target = sorted(glob(os.path.join(path_target, "*")))
    if len(images) != len(target):
        raise ValueError(
            "Number of input images (%d) and target images (%d) differ; "
            "cannot pair them by position." % (len(images), len(target))
        )

    #Both hold-out sets have the same absolute size: a fraction of the full dataset.
    holdout = int(split * len(images))

    #Split inputs and targets in one call so that every pair is guaranteed to stay together.
    #Randomly select the validation data.
    train_x, valid_x, train_y, valid_y = train_test_split(
        images, target, test_size=holdout, random_state=42)
    #Randomly select the testing data from what is left.
    train_x, test_x, train_y, test_y = train_test_split(
        train_x, train_y, test_size=holdout, random_state=42)
    return (train_x, train_y), (valid_x, valid_y), (test_x, test_y)

def _read_rgb_scaled(path):
    """Load one image file as an RGB float array of shape (IMG_SIZE, IMG_SIZE, 3) in [0, 1].

    Args:
        path: File path as bytes (it is decoded to str before reading).

    Returns:
        float64 NumPy array with values scaled to [0, 1].

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    path = path.decode()
    #Read image from path using OpenCV (3-channel colour, BGR order).
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        #cv2.imread signals failure by returning None instead of raising.
        raise OSError("cv2.imread could not read image: %s" % path)
    #Convert to RGB so training uses the same channel order as the evaluation
    #code in this notebook, which converts BGR to RGB before predicting.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #Resize to IMG_SIZE x IMG_SIZE.
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    #Scale 8-bit values to [0, 1].
    return img/255.0

def read_image(path):
    """Read a network input image.

    Args:
        path: File path as bytes.

    Returns:
        float64 RGB array of shape (IMG_SIZE, IMG_SIZE, 3) scaled to [0, 1].
    """
    return _read_rgb_scaled(path)

def read_target(path):
    """Read a target image with exactly the same processing as the input image.

    Args:
        path: File path as bytes.

    Returns:
        float64 RGB array of shape (IMG_SIZE, IMG_SIZE, 3) scaled to [0, 1].
    """
    return _read_rgb_scaled(path)

def call_convert(img, msk):
    """Load one (input, target) pair inside a tf.data pipeline.

    The Python/OpenCV readers are wrapped with tf.numpy_function so they can be
    used from Dataset.map.

    Args:
        img: Tensor holding the path of the input image.
        msk: Tensor holding the path of the target image.

    Returns:
        Tuple (image, target) of tf.float64 tensors whose static shape is set to
        (IMG_SIZE, IMG_SIZE, 3).
    """
    def _load_pair(img_path, msk_path):
        #Plain-Python step: delegate to the data processing functions.
        return read_image(img_path), read_target(msk_path)

    static_shape = [IMG_SIZE, IMG_SIZE, 3]
    loaded = tf.numpy_function(_load_pair, [img, msk], [tf.float64, tf.float64])
    img, msk = loaded
    #Declare the static shapes explicitly on the returned tensors.
    img.set_shape(static_shape)
    msk.set_shape(static_shape)

    return img, msk

def parse_dataset(img, msk, BATCH):
    """Build a shuffled, batched and endlessly repeating tf.data pipeline.

    Args:
        img: List of input image paths.
        msk: List of target image paths, paired by position with img.
        BATCH: Number of samples per batch.

    Returns:
        A tf.data.Dataset yielding (image_batch, target_batch) tuples. The dataset
        repeats indefinitely, so the consumer has to bound it with step counts.
    """
    data_set = tf.data.Dataset.from_tensor_slices((img, msk))
    data_set = (data_set
    #Shuffle the path pairs before decoding, so the buffer holds short strings instead of decoded images.
    .shuffle(BATCH*100)
    #Decode in parallel: this map does the expensive work, so this is where num_parallel_calls belongs.
    .map(call_convert, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH)
    .repeat()
    .prefetch(tf.data.AUTOTUNE))
    return data_set
  
#Build the train/validation/test path lists; split=0.1 holds out 10% of the dataset for validation and 10% for testing.
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_img(path_read_scene_and_flares, path_read_scenes, split=0.1)

#Report how many samples ended up in each subset.
print("Training data: ", len(train_x))
print("Validation data: ", len(valid_x))
print("Testing data: ", len(test_x))

Training data:  25426
Validation data:  3178
Testing data:  3178


#### Define FlareNet-TL and FlareNet-simple architectures.

In [7]:
def _decoder_block(x, skip, filters):
    """Upsample x by 2, concatenate skip and apply two SeparableConv-BatchNorm-ReLU stages."""
    x = UpSampling2D((2, 2))(x)
    x = Concatenate()([x, skip])
    for _ in range(2):
        x = SeparableConv2D(filters, (3, 3), padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x


def FlareNet_TL(input_shape=(256, 256, 3)):
    """Build FlareNet-TL: a MobileNetV2 encoder with a U-Net style separable-conv decoder.

    Args:
        input_shape: (height, width, channels) of the input image. Defaults to the
            256x256 RGB shape used throughout this notebook.
            NOTE(review): the skip connections only line up if height and width
            survive four halvings/doublings exactly (presumably multiples of 16) - confirm
            before using other sizes.

    Returns:
        A tf.keras Model mapping an image to a 3-channel image of the same spatial
        size, with sigmoid outputs in [0, 1].
    """
    #Define input layer with size and name.
    inputs = Input(shape=input_shape, name="input_image")
    #Load MobileNetV2:
      #input layer is assigned to input of model.
      #weights are pre-trained on ImageNet.
      #do not include the fully-connected layer at the top.
      #alpha < 1 proportionally decreases the number of filters in each layer (0.35 is used here).
    #NOTE(review): the ImageNet weights were presumably trained on inputs scaled to [-1, 1],
    #whereas this notebook feeds images scaled to [0, 1] - confirm this is intended.
    encoder = MobileNetV2(input_tensor=inputs, weights="imagenet", include_top=False, alpha=0.35)
    #Deepest encoder feature map used as the start of the decoder.
    x = encoder.get_layer("block_13_expand_relu").output

    #(encoder layer providing the skip connection, number of decoder filters), deepest first.
    #The last skip is the input image itself.
    decoder_stages = (
        ("block_6_expand_relu", 64),
        ("block_3_expand_relu", 48),
        ("block_1_expand_relu", 32),
        ("input_image", 16),
    )
    for skip_name, filters in decoder_stages:
        x_skip = encoder.get_layer(skip_name).output
        x = _decoder_block(x, x_skip, filters)

    #Last layer: project to 3 channels and squash to [0, 1].
    x = SeparableConv2D(3, (3, 3), padding="same")(x)
    outputs = Activation("sigmoid")(x)

    return Model(inputs, outputs)

#Instantiate FlareNet-TL.
#NOTE(review): a later cell rebinds `model` to FlareNet_simple(); run only the cell of the architecture you intend to train.
model = FlareNet_TL()

In [8]:
def FlareNet_simple():
    """Build FlareNet-simple: a small encoder-decoder trained from scratch.

    The encoder halves the resolution four times; the decoder restores it with
    four stride-2 transposed convolutions. Two additive skip connections join
    encoder and decoder at quarter and full resolution.

    Returns:
        A tf.keras Model mapping a (256, 256, 3) image to a (256, 256, 3) image
        with sigmoid outputs.
    """
    #Keyword arguments shared by every encoder convolution / decoder transposed convolution.
    conv_opts = dict(activation='relu', kernel_initializer='he_normal', padding='same')
    up_opts = dict(activation='relu', padding="same", kernel_initializer='he_normal', strides=2)

    image_in = Input((256, 256, 3))

    #Encoder.
    full_res = Conv2D(16, (3, 3), name="inputs", **conv_opts)(image_in)
    down = MaxPooling2D((2, 2))(full_res)
    down = SeparableConv2D(32, (3, 3), **conv_opts)(down)
    down = MaxPooling2D((2, 2))(down)
    quarter_res = SeparableConv2D(48, (3, 3), **conv_opts)(down)
    down = MaxPooling2D((2, 2))(quarter_res)
    down = SeparableConv2D(64, (3, 3), **conv_opts)(down)
    bottleneck = MaxPooling2D((2, 2))(down)

    #Decoder, first half: back up to quarter resolution, then merge the 48-channel skip.
    up = Conv2DTranspose(64, (3, 3), **up_opts)(bottleneck)
    up = Conv2DTranspose(48, (3, 3), **up_opts)(up)
    merged = Add()([quarter_res, up])
    merged = Activation("relu")(merged)

    #Decoder, second half: back up to full resolution, then merge the 16-channel skip.
    up = Conv2DTranspose(32, (3, 3), **up_opts)(merged)
    up = Conv2DTranspose(16, (3, 3), **up_opts)(up)
    merged = Add()([full_res, up])
    merged = Activation("relu")(merged)

    #1x1 projection to the 3 output channels.
    image_out = Conv2D(3, (1, 1), activation='sigmoid')(merged)

    return Model(image_in, image_out)

#Instantiate FlareNet-simple; this rebinds `model`, replacing any model assigned to that name by an earlier cell.
model = FlareNet_simple()

#### Model Training.

In [None]:
#Build the batched tf.data pipelines for training and validation.
#parse_dataset returns datasets that repeat indefinitely, which is why fit() is given explicit step counts.
train_dataset = parse_dataset(train_x, train_y, BATCH)
valid_dataset = parse_dataset(valid_x, valid_y, BATCH)

#Define Loss function:
def SSIMLoss(y_true, y_pred):
    """Loss equal to 1 minus the mean SSIM between targets and predictions.

    Args:
        y_true: Batch of target images.
        y_pred: Batch of predicted images.

    Returns:
        Scalar tensor; SSIM is computed with a maximum pixel value of 1.0.
    """
    ssim_values = tf.image.ssim(y_true, y_pred, 1.0)
    return 1 - tf.reduce_mean(ssim_values)

#Optimizer and metric; the SSIM loss is also tracked as a metric.
opt = tf.keras.optimizers.Nadam(LR)
metrics = [SSIMLoss]
model.compile(optimizer=opt, loss=SSIMLoss, metrics=metrics)

#Callbacks: halve the learning rate on a validation-loss plateau and stop early, restoring the best weights.
#NOTE(review): both callbacks use patience=10 on val_loss, so the learning-rate reduction
#presumably fires on the same epoch that early stopping ends training - confirm whether the
#scheduler patience was meant to be shorter.
callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10),
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
]

#Steps per epoch = number of batches needed to see every sample once (ceiling division).
train_steps = (len(train_x) + BATCH - 1)//BATCH
valid_steps = (len(valid_x) + BATCH - 1)//BATCH

history = model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=EPOCHS,
    steps_per_epoch=train_steps,
    validation_steps=valid_steps,
    callbacks=callbacks,
)

#Learning curves.
plt.plot(history.history["loss"], label="Training Loss")
plt.plot(history.history["val_loss"], label="Validation Loss")
plt.legend()
plt.show()

#### Validate Model with Testing Dataset.

In [None]:
import time
from scipy import stats

#Set to True to display input | ground truth | prediction side by side for every test image.
SHOW_COMPARISON = False

#Running sums of the per-image metrics; they are divided by num_test at the end.
mae_predict = 0
mse_predict = 0
mae_input = 0
mse_input = 0

#SSIM(ground truth, input) and SSIM(ground truth, prediction) running sums.
ssmi_input_vs_original = 0
ssmi_input_vs_predicted = 0
num_test = len(test_x)

def _load_eval_image(path):
    """Load an image file as an RGB float array of shape (IMG_SIZE, IMG_SIZE, 3) in [0, 1].

    A separate name is used on purpose: redefining read_image/read_target here would
    shadow the bytes-path readers that the tf.data pipeline looks up at call time.

    Args:
        path: File path as str.

    Returns:
        float64 NumPy array scaled to [0, 1].

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    if x is None:
        #cv2.imread signals failure by returning None instead of raising.
        raise OSError("cv2.imread could not read image: %s" % path)
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
    x = cv2.resize(x, (IMG_SIZE, IMG_SIZE))
    return x/255.0

for x_path, y_path in zip(test_x, test_y):
    x = _load_eval_image(x_path)
    y = _load_eval_image(y_path)
    #Wall-clock latency of a single-image prediction.
    start = time.perf_counter()
    y_pred = model.predict(np.expand_dims(x, axis=0))[0]
    print(time.perf_counter() - start)

    #Optional visual check: input image vs ground-truth vs prediction.
    if SHOW_COMPARISON:
        white_line = np.ones((x.shape[0], 20, 3))
        image = np.concatenate([x, white_line, y, white_line, y_pred], axis=1)
        fig = plt.figure(figsize=(12, 12))
        a = fig.add_subplot(1, 1, 1)
        a.imshow(image)
        plt.show()

    ####### Calculate Metrics #######
    #Images are floats in [0, 1], so the data range must be given explicitly; the
    #dtype-based default for floats assumes a wider range and distorts the score.
    #NOTE(review): newer scikit-image releases replace multichannel=True with channel_axis=-1.
    ssmi_input_vs_original += structural_similarity(y, x, multichannel=True, data_range=1.0)
    ssmi_input_vs_predicted += structural_similarity(y, y_pred, multichannel=True, data_range=1.0)

    mae_predict += np.mean(np.abs(y_pred - y))
    mse_predict += np.mean((y_pred - y) ** 2)

    mae_input += np.mean(np.abs(x - y))
    mse_input += np.mean((x - y) ** 2)

print("Num Test:", num_test)

if num_test == 0:
    #Avoid a division by zero when the test split is empty.
    print("No test images available; metric averages were not computed.")
else:
    #"Original" is the ground-truth target; the second line of each pair compares it with the prediction.
    print("Average SSIM Input vs Original:", ssmi_input_vs_original/num_test)
    print("Average SSIM Predicted vs Original:", ssmi_input_vs_predicted/num_test)

    print("Average MAE Input vs Original:", mae_input/num_test)
    print("Average MAE Predicted vs Original:", mae_predict/num_test)

    print("Average MSE Input vs Original", mse_input/num_test)
    print("Average MSE Predicted vs Original", mse_predict/num_test)

In [18]:
#Save the trained model to an HDF5 (.h5) file.
#NOTE(review): the model was compiled with the custom SSIMLoss, so loading it back presumably
#needs custom_objects={"SSIMLoss": SSIMLoss} or compile=False - confirm.
model.save("/PATH/FlareNet_XXX.h5")