<a href="https://colab.research.google.com/github/alexandrufalk/tensorflow/blob/Master/FEQE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, Add, ReLU
from tensorflow.keras.models import Model
from tensorflow.keras.applications import VGG19
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [7]:
# Mount Google Drive so the dataset folders under /content/drive/MyDrive
# (see the path cells below) are readable. google.colab is only available
# inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
# NOTE(review): same value as BASE_DIR in the Paths cell and not used by the visible cells; consider keeping only one.
root_dir = "/content/drive/MyDrive/Datasets/BSDS500/images/"

# Parameters

In [9]:
# Hyper-parameters shared by the data-pipeline and model cells below.
BATCH_SIZE = 16                   # images per batch in create_dataset
SCALE_FACTOR = 2                  # presumably the `scale` passed to feqe_model; unused in the visible cells
IMG_HEIGHT, IMG_WIDTH = 256, 256  # every loaded image is resized to this size
CHANNELS = 3                      # number of channels images are decoded to (RGB)
AUTOTUNE = tf.data.AUTOTUNE       # let tf.data choose parallelism / prefetch depth

# Paths

In [10]:
# Dataset locations on the mounted Google Drive.
# NOTE(review): BASE_DIR repeats the value of `root_dir` from an earlier cell.
BASE_DIR = '/content/drive/MyDrive/Datasets/BSDS500/images/'
# One sub-folder per split, in the order training / validation / test.
TRAIN_DIR, VALIDATION_DIR, TEST_DIR = (
    os.path.join(BASE_DIR, split_name)
    for split_name in ('training', 'validation', 'test')
)

# Preprocessing functions

In [11]:
def load_and_preprocess_image(file_path):
    """
    Load one image file and return it as a float32 tensor in [0, 1].

    Args:
        file_path: scalar string tensor (or str) holding the path of an image
            that tf.image.decode_image can decode (e.g. BMP, GIF, JPEG, PNG).

    Returns:
        A float32 tensor of shape (IMG_HEIGHT, IMG_WIDTH, CHANNELS) whose
        values are clipped to [0, 1].
    """
    raw_bytes = tf.io.read_file(file_path)

    # expand_animations=False makes GIFs decode to a single frame, so every
    # file yields a rank-3 tensor that tf.image.resize accepts.
    image = tf.image.decode_image(raw_bytes, channels=CHANNELS, expand_animations=False)

    # uint8 -> float32, rescaling values to [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)

    image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH], method='bicubic')

    # Bicubic interpolation overshoots near sharp edges and can produce values
    # slightly outside [0, 1]; clip so inputs/targets match the model's
    # sigmoid output range.
    return tf.clip_by_value(image, 0.0, 1.0)

def create_dataset(directory, batch_size, augment=False, shuffle=True):
    """
    Build a batched, prefetched tf.data pipeline over the images in a directory.

    Only files directly inside `directory` (not sub-folders) whose extension
    is png/jpg/jpeg/bmp (any of those in lower or upper case) are used.

    Args:
        directory: folder containing the image files.
        batch_size: images per batch; the final batch may be smaller.
        augment: if True, apply a random horizontal flip, a random vertical
            flip and a random rotation of up to +/-10% of a full turn to each
            image before batching.
        shuffle: if True (the previous behaviour) the file order is randomly
            shuffled; pass False for a deterministic order, e.g. for
            validation or test sets.

    Returns:
        A tf.data.Dataset yielding float32 batches of shape
        (batch, IMG_HEIGHT, IMG_WIDTH, CHANNELS).

    Raises:
        tf.data.Dataset.list_files raises an error when no file matches any
        of the patterns.
    """
    # Only formats tf.image.decode_image can decode. TIFF is not supported and
    # would make the pipeline fail while iterating. Upper-case variants pick up
    # files such as IMG_001.JPG, since pattern matching is case-sensitive on
    # typical Linux filesystems.
    extensions = ('png', 'jpg', 'jpeg', 'bmp')
    patterns = [
        os.path.join(directory, f'*.{variant}')
        for ext in extensions
        for variant in (ext, ext.upper())
    ]

    list_ds = tf.data.Dataset.list_files(patterns, shuffle=shuffle)

    dataset = list_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

    if augment:
        data_augmentation = tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomFlip("vertical"),
            tf.keras.layers.RandomRotation(0.1),
        ])
        # Keras random layers return their input unchanged in inference mode;
        # pass training=True explicitly because this runs outside model.fit().
        dataset = dataset.map(
            lambda image: data_augmentation(image, training=True),
            num_parallel_calls=AUTOTUNE,
        )

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTOTUNE)

    return dataset

# Feature Extraction Module VGG19

In [2]:
def feature_extractor(input_shape=(256, 256, 3)):
    """
    Build a frozen feature extractor on top of ImageNet-pretrained VGG19.

    The returned model has three outputs: the activations of the VGG19 layers
    'block1_conv2', 'block3_conv4' and 'block5_conv4' (shallow to deep). It
    applies no input preprocessing itself; the ImageNet weights were trained
    on inputs prepared with tf.keras.applications.vgg19.preprocess_input.

    Args:
        input_shape: (height, width, channels) of the input image.

    Returns:
        A non-trainable keras Model whose call returns a list of three feature
        maps.
    """
    image_in = Input(shape=input_shape)

    # include_top=False drops the dense classification head.
    backbone = VGG19(weights='imagenet', include_top=False, input_tensor=image_in)

    # Layers to tap; change these to trade low-level vs. high-level features.
    tap_names = ('block1_conv2', 'block3_conv4', 'block5_conv4')
    taps = [backbone.get_layer(tap).output for tap in tap_names]

    extractor = Model(inputs=image_in, outputs=taps)
    extractor.trainable = False  # keep the pretrained weights fixed
    return extractor

# Custom Feature Extraction Layers

In [5]:
def custom_feature_extractor(input_shape=(256, 256, 3)):
    """
    Build a small, trainable convolutional feature extractor.

    Three 3x3 'same'-padded ReLU convolutions widen the input to 64, 128 and
    then 256 channels without changing its spatial size.

    Args:
        input_shape: (height, width, channels) of the input image.

    Returns:
        A keras Model producing a single 256-channel feature map at the
        input's spatial resolution.
    """
    inputs = Input(shape=input_shape)

    hidden = inputs
    for filters in (64, 128, 256):
        hidden = Conv2D(filters, (3, 3), padding='same', activation='relu')(hidden)

    return Model(inputs=inputs, outputs=hidden)

# Quality Enhancement Module

In [3]:
def quality_enhancement_module(features, scale=2):
    """
    Upsample a feature map and decode it into a 3-channel image.

    Args:
        features: a single 4-D keras tensor (assumed channels_last:
            batch, height, width, channels). A list of tensors is not accepted
            by UpSampling2D.
        scale: integer factor applied to both height and width (bicubic
            UpSampling2D).

    Returns:
        A keras tensor of shape (batch, height * scale, width * scale, 3) with
        sigmoid-activated values.
    """
    upsampled = UpSampling2D(size=(scale, scale), interpolation='bicubic')(features)

    # Refinement stack narrowing 256 -> 128 -> 64 channels.
    # NOTE(review): these convolutions all run at the upsampled resolution,
    # which dominates this module's compute and memory cost.
    refined = upsampled
    for filters in (256, 128, 64):
        refined = Conv2D(filters, (3, 3), padding='same', activation='relu')(refined)

    # Map to 3 output channels squashed into (0, 1).
    return Conv2D(3, (3, 3), padding='same', activation='sigmoid')(refined)

# Combine Modules into FEQE Model

In [4]:
def _fuse_multiscale_features(features):
    """
    Merge a list of feature maps into one tensor at the first map's resolution.

    Every map after the first is bilinearly resized to the first map's height
    and width, then all maps are concatenated along the last (channel) axis
    (assumes channels_last). A single tensor is returned unchanged.

    Raises:
        ValueError: if the first map's height or width is not statically
            known, because the Resizing layer needs a fixed target size.
    """
    if not isinstance(features, (list, tuple)):
        return features
    if len(features) == 1:
        return features[0]

    reference = features[0]
    height, width = reference.shape[1], reference.shape[2]
    if height is None or width is None:
        raise ValueError(
            'Fusing multi-scale features requires a fully specified input '
            'height and width; got unknown spatial dimensions.'
        )

    aligned = [reference] + [
        tf.keras.layers.Resizing(height, width, interpolation='bilinear')(fmap)
        for fmap in features[1:]
    ]
    return tf.keras.layers.Concatenate(axis=-1)(aligned)


def feqe_model(input_shape=(256, 256, 3), scale=2, use_pretrained=True):
    """
    Build the FEQE super-resolution model (feature extraction + enhancement).

    Args:
        input_shape: (height, width, channels) of the input image, expected
            to hold float values in [0, 1] like load_and_preprocess_image
            produces.
        scale: integer upscaling factor used by quality_enhancement_module.
        use_pretrained: if True use the frozen ImageNet VGG19 extractor,
            otherwise the trainable custom_feature_extractor.

    Returns:
        A keras Model mapping an input image to a 3-channel sigmoid image whose
        height and width are `scale` times the input's.

    Raises:
        ValueError: with use_pretrained=True, if input_shape has unknown
            height or width (see _fuse_multiscale_features).
    """
    input_img = Input(shape=input_shape)

    if use_pretrained:
        extractor = feature_extractor(input_shape)
        # ImageNet VGG19 weights expect vgg19.preprocess_input inputs (0-255
        # RGB converted to mean-subtracted BGR), whereas the data pipeline
        # yields [0, 1] floats; convert inside the graph.
        # NOTE(review): a Python lambda inside a Lambda layer may require
        # safe_mode=False when reloading a saved model with Keras 3.
        vgg_input = tf.keras.layers.Lambda(
            lambda t: tf.keras.applications.vgg19.preprocess_input(t * 255.0),
            name='vgg19_preprocess',
        )(input_img)
        # feature_extractor returns three VGG activations at different
        # resolutions, but UpSampling2D in quality_enhancement_module needs a
        # single tensor: align + concatenate them, then project back to 256
        # channels, the width custom_feature_extractor produces.
        fused = _fuse_multiscale_features(extractor(vgg_input))
        features = Conv2D(256, (1, 1), padding='same', activation='relu')(fused)
    else:
        extractor = custom_feature_extractor(input_shape)
        features = extractor(input_img)

    enhanced_img = quality_enhancement_module(features, scale)

    return Model(inputs=input_img, outputs=enhanced_img)