In [5]:
import numpy as np
import torch 
import torch.nn as nn
import  torchvision.models as models

In [6]:
class scSEBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.sse = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

        self.cse = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels//2, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels//2, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        sse_out = self.sse(x)*x
        cse_out = self.cse(x)*x
        return torch.max(sse_out, cse_out)
        

In [7]:
class ResNet18_UNet_scSE(nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        resnet = models.resnet18(weights='IMAGENET1K_V1')

        # Encoder path
        self.encoder = nn.ModuleDict({
            "conv1":nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu),
            "maxpool": resnet.maxpool,
            "layer1": resnet.layer1,
            "layer2": resnet.layer2,
            "layer3": resnet.layer3,
            "layer4": resnet.layer4          })
        
        #scSE blocks for skip connection
        self.scSE_skip1 = scSEBlock(64) # for conv1 output
        self.scSE_skip2 = scSEBlock(64) # for layer1 output
        self.scSE_skip3 = scSEBlock(128) # for layer2 output
        self.scSE_skip4 = scSEBlock(256) # for layer3 output

        # Decoder path
        self.upconv1 = self.upconv(512, 256)
        self.dec_conv1 = nn.Conv2d(512, 256, kernel_size=1) # 256 from encoder and 256 from decoder

        self.upconv2 = self.upconv(256, 128)
        self.dec_conv2 = nn.Conv2d(256, 128, kernel_size=1) # 128 from encoder and 128 from decoder

        self.upconv3 = self.upconv(128, 64)
        self.dec_conv3 = nn.Conv2d(128, 64, kernel_size=1) # 64 from encoder and 64 from decoder

        self.upconv4 = self.upconv(64, 32)
        self.dec_conv4 = nn.Conv2d(96, 32, kernel_size=1) # 32 from decoder and 64 from encoder

        self.final_upsample = nn.ConvTranspose2d(32, 32, kernel_size=1)
        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)


    def upconv(self, in_channels, out_channels):

        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        x1 = self.encoder['conv1'](x) # 64 channel
        print('x1 shape: ', x1.shape)

        x2 = self.encoder['maxpool'](x1)
        print('x2 shape:', x2.shape)

        x3 = self.encoder['layer1'](x2) # 128 channel
        print('x3 shape:', x3.shape)

        x4 = self.encoder['layer2'](x3) # 256 channel
        print('x4 shape:', x4.shape)

        x5 = self.encoder['layer3'](x4)
        print ('x5 shape:', x5.shape)

        x6 = self.encoder['layer4'](x5)
        print('x6 shape:', x6.shape)

        # Decoder path with scSE blocks applied to skip connection
        # First Upsampling blocks
        d1 = self.upconv1(x6)
        x5_scse = self.scSE_skip4(x5)
        d1 = torch.cat([d1, x5_scse], dim=1) # 256+256 = 512 channels
        d1 = self.dec_conv1(d1) # 512 -> 256 channels

        d2 = self.upconv2(d1)  # 128 channles
        x4_scse = self.scSE_skip3(x4)
        d2 = torch.cat([d2, x4_scse], dim=1) # 128 + 128 = 256 channels
        d2 = self.dec_conv2(d2) # 256 -> 128 channels


        d3 = self.upconv3(d2)         # 64 channels
        x3_scse = self.scSE_skip2(x3)
        d3 = torch.cat([d3, x3_scse], dim=1) # 64 +64 = 128 channels
        d3 = self.dec_conv3(d3)                 # 128 _> 64 channels

        d4 = self.upconv4(d3)            # 32 channels
        x1_scse = self.scSE_skip1(x1) # 32 + 64 = 96 channels
        d4 = torch.cat([d4, x1_scse], dim=1)
        d4 = self.dec_conv4(d4) # 96 -> 32 channels

        # Final upsampling and convolution
        d5 = self.final_upsample(d4)
        out = self.final_conv(d5)

        return out


In [8]:
model = ResNet18_UNet_scSE(num_classes=1)
x = torch.randn(1, 3, 256, 256)
output = model(x)
print(output.shape)

x1 shape:  torch.Size([1, 64, 128, 128])
x2 shape: torch.Size([1, 64, 64, 64])
x3 shape: torch.Size([1, 64, 64, 64])
x4 shape: torch.Size([1, 128, 32, 32])
x5 shape: torch.Size([1, 256, 16, 16])
x6 shape: torch.Size([1, 512, 8, 8])
torch.Size([1, 1, 128, 128])
