In [4]:
import numpy as np
from keras.src.layers import Dropout
from sklearn.utils import class_weight
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.models import Model
import tensorflow as tf

In [5]:
# Directory handed to the dataset loader as the image root.
dst_path = "/Users/sosen/UniProjects/eng-thesis/data/old/temp/CROPPED/NON-STED"

# Loader parameters: square target resolution and mini-batch size.
img_height = img_width = 260
batch_size = 32

In [6]:
# Build the training and validation datasets in a single call: subset="both"
# returns the pair, split 80/20 with a fixed seed.
train_ds, val_ds = tf.keras.utils.image_dataset_from_directory(
    dst_path,
    labels='inferred',
    label_mode='categorical',
    batch_size=batch_size,
    image_size=(img_height, img_width),
    seed=123,
    validation_split=0.2,
    subset="both",
)

# Read the class names now: later cells rebind train_ds to transformed datasets.
class_names = train_ds.class_names
num_classes = len(class_names)

Found 2921 files belonging to 6 classes.
Using 2337 files for training.
Using 584 files for validation.


2024-12-08 22:53:55.104831: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Pro
2024-12-08 22:53:55.104872: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 18.00 GB
2024-12-08 22:53:55.104883: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 6.00 GB
2024-12-08 22:53:55.104921: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-12-08 22:53:55.104950: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [7]:
def copy_red_to_green_and_blue(image,label):
    """
    Return an image whose three channels are all copies of the input's first
    (red) channel, together with the unchanged label.

    Taking and returning the ``(image, label)`` pair lets the function be
    passed straight to ``Dataset.map``.
    """
    # Slice with :1 rather than index 0 so the channel axis is kept and the
    # copies can be concatenated along it.
    red = image[..., :1]
    return tf.concat([red] * 3, axis=-1), label

In [8]:
AUTOTUNE = tf.data.AUTOTUNE

# Map first, prefetch last: the prefetch buffer then holds fully transformed
# batches, so the channel copy overlaps with training instead of running
# after the buffer. num_parallel_calls lets tf.data parallelise the map; the
# element order is unchanged.
train_ds = train_ds.map(copy_red_to_green_and_blue, num_parallel_calls=AUTOTUNE).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.map(copy_red_to_green_and_blue, num_parallel_calls=AUTOTUNE).prefetch(buffer_size=AUTOTUNE)

In [9]:
from tensorflow.keras import backend as K


def f1(y_true, y_pred):
    """
    F1 score of the current batch, usable as a Keras metric function.

    Predictions are first rounded to 0/1, so a score only counts as a positive
    when it rounds up to 1. True positives, predicted positives and actual
    positives are then counted over the whole batch (all samples and classes
    pooled together), and precision and recall are combined as
    ``2 * P * R / (P + R)``. ``K.epsilon()`` is added to every denominator to
    avoid division by zero.

    The value describes one batch only; it is not a dataset-level F1.
    """
    def _count_ones(t):
        # Clip to [0, 1], round to {0, 1}, then count the ones.
        return K.sum(K.round(K.clip(t, 0, 1)))

    hard_pred = K.round(y_pred)

    hits = _count_ones(y_true * hard_pred)   # true positives
    n_predicted = _count_ones(hard_pred)     # everything predicted positive
    n_actual = _count_ones(y_true)           # everything actually positive

    prec = hits / (n_predicted + K.epsilon())
    rec = hits / (n_actual + K.epsilon())
    return 2 * ((prec * rec) / (prec + rec + K.epsilon()))

In [10]:
def normalize(x):
    """
    Standardise ``x`` along its last axis of size 3: subtract the mean and
    divide by the standard deviation.

    The same mean/std pair is applied to all three channels.
    """
    channel_mean = 0.24155269834305526
    channel_std = 0.2809975193618667
    # Alternatives that were previously left commented out here: per-channel
    # [0.485, 0.456, 0.406] / [0.229, 0.224, 0.225], and approximately the
    # constants above expressed on a 0-255 scale (about 61.596 / 71.654).
    mean = [channel_mean] * 3
    std = [channel_std] * 3
    return (x - mean) / std

In [11]:
# Augmentation stage placed in front of the network: a random rotation followed
# by a random horizontal flip.
# Currently disabled options: the red-channel copy and normalize Lambdas,
# Rescaling(1. / 255), Rescaling(scale=2, offset=-1), RandomZoom(0.15),
# RandomWidth(0.1), RandomHeight(0.1), RandomContrast(0.05).
data_augmentation = tf.keras.Sequential(name="data_augmentation")
data_augmentation.add(tf.keras.layers.RandomRotation(0.1))
data_augmentation.add(tf.keras.layers.RandomFlip("horizontal"))

In [12]:
from tensorflow.keras import layers

inputs = layers.Input(shape=(img_height, img_width, 3))
x = data_augmentation(inputs)

# Pre-trained backbone without its classification head, fed by the augmented
# input. It gets its own name so `model` only ever refers to the final network.
base_model = tf.keras.applications.EfficientNetV2B0(
    weights='imagenet',
    include_top=False,
    input_tensor=x,
    pooling=None,
    include_preprocessing=True,
)
# Freeze the backbone: only the new head defined below is trained at first.
base_model.trainable = False

AVG_POOL_TOP = True
if AVG_POOL_TOP:
    # Head: global average pooling -> batch norm -> dropout -> softmax
    x = layers.GlobalAveragePooling2D(name="avg_pool")(base_model.output)
    x = layers.BatchNormalization()(x)

    top_dropout_rate = 0.2
    x = layers.Dropout(top_dropout_rate, name="top_dropout")(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="pred")(x)
else:
    # Alternative head: flatten the feature map and add one hidden dense layer
    x = layers.Flatten()(base_model.output)
    x = layers.Dense(1000, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

# Assemble the full network from the original input to the new head
model = tf.keras.Model(inputs, outputs, name="EfficientNet")

# Compile the model
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy', f1])

# summary() prints the table itself and returns None, so it is not wrapped in print()
model.summary()

Model: "EfficientNet"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 260, 260, 3)]        0         []                            
                                                                                                  
 data_augmentation (Sequent  (None, 260, 260, 3)          0         ['input_1[0][0]']             
 ial)                                                                                             
                                                                                                  
 rescaling (Rescaling)       (None, 260, 260, 3)          0         ['data_augmentation[0][0]']   
                                                                                                  
 normalization (Normalizati  (None, 260, 260, 3)          0         ['rescaling[0][0]']

In [13]:
# Collect the one-hot training labels batch by batch and convert them to class indices
labels = np.concatenate([y for _, y in train_ds], axis=0)
label_indices = np.argmax(labels, axis=1)

# Compute balanced class weights for the classes that actually occur
present_classes = np.unique(label_indices)
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=label_indices
)

# Key each weight by its real class index. Enumerating the weights would shift
# them onto the wrong classes if a class were absent from the training split.
train_class_weights = {
    int(cls): float(weight) for cls, weight in zip(present_classes, class_weights)
}

early_stopping = EarlyStopping(
    monitor='val_f1',  # stop on the validation F1, not the training F1
    patience=4,
    mode='max',  # higher F1 is better
    restore_best_weights=True
)

# Train the model
history = model.fit(
    train_ds,
    epochs=20,
    validation_data=val_ds,
    class_weight=train_class_weights,
    callbacks=[early_stopping]
)

Epoch 1/20


2024-12-08 22:54:02.628044: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
2024-12-08 22:54:03.118679: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


 1/74 [..............................] - ETA: 6:42 - loss: 3.5931 - accuracy: 0.1250 - f1: 0.0851

2024-12-08 22:54:05.337501: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


 3/74 [>.............................] - ETA: 15s - loss: 2.9799 - accuracy: 0.1667 - f1: 0.1157

2024-12-08 22:54:05.550182: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:05.742640: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


 5/74 [=>............................] - ETA: 14s - loss: 2.4984 - accuracy: 0.1500 - f1: 0.0950

2024-12-08 22:54:05.941337: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:06.134277: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


 6/74 [=>............................] - ETA: 13s - loss: 2.6188 - accuracy: 0.1510 - f1: 0.0927

2024-12-08 22:54:06.322214: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


 8/74 [==>...........................] - ETA: 13s - loss: 2.3776 - accuracy: 0.1797 - f1: 0.1250

2024-12-08 22:54:06.569320: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:06.769471: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


 9/74 [==>...........................] - ETA: 13s - loss: 2.3642 - accuracy: 0.1979 - f1: 0.1435

2024-12-08 22:54:06.951176: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


10/74 [===>..........................] - ETA: 13s - loss: 2.3373 - accuracy: 0.1969 - f1: 0.1422

2024-12-08 22:54:07.210822: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


12/74 [===>..........................] - ETA: 13s - loss: 2.2576 - accuracy: 0.2057 - f1: 0.1628

2024-12-08 22:54:07.438558: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:07.647958: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


14/74 [====>.........................] - ETA: 12s - loss: 2.1839 - accuracy: 0.2143 - f1: 0.1689

2024-12-08 22:54:07.896021: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


15/74 [=====>........................] - ETA: 12s - loss: 2.1408 - accuracy: 0.2271 - f1: 0.1831

2024-12-08 22:54:08.104745: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:08.297851: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


17/74 [=====>........................] - ETA: 12s - loss: 2.0524 - accuracy: 0.2537 - f1: 0.2047

2024-12-08 22:54:08.512906: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:08.713571: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:08.933030: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:09.123649: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:09.307433: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:09.523154: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:09.719649: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:09.912629: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:10.113344: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:10.323644: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:10.528226: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:10.749818: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:10.941145: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:11.158612: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:11.417268: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:11.589663: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:11.804593: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:12.038293: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:12.232062: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:12.440658: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:12.667204: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:12.855638: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:13.058717: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:13.280275: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:13.479146: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:13.687331: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:13.908059: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:14.106753: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:14.328192: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:14.569080: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:14.759974: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:14.961718: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:15.213048: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:15.409592: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:15.602641: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:15.815860: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:15.978041: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:16.176721: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:16.378801: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:16.587622: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:16.786218: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:16.988791: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:17.222743: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:17.434302: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:17.638750: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:17.873717: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:18.123037: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:18.320955: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:18.534088: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:18.723877: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:18.933914: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:19.178142: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:19.350125: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:19.524710: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:19.727468: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:19.921488: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.




2024-12-08 22:54:20.120688: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2024-12-08 22:54:20.319672: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.


KeyboardInterrupt: 

In [None]:
from sklearn.metrics import classification_report
import numpy as np
import tensorflow as tf

# Assuming model is already trained and compiled

# Predict on validation data
predictions = []
y_true = []

# Labels and predictions are collected from the same batch in the same pass,
# so they stay aligned whatever order the dataset yields its batches in.
for images, labels in val_ds:
    # verbose=0: no progress bar for every single batch
    preds = model.predict(images, verbose=0)
    predictions.extend(np.argmax(preds, axis=1))
    y_true.extend(np.argmax(labels.numpy(), axis=1))  # Convert one-hot to integer labels

# Convert lists to numpy arrays
y_pred = np.array(predictions)
y_true = np.array(y_true)

# Pass the full label set explicitly so target_names always lines up with the
# class indices, even if a class is missing from both y_true and y_pred.
print(classification_report(
    y_true,
    y_pred,
    labels=np.arange(len(class_names)),
    target_names=class_names,
))

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Compute the confusion matrix over the full label set so its rows and columns
# always match class_names, even if a class never occurs in y_true / y_pred
class_indices = np.arange(len(class_names))
cm = confusion_matrix(y_true, y_pred, labels=class_indices)

# Normalize each row (true class) by its total; a row with no samples stays
# all zeros instead of producing NaN from a division by zero
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

# Plotting the normalized confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Normalized Confusion Matrix')
plt.show()

In [None]:

fig, axs = plt.subplots(3, figsize=(10, 15))

# One panel per metric, top to bottom: (history key, panel title, y-axis label).
# The validation curve is read from the same key prefixed with 'val_'.
panels = [
    ('accuracy', 'Model Accuracy', 'Accuracy'),
    ('loss', 'Model Loss', 'Loss'),
    ('f1', 'Model F1 Score', 'F1 Score'),
]

for ax, (key, title, ylabel) in zip(axs, panels):
    train_values = history.history[key]
    ax.plot(train_values)
    ax.plot(history.history['val_' + key])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(['Train', 'Val'], loc='upper left')
    # Roughly ten ticks on the epoch axis, never with a step below 1
    n_epochs = len(train_values)
    ax.set_xticks(range(0, n_epochs, max(1, n_epochs // 10)))

plt.tight_layout()  # Adjust layout to prevent overlap
plt.show()

In [None]:
from tensorflow.keras.optimizers import Adam


def unfreeze_model(original_model, num_layers=20, learning_rate=1e-4):
    """
    Make the last ``num_layers`` layers of a model trainable and recompile it.

    The model is modified in place: no copy is made, and the object returned
    is ``original_model`` itself.

    Args:
        original_model: Keras model to prepare for fine-tuning.
        num_layers: How many layers, counted from the end of
            ``original_model.layers``, to mark as trainable. A value of zero
            or less unfreezes nothing.
        learning_rate: Learning rate of the Adam optimizer used to recompile.

    Returns:
        The same model instance, recompiled with categorical cross-entropy
        loss and the ``accuracy`` and ``f1`` metrics.
    """
    new_model = original_model

    # Guard the slice: layers[-0:] would select every layer, not none.
    tail_layers = new_model.layers[-num_layers:] if num_layers > 0 else []

    # BatchNormalization layers are skipped and keep their current trainable flag.
    for layer in tail_layers:
        if not isinstance(layer, layers.BatchNormalization):
            layer.trainable = True

    # Recompile after changing the trainable flags, with the fine-tuning learning rate.
    optimizer = Adam(learning_rate=learning_rate)
    new_model.compile(
        optimizer=optimizer,
        loss="categorical_crossentropy",
        metrics=['accuracy', f1]  # f1 is the metric function defined earlier in this notebook
    )

    return new_model

# Fine-tune the already trained `model`. unfreeze_model returns the model it
# was given, so `new_model` and `model` refer to the same object.
new_model = unfreeze_model(model)

early_stopping_unfreeze = EarlyStopping(
    monitor='val_f1',  # stop on the validation F1, not the training F1
    patience=4,
    mode='max',  # higher F1 is better
    restore_best_weights=True
)
epochs = 10  # @param {type: "slider", min:4, max:10}
# NOTE(review): the first training run passed class_weight=train_class_weights;
# it is not passed here. Confirm whether that is intentional.
history_finetune = new_model.fit(
    train_ds,
    epochs=epochs,
    validation_data=val_ds,
    callbacks=[early_stopping_unfreeze],  # the callback created for this run
)

In [None]:
fig, axs = plt.subplots(3, figsize=(10, 15))

# One panel per metric, top to bottom: (history key, panel title, y-axis label).
# The validation curve is read from the same key prefixed with 'val_'.
panels = [
    ('accuracy', 'Model accuracy', 'Accuracy'),
    ('loss', 'Model loss', 'Loss'),
    ('f1', 'Model F1 Score', 'F1 Score'),
]

for ax, (key, title, ylabel) in zip(axs, panels):
    ax.plot(history_finetune.history[key])
    ax.plot(history_finetune.history['val_' + key])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(['Train', 'Val'], loc='upper left')

plt.tight_layout()  # Adjust layout to prevent overlap
plt.show()

In [None]:
predictions = []
y_true = []

# Labels and predictions are collected from the same batch in the same pass,
# so they stay aligned whatever order the dataset yields its batches in.
for images, labels in val_ds:
    # verbose=0: no progress bar for every single batch
    preds = new_model.predict(images, verbose=0)
    predictions.extend(np.argmax(preds, axis=1))
    y_true.extend(np.argmax(labels.numpy(), axis=1))  # Convert one-hot to integer labels

# Convert lists to numpy arrays
y_pred = np.array(predictions)
y_true = np.array(y_true)

# Pass the full label set explicitly so target_names always lines up with the
# class indices, even if a class is missing from both y_true and y_pred.
print(classification_report(
    y_true,
    y_pred,
    labels=np.arange(len(class_names)),
    target_names=class_names,
))