**Pre-req:** Confirm that you are actually using a GPU instance by running `nvidia-smi` from a terminal.  
**Important Note:** The GPU will not be usable unless the required GPU modules were loaded when the session was created.  

This notebook uses the TensorBoard HParams API to aid hyperparameter tuning; the logged results can then be visualized in TensorBoard once tuning is complete.

In [None]:
# Select TensorFlow's asynchronous CUDA allocator to optimize GPU memory usage.
# It is set in its own cell, before the cell that imports TensorFlow, so the
# setting is already in place when TensorFlow starts up.
import os
os.environ.update(TF_GPU_ALLOCATOR='cuda_malloc_async')

In [None]:
import datetime
import tensorflow as tf
from tensorflow.keras import datasets, layers, models,preprocessing
from tensorboard.plugins.hparams import api as hp
import numpy as np
import matplotlib.pyplot as plt
import pickle
from sklearn.preprocessing import MinMaxScaler,StandardScaler
from sklearn.metrics import f1_score, confusion_matrix,precision_score,recall_score,accuracy_score,log_loss
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold

In [None]:
# Define base paths where files will be stored.
# Step 0 pickled three paths (home, scratch, shared); unpack them here.
_pathsPickle = 'pickledHomeScratchShared.pickle'
with open(_pathsPickle, "rb") as f:
    (baseHomePath,
     baseScratchPath,
     baseSharedPath) = pickle.load(f)

In [None]:
# Set memory growth on GPUs (another step for memory optimization): with memory
# growth enabled, TensorFlow allocates GPU memory as needed rather than
# reserving it all up front.
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # TensorFlow raises RuntimeError when memory growth is set after the GPUs
    # were initialized (e.g. this cell is re-run). `e` is an exception object,
    # so it must be formatted; "Error: " + e itself raised TypeError.
    print(f"Error: {e}")

# Print GPU devices, if any, that are available
physical_devices = tf.config.list_physical_devices('GPU')
print("Physical GPU Devices: ", physical_devices)

logical_devices = tf.config.list_logical_devices('GPU')
print("Logical GPU Devices: ", logical_devices)


### Choose the view by selecting one of the options in `viewChoices`, and the dataset by selecting one of the keys in `fileChoices`

In [None]:
# Enter choice of view by selecting 'choice' variable:
viewChoices = ['Transverse', 'Coronal', 'Sagittal']
choice = viewChoices[2]

# File choices: dataset key -> description printed at the end of this cell.
fileChoices = {
    '608': 'This dataset contains 16-stitiched frames using 235 MRIs from OASIS-1("N/A" CDR rows were filtered) + 373 MRIs from OASIS-2',
    '809': 'This dataset contains 16-stitiched frames using 436 MRIs from OASIS-1("N/A" CDR rows assumed as CDR=0) + 373 MRIs from OASIS-2',
    '508': 'This dataset contains 16-stitiched frames using 135 MRIs from OASIS-1(Only CDR=0 rows were kept) + 373 MRIs from OASIS-2',
    'Processed': 'This dataset contains 436 MRIs from OASIS-1 taken from the "PROCESSED" folder. This is NOT stitched.',
}
fileChoice = list(fileChoices)[3]

# Define file names here to make it easier to substitute in next code cell.
# Stitched datasets use one fixed set of files per key, ordered as
# (transverse, sagittal, coronal, labels, MRI ids).
_stitchedFileNames = {
    '608': ('skip_120_stitched_imgs_t_all_608.pickle',
            'skip_56_stitched_imgs_s_all_608.pickle',
            'skip_120_stitched_imgs_c_all_608.pickle',
            'all_labels_allImgs_608.pickle',
            'all_mri_id_allImgs_608.pickle'),
    '809': ('skip_120_stitched_imgs_t_all_809.pickle',
            'skip_56_stitched_imgs_s_all_809.pickle',
            'skip_120_stitched_imgs_c_all_809.pickle',
            'all_labels_allImgs_809.pickle',
            'all_mri_id_allImgs_809.pickle'),
    '508': ('skip_120_stitched_imgs_t.pickle',
            'skip_56_stitched_imgs_s.pickle',
            'skip_120_stitched_imgs_c.pickle',
            'all_labels.pickle',
            'all_mri_id.pickle'),
}
# The "Processed" dataset stores labels / MRI ids per view: (labels, MRI ids).
_processedLabelFileNames = {
    'Transverse': ('all_labels_processed_t.pickle', 'all_mri_id_processed_t.pickle'),
    'Sagittal': ('all_labels_processed_s.pickle', 'all_mri_id_processed_s.pickle'),
    'Coronal': ('all_labels_processed_c.pickle', 'all_mri_id_processed_c.pickle'),
}

if fileChoice in _stitchedFileNames:
    (transverseFile, sagittalFile, coronalFile,
     labelsFile, mriIDFile) = _stitchedFileNames[fileChoice]
elif fileChoice == 'Processed':
    transverseFile = 'processed_img_t.pickle'
    sagittalFile = 'processed_img_s.pickle'
    coronalFile = 'processed_img_c.pickle'
    if choice in _processedLabelFileNames:
        labelsFile, mriIDFile = _processedLabelFileNames[choice]

print('You have chosen to train with the view of "{}"'.format(choice))
print(fileChoices[fileChoice])

In [None]:
# Load objects from serialized files

def _loadSharedPickle(fileName):
    """Unpickle and return the object stored as `fileName` under baseSharedPath."""
    with open("{}/{}".format(baseSharedPath, fileName), "rb") as f:
        return pickle.load(f)


# Selected view -> pickle that holds that view's images.
_viewImageFiles = {
    'Transverse': transverseFile,
    'Coronal': coronalFile,
    'Sagittal': sagittalFile,
}
if choice in _viewImageFiles:
    img_16frames = _loadSharedPickle(_viewImageFiles[choice])

all_labels = _loadSharedPickle(labelsFile)
all_mri_id = _loadSharedPickle(mriIDFile)

# The pickle presumably holds a 1-D collection of 2-D images (e.g. shape (508,));
# stack them along a new first axis to get (num_images, height, width) ...
img_16frames = np.stack(img_16frames, axis=0)
# ... then add a trailing single-channel axis: (num_images, height, width, 1).
num_rows, num_cols = img_16frames.shape[1], img_16frames.shape[2]
img_16frames = img_16frames.reshape(-1, num_rows, num_cols, 1)

print('You have chosen to train with the view of "{}"'.format(choice))
print('Shape of dataset : {}'.format(img_16frames.shape))
print('Shape of labels : {}'.format(all_labels.shape))


In [None]:
# Standardize data using StandardScaler()
#
# The last axis has size 1 (set by the reshape in the loading cell), so
# reshaping to (-1, shape[-1]) gives a single-feature column holding every
# pixel; the scaler therefore applies ONE global mean/std to the whole dataset.
# NOTE(review): the scaler is fitted on all images before train_test_split
# runs, so test-set statistics leak into the normalization; consider fitting
# on X_train only.
_originalShape = img_16frames.shape
scaler = StandardScaler()
# Avoid binding the flattened input to a module-level name: the original
# `img_16frames_flat` view kept the entire unscaled dataset alive after scaling.
img_16frames = scaler.fit_transform(
    img_16frames.reshape(-1, _originalShape[-1])
).reshape(_originalShape)

print(img_16frames.shape)


In [None]:
# Split to use portion of dataset.
# Disabled by default: uncomment the two lines below to train on only the
# first 235 images and labels.
# img_16frames = img_16frames[:235,:,:,:]
# all_labels = all_labels[:235]

split_images_msg = 'Shape of split dataset : {}'.format(img_16frames.shape)
split_labels_msg = 'Shape of split labels : {}'.format(all_labels.shape)
print(split_images_msg)
print(split_labels_msg)

In [None]:
# Hold out 20% of the images for testing. Stratifying on the labels keeps the
# class proportions equal in both splits; the fixed random_state makes the
# split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    img_16frames,
    all_labels,
    test_size=0.2,
    random_state=13,
    stratify=all_labels,
)

In [None]:
# Show a sample image to see if expected view

def plot_stitched_img(stitched_img):
    """Display one 2-D image in a large grayscale figure.

    Closes the current figure first, then draws `stitched_img` with the
    reversed gray colormap and nearest-neighbour interpolation. Returns None.
    NOTE(review): the original comment said the input comes from a
    `get_mri_array` function, which is not defined in this notebook.
    """
    plt.close()
    plt.figure(figsize=(50, 30))
    plt.imshow(stitched_img, interpolation="nearest", cmap=plt.cm.gray_r)
    plt.show()


# First image, channel 0 -> a 2-D (height, width) slice.
plot_stitched_img(img_16frames[0, :, :, 0])

In [None]:
# Define the hyperparameters to tune.
# filter_size / num_filters are discrete grids; dropout is declared as a real
# interval (the tuning loop below samples its endpoints).
HP_FILTER_SIZE = hp.HParam('filter_size', domain=hp.Discrete([3, 5, 7]))
HP_NUM_FILTERS = hp.HParam('num_filters', domain=hp.Discrete([8, 32, 64]))
HP_DROPOUT = hp.HParam('dropout', domain=hp.RealInterval(min_value=0.2, max_value=0.4))
# Optimizer is not tuned (models are compiled with Adam); kept for reference:
# HP_OPTIMIZER = hp.HParam('optimizer', hp.Discrete(['adam', 'sgd']))

# Names under which the evaluation metrics are logged to TensorBoard.
METRIC_ACCURACY = 'accuracy'
METRIC_LOSS = 'loss'
METRIC_RECALL = 'recall'
METRIC_PRECISION = 'precision'
METRIC_F1 = 'f1_score'

In [None]:
# Define the model function
def cnn_model(hparams,fileChoice):
    """Build and compile the CNN for one hyperparameter trial.

    Args:
        hparams: dict keyed by HP_NUM_FILTERS, HP_FILTER_SIZE and HP_DROPOUT.
            These set the first convolution's filter count and kernel size,
            and the rate of the first Dropout layer.
        fileChoice: dataset key. '608', '809' and '508' (stitched images) get a
            3-convolution network; 'Processed' gets a 4-convolution network
            with L2 regularization on every convolution and the hidden Dense.

    Returns:
        A compiled tf.keras Sequential model with 2 output logits per image,
        trained with SparseCategoricalCrossentropy(from_logits=True) and Adam
        (learning rate 1e-4).

    Raises:
        ValueError: if fileChoice is not one of the keys above. Previously this
            case surfaced as an UnboundLocalError on `model`.
    """
    # Seed BEFORE creating any layer. The first layer receives `input_shape`,
    # so Sequential builds each layer (drawing its initial weights) as soon as
    # it is added. The original seed call, placed after the layers, therefore
    # did not make weight initialization reproducible across trials.
    tf.random.set_seed(42)
    # One single-channel image, sized from the loaded dataset.
    input_shape = (img_16frames.shape[1], img_16frames.shape[2], 1)

    if fileChoice in ['608', '809', '508']:  # Model catering to stitched images
        model = models.Sequential()
        model.add(layers.Conv2D(hparams[HP_NUM_FILTERS], hparams[HP_FILTER_SIZE], activation='relu',
                                input_shape=input_shape))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(filters=64, kernel_size=3,
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu'))
        model.add(layers.Dropout(hparams[HP_DROPOUT]))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(filters=128, kernel_size=3,
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu'))
        model.add(layers.Flatten())
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(2))
    elif fileChoice == 'Processed':  # Model catering to processed images
        model = models.Sequential()
        model.add(layers.Conv2D(filters=hparams[HP_NUM_FILTERS], kernel_size=hparams[HP_FILTER_SIZE],
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu',
                                input_shape=input_shape, name="C_2d_1"))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(filters=64, kernel_size=3,
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu', name="C_2d_2"))
        model.add(layers.Dropout(hparams[HP_DROPOUT]))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(filters=64, kernel_size=3,
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu', name="C_2d_3"))
        model.add(layers.Dropout(0.15))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(filters=64, kernel_size=3,
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu', name="C_2d_4"))
        model.add(layers.Flatten())
        model.add(layers.Dense(64, kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu', name="Dense_1"))
        model.add(layers.Dense(2))
    else:
        raise ValueError(f"Unsupported fileChoice: {fileChoice!r}")

    # Both architectures share the same training configuration.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    return model

# Define the training function
def train_cnn(hparams, log_dir):
    """Train one model for `hparams` and evaluate its best checkpoint on X_test.

    Args:
        hparams: dict of HParam -> value for this trial. It is passed to
            cnn_model and logged by the hparams Keras callback.
        log_dir: directory for this trial's TensorBoard logs.

    Returns:
        Tuple (accuracy, loss, f1, precision, recall), computed on X_test using
        the weights from the epoch with the best val_accuracy.
        - `loss` is the log loss of the softmax probabilities.
        - precision, recall and F1 are macro-averaged.
    """
    model = cnn_model(hparams, fileChoice)
    # Weights-only checkpoint in a 'tmp' folder next to the shared data
    # folder. The same path is reused by every trial.
    checkpoint_filepath = "{}/{}".format(os.path.dirname(baseSharedPath), 'tmp')
    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1),
        hp.KerasCallback(log_dir, hparams),
        # Saves the weights whenever val_accuracy improves, so the best epoch
        # can be restored after training.
        tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_filepath, save_weights_only=True,
            monitor='val_accuracy', mode='max', save_best_only=True),
    ]
    # NOTE(review): X_test serves both as validation data for checkpoint
    # selection and as the evaluation set below, so the reported metrics are
    # optimistically biased.
    model.fit(X_train, y_train, epochs=15,
              validation_data=(X_test, y_test), callbacks=callbacks)
    model.load_weights(checkpoint_filepath)  # Restore the best epoch's weights

    # The model emits logits (from_logits=True in cnn_model): argmax gives the
    # predicted class, and softmax gives the probabilities that log_loss needs.
    logits = model.predict(X_test)
    y_pred_classes = np.argmax(logits, axis=1)
    y_pred_proba = tf.nn.softmax(logits, axis=-1).numpy()

    # Previously these referenced undefined y_test_binary / y_pred_binary
    # (NameError), and log_loss was fed hard labels instead of probabilities.
    accuracy = accuracy_score(y_test, y_pred_classes)
    loss = log_loss(y_test, y_pred_proba, labels=[0, 1])
    f1 = f1_score(y_test, y_pred_classes, average='macro')
    precision = precision_score(y_test, y_pred_classes, average='macro')
    recall = recall_score(y_test, y_pred_classes, average='macro')
    return (accuracy, loss, f1, precision, recall)

# Define the main function for hyperparameter tuning
def run(run_dir, hparams):
    """Run one trial and log its hyperparameters and final metrics.

    Everything is written through a summary file writer rooted at `run_dir`;
    each metric is logged once, at step 1.
    """
    writer = tf.summary.create_file_writer(run_dir)
    with writer.as_default():
        hp.hparams(hparams)
        accuracy, loss, f1, precision, recall = train_cnn(hparams, run_dir)
        logged = (
            (METRIC_ACCURACY, accuracy),
            (METRIC_LOSS, loss),
            (METRIC_F1, f1),
            (METRIC_PRECISION, precision),
            (METRIC_RECALL, recall),
        )
        for tag, value in logged:
            tf.summary.scalar(tag, value, step=1)

In [None]:
# Define the hyperparameters grid search space:
# every filter size x filter count, with dropout at both ends of its interval.
runDateTime = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
# Create the distribution strategy once and reuse its scope for every trial,
# rather than constructing a new MirroredStrategy for each of the 18 trials.
mirrored_strategy = tf.distribute.MirroredStrategy()
session_num = 0
for filter_size in HP_FILTER_SIZE.domain.values:
    for num_filters in HP_NUM_FILTERS.domain.values:
        for dropout_rate in (HP_DROPOUT.domain.min_value, HP_DROPOUT.domain.max_value):
            hparams = {
                HP_FILTER_SIZE: filter_size,
                HP_NUM_FILTERS: num_filters,
                HP_DROPOUT: dropout_rate,
            }
            run_name = f"run-{session_num}"
            print(f"--- Starting trial: {run_name}")
            print({h.name: hparams[h] for h in hparams})
            # Build and train the model under the strategy so its variables
            # are mirrored across the available GPUs.
            with mirrored_strategy.scope():
                argForRun = f"../data/logs/hparam_tuning/{choice}/{runDateTime}-{fileChoice}/{run_name}"
                run(argForRun, hparams)
            session_num += 1