In [None]:

import os
import warnings

# TensorFlow reads TF_GPU_ALLOCATOR when it creates its GPU allocator, so the
# variable has to be in the environment before that happens. Setting it before
# any import below that can load TensorFlow (directly, or indirectly through
# the project's `src` package) guarantees that, whatever those imports do.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

from typing import List

import matplotlib.pyplot as plt
import numpy as np
from loguru import logger
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    ConfusionMatrixDisplay,
    f1_score,
    precision_score,
    recall_score,
)
from tensorflow.data import AUTOTUNE

from src import (
    Dataset as WSI_Dataset,
    ModelContext,
    ModelFactory,
)

warnings.simplefilter(action="ignore", category=FutureWarning)


In [None]:

# Passed to WSI_Dataset.get as `validation_split` when the splits are loaded;
# only the third (test) split returned by that call is evaluated here.
VALIDATION_SPLIT: float = 0.2


In [None]:

# Every model the project's factory exposes. Each context is used for its
# `filename` (logged here) and its `model` (summarised and evaluated later).
contexts: List[ModelContext] = ModelFactory.models()

available_models = "\n".join(ctx.filename for ctx in contexts)
logger.info("Available models\n" + available_models)


In [None]:

# Only the third split (the test set) is evaluated in this notebook.
_, _, _raw_test_ds = WSI_Dataset.get(validation_split=VALIDATION_SPLIT)

class_names = _raw_test_ds.class_names

logger.info(
    f"Raw test set with {len(_raw_test_ds)} samples and {len(class_names)} classes, "
    f"which are {', '.join(class_names)}"
)

# One sample per batch, so the batch count reported below equals the sample
# count. `cache()` makes every pass after the first read the elements back in
# the same order. The evaluation relies on that because it iterates this
# dataset once for labels and once for predictions; this matters if the raw
# split reshuffles on each iteration.
test_ds = (
    _raw_test_ds
    .batch(1)
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)

logger.info(f"Test set with {len(test_ds)} samples")


In [None]:

def _collect_labels(dataset) -> np.ndarray:
    """Return the class index of every sample in ``dataset``, in iteration order.

    Args:
        dataset: Iterable of ``(inputs, labels)`` batches whose ``labels``
            expose ``.numpy()`` (e.g. the batched ``test_ds``).

    Returns:
        1-D integer array with one class index per sample. Labels that arrive
        as one-hot vectors (last axis longer than 1) are converted with argmax,
        so they are comparable with ``np.argmax`` predictions.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    batches = [np.atleast_1d(labels.numpy()) for _, labels in dataset]
    if not batches:
        raise ValueError("Test set is empty; nothing to evaluate")
    # Concatenate along the batch axis rather than np.array(...). This way
    # batched sparse labels become (N,) instead of an (N, 1) column.
    labels = np.concatenate(batches)
    if labels.ndim > 1 and labels.shape[-1] > 1:
        labels = np.argmax(labels, axis=-1)
    return labels.reshape(-1)


def _specificity_score(cm: np.ndarray, average: str) -> float:
    """Specificity (true-negative rate, TN / (TN + FP)) from a confusion matrix.

    Args:
        cm: Square confusion matrix with true labels on rows and predicted
            labels on columns (the ``sklearn.metrics.confusion_matrix`` layout).
        average: ``"binary"`` for a 2x2 matrix. This reports the specificity
            with class 1 as the positive class, i.e. the recall of class 0.
            Any other value gives the unweighted mean of the one-vs-rest
            specificity of each class that occurs in ``cm``.

    Returns:
        The specificity as a float. A class whose TN + FP is zero contributes
        0.0 rather than NaN.
    """
    cm = np.asarray(cm, dtype=float)
    true_pos = np.diag(cm)
    false_pos = cm.sum(axis=0) - true_pos
    false_neg = cm.sum(axis=1) - true_pos
    true_neg = cm.sum() - true_pos - false_pos - false_neg
    negatives = true_neg + false_pos
    per_class = np.divide(
        true_neg,
        negatives,
        out=np.zeros_like(true_neg),
        where=negatives > 0,
    )
    if average == "binary":
        return float(per_class[1])
    # Mirror sklearn's default label set: only classes seen in the ground
    # truth or the predictions contribute to the macro average.
    present = (cm.sum(axis=0) + cm.sum(axis=1)) > 0
    return float(per_class[present].mean()) if present.any() else 0.0


n_classes = len(class_names)
# Two classes: report metrics with class 1 as the positive class (sklearn's
# default pos_label). More classes: unweighted mean over classes ("macro").
# "micro" averaging is not used: for single-label multiclass data it makes
# precision, recall and F1 all equal to accuracy, and pos_label is ignored.
metric_average = "binary" if n_classes == 2 else "macro"
logger.info(f"Metric averaging: {metric_average}")
if metric_average == "binary":
    logger.info(f"Positive class: {class_names[1]}")

# The ground truth is the same for every model, so collect it once.
actual = _collect_labels(test_ds)

for c in contexts:

    logger.info(f"Model: {c.filename}")
    c.model.summary()

    predictions = c.model.predict(
        test_ds,
        verbose=1,
    )
    # NOTE(review): argmax assumes one output unit per class. A single-unit
    # sigmoid head would need thresholding instead; confirm for each model.
    predicted = np.argmax(predictions, axis=-1)

    # Pass every class index explicitly so the matrix is always
    # n_classes x n_classes and matches `display_labels`, even when a class
    # is missing from both the ground truth and the predictions.
    cm = confusion_matrix(actual, predicted, labels=np.arange(n_classes))

    logger.info(f"Accuracy: {accuracy_score(actual, predicted)}")
    logger.info(f"Precision: {precision_score(actual, predicted, average=metric_average)}")
    logger.info(f"Sensitivity recall: {recall_score(actual, predicted, average=metric_average)}")
    logger.info(f"Specificity: {_specificity_score(cm, metric_average)}")
    logger.info(f"F1 score: {f1_score(actual, predicted, average=metric_average)}")

    fig, ax = plt.subplots(figsize=(9, 9))
    ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=class_names,
    ).plot(cmap="Blues", ax=ax)
    ax.set_title(c.filename)
    plt.show()
    # Close the figure explicitly so figures don't pile up across models on
    # non-inline backends.
    plt.close(fig)
