In [35]:
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Input, UpSampling2D, concatenate, Conv2DTranspose, Dropout
from tensorflow.keras.optimizers import Adam

In [48]:
def unet_model(img_height, img_width, img_channels, num_classes):
    inputs = Input(shape = (img_height, img_width, num_classes))
    l1 = Conv2D(32, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(inputs)
    l1 = Dropout(0.2)(l1)
    l1 = Conv2D(32, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(l1)
    p1 = MaxPooling2D(2,2)(l1)
    
    l2 = Conv2D(64, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(p1)
    l2 = Dropout(0.2)(l2)
    l2 = Conv2D(64, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(l2)
    p2 = MaxPooling2D(2,2)(l2)
    
    l3 = Conv2D(128, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(p2)
    l3 = Dropout(0.2)(l3)
    l3 = Conv2D(128, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(l3)
    p3 = MaxPooling2D(2,2)(l3)
    
    l4 = Conv2D(256, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(p3)
    l4 = Dropout(0.2)(l4)
    l4 = Conv2D(256, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(l4)
    p4 = MaxPooling2D(2,2)(l4)
    
    l5 = Conv2D(512, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(p4)
    l5 = Dropout(0.2)(l5)
    l5 = Conv2D(512, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(l5)
    
    u6 = Conv2DTranspose(256, (2,2), strides = (2,2), padding = 'same')(l5)
    u6 = concatenate([u6, l4])
    u6 = Conv2D(256, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(u6)
    u6 = Dropout(0.2)(u6)
    u6 = Conv2D(256, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(u6)
    
    u7 = Conv2DTranspose(128, (2,2), strides = (2,2), padding = 'same')(u6)
    u7 = concatenate([u7, l3])
    u7 = Conv2D(128, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(u7)
    u7 = Dropout(0.2)(u7)
    u7 = Conv2D(128, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(u7)
    
    u8 = Conv2DTranspose(64, (2,2), strides = (2,2), padding = 'same')(u7)
    u8 = concatenate([u8, l2])
    u8 = Conv2D(64, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(u8)
    u8 = Dropout(0.2)(u8)
    u8 = Conv2D(64, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(u8)
    
    u9 = Conv2DTranspose(32, (2,2), strides = (2,2), padding = 'same')(u8)
    u9 = concatenate([u9, l1])
    u9 = Conv2D(64, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(u9)
    u9 = Dropout(0.2)(u9)
    u9 = Conv2D(64, (3,3), activation = 'relu', kernel_initializer = 'he_normal', padding='same')(u9)
    
    outputs = Conv2D(num_classes, (1,1), activation = 'softmax')(u9)
    
    model = Model(inputs = [inputs], outputs = [outputs])
    
    model.summary()
    return model

In [49]:
model = unet_model(128, 128, 3, 4)

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_20 (InputLayer)       [(None, 128, 128, 4)]        0         []                            
                                                                                                  
 conv2d_170 (Conv2D)         (None, 128, 128, 32)         1184      ['input_20[0][0]']            
                                                                                                  
 dropout_81 (Dropout)        (None, 128, 128, 32)         0         ['conv2d_170[0][0]']          
                                                                                                  
 conv2d_171 (Conv2D)         (None, 128, 128, 32)         9248      ['dropout_81[0][0]']          
                                                                                              