In [10]:
import torch
import torch.nn as nn
from torchvision import models
from functools import partial


In [11]:
def softplus_feature_map(x):
    """Apply softplus elementwise to ``x`` and return the result.

    Softplus is ``log(1 + exp(x))``: a smooth, non-negative map, used in this
    file as the kernel feature map for queries and keys.
    """
    positive_features = nn.functional.softplus(x)
    return positive_features

In [12]:
def conv3otherRelu(in_channel, out_channel, kernel_size=None, stride=None, padding=None):
    """Build a ``Conv2d -> ReLU`` block.

    Args:
        in_channel: number of input channels of the convolution.
        out_channel: number of output channels of the convolution.
        kernel_size: int or tuple; defaults to 3 when ``None``.
        stride: int or tuple; defaults to 1 when ``None``.
        padding: int or tuple; defaults to 1 when ``None``.

    Returns:
        ``nn.Sequential`` holding the convolution followed by an in-place ReLU.

    Raises:
        AssertionError: if any of the three options is neither int nor tuple.
    """
    # Fill in the defaults first, then validate in the same order as the arguments.
    kernel_size = 3 if kernel_size is None else kernel_size
    stride = 1 if stride is None else stride
    padding = 1 if padding is None else padding

    accepted = (int, tuple)
    assert isinstance(kernel_size, accepted), "kernel size is not in (int, tuple)!"
    assert isinstance(stride, accepted), "Stride is not (int, tuple)!"
    assert isinstance(padding, accepted), 'padding is not in (int, tuple)!'

    layers = [
        nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

In [13]:
class PAM_Module(nn.Module):
    """Position attention over pixels using a softplus kernel feature map.

    Queries and keys come from 1x1 convolutions that reduce the channel count
    by ``scale``; values keep all ``in_channel`` channels. Keys are contracted
    with values first and queries are applied afterwards, so no
    pixels-by-pixels attention matrix is materialised.

    Args:
        in_channel: number of channels of the input feature map.
        scale: channel reduction factor for the query/key projections.
        eps: small constant added to the summed keys so the normaliser
            stays away from zero.
    """

    def __init__(self, in_channel, scale=8, eps=1e-6):
        super(PAM_Module, self).__init__()
        # gamma starts at zero, so the module initially returns its input unchanged.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.in_channel = in_channel
        self.softplus_feature = softplus_feature_map
        # eps is a numerical stabiliser; it must be tiny, otherwise it dominates
        # the normaliser and suppresses the attention output.
        self.eps = eps

        self.query_conv = nn.Conv2d(in_channel, in_channel//scale, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channel, in_channel//scale, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channel, in_channel, kernel_size=1)

    def forward(self, x):
        """Return ``x + gamma * attention(x)`` with the same shape as ``x``.

        Args:
            x: 4-D tensor whose dimensions are unpacked as
                (batch, channel, height, width).
        """
        # Unpack in the same order that is used to restore the shape below, so
        # non-square feature maps keep their spatial layout.
        batch_size, channel, height, width = x.shape
        num_pixels = height * width

        Q = self.query_conv(x).view(batch_size, -1, num_pixels)
        K = self.key_conv(x).view(batch_size, -1, num_pixels)
        V = self.value_conv(x).view(batch_size, -1, num_pixels)

        Q = self.softplus_feature(Q).permute(-3, -1, -2)  # (batch, pixels, reduced)
        K = self.softplus_feature(K)                      # (batch, reduced, pixels)

        # Contract keys with values over the pixel axis: (batch, reduced, channel).
        KV = torch.einsum('bmn, bcn->bmc', K, V)
        # Per-pixel normaliser: 1 / (q_n . (sum over pixels of k + eps)).
        norm = 1 / torch.einsum('bnc, bc->bn', Q, torch.sum(K, dim=-1) + self.eps)
        weight_value = torch.einsum('bnm, bmc, bn->bcn', Q, KV, norm)
        weight_value = weight_value.view(batch_size, channel, height, width)

        return (x+self.gamma*weight_value).contiguous()

In [14]:
class CAM_Module(nn.Module):
    """Channel attention: mixes channels according to a channel-by-channel affinity map.

    The affinity is the Gram matrix of the flattened feature map. Each row is
    replaced by ``row_max - row`` before the softmax. The attended features are
    scaled by the learnable ``gamma`` (initialised to zero) and added to the input.
    """

    def __init__(self):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attended + x`` for a (batch, channels, height, width) tensor."""
        n, c, h, w = x.shape

        flat = x.view(n, c, -1)                         # (batch, channels, pixels)
        gram = torch.bmm(flat, flat.transpose(1, 2))    # (batch, channels, channels)

        # Subtract each entry from its row maximum before normalising.
        row_peak = gram.max(dim=-1, keepdim=True).values
        attention = self.softmax(row_peak.expand_as(gram) - gram)

        attended = torch.bmm(attention, flat).view(n, c, h, w)
        return self.gamma * attended + x


In [15]:

class PAM_CAM_Layer(nn.Module):
    """Runs ``PAM_Module`` and ``CAM_Module`` on the same input and sums their outputs.

    Args:
        in_channel: channel count of the input, forwarded to ``PAM_Module``.
    """

    def __init__(self, in_channel):
        super(PAM_CAM_Layer, self).__init__()
        self.PAM = PAM_Module(in_channel)
        self.CAM = CAM_Module()

    def forward(self, x):
        """Return ``PAM(x) + CAM(x)``."""
        position_branch = self.PAM(x)
        channel_branch = self.CAM(x)
        return position_branch + channel_branch

In [16]:
class DecoderBlock(nn.Module):
    """Upsampling decoder stage: 1x1 conv -> stride-2 transposed conv -> 1x1 conv.

    Each of the three layers is followed by batch normalisation and an in-place
    ReLU. The first conv squeezes the channels to ``in_channel // 4``, the
    transposed conv (kernel 3, stride 2, padding 1, output_padding 1) upsamples
    at that width, and the last conv expands to ``n_filters``.

    Args:
        in_channel: channels of the incoming feature map.
        n_filters: channels of the produced feature map.
    """

    def __init__(self, in_channel, n_filters):
        super().__init__()
        squeezed = in_channel // 4

        # Stage 1: channel squeeze.
        self.conv1 = nn.Conv2d(in_channel, squeezed, 1)
        self.norm1 = nn.BatchNorm2d(squeezed)
        self.relu1 = nn.ReLU(inplace=True)

        # Stage 2: spatial upsampling.
        self.deconv2 = nn.ConvTranspose2d(
            squeezed, squeezed, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.norm2 = nn.BatchNorm2d(squeezed)
        self.relu2 = nn.ReLU(inplace=True)

        # Stage 3: channel expansion.
        self.conv3 = nn.Conv2d(squeezed, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the three conv/norm/ReLU stages in order and return the result."""
        pipeline = (
            self.conv1, self.norm1, self.relu1,
            self.deconv2, self.norm2, self.relu2,
            self.conv3, self.norm3, self.relu3,
        )
        for layer in pipeline:
            x = layer(x)
        return x





In [17]:
class MANet(nn.Module):
    """Encoder-decoder segmentation network with attention on every encoder stage.

    The encoder is a pretrained torchvision ResNeXt (stem plus ``layer1`` to
    ``layer4``). Each encoder output passes through a ``PAM_CAM_Layer`` before
    it is used: the deepest one feeds the decoder, the other three are added
    as skip connections to the corresponding ``DecoderBlock`` outputs. A final
    transposed convolution and two 3x3 convolutions produce the prediction.

    Args:
        in_channel: kept for interface compatibility; it is not read by this
            constructor, because the stem reuses the backbone's own first
            convolution.
        num_classes: number of output channels of the final convolution.
        backbone: ``'resnext50'`` or ``'resnext101'``.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """

    def __init__(self, in_channel=3, num_classes=1, backbone='resnext50'):

        super().__init__()
        self.name = 'MANet'
        # Channel widths used for the four encoder stages' attention and decoder blocks.
        filters = [256, 512, 1024, 2048]
        if backbone == 'resnext50':
            resnext = models.resnext50_32x4d(pretrained=True)

        elif backbone == 'resnext101':
            # The torchvision model builder is `resnext101_32x8d`, mirroring the
            # `resnext50_32x4d` call above.
            resnext = models.resnext101_32x8d(pretrained=True)

        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        # Encoder module
        self.firstconv = resnext.conv1
        self.firstbn = resnext.bn1
        self.firstrelu = resnext.relu
        self.firstmaxpool = resnext.maxpool

        self.encoder1 = resnext.layer1
        self.encoder2 = resnext.layer2
        self.encoder3 = resnext.layer3
        self.encoder4 = resnext.layer4

        # Attention Module
        self.attention4 = PAM_CAM_Layer(filters[3])
        self.attention3 = PAM_CAM_Layer(filters[2])
        self.attention2 = PAM_CAM_Layer(filters[1])
        self.attention1 = PAM_CAM_Layer(filters[0])

        # Decoder Module
        self.decoder4 = DecoderBlock(filters[3], filters[2])
        self.decoder3 = DecoderBlock(filters[2], filters[1])
        self.decoder2 = DecoderBlock(filters[1], filters[0])
        self.decoder1 = DecoderBlock(filters[0], filters[0])

        # Final Layers
        self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
        self.finalrelu1 = nn.ReLU(inplace=True)
        self.finalconv2 =  nn.Conv2d(32, 32, 3, padding=1)
        self.finalrelu2 =  nn.ReLU(inplace=True)
        self.finalconv3 =  nn.Conv2d(32, num_classes, 3, padding=1)

    def forward(self, x):
        """Run the encoder, attention and decoder and return the final conv output.

        No activation is applied after the last convolution, so the result has
        ``num_classes`` channels of raw scores.
        """
        # Encoder stem
        x1 = self.firstconv(x)
        x1 = self.firstbn(x1)
        x1 = self.firstrelu(x1)
        x1 = self.firstmaxpool(x1)

        # Encoder stages
        e1 = self.encoder1(x1)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Attention on the deepest features before decoding
        e4 = self.attention4(e4)

        # Decoder: each stage adds the attention-refined encoder features as a skip
        d4 = self.decoder4(e4) + self.attention3(e3)
        d3 = self.decoder3(d4) + self.attention2(e2)
        d2 = self.decoder2(d3) + self.attention1(e1)
        d1 = self.decoder1(d2)

        # Prediction head
        out = self.finaldeconv1(d1)
        out = self.finalrelu1(out)
        out = self.finalconv2(out)
        out = self.finalrelu2(out)
        out = self.finalconv3(out)

        return out


In [18]:
if __name__ == '__main__':
    # Smoke test: push one random batch through MANet with each supported
    # backbone and print the resulting output shape.
    num_classes = 1
    in_batch, inchannel, in_h, in_w = (4, 3, 256, 256)
    x = torch.randn((in_batch, inchannel, in_h, in_w))

    # ResNeXt-50 backbone
    net_resnext50 = MANet(in_channel=3, num_classes=num_classes, backbone='resnext50')
    out_resnext50 = net_resnext50(x)
    print(f"ResNeXt-50 output shape: {out_resnext50.shape}")

    # ResNeXt-101 backbone
    net_resnext101 = MANet(in_channel=3, num_classes=num_classes, backbone='resnext101')
    out_resnext101 = net_resnext101(x)
    print(f"ResNeXt-101 output shape: {out_resnext101.shape}")

ResNeXt-50 output shape: torch.Size([4, 1, 256, 256])
