# Train an X-Fold Cross-Validation Binary Model on 3D Image Data

The model is based on the paper https://arxiv.org/abs/2206.13302; however, only binary classification is performed here.  
To ensure that each patient appears in the test set exactly once, the data is split differently than in the paper.

### Import Libraries and Modules

In [None]:
# Install packages into the notebook's environment via IPython shell escapes.
# NOTE(review): neither statsmodels nor seaborn is imported in the cells shown
# here; presumably they are needed by the project helper modules — confirm.
!pip install statsmodels
!pip install seaborn

In [None]:
%matplotlib inline

import os
import h5py
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import zipfile
import random
import pickle as pkl

from sklearn.model_selection import KFold
from sklearn import preprocessing
from sklearn import metrics
from scipy import ndimage

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential, Model
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical

In [None]:
# Check and set the working directory before importing the project modules
# below; they are imported by bare module name, presumably resolved from DIR.
print(os.getcwd())
DIR = "/tf/notebooks/brdd/xAI_stroke_3d/"
# os.getcwd() never returns a trailing slash while DIR has one, so a plain
# string comparison would always report a mismatch; compare normalized paths.
if os.path.normpath(os.getcwd()) != os.path.normpath(DIR):
    os.chdir(DIR)

import functions_metrics as fm
import functions_model_definition as md
import functions_read_data as rdat

print("TF  Version",tf.__version__)

In [None]:
# tf.config.experimental.get_memory_usage("GPU:0")

### Overview of all Runs

The following is an overview of all runs and their corresponding version names. Each version is saved in a folder with the same name. Before a new version is trained, its folder must be created manually.

- 10Fold_sigmoid_V0 (old name: 10Fold_sigmoid): 10 folds stratified by outcome (mrs > 2 vs. mrs <= 2), trained with the last layer activated with sigmoid (5 ensembles per split)
- 10Fold_softmax_V0: same folds as 10Fold_sigmoid, but the last layer is activated with softmax (5 ensembles per split)
- 10Fold_softmax_V1: new 10 folds stratified by mrs, last layer activated with softmax (10 ensembles per split)
- 10Fold_sigmoid_V1: same folds as 10Fold_softmax_V1, last layer activated with sigmoid (10 ensembles per split)
- 10Fold_sigmoid_V2: 10 folds with binary stratification (mrs > 2 vs. mrs <= 2) using a different seed than V0, last layer activated with sigmoid (5 ensembles per split)
- 10Fold_sigmoid_V2f: same as 10Fold_sigmoid_V2, but with a flatten layer
- 10Fold_sigmoid_V3: 10 folds with binary stratification (mrs > 2 vs. mrs <= 2) excluding TIA patients, using a different seed than V0 and V2, last layer activated with sigmoid (5 ensembles per split)

10Fold_sigmoid_V0 was trained twice, so two model_versions exist for it (1 and 2); all other versions have only one model_version. Both model_versions of 10Fold_sigmoid_V0 are saved in the same folder.

### Load Data & Model


In [None]:
# Data and output locations, both below the project root DIR.
print(os.getcwd())
DATA_DIR = f"{DIR}data/"
OUTPUT_DIR = f"{DIR}weights/10Fold_sigmoid_V0/"
# NOTE: the weight file names are assembled in the k-fold loop — check them
# there when switching versions.


In [None]:
# Load the id table (presumably the fold assignment, consumed by
# rdat.split_data) and the preprocessed image array. Switch the CSV together
# with OUTPUT_DIR when working on another run version.
id_tab = pd.read_csv(f"{DATA_DIR}10Fold_ids_V0.csv", sep=",")  # for V0
# id_tab = pd.read_csv(f"{DATA_DIR}10Fold_ids_V1.csv", sep=",")  # for V1
# id_tab = pd.read_csv(f"{DATA_DIR}10Fold_ids_V2.csv", sep=",")  # for V2 and V2f
# id_tab = pd.read_csv(f"{DATA_DIR}10Fold_ids_V3.csv", sep=",")  # for V3
X = np.load(f"{DATA_DIR}prepocessed_dicom_3d.npy")

# Version suffix embedded in the weight/history file names.
model_version = 2

print(id_tab.shape)
print(X.shape)

In [None]:
# Define parameters of the model.

# If False, the stored weights and training histories are loaded and only the
# evaluation output is computed; if True, every model is trained from scratch.
train = False

# Define Model
layer_connection = "globalAveragePooling"  # globalAveragePooling, flatten
last_activation = "sigmoid"  # sigmoid, softmax

# Fail fast on unsupported settings. Otherwise the error only appears later
# (or, in a re-run notebook, a stale value from a previous run is reused).
if layer_connection not in ("globalAveragePooling", "flatten"):
    raise ValueError(
        f"Unsupported layer_connection {layer_connection!r}; "
        "expected 'globalAveragePooling' or 'flatten'."
    )

# The loss must match the output encoding: one sigmoid unit -> binary
# cross-entropy; a softmax head (trained on one-hot labels) -> categorical.
if last_activation == "sigmoid":
    LOSS = "binary_crossentropy"
elif last_activation == "softmax":
    LOSS = tf.keras.losses.categorical_crossentropy
else:
    raise ValueError(
        f"Unsupported last_activation {last_activation!r}; "
        "expected 'sigmoid' or 'softmax'."
    )

In [None]:
# Cross-validation layout and training schedule (per-version values are
# listed in the run overview).
num_splits, num_models = 10, 5   # CV folds (10 for all); ensemble members per fold
batch_size, epochs = 6, 250      # identical for all models

In [None]:
# Shape of one sample as fed to the network: the volume plus a trailing
# channel axis of size 1 (same tuple as np.expand_dims(X, -1).shape[1:]).
input_dim = X.shape[1:] + (1,)

# One sigmoid unit for binary output, two units for a softmax head. Raise on
# anything else instead of leaving output_dim undefined (or stale from a
# previous notebook run).
if last_activation == "sigmoid":
    output_dim = 1
elif last_activation == "softmax":
    output_dim = 2
else:
    raise ValueError(
        f"Unsupported last_activation {last_activation!r}; "
        "expected 'sigmoid' or 'softmax'."
    )

# Built here only to print the architecture; the k-fold loop builds a fresh
# model for every split and ensemble member.
model_3d = md.stroke_binary_3d(input_dim = input_dim,
                               output_dim = output_dim,
                               layer_connection = layer_connection,
                               last_activation = last_activation)

model_3d.summary()

In [None]:
def validation_preprocessing(volume, label):
    """Prepare one validation (volume, label) pair without augmentation.

    A channel axis is inserted at axis 3 (the trailing axis for a rank-3
    volume); the label is returned unchanged. Intended for
    ``tf.data.Dataset.map``.
    """
    with_channel = tf.expand_dims(volume, axis=3)
    return with_channel, label


### Train Model

In [None]:
# Test-set metrics, one entry appended per (split, ensemble member).
test_auc, test_nll, test_sens, test_spez = [], [], [], []

In [None]:
# Loop over the k-fold splits and, within each split, over the ensemble
# members. For every member a fresh model is built, then trained (or its
# stored training history is loaded when train is False). It is then restored
# from the best checkpoint and evaluated on the split's test fold.


def _model_file_name(split_idx, member_idx):
    """Return the weight file name of one (split, ensemble member) model.

    Raises:
        ValueError: if ``layer_connection`` is neither "globalAveragePooling"
            nor "flatten".
    """
    layer_tags = {"globalAveragePooling": "avg", "flatten": "flat"}
    if layer_connection not in layer_tags:
        raise ValueError(
            f"Unsupported layer_connection {layer_connection!r}; "
            "expected 'globalAveragePooling' or 'flatten'."
        )
    return (
        f"3d_cnn_binary_model_split{split_idx}"
        f"_unnormalized_{layer_tags[layer_connection]}_layer_paper_model_"
        f"{last_activation}_activation_{model_version}{member_idx}.h5"
    )


def _history_path(model_name):
    """Return the path of the pickled training history of ``model_name``."""
    # model_name ends in ".h5": drop "h5" and append "pkl".
    return OUTPUT_DIR + "hist_" + model_name[:-2] + "pkl"


def _fit_model(model, weights_path, x_train, y_train, x_valid, y_valid, verbose):
    """Train ``model`` with on-the-fly augmentation and return its history dict.

    The weights with the lowest validation loss are saved (weights only) to
    ``weights_path``. Training stops after 100 epochs without improvement of
    the validation loss.
    """
    # Validation data: no augmentation, only a channel axis is added.
    validation_dataset = (
        tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
        .shuffle(len(x_valid))
        .map(validation_preprocessing)
        .batch(batch_size)
        .prefetch(2)
    )

    # datagen.fit() is deliberately not called: it only computes statistics
    # for featurewise_center / featurewise_std_normalization / zca_whitening,
    # none of which is enabled here.
    # NOTE(review): ImageDataGenerator treats rank-4 input as
    # (samples, rows, cols, channels), so each volume is augmented as a 2D
    # image whose last axis acts as channels — confirm this is intended.
    datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        zoom_range=0.15,
        shear_range=0.15,
        fill_mode="nearest",
    )

    checkpoint_cb = keras.callbacks.ModelCheckpoint(
        filepath=weights_path,
        verbose=verbose,
        save_weights_only=True,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
    )
    early_stopping_cb = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=100, restore_best_weights=True
    )

    hist = model.fit(
        datagen.flow(x_train, y_train, batch_size=batch_size, shuffle=True),
        validation_data=validation_dataset,
        epochs=epochs,
        shuffle=True,
        verbose=verbose,
        callbacks=[checkpoint_cb, early_stopping_cb],
    )
    return hist.history


def _plot_history(history):
    """Plot train vs. validation loss, accuracy and AUC side by side."""
    plt.figure(figsize=(30, 10))
    for pos, key in enumerate(("loss", "acc", "auc"), start=1):
        plt.subplot(1, 3, pos)
        plt.plot(history[key], label=key)
        plt.plot(history["val_" + key], label="val_" + key)
        plt.legend()
    plt.show()


def _plot_roc(fpr, tpr, roc_auc):
    """Plot a ROC curve annotated with its AUC, plus the chance diagonal."""
    plt.title("Receiver Operating Characteristic")
    plt.plot(fpr, tpr, "b", label="AUC = %0.2f" % roc_auc)
    plt.legend(loc="lower right")
    plt.plot([0, 1], [0, 1], "r--")
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel("True Positive Rate")
    plt.xlabel("False Positive Rate")
    plt.show()


start0 = time.time()
for i in range(num_splits):
    start1 = time.time()
    print("\n\n\n\n################################################################################")
    print("Split " + str(i))
    print("################################################################################\n\n\n\n")

    (X_train, X_valid, X_test), (y_train, y_valid, y_test) = rdat.split_data(id_tab, X, i)

    # A softmax head needs one-hot targets; a sigmoid head uses the labels as
    # returned by split_data.
    if last_activation == "softmax":
        y_train_enc = to_categorical(y_train)
        y_valid_enc = to_categorical(y_valid)
        y_test_enc = to_categorical(y_test)
    else:
        y_train_enc, y_valid_enc, y_test_enc = y_train, y_valid, y_test

    # loop over model instances (ensembling)
    for j in range(num_models):
        start2 = time.time()
        print("\n\n#######################################################")
        print("Split " + str(i) + " Model " + str(j))
        print("#######################################################\n\n")

        verbose = 1 if i == 0 and j == 0 else 0  # log only the very first model
        model_name = _model_file_name(i, j)
        weights_path = OUTPUT_DIR + model_name
        hist_path = _history_path(model_name)

        # Keras keeps global state for every model created; reset it so memory
        # does not grow over all num_splits * num_models builds.
        keras.backend.clear_session()
        model_3d = md.stroke_binary_3d(input_dim = input_dim,
                                       output_dim = output_dim,
                                       layer_connection = layer_connection,
                                       last_activation = last_activation)

        model_3d.compile(
            loss=LOSS,
            optimizer=keras.optimizers.Adam(learning_rate=5*1e-5),
            metrics=["acc", tf.keras.metrics.AUC(name = "auc")]
        )

        if train:
            histplt = _fit_model(model_3d, weights_path, X_train, y_train_enc,
                                 X_valid, y_valid_enc, verbose)
            with open(hist_path, "wb") as fh:
                pkl.dump(histplt, fh, protocol=4)
        else:
            # Only load history files written by this notebook: unpickling
            # untrusted files can execute arbitrary code.
            with open(hist_path, "rb") as fh:
                histplt = pkl.load(fh)

        _plot_history(histplt)

        # Evaluate the best checkpointed weights on the held-out test fold.
        model_3d.load_weights(weights_path)

        (AUC, NLL, sens, spec) = fm.bin_class_report(
            X_test,
            y_test_enc,
            model = model_3d)

        # 1-D positive-class score: column 1 of the softmax output, or the
        # single sigmoid output flattened from shape (n, 1).
        y_pred = model_3d.predict(X_test)
        y_pred = y_pred[:, 1] if last_activation == "softmax" else y_pred.ravel()

        fpr, tpr, _thresholds = metrics.roc_curve(y_test, y_pred)
        roc_auc = metrics.auc(fpr, tpr)
        _plot_roc(fpr, tpr, roc_auc)

        test_auc.append(AUC)
        test_nll.append(NLL)
        test_sens.append(sens)
        test_spez.append(spec)

        end2 = time.time()
        print(" ")
        print("Duration of Training: " + str(end2-start2))

    end1 = time.time()
    print(" ")
    print("Duration of Split: " + str(end1-start1))

end0 = time.time()
print(" ")
print("Duration of Everything: " + str(end0-start0))