In [None]:
# Imports all required libraries for data handling, visualization, and building a TensorFlow CNN for multiclass image classification
from sklearn.metrics import confusion_matrix
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import tensorflow as tf
import seaborn as sns
import pandas as pd
import numpy as np
import PIL
import os

In [None]:
# Chooses the distribution strategy: a TPU strategy when a TPU can be resolved and
# initialised, otherwise TensorFlow's default strategy. Prints the replica count
# and the TensorFlow version so the run environment is visible in the output.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # Non-experimental name; `tf.distribute.experimental.TPUStrategy` is a deprecated alias.
    strategy = tf.distribute.TPUStrategy(tpu)
except Exception as tpu_error:
    # Catch `Exception` instead of using a bare `except:` so KeyboardInterrupt and
    # SystemExit still propagate, and report why the fallback happened rather than
    # switching devices silently.
    print('TPU not available, falling back to the default strategy:', repr(tpu_error))
    strategy = tf.distribute.get_strategy()

print('Number of replicas:', strategy.num_replicas_in_sync)
print(tf.__version__)

In [None]:
# Main hyperparameters shared by the rest of the notebook.
# AUTOTUNE lets tf.data choose parallelism / prefetch buffer sizes at runtime
# (`tf.data.AUTOTUNE` replaces the deprecated `tf.data.experimental.AUTOTUNE` alias).
AUTOTUNE = tf.data.AUTOTUNE
# Global batch size: 16 samples per replica of the active distribution strategy.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
# Passed as `image_size` to the dataset loaders and used as the first two
# dimensions of the model input shape.
IMAGE_SIZE = [176, 208]
# Upper bound on training epochs passed to `model.fit`.
EPOCHS = 100

In [None]:
# Builds the training and validation datasets from the same directory.
# Both loader calls receive the same split fraction, seed, image size and
# batch size; only the requested `subset` differs.
split_options = dict(
    validation_split = 0.2,
    seed = 1337,
    image_size = IMAGE_SIZE,
    batch_size = BATCH_SIZE,
)

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "/kaggle/input/alzheimers-mri-images/Processed Dataset/training",
    subset = "training",
    **split_options,
)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "/kaggle/input/alzheimers-mri-images/Processed Dataset/training",
    subset = "validation",
    **split_options,
)

In [None]:
# Class names for the Alzheimer's MRI dataset, listed in label-index order
class_names = ['Mild Demented', 'Moderate Demented', 'Non Demented', 'Very Mild Demented']

# The loader attaches the class names it inferred from the directory layout.
# Integer labels index into that list, so a hard-coded list in a different order
# would silently mislabel every plot and the confusion matrix. Fail loudly instead.
# (`getattr` keeps the cell re-runnable after the datasets have been re-mapped and
# no longer carry the attribute.)
inferred_class_names = getattr(train_ds, "class_names", None)
if inferred_class_names is not None and list(inferred_class_names) != class_names:
    raise ValueError(
        f"class_names {class_names} does not match the names inferred from the "
        f"training directory {list(inferred_class_names)}"
    )

# Assigns class names to training and validation datasets
train_ds.class_names = class_names
val_ds.class_names = class_names

NUM_CLASSES = len(class_names)

In [None]:
# Visual sanity check: draws the first nine images of one training batch in a
# 3x3 grid, each titled with its class name.
sample_fig, sample_axes = plt.subplots(3, 3, figsize = (10, 10))

for images, labels in train_ds.take(1):
    for i, cell in enumerate(sample_axes.flat):
        # Cast to uint8 so imshow interprets the pixel values as 0-255 intensities.
        cell.imshow(images[i].numpy().astype("uint8"))
        cell.set_title(train_ds.class_names[labels[i]])
        cell.axis("off")

In [None]:
# Counts the images in each class folder of the training set and shows the
# counts as a bar plot.
train_dir = "/kaggle/input/alzheimers-mri-images/Processed Dataset/training"

# One entry per class sub-directory. Entries are sorted because os.listdir returns
# them in arbitrary order, and non-directory entries are skipped because calling
# os.listdir on a stray file would raise.
class_counts = {
    cls: len(os.listdir(os.path.join(train_dir, cls)))
    for cls in sorted(os.listdir(train_dir))
    if os.path.isdir(os.path.join(train_dir, cls))
}

# Converts to DataFrame
df = pd.DataFrame(list(class_counts.items()), columns = ["Class", "Count"])

# Plots the number of images per class
plt.figure(figsize = (15, 8))
ax = sns.barplot(x = df["Class"], y = df["Count"], palette = "Set1")
ax.set_xlabel("Class", fontsize = 20)
ax.set_ylabel("Count", fontsize = 20)
plt.title("The Number Of Samples For Each Class", fontsize = 20)
plt.grid(True)
plt.xticks(rotation = 45)
plt.show()

In [None]:
# Label-encoding helper mapped over the datasets for multiclass classification
def one_hot_label(image, label):
    """Return `image` unchanged, paired with `label` one-hot encoded to depth NUM_CLASSES."""
    return image, tf.one_hot(label, NUM_CLASSES)

# Input pipeline for training and validation: one-hot encode the labels, then
# cache the mapped elements and prefetch so loading overlaps with model execution.
train_ds = (
    train_ds
    .map(one_hot_label, num_parallel_calls = AUTOTUNE)
    .cache()
    .prefetch(buffer_size = AUTOTUNE)
)
val_ds = (
    val_ds
    .map(one_hot_label, num_parallel_calls = AUTOTUNE)
    .cache()
    .prefetch(buffer_size = AUTOTUNE)
)

# Number of training images per class, in the same order as `class_names`.
# Each class lives in a sub-directory named exactly like the class. The previous
# `label[:-2] + 'ed'` rewrite produced the same string for every name ending in
# "ed", so the name is now used directly.
NUM_IMAGES = [
    len(os.listdir(os.path.join("/kaggle/input/alzheimers-mri-images/Processed Dataset/training", label)))
    for label in class_names
]

# Outputs the number of images per class for review
NUM_IMAGES

In [None]:
# Reusable convolutional stage of the CNN
def conv_block(filters):
    """Return a Sequential stage: two 3x3 separable convolutions (ReLU, same
    padding) with `filters` channels, then batch normalization and max pooling.
    """
    conv_options = dict(activation = 'relu', padding = 'same')
    stage_layers = [
        tf.keras.layers.SeparableConv2D(filters, 3, **conv_options),
        tf.keras.layers.SeparableConv2D(filters, 3, **conv_options),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPool2D(),
    ]
    return tf.keras.Sequential(stage_layers)

In [None]:
# Reusable fully connected stage of the CNN
def dense_block(units, dropout_rate):
    """Return a Sequential stage: a ReLU Dense layer with `units` outputs, then
    batch normalization and dropout with rate `dropout_rate`.
    """
    stage_layers = [
        tf.keras.layers.Dense(units, activation = 'relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(dropout_rate),
    ]
    return tf.keras.Sequential(stage_layers)

In [None]:
# Assembles the full CNN from the convolutional and dense building blocks
def build_model():
    """Build the (uncompiled) multiclass classifier.

    The network takes images of shape (*IMAGE_SIZE, 3) and ends in a softmax
    over NUM_CLASSES outputs. Layers are created in the same order in which
    they are stacked.
    """
    # Input plus two plain 3x3 convolutions and a first pooling step.
    stem = [
        tf.keras.Input(shape = (*IMAGE_SIZE, 3)),
        tf.keras.layers.Conv2D(16, 3, activation = 'relu', padding = 'same'),
        tf.keras.layers.Conv2D(16, 3, activation = 'relu', padding = 'same'),
        tf.keras.layers.MaxPool2D(),
    ]

    # Separable-convolution stages with widening filter counts; dropout is
    # applied after the 128- and 256-filter stages.
    feature_stages = [
        *(conv_block(n_filters) for n_filters in (32, 64, 128)),
        tf.keras.layers.Dropout(0.2),
        conv_block(256),
        tf.keras.layers.Dropout(0.2),
    ]

    # Flatten, three dense stages with decreasing width and dropout, softmax output.
    classifier_head = [
        tf.keras.layers.Flatten(),
        *(dense_block(units, rate) for units, rate in ((512, 0.7), (128, 0.5), (64, 0.3))),
        tf.keras.layers.Dense(NUM_CLASSES, activation = 'softmax'),
    ]

    return tf.keras.Sequential(stem + feature_stages + classifier_head)

In [None]:
# Creates and compiles the model inside the strategy scope so that the model is
# built under the selected distribution strategy.
with strategy.scope():
    model = build_model()

    # AUC is the only tracked metric; the history plots read it back as 'auc' / 'val_auc'.
    METRICS = [tf.keras.metrics.AUC(name = 'auc')]
    # Labels are one-hot encoded by the input pipeline, hence categorical cross-entropy.
    loss_fn = tf.losses.CategoricalCrossentropy()

    model.compile(optimizer = 'adam', loss = loss_fn, metrics = METRICS)

In [None]:
# Sets up learning rate scheduling, model checkpointing, and early stopping for robust training
def exponential_decay(lr0, s):
    """Create an epoch -> learning-rate schedule.

    The returned function computes ``lr0 * 0.1 ** (epoch / s)``: it starts at
    `lr0` for epoch 0 and shrinks by a factor of 10 every `s` epochs.
    """
    def lr_at_epoch(epoch):
        decay_factor = 0.1 ** (epoch / s)
        return lr0 * decay_factor

    return lr_at_epoch

# Learning-rate schedule: starts at 0.01 and decays by a factor of 10 every 20 epochs
exponential_decay_fn = exponential_decay(0.01, 20)

# Applies the schedule above at the start of every epoch
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(exponential_decay_fn)

# Keeps only the best model seen so far on disk ("val_loss" is the callback's default monitor)
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "tensorflow_alzheimer_model.keras",
    monitor = "val_loss",
    save_best_only = True,
)

# Stops after 10 epochs without improvement in the monitored value ("val_loss",
# the default) and restores the best weights
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor = "val_loss",
    patience = 10,
    restore_best_weights = True,
)

In [None]:
# Prints the layer summary and writes an architecture diagram to model.png
model.summary()

tf.keras.utils.plot_model(
    model,
    to_file = 'model.png',
    show_shapes = True,
    show_layer_names = True,
    show_dtype = True,
    dpi = 120,
)

# Trains for at most EPOCHS epochs, validating on val_ds after each epoch.
# The callbacks handle checkpointing, early stopping and the learning-rate schedule.
training_callbacks = [checkpoint_cb, early_stopping_cb, lr_scheduler]
history = model.fit(
    train_ds,
    epochs = EPOCHS,
    validation_data = val_ds,
    callbacks = training_callbacks,
)

In [None]:
# Training curves: one panel per tracked quantity (AUC, loss), each showing the
# training series and its 'val_'-prefixed validation counterpart.
fig, ax = plt.subplots(1, 2, figsize=(20, 3))
ax = ax.ravel()

for panel, met in zip(ax, ('auc', 'loss')):
    for series_key in (met, 'val_' + met):
        panel.plot(history.history[series_key])
    panel.set_title('Model {}'.format(met))
    panel.set_xlabel('epochs')
    panel.set_ylabel(met)
    panel.legend(['train', 'val'])

In [None]:
# Loads the held-out test set and prepares it like the training data
# (one-hot labels, caching, prefetching).
# shuffle = False keeps the element order fixed across iterations. Later cells
# collect predictions and true labels in separate passes over test_ds, and those
# only line up if every pass yields the samples in the same order.
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "/kaggle/input/alzheimers-mri-images/Processed Dataset/test",
    image_size = IMAGE_SIZE,
    batch_size = BATCH_SIZE,
    shuffle = False,
)
test_ds = test_ds.map(one_hot_label, num_parallel_calls = AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size = AUTOTUNE)

# Evaluates the trained model on the test dataset; Keras prints the loss and metrics
_ = model.evaluate(test_ds)

In [None]:
# Generates predictions for the test set and computes the confusion matrix,
# both as raw counts (`cm`) and as row-wise percentages (`cm_percent`).
# NOTE(review): predictions and labels are gathered in two separate passes over
# test_ds; they only correspond if test_ds yields samples in the same order on
# every iteration — confirm the dataset is unshuffled or fully cached.
predictions = model.predict(test_ds)

# Converts predicted probabilities to class indices
y_pred = np.argmax(predictions, axis=1)

# Retrieves true labels and converts them from one-hot vectors to class indices
y_real = np.concatenate([y for x, y in test_ds], axis=0)
y_real = np.argmax(y_real, axis=1)

# Passing `labels` fixes the matrix to one row/column per entry of class_names,
# even when a class never occurs in y_real or y_pred; otherwise the matrix could
# be smaller than the tick labels used when plotting it.
cm = confusion_matrix(y_real, y_pred, labels=np.arange(len(class_names)))

# Row-normalises to percentages. A class with no true samples has an all-zero row;
# dividing by 1 instead of 0 leaves that row at 0 % rather than NaN.
row_totals = cm.sum(axis=1)[:, np.newaxis]
cm_percent = cm.astype('float') / np.maximum(row_totals, 1) * 100

In [None]:
# Heatmap of the row-normalised confusion matrix: rows are true classes,
# columns are predicted classes, cell values are percentages with two decimals.
plt.figure(figsize=(6,6))
heatmap_ax = sns.heatmap(
    cm_percent,
    annot=True,
    fmt=".2f",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
)
heatmap_ax.set_xlabel("Predicted Label")
heatmap_ax.set_ylabel("True Label")
heatmap_ax.set_title("Confusion Matrix (%)")
plt.show()

In [None]:
# Per-sample report: class probabilities as percentages, the predicted class,
# the actual class, and whether the prediction was correct.
for sample_probs, true_idx in zip(predictions, y_real):
    probs_percent = [f"{prob*100:.2f}%" for prob in sample_probs]
    predicted_class_idx = np.argmax(sample_probs)  # index of the highest probability
    predicted_class_name = class_names[predicted_class_idx]
    actual_class_name = class_names[true_idx]

    print(f"Predictions: {probs_percent} -> Predicted class: {predicted_class_name} (Class {predicted_class_idx}), Actual Label: {actual_class_name}")

    verdict = "Correct ✅\n" if predicted_class_idx == true_idx else "Incorrect ❌\n"
    print(verdict)