In [1]:
import torch
import torch.nn as nn

In [2]:
class Encoder(nn.Module):
  """Two stacked 3x3 convolutions that keep the spatial size unchanged.

  Layer order inside ``self.model``:
    Conv2d -> ReLU -> Conv2d -> BatchNorm2d -> ReLU

  Args:
    in_channels: channel count of the incoming tensor.
    out_channels: channel count produced by both convolutions.
  """

  def __init__(self, in_channels = None, out_channels = None):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    # kernel 3, stride 1, padding 1 -> output H and W equal the input's.
    first_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
    second_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    self.model = nn.Sequential(
        first_conv,
        nn.ReLU(inplace=True),
        second_conv,
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )

  def forward(self, x):
    """Run the conv stack; the result has ``out_channels`` channels."""
    return self.model(x)

In [None]:
data = torch.randn(64, 3, 256, 256)
encoder = Encoder(in_channels=3, out_channels=64)

# Shape check only: skip autograd so the 64x64x256x256 intermediate feature
# maps are not kept alive for a backward pass that never happens.
with torch.no_grad():
  encoded = encoder(data)

# 3x3 convs with padding=1 keep H and W; only the channel count changes.
assert encoded.shape == (64, 64, 256, 256)
encoded.shape

torch.Size([64, 64, 256, 256])

In [3]:
class Decoder(nn.Module):
  """Learned 2x spatial upsampling using a single transposed convolution.

  With kernel_size=2, stride=2 and padding=0 an input of spatial size
  (H, W) becomes (2H, 2W), and the channel count goes from
  ``in_channels`` to ``out_channels``.
  """

  def __init__(self, in_channels = None, out_channels = None):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
    self.model = nn.Sequential(upsample)

  def forward(self, x):
    """Upsample ``x`` by a factor of 2 along both spatial axes."""
    return self.model(x)

In [None]:
data1 = torch.randn(64, 1024, 16, 16)   # tensor to upsample
data2 = torch.randn(64, 512, 32, 32)    # skip-sized reference the upsampled map must match

decoder = Decoder(in_channels=1024, out_channels=512)

# Decoder.forward takes a single tensor; joining with the skip tensor is done
# later by torch.cat inside AttentionUNet.forward, not by the decoder.
with torch.no_grad():  # shape check only; no autograd graph needed
  upsampled = decoder(data1)

assert upsampled.shape == data2.shape
upsampled.shape

torch.Size([64, 512, 32, 32])

In [4]:
class AttentionBlock(nn.Module):
  """Attention gate that rescales (projected) skip features by a learned map.

  Both inputs are projected by a 1x1 conv + BatchNorm to ``out_channels``.
  The two projections are concatenated along the channel axis and passed
  through ReLU. ``psi`` then reduces them to a single-channel sigmoid map in
  [0, 1], which multiplies the projected skip tensor.

  Args:
    in_channels: channel count of BOTH ``x`` (gating signal) and
      ``skip_info``; the two projections share this input width.
    out_channels: width of each projection and of the returned tensor.

  ``x`` and ``skip_info`` must share batch and spatial size, otherwise the
  channel-wise concatenation fails.

  NOTE(review): the output gates ``W_x(skip_info)`` rather than the raw
  ``skip_info``; the two only have the same channel count when
  in_channels == out_channels. Confirm this is intended.
  """

  def __init__(self, in_channels = None, out_channels = None):
    super().__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels

    # 1x1 projection of the gating signal.
    self.W_gate = nn.Sequential(
        nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0, bias=True),
        nn.BatchNorm2d(self.out_channels)
    )

    # 1x1 projection of the skip connection.
    self.W_x = nn.Sequential(
        nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0, bias=True),
        nn.BatchNorm2d(self.out_channels)
    )

    # psi consumes the concatenation of two out_channels-wide projections,
    # so its input width is 2 * out_channels. (The previous 2 * in_channels
    # only worked by coincidence when in_channels == out_channels.)
    self.psi = nn.Sequential(
        nn.Conv2d(self.out_channels * 2, 1, 1, 1, 0, bias=True),
        nn.BatchNorm2d(1),
        nn.Sigmoid()
    )

    self.relu = nn.ReLU(inplace = True)

  def forward(self, x, skip_info):
    """Return the attention-weighted projection of ``skip_info``.

    Args:
      x: gating tensor with ``in_channels`` channels.
      skip_info: skip-connection tensor with ``in_channels`` channels and
        the same batch/spatial size as ``x``.

    Returns:
      Tensor with ``out_channels`` channels and the spatial size of the inputs.
    """
    gate = self.W_gate(x)
    skip = self.W_x(skip_info)

    # In-place ReLU is safe here: torch.cat returns a fresh tensor.
    combined = self.relu(torch.cat((gate, skip), dim = 1))

    # Single-channel map; broadcasts across all channels of `skip`.
    attention_map = self.psi(combined)

    return attention_map * skip

In [77]:
# Attention-gate shape check: gating signal and skip connection share batch
# and spatial size; with in_channels == out_channels the output keeps that shape.
data1 = torch.randn(64, 512, 16, 16)   # gating signal (x)
data2 = torch.randn(64, 512, 16, 16)   # skip connection (skip_info)

attention = AttentionBlock(in_channels=512, out_channels=512)

with torch.no_grad():  # shape check only; no autograd graph needed
  attended = attention(data1, data2)

assert attended.shape == (64, 512, 16, 16)
attended.shape

torch.Size([64, 512, 16, 16])

In [5]:
from math import e
class AttentionUNet(nn.Module):
  """U-Net with an attention gate on every skip connection.

  Contracting path: five Encoder (double-conv) stages separated by 2x2
  max-pooling, widening 64 -> 128 -> 256 -> 512 -> 1024 channels.

  Expanding path: four stages, deepest first. Each stage does
    Decoder (2x upsample) -> AttentionBlock(upsampled, skip)
    -> concat(gated, skip) -> Encoder (fuse back down to the skip's width)
  and a final 1x1 conv maps 64 channels to ``out_channels``.

  Args:
    in_channels: channels of the input image (default 3, as before).
    out_channels: channels of the output map (default 1, as before).

  Input height and width must be divisible by 16 (four pooling steps);
  otherwise the upsampled maps cannot be aligned with the skip tensors.
  """

  def __init__(self, in_channels=3, out_channels=1):
    super().__init__()

    self.encoder_block1 = Encoder(in_channels=in_channels, out_channels=64)
    self.encoder_block2 = Encoder(in_channels=64, out_channels=128)
    self.encoder_block3 = Encoder(in_channels=128, out_channels=256)
    self.encoder_block4 = Encoder(in_channels=256, out_channels=512)
    self.encoder_block5 = Encoder(in_channels=512, out_channels=1024)

    # Fuse concat(gated skip, skip) back down to the skip's width.
    # NOTE: "intermiadte" is a misspelling kept on purpose: renaming these
    # attributes would change state_dict keys and break saved checkpoints.
    self.intermiadte_block1 = Encoder(in_channels=1024, out_channels=512)
    self.intermiadte_block2 = Encoder(in_channels=512, out_channels=256)
    self.intermiadte_block3 = Encoder(in_channels=256, out_channels=128)
    self.intermiadte_block4 = Encoder(in_channels=128, out_channels=64)

    # 2x upsampling, halving the channel count.
    self.decoder_block1 = Decoder(in_channels=1024, out_channels=512)
    self.decoder_block2 = Decoder(in_channels=512, out_channels=256)
    self.decoder_block3 = Decoder(in_channels=256, out_channels=128)
    self.decoder_block4 = Decoder(in_channels=128, out_channels=64)

    self.maxpool = nn.MaxPool2d(2, 2)

    self.attention_block1 = AttentionBlock(in_channels=512, out_channels=512)
    self.attention_block2 = AttentionBlock(in_channels=256, out_channels=256)
    self.attention_block3 = AttentionBlock(in_channels=128, out_channels=128)
    self.attention_block4 = AttentionBlock(in_channels=64, out_channels=64)

    self.final = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)

  def forward(self, x):
    """Segment ``x``; output has ``out_channels`` channels and x's H and W.

    Raises:
      ValueError: if the input height or width is not divisible by 16.
    """
    height, width = x.shape[-2:]
    multiple = 2 ** 4  # four 2x2 max-pool steps on the contracting path
    if height % multiple or width % multiple:
      raise ValueError(
          f"AttentionUNet expects H and W divisible by {multiple}, "
          f"got {height}x{width}"
      )

    # Contracting path; e1..e4 are kept as skip connections.
    e1 = self.encoder_block1(x)
    e2 = self.encoder_block2(self.maxpool(e1))
    e3 = self.encoder_block3(self.maxpool(e2))
    e4 = self.encoder_block4(self.maxpool(e3))
    features = self.encoder_block5(self.maxpool(e4))  # bottleneck, 1024 ch

    # Expanding path, deepest stage first; each stage pairs with its skip.
    stages = (
        (self.decoder_block1, self.attention_block1, self.intermiadte_block1, e4),
        (self.decoder_block2, self.attention_block2, self.intermiadte_block2, e3),
        (self.decoder_block3, self.attention_block3, self.intermiadte_block3, e2),
        (self.decoder_block4, self.attention_block4, self.intermiadte_block4, e1),
    )
    for upsample, gate, fuse, skip in stages:
      upsampled = upsample(features)
      gated = gate(upsampled, skip)
      features = fuse(torch.cat((gated, skip), dim = 1))

    return self.final(features)

In [6]:
data = torch.randn(64, 3, 128, 128)
model = AttentionUNet()

# Shape check only: without no_grad, every intermediate activation of the
# full network at batch 64 would be retained for a backward pass.
with torch.no_grad():
  prediction = model(data)

assert prediction.shape == (64, 1, 128, 128)
prediction.shape

torch.Size([64, 1, 128, 128])