# Deep Neural Network for Artefact Suppression

There are two main networks:
* Suppression Network
* VQA Network

The suppression network is made up of two sub-networks:
* Spatial Suppression Network
* Temporal Suppression Network

Both suppression networks use a U-Net architecture.

### Imports
Main imports: tensorflow, numpy, subprocess, pandas

In [None]:
%matplotlib inline
import os
# os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import layers
from keras.utils.vis_utils import plot_model
import numpy as np
import subprocess
import matplotlib.pyplot as plt
import pandas as pd
import datetime

#### Enable GPU to dynamically allocate more memory if needed

In [None]:
# computeDevices = tf.config.list_physical_devices('GPU')
# print(computeDevices)
# for device in computeDevices:
#     try:
#         tf.config.experimental.set_memory_growth(device, True)
#     except:
#         print(f"{device} cannot be set")
#         pass


### YUV420 Functions
Helper functions to read and write YUV420 files

In [None]:
def readYUV420(name: str, resolution: tuple, upsampleUV: bool = False):
    """Read every frame of a raw 8-bit planar YUV 4:2:0 file.

    Args:
        name: Path of the file to read.
        resolution: Two-element sequence. Each luma plane is reshaped to
            ``(resolution[1], resolution[0])`` and each chroma plane to
            ``(resolution[1] // 2, resolution[0] // 2)``.
        upsampleUV: When True, U and V are repeated 2x along both spatial
            axes so that they match the luma plane size.

    Returns:
        Tuple ``(Y, U, V)`` of uint8 arrays with the frame index on axis 0.

    Raises:
        ValueError: If the file is empty (nothing to stack) or ends with an
            incomplete frame.
    """
    lumaShape = (resolution[1], resolution[0])
    chromaShape = (resolution[1] // 2, resolution[0] // 2)
    lumaSize = int(resolution[0] * resolution[1])
    chromaSize = int(lumaSize / 4)
    frameSize = lumaSize + 2 * chromaSize

    # (element count, byte offset, plane shape) of Y, U and V inside one frame.
    layout = (
        (lumaSize, 0, lumaShape),
        (chromaSize, lumaSize, chromaShape),
        (chromaSize, lumaSize + chromaSize, chromaShape),
    )
    planes = ([], [], [])

    with open(name, "rb") as yuvFile:
        while True:
            frameBytes = yuvFile.read(frameSize)
            if not frameBytes:
                break
            for collected, (count, offset, shape) in zip(planes, layout):
                flat = np.frombuffer(frameBytes, dtype=np.uint8, count=count, offset=offset)
                collected.append(np.reshape(flat, shape))

    Y, U, V = (np.stack(collected) for collected in planes)
    if upsampleUV:
        U, V = (np.repeat(np.repeat(chroma, 2, axis=1), 2, axis=2) for chroma in (U, V))
    return Y, U, V


def readYUV420Range(name: str, resolution: tuple, range: tuple, upsampleUV: bool = False):
    """Read an inclusive range of frames from a raw 8-bit planar YUV 4:2:0 file.

    Args:
        name: Path of the file to read.
        resolution: Two-element sequence. Luma planes are reshaped to
            ``(resolution[1], resolution[0])`` and chroma planes to half of
            that in each dimension.
        range: ``(firstFrame, lastFrame)`` frame indices, both inclusive.
        upsampleUV: When True, U and V are repeated 2x along both spatial
            axes so that they match the luma plane size.

    Returns:
        Tuple ``(Y, U, V)`` of uint8 arrays with the frame index on axis 0.
    """
    rows, cols = resolution[1], resolution[0]
    lumaSize = int(resolution[0] * resolution[1])
    chromaSize = int(lumaSize / 4)
    frameSize = lumaSize + 2 * chromaSize

    # Byte window covering frames range[0] .. range[1] (inclusive).
    byteOffset = range[0] * frameSize
    byteCount = (range[1] + 1) * frameSize - byteOffset

    with open(name, "rb") as yuvFile:
        raw = np.fromfile(yuvFile, np.uint8, byteCount, offset=byteOffset)

    # One row per frame; columns are laid out as Y | U | V.
    perFrame = raw.reshape(-1, frameSize)
    uStart = lumaSize
    vStart = lumaSize + chromaSize
    vStop = lumaSize + 2 * chromaSize

    Y = np.reshape(perFrame[:, :uStart], (-1, rows, cols))
    U = np.reshape(perFrame[:, uStart:vStart], (-1, rows // 2, cols // 2))
    V = np.reshape(perFrame[:, vStart:vStop], (-1, rows // 2, cols // 2))

    if upsampleUV:
        U, V = (np.repeat(np.repeat(chroma, 2, axis=1), 2, axis=2) for chroma in (U, V))
    return Y, U, V


def writeYUV420(name: str, Y, U, V, downsample=True):
    """Write frames to a raw planar YUV file, frame by frame as Y, U, V.

    Args:
        name: Destination path; the file is overwritten.
        Y, U, V: Arrays with the frame index on axis 0.
        downsample: When True, U and V are subsampled by taking every second
            sample along axes 1 and 2 before writing.
    """
    if downsample:
        U, V = U[:, ::2, ::2], V[:, ::2, ::2]

    frameCount = Y.shape[0]
    # Assemble the whole payload first so the file is only opened once all
    # planes have been serialised.
    payload = b"".join(
        plane[frameIdx].tobytes()
        for frameIdx in range(frameCount)
        for plane in (Y, U, V)
    )

    with open(name, "wb") as destination:
        destination.write(payload)

### Data Generator
Data generator that feeds batches of a fixed size to the suppression networks

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that yields (degraded, reference) frame windows.

    Each sample is a window of frames read from a raw YUV 4:2:0 file with the
    chroma planes upsampled to luma size. The input ``X`` holds the degraded
    frames over ``frameRange`` and the target ``y`` holds the reference frames
    over ``(frameRange[0] + 1, frameRange[1] - 1)``. The batch buffers are
    allocated for 5 input frames and 3 target frames, so every ``frameRange``
    is expected to span exactly 5 frames.

    Args:
        referencePaths: Paths of the reference (clean) YUV files.
        degradedPaths: Paths of the degraded YUV files, same order.
        frameRanges: ``(firstFrame, lastFrame)`` per sample, inclusive.
        batch_size: Number of samples per batch.
        dim: Per-frame shape ``(rows, cols, channels)``.
        shuffle: Whether to shuffle the sample order on construction and at
            the end of every epoch.
    """

    def __init__(self, referencePaths, degradedPaths, frameRanges, batch_size, dim, shuffle=True):
        super().__init__()
        self.referencePaths = referencePaths
        self.degradedPaths = degradedPaths
        self.batch_size = batch_size
        self.dim = dim
        self.shuffle = shuffle
        self.frameRanges = frameRanges
        self.on_epoch_end()

    def on_epoch_end(self):
        """Rebuild (and optionally shuffle) the sample order."""
        self.indexes = np.arange(len(self.referencePaths))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        """Return the number of full batches; a trailing partial batch is dropped."""
        return len(self.referencePaths) // self.batch_size

    def __getitem__(self, index):
        """Return batch ``index`` as ``(X, y)``."""
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        referencePaths_temp = [self.referencePaths[i] for i in indexes]
        degradedPaths_temp = [self.degradedPaths[i] for i in indexes]
        frameRanges_temp = [self.frameRanges[i] for i in indexes]

        return self.__data_generation(referencePaths_temp, degradedPaths_temp, frameRanges_temp)

    def __data_generation(self, referencePaths_temp, degradedPaths_temp, frameRanges_temp):
        """Load one batch from disk and scale it to [0, 1] float64."""
        X = np.empty([self.batch_size, 5, *self.dim])
        y = np.empty([self.batch_size, 3, *self.dim])

        # readYUV420Range reshapes planes to (resolution[1], resolution[0]),
        # so the reader resolution is the frame shape with its first two axes
        # swapped. Deriving it from self.dim keeps the reader and the batch
        # buffers consistent for any frame size (previously fixed to 1920x1080).
        resolution = (self.dim[1], self.dim[0])

        for i, (degradedPath, frameRange) in enumerate(zip(degradedPaths_temp, frameRanges_temp)):
            Ydeg, Udeg, Vdeg = readYUV420Range(degradedPath, resolution, frameRange, upsampleUV=True)
            X[i] = np.stack([Ydeg, Udeg, Vdeg], axis=-1)

        for i, (referencePath, frameRange) in enumerate(zip(referencePaths_temp, frameRanges_temp)):
            # The target drops the first and last frame of the input window.
            Yref, Uref, Vref = readYUV420Range(
                referencePath, resolution, (frameRange[0]+1, frameRange[1]-1), upsampleUV=True)
            y[i] = np.stack([Yref, Uref, Vref], axis=-1)

        # np.empty already allocated float64 buffers; scale in place instead
        # of creating two extra full-size copies per batch.
        X /= 255
        y /= 255

        return X, y
        



### Loss functions
MSE, perceptual, and VQA loss functions

In [None]:
def mse(y, x):
    """Return the mean of the squared element-wise difference ``y - x``."""
    squaredError = tf.square(tf.subtract(y, x))
    return tf.reduce_mean(squaredError)

def perceptCompute(ref, pred, path='/home/ramsookd/ArtefactReduction/processingTemp/'):
    """Compute the mean VMAF score of ``pred`` against ``ref`` with the ``vmaf`` CLI.

    Both inputs are scaled from [0, 1] to 8-bit, written as YUV 4:2:0 files
    under ``path`` and scored with the ``vmaf_v0.6.1`` model.

    Args:
        ref: Reference frames indexed as ``[frame, row, col, channel]`` with
            channels Y, U, V and values in [0, 1].
        pred: Predicted frames as a tensor with the same layout.
        path: Directory prefix (including the trailing separator) for the
            temporary YUV files and the CSV report.

    Returns:
        Mean of the ``vmaf`` column of the CSV report.

    Raises:
        RuntimeError: If the ``vmaf`` process exits with a non-zero status.
            Previously the exit status was ignored, so a failed run could
            silently reuse a stale report from an earlier call.
    """
    # Round (rather than truncate) the reference, exactly as is done for the
    # prediction; values such as x / 255 * 255 can land just below x.
    ref = np.rint(np.clip(np.asarray(ref) * 255, 0, 255)).astype(np.uint8)
    pred = tf.round(tf.clip_by_value(pred, 0, 1) * 255).numpy().astype(np.uint8)

    refFile = f'{path}ref_temp.yuv'
    predFile = f'{path}pred_temp.yuv'
    csvFile = f'{path}CSV.csv'

    writeYUV420(refFile, ref[:,:,:,0], ref[:,:,:,1], ref[:,:,:,2], downsample=True)
    writeYUV420(predFile, pred[:,:,:,0], pred[:,:,:,1], pred[:,:,:,2], downsample=True)

    # Frame size comes from the data that was just written instead of being
    # fixed to 1920x1080.
    height, width = ref.shape[1], ref.shape[2]

    # Argument list + no shell: paths containing spaces or shell
    # metacharacters are passed through verbatim.
    commandVMAF = [
        "vmaf",
        "--width", str(width),
        "--height", str(height),
        "-p", "420",
        "-b", "8",
        "-m", "version=vmaf_v0.6.1",
        "-o", csvFile,
        "--csv",
        "-r", refFile,
        "-d", predFile,
    ]
    runVMAF = subprocess.run(commandVMAF, stderr=subprocess.PIPE, stdout=subprocess.PIPE, cwd='/')
    if runVMAF.returncode != 0:
        raise RuntimeError(
            f"vmaf exited with status {runVMAF.returncode}: "
            f"{runVMAF.stderr.decode(errors='replace')}"
        )

    vmafDF = pd.read_csv(csvFile)
    return vmafDF['vmaf'].mean()

def l2Loss(x, tensor):
    """Return the mean squared difference between ``tensor`` and ``x``."""
    residual = tensor - x
    return tf.reduce_mean(tf.square(residual))

def perceptualLoss(alpha, ref_frames, pred_frames, vqa_pred):
    """Batch-mean suppression loss mixing a VQA term with pixel MSE.

    For every sample the loss is::

        alpha * l2Loss(100, vqa_pred[i])
        + (1 - alpha) * mse(Y_ref, Y_pred)
        + mse(UV_ref, UV_pred)

    Args:
        alpha: Weight of the VQA term against the luma MSE term.
        ref_frames: Reference frames indexed as
            ``[sample, frame, row, col, channel]``; channel 0 is Y, the rest UV.
        pred_frames: Predicted frames with the same layout.
        vqa_pred: Predicted quality score per sample; pulled towards 100.

    Returns:
        The per-sample loss averaged over the batch.
    """
    batch_size = ref_frames.shape[0]
    totalLoss = 0
    for i in range(batch_size):
        Y_frames_ref = ref_frames[i, :, :, :, 0]
        UV_frames_ref = ref_frames[i, :, :, :, 1:]

        Y_frames_pred = pred_frames[i, :, :, :, 0]
        UV_frames_pred = pred_frames[i, :, :, :, 1:]

        print(vqa_pred[i])
        sampleLoss = (alpha*(l2Loss(100,vqa_pred[i])) + (1-alpha)*mse(Y_frames_ref, Y_frames_pred)) + (mse(UV_frames_ref,UV_frames_pred))
        # Accumulate into a separate total. The previous code assigned the
        # sample loss to the accumulator and then added it to itself, which
        # discarded earlier samples and doubled the last one.
        totalLoss += sampleLoss

    return totalLoss/batch_size

def vqaLoss(ref_frames, pred_frames, pred_vmaf):
    """MSE between the VQA network's predicted score and the measured VMAF.

    Args:
        ref_frames: Reference frames, one entry per sample on axis 0.
        pred_frames: Predicted frames, one entry per sample on axis 0.
        pred_vmaf: Scores predicted by the VQA network, one per sample.

    Returns:
        Mean squared error between predicted and measured scores.
    """
    batch_size = ref_frames.shape[0]
    actual_vmaf = [
        perceptCompute(ref_frames[i], pred_frames[i]) for i in range(batch_size)
    ]

    # Flatten the prediction to one score per sample and bring the measured
    # scores to the same dtype. Without this, a (batch, 1) prediction against
    # a (batch,) target broadcasts to (batch, batch), and the target's dtype
    # (pandas means are presumably float64) may not match the network output.
    pred_vmaf = tf.reshape(pred_vmaf, [-1])
    actual_vmaf = tf.cast(tf.convert_to_tensor(actual_vmaf), pred_vmaf.dtype)

    print(actual_vmaf)
    print(pred_vmaf)
    return mse(pred_vmaf, actual_vmaf)

In [None]:
class CNNBNReluDown(layers.Layer):
    """Conv2D -> optional BatchNormalization -> LeakyReLU block.

    Args:
        numFilters: Number of convolution filters.
        size: Convolution kernel size.
        strides: Convolution strides.
        bn: When True, batch normalisation is applied after the convolution.
        **kwargs: Standard ``Layer`` keyword arguments (e.g. ``name``),
            forwarded to the base class.
    """

    def __init__(self, numFilters, size, strides, bn=False, **kwargs):
        # Forward kwargs so that name/dtype/trainable are honoured, including
        # when the layer is rebuilt from get_config().
        super().__init__(**kwargs)
        self.numFilters = numFilters
        self.size = size
        self.bn = bn
        self.strides = strides
        self.convLayer = layers.Conv2D(self.numFilters, self.size, strides=self.strides ,padding="same")
        if self.bn:
            self.bnLayer = layers.BatchNormalization()
        self.reluLayer = layers.LeakyReLU()

    def get_config(self):
        """Return the base config extended with this block's constructor arguments."""
        config = super().get_config().copy()
        config.update({
            'numFilters' : self.numFilters,
            'size' : self.size,
            'strides' : self.strides,
            'bn' : self.bn
        })
        return config

    def call(self, inputs, training=False):
        """Apply convolution, optional batch norm, then LeakyReLU."""
        x = self.convLayer(inputs)
        if self.bn:
            x = self.bnLayer(x, training=training)
        x = self.reluLayer(x)
        return x


class CNNBNReluUp(layers.Layer):
    """Conv2DTranspose -> optional Dropout(0.3) -> LeakyReLU block.

    Args:
        numFilters: Number of transposed-convolution filters.
        size: Kernel size.
        strides: Transposed-convolution strides.
        dropout: When True, dropout with rate 0.3 follows the convolution.
        **kwargs: Standard ``Layer`` keyword arguments (e.g. ``name``),
            forwarded to the base class.
    """

    def __init__(self, numFilters, size, strides, dropout=False, **kwargs):
        # Forward kwargs so that name/dtype/trainable are honoured, including
        # when the layer is rebuilt from get_config().
        super().__init__(**kwargs)
        self.numFilters = numFilters
        self.size = size
        self.dropout = dropout
        self.strides = strides
        self.convLayer = layers.Conv2DTranspose(self.numFilters, self.size, strides=self.strides, padding="same")
        if self.dropout:
            self.dropoutLayer = layers.Dropout(0.3)
        self.reluLayer = layers.LeakyReLU()

    def get_config(self):
        """Return the base config extended with this block's constructor arguments."""
        config = super().get_config().copy()
        config.update({
            'numFilters' : self.numFilters,
            'size' : self.size,
            'strides' : self.strides,
            'dropout' : self.dropout
        })
        return config

    def call(self, inputs, training=False):
        """Apply transposed convolution, optional dropout, then LeakyReLU."""
        x = self.convLayer(inputs)
        if self.dropout:
            x = self.dropoutLayer(x, training=training)
        x = self.reluLayer(x)
        return x

In [None]:
class SpatialSuppression(keras.Model):
    """2-D encoder/decoder with skip connections for per-frame suppression.

    Args:
        encoder: Sequence of layer specs, each passed positionally to
            ``CNNBNReluDown`` as ``(numFilters, size, strides[, bn])``.
        decoder: Sequence of layer specs, each passed positionally to
            ``CNNBNReluUp`` as ``(numFilters, size, strides[, dropout])``.
        outChannels: Number of channels produced by the final transposed
            convolution.
    """

    def __init__(self, encoder, decoder, outChannels = 3):
        super(SpatialSuppression, self).__init__()
        self.numEncoderBlocks = len(encoder)
        self.numDecoderBlocks = len(decoder)
        self.encoder = []
        self.decoder = []
        # Pass up to four spec entries so the optional bn / dropout flag is
        # honoured, matching TemporalSuppression. Previously only the first
        # three entries were used and the flag was silently ignored;
        # three-element specs behave exactly as before.
        for encoder_opts in encoder:
            self.encoder.append(CNNBNReluDown(*encoder_opts[:4]))

        for decoder_opts in decoder:
            self.decoder.append(CNNBNReluUp(*decoder_opts[:4]))
        self.lastConv = layers.Conv2DTranspose(outChannels, 4, strides=2, padding="same")

    def call(self, x, training=False):
        """Run the encoder, then the decoder with skip concatenation."""
        skips = []
        for encoderLayer in self.encoder:
            x = encoderLayer(x, training=training)
            skips.append(x)

        # The deepest encoder output is the decoder input, not a skip; the
        # remaining outputs are consumed deepest-first.
        skips = reversed(skips[:-1])

        for decoderLayer, skip in zip(self.decoder, skips):
            x = decoderLayer(x, training=training)
            x = layers.Concatenate()([x, skip])

        x = self.lastConv(x)

        return x

    def model(self):
        """Wrap this network in a functional ``keras.Model`` for 1080x1920x3 input."""
        x = keras.Input(shape=(1080,1920,3))
        return keras.Model(inputs=[x], outputs=self.call(x))
    

class CNN3D_BNReluDown(layers.Layer):
    """Conv3D -> optional BatchNormalization -> LeakyReLU block.

    Args:
        numFilters: Number of convolution filters.
        size: Convolution kernel size.
        strides: Convolution strides.
        bn: When True, batch normalisation is applied after the convolution.
        **kwargs: Standard ``Layer`` keyword arguments (e.g. ``name``),
            forwarded to the base class.
    """

    def __init__(self, numFilters, size, strides, bn=False,  **kwargs):
        # Forward kwargs so that name/dtype/trainable are honoured, including
        # when the layer is rebuilt from get_config().
        super().__init__(**kwargs)
        self.numFilters = numFilters
        self.size = size
        self.bn = bn
        self.strides = strides
        self.conv3DLayer = layers.Conv3D(self.numFilters, self.size, strides=self.strides ,padding="same")
        if self.bn:
            self.bnLayer = layers.BatchNormalization()
        self.reluLayer = layers.LeakyReLU()

    def get_config(self):
        """Return the base config extended with this block's constructor arguments."""
        config = super().get_config().copy()
        config.update({
            'numFilters' : self.numFilters,
            'size' : self.size,
            'strides' : self.strides,
            'bn' : self.bn
        })
        return config

    def call(self, inputs, training=False):
        """Apply 3-D convolution, optional batch norm, then LeakyReLU."""
        x = self.conv3DLayer(inputs)
        if self.bn:
            x = self.bnLayer(x, training=training)
        x = self.reluLayer(x)
        return x


class CNN3D_BNReluUp(layers.Layer):
    """Conv3DTranspose -> optional Dropout(0.3) -> LeakyReLU block.

    Args:
        numFilters: Number of transposed-convolution filters.
        size: Kernel size.
        strides: Transposed-convolution strides.
        dropout: When True, dropout with rate 0.3 follows the convolution.
        **kwargs: Standard ``Layer`` keyword arguments (e.g. ``name``),
            forwarded to the base class.
    """

    def __init__(self, numFilters, size, strides, dropout=False,  **kwargs):
        # Forward kwargs so that name/dtype/trainable are honoured, including
        # when the layer is rebuilt from get_config().
        super().__init__(**kwargs)
        self.numFilters = numFilters
        self.size = size
        self.dropout = dropout
        self.strides = strides
        self.conv3DLayer = layers.Conv3DTranspose(self.numFilters, self.size, strides=self.strides, padding="same")
        if self.dropout:
            self.dropoutLayer = layers.Dropout(0.3)
        self.reluLayer = layers.LeakyReLU()

    def get_config(self):
        """Return the base config extended with this block's constructor arguments."""
        config = super().get_config().copy()
        config.update({
            'numFilters' : self.numFilters,
            'size' : self.size,
            'strides' : self.strides,
            'dropout' : self.dropout
        })
        return config

    def call(self, inputs, training=False):
        """Apply 3-D transposed convolution, optional dropout, then LeakyReLU."""
        x = self.conv3DLayer(inputs)
        if self.dropout:
            x = self.dropoutLayer(x, training=training)
        x = self.reluLayer(x)
        return x

class TemporalSuppression(keras.Model):
    """3-D encoder/decoder applied to short windows of frames.

    Encoder layers are created and run under ``/device:GPU:1``; decoder and
    output layers under ``/device:GPU:2``. No skip connections are used
    between encoder and decoder.

    Args:
        encoder: Sequence of ``(numFilters, size, strides, bn)`` specs for
            ``CNN3D_BNReluDown``.
        decoder: Sequence of ``(numFilters, size, strides, dropout)`` specs
            for ``CNN3D_BNReluUp``.
        outChannels: Number of output channels.
    """

    def __init__(self, encoder, decoder, outChannels = 3):
        super().__init__()
        self.numEncoderBlocks = len(encoder)
        self.numDecoderBlocks = len(decoder)

        with tf.device('/device:GPU:1'):
            self.encoder = [
                CNN3D_BNReluDown(spec[0], spec[1], spec[2], spec[3])
                for spec in encoder
            ]

        with tf.device('/device:GPU:2'):
            self.decoder = [
                CNN3D_BNReluUp(spec[0], spec[1], spec[2], spec[3])
                for spec in decoder
            ]
            # Output head: upsample spatially by 2, then a strided Conv3D
            # (stride 3 on the first non-batch axis).
            self.lastConv = layers.Conv3DTranspose(outChannels, 4, strides=(1,2,2), padding="same")
            self.lastReLU = layers.LeakyReLU()
            self.lastConv3D = layers.Conv3D(outChannels,3,strides=(3,1,1), padding="same")
            self.lastReLUFinal = layers.LeakyReLU()

    def call(self, x, training=False):
        """Run encoder, decoder and output head; the result is clipped to [0, 1]."""
        with tf.device('/device:GPU:1'):
            for encoderBlock in self.encoder:
                x = encoderBlock(x, training=training)

        with tf.device('/device:GPU:2'):
            for decoderBlock in self.decoder:
                x = decoderBlock(x, training=training)

            outputHead = (self.lastConv, self.lastReLU, self.lastConv3D, self.lastReLUFinal)
            for headLayer in outputHead:
                x = headLayer(x, training=training)

            x = tf.clip_by_value(x, 0, 1)

        return x

    def model(self):
        """Wrap this network in a functional ``keras.Model`` for 3x1080x1920x3 input."""
        windowInput = keras.Input(shape=(3,1080,1920,3))
        return keras.Model(inputs=[windowInput], outputs=self.call(windowInput))

class VideoQualityAssessment(keras.Model):
    """Network that predicts a quality score from reference/distorted frame triplets.

    The same spatial stack is applied to each of the six input frames. The
    three reference outputs are concatenated, as are the three distorted
    outputs; both go through the same temporal stack, are concatenated again,
    and pass through the final stack and the dense layers.

    Args:
        spatialBlocks: ``(numFilters, size, strides, bn)`` specs for the
            per-frame stack.
        temporalBlock: Specs, same format, for the per-triplet stack.
        finalBlock: Specs, same format, for the joint stack.
        denseBlock: Unit counts of the dense layers.
    """

    def __init__(self, spatialBlocks, temporalBlock, finalBlock, denseBlock):
        super().__init__()
        self.spatialBlocks = [
            CNNBNReluDown(spec[0], spec[1], spec[2], spec[3]) for spec in spatialBlocks
        ]
        self.temporalBlock = [
            CNNBNReluDown(spec[0], spec[1], spec[2], spec[3]) for spec in temporalBlock
        ]
        self.finalBlock = [
            CNNBNReluDown(spec[0], spec[1], spec[2], spec[3]) for spec in finalBlock
        ]
        self.denseBlock = [layers.Dense(units) for units in denseBlock]

    def call(self, x_ref_min1, x_ref, x_ref_pl1, x_dist_min1, x_dist, x_dist_pl1, training = False):
        """Return the predicted score for one reference/distorted triplet pair."""

        def runStack(stack, tensor):
            # Feed the tensor through every layer of the stack in order.
            for layer in stack:
                tensor = layer(tensor, training=training)
            return tensor

        # Reference branch: previous, current and next frame.
        x_ref_min1 = runStack(self.spatialBlocks, x_ref_min1)
        x_ref = runStack(self.spatialBlocks, x_ref)
        x_ref_pl1 = runStack(self.spatialBlocks, x_ref_pl1)
        x_ref = layers.Concatenate()([x_ref_min1, x_ref, x_ref_pl1])

        # Distorted branch, through the same spatial layers.
        x_dist_min1 = runStack(self.spatialBlocks, x_dist_min1)
        x_dist = runStack(self.spatialBlocks, x_dist)
        x_dist_pl1 = runStack(self.spatialBlocks, x_dist_pl1)
        x_dist = layers.Concatenate()([x_dist_min1, x_dist, x_dist_pl1])

        x_ref = runStack(self.temporalBlock, x_ref)
        x_dist = runStack(self.temporalBlock, x_dist)

        x = layers.Concatenate()([x_ref, x_dist])
        x = runStack(self.finalBlock, x)
        x = layers.Flatten()(x)
        x = runStack(self.denseBlock, x)

        return x

    def model(self):
        """Wrap this network in a functional ``keras.Model`` with six 1080x1920x3 inputs."""
        frameShape = (1080,1920,3)

        x_ref_min1 = keras.Input(shape=frameShape)
        x_ref = keras.Input(shape=frameShape)
        x_ref_pl1 = keras.Input(shape=frameShape)

        x_dist_min1 = keras.Input(shape=frameShape)
        x_dist = keras.Input(shape=frameShape)
        x_dist_pl1 = keras.Input(shape=frameShape)

        frameInputs = [x_ref_min1, x_ref, x_ref_pl1, x_dist_min1, x_dist, x_dist_pl1]
        return keras.Model(inputs=frameInputs, outputs=self.call(*frameInputs))

In [None]:
# Layer specs for the VQA network. VideoQualityAssessment passes each tuple
# positionally to CNNBNReluDown as (numFilters, size, strides, bn).
spatialBlock = [
    (64, 3, 2, True),
    (64, 3, 2, False),
    (128, 3, 2, False),
    (128, 3, 2, False),
]

temporalBlock = [
    (128, 3, 2, False),
    (128, 3, 2, False),
    (256, 3, 2, False),
]

finalBlock = [
    (256, 3, 2, False),
    (256, 3, 2, False),
    (512, 3, 2, False),
]

# Unit counts of the dense layers; the last layer emits a single value.
denseBlock = [
    1024,
    512,
    128,
    1
]

# Build the VQA network as a functional model under GPU:3.
with tf.device('/device:GPU:3'):
    vqaModel = VideoQualityAssessment(spatialBlock,temporalBlock, finalBlock, denseBlock).model()



# Layer specs for the temporal suppression network. TemporalSuppression
# passes each tuple positionally to CNN3D_BNReluDown as
# (numFilters, size, strides, bn); strides are 3-tuples whose first entry is 1
# in every spec.
encoderTemporal = [
    (64, 3, (1,2,2), True),
    (64, 3, (1,2,2), False),
    (128, 3, (1,2,2), False),
    (128, 5, (1,5,5), False),
    (256, 3, (1,3,3), False),
    (256, 3, (1,3,2), False),
    (256, 3, (1,3,2), False),
    (512, 3, (1,2,2), False),
]

# Passed positionally to CNN3D_BNReluUp as (numFilters, size, strides, dropout).
decoderTemporal = [
    (256, 3, (1,1,2), True),
    (256, 3, (1,3,2), True),
    (256, 3, (1,3,2), True),
    (128, 3, (1,3,3), False),
    (128, 5, (1,5,5), False),
    (64, 3, (1,2,2), False),
    (64, 3, (1,2,2), False),
]

# Build the temporal suppression network as a functional model under GPU:2.
with tf.device('/device:GPU:2'):
    temporalModel = TemporalSuppression(encoderTemporal, decoderTemporal).model()

# Layer specs for the spatial suppression network, in the order
# (numFilters, size, strides, flag). The fourth entry is presumably the
# batch-norm flag for encoder specs — TODO confirm against SpatialSuppression.
encoderSpatial = [
    (64, 4, 2, True),
    (64, 4, 2, False),
    (128, 4, 2, False),
    (128, 5, 5, False),
    (256, 4, 3, False),
    (256, 4, (3,2), False),
    (512, 3, (3,2), False),
    (512, 3, 2, False),
]

# Same tuple order; the fourth entry is presumably the dropout flag for
# decoder specs — TODO confirm against SpatialSuppression.
decoderSpatial = [
    (512, 3, (1,2), True),
    (256, 4, (3,2), True),
    (256, 4, (3,2), True),
    (128, 4, (3,3), False),
    (128, 5, 5, False),
    (64, 4, 2, False),
    (64, 4, 2, False),
]

# Build the spatial suppression network as a functional model under GPU:1.
with tf.device('/device:GPU:1'):
    spatialModel = SpatialSuppression(encoderSpatial, decoderSpatial).model()

In [None]:
## Data prep
# Load the window index and keep only the entry at position 1001.
videoDF = pd.read_csv("/home/ramsookd/ArtefactReduction/data/windowedDataset.csv")
sampleRows = slice(1001, 1002)
refFiles, degFiles, frameStart, frameEnd = (
    videoDF[column].tolist()[sampleRows]
    for column in ('cleanPath', 'degradedPath', 'StartFrame', 'EndFrame')
)
# One (start, end) frame range per selected window.
frames = list(zip(frameStart, frameEnd))

In [None]:
## Training Loop
# Alternates two updates per batch:
#   1. suppression step - spatial + temporal networks, VQA network frozen
#      (called with training=False and excluded from the gradient);
#   2. VQA step - VQA network trained against the measured VMAF of the
#      frames produced in step 1.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
# NOTE: the two metrics below are created but not updated by this loop.
vqaTrainLoss = tf.keras.metrics.Mean('vqaTrainLoss', dtype=tf.float32)
suppressionTrainLoss = tf.keras.metrics.Mean('suppressionTrainLoss', dtype=tf.float32)


EPOCHS = 5 
batch_size = 1
dims = (1080, 1920, 3)
alpha = 1e-5

opt_vqa = keras.optimizers.Adam(1e-4)
opt_suppresion = keras.optimizers.Adam(1e-2)

for epoch in range(EPOCHS):
    # A fresh generator per epoch also reshuffles the sample order.
    dataGen = DataGenerator(refFiles, degFiles, frames, batch_size, dims, True)
    batchSteps = len(dataGen)
    for i in range(batchSteps):
        X, y = dataGen[i]

        ## Train suppressionNet
        with tf.GradientTape() as suppressionTape:
            with tf.device('/device:GPU:1'):
                # Fold the 5-frame window into the batch axis for the 2-D network.
                XSpa = spatialModel(tf.reshape(X, [batch_size*5, *dims]), training=True)
                del X
                XSpa = tf.reshape(XSpa, [batch_size, 5, *dims])
                # Three overlapping 3-frame windows centred on frames 1, 2, 3.
                # (The former `frameCenter == 4` branch was unreachable:
                # range(1, 4) never yields 4.)
                tempIn = []
                for bs in range(batch_size):
                    for frameCenter in range(1,4):
                        tempIn.append(XSpa[bs,frameCenter-1:frameCenter+2,:,:,:])
                del XSpa
                tempIn = tf.stack(tempIn, axis=0)

            with tf.device('/device:GPU:2'):
                tempOut = temporalModel(tempIn, training=True)
                tempOut = tf.reshape(tempOut, [batch_size, 3, *dims])
            with tf.device('/device:GPU:3'):
                vqaPred = vqaModel([y[:,0,:,:,:], y[:,1,:,:,:], y[:,2,:,:,:], tempOut[:,0,:,:,:], tempOut[:,1,:,:,:], tempOut[:,2,:,:,:]],training=False)
            lossSuppression = perceptualLoss(alpha,y,tempOut,vqaPred)
        suppressionWeights = spatialModel.trainable_weights+temporalModel.trainable_weights
        grads = suppressionTape.gradient(lossSuppression, suppressionWeights)
        opt_suppresion.apply_gradients(zip(grads, suppressionWeights))

        ## Train vqaNet
        with tf.GradientTape() as vqaTape:
            with tf.device('/device:GPU:3'):
                vqaPred = vqaModel([y[:,0,:,:,:], y[:,1,:,:,:], y[:,2,:,:,:], tempOut[:,0,:,:,:], tempOut[:,1,:,:,:], tempOut[:,2,:,:,:]],training=True)
            lossVQA = vqaLoss(y, tempOut, vqaPred)
        grads = vqaTape.gradient(lossVQA, vqaModel.trainable_weights)
        opt_vqa.apply_gradients(zip(grads, vqaModel.trainable_weights))



        print(f"Epoch: {epoch}, step: {i}, VQA Loss: {lossVQA.numpy()}, Suppression Loss: {lossSuppression.numpy()}")
