In [1]:
import tensorflow as tf
from tensorflow.keras import layers as tfl
from tensorflow.keras.models import Model

In [61]:
class CNNBlock(tfl.Layer):
    """Double-convolution unit used by every U-Net stage.

    Applies ``Conv2D(kernels, 3, padding='same') -> BatchNormalization -> ReLU``
    twice. With 'same' padding and the default stride of 1 the spatial size of
    the input is preserved; the output has ``kernels`` channels.

    Args:
        kernels: number of filters in both convolutions.
        **kwargs: standard Keras ``Layer`` options (e.g. ``name``, ``dtype``),
            forwarded to the base class.
    """

    def __init__(self, kernels, **kwargs):
        # Forward kwargs so options such as ``name`` actually reach the base
        # Layer instead of being accepted and silently discarded.
        super().__init__(**kwargs)
        self.conv = tfl.Conv2D(kernels, 3, padding='same')
        self.batchN = tfl.BatchNormalization()
        self.act = tfl.Activation("relu")
        self.conv2 = tfl.Conv2D(kernels, 3, padding='same')
        self.batchN2 = tfl.BatchNormalization()
        self.act2 = tfl.Activation("relu")

    def call(self, x):
        """Run both conv/BN/ReLU stages on ``x`` and return the result."""
        x = self.conv(x)
        x = self.batchN(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.batchN2(x)
        x = self.act2(x)
        return x
    

In [62]:
class encoder(tf.keras.Model):
    """One contracting stage of the U-Net: ``CNNBlock`` followed by 2x2 max pooling.

    Args:
        kernels: number of filters for the stage's ``CNNBlock``.
        **kwargs: standard Keras options (e.g. ``name``), forwarded to the
            base class.

    Calling the stage returns ``(x, p)``:
        x: the feature map before pooling, used as the skip connection.
        p: ``x`` after ``MaxPool2D((2, 2))`` (stride defaults to the pool
           size, so height and width are halved), fed to the next stage.
    """

    def __init__(self, kernels, **kwargs):
        # Forward kwargs so options such as ``name`` take effect.
        super().__init__(**kwargs)
        self.cnnBlock = CNNBlock(kernels)
        self.maxPool = tfl.MaxPool2D((2, 2))

    def call(self, x):
        x = self.cnnBlock(x)
        p = self.maxPool(x)
        return x, p

In [34]:
def attention_gate(g, s, kernels):
    """Additive attention gate applied to a skip tensor (functional-API builder).

    ``g`` (gating signal) and ``s`` (skip features) are each projected with a
    1x1 ``Conv2D(kernels)`` + ``BatchNormalization``, summed, passed through
    ReLU, then another 1x1 ``Conv2D(kernels)`` + sigmoid produces the gate
    coefficients that multiply ``s``.

    NOTE: every call constructs brand-new layers. That is fine when building a
    functional graph once, but calling this from a subclassed layer's
    ``call()`` would create fresh, untracked weights on every forward pass.

    Args:
        g: gating tensor; must be addable to the projection of ``s``
           (presumably the same spatial size — TODO confirm for callers).
        s: skip tensor to be gated.
        kernels: filter count for all three 1x1 convolutions; must be
           broadcast-compatible with the channel count of ``s`` for the final
           multiply.

    Returns:
        ``s`` scaled element-wise by the sigmoid attention coefficients.
    """
    def _project(tensor):
        # 1x1 conv then batch norm, built in that order.
        projected = tfl.Conv2D(kernels, 1, padding='same')(tensor)
        return tfl.BatchNormalization()(projected)

    gate_proj = _project(g)
    skip_proj = _project(s)

    hidden = tfl.Activation("relu")(skip_proj + gate_proj)
    logits = tfl.Conv2D(kernels, 1, padding='same')(hidden)
    coefficients = tfl.Activation("sigmoid")(logits)
    return coefficients * s

In [63]:
class decoder(tf.keras.Model):
    """One expanding stage of the attention U-Net.

    Bilinearly upsamples the incoming features by 2, gates the encoder skip
    tensor with an additive attention gate (using the upsampled features as
    the gating signal), concatenates the two and runs a ``CNNBlock``.

    All attention-gate layers are created once in ``__init__``. Previously the
    gate was built via ``attention_gate(...)`` inside ``call()``, which made
    new Conv2D/BatchNormalization layers - with new random weights that were
    never tracked or trained - on every forward pass.

    Args:
        kernels: filter count for the gate projections and the ``CNNBlock``.
        **kwargs: standard Keras options (e.g. ``name``), forwarded to the
            base class.
    """

    def __init__(self, kernels, **kwargs):
        super().__init__(**kwargs)
        self.upsample = tfl.UpSampling2D(interpolation="bilinear")
        self.kernels = kernels
        self.cnnBlock = CNNBlock(kernels)

        # Attention-gate layers (same structure as ``attention_gate``), owned
        # by this stage so their weights persist and are trainable.
        self.gate_conv_g = tfl.Conv2D(kernels, 1, padding='same')
        self.gate_bn_g = tfl.BatchNormalization()
        self.gate_conv_s = tfl.Conv2D(kernels, 1, padding='same')
        self.gate_bn_s = tfl.BatchNormalization()
        self.gate_relu = tfl.Activation("relu")
        self.gate_conv_psi = tfl.Conv2D(kernels, 1, padding='same')
        self.gate_sigmoid = tfl.Activation("sigmoid")
        self.concat = tfl.Concatenate()

    def _attention_gate(self, g, s):
        """Return ``s`` scaled by sigmoid attention coefficients computed from ``g`` and ``s``."""
        wg = self.gate_bn_g(self.gate_conv_g(g))
        ws = self.gate_bn_s(self.gate_conv_s(s))
        out = self.gate_relu(ws + wg)
        out = self.gate_sigmoid(self.gate_conv_psi(out))
        return out * s

    def call(self, x, s):
        """Upsample ``x``, gate skip tensor ``s`` with it, concatenate and convolve.

        ``s`` is expected to match the spatial size of ``x`` after 2x
        upsampling (the add inside the gate and the concatenation require it).
        """
        x = self.upsample(x)
        s = self._attention_gate(x, s)
        x = self.concat([x, s])
        x = self.cnnBlock(x)
        return x

In [68]:
class Attention_unet(tf.keras.Model):
    """Attention U-Net built by model subclassing.

    Architecture: three ``encoder`` stages (64/128/256 filters), a 512-filter
    ``CNNBlock`` bottleneck, three attention-gated ``decoder`` stages
    (256/128/64 filters) and a 1x1 sigmoid ``Conv2D`` head.

    Each encoder halves height/width and each decoder doubles them, so input
    height and width presumably need to be divisible by 8 for the skip
    tensors to line up — TODO confirm for non-square / odd inputs.

    Args:
        classes: number of output channels; each is an independent sigmoid.
        **kwargs: standard Keras options (e.g. ``name``), forwarded to
            ``tf.keras.Model``.
    """

    def __init__(self, classes, **kwargs):
        super().__init__(**kwargs)
        self.e1 = encoder(64)
        self.e2 = encoder(128)
        self.e3 = encoder(256)
        self.b = CNNBlock(512)
        self.d1 = decoder(256)
        self.d2 = decoder(128)
        self.d3 = decoder(64)
        self.output_con = tfl.Conv2D(classes, 1, padding="same", activation="sigmoid")

    def call(self, x):
        # Contracting path: keep each stage's pre-pool features as skips.
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        b1 = self.b(p3)
        # Expanding path: pair each decoder with the matching-resolution skip.
        x1 = self.d1(b1, s3)
        x2 = self.d2(x1, s2)
        x3 = self.d3(x2, s1)
        return self.output_con(x3)


In [58]:

# Smoke test: push one random batch through the subclassed model so its
# weights get built, check the output shape, then print the summary.
tf.random.set_seed(42)  # reproducible dummy input and weight initialisation
model = Attention_unet(classes=1)
dummy_input = tf.random.normal([1, 256, 256, 3])
dummy_output = model(dummy_input)
# 'same' padding + three 2x pools and three 2x upsamples keep 256x256;
# the head emits one channel per class.
assert dummy_output.shape == (1, 256, 256, 1), dummy_output.shape
model.summary()

In [67]:
def call_m(x, classes=1):
    """Build the attention U-Net with the Keras functional API.

    Same wiring as ``Attention_unet``: three ``encoder`` stages (64/128/256),
    a 512-filter ``CNNBlock`` bottleneck, three ``decoder`` stages
    (256/128/64) and a 1x1 sigmoid ``Conv2D`` head.

    Args:
        x: input shape without the batch dimension, e.g. ``(256, 256, 3)``;
            passed straight to ``tf.keras.layers.Input``. Height/width
            presumably need to be divisible by 8 so the decoder outputs line
            up with the skips — TODO confirm.
        classes: number of sigmoid output channels. Defaults to 1, the value
            previously hard-coded here.

    Returns:
        A ``tf.keras.Model`` named ``"a_u"``.
    """
    inputs = tfl.Input(x)
    e1 = encoder(64)
    e2 = encoder(128)
    e3 = encoder(256)
    b = CNNBlock(512)

    # Contracting path: each encoder returns (skip features, pooled features).
    s1, p1 = e1(inputs)
    s2, p2 = e2(p1)
    s3, p3 = e3(p2)
    b1 = b(p3)

    # Expanding path: each decoder consumes the matching-resolution skip.
    d1 = decoder(256)
    d2 = decoder(128)
    d3 = decoder(64)
    x1 = d1(b1, s3)
    x2 = d2(x1, s2)
    x3 = d3(x2, s1)

    outputs = tfl.Conv2D(classes, 1, padding="same", activation="sigmoid")(x3)
    model = Model(inputs, outputs, name="a_u")
    return model

# Functional-API variant for 256x256 RGB inputs (rebinds `model`).
model = call_m(x=(256, 256, 3))
model.summary()