This notebook implements a custom convolutional neural network (CNN) to classify Alzheimer's disease from anatomical MRI images. It explores hyperparameter tuning to optimize model performance and applies a distributed training strategy using TensorFlow's MirroredStrategy to accelerate training.

In [0]:
# "standard"
import numpy as np
import pandas as pd

# machine learning and statistics
import pyspark
from pyspark.sql import SparkSession
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from keras.callbacks import EarlyStopping
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from keras.utils import plot_model
import keras_tuner as kt
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from scipy.stats import false_discovery_control
from sklearn.metrics import confusion_matrix

# plotting
import matplotlib.pyplot as plt
import seaborn as sns

# misc
import cv2, magic, datetime, sys, os, wget, pickle
from IPython.display import clear_output

clear_output(wait=False)

Mount AWS S3 bucket containing processed data

In [0]:
# AWS credentials are read from the Databricks secret scope so they never appear in the notebook
ACCESS_KEY = dbutils.secrets.get(scope="brad-aws", key="access_key")
SECRET_KEY = dbutils.secrets.get(scope="brad-aws", key="secret_key")

# Bucket prefix to mount; the mount point is named after the last path component
# (this relies on the trailing "/" so that split('/')[-2] is the final folder name)
AWS_S3_BUCKET = "databricks-workspace-stack-brad-personal-bucket/AD_MRI_classification/raw/"
MOUNT_NAME = f"/mnt/{AWS_S3_BUCKET.split('/')[-2]}"
SOURCE_URL = f"s3a://{AWS_S3_BUCKET}"
EXTRA_CONFIGS = {
    "fs.s3a.access.key": ACCESS_KEY,
    "fs.s3a.secret.key": SECRET_KEY,
}

# Mount only when the mount point is not present yet, so the cell can be re-run safely
existing_mount_points = {mount.mountPoint for mount in dbutils.fs.mounts()}
if MOUNT_NAME not in existing_mount_points:
    dbutils.fs.mount(SOURCE_URL, MOUNT_NAME, extra_configs=EXTRA_CONFIGS)
    print(f"{MOUNT_NAME} is now mounted.")
else:
    print(f"{MOUNT_NAME} is already mounted.")

/mnt/raw is already mounted.


In [0]:
# Read the preprocessed dataset and unpack it into train/test images and labels
s3_file_path = 'AD_MRI_classification/preprocessed/data_pre.pkl'
file_path = "/dbfs/mnt/" + s3_file_path

# NOTE: unpickling can execute arbitrary code; only load files from a trusted location.
# The pickle is expected to hold exactly four objects, in this order.
with open(file_path, mode='rb') as f:
    train_data, train_lab, test_data, test_lab = pickle.load(f)

Define custom CNN and distributed training strategy

In [0]:
def create_model():
    """Build the baseline CNN (uncompiled).

    Architecture, in order:
      * input of shape (128, 128, 1)
      * three stages of Conv2D (3x3, ReLU, L2(0.01) kernel regularizer) followed by
        2x2 max pooling, with 32, 64 and 128 filters respectively
      * Flatten, a 256-unit ReLU dense layer, and a 4-unit softmax output layer

    Returns:
        keras.Sequential: the assembled model; the caller is responsible for compiling it.
    """
    layer_stack = [keras.Input(shape=(128, 128, 1))]

    # Convolutional feature extractor: the filter count doubles at every stage
    for n_filters in (32, 64, 128):
        layer_stack.append(
            keras.layers.Conv2D(
                filters=n_filters,
                kernel_size=(3, 3),
                activation='relu',
                kernel_regularizer=keras.regularizers.l2(0.01),
            )
        )
        layer_stack.append(keras.layers.MaxPooling2D((2, 2)))

    # Classification head
    layer_stack.append(keras.layers.Flatten())
    layer_stack.append(keras.layers.Dense(256, activation='relu'))  # fully connected layer
    layer_stack.append(keras.layers.Dense(4, activation='softmax'))

    return keras.Sequential(layer_stack)

# Create and compile the model inside the strategy scope so that its variables are
# created under MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = create_model()
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',  # expects one-hot encoded labels
        metrics=['accuracy'],
    )

# Drop the device/strategy log lines printed while the model was being set up
clear_output(wait=False)

Fit model

In [0]:
# One-hot encode the labels for the categorical cross-entropy loss.
# The encoded arrays get their OWN names: overwriting train_lab / test_lab would make this
# cell non-idempotent (re-running it would encode already-encoded labels) and would hide
# the integer class ids from the cells below that still use them.
NUM_CLASSES = 4
train_lab_onehot = to_categorical(train_lab.astype('int8'), num_classes=NUM_CLASSES)
test_lab_onehot = to_categorical(test_lab.astype('int8'), num_classes=NUM_CLASSES)

# Stop once val_loss has not improved for 5 epochs and roll back to the best weights.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# NOTE(review): the test set doubles as the validation set here, so early stopping and
# weight restoration are selected on test data. Metrics reported on test_data afterwards
# are therefore optimistic; consider a separate validation split of the training data.
validation_data = (test_data, test_lab_onehot)

history = model.fit(
    train_data,
    train_lab_onehot,
    epochs=25,
    batch_size=16,
    validation_data=validation_data,
    callbacks=[early_stopping]
)

2025-01-09 08:09:51.622914: W external/local_xla/xla/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 676003840 exceeds 10% of free system memory.
2025-01-09 08:10:07.042457: W external/local_xla/xla/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 676003840 exceeds 10% of free system memory.
2025-01-09 08:10:08.637815: W tensorflow/core/framework/dataset.cc:993] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 1/25
[1m  1/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m21:45[0m 2s/step - accuracy: 0.3750 - loss: 13.9409[1m  2/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3:39[0m 342ms/step - accuracy: 0.3125 - loss: 98.3246[1m  3/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:01[0m 375ms/step - accuracy: 0.2917 - loss: 134.6029[1m  4/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3:55[0m 367ms/step - accuracy: 0.2930 - loss: 153.1266[1m  5/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3:48[0m 357ms/step - accuracy: 0.2969 - loss: 159.8931[1m  6/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3:46[0m 355ms/s

2025-01-09 08:14:59.022095: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]


[1m645/645[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 447ms/step - accuracy: 0.4568 - loss: 10.3613

2025-01-09 08:15:00.730073: W tensorflow/core/framework/dataset.cc:993] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
2025-01-09 08:15:07.452822: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]


[1m645/645[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m303s[0m 468ms/step - accuracy: 0.4569 - loss: 10.3505 - val_accuracy: 0.4328 - val_loss: 1.6820
Epoch 2/25
[1m  1/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m15:33[0m 1s/step - accuracy: 0.6250 - loss: 1.5574[1m  2/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m6:01[0m 562ms/step - accuracy: 0.5469 - loss: 1.6055[1m  3/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:56[0m 462ms/step - accuracy: 0.5451 - loss: 1.6039[1m  4/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:48[0m 450ms/step - accuracy: 0.5534 - loss: 1.5903[1m  5/645[0m [37m━━━━━━━━━

2025-01-09 08:19:17.055563: W tensorflow/core/framework/dataset.cc:993] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


[1m645/645[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m250s[0m 386ms/step - accuracy: 0.5186 - loss: 1.5482 - val_accuracy: 0.5672 - val_loss: 1.3764
Epoch 3/25


2025-01-09 08:19:22.479278: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]


[1m  1/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m6:01[0m 562ms/step - accuracy: 0.6250 - loss: 1.3071[1m  2/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:49[0m 451ms/step - accuracy: 0.6562 - loss: 1.2828[1m  3/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:17[0m 400ms/step - accuracy: 0.6597 - loss: 1.2767[1m  4/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:04[0m 381ms/step - accuracy: 0.6549 - loss: 1.2818[1m  5/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:09[0m 390ms/step - accuracy: 0.6590 - loss: 1.2837[1m  6/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:14[0m 398ms/step - accuracy: 0.6550 

2025-01-09 08:23:32.656358: W tensorflow/core/framework/dataset.cc:993] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


[1m645/645[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m258s[0m 400ms/step - accuracy: 0.5409 - loss: 1.3832 - val_accuracy: 0.5805 - val_loss: 1.2462
Epoch 4/25
[1m  1/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m7:02[0m 656ms/step - accuracy: 0.4375 - loss: 1.5032[1m  2/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m5:43[0m 534ms/step - accuracy: 0.4844 - loss: 1.4544[1m  3/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:51[0m 454ms/step - accuracy: 0.5104 - loss: 1.4177[1m  4/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:32[0m 426ms/step - accuracy: 0.5195 - loss: 1.4011[1m  5/645[0m [37m━━━━━━━

2025-01-09 08:28:03.232950: W tensorflow/core/framework/dataset.cc:993] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


[1m645/645[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m270s[0m 418ms/step - accuracy: 0.5818 - loss: 1.2115 - val_accuracy: 0.6187 - val_loss: 1.1172
Epoch 5/25


2025-01-09 08:28:10.055366: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]


[1m  1/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m5:15[0m 491ms/step - accuracy: 0.5625 - loss: 1.2793[1m  2/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:11[0m 391ms/step - accuracy: 0.5625 - loss: 1.2335[1m  3/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:30[0m 421ms/step - accuracy: 0.5486 - loss: 1.2053[1m  4/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:31[0m 423ms/step - accuracy: 0.5521 - loss: 1.1844[1m  5/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:30[0m 422ms/step - accuracy: 0.5517 - loss: 1.1784[1m  6/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:30[0m 424ms/step - accuracy: 0.5535 

2025-01-09 08:33:00.275490: W tensorflow/core/framework/dataset.cc:993] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


[1m645/645[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m299s[0m 463ms/step - accuracy: 0.5987 - loss: 1.1406 - val_accuracy: 0.6313 - val_loss: 1.0322
Epoch 6/25
[1m  1/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m9:25[0m 878ms/step - accuracy: 0.6875 - loss: 0.8800[1m  2/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:39[0m 434ms/step - accuracy: 0.7031 - loss: 0.8846[1m  3/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:30[0m 422ms/step - accuracy: 0.6979 - loss: 0.8926[1m  4/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4:42[0m 441ms/step - accuracy: 0.6875 - loss: 0.9101[1m  5/645[0m [37m━━━━━━━

2025-01-09 08:37:54.503489: W tensorflow/core/framework/dataset.cc:993] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


[1m645/645[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m300s[0m 464ms/step - accuracy: 0.6798 - loss: 0.9408 - val_accuracy: 0.6961 - val_loss: 0.8806
Epoch 7/25
[1m  1/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m13:22[0m 1s/step - accuracy: 0.8125 - loss: 0.7499[1m  2/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m8:16[0m 771ms/step - accuracy: 0.8125 - loss: 0.7570[1m  3/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m8:50[0m 827ms/step - accuracy: 0.8194 - loss: 0.7611[1m  4/645[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m8:55[0m 835ms/step - accuracy: 0.8177 - loss: 0.7698[1m  5/645[0m [37m━━━━━━━━━━━

Visualize model fit

In [0]:
# Training curves: loss on the left panel, accuracy on the right panel.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 6))

# (axis, history key, display name, legend position) for each panel
panels = [
    (ax1, 'loss', 'Loss', 'upper right'),
    (ax2, 'accuracy', 'Accuracy', 'upper left'),
]

for axis, metric, display_name, legend_loc in panels:
    # Solid green = training curve, dashed orange = validation curve
    axis.plot(history.history[metric], color='green',
              label=f'Training {display_name}', linewidth=2.5)
    axis.plot(history.history[f'val_{metric}'], color='orange', linestyle='--',
              label=f'Validation {display_name}', linewidth=2.5)
    axis.set_xlabel('Epochs', fontsize=18)
    axis.set_ylabel(display_name, fontsize=18)
    axis.legend(loc=legend_loc, fontsize=18)
    axis.tick_params(axis='both', which='major', labelsize=16)

fig.tight_layout()
plt.show()

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:138)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:133)
	at scala.collection.immutable.Range.foreach(Range.scala:158)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:133)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:717)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:458)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:706)
	at com.data

Predict test data, evaluate accuracy and visualize

In [0]:
# Predict class probabilities for the test images and keep the most likely class
prob = model.predict(test_data)
predict_classes = np.argmax(prob, axis=1)

class_names = ['No AD', 'Mild AD', 'Moderate AD', 'Severe AD']


def _to_class_ids(labels):
    """Return 1-D integer class ids from labels stored either as ids or as one-hot rows."""
    labels = np.asarray(labels)
    if labels.ndim > 1 and labels.shape[-1] > 1:
        return labels.argmax(axis=-1)
    return labels.reshape(-1).astype(int)


# confusion_matrix needs class ids on both sides; the labels may have been one-hot
# encoded by an earlier cell, so normalise them first.
true_classes = _to_class_ids(test_lab)
train_classes = _to_class_ids(train_lab)

# Generate the confusion matrix; `labels` pins a 4x4 result even if a class is never seen
conf_matrix = confusion_matrix(true_classes, predict_classes,
                               labels=np.arange(len(class_names)))

# Row-normalise so each row shows how the samples of one true class were distributed.
# Rows without any samples stay at 0 instead of producing NaN from a 0/0 division.
row_totals = conf_matrix.sum(axis=1, keepdims=True)
conf_matrix_normalized = np.divide(
    conf_matrix, row_totals,
    out=np.zeros(conf_matrix.shape, dtype=float),
    where=row_totals > 0,
)

fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# Values are fractions in [0, 1] (not percentages), so the colour bar is labelled accordingly
sns.heatmap(conf_matrix_normalized, annot=True, fmt='.2f', cmap='Blues',
            cbar_kws={'label': 'Fraction of true class'},
            xticklabels=class_names,
            yticklabels=class_names,
            ax=ax[0], linewidths=1, linecolor='black')
ax[0].set_xlabel('Predicted Labels')
ax[0].set_ylabel('True Labels')
ax[0].set_title('Confusion Matrix')

# Bar plot for the class distribution of the training set
train_label_counts = np.bincount(train_classes, minlength=len(class_names))
class_ids = np.arange(len(train_label_counts))
ax[1].bar(class_ids, train_label_counts, color=['#aec7e8', '#ffbb78', '#98df8a', '#ff9896'])
ax[1].set_xlabel('Class Label')
ax[1].set_ylabel('Frequency')
ax[1].set_title('Distribution of Training Set')
ax[1].set_xticks(class_ids)
ax[1].set_xticklabels(class_names)

plt.tight_layout()
plt.show()

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:138)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:133)
	at scala.collection.immutable.Range.foreach(Range.scala:158)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:133)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:717)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:458)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:706)
	at com.data

Overall, we do not see a direct link between the per-class accuracies on the test set and the number of measurements per class in the training set. In fact, the class with the lowest number of training measurements achieved the second-highest accuracy. Therefore, even though it would be ideal to have completely (or at least more) balanced training classes, class imbalance does not appear to be a major source of bias in the current model.

Define hyperparameters to tune and the corresponding parameter space

In [0]:
# Same architecture as before, but with hyperparameter ranges
def build_model(hp):
    """Build and compile the CNN for one KerasTuner trial.

    The layer layout is fixed (three Conv2D + MaxPooling2D stages, Flatten, one dense
    layer, Dropout(0.5), 4-unit softmax output); the tuner chooses the sizes.

    Hyperparameters registered on `hp`, in this order:
      * conv_{1,2,3}_filter: filter counts, from 32 / 64 / 96 up to 128 in steps of 32
      * conv_{1,2,3}_kernel: kernel size (both listed choices are 3, so it is always 3)
      * dense_1_units: dense layer width, 128 to 256 in steps of 32
      * learning_rate: Adam learning rate, 1e-3 or 1e-4

    Args:
        hp: hyperparameter container supplied by the tuner.

    Returns:
        A keras.Sequential model compiled with categorical cross-entropy and accuracy.
    """
    # (filter hyperparameter, kernel hyperparameter, smallest filter count) per conv stage
    conv_stages = [
        ('conv_1_filter', 'conv_1_kernel', 32),
        ('conv_2_filter', 'conv_2_kernel', 64),
        ('conv_3_filter', 'conv_3_kernel', 96),
    ]

    layer_stack = [keras.Input(shape = (128, 128, 1))]

    for filter_hp, kernel_hp, min_filters in conv_stages:
        n_filters = hp.Int(filter_hp, min_value = min_filters, max_value = 128, step = 32)
        k_size = hp.Choice(kernel_hp, values = [3,3])
        layer_stack.append(
            keras.layers.Conv2D(
                filters = n_filters,
                kernel_size = k_size,
                activation = 'relu',
                kernel_regularizer=keras.regularizers.l2(0.01)))
        layer_stack.append(keras.layers.MaxPooling2D((2, 2)))

    layer_stack.append(keras.layers.Flatten())

    dense_units = hp.Int('dense_1_units', min_value = 128, max_value = 256, step = 32)
    layer_stack.append(
        keras.layers.Dense(
            units=dense_units,
            activation='relu',
            kernel_regularizer=keras.regularizers.l2(0.01)))

    layer_stack.append(keras.layers.Dropout(0.5))
    layer_stack.append(keras.layers.Dense(4, activation = 'softmax'))

    model = keras.Sequential(layer_stack)

    hp_learning_rate = hp.Choice('learning_rate', values=[1e-3, 1e-4])
    model.compile(loss = 'categorical_crossentropy',
                  optimizer = keras.optimizers.Adam(learning_rate = hp_learning_rate),
                  metrics = ['accuracy'])

    return model

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:138)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:133)
	at scala.collection.immutable.Range.foreach(Range.scala:158)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:133)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:717)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:458)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:706)
	at com.data

Initiate tuner

In [0]:
# Hyperband search over the space defined in build_model, ranked by validation accuracy.
# Trial results are written under my_dir/AD_class.
tuner = kt.Hyperband(
    build_model,
    objective = 'val_accuracy',
    max_epochs = 20,
    factor = 3,
    directory = 'my_dir',
    project_name = 'AD_class',
)

# Abort a trial once val_loss has not improved for 5 consecutive epochs
stop_early = tf.keras.callbacks.EarlyStopping(patience = 5, monitor = 'val_loss')

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:138)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:133)
	at scala.collection.immutable.Range.foreach(Range.scala:158)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:133)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:717)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:458)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:706)
	at com.data

Run search

In [0]:
# Tune hyperparameters on a class-stratified 20% subset of the data to conserve time/memory
TUNE_FRACTION = 0.2
NUM_CLASSES = 4
TUNE_SEED = 42  # fixed seed so the subset (and therefore the search) is reproducible


def _class_ids(labels):
    """Return 1-D integer class ids from labels stored either as ids or as one-hot rows."""
    labels = np.asarray(labels)
    if labels.ndim > 1 and labels.shape[-1] > 1:
        return labels.argmax(axis=-1)
    return labels.reshape(-1).astype(int)


def _stratified_subset(class_ids, fraction, rng):
    """Return sample indices holding a random `fraction` of every class in `class_ids`."""
    picked = []
    for cls in np.unique(class_ids):
        cls_indices = np.flatnonzero(class_ids == cls)
        rng.shuffle(cls_indices)
        picked.extend(cls_indices[:int(fraction * len(cls_indices))])
    return np.array(picked, dtype=int)


rng = np.random.default_rng(TUNE_SEED)

# Work on integer class ids: sampling per class on one-hot rows would iterate over the
# values {0, 1} rather than over the classes.
train_ids = _class_ids(train_lab)
train_subset = _stratified_subset(train_ids, TUNE_FRACTION, rng)
train_data_tune = train_data[train_subset]
# Encode exactly once, with a fixed class count so the width matches the 4-unit output
train_lab_tune = to_categorical(train_ids[train_subset], num_classes=NUM_CLASSES)

test_ids = _class_ids(test_lab)
test_subset = _stratified_subset(test_ids, TUNE_FRACTION, rng)
test_data_tune = test_data[test_subset]
test_lab_tune = to_categorical(test_ids[test_subset], num_classes=NUM_CLASSES)

# Plot bar graph of label distribution in the training subset (after it has been built)
label_counts = np.bincount(train_lab_tune.argmax(axis=1), minlength=NUM_CLASSES)
labels = np.arange(len(label_counts))

plt.bar(labels, label_counts, tick_label=labels)
plt.xlabel('Labels')
plt.ylabel('Count')
plt.title('Label Distribution in Subset')
plt.show()

# NOTE(review): Hyperband appears to schedule the number of epochs per trial itself
# (bounded by max_epochs), so `epochs = 10` may have no effect - confirm against the
# KerasTuner documentation.
tuner.search(train_data_tune, train_lab_tune, epochs = 10, callbacks = [stop_early],
             validation_data = (test_data_tune, test_lab_tune))

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:138)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:133)
	at scala.collection.immutable.Range.foreach(Range.scala:158)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:133)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:717)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:458)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:706)
	at com.data

In [0]:
# Optimal hyperparameters: take the single best trial found by the tuner
best_hps = tuner.get_best_hyperparameters(num_trials = 1)[0]

# The names passed to .get() must match the names registered in build_model
# (the dense layer width is registered as 'dense_1_units').
print(f"""
Optimal parameters are as follows:

Filter 1 output dim: {best_hps.get('conv_1_filter')}
Filter 2 output dim: {best_hps.get('conv_2_filter')}
Filter 3 output dim: {best_hps.get('conv_3_filter')}

Dense layer units: {best_hps.get('dense_1_units')}

Learning Rate: {best_hps.get('learning_rate')}
""")

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:138)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:133)
	at scala.collection.immutable.Range.foreach(Range.scala:158)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:133)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:717)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:458)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:706)
	at com.data

Optimal # of epochs

In [0]:
# NOTE(review): this cell is currently disabled - every statement below is commented out.
# It would rebuild the model from the best hyperparameters, train it for 50 epochs and
# report the epoch with the highest validation accuracy.
# Before re-enabling: `train_images` / `train_labels` are not defined in the visible cells
# of this notebook (the loaded arrays are named `train_data` / `train_lab`) - TODO confirm.
# model = tuner.hypermodel.build(best_hps)
# history = model.fit(train_images, train_labels, epochs = 50, validation_split = 0.2)

# val_acc_per_epoch = history.history['val_accuracy']
# best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
# print('Best epoch: %d' % (best_epoch,))

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:138)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:133)
	at scala.collection.immutable.Range.foreach(Range.scala:158)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:133)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:717)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:458)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:706)
	at com.data