In [None]:
from keras.models import Model
from keras.layers import Input, Conv3D, MaxPooling3D, concatenate, Conv3DTranspose, BatchNormalization, Dropout, Lambda, Activation, UpSampling3D
from tensorflow.keras.optimizers import Adam
from keras.metrics import MeanIoU

# Total spatial downsampling of the encoder: four 2x2x2 max-pooling stages
# (2 ** 4). Every spatial input dimension must be a multiple of this so that
# the upsampled decoder maps line up with the encoder maps they are
# concatenated with.
_UNET_DOWNSAMPLE_FACTOR = 16


def _conv_relu(x, filters, kernel_initializer):
    """Apply a 3x3x3 'same'-padded Conv3D with a built-in ReLU activation."""
    return Conv3D(filters, (3, 3, 3), activation='relu',
                  kernel_initializer=kernel_initializer, padding='same')(x)


def _conv_bn_relu(x, filters, kernel_initializer):
    """Apply a 3x3x3 'same'-padded Conv3D, then BatchNormalization, then ReLU."""
    x = Conv3D(filters, (3, 3, 3), kernel_initializer=kernel_initializer,
               padding='same')(x)
    x = BatchNormalization()(x)
    return Activation("relu")(x)


def _encoder_level(x, filters, kernel_initializer, pool=True):
    """Build one encoder stage.

    Conv+ReLU -> Conv+BN+ReLU -> [MaxPool 2x2x2 if ``pool``] -> Conv+ReLU.

    Returns:
        (bn, out): ``bn`` is the BN+ReLU feature map at the stage's input
        resolution (reused by skip connections); ``out`` is the stage output,
        at half resolution when ``pool`` is True, otherwise unchanged.
    """
    x = _conv_relu(x, filters, kernel_initializer)
    bn = _conv_bn_relu(x, filters, kernel_initializer)
    x = MaxPooling3D((2, 2, 2))(bn) if pool else bn
    out = _conv_relu(x, filters, kernel_initializer)
    return bn, out


def _pool_and_concat(x, skip, times):
    """Max-pool ``skip`` 2x2x2 ``times`` times, then concatenate it after ``x``.

    The repeated pooling brings a higher-resolution encoder map down to the
    resolution of ``x`` so the two can be concatenated along channels.
    """
    for _ in range(times):
        skip = MaxPooling3D(pool_size=(2, 2, 2))(skip)
    return concatenate([x, skip])


def _decoder_level(x, filters, encoder_skip, previous_bn, kernel_initializer):
    """Build one decoder stage.

    Upsample ``x`` 2x2x2 -> Conv+BN+ReLU, concatenate with the same-resolution
    ``encoder_skip`` map and, when ``previous_bn`` is given, with the
    2x2x2-upsampled BN map of the previous decoder stage; then two Conv+ReLU.

    Returns:
        (bn, out): the BN+ReLU map (passed to the next stage as its
        ``previous_bn``) and the stage output.
    """
    upsampled = UpSampling3D(size=(2, 2, 2))(x)
    bn = _conv_bn_relu(upsampled, filters, kernel_initializer)
    to_merge = [bn, encoder_skip]
    if previous_bn is not None:
        to_merge.append(UpSampling3D(size=(2, 2, 2))(previous_bn))
    merged = concatenate(to_merge)
    out = _conv_relu(merged, filters, kernel_initializer)
    out = _conv_relu(out, filters, kernel_initializer)
    return bn, out


def custom_model(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes,
                 kernel_initializer='he_normal'):
    """Build a 3D U-Net-style segmentation model with extra skip connections.

    Encoder: four stages with 32, 64, 128 and 256 filters, each downsampling
    by 2x2x2, followed by a 512-filter bridge without pooling. The output of
    each encoder stage from the second on (and of the bridge) is concatenated
    with the previous stage's BN+ReLU map, max-pooled to matching size.

    Decoder: four stages with 256, 128, 64 and 32 filters, each upsampling by
    2x2x2 and concatenating with the same-resolution encoder BN+ReLU map and,
    from the second stage on, with the upsampled BN+ReLU map of the previous
    decoder stage.

    Output: a 1x1x1 Conv3D with ``num_classes`` channels and softmax.

    Args:
        IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH: spatial input size. Each must be a
            positive multiple of 16 (four 2x2x2 poolings), or ``None`` for a
            dynamic dimension.
        IMG_CHANNELS: number of input channels.
        num_classes: number of output channels (classes) of the softmax.
        kernel_initializer: initializer for every 3x3x3 convolution.

    Returns:
        A Keras ``Model`` mapping the 5D input to per-voxel class scores.

    Raises:
        ValueError: if a spatial dimension is not a positive multiple of 16.
    """
    for dim_name, size in (("IMG_HEIGHT", IMG_HEIGHT),
                           ("IMG_WIDTH", IMG_WIDTH),
                           ("IMG_DEPTH", IMG_DEPTH)):
        # Fail early with a clear message; otherwise Keras fails later with an
        # opaque shape mismatch in a decoder concatenation.
        if size is not None and (size <= 0 or size % _UNET_DOWNSAMPLE_FACTOR != 0):
            raise ValueError(
                f"{dim_name} must be a positive multiple of "
                f"{_UNET_DOWNSAMPLE_FACTOR} (or None), got {size!r}"
            )

    # NOTE: layers are created in a fixed order below; keep it stable, since
    # Keras' auto-generated layer names (and name-based weight loading)
    # depend on creation order.
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))

    # ----------------------------- ENCODER -----------------------------
    bn_1, x = _encoder_level(inputs, 32, kernel_initializer)
    bn_2, x = _encoder_level(x, 64, kernel_initializer)
    x = _pool_and_concat(x, bn_1, times=2)
    bn_3, x = _encoder_level(x, 128, kernel_initializer)
    x = _pool_and_concat(x, bn_2, times=2)
    bn_4, x = _encoder_level(x, 256, kernel_initializer)
    x = _pool_and_concat(x, bn_3, times=2)

    # ----------------------------- BRIDGE ------------------------------
    # Same resolution as the last pooled encoder level (e.g. 15x15x9 for a
    # 240x240x144 input).
    _, x = _encoder_level(x, 512, kernel_initializer, pool=False)
    x = _pool_and_concat(x, bn_4, times=1)

    # ----------------------------- DECODER -----------------------------
    bn_6, x = _decoder_level(x, 256, bn_4, None, kernel_initializer)
    bn_7, x = _decoder_level(x, 128, bn_3, bn_6, kernel_initializer)
    bn_8, x = _decoder_level(x, 64, bn_2, bn_7, kernel_initializer)
    _, x = _decoder_level(x, 32, bn_1, bn_8, kernel_initializer)

    # ----------------------------- OUTPUT ------------------------------
    outputs = Conv3D(num_classes, (1, 1, 1), activation='softmax')(x)

    return Model(inputs=[inputs], outputs=[outputs])

model = custom_model(240, 240, 144, 4, 6)
# Model.summary() prints the table itself and returns None; wrapping it in
# print() would emit a stray "None" after the summary.
model.summary()

Model: "model_6"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_8 (InputLayer)           [(None, 240, 240, 1  0           []                               
                                44, 4)]                                                           
                                                                                                  
 conv3d_131 (Conv3D)            (None, 240, 240, 14  3488        ['input_8[0][0]']                
                                4, 32)                                                            
                                                                                                  
 conv3d_132 (Conv3D)            (None, 240, 240, 14  27680       ['conv3d_131[0][0]']             
                                4, 32)                                                      