# Transfer Learning

### Import Libraries

In [10]:
import pathlib
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# import keras_tuner as kt
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.utils.class_weight import compute_class_weight


In [11]:
# import os
# import shutil

# def merge_folders(input_root, output_root):
#     # Create the output directory if it doesn't exist
#     if not os.path.exists(output_root):
#         os.makedirs(output_root)
    
#     # Loop through each fold (e.g. FOLD1, FOLD2, etc.)
#     for fold in os.listdir(input_root):
#         fold_path = os.path.join(input_root, fold)
#         fold_path = os.path.join(fold_path, "Test")
#         if not os.path.isdir(fold_path):
#             continue  # Skip files; expect only folders here

#         # For each category folder inside the fold
#         for category in os.listdir(fold_path):
#             category_path = os.path.join(fold_path, category)
#             if not os.path.isdir(category_path):
#                 continue  # Skip non-directory items
            
#             # Create the corresponding category folder in the output if it doesn't exist
#             output_category_path = os.path.join(output_root, category)
#             if not os.path.exists(output_category_path):
#                 os.makedirs(output_category_path)
            
#             # Copy each file from the category folder to the merged output only if it doesn't exist already
#             for file_name in os.listdir(category_path):
#                 file_path = os.path.join(category_path, file_name)
#                 if os.path.isfile(file_path):
#                     destination = os.path.join(output_category_path, file_name)
#                     if os.path.exists(destination):
#                         print(f"Skipped {file_path} because {destination} already exists.")
#                     else:
#                         shutil.copy(file_path, destination)
#                         print(f"Copied {file_path} to {destination}")

# # Example usage:
# # merge_folders("../data/MSLDV2/Augmented Images/Augmented Images/FOLDS_AUG", "../data/merged_MSLD/Train")
# merge_folders("../data/MSLDV2/Original Images/Original Images/FOLDS", "../data/merged_MSLD/Test")


## Train with Skin Cancer Dataset

### Hyperparameters

In [12]:
data_root = pathlib.Path("../data/merged_MSLD/Train")    # points to the folder containing the images that will be used for training

# Global input-pipeline hyperparameters (used when the dataset is loaded).
img_height = 224        # input image height
img_width = 224         # input image width
batch_size = 32         # size of the batch that will be fed to model

# Meaning of the cross-validation / fine-tuning keys used in each config below:
# folds            = number of folds created for cross-validation
# fine_tune_epochs = number of epochs after which we start fine-tuning
# fine_tune_at     = index of the first base-model layer that gets unfrozen

# Configurations that will be used in training: one dict per run.
# NOTE(review): "batch_size" and "image_size" repeat the globals above; keep them in sync.
# NOTE(review): the create_model() defined in this notebook does not read "model_name" or
# "image_size" -- confirm that the architecture is selected somewhere else.
configs = [
    {"model_name": "mobilenet", "learning_rate": 0.001, "batch_size": 32, "image_size" : 224, "optimizer": "adam", "epochs": 50, "save_metrics": True, "folds": 5, "fine_tune": False, "fine_tune_epochs": 25, "fine_tune_at": 150},
    # {"model_name": "mobilenet", "learning_rate": 0.001, "batch_size": 32, "image_size" : 224, "optimizer": "adam", "epochs": 50, "save_metrics": True, "folds": 5, "fine_tune": True, "fine_tune_epochs": 25, "fine_tune_at": 148},
    # {"model_name": "mobilenet", "learning_rate": 0.001, "batch_size": 32, "image_size" : 224, "optimizer": "adam", "epochs": 50, "save_metrics": True, "folds": 5, "fine_tune": False, "fine_tune_epochs": 25, "fine_tune_at": 150},
    # {"model_name": "mobilenet", "learning_rate": 0.001, "batch_size": 32, "image_size" : 224, "optimizer": "adam", "epochs": 50, "save_metrics": True, "folds": 5, "fine_tune": True, "fine_tune_epochs": 25, "fine_tune_at": 148},

    {"model_name": "efficientnet", "learning_rate": 0.001, "batch_size": 32, "image_size" : 224, "optimizer": "adam", "epochs": 50, "save_metrics": True, "folds": 5, "fine_tune": False, "fine_tune_epochs": 25, "fine_tune_at": 150},
    
    {"model_name": "densenet", "learning_rate": 0.001, "batch_size": 32, "image_size" : 224, "optimizer": "adam", "epochs": 50, "save_metrics": True, "folds": 5, "fine_tune": False, "fine_tune_epochs": 25, "fine_tune_at": 150},

    {"model_name": "inceptionv3", "learning_rate": 0.001, "batch_size": 32, "image_size" : 224, "optimizer": "adam", "epochs": 50, "save_metrics": True, "folds": 5, "fine_tune": False, "fine_tune_epochs": 25, "fine_tune_at": 150},

    {"model_name": "resnet50", "learning_rate": 0.001, "batch_size": 32, "image_size" : 224, "optimizer": "adam", "epochs": 50, "save_metrics": True, "folds": 5, "fine_tune": False, "fine_tune_epochs": 25, "fine_tune_at": 150},

    # vgg16 is the only entry with a different fine_tune_at (15 instead of 150).
    {"model_name": "vgg16", "learning_rate": 0.001, "batch_size": 32, "image_size" : 224, "optimizer": "adam", "epochs": 50, "save_metrics": True, "folds": 5, "fine_tune": False, "fine_tune_epochs": 25, "fine_tune_at": 15},

    # {"model_name": "xception", "learning_rate": 0.001, "batch_size": 32, "image_size" : 224, "optimizer": "adam", "epochs": 50, "save_metrics": True, "folds": 5, "fine_tune": False, "fine_tune_epochs": 25, "fine_tune_at": 150},
]

# Define the base path for saving models (created if missing; no error when it already exists)
save_dir = "../saved_models"
os.makedirs(save_dir, exist_ok=True)

## Training 

### Setup

In [13]:
# Load the whole dataset without splitting; the split happens further down.
# A fixed shuffle seed is passed so that the order in which samples are
# materialised below is the same on every run. Without it, the
# random_state=42 train/test split would still select a different hold-out
# set each time the notebook is executed.
dataset = tf.keras.utils.image_dataset_from_directory(
    data_root,                                  # loads images from the data_root directory
    image_size=(img_height, img_width),         # resizes all images to (img_height, img_width) pixels
    batch_size=batch_size,                      # set the batch size
    shuffle=True,                               # shuffle data when loaded
    seed=42                                     # same seed as the split below, for reproducibility
)

# test_ds = tf.keras.utils.image_dataset_from_directory(
#     "../data/merged_MSLD/Test",                                  # loads images from the data_root directory
#     image_size=(img_height, img_width),         # resizes all images to (224, 224) pixels
#     batch_size=batch_size,                      # set the batch size
#     shuffle=False,                              # shufle data when loaded
# )

# test_ds = test_ds.unbatch().batch(batch_size, drop_remainder=True)

class_names = np.array(dataset.class_names)     # get the class names for the data
num_classes = len(class_names)                  # get the number of classes in the dataset

# Materialise the dataset as in-memory (image, label) arrays. This makes it
# easier to perform cross-validation.
# NOTE: despite its name, `image_paths` holds the decoded image arrays taken
# from each batch, not file paths. The name is kept because other cells refer to it.
image_paths, labels = [], []
for image_batch, label_batch in dataset:
    image_paths.extend(image_batch.numpy())
    labels.extend(label_batch.numpy())

image_paths = np.array(image_paths)             # convert to numpy array to facilitate training
labels = np.array(labels)                       # convert to numpy array to facilitate training

# Split the dataset into training/validation (90%) and test (10%) sets,
# stratified so every class keeps its proportion in both parts.
train_val_images, test_images, train_val_labels, test_labels = train_test_split(
    image_paths, labels, test_size=0.10, random_state=42, stratify=labels
)

# train_val_images = np.concatenate((train_val_images, test_images), axis=0)
# train_val_labels = np.concatenate((train_val_labels, test_labels), axis=0)


print(len(train_val_images))
print(len(train_val_labels))

def callbacks_setup(checkpoint_filepath):
    """Build the three Keras callbacks used during training.

    Every callback watches ``val_loss`` and treats lower values as better.

    Args:
        checkpoint_filepath: Path handed to ``ModelCheckpoint`` as the place
            where the best weights are written.

    Returns:
        Tuple ``(early_stopping, model_checkpoint, reduce_lr)``.
    """
    watched_metric = 'val_loss'
    smallest_improvement = 0.0003   # changes below this do not count as progress

    # Halve the learning rate after 4 epochs without improvement, never going below 1e-6.
    lr_scheduler = ReduceLROnPlateau(
        monitor=watched_metric,
        mode='min',
        factor=0.5,
        patience=4,
        min_delta=smallest_improvement,
        min_lr=1e-6,
    )

    # Write weights only (not the full model), and only when val_loss reaches a new minimum.
    checkpointer = ModelCheckpoint(
        filepath=checkpoint_filepath,
        monitor=watched_metric,
        mode='min',
        save_best_only=True,
        save_weights_only=True,
    )

    # Stop after 6 epochs without improvement. The best weights are not restored
    # here; they are the ones saved by the checkpoint callback above.
    stopper = EarlyStopping(
        monitor=watched_metric,
        mode='min',
        patience=6,
        min_delta=smallest_improvement,
        restore_best_weights=False,
    )

    return stopper, checkpointer, lr_scheduler

Found 10570 files belonging to 6 classes.
9513
9513


### Metrics

In [14]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from typing import Tuple, Dict, Any, List
from sklearn.metrics import (
    precision_score,
    classification_report,
    roc_auc_score,
    roc_curve,
    auc,
    recall_score,
    f1_score,
    confusion_matrix,
    ConfusionMatrixDisplay
)
from sklearn.preprocessing import label_binarize

def save_confusion_matrix(true_labels: np.ndarray, predicted_labels: np.ndarray, 
                          class_names: List[str], save_path: str) -> None:
    """
    Plots and saves the confusion matrix for multi-class classification.

    The matrix is always computed over the full label set
    ``0 .. len(class_names) - 1``. Without this, a class that appears in
    neither ``true_labels`` nor ``predicted_labels`` would be dropped from the
    matrix, and the matrix would no longer line up with ``class_names``.

    Args:
        true_labels (np.ndarray): Array of true class labels (integer class indices).
        predicted_labels (np.ndarray): Array of predicted class labels (integer class indices).
        class_names (List[str]): List of class names corresponding to class indices.
        save_path (str): Path to save the confusion matrix plot.
    """
    # Pin the label set so the matrix shape always matches the display labels.
    label_ids = np.arange(len(class_names))
    cm = confusion_matrix(true_labels, predicted_labels, labels=label_ids)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)

    # Draw on an explicit figure/axes pair rather than pyplot's implicit "current" state.
    fig, ax = plt.subplots(figsize=(8, 6))
    disp.plot(cmap=plt.cm.Blues, ax=ax)

    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted label", fontsize=12)
    ax.set_ylabel("True label", fontsize=12)

    # Rotate tick labels on both axes for better readability and alignment
    tick_style = dict(rotation=45, ha='right', rotation_mode='anchor', fontsize=10)
    plt.setp(ax.get_xticklabels(), **tick_style)
    plt.setp(ax.get_yticklabels(), **tick_style)

    # Prevent labels from being cut off
    fig.tight_layout()

    # Save, then close this specific figure so repeated calls do not accumulate open figures
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)

def save_loss_curve(history: Dict[str, Any], save_path: str) -> None:
    """
    Draws the training and validation loss per epoch and writes the plot to disk.

    Args:
        history (Dict[str, Any]): Training history; the 'loss' and 'val_loss'
            entries are plotted.
        save_path (str): Path to save the loss curve plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # One line per series: (history key, legend label, colour).
    series = (
        ('loss', 'Training Loss', 'blue'),
        ('val_loss', 'Validation Loss', 'orange'),
    )
    for key, legend_label, colour in series:
        ax.plot(history[key], label=legend_label, color=colour)

    ax.set_title("Training and Validation Loss Over Epochs")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True)

    fig.savefig(save_path)
    plt.close(fig)

def save_roc_auc(true_labels: np.ndarray, predicted_probs: np.ndarray, class_names: list, save_path: str = None):
    """
    Draws one one-vs-rest ROC curve per class, each labelled with its AUC.

    Args:
        true_labels (np.ndarray): True class labels (integer class indices).
        predicted_probs (np.ndarray): Per-class scores; column ``i`` is used
            as the score for class ``i``.
        class_names (list): List of class names, indexed by class id.
        save_path (str, optional): Where to save the plot. When omitted (or
            empty) the figure is shown instead of saved. Defaults to None.
    """
    plt.figure(figsize=(10, 6))

    for class_idx in range(len(class_names)):
        # Current class is the positive class, every other class is negative.
        is_this_class = true_labels == class_idx
        false_pos_rate, true_pos_rate, _ = roc_curve(is_this_class, predicted_probs[:, class_idx])
        area = auc(false_pos_rate, true_pos_rate)
        plt.plot(false_pos_rate, true_pos_rate,
                 label=f'{class_names[class_idx]} (AUC = {area:.2f})')

    # Diagonal reference line for a random classifier.
    plt.plot([0, 1], [0, 1], 'k--', label='Random Chance')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC AUC Curve')
    plt.legend(loc='lower right')

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path)
    plt.close()

def save_evaluation_metrics(true_labels: np.ndarray, predicted_labels: np.ndarray, 
                            predicted_probs: np.ndarray, save_path: str) -> Dict[str, float]:
    """
    Computes accuracy, macro precision/recall/F1 and macro one-vs-rest ROC AUC,
    saves them as an annotated bar chart, and returns them.

    Args:
        true_labels (np.ndarray): Array of true class labels.
        predicted_labels (np.ndarray): Array of predicted class labels.
        predicted_probs (np.ndarray): Array of per-class scores (shape: [n_samples, n_classes]).
        save_path (str): Path to save the evaluation metrics bar chart.

    Returns:
        Dict[str, float]: Metric name -> value, in the order they are plotted.
    """
    # ROC AUC needs one-hot targets; the class count comes from the score matrix.
    class_count = predicted_probs.shape[1]
    one_hot_truth = label_binarize(true_labels, classes=list(range(class_count)))

    metrics = {
        "Accuracy": np.mean(predicted_labels == true_labels),
        "Precision": precision_score(true_labels, predicted_labels, average='macro'),
        "Sensitivity (Recall)": recall_score(true_labels, predicted_labels, average='macro'),
        "F1-Score": f1_score(true_labels, predicted_labels, average='macro'),
        "ROC AUC": roc_auc_score(one_hot_truth, predicted_probs, multi_class='ovr', average='macro'),
    }

    # One bar per metric, each with its own colour.
    bar_colours = ['darkturquoise', 'sandybrown', 'hotpink', 'limegreen', 'mediumpurple']
    plt.figure(figsize=(10, 6))
    drawn_bars = plt.bar(metrics.keys(), metrics.values(), color=bar_colours)

    # Print the numeric value just above each bar.
    for rect in drawn_bars:
        height = rect.get_height()
        centre_x = rect.get_x() + rect.get_width() / 2
        plt.text(centre_x, height, f'{height:.4f}', ha='center', va='bottom')

    plt.title("Model Evaluation Metrics")
    plt.ylim([0, 1])
    plt.yticks(np.arange(0, 1.1, 0.1))
    plt.ylabel("Score")
    plt.savefig(save_path)
    plt.close()

    return metrics

def save_classification_report(true_labels: np.ndarray, predicted_labels: np.ndarray, 
                               class_names: List[str], save_path: str) -> None:
    """
    Saves the classification report to a text file for multi-class classification.

    The report is always built over the full label set
    ``0 .. len(class_names) - 1``. Without this, a class that occurs in
    neither ``true_labels`` nor ``predicted_labels`` makes the number of
    detected classes differ from ``len(class_names)`` and the report fails.

    Args:
        true_labels (np.ndarray): Array of true class labels (integer class indices).
        predicted_labels (np.ndarray): Array of predicted class labels (integer class indices).
        class_names (List[str]): List of class names, indexed by class id.
        save_path (str): Path to save the classification report.
    """
    # Pin the label set so it always has one entry per name in class_names.
    label_ids = list(range(len(class_names)))
    report = classification_report(
        true_labels,
        predicted_labels,
        labels=label_ids,
        target_names=class_names,
        digits=4,
    )
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(report)

def calculate_metrics(true_labels: np.ndarray, predictions: np.ndarray) -> Tuple[float, float, float, float, float]:
    """
    Derives class predictions from per-class scores and computes the
    multi-class evaluation metrics.

    Args:
        true_labels (np.ndarray): Array of true class labels.
        predictions (np.ndarray): Array of per-class scores (shape: [n_samples, n_classes]).

    Returns:
        Tuple[float, float, float, float, float]: accuracy, macro precision,
            macro recall, macro F1 score and macro one-vs-rest ROC AUC, in that order.
    """
    # The highest-scoring class is the predicted class.
    predicted_labels = np.argmax(predictions, axis=1)

    accuracy = np.mean(predicted_labels == true_labels)
    precision, recall, f1 = (
        scorer(true_labels, predicted_labels, average='macro')
        for scorer in (precision_score, recall_score, f1_score)
    )

    # ROC AUC is computed against one-hot encoded true labels.
    class_ids = list(range(predictions.shape[1]))
    one_hot_truth = label_binarize(true_labels, classes=class_ids)
    roc_auc = roc_auc_score(one_hot_truth, predictions, multi_class='ovr', average='macro')

    return accuracy, precision, recall, f1, roc_auc

def save_best_model_visuals(history: tf.keras.callbacks.History, model: tf.keras.Model, 
                              val_ds: tf.data.Dataset, class_names: List[str], 
                              weights_path: str, fold: int) -> None:
    """
    Runs the model on the validation set and writes the evaluation artifacts
    (confusion matrix, loss curve, ROC curves, metrics bar chart and
    classification report) into ``weights_path``.

    Args:
        history (tf.keras.callbacks.History): Training history object.
        model (tf.keras.Model): Trained model.
        val_ds (tf.data.Dataset): Validation dataset yielding (inputs, labels) batches.
        class_names (List[str]): List of class names.
        weights_path (str): Directory path to save visuals.
        fold (int): Current fold number. Not used in the body; the output
            file names are the same for every fold.
    """
    def _artifact(filename: str) -> str:
        # Every output file lives directly inside weights_path.
        return os.path.join(weights_path, filename)

    # Per-class scores for the validation set, reduced to class ids via argmax.
    val_scores = model.predict(val_ds)
    val_predicted_ids = np.argmax(val_scores, axis=1)

    # True labels are gathered in a second pass over val_ds.
    # NOTE(review): this lines up with the predictions only if val_ds yields
    # the same order on every iteration -- confirm val_ds is not reshuffled.
    true_labels = np.concatenate([batch_labels for _, batch_labels in val_ds], axis=0)

    # Confusion matrix
    save_confusion_matrix(true_labels, val_predicted_ids, class_names,
                          _artifact("confusion_matrix.png"))

    # Loss curve from the training history
    save_loss_curve(history.history, _artifact("loss_curve.png"))

    # ROC curves from the true labels and the per-class scores
    save_roc_auc(true_labels, val_scores, class_names, _artifact("roc_auc_curve.png"))

    # Metrics bar chart (the scores are needed for the ROC AUC value)
    save_evaluation_metrics(true_labels, val_predicted_ids, val_scores,
                            _artifact("evaluation_metrics.png"))

    # Classification report as a text file
    save_classification_report(true_labels, val_predicted_ids, class_names,
                               _artifact("classification_report.txt"))

def compute_cv_statistics(accuracies, precisions, recalls, f1_scores, use_std_error=False, output_filepath=None):
    """
    Compute the mean and variation (std dev or standard error) for each metric and optionally save the results to a file.
    
    Parameters:
        accuracies (list of float): Accuracy values for each fold.
        precisions (list of float): Precision values for each fold.
        recalls (list of float): Recall values for each fold.
        f1_scores (list of float): F1 score values for each fold.
        use_std_error (bool): If True, compute standard error (std / sqrt(n)); otherwise, compute standard deviation.
        output_filepath (str): If provided, the path to a text file where the metrics will be saved.
    
    The function prints each metric in the format:
        Metric: mean ± variation
    (Metrics are multiplied by 100 to display percentages.)

    With fewer than two values the sample standard deviation is undefined, so
    the variation is reported as 0.00 rather than ``nan`` (which also came
    with a NumPy RuntimeWarning).
    """
    metrics = {
        "Accuracy": np.array(accuracies),
        "Precision": np.array(precisions),
        "Recall": np.array(recalls),
        "F1 Score": np.array(f1_scores)
    }
    
    output_lines = []
    for metric_name, values in metrics.items():
        mean_val = np.mean(values)
        if values.size > 1:
            # Use ddof=1 for sample standard deviation
            variation = np.std(values, ddof=1)
            if use_std_error:
                variation /= np.sqrt(values.size)
        else:
            # A single fold has no spread to measure; ddof=1 would divide by zero.
            variation = 0.0
        # Format the output as percentages; adjust if your metrics are already in percentage
        output_lines.append(f"{metric_name}: {mean_val*100:.2f} ± {variation*100:.2f}\n")
    
    metrics_str = "".join(output_lines)
    
    # Print the overall metrics to the console
    print("\nOverall Cross-Validation Metrics:")
    print(metrics_str)
    
    # Save the metrics to a text file if an output file path is provided.
    # The encoding is explicit because the text contains the non-ASCII "±",
    # which not every platform default encoding can represent.
    if output_filepath:
        with open(output_filepath, "w", encoding="utf-8") as f:
            f.write("Overall Cross-Validation Metrics:\n")
            f.write(metrics_str)


In [15]:
def save_confusion_matrix_binary(true_labels: np.ndarray, predicted_labels: np.ndarray, 
                          save_path: str, mpox_index: int = 0) -> None:
    """
    Converts multi-class labels to binary (Mpox vs Other) and plots/saves the confusion matrix.

    The matrix is always 2x2. Without an explicit label set, data in which
    only one of the two binary classes occurs would give a 1x1 matrix that
    does not match the two display labels.

    Args:
        true_labels (np.ndarray): Array of true class labels (multi-class integers).
        predicted_labels (np.ndarray): Array of predicted class labels (multi-class integers).
        save_path (str): Path to save the confusion matrix plot.
        mpox_index (int): The index corresponding to Mpox. All other labels are considered "Other".
            NOTE(review): the default here is 0 while save_best_model_visuals_binary
            defaults to 1 -- confirm which class index is actually Mpox.
    """
    # Convert multi-class labels to binary: 1 if label equals mpox_index, else 0.
    binary_true = (true_labels == mpox_index).astype(int)
    binary_pred = (predicted_labels == mpox_index).astype(int)

    # Pin the label order to [Other, Mpox] so the matrix always matches display_labels.
    cm = confusion_matrix(binary_true, binary_pred, labels=[0, 1])
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Other", "Mpox"])
    
    # Draw on an explicit figure/axes pair rather than pyplot's implicit "current" state.
    fig, ax = plt.subplots(figsize=(8, 6))
    disp.plot(cmap=plt.cm.Blues, ax=ax)

    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted label", fontsize=12)
    ax.set_ylabel("True label", fontsize=12)

    # Rotate tick labels on both axes for better readability and alignment
    tick_style = dict(rotation=45, ha='right', rotation_mode='anchor', fontsize=10)
    plt.setp(ax.get_xticklabels(), **tick_style)
    plt.setp(ax.get_yticklabels(), **tick_style)

    # Prevent labels from being cut off
    fig.tight_layout()

    # Save, then close this specific figure so repeated calls do not accumulate open figures
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)

def save_loss_curve_binary(history: Dict[str, Any], save_path: str) -> None:
    """
    Plots and saves the training and validation loss curves.

    The loss curve does not depend on how labels are grouped, so this
    forwards to ``save_loss_curve`` instead of keeping a second copy of the
    same plotting code.

    Args:
        history (Dict[str, Any]): Dictionary containing training history (loss values).
        save_path (str): Path to save the loss curve plot.
    """
    save_loss_curve(history, save_path)

def save_evaluation_metrics_binary(true_labels: np.ndarray, predicted_labels: np.ndarray, 
                            predicted_probs: np.ndarray, save_path: str, 
                            mpox_index: int = 0) -> Dict[str, float]:
    """
    Collapses multi-class labels to Mpox vs Other, computes the binary
    metrics, saves them as an annotated bar chart, and returns them.

    Args:
        true_labels (np.ndarray): Array of true class labels (multi-class integers).
        predicted_labels (np.ndarray): Array of predicted class labels (multi-class integers).
        predicted_probs (np.ndarray): Array of per-class scores
                                      (shape: [n_samples, n_classes]).
        save_path (str): Path to save the evaluation metrics bar chart.
        mpox_index (int): The index corresponding to Mpox.

    Returns:
        Dict[str, float]: Metric name -> value, in the order they are plotted.
    """
    # Mpox is the positive class (1); everything else is negative (0).
    is_mpox_true = (true_labels == mpox_index).astype(int)
    is_mpox_pred = (predicted_labels == mpox_index).astype(int)
    # The Mpox column of the score matrix serves as the positive-class score.
    positive_scores = predicted_probs[:, mpox_index]

    metrics = {
        "Accuracy": np.mean(is_mpox_true == is_mpox_pred),
        "Precision": precision_score(is_mpox_true, is_mpox_pred),
        "Sensitivity (Recall)": recall_score(is_mpox_true, is_mpox_pred),
        "F1-Score": f1_score(is_mpox_true, is_mpox_pred),
        "ROC AUC": roc_auc_score(is_mpox_true, positive_scores),
    }

    # One bar per metric, each with its own colour.
    bar_colours = ['darkturquoise', 'sandybrown', 'hotpink', 'limegreen', 'mediumpurple']
    plt.figure(figsize=(10, 6))
    drawn_bars = plt.bar(metrics.keys(), metrics.values(), color=bar_colours)

    # Print the numeric value just above each bar.
    for rect in drawn_bars:
        height = rect.get_height()
        centre_x = rect.get_x() + rect.get_width() / 2
        plt.text(centre_x, height, f'{height:.4f}', ha='center', va='bottom')

    plt.title("Evaluation Metrics (Mpox vs Other)")
    plt.ylim([0, 1])
    plt.yticks(np.arange(0, 1.1, 0.1))
    plt.ylabel("Score")
    plt.savefig(save_path)
    plt.close()

    return metrics

def save_classification_report_binary(true_labels: np.ndarray, predicted_labels: np.ndarray, 
                               save_path: str, mpox_index: int = 0) -> None:
    """
    Saves the classification report for binary classification (Mpox vs Other).

    Both binary classes are always reported. Without an explicit label set,
    data in which only one of the two classes occurs yields a single detected
    class, which does not match the two target names.

    Args:
        true_labels (np.ndarray): Array of true class labels (multi-class integers).
        predicted_labels (np.ndarray): Array of predicted class labels (multi-class integers).
        save_path (str): Path to save the classification report.
        mpox_index (int): The index corresponding to Mpox.
    """
    # Mpox is the positive class (1); everything else is "Other" (0).
    binary_true = (true_labels == mpox_index).astype(int)
    binary_pred = (predicted_labels == mpox_index).astype(int)

    # labels=[0, 1] keeps the rows aligned with ["Other", "Mpox"] in every case.
    report = classification_report(
        binary_true,
        binary_pred,
        labels=[0, 1],
        target_names=["Other", "Mpox"],
        digits=4,
    )
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(report)

def calculate_metrics_binary(true_labels: np.ndarray, predictions: np.ndarray, 
                      mpox_index: int = 0) -> Tuple[float, float, float, float, float]:
    """
    Computes Mpox-vs-Other metrics from multi-class scores.

    The predicted class is the argmax over the score columns; a sample counts
    as positive when its class equals ``mpox_index`` and negative otherwise.

    Args:
        true_labels (np.ndarray): Array of true class labels (multi-class integers).
        predictions (np.ndarray): Array of per-class scores (shape: [n_samples, n_classes]).
        mpox_index (int): The index corresponding to Mpox.

    Returns:
        Tuple[float, float, float, float, float]:
            Accuracy, Precision, Recall, F1 Score, and ROC AUC.
    """
    # Multi-class decision first, then collapse to 1 = Mpox / 0 = anything else.
    predicted_class_ids = np.argmax(predictions, axis=1)
    is_mpox_true = (true_labels == mpox_index).astype(int)
    is_mpox_pred = (predicted_class_ids == mpox_index).astype(int)

    accuracy = np.mean(is_mpox_true == is_mpox_pred)
    precision, recall, f1 = (
        scorer(is_mpox_true, is_mpox_pred)
        for scorer in (precision_score, recall_score, f1_score)
    )

    # The Mpox column of the score matrix serves as the positive-class score.
    roc_auc = roc_auc_score(is_mpox_true, predictions[:, mpox_index])

    return accuracy, precision, recall, f1, roc_auc

def save_best_model_visuals_binary(history: tf.keras.callbacks.History, model: tf.keras.Model, 
                              val_ds: tf.data.Dataset, weights_path: str, 
                              fold: int, mpox_index: int = 1) -> None:
    """
    Runs the model on the validation set and writes the Mpox-vs-Other
    evaluation artifacts (confusion matrix, loss curve, metrics bar chart and
    classification report) into ``weights_path``.

    Args:
        history (tf.keras.callbacks.History): Training history object.
        model (tf.keras.Model): Trained model.
        val_ds (tf.data.Dataset): Validation dataset yielding (inputs, labels) batches.
        weights_path (str): Directory path to save visuals.
        fold (int): Current fold number. Not used in the body; the output
            file names are the same for every fold.
        mpox_index (int): The index corresponding to Mpox; forwarded to every helper.
            NOTE(review): defaults to 1 here, while the binary helper functions
            default to 0 -- confirm which class index is actually Mpox.
    """
    def _artifact(filename: str) -> str:
        # Every output file lives directly inside weights_path.
        return os.path.join(weights_path, filename)

    # Per-class scores for the validation set, reduced to class ids via argmax.
    val_scores = model.predict(val_ds)
    predicted_ids_multi = np.argmax(val_scores, axis=1)

    # True labels are gathered in a second pass over val_ds.
    # NOTE(review): this lines up with the predictions only if val_ds yields
    # the same order on every iteration -- confirm val_ds is not reshuffled.
    true_labels = np.concatenate([batch_labels for _, batch_labels in val_ds], axis=0)

    # Confusion matrix
    save_confusion_matrix_binary(true_labels, predicted_ids_multi,
                                 _artifact("confusion_matrix_binary.png"), mpox_index)

    # Loss curve
    save_loss_curve_binary(history.history, _artifact("loss_curve_binary.png"))

    # Metrics bar chart
    save_evaluation_metrics_binary(true_labels, predicted_ids_multi, val_scores,
                                   _artifact("evaluation_metrics_binary.png"), mpox_index)

    # Classification report
    save_classification_report_binary(true_labels, predicted_ids_multi,
                                      _artifact("classification_report_binary.txt"), mpox_index)


### Automated Hyperparameter Tuning

In [16]:
# # Split data into training/validation set for hyperparameter tuning
# train_images_tuning, val_images_tuning, train_labels_tuning, val_labels_tuning = train_test_split(
#     image_paths, labels, test_size=0.1, random_state=42, stratify=labels
# )

# # Define the hypermodel for hyperparameter tuning
# def build_model(hp):
#     base_model = tf.keras.applications.MobileNetV2(
#         input_shape=(img_height, img_width, 3),
#         include_top=False,
#         weights='imagenet'
#     )
#     base_model.trainable = False  # Freeze layers initially
    
#     model = Sequential([
#         base_model,
#         layers.GlobalAveragePooling2D(),
#         layers.Dense(num_classes)
#     ])

#     # Tune hyperparameters
#     learning_rate = hp.Float('learning_rate', min_value=1e-4, max_value=1e-2, sampling='log')
#     optimizer = hp.Choice('optimizer', values=['adam', 'sgd'])

#     if optimizer == 'adam':
#         opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
#     else:
#         opt = tf.keras.optimizers.SGD(learning_rate=learning_rate)

#     model.compile(
#         optimizer=opt,
#         loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
#         metrics=['accuracy']
#     )
    
#     return model

# # Set up the tuner for hyperparameter tuning
# tuner = kt.RandomSearch(
#     build_model,
#     objective='val_accuracy',  # Optimize for validation accuracy
#     max_trials=10,             # Try 10 different hyperparameter combinations
#     executions_per_trial=1,    # Run each combination once
#     directory='hyperparameter_tuning',
#     project_name='best_hyperparams_tuning'
# )

# # Prepare TensorFlow datasets for training and validation
# train_ds = tf.data.Dataset.from_tensor_slices((train_images_tuning, train_labels_tuning)).batch(batch_size)
# val_ds = tf.data.Dataset.from_tensor_slices((val_images_tuning, val_labels_tuning)).batch(batch_size)

# # Perform the hyperparameter search on the validation set
# tuner.search(train_ds, validation_data=val_ds, epochs=10)

# # Get the best hyperparameters after the search
# best_hyperparams = tuner.get_best_hyperparameters(num_trials=1)[0]

# # Print the best hyperparameters
# print(f"Best Hyperparameters: {best_hyperparams.values}")


### Model creation and fine tuning

In [17]:
# Function to create and compile the model
def create_model(num_classes, config, fine_tune=None):
    """Build and compile a MobileNetV2-based transfer-learning classifier.

    A MobileNetV2 backbone with ImageNet weights and no classification head is
    created, frozen as a whole, and followed by global average pooling and a
    dense layer that emits one logit per class.

    Args:
        num_classes: Number of output units of the final dense layer.
        config: Mapping that must provide "optimizer" ("adam" or "sgd") and
            "learning_rate".
        fine_tune: Kept for backward compatibility and otherwise unused.
            Previously the model was only built when this was exactly False;
            every other value (including the default None) fell through to
            ``model.compile`` with ``model`` unassigned and raised
            UnboundLocalError. A fresh frozen model is now always returned;
            unfreezing is done afterwards by ``fine_tune_model``.

    Returns:
        The compiled Sequential model. The loss is sparse categorical
        cross-entropy on logits and the tracked metric is accuracy.

    Raises:
        ValueError: If config["optimizer"] is neither "adam" nor "sgd".
        KeyError: If "optimizer" or "learning_rate" is missing from config.
    """
    # Reject an unsupported optimizer before the backbone is instantiated, so a
    # bad configuration fails immediately instead of after building the network.
    optimizer_name = config["optimizer"]
    if optimizer_name not in ("adam", "sgd"):
        raise ValueError(f"Unsupported optimizer: {config['optimizer']}")

    # Pre-trained feature extractor; img_height / img_width are notebook-level
    # globals defined in an earlier cell.
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(img_height, img_width, 3),     # input shape the network will receive
        include_top=False,                          # drop the ImageNet classification head
        weights='imagenet'                          # start from ImageNet weights
    )
    base_model.trainable = False                    # freeze the whole backbone

    # Classification head for our dataset on top of the frozen backbone.
    model = Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),            # collapse the spatial dimensions
        layers.Dense(num_classes)                   # one logit per class (no softmax)
    ])

    # Select the optimizer according to the configuration.
    if optimizer_name == "adam":
        optimizer = tf.keras.optimizers.Adam(learning_rate=config["learning_rate"])
    else:
        optimizer = tf.keras.optimizers.SGD(learning_rate=config["learning_rate"])

    # from_logits=True matches the activation-free Dense head above.
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )

    return model

# fine tune model by unfreezing the layers after the first fine_tune_at layers
def fine_tune_model(base_model, fine_tune_at):
    """Unfreeze the tail of ``base_model`` for fine-tuning, in place.

    Args:
        base_model: Object exposing a ``trainable`` flag and a ``layers``
            sequence (here: the MobileNetV2 backbone, i.e. ``model.layers[0]``).
        fine_tune_at: Index of the first layer to make trainable. Layers
            ``[:fine_tune_at]`` stay frozen, layers ``[fine_tune_at:]`` are
            unfrozen (regular Python slice semantics).

    Returns:
        None. The model must be re-compiled by the caller for the change to
        take effect in training.
    """
    # create_model freezes the backbone with ``base_model.trainable = False``.
    # Keras reports no trainable weights for a container whose own flag is
    # False, regardless of its sub-layers, so the container itself has to be
    # re-enabled first; otherwise the per-layer flags below change nothing.
    base_model.trainable = True

    # Keep the early layers frozen ...
    for frozen_layer in base_model.layers[:fine_tune_at]:
        frozen_layer.trainable = False
    # ... and open up everything from fine_tune_at onwards.
    for tunable_layer in base_model.layers[fine_tune_at:]:
        tunable_layer.trainable = True


### Training loop

In [18]:
# Stateless rescaling shared by every dataset below (built once, not once per fold).
# NOTE(review): this maps 0-255 pixel values to [0, 1], whereas
# tf.keras.applications.mobilenet_v2.preprocess_input maps them to [-1, 1] —
# confirm the [0, 1] range is intentional for the ImageNet-pretrained backbone.
normalization_layer = layers.Rescaling(1.0 / 255)
AUTOTUNE = tf.data.AUTOTUNE

# Held-out test set: normalised and batched once, reused for every config and fold.
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y)).batch(batch_size)
# Ground-truth test labels in dataset order (test_ds is never shuffled).
test_true_labels = np.concatenate([y for _, y in test_ds])

for i, config in enumerate(configs):
    print(f"Training model {i + 1}/{len(configs)} with config: {config}")

    # Per-config metric accumulators. They are reset here so that the
    # cross-validation statistics written for this config only contain its own
    # folds and not the folds of previously trained configs.
    fold_accuracies = []
    fold_precisions = []
    fold_recalls = []
    fold_f1s = []

    # K-fold Cross Validation
    kfold = StratifiedKFold(n_splits=config['folds'], shuffle=True, random_state=42)
    best_val_f1score = -float('inf')            # best validation F1 seen so far for this config

    # Define the base path for saving models
    model_subdir = os.path.join(save_dir, f'model{i + 1}')
    os.makedirs(model_subdir, exist_ok=True)

    # Define the base path for saving checkpoints for the model
    checkpoint_folder = os.path.join(model_subdir, 'checkpoints')
    os.makedirs(checkpoint_folder, exist_ok=True)

    # Define the base path for saving the model with the best F1 score
    best_f1_dir = os.path.join(model_subdir, 'best_f1score_fold')
    os.makedirs(best_f1_dir, exist_ok=True)

    # Training and validation loop for each fold
    for fold, (train_idx, val_idx) in enumerate(kfold.split(train_val_images, train_val_labels), start=1):
        print(f"\nFold {fold}/{config['folds']}...")

        checkpoint_filepath = os.path.join(checkpoint_folder, f'checkpoint_fold{fold}.weights.h5')

        # Create subset arrays for training and validation
        train_images, train_labels = train_val_images[train_idx], train_val_labels[train_idx]
        val_images, val_labels = train_val_images[val_idx], train_val_labels[val_idx]

        # Balanced class weights computed from this fold's training labels only,
        # stored as a {class index: weight} dictionary for model.fit.
        present_classes = np.unique(train_labels)
        computed_weights = compute_class_weight(class_weight='balanced', classes=present_classes, y=train_labels)
        class_weight_dict = {int(cls): weight for cls, weight in zip(present_classes, computed_weights)}

        # Convert NumPy arrays back to TensorFlow datasets and normalise them
        train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
        val_ds = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
        train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
        val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))

        # Batch, cache the batches in memory, and prefetch so that input
        # preparation overlaps with model execution.
        # NOTE(review): train_ds is never shuffled here — confirm that
        # train_val_images is already shuffled upstream.
        train_ds = train_ds.batch(batch_size).cache().prefetch(buffer_size=AUTOTUNE)
        val_ds = val_ds.batch(batch_size).cache().prefetch(buffer_size=AUTOTUNE)

        # Step 1: Train model with frozen layers
        print(f"Training with frozen base layers for {config['epochs']} epochs...")

        # Create and compile a fresh model for each fold
        model = create_model(num_classes, config, fine_tune=False)

        # setup callbacks
        early_stopping, model_checkpoint, reduce_lr = callbacks_setup(checkpoint_filepath)

        # train the classification head for at most the configured number of epochs
        history_frozen = model.fit(
            train_ds,                                       # dataset used for training
            validation_data=val_ds,                         # dataset used for validation
            epochs=config['epochs'],                        # epochs used for training
            callbacks=[early_stopping, model_checkpoint, reduce_lr],
            class_weight=class_weight_dict,
            verbose=1
        )

        # load the best weights written by the checkpoint callback
        model.load_weights(checkpoint_filepath)

        if config["fine_tune"]:
            # Step 2: Unfreeze layers and fine-tune
            print(f"Unfreezing layers starting from layer {config['fine_tune_at']} for fine-tuning...")
            fine_tune_model(model.layers[0], config['fine_tune_at'])      # model.layers[0] is the backbone

            # re-compile the model with a lower learning rate for fine-tuning
            fine_tune_lr = config['learning_rate'] * 0.01

            model.compile(
                optimizer=tf.keras.optimizers.Adam(learning_rate=fine_tune_lr),
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
            )

            print(f"Fine-tuning for {config['fine_tune_epochs']} epochs...")

            # Fresh callbacks for the fine-tuning phase (same checkpoint file as
            # the frozen phase). callbacks_setup returns three callbacks;
            # unpacking only two raised ValueError. The LR-reduction callback
            # was not used in this phase originally, so it is discarded.
            early_stopping, model_checkpoint, _ = callbacks_setup(checkpoint_filepath)

            history_fine_tune = model.fit(
                train_ds,                                       # dataset used for training
                validation_data=val_ds,                         # dataset used for validation
                epochs=config['fine_tune_epochs'],              # epochs used for fine-tuning
                callbacks=[early_stopping, model_checkpoint],
                class_weight=class_weight_dict,                 # same class weighting as the frozen phase
                verbose=1
            )

            # load the best weights after fine-tuning
            model.load_weights(checkpoint_filepath)

        # evaluate on validation set after training
        val_predictions = model.predict(val_ds)
        avg_val_loss = model.evaluate(val_ds, verbose=0)[0]
        avg_val_accuracy, avg_val_precision, avg_val_recall, avg_val_f1, avg_val_auc = calculate_metrics(
            np.concatenate([y for _, y in val_ds]), val_predictions
        )

        print(f"\nValidation: \tFold {fold} - Loss: {avg_val_loss:.4f}, Accuracy: {avg_val_accuracy:.4f}, Precision: {avg_val_precision:.4f}, Recall: {avg_val_recall:.4f}, F1 Score: {avg_val_f1:.4f}, AUC Score: {avg_val_auc:.4f}")

        # evaluate on the held-out test set
        test_predictions = model.predict(test_ds)
        avg_test_accuracy, avg_test_precision, avg_test_recall, avg_test_f1, _ = calculate_metrics(
            test_true_labels, test_predictions
        )

        # Append the fold metrics to the lists
        fold_accuracies.append(avg_test_accuracy)
        fold_precisions.append(avg_test_precision)
        fold_recalls.append(avg_test_recall)
        fold_f1s.append(avg_test_f1)

        # Keep the fold with the best *validation* F1. Selecting on the test set
        # would leak held-out data into model selection.
        if avg_val_f1 > best_val_f1score:
            best_val_f1score = avg_val_f1
            # Save the best model (using model.export for TensorFlow SavedModel format)
            model.export(best_f1_dir)
            print(f"Best model updated at Fold {fold} with validation F1 Score: {best_val_f1score:.4f}")
            if config.get('save_metrics', False):
                save_best_model_visuals(history_frozen, model, test_ds, class_names, model_subdir, fold)
                save_best_model_visuals_binary(history_frozen, model, test_ds, model_subdir, fold, 5)

    # Summarise this config's per-fold test metrics
    output_file_path = os.path.join(model_subdir, "cv_metrics.txt")
    compute_cv_statistics(fold_accuracies, fold_precisions, fold_recalls, fold_f1s,
                        use_std_error=False, output_filepath=output_file_path)

# save metrics after training
# np.save(os.path.join(save_dir, 'train_metrics.npy'), train_metrics)
# np.save(os.path.join(save_dir, 'val_metrics.npy'), val_metrics)

Training model 1/6 with config: {'model_name': 'mobilenet', 'learning_rate': 0.001, 'batch_size': 32, 'image_size': 224, 'optimizer': 'adam', 'epochs': 50, 'save_metrics': True, 'folds': 5, 'fine_tune': False, 'fine_tune_epochs': 25, 'fine_tune_at': 150}

Fold 1/5...
Training with frozen base layers for 50 epochs...
Epoch 1/50
[1m238/238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 18ms/step - accuracy: 0.5134 - loss: 1.2286 - val_accuracy: 0.7546 - val_loss: 0.6893 - learning_rate: 0.0010
Epoch 2/50
[1m238/238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 7ms/step - accuracy: 0.7664 - loss: 0.5481 - val_accuracy: 0.7840 - val_loss: 0.5884 - learning_rate: 0.0010
Epoch 3/50
[1m238/238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 7ms/step - accuracy: 0.8132 - loss: 0.4317 - val_accuracy: 0.7987 - val_loss: 0.5362 - learning_rate: 0.0010
Epoch 4/50
[1m238/238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 7ms/step - accuracy: 0.8419 - loss: 0.3673 - v

INFO:tensorflow:Assets written to: ../saved_models/model1/best_f1score_fold/assets


Saved artifact at '../saved_models/model1/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_4894')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131694436744656: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131698426687376: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131698426686992: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131698426687184: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131698426691408: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131698426685456: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694037138704: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694037139088: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131698426688528: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694037131024: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model1/best_f1score_fold/assets


Saved artifact at '../saved_models/model1/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_5210')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131694838034512: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693492470032: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693492470800: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693632633744: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693632630864: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693492470608: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693492472144: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693492472336: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693492471952: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693492471184: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model2/best_f1score_fold/assets


Saved artifact at '../saved_models/model2/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_5684')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131693663921232: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694459186320: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694459185744: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694459187856: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694459187280: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694459188432: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694459190544: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694459190736: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694459189008: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131694459189584: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model2/best_f1score_fold/assets


Saved artifact at '../saved_models/model2/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_6000')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131693488309072: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666007568: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666005840: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666007760: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666003152: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666007184: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666005072: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666004880: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666005264: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666006416: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model3/best_f1score_fold/assets


Saved artifact at '../saved_models/model3/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_6474')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131693665627472: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693647373584: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693647373392: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693647368976: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693647378000: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693647374160: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693647375888: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693647374736: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693647376080: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693647375120: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model3/best_f1score_fold/assets


Saved artifact at '../saved_models/model3/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_6632')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131693663926992: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666001232: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666008720: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693663923728: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693663915088: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693665997776: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666005456: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693665993744: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666008336: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693666007952: TensorSpec(shape=(), dtype=tf.resource

2025-02-13 14:50:46.606895: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


INFO:tensorflow:Assets written to: ../saved_models/model3/best_f1score_fold/assets


INFO:tensorflow:Assets written to: ../saved_models/model3/best_f1score_fold/assets


Saved artifact at '../saved_models/model3/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_6790')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131693663913552: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693674014160: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693674020496: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693674015696: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693674006672: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693674017040: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693674020304: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693674011472: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693674017808: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693674016464: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model3/best_f1score_fold/assets


Saved artifact at '../saved_models/model3/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_7106')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131693647553808: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693445887376: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693445885264: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693445885840: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693445877392: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693445875280: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693445876240: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693445881808: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693445881040: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693445883920: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model4/best_f1score_fold/assets


Saved artifact at '../saved_models/model4/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_7264')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131693445876048: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693652703760: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693652704528: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693652704912: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693652695696: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693652698960: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693652699536: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693652703952: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693652702224: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693652693584: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model4/best_f1score_fold/assets


Saved artifact at '../saved_models/model4/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_7580')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131693649125904: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693649149008: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693649143248: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693649140560: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693649144016: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693649141136: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693487961104: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693649147088: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693649148240: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693649135184: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model5/best_f1score_fold/assets


Saved artifact at '../saved_models/model5/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_8054')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131693244113232: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692990427600: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692990427984: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692990430480: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692990426448: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692990430864: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692874302864: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692874303056: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692874302672: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692874302096: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model5/best_f1score_fold/assets


Saved artifact at '../saved_models/model5/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_8370')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131693665995472: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693477096016: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693477099088: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693633931024: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693477095248: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693650507920: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693678427600: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131695306813136: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693650501392: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693650495632: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model6/best_f1score_fold/assets


Saved artifact at '../saved_models/model6/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_8844')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131693448119504: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693448128720: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693448134096: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693448129104: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693448121424: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693448118736: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693652698384: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693652696464: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693448131792: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131693448120272: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model6/best_f1score_fold/assets


Saved artifact at '../saved_models/model6/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_9160')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131692871940816: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131691987743952: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131691987738192: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692990425104: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131691987733968: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131691987733008: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131691987734160: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131691987734928: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131691987742032: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131691987730704: TensorSpec(shape=(), dtype=tf.resource

INFO:tensorflow:Assets written to: ../saved_models/model6/best_f1score_fold/assets


Saved artifact at '../saved_models/model6/best_f1score_fold'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_9476')
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  131692991720016: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692991716176: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692991715408: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692991723664: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692991709648: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692991717328: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692991721168: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692991717904: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692991713488: TensorSpec(shape=(), dtype=tf.resource, name=None)
  131692991725008: TensorSpec(shape=(), dtype=tf.resource

## Testing

In [None]:
# test_ds = tf.keras.utils.image_dataset_from_directory(
#     "../data/merged_MSLD/Test",                                  # loads images from the data_root directory
#     image_size=(img_height, img_width),         # resizes all images to (224, 224) pixels
#     batch_size=batch_size,                      # set the batch size
#     shuffle=False,                              # do not shuffle, so predictions stay aligned with the labels read back from test_ds
# )

# test_ds = test_ds.unbatch().batch(batch_size, drop_remainder=True)

# test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y))

# for i, config in enumerate(configs):
#     # model = tf.keras.models.load_model(f'../saved_models/model{i+1}/best_f1score_fold')

#     # Load the SavedModel using TFSMLayer, treating it as a Keras layer
#     model_layer = tf.keras.layers.TFSMLayer(f'../saved_models/model{i+1}/best_f1score_fold', call_endpoint='serving_default')
    
#     # Wrap the TFSMLayer in a Sequential model for inference
#     model = tf.keras.Sequential([model_layer])

#     # once training is complete, evaluate on the held-out test set
#     print(f"Evaluating the best model for model{i+1} on the held-out test set...")
#     # test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
#     # test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y)).batch(batch_size)

#     test_predictions = model.predict(test_ds)

#     # Print the shape and type of predictions for debugging
#     # print(f"Predictions shape: {len(test_predictions['output_0'])}, type: {type(test_predictions)}")
#     # print(f"First few predictions: {test_predictions[:5]}")  # Check the first few predictions

#     # avg_test_loss = model.evaluate(test_ds, verbose=0)[0]
#     avg_test_accuracy, avg_test_precision, avg_test_recall, avg_test_f1, avg_test_auc = calculate_metrics(
#         np.concatenate([y for _, y in test_ds]), test_predictions['output_0']
#     )

#     print(f"\nTest Set Evaluation - Accuracy: {avg_test_accuracy:.4f}, Precision: {avg_test_precision:.4f}, Recall: {avg_test_recall:.4f}, F1 Score: {avg_test_f1:.4f}, AUC Score: {avg_test_auc:.4f}\n")
