<a href="https://colab.research.google.com/github/alexandrufalk/tensorflow/blob/Master/FEQE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, Add, ReLU
from tensorflow.keras.models import Model
from tensorflow.keras.applications import VGG19
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Feature Extraction Module (pretrained VGG19)

In [2]:
def feature_extractor(input_shape=(256, 256, 3)):
    """Build a frozen, multi-output VGG19 feature extractor.

    The ImageNet-pretrained VGG19 convolutional base (classifier head
    excluded) is tapped at three depths, producing one output per tap, in
    this order: ``block1_conv2`` (shallow), ``block3_conv4`` (middle) and
    ``block5_conv4`` (deep). Choose different tap names to trade low-level
    detail for more abstract features.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.

    Returns:
        A non-trainable ``Model`` whose outputs are the three tapped feature
        maps. Because it has several outputs, calling it yields a list.

    NOTE(review): images are fed to VGG19 unmodified; ImageNet weights are
    normally paired with ``tf.keras.applications.vgg19.preprocess_input``.
    Confirm the expected input range.
    """
    tap_names = ('block1_conv2', 'block3_conv4', 'block5_conv4')

    image = Input(shape=input_shape)
    backbone = VGG19(weights='imagenet', include_top=False, input_tensor=image)

    taps = [backbone.get_layer(layer_name).output for layer_name in tap_names]
    extractor = Model(inputs=image, outputs=taps)

    # Keep the pretrained ImageNet weights fixed while the rest trains.
    extractor.trainable = False
    return extractor

# Custom Feature Extraction Layers

In [5]:
def custom_feature_extractor(input_shape=(256, 256, 3)):
    """Build a small, trainable convolutional feature extractor.

    Three stacked 3x3 ReLU convolutions with ``'same'`` padding widen the
    channel count 64 -> 128 -> 256. They use the default stride of 1, so the
    spatial size of the input is preserved.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.

    Returns:
        A single-output ``Model`` producing a 256-channel feature map.
    """
    image = Input(shape=input_shape)

    features = image
    for width in (64, 128, 256):
        features = Conv2D(width, (3, 3), padding='same', activation='relu')(features)

    return Model(inputs=image, outputs=features)

# Quality Enhancement Module

In [3]:
def quality_enhancement_module(features, scale=2):
    """Upsample a feature map and decode it into a 3-channel image.

    The features are enlarged by ``scale`` in both height and width using
    bicubic interpolation. They are then refined by three 3x3 ReLU
    convolutions (256 -> 128 -> 64 channels) and projected to 3 channels by a
    final sigmoid convolution.

    Args:
        features: a single 4-D Keras tensor (presumably channels-last
            ``(batch, H, W, C)`` — confirm against the Keras data_format).
        scale: integer upsampling factor applied to height and width.

    Returns:
        A Keras tensor with 3 channels whose values lie in (0, 1) because of
        the sigmoid activation.
    """
    factor = (scale, scale)
    upsampled = UpSampling2D(size=factor, interpolation='bicubic')(features)

    refined = upsampled
    for width in (256, 128, 64):
        refined = Conv2D(width, (3, 3), padding='same', activation='relu')(refined)

    # Sigmoid bounds every output channel to (0, 1).
    return Conv2D(3, (3, 3), padding='same', activation='sigmoid')(refined)

# Combine Modules into FEQE Model

In [4]:
def _fuse_feature_maps(features):
    """Merge an extractor's output(s) into one tensor for the enhancement head.

    A multi-output Keras model returns a list when called, while a
    single-output model returns one tensor. A single tensor, or a one-element
    list, is returned unchanged. Otherwise every map after the first is
    bilinearly resized to the first map's height and width, and all maps are
    concatenated along the last (channel) axis.
    """
    from tensorflow.keras.layers import Concatenate, Lambda

    if not isinstance(features, (list, tuple)):
        return features
    if len(features) == 1:
        return features[0]

    reference = features[0]
    aligned = [reference]
    for feature_map in features[1:]:
        # The target size is read from the reference tensor at run time, so
        # the resize also works when the spatial dims are not fixed.
        # NOTE(review): Lambda layers serialize Python code; reloading a saved
        # model may require safe_mode=False in newer Keras versions.
        resized = Lambda(
            lambda pair: tf.image.resize(
                pair[0], tf.shape(pair[1])[1:3], method='bilinear'
            )
        )([feature_map, reference])
        aligned.append(resized)

    return Concatenate(axis=-1)(aligned)


def feqe_model(input_shape=(256, 256, 3), scale=2, use_pretrained=True):
    """Assemble the FEQE model: feature extraction + quality enhancement.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        scale: upsampling factor forwarded to ``quality_enhancement_module``.
        use_pretrained: if True, use the frozen multi-output VGG19
            ``feature_extractor``; otherwise use the trainable
            ``custom_feature_extractor``.

    Returns:
        A ``Model`` mapping an input image to an enhanced 3-channel image.
    """
    input_img = Input(shape=input_shape)

    if use_pretrained:
        extractor = feature_extractor(input_shape)
    else:
        extractor = custom_feature_extractor(input_shape)

    # feature_extractor has three outputs, so calling it yields a list of
    # maps taken at different network depths. UpSampling2D inside
    # quality_enhancement_module needs a single tensor, so fuse them first.
    features = _fuse_feature_maps(extractor(input_img))
    enhanced_img = quality_enhancement_module(features, scale)

    return Model(inputs=input_img, outputs=enhanced_img)