#### Libraries:

In [4]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from glob import glob
import math
import tensorflow as tf
#!pip install tensorflow-addons
from tensorflow_addons import image as tfa_image
import tensorflow as tf

#### Data Merging and Augmentation Functions:
##### Functions in this cell section are based on research done in: "How to Train Neural Networks for Flare Removal"
##### LINK: https://github.com/google-research/google-research/tree/master/flare_removal

In [5]:
def remove_background(im
    EPS = 1e-7
    im_min = tf.reduce_min(im, axis=(-3, -2), keepdims=True)
    im_max = tf.reduce_max(im, axis=(-3, -2), keepdims=True)
    return (im - im_min) * im_max / (im_max - im_min + EPS)

def _center_transform(t, height, width):
    center_to_origin = tfa_image.translations_to_projective_transforms([-width / 2, -height / 2])
    origin_to_center = tfa_image.translations_to_projective_transforms([width / 2, height / 2])
    t = tfa_image.compose_transforms([center_to_origin, t, origin_to_center])
    return t

def shears_to_projective_transforms(shears, height, width):
    shears = tf.convert_to_tensor(shears)
    if tf.rank(shears) == 1:
        shears = shears[None, :]
    shears_x = tf.reshape(tf.tan(shears[:, 0]), (-1, 1))
    shears_y = tf.reshape(tf.tan(shears[:, 1]), (-1, 1))
    ones = tf.ones_like(shears_x)
    zeros = tf.zeros_like(shears_x)
    transform = tf.concat([ones, shears_x, zeros, shears_y, ones, zeros, zeros, zeros], axis=-1)
    return _center_transform(transform, height, width)

def scales_to_projective_transforms(scales, height, width):
    scales = tf.convert_to_tensor(scales)
    if tf.rank(scales) == 1:
        scales = scales[None, :]
    scales_x = tf.reshape(scales[:, 0], (-1, 1))
    scales_y = tf.reshape(scales[:, 1], (-1, 1))
    zeros = tf.zeros_like(scales_x)
    transform = tf.concat([scales_x, zeros, zeros, zeros, scales_y, zeros, zeros, zeros], axis=-1)
    return _center_transform(transform, height, width)

def apply_affine_transform(image, rotation = 0., shift_x = 0., shift_y = 0., shear_x = 0., shear_y = 0., scale_x = 1., scale_y = 1., interpolation = 'bilinear'):
    height, width = image.shape[0:2]
    rotation = tfa_image.angles_to_projective_transforms(rotation, height, width)
    shear = shears_to_projective_transforms([shear_x, shear_y], height, width)
    scaling = scales_to_projective_transforms([scale_x, scale_y], height, width)
    translation = tfa_image.translations_to_projective_transforms([shift_x, shift_y])
    t = tfa_image.compose_transforms([rotation, shear, scaling, translation])
    image = tf.expand_dims(image, axis=0)
    transformed = tfa_image.transform(image, t, interpolation=interpolation)
    return transformed

def normalize_white_balance(im):
    EPS = 1e-7
    channel_mean = tf.reduce_mean(im, axis=(-3, -2), keepdims=True)
    max_of_mean = tf.reduce_max(channel_mean, axis=(-3, -2, -1), keepdims=True)
    normalized = max_of_mean * im / (channel_mean + EPS)
    return normalized

def _gaussian_kernel(kernel_size, sigma, n_channels, dtype):
    x = tf.range(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=dtype)
    g = tf.math.exp(-(tf.pow(x, 2) / (2 * tf.pow(tf.cast(sigma, dtype), 2))))
    g_norm2d = tf.pow(tf.reduce_sum(g), 2)
    g_kernel = tf.tensordot(g, g, axes=0) / g_norm2d
    g_kernel = tf.expand_dims(g_kernel, axis=-1)
    return tf.expand_dims(tf.tile(g_kernel, (1, 1, n_channels)), axis=-1)

def apply_blur(im, sigma):
    blur = _gaussian_kernel(21, sigma, im.shape[-1], im.dtype)
    im = tf.expand_dims(im, axis=0)
    im = tf.nn.depthwise_conv2d(im, blur, [1, 1, 1, 1], 'SAME')
    im = tf.squeeze(im)
    return im

def quantize_8(image):
    q8 = tf.image.convert_image_dtype(image, tf.uint8, saturate=True)
    return tf.cast(q8, tf.float32) * (1.0 / 255.0)

def add_flare(scene, flare, noise, count, flare_max_gain = 5.0, apply_affine = True, training_res = 512):
    
    batch_size = 1
    flare_input_height, flare_input_width, _ = flare.shape
    gamma = tf.random.uniform([], 1.8, 2.2)
    flare_linear = tf.image.adjust_gamma(flare, gamma)
    # Remove DC background in flare.
    flare_linear = remove_background(flare_linear)
                      
    rotation = tf.random.uniform([1], minval=-math.pi, maxval=math.pi)
    shift = tf.random.normal([1], mean=0.0, stddev=10.0)
    shear = tf.random.uniform([1], minval=-math.pi / 9, maxval=math.pi / 9)
    scale = tf.random.uniform([1], minval=0.8, maxval=1.2)
    flare_linear = apply_affine_transform(flare_linear, rotation=rotation, shift_x=shift[0], shift_y=shift[0], shear_x=shear[0], shear_y=shear[0], scale_x=scale[0], scale_y=scale[0])
    
    flare_linear = tf.clip_by_value(flare_linear, 0.0, 1.0)
    flare_linear = tf.image.crop_to_bounding_box(flare_linear, offset_height=(flare_input_height - training_res) // 2,
                                                  offset_width=(flare_input_width - training_res) // 2,
                                                  target_height=training_res, target_width=training_res)
    flare_linear = tf.image.random_flip_left_right(tf.image.random_flip_up_down(flare_linear))
    
    #Normalize the white balance and then apply random white balance.
    flare_linear = normalize_white_balance(flare_linear)
    rgb_gains = tf.random.uniform([1], 0, flare_max_gain, dtype=tf.float32) # If random is set to 1, then all flare are white tones.
    flare_linear *= rgb_gains

    #Augmentation on flare patterns: random blur and DC offset.
    blur_size = tf.random.uniform([], 0.1, 3)
    flare_linear = apply_blur(flare_linear, blur_size)
    offset = tf.random.uniform([], -0.02, 0.02)
    flare_linear = tf.clip_by_value(flare_linear + offset, 0.0, 1.0)
    flare_srgb = tf.image.adjust_gamma(flare_linear, 1.0 / gamma)

    #Random crop and flips.
    scene_linear = tf.image.adjust_gamma(scene, gamma)
    scene_linear = tf.image.random_crop(scene_linear, flare_linear.shape)
    scene_linear = tf.image.random_flip_left_right(tf.image.random_flip_up_down(scene_linear))

    #Add additive Gaussian noise.
    sigma = tf.abs(tf.random.normal([], 0, noise))
    noise = tf.random.normal(scene_linear.shape, 0, sigma)
    scene_linear += noise
    
    #Add random digital gain.
    gain = tf.random.uniform([], 0, 1.2)  # varying the intensity scale
    scene_linear = tf.clip_by_value(gain * scene_linear, 0.0, 1.0)
    scene_srgb = tf.image.adjust_gamma(scene_linear, 1.0 / gamma)

    #Combine the flare-free scene with a flare pattern to produce a synthetic training example.
    combined_linear = scene_linear + flare_linear
    combined_srgb = tf.image.adjust_gamma(combined_linear, 1.0 / gamma)
    combined_srgb = tf.clip_by_value(combined_srgb, 0.0, 1.0)

    #Define paths from where the modified scene, modified flare, and merge image are written into.
    #It is still important to save the modified scene and modified flare so that in a later step they can be used to 
    #generate the saturated pixel ground-truth.
    path_combined = "/PATH/Merged_images/img_combined_" + str(count) + ".jpg"
    path_scene = "/PATH/Scene_images/img_scene_" + str(count) + ".jpg"
    path_flare = "/PATH/Flare_images/img_flare_" + str(count) + ".jpg"
    tf.keras.utils.save_img(path_combined, (combined_srgb), data_format=None, scale=True)
    tf.keras.utils.save_img(path_scene, (scene_srgb), data_format=None, scale=True)
    tf.keras.utils.save_img(path_flare, (flare_srgb), data_format=None, scale=True)

    return (quantize_8(scene_srgb), quantize_8(flare_srgb), quantize_8(combined_srgb), gamma)

#### Read Scenes and Flare Dataset and start processing images.

In [6]:
def load_img(path_read_img, path_read_flares):
    #Obtain all the file paths for the input images and output targets. 
    images = sorted(glob(os.path.join(path_read_img, "*")))
    flares = sorted(glob(os.path.join(path_read_flares, "*")))

    images_files = tf.data.Dataset.list_files(images, shuffle = True, seed=0)
    flares_files = tf.data.Dataset.list_files(flares, shuffle = True, seed=0)

    return images_files, flares_files

def _parser_img(file_name):
    blob = tf.io.read_file(file_name)
    image = tf.io.decode_image(blob, dtype=tf.float32)
    image.set_shape((640,640,3))
    return image

def _parser_flare(file_name):
    blob = tf.io.read_file(file_name)
    image = tf.io.decode_image(blob, dtype=tf.float32)
    image.set_shape((752,1008,3))
    return image

#Define paths from where the flare and scene images are going to be read from.
path_read_flares = "/PATH/Flares_dataset/"
path_read_img = "/PATH/Scenes_dataset/"
images_files, flares_files = load_img(path_read_img, path_read_flares)

#Read location of each image in dataset.
scenes = images_files.map(_parser_img, num_parallel_calls=tf.data.AUTOTUNE, deterministic= not True)
flares = flares_files.map(_parser_flare, num_parallel_calls=tf.data.AUTOTUNE, deterministic= not True)

#Process dataset.
count=0
for scene, flare in tf.data.Dataset.zip((scenes, flares)):
    print(count)
    scene_srgb, flare_srgb, combined_srgb, gamma = add_flare(scene, flare, 0.01, count, flare_max_gain = 10.0, apply_affine = True, training_res = 512)
    count = count + 1

#### Generate Ground-truth Images with Saturated pixels only

In [8]:
IMG_SIZE = 256
#Dataset size.
num_flares = 31783
#This process uses the modified flare, scene and merged flared image to generate the saturated pixel one.
path_scene =  "/PATH/Scene_images/img_scene_"
path_flare =  "/PATH/Flare_images/img_flare_"
path_combined =  "/PATH/Merged_images/img_combined_"

def read_image_combined(path):
    #Read image from path using OpenCV.
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #Resize image to 255x255.
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    #Normalize image.
    img = img/255.0
    return img

def read_image(path):
    #Read image from path using OpenCV.
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #Resize image to 255x255.
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    #Normalize image.
    img = img/255.0
    return img

def read_flare(path):
    #Read image in Grayscale from path using OpenCV.
    img = cv2.imread(path,0)
    #Resize image to 255x255.
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    #Normalize image.
    img = img/255.0
    #Triplicate the last dimension of grayscale values to 3 channels to evaluate saturation in all of them the same way.
    a = np.ones((256, 256, 3))
    img_flare = a * img[..., np.newaxis]
    return img_flare

for i in range (num_flares):
    print("Image #", i)
    #Read Images.
    path_scene_element = path_scene + str(i) + ".jpg"
    path_flare_element = path_flare + str(i) + ".jpg"
    path_combined_element = path_combined + str(i) + ".jpg"
    img_scene = read_image(path_scene_element)
    img_flare = read_flare(path_flare_element)
    img_combined = read_flare(path_combined_element)
    #Convert to img to tensor.
    tensor_img_scene = tf.convert_to_tensor(img_scene)
    tensor_img_flare = tf.convert_to_tensor(img_flare)
    tensor_img_combined = tf.convert_to_tensor(img_combined)
    #To have 4 dimensions for Flatten to work fine.
    tensor_img_scene = tf.expand_dims(tensor_img_scene, axis=0)
    tensor_img_flare = tf.expand_dims(tensor_img_flare, axis=0)
    tensor_img_combined = tf.expand_dims(tensor_img_combined, axis=0)
    #To cast the tensorst to float32.
    tensor_img_scene = tf.cast(tensor_img_scene, tf.float32)
    tensor_img_flare = tf.cast(tensor_img_flare,tf.float32)
    tensor_img_combined = tf.cast(tensor_img_flare,tf.float32)

    #Evaluate if pixels in Flare position not oversaturated (<0.97), then give back pixels from ideal scene, else give back saturated pixels from Flare.
    threshold = tf.constant(.97)
    masked_scene = tf.where(tf.less(tensor_img_flare, threshold), tensor_img_scene, tensor_img_combined)
    #Get rid of batch dimension
    scene_only_oversat = tf.squeeze(masked_scene)
    #Write scene_only_oversat image to drive file.
    path_masked_scene = "/PATH/Merged_images_saturated/scene_oversat_" + str(i) + ".jpg"
    tf.keras.utils.save_img(path_masked_scene, (scene_only_oversat), data_format=None, scale=True)
    

Image # 0
Image # 1
Image # 2
Image # 3
Image # 4
Image # 5
Image # 6
Image # 7
Image # 8
Image # 9
Image # 10
Image # 11
Image # 12
Image # 13
Image # 14
Image # 15
Image # 16
Image # 17
Image # 18
Image # 19
Image # 20
Image # 21
Image # 22
Image # 23
Image # 24
Image # 25
Image # 26
Image # 27
Image # 28
Image # 29
Image # 30
Image # 31
Image # 32
Image # 33
Image # 34
Image # 35
Image # 36
Image # 37
Image # 38
Image # 39
Image # 40
Image # 41
Image # 42
Image # 43
Image # 44
Image # 45
Image # 46
Image # 47
Image # 48
Image # 49
Image # 50
Image # 51
Image # 52
Image # 53
Image # 54
Image # 55
Image # 56
Image # 57
Image # 58
Image # 59
Image # 60
Image # 61
Image # 62
Image # 63
Image # 64
Image # 65
Image # 66
Image # 67
Image # 68
Image # 69
Image # 70
Image # 71
Image # 72
Image # 73
Image # 74
Image # 75
Image # 76
Image # 77
Image # 78
Image # 79
Image # 80
Image # 81
Image # 82
Image # 83
Image # 84
Image # 85
Image # 86
Image # 87
Image # 88
Image # 89
Image # 90
Image # 9

error: OpenCV(4.6.0) C:\b\abs_74oeeuevib\croots\recipe\opencv-suite_1664548340488\work\modules\imgproc\src\color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cv::cvtColor'
