In [None]:
import imageio
import random
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, GlobalAveragePooling2D
from notebooks.utils import show_images, gaussian_filter, image_normalization, rescale, image_shape
import imquality.datasets

In [None]:
# Report the installed TensorFlow version so that results can be tied to it.
print(f'tensorflow version {tf.__version__}')

## Dataset

The dataset that we are going to use to train and test this algorithm is [LiveIQA](https://live.ece.utexas.edu/research/quality/subjective.htm).
It is comprised of 30 reference images, and 5 different distortions with 5 severity levels each.

The first thing we need to do is to download the dataset. For this, I have created a couple of builders
for Image Quality datasets in the [image-quality](https://github.com/ocampor/image-quality) package. The builders
are an interface defined by tensorflow in [tensorflow-datasets](https://www.tensorflow.org/datasets) package. 

This process is going to take a couple of minutes because the dataset size is around 700 megabytes.

In [None]:
# Instantiate the LiveIQA dataset builder provided by the image-quality package.
builder = imquality.datasets.LiveIQA()
# Download the data and prepare it for use with `as_dataset`.
builder.download_and_prepare()

After downloading and preparing the data, we can turn the builder into a dataset and shuffle it.

In [None]:
# Build the input pipeline in one expression: take the 'train' split,
# shuffle with a buffer of 1024 elements and emit batches of a single element.
ds = (
    builder.as_dataset(shuffle_files=True)['train']
    .shuffle(1024)
    .batch(1)
)

The output behaves like a generator; therefore, we cannot access its elements unless we iterate over it, for example in a for loop.
In order to display some images, I am taking two samples from it. You can run this cell several times to understand the dataset.

In [None]:
# Display a couple of (reference, distorted) pairs together with their metadata.
for features in ds.take(2):
    distorted_image = features['distorted_image']
    reference_image = features['reference_image']
    # tf.round has no "decimals" argument (its second positional parameter is
    # `name`), so scale by 100 before rounding to keep two decimal places.
    dmos = tf.round(features['dmos'][0] * 100) / 100
    distortion = features['distortion'][0]
    # The first value printed is the DMOS score, not the distortion type.
    print(f'The DMOS of the image is {dmos} with'
          f' a distortion {distortion} and shape {distorted_image.shape}')
    show_images([reference_image, distorted_image])

## Image Normalization

As pre-processing, the image is first converted to grayscale. Then, a low-pass filter is applied
to the grayscale image. Finally, the low-pass filtered image is subtracted from the grayscale image. The
low-frequency image is the result of blurring the image, downscaling it by a factor of 1 / 4 and upscaling it back
to the original size.

\begin{align*}
\hat{I} = I_{gray} - I^{low}
\end{align*}

The main reasons for the image normalization are:
1. The Human Visual System (HVS) is not sensitive to changes in low frequency band.

2. Image distortions barely affect the low-frequency component of images. 

In [None]:
def image_preprocess(image: tf.Tensor) -> tf.Tensor:
    """Return the grayscale version of ``image`` minus its low-frequency part.

    The low-frequency part is built from the grayscale image by applying
    ``gaussian_filter``, rescaling the result by 1 / 4 and resizing it back
    to the shape of the grayscale image (bicubic interpolation both times).

    :param image: image tensor accepted by ``tf.image.rgb_to_grayscale``;
        it is cast to ``tf.float32`` first.
    :return: the grayscale image with the low-frequency part subtracted.
    """
    assert isinstance(image, tf.Tensor), 'The input must be a tf.Tensor'
    bicubic = tf.image.ResizeMethod.BICUBIC
    gray = tf.image.rgb_to_grayscale(tf.cast(image, tf.float32))
    blurred = gaussian_filter(gray, 16, 7 / 6)
    downscaled = rescale(blurred, 1 / 4, method=bicubic)
    low_frequency = tf.image.resize(downscaled, size=image_shape(gray), method=bicubic)
    return gray - tf.cast(low_frequency, gray.dtype)

In [None]:
# Visual check: reference image next to the pre-processed distorted image.
for features in ds.take(2):
    reference_image = features['reference_image']
    distorted_image = features['distorted_image']
    # Pre-process, expand back to three channels and normalize for display.
    I_d = image_normalization(
        tf.image.grayscale_to_rgb(image_preprocess(distorted_image)), 0, 1)
    show_images([reference_image, I_d])

## Objective Error Map

In the first stage of training, the objective error maps are used as proxy regression targets to get the effect of 
increasing data. The loss function is defined by the mean squared error between the predicted and ground-truth error
maps.

\begin{align*}
\mathbf{e}_{gt} = err(\hat{I}_r, \hat{I}_d)
\end{align*}

and $err(\cdot)$ is any error function. The authors decided to use

\begin{align*}
\mathbf{e}_{gt} = | \hat{I}_r -  \hat{I}_d | ^ p
\end{align*}

with $p=0.2$ in order to prevent the values in the error map from being small or close to zero.

In [None]:
def error_map(reference: tf.Tensor, distorted: tf.Tensor, p: float=0.2) -> tf.Tensor:
    """Element-wise objective error map ``|reference - distorted| ** p``.

    :param reference: pre-processed reference image, dtype ``tf.float32``.
    :param distorted: pre-processed distorted image, dtype ``tf.float32``.
    :param p: exponent applied to the absolute difference.
    :return: tensor holding the absolute difference raised to ``p``.
    """
    for tensor in (reference, distorted):
        assert tensor.dtype == tf.float32, 'dtype must be tf.float32'
    absolute_difference = tf.abs(reference - distorted)
    return tf.pow(absolute_difference, p)

In [None]:
# Visual check: reference image, pre-processed distorted image and error map.
for features in ds.take(3):
    reference_image = features['reference_image']
    I_r = image_preprocess(reference_image)
    I_d = image_preprocess(features['distorted_image'])
    e_gt = error_map(I_r, I_d, 0.2)
    # Both single-channel maps get the same display treatment:
    # expand to three channels, then normalize.
    I_d, e_gt = (
        image_normalization(tf.image.grayscale_to_rgb(single_channel), 0, 1)
        for single_channel in (I_d, e_gt)
    )
    show_images([reference_image, I_d, e_gt])

## Reliability Map Prediction

According to the author, the model is likely to fail to predict the objective error map of
homogeneous regions without having information of its pristine image. Thus, he proposes a 
reliability function. The assumption is that blurry regions have lower reliability than textured 
regions.

\begin{align*}
\mathbf{r} = \frac{2}{1 + exp(-\alpha|\hat{I}_d|)} - 1
\end{align*}

where α controls the saturation property of the reliability map. To assign sufficiently
large values to pixels with small values, the positive part of a sigmoid is used.

In [None]:
def reliability_map(distorted: tf.Tensor, alpha: float) -> tf.Tensor:
    """Reliability map ``2 / (1 + exp(-alpha * |distorted|)) - 1``.

    :param distorted: pre-processed distorted image, dtype ``tf.float32``.
    :param alpha: factor multiplying the absolute value inside the exponential.
    :return: tensor with the formula above applied element-wise.
    """
    assert distorted.dtype == tf.float32, 'The Tensor must by of dtype tf.float32'
    magnitude = tf.abs(distorted)
    decay = tf.exp(- alpha * magnitude)
    return 2 / (1 + decay) - 1

Besides, to prevent the reliability map from directly affecting the predicted score,
it is divided by its average

\begin{align*}
\mathbf{\hat{r}} = \frac{1}{\frac{1}{H_rW_r}\sum_{(i,j)}\mathbf{r}(i,j)}\mathbf{r}
\end{align*}

In [None]:
def average_reliability_map(distorted: tf.Tensor, alpha: float) -> tf.Tensor:
    """Reliability map of ``distorted`` divided by its own mean value.

    :param distorted: tensor forwarded to ``reliability_map``.
    :param alpha: forwarded to ``reliability_map``.
    :return: the reliability map divided by the mean over all its elements.
    """
    reliability = reliability_map(distorted, alpha)
    mean_reliability = tf.reduce_mean(reliability)
    return reliability / mean_reliability

In [None]:
# Visual check: reference image next to the averaged reliability map.
for features in ds.take(2):
    reference_image = features['reference_image']
    I_d = image_preprocess(features['distorted_image'])
    # Compute the map with alpha = 1, expand to three channels and normalize for display.
    r = image_normalization(
        tf.image.grayscale_to_rgb(average_reliability_map(I_d, 1)), 0, 1)
    show_images([reference_image, r], cmap='gray')

## Loss function
The loss function is the mean square error of the product between the reliability map and the
error. The error is the difference between the predicted error map and the ground-truth error map.

\begin{align*}
\mathcal{L}_1(\hat{I}_d; \theta_f, \theta_g) = ||(g(f(\hat{I}_d, \theta_f), \theta_g) - \mathbf{e}_{gt}) \odot \mathbf{\hat{r}}||^2_2
\end{align*}

## Read Files
We don't want to mix reference images in train and test because we want to test with completely unseen samples.

In [None]:
def calculate_error_map(features):
    """Map a dataset element to the ``(input, target)`` pair of the first stage.

    :param features: dataset element with the keys ``'distorted_image'`` and
        ``'reference_image'``.
    :return: tuple ``(I_d, e_gt)`` with the pre-processed distorted image and
        the ground-truth error map rescaled by 1 / 4.
    """
    I_d = image_preprocess(features['distorted_image'])
    I_r = image_preprocess(features['reference_image'])
    # NOTE(review): a rescaled reliability map used to be computed here but was
    # never returned or used, so the dead computation has been removed. If the
    # reliability-weighted loss is wired in, it has to be produced and consumed
    # explicitly.
    e_gt = rescale(error_map(I_r, I_d, 0.2), 1 / 4)
    return (I_d, e_gt)

In [None]:
# Stage-one training set: every element becomes the pair
# (pre-processed distorted image, ground-truth error map).
train = ds.map(calculate_error_map)

In [None]:
# Encoder specification: (layer name, number of filters, strides).
# Every layer uses a 3x3 kernel, ReLU activation and 'same' padding;
# only Conv2 and Conv4 use a stride of 2.
CONV_LAYERS = (
    ('Conv1', 48, (1, 1)),
    ('Conv2', 48, (2, 2)),
    ('Conv3', 64, (1, 1)),
    ('Conv4', 64, (2, 2)),
    ('Conv5', 64, (1, 1)),
    ('Conv6', 64, (1, 1)),
    ('Conv7', 128, (1, 1)),
    ('Conv8', 128, (1, 1)),
)

# NOTE: `input` shadows the Python builtin; the name is kept because the cell
# that builds the `subjective_error` model uses it as well.
input = tf.keras.Input(shape=(None, None, 1), batch_size=1, name='original_image')

# f is the output of the shared encoder; the subjective model is built on it too.
f = input
for layer_name, filters, strides in CONV_LAYERS:
    f = Conv2D(filters, (3, 3), name=layer_name, activation='relu',
               padding='same', strides=strides)(f)

# g maps the encoder features to a single-channel map with a linear 1x1 convolution.
g = Conv2D(1, (1, 1), name='Conv9', padding='same', activation='linear')(f)
objective_error_map = tf.keras.Model(input, g, name='objective_error_map')

In [None]:
def diqa_loss_1(weights):
    """Build the reliability-weighted loss of the first training stage.

    The returned function computes the mean of ``((y_true - y_pred) * weights) ** 2``.

    :param weights: tensor multiplied element-wise with the residual (the
        reliability map in the notebook's formula); it must be broadcastable
        to ``y_true - y_pred``.
    :return: a function ``loss(y_true, y_pred)`` returning a scalar tensor.
    """
    def loss(y_true, y_pred):
        # Weight the residual BEFORE squaring: the documented loss squares the
        # product of error and reliability, so the weights enter squared too.
        return tf.reduce_mean(tf.square((y_true - y_pred) * weights))
    return loss

In [None]:
# Stage-one training setup: Nadam optimizer, mean squared error as both
# the loss and the reported metric.
optimizer = tf.keras.optimizers.Nadam(learning_rate=2 * 10 ** -4)
objective_error_map.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[tf.keras.metrics.MeanSquaredError()],
)

In [None]:
# Print the layer-by-layer summary of the stage-one model.
objective_error_map.summary()

In [None]:
# Train the error-map model for one epoch. `train` yields (input, target)
# pairs, so no separate `y` is passed.
history = objective_error_map.fit(x=train, epochs=1)

In [None]:
# Regression head on top of the shared encoder output `f`:
# global average pooling, two 128-unit ReLU layers and a single linear output.
v = GlobalAveragePooling2D(data_format='channels_last')(f)
h = v
for units in (128, 128):
    h = Dense(units, activation='relu')(h)
h = Dense(1)(h)
subjective_error = tf.keras.Model(inputs=input, outputs=h, name='subjective_error')

# Stage-two training setup: a fresh Nadam optimizer with the same learning
# rate, mean squared error as both the loss and the reported metric.
optimizer = tf.keras.optimizers.Nadam(learning_rate=2 * 10 ** -4)
subjective_error.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[tf.keras.metrics.MeanSquaredError()],
)

In [None]:
# Print the layer-by-layer summary of the stage-two model.
subjective_error.summary()

In [None]:
def calculate_subjective_score(features):
    """Map a dataset element to the ``(input, target)`` pair of the second stage.

    :param features: dataset element with the keys ``'distorted_image'`` and
        ``'dmos'``.
    :return: tuple with the pre-processed distorted image and the value stored
        under ``'dmos'``.
    """
    return (image_preprocess(features['distorted_image']), features['dmos'])

In [None]:
# Stage-two training set: (pre-processed distorted image, dmos) pairs.
# NOTE: this rebinds `train`, replacing the stage-one dataset.
train = ds.map(calculate_subjective_score)

In [None]:
# Train the subjective-score model for one epoch.
# NOTE: this rebinds `history`, replacing the stage-one training history.
history = subjective_error.fit(train, epochs=1)

In [None]:
# Compare the model prediction with the target for a single dataset element.
sample = next(iter(ds))
target = sample['dmos'][0]
I_d = image_preprocess(sample['distorted_image'])
# `predict` is indexed at [0, 0] to extract the single predicted value.
prediction = subjective_error.predict(I_d)[0, 0]

print(f'the predicted value is: {prediction} and target is: {target}')