In [None]:
def conv_block(inputs, num_filters):
    """Apply two consecutive (3x3 Conv2D -> BatchNormalization -> ReLU) stages.

    Args:
        inputs: Input feature tensor.
        num_filters: Number of output channels for both convolutions.

    Returns:
        The activated feature tensor. Spatial size is unchanged because both
        convolutions use stride 1 with ``padding="same"``.
    """
    features = inputs
    for _ in range(2):
        features = Conv2D(num_filters, 3, padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features




def encoder_block(inputs, num_filters):
    """Run one contracting-path stage of the U-Net.

    Args:
        inputs: Input feature tensor.
        num_filters: Channel count passed to ``conv_block``.

    Returns:
        A ``(skip, pooled)`` pair. ``skip`` is the ``conv_block`` output at the
        input resolution, kept for the skip connection. ``pooled`` is that
        output after 2x2 max pooling, which feeds the next stage.
    """
    skip = conv_block(inputs, num_filters)
    pooled = MaxPool2D((2, 2))(skip)
    return skip, pooled

def decoder_block(inputs, att_skip_features, num_filters):
    """Upsample ``inputs`` 2x and merge it with attention-gated skip features.

    Args:
        inputs: Lower-resolution decoder (or bottleneck) tensor.
        att_skip_features: Skip tensor produced by the attention gate. Its
            spatial size must match the upsampled ``inputs``.
        num_filters: Output channels of the 2x2, stride-2 transposed conv.

    Returns:
        The ReLU-activated upsampled tensor concatenated with
        ``att_skip_features`` along the last axis.
    """
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
    upsampled = Activation("relu")(upsampled)
    # NOTE(review): a post-concatenation refinement stage
    # (BatchNormalization -> ReLU -> conv_block) was commented out in this
    # block, so the merged tensor is returned without further convolution.
    return Concatenate()([upsampled, att_skip_features])

def expend_as(tensor, rep):
    """Repeat ``tensor`` ``rep`` times along axis 3 inside a Keras ``Lambda`` layer.

    This is a module-level helper, so it takes no ``self`` argument. It is
    called as ``expend_as(tensor, rep)``.

    Args:
        tensor: Input tensor. Axis 3 is assumed to be the channel axis
            (NHWC layout); TODO confirm for other layouts.
        rep: Number of times each element is repeated along axis 3.

    Returns:
        The repeated tensor, whose axis-3 size is ``rep`` times the input's.
    """
    return Lambda(
        lambda x, repnum: K.repeat_elements(x, repnum, axis=3),
        arguments={"repnum": rep},
    )(tensor)


def Attention_Gate(inputs, skip_features, num_filters):
    """Additive attention gate that reweights encoder skip features.

    The steps are:

    1. Project the gating signal ``inputs`` to ``2 * num_filters`` channels.
    2. Downsample ``skip_features`` with a stride-2 conv to the same channel
       count.
    3. Sum the two, apply ReLU, and reduce to a single-channel sigmoid map.
    4. Upsample the map back to the skip resolution and repeat it across
       every skip channel.
    5. Multiply it into ``skip_features``, then apply a 1x1 conv and
       BatchNormalization.

    Args:
        inputs: Gating tensor (coarser decoder/bottleneck features). Static
            spatial dims are required because ``K.int_shape`` values are used
            in integer arithmetic.
        skip_features: Encoder skip tensor. Its stride-2-downsampled spatial
            size must be an integer multiple of the gating tensor's.
        num_filters: Base width; the intermediate projections use
            ``2 * num_filters`` channels.

    Returns:
        A tensor with the same channel count as ``skip_features``. It also has
        the same spatial size when that size is even, since the attention map
        is upsampled by integer division.
    """
    skip_shape = K.int_shape(skip_features)
    inter_channels = num_filters * 2

    # Gating path: 1x1 projection of the coarser signal.
    gate = Conv2D(inter_channels, (1, 1), padding="same")(inputs)
    gate = BatchNormalization()(gate)
    gate = Activation("relu")(gate)
    gate_shape = K.int_shape(gate)

    # Skip path: halve the resolution so it lines up with the gating signal.
    theta_x = Conv2D(inter_channels, (2, 2), strides=2, padding="same")(skip_features)
    theta_shape = K.int_shape(theta_x)

    # Bring the gate to theta_x's channel count and resolution so the two can
    # be summed element-wise. Both must use inter_channels.
    phi_g = Conv2D(inter_channels, (1, 1), padding="same")(gate)
    phi_g = Conv2DTranspose(
        inter_channels,
        (3, 3),
        strides=(theta_shape[1] // gate_shape[1], theta_shape[2] // gate_shape[2]),
        padding="same",
    )(phi_g)

    add_xg = add([phi_g, theta_x])  # element-wise add
    relu_xg = Activation("relu")(add_xg)

    # One attention coefficient per pixel. A single channel is required so
    # that repeating it skip_shape[3] times matches the skip tensor's channels.
    psi = Conv2D(1, (1, 1), padding="same")(relu_xg)
    sig_xg = Activation("sigmoid")(psi)
    sig_shape = K.int_shape(sig_xg)
    upsample_psi = UpSampling2D(
        size=(skip_shape[1] // sig_shape[1], skip_shape[2] // sig_shape[2])
    )(sig_xg)

    # Broadcast the single-channel map across every skip channel.
    upsample_psi = expend_as(upsample_psi, skip_shape[3])

    y = multiply([upsample_psi, skip_features])

    result = Conv2D(skip_shape[3], (1, 1), padding="same")(y)
    return BatchNormalization()(result)
    

In [None]:
def build_unet(input_shape):
    """Assemble the attention U-Net as a Keras functional ``Model`` named "UNET".

    Architecture:

    - Encoder: four ``encoder_block`` stages with 64/128/256/512 filters.
    - Bottleneck: a 1024-filter ``conv_block``.
    - Decoder: four stages. Each gates the matching skip tensor with
      ``Attention_Gate`` and then upsamples and merges it with
      ``decoder_block``.
    - Head: a 1x1 sigmoid conv with one output channel (presumably a binary
      segmentation mask; confirm with the training code).

    Args:
        input_shape: Shape of one input sample, e.g. ``(256, 256, 3)``. The
            spatial dims presumably need to be divisible by 16 (four 2x2
            poolings) for the decoder shapes to line up.

    Returns:
        The uncompiled Keras ``Model``.
    """
    inputs = Input(input_shape)

    # Contracting path: keep each pre-pool feature map and its width for the
    # skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        skip, x = encoder_block(x, filters)
        skips.append((skip, filters))

    x = conv_block(x, 1024)  # bottleneck

    # Expanding path: deepest skip first.
    for skip, filters in reversed(skips):
        gated_skip = Attention_Gate(x, skip, filters)
        x = decoder_block(x, gated_skip, filters)

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(x)
    return Model(inputs, outputs, name="UNET")

# Build the network for 256x256 inputs with 3 channels and print its layer summary.
model = build_unet(input_shape=(256, 256, 3))
model.summary()