In [1]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras 

import os
import time
from datetime import datetime

from matplotlib import pyplot as plt
from IPython import display
import numpy as np
from tqdm import tqdm
import cv2
import albumentations as A

In [2]:
# TensorFlow version used for this run (the recorded output shows 2.2.0).
tf.__version__

'2.2.0'

## Create training dataset

In [69]:
# Every recording session stores its aligned frames as <root>/<session>/<aligned>/B
# (flare inputs) and <root>/<session>/<aligned>/gt (ground truth). Index i of
# USC_FOLDERS and GT_FOLDERS refers to the same session.
_DATA_ROOT = "/shared/data1/Video_data_small"
_SESSIONS = [
    ("20201111out", "aligned_B"),
    ("20201111nig", "aligned_B"),
    ("20201110nig", "aligned_B"),
    ("20201111day", "aligned_B"),
    ("20201110day", "aligned_B"),
    ("20200824_indoor_video_flare", "aligned"),
    ("20200728_video", "aligned"),
    ("20200810_video_flare", "aligned"),
    ("20200824_outdoor_video_flare", "aligned"),
]
USC_FOLDERS = ["/".join((_DATA_ROOT, session, aligned, "B")) for session, aligned in _SESSIONS]
GT_FOLDERS = ["/".join((_DATA_ROOT, session, aligned, "gt")) for session, aligned in _SESSIONS]

In [3]:
from train_utils.utils  import dataset, image_io, image_transform
from train_utils.train  import losses, train_step

In [71]:
# PNG reader passed to dataset.get_images_array below; transform_list=None presumably
# means no per-image transforms are applied on read (confirm in train_utils.utils.image_io).
png_image_reader = image_io.ImageReaderPNG(transform_list=None)

In [72]:
# Read every '.png' image from each folder in USC_FOLDERS (the output shows one progress
# bar per folder). Inputs and ground truths are later paired by index, so both lists must
# be read in matching order — presumably get_images_array sorts file names; confirm.
usc_arrays = dataset.get_images_array(folders_list=USC_FOLDERS, image_reader=png_image_reader, images_extension='.png')

100%|██████████| 9/9 [00:00<00:00, 11.71it/s]
100%|██████████| 15/15 [00:00<00:00, 17.39it/s]
100%|██████████| 21/21 [00:01<00:00, 16.25it/s]
100%|██████████| 38/38 [00:02<00:00, 16.38it/s]
100%|██████████| 70/70 [00:04<00:00, 15.38it/s]
100%|██████████| 38/38 [00:02<00:00, 15.58it/s]
100%|██████████| 46/46 [00:02<00:00, 15.75it/s]
100%|██████████| 46/46 [00:02<00:00, 15.87it/s]
100%|██████████| 20/20 [00:01<00:00, 15.72it/s]


In [73]:
# Number of input frames loaded across all USC_FOLDERS (303 in the recorded run).
len(usc_arrays)

303

In [74]:
# Ground-truth frames, read from GT_FOLDERS with the same reader and extension as the inputs.
gt_arrays = dataset.get_images_array(folders_list=GT_FOLDERS, image_reader=png_image_reader, images_extension='.png')

100%|██████████| 9/9 [00:00<00:00, 14.67it/s]
100%|██████████| 15/15 [00:00<00:00, 15.26it/s]
100%|██████████| 21/21 [00:01<00:00, 14.91it/s]
100%|██████████| 38/38 [00:02<00:00, 15.48it/s]
100%|██████████| 70/70 [00:04<00:00, 14.91it/s]
100%|██████████| 38/38 [00:02<00:00, 15.33it/s]
100%|██████████| 46/46 [00:03<00:00, 15.29it/s]
100%|██████████| 46/46 [00:03<00:00, 14.81it/s]
100%|██████████| 20/20 [00:01<00:00, 14.85it/s]


In [75]:
# Inputs and ground truths are paired by index later on, so the two lists must line up.
assert len(gt_arrays) == len(usc_arrays), (len(gt_arrays), len(usc_arrays))
len(gt_arrays)

303

In [11]:
from skimage.transform import rescale, resize

In [76]:
def img_resize(image, dsize=(1920 // 4, 1080 // 4), interpolation=None):
    """Resize ``image`` with OpenCV.

    Args:
        image: image array accepted by ``cv2.resize``.
        dsize: target size in OpenCV order, i.e. (width, height). The default is a
            quarter of 1920x1080, i.e. 480x270.
        interpolation: optional OpenCV interpolation flag. When None, ``cv2.resize``'s
            own default is used, so the default call behaves exactly as before.
            ``cv2.INTER_AREA`` is usually preferred when shrinking.

    Returns:
        The resized array.
    """
    if interpolation is None:
        return cv2.resize(image, dsize)
    return cv2.resize(image, dsize, interpolation=interpolation)

In [77]:
# Downscale every input frame with img_resize (progress shown by tqdm).
usc_arrays_downscaled = list(map(img_resize, tqdm(usc_arrays)))

100%|██████████| 303/303 [00:00<00:00, 1320.17it/s]


In [78]:
# Downscale every ground-truth frame the same way as the inputs.
gt_arrays_downscaled = list(map(img_resize, tqdm(gt_arrays)))

100%|██████████| 303/303 [00:00<00:00, 1974.31it/s]


In [79]:
# Cache the downscaled arrays as .npz files (each stored under np.savez's default key 'arr_0').
# NOTE(review): the dataset-loading cell below reads *_flare_only.npz from another
# directory, not these files — confirm which dataset is intended.
for _npz_name, _arrays in (("usc_arrays_downscaled_small", usc_arrays_downscaled),
                           ("gt_arrays_downscaled_small", gt_arrays_downscaled)):
    np.savez(_npz_name, _arrays)

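## Dataset loading

Load pre-downscaled input / ground-truth arrays from `.npz` files instead of re-reading the PNGs above.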

In [4]:
# Load the pre-computed input / ground-truth arrays ('arr_0' is the key np.savez assigns
# to its first positional argument). The `with` blocks close the .npz files once each
# array has been read into memory.
# NOTE(review): these files differ from the *_small.npz files written by the cell above —
# confirm the intended dataset.
with np.load("/home/p00536919/Flare_removal/Downscale_ref/usc_arrays_flare_only.npz") as usc_npz:
    usc_arrays_downscaled = usc_npz['arr_0']
with np.load("/home/p00536919/Flare_removal/Downscale_ref/gt_arrays_flare_only.npz") as gt_npz:
    gt_arrays_downscaled = gt_npz['arr_0']

## GPU selection

In [5]:
# Train on a single GPU: enable memory growth on physical GPU #GPU_INDEX and make it the
# only GPU visible to TensorFlow (logical devices are then numbered from 0, i.e. '/GPU:0').
GPU_INDEX = 1
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) <= GPU_INDEX:
    raise RuntimeError(
        "GPU_INDEX={} requested but only {} GPU(s) found: {}".format(
            GPU_INDEX, len(physical_devices), physical_devices))
tf.config.experimental.set_memory_growth(physical_devices[GPU_INDEX], True)
tf.config.experimental.set_visible_devices(physical_devices[GPU_INDEX], 'GPU')

In [6]:
# Reset Keras' global state before building the models below.
tf.keras.backend.clear_session()

In [7]:
from guided_filter_tf.guided_filter import guided_filter

# Refine each ground-truth frame with a guided filter (radius 16, NHWC) computed from the
# GT (`x`) and the matching input frame (`y`); the result replaces the GT entry in place.
# NOTE(review): which argument acts as the guide is defined by guided_filter_tf — confirm.
# NOTE(review): the in-place update means re-running this cell filters the GT again;
# reload the arrays before re-executing it.
n_pairs = min(len(usc_arrays_downscaled), len(gt_arrays_downscaled))
for pair_idx in tqdm(range(n_pairs)):
    gt_batch = tf.cast(gt_arrays_downscaled[pair_idx][np.newaxis, :, :, :], tf.float32)
    usc_batch = tf.cast(usc_arrays_downscaled[pair_idx][np.newaxis, :, :, :], tf.float32)
    filtered = guided_filter(x=gt_batch, y=usc_batch, r=16, nhwc=True)
    gt_arrays_downscaled[pair_idx] = filtered.numpy()[0]

100%|██████████| 303/303 [00:05<00:00, 60.21it/s]


In [8]:
# Training uses the first 300 input/GT pairs, packed as (inputs, targets).
training_dataset = tuple(arrays[:300] for arrays in (usc_arrays_downscaled, gt_arrays_downscaled))

In [9]:
# Training batches: patch_size=224, batch_size=1, reshuffled between epochs (shuffle=True).
# image_size is (height, width) = (270, 480), presumably matching the downscaled arrays.
# NOTE(review): how patches are cropped and what `pad` does live in dataset.DataGenerator — confirm there.
train_data_generator = dataset.DataGenerator(training_dataset, patch_size=224, batch_size=1, image_size=(1080//4, 1920//4), shuffle=True, pad=2)

### Clear old TensorBoard logs (disabled — uncomment to use)

In [20]:
#!RMDIR /Q /S G:\logs\flare_removal
#!rm -rf /shared/p00536919/logs/flare_removal

## Build models

### Generator

In [10]:
def conv_block(growth_rate, filters, kernel_size, strides, x):
    """Channels-first Conv2D ('same' padding, given strides) followed by LeakyReLU.

    The convolution has ``growth_rate * filters`` output channels; returns the activated tensor.
    """
    conv = tf.keras.layers.Conv2D(
        filters=growth_rate * filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        data_format='channels_first',
    )
    activation = tf.keras.layers.LeakyReLU()
    return activation(conv(x))


def dilated_conv_block(growth_rate, filters, kernel_size, dilation_rate, x):
    """Channels-first dilated Conv2D ('same' padding, stride 1) followed by LeakyReLU.

    The convolution has ``growth_rate * filters`` output channels, so the spatial size is preserved.
    """
    n_channels = growth_rate * filters
    conv_out = tf.keras.layers.Conv2D(
        filters=n_channels,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        padding='same',
        data_format='channels_first',
    )(x)
    return tf.keras.layers.LeakyReLU()(conv_out)


def conv_skip_block(growth_rate, filters, kernel_size, x):
    """Stride-1 Conv2DTranspose ('same' padding, channels-first) followed by LeakyReLU.

    In Generator it fuses concatenated skip features with a 1x1 kernel; the output
    has ``growth_rate * filters`` channels.
    """
    transpose_conv = tf.keras.layers.Conv2DTranspose(
        filters=growth_rate * filters,
        kernel_size=kernel_size,
        padding='same',
        data_format='channels_first',
    )
    leaky = tf.keras.layers.LeakyReLU()
    return leaky(transpose_conv(x))


def deconv_block(growth_rate, filters, kernel_size, strides, x):
    """Strided Conv2DTranspose, then a 2x2 stride-1 average pool, then LeakyReLU.

    All layers are channels-first with 'same' padding. The stride-1 pool keeps the
    spatial size; presumably it smooths transposed-convolution (checkerboard) artifacts.
    """
    upsampled = tf.keras.layers.Conv2DTranspose(
        growth_rate * filters, kernel_size, strides=strides, padding='same',
        data_format='channels_first')(x)
    smoothed = tf.keras.layers.AveragePooling2D(
        pool_size=(2, 2), strides=1, padding='same', data_format='channels_first')(upsampled)
    return tf.keras.layers.LeakyReLU()(smoothed)

In [11]:
def Generator(height=None, width=None, input_channels=3, filters=32):
    """Build the channels-first flare-removal generator (Keras model named 'derain_net').

    Encoder: stacked dilated 'same' convolutions with two stride-2 average pools.
    Decoder: two stride-2 transposed-conv blocks, each concatenated with the encoder
    feature map of the same resolution (res2, then res1) and fused by a 1x1 block.
    Head: a tanh-bounded 3-channel map is subtracted from the input, then a final 3x3
    conv with sigmoid produces the 3-channel output in (0, 1).

    Args:
        height, width: spatial input size; None accepts any size. Because of the two
            stride-2 'same' poolings followed by two stride-2 upsamplings, the skip
            Concatenates only line up when H and W are multiples of 4.
        input_channels: number of input channels. Must be 3 in practice: the Subtract
            layer combines the input with a 3-channel tensor.
        filters: base width; each block uses growth_rate * filters channels.

    Returns:
        tf.keras.Model mapping (N, input_channels, H, W) to (N, 3, H, W).

    Layer construction order determines Keras' auto-generated layer names and the order
    of the saved weights; reordering these calls may break loading existing .h5 weights.
    """

    inputs = tf.keras.Input(shape=[input_channels, height, width])

    # Full-resolution encoder: large-receptive-field dilated convs (dilation 16 -> 8 -> 4).
    x = dilated_conv_block(growth_rate=2, filters=filters, kernel_size=(9, 9), dilation_rate=16, x=inputs)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(5, 5), dilation_rate=8, x=x)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    res1 = x  # full-resolution skip, concatenated back before the last decoder stage

    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    # Downsample to 1/2 resolution.
    x = tf.keras.layers.AveragePooling2D((2, 2), 2, padding='same', data_format='channels_first')(x)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    res2 = x  # 1/2-resolution skip

    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    # Downsample to 1/4 resolution (bottleneck, 8 * filters channels).
    x = tf.keras.layers.AveragePooling2D((2, 2), 2, padding='same', data_format='channels_first')(x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)

    # Upsample back to 1/2 resolution and fuse with res2 through a 1x1 block.
    x = deconv_block(growth_rate=4, filters=filters, kernel_size=(4, 4), strides=2, x=x)

    x = tf.keras.layers.Concatenate(axis=1)([x, res2])
    x = conv_skip_block(growth_rate=4, filters=filters, kernel_size=(1, 1), x=x)

    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=2, x=x)

    # Upsample to full resolution and fuse with res1.
    x = deconv_block(growth_rate=2, filters=filters, kernel_size=(4, 4), strides=2, x=x)

    x = tf.keras.layers.Concatenate(axis=1)([x, res1])
    x = conv_skip_block(growth_rate=2, filters=filters, kernel_size=(1, 1), x=x)

    x = dilated_conv_block(growth_rate=1, filters=filters, kernel_size=(3, 3), dilation_rate=1, x=x)
    # Residual head: predict a tanh-bounded 3-channel map, subtract it from the input,
    # then refine with a 3x3 conv whose sigmoid keeps the output in (0, 1).
    x = tf.keras.layers.Conv2D(3, (3, 3), padding='same', data_format='channels_first')(x)
    x = tf.keras.layers.Activation('tanh')(x)
    x = tf.keras.layers.Subtract()([inputs, x])
    x = tf.keras.layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same', data_format='channels_first')(x)
    
    _model = tf.keras.Model(inputs=inputs, outputs=x, name='derain_net')
    return _model  


In [12]:
# Build with height=width=None so the model accepts any spatial size at call time.
# H and W must be multiples of 4 for the skip connections (training presumably uses 224x224 patches).
generator = Generator()

In [13]:
# Layer-by-layer overview of the generator (recorded output is truncated).
generator.summary()

Model: "derain_net"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 3, None, Non 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 64, None, Non 15616       input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 64, None, Non 0           conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, None, No 204928      leaky_re_lu[0][0]                
_________________________________________________________________________________________

In [33]:
# Save the full model (architecture + weights) as HDF5 so it can be inspected with netron below.
generator.save('deflare.h5')

In [34]:
import netron

In [35]:
# Serve deflare.h5 in the netron model viewer on port 8989 (URL shown in the output).
netron.start('deflare.h5', port=8989)

Serving 'deflare.h5' at http://localhost:8989


In [28]:
# Shut down the netron server on port 8989.
netron.stop(port=8989)


Stopping http://localhost:8989


## Define losses

### Generator loss

In [14]:
from train_utils.train  import losses

In [15]:
# Perceptual loss on VGG16 features from blocks 1 and 3, compared with mean absolute error.
# blocks_dict presumably maps VGG block index -> weight (both 1 here) — confirm in
# train_utils.train.losses. Weights come from a local Keras "notop" VGG16 file.
vgg16_loss_1_3 = losses.make_VGG16_loss(blocks_dict={1:1, 3:1}, 
                                          weights_path=r"/home/p00536919/usc-image-enhancement/srs_refactor/refactored_pipeline/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5",
                                          loss_type='MAE')

In [16]:
def generator_loss(gen_output, target, l1_weight=0.5, ms_ssim_weight=1.0, vgg_weight=0.08,
                   ms_ssim_power_factors=(0.4, 0.25, 0.25, 0.2363, 0.1333)):
    """Weighted L1 + MS-SSIM + VGG16 perceptual loss for the generator.

    Args:
        gen_output: generator prediction, channels-first (N, C, H, W).
        target: ground truth in the same layout as ``gen_output``.
        l1_weight: weight of the L1 term.
        ms_ssim_weight: weight of the (1 - mean MS-SSIM) term.
        vgg_weight: weight of the VGG16 perceptual term.
        ms_ssim_power_factors: per-scale exponents for ``tf.image.ssim_multiscale``;
            their count sets the number of scales. The default differs from TF's
            own default (0.0448, 0.2856, 0.3001, 0.2363, 0.1333).

    Returns:
        ``(total_loss, [l1_loss, ms_ssim_loss, vgg_loss])``.
    """
    # The generator is channels-first; the losses below expect NHWC.
    gen_output = tf.transpose(gen_output, [0, 2, 3, 1])
    target = tf.transpose(target, [0, 2, 3, 1])

    l1_loss_val = losses.L1_loss(target, gen_output)

    # max_val=1: pixel values are assumed to lie in [0, 1].
    ms_ssim_loss_val = 1 - tf.math.reduce_mean(
        tf.image.ssim_multiscale(target, gen_output, 1, power_factors=ms_ssim_power_factors))

    vgg_loss_val = vgg16_loss_1_3(target, gen_output)

    total_gen_loss_val = (l1_weight * l1_loss_val
                          + ms_ssim_weight * ms_ssim_loss_val
                          + vgg_weight * vgg_loss_val)

    return total_gen_loss_val, [l1_loss_val, ms_ssim_loss_val, vgg_loss_val]

## Define the Optimizers and Checkpoint-saver

In [17]:
# Disabled experiment: a piecewise-constant learning-rate schedule kept as a string literal.
# Running this cell only echoes the string (see output); no visible cell uses it.
"""decay_steps = 1000
initial_learning_rate_generator = 2e-4
initial_learning_rate_discriminator = 1e-4

step_decay = 1e-2

boundaries = [1000]

step_decay_gen = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, [initial_learning_rate_generator, initial_learning_rate_generator*step_decay])"""

'decay_steps = 1000\ninitial_learning_rate_generator = 2e-4\ninitial_learning_rate_discriminator = 1e-4\n\nstep_decay = 1e-2\n\nboundaries = [1000]\n\nstep_decay_gen = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, [initial_learning_rate_generator, initial_learning_rate_generator*step_decay])'

In [20]:
import tensorflow_addons as tfa

# RectifiedAdam (learning rate 2e-4) wrapped in Lookahead: the wrapper keeps a set of
# slow weights and every `sync_period` (6) steps moves them `slow_step_size` (0.5) of the
# way toward the fast RAdam weights.
radam = tfa.optimizers.RectifiedAdam(learning_rate=2e-4)
generator_optimizer = tfa.optimizers.Lookahead(radam, sync_period=6, slow_step_size=0.5)

In [21]:
# Generator + optimizer checkpoints go to a timestamped directory tagged with the run name.
_CKPT_ROOT = '/shared/p00536919/training_checkpoints/flare_removal/'
_CKPT_RUN_TAG = r"_flare_only_32f_dil_stack_activations_BS1_shift"
checkpoint_dir = "".join([_CKPT_ROOT, datetime.now().strftime("%Y%m%d-%H%M%S"), _CKPT_RUN_TAG])
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, generator=generator)

## Training

In [22]:
log_dir = r"/shared/p00536919/logs/flare_removal/"

# Compute the run directory once so scalar and image summaries land in the same run;
# two separate datetime.now() calls can straddle a second boundary and split the logs.
run_log_dir = (log_dir + "fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
               + "_flare_only_32f_dil_stack_activations_BS1_shift")

summary_writer = tf.summary.create_file_writer(run_log_dir)
file_writer_img = tf.summary.create_file_writer(run_log_dir + '/img')

In [23]:
from train_utils.train  import train_step 

In [24]:
# Import the module under a private alias: assigning `train_step = train_step.make_train_step()`
# replaced the module with the returned function, so re-running the cell raised AttributeError.
# `train_step` remains the training-step callable used by fit().
from train_utils.train import train_step as _train_step_module

train_step = _train_step_module.make_train_step()

In [25]:
# Paired augmentation: additional_targets={'image0': 'image'} makes albumentations apply the
# same random parameters to `image` (input) and `image0` (ground truth).
# NOTE(review): for float images albumentations interprets hue_shift_limit in degrees, so
# 20/255 is a negligible hue shift (sat/val limits are on the 0-1 scale) — confirm intent.
_paired_augmentations = [
    # scale_limit is offset by 1 in albumentations -> zoom factor sampled in [0.99, 6].
    A.ShiftScaleRotate(
        scale_limit=[-0.01, 5],
        rotate_limit=180,
        interpolation=cv2.INTER_CUBIC,
        border_mode=cv2.BORDER_REFLECT_101,
        p=0.8,
    ),
    A.HueSaturationValue(
        hue_shift_limit=20/255,
        sat_shift_limit=20/255,
        val_shift_limit=30/255,
        always_apply=False,
        p=0.25,
    ),
    A.ChannelShuffle(p=0.25),
]
transform = A.Compose(_paired_augmentations, additional_targets={'image0': 'image'})

In [26]:
def random_alpha_blend(image, size=224, max_shift=100, seed=244, proba=0.5):
    """Randomly "lighten"-blend an image with a cyclically shifted copy of itself.

    With probability ``proba`` the image is combined with a copy rolled up by
    ``height_shift`` rows and left by ``width_shift`` columns (each drawn from
    ``[0, max_shift)``), keeping the per-pixel maximum. Otherwise ``image`` is
    returned unchanged (the same object).

    All randomness comes from a private ``np.random.RandomState(seed)``. Calling this
    on an input and on its ground truth with the same ``seed`` therefore applies the
    same decision and the same shift to both, and the global NumPy RNG is not touched.
    Reseeding the global RNG would make every later ``np.random`` draw depend only on
    ``seed``.

    Args:
        image: array whose first two axes are (height, width).
        size: unused; kept for backward compatibility.
        max_shift: exclusive upper bound of each shift; values < 1 mean no shift.
        seed: seed of the private RNG.
        proba: probability of applying the blend.

    Returns:
        The blended array (same dtype as ``image``), or ``image`` itself.
    """
    rng = np.random.RandomState(seed)
    if rng.random_sample() >= proba:
        return image
    if max_shift >= 1:
        # Independent draws, so horizontal and vertical shifts are no longer forced equal.
        width_shift = rng.randint(0, max_shift)
        height_shift = rng.randint(0, max_shift)
    else:
        width_shift = height_shift = 0
    # np.roll gives the same wrap-around shift as copyMakeBorder(BORDER_WRAP) + crop, and it
    # keeps dtype and values. A uint8 round-trip would quantize values and wrap any outside [0, 1].
    shifted = np.roll(image, shift=(-height_shift, -width_shift), axis=(0, 1))
    return np.maximum(image, shifted)

In [27]:
# Probability that random_alpha_blend is applied to each input/target pair during training.
proba = 0.25

In [28]:
# A separate test frame whose prediction is logged to TensorBoard during training.
new_test_image = plt.imread(r"/shared/p00536919/000210_inp.png")
# PNGs with an alpha channel load as RGBA; the generator expects 3 input channels.
if new_test_image.ndim == 3 and new_test_image.shape[-1] == 4:
    new_test_image = new_test_image[..., :3]
# Downscale to 480x270 (cv2 takes (width, height)) and crop the height to 264 so H and W
# are multiples of 4, as the generator's pool/upsample skip connections require.
new_test_image = cv2.resize(new_test_image, (1920//4, 1080//4))[np.newaxis, :264, :480, :]
new_test_image = tf.cast(new_test_image, tf.float32)

In [29]:
def _log_sample_predictions(epoch):
    """Save a weight snapshot for ``epoch`` and log sample images to TensorBoard.

    Writes ``./weights_train/log_weights0_32_epoch_{epoch}.hdf5`` (creating the folder if
    needed). Through ``file_writer_img`` it then logs the input, ground truth and
    prediction for one random training pair, plus the prediction for ``new_test_image``.
    """
    weights_path = "./weights_train/log_weights0_32_epoch_{}.hdf5".format(epoch)
    os.makedirs(os.path.dirname(weights_path), exist_ok=True)
    generator.save_weights(weights_path)

    sample_index = np.random.randint(len(usc_arrays_downscaled))
    # Crop to 264x480 so both spatial dims are multiples of 4, as the generator's
    # pool/upsample skip connections require.
    sample_input = usc_arrays_downscaled[sample_index][np.newaxis, :264, :480, :]
    sample_target = gt_arrays_downscaled[sample_index][np.newaxis, :264, :480, :]

    with tf.device('/GPU:0'):
        # Use the live generator. Its weights are exactly the ones just saved, so there is no
        # need to rebuild a model (or reset Keras state with clear_session()) mid-training.
        # The generator has no dropout/batch-norm, so training=False does not change outputs.
        preds = generator(tf.transpose(tf.cast(sample_input, tf.float32), [0, 3, 1, 2]), training=False)
        preds = tf.transpose(preds, [0, 2, 3, 1])
        preds_new = generator(tf.transpose(new_test_image, [0, 3, 1, 2]), training=False)
        preds_new = tf.transpose(preds_new, [0, 2, 3, 1])

    with file_writer_img.as_default():
        tf.summary.image("Input_image", sample_input[..., :3], step=epoch)
        tf.summary.image("Ground truth", sample_target[..., :3], step=epoch)
        tf.summary.image("Model results", preds.numpy()[..., :3], step=epoch)
        tf.summary.image("Input1", new_test_image.numpy()[..., :3], step=epoch)
        tf.summary.image("Model results 1", preds_new.numpy()[..., :3], step=epoch)


def _augment_batch(input_batch, target_batch):
    """Apply identical random augmentation to every input/target pair of a batch.

    Each pair gets one random seed, so random_alpha_blend makes the same choice for
    input and target; ``transform`` then applies shared parameters via its
    ``image0`` target. Returns both batches as float32 tensors (NHWC).
    """
    input_batch = np.array(input_batch)
    target_batch = np.array(target_batch)
    for pair_id in range(len(input_batch)):
        seed = np.random.randint(0, 1000)
        input_batch[pair_id] = random_alpha_blend(input_batch[pair_id], size=224, max_shift=100, seed=seed, proba=proba)
        target_batch[pair_id] = random_alpha_blend(target_batch[pair_id], size=224, max_shift=100, seed=seed, proba=proba)
        transformed = transform(image=input_batch[pair_id], image0=target_batch[pair_id])
        input_batch[pair_id] = transformed['image']
        target_batch[pair_id] = transformed['image0']
    return tf.cast(input_batch, tf.float32), tf.cast(target_batch, tf.float32)


def fit(epochs, run):
    """Train ``generator`` for ``epochs`` epochs as chunk number ``run`` of a longer schedule.

    Epoch numbers are offset by ``epochs * run`` so TensorBoard steps, weight snapshots and
    checkpoints continue across successive calls. The schedule is:
    - every 2nd epoch: a weight snapshot and sample images are logged;
    - every epoch: mean losses are written to ``summary_writer``;
    - every 100th epoch: a checkpoint is saved.

    Args:
        epochs: number of epochs to run in this call.
        run: zero-based index of this call; only used to offset epoch numbers.
    """
    epoch_offset = int(epochs * run)
    for epoch in range(epoch_offset, epoch_offset + epochs):
        start = time.time()
        display.clear_output(wait=True)

        if (epoch + 1) % 2 == 0:
            _log_sample_predictions(epoch)

        print("Epoch: ", epoch)
        train_data_generator.on_epoch_end()

        # Train
        gen_total_loss, l1_loss_val, ms_ssim_loss_val, vgg_loss_val = 0, 0, 0, 0
        n_batches = 0
        for n, (input_image, target) in enumerate(train_data_generator):
            print('.', end='')
            if (n + 1) % 100 == 0:
                print('\n')
            input_image, target = _augment_batch(input_image, target)

            # The generator is channels-first: NHWC -> NCHW.
            step_losses = train_step(input_image=tf.transpose(input_image, [0, 3, 1, 2]),
                                     target=tf.transpose(target, [0, 3, 1, 2]),
                                     generator=generator,
                                     generator_loss=generator_loss,
                                     generator_optimizer=generator_optimizer,
                                     training=True)

            gen_total_loss += step_losses[0]
            l1_loss_val += step_losses[1][0]
            ms_ssim_loss_val += step_losses[1][1]
            vgg_loss_val += step_losses[1][2]
            n_batches += 1

        print()
        # Average over the batches actually seen; guard against an empty generator.
        n_batches = max(n_batches, 1)
        with summary_writer.as_default():
            tf.summary.scalar('gen_total_loss', gen_total_loss / n_batches, step=epoch)
            tf.summary.scalar('l1_loss_val', l1_loss_val / n_batches, step=epoch)
            tf.summary.scalar('ms_ssim_loss_val', ms_ssim_loss_val / n_batches, step=epoch)
            tf.summary.scalar('vgg_loss_val', vgg_loss_val / n_batches, step=epoch)

        # Checkpoint the generator and optimizer every 100 epochs.
        if (epoch + 1) % 100 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time() - start))


In [30]:
# Epochs per fit() call; fit() offsets its epoch numbers by EPOCHS * run.
EPOCHS = 1000

In [None]:
# 100 back-to-back fit() calls of EPOCHS epochs each (100 * EPOCHS epochs in total);
# `i` only offsets the epoch numbering used for logs, snapshots and checkpoints.
for i in range(100):
    fit(EPOCHS, i)

### Restore from checkpoint

In [66]:
# Restore generator + optimizer state from a saved checkpoint.
# NOTE(review): the path points to a BS4 run while this notebook trains with batch_size=1,
# and the returned status is not checked (e.g. .assert_consumed() / .expect_partial()) — confirm intent.
checkpoint.restore("/shared/p00536919/training_checkpoints/flare_removal/20210301-164316_flare_only_32f_dil_stack_activations_BS4/ckpt-6")

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd2a813e550>

### Save / load model weights

In [51]:
# No discriminator is defined in this notebook (training uses generator losses only),
# so only the generator weights are saved. Create the target folder if it is missing.
os.makedirs('./weights', exist_ok=True)
generator.save_weights(r'./weights/trained_vanilla_pix2pix_hist_superlight_no_rescale_FINAL_gen_no_gamma_11_5_2020.hdf5')

In [39]:
# Load generator weights from an earlier experiment. No discriminator is defined in this
# notebook, so only the generator is loaded.
# NOTE(review): Windows-style path (F:\...), while the rest of the notebook uses /shared and
# /home paths — adjust before running.
generator.load_weights(r'F:\models\new_approach_gen_27_2_2020.hdf5')