In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

In [None]:
class DoubleConv(nn.Module):
    """Residual double-convolution block.

    Main path: ``Conv3x3 -> GroupNorm -> GELU -> Conv3x3 -> GroupNorm``.
    The main path is summed with a shortcut and passed through a final GELU.
    The shortcut is the identity when the channel count is unchanged, otherwise
    ``Conv1x1 -> GroupNorm``. Both 3x3 convs use ``padding=1``, so H and W are
    preserved.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        mid_channels: Channels between the two 3x3 convs. Any falsy value
            (``None`` or ``0``) falls back to ``out_channels``.
    """

    @staticmethod
    def _num_groups(channels):
        """Return the largest of (32, 16, 8) that divides ``channels``, else 1."""
        for groups in (32, 16, 8):
            if channels % groups == 0:
                return groups
        return 1

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels

        # GroupNorm requires num_channels % num_groups == 0.
        groups_mid = self._num_groups(mid_channels)
        groups_out = self._num_groups(out_channels)

        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(groups_mid, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(groups_out, out_channels),
        )

        if in_channels != out_channels:
            # 1x1 projection so the input can be added to the main path; it is
            # normalised with the same group count as the main path's last GN.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.GroupNorm(groups_out, out_channels),
            )
        else:
            self.shortcut = nn.Identity()

        self.final_act = nn.GELU()

    def forward(self, x):
        # Residual sum, then the post-activation.
        return self.final_act(self.double_conv(x) + self.shortcut(x))
class Down(nn.Module):
    """Encoder-style downscaling stage: 2x2 max-pool, then a residual DoubleConv.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)  # halves H and W
        block = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, block)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


In [None]:
class PixelShuffleUp(nn.Module):
    """Decoder stage: 2x sub-pixel upsampling of deep features, then skip fusion.

    The deep map is expanded 4x in channels by a 1x1 conv. ``PixelShuffle(2)``
    then rearranges it to ``deep_ch`` channels at twice the resolution, followed
    by GroupNorm(1, ·) and GELU. The upsampled map is padded to the skip map's
    H and W, concatenated after the skip (skip first), and fused by a DoubleConv.

    Args:
        deep_ch: Channels of the deeper (lower-resolution) input.
        skip_ch: Channels of the skip-connection input.
        out_ch: Channels of the fused output.
    """

    def __init__(self, deep_ch, skip_ch, out_ch):
        super().__init__()
        expand = nn.Conv2d(deep_ch, 4 * deep_ch, kernel_size=1, bias=False)
        self.project = nn.Sequential(
            expand,
            nn.PixelShuffle(upscale_factor=2),  # (4*C, H, W) -> (C, 2H, 2W)
            nn.GroupNorm(1, deep_ch),
            nn.GELU(),
        )
        fused_ch = skip_ch + deep_ch
        self.conv = DoubleConv(fused_ch, out_ch)

    def forward(self, x_deep, x_skip):
        upsampled = self.project(x_deep)

        # Pad so the upsampled map has exactly the skip map's H and W; any odd
        # remainder goes to the bottom/right side.
        dh = x_skip.shape[2] - upsampled.shape[2]
        dw = x_skip.shape[3] - upsampled.shape[3]
        top = dh // 2
        left = dw // 2
        upsampled = F.pad(upsampled, [left, dw - left, top, dh - top])

        fused = torch.cat([x_skip, upsampled], dim=1)
        return self.conv(fused)

In [None]:
class OutConv(nn.Module):
    """Output head: a single 1x1 convolution (with bias) to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits

class PAM_Module(nn.Module):
    """ Position Attention Module (GSA)

    Spatial self-attention over all H*W positions. Query/key are 1x1 convs to
    ``max(1, in_dim // reduction)`` channels, and the value is a 1x1 conv to
    ``in_dim`` channels. The result is ``gamma * attended + x``. ``gamma`` starts
    at 0, so the module initially returns its input unchanged.

    The attention matrix is (B, H*W, H*W), so memory grows quadratically with
    the number of spatial positions.

    Args:
        in_dim: Channels of the input tensor.
        reduction: Channel reduction factor for query/key (default 8). The
            width is clamped to at least 1 so small ``in_dim`` does not produce
            zero-channel projections (which would make attention uniform).
    """
    def __init__(self, in_dim, reduction=8):
        super(PAM_Module, self).__init__()
        self.chanel_in = in_dim  # (sic) attribute name kept for compatibility
        qk_dim = max(1, in_dim // reduction)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=qk_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=qk_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        # (B, HW, qk) and (B, qk, HW) -> energy (B, HW, HW)
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)  # each row sums to 1 over positions
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # (B, C, HW)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma * out + x
        return out

In [None]:
class TSA_Block(nn.Module):
    """Large Kernel Attention (LKA), used in place of transformer self-attention.

    This is the LKA decomposition from the Visual Attention Network (VAN):
    a depthwise 5x5 conv, then a depthwise 7x7 conv with dilation 3 (19x19
    footprint, about 23x23 combined), then a 1x1 conv for channel mixing. The
    resulting map gates the input element-wise, as ``x * attn``. The paddings
    (2 and 9) keep H and W unchanged.

    Args:
        dim: Number of channels (input and output).
    """
    def __init__(self, dim):
        super().__init__()
        # Local context: depthwise 5x5
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # Long-range context: depthwise 7x7, dilation 3
        self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        # Channel mixing
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        # No in-place op touches x, so gating it directly is safe; copying it
        # first (as the reference implementation does) only costs memory.
        return x * attn

class SAA_Module(nn.Module):
    """Self-Aware Attention block: two pre-norm residual sub-layers.

    1. ``x + TSA_Block(BatchNorm(x))`` — large-kernel attention.
    2. ``x + FFN(BatchNorm(x))`` — a 1x1-conv MLP with 4x hidden width, GELU,
       and Dropout(0.1) after the activation and after the output projection.

    Args:
        dim: Number of channels (input and output).
    """

    def __init__(self, dim):
        super().__init__()
        hidden = 4 * dim
        self.norm1 = nn.BatchNorm2d(dim)
        self.tsa = TSA_Block(dim)
        self.norm2 = nn.BatchNorm2d(dim)
        self.ffn = nn.Sequential(
            nn.Conv2d(dim, hidden, 1),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Conv2d(hidden, dim, 1),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        after_attn = x + self.tsa(self.norm1(x))
        return after_attn + self.ffn(self.norm2(after_attn))



In [None]:
class CAM_Module(nn.Module):
    """ Channel Attention Module

    Channel-to-channel attention computed from the (B, C, C) Gram matrix of
    the flattened input. The softmax is applied to ``rowmax(energy) - energy``,
    so channels that are *less* similar to a given channel get larger weights.
    The result is ``gamma * attended + x``. ``gamma`` starts at 0, so the module
    initially returns its input unchanged. There are no learned projections
    apart from ``gamma``.

    Args:
        in_dim: Channels of the input tensor.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.chanel_in = in_dim  # (sic) attribute name kept for compatibility
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch, channels, h, w = x.size()
        flat = x.view(batch, channels, -1)                   # (B, C, HW)
        gram = torch.bmm(flat, flat.permute(0, 2, 1))        # (B, C, C)
        row_max = torch.max(gram, -1, keepdim=True)[0]
        weights = self.softmax(row_max.expand_as(gram) - gram)
        mixed = torch.bmm(weights, flat).view(batch, channels, h, w)
        return self.gamma * mixed + x


class DANetBlock(nn.Module):
    """Dual attention: position (PAM) and channel (CAM) attention in parallel.

    Both branches see the same input, and their outputs are summed. Each branch
    already adds its own ``x`` residual, and both ``gamma`` parameters start at
    zero, so at initialisation the block returns ``2 * x``.
    NOTE(review): confirm this doubled identity path is intended.

    Args:
        in_channels: Channels of the input tensor.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.pam = PAM_Module(in_channels)  # spatial attention
        self.cam = CAM_Module(in_channels)  # channel attention

    def forward(self, x):
        spatial = self.pam(x)
        channel = self.cam(x)
        fused = spatial + channel
        return fused
    
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling with GroupNorm + GELU.

    Five parallel branches run over the same input:
      * ``branch1``: 1x1 conv.
      * ``branch2``–``branch4``: 3x3 convs with dilation 6, 12 and 18. Padding
        equals dilation, so H and W are preserved.
      * ``branch5``: global average pool followed by a 1x1 conv. The result is
        broadcast back to H and W with bilinear interpolation.
    The branch outputs are concatenated (``5 * out_channels``) and fused by
    ``bottleneck``: 1x1 conv, GroupNorm, GELU, Dropout(0.1).

    Every GroupNorm uses the largest of 32/16/8 groups that divides
    ``out_channels`` (else 1), so ``out_channels`` need not be a multiple of 32.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of every branch and of the output.
    """

    @staticmethod
    def _num_groups(channels):
        """Return the largest of (32, 16, 8) that divides ``channels``, else 1."""
        for groups in (32, 16, 8):
            if channels % groups == 0:
                return groups
        return 1

    def __init__(self, in_channels, out_channels):
        super(ASPP, self).__init__()
        groups = self._num_groups(out_channels)

        def atrous(rate):
            # 3x3 conv with padding == dilation keeps H and W unchanged.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False),
                nn.GroupNorm(groups, out_channels),
                nn.GELU(),
            )

        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.GroupNorm(groups, out_channels),
            nn.GELU(),
        )
        self.branch2 = atrous(6)
        self.branch3 = atrous(12)
        self.branch4 = atrous(18)
        self.branch5 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.GroupNorm(groups, out_channels),
            nn.GELU(),
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(out_channels * 5, out_channels, 1, bias=False),
            nn.GroupNorm(groups, out_channels),
            nn.GELU(),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        h, w = x.shape[2], x.shape[3]
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        b4 = self.branch4(x)
        # Image-level features are 1x1; resize them to match the other branches.
        b5 = F.interpolate(self.branch5(x), size=(h, w), mode='bilinear', align_corners=True)
        out = torch.cat([b1, b2, b3, b4, b5], dim=1)
        return self.bottleneck(out)




In [None]:

class AttentionGate(nn.Module):
    """Additive attention gate applied to skip features.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels (1x1 conv + GroupNorm). If needed, the projected gate is
    bilinearly resized to the skip map's H and W. The two are summed, passed
    through ReLU, and reduced to one channel by a 1x1 conv + sigmoid. The
    resulting per-pixel coefficient in (0, 1) scales ``x``.

    Args:
        F_g: Channels of the gating signal ``g``.
        F_l: Channels of the skip features ``x``.
        F_int: Intermediate channels. GroupNorm uses the largest of
            32/16/8 groups dividing it, else 1.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        groups = next((n for n in (32, 16, 8) if F_int % n == 0), 1)

        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(groups, F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(groups, F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gate = self.W_g(g)
        skip = self.W_x(x)

        target_hw = skip.shape[2:]
        if gate.shape[2:] != target_hw:
            gate = F.interpolate(gate, size=target_hw, mode='bilinear', align_corners=True)

        coeff = self.psi(self.relu(gate + skip))  # (B, 1, H, W)
        return x * coeff

In [None]:
class PretrainedTransAttUnet(nn.Module):
    """U-Net-style segmentation model on top of a pretrained ``timm`` encoder.

    Pipeline:
      * Encoder: ``timm.create_model(..., features_only=True,
        out_indices=(0, 1, 2, 3))`` returns four feature maps ``x1..x4``. They
        are presumably ordered shallow-to-deep with increasing stride; confirm
        for non-ConvNeXt backbones. Channel counts come from
        ``feature_info.channels()``.
      * Bridge on ``x4``: ASPP -> DANetBlock -> SAA_Module.
      * Decoder: three PixelShuffleUp stages. Each fuses the previous decoder
        state with an AttentionGate-weighted skip (``x3``, ``x2``, ``x1``).
      * Heads: 1x1 OutConv on ``d3``, ``d2`` and ``d1``. Each is bilinearly
        resized to the input's H and W.

    Args:
        n_classes: Output channels of every head.
        backbone_name: ``timm`` model name.
        pretrained: Whether ``timm`` loads pretrained weights.

    Returns (forward):
        ``[out_final, out_d2, out_d1]``, the heads on ``d3``, ``d2`` and ``d1``.
    """

    def __init__(self, n_classes=1, backbone_name='convnextv2_tiny.fcmae_ft_in22k_in1k', pretrained=True):
        super().__init__()

        self.encoder = timm.create_model(
            backbone_name,
            features_only=True,
            pretrained=pretrained,
            out_indices=(0, 1, 2, 3),
        )

        feature_channels = self.encoder.feature_info.channels()
        c1, c2, c3, c4 = feature_channels
        print(f"Adaptive Channels: {feature_channels}")

        # Bridge (all at the deepest level's channel count)
        self.aspp = ASPP(in_channels=c4, out_channels=c4)
        self.danet = DANetBlock(c4)
        self.saa_bridge = SAA_Module(c4)

        # Gates: the deeper signal g weights the skip features x
        self.ag1 = AttentionGate(F_g=c4, F_l=c3, F_int=c3 // 2)
        self.ag2 = AttentionGate(F_g=c3, F_l=c2, F_int=c2 // 2)
        self.ag3 = AttentionGate(F_g=c2, F_l=c1, F_int=c1 // 2)

        # Decoder stages
        self.up1 = PixelShuffleUp(deep_ch=c4, skip_ch=c3, out_ch=c3)
        self.up2 = PixelShuffleUp(deep_ch=c3, skip_ch=c2, out_ch=c2)
        self.up3 = PixelShuffleUp(deep_ch=c2, skip_ch=c1, out_ch=c1)

        # Prediction heads (final + two auxiliary)
        self.outc_final = OutConv(c1, n_classes)
        self.outc_up1 = OutConv(c3, n_classes)
        self.outc_up2 = OutConv(c2, n_classes)

    @staticmethod
    def _resize_to(logits, size):
        """Bilinearly resize ``logits`` to spatial ``size`` (align_corners=True)."""
        return F.interpolate(logits, size=size, mode='bilinear', align_corners=True)

    def forward(self, x):
        x1, x2, x3, x4 = self.encoder(x)

        center = self.saa_bridge(self.danet(self.aspp(x4)))

        d1 = self.up1(center, self.ag1(g=center, x=x3))
        d2 = self.up2(d1, self.ag2(g=d1, x=x2))
        d3 = self.up3(d2, self.ag3(g=d2, x=x1))

        hw = x.shape[2:]
        main = self._resize_to(self.outc_final(d3), hw)
        aux_d1 = self._resize_to(self.outc_up1(d1), hw)
        aux_d2 = self._resize_to(self.outc_up2(d2), hw)

        return [main, aux_d2, aux_d1]