# ColonSegNet

<br/>

<span style="font-size: 18px; line-height: 25px;">
Research Paper: <a href="https://github.com/DebeshJha/ColonSegNet/blob/main/access.pdf"> Real-Time Polyp Detection, Localization and Segmentation in Colonoscopy Using Deep Learning </a>

<br/>
<ul>
    <li> ColonSegNet is an encoder-decoder architecture designed for polyp segmentation. </li>
    <li> Its main building block is a residual block with squeeze-and-excitation (SE) channel attention. </li>
</ul>
    
</span>

<img src="images/ColonSegNet.png">

## Import

In [1]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D, Dense
from tensorflow.keras.layers import GlobalAveragePooling2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.layers import MaxPool2D, Reshape
from tensorflow.keras.models import Model

## Squeeze and Excitation
<img src="images/squeeze_and_excitation_detailed_block_diagram.png" style="width: 500px;">

In [2]:
def se_layer(x, num_filters, reduction=16):
    """Squeeze-and-Excitation: rescale each channel of `x` by a learned gate in (0, 1).

    Args:
        x: 4-D feature map whose last axis has `num_filters` channels
           (assumes channels-last layout, matching the rest of this notebook).
        num_filters: number of channels in `x`.
        reduction: ratio by which the first Dense layer shrinks the channel count.

    Returns:
        Tensor with the same shape as `x`, each channel multiplied by its gate.
    """
    x_init = x

    # Squeeze: one average per channel -> shape (batch, num_filters).
    x = GlobalAveragePooling2D()(x)
    # Excitation: bottleneck MLP. max(1, ...) stops the bottleneck collapsing to
    # a 0-unit Dense layer when num_filters < reduction.
    x = Dense(max(1, num_filters // reduction), use_bias=False, activation="relu")(x)
    x = Dense(num_filters, use_bias=False, activation="sigmoid")(x)
    # Reshape the gate to (batch, 1, 1, C) so it broadcasts over H and W.
    # Multiplying (B, H, W, C) by an un-reshaped (B, C) would align B with W and
    # fail at runtime for any batch size other than 1.
    x = Reshape((1, 1, num_filters))(x)

    return x_init * x

## Residual Block
<img src="images/ResidualBlock.png">

In [3]:
def residual_block(x, num_filters):
    """Residual block whose shortcut path is re-weighted by squeeze-and-excitation.

    Main path: 3x3 Conv -> BN -> ReLU -> 3x3 Conv -> BN.
    Shortcut:  1x1 Conv -> BN -> SE, applied to the block input.
    The two paths are summed and passed through a final ReLU.

    Args:
        x: input feature map (any channel count; the 1x1 conv projects it).
        num_filters: number of output channels.

    Returns:
        Feature map with the same spatial size as `x` (stride 1, "same" padding)
        and `num_filters` channels.
    """
    x_init = x

    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)

    # Shortcut projection of the *input*. The BatchNormalization must consume `s`;
    # feeding it the main-path `x` would discard the 1x1 conv entirely and turn
    # the "shortcut" into a re-normalised copy of the main path.
    s = Conv2D(num_filters, 1, padding="same")(x_init)
    s = BatchNormalization()(s)
    s = se_layer(s, num_filters)

    x = Activation("relu")(x + s)

    return x

## Strided Convolution
<img src="images/Strided_Conv_Block.png">

In [4]:
def strided_conv_block(x, num_filters):
    """Downsample with a learned stride-2 convolution.

    Applies, in order: 3x3 Conv2D (stride 2, "same" padding) -> BatchNormalization
    -> ReLU. The stride-2 "same" convolution halves each spatial dimension
    (rounding up) and outputs `num_filters` channels.
    """
    for layer in (
        Conv2D(num_filters, 3, strides=2, padding="same"),
        BatchNormalization(),
        Activation("relu"),
    ):
        x = layer(x)
    return x

## Encoder Block

In [5]:
def encoder_block(x, num_filters):
    """One ColonSegNet encoder stage.

    Returns a 3-tuple:
        - features at the input resolution (first residual block),
        - features at half resolution (stride-2 conv followed by a residual block),
        - the half-resolution features after 2x2 max-pooling (input to the next stage).
    The first two are used as skip connections by the decoder.
    """
    full_res = residual_block(x, num_filters)
    # Nested call keeps the original order: strided conv first, then the residual block.
    half_res = residual_block(strided_conv_block(full_res, num_filters), num_filters)
    pooled = MaxPool2D((2, 2))(half_res)
    return full_res, half_res, pooled

## ColonSegNet

In [12]:
def build_colonsegnet(input_shape, num_classes=5):
    """Build the ColonSegNet encoder-decoder as a Keras functional model.

    Args:
        input_shape: (height, width, channels) of one input image (channels-last).
            Height and width must be multiples of 8, or None. The encoder
            downsamples by 8 overall (stride-2 conv, 2x2 max-pool, stride-2 conv),
            and the decoder's skip concatenations only line up when that
            division is exact.
        num_classes: number of output channels. For num_classes > 1 the head is
            a per-pixel softmax (the default, 5, reproduces the original model).
            For num_classes == 1 it is a sigmoid, since a softmax over a single
            channel is constantly 1.

    Returns:
        A `Model` named "ColonSegNet" mapping the input to a
        (height, width, num_classes) prediction map.

    Raises:
        ValueError: if a known height/width is not a multiple of 8, or if
            num_classes < 1.
    """
    for dim in tuple(input_shape)[:2]:
        if dim is not None and dim % 8 != 0:
            raise ValueError(
                f"ColonSegNet needs height and width divisible by 8, got {tuple(input_shape)}"
            )
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    # Input
    inputs = Input(input_shape)

    # Encoder (H = input height): s11 at H, s12 at H/2, p1 at H/4,
    # s21 at H/4, s22 at H/8. The second stage's pooled output is unused.
    s11, s12, p1 = encoder_block(inputs, 64)
    s21, s22, _ = encoder_block(p1, 256)

    # Decoder 1: fuse the deepest features back at H/2.
    x = Conv2DTranspose(128, 4, strides=4, padding="same")(s22)  # H/8 -> H/2
    x = Concatenate()([x, s12])
    r1 = residual_block(x, 128)

    x = Conv2DTranspose(128, 4, strides=2, padding="same")(s21)  # H/4 -> H/2
    x = Concatenate()([x, r1])
    x = residual_block(x, 128)

    # Decoder 2: upsample to full resolution H.
    x = Conv2DTranspose(64, 4, strides=2, padding="same")(x)  # H/2 -> H
    x = Concatenate()([x, s11])
    r2 = residual_block(x, 64)

    x = Conv2DTranspose(64, 4, strides=2, padding="same")(s12)  # H/2 -> H
    x = Concatenate()([x, r2])
    x = residual_block(x, 32)

    # Output head
    activation = "sigmoid" if num_classes == 1 else "softmax"
    output = Conv2D(num_classes, 1, padding="same", activation=activation)(x)

    # Model
    model = Model(inputs, output, name="ColonSegNet")

    return model

## Model

In [13]:
# One 512x512, 3-channel image (height, width, channels); Input adds the batch axis.
input_shape = (512, 512, 3)
model = build_colonsegnet(input_shape=input_shape)

In [14]:
# Print the per-layer table of output shapes, parameter counts and connections.
model.summary()

Model: "ColonSegNet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d_54 (Conv2D)              (None, 512, 512, 64) 1792        input_3[0][0]                    
__________________________________________________________________________________________________
batch_normalization_52 (BatchNo (None, 512, 512, 64) 256         conv2d_54[0][0]                  
__________________________________________________________________________________________________
activation_36 (Activation)      (None, 512, 512, 64) 0           batch_normalization_52[0][0]     
________________________________________________________________________________________