# Alzheimer's Disease classification from anatomical MRI

### This notebook explores the use of a custom CNN to classify Alzheimer's disease from anatomical MRI images.

Briefly, the pipeline involves the following steps and technical features:

- Data formatting and quality check
- Custom CNN for classification
- Hyperparameter tuning
- Final model application

In [0]:
# "standard"
import numpy as np
import pandas as pd

# machine learning and statistics
import pyspark
from pyspark.sql import SparkSession
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from keras.utils import plot_model
import keras_tuner as kt
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from scipy.stats import false_discovery_control
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE

# Parallel computing
import dask
from dask.distributed import Client, progress

# plotting
import matplotlib.pyplot as plt
import seaborn as sns

# misc
import cv2, magic, datetime, sys, os, wget
from IPython.display import clear_output

import sys
# Make the project's local helper modules (src/) importable; this must run
# before the img_preprocessing import below.
# NOTE(review): absolute, user-specific workspace path — consider deriving it
# from a configurable project-root constant instead.
sys.path.append('/Workspace/Users/bjedelma@gmail.com/Alzheimers-MRI-Classification/src')
from img_preprocessing import dict_to_image

# Clear whatever output this import cell produced.
clear_output(wait=False)

### Mount AWS S3 bucket containing parquet data files

In [0]:

# S3 prefix holding the parquet data, and the DBFS CSV holding the AWS access keys.
AWS_S3_BUCKET = "databricks-workspace-stack-brad-personal-bucket/AD_MRI_classification/"
KEY_FILE = "/FileStore/tables/brad_databricks_personal_accessKeys_new.csv"

# Extract AWS credentials from the hidden key table (column 0 = access key,
# column 1 = secret key). The table is read with a single Spark action, and an
# empty table fails with a clear message instead of an IndexError.
# NOTE(review): a Databricks secret scope (dbutils.secrets.get) would avoid
# keeping plaintext keys in DBFS.
aws_keys_df = spark.read.format("csv").option("header", "true").option("sep", ",").load(KEY_FILE)

key_row = aws_keys_df.first()
if key_row is None:
    raise ValueError(f"No AWS credentials found in {KEY_FILE}")
ACCESS_KEY, SECRET_KEY = key_row[0], key_row[1]

# Mount point is named after the last path component of the bucket prefix;
# rstrip('/') makes this work with or without a trailing slash.
MOUNT_NAME = f"/mnt/{AWS_S3_BUCKET.rstrip('/').split('/')[-1]}"
SOURCE_URL = f"s3a://{AWS_S3_BUCKET}"
EXTRA_CONFIGS = {"fs.s3a.access.key": ACCESS_KEY, "fs.s3a.secret.key": SECRET_KEY}

# Mount only if not already mounted, so re-running this cell is safe.
if any(mount.mountPoint == MOUNT_NAME for mount in dbutils.fs.mounts()):
    print(f"{MOUNT_NAME} is already mounted.")
else:
    dbutils.fs.mount(SOURCE_URL, MOUNT_NAME, extra_configs=EXTRA_CONFIGS)
    print(f"{MOUNT_NAME} is now mounted.")

## Custom CNN for classification

In [0]:
def create_model(input_shape=(128, 128, 1), num_classes=4, l2_penalty=0.01):
    """Build the baseline CNN classifier (returned uncompiled).

    Architecture: three stages of Conv2D (3x3 kernel, ReLU, L2-regularized
    kernel) followed by 2x2 max-pooling, with 32, 64 and 128 filters; then
    Flatten, one 256-unit ReLU fully connected hidden layer, and a softmax
    output layer.

    Args:
        input_shape: shape of one input image, excluding the batch axis.
            Defaults to 128x128 single-channel images.
        num_classes: number of softmax output units (defaults to 4).
        l2_penalty: L2 regularization factor applied to the conv kernels.

    Returns:
        An uncompiled ``keras.Sequential`` model.
    """
    conv_stages = []
    for filters in (32, 64, 128):
        conv_stages += [
            keras.layers.Conv2D(filters, (3, 3), activation='relu',
                                kernel_regularizer=keras.regularizers.l2(l2_penalty)),
            keras.layers.MaxPooling2D((2, 2)),
        ]

    model = keras.Sequential([
        keras.Input(shape=input_shape),
        *conv_stages,
        keras.layers.Flatten(),
        keras.layers.Dense(256, activation='relu'),  # fully connected hidden layer
        keras.layers.Dense(num_classes, activation='softmax'),
    ])
    return model

# Build and compile the model inside the MirroredStrategy scope so its
# variables are created under the distribution strategy.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = create_model()
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

### Fit model

In [0]:
# One-hot encode the integer class labels. num_classes is pinned to 4 so the
# label matrix always matches the model's 4-unit softmax output, even if the
# highest class index happens to be missing from train_lab_idx.
# NOTE(review): train_data / train_lab_idx are created in a data-preparation
# cell that is not visible here.
train_lab = to_categorical(train_lab_idx.astype('int8'), num_classes=4)
history = model.fit(train_data, train_lab, epochs=50, batch_size=16)

### Visualize model fit

In [0]:
# Plot per-epoch training loss (left axis, red) and training accuracy
# (right axis, blue) from the History object returned by model.fit.
fig, ax1 = plt.subplots()

ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss', color='tab:red')
ax1.plot(history.history['loss'], color='tab:red', label='Loss')
ax1.tick_params(axis='y', labelcolor='tab:red')

# Second y-axis sharing the epoch axis; accuracy is converted from a fraction
# to a percentage, and the line colour matches its axis label and ticks.
ax2 = ax1.twinx()
ax2.set_ylabel('Accuracy (%)', color='tab:blue')
ax2.plot(np.array(history.history['accuracy']) * 100, color='tab:blue', label='Accuracy (%)')
ax2.tick_params(axis='y', labelcolor='tab:blue')

# Set the title before tight_layout so the layout accounts for it.
ax1.set_title('Training Loss and Accuracy')
fig.tight_layout()
plt.show()

### Predict test data

In [0]:
# Softmax class probabilities for each test image, reduced to the index of
# the most probable class per image.
prob = model.predict(test_data)
predict_classes = prob.argmax(axis=1)
predict_classes

Evaluate accuracy and visualize

In [0]:
# Generate and plot the row-normalized confusion matrix for the test set,
# next to the class distribution of the training set.
class_names = ['No AD', 'Mild AD', 'Moderate AD', 'Severe AD']

test_lab_idx = np.asarray(test['label'])
# labels= pins the matrix to all four class indices (in order), so it always
# lines up with class_names, even when a class is absent from both the true
# and the predicted labels.
conf_matrix = confusion_matrix(test_lab_idx, predict_classes, labels=np.arange(len(class_names)))
# Divide each row by its number of true samples, so the diagonal holds the
# per-class accuracy (recall). Rows with no true samples stay NaN (drawn
# blank) rather than triggering a 0/0 warning.
row_totals = conf_matrix.sum(axis=1, keepdims=True)
conf_matrix_normalized = np.divide(conf_matrix, row_totals,
                                   out=np.full(conf_matrix.shape, np.nan),
                                   where=row_totals > 0)

fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# Cell values are fractions in [0, 1], not percentages.
sns.heatmap(conf_matrix_normalized, annot=True, fmt='.2f', cmap='Blues',
            cbar_kws={'label': 'Fraction of true class'},
            xticklabels=class_names, yticklabels=class_names,
            ax=ax[0], linewidths=1, linecolor='black')
ax[0].set_xlabel('Predicted Labels')
ax[0].set_ylabel('True Labels')
ax[0].set_title('Confusion Matrix')

# Bar plot of the class distribution of the training set
train_label_counts = train['label'].value_counts().sort_index()
bar_colors = ['#aec7e8', '#ffbb78', '#98df8a', '#ff9896']
ax[1].bar(train_label_counts.index, train_label_counts.values,
          color=bar_colors[:len(train_label_counts)])
ax[1].set_xlabel('Class Label')
ax[1].set_ylabel('Frequency')
ax[1].set_title('Distribution of Training Set')
ax[1].set_xticks(train_label_counts.index)
# Name only the classes actually present, so ticks and labels stay aligned.
ax[1].set_xticklabels([class_names[int(label)] for label in train_label_counts.index])

plt.tight_layout()
plt.show()

Overall, we do not see a direct link between the per-class test accuracies (the diagonal of the row-normalized confusion matrix) and the number of samples per class in the training set. In fact, the class with the fewest training samples achieved the second-highest accuracy. Therefore, even though it would be ideal to have completely (or more nearly) balanced training classes, class imbalance does not appear to be a major source of bias in the current model.

### Hyperparameter tuning

Define parameters to tune and the corresponding parameter space

In [0]:
# Tunable variant of the baseline CNN: the same three conv/pool stages, with
# searchable filter counts, dense width and learning rate. Unlike create_model,
# the hidden dense layer is also L2-regularized and followed by 50% dropout.
def build_model(hp):
    """Build and compile a CNN whose hyperparameters are drawn from ``hp``.

    Search space (names as registered with keras_tuner):
        conv_1_filter : 32-128, step 32
        conv_2_filter : 64-128, step 32
        conv_3_filter : 96-128, step 32
        conv_N_kernel : fixed at 3 for each conv layer (recorded, not searched)
        dense_1_units : 128-256, step 32
        learning_rate : 1e-3 or 1e-4

    Args:
        hp: the keras_tuner HyperParameters instance supplied by the tuner.

    Returns:
        A ``keras.Sequential`` model compiled with Adam, categorical
        cross-entropy loss and an accuracy metric.
    """

    def conv_stage(index, min_filters):
        # One Conv2D (ReLU, L2-regularized) + 2x2 max-pooling stage. A Choice
        # over [3, 3] offers only one distinct value, so the kernel size is
        # registered as Fixed: it is still recorded per trial but not searched.
        return [
            keras.layers.Conv2D(
                filters=hp.Int(f'conv_{index}_filter', min_value=min_filters, max_value=128, step=32),
                kernel_size=hp.Fixed(f'conv_{index}_kernel', 3),
                activation='relu',
                kernel_regularizer=keras.regularizers.l2(0.01)),
            keras.layers.MaxPooling2D((2, 2)),
        ]

    model = keras.Sequential([
        keras.Input(shape=(128, 128, 1)),
        *conv_stage(1, min_filters=32),
        *conv_stage(2, min_filters=64),
        *conv_stage(3, min_filters=96),
        keras.layers.Flatten(),
        keras.layers.Dense(
            units=hp.Int('dense_1_units', min_value=128, max_value=256, step=32),
            activation='relu',
            kernel_regularizer=keras.regularizers.l2(0.01)),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(4, activation='softmax'),
    ])

    hp_learning_rate = hp.Choice('learning_rate', values=[1e-3, 1e-4])
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model

### Initiate tuner

In [0]:
# Hyperband search over build_model, ranked by validation accuracy; trial
# state is stored under my_dir/AD_class.
# NOTE(review): no overwrite flag is passed, so a re-run may resume an earlier
# search stored in that directory — confirm this is intended.
tuner = kt.Hyperband(
    build_model,
    objective='val_accuracy',
    max_epochs=10,
    factor=3,
    directory='my_dir',
    project_name='AD_class',
)

# End a trial early once validation loss has not improved for 5 epochs.
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

### Run search

In [0]:
# The tuner is fed NumPy arrays rather than tensors: unpack the (image, label)
# tf.data datasets into separate image and label arrays.
# NOTE(review): train_ds, test and test_data are created in cells not visible here.
train_images = np.asarray(list(train_ds.map(lambda x, y: x)))
train_labels = np.asarray(list(train_ds.map(lambda x, y: y)))
# num_classes pinned to 4 so the one-hot width always matches the softmax output.
train_labels = to_categorical(train_labels.astype('int8'), num_classes=4)

test_lab_idx = np.asarray(test.iloc[:].label)
test_ds = tf.data.Dataset.from_tensor_slices((test_data, test_lab_idx))
test_images = np.asarray(list(test_ds.map(lambda x, y: x)))
test_labels = np.asarray(list(test_ds.map(lambda x, y: y)))
test_labels = to_categorical(test_labels.astype('int8'), num_classes=4)

# Select hyperparameters on a stratified validation split carved out of the
# training data. Validating on the test set would leak test information into
# model selection and make the final test accuracy optimistically biased.
tune_images, val_images, tune_labels, val_labels = train_test_split(
    train_images, train_labels,
    test_size=0.2,
    stratify=train_labels.argmax(axis=1),
    random_state=42,  # fixed seed so the split is reproducible
)

tuner.search(tune_images, tune_labels, epochs=10, callbacks=[stop_early],
             validation_data=(val_images, val_labels))

In [0]:
# Optimal hyperparameters: the best combination found by the search. The names
# passed to get() must match those registered in build_model.
best_hps = tuner.get_best_hyperparameters(num_trials = 1)[0]

print(f"""
Optimal parameters are as follows:

Filter 1 output dim: {best_hps.get('conv_1_filter')}
Filter 2 output dim: {best_hps.get('conv_2_filter')}
Filter 3 output dim: {best_hps.get('conv_3_filter')}

Dense layer units: {best_hps.get('dense_1_units')}

Learning Rate: {best_hps.get('learning_rate')}
""")

Optimal # of epochs

In [0]:
# model = tuner.hypermodel.build(best_hps)
# history = model.fit(train_images, train_labels, epochs = 50, validation_split = 0.2)

# val_acc_per_epoch = history.history['val_accuracy']
# best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
# print('Best epoch: %d' % (best_epoch,))