In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Hard sigmoid: shift x by 3, clamp to [0, 6] with ReLU6, then divide by 6 so the result lies in [0, 1]
class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``, a piecewise-linear gate in [0, 1].

    Args:
        inplace: forwarded to ``nn.ReLU6``. The clamp is applied to the
            temporary ``x + 3``, so the caller's tensor is not overwritten.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        # Shift first so that the [0, 6] clamp window covers x in [-3, 3].
        shifted = x + 3
        clamped = self.relu(shifted)
        return clamped / 6

# Hard swish: x * h_sigmoid(x), i.e. x * 0 for x <= -3, x * 1 for x >= 3, and x * (x + 3) / 6 in between
class h_swish(nn.Module):
    """Hard swish: the input multiplied element-wise by its hard-sigmoid gate.

    Args:
        inplace: forwarded to the wrapped ``h_sigmoid`` module.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        # The gate lies in [0, 1], so the output is x scaled down (or kept).
        gate = self.sigmoid(x)
        return x * gate


In [43]:
# h_swish scales each element x by a gate in [0, 1]: strongly negative
# values collapse to 0, large positive values pass through unchanged.
# Note: `a` is the result of a multiplication, so it is a non-leaf tensor
# (its repr shows grad_fn=<MulBackward0>).
a = 5 * torch.randn(10, requires_grad=True)
print(a)
print('============================')
print(h_swish()(a))

tensor([-6.8560,  5.1842, -2.7119,  0.4446,  2.0827, 13.4886,  0.0136, -4.9741,
        -0.5176,  3.2783], grad_fn=<MulBackward0>)
tensor([-0.0000e+00,  5.1842e+00, -1.3021e-01,  2.5521e-01,  1.7643e+00,
         1.3489e+01,  6.8293e-03, -0.0000e+00, -2.1414e-01,  3.2783e+00],
       grad_fn=<MulBackward0>)


In [48]:
# the coordinate attention
class CoordAtt(nn.Module):
    """Coordinate attention block.

    The input is average-pooled separately along the width and along the
    height, the two pooled descriptors are concatenated and passed through a
    shared 1x1 conv + BatchNorm + h_swish, then split again and turned into
    two sigmoid attention maps that re-weight the input.

    Args:
        inp: number of channels of the input feature map.
        oup: number of channels of the two attention maps. They are
            multiplied with the input, so ``oup`` has to broadcast against
            ``inp`` (normally ``oup == inp``).
        reduction: channel reduction ratio; the bottleneck width is
            ``max(8, inp // reduction)``.
        verbose: when True (the default, which keeps the shape trace this
            notebook prints), ``forward`` prints the intermediate tensor
            sizes. Pass ``verbose=False`` to silence it.
    """

    def __init__(self, inp, oup, reduction=32, verbose=True):
        super(CoordAtt, self).__init__()
        self.verbose = verbose

        # (None, 1) keeps the height and averages over the width;
        # (1, None) keeps the width and averages over the height.
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        # Bottleneck width, never narrower than 8 channels.
        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def _trace(self, message):
        """Print a debugging message only when the module is verbose."""
        if self.verbose:
            print(message)

    def forward(self, x):
        """Re-weight ``x`` (N, C, H, W) with the two directional attention maps.

        Returns:
            ``x * a_w * a_h``, where ``a_h`` is (N, oup, H, 1) and ``a_w`` is
            (N, oup, 1, W); both broadcast over the missing spatial axis.
        """
        identity = x

        # Unpacking four values also rejects inputs that are not 4-D;
        # only the spatial sizes are needed afterwards.
        _, _, h, w = x.size()
        x_h = self.pool_h(x)                      # (N, C, H, 1)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # (N, C, 1, W) -> (N, C, W, 1)
        self._trace(f"horizontal: {x_h.size()}, vertical: {x_w.size()}")

        # Stack both descriptors along dim 2 so one conv/BN processes them jointly.
        y = torch.cat([x_h, x_w], dim=2)          # (N, C, H + W, 1)
        self._trace(y.size())
        y = self.conv1(y)
        self._trace(y.size())
        y = self.bn1(y)
        self._trace(y.size())
        y = self.act(y)
        self._trace(y.size())

        # Undo the concatenation and restore the (N, mip, 1, W) layout.
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        self._trace(f"horizontal: {x_h.size()}, vertical: {x_w.size()}")

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        self._trace(f"horizontal a_h : {a_h.size()}, vertical a_w: {a_w.size()}")

        out = identity * a_w * a_h

        return out

In [51]:
"""
    1. pool_h + pool_w pool the feature maps in two different ways
    2. integrate the horizontal and veritical semantic info
"""
s = torch.randn(1, 512, 1024, 1024)
coordatt = CoordAtt(inp=512, oup=512, reduction=32)
coordatt(s)

horizontal: torch.Size([1, 512, 1024, 1]), vertical: torch.Size([1, 512, 1024, 1])
torch.Size([1, 512, 2048, 1])
torch.Size([1, 16, 2048, 1])
torch.Size([1, 16, 2048, 1])
torch.Size([1, 16, 2048, 1])
horizontal: torch.Size([1, 16, 1024, 1]), vertical: torch.Size([1, 16, 1, 1024])
horizontal a_h : torch.Size([1, 512, 1024, 1]), vertical a_w: torch.Size([1, 512, 1, 1024])


tensor([[[[ 3.5031e-01,  5.0381e-01, -6.8972e-02,  ...,  2.4235e-01,
           -3.2165e-02, -2.3709e-01],
          [-4.9914e-01, -9.5855e-02,  3.5843e-01,  ...,  2.0877e-01,
            1.7718e-01,  2.3058e-01],
          [-2.8197e-01, -3.1710e-01,  1.4525e-01,  ...,  4.1460e-02,
            1.0780e+00, -7.5117e-02],
          ...,
          [ 1.7511e-02, -9.0027e-01,  3.4201e-02,  ...,  2.5167e-01,
            9.5652e-02, -6.0967e-02],
          [ 5.9386e-01, -3.6653e-01,  1.0572e-02,  ..., -7.2872e-01,
           -5.0225e-01,  3.0172e-01],
          [-7.2008e-01, -3.5814e-01,  1.8359e-01,  ..., -4.8479e-01,
           -1.3253e-02, -2.6786e-01]],

         [[-2.8347e-01,  2.3336e-01, -2.2689e-01,  ..., -1.2091e-01,
            1.8717e-01,  1.0198e-01],
          [-1.4968e-01,  4.5618e-02,  7.2103e-01,  ...,  1.2494e-01,
           -9.4589e-02, -2.3135e-02],
          [-1.7523e-01, -2.0264e-01, -1.4485e-01,  ..., -1.8179e-02,
            2.9364e-01, -2.8330e-01],
          ...,
     

In [52]:
# Element-wise multiplication with broadcasting: a (1, 128, 1, 1024) and a
# (1, 128, 1024, 1) tensor are expanded over their size-1 axes, so the
# product keeps the full (1, 128, 1024, 1024) shape.
tmp1 = torch.randn((1, 128, 1, 1024))
tmp2 = torch.randn((1, 128, 1024, 1))
print((torch.randn((1, 128, 1024, 1024)) * tmp1 * tmp2).shape)

torch.Size([1, 128, 1024, 1024])


In [53]:
# Run a freshly constructed single-channel BatchNorm2d on a random 5x5 map
# and compare the printed input with the displayed output.
tmp3 = torch.randn((1, 1, 5, 5))
bn = nn.BatchNorm2d(num_features=1)
print(tmp3)
bn(tmp3)

tensor([[[[ 0.2637, -0.1923, -1.0151,  0.1154,  0.7724],
          [-1.7421, -0.2755, -0.1362, -1.1723,  0.5523],
          [ 0.8475,  0.5016, -0.2030,  0.0673,  0.9439],
          [-1.5750,  1.3467, -0.4171,  0.7304,  0.0941],
          [-1.3239,  0.1416,  0.2633, -0.5850,  0.5797]]]])


tensor([[[[ 0.4016, -0.1699, -1.2012,  0.2157,  1.0392],
          [-2.1125, -0.2742, -0.0996, -1.3983,  0.7633],
          [ 1.1333,  0.6998, -0.1834,  0.1554,  1.2542],
          [-1.9030,  1.7590, -0.4518,  0.9866,  0.1890],
          [-1.5884,  0.2485,  0.4011, -0.6622,  0.7977]]]],
       grad_fn=<NativeBatchNormBackward0>)