### Importing Necessary Packages and Checking GPU Availability

In [None]:
import os
# TF_CPP_MIN_LOG_LEVEL must be set BEFORE TensorFlow is imported: TensorFlow's
# native runtime emits its start-up log messages during `import tensorflow`,
# so setting the variable afterwards (as this cell originally did) does not
# reliably silence them.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  ## To turn off debugging information

import random
import numpy as np
import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout, LeakyReLU, ReLU
from tensorflow.keras.optimizers import AdamW
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras import mixed_precision

# Seed every RNG the pipeline touches (Python, NumPy, TensorFlow) and ask
# TensorFlow for deterministic kernels so re-runs are reproducible.
random.seed(123)
np.random.seed(123)
tf.random.set_seed(123)
tf.config.experimental.enable_op_determinism()

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

### Creating Image Generators for Each Data Split

In [5]:
from tensorflow.keras.applications.resnet50 import preprocess_input

# Root folder holding the `train`, `val` and `test` split sub-directories
# (flow_from_directory treats each sub-folder of a split as one class).
# Change this to your local dataset location.
DATA_ROOT = '/Your/Directory/Train,val,test'
IMAGE_SIZE = (96, 96)
BATCH_SIZE = 32
SHUFFLE_SEED = 123

# Create ImageDataGenerator for ResNet50 (with its specific preprocessing)
datagen_resnet = ImageDataGenerator(preprocessing_function=preprocess_input)


def make_resnet_generator(split, shuffle):
    """Build a 3-channel (RGB) directory iterator for one data split.

    Args:
        split: sub-directory of DATA_ROOT to read ('train', 'val' or 'test').
        shuffle: whether to shuffle sample order. Must be False for any split
            whose predictions are later compared with ``generator.classes``,
            because that array is in directory order, not shuffled order.

    Returns:
        The iterator returned by ``datagen_resnet.flow_from_directory``.
    """
    return datagen_resnet.flow_from_directory(
        f'{DATA_ROOT}/{split}',
        target_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        class_mode="categorical",
        color_mode="rgb",   # <-- Make it 3-channel (ResNet50 expects RGB)
        shuffle=shuffle,
        seed=SHUFFLE_SEED,  # reproducible shuffling for the training split
    )


train_generator_resnet = make_resnet_generator('train', shuffle=True)
# Evaluation splits keep a fixed order so predictions line up with `.classes`.
val_generator_resnet = make_resnet_generator('val', shuffle=False)
test_generator_resnet = make_resnet_generator('test', shuffle=False)

# Print class labels
print("Class Labels training set:", train_generator_resnet.class_indices)
print("Class Labels validation set:", val_generator_resnet.class_indices)
print("Class Labels testing set:", test_generator_resnet.class_indices)

# All splits must map class names to the same indices; otherwise val/test
# metrics would silently compare mismatched labels.
assert (train_generator_resnet.class_indices
        == val_generator_resnet.class_indices
        == test_generator_resnet.class_indices), "class folders differ between splits"

Found 2592 images belonging to 4 classes.
Found 648 images belonging to 4 classes.
Found 360 images belonging to 4 classes.
Class Labels training set: {'Mild Dementia': 0, 'Moderate Dementia': 1, 'Non Demented': 2, 'Very mild Dementia': 3}


### Model Architecture and Training

In [6]:
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# Ensure reproducibility: re-seed right before building the model so the
# randomly initialised head weights are the same on every run.
import random
import numpy as np

random.seed(123)
np.random.seed(123)
tf.random.set_seed(123)

# Take the input shape and class count from the training generator instead of
# hard-coding them, so the model always matches the data pipeline
# (e.g. target_size=(96, 96) with color_mode="rgb" gives (96, 96, 3)).
input_shape = train_generator_resnet.image_shape
num_classes = train_generator_resnet.num_classes

# Load ResNet50 pre-trained on ImageNet, no top (i.e., no final dense layers)
base_resnet = ResNet50(weights='imagenet',
                       include_top=False,
                       input_shape=input_shape)

x = base_resnet.output

# Collapse the spatial feature map to one vector per image
x = GlobalAveragePooling2D()(x)

# Fully connected dense layer
x = Dense(128, activation='relu')(x)
x = Dropout(0.5)(x)

# Final softmax output layer, one unit per class
outputs = Dense(num_classes, activation='softmax')(x)

model_resnet = Model(inputs=base_resnet.input, outputs=outputs)

# Compiling model (whole network is trainable, i.e. full fine-tuning)
model_resnet.compile(optimizer=Adam(learning_rate=1e-4),
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])

# Setup callbacks: halve the LR when val_loss stalls for 2 epochs; stop after
# 4 epochs without val_accuracy improvement and restore the best weights.
reduce_lr = ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2, min_lr=1e-7)
early_stopping = EarlyStopping(monitor="val_accuracy", patience=4, restore_best_weights=True)

# Train
history_resnet = model_resnet.fit(
    train_generator_resnet,
    validation_data=val_generator_resnet,
    epochs=30,
    callbacks=[reduce_lr, early_stopping]
)

Epoch 1/30


  self._warn_if_super_not_called()


[1m81/81[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m142s[0m 1s/step - accuracy: 0.3913 - loss: 1.9086 - val_accuracy: 0.4213 - val_loss: 2.8420 - learning_rate: 1.0000e-04
Epoch 2/30
[1m81/81[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m108s[0m 1s/step - accuracy: 0.7800 - loss: 0.5843 - val_accuracy: 0.5417 - val_loss: 2.3267 - learning_rate: 1.0000e-04
Epoch 3/30
[1m81/81[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m111s[0m 1s/step - accuracy: 0.8813 - loss: 0.3129 - val_accuracy: 0.7253 - val_loss: 0.9440 - learning_rate: 1.0000e-04
Epoch 4/30
[1m81/81[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 1s/step - accuracy: 0.9228 - loss: 0.1941 - val_accuracy: 0.7824 - val_loss: 0.8047 - learning_rate: 1.0000e-04
Epoch 5/30
[1m37/81[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m1:00[0m 1s/step - accuracy: 0.9681 - loss: 0.1050

KeyboardInterrupt: 

### Model Evaluation

In [None]:
# Evaluate ResNet50 Model
test_loss_resnet, test_acc_resnet = model_resnet.evaluate(test_generator_resnet, verbose=0)
print(f"ResNet50 Test Accuracy: {test_acc_resnet:.4f}")

# The baseline CNN (`model`) and its generator (`test_generator`) are not
# built anywhere in this notebook, so referencing them unconditionally raises
# NameError on a fresh kernel. Only compare when they exist in the session.
if "model" in globals() and "test_generator" in globals():
    # Evaluate Baseline CNN Model
    test_loss_cnn, test_acc_cnn = model.evaluate(test_generator, verbose=0)
    print(f"Baseline CNN Test Accuracy: {test_acc_cnn:.4f}")
else:
    print("Baseline CNN (`model`, `test_generator`) not defined in this session; skipping comparison.")

### Plotting Training Accuracy & Validation Accuracy

In [None]:
# Compare training accuracy
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# The baseline CNN's `history` is not produced in this notebook; plot its
# curves only when it exists in the session (avoids NameError on a fresh kernel).
if "history" in globals():
    ax.plot(history.history['accuracy'], label='CNN Training Acc')
    ax.plot(history.history['val_accuracy'], label='CNN Validation Acc')

ax.plot(history_resnet.history['accuracy'], label='ResNet50 Training Acc')
ax.plot(history_resnet.history['val_accuracy'], label='ResNet50 Validation Acc')

ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Training & Validation Accuracy Comparison')
ax.legend()
plt.show()

### Confusion Matrix

In [None]:
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, f1_score

# Collect true labels and predictions together, batch by batch, in a single
# pass over the test generator. Every prediction is then paired with the label
# of the same image whatever the generator's `shuffle` setting is.
# (`test_generator_resnet.classes` is in directory order and only matches
# `predict()` output when shuffle=False.)
true_batches, prob_batches = [], []
for batch_idx in range(len(test_generator_resnet)):
    batch_images, batch_onehot = test_generator_resnet[batch_idx]
    true_batches.append(np.argmax(batch_onehot, axis=1))
    prob_batches.append(model_resnet.predict_on_batch(batch_images))

y_true = np.concatenate(true_batches)          # True labels from test set
y_pred = np.concatenate(prob_batches)          # Softmax probabilities
y_pred_classes = np.argmax(y_pred, axis=1)     # Convert probabilities to class labels

# Class names ordered by their integer index, so tick labels match matrix rows/cols
cm_class_names = sorted(test_generator_resnet.class_indices,
                        key=test_generator_resnet.class_indices.get)

# Compute confusion matrix; pin `labels` so it is always n x n even if some
# class never appears in y_true or y_pred.
cm = confusion_matrix(y_true, y_pred_classes, labels=list(range(len(cm_class_names))))

# Plot confusion matrix using seaborn
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=cm_class_names,
            yticklabels=cm_class_names,
            ax=ax)
ax.set_xlabel("Predicted Label")
ax.set_ylabel("True Label")
ax.set_title("Confusion Matrix for ResNet50 Model")
plt.show()

### Classification Report

In [None]:

# Class names ordered by their integer index (matches the label encoding)
report_class_names = sorted(test_generator_resnet.class_indices,
                            key=test_generator_resnet.class_indices.get)

# Compute classification report (includes F1-score for each class).
# Passing `labels` explicitly keeps `target_names` aligned even when a class is
# absent from both y_true and y_pred; without it sklearn raises a size-mismatch
# ValueError. zero_division=0 reports 0 instead of warning for classes that are
# never predicted.
class_report = classification_report(
    y_true,
    y_pred_classes,
    labels=list(range(len(report_class_names))),
    target_names=report_class_names,
    zero_division=0,
)

# Compute overall F1-score (macro and weighted)
f1_macro = f1_score(y_true, y_pred_classes, average="macro")
f1_weighted = f1_score(y_true, y_pred_classes, average="weighted")

# Display results
print("Classification Report:\n", class_report)
print(f"Macro-Averaged F1-score: {f1_macro:.4f}")
print(f"Weighted-Averaged F1-score: {f1_weighted:.4f}")

### Displaying Metrics

In [None]:
from sklearn.metrics import recall_score, classification_report, confusion_matrix

# Get true labels and predictions from one pass over the test generator, so
# each prediction is paired with its own image's label regardless of the
# generator's `shuffle` setting (`generator.classes` is in directory order).
true_batches, prob_batches = [], []
for batch_idx in range(len(test_generator_resnet)):
    batch_images, batch_onehot = test_generator_resnet[batch_idx]
    true_batches.append(np.argmax(batch_onehot, axis=1))
    prob_batches.append(model_resnet.predict_on_batch(batch_images))

y_true = np.concatenate(true_batches)          # True labels from test set
y_pred = np.concatenate(prob_batches)
y_pred_classes = np.argmax(y_pred, axis=1)     # Convert probabilities to class labels

# Class names ordered by their integer index
class_labels = sorted(test_generator_resnet.class_indices,
                      key=test_generator_resnet.class_indices.get)

# Compute recall for each class. `labels` forces one entry per class in index
# order; without it sklearn returns only the labels it sees, which can
# misalign with (or be shorter than) class_labels.
recall_per_class = recall_score(y_true, y_pred_classes,
                                labels=list(range(len(class_labels))),
                                average=None, zero_division=0)

# Compute overall recall scores
recall_macro = recall_score(y_true, y_pred_classes, average="macro")  # Equal weight for each class
recall_weighted = recall_score(y_true, y_pred_classes, average="weighted")  # Weighted by sample count

recall_results = dict(zip(class_labels, recall_per_class))

# Print results
print("Recall per class:")
for class_name, recall_value in recall_results.items():
    print(f"{class_name}: {recall_value:.4f}")

print(f"\nMacro-Averaged Recall: {recall_macro:.4f}")
print(f"Weighted-Averaged Recall: {recall_weighted:.4f}")

### Plotting Recall Values

In [None]:
# Bar chart of the per-class recall computed in the metrics cell above,
# drawn on an explicit Axes with the y-range fixed to [0, 1].
fig, ax = plt.subplots(figsize=(8, 5))
sns.barplot(x=class_labels, y=recall_per_class, palette="Blues", ax=ax)
ax.set_xlabel("Class")
ax.set_ylabel("Recall")
ax.set_title("Recall per Class for ResNet50 Model")
ax.set_ylim(0, 1)
plt.show()