# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools every channel to a scalar ("squeeze"), feeds the
    per-channel vector through a bias-free bottleneck MLP
    (Linear -> ReLU -> Linear -> Sigmoid, "excitation") and multiplies each
    input channel by the resulting weight in (0, 1).

    Args:
        channels: Number of channels C of the input tensor.
        reduction: Bottleneck ratio; the hidden width is
            ``max(1, channels // reduction)``.
    """

    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        # Clamp to at least one hidden unit. With channels < reduction the
        # integer division yields 0, producing zero-width Linear layers whose
        # output is all zeros, so the gate would silently become a constant
        # 0.5 for every channel.
        hidden = max(1, channels // reduction)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale the channels of ``x`` by learned per-sample gates.

        Args:
            x: 4-D tensor (B, C, H, W) with C == ``channels``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, _, _ = x.size()
        y = self.pool(x).view(b, c)          # (B, C) channel descriptors
        y = self.fc(y).view(b, c, 1, 1)      # (B, C, 1, 1) gates, broadcast over H, W
        return x * y

class ResidualConvBlock(nn.Module):
    """Two BN -> ReLU -> 3x3 Conv stages, SE-gated, added to a skip path.

    Computes::

        residual = se(conv2(relu(bn2(conv1(relu(bn1(x)))))))
        out = relu(residual + shortcut(x))

    ``shortcut`` is the identity when ``in_ch == out_ch`` and a bias-free
    1x1 convolution otherwise. Height and width are preserved (3x3 kernels,
    padding 1, stride 1).

    NOTE(review): the branch uses pre-activation ordering, yet the sum is
    still passed through a ReLU, which keeps the identity path from being a
    pure identity; confirm this hybrid is intended.

    Args:
        in_ch: Channels of the input tensor.
        out_ch: Channels of the output tensor.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Keep this construction order: it fixes the order in which the
        # layers draw their initial weights from the RNG and the state_dict
        # layout.
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.se = SEBlock(out_ch)
        # Only project the skip path when the channel count changes.
        self.shortcut = (
            nn.Identity()
            if in_ch == out_ch
            else nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        )

    def forward(self, x):
        """Map ``x`` of shape (B, in_ch, H, W) to (B, out_ch, H, W)."""
        skip = self.shortcut(x)

        h = self.conv1(F.relu(self.bn1(x)))
        h = self.conv2(F.relu(self.bn2(h)))
        h = self.se(h)

        # ``h + skip`` is a fresh tensor, so the in-place ReLU is safe.
        return F.relu(h + skip, inplace=True)

class ASM(nn.Module):
    """Four-level U-Net assembled from SE residual blocks.

    Encoder widths are 64-128-256-512 with a 1024-channel bottleneck, and
    each level halves H and W with 2x2 max-pooling. The decoder upsamples
    with 2x2 stride-2 transposed convolutions, aligns every upsampled map to
    its encoder skip (see ``_pad_or_crop``), concatenates
    ``[upsampled, skip]`` along channels and fuses them with a residual
    block. A final 1x1 convolution maps 64 channels to ``out_ch``.

    The output has the same H and W as the input, odd sizes included. H and
    W must be at least 16 so that the fourth pooling level is non-empty.

    Args:
        in_ch: Channels of the input tensor.
        out_ch: Channels of the output tensor.
    """

    def __init__(self, in_ch=8, out_ch=1):
        super().__init__()

        # Encoder: residual block, then 2x2 max-pool (halves H and W).
        # Attribute names and creation order fix the state_dict keys and the
        # RNG draw order for weight initialisation; keep them as they are.
        self.conv1 = ResidualConvBlock(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = ResidualConvBlock(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = ResidualConvBlock(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = ResidualConvBlock(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck.
        self.conv5 = ResidualConvBlock(512, 1024)

        # Decoder: each fusing block receives upsampled + skip channels,
        # i.e. twice the width it outputs.
        self.up6 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv6 = ResidualConvBlock(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv7 = ResidualConvBlock(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv8 = ResidualConvBlock(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv9 = ResidualConvBlock(128, 64)

        # Per-pixel projection to the requested number of output channels.
        self.out_conv = nn.Conv2d(64, out_ch, kernel_size=1)

    def forward(self, x):
        """Map ``x`` of shape (B, in_ch, H, W) to (B, out_ch, H, W)."""
        # Encoder: keep every level's pre-pooling features as skips.
        skip1 = self.conv1(x)
        skip2 = self.conv2(self.pool1(skip1))
        skip3 = self.conv3(self.pool2(skip2))
        skip4 = self.conv4(self.pool3(skip3))
        feat = self.conv5(self.pool4(skip4))

        # Decoder: walk back up, fusing with the matching skip each time.
        feat = self._decode(feat, skip4, self.up6, self.conv6)
        feat = self._decode(feat, skip3, self.up7, self.conv7)
        feat = self._decode(feat, skip2, self.up8, self.conv8)
        feat = self._decode(feat, skip1, self.up9, self.conv9)

        return self.out_conv(feat)

    def _decode(self, feat, skip, upsample, fuse):
        """Upsample ``feat``, align it to ``skip``, concat ``[up, skip]``, fuse."""
        aligned = self._pad_or_crop(upsample(feat), skip)
        return fuse(torch.cat([aligned, skip], dim=1))

    def _pad_or_crop(self, up, bypass):
        """Resize ``up`` in H and W to match ``bypass`` via ``F.pad``.

        Positive differences are zero-padded. For an odd difference the extra
        row/column goes on the bottom/right. Negative differences become
        negative padding, which ``F.pad`` treats as cropping. In this network
        max-pooling floors and the transposed convs exactly double, so the
        difference is always 0 or 1.
        """
        dh = bypass.size(2) - up.size(2)
        dw = bypass.size(3) - up.size(3)
        if dh == 0 and dw == 0:
            return up
        # F.pad lists the last dim first: (left, right, top, bottom).
        return F.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

if __name__ == "__main__":
    # Smoke test: push a random batch with odd spatial size through the
    # network and print the shapes; output H and W should equal the input's.
    model = ASM(in_ch=8, out_ch=1)
    # Inference mode: BatchNorm uses running statistics instead of updating
    # them from this dummy batch.
    model.eval()
    x = torch.randn(2, 8, 181, 380)
    # Only shapes are inspected, so skip recording the autograd graph; that
    # would otherwise keep every intermediate activation alive.
    with torch.no_grad():
        y = model(x)
    print("Input:", x.shape)
    print("Output:", y.shape)

# Out [1] (printed by the demo above):
# Input: torch.Size([2, 8, 181, 380])
# Output: torch.Size([2, 1, 181, 380])
