In [None]:
#---------------------------------------------------------------------------------
# CREATED BY: Arka Bhowmik and Sarah Eskreis-Winkler, Memorial Sloan Kettering Cancer Center, NY (2022)
#
# --------------------------------------------------------------------------------
# THIS IS THE MAIN PROGRAM THAT TRAINS A CLASSIFICATION MODEL (VGG-16)
#
# 
# THE PROGRAM CALLS VARIOUS SUB-FUNCTIONS
# INSTRUCTIONS: ENSURE ALL SUB-FUNCTIONS AND THE MAIN PROGRAM ARE IN THE SAME FOLDER
#               OR CHANGE THE PATH INSIDE ACCORDINGLY
#---------------------------------------------------------------------------------
# SUPPRESS TENSORFLOW C++ LOG MESSAGES ('2' = HIDE INFO AND WARNING MESSAGES).
# This must be set BEFORE tensorflow is imported: the C++ runtime reads it while the
# library loads, so assigning it after the import has little or no effect.
# (The previous '3' assignment was immediately overwritten by '2', so '2' is kept.)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
#
# IMPORT IMPORTANT TENSORFLOW AND KERAS LIBRARIES TO RUN DEEP CNN
#
import tensorflow as tf
import tensorflow_addons as tfa
#
# SUPPORTING LIBRARIES
import os.path
import sys
import numpy as np
import matplotlib.pyplot as plt
import time
import tqdm
#
# SILENCE PYTHON-SIDE TENSORFLOW LOGGING BELOW ERROR LEVEL
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
#
#

In [None]:
#
#******************************************************************
#                  STEP 1: USER MODEL INPUTS
#******************************************************************
# Add the folder with the supporting functions to the import path. The guard keeps
# re-runs of this cell from appending the same entry to sys.path again and again.
SUPPORT_FUNCTION_DIR = '/mri_triage_normal/support_function/'
if SUPPORT_FUNCTION_DIR not in sys.path:
    sys.path.append(SUPPORT_FUNCTION_DIR)
#
# CHANGE ALL INPUT PARAMETERS ONLY IN config_vgg.py (found in the support-function folder above)
import config_vgg as config
#
from read_and_split_vgg import read_and_split           # Splits the dataset into training and validation sets
from user_input_balancing import user_input_balancing   # Class balancing (manual oversampling)
from create_dictionary_vgg import create_dictionary
from Data_gen_vgg import DataGenerator                  # Data generation in case of full image
#
#

In [None]:
#
#-----------------------------------------------------------------------------
#                  STEP 2: DATA PRE-PROCESSING
#-----------------------------------------------------------------------------
#
# THE DATA PRE-PROCESSING STEP READS AND SPLITS THE DATA INTO TRAIN AND VALID SETS
# (IF THE TRAIN AND VALID FILES ALREADY EXIST IN THE PATH --> set split_flg='nosplit')
# 
#******************************************************************************
# USER INSTRUCTION:- THIS STEP WILL REPLACE THE EARLIER TRAIN AND VALID FILES
# (If split_flg = 'split')
#******************************************************************************
#
# Positional arguments for read_and_split, in the order the function expects them.
split_args = (config.BASE_PATH, config.CSV_NOTSPLIT, config.split_type,
              config.split_ID, config.train_ratio, config.validation_ratio,
              config.split_flg)
counter_tr, counter_vl = read_and_split(*split_args)
#
#--------------------------------------------------------------------------------
# RANDOM OVERSAMPLING OF THE WEAKER (MINORITY) POPULATION, IF UNBALANCED
#--------------------------------------------------------------------------------
#
user_input_balancing(counter_tr, counter_vl)
#
#

In [None]:
#
#-----------------------------------------------------------------------------
#                          STEP 3: DATA GENERATION
#-------------------------------------------------------------------------------
#
# THIS IS A CUSTOM DATA GENERATION STEP
# 
# Keyword arguments shared by the training and validation DataGenerators.
params = dict(
    batch_size=config.BATCH_SIZE,                  # BATCH SIZE
    dim=(config.IMAGE_SIZE, config.IMAGE_SIZE),    # IMAGE WIDTH AND HEIGHT
    n_channels=config.IMAGE_CHANNELS,              # INPUT CHANNELS TO THE NETWORK
    n_classes=config.CLASS_NUM,                    # NUMBER OF CLASSES --> (2)
    shuffle=config.SHUFF,                          # SHUFFLE DATASET EACH EPOCH
    augmentation=config.AGUMENT_METH,              # DATA AUGMENTATION TYPE
    imgsize=config.IMAGE_SIZE,                     # IMAGE SIZE
)
#
#---------------------------------------------------------------
# DICTIONARY FILES (IDs, LABELS, IMAGE PATHS) FOR TRAINING AND VALIDATION
#---------------------------------------------------------------
#
partition_tr, labels_tr, impath_tr = create_dictionary(config.train_filename, config.BASE_PATH, 'random')
partition_vl, labels_vl, impath_vl = create_dictionary(config.valid_filename, config.BASE_PATH, 'random')
#
#-------------------------------------------------------------------------------
# DATA GENERATORS FOR TRAINING AND VALIDATION SETS (training is built first)
# NOTE(review): the validation generator receives the same `shuffle` and
# `augmentation` settings as training; if AGUMENT_METH enables random augmentation,
# validation metrics are computed on augmented images -- confirm this is intended.
#-------------------------------------------------------------------------------
training_generator, validation_generator = (
    DataGenerator(ids, lbls, paths, **params)
    for ids, lbls, paths in ((partition_tr, labels_tr, impath_tr),
                             (partition_vl, labels_vl, impath_vl))
)
#
#
#

In [None]:
#
#-----------------------------------------------------------------------------
#                          STEP 3A: DISPLAY DATA GENERATION FILE
#-------------------------------------------------------------------------------
# Visualize images produced by the train/valid data generator
#
def show_generator_batch(generator, n_rows=3, n_cols=3):
    """Plot the first images of ONE batch from `generator` on an n_rows x n_cols grid.

    The previous version looped over 10 batches and drew each of them on top of the
    same 3x3 axes, so only the last batch was visible. It also indexed a fixed 9
    images, which raised IndexError for batches smaller than 9.

    Args:
        generator: an iterable yielding (images, labels) batches (e.g. DataGenerator).
        n_rows, n_cols: grid layout; at most n_rows * n_cols images are drawn.

    Returns:
        The matplotlib Figure.
    """
    images, _ = next(iter(generator))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)
    # zip stops at the shorter of (panels, images in the batch)
    for ax, image in zip(axes.flat, images):
        ax.imshow(image, cmap='gray')
    return fig
#
show_generator_batch(validation_generator);
#
#

In [None]:
#
#******************************************************************
#          STEP 4: CREATE A CLASSIFICATION CNN MODEL
#******************************************************************
#
# ImageNet-pretrained VGG-16 convolutional base. include_top=False drops VGG's own
# classifier, so VGG16's `classes` argument would be ignored and is not passed;
# the task-specific classification head is built below instead.
base_model = tf.keras.applications.vgg16.VGG16(
    weights="imagenet",
    include_top=False,
    input_shape=(config.IMAGE_SIZE, config.IMAGE_SIZE, config.IMAGE_CHANNELS))
#
# FREEZE THE FIRST config.FROZEN_LAYERS LAYERS (NOT TRAINABLE)
for layer in base_model.layers[:config.FROZEN_LAYERS]:
    layer.trainable = False
#
def _activity_l1_l2():
    """Return a fresh L1+L2 activity regularizer (l1=l2=1e-5) for one head layer."""
    return tf.keras.regularizers.l1_l2(l1=1e-5, l2=1e-5)
#
# Flatten the output of the convolutional base
x = base_model.output
flatten = tf.keras.layers.Flatten()(x)
#
# Two fully connected ReLU layers, each followed by 50% dropout
softmaxHead = tf.keras.layers.Dense(512, activation="relu", activity_regularizer=_activity_l1_l2())(flatten)
softmaxHead = tf.keras.layers.Dropout(0.5)(softmaxHead)
softmaxHead = tf.keras.layers.Dense(512, activation="relu", activity_regularizer=_activity_l1_l2())(softmaxHead)
softmaxHead = tf.keras.layers.Dropout(0.5)(softmaxHead)
#
# Softmax output with one unit per class. config.CLASS_NUM is the value passed to the
# data generators as n_classes, so the output width follows the configured class count
# instead of a hard-coded 2 (presumably equal to the label width -- confirm in DataGenerator).
# NOTE(review): this activity regularizer acts on the class probabilities themselves;
# its L2 part slightly penalizes confident predictions -- confirm this is intended.
output_layer = tf.keras.layers.Dense(config.CLASS_NUM, activation="softmax",
                                     activity_regularizer=_activity_l1_l2())(softmaxHead)
#
# Construct the model
model = tf.keras.Model(inputs=[base_model.input], outputs=[output_layer])
#
#

In [None]:
#
# Displays the model: layer-by-layer output shapes and parameter counts
# (trainable vs. non-trainable, which reflects the frozen VGG-16 layers)
#
model.summary()

In [None]:
#***********************************************************************
#           STEP 5: COMPILE THE MODEL
#***********************************************************************
#
opt = tf.keras.optimizers.Adam(learning_rate=config.INIT_LR)        # Adam optimizer
#
# The AUC metric is named explicitly. Without a name, Keras auto-names repeated
# instances 'auc_1', 'auc_2', ... when this cell is re-run in the same kernel, and the
# checkpoint filename template (which formats '{val_auc:.3f}') then fails on a missing key.
# NOTE: with a 2-unit softmax and one-hot labels (assumed -- confirm in DataGenerator),
# binary_crossentropy gives the same value as categorical crossentropy.
model.compile(loss="binary_crossentropy", optimizer=opt,
              metrics=["accuracy", tf.keras.metrics.AUC(name="auc")])
#
#

In [None]:
#
#----------------------------------------------------------------------------------------
#                         STEP 6: TRAINING
#----------------------------------------------------------------------------------------
# USER DEFINED CALLBACKS
# 1. PROGRESS BAR: TQDM PROGRESS BAR (Keras' own per-epoch output is off, verbose=0)
# 2. EARLY STOPPING: STOPS IF VAL_LOSS DOES NOT DECREASE FOR 7 CONSECUTIVE EPOCHS
# 3. CHECKPOINT: SAVES THE MODEL (.h5) EACH TIME VAL_LOSS REACHES A NEW MINIMUM
#    (save_best_only=True); EPOCH, VAL ACC, VAL LOSS AND VAL AUC ARE IN THE FILE NAME
checkpoint_filepath = 'vgg_models/fold1/'+ 'model_' + config.h5file_name + '_{epoch:02d}_acc_{val_accuracy:.3f}_loss_{val_loss:.3f}_AUC_{val_auc:.3f}.h5'
checkpoint_path = os.path.join(config.BASE_OUTPUT, checkpoint_filepath)
# Create the checkpoint folder up front so the first save cannot fail on a missing directory
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
#
my_callbacks = [
    tfa.callbacks.TQDMProgressBar(),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=7, verbose=1, restore_best_weights=False),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                       monitor='val_loss', verbose=1, mode='min',
                                       save_best_only=True, save_freq="epoch")]
#
# perf_counter is a monotonic clock intended for measuring elapsed time
start = time.perf_counter()
#
train_history = model.fit(training_generator, validation_data=validation_generator,
                          epochs=config.NUM_EPOCHS, callbacks=my_callbacks,
                          verbose=0, initial_epoch=0, workers=4, use_multiprocessing=True)
#
stop = time.perf_counter()
#
print(f"Training time: {stop - start}s")
#
#

In [None]:
#-------------------------------------------------------------------------------------------
#
#                  STEP 7: STORING THE TRAINING DATA
#-------------------------------------------------------------------------------------------
# KEEPS THE PER-EPOCH CURVES FROM model.fit IN VARIABLES (nothing is written to disk here):
# training accuracy, training loss, validation accuracy, validation loss
my_acc_train, my_loss_train, my_acc_valid, my_loss_valid = (
    train_history.history[metric_key]
    for metric_key in ('accuracy', 'loss', 'val_accuracy', 'val_loss')
)
#
#

In [None]:
#------------------------------------------------------------------------------------
#                  STEP 8: PLOT THE ACCURACY AND LOSS WITH EPOCH
#-----------------------------------------------------------------------------------
#
if config.plot_ACC_LOSS == 'Y':
    #------------------------------------------------------------------------------------------
    #                         PLOTTING STEPS
    #                         (THIS STEP IS OPTIONAL; plot_ACC_LOSS != 'Y' SKIPS IT)
    fig_name = 'PLOT_ACC_LOSS_' + config.h5file_name + '.png'
    #
    def plot_acc_loss(train_history, epochs=None):
        """Plot train/validation accuracy and loss per epoch and save the figure.

        The first recorded epoch is left out of the curves, as in the original plot.
        The figure is saved as `fig_name` inside config.BASE_OUTPUT.

        Args:
            train_history: History object returned by model.fit; its .history dict
                must contain 'accuracy', 'loss', 'val_accuracy' and 'val_loss'.
            epochs: number of epochs to plot. Defaults to, and is capped at, the number
                of epochs actually recorded. Without the cap, early stopping made the
                x range longer than the recorded values and matplotlib raised ValueError.

        Returns:
            An empty tuple (kept from the original interface).
        """
        history = train_history.history
        n_recorded = len(history['loss'])
        n_epochs = n_recorded if epochs is None else min(epochs, n_recorded)
        xs = range(1, n_epochs)
        fig, axes = plt.subplots(1, 2, figsize=(15, 5))
        panels = (('accuracy', 'acc', 'Accuracy'), ('loss', 'loss', 'Loss'))
        for ax, (key, tag, name) in zip(axes, panels):
            ax.plot(xs, history[key][1:n_epochs], label='Train_' + tag)
            ax.plot(xs, history['val_' + key][1:n_epochs], label='Val_' + tag)
            ax.set_title(name + ' over ' + str(n_epochs) + ' Epochs', size=15)
            ax.set_xlabel('Epochs', size=14)
            ax.set_ylabel(name, size=14)
            ax.legend()
            ax.grid(True)
        fig.savefig(os.path.join(config.BASE_OUTPUT, fig_name))
        return ()
    #
    # CALL THE PLOT FUNCTION. The epoch count comes from the recorded history, because
    # early stopping may end training before config.NUM_EPOCHS.
    plot_acc_loss(train_history)
    #
#
#------------------------
# END OF PROGRAM
#-----------------------