In [None]:
import shutil
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models, utils
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, ModelCheckpoint
from pathlib import Path
from sklearn.metrics import ConfusionMatrixDisplay

from python import classes, metrics, data_loader

In [None]:
# Run parameters. notebook_classification ("") and notebook_cv (0) are
# "not provided" placeholders that the validation cell rejects, so they must be
# overridden for every run (presumably injected per run, e.g. via papermill — confirm).
notebook_name = "main"
notebook_classification = ""   # must be "models" or "types"
notebook_cv = 0                # validation fold, 1..total_cv

# Cross-validation and training hyperparameters.
total_cv = 5                   # number of cross-validation folds
max_epochs = 200               # upper bound; EarlyStopping may end training earlier
early_stopping_patience = 5    # epochs without val_loss improvement before stopping
batch_size = 64
image_size = 128               # side length of the square model input, in pixels

# Root directories (relative paths, resolved against the working directory);
# later cells extend them with per-run subdirectories.
data_dir = "../../data"
tensorboard_dir = "../../out/logs"
metrics_dir = "../../out/metrics"
models_dir = "../../out/keras"
weights_dir = "../../out/weights"

In [None]:
# Fail fast on missing or inconsistent run parameters, before any output
# directory is touched or training starts.
assert notebook_name != "", "notebook_name must be provided"
assert notebook_classification in ['models', 'types'], "notebook_classification must be one of ['models', 'types']"
assert notebook_cv != 0, "notebook_cv must be provided"
# Fold indices are 1-based here (0 is the "not provided" sentinel and the upper
# bound total_cv is inclusive). Negative values passed both checks around this
# one, so reject them explicitly.
assert notebook_cv > 0, "notebook_cv must be a positive fold number"
assert notebook_cv <= total_cv, "notebook_cv must not be greater than total_cv"

In [None]:
# Class labels for the selected classification task. The class count is also
# pushed into the metrics module as a module-level setting (presumably read by
# its custom metric functions — confirm in python/metrics).
class_names = classes.class_names[notebook_classification]
classes_num = len(class_names)
metrics.classes_num = classes_num

# Artifact layout: <root>/<model>/<classification>/<name>/...
# TensorBoard logs, the .keras model and the weights are per fold (cv<N>).
# The metrics JSON has no fold in its path; the fold is passed to save_metrics instead.
notebook_model = "simple"
run_prefix = f"{notebook_model}/{notebook_classification}/{notebook_name}"

# NOTE(review): data_dir and tensorboard_dir overwrite their config-cell values,
# so re-running this cell without re-running the config cell nests the suffix a
# second time (e.g. ../../data/models/models). Re-run the config cell first.
data_dir = Path(data_dir) / f"{notebook_classification}"
tensorboard_dir = Path(tensorboard_dir) / f"{run_prefix}/cv{notebook_cv}"
metrics_file = Path(metrics_dir) / f"{run_prefix}.json"
model_file = Path(models_dir) / f"{run_prefix}/cv{notebook_cv}.keras"
weights_file = Path(weights_dir) / f"{run_prefix}/cv{notebook_cv}.weights.h5"

In [None]:
# Refuse to retrain a run/fold whose final model was already saved. This check
# runs before the next cell deletes this fold's TensorBoard logs.
model_already_saved = model_file.is_file()
assert not model_already_saved, "This model already exists"

In [None]:
# Start from an empty TensorBoard log directory for this fold, so curves from an
# earlier attempt do not mix with this run. A missing directory is not an error.
shutil.rmtree(tensorboard_dir, ignore_errors=True)

# Create every output directory up front (including parents; no-op if present).
for output_dir in (tensorboard_dir, metrics_file.parent, model_file.parent, weights_file.parent):
    output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Sanity check: how many GPUs TensorFlow can see (0 means training runs on CPU).
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [None]:
# Build the train/validation datasets for this fold. Per the argument names,
# fold notebook_cv out of total_cv is held out for validation. The splitting
# itself lives in the python.data_loader module.
loader_buffer_size = 10000  # passed as buffer_size; presumably the shuffle buffer — confirm in data_loader
train_data, val_data = data_loader.load_data(
    data_dir=data_dir,
    class_names=class_names,
    image_size=image_size,
    batch_size=batch_size,
    val_fold=notebook_cv,
    total_folds=total_cv,
    buffer_size=loader_buffer_size,
)

In [None]:
# Baseline CNN ("simple" model). The Input shape assumes 3-channel images of
# image_size x image_size pixels, as produced by data_loader (confirm there).
model = models.Sequential()
model.add(layers.Input(shape=(image_size, image_size, 3)))
# Scale pixel values by 1/255 inside the graph (assumes 0-255 inputs), so the
# saved .keras model applies the same scaling at inference time.
model.add(layers.Rescaling(1./255))

# Feature extractor: three Conv(3x3, ReLU) -> MaxPool(2x2) -> Dropout(0.2)
# stages, then one final Conv(3x3, ReLU) without pooling or dropout.
for conv_filters in (32, 64, 64):
    model.add(layers.Conv2D(conv_filters, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Dropout(0.2))
model.add(layers.Conv2D(32, (3, 3), activation='relu'))

# Classifier head: flatten the feature maps and apply heavier dropout (0.4)
# before a 64-unit dense layer.
model.add(layers.Flatten())
model.add(layers.Dropout(0.4))
model.add(layers.Dense(64, activation='relu'))

# One softmax output per class. These are probabilities, which matches the
# SparseCategoricalCrossentropy loss (default from_logits=False) used when compiling.
model.add(layers.Dense(classes_num, activation='softmax'))

model.summary()

In [None]:
# Adam with default settings. Sparse categorical cross-entropy expects integer
# class indices as labels (presumably what data_loader yields — confirm).
# Alongside accuracy, track the project's F1 / precision / recall metrics.
tracked_metrics = ["accuracy", metrics.f1_m, metrics.precision_m, metrics.recall_m]
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=tracked_metrics,
)

In [None]:
# Training callbacks:
#  - checkpoint: write weights to weights_file only when val_loss improves
#    (val_loss is ModelCheckpoint's default monitor, spelled out here).
#  - early_stop: stop after early_stopping_patience epochs without val_loss
#    improvement and restore the best epoch's weights when it stops.
#  - tensorboard: per-epoch logs under this fold's log directory.
checkpoint = ModelCheckpoint(weights_file, monitor="val_loss", save_best_only=True, save_weights_only=True)
early_stop = EarlyStopping(monitor="val_loss", patience=early_stopping_patience, restore_best_weights=True)
tensorboard = TensorBoard(log_dir=tensorboard_dir)
training_callbacks = [checkpoint, early_stop, tensorboard]

# max_epochs is only an upper bound; the trailing ';' hides the History repr.
model.fit(train_data, epochs=max_epochs, validation_data=val_data, callbacks=training_callbacks);

In [None]:
# Run the trained model over the validation fold. model_predict returns the true
# labels together with the predictions so the two stay aligned.
val_predictions = metrics.model_predict(model, val_data)
y_true, y_pred = val_predictions

In [None]:
# Score this fold's validation predictions and persist them to the run's metrics
# file, passing notebook_cv so save_metrics can record which fold they belong to
# (presumably merged with other folds' entries in the same JSON — confirm).
fold_metrics = metrics.evaluate_metrics(y_true, y_pred)
metrics.save_metrics(fold_metrics, metrics_file, notebook_cv)
# Bare last expression: displayed as the cell output.
fold_metrics

In [None]:
# Confusion matrix for this fold's validation set. normalize="true" divides each
# row (true class) by its count, so the diagonal is per-class recall.
# The labels are pinned to every class index 0..classes_num-1 (assumes
# y_true/y_pred are integer class indices, as the sparse loss implies). Without
# this, sklearn infers labels from y_true ∪ y_pred; a class absent from both
# then shrinks the matrix, and the mismatch with display_labels raises ValueError.
side = max(classes_num, 6)  # grow with class count, 6 inches minimum
fig, ax = plt.subplots(figsize=(side, side))
cmp = ConfusionMatrixDisplay.from_predictions(
    y_true,
    y_pred,
    labels=list(range(classes_num)),
    normalize="true",
    display_labels=class_names,
    colorbar=False,
    xticks_rotation='vertical',
    ax=ax,
)
ax.set_title(f"{notebook_model} / {notebook_classification} / {notebook_name} — fold {notebook_cv} (normalized by true class)")
fig.tight_layout()  # keep the vertical tick labels inside the figure

In [None]:
# Persist the trained model as a .keras file. Its existence is what the guard
# cell near the top checks, so a saved fold is not retrained on the next run.
model.save(model_file)