In [None]:
import os

os.environ["KERAS_BACKEND"] = "torch"
import keras
from keras import layers

# Shared configuration for the rest of the notebook.
IMAGE_SIZE = (224, 224)  # (height, width) every image is resized/padded to when loaded
BATCH_SIZE = 64
SEED = 42

# Seed the Python, NumPy and backend (torch) RNGs up front so weight
# initialisation of the new output layer and training-time shuffling are
# repeatable across Restart & Run All.
# NOTE(review): bit-exact GPU determinism may additionally need backend-specific flags.
keras.utils.set_random_seed(SEED)

In [None]:
# Read the image folder once and split it 80/20 into training and validation
# subsets. Class labels come from the sub-directory names and are one-hot
# encoded ("categorical"). The fixed seed makes both subsets draw from the same
# shuffled file order, so they never overlap and the split is stable between runs.
_data_dir = os.path.join("data", "fish_lizard_monkey_snake", "images")
TRAIN_DS, VAL_DS = keras.utils.image_dataset_from_directory(
    _data_dir,
    # -- split --
    validation_split=0.2,
    subset="both",
    seed=42,
    shuffle=True,
    # -- labels --
    labels="inferred",
    label_mode="categorical",
    # -- image decoding / resizing (padded, not stretched, to IMAGE_SIZE) --
    color_mode="rgb",
    image_size=IMAGE_SIZE,
    interpolation="bilinear",
    pad_to_aspect_ratio=True,
    # -- batching --
    batch_size=BATCH_SIZE,
)

In [None]:
# Training callbacks and metrics. All three callbacks track "val_accuracy",
# the validation version of the metric named "accuracy" in METRICS below.
# NOTE(review): callback objects keep state between fit() calls (e.g. the best
# score seen so far); re-run this cell before training again.

# Stop once val_accuracy has failed to improve by at least 0.015 (1.5 points)
# for 8 epochs, then roll back to the best weights seen.
early_stop_cb = keras.callbacks.EarlyStopping(
    min_delta=0.015,
    monitor="val_accuracy",
    mode="max",
    patience=8,
    restore_best_weights=True,
    verbose=1,
)

# Save the fine-tuned model under its own file name. Writing to
# "bird_dog_classifier.keras" would overwrite the pretrained base model that the
# transfer-learning cell loads (the first epoch always saves, because no
# initial_value_threshold is set), so the base weights would be lost.
CHECKPOINT_PATH = "fish_lizard_monkey_snake_classifier.keras"
checkpointing_cb = keras.callbacks.ModelCheckpoint(
    CHECKPOINT_PATH,
    monitor="val_accuracy",
    mode="max",
    verbose=1,
    save_best_only=True,
    save_weights_only=False,
    save_freq="epoch",
    initial_value_threshold=None,
)

# Halve the learning rate after 6 stagnant epochs (floor 1e-5). The patience is
# shorter than early stopping's 8, so the LR is reduced before training gives up.
lr_cb = keras.callbacks.ReduceLROnPlateau(
    min_delta=0.015,
    monitor="val_accuracy",
    mode="max",
    patience=6,
    factor=0.5,
    min_lr=0.00001,
    verbose=1,
)

METRICS = [
    keras.metrics.Precision(name="precision"),
    keras.metrics.Recall(name="recall"),
    keras.metrics.F1Score(name="f1"),
    # The name "accuracy" is what the callbacks' "val_accuracy" monitor refers to.
    keras.metrics.CategoricalAccuracy(name="accuracy"),
]

In [None]:
# Transfer learning: reuse the pretrained bird/dog classifier as a frozen
# feature extractor and train only a freshly initialised output layer.
# NOTE(review): pop()/add() require the saved model to be a keras.Sequential — confirm.
transfer_model = keras.models.load_model("bird_dog_classifier.keras")
transfer_model.pop()  # drop the old classification head

# Size the new head from the dataset instead of hard-coding the class count.
num_classes = len(TRAIN_DS.class_names)

# Emit probabilities (softmax), not raw logits: Keras Precision/Recall in METRICS
# threshold predictions at 0.5 by default, which is only meaningful for probabilities.
new_layer = layers.Dense(num_classes, activation="softmax", name="new_output_v01")
transfer_model.add(new_layer)

# Freeze every pretrained layer; only the new head (last layer) is trained.
for layer in transfer_model.layers[:-1]:
    layer.trainable = False

transfer_model.summary()

transfer_model.compile(
    optimizer=keras.optimizers.Adam(0.001),
    # from_logits=False because the head already applies softmax.
    loss=keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=METRICS,
)

transfer_history = transfer_model.fit(
    TRAIN_DS,
    validation_data=VAL_DS,
    epochs=50,
    callbacks=[early_stop_cb, checkpointing_cb, lr_cb],
)