In [1]:
import numpy as np 
import os
import cv2
import skimage.io as io
import skimage.transform as trans
import numpy as np
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.losses import *
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras import backend as keras
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.utils import plot_model, to_categorical
from tensorflow.keras.models import model_from_yaml
import tensorflow as tf 

from __future__ import print_function
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np 
import os
import glob
import cv2
import glob
import itertools
import skimage.io as io
import skimage.transform as trans
from tensorflow.keras.initializers import Constant
from sklearn.model_selection import train_test_split


from matplotlib import pyplot as plt
%matplotlib inline

from skimage.morphology import disk
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, auc
#from sklearn.metrics import jaccard_similarity_score

## GPU configuration

In [2]:
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# at start-up. The option is applied to every visible GPU (TensorFlow expects
# the setting to be the same on all of them), and the loop is simply skipped
# when no GPU is present, so the notebook no longer fails with an IndexError
# on CPU-only machines.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

## the proposed network


In [3]:
'''
The implementation details and their explanation can be found in the paper.
Please read the paper first, then read the source code.
'''
# Weights file that every dense-prediction head is initialised from by default.
_PRETRAINED_WEIGHTS_PATH = '/home/mahmoud/Desktop/laparoscopic-Tools-Segmentation/Models/pretrainedRobotic_editedFinal2.h5'


def _sep_conv_bn(x, filters, kernel_size=(3, 3)):
    '''Apply a ReLU SeparableConv2D ('same' padding) followed by BatchNormalization.'''
    x = SeparableConv2D(filters = filters,
                        kernel_size = kernel_size,
                        activation = 'relu',
                        kernel_initializer='glorot_uniform',
                        padding = "same")(x)
    return BatchNormalization()(x)


def _upsample_and_merge(x, skip):
    '''Upsample `x` by a factor of 2 and concatenate it after `skip` on the channel axis.'''
    x = UpSampling2D(size = (2, 2))(x)
    return concatenate([skip, x], axis=-1)


def _dense_head(name, kept, block1_pool, FrG, nClasses, img_input, pretrained_weights):
    '''
    Build one full-resolution, single-channel sigmoid head named `name`.

    The head upsamples the shared decoder tensor `kept` twice (merging the
    VGG16 `block1_pool` skip after the first upsampling), reduces it to
    `nClasses` channels, concatenates the full-resolution `FrG` features and
    finishes with a 1x1 sigmoid convolution. The head is wrapped in its own
    Model so that `pretrained_weights` (if not None) can be loaded into it.
    '''
    x = _upsample_and_merge(kept, block1_pool)
    x = _sep_conv_bn(x, 64)

    x = UpSampling2D(size = (2, 2))(x)
    x = _sep_conv_bn(x, 64)

    x = _sep_conv_bn(x, nClasses, kernel_size = (1, 1))

    x = concatenate([x, FrG], axis=-1)
    x = Conv2D(filters = 1,
               kernel_size = 1,
               activation = 'sigmoid',
               name = name)(x)

    head_model = Model(inputs = img_input, outputs = x)
    if pretrained_weights is not None:
        head_model.load_weights(pretrained_weights)
    return head_model


def SIMOCNN(nClasses, input_height, input_width,
            pretrained_weights=_PRETRAINED_WEIGHTS_PATH):
    '''
    Build the single-input multi-output (SIMO) network.

    A VGG16 encoder (ImageNet weights, no classification top) feeds a shared
    decoder with U-Net-like skip connections to block4/block3/block2 pooling
    outputs. On top of the shared decoder, four structurally identical heads
    produce full-resolution single-channel sigmoid maps, each of which is
    additionally concatenated with a full-resolution feature map (FrG)
    computed directly from the input image without any sub-sampling. A fifth
    head classifies the whole image from the `block5_conv3` features.

    Input Arguments:
        nClasses: number of channels produced by the FrG path and by the last
                  separable convolution of every head.
        input_height, input_width: spatial size of the RGB input image.
                  NOTE(review): the skip concatenations appear to require both
                  to be divisible by 32 (five 2x2 poolings) -- confirm.
        pretrained_weights: path of the weights file loaded into each of the
                  four dense-prediction head models right after it is built
                  (so the layers shared by the heads are loaded repeatedly
                  from the same file). Pass None to skip loading. Defaults to
                  the path that was previously hard-coded.

    Output Arguments:
        A Model with one input and five outputs, in this order:
            Output[0] = 'tool'      : probability map of the tool segmentation.
            Output[1] = 'midline'   : regression map of the tool mid-line.
            Output[2] = 'tooltip'   : regression map of the tool-tip.
            Output[3] = 'edgeline'  : regression map of the tool edge-line.
            Output[4] = 'detection' : 2-way softmax of the tool detection.
    '''

    # Input image; the 3 channels are RGB.
    img_input = Input(shape=(input_height, input_width, 3))

    # Encoder (convolution part of the segmentation network).
    vgg_Base = VGG16(weights = 'imagenet',
                     include_top = False,
                     input_tensor = img_input)

    # FrG: full-resolution features straight from the image (no sub-sampling),
    # concatenated into every head to compensate for the spatial information
    # lost in the encoder.
    FrG = img_input
    for n_filters in (64, 256, 64, nClasses):
        FrG = _sep_conv_bn(FrG, n_filters)

    # Shared decoder: two 1024-filter blocks on the encoder output, then three
    # upsampling stages, each merged with the same-resolution encoder skip.
    kept = _sep_conv_bn(vgg_Base.output, 1024)
    kept = _sep_conv_bn(kept, 1024)
    for skip_name, n_filters in (("block4_pool", 512),
                                 ("block3_pool", 256),
                                 ("block2_pool", 128)):
        skip = vgg_Base.get_layer(name=skip_name).output
        kept = _sep_conv_bn(_upsample_and_merge(kept, skip), n_filters)

    # Segmentation head ('tool') and the three regression heads. They are
    # built in this fixed order so that auto-generated layer names and the
    # order of the model outputs stay the same as before.
    block1_pool = vgg_Base.get_layer(name="block1_pool").output
    head_models = [_dense_head(head_name, kept, block1_pool, FrG, nClasses,
                               img_input, pretrained_weights)
                   for head_name in ('tool', 'midline', 'tooltip', 'edgeline')]

    # Detection head: flags whether a tool is present, i.e. whether the pose
    # should be estimated at all.
    x = vgg_Base.get_layer('block5_conv3').output
    x = GlobalAveragePooling2D()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(2, activation='softmax',name='detection')(x)

    modeldetection = Model(inputs=img_input, outputs=x)

    SIMO = Model(inputs = img_input,
                 outputs = [head.output for head in head_models]
                           + [modeldetection.output])

    return SIMO

## Data Generator 

In [4]:
def DataGenerator(data, batch_size, image_shape=(192, 256)):
    '''
    Endless batch generator for training/validating the SIMO model.

    Input Arguments:
        data = sequence (e.g. numpy object array) with N rows and 6 columns,
               one row per sample:

               orgImage | tool mask | edge-line | mid-line | tool-tip | label |
               ---------|-----------|-----------|----------|----------|-------|

               The image is reshaped to (height, width, 3), the four maps to
               (height, width, 1) and the label is one-hot encoded over 2
               classes.
        batch_size = number of samples per yielded batch (must be >= 1).
        image_shape = (height, width) of every image/map; defaults to the
               previously hard-coded (192, 256).

    Output Arguments:
        A generator yielding `(images, targets)` forever, where `targets` is a
        dict keyed by the model output names 'tool', 'midline', 'tooltip',
        'edgeline' and 'detection'. Samples are served in their stored order
        and wrap around to the first sample when the data is exhausted.

    Raises:
        ValueError: if `data` is empty or `batch_size` is smaller than 1.
                    Raised at call time, not on the first `next()`.
    '''
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
    if len(data) == 0:
        raise ValueError('DataGenerator needs at least one sample, got empty data')

    height, width = image_shape

    img = np.array([row[0] for row in data]).reshape(-1, height, width, 3)
    mask = np.array([row[1] for row in data]).reshape(-1, height, width, 1)
    edge = np.array([row[2] for row in data]).reshape(-1, height, width, 1)
    mid = np.array([row[3] for row in data]).reshape(-1, height, width, 1)
    tip = np.array([row[4] for row in data]).reshape(-1, height, width, 1)
    label = np.array([row[5] for row in data])

    label = to_categorical(label, num_classes=2, dtype='float32')

    def _batches():
        n_samples = len(img)
        start = 0
        while True:
            # Indices of the next batch; the modulo makes the stream cyclic.
            idx = (start + np.arange(batch_size)) % n_samples
            start = (start + batch_size) % n_samples
            yield (img[idx], {'tool': mask[idx],
                              'midline': mid[idx],
                              'tooltip': tip[idx],
                              'edgeline': edge[idx],
                              'detection': label[idx]})

    # Returning the inner generator (instead of making this function itself a
    # generator) lets the argument checks above fail immediately.
    return _batches()

## Loss Functions 

In [5]:
from tensorflow.keras import backend as K

def IoU(y_true, y_pred, smooth=1e-7):

    '''
    The Intersection over Union (IoU), also referred to as the Jaccard index
    (JI), quantifies the overlap between the ground-truth (GT) mask and the
    predicted mask. It is computed over all elements of both tensors at once
    (both are flattened), as

        (sum(y_true * y_pred) + smooth)
        / (sum(y_true) + sum(y_pred) - sum(y_true * y_pred) + smooth)

    Input Arguments:
        y_true: True labels of the 2D images, the so-called ground truth (GT).
        y_pred: Predicted labels of the 2D images, the predicted/segmented mask.
        smooth: Small constant added to numerator and denominator so that two
                all-zero masks give an IoU of 1 instead of 0/0 = NaN. The
                default is small enough not to change the value noticeably
                whenever the union is non-zero.

    Output Arguments:

        iou: The (scalar) IoU between y_true and y_pred

    Author: Md. Kamrul Hasan,
            Erasmus Scholar on Medical Imaging and Application (MAIA)
            E-mail: kamruleeekuet@gmail.com

    '''

    y_pred_f = K.flatten(y_pred)
    # Cast the ground truth to the prediction dtype so integer masks can be
    # multiplied with float predictions.
    y_true_f = K.cast(K.flatten(y_true), y_pred_f.dtype)

    intersection = K.sum(y_true_f * y_pred_f)
    union = K.sum(y_true_f) + K.sum(y_pred_f) - intersection
    return (intersection + smooth) / (union + smooth)


def IoU_loss(y_true, y_pred):
    '''
    IoU expressed as a loss: one minus the IoU of `y_true` and `y_pred`,
    so a larger overlap gives a smaller loss.
    '''
    overlap = IoU(y_true, y_pred)
    return 1 - overlap


def bce_IoU_loss(y_true, y_pred):
    '''
    Combined loss: the sum of the binary cross-entropy and the IoU loss of
    `y_true` and `y_pred`.
    '''
    cross_entropy_term = binary_crossentropy(y_true, y_pred)
    overlap_term = IoU_loss(y_true, y_pred)
    return (cross_entropy_term + overlap_term)


## Train model 

In [None]:
CurrentDirectory=os.getcwd()

# ---- Run configuration ------------------------------------------------------
PROJECT_DIR = '/home/mahmoud/Desktop/laparoscopic-Tools-Segmentation'
RUN_DIR = os.path.join(PROJECT_DIR, 'Models', 'FullTraining', 'ModelLR1')

BATCH_SIZE = 5
EPOCHS = 150
# NOTE(review): with BATCH_SIZE = 5 these step counts cover 665 training and
# 155 validation samples per epoch, fewer than the 1125 / 199 samples printed
# for the split. Because the generators are cyclic, every epoch therefore sees
# a different window of the data. Values are kept as they were; confirm they
# are intended.
STEPS_PER_EPOCH = 133
VALIDATION_STEPS = 31
SPLIT_SEED = 42  # fixes the train/validation split so runs are repeatable

# The CSV logger, checkpoints and final model are all written below RUN_DIR.
os.makedirs(RUN_DIR, exist_ok=True)

# ---- Data -------------------------------------------------------------------
# allow_pickle=True unpickles arbitrary objects: only load files you trust.
datatrain = np.load(os.path.join(PROJECT_DIR, 'DataFinal', 'training_data.npy'),
                    allow_pickle = True)
x_train, x_valid = train_test_split(datatrain,
                                    test_size=0.15,
                                    shuffle= True,
                                    random_state=SPLIT_SEED)
print(x_train.shape, x_valid.shape)

TrainGen = DataGenerator(data= x_train, batch_size=BATCH_SIZE)
TestGen = DataGenerator(data= x_valid, batch_size=BATCH_SIZE)  # validation split

# ---- Model ------------------------------------------------------------------
model = SIMOCNN(2, 192, 256)

plot_model(model, show_shapes=True, to_file='Graph of ART-Net.png')
model.summary()

optim = Adadelta(learning_rate=1)

# One loss and one metric per named model output.
model.compile(optimizer = optim,
              loss = {'tool':bce_IoU_loss,
                      'midline':'mean_squared_error',
                      'tooltip':'mean_squared_error',
                      'edgeline':'mean_squared_error',
                      'detection':'categorical_crossentropy',},
              metrics = {'tool': IoU,
                         'midline':'mae',
                         'tooltip':'mae',
                         'edgeline':'mae',
                         'detection': 'acc'})

# ---- Callbacks --------------------------------------------------------------
checkpoint_path = os.path.join(RUN_DIR, 'Checkpoint')
modelSavePath = os.path.join(RUN_DIR, 'fullmodel26-5-2021.h5')

# `model_checkpoint` is the complete list of callbacks (the name is kept for
# compatibility with the rest of the notebook).
model_checkpoint = [
    tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                       monitor='val_loss',
                                       save_best_only=True,
                                       save_weights_only=True,
                                       save_freq = 'epoch'),
    # new_lr = lr * factor after `patience` epochs without improvement
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, verbose=1,
                                         patience=20, mode='min'),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', verbose=1, patience=40,
                                     mode='min', restore_best_weights=True),
    tf.keras.callbacks.CSVLogger(os.path.join(RUN_DIR, 'training26-5-2021.csv')),
    tf.keras.callbacks.TensorBoard(log_dir=os.path.join(RUN_DIR, 'logs'),
                                   write_graph=True),
    tf.keras.callbacks.TerminateOnNaN()]

# ---- Save the architecture --------------------------------------------------
# NOTE: `Model.to_yaml` was removed in TensorFlow 2.6; switch to
# `model.to_json()` when upgrading.
model_yaml = model.to_yaml()
with open('modelSaved.yaml', "w") as yaml_file:
    yaml_file.write(model_yaml)

# ---- Train ------------------------------------------------------------------
# `fit` accepts Python generators directly (`fit_generator` is deprecated).
# The callback list is passed as-is rather than wrapped in a second list.
history=model.fit(TrainGen,
                  steps_per_epoch=STEPS_PER_EPOCH,
                  epochs=EPOCHS,
                  verbose=1,
                  validation_data= TestGen ,
                  validation_steps=VALIDATION_STEPS,
                  callbacks=model_checkpoint)

model.save(modelSavePath)

(1125, 6) (199, 6)
('Failed to import pydot. You must `pip install pydot` and install graphviz (https://graphviz.gitlab.io/download/), ', 'for `pydotprint` to work.')
Model: "model_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 192, 256, 3) 0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 192, 256, 64) 1792        input_1[0][0]                    
__________________________________________________________________________________________________
block1_conv2 (Conv2D)           (None, 192, 256, 64) 36928       block1_conv1[0][0]               
__________________________________________________________________________________________________
block1_pool (MaxPooling2



Epoch 1/150
Epoch 2/150

In [None]:
def plot_loss_history(history, loss_key, title, legend_labels):
    '''
    Plot one training loss curve together with its validation counterpart.

    Input Arguments:
        history: object whose `.history` dict holds the per-epoch values
                 (the return value of `model.fit`).
        loss_key: key of the training curve in `history.history`; the
                  validation curve is read from 'val_' + loss_key.
        title: figure title.
        legend_labels: the two legend entries, [train label, validation label].
    '''
    plt.figure()
    plt.plot(history.history[loss_key])
    plt.plot(history.history['val_' + loss_key])
    plt.title(title)
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(legend_labels, loc='upper right')
    # A boolean is the supported way to switch the grid on; the string 'on'
    # only worked because any non-empty string is truthy.
    plt.grid(True)
    plt.show()


# Plot training & validation loss values: the total loss first, then one
# figure per model output.
LOSS_PLOTS = [
    ('loss', 'Model loss Total',
     ['Train_total','val_total']),
    ('tool_loss', 'Model loss for the Tool Segmentation',
     ['Train_tool_loss', 'val_tool_loss']),
    ('midline_loss', 'Model loss for the Midline line prediction',
     ['Train_midline_loss','val_midline_loss']),
    ('tooltip_loss', 'Model loss for the tool tip point prediction',
     ['Train_tooltip_loss','val_tooltip_loss']),
    ('edgeline_loss', 'Model loss for the edge line prediction',
     ['Train_edgeLine_loss','val_edgeLine_loss']),
    ('detection_loss', 'Model loss for the detection ',
     ['Train_detection_loss','val_detection_loss']),
]

for loss_key, plot_title, plot_legend in LOSS_PLOTS:
    plot_loss_history(history, loss_key, plot_title, plot_legend)