In [2]:
import torch
import torch.nn as nn

In [3]:
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)
    

In [4]:
class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        s = self.conv(x)
        p = self.pool(s)
        return s, p


In [22]:
class attention_gate(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.Wg = nn.Sequential(
            nn.Conv2d(in_c[0], out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c)
        )
        self.Ws = nn.Sequential(
            nn.Conv2d(in_c[1], out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c)
        )
        self.Wv = nn.Sequential(
            nn.Conv2d(in_c[0]+in_c[1], out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c)
        )
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.pool = nn.MaxPool2d((2, 2))
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, g, s):
        q = self.Wg(g)
        k = self.Ws(s)
        
       
        
        print(q.shape)
        print(k.shape)
        out = self.relu(k*q)
        out = self.output(out)
        
        
        v = torch.cat([g,s], axis=1)
        v1=  self.Wv(v)
        
        result=v1*out+v1
        
        
        return result


In [23]:
class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.ag = attention_gate(in_c, out_c)
        self.c1 = conv_block(out_c, out_c)

    def forward(self, x, s):
        x = self.up(x)
        s = self.ag(x, s)
        x = self.c1(s)
        return x


In [24]:
class attention_unet(nn.Module):
    def __init__(self):
        super().__init__()

        self.e1 = encoder_block(3,32)
        self.e2 = encoder_block(32,64)
        
        self.b1 = conv_block(64,128)

        self.d1 = decoder_block([128,64],64)
        self.d2 = decoder_block([64,32],32)
       
         
        self.output1 = nn.Conv2d(32, 3, kernel_size=3, padding=1)
        self.finaloutput = nn.Conv2d(3, 1, kernel_size=1, padding=0)
        
       
        

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
       

        b1 = self.b1(p2)

        d1 = self.d1(b1, s2)
        d2 = self.d2(d1, s1)
        

        output = self.output1(d2)
        output = self.finaloutput(output)
        
        return output


In [25]:

if __name__ == "__main__":
    x = torch.randn((8, 3, 256, 256)) 
    model = attention_unet()
    output = model(x)
    print(output.shape)

torch.Size([8, 64, 128, 128])
torch.Size([8, 64, 128, 128])
torch.Size([8, 32, 256, 256])
torch.Size([8, 32, 256, 256])
torch.Size([8, 1, 256, 256])
