In [None]:
import os
from random import randint

import numpy as np
import tensorflow as tf
from keras.src.layers import Dropout, GlobalAveragePooling2D
from sklearn.utils import class_weight
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import layers
from tensorflow.keras.models import Model


In [None]:
# Dataset root handed to image_dataset_from_directory further down.
# Set the STED_DATA_DIR environment variable to point the notebook at another
# copy of the data instead of editing this cell; when it is unset, the original
# author's local path is used.
dst_path = os.environ.get(
    "STED_DATA_DIR",
    "/Users/sosen/UniProjects/eng-thesis/data/data-uncompressed/2D-tiff-brightest-areas-sted-128",
)

# Define parameters
batch_size = 64   # images per batch produced by the dataset loader
img_height = 128  # images are resized to (img_height, img_width) on load
img_width = 128

In [None]:
def copy_red_to_green_and_blue(image, label):
    """
    Return an image whose three channels are all copies of the input's
    first (red) channel, together with the unchanged label.

    Args:
        image: tensor whose last axis is the channel axis; only channel 0 is read.
        label: passed through untouched so the function can be used in `Dataset.map`.

    Returns:
        Tuple `(new_image, label)` where `new_image` has three channels on its last axis.
    """
    # Slice with 0:1 rather than index 0 so the channel axis is kept (size 1)
    # and the pieces can be concatenated along it.
    red = image[..., 0:1]
    replicated = tf.concat([red] * 3, axis=-1)
    return replicated, label

In [None]:

# Load the image folder once and get both splits back (subset="both"):
# 70 % training / 30 % validation. Labels are inferred from the directory
# layout and returned one-hot encoded ('categorical'); images are resized to
# (img_height, img_width). The fixed seed is shared by the shuffle and the split.
split_options = dict(validation_split=0.3, subset="both", seed=123)

train_ds, val_ds = tf.keras.utils.image_dataset_from_directory(
    dst_path,
    labels='inferred',
    label_mode='categorical',
    batch_size=batch_size,
    image_size=(img_height, img_width),
    **split_options,
)

In [None]:

# Retrieve number of classes. This must happen before any `.map()` call below,
# because the mapped datasets no longer carry the `class_names` attribute.
class_names = train_ds.class_names
num_classes = len(class_names)

# Data augmentation (training only).
# There is deliberately no Rescaling / per-image standardization step here:
# the Keras EfficientNet applications ship their own input preprocessing and
# expect float pixel values in the [0, 255] range. Normalizing the images
# beforehand squeezes the input into a tiny range around zero, which the
# frozen ImageNet backbone cannot make use of.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomRotation(0.1),
    # You can add more augmentations if needed
    # tf.keras.layers.RandomZoom(0.15),
    # tf.keras.layers.RandomWidth(0.2),
    # tf.keras.layers.RandomHeight(0.2),
    # tf.keras.layers.RandomShear(0.15),
    tf.keras.layers.RandomFlip("horizontal"),
])

AUTOTUNE = tf.data.AUTOTUNE

# Modify the datasets to use the red channel for all three channels
train_ds = train_ds.map(copy_red_to_green_and_blue, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(copy_red_to_green_and_blue, num_parallel_calls=AUTOTUNE)

# Augment the training set only; validation images are evaluated as loaded.
train_ds = train_ds.map(
    lambda x, y: (data_augmentation(x, training=True), y),
    num_parallel_calls=AUTOTUNE,
)

# Optimize dataset performance: overlap input preparation with training.
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)

In [None]:
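# Optional sanity check: uncomment to inspect the shapes of one training batch.
# A tf.data.Dataset is iterable but not an iterator, hence the iter() call.
# x_sample, y_sample = next(iter(train_ds))
# print("Shape of input batch: ", x_sample.shape)
# print("Shape of labels batch: ", y_sample.shape)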

In [None]:
from tensorflow.keras import backend as K


def f1(y_true, y_pred):
    """
    Batch-wise F1 score on hard (rounded) predictions.

    Predictions are rounded to 0/1 first. True positives, predicted positives
    and actual positives are then summed over every element of the batch, so
    the value is computed per batch rather than accumulated over an epoch.

    Args:
        y_true: ground-truth targets (one-hot encoded in this notebook).
        y_pred: model outputs for the same batch.

    Returns:
        Scalar `2 * P * R / (P + R)`, where precision P and recall R share the
        same true-positive count; `K.epsilon()` guards every division.
    """
    def _count_positives(values):
        # Clip into [0, 1], round to hard 0/1 decisions, then count the ones.
        return K.sum(K.round(K.clip(values, 0, 1)))

    hard_pred = K.round(y_pred)

    hits = _count_positives(y_true * hard_pred)
    # Precision: share of predicted positives that are correct.
    batch_precision = hits / (_count_positives(hard_pred) + K.epsilon())
    # Recall: share of actual positives that were found.
    batch_recall = hits / (_count_positives(y_true) + K.epsilon())

    harmonic = (batch_precision * batch_recall) / (batch_precision + batch_recall + K.epsilon())
    return 2 * harmonic

In [None]:


# ImageNet-pretrained EfficientNetB0 backbone without its classification head.
# `input_shape` uses the (height, width, channels) layout so that it matches
# the `image_size=(img_height, img_width)` used when the datasets were loaded.
base_model = tf.keras.applications.EfficientNetB0(
    weights='imagenet',
    include_top=False,
    input_shape=(img_height, img_width, 3),
)

# Freeze the backbone: only the new head defined below is trained.
base_model.trainable = False

# Classification head: global average pooling -> batch norm -> dropout -> softmax.
x = base_model.output
x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.2, name="top_dropout")(x)
# x = layers.Dense(512, activation='relu')(x)
# One output unit per class found in the dataset directory.
predictions = layers.Dense(num_classes, activation='softmax', name="pred")(x)

# Define the final model
model = Model(inputs=base_model.input, outputs=predictions, name="EfficientNet")

# Compile the model. The labels are one-hot, hence categorical cross-entropy;
# the custom `f1` metric is logged as 'f1' / 'val_f1'.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy', f1])


In [None]:

# Count occurrences of each class in the training dataset.
# The labels are one-hot, so argmax recovers the integer class index.
labels = np.concatenate([y for x, y in train_ds], axis=0)
label_indices = np.argmax(labels, axis=1)

# Compute "balanced" class weights for the classes present in the training split
present_classes = np.unique(label_indices)
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=label_indices
)

# Key every weight by its actual class index. Enumerating the weights would
# shift them onto the wrong classes whenever a class has no training samples,
# because `present_classes` then skips that index. Classes that are absent
# from the training split keep the neutral weight 1.0.
train_class_weights = {class_index: 1.0 for class_index in range(labels.shape[1])}
train_class_weights.update(
    {int(class_index): float(weight)
     for class_index, weight in zip(present_classes, class_weights)}
)


early_stopping = EarlyStopping(
    monitor='val_f1',  # validation value of the custom `f1` metric
    patience=3,
    mode='max',  # since higher F1 scores are better
    restore_best_weights=True
)

# Train the model
history = model.fit(
    train_ds,
    epochs=20,
    validation_data=val_ds,
    class_weight=train_class_weights,
    callbacks=[early_stopping]  # add the early stopping callback
)

In [None]:
from sklearn.metrics import classification_report
import numpy as np
import tensorflow as tf

# Assuming model is already trained and compiled

# Collect integer class predictions and ground truth batch by batch. Images and
# labels come out of the dataset together, so the two lists stay aligned.
pred_indices = []
true_indices = []

for images, batch_labels in val_ds:
    # predict_on_batch handles a single batch directly; calling model.predict()
    # inside a loop adds per-call overhead and prints a progress bar per batch.
    probabilities = model.predict_on_batch(images)
    pred_indices.extend(np.argmax(probabilities, axis=1))
    true_indices.extend(np.argmax(batch_labels.numpy(), axis=1))  # Convert one-hot to integer labels

# Convert lists to numpy arrays
y_pred = np.array(pred_indices)
y_true = np.array(true_indices)

# Pass `labels` explicitly so the rows always line up with `class_names`, even
# when a class never occurs in the validation labels or in the predictions
# (otherwise the number of target names would not match and the call fails).
print(classification_report(
    y_true,
    y_pred,
    labels=np.arange(len(class_names)),
    target_names=class_names,
))

In [None]:
import matplotlib.pyplot as plt

# One stacked panel per tracked metric, each comparing training and validation curves.
fig, axs = plt.subplots(3, figsize=(10, 15))

# (training history key, validation history key, panel title, y-axis label)
panels = [
    ('accuracy', 'val_accuracy', 'Model accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Model loss', 'Loss'),
    ('f1', 'val_f1', 'Model F1 Score', 'F1 Score'),
]

for ax, (train_key, val_key, title, ylabel) in zip(axs, panels):
    ax.plot(history.history[train_key])
    ax.plot(history.history[val_key])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(['Train', 'Val'], loc='upper left')

plt.show()