In [None]:
import tensorflow as tf
from sklearn import metrics
import numpy as np

In [5]:
# EarlyStopping based on F1 score
class EarlyStoppingByF1(tf.keras.callbacks.Callback):
    """Stop training once the monitored score stops improving.

    An epoch counts as an improvement only when the monitored value is
    strictly greater than ``bestScore + delta``. After ``patience``
    consecutive epochs without such an improvement, ``stop_training`` is set
    to ``True`` on the attached model.

    Args:
        patience: Number of consecutive non-improving epochs tolerated before
            training is stopped.
        delta: Margin the score must exceed the best value by in order to be
            treated as an improvement.
        monitor: Key looked up in the ``logs`` dict at the end of each epoch.
    """

    def __init__(self, patience=5, delta=0.01, monitor='val_f1_score'):
        super().__init__()
        self.patience = patience
        self.delta = delta
        self.monitor = monitor
        # Best score seen so far; -inf guarantees the first reported score
        # is recorded as an improvement.
        self.bestScore = -np.inf
        # Consecutive epochs without an improvement.
        self.wait = 0

    def on_epoch_end(self, epoch, logs=None):
        """Update the patience counter and request a stop when it runs out.

        Args:
            epoch: Zero-based epoch index (printed as ``epoch + 1``).
            logs: Metric dict for the epoch. When it is ``None`` or does not
                contain ``monitor``, the call is a no-op.
        """
        # `logs` defaults to None, so guard before calling `.get` on it.
        if logs is None:
            return

        currentScore = logs.get(self.monitor)
        if currentScore is None:
            return

        if currentScore > self.bestScore + self.delta:
            self.bestScore = currentScore
            self.wait = 0
        else:
            self.wait += 1

        if self.wait >= self.patience:
            print(f'\nEarly stopping triggered at epoch {epoch + 1} due to lack of improvement in {self.monitor}')
            self.model.stop_training = True

In [7]:
# Precision, Recall, F1-Score Calculation Callback
class SklearnMetricsCallback(tf.keras.callbacks.Callback):
    """Compute macro precision, recall and F1 with scikit-learn each epoch.

    At the end of every epoch one batch is drawn from ``validationData`` with
    ``next()``, the model predicts on it, and the three scores are printed and
    written into ``logs`` under ``val_precision``, ``val_recall`` and
    ``val_f1_score``.

    NOTE(review): the scores describe only the single batch drawn that epoch,
    not the whole validation set.

    Args:
        validationData: Iterator yielding ``(inputs, labels)`` pairs. Labels
            are assumed to be integer class ids, since they are compared
            directly against ``argmax`` predictions - confirm with the caller.
        earlyStoppingMonitor: Optional callback whose ``on_epoch_end`` is
            invoked with the updated ``logs`` after the scores are stored.
            NOTE(review): if the same instance is also registered with
            ``fit()``, its ``on_epoch_end`` would run twice per epoch.
        batchSize: Batch size forwarded to ``model.predict``.
    """

    def __init__(self, validationData=None, earlyStoppingMonitor=None, batchSize=32):
        super().__init__()
        self.validationData = validationData
        self.earlyStoppingMonitor = earlyStoppingMonitor
        self.batchSize = batchSize

    def on_epoch_end(self, epoch, logs=None):
        """Score one validation batch, record the metrics and delegate.

        Args:
            epoch: Zero-based epoch index (printed as ``epoch + 1``).
            logs: Metric dict for the epoch; the three scores are added to it.
        """
        # Nothing to evaluate without a validation source.
        if not self.validationData:
            return

        # `logs` defaults to None; use a fresh dict so the item assignments
        # below cannot fail when the method is called without one.
        if logs is None:
            logs = {}

        # Pull a single batch from the validation iterator.
        valData, valLabels = next(self.validationData)

        # verbose=0 keeps predict() from printing its own progress bar in the
        # middle of the training output.
        yPred = self.model.predict(valData, batch_size=self.batchSize, verbose=0)
        yPredClasses = np.argmax(yPred, axis=1)

        # Labels are used as-is (expected to be integer encoded).
        yTrue = valLabels

        precision = metrics.precision_score(yTrue, yPredClasses, average='macro')
        recall = metrics.recall_score(yTrue, yPredClasses, average='macro')
        f1 = metrics.f1_score(yTrue, yPredClasses, average='macro')

        print(f'\nEpoch {epoch + 1} Metrics: Precision = {precision:.4f}, Recall = {recall:.4f}, F1-Score = {f1:.4f}')

        # Expose the scores to any callback that reads `logs` afterwards.
        logs['val_precision'] = precision
        logs['val_recall'] = recall
        logs['val_f1_score'] = f1

        monitor = self.earlyStoppingMonitor
        if monitor is not None:
            # A callback that was not registered with fit() never received a
            # model, so it could not set `stop_training`; attach ours first.
            if getattr(monitor, 'model', None) is None and hasattr(monitor, 'set_model'):
                monitor.set_model(self.model)
            monitor.on_epoch_end(epoch, logs)
