In [None]:
import keras
from keras.saving import register_keras_serializable

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay
from functools import partial
import seaborn as sns


In [None]:
# Load the test images as a dataset (the model itself is loaded in a later cell)
TEST_DATADIR = "../data/test_directory"

test_ds = keras.utils.image_dataset_from_directory(
    TEST_DATADIR,
    # Labelling: labels="inferred" with no explicit class_names,
    # returned as categorical (one-hot) vectors.
    labels="inferred",
    class_names=None,
    label_mode="categorical",
    # Image format handed to the model.
    image_size=(224, 224),
    color_mode="rgb",
    # One image per batch, no shuffling.
    batch_size=1,
    shuffle=False,
)

# Class names as reported by the dataset; used to label the evaluation output
class_names = test_ds.class_names

In [None]:


# Conv2D factory with the defaults shared by every convolution in ResidualUnit.
# Any of these keywords can still be overridden per call.
DefaultConv2D = partial(
    keras.layers.Conv2D,
    kernel_size=3,
    strides=1,
    padding="same",
    kernel_initializer="he_normal",
    use_bias=False,
)


@register_keras_serializable(package='Custom', name='ResidualUnit')
class ResidualUnit(keras.layers.Layer):
    """Residual building block for ResNet-style models.

    The main path applies convolution -> batch normalization -> activation ->
    convolution -> batch normalization. The shortcut path forwards the input
    unchanged, except when ``strides > 1``: then the input goes through a
    strided 1x1 convolution followed by batch normalization. Both paths are
    added and the activation is applied to the sum.

    Args:
        filters (int): Number of filters used by every convolution in the unit.
        strides (int, optional): Stride of the first main-path convolution and
            of the shortcut convolution. Defaults to 1.
        activation (str or callable, optional): Activation applied inside the
            main path and to the final sum. Defaults to "relu".
    """

    def __init__(self, filters, strides=1, activation="relu", **kwargs):
        """Create the sub-layers of the main and shortcut paths."""
        super().__init__(**kwargs)
        # Constructor arguments are stored so get_config() can report them.
        self.strides = strides
        self.filters = filters
        self.activation = keras.activations.get(activation)

        # Main path; sub-layers are created in the order they are applied.
        first_conv = DefaultConv2D(filters, strides=strides)
        first_norm = keras.layers.BatchNormalization()
        second_conv = DefaultConv2D(filters)
        second_norm = keras.layers.BatchNormalization()
        self.main_layers = [
            first_conv,
            first_norm,
            self.activation,
            second_conv,
            second_norm,
        ]

        # Shortcut path: empty (identity) unless the main path uses a stride
        # above 1, in which case the input gets the same stride via a 1x1 conv.
        needs_projection = strides > 1
        self.skip_layers = (
            [
                DefaultConv2D(filters, kernel_size=1, strides=strides),
                keras.layers.BatchNormalization(),
            ]
            if needs_projection
            else []
        )

    def build(self, input_shape):
        """Defer entirely to the base-class build."""
        super().build(input_shape)

    def call(self, inputs):
        """Run both paths on ``inputs`` and return the activated sum."""
        main_path = inputs
        for step in self.main_layers:
            main_path = step(main_path)

        shortcut = inputs
        for step in self.skip_layers:
            shortcut = step(shortcut)

        return self.activation(main_path + shortcut)

    def get_config(self):
        """Return the base config extended with this layer's constructor arguments."""
        config = super().get_config()
        config["filters"] = self.filters
        config["strides"] = self.strides
        config["activation"] = keras.activations.serialize(self.activation)
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer by passing ``config`` as keyword arguments."""
        return cls(**config)

In [None]:
# Location of the saved model file.
# NOTE(review): presumably the saved model contains ResidualUnit layers, which
# is why that class is defined and registered before this cell - confirm.
MODEL_PATH = "../models/best_model_resnet18.keras"

model = keras.models.load_model(MODEL_PATH)

In [None]:
def evaluate_model_performance(model, test_ds, class_labels):
    """
    Evaluate the model's performance on the test set.

    Prints scikit-learn's classification report and then shows the confusion
    matrix as a seaborn heatmap.

    Args:
        model: A trained Keras model whose output holds one score per class
            along axis 1.
        test_ds: A tf.data.Dataset yielding ``(images, labels)`` batches, with
            labels encoded as categorical (one-hot) vectors. Any batch size is
            supported.
        class_labels: A list of class labels, ordered by class index.

    Raises:
        ValueError: If ``test_ds`` yields no batches.
    """
    true_batches = []
    predicted_batches = []

    # True labels and predictions are collected in the same pass, so they stay
    # aligned even if the dataset yields its samples in a different order each
    # time it is iterated.
    for images, labels in test_ds:
        # axis=1 gives one class index per sample; an argmax over the flattened
        # batch would only be correct for a batch size of 1.
        true_batches.append(np.argmax(labels.numpy(), axis=1))
        batch_scores = model.predict_on_batch(images)
        predicted_batches.append(np.argmax(batch_scores, axis=1))

    if not true_batches:
        raise ValueError("test_ds yielded no batches; there is nothing to evaluate.")

    true_labels = np.concatenate(true_batches)
    predicted_labels = np.concatenate(predicted_batches)

    # Passing every class index explicitly keeps the report rows and the matrix
    # shape in step with class_labels even when a class never occurs in the
    # true or predicted labels.
    class_indices = np.arange(len(class_labels))

    # Classification report
    report = classification_report(
        true_labels,
        predicted_labels,
        labels=class_indices,
        target_names=class_labels,
    )
    print(report)
    print('\n')

    # Define a custom colormap
    colors = ["white", "royalblue"]
    cmap_cm = LinearSegmentedColormap.from_list("cmap_cm", colors)

    # Confusion matrix: rows are true classes, columns are predicted classes
    cm = confusion_matrix(true_labels, predicted_labels, labels=class_indices)

    # Plotting confusion matrix using seaborn
    _, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        cmap=cmap_cm,
        fmt='d',
        xticklabels=class_labels,
        yticklabels=class_labels,
        ax=ax,
    )
    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')
    ax.set_title('Confusion Matrix')
    plt.show()


In [None]:
# Report test-set performance of the loaded model
evaluate_model_performance(
    model=model,
    test_ds=test_ds,
    class_labels=class_names,
)