# Checking that the new datasets are equivalent to the ones we have been using before

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
from glob import glob
from collections import namedtuple

In [2]:
import cv2
import tensorflow as tf
from iqadatasets.datasets.tid2013 import TID2013

## Old code

In [3]:
def filter_data_2008(img_ids='all', 
                dist_ids='all',
                dist_ints='all',
                exclude_img_ids=None,
                exclude_dist_ids=None,
                exclude_dist_ints=None):
    """
    Collects the (reference, distorted, metric) triplets found under
    `path_2008`, optionally restricted to a subset of images, distortions
    and distortion intensities.

    Every ID is compared as a string exactly as it appears in the file name
    (lower-cased): a distorted file whose name, once the first character and
    the extension are dropped, reads `<img_id>_<dist_id>_<dist_int>`.

    Parameters
    ----------
    img_ids: list[string] or 'all'
        Image IDs to keep. 'all' keeps every image.
    dist_ids: list[string] or 'all'
        Distortion IDs to keep. 'all' keeps every distortion.
    dist_ints: list[string] or 'all'
        Distortion intensities to keep. 'all' keeps every intensity.
        As of now, the intensities go from 1 to 5.
    exclude_img_ids: list[string]
        Image IDs to discard.
    exclude_dist_ids: list[string]
        Distortion IDs to discard.
    exclude_dist_ints: list[string]
        Distortion intensities to discard.
        As of now, the intensities go from 1 to 5.

    An exclude list is applied on top of the corresponding include list, so
    a value present in both is discarded.

    Returns
    -------
    data: list[ImagePair]
        List of ImagePair objects containing the paths to the image pairs and 
        their corresponding metric.

    Raises
    ------
    ValueError
        If a matched distorted file name does not split into exactly three
        "_"-separated fields.
    KeyError
        If a matched distorted file has no entry in `name_metric_2008`.
    """
    ## It's not good practice to default a parameter as an empty list.
    ## The good practice is to default it as a None and then create the empty list.
    exclude_img_ids = [] if exclude_img_ids is None else exclude_img_ids
    exclude_dist_ids = [] if exclude_dist_ids is None else exclude_dist_ids
    exclude_dist_ints = [] if exclude_dist_ints is None else exclude_dist_ints

    def _keep(value, include, exclude):
        # A value survives when it is selected (or everything is) and not excluded.
        if include != 'all' and value not in include:
            return False
        return value not in exclude

    def _stem(path):
        # File name up to its first dot; basename works with any path separator.
        return os.path.basename(path).split(".")[0]

    data = []
    for img_path in glob(os.path.join(path_2008, 'reference_images', '*.BMP')):
        ref_stem = _stem(img_path)
        # The image ID is the reference name without its leading letter.
        if not _keep(ref_stem.lower()[1:], img_ids, exclude_img_ids):
            continue
        for dist_img_path in glob(os.path.join(path_2008, 'distorted_images', f'{ref_stem}*')):
            dist_stem = _stem(dist_img_path)
            dist_id, dist_int = dist_stem.lower()[1:].split("_")[1:]

            if not _keep(dist_id, dist_ids, exclude_dist_ids):
                continue
            if not _keep(dist_int, dist_ints, exclude_dist_ints):
                continue

            # Keys of name_metric_2008 are stored upper-cased, so normalise the lookup too.
            data.append(ImagePair(img_path, dist_img_path, name_metric_2008[dist_stem.upper()]))
    return data

def filter_data_2013(img_ids='all', 
                dist_ids='all',
                dist_ints='all',
                exclude_img_ids=None,
                exclude_dist_ids=None,
                exclude_dist_ints=None):
    """
    Collects the (reference, distorted, metric) triplets found under
    `path_2013`, optionally restricted to a subset of images, distortions
    and distortion intensities.

    Every ID is compared as a string exactly as it appears in the file name
    (lower-cased): a distorted file whose name, once the first character and
    the extension are dropped, reads `<img_id>_<dist_id>_<dist_int>`.
    Distorted files are searched with the lower-cased reference name as prefix.

    Parameters
    ----------
    img_ids: list[string] or 'all'
        Image IDs to keep. 'all' keeps every image.
    dist_ids: list[string] or 'all'
        Distortion IDs to keep. 'all' keeps every distortion.
    dist_ints: list[string] or 'all'
        Distortion intensities to keep. 'all' keeps every intensity.
        As of now, the intensities go from 1 to 5.
    exclude_img_ids: list[string]
        Image IDs to discard.
    exclude_dist_ids: list[string]
        Distortion IDs to discard.
    exclude_dist_ints: list[string]
        Distortion intensities to discard.
        As of now, the intensities go from 1 to 5.

    An exclude list is applied on top of the corresponding include list, so
    a value present in both is discarded.

    Returns
    -------
    data: list[ImagePair]
        List of ImagePair objects containing the paths to the image pairs and 
        their corresponding metric.

    Raises
    ------
    ValueError
        If a matched distorted file name does not split into exactly three
        "_"-separated fields.
    KeyError
        If a matched distorted file has no entry in `name_metric_2013`.
    """
    ## It's not good practice to default a parameter as an empty list.
    ## The good practice is to default it as a None and then create the empty list.
    exclude_img_ids = [] if exclude_img_ids is None else exclude_img_ids
    exclude_dist_ids = [] if exclude_dist_ids is None else exclude_dist_ids
    exclude_dist_ints = [] if exclude_dist_ints is None else exclude_dist_ints

    def _keep(value, include, exclude):
        # A value survives when it is selected (or everything is) and not excluded.
        if include != 'all' and value not in include:
            return False
        return value not in exclude

    def _stem(path):
        # File name up to its first dot; basename works with any path separator.
        return os.path.basename(path).split(".")[0]

    data = []
    for img_path in glob(os.path.join(path_2013, 'reference_images', '*.BMP')):
        ref_stem = _stem(img_path).lower()
        # The image ID is the reference name without its leading letter.
        if not _keep(ref_stem[1:], img_ids, exclude_img_ids):
            continue
        for dist_img_path in glob(os.path.join(path_2013, 'distorted_images', f'{ref_stem}*')):
            dist_stem = _stem(dist_img_path).lower()
            dist_id, dist_int = dist_stem[1:].split("_")[1:]

            if not _keep(dist_id, dist_ids, exclude_dist_ids):
                continue
            if not _keep(dist_int, dist_ints, exclude_dist_ints):
                continue

            # Keys of name_metric_2013 are stored lower-cased, so the lookup uses the lower-cased stem.
            data.append(ImagePair(img_path, dist_img_path, name_metric_2013[dist_stem]))
    return data

In [None]:
# Root folders of the two TID databases read by the old pipeline.
path_2008 = '/media/disk/databases/BBDD_video_image/Image_Quality/TID/TID2008'
path_2013 = '/media/disk/databases/BBDD_video_image/Image_Quality/TID/TID2013'


def _load_name_metric(root, normalise):
    """
    Reads `<root>/mos_with_names.txt` into a {file stem: metric} dictionary.

    Each non-blank line must hold exactly two whitespace-separated fields:
    the metric followed by the file name. The file name is passed through
    `normalise` (e.g. `str.upper`) and cut at its first dot to form the key.

    Raises
    ------
    ValueError
        If a non-blank line does not hold exactly two fields, or if the
        metric cannot be converted to float.
    """
    table = {}
    with open(os.path.join(root, 'mos_with_names.txt')) as f:
        for line in f:
            # split() drops the line ending (\n or \r\n) and tolerates
            # repeated spaces; blank lines are skipped instead of crashing.
            fields = line.split()
            if not fields:
                continue
            metric, file_name = fields
            table[normalise(file_name).split(".")[0]] = float(metric)
    return table


# The 2008 table is keyed by upper-cased stems, the 2013 one by lower-cased stems.
name_metric_2008 = _load_name_metric(path_2008, str.upper)
name_metric_2013 = _load_name_metric(path_2013, str.lower)

ImagePair = namedtuple('ImagePair', 'img_path dist_img_path metric')

# Image '25' is left out of both sets.
train_data = filter_data_2008(img_ids='all',
                            dist_ids='all',
                            dist_ints='all',
                            exclude_img_ids=['25'],
                    )
test_data = filter_data_2013(img_ids='all',
                        dist_ids='all',
                        dist_ints='all',
                        exclude_img_ids=['25'],
                    )

def train_gen():
    """
    Yields one (reference image, distorted image, metric) tuple per entry of
    `train_data`.

    Both images are read with OpenCV, converted from BGR to RGB and divided
    by 255.0.

    Raises
    ------
    OSError
        If OpenCV cannot read one of the image files.
    """
    def _load_rgb(path):
        bgr = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if bgr is None:
            raise OSError(f"cv2.imread could not read image: {path!r}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) / 255.0

    for sample in train_data:
        yield _load_rgb(sample.img_path), _load_rgb(sample.dist_img_path), sample.metric

def test_gen():
    """
    Yields one (reference image, distorted image, metric) tuple per entry of
    `test_data`.

    Both images are read with OpenCV, converted from BGR to RGB and divided
    by 255.0.

    Raises
    ------
    OSError
        If OpenCV cannot read one of the image files.
    """
    def _load_rgb(path):
        bgr = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if bgr is None:
            raise OSError(f"cv2.imread could not read image: {path!r}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) / 255.0

    for sample in test_data:
        yield _load_rgb(sample.img_path), _load_rgb(sample.dist_img_path), sample.metric

# Wrap both generators as tf.data datasets. The element signature is the same
# for the two of them: two float32 images of shape (384, 512, 3) followed by a
# float32 scalar metric. The train dataset is built first, then the test one.
train_dataset, test_dataset = (
    tf.data.Dataset.from_generator(
        sample_source,
        output_signature=(
            tf.TensorSpec(shape=(384, 512, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(384, 512, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.float32),
        ),
    )
    for sample_source in (train_gen, test_gen)
)

## Inspection

In [5]:
# Display both dataset objects so their element shapes and dtypes can be inspected.
train_dataset, test_dataset

(<FlatMapDataset shapes: ((384, 512, 3), (384, 512, 3), ()), types: (tf.float32, tf.float32, tf.float32)>,
 <FlatMapDataset shapes: ((384, 512, 3), (384, 512, 3), ()), types: (tf.float32, tf.float32, tf.float32)>)

In [8]:
# Build the dataset with the iqadatasets TID2013 loader from the same TID2013 folder.
# NOTE(review): exclude_imgs=[25] is presumably the counterpart of
# exclude_img_ids=['25'] used for test_data — confirm against the loader.
l = TID2013("/media/disk/databases/BBDD_video_image/Image_Quality/TID/TID2013", exclude_imgs=[25])

In [11]:
# Both pipelines must yield the same number of samples. Counting explicitly
# (rather than comparing the last enumerate index) also works for an empty
# dataset, where the loop variable would never be bound.
n_old = sum(1 for _ in test_dataset)
n_new = sum(1 for _ in l.dataset)
assert n_old == n_new, f"sample count mismatch: old={n_old}, new={n_new}"

### Check if every element of the new dataset is in the old dataset

In [18]:
from fastprogress.fastprogress import progress_bar

In [None]:
# matches = 0
# for new in progress_bar(l.dataset, total=2880):
#     for old in test_dataset:
#         if (new[0].numpy()==old[0].numpy()).all() & (new[1].numpy()==old[1].numpy()).all() & (new[2]==old[2]): 
#             matches += 1
#             break
#         else: continue