### This notebook explores the use of transfer learning with ResNet50 to classify Alzheimer's disease from anatomical MRI images.

Briefly, the pipeline involves the following steps and technical features:

- Data formatting and quality check
- Transfer learning using ResNet50
- Hyperparameter tuning
- Final model application

### Mount the S3 data bucket and import analysis and plotting libraries

In [0]:
# Load AWS credentials from a CSV stored in DBFS. The first data row is expected to hold
# the access key id (column 0) and the secret access key (column 1).
# NOTE(review): files under /FileStore are generally visible to other workspace users, so
# this table is not really hidden; consider a Databricks secret scope (dbutils.secrets.get).
aws_keys_df = spark.read.format("csv").option("header", "true").option("sep", ",").load("/FileStore/tables/brad_databricks_personal_accessKeys_new.csv")

# Fetch the credential row once rather than launching two separate collect() jobs.
key_row = aws_keys_df.first()
if key_row is None:
    raise ValueError("AWS credentials file contains no data rows")
ACCESS_KEY, SECRET_KEY = key_row[0], key_row[1]

# specify bucket and mount point
AWS_S3_BUCKET = "databricks-workspace-stack-brad-personal-bucket/AD_MRI_classification/"
MOUNT_NAME = "/mnt/AD_classification"
SOURCE_URL = f"s3a://{AWS_S3_BUCKET}"
EXTRA_CONFIGS = { "fs.s3a.access.key": ACCESS_KEY, "fs.s3a.secret.key": SECRET_KEY}

# Mount the bucket only if the mount point does not exist yet. This makes the cell safe to
# re-run (mounting an already-mounted point fails) and avoids hand-toggling commented code.
if not any(mount.mountPoint == MOUNT_NAME for mount in dbutils.fs.mounts()):
    dbutils.fs.mount(SOURCE_URL, MOUNT_NAME, extra_configs=EXTRA_CONFIGS)

In [0]:
# "standard"
import numpy as np
import pandas as pd

# machine learning and statistics
import pyspark
from pyspark.sql import SparkSession
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from keras.utils import plot_model
from keras import regularizers
from keras import applications, layers, models, applications, callbacks, optimizers
import keras_tuner as kt
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from scipy.stats import false_discovery_control
from imblearn.over_sampling import SMOTE

# Parallel computing
import dask
from dask.distributed import Client, progress

# plotting
import matplotlib.pyplot as plt
import seaborn as sns

# misc
import cv2
import magic
from IPython.display import clear_output
# Clear this cell's output (e.g. any messages emitted while importing the libraries above);
# wait defaults to False, so the output is cleared immediately.
clear_output()

## Load and adapt ResNet50

In [0]:
# Integer class labels of the held-out set (one-hot encoded later for training).
# Select the column directly instead of copying the whole frame with .iloc[:].
test_lab_idx = test["label"].to_numpy()

# Load the ResNet50 backbone (no classification head) with ImageNet weights so that the
# frozen layers carry pretrained features. With weights=None the network would be randomly
# initialised, and freezing its layers would just lock in random, untrained filters.
# The backbone itself still expects 3-channel input; grayscale images are adapted by the
# learnable Conv2D placed in front of it below.
# NOTE(review): ImageNet ResNet50 weights were trained on inputs transformed by
# keras.applications.resnet50.preprocess_input; confirm how train_data/test_data are scaled.
res_model = applications.ResNet50(include_top=False, weights='imagenet', input_shape=(128, 128, 3))

# Freeze everything except the last 10 layers of the backbone
for layer in res_model.layers[:-10]:
    layer.trainable = False

# Print layers to verify
for i, layer in enumerate(res_model.layers):
    print(f"Layer {i}: {layer.name}, Trainable: {layer.trainable}")

# Learnable 1 -> 3 channel adapter so single-channel MRI slices fit the 3-channel backbone
input_layer = layers.Input(shape=(128, 128, 1))
x = layers.Conv2D(3, (3, 3), padding='same')(input_layer)
x = res_model(x)
x = layers.Flatten()(x)
x = layers.Dense(4, activation='softmax')(x)  # Add Dense layer with number of classes

In [0]:
# Wrap the grayscale adapter + ResNet50 backbone + softmax head into one functional model
model = models.Model(inputs=input_layer, outputs=x)

# Adam at a modest learning rate for training the unfrozen top of the network;
# categorical cross-entropy matches the one-hot labels and 4-way softmax output.
std_learning_rate = 1e-4
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=std_learning_rate),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

model.summary()

In [0]:
# One-hot encode the integer class labels (4 classes) for categorical cross-entropy
train_labels_cat, test_labels_cat = (
    to_categorical(label_idx.astype('int8'), num_classes=4)
    for label_idx in (train_lab_idx, test_lab_idx)
)

# Train for 10 epochs, evaluating on the held-out set after each epoch.
# NOTE(review): the test set doubles as validation data here, so any decision taken from
# these curves (epoch count, hyperparameters) leaks test information; consider a separate
# validation split.
history = model.fit(
    train_data,
    train_labels_cat,
    epochs=10,
    validation_data=(test_data, test_labels_cat),
)

View the training history (loss and accuracy per epoch)

In [0]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 6))

# One entry per panel: (axis, y-label, train key, val key, train label, val label, legend spot).
# Left panel shows loss, right panel accuracy; training solid green, validation dashed orange.
curve_specs = [
    (ax1, 'Loss', 'loss', 'val_loss', 'Training Loss', 'Validation Loss', 'upper right'),
    (ax2, 'Accuracy', 'accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy', 'upper left'),
]
for ax, y_label, train_key, val_key, train_label, val_label, legend_loc in curve_specs:
    ax.set_xlabel('Epochs', fontsize=18)
    ax.set_ylabel(y_label, fontsize=18)
    ax.plot(history.history[train_key], color='green', label=train_label, linewidth=2.5)
    ax.plot(history.history[val_key], color='orange', linestyle='--', label=val_label, linewidth=2.5)
    ax.legend(loc=legend_loc, fontsize=18)
    ax.tick_params(axis='both', which='major', labelsize=16)

fig.tight_layout()
plt.show()

Fine-tune ResNet50

Because the dataset is relatively small and computational resources are limited, we unfreeze only the last convolutional stage of the network (rather than all layers) for retraining, and use a lower learning rate.

In [0]:
# How many layers are in the base model
print("Number of layers in ResNet50: ", len(res_model.layers))

# Unfreeze the last convolutional stage for fine tuning. The cut point is the first layer
# whose name starts with "conv5_", looked up by name instead of hard-coding an index so it
# cannot silently shift if the Keras layer list differs between versions.
conv5_indices = [i for i, layer in enumerate(res_model.layers) if layer.name.startswith("conv5_")]
if not conv5_indices:
    raise ValueError("No 'conv5_' layers found in res_model; cannot locate the fine-tuning cut point")
fine_tune_at = conv5_indices[0]
for layer in res_model.layers[fine_tune_at:]:
    layer.trainable = True

# Print layers to verify
for i, layer in enumerate(res_model.layers):
    print(f"Layer {i}: {layer.name}, Trainable: {layer.trainable}")

# Continue with the model trained above instead of building a new one. Rebuilding would
# attach a freshly initialised grayscale adapter (Conv2D) and classification head (Dense),
# discarding the weights they just learned. Keras only applies changes to `trainable`
# after compile(), so recompile the existing `model`.
std_learning_rate = 1e-5  # use lower learning rate with more trainable layers
model.compile(optimizer=keras.optimizers.Adam(learning_rate=std_learning_rate),
              loss='categorical_crossentropy',
              metrics=['accuracy']
)
model.summary()


In [0]:
# One-hot encode labels for the 4-class softmax output (same encoding as the first stage)
train_labels_cat, test_labels_cat = [
    to_categorical(lab.astype('int8'), num_classes=4) for lab in (train_lab_idx, test_lab_idx)
]

# Train the fine-tuning model for 10 epochs, tracking the held-out set as validation data
history = model.fit(train_data, train_labels_cat,
                    epochs=10,
                    validation_data=(test_data, test_labels_cat))

In [0]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 6))

# Shared line styles: training curves solid blue, validation curves dashed red
train_style = {'color': 'blue', 'linewidth': 2.5}
val_style = {'color': 'red', 'linestyle': '--', 'linewidth': 2.5}

# Left panel: loss (legend top-right); right panel: accuracy (legend top-left)
for ax, metric, y_label, train_label, val_label, legend_loc in (
    (ax1, 'loss', 'Loss', 'Training Loss', 'Validation Loss', 'upper right'),
    (ax2, 'accuracy', 'Accuracy', 'Training Accuracy', 'Validation Accuracy', 'upper left'),
):
    ax.set_xlabel('Epochs', fontsize=18)
    ax.set_ylabel(y_label, fontsize=18)
    ax.plot(history.history[metric], label=train_label, **train_style)
    ax.plot(history.history['val_' + metric], label=val_label, **val_style)
    ax.legend(loc=legend_loc, fontsize=18)
    ax.tick_params(axis='both', which='major', labelsize=16)

fig.tight_layout()
plt.show()