In [None]:
#| default_exp callbacks.evaluation

# Evaluation

> A simple callback that evaluates the model being trained on a given dataset at configurable points during training.

In [None]:
#| export
from typing import Dict

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.callbacks import Callback

We build a class that stores a given dataset and, when needed, calls `evaluate` on the model being trained to obtain the evaluation metrics. It should be flexible enough to set the evaluation frequency as a number of epochs, a number of batches, or both. We could write separate callbacks for epochs and for batches, but a single callback can handle both cases.

In [None]:
#| export
class EvaluateDataset(Callback):
    """Evaluates a given `tf.data.Dataset` at different training times.

    Evaluation can be triggered every `freq_batches` training batches, every
    `freq_epochs` epochs, or both. Each evaluation result is stored, and the
    accumulated history is exposed through `results_batches` / `results_epochs`
    as dicts mapping each metric name (suffixed with `append`) to the list of
    values it took at each evaluation.
    """

    def __init__(self,
                 dataset, # Dataset to be evaluated.
                 subset=None, # Slice end applied to `model.layers` to select the layers to evaluate. `None` means the whole model.
                 freq_epochs=None, # Number of epochs to wait between evaluations. `None` means not evaluating at an epoch interval.
                 freq_batches=None, # Number of batches to wait between evaluations. `None` means not evaluating at a batch interval.
                 append="", # Text to append to the metrics' names as an identifier.
                 ):
        super().__init__()
        # Fail fast: a zero/negative frequency would otherwise surface as a
        # ZeroDivisionError (or never trigger) deep inside a training hook.
        for arg_name, freq in (("freq_epochs", freq_epochs), ("freq_batches", freq_batches)):
            if freq is not None and freq <= 0:
                raise ValueError(f"`{arg_name}` must be a positive integer or None, got {freq!r}.")
        self.dataset = dataset if isinstance(dataset, tf.data.Dataset) else self._convert_to_dataset(dataset)
        self.subset = subset
        self.freq_epochs = freq_epochs
        self.freq_batches = freq_batches
        self.append = append
        # Counters keep running across `fit` calls; they decide when to evaluate.
        self.batches_seen, self.epochs_seen = 0, 0
        self._results_batches, self._results_epochs = [], []
        # Model actually evaluated (full model or a layer subset). Deliberately NOT
        # named `_model`: Keras 3's `Callback` stores the training model in `_model`
        # behind its `model` property, so reusing that name would overwrite it.
        self._eval_model = None

    def _convert_to_dataset(self,
                            dataset, # Dataset to be converted.
                            ):
        """Hook for converting non-`tf.data.Dataset` inputs.

        Currently a pass-through: `dataset` is returned unchanged and later handed
        to `Model.evaluate` as-is.
        """
        return dataset

    def _build_eval_model(self):
        """Returns the training model itself, or a `Sequential` of its first `subset` layers."""
        if self.subset is None: return self.model
        # NOTE(review): this sub-model is not compiled, and Keras requires `compile()`
        # before `evaluate()`; confirm how the subset model is meant to be compiled.
        return Sequential(self.model.layers[:self.subset])

    def evaluate(self,
                 ) -> Dict: # Dictionary of evaluation results.
        """Calls the `.evaluate()` method of the evaluated model on the stored `dataset`.

        Metric names are suffixed with `append`. If `on_train_begin` has not run yet,
        the evaluated model is resolved on first use.
        """
        if self._eval_model is None: self._eval_model = self._build_eval_model()
        results = self._eval_model.evaluate(self.dataset, verbose=0, return_dict=True)
        return {f"{name}{self.append}": value for name, value in results.items()}

    def on_train_begin(self,
                       logs=None, # Training logs (unused).
                       ):
        """Selects the model to evaluate for this training run."""
        # Read from `self.model` (set by Keras), not from the not-yet-assigned
        # evaluation model, so `subset` works.
        self._eval_model = self._build_eval_model()

    def on_train_batch_end(self,
                           batch, # Batch number in an epoch.
                           logs=None, # Training logs.
                           ):
        """Evaluates after every `freq_batches`-th training batch, starting with the first one."""
        if self.freq_batches is None: return
        # `batches_seen` counts across epochs, unlike `batch`, which is per-epoch.
        if self.batches_seen % self.freq_batches == 0:
            self._results_batches.append(self.evaluate())
        self.batches_seen += 1

    def on_epoch_end(self,
                     epoch, # Epoch number.
                     logs=None, # Training logs.
                     ):
        """Evaluates after every `freq_epochs`-th epoch, starting with the first one."""
        if self.freq_epochs is None: return
        if self.epochs_seen % self.freq_epochs == 0:
            self._results_epochs.append(self.evaluate())
        self.epochs_seen += 1

    @staticmethod
    def _unpack_list_dicts(list_of_dicts):
        """Unpacks a list of dicts sharing keys into a dict with lists as values."""
        res = {}
        for result in list_of_dicts:
            for metric, value in result.items():
                res.setdefault(metric, []).append(value)
        return res

    @property
    def results_batches(self):
        """Batch-interval results as `{metric_name: [value per evaluation]}`; raises `ValueError` if empty."""
        if len(self._results_batches) == 0: raise ValueError("No values stored yet.")
        return self._unpack_list_dicts(self._results_batches)

    @property
    def results_epochs(self):
        """Epoch-interval results as `{metric_name: [value per evaluation]}`; raises `ValueError` if empty."""
        if len(self._results_epochs) == 0: raise ValueError("No values stored yet.")
        return self._unpack_list_dicts(self._results_epochs)

In [None]:
#| eval: false
from iqadatasets.datasets.tid2013 import TID2013
from iqadatasets.datasets.tid2008 import TID2008

ModuleNotFoundError: No module named 'iqadatasets'

In [None]:
#| eval: false
# Both TID databases live under the same root; keep it in one place.
TID_ROOT = "/media/disk/databases/BBDD_video_image/Image_Quality/TID"
tid13 = TID2013(f"{TID_ROOT}/TID2013")
tid08 = TID2008(f"{TID_ROOT}/TID2008")  # TID2008 loader (imported above) for the TID2008 database.

In [None]:
#| eval: false
from perceptnet.networks import PerceptNet
from perceptnet.pearson_loss import PearsonCorrelation

In [None]:
#| eval: false
model = PerceptNet(
    kernel_initializer="ones",
    gdn_kernel_size=1,
    learnable_undersampling=False,
)
# Train with the Pearson-correlation loss using Adam.
model.compile(optimizer="adam", loss=PearsonCorrelation())

In [None]:
#| eval: false
eval_ds = tid13.dataset.batch(16).take(4)    # 4 batches of 16 from `tid13`, evaluated by the callback.
train_ds = tid08.dataset.batch(16).take(10)  # 10 batches of 16 from `tid08`, used for training.
cb_eval = EvaluateDataset(eval_ds, freq_batches=5, append="_TID2013")  # Evaluate every 5 training batches.
history = model.fit(train_ds, epochs=2, callbacks=[cb_eval])

In [None]:
#| eval: false
batch_results = cb_eval.results_batches  # {metric name + "_TID2013": [value at each batch-interval evaluation]}
batch_results

{'loss_TID2013': [-0.9020472764968872,
  -0.9032243490219116,
  -0.9034193754196167,
  -0.9044590592384338]}