## Inference pipeline for the deep-feature **video enhancement** solution with integrated color correction

Model initialization is available in two ways: declaring the architecture and loading weights, or restoring from a checkpoint. 

The pipeline processes videos that have already been decomposed into frame sequences, with each video in a separate folder. Paths are provided as lists of the corresponding folders. The pipeline reads frames and yields the downscaled output, the full-resolution guidance images and the full-resolution result. 

Once the result is accumulated, it is written by the cv2 video writer in .avi format.  

**Attention:** not suitable for large videos, as all data is stored in RAM. For large videos, rewrite the pipeline for online mode.

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras 

import os
import time
import glob
from datetime import datetime
import shutil

from matplotlib import pyplot as plt
from IPython import display
import numpy as np
from tqdm import tqdm
import cv2
from skimage.transform import rescale, resize
import skimage.filters

## GPU selection

In [None]:
# Restrict TensorFlow to the first GPU and let its memory allocation grow on
# demand instead of reserving the whole device up front.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
    tf.config.experimental.set_visible_devices(physical_devices[0], 'GPU')
else:
    # Indexing an empty device list would raise IndexError; report and continue.
    print("No GPU detected by TensorFlow; running on the default device.")

In [None]:
# Reset the global Keras state before the models below are built.
tf.keras.backend.clear_session()

## Initialize models

In [4]:
def conv_block(growth_rate, filters, kernel_size, strides, x):
    """Apply a channels-first Conv2D followed by a LeakyReLU activation.

    Args:
        growth_rate: Multiplier applied to ``filters`` to get the output
            channel count.
        filters: Base number of filters.
        kernel_size: Kernel size forwarded to ``Conv2D``.
        strides: Stride forwarded to ``Conv2D``.
        x: Input tensor (channels-first layout).

    Returns:
        The activated output tensor.
    """
    convolution = tf.keras.layers.Conv2D(
        growth_rate * filters,
        kernel_size,
        strides=strides,
        padding='same',
        data_format='channels_first',
    )
    activation = tf.keras.layers.LeakyReLU()
    return activation(convolution(x))


def dilated_conv_block(growth_rate, filters, kernel_size, dilation_rate, x):
    """Apply a dilated channels-first Conv2D followed by a LeakyReLU activation.

    Args:
        growth_rate: Multiplier applied to ``filters`` to get the output
            channel count.
        filters: Base number of filters.
        kernel_size: Kernel size forwarded to ``Conv2D``.
        dilation_rate: Dilation rate forwarded to ``Conv2D``.
        x: Input tensor (channels-first layout).

    Returns:
        The activated output tensor.
    """
    dilated_convolution = tf.keras.layers.Conv2D(
        growth_rate * filters,
        kernel_size,
        dilation_rate=dilation_rate,
        padding='same',
        data_format='channels_first',
    )
    activation = tf.keras.layers.LeakyReLU()
    return activation(dilated_convolution(x))


def conv_skip_block(growth_rate, filters, kernel_size, x):
    """Apply a channels-first Conv2DTranspose followed by LeakyReLU.

    No ``strides`` argument is passed, so the layer's default stride applies.
    In this file the block is used right after a skip-connection
    concatenation.

    Args:
        growth_rate: Multiplier applied to ``filters`` to get the output
            channel count.
        filters: Base number of filters.
        kernel_size: Kernel size forwarded to ``Conv2DTranspose``.
        x: Input tensor (channels-first layout).

    Returns:
        The activated output tensor.
    """
    transposed_convolution = tf.keras.layers.Conv2DTranspose(
        growth_rate * filters,
        kernel_size,
        padding='same',
        data_format='channels_first',
    )
    activation = tf.keras.layers.LeakyReLU()
    return activation(transposed_convolution(x))


def deconv_block(growth_rate, filters, kernel_size, strides, x):
    """Upsample with Conv2DTranspose, then average-pool and apply LeakyReLU.

    The 2x2 average pooling uses stride 1 with 'same' padding, so it does not
    change the spatial size (presumably it smooths the transposed-convolution
    output -- TODO confirm intent).

    Args:
        growth_rate: Multiplier applied to ``filters`` to get the output
            channel count.
        filters: Base number of filters.
        kernel_size: Kernel size forwarded to ``Conv2DTranspose``.
        strides: Stride forwarded to ``Conv2DTranspose``.
        x: Input tensor (channels-first layout).

    Returns:
        The activated output tensor.
    """
    upsample = tf.keras.layers.Conv2DTranspose(
        growth_rate * filters,
        kernel_size,
        strides=strides,
        padding='same',
        data_format='channels_first',
    )
    smooth = tf.keras.layers.AveragePooling2D(
        (2, 2), 1, padding='same', data_format='channels_first'
    )
    activation = tf.keras.layers.LeakyReLU()
    return activation(smooth(upsample(x)))


In [5]:
def get_downscale_generator(height=None, width=None, input_channels=3, filters=32):
    """Build the 'downscale_net' Keras model (channels-first).

    The network is an encoder/decoder made of dilated convolutions with two
    stride-2 average-pooling steps, two stride-2 transposed-convolution
    upsampling steps and two skip connections. The predicted 3-channel
    tanh map is subtracted from the input and passed through a final
    sigmoid convolution.

    Layers are created in exactly the same order as the reference
    implementation so that stored weights keep matching.

    Args:
        height: Input height, or None for a variable size.
        width: Input width, or None for a variable size.
        input_channels: Number of input channels. The final Subtract pairs the
            input with a 3-channel tensor.
        filters: Base filter count scaled by each block's growth rate.

    Returns:
        A ``tf.keras.Model`` named 'downscale_net'.
    """

    def dilated_stack(tensor, specs):
        # Chain dilated conv blocks described as (growth_rate, kernel, dilation).
        for growth, kernel, dilation in specs:
            tensor = dilated_conv_block(growth_rate=growth, filters=filters,
                                        kernel_size=kernel, dilation_rate=dilation, x=tensor)
        return tensor

    def halve(tensor):
        # 2x2 average pooling with stride 2.
        pooling = tf.keras.layers.AveragePooling2D(
            (2, 2), 2, padding='same', data_format='channels_first')
        return pooling(tensor)

    net_input = tf.keras.Input(shape=[input_channels, height, width])

    # Encoder at input resolution; its output feeds the last skip connection.
    skip_full = dilated_stack(net_input, [(2, (9, 9), 16), (4, (5, 5), 8), (4, (3, 3), 4)])

    # Encoder after the first pooling; its output feeds the first skip connection.
    features = dilated_stack(skip_full, [(4, (3, 3), 4)] * 2)
    features = halve(features)
    skip_half = dilated_stack(features, [(4, (3, 3), 4)] * 2)

    # Bottleneck after the second pooling.
    features = dilated_stack(skip_half, [(4, (3, 3), 4)])
    features = halve(features)
    features = dilated_stack(features, [(8, (3, 3), 4)] * 4)

    # Decoder: first upsampling step, merged with the half-resolution skip.
    features = deconv_block(growth_rate=4, filters=filters, kernel_size=(4, 4), strides=2, x=features)
    features = tf.keras.layers.Concatenate(axis=1)([features, skip_half])
    features = conv_skip_block(growth_rate=4, filters=filters, kernel_size=(1, 1), x=features)
    features = dilated_stack(features, [(8, (3, 3), 2)])

    # Decoder: second upsampling step, merged with the full-resolution skip.
    features = deconv_block(growth_rate=2, filters=filters, kernel_size=(4, 4), strides=2, x=features)
    features = tf.keras.layers.Concatenate(axis=1)([features, skip_full])
    features = conv_skip_block(growth_rate=2, filters=filters, kernel_size=(1, 1), x=features)
    features = dilated_stack(features, [(1, (3, 3), 1)])

    # Output head: tanh map subtracted from the input, then a sigmoid conv.
    predicted = tf.keras.layers.Conv2D(3, (3, 3), padding='same', data_format='channels_first')(features)
    predicted = tf.keras.layers.Activation('tanh')(predicted)
    difference = tf.keras.layers.Subtract()([net_input, predicted])
    outputs = tf.keras.layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same',
                                     data_format='channels_first')(difference)

    return tf.keras.Model(inputs=net_input, outputs=outputs, name='downscale_net')

In [6]:
def get_fullres_generator(height=None, width=None, input_channels=6, filters=16):
    """Build the 'fullres_net' Keras model (channels-first).

    Same encoder/decoder layout as the downscale network but with a single
    dilated block before the first skip connection. The final 3-channel
    convolution output is subtracted from the first three input channels;
    there is no output activation.

    Layers are created in exactly the same order as the reference
    implementation so that stored weights keep matching.

    Args:
        height: Input height, or None for a variable size.
        width: Input width, or None for a variable size.
        input_channels: Number of input channels; only the first three take
            part in the final subtraction.
        filters: Base filter count scaled by each block's growth rate.

    Returns:
        A ``tf.keras.Model`` named 'fullres_net'.
    """

    def dilated_stack(tensor, specs):
        # Chain dilated conv blocks described as (growth_rate, kernel, dilation).
        for growth, kernel, dilation in specs:
            tensor = dilated_conv_block(growth_rate=growth, filters=filters,
                                        kernel_size=kernel, dilation_rate=dilation, x=tensor)
        return tensor

    def halve(tensor):
        # 2x2 average pooling with stride 2.
        pooling = tf.keras.layers.AveragePooling2D(
            (2, 2), 2, padding='same', data_format='channels_first')
        return pooling(tensor)

    net_input = tf.keras.Input(shape=[input_channels, height, width])

    # Encoder at input resolution; its output feeds the last skip connection.
    skip_full = dilated_stack(net_input, [(2, (5, 5), 8)])

    # Encoder after the first pooling; its output feeds the first skip connection.
    features = dilated_stack(skip_full, [(4, (3, 3), 4)] * 2)
    features = halve(features)
    skip_half = dilated_stack(features, [(4, (3, 3), 4)] * 2)

    # Bottleneck after the second pooling.
    features = dilated_stack(skip_half, [(4, (3, 3), 4)])
    features = halve(features)
    features = dilated_stack(features, [(8, (3, 3), 4)] * 4)

    # Decoder: first upsampling step, merged with the half-resolution skip.
    features = deconv_block(growth_rate=4, filters=filters, kernel_size=(4, 4), strides=2, x=features)
    features = tf.keras.layers.Concatenate(axis=1)([features, skip_half])
    features = conv_skip_block(growth_rate=4, filters=filters, kernel_size=(1, 1), x=features)
    features = dilated_stack(features, [(8, (3, 3), 2)])

    # Decoder: second upsampling step, merged with the full-resolution skip.
    features = deconv_block(growth_rate=2, filters=filters, kernel_size=(4, 4), strides=2, x=features)
    features = tf.keras.layers.Concatenate(axis=1)([features, skip_full])
    features = conv_skip_block(growth_rate=2, filters=filters, kernel_size=(1, 1), x=features)
    features = dilated_stack(features, [(1, (3, 3), 1)])

    # Output head: prediction subtracted from the first three input channels.
    predicted = tf.keras.layers.Conv2D(3, (3, 3), padding='same', data_format='channels_first')(features)
    outputs = tf.keras.layers.Subtract()([net_input[:, :3, :, :], predicted])

    return tf.keras.Model(inputs=net_input, outputs=outputs, name='fullres_net')

In [7]:
# Instantiate both generators with unspecified spatial size (height/width = None).
# The full-resolution model takes 6 input channels: the batch loop below feeds it
# each frame stacked with its guidance image via np.dstack.
downscale_model = get_downscale_generator(height=None, width=None, input_channels=3, filters=32)
full_res_model = get_fullres_generator(height=None, width=None, input_channels=6, filters=16)

In [8]:
# Weight files for the two generators; the paths are relative, so they are
# resolved against the current working directory.
downscale_model_weights = r"20210303_dilation_stack_activations_shift.hdf5"
full_res_model_weights = r"20210317_full_res_shift_optimized.hdf5"

In [9]:
# Load the stored weights into the architectures declared above.
downscale_model.load_weights(downscale_model_weights)
full_res_model.load_weights(full_res_model_weights)

In [10]:
#fake_downscale_optimizer = tf.keras.optimizers.Adam(beta_1=0.5)
#fake_full_ref_optimizer = tf.keras.optimizers.Adam(beta_1=0.5)

In [11]:
#checkpoint_downscaled = tf.train.Checkpoint(generator_optimizer=fake_downscale_optimizer, generator=downscale_model)
#checkpoint_full_res = tf.train.Checkpoint(generator_optimizer=fake_full_ref_optimizer, generator=full_res_model)

In [12]:
#checkpoint_downscaled.restore(r"/shared/p00536919/training_checkpoints/flare_removal/20210212-193037_flare_only_32f/ckpt-29")
#checkpoint_full_res.restore(r"/shared/p00536919/training_checkpoints/flare_removal/20201127-151914/ckpt-58")

In [10]:
def get_color_correction_generator(filters=8, data_format='channels_first',
                                   frame_height=1080, frame_width=1920):
    """Build the color-correction Keras model.

    A global branch (average pooling, strided convolutions, global average
    pooling and two 1x1 convolutions) produces one feature vector per image.
    That vector is tiled to a ``frame_height x frame_width`` map, concatenated
    with the input image along the channel axis and mapped to a 3-channel
    sigmoid output by a chain of 1x1 convolutions.

    Args:
        filters: Base filter count of the global branch.
        data_format: Data format forwarded to the Keras layers. NOTE(review):
            the input shape ``[3, None, None]``, the concatenation axis and the
            transpose below assume 'channels_first'; other values are untested.
        frame_height: Height of the tiled global-feature map. Input frames
            must have this height for the concatenation to succeed.
        frame_width: Width of the tiled global-feature map. Input frames must
            have this width for the concatenation to succeed.

    Returns:
        An unnamed ``tf.keras.Model`` mapping a 3-channel image to a 3-channel
        image.
    """
    # Channel axis of a channels-first (N, C, H, W) tensor.
    axis = -3

    inputs = tf.keras.layers.Input(shape=[3, None, None])

    # --- Global branch: squeeze the image into a single feature vector. ---
    out = tf.keras.layers.AvgPool2D(pool_size=(4, 8), padding='same', name='avg_pool1', data_format=data_format)(inputs)

    out = tf.keras.layers.Conv2D(filters, 5, strides=4, padding='same', name='conv1', data_format=data_format)(out)

    out = tf.keras.layers.LeakyReLU(name='activation1')(out)
    out = tf.keras.layers.Conv2D(filters * 2, 3, strides=2, padding='same', name='conv2', data_format=data_format)(out)

    out = tf.keras.layers.LeakyReLU(name='activation2')(out)
    out = tf.keras.layers.Conv2D(filters * 4, 3, strides=2, padding='same', name='conv3', data_format=data_format)(out)

    out = tf.keras.layers.LeakyReLU(name='activation3')(out)
    out = tf.keras.layers.Conv2D(filters * 8, 3, strides=2, padding='same', name='conv4', data_format=data_format)(out)

    out = tf.keras.layers.LeakyReLU(name='activation4')(out)
    # Restore two singleton spatial axes so the 1x1 convolutions can be applied.
    out = tf.expand_dims(tf.keras.layers.GlobalAveragePooling2D(data_format=data_format)(out), axis=-1)
    out = tf.expand_dims(out, axis=-1)

    out = tf.keras.layers.Conv2D(filters * 16, 1, strides=1, padding='same', name='conv5', data_format=data_format)(out)

    out = tf.keras.layers.LeakyReLU(name='activation5')(out)
    out = tf.keras.layers.Conv2D(filters * 32, 1, strides=1, padding='same', name='conv6', data_format=data_format)(out)

    out = tf.keras.layers.LeakyReLU(name='activation7')(out)

    global_features = tf.keras.layers.Flatten(name='flatten')(out)

    # --- Broadcast the global vector to every pixel of the frame. ---
    out = tf.tile(tf.expand_dims(global_features, axis=1), multiples=[1, 1, frame_height * frame_width])
    out = tf.keras.layers.Reshape((frame_height, frame_width, filters * 32))(out)
    out = tf.transpose(out, [0, 3, 1, 2])

    # --- Per-pixel head: fuse global features with the input image. ---
    # The original code concatenated with an undefined name `x`; the network
    # input is the only tensor that keeps the graph connected.
    out = tf.keras.layers.Concatenate(axis=axis)([out, inputs])
    out = tf.keras.layers.Conv2D(256*1, 1, padding='same', data_format=data_format, name='conv8')(out)
    out = tf.keras.layers.LeakyReLU(name='activation8')(out)

    out = tf.keras.layers.Conv2D(128, 1, padding='same', data_format=data_format, name='conv92')(out)
    out = tf.keras.layers.LeakyReLU(name='activation92')(out)

    out = tf.keras.layers.Conv2D(64, 1, padding='same', data_format=data_format, name='conv10')(out)
    out = tf.keras.layers.LeakyReLU(name='activation10')(out)
    out = tf.keras.layers.Conv2D(32, 1, padding='same', data_format=data_format, name='conv11')(out)

    out = tf.keras.layers.LeakyReLU(name='activation100')(out)
    out = tf.keras.layers.Conv2D(16, 1, padding='same', data_format=data_format, name='conv112')(out)

    out = tf.keras.layers.LeakyReLU(name='activation101')(out)
    out = tf.keras.layers.Conv2D(8, 1, padding='same', data_format=data_format, name='conv1121')(out)

    out = tf.keras.layers.LeakyReLU(name='activation10011')(out)
    out = tf.keras.layers.Conv2D(4, 1, padding='same', data_format=data_format, name='conv1122')(out)

    out = tf.keras.layers.LeakyReLU(name='activation12')(out)

    out = tf.keras.layers.Conv2D(3, 1, padding='same', data_format=data_format, activation='sigmoid', name='conv14')(out)

    return tf.keras.Model(inputs=inputs, outputs=out)

In [None]:
# Build the color-correction network with its default arguments and load its
# weights from an absolute, machine-specific path.
cc_net = get_color_correction_generator()
cc_net.load_weights(r'/home/p00536919/Flare_removal/Downscale_ref/log_colorcor_videodata256augmaecropnn.hdf5')

## Create pipeline

In [None]:
# Frame size in pixels. The helpers below pass these to cv2.resize as
# (IMAGE_WIDTH, IMAGE_HEIGHT), i.e. frames are 1080 wide and 1920 tall.
IMAGE_WIDTH = 1080
IMAGE_HEIGHT = 1920

In [10]:
def prepare_full_res_guidance(usc_downscaled_image, cnn):
    """Run ``cnn`` on a downscaled frame and upscale the result to full size.

    Args:
        usc_downscaled_image: Single H x W x C frame (channels-last).
        cnn: Callable model taking a channels-first batch.

    Returns:
        The model output clipped to [0, 1] and resized with bilinear
        interpolation to (IMAGE_WIDTH, IMAGE_HEIGHT).
    """
    # Padding placeholder: meant for frame sizes the network cannot handle
    # directly (presumably sizes not divisible by the required power of two).
    # All amounts are zero, so the pad below leaves the batch unchanged.
    zero_paddings = tf.constant([[0, 0], [0, 0], [0, 0], [0, 0]])

    batch_nhwc = tf.cast(usc_downscaled_image[np.newaxis, :, :, :], tf.float32)
    batch_nhwc = tf.pad(batch_nhwc, zero_paddings, "SYMMETRIC")

    # The model works channels-first; convert in, then convert back out.
    batch_nchw = tf.transpose(batch_nhwc, [0, 3, 1, 2])
    prediction_nhwc = tf.transpose(cnn(batch_nchw), [0, 2, 3, 1])

    guidance = np.clip(prediction_nhwc.numpy()[0], 0, 1)
    return cv2.resize(guidance, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_LINEAR)

def get_full_res_output(image, cnn):
    """Run ``cnn`` on a single channels-last frame and return its output.

    Args:
        image: Single H x W x C frame (channels-last).
        cnn: Callable model taking a channels-first batch. It is invoked with
            ``training=True``.

    Returns:
        The first (only) batch element of the model output as a channels-last
        NumPy array.
    """
    batch_nhwc = tf.cast(image[np.newaxis, :, :, :], tf.float32)
    batch_nchw = tf.transpose(batch_nhwc, [0, 3, 1, 2])
    prediction_nchw = cnn(batch_nchw, training=True)
    prediction_nhwc = tf.transpose(prediction_nchw, [0, 2, 3, 1])
    return prediction_nhwc.numpy()[0]

def img_resize(image, factor=4):
    """Resize ``image`` to the global frame size divided by ``factor``.

    The target size is derived from IMAGE_WIDTH / IMAGE_HEIGHT (integer
    division), not from the size of ``image`` itself.
    """
    target_size = (IMAGE_WIDTH // factor, IMAGE_HEIGHT // factor)
    return cv2.resize(image, target_size)

### Batch processing

In [16]:
# Input folders, one per video. The batch loop zips the two lists, so the
# i-th USC folder is paired with the i-th GT folder and both lists must be in
# matching order. The commented-out assignments are alternative inputs.
#USC_TEST_FOLDERS_LIST = glob.glob(r"/shared/data1/shared/data1/Video_data_custom/1223_23/outside"+"/*")
#USC_TEST_FOLDERS_LIST = [r"/shared/data1/VID_20200806_143342"]
USC_TEST_FOLDERS_LIST = [r"/shared/data1/20210318_USC_Tablet_data/Test_data/sample_4_W/VID_20210323_170053",
                         r"/shared/data1/20210318_USC_Tablet_data/Test_data/sample_4_W/VID_20210323_170453",
                        ]
GT_TEST_FOLDERS_LIST =  [r"/shared/data1/test_flare_videos/gt/VID_20200806_143508",
                        r"/shared/data1/test_flare_videos/gt/VID_20200806_143559",
                        ]

# Output locations, relative to the current working directory.
OUTPUT_FOLDER = r"./test_output/"
VIDEO_OUT_FOLDER = r"./video_out/"

In [17]:
from train_utils.utils  import dataset, image_io, image_transform

In [None]:
# Start from an empty output folder. The existence check keeps the first run
# (when the folder has not been created yet) from failing in shutil.rmtree.
if os.path.isdir(OUTPUT_FOLDER):
    shutil.rmtree(OUTPUT_FOLDER)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

In [19]:
def _to_uint8_frame(frame):
    """Clip a float frame to [0, 1] and convert it to uint8 values in [0, 255]."""
    return (np.clip(frame, 0, 1) * 255).astype(np.uint8)


# cv2.VideoWriter does not create missing directories, so make sure it exists.
os.makedirs(VIDEO_OUT_FOLDER, exist_ok=True)

# Process every (ground truth, USC) folder pair: run the three networks on all
# frames, then write a side-by-side comparison video [GT | result | USC input].
for gt_test_folder, usc_test_folder in zip(GT_TEST_FOLDERS_LIST, USC_TEST_FOLDERS_LIST):

    GT_TEST_FOLDERS = [gt_test_folder]
    USC_TEST_FOLDERS = [usc_test_folder]
    png_image_reader = image_io.ImageReaderPNG(transform_list=None)
    # NOTE(review): usc_filenames is not used further down in this cell.
    usc_filenames = dataset.get_filenames_array(folders_list=USC_TEST_FOLDERS, images_extension='.png')

    usc_arrays = dataset.get_images_array(folders_list=USC_TEST_FOLDERS, image_reader=png_image_reader, images_extension='.png')
    gt_arrays = dataset.get_images_array(folders_list=GT_TEST_FOLDERS, image_reader=png_image_reader, images_extension='.png')

    # Stage 1: downscale each frame and let the downscale model produce a
    # guidance image that is resized back to full resolution.
    usc_arrays_downscaled = [img_resize(image) for image in tqdm(usc_arrays)]
    usc_fullres_guidnance_arrays = [prepare_full_res_guidance(image, downscale_model) for image in tqdm(usc_arrays_downscaled)]
    # Stage 2: stack frame and guidance along the channel axis for the
    # full-resolution model.
    usc_concatenated_fullres_arrays = [np.dstack([usc_image, usc_guidance]) for
                                       (usc_image, usc_guidance) in tqdm(list(zip(usc_arrays, usc_fullres_guidnance_arrays)))]
    full_res_output_arrays = [get_full_res_output(image, full_res_model) for image in tqdm(usc_concatenated_fullres_arrays)]
    # Stage 3: color correction on the full-resolution result.
    full_res_output_arrays = [get_full_res_output(image, cc_net) for image in tqdm(full_res_output_arrays)]

    # Optional per-frame PNG dump, kept for debugging:
    #for i in tqdm(list(range(len(usc_fullres_guidnance_arrays)))):
        #plt.imsave(os.path.join(OUTPUT_FOLDER, "{}.png".format(str(i).zfill(4))), np.clip(full_res_output_arrays[i],0,1))

    # normpath drops a trailing separator so the folder name is always found.
    video_name = os.path.basename(os.path.normpath(usc_test_folder)) + '.avi'
    video_path = os.path.join(VIDEO_OUT_FOLDER, video_name)

    print(video_path)

    # Three frames are placed side by side, hence the tripled width.
    out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'DIVX'), 24, (IMAGE_WIDTH*3, IMAGE_HEIGHT))
    if not out.isOpened():
        # Without this check a writer that failed to open drops frames silently.
        raise IOError("Could not open video writer for " + video_path)

    try:
        for gt_image, usc_image, out_image in tqdm(list(zip(gt_arrays, usc_arrays, full_res_output_arrays))):
            gt_image = _to_uint8_frame(gt_image)
            usc_image = _to_uint8_frame(usc_image)
            out_image = _to_uint8_frame(out_image)

            # The network result is flipped horizontally and transposed before
            # stacking. NOTE(review): presumably this rotates it into the same
            # orientation as the GT/USC frames -- confirm.
            triple_img = np.hstack([gt_image, np.transpose(out_image[:,::-1,:],[1,0,2]), usc_image])
            # Reverse the channel order for the writer (presumably RGB -> BGR).
            out.write(triple_img[:,:,::-1])
    finally:
        # Release the writer even if a frame fails, so the file gets closed.
        out.release()

    # NOTE(review): the purpose of this pause is not evident from the code.
    time.sleep(5)
    print("\n" + video_name + r" released" + "\n")

100%|██████████| 1422/1422 [00:00<00:00, 692556.93it/s]
100%|██████████| 1422/1422 [00:26<00:00, 54.65it/s]
100%|██████████| 1422/1422 [00:00<00:00, 2125.80it/s]
100%|██████████| 1422/1422 [01:19<00:00, 17.82it/s]
100%|██████████| 1422/1422 [00:24<00:00, 57.34it/s]
100%|██████████| 1422/1422 [01:47<00:00, 13.21it/s]
  0%|          | 2/1422 [00:00<01:59, 11.89it/s]

/home/p00536919/Flare_removal/Downscale_ref/video_out/VID_20210323_170053.avi


100%|██████████| 1422/1422 [01:52<00:00, 12.69it/s]
100%|██████████| 178/178 [00:00<00:00, 416157.25it/s]
  2%|▏         | 4/178 [00:00<00:04, 39.73it/s]


VID_20210323_170053.avi released



100%|██████████| 178/178 [00:03<00:00, 50.45it/s]
100%|██████████| 178/178 [00:00<00:00, 1991.91it/s]
100%|██████████| 178/178 [00:09<00:00, 17.93it/s]
100%|██████████| 178/178 [00:02<00:00, 59.96it/s]
100%|██████████| 178/178 [00:13<00:00, 13.53it/s]
  1%|          | 2/178 [00:00<00:13, 13.38it/s]

/home/p00536919/Flare_removal/Downscale_ref/video_out/VID_20210323_170453.avi


100%|██████████| 178/178 [00:13<00:00, 12.92it/s]



VID_20210323_170453.avi released

