In [2]:
import torch
from torch import nn
from collections import OrderedDict
import torch.nn.functional as F

In [3]:
class ChannelAttention(nn.Module):
    def __init__(self,channel,reduction=16):
        super().__init__()
        self.maxpool=nn.AdaptiveMaxPool2d(1)
        self.avgpool=nn.AdaptiveAvgPool2d(1)
        self.se=nn.Sequential(
            nn.Conv2d(channel,channel//reduction,1,bias=False),
            nn.ReLU(),
            nn.Conv2d(channel//reduction,channel,1,bias=False)
        )
        self.sigmoid=nn.Sigmoid()
    
    def forward(self, x) :
        max_result=self.maxpool(x)
        avg_result=self.avgpool(x)
        max_out=self.se(max_result)
        avg_out=self.se(avg_result)
        output=self.sigmoid(max_out+avg_out)
        return output

class SpatialAttention(nn.Module):
    def __init__(self,kernel_size=3):
        super().__init__()
        self.conv=nn.Conv2d(2,1,kernel_size=kernel_size,padding=kernel_size//2)
        self.sigmoid=nn.Sigmoid()
    
    def forward(self, x) :
        max_result,_=torch.max(x,dim=1,keepdim=True)
        avg_result=torch.mean(x,dim=1,keepdim=True)
        result=torch.cat([max_result,avg_result],1)
        output=self.conv(result)
        output=self.sigmoid(output)
        return output

class CBAMBlock(nn.Module):

    def __init__(self, channel=512,reduction=16,kernel_size=3):
        super().__init__()

        self.ca=ChannelAttention(channel=channel,reduction=reduction)
        self.sa=SpatialAttention(kernel_size=kernel_size)


    def forward(self, x):
        b, c, _, _ = x.size()
        residual=x

        out_ca = x*self.ca(x)
        out_sa = out_ca*self.sa(out_ca)

        out = out_sa + residual

        return out

In [4]:
cbam = CBAMBlock(channel=64)
x  = torch.randn(2,64,7,7)
output = cbam(x)
print(output.shape)


torch.Size([2, 64, 7, 7])


In [22]:
class SKUnit(nn.Module):
    def __init__(self, CBAMBlock, kernels=[3,5,7,11], channels=None, L = 32, reduction=16):
        super().__init__()

        self.d=max(L,channels//reduction)
        
        self.path_0 = self._make_pathways(kernel_size=(kernels[0], kernels[0]),
                                        stride=2, channels=channels, CBAMBlock=CBAMBlock)
        self.path_1 = self._make_pathways(kernel_size=(kernels[1], kernels[1]),
                                        stride=2, channels=channels, CBAMBlock=CBAMBlock)
        self.path_2 = self._make_pathways(kernel_size=(kernels[2], kernels[2]),
                                        stride=2, channels=channels, CBAMBlock=CBAMBlock)
        self.path_3 = self._make_pathways(kernel_size=(kernels[3], kernels[3]),
                                        stride=2, channels=channels, CBAMBlock=CBAMBlock)
        
        self.up_0 = nn.ConvTranspose2d(in_channels=channels, out_channels=channels,kernel_size=(kernels[0], kernels[0]),
                                        stride=2)
        self.up_1 = nn.ConvTranspose2d(in_channels=channels, out_channels=channels,kernel_size=(kernels[1], kernels[1]),
                                        stride=2)
        self.up_2 = nn.ConvTranspose2d(in_channels=channels, out_channels=channels,kernel_size=(kernels[2], kernels[2]),
                                        stride=2)
        self.up_3 = nn.ConvTranspose2d(in_channels=channels, out_channels=channels,kernel_size=(kernels[3], kernels[3]),
                                        stride=2)

        self.fc=nn.Linear(channels,self.d)

        self.fcs=nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d,channels))

        self.softmax=nn.Softmax(dim=0)
        
    def forward(self, input):
        identity = input
        bs, c, _, _ = input.size()

        p_0 = self.path_0(input)
        p_1 = self.path_1(input)
        p_2 = self.path_2(input)
        p_3 = self.path_3(input)

        x_0 = self.up_0(p_0, output_size=identity.size())
        x_1 = self.up_1(p_1, output_size=identity.size())
        x_2 = self.up_2(p_2, output_size=identity.size())
        x_3 = self.up_3(p_3, output_size=identity.size())

        print(f"p_shape{p_0.shape}, x_0.shape{x_0.shape}")

        feats = torch.stack([x_0, x_1, x_2, x_3], dim=0)
        print(f"feats.shape{feats.shape}")
        select = torch.sum(feats, dim=0)
        print(f"select.shape{select.shape}")
        S=select.mean(-1).mean(-1) 
        print(f"S.shape:{S.shape}")
        Z=self.fc(S) 

        weights=[]
        for fc in self.fcs:
            weight=fc(Z)
            weights.append(weight.view(bs,c,1,1)) 
        attention_weughts=torch.stack(weights,0)
        attention_weughts=self.softmax(attention_weughts)

        V=(attention_weughts*feats).sum(0)
        return V

    def _make_pathways(self, kernel_size, stride, channels, CBAMBlock):

        return nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=stride),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            CBAMBlock(channels)
        )

In [23]:
sk = SKUnit(CBAMBlock, channels=64)
x  = torch.randn(2,64,112,112)
output = sk(x)
print(output.shape)

p_shapetorch.Size([2, 64, 55, 55]), x_0.shapetorch.Size([2, 64, 112, 112])
feats.shapetorch.Size([4, 2, 64, 112, 112])
select.shapetorch.Size([2, 64, 112, 112])
S.shape:torch.Size([2, 64])
torch.Size([2, 64, 112, 112])


In [17]:
class block(nn.Module):
    def __init__(self, SKUnit, CBAMBlock, in_channels, out_channels, identity_downsample=None, stride = 1):
        super(block, self).__init__()
        self.expansion = 2

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = SKUnit(CBAMBlock, channels=out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)

        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample

    def forward(self, x):
        identity = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        
        x = self.conv3(x)
        x = self.bn3(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)
        
        x += identity
        x = self.relu(x)

        return x

In [18]:
class ResNet(nn.Module):
    def __init__(self, block, SKUnit, CBAMBlock, layers, num_channels, num_classes):
        super(ResNet, self).__init__()

        self.in_channels = 64
        self.expansion = 2

        self.conv1 = nn.Conv2d(in_channels=num_channels, out_channels=self.in_channels,
                               kernel_size=3, stride=2, padding=0)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)
        self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=1)
        self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=1)
        self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=1)

        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(self.in_channels, num_classes)

    def _make_layer(self, block, no_residual_blocks, out_channels, stride):

        layers = []
        identity_downsample = None

        if stride != 1 or self.in_channels != out_channels*self.expansion:
            identity_downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels=out_channels*self.expansion, kernel_size=1, stride=stride),
                nn.BatchNorm2d(self.expansion*out_channels)
            )
        
        layers.append(block(SKUnit, CBAMBlock, self.in_channels, out_channels, identity_downsample, stride))
        self.in_channels = out_channels*self.expansion

        for i in range(no_residual_blocks-1):
            layers.append(block(SKUnit, CBAMBlock, self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)

        x = self.layer2(x)

        x = self.layer3(x)

        x = self.layer4(x)

        x = self.avg_pool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)

        return x

In [20]:
layers = [2,2,2,2]
model = ResNet(block,SKUnit, CBAMBlock, layers, num_channels=1, num_classes=2)