In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.utils import to_categorical
import pickle as pkl
import time
# Print the installed TensorFlow version so it is recorded in the run log.
print("TF  Version",tf.__version__)

TF  Version 2.2.0


Using TensorFlow backend.


In [2]:
# Check and set the working directory before the project modules are loaded
# (the imports that follow are presumably resolved relative to it).
INPUT_DIR = "/tf/notebooks/schnemau/xAI_stroke_3d/"
OUTPUT_DIR = "/tf/notebooks/schnemau/xAI_stroke_3d/"

# For a non-root directory os.getcwd() reports no trailing separator, so a
# plain string comparison with OUTPUT_DIR (which ends in "/") is always
# unequal. Compare the resolved paths so that the directory is only changed
# when it really differs.
if os.path.realpath(os.getcwd()) != os.path.realpath(OUTPUT_DIR):
    os.chdir(OUTPUT_DIR)
    
import functions_model_definition as md
import functions_read_data as rdat

from functions.augmentation3d import zoom, rotate, flip, shift

In [3]:
# --- Run configuration -------------------------------------------------------
# Data/model variant that is trained; the alternative is kept for switching.
version = "CIB"
#version = "CIBLSX"

# Model version number.
model_version = 4

# Naming convention: True selects the old convention
# (for CIBLSX with model_version >= 3 this should be False).
comp_mode = False

# Not read by any of the cells visible here.
save_csv = False

# Derive the directories and names for this configuration.
(DATA_DIR,
 WEIGHT_DIR,
 DATA_OUTPUT_DIR,
 PIC_OUTPUT_DIR,
 pic_save_name) = rdat.dir_setup(INPUT_DIR, OUTPUT_DIR, version, model_version,
                                 compatibility_mode=comp_mode)

In [4]:
# Load the images, the ids and the accompanying tables for the selected
# configuration; num_models is the ensemble size used further down.
X_in, pat_ids, id_tab, all_results_tab, pat_orig_tab, pat_norm_tab, num_models = (
    rdat.version_setup(
        DATA_DIR=DATA_DIR,
        version=version,
        model_version=model_version,
        compatibility_mode=comp_mode,
    )
)

Results Table does not exist for CIB M66. Returning None for all_results_tab.


In [5]:
# Settings that md.model_setup returns for the selected version.
input_dim_img, output_dim, LOSS, layer_connection, last_activation = md.model_setup(version)

# The tabular input size is only computed for the "LSX" versions: it is the
# number of columns of the normalised patient table without the "p_id" column.
if "LSX" in version:
    n_tab_features = pat_norm_tab.drop(columns=["p_id"]).shape[1]
else:
    n_tab_features = None

# Initialise the model.
model_3d = md.model_init(
    version=version,
    output_dim=output_dim,
    LOSS=LOSS,
    layer_connection=layer_connection,
    last_activation=last_activation,
    C=2,
    learning_rate=5*1e-5,
    batch_size=6,
    input_dim=input_dim_img,
    input_dim_tab=n_tab_features,
)

In [6]:
# Set up the model-name generator for this configuration.
# NOTE(review): the keyword is spelled "compatability_mode" here, unlike the
# "compatibility_mode" keyword passed to the rdat helpers -- confirm against md.
generate_model_name = md.set_generate_model_name(
    path=WEIGHT_DIR,
    model_version=model_version,
    layer_connection=layer_connection,
    last_activation=last_activation,
    compatability_mode=comp_mode,
)

In [7]:
# Indices of the ensemble members that are trained for every split.
model_nrs = [*range(num_models)]
# Indices of the cross-validation splits (10 fold).
which_splits = [*range(10)]

In [8]:
# Augmentation of the training samples.
def _augment_volume(volume):
    """Apply zoom, rotate, shift and flip (in this order) to an image volume."""
    return flip(shift(rotate(zoom(volume))))


if pat_norm_tab is None:
    def train_preprocessing(data, label):
        """Augment an image-only sample; the label is returned unchanged."""
        return _augment_volume(data), label
else:
    def train_preprocessing(data, label):
        """Augment the image of an (image, tabular) sample.

        The tabular part and the label are returned unchanged.
        """
        image, tabular = data[0], data[1]
        return (_augment_volume(image), tabular), label

In [9]:
# Train one network per (cross-validation split, ensemble member) and store the
# best weights and the training history of every run.
#
# Every run starts from a freshly initialised model. Re-using one model object
# for all runs would start each fit from the weights left by the previous one:
# the ensemble members would not be independent, and the validation patients of
# a split would already have been trained on in earlier splits.

BATCH_SIZE = 6                  # used for the datasets and handed to md.model_init
MAX_EPOCHS = 500
EARLY_STOPPING_PATIENCE = 75    # epochs without val_loss improvement before stopping

uses_tabular = pat_norm_tab is not None
model_tag = "CIBLSX" if uses_tabular else "CIB"   # part of the weight file name


def _new_model():
    """Return a newly initialised model, configured like the model-definition cell.

    fit() is called on the result without a compile step, so md.model_init is
    presumably returning a compiled model -- TODO confirm.
    """
    # Release the Keras global state left over from the previous run.
    keras.backend.clear_session()
    return md.model_init(
        version = version,
        output_dim = output_dim,
        LOSS = LOSS,
        layer_connection = layer_connection,
        last_activation = last_activation,
        C = 2,
        learning_rate = 5*1e-5,
        batch_size = BATCH_SIZE,
        input_dim = input_dim_img,
        input_dim_tab = pat_norm_tab.drop(columns=["p_id"]).shape[1] if "LSX" in version else None,
    )


start0 = time.time()
for which_split in which_splits:
    start1 = time.time()
    print("\n\n\n\n################################################################################")
    print("Split " + str(which_split))
    print("################################################################################\n\n\n\n")

    data_split = rdat.split_data(id_tab, X_in, which_split, X_tab = pat_norm_tab)

    # Images: append a trailing axis of length one to the image arrays.
    X_valid = np.expand_dims(data_split["X"]["valid"], axis=-1)
    X_train = np.expand_dims(data_split["X"]["train"], axis=-1)

    # Outcomes, one-hot encoded.
    Y_valid = to_categorical(data_split["y"]["valid"])
    Y_train = to_categorical(data_split["y"]["train"])

    # Tabular data (only when a normalised patient table is available).
    if uses_tabular:
        X_tab_train = data_split["X_tab"]["train"]
        X_tab_valid = data_split["X_tab"]["valid"]
        train_data = tf.data.Dataset.from_tensor_slices((X_train, X_tab_train))
        valid_data = tf.data.Dataset.from_tensor_slices((X_valid, X_tab_valid))
    else:
        X_tab_train = None
        X_tab_valid = None
        train_data = tf.data.Dataset.from_tensor_slices(X_train)
        valid_data = tf.data.Dataset.from_tensor_slices(X_valid)

    # NOTE(review): drop_remainder=True also discards the last incomplete
    # validation batch, so up to BATCH_SIZE - 1 validation samples do not
    # contribute to val_loss -- confirm this is intended.
    valid_labels = tf.data.Dataset.from_tensor_slices(Y_valid)
    valid_loader = tf.data.Dataset.zip((valid_data, valid_labels))
    valid_dataset = valid_loader.batch(BATCH_SIZE, drop_remainder=True)

    # Training pipeline: shuffle over the whole training set, augment, batch.
    train_labels = tf.data.Dataset.from_tensor_slices(Y_train)
    train_loader = tf.data.Dataset.zip((train_data, train_labels))
    train_dataset = (
        train_loader
        .shuffle(len(X_train))
        .map(train_preprocessing)
        .batch(BATCH_SIZE, drop_remainder=True)
    )

    for model_nr in model_nrs:
        start2 = time.time()

        # Fresh weights for every run (see the comment at the top of the cell).
        model_3d = _new_model()

        model_name = ("3D_CNN_avg_layer_binary_outcome_" + model_tag + "_split" + str(which_split) +
                      "_ens" + str(model_nr) + "_M" + str(model_version) + ".h5")

        # Only the very first run reports its progress.
        verbosity = 1 if which_split == 0 and model_nr == 0 else 0

        # Keep the weights of the epoch with the lowest validation loss on disk.
        checkpoint_cb = keras.callbacks.ModelCheckpoint(
            filepath = WEIGHT_DIR + model_name,
            verbose = verbosity,
            save_weights_only = True,
            monitor = "val_loss",
            mode = 'min',
            save_best_only = True)

        early_stopping_cb = keras.callbacks.EarlyStopping(
            monitor="val_loss",
            patience=EARLY_STOPPING_PATIENCE,
            restore_best_weights=True)

        # No `shuffle` argument: the input is a tf.data.Dataset, which is
        # already shuffled by the pipeline above.
        hist = model_3d.fit(
            train_dataset,
            validation_data=valid_dataset,
            epochs=MAX_EPOCHS,
            verbose=verbosity,
            callbacks=[checkpoint_cb, early_stopping_cb])

        # Store the training history next to the weights ("hist_<name>.pkl");
        # the with-block makes sure the file is flushed and closed.
        hist_path = WEIGHT_DIR + "hist_" + os.path.splitext(model_name)[0] + ".pkl"
        with open(hist_path, "wb") as hist_file:
            pkl.dump(hist.history, hist_file, protocol=4)

        end2 = time.time()
        print(" ")
        print("Duration of Training: " + str(end2-start2))

    end1 = time.time()
    print(" ")
    print("Duration of Split: " + str(end1-start1))

end0 = time.time()
print(" ")
print("Duration of Everything: " + str(end0-start0))





################################################################################
Split 0
################################################################################




Epoch 1/500
Epoch 00001: val_loss improved from inf to 0.44791, saving model to /tf/notebooks/schnemau/xAI_stroke_3d/weights/10Fold_CIB/3D_CNN_avg_layer_binary_outcome_CIB_split0_ens0_M66.h5
Epoch 2/500

KeyboardInterrupt: 