# Alzheimer's Disease classification from anatomical MRI

### This notebook explores the use of transfer learning with ResNet50 to classify Alzheimer's disease from anatomical MRI images.

Briefly, the pipeline involves the following steps and technical features:

- Data formatting and quality check
- Transfer learning using ResNet50
- Hyperparameter tuning
- Final model application

### Import analysis and plotting libraries

In [0]:
# Read the AWS credentials from a CSV table kept outside the notebook source.
aws_keys_df = (
    spark.read.format("csv")
    .option("header", "true")
    .option("sep", ",")
    .load("/FileStore/tables/brad_databricks_personal_accessKeys_new.csv")
)

# Collect once and reuse the row: each collect() call re-runs the Spark read,
# so fetching the two columns separately doubled the work.
aws_keys_row = aws_keys_df.collect()[0]
ACCESS_KEY = aws_keys_row[0]
SECRET_KEY = aws_keys_row[1]

# specify bucket and mount point
AWS_S3_BUCKET = "databricks-workspace-stack-brad-personal-bucket/AD_MRI_classification/"
MOUNT_NAME = "/mnt/AD_classification"
SOURCE_URL = f"s3a://{AWS_S3_BUCKET}"
EXTRA_CONFIGS = {"fs.s3a.access.key": ACCESS_KEY, "fs.s3a.secret.key": SECRET_KEY}

# mount bucket (left disabled; uncomment to (re)mount the bucket)
# dbutils.fs.unmount(MOUNT_NAME)
# dbutils.fs.mount(SOURCE_URL, MOUNT_NAME, extra_configs = EXTRA_CONFIGS)

In [0]:
# "standard"
import numpy as np
import pandas as pd

# machine learning and statistics
import pyspark
from pyspark.sql import SparkSession
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from keras.utils import plot_model
import keras_tuner as kt
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from scipy.stats import false_discovery_control

# Parallel computing
import dask
from dask.distributed import Client, progress

# plotting
import matplotlib.pyplot as plt
import seaborn as sns

# misc
import cv2
import magic
from IPython.display import clear_output

### Load and format training data

In [0]:
# Display names for the four classes; position i is the name this notebook's
# plots use for integer label i.
Lab = ["Mild", "Moderate", "None", "Very Mild"]

# Training split, read from the mounted bucket.
train = pd.read_parquet(
    path="/dbfs/mnt/AD_classification/train-00000-of-00001-c08a401c53fe5312.parquet"
)
train.head()

#### Convert data to readable format

In [0]:
def dict_to_image(image_dict):
    """Decode the encoded image stored under the ``'bytes'`` key as grayscale.

    Args:
        image_dict: Dictionary whose ``'bytes'`` entry holds the encoded image.

    Returns:
        Whatever ``cv2.imdecode`` returns for the buffer in
        ``cv2.IMREAD_GRAYSCALE`` mode.

    Raises:
        TypeError: If ``image_dict`` is not a dict or lacks a ``'bytes'`` key.
    """
    # Reject malformed rows up front so the decode path below stays flat.
    if not (isinstance(image_dict, dict) and 'bytes' in image_dict):
        raise TypeError(f"Expected dictionary with 'bytes' key, got {type(image_dict)}")

    encoded = np.frombuffer(image_dict['bytes'], np.uint8)
    return cv2.imdecode(encoded, cv2.IMREAD_GRAYSCALE)

# Decode each row's encoded image into an array column, then drop the raw
# encoded column.
train = (
    train
    .assign(img_arr=train["image"].apply(dict_to_image))
    .drop(columns=["image"])
)
train.head()

### Load and format test data

In [0]:
# Test split, read from the mounted bucket.
test = pd.read_parquet(
    path="/dbfs/mnt/AD_classification/test-00000-of-00001-44110b9df98c5585.parquet"
)

# Same decoding as the training split: add the decoded array column and drop
# the raw encoded one.
test = (
    test
    .assign(img_arr=test["image"].apply(dict_to_image))
    .drop(columns=["image"])
)
test.head()

## Explore structure and visualization of the data

### Distribution of the datasets (are all classes represented equally?)

In [0]:
# Class frequencies of the training and test splits, side by side.
f, ax = plt.subplots(1, 3)
ax[1].axis('off')  # the middle panel is only a spacer between the two bar charts

for panel, split_df, split_name in ((ax[0], train, 'Training'), (ax[2], test, 'Testing')):
    unique, counts = np.unique(np.asarray(split_df.label), return_counts=True)
    panel.bar(unique, counts)
    panel.set_xticks(unique)
    # Look the names up per label so ticks and names stay aligned even if a
    # class is missing from a split.
    panel.set_xticklabels([Lab[int(u)] for u in unique], rotation=45)
    panel.set_title(split_name)
    panel.set_xlabel('Class')
    # Previously the test panel never received a y-label: the second call
    # targeted ax[0] again instead of ax[2].
    panel.set_ylabel('# of images')

# An obvious imbalance across classes, but each class seems to be balanced across training/testing sets

## Visually inspect data

In [0]:
train_lab_idx = np.asarray(train.label)

# 4x4 grid: one column per class, four randomly drawn examples per column.
f, ax = plt.subplots(4, 4)
for lab in range(4):
    # Rows belonging to the class of THIS column. The comparison used to be
    # hard-coded to class 1, so every column showed the same class under a
    # different title.
    class_rows = np.flatnonzero(train_lab_idx == lab)

    for ex in range(4):
        ax[ex, lab].axis('off')
        if ex == 0:
            ax[ex, lab].set_title(Lab[lab])
        if class_rows.size == 0:
            # Nothing to draw for a class absent from the training split.
            continue

        # randint's upper bound is exclusive, so pass the full size; the old
        # `len - 1` could never pick the last image and raised for a class
        # with a single image.
        row = class_rows[np.random.randint(class_rows.size)]
        ax[ex, lab].imshow(train.iloc[row].img_arr, cmap="gray")

# Clearly, images show different slices within the brain, which may be a major confound...

## Load and adapt ResNet50

### Format images (stored in local memory)

In [0]:
# training data: pre-allocate one 128x128 slot per row, then fill image by image
train_data = np.empty((len(train), 128, 128))
for row, img in enumerate(train["img_arr"]):
    train_data[row] = img

# test data: same layout
test_data = np.empty((len(test), 128, 128))
for row, img in enumerate(test["img_arr"]):
    test_data[row] = img

# append a trailing channel axis -> (n_images, 128, 128, 1), the layout keras expects
train_data = train_data[..., np.newaxis]
test_data = test_data[..., np.newaxis]
train_data.shape

### Data augmentation

In [0]:
# Preprocessing applied to every dataset: resize to IMG_SIZE x IMG_SIZE and
# scale intensities by 1/255.
IMG_SIZE = 128
resize_and_rescale = keras.Sequential([
    layers.Resizing(height=IMG_SIZE, width=IMG_SIZE),
    layers.Rescaling(scale=1. / 255),
])

# Random transforms used for augmentation: flips along both axes and rotations.
data_augmentation = keras.Sequential([
    layers.RandomFlip(mode="horizontal_and_vertical"),
    layers.RandomRotation(factor=0.2),
])

# Pipeline settings shared by prepare().
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds, shuffle = False, augment =False):
    """Build the input pipeline for a dataset of ``(image, label)`` pairs.

    Steps, in order: resize/rescale every image, optionally shuffle with a
    buffer of 1000 elements, batch by ``batch_size``, optionally apply the
    random augmentation to each batch, and prefetch.

    Args:
        ds: Dataset yielding ``(image, label)`` pairs.
        shuffle: Shuffle the elements before batching.
        augment: Apply ``data_augmentation`` (intended for training data only).

    Returns:
        The batched, prefetched dataset.
    """
    def _rescale(x, y):
        return resize_and_rescale(x), y

    def _augment(x, y):
        # training=True keeps the random layers active outside of fit().
        return data_augmentation(x, training = True), y

    ds = ds.map(_rescale, num_parallel_calls = AUTOTUNE)

    if shuffle:
        ds = ds.shuffle(1000)

    ds = ds.batch(batch_size)

    if augment:
        ds = ds.map(_augment, num_parallel_calls = AUTOTUNE)

    return ds.prefetch(buffer_size = AUTOTUNE)

### Augment training dataset only

#### Verify augmentation on first image

In [0]:
# sample dataset: slice (rather than index) so the leading batch axis is kept
train_data_tmp = train_data[:1]

# run the single image through the full augmentation pipeline
train_ds = prepare(
    tf.data.Dataset.from_tensor_slices((train_data_tmp, train_lab_idx[:1])),
    shuffle=True,
    augment=True,
)

# pull the augmented image(s) back out as a numpy array
train_ds = train_ds.unbatch()
images = np.asarray(list(train_ds.map(lambda x, y: x)))

# original next to its augmented version
f, ax = plt.subplots(1, 2)
panels = (
    (ax[0], train_data[0], 'Original'),
    (ax[1], images[0, :, :, :], 'Augmented'),
)
for panel, picture, caption in panels:
    panel.axis('off')
    panel.imshow(picture, cmap="gray")
    panel.set_title(caption, fontsize=10)
plt.tight_layout()


### Augment multiple times to generate larger training set

In [0]:
# Base dataset: the un-augmented training images (raw 0-255 intensities) and labels.
train_ds = tf.data.Dataset.from_tensor_slices((train_data, train_lab_idx))
train_ds_for_aug = train_ds

# Append two independently augmented copies of the training set.
for i in range(2):
    train_aug = prepare(train_ds_for_aug, shuffle = True, augment = True)
    train_aug = train_aug.unbatch()
    # prepare() rescales intensities to [0, 1] while the base images (and
    # test_data) stay in the raw 0-255 range. Map the augmented copies back to
    # 0-255 so the concatenated set is on a single scale.
    train_aug = train_aug.map(lambda x, y: (tf.cast(x, tf.float64) * 255.0, y))
    train_ds = train_ds.concatenate(train_aug)

del train_aug, train_ds_for_aug

# Materialise images and labels in ONE pass. The augmented parts reshuffle and
# re-augment on every iteration, so extracting images and labels in two
# separate passes pairs images with the wrong labels.
train_pairs = list(train_ds.as_numpy_iterator())
images = np.stack([img for img, _ in train_pairs])
labels = np.asarray([lbl for _, lbl in train_pairs])
del train_pairs
images.shape

### Load pre-trained ResNet50 model and adjust for current dataset

In [0]:
from keras import regularizers
from keras import applications, layers, models

# Pre-trained convolutional base (no classification head), built for
# 3-channel 128x128 inputs.
res_model = applications.ResNet50(include_top=False, weights='imagenet', input_shape=(128, 128, 3))

# Freeze everything except the final 10 layers of the base
for layer in res_model.layers[:-10]:
    layer.trainable = False

# Verify
for i, layer in enumerate(res_model.layers):
    print(i, layer.name, layer.trainable)

# The MRI arrays are single-channel (n, 128, 128, 1), but the base expects
# three channels; feeding them directly fails with a shape mismatch at fit
# time. Accept grayscale input and replicate it across three channels.
gray_input = layers.Input(shape=(128, 128, 1))
rgb_like = layers.Concatenate(axis=-1)([gray_input, gray_input, gray_input])

# Connect pretrained ResNet50 model with new layers
x = res_model(rgb_like)
x = layers.Flatten()(x)
x = layers.Dense(1024, activation='relu', kernel_regularizer=regularizers.l2(0.01))(x)
x = layers.Dropout(0.5)(x)
predictions = layers.Dense(4, activation='softmax', kernel_regularizer=regularizers.l2(0.01))(x)
model = models.Model(inputs=gray_input, outputs=predictions)

In [0]:
test_lab_idx = np.asarray(test.label)

# Write the model to best_model.keras only when validation accuracy improves.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "best_model.keras",
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-6),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# One-hot targets for the categorical cross-entropy loss; the test split
# serves as validation data here.
history = model.fit(
    x=train_data,
    y=to_categorical(train_lab_idx.astype('int8')),
    validation_data=(test_data, to_categorical(test_lab_idx.astype('int8'))),
    epochs=10,
    callbacks=[checkpoint_cb],
)

### Compile model

In [0]:
# Re-compile with the default Adam optimizer settings and print the layer table.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# plot_model(model, to_file='simple_CNN.png', show_shapes = True, show_layer_names = True)
model.summary()

### Fit model

In [0]:
# Train for 10 epochs on the raw training images with one-hot targets.
history = model.fit(
    x=train_data,
    y=to_categorical(train_lab_idx.astype('int8')),
    epochs=10,
)

### visualize model fit

In [0]:
# Training loss (red) and accuracy in percent (black), per epoch. Both share
# one axis, so the legend is what tells the curves apart.
plt.plot(history.history.get('loss'), 'r', label='loss')
plt.plot(np.array(history.history.get('accuracy')) * 100, 'k', label='accuracy (%)')
plt.xlabel('Epoch')
plt.title('Training history')
plt.legend()

### Predict test data

In [0]:
# Class probabilities for every test image; the predicted class is the most
# probable column.
prob = model.predict(x=test_data)
predict_classes = prob.argmax(axis=1)
predict_classes

### evaluate accuracy and visualize

In [0]:
test_lab_idx = np.asarray(test.label)

# Fraction of test images whose predicted class matches the true label.
test_accuracy = np.mean(predict_classes == test_lab_idx)
print(f"Test accuracy: {test_accuracy:.3f}")

# Predicted vs. true class for the first 100 test images.
plt.plot(predict_classes[0:100], 'k', label='predicted')
plt.plot(test_lab_idx[0:100], 'r', label='true')
plt.xlabel('Test image index')
plt.ylabel('Class')
plt.legend()

### Hyperparameter tuning

#### Define parameters to tune and the corresponding parameter space
##### (same architecture as before, but with hyperparameter ranges)

In [0]:
def build_model(hp):
    """Build and compile the tunable CNN for a KerasTuner trial.

    Architecture: three Conv2D + MaxPooling2D blocks, a dense hidden layer and
    a 4-way softmax output, on 128x128 single-channel inputs.

    Hyperparameters registered on ``hp``:
        conv_{1,2,3}_filter: number of filters per block (32-128, step 16).
        conv_{1,2,3}_kernel: kernel size per block (3 or 5).
        dense_1_units: width of the hidden dense layer (128-512, step 32).
        learning_rate: Adam learning rate (1e-2, 1e-3 or 1e-4).

    Args:
        hp: The ``HyperParameters`` object supplied by the tuner.

    Returns:
        The compiled ``keras.Sequential`` model.
    """
    stack = [keras.Input(shape = (128, 128, 1))]

    # Three identical conv/pool blocks, each with its own tunable filter count
    # and kernel size.
    for block in (1, 2, 3):
        stack.append(keras.layers.Conv2D(
            filters = hp.Int(f'conv_{block}_filter', min_value = 32, max_value = 128, step = 16),
            # The choices used to be [3, 3], i.e. no choice at all; offer two
            # distinct kernel sizes so this dimension is actually searched.
            kernel_size = hp.Choice(f'conv_{block}_kernel', values = [3, 5]),
            activation = 'relu'))
        stack.append(keras.layers.MaxPooling2D((2, 2)))

    stack.append(keras.layers.Flatten())
    stack.append(keras.layers.Dense(
        units = hp.Int('dense_1_units', min_value = 128, max_value = 512, step = 32),
        activation = 'relu'))

    # output layer: one probability per class
    stack.append(keras.layers.Dense(4, activation = 'softmax'))

    model = keras.Sequential(stack)

    # compile model
    hp_learning_rate = hp.Choice('learning_rate', values = [1e-2, 1e-3, 1e-4])
    model.compile(optimizer = keras.optimizers.Adam(learning_rate = hp_learning_rate),
                  loss = 'categorical_crossentropy',
                  metrics = ['accuracy'])

    return model

### Initiate tuner

In [0]:
# Hyperband search over build_model, ranked by validation accuracy; trial
# state is stored under my_dir/AD_class.
tuner = kt.Hyperband(
    build_model,
    objective = 'val_accuracy',
    max_epochs = 10,
    factor = 3,
    directory = 'my_dir',
    project_name = 'AD_class',
)

# Stop a trial once validation loss has not improved for 5 epochs.
stop_early = tf.keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5)

### Run search

In [0]:
# need numpy arrays rather than tensors for tuner
# Draw images and labels together in ONE pass: the augmented parts of train_ds
# reshuffle and re-augment on every iteration, so two separate passes would
# pair images with the wrong labels.
train_pairs = list(train_ds.as_numpy_iterator())
train_images = np.stack([img for img, _ in train_pairs])
train_labels = np.asarray([lbl for _, lbl in train_pairs])
train_labels = to_categorical(train_labels.astype('int8'))
del train_pairs

# The test arrays are already numpy, so the round trip through a tf.data
# dataset is not needed; test_ds is still defined for later use.
test_ds = tf.data.Dataset.from_tensor_slices((test_data, test_lab_idx))
test_images = np.asarray(test_data)
test_labels = to_categorical(np.asarray(test_lab_idx).astype('int8'))

# NOTE(review): the test split doubles as the validation set for the search,
# so test accuracy reported after tuning is optimistic. Consider holding out
# a separate validation split.
tuner.search(train_images, train_labels, epochs = 20, callbacks = [stop_early],
             validation_data = (test_images, test_labels))

In [0]:
# Get the optimal hyperparameters
best_hps = tuner.get_best_hyperparameters(num_trials = 1)[0]

# The names must match those registered in build_model: the dense width is
# 'dense_1_units' (there is no 'units' hyperparameter), and the third line
# reports the third conv block.
print(f"""
Optimal parameters are as follows:

Filter 1 output dim: {best_hps.get('conv_1_filter')}
Filter 2 output dim: {best_hps.get('conv_2_filter')}
Filter 3 output dim: {best_hps.get('conv_3_filter')}

Dense layer units: {best_hps.get('dense_1_units')}

Learning Rate: {best_hps.get('learning_rate')}
""")

### Optimal # of epochs

In [0]:
# Rebuild the model with the winning hyperparameters and train it for 50
# epochs, holding out 20% of the training arrays for validation.
model = tuner.hypermodel.build(best_hps)
history = model.fit(x=train_images, y=train_labels, validation_split=0.2, epochs=50)

# 1-based index of the first epoch that reached the highest validation accuracy.
val_acc_per_epoch = history.history['val_accuracy']
best_epoch = int(np.argmax(val_acc_per_epoch)) + 1
print('Best epoch: %d' % (best_epoch,))

## Visualize feature maps from the tuned model

In [0]:
# First test image with a leading batch axis of length 1, as predict() expects
# a batch.
t = test_images[0, :, :, :][np.newaxis, ...]
t.shape

In [0]:
# Model output for the single-image batch; show the entry for that image.
feature_map = model.predict(x=t)
feature_map[0]