In [None]:
import os
import random

import imageio
import imquality.datasets
import tensorflow as tf
from tensorflow_core.python.keras.layers.pooling import GlobalAveragePooling2D
from tensorflow_core.python.layers.convolutional import Conv2D
from tensorflow_core.python.layers.core import Dense

from notebooks.utils import show_images, gaussian_filter, image_normalization, rescale, read_image

In [None]:
# Report the TensorFlow release this notebook is running against.
print('tensorflow version {}'.format(tf.__version__))

## Dataset

The dataset that we are going to use to train and test this algorithm is [LiveIQA](https://live.ece.utexas.edu/research/quality/subjective.htm).
It is comprised of 30 reference images, and 5 different distortions with 5 severity levels each.

The first thing we need to do is to download the dataset. For this, I have created a couple of builders
for Image Quality datasets in the [image-quality](https://github.com/ocampor/image-quality) package. The builders
are an interface defined by tensorflow in [tensorflow-datasets](https://www.tensorflow.org/datasets) package. 

This process is going to take a couple of minutes because the dataset size is around 700 megabytes.

In [None]:
# Instantiate the LiveIQA dataset builder from the image-quality package.
builder = imquality.datasets.LiveIQA()
# Download and prepare the data; the surrounding text estimates ~700 MB and a couple of minutes.
builder.download_and_prepare()

After downloading and preparing the data, we can turn the builder into a dataset and shuffle it. 

In [None]:
# Take the 'train' split, then shuffle (buffer of 1024), emit
# single-example batches and prefetch one batch ahead.
ds = builder.as_dataset()['train']
ds = ds.shuffle(1024)
ds = ds.batch(1)
ds = ds.prefetch(1)

The output is a generator; therefore, we cannot access it unless we iterate in a for loop. In order to display an
image, I am iterating once to extract a sample. You can iterate this several times to understand the dataset.

In [None]:
# Pull one batch from the dataset and display the reference/distorted pair.
for features in ds.take(1):
    distorted_image = features['distorted_image']
    reference_image = features['reference_image']
    # tf.round has no "decimals" argument (its second positional parameter is
    # `name`), so round to two decimals by scaling before and after rounding.
    dmos = tf.round(features['dmos'][0] * 100) / 100
    distortion = features['distortion'][0]
    # The first value reported is the DMOS score, not the distortion type.
    print(f'The DMOS of the image is {dmos} with'
          f' a distortion {distortion}')
    show_images([reference_image, distorted_image])

## Image Normalization

As a pre-processing step, the image is first converted to grayscale. Then a low-pass filter is applied
to the grayscale image. Finally, the low-pass filtered image is subtracted from the grayscale image. The
low-frequency image is the result of blurring the image, downscaling it by a factor of 1 / 4, and upscaling it back
to the original size. 

\begin{align*}
\hat{I} = I_{gray} - I^{low}
\end{align*}

The main reasons for the image normalization are:
1. The Human Visual System (HVS) is not sensitive to changes in low frequency band.

2. Image distortions barely affect the low-frequency component of images. 

In [None]:
def image_preprocess(image: tf.Tensor) -> tf.Tensor:
    """Return the grayscale image minus its low-frequency component.

    The input is cast to float32 and converted to grayscale. The low-frequency
    component is obtained by Gaussian filtering, rescaling by 1/4 and then
    rescaling by 4 (both bicubic); it is subtracted from the grayscale image.

    :param image: RGB image as a ``tf.Tensor``.
    :return: float32 tensor holding the normalized grayscale image.
    :raises AssertionError: if ``image`` is not a ``tf.Tensor``.
    """
    assert isinstance(image, tf.Tensor), 'The input must be a tf.Tensor'
    gray = tf.image.rgb_to_grayscale(tf.cast(image, tf.float32))
    blurred = gaussian_filter(gray, 16, 7 / 6)
    shrunk = rescale(blurred, 1 / 4, method=tf.image.ResizeMethod.BICUBIC)
    low_frequency = rescale(shrunk, 4, method=tf.image.ResizeMethod.BICUBIC)
    return gray - tf.cast(low_frequency, gray.dtype)

In [None]:
# Preprocess one distorted sample and show it next to its reference image.
for features in ds.take(1):
    reference_image = features['reference_image']
    distorted_image = features['distorted_image']
    # Back to three channels and normalized with bounds (0, 1) for display.
    I_d = image_normalization(
        tf.image.grayscale_to_rgb(image_preprocess(distorted_image)), 0, 1)

show_images([reference_image, I_d])

## Objective Error Map

In the first stage of training, the objective error maps are used as proxy regression targets to get the effect of 
increasing data. The loss function is defined by the mean squared error between the predicted and ground-truth error
maps.

\begin{align*}
\mathbf{e}_{gt} = err(\hat{I}_r, \hat{I}_d)
\end{align*}

and $err(\cdot)$ is any error function. The authors decided to use

\begin{align*}
\mathbf{e}_{gt} = | \hat{I}_r -  \hat{I}_d | ^ p
\end{align*}

with $p=0.2$ to prevent the values in the error map from being small or close to zero.

In [None]:
def error_map(reference: tf.Tensor, distorted: tf.Tensor, p: float=0.2) -> tf.Tensor:
    """Compute the objective error map ``|reference - distorted| ** p``.

    :param reference: preprocessed reference image (float32).
    :param distorted: preprocessed distorted image (float32), same shape as ``reference``.
    :param p: exponent applied to the absolute difference.
    :return: element-wise error map with the same shape as the inputs.
    :raises AssertionError: if the shapes differ or either dtype is not float32.
    """
    assert reference.shape == distorted.shape, 'Both images must be of the same size'
    assert all(t.dtype == tf.float32 for t in (reference, distorted)), 'dtype must be tf.float32'
    absolute_difference = tf.abs(reference - distorted)
    return tf.pow(absolute_difference, p)

In [None]:
# NOTE(review): `get_image_url` is neither imported nor defined in any cell of this
# notebook, so this cell fails on a fresh kernel — confirm where it should come from.
# The third argument 0 presumably selects the undistorted reference image — TODO confirm.
I = tf.convert_to_tensor(imageio.imread(get_image_url(2, 11, 0)))
I_r = image_preprocess(I)
results = []
# Build one error map per severity level against the same reference.
for severity in (1, 3, 5):
    I = tf.convert_to_tensor(imageio.imread(get_image_url(2, 11, severity)))
    I_d = image_preprocess(I)
    e_gt = error_map(I_r, I_d, 0.2)
    # Expand to three channels and normalize with bounds (0, 1) for display.
    e_gt = tf.image.grayscale_to_rgb(e_gt)
    e_gt = image_normalization(e_gt, 0, 1)
    results.append(e_gt)

show_images(results, cmap='gray')

## Reliability Map Prediction

According to the authors, the model is likely to fail to predict the objective error map of
homogeneous regions without information about the pristine image. Thus, they propose a 
reliability function. The assumption is that blurry regions have lower reliability than textured 
regions.

\begin{align*}
\mathbf{r} = \frac{2}{1 + exp(-\alpha|\hat{I}_d|)} - 1
\end{align*}

where α controls the saturation property of the reliability map. To assign sufficiently
large values to pixels with small values, the positive part of a sigmoid is used.

In [None]:
def reliability_map(distorted: tf.Tensor, alpha: float) -> tf.Tensor:
    """Compute the reliability map ``2 / (1 + exp(-alpha * |distorted|)) - 1``.

    :param distorted: preprocessed distorted image (float32).
    :param alpha: factor applied to the absolute pixel values before the exponential.
    :return: reliability map with the same shape as ``distorted``.
    :raises AssertionError: if ``distorted`` is not float32.
    """
    assert distorted.dtype == tf.float32, 'The Tensor must by of dtype tf.float32'
    magnitude = tf.abs(distorted)
    denominator = 1 + tf.exp(- alpha * magnitude)
    return 2 / denominator - 1

In addition, to prevent the reliability map from directly affecting the predicted score,
it is divided by its average

\begin{align*}
\mathbf{\hat{r}} = \frac{1}{\frac{1}{H_rW_r}\sum_{(i,j)}\mathbf{r}(i,j)}\mathbf{r}
\end{align*}

In [None]:
def average_reliability_map(distorted: tf.Tensor, alpha: float) -> tf.Tensor:
    """Return the reliability map divided by its own mean.

    :param distorted: preprocessed distorted image (float32).
    :param alpha: saturation factor forwarded to ``reliability_map``.
    :return: reliability map normalized by its mean over all elements.
    """
    r = reliability_map(distorted, alpha)
    # An all-zero input gives an all-zero map whose mean is 0; a plain division
    # would then yield NaN everywhere. divide_no_nan returns 0 in that case.
    return tf.math.divide_no_nan(r, tf.reduce_mean(r))

In [None]:
# NOTE(review): `get_image_url` is neither imported nor defined in any cell of this
# notebook, so this cell fails on a fresh kernel — confirm where it should come from.
results = []
# Build one mean-normalized reliability map (alpha = 1) per severity level.
for severity in (1, 3, 5):
    I = tf.convert_to_tensor(imageio.imread(get_image_url(2, 11, severity)))
    I_d = image_preprocess(I)
    r = average_reliability_map(I_d, 1)
    # Expand to three channels and normalize with bounds (0, 1) for display.
    r = tf.image.grayscale_to_rgb(r)
    results.append(image_normalization(r, 0, 1))

show_images(results, cmap='gray')

## Loss function
The loss function is the mean square error of the product between the reliability map and the
error. The error is the difference between the predicted error map and the ground-truth error map.

\begin{align*}
\mathcal{L}_1(\hat{I}_d; \theta_f, \theta_g) = \left\| \left( g(f(\hat{I}_d, \theta_f), \theta_g) - \mathbf{e}_{gt} \right) \odot \mathbf{\hat{r}} \right\|^2_2
\end{align*}

## Read Files
We don't want to mix reference images between train and test because we want to test with completely unseen samples.

In [None]:
# Fixed seed so the shuffle (and therefore the split) is reproducible.
random.seed(1)

# Reference image ids 1..24, in random order.
idx = list(range(1, 25))
random.shuffle(idx)

# Split at a single boundary so that no reference image ends up in both sets
# (slicing with [0:22] and [21:] put idx[21] in train AND test). Three test
# references are kept so the later cells that index test_uris stay in range.
n_train = 21
train_idx = idx[:n_train]
test_idx = idx[n_train:]

In [None]:
def get_paths(idxs: list, base_uri: str,
              distortions=range(1, 24), severities=range(1, 5)) -> list:
    """List ``(reference id, image url)`` pairs for every requested combination.

    :param idxs: reference image ids to include.
    :param base_uri: dataset root forwarded to ``get_image_url``.
    :param distortions: distortion ids to enumerate (default 1..23, as before).
    :param severities: severity levels to enumerate (default 1..4, as before).
    :return: list of ``(idx, url)`` tuples ordered by idx, distortion, severity.

    NOTE(review): other cells in this notebook request severity 5, so the
    defaults may be leaving out the last level (and possibly the last
    distortion) — confirm whether that is intentional.
    """
    # Materialize so one-shot iterables survive the nested loops.
    distortions = tuple(distortions)
    severities = tuple(severities)
    return [
        (idx, get_image_url(idx, distortion, severity, base_uri))
        for idx in idxs
        for distortion in distortions
        for severity in severities
    ]

In [None]:
# Dataset root. Set the TID2013_DIR environment variable to point at your own
# copy; the fallback is the path this notebook was originally written against.
base_uri = os.environ.get('TID2013_DIR', '/Users/ricardoocampo/Data/tid2013')
train_uris = get_paths(train_idx, base_uri)
test_uris = get_paths(test_idx, base_uri)

In [None]:
def load_and_preproces_image(uri):
    """Read the image at ``uri`` and return its ``image_preprocess`` output."""
    return image_preprocess(read_image(uri))

In [None]:
# Load and preprocess every training image, in the order of train_uris.
train_images = list(map(
    load_and_preproces_image,
    (filepath for _, filepath in train_uris)))

In [None]:
# Stack the per-image tensors along a new leading axis.
train = tf.stack(values=train_images, axis=0)

In [None]:
def calculate_y(idx, train_uris, train, base_uri):
    """Build the resized ground-truth error map for training sample ``idx``.

    The reference image is looked up through the reference id stored in
    ``train_uris[idx][0]``, loaded and preprocessed, and compared against
    ``train[idx]`` with ``error_map``. The result is resized to (96, 128).
    """
    reference_id = train_uris[idx][0]
    reference = load_and_preproces_image(get_image_url(reference_id, None, 0, base_uri))
    target_size = (384 // 4, 512 // 4)
    return tf.image.resize(error_map(reference, train[idx]), target_size)

In [None]:
def calculate_r(distorted):
    """Return the mean-normalized reliability map (alpha = 0.2) resized to (96, 128)."""
    target_size = (384 // 4, 512 // 4)
    return tf.image.resize(average_reliability_map(distorted, 0.2), target_size)

In [None]:
# One ground-truth error map per training sample.
train_y = list(
    calculate_y(sample, train_uris, train, base_uri)
    for sample in range(len(train)))

In [None]:
# One reliability map per training sample.
r = list(map(calculate_r, train))

In [None]:
# Stack the per-sample error maps along a new leading axis.
train_y = tf.stack(values=train_y, axis=0)

In [None]:
# Stack the per-sample reliability maps along a new leading axis.
train_r = tf.stack(values=r, axis=0)

In [None]:
# NOTE(review): `input` shadows the Python builtin, and `Conv2D` is the class
# imported from tensorflow_core.python.layers (not tf.keras.layers) — confirm
# both are intended.
input = tf.keras.Input(shape=(None, None, 1), batch_size=50, name='original_image')

# (layer name, filters, strides) for the eight 3x3 ReLU convolutions; the two
# stride-2 layers halve the resolution, hence the 1/4-size training targets.
conv_specs = (
    ('Conv1', 48, (1, 1)),
    ('Conv2', 48, (2, 2)),
    ('Conv3', 64, (1, 1)),
    ('Conv4', 64, (2, 2)),
    ('Conv5', 64, (1, 1)),
    ('Conv6', 64, (1, 1)),
    ('Conv7', 128, (1, 1)),
    ('Conv8', 128, (1, 1)),
)
f = input
for layer_name, filters, strides in conv_specs:
    f = Conv2D(filters, (3, 3), name=layer_name, activation='relu',
               padding='same', strides=strides)(f)

# 1x1 linear convolution that maps the features to a single-channel error map.
g = Conv2D(1, (1, 1), name='Conv9', padding='same', activation='linear')(f)
objective_error_map = tf.keras.Model(input, g, name='objective_error_map')

In [None]:
def diqa_loss_1(weights):
    """Build the reliability-weighted loss of the first training stage.

    The returned function computes ``mean(((y_true - y_pred) * weights) ** 2)``,
    i.e. the error is multiplied by the reliability map *before* squaring, as
    in the loss formula stated in this notebook.

    :param weights: reliability map broadcastable to the error map.
    :return: ``loss(y_true, y_pred)`` callable returning a scalar tensor.
    """
    def loss(y_true, y_pred):
        # Weight inside the square: ||(pred - e_gt) * r||^2, not (pred - e_gt)^2 * r.
        return tf.reduce_mean(tf.square((y_true - y_pred) * weights))
    return loss

In [None]:
# Nadam optimizer with a learning rate of 2e-4.
optimizer = tf.optimizers.Nadam(learning_rate=2 * 10 ** -4)
# NOTE(review): the model is compiled with a plain mean squared error; the
# reliability-weighted `diqa_loss_1` defined in this notebook is not used here.
objective_error_map.compile(
    optimizer=optimizer,
    loss=tf.losses.MeanSquaredError(),
    metrics=[tf.metrics.MeanSquaredError()])

In [None]:
# Print the layer-by-layer summary of the error-map model.
objective_error_map.summary()

In [None]:
# Train the error-map model for one epoch, holding out 20% of the samples for validation.
# NOTE(review): `train_r` (the reliability maps) is computed earlier but not passed in here.
# NOTE(review): Keras presumably takes the validation slice from the end of `train`,
# which is ordered by reference image — confirm this is the intended hold-out.
history = objective_error_map.fit(train, train_y,
                    batch_size=50,
                    epochs=1,
                    validation_split=0.2)

In [None]:
# Compare, for one test image: the preprocessed input, the predicted error map
# and the ground-truth error map.
images = []
# NOTE(review): `idx` is rebound from the shuffled id list to an int here, and
# 230 is an arbitrary position in test_uris.
idx = 230
test_x = load_and_preproces_image(test_uris[idx][1])
images.append(image_normalization(tf.squeeze(test_x), 0, 1))
# Add a leading batch axis before predicting.
x = objective_error_map.predict(test_x[tf.newaxis, :, :, :])
images.append(image_normalization(tf.squeeze(x), 0, 1))
# NOTE(review): `get_image_url` is not imported or defined in this notebook — confirm its origin.
reference = load_and_preproces_image(get_image_url(test_uris[idx][0], None, 0, base_uri))
e_gt = error_map(reference, test_x)
images.append(image_normalization(tf.squeeze(e_gt), 0, 1))
show_images(images, cmap='gray')

In [None]:
# Read the MOS table stored under the dataset root. The `with` block closes the
# file handle, and the path is derived from `base_uri` instead of being repeated.
with open(f'{base_uri}/mos_with_names.txt', 'r') as mos_file:
    mos = mos_file.readlines()

In [None]:
# Each non-blank line is "<score> <file name>". Splitting on any whitespace
# also strips the line ending (including '\r\n'), and blank lines are skipped
# instead of breaking the two-value unpacking below.
mos = [line.split() for line in mos if line.strip()]
# Map lower-cased file name -> score (still a string at this point).
mos = {name.lower(): score for score, name in mos}

In [None]:
def get_file_name(x):
    """Return the part of ``x`` after the last ``/`` (all of ``x`` if there is none)."""
    _, _, file_name = x.rpartition('/')
    return file_name

In [None]:
# MOS score for every training image, in the order of train_uris. The keys of
# `mos` are lower-cased, so the file name is lower-cased before the lookup.
mos_arr = [float(mos[get_file_name(train_uri).lower()]) for _, train_uri in train_uris]

In [None]:
# Regression targets for the second stage: MOS scores as a float32 tensor.
mos_y = tf.convert_to_tensor(mos_arr, dtype=tf.float32)

In [None]:
# Second-stage model: pool the Conv8 features (`f`) and regress a single score.
# `f` and `input` are the same tensors used by `objective_error_map`, so both
# models share the convolutional layers.
# NOTE(review): `Dense` is the class imported from tensorflow_core.python.layers
# (not tf.keras.layers) — confirm that is intended.
v = GlobalAveragePooling2D(data_format='channels_last')(f)
h = Dense(128, activation='relu')(v)
h = Dense(128, activation='relu')(h)
h = Dense(1)(h)
subjective_error = tf.keras.Model(input, h, name='subjective_error')

# Same optimizer settings and loss as the first stage.
optimizer = tf.optimizers.Nadam(learning_rate=2 * 10 ** -4)
subjective_error.compile(
    optimizer=optimizer,
    loss=tf.losses.MeanSquaredError(),
    metrics=[tf.metrics.MeanSquaredError()])

In [None]:
# Print the layer-by-layer summary of the score-regression model.
subjective_error.summary()

In [None]:
# Train the score-regression model for one epoch against the MOS targets,
# holding out 20% of the samples for validation.
# Note: this rebinds `history`, discarding the first stage's training history.
history = subjective_error.fit(train, mos_y,
                    batch_size=50,
                    epochs=1,
                    validation_split=0.2)

In [None]:
# Predict the score of one test image and compare it with its MOS.
# NOTE(review): `images` is filled but never displayed in this cell.
images = []
idx = 170
test_x = load_and_preproces_image(test_uris[idx][1])
images.append(image_normalization(tf.squeeze(test_x), 0, 1))
# Add a leading batch axis, then take the single scalar prediction.
prediction = subjective_error.predict(test_x[tf.newaxis, :, :, :])[0][0]
# The keys of `mos` are lower-cased, so lower-case the file name before the lookup.
target = float(mos[get_file_name(test_uris[idx][1]).lower()])

print(f'the predicted value is: {prediction} and target is: {target}')