In [30]:
import torch
import torch.nn as nn

In [31]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        return self.double_conv(x)

In [32]:
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False),
            nn.Sigmoid()
        )
    

In [33]:
class ChanAttUNet(nn.Module):
    def __init__(self):
        super(ChanAttUNet, self).__init__()
        ### 1st part encoder
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = DoubleConv(3, 64)
        self.down_conv_2 = DoubleConv(64, 128)
        self.down_conv_3 = DoubleConv(128, 256)
        self.down_conv_4 = DoubleConv(256, 512)
        self.down_conv_5 = DoubleConv(512, 1024)

In [34]:
def forward(self, x):
    avg_out = self.fc(self.avg_pool(x))
    print(avg_out.size())
    max_out = self.fc(self.max_pool(x))
    out = avg_out + max_out
    print(out.size())
    return out

In [35]:
def forward(self, x):
        # encoder
        x1 = self.down_conv_1(x) # we will need x1, x3, x5, x7 in the decoder part. We don't need x9 in the decoder part because there is no maxpooling.
        print("Output after 1st convolution:", x1.size())
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2) #
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4) #
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6) #
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)
        print("Output after 5th convolution:", x9.size())

In [36]:
if __name__ == '__main__':
        image = torch.rand((1, 3, 512, 512))
        model = ChanAttUNet()
        # print(model(image))
        output = model(x)
        print("Final shape", output.shape)

NameError: name 'x' is not defined