**Pre-req:** Confirm that you are actually using a GPU instance by running `nvidia-smi` from terminal.  
**Important Note:** The GPU will not be usable unless the required GPU modules were already loaded when the session was created.  

In [None]:
# Env variable to optimize memory usage.
# It is set in this first cell, before the cell that imports TensorFlow, so the
# setting is already in the process environment when TensorFlow starts up.
import os
import datetime
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'

In [None]:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models,preprocessing,Model
from tensorflow.keras.applications.inception_v3 import InceptionV3
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
from sklearn.preprocessing import MinMaxScaler,StandardScaler
from sklearn.metrics import f1_score, confusion_matrix,precision_score,recall_score,accuracy_score,log_loss
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from IPython.display import display_html 

In [None]:
# Restore the three base directories (home, scratch, shared) that Step 0
# serialized into a single pickle file in the working directory.
# NOTE: pickle.load can execute arbitrary code; only load a file you created yourself.

with open('pickledHomeScratchShared.pickle', "rb") as basePathsFile:
    basePaths = pickle.load(basePathsFile)
baseHomePath, baseScratchPath, baseSharedPath = basePaths

In [None]:
# Set memory growth on GPUs (another step for memory optimization).
# TensorFlow raises RuntimeError if memory growth is changed after the GPUs have
# been initialized; that case is reported below instead of aborting the notebook.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  try:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
  except RuntimeError as e:
    # Format the exception rather than concatenating it: "Error: " + e raises
    # TypeError (str + exception) and would hide the real error message.
    print("Error: {}".format(e))

# Print GPU devices, if any, that are available
physical_devices = tf.config.list_physical_devices('GPU')
print("Physical GPU Devices: ", physical_devices)

logical_devices = tf.config.list_logical_devices('GPU')
print("Logical GPU Devices: ", logical_devices)


### Choose the dataset here by selecting one of the options in `fileChoices`, and a seed from `seedValueOptions` (all views in `viewChoices` are looped over automatically)

In [None]:
# Available datasets: key -> human-readable description (printed below for the chosen key).
fileChoices = {
    '608': 'This dataset contains 16-stitiched frames using 235 MRIs from OASIS-1("N/A" CDR rows were filtered) + 373 MRIs from OASIS-2',
    '809': 'This dataset contains 16-stitiched frames using 436 MRIs from OASIS-1("N/A" CDR rows assumed as CDR=0) + 373 MRIs from OASIS-2',
    '508': 'This dataset contains 16-stitiched frames using 135 MRIs from OASIS-1(Only CDR=0 rows were kept) + 373 MRIs from OASIS-2',
    'Processed': 'This dataset contains 436 MRIs from OASIS-1 taken from the "PROCESSED" folder. This is NOT stitched.',
    'Pre-trained': 'Pre-trained InceptionV3 model applied on top of "Processed" dataset (i.e. 436 Processed MRIs from OASIS-1)',
}
# IMPORTANT NOTE: the dataset is picked here by position in `fileChoices`
# (index 1 -> '809'); the choice is used for the rest of the notebook.
availableChoices = list(fileChoices)
fileChoice = availableChoices[1]
print(fileChoice, " : ", fileChoices[fileChoice])

# Views to train on. Nothing to select: the training cell loops through all three.
viewChoices = ['Transverse', 'Coronal', 'Sagittal']

# Seed used for the k-fold shuffling and TensorFlow seeding further down.
seedValueOptions = [14, 42, 696]
seedValue = seedValueOptions[2]
print('Seed value used throughout this notebook is {}'.format(seedValue))

In [None]:
# Code to help choose files based on file choice and view choice
def chooseFiles(viewChoice,fileChoice):
    """Load the image stack, labels and MRI ids for one dataset / view combination.

    Parameters
    ----------
    viewChoice : str
        One of 'Transverse', 'Coronal' or 'Sagittal'.
    fileChoice : str
        One of '608', '809', '508', 'Processed' or 'Pre-trained'.
        'Processed' and 'Pre-trained' read the same files; they differ only in
        how the channel axis is produced at the end.

    Returns
    -------
    tuple
        (img_16frames, all_labels, all_mri_id). The images are a 4-D numpy array:
        the stacked images with a trailing axis of size 3 for 'Pre-trained'
        (values repeated per channel) and of size 1 otherwise. Labels and ids
        are returned exactly as unpickled.

    Raises
    ------
    ValueError
        If viewChoice or fileChoice is not a supported option (previously this
        surfaced as an UnboundLocalError on one of the file-name variables).

    Files are read from the module-level ``baseSharedPath`` directory.
    NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    """
    # Letter used in the per-view file names of the processed datasets.
    viewSuffixes = {'Transverse': 't', 'Coronal': 'c', 'Sagittal': 's'}
    # Stitched datasets: (image file per view, labels file, MRI id file).
    # Labels and ids are shared by all three views of a stitched dataset.
    stitchedFiles = {
        '608': ({'Transverse': 'skip_120_stitched_imgs_t_all_608.pickle',
                 'Sagittal': 'skip_56_stitched_imgs_s_all_608.pickle',
                 'Coronal': 'skip_120_stitched_imgs_c_all_608.pickle'},
                'all_labels_allImgs_608.pickle',
                'all_mri_id_allImgs_608.pickle'),
        '809': ({'Transverse': 'skip_120_stitched_imgs_t_all_809.pickle',
                 'Sagittal': 'skip_56_stitched_imgs_s_all_809.pickle',
                 'Coronal': 'skip_120_stitched_imgs_c_all_809.pickle'},
                'all_labels_allImgs_809.pickle',
                'all_mri_id_allImgs_809.pickle'),
        '508': ({'Transverse': 'skip_120_stitched_imgs_t.pickle',
                 'Sagittal': 'skip_56_stitched_imgs_s.pickle',
                 'Coronal': 'skip_120_stitched_imgs_c.pickle'},
                'all_labels.pickle',
                'all_mri_id.pickle'),
    }

    # Validate up front so a typo fails with a clear message before any file is opened.
    if viewChoice not in viewSuffixes:
        raise ValueError("Unknown viewChoice {!r}; expected one of {}".format(
            viewChoice, sorted(viewSuffixes)))

    if fileChoice in stitchedFiles:
        imageFiles, labelsFile, mriIDFile = stitchedFiles[fileChoice]
        imageFile = imageFiles[viewChoice]
    elif fileChoice in ('Processed', 'Pre-trained'):
        # Processed data keeps separate label / id files for every view.
        suffix = viewSuffixes[viewChoice]
        imageFile = 'processed_img_{}.pickle'.format(suffix)
        labelsFile = 'all_labels_processed_{}.pickle'.format(suffix)
        mriIDFile = 'all_mri_id_processed_{}.pickle'.format(suffix)
    else:
        raise ValueError("Unknown fileChoice {!r}; expected one of {}".format(
            fileChoice, list(stitchedFiles) + ['Processed', 'Pre-trained']))

    def _loadPickle(fileName):
        # Load one serialized object from the shared directory.
        with open("{}/{}".format(baseSharedPath, fileName), "rb") as f:
            return pickle.load(f)

    img_16frames = _loadPickle(imageFile)
    all_labels = _loadPickle(labelsFile)
    all_mri_id = _loadPickle(mriIDFile)

    # 'Stack' the per-image arrays into one array with the image index as axis 0
    img_16frames = np.stack(img_16frames, axis=0)

    if fileChoice == 'Pre-trained': # InceptionV3 needs 3 channels, hence have to adjust
        img_16frames = np.repeat(img_16frames[..., np.newaxis], 3, -1)
        print('Adjusting image shape for Pre-trained model to {}'.format(img_16frames.shape))
    else: # Reshape to 4 dimensions: (number of images, axis 1, axis 2, single channel)
        img_16frames = img_16frames.reshape(-1, img_16frames.shape[1], img_16frames.shape[2],1)

    return(img_16frames,all_labels,all_mri_id)

# Code to standardize dataset
def standardizeImg(img_16frames):
    """Standardize an image array with scikit-learn's ``StandardScaler``.

    The array is reshaped to 2-D with one column per entry of its last axis,
    the scaler is fitted and applied column-wise, and the scaled values are
    returned in the shape of the input.

    NOTE(review): the scaler is fitted on whatever array is passed in. If the
    caller passes the full dataset before splitting it into folds, statistics
    of the held-out images influence the scaling -- confirm this is intended.
    """
    originalShape = img_16frames.shape
    # One row per pixel position, one column per entry of the last axis
    pixelRows = img_16frames.reshape(-1, originalShape[-1])
    scaledRows = StandardScaler().fit_transform(pixelRows)
    return scaledRows.reshape(originalShape)


In [None]:
# No longer needed since this is being done inside the KFold code below.
# X_train, X_test, y_train, y_test  = train_test_split(img_16frames, all_labels
# ,test_size=0.2, random_state=seedValue , stratify=all_labels)

In [None]:
# Show sample image of one of the images based on the view that was picked

def plot_stitched_img(stitched_img):
    """Show one image array on a large figure using the reversed gray colormap.

    Closes the current pyplot figure, opens a new 50x30 figure, draws
    ``stitched_img`` with nearest-neighbour interpolation and displays it.
    Nothing is returned.
    """
    plt.close()
    plt.figure(figsize=(50, 30))
    plt.imshow(stitched_img, interpolation="nearest", cmap=plt.cm.gray_r)
    plt.show()

#plot_stitched_img(img_16frames[0][:,:,0])

### Choose model hyperparameters and define reusable code to generate a fresh model for each of the cross-validation folds.

In [None]:
# Set values of hyperparameters here and substitute in code below.
conv2d_1_numfilters = 32
conv2d_1_kernelSize = 3
dropput_1 = 0.4  # dropout rate of the stitched-image model (spelling of the name kept so existing references still work)
learning_rate = 0.0001

# Re-usable code for generating model using hyperparameters in global space
def generateModel(fileChoice, imageShape=None):
    """Build and compile a fresh, untrained Keras classifier for the chosen dataset.

    Parameters
    ----------
    fileChoice : str
        '608', '809' or '508' -> CNN with three conv layers for stitched,
        single-channel images.
        'Processed' -> L2-regularised CNN with four conv layers for
        single-channel images.
        'Pre-trained' -> InceptionV3 (ImageNet weights, no top) with a new
        dense head, for 3-channel images.
    imageShape : tuple of (int, int), optional
        The two spatial dimensions of the input images. When omitted they are
        read from axes 1 and 2 of the module-level ``img_16frames`` array,
        which is what this function always did before.

    Returns
    -------
    A compiled Keras model (optimizer, loss and 'accuracy' metric attached).

    Raises
    ------
    ValueError
        If ``fileChoice`` is not a supported option (previously the function
        reached ``return model`` with ``model`` unbound).

    Reads the module-level hyperparameters above and ``seedValue``.
    """
    supportedChoices = ('608', '809', '508', 'Processed', 'Pre-trained')
    if fileChoice not in supportedChoices:
        raise ValueError("Unknown fileChoice {!r}; expected one of {}".format(
            fileChoice, list(supportedChoices)))

    if imageShape is None:
        # Backward-compatible default: take the shape from the global dataset.
        imageShape = (img_16frames.shape[1], img_16frames.shape[2])
    dim1, dim2 = imageShape

    # Seed BEFORE any layer is created. The seed used to be set only after all
    # layers had been added (and never for 'Pre-trained'), so the randomly
    # initialised weights were not covered by it.
    tf.random.set_seed(seedValue)

    if fileChoice in ['608','809','508']: # Model catering to stitched images
        model = models.Sequential()
        model.add(layers.Conv2D(filters=conv2d_1_numfilters, kernel_size=conv2d_1_kernelSize, activation='relu',
                                input_shape=(dim1, dim2, 1)))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(filters=64, kernel_size=3,
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu'))
        model.add(layers.Dropout(dropput_1))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(filters=128, kernel_size=3,
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu'))
        model.add(layers.Flatten())
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(2))  # raw logits, hence from_logits=True in the loss
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
    elif fileChoice == 'Processed': # Model catering to processed images
        model = models.Sequential()
        model.add(layers.Conv2D(filters=conv2d_1_numfilters, kernel_size=conv2d_1_kernelSize,
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu',
                                input_shape=(dim1, dim2, 1), name="C_2d_1"))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(filters=64, kernel_size=3,
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu', name="C_2d_2"))
        model.add(layers.Dropout(0.15))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(filters=64, kernel_size=3,
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu', name="C_2d_3"))
        model.add(layers.Dropout(0.15))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(filters=64, kernel_size=3,
                                kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu', name="C_2d_4"))
        model.add(layers.Flatten())
        model.add(layers.Dense(64, kernel_regularizer=tf.keras.regularizers.L2(0.01), activation='relu', name="Dense_1"))
        model.add(layers.Dense(2))  # raw logits, hence from_logits=True in the loss
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
    else: # 'Pre-trained': InceptionV3 convolutional base with a new dense head
        inceptionBase = InceptionV3(input_shape=(dim1, dim2, 3),
                                    weights='imagenet', include_top=False)
        flattenedFeatures = layers.Flatten()(inceptionBase.output)
        head = layers.Dense(1024, kernel_regularizer=tf.keras.regularizers.L2(0.01),
                            activation='relu')(flattenedFeatures)
        head = layers.Dropout(0.2)(head)
        # NOTE(review): 8 softmax units are kept as in the original, but
        # generateF1Score only compares output columns 0 and 1 -- confirm
        # whether this head should have 2 units instead.
        head = layers.Dense(8, kernel_regularizer=tf.keras.regularizers.L2(0.01),
                            activation='softmax')(head)
        model = Model(inputs=inceptionBase.input, outputs=head)
        # `learning_rate=` replaces the deprecated `lr=` keyword and now follows
        # the shared hyperparameter (same value, 0.0001, as the old literal).
        model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9),
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
    return model

# Reusable code to generate f1_score, precision, recall and confusion matrix for a pair of X_test,y_test against a model
def generateF1Score(model,X_test,y_test):
    """Score a fitted model on one held-out set of a binary (0/1) problem.

    Parameters
    ----------
    model : object with a ``predict(X)`` method returning one row of class
        scores per sample; only columns 0 and 1 are used.
    X_test : inputs passed straight to ``model.predict``.
    y_test : array of true 0/1 labels; it is flattened to 1-D.

    Returns
    -------
    tuple
        (macro F1, precision, recall, accuracy, confusion matrix). The
        confusion matrix is a 2x2 DataFrame with rows 'Actual-Neg'/'Actual-Pos'
        and columns 'Pred-Neg'/'Pred-Pos'.
    """
    y_pred = np.asarray(model.predict(X_test))
    # Predict class 0 only when its score is strictly larger, class 1 otherwise.
    # The raw scores are compared: rounding them to one decimal first (as was
    # done before) turned near-ties into exact ties, which all became class 1.
    y_pred_binary = np.where(y_pred[:, 0] > y_pred[:, 1], 0, 1)
    y_test_binary = np.asarray(y_test).reshape(-1)
    f1Score = f1_score(y_test_binary, y_pred_binary, average='macro')
    precision = precision_score(y_test_binary, y_pred_binary)
    recall = recall_score(y_test_binary, y_pred_binary)
    accuracy = accuracy_score(y_test_binary, y_pred_binary)
    # labels=[0, 1] pins the matrix to 2x2 even if a class is missing from both
    # the labels and the predictions; otherwise the fixed row/column names
    # below would not match the matrix shape.
    conf_matx = pd.DataFrame(confusion_matrix(y_test_binary, y_pred_binary, labels=[0, 1]),
                             index=['Actual-Neg', 'Actual-Pos'],
                             columns=['Pred-Neg', 'Pred-Pos'])
    return (f1Score,precision,recall,accuracy,conf_matx)


In [None]:
# If 'Pre-trained', have to change img_16frames

In [None]:
# Main code cell for training over stratified k-fold data.
# For every view: load + standardize the dataset, then train a fresh model on each
# of the n_splits stratified folds and collect its metrics. The result is
# allModelsScores[view] = list (one entry per fold) of
# [neg-in-train, pos-in-train, neg-in-test, pos-in-test, f1, precision, recall, accuracy, confusion-matrix].

allModelsScores = dict() # Initialize a dictionary that stores all the information of all the models
n_splits = 5

for viewChoice in viewChoices:
    print('The choice of view is {}'.format(viewChoice))
    # img_16frames is deliberately a notebook-level name: generateModel() reads its
    # shape to size the model input, so it has to be assigned before the model is built.
    img_16frames,all_labels,all_mri_id = chooseFiles(viewChoice,fileChoice) # Generate dataset
    # NOTE(review): standardization is fitted on the whole dataset before it is split
    # into folds, so held-out images contribute to the scaling -- confirm this is acceptable.
    img_16frames = standardizeImg(img_16frames) # Standardize dataset
    modelCVScores = [] # Empty list to store scores from each iteration
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seedValue) # Define the Stratified KFold object
    # Train for chosen view over k-folds using distributed strategy
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        # Iterate through each fold
        for i, (train_index, test_index) in enumerate(skf.split(img_16frames,all_labels)):
            print('Iteration: {} for {}'.format(i+1,viewChoice))
            # Generate model from scratch so no weights carry over between folds
            model = generateModel(fileChoice)
            # NOTE(review): os.path.dirname() yields the parent directory of baseSharedPath
            # unless that path ends with a separator -- confirm where 'tmp' is meant to live.
            # The same checkpoint path is reused by every fold and every view.
            checkpoint_filepath = "{}/{}".format(os.path.dirname(baseSharedPath),'tmp') # 'tmp' checkpoint location
            callbacks = [
                tf.keras.callbacks.ModelCheckpoint( # Saves the weights whenever val_accuracy reaches a new maximum
                    filepath=checkpoint_filepath, save_weights_only=True,monitor='val_accuracy',mode='max',save_best_only=True)
            ]
            # Generate new datasets for X_train, X_test, y_train, y_test on each run
            X_train, X_test =  img_16frames[train_index],img_16frames[test_index]
            y_train, y_test = all_labels[train_index],all_labels[test_index]
            # Class counts per split, stored next to the metrics for reference
            ActualNegativesInTrain ,ActualPositivesInTrain = len(y_train[y_train==0]),len(y_train[y_train==1])
            ActualNegativesInTest, ActualPositivesInTest= len(y_test[y_test==0]),len(y_test[y_test==1])
            # NOTE(review): the held-out fold is used as validation data, and the checkpoint
            # with the best val_accuracy on that same fold is the one scored below, so the
            # reported metrics are likely optimistic -- confirm this is intended.
            model.fit(X_train,  y_train, epochs=20, 
                            validation_data=(X_test, y_test),callbacks=callbacks,verbose=2)
            # `history` receives the return value of load_weights (not a training history)
            # and is not used again in this cell.
            history = model.load_weights(checkpoint_filepath) # Load best weights saved by the ModelCheckpoint callback
            # Store the evaluation metrics for this fold
            modelCVScores.append([ActualNegativesInTrain ,ActualPositivesInTrain
                                  ,ActualNegativesInTest, ActualPositivesInTest
                                  ,*generateF1Score(model,X_test,y_test)])
        allModelsScores[viewChoice] = modelCVScores

In [None]:
# Display all of the confusion matrices one over the other:
# one output row per fold, with the matrices of all views side by side.

# conf_mat_list[fold] = confusion matrices (last element of each fold's score entry), one per view
conf_mat_list = [[allModelsScores[key][i][-1] for key in allModelsScores] for i in range(n_splits)]
space = "\xa0" * 10  # non-breaking spaces used as a horizontal gap between tables
for i, iteration in enumerate(conf_mat_list, start=1):
    # Caption every matrix with the key it was stored under, so a label cannot
    # drift from its data (the old hard-coded captions misspelled 'Transverse'
    # and required exactly three views in a fixed order).
    styled_html = [
        conf_mat.style
                .set_table_attributes("style='display:inline'")
                .set_caption('Iteration {}-{}'.format(i, key))
                ._repr_html_()
        for key, conf_mat in zip(allModelsScores, iteration)
    ]
    display_html(space.join(styled_html), raw=True)

In [None]:
# Unpack all of the models from allModelsScores and display one table of
# per-fold scores for each view, side by side.
print('Below are results for a {} K-Fold validation training using the {} dataset'.format(n_splits,fileChoice))
scoreColumns = ["ActualNegativesInTrain", "ActualPositivesInTrain",
                "ActualNegativesInTest", "ActualPositivesInTest",
                "f1Score", "precision", "recall", "accuracy"]
# item[0:-1] drops the trailing confusion matrix so only scalar scores remain.
# The three-way unpack relies on allModelsScores holding the views in the order
# Transverse, Coronal, Sagittal (the order in which the training cell inserts them).
modelDfT,modelDfC,modelDfS = [pd.DataFrame([item[0:-1] for item in allModelsScores[key]], columns=scoreColumns)
                              for key in allModelsScores]

# Caption fixed: it used to be rendered as the misspelled 'Transvere'.
modelDfT_styler = modelDfT.style.set_table_attributes("style='display:inline'").set_caption('Transverse')
modelDfC_styler = modelDfC.style.set_table_attributes("style='display:inline'").set_caption('Coronal')
modelDfS_styler = modelDfS.style.set_table_attributes("style='display:inline'").set_caption('Sagittal')

space = "\xa0" * 10  # non-breaking spaces used as a horizontal gap between tables
display_html(modelDfT_styler._repr_html_()+space+modelDfC_styler._repr_html_()+space+modelDfS_styler._repr_html_()
             , raw=True)


In [None]:
# Average of every score column over the folds, one row per view.
# The row labels come from viewChoices, so the frames are listed in that same order.
perViewMeans = [scores.mean() for scores in (modelDfT, modelDfC, modelDfS)]
pd.DataFrame(perViewMeans, index=viewChoices)

In [None]:
# Pickle results with appropriate name so that it can be loaded and used if needed.
# The file name encodes dataset, fold count, seed and run timestamp, so repeated
# runs do not overwrite each other.
runDateTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
scoresDir = baseSharedPath+'/model_scores'
# open(..., "wb") does not create missing directories; make sure the target exists
# so a finished training run is not lost to a FileNotFoundError at the very end.
os.makedirs(scoresDir, exist_ok=True)
with open("{}/all_model_scores_{}_{}fold_seed{}_{}.pickle".format(scoresDir,fileChoice,n_splits,seedValue,runDateTime), "wb") as f:
    pickle.dump(allModelsScores, f)