In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class UpConv(nn.Module):
    """Decoder upsampling step: 2x bilinear upsample, then 3x3 conv, BatchNorm, ReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the 3x3 convolution.

    The spatial size of the output is twice that of the input; the 3x3 conv
    uses padding=1 so it does not shrink the upsampled map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # A single Sequential attribute named `up` keeps parameter names
        # (up.1.weight, up.2.running_mean, ...) stable for saved checkpoints.
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        """Upsample ``x`` by 2 and project it to ``out_channels`` channels."""
        upsampled = self.up(x)
        return upsampled

class AttentionBlock(nn.Module):
    """Additive attention gate applied to a skip connection.

    The gating signal and the skip features are each projected by a 1x1 conv
    to ``n_coefficients`` channels, summed, passed through ReLU, squeezed to a
    single channel by another 1x1 conv and mapped into (0, 1) with a sigmoid.
    That one-channel map then rescales every channel of ``skip``.

    Args:
        F_g: channels of the gating signal.
        F_l: channels of the skip-connection features.
        n_coefficients: channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, n_coefficients):
        super().__init__()
        # Attribute names and creation order are kept so state_dict keys and
        # seeded initialisation match the original layout.
        self.W_g = nn.Conv2d(F_g, n_coefficients, kernel_size=1)
        self.W_x = nn.Conv2d(F_l, n_coefficients, kernel_size=1)
        self.psi = nn.Conv2d(n_coefficients, 1, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, gate, skip):
        """Return ``skip`` weighted by the attention map computed from ``gate``.

        ``gate`` and ``skip`` need matching batch and spatial sizes so their
        projections can be added element-wise; the result has ``skip``'s shape.
        """
        combined = self.relu(self.W_x(skip) + self.W_g(gate))
        weights = self.sigmoid(self.psi(combined))  # (N, 1, H, W), broadcast over channels
        return weights * skip

class ResNet18AttentionUNet(nn.Module):
    """U-Net with a ResNet-18 encoder and an attention gate on every skip connection.

    Encoder feature maps for an input of spatial size H x W:
        stem  (conv1/bn1/relu)   -> 64 ch,  ~H/2
        maxpool + layer1         -> 64 ch,  ~H/4
        layer2 / layer3 / layer4 -> 128 / 256 / 512 ch, ~H/8 / ~H/16 / ~H/32

    Each decoder stage upsamples, gates the matching encoder skip with an
    ``AttentionBlock``, concatenates, and fuses with conv-BN-ReLU. The finest
    stage uses the pre-maxpool stem activations, so it receives real H/2 detail
    rather than an interpolation of the pooled H/4 map. Upsampled maps are
    resized to their skip's size, and the logits to the input's size, so any
    H and W work (not only multiples of 32).

    Args:
        num_classes: number of output channels (raw logits, no activation).
        pretrained: load ImageNet weights into the encoder. Defaults to
            ``True`` (the previous hard-coded behaviour); pass ``False`` to
            build the model without downloading weights.
    """

    def __init__(self, num_classes=1, pretrained=True):
        super().__init__()

        # ResNet-18 encoder
        resnet = self._build_resnet18(pretrained)
        # Kept as one Sequential so state_dict keys (initial.0.*, initial.1.*)
        # are unchanged; forward() splits it at the maxpool (index 3) to reuse
        # the pre-pool activations as the H/2 skip connection.
        self.initial = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )

        self.enc1 = resnet.layer1  # 64
        self.enc2 = resnet.layer2  # 128
        self.enc3 = resnet.layer3  # 256
        self.enc4 = resnet.layer4  # 512

        # Decoder with attention; decoder inputs are [gated skip, upsampled].
        self.up4 = UpConv(512, 256)
        self.att4 = AttentionBlock(256, 256, 128)
        self.dec4 = self._conv_bn_relu(512, 256)

        self.up3 = UpConv(256, 128)
        self.att3 = AttentionBlock(128, 128, 64)
        self.dec3 = self._conv_bn_relu(256, 128)

        self.up2 = UpConv(128, 64)
        self.att2 = AttentionBlock(64, 64, 32)
        self.dec2 = self._conv_bn_relu(128, 64)

        self.up1 = UpConv(64, 64)
        self.att1 = AttentionBlock(64, 64, 32)
        self.dec1 = self._conv_bn_relu(128, 64)

        self.final = nn.Conv2d(64, num_classes, kernel_size=1)

    @staticmethod
    def _build_resnet18(pretrained):
        """Create a torchvision ResNet-18, preferring the ``weights=`` API when present."""
        weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
        if weights_enum is None:
            # Older torchvision (< 0.13) only accepts the legacy boolean flag.
            return torchvision.models.resnet18(pretrained=pretrained)
        return torchvision.models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """3x3 conv -> BatchNorm -> ReLU used to fuse each concatenation."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _resize_to(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref``; no-op when equal."""
        if x.shape[-2:] == ref.shape[-2:]:
            return x
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        # Encoder
        stem = self.initial[:3](x)    # 64, ~H/2 (conv1/bn1/relu, before maxpool)
        e0 = self.initial[3](stem)    # 64, ~H/4
        e1 = self.enc1(e0)            # 64, ~H/4
        e2 = self.enc2(e1)            # 128, ~H/8
        e3 = self.enc3(e2)            # 256, ~H/16
        e4 = self.enc4(e3)            # 512, ~H/32

        # Decoder: the upsample can be off by one pixel from the skip when H/W
        # are not multiples of 32, so align to the skip before gating/concat.
        d4 = self._resize_to(self.up4(e4), e3)                  # 256, ~H/16
        d4 = self.dec4(torch.cat([self.att4(d4, e3), d4], dim=1))

        d3 = self._resize_to(self.up3(d4), e2)                  # 128, ~H/8
        d3 = self.dec3(torch.cat([self.att3(d3, e2), d3], dim=1))

        d2 = self._resize_to(self.up2(d3), e1)                  # 64, ~H/4
        d2 = self.dec2(torch.cat([self.att2(d2, e1), d2], dim=1))

        d1 = self._resize_to(self.up1(d2), stem)                # 64, ~H/2
        d1 = self.dec1(torch.cat([self.att1(d1, stem), d1], dim=1))

        # Resize to the exact input resolution rather than assuming 2x.
        out = self._resize_to(d1, x)
        return self.final(out)  # num_classes, H, W

In [2]:
torch.manual_seed(0)  # reproducible decoder init and random input

model = ResNet18AttentionUNet(num_classes=1)
# eval(): BatchNorm uses its stored statistics instead of overwriting the
# pretrained running stats with this random batch.
model.eval()

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():  # inference only; no autograd graph needed
    output = model(x)

# Logits must come back at the input's spatial resolution.
assert output.shape == (1, 1, x.shape[-2], x.shape[-1]), output.shape
output.shape



torch.Size([1, 1, 256, 256])
