In [6]:
import os
from pathlib import Path

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Dense, Conv3D, Dropout, MaxPooling3D, UpSampling3D, Activation, BatchNormalization, PReLU, Conv3DTranspose, concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler, ReduceLROnPlateau, EarlyStopping

from config import *
# The 3D network below builds its inputs as (*patch_size, channels) and
# normalises over axis=-1, so Keras must use the channels-last layout.
K.set_image_data_format('channels_last')
print(K.image_data_format())

channels_last


In [25]:
# Notebook-level training settings:
#   run_eagerly -- forwarded to Keras `compile`; True makes debugging easier but is slower
#   fold        -- cross-validation fold index, intended for checkpoint file names
run_eagerly, fold = False, 0
class cascade_model:
    """Two-stage cascade of small 3D CNN patch classifiers built with tf.keras.

    Both stages share the architecture defined in ``network_layers``. The
    constructor only reads configuration and creates the weights folder;
    ``get_cascade_model`` builds and compiles both stages and then either
    restores their weights or trains them.

    Args:
        options: configuration dict. Keys read here: 'modalities',
            'train_split', 'max_epochs', 'patience', 'net_verbose',
            'weight_paths', 'experiment', 'patch_size', 'load_weights'.
    """

    def __init__(self, options):
        self.channels = len(options['modalities'])
        self.train_split_perc = options['train_split']
        self.num_epochs = options['max_epochs']
        self.max_epochs_patience = options['patience']
        self.verbose = options['net_verbose']
        # kept so get_cascade_model can report the experiment without `options`
        self.experiment = options['experiment']

        # Save models to disk to re-use them: <weight_paths>/<experiment>/nets
        self.nets_path = os.path.join(options['weight_paths'], options['experiment'], 'nets')
        Path(self.nets_path).mkdir(parents=True, exist_ok=True)

        self.padding = 'same'
        self.strides = (1, 1, 1)
        # channels-last input: spatial patch dims followed by one channel per modality
        self.shape = (*options['patch_size'], self.channels)
        self.pooling_kernel = (2, 2, 2)
        self.pooling_strides = (2, 2, 2)
        self.drop_rate = 0.5
        # Use a loss *instance*: some tf.keras versions treat a bare class as a
        # plain callable and would invoke its constructor with (y_true, y_pred).
        self.objective_loss_function = tf.keras.losses.CategoricalCrossentropy()
        # config flag; 'True' (string) or True (bool) means "restore saved weights"
        self.load_weights = options['load_weights']

    @staticmethod
    def _is_true(value):
        """Interpret a config flag stored either as the string 'True' or as a bool."""
        return value == 'True' or value is True

    def network_layers(self, n="1"):
        """Build one (uncompiled) stage of the cascade.

        Two blocks of Conv3D(kernel 3) -> BatchNormalization -> MaxPooling3D
        (32 then 64 filters), followed by Dropout, Dense(256), Dense(2) and softmax.

        Args:
            n: suffix appended to most layer names so the stages can be told apart.

        Returns:
            A functional ``Model`` whose input has shape ``self.shape``.
        """
        input_layer = Input(shape=self.shape, name='in' + n)

        # NOTE(review): the Conv3D layers use Keras' default (linear) activation;
        # confirm that no non-linearity between blocks is intended.
        conv1 = Conv3D(32, 3, padding=self.padding, strides=self.strides, name='conv1_' + n)(input_layer)
        batch_norm1 = BatchNormalization(axis=-1, name='BN1' + n)(conv1)
        # Layer names say "avgpool" but this is max pooling; names are left unchanged.
        pool1 = MaxPooling3D(pool_size=self.pooling_kernel, strides=self.pooling_strides,
                             name='avgpool1_' + n)(batch_norm1)

        conv2 = Conv3D(64, 3, padding=self.padding, strides=self.strides, name='conv2_' + n)(pool1)
        batch_norm2 = BatchNormalization(axis=-1, name='BN2_' + n)(conv2)
        pool2 = MaxPooling3D(pool_size=self.pooling_kernel, strides=self.pooling_strides,
                             name='avgpool2_' + n)(batch_norm2)

        dr = Dropout(name='l2drop', rate=self.drop_rate)(pool2)
        # NOTE(review): there is no Flatten, so Dense acts on the last axis of the
        # 5-D pooled tensor and the output keeps a spatial grid (batch, d', h', w', 2).
        # If one prediction per patch is intended, insert a Flatten here -- TODO
        # confirm against the label format produced by the data generators.
        dens1 = Dense(name='d1_' + n, units=256)(dr)
        dens2 = Dense(name='out', units=2)(dens1)
        act = Activation('softmax')(dens2)

        return Model(inputs=input_layer, outputs=act)

    def get_callbacks(self, weights_file_path, fold, initial_learning_rate=0.0001, learning_rate_drop=0.5,
                      learning_rate_patience=50, verbosity=1, early_stopping_patience=None):
        """Create the Keras training callbacks for one network.

        Args:
            weights_file_path: path *prefix*; file names are appended directly
                (no separator), e.g. '<prefix>fold_0_weights-...hdf5'.
            fold: fold identifier; converted with ``str`` so ints are accepted.
            initial_learning_rate: unused; kept for backward compatibility.
            learning_rate_drop: ReduceLROnPlateau factor.
            learning_rate_patience: ReduceLROnPlateau patience (epochs).
            verbosity: verbosity for ReduceLROnPlateau / EarlyStopping.
            early_stopping_patience: if truthy, an EarlyStopping callback is added.

        Returns:
            [ModelCheckpoint, CSVLogger, ReduceLROnPlateau] plus EarlyStopping
            when ``early_stopping_patience`` is set.
        """
        prefix = os.fspath(weights_file_path)
        fold = str(fold)

        check_point = ModelCheckpoint(prefix + 'fold_' + fold + '_weights-{epoch:02d}-{val_loss:.2f}.hdf5',
                                      save_best_only=True)
        csv_log = CSVLogger(prefix + 'training-log.csv', append=True)

        # potential problem of reduce learning rate: https://github.com/keras-team/keras/issues/10924
        reduce_lr = ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience,
                                      verbose=verbosity)
        callbacks = [check_point, csv_log, reduce_lr]
        if early_stopping_patience:
            callbacks.append(EarlyStopping(verbose=verbosity, patience=early_stopping_patience))
        return callbacks

    def _compiled_network(self, n, run_eagerly):
        """Build stage ``n`` and compile it with Adam(1e-3), the configured loss and MSE metric."""
        net = self.network_layers(n=n)
        net.compile(optimizer=Adam(learning_rate=0.001),
                    loss=self.objective_loss_function,
                    metrics=['mse'],
                    run_eagerly=run_eagerly)
        return net

    def _fit(self, net, weights_path, fold, train_generator, valid_generator):
        """Train ``net`` with checkpointing, CSV logging, LR decay and early stopping.

        ``valid_generator`` should be provided: the callbacks monitor 'val_loss'.
        """
        callbacks = self.get_callbacks(weights_path, str(fold),
                                       initial_learning_rate=0.001,
                                       learning_rate_drop=0.1,
                                       learning_rate_patience=self.max_epochs_patience,
                                       early_stopping_patience=self.max_epochs_patience)
        # batch_size / validation_batch_size are deliberately not passed: for
        # generator / Sequence inputs Keras takes batches from the generator.
        return net.fit(x=train_generator,
                       epochs=self.num_epochs,
                       verbose=self.verbose,
                       callbacks=callbacks,
                       validation_data=valid_generator,
                       workers=4)

    def get_cascade_model(self, train_generator=None, valid_generator=None, fold=0,
                          run_eagerly=False, train_generator_2=None, valid_generator_2=None):
        """Build both cascade stages, then restore or train them.

        If the 'load_weights' option is 'True' (or True), weights are loaded from
        '<nets_path>/model_1' and '<nets_path>/model_2' and no training is done
        (previously training ran first and was then overwritten by the load).
        Otherwise stage 1 is trained when ``train_generator`` is given and stage 2
        when ``train_generator_2`` is given.

        Args:
            train_generator, valid_generator: training/validation data for stage 1.
            fold: fold identifier used in checkpoint file names.
            run_eagerly: forwarded to ``Model.compile``.
            train_generator_2, valid_generator_2: optional data for stage 2.

        Returns:
            [net1, net2], both compiled.
        """
        net1 = self._compiled_network("1", run_eagerly)
        net2 = self._compiled_network("2", run_eagerly)
        path_w1 = os.path.join(self.nets_path, 'model_1')
        path_w2 = os.path.join(self.nets_path, 'model_2')

        if self._is_true(self.load_weights):
            # NOTE(review): get_callbacks writes '<path>fold_<k>_weights-...hdf5'
            # files, not '<path>' itself -- make sure weights exist at these paths.
            print("    --> CNN, loading weights from", self.experiment, 'configuration')
            net1.load_weights(path_w1)
            net2.load_weights(path_w2)
            return [net1, net2]

        if train_generator is not None:
            self._fit(net1, path_w1, fold, train_generator, valid_generator)
        # Stage 2 of a cascade is usually trained on candidates selected by
        # stage 1, so it only trains when its own data is supplied.
        if train_generator_2 is not None:
            self._fit(net2, path_w2, fold, train_generator_2, valid_generator_2)
        return [net1, net2]