In [None]:
from loguru import logger
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    ConfusionMatrixDisplay,
    f1_score,
    precision_score,
    recall_score,
)
from tensorflow.data import (
    AUTOTUNE,
    Dataset,
)
from tensorflow.keras.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    TensorBoard,
)
from tensorflow.keras.layers import (
    RandomFlip,
    RandomRotation,
    RandomZoom,
)
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from typing import List

from src import (
    Dataset as WSI_Dataset,
    ModelContext,
    ModelFactory,
)

import itertools
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

%load_ext tensorboard


In [None]:
# ---------------------------------------------------------------------------
# Run configuration: everything a reader may want to tune lives in this cell.
# ---------------------------------------------------------------------------

# Compared against `ModelContext.filename` to pick the model to train, and
# used as the sub-directory name for the TensorBoard logs.
MODEL_FILENAME = ""

# Data pipeline.
BATCH_SIZE = 4
VALIDATION_SPLIT = 0.2
DATA_AUGMENTATION = True

# When True the model is compiled again with OPTIMIZER / LOSS_FUNCTION below.
MODEL_RECOMPILE = True

# Optimisation.
EPOCHS = 100
LEARNING_RATE = 1e-5
OPTIMIZER = Adam(learning_rate=LEARNING_RATE)
# from_logits=True: the loss expects raw (un-normalised) model outputs.
LOSS_FUNCTION = SparseCategoricalCrossentropy(from_logits=True)

FIT_CALLBACKS = [
    # Stop after 10 epochs without a better val_loss and roll back to the
    # best weights seen.
    EarlyStopping(
        restore_best_weights=True,
        patience=10,
        monitor="val_loss",
        verbose=1,
    ),
    # Per-epoch histograms, profiler disabled.
    TensorBoard(
        profile_batch=0,
        histogram_freq=1,
        log_dir=f"logs/fit/{MODEL_FILENAME}",
    ),
]


In [None]:
# Look up the ModelContext whose `filename` equals MODEL_FILENAME.
contexts: List[ModelContext] = ModelFactory.models()

model_list: List[ModelContext] = [
    candidate for candidate in contexts if candidate.filename == MODEL_FILENAME
]

if not model_list:
    logger.error(f"Model {MODEL_FILENAME} not found")
    # Raise instead of calling `exit(1)`: inside a Jupyter kernel `exit` asks
    # the kernel to shut down, whereas an exception fails this cell with a
    # readable traceback and stops "Run All" at the right place.
    _available = ", ".join(sorted(str(candidate.filename) for candidate in contexts))
    raise ValueError(
        f"Model {MODEL_FILENAME!r} not found; available models: {_available}"
    )

# If several contexts share the filename, the first one returned wins.
context = model_list[0]


In [None]:
# Fetch the three raw splits; VALIDATION_SPLIT is forwarded to the loader.
_raw_train_ds, _raw_val_ds, _raw_test_ds = WSI_Dataset.get(
    validation_split=VALIDATION_SPLIT,
)

# Class names are read from the train split; they are used later as the
# display labels of the confusion matrix.
class_names = _raw_train_ds.class_names

# One summary line per split: element count and the classes it exposes.
for _split_label, _split_ds in (
    ("train", _raw_train_ds),
    ("validation", _raw_val_ds),
    ("test", _raw_test_ds),
):
    logger.info(
        f"Raw {_split_label} set with {len(_split_ds)} samples "
        f"and {len(_split_ds.class_names)} of classes, "
        f"which are {', '.join(_split_ds.class_names)}"
    )


In [None]:
def _process_ds(
    ds: Dataset,
    batch: int,
    shuffle: bool,
    shuffle_buffer: int = 500,
) -> Dataset:
    """Turn a raw dataset into a cached, optionally shuffled, batched pipeline.

    Args:
        ds: Dataset to process.
        batch: Number of consecutive elements combined into one batch.
        shuffle: Whether to shuffle the elements, with a new order on every
            iteration over the dataset.
        shuffle_buffer: Size of the shuffle buffer (only used when `shuffle`
            is True).

    Returns:
        The dataset after cache -> (shuffle) -> batch -> prefetch.
    """
    # Cache first: a cache replays exactly the same elements on every
    # iteration, so the shuffle has to come after it to get a new order each
    # epoch.
    ds = ds.cache()

    if shuffle:
        # Dataset transformations return a new dataset; the result must be
        # re-assigned or the shuffle is silently dropped.
        ds = ds.shuffle(buffer_size=shuffle_buffer, reshuffle_each_iteration=True)

    # Batch after shuffling so that batch composition changes between epochs.
    ds = ds.batch(batch)
    return ds.prefetch(buffer_size=AUTOTUNE)

def _augment(ds: Dataset) -> Dataset:
    """Map a random flip / rotation / zoom stack over a dataset.

    Each element of `ds` is unpacked as an `(x, y)` pair. Only `x` goes
    through the augmentation layers (called with `training=True`); `y` is
    passed through unchanged.

    Args:
        ds: Dataset yielding `(x, y)` pairs.
            NOTE(review): `x` is presumably an image tensor — confirm
            against the loader.

    Returns:
        A new dataset with the augmentation mapped over `ds`.
    """
    augmenter = Sequential(
        [
            RandomFlip("horizontal"),
            RandomRotation(0.1),
            RandomZoom(0.1),
        ]
    )

    def _apply(x, y):
        # training=True keeps the random layers active outside of `fit`.
        return augmenter(x, training=True), y

    return ds.map(_apply, num_parallel_calls=AUTOTUNE)


# Build the three pipelines from the raw splits. The raw splits are never
# overwritten, so this cell can be re-run without stacking augmentations.
train_ds = _process_ds(_raw_train_ds, batch=BATCH_SIZE, shuffle=True)
val_ds = _process_ds(_raw_val_ds, batch=BATCH_SIZE, shuffle=False)
test_ds = _process_ds(_raw_test_ds, batch=1, shuffle=False)

if DATA_AUGMENTATION:
    # `_process_ds` caches the data it is given. Augmenting before it would
    # put the random transforms upstream of the cache, so the first epoch's
    # augmented images would be replayed unchanged on every later epoch.
    # Augmenting afterwards draws fresh random transforms each epoch.
    train_ds = _augment(train_ds).prefetch(buffer_size=AUTOTUNE)


# `len()` of a batched dataset is its number of batches.
logger.info(f"Batched train set with {len(train_ds)} samples")
logger.info(f"Batched validation set with {len(val_ds)} samples")
logger.info(f"Test set with {len(test_ds)} samples")


In [None]:
# Compile with the configured optimiser and loss only when requested;
# otherwise the model is used exactly as its context provides it.
if MODEL_RECOMPILE:
    context.model.compile(
        loss=LOSS_FUNCTION,
        optimizer=OPTIMIZER,
        metrics=["accuracy"],
    )

# Print the layer-by-layer architecture of the model about to be trained.
context.model.summary()


In [None]:
%tensorboard --logdir logs/fit

context.model.fit(
    train_ds,
    callbacks=FIT_CALLBACKS,
    validation_data=val_ds,
    epochs=EPOCHS,
    verbose=1,
)


In [None]:
# Run the trained model over the test pipeline.
predictions = context.model.predict(test_ds, verbose=1)

logger.debug(f"Predictions shape: {predictions.shape}")
logger.debug(f"Predictions\n{predictions}")


# Ground-truth labels, in the same order `predict` consumed them (the test
# pipeline is built with shuffle=False). Every label batch is flattened and
# concatenated so `actual` is always 1-D with one entry per sample.
# Stacking the batches with `np.array` instead gives shape (n, 1) for
# batch size 1 and a ragged array when the last batch is shorter.
actual = np.concatenate([np.ravel(labels.numpy()) for _, labels in test_ds])
# Predicted class = index of the highest score along the last axis.
predicted = np.argmax(predictions, axis=-1)

logger.debug(f"Actual shape: {actual.shape}")
logger.debug(f"Actual values\n{actual}")

logger.debug(f"Predicted shape: {predicted.shape}")
logger.debug(f"Predicted values\n{predicted}")


# Confusion matrix of ground truth (`actual`) versus model output (`predicted`).
cm = confusion_matrix(actual, predicted)
logger.debug(f"Confusion Matrix\n{cm}")


# Draw it on an explicit 9x9 inch axes, labelled with the class names.
cm_fig, cm_ax = plt.subplots(figsize=(9, 9))
cm_display = ConfusionMatrixDisplay(
    display_labels=class_names,
    confusion_matrix=cm,
)
cm_display.plot(ax=cm_ax, cmap="Blues")

plt.show()


def _micro_specificity(matrix) -> float:
    """Micro-averaged specificity, TN / (TN + FP), from a confusion matrix.

    Args:
        matrix: Square confusion matrix with true classes on the rows and
            predicted classes on the columns (scikit-learn's layout).

    Returns:
        Sum of per-class true negatives divided by the sum of per-class
        true negatives plus false positives; NaN when that denominator is 0.
    """
    matrix = np.asarray(matrix)
    true_pos = np.diag(matrix)
    false_pos = matrix.sum(axis=0) - true_pos
    false_neg = matrix.sum(axis=1) - true_pos
    true_neg = matrix.sum() - (true_pos + false_pos + false_neg)
    negatives = true_neg.sum() + false_pos.sum()
    return float(true_neg.sum() / negatives) if negatives else float("nan")


# With single-label multiclass targets, micro-averaged precision, recall and
# F1 are all equal to accuracy.
logger.info(f"Accuracy: {accuracy_score(actual, predicted)}")
logger.info(f"Precision: {precision_score(actual, predicted, average='micro')}")
logger.info(f"Sensitivity recall: {recall_score(actual, predicted, average='micro')}")
# `recall_score(..., pos_label=0, average='micro')` ignores `pos_label`
# (it only applies to average='binary'), so it just repeats the recall.
# Specificity is derived from the confusion matrix instead.
logger.info(f"Specificity: {_micro_specificity(cm)}")
logger.info(f"F1 score: {f1_score(actual, predicted, average='micro')}")


In [None]:
# Print the model architecture once more, then call the context's `save()`.
# NOTE(review): `save()` presumably persists the trained model — confirm the
# target location and format in `src.ModelContext`.
context.model.summary()
context.save()
