<a href="https://colab.research.google.com/github/alexandrufalk/tensorflow/blob/Master/FEQE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [88]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Conv2D, Add, Input, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.applications import VGG19
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import os
from PIL import Image
import matplotlib.pyplot as plt

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
root_dir = "/content/drive/MyDrive/Datasets/BSDS500/images/"

#Parameters

In [120]:
# Parameters
BATCH_SIZE = 16
SCALE_FACTOR = 2
IMG_HEIGHT = 256
IMG_WIDTH = 256
CHANNELS = 3
AUTOTUNE = tf.data.AUTOTUNE

#Paths

In [121]:
# Paths
BASE_DIR = '/content/drive/MyDrive/Datasets/BSDS500/images/'
TRAIN_DIR = os.path.join(BASE_DIR, 'training')
VALIDATION_DIR = os.path.join(BASE_DIR, 'validation')
TEST_DIR = os.path.join(BASE_DIR, 'test')

#Preprocessing functions

In [122]:
# Define Preprocessing and Utility Functions
def load_and_preprocess_image(file_path):
    """
    Loads an image from a file, preprocesses it by resizing and scaling,
    and generates a low-resolution version.
    """
    try:
        # Read the image file
        image = tf.io.read_file(file_path)

        # Decode the image to RGB format (forces 3 channels)
        image = tf.image.decode_image(image, channels=3, expand_animations=False)

        # Convert to float32 and scale to [0, 1]
        image = tf.image.convert_image_dtype(image, tf.float32)

        # Resize the image to ensure consistent dimensions
        image = tf.image.resize(image, [256, 256], method='bicubic')

        # Generate low-resolution image by downsampling
        lr_image = tf.image.resize(image, [128, 128], method='bicubic')

        return lr_image, image  # (LR, HR)
    except tf.errors.InvalidArgumentError:
        # Handle corrupted images by returning zero tensors or skipping
        return tf.zeros([128, 128, 3]), tf.zeros([256, 256, 3])

def augment(lr, hr):
    """
    Applies random horizontal and vertical flips to the LR and HR images.
    """
    # Random horizontal flip
    if tf.random.uniform(()) > 0.5:
        lr = tf.image.flip_left_right(lr)
        hr = tf.image.flip_left_right(hr)

    # Random vertical flip
    if tf.random.uniform(()) > 0.5:
        lr = tf.image.flip_up_down(lr)
        hr = tf.image.flip_up_down(hr)

    return lr, hr

def create_dataset(directory, batch_size, augment_data=False):
    """
    Creates a TensorFlow dataset from image files in a directory.
    """
    # Supported extensions
    extensions = ['*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tiff']

    # Create a dataset of file paths
    list_ds = tf.data.Dataset.list_files(
        [os.path.join(directory, ext) for ext in extensions],
        shuffle=True
    )

    # Load and preprocess images
    dataset = list_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)

    # Apply data augmentation if specified
    if augment_data:
        dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch and prefetch
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

def postprocess_image(image):
    """
    Converts image tensor from [0, 1] to [0, 255] and casts to uint8.
    """
    image = tf.clip_by_value(image, 0.0, 1.0)  # Ensure values are within [0, 1]
    image = image * 255.0
    return tf.cast(image, tf.uint8).numpy()

def save_image(image, filename):
    """
    Saves a numpy array as an image file.
    """
    img = Image.fromarray(image)
    img.save(filename)

In [123]:
# Create datasets
train_dataset = create_dataset(TRAIN_DIR, BATCH_SIZE, augment_data=True)
validation_dataset = create_dataset(VALIDATION_DIR, BATCH_SIZE, augment_data=False)
test_dataset = create_dataset(TEST_DIR, BATCH_SIZE, augment_data=False)

In [138]:
# Define Custom DepthToSpaceLayer
class DepthToSpaceLayer(Layer):
    def __init__(self, scale, **kwargs):
        super(DepthToSpaceLayer, self).__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        return tf.nn.depth_to_space(inputs, self.scale)

    def get_config(self):
        config = super(DepthToSpaceLayer, self).get_config()
        config.update({'scale': self.scale})
        return config

In [139]:
# Define Residual Block (if needed)
def residual_block(x):
    skip = x
    x = Conv2D(64, (3, 3), padding='same', activation='relu')(x)
    x = Conv2D(64, (3, 3), padding='same')(x)
    x = Add()([x, skip])
    return x

# Feature Extraction Module VGG19

In [140]:
# Define Custom Feature Extractor without Downsampling
def feature_extractor(input_shape=(128, 128, 3)):
    input_img = Input(shape=input_shape)

    # Simple convolutional layers without pooling
    x = Conv2D(64, (3,3), padding='same', activation='relu')(input_img)
    x = Conv2D(128, (3,3), padding='same', activation='relu')(x)
    x = Conv2D(256, (3,3), padding='same', activation='relu')(x)

    # Create the feature extraction model
    model = Model(inputs=input_img, outputs=x)

    # Freeze the pretrained weights if any
    model.trainable = False

    return model

In [141]:
# Define Quality Enhancement Module with Single Upsampling Step

def quality_enhancement_module(features, scale=2):
    x = features
    upsample_steps = int(tf.math.log(float(scale)) / tf.math.log(2.0))  # For scale=2, steps=1

    for _ in range(upsample_steps):
        # Upsample by a factor of 2 with 'bilinear' interpolation
        x = UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
        # Apply convolution to refine the upsampled feature maps
        x = Conv2D(256, (3, 3), padding='same', activation='relu')(x)

    # Final output layer to get 3 channels
    output = Conv2D(3, (3, 3), padding='same', activation='sigmoid')(x)

    return output

#Custom Feature Extraction Layers

In [142]:
def custom_feature_extractor(input_shape=(256, 256, 3)):
    input_img = Input(shape=input_shape)

    x = Conv2D(64, (3, 3), padding='same', activation='relu')(input_img)
    x = Conv2D(128, (3, 3), padding='same', activation='relu')(x)
    x = Conv2D(256, (3, 3), padding='same', activation='relu')(x)

    model = Model(inputs=input_img, outputs=x)

    return model

In [143]:
# Combine into FEQE Model
def feqe_model(input_shape=(128, 128, 3), scale=2, use_pretrained=True):
    input_img = Input(shape=input_shape)

    if use_pretrained:
        extractor = feature_extractor(input_shape)
    else:
        extractor = custom_feature_extractor(input_shape)  # Define if needed

    features = extractor(input_img)  # Single tensor of shape (None, 128, 128, 256)
    enhanced_img = quality_enhancement_module(features, scale)

    model = Model(inputs=input_img, outputs=enhanced_img)

    return model

In [144]:
# Inspect a batch from the training dataset
for lr_batch, hr_batch in train_dataset.take(1):
    print(f"LR batch shape: {lr_batch.shape}")  # Expected: (BATCH_SIZE, 128, 128, 3)
    print(f"HR batch shape: {hr_batch.shape}")  # Expected: (BATCH_SIZE, 256, 256, 3)
    break

LR batch shape: (16, 128, 128, 3)
HR batch shape: (16, 256, 256, 3)


In [145]:
# Define Custom Metrics
def psnr_metric(y_true, y_pred):
    return tf.image.psnr(y_true, y_pred, max_val=1.0)

def ssim_metric(y_true, y_pred):
    return tf.image.ssim(y_true, y_pred, max_val=1.0)

#Quality Enhancement Module

# Combine Modules into FEQE Model

#Compile and Train the FEQE Model

In [146]:
# Choose whether to use a pretrained feature extractor
USE_PRETRAINED = True  # Set to False to use custom feature extractor

# Instantiate the FEQE model
feqe = feqe_model(scale=SCALE_FACTOR, use_pretrained=USE_PRETRAINED)

# Display the model summary
feqe.summary()


In [147]:
# Compile the Model
feqe.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss='mse',
    metrics=[psnr_metric, ssim_metric]
)

#Define Callbacks

In [148]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Define callbacks
checkpoint = ModelCheckpoint(
    'feqe_best_model.keras',
    monitor='val_psnr_metric',
    save_best_only=True,
    mode='max',
    verbose=1
)

early_stop = EarlyStopping(
    monitor='val_psnr_metric',
    patience=10,
    mode='max',
    verbose=1
)

reduce_lr = ReduceLROnPlateau(
    monitor='val_psnr_metric',
    factor=0.5,
    patience=5,
    mode='max',
    verbose=1
)

In [149]:
# Train the Model
history = feqe.fit(
    train_dataset,
    epochs=100,
    validation_data=validation_dataset,
    callbacks=[checkpoint, early_stop, reduce_lr],
    verbose=1  # Enable verbose logging
)


Epoch 1/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - loss: 0.0689 - psnr_metric: 12.0647 - ssim_metric: 0.3358   
Epoch 1: val_psnr_metric improved from -inf to 12.31636, saving model to feqe_best_model.keras
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 6s/step - loss: 0.0686 - psnr_metric: 12.0824 - ssim_metric: 0.3353 - val_loss: 0.0665 - val_psnr_metric: 12.3164 - val_ssim_metric: 0.3049 - learning_rate: 1.0000e-04
Epoch 2/100
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 554ms/step - loss: 0.0622 - psnr_metric: 12.6070 - ssim_metric: 0.3405
Epoch 2: val_psnr_metric improved from 12.31636 to 12.46258, saving model to feqe_best_model.keras
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 725ms/step - loss: 0.0623 - psnr_metric: 12.5977 - ssim_metric: 0.3405 - val_loss: 0.0645 - val_psnr_metric: 12.4626 - val_ssim_metric: 0.3123 - learning_rate: 1.0000e-04
Epoch 3/100
[1m13/13[0m [32m━━━━━━━━━━━