In [None]:
#| default_exp datasets.kadik10k

# KADIK10K

> Building a `tf.data.Dataset` for Kadik10k.

In [None]:
#| hide
# Set CUDA_VISIBLE_DEVICES to "-1" before TensorFlow is imported in the next cell,
# so that no CUDA GPU is visible to it while running this notebook.
import os; os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [None]:
#| export
from pathlib import Path
from typing import List

import pandas as pd
import tensorflow as tf
import cv2

After setting up the path to the directory and loading the corresponding `.csv` file, we need to create a generator that iterates over the dataframe, loads the images, and yields a 3-tuple: `(Reference Image, Distorted Image, DMOS)`. We can then pass that generator to `tf.data.Dataset.from_generator()` to build the `Dataset` object:

In [None]:
#| export
class KADIK10K():
    """Builder for the KADIK10K dataset.

    Reads `dmos.csv` from the dataset root, optionally drops rows by image,
    distortion or intensity ID, and exposes the result both as a dataframe
    (`data`) and as a `tf.data.Dataset` (`dataset`) that loads images lazily
    from the `images` sub-directory.
    """

    def __init__(self,
                 path, # Path to the root directory of the dataset.
                 exclude_imgs: List[int] = None, # Image ID's to exclude.
                 exclude_dist: List[int] = None, # Distortion ID's to exclude.
                 exclude_ints: List[int] = None, # Distortion Intensities ID's to exclude.
                 ):
        # `Path()` accepts both `str` and `Path`, so no type check is needed.
        self.path_root = Path(path)
        self.path_csv = self.path_root/"dmos.csv"
        self.path_images = self.path_root/"images"
        self.data = self.load_data(self.path_csv, exclude_imgs, exclude_dist, exclude_ints)

    @property
    def dataset(self):
        """tf.data.Dataset object built from the KADIK10K dataset.

        Each element is a `(reference, distorted, dmos)` tuple: two
        `(384, 512, 3)` float32 images and a scalar float32 score.
        """
        return tf.data.Dataset.from_generator(
            self.data_gen,
            output_signature=(
                tf.TensorSpec(shape=(384, 512, 3), dtype=tf.float32),
                tf.TensorSpec(shape=(384, 512, 3), dtype=tf.float32),
                tf.TensorSpec(shape=(), dtype=tf.float32),
            ),
        )

    def _load_image(self,
                    name, # File name of the image, relative to `self.path_images`.
                    ):
        """Read an image as RGB and scale its values by 1/255.

        Raises `FileNotFoundError` when OpenCV cannot read the file, because
        `cv2.imread` signals failure by returning `None` rather than raising.
        """
        path = self.path_images/name
        img = cv2.imread(str(path))
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV loads images as BGR; convert to RGB before normalising.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)/255.0

    def data_gen(self):
        """Dataset generator to build the tf.data.Dataset.

        Yields one `(reference image, distorted image, dmos)` tuple per row
        of `self.data`.
        """
        for row in self.data.itertuples(index=False):
            dist = self._load_image(row.dist_img)
            ref = self._load_image(row.ref_img)
            yield ref, dist, row.dmos

    def load_data(self,
                  path, # Path to the `.csv` file with the scores.
                  exclude_imgs, # Image ID's to exclude (or `None`).
                  exclude_dist, # Distortion ID's to exclude (or `None`).
                  exclude_ints, # Distortion Intensities ID's to exclude (or `None`).
                  ):
        """Load the scores table and drop the excluded rows.

        The image, distortion and intensity IDs are parsed from the `dist_img`
        file name, which is assumed to look like `I<image>_<distortion>_<intensity>.png`
        (e.g. `I01_02_03.png` -> image 1, distortion 2, intensity 3). Rows whose
        file name does not match that pattern are never excluded. The returned
        dataframe keeps the columns and index of the `.csv` file.
        """
        data = pd.read_csv(path)
        exclusions = (exclude_imgs, exclude_dist, exclude_ints)
        if all(excluded is None for excluded in exclusions):
            return data

        # One column per ID, in the same order as `exclusions`.
        ids = data["dist_img"].astype(str).str.extract(r"^I(\d+)_(\d+)_(\d+)")
        keep = pd.Series(True, index=data.index)
        for col, excluded in zip(ids.columns, exclusions):
            if excluded is None:
                continue
            # "01" -> 1 so the IDs compare equal to the integer ID's passed in.
            values = pd.to_numeric(ids[col], errors="coerce")
            keep &= ~values.isin(list(excluded))
        return data[keep]


In [None]:
# NOTE(review): machine-specific absolute path — point it at your local copy of the dataset to run the cells below.
l = KADIK10K(path = Path("/media/disk/databases/BBDD_video_image/Image_Quality/KADIK10K"))

In [None]:
# Preview the first rows of the loaded scores table.
l.data.head()

Unnamed: 0,dist_img,ref_img,dmos,var
0,I01_01_01.png,I01.png,4.57,0.496
1,I01_01_02.png,I01.png,4.33,0.869
2,I01_01_03.png,I01.png,2.67,0.789
3,I01_01_04.png,I01.png,1.67,0.596
4,I01_01_05.png,I01.png,1.1,0.3


In [None]:
# Pull a single sample. `next` raises StopIteration on an empty dataset instead of
# leaving `a`, `b`, `c` undefined (or stale from a previous run) for the assertion.
a, b, c = next(iter(l.dataset))
assert a.shape == b.shape, "Reference and distorted images must have the same shape."