In [None]:
import os

import numpy as np
import pandas as pd
import tensorflow as tf

from .utils._logger import logger
from .utils._validation import config_args

In [None]:
# Image folders for each split, both located under the configured base directory.
train_dir: str = os.path.join(config_args.base_dir, "train")
test_dir: str = os.path.join(config_args.base_dir, "test")

# Loader options that are identical for both splits.
_loader_options = dict(
    batch_size=config_args.batch_size,
    image_size=tuple(config_args.image_size),
    seed=config_args.seed,
    label_mode="categorical",
)

# Only the training split is shuffled; the test split is loaded with
# shuffle=False so later cells can line predictions up with labels.
train_dataset = tf.keras.utils.image_dataset_from_directory(
    train_dir, shuffle=True, **_loader_options
)
test_dataset = tf.keras.utils.image_dataset_from_directory(
    test_dir, shuffle=False, **_loader_options
)

In [None]:
# Class names exposed by the training dataset, and how many of them there are.
CLASS_NAMES = train_dataset.class_names
NUM_CLASS: int = len(CLASS_NAMES)

In [None]:
# Cache setup
AUTOTUNE = tf.data.AUTOTUNE
# Size of the post-cache shuffle buffer. Dataset elements are whole batches
# here (batch_size was given to the loader), so this is counted in batches.
SHUFFLE_BUFFER_SIZE: int = 1000

# cache() replays exactly the same sequence of elements on every pass, so the
# shuffling requested when the training dataset was created would be frozen
# after the first epoch. Shuffling again *after* the cache restores a fresh
# (but seeded, hence reproducible) batch order for every epoch.
train_dataset = (
    train_dataset.cache()
    .shuffle(buffer_size=SHUFFLE_BUFFER_SIZE, seed=config_args.seed)
    .prefetch(buffer_size=AUTOTUNE)
)
# The test split is deliberately left unshuffled: later cells pair predictions
# with labels by position.
test_dataset = test_dataset.cache().prefetch(buffer_size=AUTOTUNE)

In [None]:
# Backbone: EfficientNetV2B0 with ImageNet weights, no classification top and
# average-pooled output. It is marked non-trainable right after construction.
base_model = tf.keras.applications.EfficientNetV2B0(
    input_shape=(*config_args.image_size, 3),
    weights="imagenet",
    include_top=False,
    pooling="avg",
)
base_model.trainable = False

In [None]:
# Augmentation pipeline: horizontal flip, then random rotation, height, width
# and zoom changes, each configured with a factor of 0.4.
_augmentation_layers = [
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.4),
    tf.keras.layers.RandomHeight(0.4),
    tf.keras.layers.RandomWidth(0.4),
    tf.keras.layers.RandomZoom(0.4),
]
data_augmentation = tf.keras.Sequential(layers=_augmentation_layers)

In [None]:
# Assemble the classifier: augmentation -> backbone -> dense head.
inputs = tf.keras.Input(shape=(*config_args.image_size, 3))
# The backbone is always called with training=False.
x = base_model(data_augmentation(inputs), training=False)
x = tf.keras.layers.Dense(units=512, activation="relu")(x)
x = tf.keras.layers.Dropout(rate=0.2)(x)
# One softmax output unit per class.
outputs = tf.keras.layers.Dense(units=NUM_CLASS, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Training callbacks, both driven by the validation loss.
# Stop after 3 epochs without improvement and roll back to the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
    verbose=1,
)
# Keep only the best model on disk. The path is output_dir with the file name
# appended by plain string concatenation, so no path separator is inserted.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=config_args.output_dir + "checkpoint.keras",
    monitor="val_loss",
    save_best_only=True,
    verbose=1,
)
callbacks = [early_stopping, model_checkpoint]

# Compile with Adam at the configured learning rate, categorical cross-entropy
# loss, and accuracy as the reported metric.
model.compile(
    loss="categorical_crossentropy",
    metrics=["accuracy"],
    optimizer=tf.keras.optimizers.Adam(learning_rate=config_args.base_learning_rate),
)

In [None]:
# Training
# `callbacks` is already a list of callback instances, so it is handed to
# fit() as-is instead of being wrapped in a second list.
# NOTE(review): the test split doubles as validation data here, so early
# stopping and checkpoint selection are both tuned on it.
r = model.fit(
    train_dataset,
    epochs=20,
    validation_data=test_dataset,
    callbacks=callbacks,
)

In [None]:
# Evaluation
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

# Load the saved checkpoint and run it over the test split.
model = tf.keras.models.load_model(config_args.output_dir + "checkpoint.keras")
test_preds = model.predict(test_dataset)
# Gather the ground-truth label batches into a single array.
test_ds_labels = np.concatenate([labels for _, labels in test_dataset], axis=0)

# Reduce both predictions and labels to integer class indices.
test_preds_labels = np.argmax(test_preds, axis=-1)
test_ds_labels_argmax = np.argmax(test_ds_labels, axis=-1)

# Confusion matrix of true vs. predicted class indices.
confusion_matrix_preds = confusion_matrix(
    y_true=test_ds_labels_argmax,
    y_pred=test_preds_labels,
)

# Plot the matrix with class names on the axes.
confusion_matrix_display = ConfusionMatrixDisplay(
    confusion_matrix=confusion_matrix_preds,
    display_labels=CLASS_NAMES,
)
fig, ax = plt.subplots(figsize=(6, 6))
confusion_matrix_display.plot(
    ax=ax, cmap="Blues", colorbar=False, xticks_rotation="vertical"
)

In [None]:
# Run inference over the test split and derive predicted / true class indices.
test_preds = model.predict(test_dataset)
test_preds_labels = np.argmax(test_preds, axis=-1)
test_ds_labels = np.concatenate([labels for _, labels in test_dataset], axis=0)
test_ds_labels_argmax = np.argmax(test_ds_labels, axis=-1)

# Largest output value of each prediction row.
test_pred_probs_max = tf.reduce_max(test_preds, axis=-1).numpy()

# One row per test image: predicted index / score / name next to the truth.
test_results_df = pd.DataFrame(
    {
        "test_pred_label": test_preds_labels,
        "test_pred_prob": test_pred_probs_max,
        "test_pred_class_name": [CLASS_NAMES[idx] for idx in test_preds_labels],
        "test_truth_label": test_ds_labels_argmax,
        "test_truth_class_name": [CLASS_NAMES[idx] for idx in test_ds_labels_argmax],
    }
)

# A prediction counts as correct when the predicted and true class names match.
test_results_df["correct"] = test_results_df["test_pred_class_name"].eq(
    test_results_df["test_truth_class_name"]
)

test_results_df.head()

In [None]:
# Mean of the boolean `correct` column per true class, i.e. per-class accuracy.
accuracy_per_class = test_results_df.groupby("test_truth_class_name")["correct"].mean()

# Same numbers as a frame with the class name as a column, best class first.
accuracy_per_class_df = (
    accuracy_per_class.to_frame()
    .reset_index()
    .sort_values(by="correct", ascending=False)
)
accuracy_per_class_df.head(), accuracy_per_class_df.tail()

In [None]:
def predict_custom(img):
    """Predict the class of a single custom image.

    The model is loaded from ``config_args.output_dir + "checkpoint.keras"``
    on every call, the image is resized to ``config_args.image_size`` and a
    one-image batch is passed through the model.

    Args:
        img: Path of image.

    Returns:
        A ``(score, class_name)`` tuple, where ``score`` is the largest value
        of the model output rounded to two decimals and ``class_name`` is the
        matching entry of ``CLASS_NAMES``. ``None`` is returned when anything
        fails; the error is logged instead of being raised.
    """
    try:
        loaded_model = tf.keras.models.load_model(
            config_args.output_dir + "checkpoint.keras"
        )
        # Resize to the resolution the model input was declared with rather
        # than a hard-coded 224x224, so the function keeps working when
        # config_args.image_size is set to a different size.
        image = tf.keras.utils.load_img(
            path=img, color_mode="rgb", target_size=tuple(config_args.image_size)
        )
        image_array = tf.keras.utils.img_to_array(image)
        # Add a leading batch axis before predicting, then take the single row.
        scores = loaded_model.predict(np.expand_dims(image_array, axis=0))[0]
        best_index = int(np.argmax(scores, axis=-1))
        return round(np.max(scores, axis=-1), 2), CLASS_NAMES[best_index]
    except Exception as e:
        # Best-effort helper: report the failure and signal it with None.
        logger.error(f"Predicting custom image failed: {e}")
        return None