In [1]:

import torch
import math
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Seed the global RNG so this random batch (and any module weights initialised
# afterwards, when cells are run in order) is identical on every fresh run.
SEED = 0
torch.manual_seed(SEED)

# Dummy batch in (batch, channels, length) layout for the Conv1d-based modules below.
input_tensor = torch.rand((128, 4, 1000))
input_tensor.shape

torch.Size([128, 4, 1000])

In [8]:
class BasicConv(nn.Module):
    """1D convolution optionally followed by batch normalisation and ReLU.

    Args:
        in_planes: number of input channels of the convolution.
        out_planes: number of output channels (also exposed as ``out_channels``).
        kernel_size, stride, padding, dilation, groups, bias: passed through
            unchanged to ``nn.Conv1d``.
        relu: if true, a ReLU is applied as the last step.
        bn: if true, ``nn.BatchNorm1d`` is applied right after the convolution.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv1d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Disabled stages are stored as None so forward() can skip them.
        if bn:
            self.bn = nn.BatchNorm1d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        """Apply conv -> (bn) -> (relu) and return the result."""
        out = self.conv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out

In [9]:
# BasicConv with 4 input channels, 64 output channels and kernel size 7
# (defaults: stride 1, no padding, batch norm and ReLU enabled); show its repr.
m = BasicConv(4, 64, 7)
m

BasicConv(
  (conv): Conv1d(4, 64, kernel_size=(7,), stride=(1,), bias=False)
  (bn): BatchNorm1d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  (relu): ReLU()
)

In [10]:
# Forward the random batch; kernel size 7 without padding shortens the
# length axis by 6 (1000 -> 994). `m` must still be the BasicConv from above.
output_BasicConv = m(input_tensor)
output_BasicConv.shape

torch.Size([128, 64, 994])

In [20]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)

In [21]:
# NOTE: `m` is rebound here (it held the BasicConv before), so cells that
# call `m` only give the shown results when run in notebook order.
m = Flatten()
m

Flatten()

In [22]:
# Flatten the conv output: 64 channels * 994 positions = 63616 features per sample.
output_Flatten = m(output_BasicConv)
output_Flatten.shape

torch.Size([128, 63616])

In [25]:
class ChannelGate(nn.Module):
    """Channel attention gate for inputs shaped (batch, channels, length).

    For every configured pooling type the input is pooled over its whole
    length axis, the pooled vector is passed through a shared two-layer MLP,
    and the MLP outputs are summed. A sigmoid of that sum gives one scale
    per (sample, channel), which is multiplied onto the input.

    Args:
        gate_channels: number of channels of the input.
        reduction_ratio: the MLP hidden width is
            ``gate_channels // reduction_ratio``.
        pool_types: iterable of pooling names; each must be ``'avg'`` or
            ``'max'`` and at least one is required.

    Raises:
        ValueError: if ``pool_types`` is empty or contains an unsupported name.
    """

    _SUPPORTED_POOLS = ('avg', 'max')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        # Zero-argument super(): this notebook later rebinds the name
        # `ChannelGate` to a different class, which would break
        # super(ChannelGate, self) for instances created after that point.
        super().__init__()

        # Copy so instances never share (or mutate) the default / caller's sequence.
        pool_types = list(pool_types)
        if not pool_types:
            raise ValueError(
                "pool_types must contain at least one of %s" % (self._SUPPORTED_POOLS,)
            )
        unknown = [p for p in pool_types if p not in self._SUPPORTED_POOLS]
        if unknown:
            raise ValueError(
                "unsupported pool type(s) %s; expected a subset of %s"
                % (unknown, self._SUPPORTED_POOLS)
            )

        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types

    def forward(self, x):
        """Return ``x`` rescaled per channel; the shape is unchanged."""
        length = x.size(2)
        channel_att_sum = None
        for pool_type in self.pool_types:
            # Kernel == stride == full length -> one pooled value per channel.
            if pool_type == 'avg':
                pooled = F.avg_pool1d(x, length, stride=length)
            elif pool_type == 'max':
                pooled = F.max_pool1d(x, length, stride=length)
            else:
                # Only reachable if pool_types was modified after construction.
                raise ValueError("unsupported pool type %r" % (pool_type,))
            channel_att_raw = self.mlp(pooled)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        # (batch, channels) -> (batch, channels, 1) -> broadcast over length.
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).expand_as(x)
        return x * scale


In [26]:
# Channel gate for 64 channels; with the default reduction_ratio=16 the MLP
# bottleneck has 64 // 16 = 4 units. (`m` is rebound again.)
m = ChannelGate(64)
m

ChannelGate(
  (mlp): Sequential(
    (0): Flatten()
    (1): Linear(in_features=64, out_features=4, bias=True)
    (2): ReLU()
    (3): Linear(in_features=4, out_features=64, bias=True)
  )
)

In [27]:
# The gate only rescales values, so the output shape equals the input shape.
output_ChannelGate = m(output_BasicConv)
output_ChannelGate.shape

torch.Size([128, 64, 994])

In [31]:
# Side-by-side peek at sample 0, first 5 channels, first 10 positions:
# conv output vs. channel-gated output.
output_BasicConv[0, :5, :10], output_ChannelGate[0, :5, :10]

(tensor([[0.0000, 0.0000, 0.0000, 0.0546, 0.3367, 0.0000, 0.0000, 0.4305, 0.0049,
          0.0992],
         [0.0000, 0.0000, 0.1492, 0.0000, 0.0000, 0.8467, 0.0000, 0.0000, 0.5111,
          0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.1184, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.3481, 0.1092, 0.0000, 0.0210, 0.1440, 0.0000, 0.2298, 0.0000,
          0.0000],
         [0.0000, 0.0000, 0.0519, 0.6489, 0.0000, 0.1636, 0.3114, 0.0579, 0.0000,
          0.2451]], grad_fn=<SliceBackward>),
 tensor([[0.0000, 0.0000, 0.0000, 0.0386, 0.2379, 0.0000, 0.0000, 0.3042, 0.0034,
          0.0701],
         [0.0000, 0.0000, 0.0276, 0.0000, 0.0000, 0.1565, 0.0000, 0.0000, 0.0945,
          0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.1274, 0.0400, 0.0000, 0.0077, 0.0527, 0.0000, 0.0841, 0.0000,
          0.0000],
         [0.0000, 0.0000, 0.0290, 0.3619, 0.0000, 0.0912, 0.1737

In [32]:
def logsumexp_2d(tensor):
    """Log-sum-exp over everything after the first two dimensions.

    Args:
        tensor: tensor with at least two dimensions; all dimensions after
            the second are flattened and reduced together.

    Returns:
        Tensor of shape ``(tensor.size(0), tensor.size(1), 1)``.

    ``torch.logsumexp`` is used instead of a manual max-shift so that slices
    made up entirely of ``-inf`` give ``-inf`` rather than NaN, and
    ``reshape`` (instead of ``view``) also accepts non-contiguous input.
    """
    flat = tensor.reshape(tensor.size(0), tensor.size(1), -1)
    return torch.logsumexp(flat, dim=2, keepdim=True)

In [33]:
class ChannelPool(nn.Module):
    """Reduce the channel axis (dim 1) to two channels: max first, then mean."""

    def forward(self, x):
        peak, _ = x.max(dim=1, keepdim=True)
        average = x.mean(dim=1, keepdim=True)
        return torch.cat([peak, average], dim=1)

In [34]:
# ChannelPool has no parameters; `m` is rebound once more.
m = ChannelPool()
m

ChannelPool()

In [36]:
# The 64 channels collapse into 2 (channel-wise max and mean); length is unchanged.
output_ChannelPool = m(output_BasicConv)
output_ChannelPool.shape

torch.Size([128, 2, 994])

In [39]:
class SpatialGate(nn.Module):
    """Position-wise attention gate for inputs shaped (batch, channels, length).

    The channels are compressed to two (max and mean) by ``ChannelPool``, a
    single ``BasicConv`` (2 -> 1 channel, batch norm, no ReLU) turns them
    into one score per position, and the sigmoid of that score is multiplied
    onto every channel of the input.

    Args:
        kernel_size: width of the convolution over the compressed channels.
            Must be a positive odd number so that the ``(kernel_size - 1) // 2``
            padding keeps the length unchanged. Defaults to 7.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        # Zero-argument super(): this notebook later rebinds the name
        # `SpatialGate` to a different class, which would break
        # super(SpatialGate, self) for instances created after that point.
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size must be a positive odd integer, got %r" % (kernel_size,)
            )
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)

    def forward(self, x):
        """Return ``x`` rescaled per position; the shape is unchanged."""
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        # (batch, 1, length) scale broadcasts across the channel axis.
        scale = torch.sigmoid(x_out)
        return x * scale

In [40]:
# Spatial gate: 2 -> 1 channel conv with kernel 7 and padding 3, batch norm, no ReLU.
m = SpatialGate()
m

SpatialGate(
  (compress): ChannelPool()
  (spatial): BasicConv(
    (conv): Conv1d(2, 1, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (bn): BatchNorm1d(1, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
)

In [41]:
# The per-position scale is broadcast over channels, so the shape is unchanged.
output_SpatialGate = m(output_BasicConv)
output_SpatialGate.shape

torch.Size([128, 64, 994])

In [45]:
# Side-by-side peek at sample 0, first 5 channels, first 10 positions:
# conv output vs. spatially gated output.
output_BasicConv[0, :5, :10], output_SpatialGate[0, :5, :10]

(tensor([[0.0000, 0.0000, 0.0000, 0.0546, 0.3367, 0.0000, 0.0000, 0.4305, 0.0049,
          0.0992],
         [0.0000, 0.0000, 0.1492, 0.0000, 0.0000, 0.8467, 0.0000, 0.0000, 0.5111,
          0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.1184, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.3481, 0.1092, 0.0000, 0.0210, 0.1440, 0.0000, 0.2298, 0.0000,
          0.0000],
         [0.0000, 0.0000, 0.0519, 0.6489, 0.0000, 0.1636, 0.3114, 0.0579, 0.0000,
          0.2451]], grad_fn=<SliceBackward>),
 tensor([[0.0000, 0.0000, 0.0000, 0.0165, 0.0856, 0.0000, 0.0000, 0.3674, 0.0018,
          0.0254],
         [0.0000, 0.0000, 0.1324, 0.0000, 0.0000, 0.3282, 0.0000, 0.0000, 0.1894,
          0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0301, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.3399, 0.0969, 0.0000, 0.0053, 0.0558, 0.0000, 0.1961, 0.0000,
          0.0000],
         [0.0000, 0.0000, 0.0461, 0.1958, 0.0000, 0.0634, 0.2303

In [42]:
class CBAM(nn.Module):
    """Channel gate followed by an optional spatial gate.

    Args:
        gate_channels: number of input channels, forwarded to ``ChannelGate``.
        reduction_ratio: forwarded to ``ChannelGate``.
        pool_types: pooling names forwarded to ``ChannelGate``.
        no_spatial: if true, no ``SpatialGate`` is created and only the
            channel gate is applied.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max'), no_spatial=False):
        super().__init__()
        # Hand the gate its own list so the default (or the caller's sequence)
        # is never shared between instances.
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, list(pool_types))
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        """Apply the channel gate, then the spatial gate unless disabled."""
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

In [43]:
# Full CBAM block for 64 channels with the default settings (both gates enabled).
m = CBAM(64)
m

CBAM(
  (ChannelGate): ChannelGate(
    (mlp): Sequential(
      (0): Flatten()
      (1): Linear(in_features=64, out_features=4, bias=True)
      (2): ReLU()
      (3): Linear(in_features=4, out_features=64, bias=True)
    )
  )
  (SpatialGate): SpatialGate(
    (compress): ChannelPool()
    (spatial): BasicConv(
      (conv): Conv1d(2, 1, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
      (bn): BatchNorm1d(1, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    )
  )
)

In [44]:
# Both gates only rescale values, so the shape matches the input.
output_CBAM = m(output_BasicConv)
output_CBAM.shape

torch.Size([128, 64, 994])

In [46]:
# Side-by-side peek at sample 0, first 5 channels, first 10 positions:
# conv output vs. CBAM output.
output_BasicConv[0, :5, :10], output_CBAM[0, :5, :10]

(tensor([[0.0000, 0.0000, 0.0000, 0.0546, 0.3367, 0.0000, 0.0000, 0.4305, 0.0049,
          0.0992],
         [0.0000, 0.0000, 0.1492, 0.0000, 0.0000, 0.8467, 0.0000, 0.0000, 0.5111,
          0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.1184, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.3481, 0.1092, 0.0000, 0.0210, 0.1440, 0.0000, 0.2298, 0.0000,
          0.0000],
         [0.0000, 0.0000, 0.0519, 0.6489, 0.0000, 0.1636, 0.3114, 0.0579, 0.0000,
          0.2451]], grad_fn=<SliceBackward>),
 tensor([[0.0000, 0.0000, 0.0000, 0.0143, 0.0853, 0.0000, 0.0000, 0.0977, 0.0011,
          0.0257],
         [0.0000, 0.0000, 0.0277, 0.0000, 0.0000, 0.2059, 0.0000, 0.0000, 0.1083,
          0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0302, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.0851, 0.0206, 0.0000, 0.0049, 0.0357, 0.0000, 0.0476, 0.0000,
          0.0000],
         [0.0000, 0.0000, 0.0124, 0.1961, 0.0000, 0.0513, 0.0948

In [3]:
class Flatten(nn.Module):
    """Reshape ``(N, ...)`` input to ``(N, -1)``, keeping the first dimension."""

    def forward(self, x):
        leading = x.shape[0]
        flat = x.view(leading, -1)
        return flat

In [None]:
class ChannelGate(nn.Module):
    """Channel attention branch used by ``BAM``.

    The input is averaged over every dimension after the channel axis, the
    resulting (batch, channels) matrix goes through an MLP (Linear ->
    BatchNorm1d -> ReLU, repeated ``num_layers`` times, then a final Linear),
    and the result is expanded back to the shape of the input. No sigmoid is
    applied here.

    Args:
        gate_channel: number of channels of the input.
        reduction_ratio: hidden width is ``gate_channel // reduction_ratio``.
        num_layers: number of hidden Linear/BatchNorm/ReLU groups.
    """

    def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
        super().__init__()
        # The original stored `gate_activation` here, a name that is defined
        # nowhere and made the constructor raise NameError; it was unused.
        self.gate_c = nn.Sequential()
        self.gate_c.add_module( 'flatten', Flatten() )
        gate_channels = [gate_channel]
        gate_channels += [gate_channel // reduction_ratio] * num_layers
        gate_channels += [gate_channel]
        for i in range( len(gate_channels) - 2 ):
            self.gate_c.add_module( 'gate_c_fc_%d'%i, nn.Linear(gate_channels[i], gate_channels[i+1]) )
            self.gate_c.add_module( 'gate_c_bn_%d'%(i+1), nn.BatchNorm1d(gate_channels[i+1]) )
            self.gate_c.add_module( 'gate_c_relu_%d'%(i+1), nn.ReLU() )
        self.gate_c.add_module( 'gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]) )

    def forward(self, in_tensor):
        """Return per-channel attention logits expanded to ``in_tensor``'s shape."""
        # Global average over all trailing dims. This works for the
        # (batch, channels, length) tensors produced by the Conv1d-based
        # modules in this notebook as well as for (batch, channels, H, W).
        avg_pool = in_tensor.flatten(2).mean(dim=2)
        gate = self.gate_c( avg_pool )
        # Append one singleton axis per trailing input dim so expand_as can broadcast.
        trailing = [1] * (in_tensor.dim() - 2)
        gate = gate.reshape(gate.size(0), gate.size(1), *trailing)
        return gate.expand_as(in_tensor)

In [4]:
class SpatialGate(nn.Module):
    """Spatial attention branch built from 1D convolutions.

    A 1x1 convolution reduces the channels to ``gate_channel // reduction_ratio``,
    ``dilation_conv_num`` dilated kernel-3 convolutions follow (each with
    batch norm and ReLU), and a final 1x1 convolution produces one channel,
    which is expanded to the shape of the input.

    Args:
        gate_channel: number of channels of the input.
        reduction_ratio: divisor giving the reduced channel count.
        dilation_conv_num: number of dilated convolution groups.
        dilation_val: dilation (and padding) of the dilated convolutions.
    """

    def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
        super().__init__()
        reduced = gate_channel // reduction_ratio
        layers = nn.Sequential()

        # Channel reduction.
        layers.add_module('gate_s_conv_reduce0', nn.Conv1d(gate_channel, reduced, kernel_size=1))
        layers.add_module('gate_s_bn_reduce0', nn.BatchNorm1d(reduced))
        layers.add_module('gate_s_relu_reduce0', nn.ReLU())

        # Dilated stack; padding equals the dilation so kernel 3 keeps the length.
        for idx in range(dilation_conv_num):
            dilated = nn.Conv1d(reduced, reduced, kernel_size=3, padding=dilation_val, dilation=dilation_val)
            layers.add_module('gate_s_conv_di_%d' % idx, dilated)
            layers.add_module('gate_s_bn_di_%d' % idx, nn.BatchNorm1d(reduced))
            layers.add_module('gate_s_relu_di_%d' % idx, nn.ReLU())

        # Collapse to a single attention channel.
        layers.add_module('gate_s_conv_final', nn.Conv1d(reduced, 1, kernel_size=1))
        self.gate_s = layers

    def forward(self, in_tensor):
        """Return the single-channel attention map expanded to ``in_tensor``'s shape."""
        attention = self.gate_s(in_tensor)
        return attention.expand_as(in_tensor)

In [None]:
class BAM(nn.Module):
    """Combine a channel branch and a spatial branch into one attention map.

    The outputs of the two branches are multiplied element-wise, squashed
    with a sigmoid and offset by 1, and the result multiplies the input.

    NOTE: this uses the ``ChannelGate`` / ``SpatialGate`` classes that take
    ``gate_channel`` as their first argument (the most recent definitions in
    this notebook), not the earlier classes of the same names used by CBAM.

    Args:
        gate_channel: number of channels of the input, passed to both branches.
    """

    def __init__(self, gate_channel):
        super().__init__()
        self.channel_att = ChannelGate(gate_channel)
        self.spatial_att = SpatialGate(gate_channel)

    def forward(self, in_tensor):
        """Return ``in_tensor`` scaled by ``1 + sigmoid(channel * spatial)``."""
        # torch.sigmoid replaces the deprecated F.sigmoid (same values, no warning).
        att = 1 + torch.sigmoid( self.channel_att(in_tensor) * self.spatial_att(in_tensor) )
        return att * in_tensor