In [1]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model 
# Concatenate: joins decoder feature maps with the matching encoder feature maps (the skip connection)
# skip connection: lets the decoder reuse the higher-resolution features produced by early encoder layers

## Convolution Segment
- Two 3×3 convolutions with `"same"` padding (spatial size is preserved)
- Each convolution is followed by batch normalisation and a ReLU activation

In [2]:
def ConvSegment(inputs, filterNum):
    """Apply two consecutive (3x3 Conv2D -> BatchNormalization -> ReLU) stages.

    This building block is shared by the encoder levels, the bridge and the
    decoder levels of the U-Net below.

    Args:
        inputs: Image tensor or feature map to transform.
        filterNum: Number of output channels of each of the two convolutions.

    Returns:
        The transformed feature map. Both convolutions use ``padding="same"``
        with the default stride of 1, so the spatial size matches ``inputs``;
        only the channel count becomes ``filterNum``. Keeping the spatial size
        unchanged is what lets encoder skip features line up with the upsampled
        decoder maps when they are concatenated.
    """
    features = inputs
    # Same three layers, built twice in the same order as an unrolled version.
    for _ in range(2):
        features = Conv2D(filterNum, 3, padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features

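## Encoder Segment
- Convolution segment (its output is kept as the skip connection for the decoder)
- 2×2 max pooling, which halves the height and width passed to the next level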

In [3]:
def EncoderSegment(inputs, filterNum):
    """One encoder (contracting / down-sampling) level of the U-Net.

    Args:
        inputs: Image tensor or the pooled output of the previous encoder level.
        filterNum: Number of filters for the ConvSegment of this level.

    Returns:
        A ``(skip, pooled)`` pair:
          * ``skip`` - the ConvSegment output at this level's resolution, kept
            for the decoder's skip connection;
          * ``pooled`` - ``skip`` after 2x2 max pooling (height and width halved,
            floor division), fed to the next level.
    """
    skip = ConvSegment(inputs, filterNum)
    pooled = MaxPool2D((2, 2))(skip)
    return skip, pooled

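## Decoder Segment
- 2×2 transposed convolution with stride 2 (doubles the height and width)
- Concatenation with the corresponding encoder skip connection
- Convolution segment to merge the concatenated features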

In [4]:
def DecoderSegment(inputs, skipFeatures, filterNum):
    """One decoder (expanding / up-sampling) level of the U-Net.

    The level upsamples ``inputs`` with a 2x2, stride-2 transposed convolution,
    which doubles height and width. It then concatenates the matching encoder
    skip connection along the last (channel) axis and refines the merged map
    with a ConvSegment.

    Args:
        inputs: Output of the bridge or of the previous decoder level.
        skipFeatures: Skip tensor (first value returned by EncoderSegment) whose
            spatial size equals that of the upsampled ``inputs``.
        filterNum: Filters for both the transposed convolution and the
            ConvSegment.

    Returns:
        The refined feature map with ``filterNum`` channels.
    """
    upsampled = Conv2DTranspose(filterNum, (2, 2), strides=2, padding="same")(inputs)
    # Concatenate needs identical spatial dims; only the channel axis grows.
    merged = Concatenate()([upsampled, skipFeatures])
    return ConvSegment(merged, filterNum)

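## Full U-Net-like Architecture
Four encoder levels (64 → 512 filters), a 1024-filter bridge, four decoder levels with skip connections, and a 1×1 sigmoid convolution producing the output mask.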

In [8]:
def FullUNET(inputShape, baseFilters=64, outputChannels=1):
    """Build a 4-level U-Net-style encoder/decoder segmentation model.

    Args:
        inputShape: Shape of one input sample (without the batch axis), passed
            to ``Input``, e.g. ``(512, 512, 3)`` for height, width, channels.
            Every known spatial dimension (all but the last, channel axis) must
            be divisible by 16. The encoder halves the resolution four times and
            the decoder doubles it four times, and each doubled map must match
            its skip tensor exactly. ``None`` dimensions are not checked.
        baseFilters: Filters in the first encoder level. Each deeper level
            doubles it; with the default of 64 the levels use 64, 128, 256 and
            512, and the bridge uses 1024.
        outputChannels: Channels of the final 1x1 sigmoid convolution. The
            default of 1 gives a single binary mask (in this notebook, the area
            of the lungs). Larger values give independent per-channel sigmoid
            masks.

    Returns:
        A ``tensorflow.keras.Model`` named ``"U-NET"``. Its output has the same
        spatial size as the input and ``outputChannels`` channels.

    Raises:
        ValueError: If a known spatial dimension of ``inputShape`` is not
            divisible by 16.
    """
    depth = 4  # number of encoder (pooling) / decoder (upsampling) levels
    divisor = 2 ** depth
    shape = tuple(inputShape)
    # Fail early with a clear message instead of a Concatenate shape mismatch
    # deep inside the decoder.
    for dim in shape[:-1]:
        if dim is not None and dim % divisor != 0:
            raise ValueError(
                f"Spatial dimensions of inputShape must be divisible by {divisor} "
                f"(got {shape})"
            )

    # Filters per level: base, 2x, 4x, 8x; the bridge uses 16x.
    f1, f2, f3, f4, fBridge = (baseFilters * 2 ** level for level in range(depth + 1))

    inputs = Input(inputShape)

    # Encoder: each level returns (skip connection, pooled output).
    skipCon1, pool1 = EncoderSegment(inputs, f1)
    skipCon2, pool2 = EncoderSegment(pool1, f2)
    skipCon3, pool3 = EncoderSegment(pool2, f3)
    skipCon4, pool4 = EncoderSegment(pool3, f4)

    # Bridge between encoder and decoder at the lowest resolution.
    bridge1 = ConvSegment(pool4, fBridge)

    # Decoder: upsample and merge with the skip connection of matching size.
    decoderOut1 = DecoderSegment(bridge1, skipCon4, f4)
    decoderOut2 = DecoderSegment(decoderOut1, skipCon3, f3)
    decoderOut3 = DecoderSegment(decoderOut2, skipCon2, f2)
    decoderOut4 = DecoderSegment(decoderOut3, skipCon1, f1)

    # Output: 1x1 convolution + sigmoid -> per-pixel probability per channel.
    outputs = Conv2D(outputChannels, 1, padding="same", activation="sigmoid")(decoderOut4)

    model = Model(inputs, outputs, name="U-NET")
    return model
        
# Build the network for 512x512 RGB inputs and print its layer-by-layer summary.
# NOTE: inside a notebook kernel __name__ is always "__main__", so this guard
# only has an effect if the cell is exported to a standalone script.
if __name__ == "__main__":
    inputShape = (512, 512, 3)  # (height, width, channels)
    model = FullUNET(inputShape=inputShape)
    model.summary()

Model: "U-NET"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 512, 512, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 512, 512, 64  256        ['conv2d[0][0]']                 
 alization)                     )                                                             

                                                                                                  
 conv2d_8 (Conv2D)              (None, 32, 32, 1024  4719616     ['max_pooling2d_3[0][0]']        
                                )                                                                 
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 32, 32, 1024  4096       ['conv2d_8[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 activation_8 (Activation)      (None, 32, 32, 1024  0           ['batch_normalization_8[0][0]']  
                                )                                                                 
                                                                                                  
 conv2d_9 

                                                                                                  
 activation_15 (Activation)     (None, 256, 256, 12  0           ['batch_normalization_15[0][0]'] 
                                8)                                                                
                                                                                                  
 conv2d_transpose_3 (Conv2DTran  (None, 512, 512, 64  32832      ['activation_15[0][0]']          
 spose)                         )                                                                 
                                                                                                  
 concatenate_3 (Concatenate)    (None, 512, 512, 12  0           ['conv2d_transpose_3[0][0]',     
                                8)                                'activation_1[0][0]']           
                                                                                                  
 conv2d_16