## Imports

In [1]:
import os
import cv2
import time
import functools
import skimage.io
import numpy as np
import skimage.filters
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import scipy.ndimage.filters as fi

from itertools import tee
from PIL import Image, ImageOps
from tensorflow.keras import Model
from skimage.transform import resize
from tensorflow.keras.utils import plot_model
from tensorflow.keras.layers import Input, Conv2D, add

### Dataset Notes:

* Dataset length = 144,682 sequence
* Each sequence = 7 frames
* Each frame = 256 x 256 x 3 
* This is our Ground Truth dataset and it will be downsampled on the fly to train the model

### Dataset splitting: 
* Train = 90,682     
* Test = 27,000     
* Validation = 27,000

## Data loading and preprocessing

In [2]:
def load_sequence(sequence_path):
    GT = [Image.open(sequence_path + str(img) + '.png') for img in range(7)]
    return GT

In [3]:
def gkern(kernlen=13, nsig=1.6):
    inp = np.zeros((kernlen, kernlen))
    # set element at the middle to one, a dirac delta
    inp[kernlen // 2, kernlen // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter mask
    return fi.gaussian_filter(inp, nsig)

def Gaussian_down_sample(sequence, scale):
    """
        sequence sahpe : [F, H, W, C]
    """
    assert scale in[2, 4], "Enter valid scale value [2, 4]"
    
    nsigma = scale * 0.4
    kernal_in = gkern()
    gauss_kernel_2d = tf.convert_to_tensor(kernal_in, dtype=tf.float32)
    gauss_kernel = tf.tile(gauss_kernel_2d[:, :, tf.newaxis, tf.newaxis], [1, 1, 3, 3]) # 13*13*3*3
    
    pad_w, pad_h = 6, 6   # Filter padding
    padded_sequence = tf.pad(sequence, [[0,0], [pad_w,pad_w], [pad_h,pad_h], [0,0]], mode='REFLECT')
    
    # Pointwise filter that does nothing
    pointwise_filter = tf.eye(3, batch_shape=[1, 1])
    LR_seq = tf.nn.separable_conv2d(padded_sequence, gauss_kernel, pointwise_filter,
                                    strides=[1, scale, scale, 1], padding=[[0,0],[0,0],[0,0],[0,0]])
    return LR_seq

In [4]:
def train_data_generator(train_dataset_path, batch_size, scale_factor):
    train_sequences_count = 90682
    train_batches_count = train_sequences_count // batch_size
    train_batches = np.random.choice(np.arange(train_sequences_count), (train_batches_count, batch_size), replace=False)


#     validation_sequences_count = 27000
#     validation_batches_count = validation_sequences_count // batch_size
#     validation_batches = np.random.choice(np.arange(validation_sequences_count), (validation_batches_count, batch_size), replace=False)
    
    sequence_len = 7
    
    # Now lets loop over the full epoch and process each batch befor return it to the network
    for batch in train_batches:
        GT_batch = []
        LR_batch = []
        for sequence in batch:
            sequences_path = train_dataset_path + str('%05d'%sequence) + '/'
            GT = load_sequence(sequences_path)
            GT = [np.asarray(img) for img in GT]
            GT = tf.cast(tf.convert_to_tensor(GT), tf.float32)
            LR = Gaussian_down_sample(GT, scale_factor)
            LR = tf.concat([LR[1:2, :, :, :], LR], axis=0)
            GT_batch.append(GT)
            LR_batch.append(LR)

        yield np.array(GT_batch), np.array(LR_batch), train_batches_count

## Subclassed Model

In [5]:
# Concatenate the Residual Blocks
def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return tf.keras.Sequential(layers)

In [6]:
# Residual Blocks 
class Residual_Blocks(tf.keras.Model):
    '''Residual block w/o BN
    -+-Conv-ReLU-Conv-+-
     |________________|
    '''
    def __init__(self, n_f):
        super(Residual_Blocks, self).__init__()
        self.conv1 = Conv2D(filters=n_f, kernel_size=(3,3), padding='same', activation='relu', kernel_initializer=tf.keras.initializers.HeNormal())
        self.conv2 = Conv2D(filters=n_f, kernel_size=(3,3), padding='same', kernel_initializer=tf.keras.initializers.HeNormal())
        
    def call(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        return out + identity

In [7]:
# Hidden layers of the network
class hidden(tf.keras.Model):
    def __init__(self, n_f, n_b, scale):
        super(hidden, self).__init__()
        self.conv1 = Conv2D(filters=n_f, kernel_size=(3,3), activation='relu', padding='same', kernel_initializer=tf.keras.initializers.HeNormal())
        basic_block = functools.partial(Residual_Blocks, n_f=n_f)
        self.residual_blocks = make_layer(basic_block, n_b)
        self.conv_h = Conv2D(filters=n_f, kernel_size=(3,3), activation='relu', padding='same', name='hidden_state', kernel_initializer=tf.keras.initializers.HeNormal())
        self.conv_o = Conv2D(filters=scale*scale*3, kernel_size=(3,3), padding='same', name='output', kernel_initializer=tf.keras.initializers.HeNormal())
        
    def call(self, X, h, o):
        x_input = tf.concat([X, tf.cast(h, tf.float32), tf.cast(o, tf.float32)], axis=-1)
        x = self.conv1(x_input)
        x = self.residual_blocks(x)
        x_h = self.conv_h(x)
        x_o = self.conv_o(x)
        return x_h, x_o

In [8]:
# Down smaple the input
class PixelUnShuffle(tf.keras.Model):
    def __init__(self, scale):
        super(PixelUnShuffle, self).__init__()
        self.scale_factor = scale
    
    def call(self, x_o):
        x_o = tf.nn.space_to_depth(x_o, self.scale_factor)
        return x_o

In [9]:
# Main class
class RRN(tf.keras.Model):
    def __init__(self, n_f, n_b, scale):
        super(RRN, self).__init__()
        self.hidden = hidden(n_f, n_b, scale)
        self.scale = scale
        self.down = PixelUnShuffle(scale)
        self.n_f = n_f
        
    def call(self, x, x_h, x_o, init):
        f1 = x[:,0,:,:,:]
        f2 = x[:,1,:,:,:]
        h,w = f1.shape[1:3]
        x_input = tf.concat([f1, f2], axis=-1)
        if init:
            x_h, x_o = self.hidden(x_input, x_h, x_o)
        else:
            x_o = self.down(x_o)
            x_h, x_o = self.hidden(x_input, x_h, x_o)
        
        x_o = tf.image.resize(f2, (h*self.scale, w*self.scale)) + tf.nn.depth_to_space(x_o, self.scale)
        return x_h, x_o
    
    def train_step(self, GT, LR, loss_fn, optimizer):
        B,F,_,_,_ = LR.shape
        output = []
        
        with tf.GradientTape() as tape:
            for frame_index in range(F-1):
                if not bool(frame_index):
                    # Initialize frame[-1], hidden_state[-1] and prediction[-1]
                    init_frame = tf.zeros_like(LR[:,0,:,:,:])
                    prediction = tf.repeat(init_frame, repeats= self.scale**2, axis=3)
                    hidden_state = tf.repeat(init_frame[:,:,:,:1], repeats= self.n_f, axis=-1)

                hidden_state, prediction = self(LR[:,frame_index:frame_index+2,:,:,:], hidden_state, prediction, not bool(frame_index))
                output.append(prediction)
                
            output = tf.stack(output, axis=1)
            loss = tf.reduce_sum(loss_fn(GT, output))/(B*7)
            gradients = tape.gradient(loss, self.trainable_variables)
            optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            
        return output, loss
    
    def plot_GT_pred(self, GT, pred):
        fig, axs = plt.subplots(1, 2, figsize=(15, 15))
        axs[0].imshow(tf.cast(GT[0,0], tf.uint8))
        axs[0].title.set_text('Ground Truth')
        axs[1].imshow(tf.cast(pred[0,0], tf.uint8))
        axs[1].title.set_text('Network output')
        plt.show()
        
    def fit(self, data_generator, epochs, batch_size, loss_fn, optimizer, load_from_check_point = False):
        if load_from_check_point:
            self.load_weights('check_point/model_weights')
            
        for epoch in range(epochs):
            data_generator, data_gen = tee(data_generator)
            for batch_index, data in enumerate(data_gen):
                t0 = time.time()
                
                GT, LR, total_batches_count = data[0], data[1], data[2]
                prediction, batch_loss = self.train_step(GT, LR, loss_fn, optimizer)
                
                # self.plot_GT_pred(GT, prediction)
                if batch_index%20 == 0:
                    self.save_weights('check_point/model_weights')  
                    
                t1 = time.time()
                
                print("Epoch[{}/{}],({}/{}): Batch_Loss: {:.4f} || Timer: {:.4f} sec.".
                format(epoch, epochs, (batch_index+1),total_batches_count, batch_loss, (t1-t0)))

## Train the model

In [12]:
# Enviroment variables
epochs = 5
batch_size = 32
scale_factor = 4
n_f = 128
n_b = 5
GT_path = 'Vimeo90k_256/train/'
input_shape = (64, 64, 118)

In [None]:
loss_fn = tf.keras.losses.mae
optimizer = tfa.optimizers.AdamW(learning_rate=1e-4, weight_decay=5e-4)
model = RRN(n_f, n_b, scale_factor)
data_gen = train_data_generator(GT_path, batch_size, scale_factor)
model.fit(data_gen, epochs, batch_size, loss_fn, optimizer)

Epoch[0/5],(1/2833): Batch_Loss: 5701515608064.0000 || Timer: 8.3632 sec.
Epoch[0/5],(2/2833): Batch_Loss: 894391156736.0000 || Timer: 8.3560 sec.
Epoch[0/5],(3/2833): Batch_Loss: 341150105600.0000 || Timer: 8.4944 sec.
Epoch[0/5],(4/2833): Batch_Loss: 177878106112.0000 || Timer: 8.5013 sec.
Epoch[0/5],(5/2833): Batch_Loss: 105108463616.0000 || Timer: 8.7005 sec.
Epoch[0/5],(6/2833): Batch_Loss: 65573650432.0000 || Timer: 8.5317 sec.
Epoch[0/5],(7/2833): Batch_Loss: 54557921280.0000 || Timer: 8.6683 sec.
Epoch[0/5],(8/2833): Batch_Loss: 36014542848.0000 || Timer: 8.4122 sec.
Epoch[0/5],(9/2833): Batch_Loss: 35948769280.0000 || Timer: 8.4894 sec.
Epoch[0/5],(10/2833): Batch_Loss: 24839538688.0000 || Timer: 8.6573 sec.
Epoch[0/5],(11/2833): Batch_Loss: 24162465792.0000 || Timer: 8.6831 sec.
Epoch[0/5],(12/2833): Batch_Loss: 20452900864.0000 || Timer: 8.7176 sec.
Epoch[0/5],(13/2833): Batch_Loss: 17302509568.0000 || Timer: 8.4940 sec.
Epoch[0/5],(14/2833): Batch_Loss: 13272921088.0000 || 