In [1]:
import torch
from torch import nn

In [2]:
class ChannelAttention(nn.Module):
    """
    Channel attention branch of CBAM (Convolutional Block Attention Module).

    Squeezes the spatial dimensions with global average- and max-pooling,
    feeds both descriptors through a shared bottleneck of two 1x1 convs
    (C -> hidden -> C, ReLU in between), sums the two results and applies a
    sigmoid.

    Args:
        in_channel (int): number of channels C of the input feature map.
        ratio (int): reduction ratio of the bottleneck. The hidden layer has
            ``max(1, in_channel // ratio)`` channels, so small ``in_channel``
            values never produce an empty (zero-channel) bottleneck.
    Returns:
        out (Tensor): channel attention weights of shape (N, C, 1, 1) with
            values in [0, 1], meant to be broadcast-multiplied with the input.
    """
    def __init__(self, in_channel, ratio=16):
        super(ChannelAttention, self).__init__()

        # Both pools reduce H and W to 1: (N, C, H, W) -> (N, C, 1, 1).
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Clamp so that e.g. in_channel=8 with the default ratio=16 still has
        # one hidden channel instead of a degenerate 0-channel conv.
        hidden = max(1, in_channel // ratio)
        # Shared MLP implemented with 1x1 convs; the same weights are applied
        # to the avg- and max-pooled descriptors.
        self.linear = nn.Sequential(nn.Conv2d(in_channel, hidden, 1, bias=False),
                                    nn.ReLU(),
                                    nn.Conv2d(hidden, in_channel, 1, bias=False)
        )
        self.act = nn.Sigmoid()

    def forward(self, x):
        avg = self.linear(self.avg_pool(x))
        maxs = self.linear(self.max_pool(x))
        out = avg + maxs
        return self.act(out)
        
class SpatialAttention(nn.Module):
    """
    Spatial attention branch of CBAM.

    Collapses the channel axis with a mean and a max, stacks the two maps
    into a 2-channel tensor, convolves it down to a single map and squashes
    the result with a sigmoid.

    Args:
        kernel_size (int): conv kernel size; padding is ``kernel_size // 2``,
            so odd sizes keep H and W unchanged.
    Returns:
        Tensor of shape (N, 1, H, W) for odd ``kernel_size`` with values in
        [0, 1]. Assumes a (N, C, H, W) input -- TODO confirm for other ranks.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        half_width = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=half_width)
        self.act = nn.Sigmoid()

    def forward(self, x):
        # Channel-wise descriptors, each (N, 1, H, W).
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((mean_map, max_map), dim=1)
        return self.act(self.conv(stacked))

In [5]:
class CBAM(nn.Module):
    """
    ResNet-style basic block (two 3x3 conv + BatchNorm) with CBAM attention.

    After the second conv/BN, the features are refined by channel attention
    and then spatial attention. The shortcut is then added and a final ReLU
    is applied.

    Args:
        in_channel (int): number of input channels.
        planes (int): number of output channels of both 3x3 convs.
        downsample (bool | nn.Module | None): projection for the shortcut.
            ``False``/``None`` means an identity shortcut, which requires
            ``in_channel == planes``. ``True`` builds a 1x1 conv + BatchNorm
            projecting ``in_channel -> planes``. Any ``nn.Module`` is applied
            to the input as given.
    Raises:
        ValueError: if ``in_channel != planes`` and no shortcut projection is
            provided, because the residual addition could not broadcast.
    """
    def __init__(self, in_channel,
                planes,
                downsample=False):
        super(CBAM, self).__init__()
        self.in_channel = in_channel
        self.planes = planes

        self.conv1 = nn.Conv2d(in_channel, planes, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.act = nn.ReLU()

        self.conv2 = nn.Conv2d(planes, planes, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)

        if downsample is True:
            # A bare True used to be stored and later *called* in forward
            # (TypeError). Build the matching channel projection instead.
            # All convs use stride 1, so only the channel count changes.
            downsample = nn.Sequential(
                nn.Conv2d(in_channel, planes, 1, bias=False),
                nn.BatchNorm2d(planes),
            )
        elif (downsample is None or downsample is False) and in_channel != planes:
            raise ValueError(
                f"in_channel ({in_channel}) != planes ({planes}): pass "
                f"downsample=True or a projection module so the residual "
                f"can be added"
            )

        # NOTE: the misspelled attribute name is kept on purpose. A module
        # assigned here appears in state_dict under "donwsample.*", so
        # renaming it would break existing checkpoints.
        self.donwsample = downsample

        self.channel = ChannelAttention(planes)
        self.spatial = SpatialAttention()

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Attention weights broadcast over the features: channel weights are
        # (N, C, 1, 1), spatial weights are (N, 1, H, W).
        out = self.channel(out) * out
        out = self.spatial(out) * out

        # Explicit None/False check rather than Module truthiness, which
        # depends on whether the module defines __len__.
        if self.donwsample is not None and self.donwsample is not False:
            residual = self.donwsample(residual)

        out = out + residual

        return self.act(out)

In [8]:
# Fixed seed so the exported weights and example input are reproducible.
torch.manual_seed(0)
test = CBAM(64, 64)
# Export in inference mode so BatchNorm uses running statistics. Recent
# torch.onnx.export versions deprecate the `training` argument and expect the
# caller to set the module's mode before exporting.
test.eval()
a = torch.randn(1, 64, 10, 10)
# Sanity check before export: the identity-shortcut block keeps (N, C, H, W).
with torch.no_grad():
    assert test(a).shape == a.shape
torch.onnx.export(test, a, 'mo.onnx')