In [1]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras 

import os
import time
from datetime import datetime

from matplotlib import pyplot as plt
from IPython import display
import numpy as np
from tqdm import tqdm
import cv2

## Create training dataset

In [2]:
# Input (USC) and ground-truth (GT) frame folders share one dataset root.
_VIDEO_DATA_ROOT = "/shared/data1/Video_data_small"
USC_FOLDERS = [_VIDEO_DATA_ROOT + "/usc"]
GT_FOLDERS = [_VIDEO_DATA_ROOT + "/gt"]

In [3]:
from train_utils.utils  import dataset, image_io, image_transform
from train_utils.train  import losses, train_step

In [5]:
# PNG frame reader; transform_list=None presumably means no per-image transforms on load.
png_image_reader = image_io.ImageReaderPNG(
    transform_list=None,
)

In [53]:
# Read every '.png' frame found in the USC input folders.
usc_arrays = dataset.get_images_array(
    folders_list=USC_FOLDERS,
    image_reader=png_image_reader,
    images_extension='.png',
)

100%|██████████| 303/303 [00:18<00:00, 16.36it/s]


In [54]:
# Read every '.png' frame found in the ground-truth folders.
gt_arrays = dataset.get_images_array(
    folders_list=GT_FOLDERS,
    image_reader=png_image_reader,
    images_extension='.png',
)

100%|██████████| 303/303 [00:18<00:00, 16.08it/s]


In [55]:
def img_resize(image, size=(1920 // 4, 1080 // 4)):
    """Resize an image with OpenCV.

    Args:
        image: array accepted by ``cv2.resize``.
        size: target size in OpenCV's ``(width, height)`` order. The default
            (480, 270) is a quarter of 1920x1080, the original behavior.

    Returns:
        The resized image (``cv2.resize`` default interpolation).
    """
    return cv2.resize(image, size)

In [56]:
# Quarter-resolution copies of the USC frames (progress shown by tqdm).
usc_arrays_downscaled = list(map(img_resize, tqdm(usc_arrays)))

100%|██████████| 303/303 [00:00<00:00, 687.24it/s]


In [57]:
# Quarter-resolution copies of the ground-truth frames (progress shown by tqdm).
gt_arrays_downscaled = list(map(img_resize, tqdm(gt_arrays)))

100%|██████████| 303/303 [00:00<00:00, 2735.47it/s]


In [4]:
CACHE_DIR = "/home/p00536919/Flare_removal/Downscale_ref"


def _load_npz_array(file_name):
    """Return the default ``arr_0`` array stored in ``CACHE_DIR/file_name``.

    ``np.load`` on an ``.npz`` returns an open ``NpzFile``. The ``with`` block
    closes it once the array has been read, so the file handle is not leaked.
    """
    with np.load(os.path.join(CACHE_DIR, file_name)) as archive:
        return archive['arr_0']


# Cached arrays; these rebind the names produced by the loading/resizing cells above.
usc_arrays_downscaled = _load_npz_array("usc_arrays_flare_only.npz")
gt_arrays_downscaled = _load_npz_array("gt_arrays_flare_only.npz")

usc_arrays = _load_npz_array("usc_arrays_flare_only_full.npz")
gt_arrays = _load_npz_array("gt_arrays_flare_only_full.npz")

In [12]:
#training_dataset = (usc_arrays_downscaled, gt_arrays_downscaled)

In [13]:
#train_data_generator = dataset.DataGenerator(training_dataset, patch_size=256, batch_size=1, image_size=(2448//8, 3264//8), shuffle=True, pad=2)

## GPU selection

In [8]:
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if physical_devices:
    # Enable on-demand memory growth and restrict TF to the first GPU.
    # TF requires both calls before the GPUs are initialised.
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
    tf.config.experimental.set_visible_devices(physical_devices[0], 'GPU')
else:
    # Previously this cell raised IndexError on a machine without a GPU.
    print("No GPU found; running on CPU.")

In [9]:
# Reset global Keras state before the models below are built.
tf.keras.backend.clear_session()

In [7]:
from guided_filter_tf.guided_filter import guided_filter

# Replace every full-resolution GT frame with
# guided_filter(x=GT, y=USC, r=100, nhwc=True) computed on the matching pair.
# NOTE(review): this writes back into gt_arrays in place, so re-running the cell
# filters already-filtered frames again; reload gt_arrays before re-running.
# NOTE(review): if gt_arrays is an integer-dtype array, the float filter output is
# truncated on assignment; confirm the dtype of the loaded arrays.
# NOTE(review): confirm which of x / y guided_filter treats as the guide image
# (r=100 is presumably the window radius).
for num, (usc_image, gt_image) in tqdm(list(enumerate(zip(usc_arrays, gt_arrays)))):
    
    # Add a batch axis and cast to float32; nhwc=True presumably marks channels-last input.
    gt_image = guided_filter(x=tf.cast(gt_image[np.newaxis, :, :, :], tf.float32),
                             y=tf.cast(usc_image[np.newaxis, :, :, :], tf.float32), r=100, nhwc=True)
    # Drop the batch axis and store the result back over the original frame.
    gt_image = gt_image.numpy()[0]
    gt_arrays[num] = gt_image

100%|██████████| 303/303 [00:19<00:00, 15.77it/s]


In [9]:
#training_dataset = (usc_arrays_downscaled, gt_arrays_downscaled)

In [10]:
#train_data_generator = dataset.DataGenerator(training_dataset, patch_size=224, batch_size=1, image_size=(1080//4, 1920//4), shuffle=True, pad=2)

### Clear logs

In [40]:
#!RMDIR /Q /S G:\logs\flare_removal
#!rm -rf /shared/p00536919/logs/flare_removal

## Build models

### Generator

In [10]:
def conv_block(growth_rate, filters, kernel_size, strides, x):
    """Apply a 'same'-padded, channels-first Conv2D followed by LeakyReLU.

    Args:
        growth_rate: multiplier applied to ``filters`` for the output width.
        filters: base channel count.
        kernel_size: convolution kernel size.
        strides: convolution stride.
        x: NCHW input tensor.

    Returns:
        Activated tensor with ``growth_rate * filters`` channels.
    """
    conv = tf.keras.layers.Conv2D(
        filters=growth_rate * filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        data_format='channels_first',
    )
    activation = tf.keras.layers.LeakyReLU()
    return activation(conv(x))


def dilated_conv_block(growth_rate, filters, kernel_size, dilation_rate, x):
    """Apply a 'same'-padded, channels-first dilated Conv2D followed by LeakyReLU.

    Args:
        growth_rate: multiplier applied to ``filters`` for the output width.
        filters: base channel count.
        kernel_size: convolution kernel size.
        dilation_rate: spacing between kernel taps (widens the receptive field).
        x: NCHW input tensor.

    Returns:
        Activated tensor with ``growth_rate * filters`` channels and the same
        spatial size as ``x``.
    """
    conv = tf.keras.layers.Conv2D(
        filters=growth_rate * filters,
        kernel_size=kernel_size,
        padding='same',
        dilation_rate=dilation_rate,
        data_format='channels_first',
    )
    activation = tf.keras.layers.LeakyReLU()
    return activation(conv(x))


def conv_skip_block(growth_rate, filters, kernel_size, x):
    """Mix channels after a skip concatenation: stride-1 Conv2DTranspose + LeakyReLU.

    Args:
        growth_rate: multiplier applied to ``filters`` for the output width.
        filters: base channel count.
        kernel_size: kernel size (called with (1, 1) in this notebook).
        x: NCHW input tensor.

    Returns:
        Activated tensor with ``growth_rate * filters`` channels.
    """
    transpose_conv = tf.keras.layers.Conv2DTranspose(
        filters=growth_rate * filters,
        kernel_size=kernel_size,
        padding='same',
        data_format='channels_first',
    )
    activation = tf.keras.layers.LeakyReLU()
    return activation(transpose_conv(x))


def deconv_block(growth_rate, filters, kernel_size, strides, x):
    """Upsample with a strided Conv2DTranspose, smooth with a 2x2 average pool, then LeakyReLU.

    Args:
        growth_rate: multiplier applied to ``filters`` for the output width.
        filters: base channel count.
        kernel_size: transposed-convolution kernel size.
        strides: upsampling stride of the transposed convolution.
        x: NCHW input tensor.

    Returns:
        Activated tensor with ``growth_rate * filters`` channels.
    """
    stages = (
        tf.keras.layers.Conv2DTranspose(
            filters=growth_rate * filters,
            kernel_size=kernel_size,
            strides=strides,
            padding='same',
            data_format='channels_first',
        ),
        # Stride-1 'same' pool: blurs the upsampled map without changing its size.
        tf.keras.layers.AveragePooling2D(
            pool_size=(2, 2),
            strides=1,
            padding='same',
            data_format='channels_first',
        ),
        tf.keras.layers.LeakyReLU(),
    )
    for stage in stages:
        x = stage(x)
    return x

In [11]:
def Generator(height=None, width=None, input_channels=3, filters=16):
    """Build the full-resolution residual de-flare network (channels-first).

    Encoder-decoder of dilated conv blocks with two stride-2 average-pool
    downsamplings, two stride-2 transposed-conv upsamplings and two skip
    concatenations (res1 at full resolution, res2 at half resolution). The last
    Conv2D predicts a 3-channel residual that is subtracted from the first 3
    input channels.

    Layer creation order determines how saved weights/checkpoints map back onto
    this model, so do not reorder the statements below.

    Args:
        height, width: spatial size of the input; None accepts any size.
            NOTE(review): with two 'same' stride-2 pools and two stride-2
            transposed convs, height and width presumably must be multiples of
            4 for the Concatenate layers to line up; confirm.
        input_channels: channels of the NCHW input. Only the first 3 serve as
            the base of the residual subtraction.
        filters: base width; each block uses growth_rate * filters channels.

    Returns:
        tf.keras.Model mapping (N, input_channels, H, W) to (N, 3, H, W).
    """

    inputs = tf.keras.Input(shape=[input_channels, height, width])

    # Stem with a wide dilation; its output is kept as the full-resolution skip.
    x = dilated_conv_block(growth_rate=2, filters=filters, kernel_size=(5, 5), dilation_rate=8, x=inputs)
    res1 = x

    # Downsample x2; res2 is the half-resolution skip.
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = tf.keras.layers.AveragePooling2D((2, 2), 2, padding='same', data_format='channels_first')(x)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    res2 = x

    # Downsample x2 again: quarter-resolution bottleneck.
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = tf.keras.layers.AveragePooling2D((2, 2), 2, padding='same', data_format='channels_first')(x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)

    # Upsample to half resolution and fuse with res2 along the channel axis.
    x = deconv_block(growth_rate=4, filters=filters, kernel_size=(4, 4), strides=2, x=x)

    x = tf.keras.layers.Concatenate(axis=1)([x, res2])
    x = conv_skip_block(growth_rate=4, filters=filters, kernel_size=(1, 1), x=x)

    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=2, x=x)

    # Upsample to full resolution and fuse with res1.
    x = deconv_block(growth_rate=2, filters=filters, kernel_size=(4, 4), strides=2, x=x)

    x = tf.keras.layers.Concatenate(axis=1)([x, res1])
    x = conv_skip_block(growth_rate=2, filters=filters, kernel_size=(1, 1), x=x)

    # Predict a 3-channel residual and subtract it from the first 3 input channels.
    x = dilated_conv_block(growth_rate=1, filters=filters, kernel_size=(3, 3), dilation_rate=1, x=x)
    x = tf.keras.layers.Conv2D(3, (3, 3), padding='same', data_format='channels_first')(x)
    x = tf.keras.layers.Subtract()([inputs[:,:3,:,:], x])
    
    # NOTE(review): the name 'derain_net' looks carried over from another project;
    # left unchanged in case saved artifacts reference it.
    _model = tf.keras.Model(inputs=inputs, outputs=x, name='derain_net')
    return _model    


In [12]:
def Generator_downscale(height=None, width=None, input_channels=3, filters=16):
    """Build the downscaled guidance network (channels-first).

    Same encoder-decoder layout as the full-resolution Generator, with a deeper
    three-block dilated stem (dilation 16/8/4). The predicted residual is
    tanh-bounded and subtracted from the first 3 input channels. A final
    sigmoid Conv2D maps the result to (0, 1).

    Layer creation order determines how saved weights map onto this model, so
    do not reorder the layer-building statements.

    Args:
        height, width: spatial size of the input; None accepts any size
            (presumably multiples of 4, given the two stride-2 pools; confirm).
        input_channels: channels of the NCHW input. Previously any value other
            than 3 failed at the Subtract layer because the residual has 3
            channels; now the first 3 channels are used, as in Generator.
        filters: base width; each block uses growth_rate * filters channels.

    Returns:
        tf.keras.Model mapping (N, input_channels, H, W) to (N, 3, H, W).
    """

    inputs = tf.keras.Input(shape=[input_channels, height, width])

    # Wide-receptive-field stem; its output is the full-resolution skip.
    x = dilated_conv_block(growth_rate=2, filters=filters, kernel_size=(9, 9), dilation_rate=16, x=inputs)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(5, 5), dilation_rate=8, x=x)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    res1 = x

    # Downsample x2; res2 is the half-resolution skip.
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = tf.keras.layers.AveragePooling2D((2, 2), 2, padding='same', data_format='channels_first')(x)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    res2 = x

    # Downsample x2 again: quarter-resolution bottleneck.
    x = dilated_conv_block(growth_rate=4, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = tf.keras.layers.AveragePooling2D((2, 2), 2, padding='same', data_format='channels_first')(x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)
    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=4, x=x)

    # Upsample to half resolution and fuse with res2.
    x = deconv_block(growth_rate=4, filters=filters, kernel_size=(4, 4), strides=2, x=x)

    x = tf.keras.layers.Concatenate(axis=1)([x, res2])
    x = conv_skip_block(growth_rate=4, filters=filters, kernel_size=(1, 1), x=x)

    x = dilated_conv_block(growth_rate=8, filters=filters, kernel_size=(3, 3), dilation_rate=2, x=x)

    # Upsample to full resolution and fuse with res1.
    x = deconv_block(growth_rate=2, filters=filters, kernel_size=(4, 4), strides=2, x=x)

    x = tf.keras.layers.Concatenate(axis=1)([x, res1])
    x = conv_skip_block(growth_rate=2, filters=filters, kernel_size=(1, 1), x=x)

    # tanh-bounded 3-channel residual.
    x = dilated_conv_block(growth_rate=1, filters=filters, kernel_size=(3, 3), dilation_rate=1, x=x)
    x = tf.keras.layers.Conv2D(3, (3, 3), padding='same', data_format='channels_first')(x)
    x = tf.keras.layers.Activation('tanh')(x)
    # Subtract the residual from the first 3 input channels. For the 3-channel
    # case the raw input is used directly, so the graph (and therefore existing
    # saved weights) is exactly as before.
    residual_base = inputs if input_channels == 3 else inputs[:, :3, :, :]
    x = tf.keras.layers.Subtract()([residual_base, x])
    # Final sigmoid conv maps the corrected image to (0, 1).
    x = tf.keras.layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same', data_format='channels_first')(x)

    _model = tf.keras.Model(inputs=inputs, outputs=x, name='derain_net')
    return _model

In [13]:
# Full-resolution model: 6 input channels (USC frame stacked with its guidance map),
# 16 base filters; height/width left at their None defaults (any size).
generator_full = Generator(input_channels=6, filters=16)

In [44]:
#generator_full.summary()

In [45]:
# Save the model to HDF5 so its graph can be opened in Netron below.
generator_full.save(filepath='deflare.h5')

In [46]:
import netron

In [47]:
# Serve the saved model graph in Netron on port 8989.
netron.start('deflare.h5',
             port=8989)

Serving 'deflare.h5' at http://localhost:8989


In [48]:
# Shut down the Netron server started on port 8989.
netron.stop(
    port=8989,
)


Stopping http://localhost:8989


## Guidance data preparation

In [14]:
# Optimizer used in this notebook only as the `generator_optimizer` entry of the
# guidance checkpoint object below.
fake_optimizer = tf.keras.optimizers.Adam(beta_1=0.5)
# 3-channel downscaled guidance network, 32 base filters, with pretrained weights.
guidance_cnn = Generator_downscale(
    input_channels=3,
    filters=32,
)
guidance_cnn.load_weights(
    r"/home/p00536919/Flare_removal/Downscale_ref/20210303_dilation_stack_activations_shift.hdf5"
)

In [57]:
# The keyword names are the checkpoint keys and must match the run being restored.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=fake_optimizer,
    generator=guidance_cnn,
)

In [58]:
# Restore guidance_cnn / fake_optimizer from an earlier flare-only training run.
# NOTE(review): this runs after guidance_cnn.load_weights(...), so every variable
# matched here would replace the HDF5 weights just loaded; confirm which source is intended.
# NOTE(review): the returned status is not checked; chain
# .assert_existing_objects_matched() to fail loudly on a mismatched checkpoint.
checkpoint.restore(r"/shared/p00536919/training_checkpoints/flare_removal/20210211-164544_flare_only/ckpt-33")

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7ef9345b2070>

In [15]:
def prepare_full_res_guidance(usc_downscaled_image, cnn, output_size=(1920, 1080), multiple=4):
    """Run the guidance CNN on one downscaled frame and upsample its output.

    The frame is symmetrically padded at the bottom/right so that height and
    width become multiples of ``multiple``. The skip concatenations in the
    network need this after its two stride-2 poolings. The padding is cropped
    off the prediction afterwards.

    Args:
        usc_downscaled_image: HWC image (numpy array).
        cnn: Keras model taking NCHW input and returning NCHW output of the
            same spatial size.
        output_size: ``(width, height)`` for ``cv2.resize``; defaults to 1920x1080.
        multiple: pad height/width up to a multiple of this. The default 4
            reproduces the original 2-row pad for 270-row frames.

    Returns:
        numpy array of shape ``(output_size[1], output_size[0], C)`` clipped to [0, 1].
    """
    height, width = usc_downscaled_image.shape[:2]
    # Padding needed to reach the next multiple (0 when already aligned; the old
    # `[:-2]` crop would have been wrong for any other height).
    pad_h = -height % multiple
    pad_w = -width % multiple

    batch = tf.cast(usc_downscaled_image[np.newaxis, :, :, :], tf.float32)
    padded = tf.pad(batch, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]], "SYMMETRIC")

    # NHWC -> NCHW for the channels-first model, then back.
    cnn_out = cnn(tf.transpose(padded, [0, 3, 1, 2]))
    guidance = tf.transpose(cnn_out, [0, 2, 3, 1]).numpy()[0][:height, :width]
    guidance = np.clip(guidance, 0, 1)
    return cv2.resize(guidance, output_size, interpolation=cv2.INTER_LINEAR)

In [16]:
# Full-resolution guidance map for every downscaled USC frame.
usc_fullres_guidnance_arrays = list(
    map(lambda small_frame: prepare_full_res_guidance(small_frame, guidance_cnn),
        tqdm(usc_arrays_downscaled))
)

100%|██████████| 303/303 [00:37<00:00,  8.02it/s]


In [17]:
# Stack each full-res USC frame with its guidance map along the channel axis.
usc_concatenated_fullres_arrays = list(
    map(lambda frame_and_guidance: np.dstack(list(frame_and_guidance)),
        tqdm(zip(usc_arrays, usc_fullres_guidnance_arrays)))
)

303it [00:13, 22.83it/s]


In [18]:
# (inputs, targets): USC frames stacked with guidance, paired with GT frames.
training_dataset = (
    usc_concatenated_fullres_arrays,
    gt_arrays,
)

In [19]:
# Shuffled batches of one pair; patch_size/pad semantics are defined by train_utils.dataset.
train_data_generator = dataset.DataGenerator(
    training_dataset,
    patch_size=896,
    batch_size=1,
    image_size=(1080, 1920),
    shuffle=True,
    pad=11,
)

## Define losses

### Generator loss

In [20]:
from train_utils.train  import losses

In [21]:
# Perceptual loss on VGG16 blocks 1 and 3 (dict presumably maps block -> weight), MSE distance.
vgg16_loss_1_3 = losses.make_VGG16_loss(
    blocks_dict={1: 1, 3: 1},
    weights_path=r"/home/p00536919/usc-image-enhancement/srs_refactor/refactored_pipeline/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5",
    loss_type='MSE',
)

In [22]:
def generator_loss(gen_output, target, l1_weight=0.5, ms_ssim_weight=1.0, vgg_weight=0.08):
    """Weighted L1 + MS-SSIM + VGG16 perceptual loss for the generator.

    Args:
        gen_output: generator prediction in NCHW layout (transposed to NHWC here).
        target: ground truth in the same layout as ``gen_output``.
        l1_weight, ms_ssim_weight, vgg_weight: weights of the three terms.
            The defaults reproduce the original 0.5 / 1 / 0.08 mix.

    Returns:
        ``(total_loss, [l1_loss, ms_ssim_loss, vgg_loss])``.
    """
    gen_output = tf.transpose(gen_output, [0, 2, 3, 1])
    target = tf.transpose(target, [0, 2, 3, 1])

    l1_loss_val = losses.L1_loss(target, gen_output)

    # max_val=1: assumes images scaled to [0, 1]; custom 5-scale power factors.
    ms_ssim = tf.image.ssim_multiscale(target, gen_output, 1,
                                       power_factors=(0.4, 0.25, 0.25, 0.2363, 0.1333))
    ms_ssim_loss_val = 1 - tf.math.reduce_mean(ms_ssim)

    # Perceptual term computed on 224x224 resized copies.
    vgg_loss_val = vgg16_loss_1_3(tf.image.resize(target, [224, 224]),
                                  tf.image.resize(gen_output, [224, 224]))

    total_gen_loss_val = (l1_weight * l1_loss_val
                          + ms_ssim_weight * ms_ssim_loss_val
                          + vgg_weight * vgg_loss_val)

    return total_gen_loss_val, [l1_loss_val, ms_ssim_loss_val, vgg_loss_val]

## Define the Optimizers and Checkpoint-saver

In [62]:
#generator_optimizer = tf.keras.optimizers.Adam(lr=2e-4, beta_1=0.25, beta_2=0.75)
import tensorflow_addons as tfa

# Rectified Adam wrapped in Lookahead (slow weights synced every 6 steps, step size 0.5).
# `learning_rate` replaces the deprecated `lr` alias; same value.
radam = tfa.optimizers.RectifiedAdam(learning_rate=2e-6)
generator_optimizer = tfa.optimizers.Lookahead(radam, sync_period=6, slow_step_size=0.5)

In [24]:
# Timestamped checkpoint directory for this run, with files prefixed "ckpt".
checkpoint_dir = (
    "/shared/p00536919/training_checkpoints/flare_removal/"
    f"{datetime.now():%Y%m%d-%H%M%S}_full_res_video"
)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    generator=generator_full,
)

## Training

In [25]:
log_dir = r"/shared/p00536919/logs/flare_removal/"

# One timestamp for both writers. Two separate datetime.now() calls could fall in
# different seconds and split scalars and images into different TensorBoard runs.
run_name = datetime.now().strftime("%Y%m%d-%H%M%S") + "_full_res_video"

summary_writer = tf.summary.create_file_writer(log_dir + "fit/" + run_name)

file_writer_img = tf.summary.create_file_writer(log_dir + "fit/" + run_name + '/img')

In [35]:
import albumentations as A

In [46]:
transform = A.Compose(
    [
        # Random shift / scale / rotation, reflect-padded at the borders.
        A.ShiftScaleRotate(
            scale_limit=[-0.01, 5],
            rotate_limit=180,
            interpolation=cv2.INTER_CUBIC,
            border_mode=cv2.BORDER_REFLECT_101,
            p=0.8,
        ),
        # Occasional colour jitter.
        A.HueSaturationValue(
            hue_shift_limit=20 / 255,
            sat_shift_limit=20 / 255,
            val_shift_limit=30 / 255,
            always_apply=False,
            p=0.25,
        ),
        A.ChannelShuffle(p=0.25),
    ],
    # Extra image targets receive the same sampled parameters as 'image'.
    additional_targets={'image0': 'image', 'image1': 'image'},
)

In [63]:
from train_utils.train  import train_step 

In [64]:
# Re-import under an alias: rebinding `train_step` to the returned function would
# otherwise make this cell fail on a second run (the function has no make_train_step).
from train_utils.train import train_step as train_step_module
train_step = train_step_module.make_train_step()

In [65]:
def _log_preview_images(epoch):
    """Log one random training frame and the current prediction to TensorBoard.

    Also writes the current generator_full weights to "log_weights2_full.hdf5"
    (overwritten on every call), so a recent snapshot exists on disk.
    """
    generator_full.save_weights("log_weights2_full.hdf5")
    # Predict with the live model. Rebuilding a copy (after clear_session) and
    # reloading the weights just saved gives the same output, but costs a full
    # model build and resets global Keras state mid-training.
    img_index = np.random.randint(len(usc_concatenated_fullres_arrays))
    with tf.device('/GPU:0'):
        test_image = tf.cast(usc_concatenated_fullres_arrays[img_index][np.newaxis, :, :, :], tf.float32)
        preds = generator_full(tf.transpose(test_image, [0, 3, 1, 2]), training=True)
        preds = tf.transpose(preds, [0, 2, 3, 1])

    with file_writer_img.as_default():
        tf.summary.image("Input_image", usc_arrays[img_index][np.newaxis, :, :, :], step=epoch)
        tf.summary.image("Ground truth", gt_arrays[img_index][np.newaxis, :, :, :], step=epoch)
        tf.summary.image("Guidance", usc_fullres_guidnance_arrays[img_index][np.newaxis, :, :, :], step=epoch)
        tf.summary.image("Model results", preds.numpy()[:, :, :, :3], step=epoch)


def _augment_batch(input_image, target):
    """Augment every (input, target) pair of a batch in place with `transform`.

    Each input is HWC with the USC frame in channels 0-2 and its guidance map in
    channels 3-5. The frame, guidance and target go through a single `transform`
    call so they share the sampled parameters (via `additional_targets`).
    Returns the mutated arrays.
    """
    for pair_id in range(len(input_image)):
        transformed = transform(image=input_image[pair_id][:, :, :3],
                                image0=input_image[pair_id][:, :, 3:6],
                                image1=target[pair_id])
        input_image[pair_id] = np.dstack([transformed['image'], transformed['image0']])
        target[pair_id] = transformed['image1']
    return input_image, target


def _train_one_epoch():
    """Run one optimisation pass over `train_data_generator`.

    Returns:
        Summed (total, l1, ms_ssim, vgg) losses over all batches.
    """
    gen_total_loss, l1_loss_val, ms_ssim_loss_val, vgg_loss_val = 0, 0, 0, 0
    for n, (input_image, target) in enumerate(train_data_generator):
        # Progress: one dot per batch, a line break every 100 batches.
        print('.', end='')
        if (n + 1) % 100 == 0:
            print('\n')
        input_image, target = _augment_batch(np.array(input_image), np.array(target))

        # Named step_losses so it does not read as the train_utils `losses` module.
        step_losses = train_step(input_image=tf.transpose(tf.cast(input_image, tf.float32), [0, 3, 1, 2]),
                                 target=tf.transpose(tf.cast(target, tf.float32), [0, 3, 1, 2]),
                                 generator=generator_full,
                                 generator_loss=generator_loss,
                                 generator_optimizer=generator_optimizer,
                                 training=True)

        gen_total_loss += step_losses[0]
        l1_loss_val += step_losses[1][0]
        ms_ssim_loss_val += step_losses[1][1]
        vgg_loss_val += step_losses[1][2]
    print()
    return gen_total_loss, l1_loss_val, ms_ssim_loss_val, vgg_loss_val


def fit(epochs, run):
    """Train `generator_full` for `epochs` epochs, continuing a global epoch count.

    Every epoch ending in a multiple of 50 logs a preview (and dumps weights).
    Every epoch writes mean losses to `summary_writer`. Every 100th epoch saves
    `checkpoint` under `checkpoint_prefix`.

    Args:
        epochs: number of epochs to run in this call.
        run: epoch-offset multiplier; epochs are numbered from
            round(epochs * run), so successive calls continue the count.
    """
    # round() instead of int(): epochs * run for a float run such as i + 1.324
    # can land just below an integer, which int() would truncate.
    epoch_offset = int(round(epochs * run))
    loss_names = ('gen_total_loss', 'l1_loss_val', 'ms_ssim_loss_val', 'vgg_loss_val')

    for local_epoch in range(epochs):
        epoch = epoch_offset + local_epoch
        start = time.time()
        display.clear_output(wait=True)

        if (epoch + 1) % 50 == 0:
            _log_preview_images(epoch)

        print("Epoch: ", epoch)
        train_data_generator.on_epoch_end()
        loss_sums = _train_one_epoch()

        # Guard against an empty generator, which would otherwise divide by zero.
        n_batches = max(len(train_data_generator), 1)
        with summary_writer.as_default():
            for name, loss_sum in zip(loss_names, loss_sums):
                tf.summary.scalar(name, loss_sum / n_batches, step=epoch)

        # Save a checkpoint every 100 epochs.
        if (epoch + 1) % 100 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                            time.time() - start))


In [66]:
EPOCHS = 1_000  # epochs per fit() call

In [67]:
# 100 consecutive fit() calls. fit numbers epochs from EPOCHS * run, so the
# +1.324 offset makes the first call start at epoch 1324 (presumably to continue
# the numbering of an earlier session).
for run_idx in range(100):
    fit(EPOCHS, run_idx + 1.324)

Epoch:  3243
....................................................................................................

.......................

KeyboardInterrupt: 

### Restore from checkpoint

In [69]:
# Restore from a saved checkpoint into whatever `checkpoint` is currently bound.
# That is the generator_full/generator_optimizer one if the checkpoint-setup cell
# ran last; the guidance cell binds the same name.
# NOTE(review): this path (20200114-114537, no "_full_res_video" suffix) does not
# follow this run's checkpoint_dir naming; confirm it is the intended checkpoint.
checkpoint.restore("/shared/p00536919/training_checkpoints/flare_removal/20200114-114537/ckpt-10")

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f87bc185290>

### Save Models

In [51]:
# NOTE(review): `generator` and `discriminator` are not defined in any visible cell
# (the trained model here is `generator_full`), so this cell raises NameError on a
# fresh kernel. It looks like a leftover from an earlier pix2pix notebook.
generator.save_weights(r'./weights/trained_vanilla_pix2pix_hist_superlight_no_rescale_FINAL_gen_no_gamma_11_5_2020.hdf5')
discriminator.save_weights(r'./weights/trained_vanilla_pix2pix_hist_superlight_no_rescale_FINAL_disc_no_gamma_18_5_2020.hdf5')

In [33]:
# NOTE(review): same issue: `generator` is undefined in the visible cells.
generator.load_weights(r'/home/p00536919/Flare_removal/Downscale_ref/log_weights1.hdf5')
#discriminator.load_weights(r'F:\models\new_approach_disc_27_2_2020.hdf5')

In [68]:
# Export a crop of one downscaled input frame.
# NOTE(review): index 495 exceeds the 303 frames processed earlier (tqdm reported
# 303/303), so these cells raise IndexError with the current data. Also, if the
# frames are 270x480 (img_resize output), the 296-row crop is silently clipped to 270 rows.
plt.imsave(r"usc1.png", usc_arrays_downscaled[495][:296, :408, :])

In [69]:
# Export the matching ground-truth crop.
plt.imsave(r"gt1.png", gt_arrays_downscaled[495][:296, :408, :])

In [72]:
# Run a model on the crop (NHWC -> NCHW -> NHWC) and save the clipped prediction.
# NOTE(review): `generator` is undefined in the visible cells. generator_full
# expects 6 input channels and guidance_cnn 3; confirm which model was meant.
test_image = usc_arrays_downscaled[495][np.newaxis, :296, :408, :]
test_image = tf.cast(test_image, tf.float32)
#test_image = tf.image.resize(test_image, [256,256])
preds = generator(tf.transpose(test_image,[0,3,1,2]), training=True)  
preds = tf.transpose(preds, [0,2,3,1])
plt.imsave(r"preds1.png", np.clip(preds.numpy(),0,1)[0])