### Importing key libraries

In [None]:
import os, json, warnings
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

warnings.filterwarnings('ignore')

### TF configuration

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    # Indexing physical_devices[0] would raise IndexError on a CPU-only machine.
    print("No GPU detected; running on CPU.")
for gpu in physical_devices:
    try:
        # Applied to every visible GPU, not just the first one.
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised when the GPU has already been initialized (e.g. this cell is
        # re-run in a live kernel); memory growth can only be set before that.
        print(f"Could not enable memory growth for {gpu.name}: {err}")

# TF1-compat session with the same allow_growth policy.
# NOTE(review): `session` is not used by any other cell visible in this notebook — confirm it is still needed.
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
config = tf.compat.v1.ConfigProto(gpu_options=gpu_options)
session = tf.compat.v1.Session(config=config)

### Hyperparameters

In [None]:
IMG_SIZE = 224 # images are resized to IMG_SIZE x IMG_SIZE
LEARNING_RATE = 0.0001
BATCH_SIZE = 32
DROPOUT_PROB = 0.3
NUM_EPOCHS = 200 # upper bound; EarlyStopping normally ends training sooner
# EarlyStopping
PATIENCE = 5 # no. of epochs with no improvement after which training will be stopped
# ReduceLROnPlateau
# LR_PATIENCE must stay below PATIENCE: when both are equal, training is stopped on the
# very epoch the learning rate is first reduced, so the reduction never takes effect.
LR_PATIENCE = 2 # no. of epochs with no improvement after which learning rate will be reduced
LR_FACTOR = 0.1 # factor by which the learning rate will be reduced
MIN_LR = 1e-6 # lower bound on the learning rate

### Dataset splitting (one-time)

In [None]:
output_dir = r'splitted_data' # Root folder of the split dataset (train/val/test subfolders are read from it). Intentionally declared in separate cell

In [None]:
# # !pip install -q split-folders
# import splitfolders

# dataset_dir = r'dataset'

# # Make output dir if doesn't exist, else skip
# os.makedirs(output_dir, exist_ok=True)

# splitfolders.ratio(dataset_dir, output=output_dir, seed=1337, ratio=(0.75, 0.15, 0.1))
# print(f"Dataset successfully splitted into: {os.listdir(output_dir)}")

### Dataset config

In [None]:
def helper_ds(partition, shuffle_status=True):
    """Load one partition of the split dataset as a batched image dataset.

    Args:
        partition: Name of the sub-folder of `output_dir` to read (e.g. 'train').
        shuffle_status: Whether the loader shuffles the samples.

    Returns:
        The dataset built by `tf.keras.utils.image_dataset_from_directory`, with
        images resized to IMG_SIZE x IMG_SIZE, batches of BATCH_SIZE and binary
        labels inferred from the folder names.
    """
    partition_dir = os.path.join(output_dir, partition)
    loader_options = {
        'image_size': (IMG_SIZE, IMG_SIZE),
        'batch_size': BATCH_SIZE,
        'seed': 1337,
        'labels': 'inferred',
        'label_mode': 'binary',
        'shuffle': shuffle_status,
    }
    return tf.keras.utils.image_dataset_from_directory(directory=partition_dir, **loader_options)

# Build the three partitions; only the test set keeps a fixed (unshuffled) order.
train_ds, test_ds, val_ds = (
    helper_ds('train'),
    helper_ds('test', shuffle_status=False),
    helper_ds('val'),
)

### Data visualization

In [None]:
# Preview up to 12 images from the first training batch, each titled with its class.
# Titles are looked up in the dataset's inferred class names (label k -> class_names[k])
# instead of a hard-coded mapping, which can silently disagree with the folder order.
preview_class_names = train_ds.class_names

plt.figure(figsize=(10, 8))
for images, labels in train_ds.take(1):
    # Guard against a first batch that holds fewer than 12 images.
    for i in range(min(12, len(images))):
        ax = plt.subplot(3, 4, i + 1)
        plt.imshow(np.array(images[i]).astype("uint8"))
        label_index = int(np.ravel(labels[i])[0])
        plt.title(preview_class_names[label_index])
        plt.axis("off")
# Lay the grid out once, after every subplot exists.
plt.tight_layout()
plt.show()

In [None]:
class_names = train_ds.class_names
print(f"Class names: {class_names}")

# Persist the class names so that later cells can reload them from disk
with open('class_names.json', 'w') as f:
    f.write(json.dumps(class_names))

# Sanity check: report whether the file is present on disk
if os.path.exists('class_names.json'):
    print("Class names successfully saved.")
else:
    print("Failed to save the file.")

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# Normalize pixel values (currently disabled)
# normalization_layer = tf.keras.layers.Rescaling(1./255)
# train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y), num_parallel_calls=AUTOTUNE)
# val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y), num_parallel_calls=AUTOTUNE)
# test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y), num_parallel_calls=AUTOTUNE)

# Enable prefetching and shuffling.
# train_ds is already batched, so shuffle() buffers whole batches: a buffer of 10000 would
# try to hold up to 10000 batches in RAM (a 32 x 224 x 224 x 3 float32 batch is ~19 MB).
# The loader already shuffles the samples, so a small batch-level buffer is sufficient here.
SHUFFLE_BUFFER_BATCHES = 16
train_ds = train_ds.shuffle(buffer_size=SHUFFLE_BUFFER_BATCHES).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)

### Data augmentation example (optional)

In [None]:
# Stand-alone augmentation layers used only for the preview below.
# No `input_shape` argument: it has no effect on a layer that is called directly
# (outside a Sequential model), and newer Keras versions warn about it.
data_aug = [
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomZoom(0.2),
    tf.keras.layers.RandomContrast(0.2),
    tf.keras.layers.RandomBrightness([-0.3, 0.1])
]

def data_augmentation(images):
    """Pass `images` through every layer of `data_aug`, in order, and return the result."""
    for layer in data_aug:
        # training=True is passed explicitly: the random layers only transform
        # their input in training mode.
        images = layer(images, training=True)
    return images

plt.figure(figsize=(10, 6))
for images, _ in train_ds.take(1):
    # Only the first image is displayed, so augment just that one (batch axis kept).
    first_image = images[:1]
    for i in range(8):
        augmented_images = data_augmentation(first_image)
        ax = plt.subplot(2, 4, i + 1)
        plt.imshow(np.array(augmented_images[0]).astype("uint8"))
        plt.title(f"augmented {i+1}")
        plt.axis("off")
# Lay the grid out once, after every subplot exists.
plt.tight_layout()
plt.show()

In [None]:
# train_ds was re-assigned after the class names were read, so attach them to the current object
setattr(train_ds, 'class_names', class_names)
print(train_ds.class_names)

In [None]:
# Confirm that the training dataset still exposes its class names
try:
    print("Class Names:", train_ds.class_names)
except AttributeError:
    print("Error: train_ds does not have class_names. Verify dataset creation.")

### Stacking models

In [None]:
def stacked_model(input_shape=(IMG_SIZE, IMG_SIZE, 3)):
    """Build and compile a binary classifier stacked on three frozen ImageNet backbones.

    DenseNet121, EfficientNetV2B2 and MobileNetV2 each extract features from the same
    augmented input. Their globally average-pooled features are concatenated and fed
    to a Dense(1024) -> Dropout -> Dense(1, sigmoid) head.

    Args:
        input_shape: (height, width, channels) of the input images. The model expects
            raw pixel values in [0, 255]; backbone-specific scaling is done inside.

    Returns:
        A compiled `tf.keras.Model` (Adam, binary cross-entropy; accuracy, AUC,
        precision and recall metrics).
    """
    # Accept lists as well as tuples without keeping a mutable default argument.
    input_shape = tuple(input_shape)

    d_net = tf.keras.applications.DenseNet121(include_top=False, weights='imagenet', input_shape=input_shape)
    e_net = tf.keras.applications.EfficientNetV2B2(include_top=False, weights='imagenet', input_shape=input_shape)
    m_net = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=input_shape)

    # Freeze the models
    for backbone in (d_net, e_net, m_net):
        backbone.trainable = False

    # Input layer
    inputs = tf.keras.Input(shape=input_shape)

    # Data augmentation (follows `input_shape` rather than the global image size)
    data_augmentation = tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(0.2),
        # tf.keras.layers.RandomZoom(0.2),
        tf.keras.layers.RandomContrast(0.2),
        tf.keras.layers.RandomBrightness([-0.3, 0.1]),
    ])

    augmented_inputs = data_augmentation(inputs)

    # Backbone-specific preprocessing. The ImageNet weights only produce meaningful
    # features when each network sees the input scaling it was trained with:
    #   - DenseNet121: pixels scaled to [0, 1], then normalized per channel with the
    #     ImageNet mean / std.
    #   - MobileNetV2: pixels scaled to [-1, 1].
    #   - EfficientNetV2B2: preprocessing is built into the model, raw [0, 255] is expected.
    d_net_inputs = tf.keras.layers.Rescaling(1.0 / 255)(augmented_inputs)
    d_net_inputs = tf.keras.layers.Normalization(
        axis=-1,
        mean=[0.485, 0.456, 0.406],
        variance=[0.229 ** 2, 0.224 ** 2, 0.225 ** 2],
    )(d_net_inputs)
    m_net_inputs = tf.keras.layers.Rescaling(1.0 / 127.5, offset=-1.0)(augmented_inputs)

    # Extracted features
    d_net_features = d_net(d_net_inputs)
    e_net_features = e_net(augmented_inputs)
    m_net_features = m_net(m_net_inputs)

    # Global average pooling
    d_net_pooling = tf.keras.layers.GlobalAveragePooling2D()(d_net_features)
    e_net_pooling = tf.keras.layers.GlobalAveragePooling2D()(e_net_features)
    m_net_pooling = tf.keras.layers.GlobalAveragePooling2D()(m_net_features)

    # Combine outputs of all three backbones. The MobileNetV2 branch used to be built
    # and pooled but then left out of the concatenation, so it never reached the head.
    combined_outputs = tf.keras.layers.concatenate([d_net_pooling, e_net_pooling, m_net_pooling])
    # REFERENCE: https://stackoverflow.com/a/71170687/23011800
    outputs_pre = tf.keras.layers.Dense(1024, activation='relu')(combined_outputs)
    outputs_pre = tf.keras.layers.Dropout(DROPOUT_PROB)(outputs_pre)
    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(outputs_pre)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss='binary_crossentropy',
        # Explicit metric names keep the keys of the training logs stable
        # when the model is rebuilt in the same kernel.
        metrics=[
            'accuracy',
            tf.keras.metrics.AUC(name='auc'),
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
        ],
    )
    return model

### Defining callbacks

In [None]:
# Stop once val_loss has not improved for PATIENCE epochs and roll back to the best weights
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=PATIENCE,
    restore_best_weights=True,
)

# Multiply the learning rate by LR_FACTOR after LR_PATIENCE stagnant epochs, never going below MIN_LR
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=LR_FACTOR,
    patience=LR_PATIENCE,
    min_lr=MIN_LR,
)

# Keep only the checkpoint with the lowest val_loss seen so far
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "checkpoints/best_model.keras",
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)

# Custom training curve callback
from IPython.display import clear_output

# Per-epoch metric histories; kept at module level so other cells can still read them.
train_losses=[]; val_losses=[]; precision_scores=[]; recall_scores=[]

class TrainingCurveCallback(tf.keras.callbacks.Callback):
    """Redraw the loss and precision/recall curves in the cell output after every epoch.

    Values are appended to the module-level lists `train_losses`, `val_losses`,
    `precision_scores` and `recall_scores`. The lists are emptied whenever a new
    training run starts, so curves from different `fit()` calls are never mixed.
    """

    def on_train_begin(self, logs=None):
        # Without this reset, re-running the training cell would keep appending to the
        # histories of the previous run and plot both runs as one curve.
        for history in (train_losses, val_losses, precision_scores, recall_scores):
            history.clear()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Loss values are required for the plot; a missing key fails loudly here.
        train_losses.append(logs['loss'])
        val_losses.append(logs['val_loss'])
        # Precision / recall are optional: NaN simply leaves a gap in the curve.
        precision_scores.append(logs.get('precision', float('nan')))
        recall_scores.append(logs.get('recall', float('nan')))

        clear_output(wait=True) # clear output before plotting

        fig, ax1 = plt.subplots(figsize=(10, 5)) # create figure; will contain loss and precision/recall curves
        fig.suptitle('Training Curves')

        # loss curve (ax1 - left y axis)
        ax1.plot(train_losses, label='Train Loss')
        ax1.plot(val_losses, label='Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_xlim(left=0)
        ax1.set_ylabel('Loss function')

        # determine left y axis range from min/max loss values
        y_ax1_min = min(min(train_losses), min(val_losses)) # find lowest loss value across both curves
        y_ax1_min = max(0, y_ax1_min - 0.1) # add some padding to the bottom of the plot. lower bound can't be less than 0
        y_ax1_max = max(max(train_losses), max(val_losses)) + 0.1 # find highest loss value across both curves, add some padding to the top of the plot
        ax1.set_ylim(y_ax1_min, y_ax1_max) # set y axis limits

        # Best epoch (smallest validation loss); 0-based index into the histories
        best_epoch = int(np.argmin(val_losses))
        best_loss = val_losses[best_epoch]

        # add vertical line for best epoch
        ax1.vlines(best_epoch, ymin=0, ymax=y_ax1_max, linestyles='dashed', colors='black',
                   label=f'best epoch={best_epoch}\nval loss={best_loss:.3f}')
        ax1.legend(loc='upper left')

        # create right y axis for precision/recall curves
        ax2 = ax1.twinx()

        # Plot precision/recall (right y-axis)
        ax2.plot(precision_scores, label=f'Precision\n{precision_scores[best_epoch]:.3f} @ {best_epoch}', color='red')
        ax2.plot(recall_scores, label=f'Recall\n{recall_scores[best_epoch]:.3f} @ {best_epoch}', color='green')
        ax2.set_ylabel('Precision / Recall')
        ax2.set_ylim(0, 1)
        ax2.legend(loc='upper right')
        plt.tight_layout()
        plt.show()
        # One figure is created per epoch; release it so figures do not pile up in memory.
        plt.close(fig)

### Initializing the model

In [None]:
ensemble_model = stacked_model() # Builds and compiles the ensemble with the default input shape. Intentionally declared in separate cell

### Model summary (optional)

In [None]:
# ensemble_model.summary()

### Training the model

In [None]:
# A fresh TrainingCurveCallback is created for every run of this cell
training_callbacks = [early_stopping, reduce_lr, model_checkpoint, TrainingCurveCallback()]

history = ensemble_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
    callbacks=training_callbacks,
    shuffle=True,
)

### Saving the model

In [None]:
h5_filename = 'final.h5' # File the trained model is saved to and later reloaded from. Intentionally declared in separate cell

In [None]:
ensemble_model.save(h5_filename)

# Sanity check: report whether the model file is present on disk
if os.path.exists(h5_filename):
    print(f"Model successfully saved as {h5_filename}")
else:
    print("Failed to save the model.")

##### ******* Can start from here once training is finished: re-run only the imports, hyperparameters, dataset config and `h5_filename` cells first (no need to re-run training) *******

### Load model and JSON file

In [None]:
# Restore the trained model from the file it was saved to
trained_ensemble = tf.keras.models.load_model(h5_filename)

# Read the class names back from the JSON file
with open('class_names.json', 'r') as f:
    class_names = json.loads(f.read())

### Model evaluation (w/ val dataset)

In [None]:
# evaluate() returns the loss first, then the compiled metrics in order (accuracy comes first)
val_results = trained_ensemble.evaluate(val_ds)
final_val_loss, final_val_accuracy = val_results[0], val_results[1]
rest = list(val_results[2:])
print(f"Ensemble Model - Validation Loss: {final_val_loss:.4f}, Validation Accuracy: {final_val_accuracy:.4f}")

### Model testing (w/ test dataset)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Collecting predictions and true labels in a single pass, so both stay aligned.
# Everything is flattened to 1-D integer arrays: the batches come as (batch, 1) arrays,
# and per-sample numpy arrays are unhashable, which made set(y_true) raise TypeError.
y_pred_batches = []; y_true_batches = []
for images, labels in test_ds:
    preds = trained_ensemble.predict(images, verbose=0)
    # Sigmoid output: a probability above 0.5 means class index 1.
    y_pred_batches.append((np.ravel(preds) > 0.5).astype(int))
    y_true_batches.append(np.ravel(labels.numpy()).astype(int))

y_pred = np.concatenate(y_pred_batches)
y_true = np.concatenate(y_true_batches)

# Ensuring class names match the number of classes
if len(class_names) < len(np.unique(y_true)):
    raise ValueError("Number of class names does not match the number of unique classes in the labels.")

# Computing confusion matrix. Passing every class index keeps the matrix shape equal to
# len(class_names) even when a class never occurs in the test labels or predictions.
class_indices = list(range(len(class_names)))
cm = confusion_matrix(y_true, y_pred, labels=class_indices)

# Display the confusion matrix on an explicit axes; calling plt.figure() first and then
# disp.plot() without `ax` left an empty extra figure behind.
fig, ax = plt.subplots(figsize=(8, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap=plt.cm.Blues, values_format='d', ax=ax)
# Rotating tick labels for better readability
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
plt.setp(ax.get_yticklabels(), rotation=45, ha='right')
ax.set_title('Confusion Matrix')
fig.tight_layout()
plt.show()

### Model testing 2 (w/ external data)

In [None]:
# Define the external dataset path
external_data_dir = 'path_here'
predictions = []  # (filename, predicted class name) pairs
GRID_COLUMNS = 4

# Case-insensitive extension match; sorted so that the output order is reproducible.
image_files = sorted(
    name for name in os.listdir(external_data_dir)
    if name.lower().endswith(('.png', '.jpg', '.jpeg'))
)

print("Predictions for Test Dataset:")
labelled_images = []
for img_file in image_files:
    img_path = os.path.join(external_data_dir, img_file)

    # Load and preprocess the image
    img = tf.keras.preprocessing.image.load_img(img_path, target_size=(IMG_SIZE, IMG_SIZE))
    img_array = tf.keras.preprocessing.image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)

    preds = trained_ensemble.predict(img_array, verbose=0)
    # The model has a single sigmoid output, so argmax over it is always 0.
    # Threshold the probability instead: above 0.5 means class index 1.
    predicted_class_index = int(np.ravel(preds)[0] > 0.5)
    predicted_label = class_names[predicted_class_index]

    # Store the prediction along with the image filename
    predictions.append((img_file, predicted_label))
    print(f"{img_file}: {predicted_label}")
    labelled_images.append((img, predicted_label))

if labelled_images:
    # Size the grid from the number of images; a fixed 4x4 grid fails on the 17th image.
    n_cols = min(GRID_COLUMNS, len(labelled_images))
    n_rows = -(-len(labelled_images) // n_cols)  # ceiling division

    plt.figure(figsize=(15, 15))
    for counter, (img, predicted_label) in enumerate(labelled_images, start=1):
        # Display the image and prediction
        plt.subplot(n_rows, n_cols, counter)
        plt.imshow(img)
        plt.title(f"Predicted: {predicted_label}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()
else:
    print(f"No .png/.jpg/.jpeg images found in {external_data_dir!r}.")