# Convolutional Attention Block

Propuesta alternativa al CoordinateAttention.

Se utilizan kernels de distintos tamaños para obtener descriptores horizontales y verticales. El propósito es obtener descriptores que definan la información espacial de forma complementaria, para luego unificar dicha información mediante la adición (o concatenación) de ambos descriptores.

De todas formas, no se puede ignorar la información de los canales de cada input. Estos podrían afectar drásticamente a los descriptores espaciales. Por este motivo, se han aplicado DepthWise Separable convolutions para que cada descriptor no base su información en todos los canales, sino en un subconjunto de ellos. (Es necesario desarrollar esto).

In [1]:
import torch
from torch import nn

In [2]:
class ConvolutionalAttentionBlock(nn.Module):
    """Attention map built from two full-extent strip convolutions.

    ``conv_h`` uses a ``(1, W)`` kernel and ``conv_w`` uses an ``(H, 1)``
    kernel, both without padding. When the input feature map is exactly
    ``H x W`` the first collapses the width axis (one value per row) and the
    second collapses the height axis (one value per column). The two
    descriptors are summed, relying on broadcasting to get back to ``H x W``,
    then projected to ``in_channels`` by a 1x1 convolution and squashed with
    a sigmoid.

    Args:
        img_size: ``(H, W)`` of the feature map this block will receive.
        in_channels: Number of channels of the input feature map.
        reduction_rate: Divisor applied to ``in_channels`` to get the width
            of the descriptors (never fewer than 8 channels).
        groups: If truthy, the strip convolutions are grouped; otherwise
            they are dense (``groups=1``).
        bias: Whether the strip convolutions have a bias term.
    """

    def __init__(self, img_size: tuple, in_channels: int, reduction_rate: int, groups=True, bias=True) -> None:
        import math  # local import: only needed to pick a valid group count

        super().__init__()
        out_channels = max(8, in_channels // reduction_rate)
        H, W = img_size

        # nn.Conv2d requires BOTH channel counts to be divisible by `groups`.
        # gcd equals `out_channels` whenever `in_channels` is a multiple of it
        # (the only case that previously worked) and falls back to the largest
        # valid group count otherwise instead of raising.
        n_groups = math.gcd(in_channels, out_channels) if groups else 1

        # Height descriptor: the kernel spans the full width.
        self.conv_h = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (1, W), bias=bias, groups=n_groups),
            nn.BatchNorm2d(out_channels),
            nn.SiLU()
        )

        # Width descriptor: the kernel spans the full height.
        self.conv_w = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (H, 1), bias=bias, groups=n_groups),
            nn.BatchNorm2d(out_channels),
            nn.SiLU()
        )

        # Project the fused descriptor back to the input channel count and
        # map it to the (0, 1) range.
        self.att = nn.Sequential(
            nn.Conv2d(out_channels, in_channels, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the attention map computed from ``x``.

        The result is NOT multiplied with ``x`` here; the caller decides how
        to apply it.
        """
        x_h = self.conv_h(x)  # Height descriptor
        x_w = self.conv_w(x)  # Width descriptor

        # Coordinate attention: the sum broadcasts the two descriptors
        # against each other.
        # TODO: Concatenate x_h and x_w
        coordAtt = self.att(x_h + x_w)

        return coordAtt

class CoordAttConv(nn.Module):
    """Conv -> BatchNorm -> ReLU stem intended to be followed by attention.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding applied on each side.
        groups: Number of convolution groups.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int, groups: int, bias: bool) -> None:
        super().__init__()

        # Honour the constructor arguments; they used to be accepted but
        # silently replaced by a fixed 3x3 / stride 1 / padding 1 convolution.
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                bias=bias,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # TODO: attach an attention block on the `out_channels` feature map;
        # none is wired in yet.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolutional stem (no attention is applied yet)."""
        return self.conv(x)


# ResNet-18

In [3]:
import torchvision.models as models
from torchvision.models.resnet import conv3x3, conv1x1, Bottleneck

class ResNet(nn.Module):
    """Local ResNet backbone that forwards extra keyword arguments to its blocks.

    The network is a 7x7 stem convolution, max pooling, four residual stages
    (64, 128, 256 and 512 planes), global average pooling and a linear
    classifier.

    Args:
        block: Residual block class. It must expose an ``expansion``
            attribute and accept the positional arguments used in
            ``_make_layer``.
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the final linear layer.
        zero_init_residual: If true, zero the weight of the last norm layer
            of every ``Bottleneck`` / ``BasicBlock``.
        groups: Stored and passed to every block.
        width_per_group: Stored as ``base_width`` and passed to every block.
        replace_stride_with_dilation: Three flags (for stages 2-4) selecting
            dilation instead of stride; ``None`` means all ``False``.
        norm_layer: Normalisation layer factory, ``nn.BatchNorm2d`` by
            default.
        input_size: Accepted but currently unused by this class.
        **kargs: Extra keyword arguments forwarded to every block.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have exactly
            three elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, input_size=None, **kargs):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem: 7x7 stride-2 convolution followed by a stride-2 max pool.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; stages 2-4 halve the resolution unless the
        # matching dilation flag is set.
        self.layer1 = self._make_layer(block, 64, layers[0], **kargs)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0], **kargs)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1], **kargs)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2], **kargs)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, **kargs):
        """Build one residual stage made of ``blocks`` instances of ``block``.

        Only the first block may change resolution or channel count; when it
        does, a 1x1 convolution + norm shortcut is created for it. Extra
        ``kargs`` are forwarded to every block of the stage.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation so the resolution is preserved.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        # The first block used to be built without `kargs`, unlike the rest
        # of the stage; forward them here too so every block is configured
        # the same way.
        layers = [block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer, **kargs)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, **kargs))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run stem, the four stages, pooling and the classifier on ``x``."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        """Return the classifier output for the batch ``x``."""
        return self._forward_impl(x)



class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3) with an additive shortcut.

    Args:
        inplanes: Channels of the input feature map.
        planes: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input so that it matches
            the residual branch before the addition.
        groups: Must be 1.
        base_width: Must be 64.
        dilation: Must not exceed 1.
        norm_layer: Normalisation layer factory, ``nn.BatchNorm2d`` by
            default.
        **kargs: Accepted for interface compatibility and ignored.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, **kargs):
        super().__init__()

        # The debug prints of `kargs` / `inplanes` that ran on every
        # construction were removed.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when the residual branch changed shape.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


In [83]:
from torchvision.models.resnet import BasicBlock

class ResAttentionBlock(BasicBlock):
    """``BasicBlock`` whose residual branch is re-weighted by an attention map.

    Args:
        img_size: ``(H, W)`` of the feature map produced by the residual
            branch, i.e. the size after this block's stride is applied.
        inplanes: Channels of the input feature map.
        planes: Channels produced by the block.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the shortcut.
        att_reduction: Channel reduction rate of the attention block.
        att_groups: Whether the attention convolutions are grouped.
        att_bias: Whether the attention convolutions have a bias term.
        **kargs: Remaining ``BasicBlock`` options (``groups``,
            ``base_width``, ``dilation``, ``norm_layer``), forwarded to the
            base constructor.
    """

    def __init__(self, img_size: tuple, inplanes:int, planes:int, stride=1, downsample=None, 
                att_reduction=8, att_groups=True, att_bias=True, **kargs):

        # Forward the remaining options instead of dropping them: the layer
        # builder passes `norm_layer`, `groups`, `base_width` and `dilation`
        # by keyword and they used to be silently ignored.
        super().__init__(inplanes, planes, stride, downsample, **kargs)

        self.attention = ConvolutionalAttentionBlock(img_size, planes, att_reduction, 
                    groups=att_groups, bias=att_bias)

    def forward(self, x):
        """Return ``relu(attention(residual) * residual + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # The attention map is computed from the residual branch itself.
        att = self.attention(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        res = self.relu((att*out) + identity)
        return res

In [84]:
# att_block = ResAttentionBlock((128,128), 64, 64)
# # att_block(torch.rand(1,64,128,128))


tensor([[[[0.1452, 0.0000, 0.0000,  ..., 0.0000, 0.1423, 1.1244],
          [0.3389, 0.0000, 1.6299,  ..., 0.0966, 1.2071, 0.7163],
          [0.0000, 0.0000, 0.6837,  ..., 1.6847, 0.0000, 0.0856],
          ...,
          [0.3527, 1.2866, 0.0000,  ..., 0.3138, 0.0000, 0.4681],
          [0.1139, 0.0000, 0.0000,  ..., 0.9910, 0.4116, 1.4662],
          [1.3024, 1.5806, 1.0227,  ..., 1.6218, 2.4134, 1.0917]],

         [[0.0000, 0.0921, 0.2776,  ..., 0.6466, 1.1902, 0.3347],
          [0.9279, 0.0774, 0.1865,  ..., 0.0000, 0.7676, 0.0430],
          [0.0739, 0.0591, 0.6241,  ..., 1.4331, 0.9105, 0.5466],
          ...,
          [1.1355, 0.0000, 1.4560,  ..., 0.2645, 1.4166, 1.0860],
          [0.3188, 0.8262, 1.9807,  ..., 0.4303, 1.6467, 0.0555],
          [0.0929, 0.1579, 0.5086,  ..., 0.4295, 0.1069, 0.3421]],

         [[0.5223, 0.0000, 0.9467,  ..., 0.3550, 0.7645, 0.6179],
          [1.0765, 0.7996, 1.9936,  ..., 0.8632, 1.2247, 0.9293],
          [0.0000, 1.4586, 0.4162,  ..., 0

In [59]:
# Sanity check: run a 64 -> 64 channel BasicBlock (default stride, no
# downsample) on a random 1x64x128x128 tensor and display the output shape.
block = BasicBlock(64, 64)
block(torch.rand(1,64,128,128)).shape

torch.Size([1, 64, 128, 128])

In [78]:
from torchvision.models import ResNet
import numpy as np

class ResNet_Attention(ResNet):
    """ResNet whose four residual stages are rebuilt from attention blocks.

    The parent constructor builds the whole network with plain
    ``BasicBlock`` stages; those stages are then replaced by stages made of
    ``block``, which additionally receives the spatial size of the feature
    map it will process.

    Args:
        img_size: ``(H, W)`` of the network input.
        block: Block class taking ``img_size`` as its first argument and
            exposing an ``expansion`` attribute.
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the classifier.
        **kargs: Keyword arguments for the parent ``ResNet`` constructor.
    """

    def __init__(self, img_size:tuple, block:nn.Module, layers:list, num_classes=1000, **kargs):
        # `kargs` must be unpacked: passed positionally, the dict landed in
        # the parent's fourth parameter (`zero_init_residual`) and none of
        # the options were applied.
        super().__init__(BasicBlock, layers, num_classes, **kargs)
        # NOTE(review): any weight initialisation done by the parent
        # constructor happened on the stages that are replaced below, not on
        # the rebuilt ones; the classifier width also comes from
        # BasicBlock.expansion - confirm both are intended.

        # Plain Python ints keep the convolution kernel sizes free of numpy
        # scalar types.
        height, width = (int(side) for side in img_size)

        self.inplanes = 64 # Because in super init it has been set to 512

        # (attribute, planes, stride, log2 of the total downsampling at the
        # output of the stage: the stem and max pool account for 2**2).
        stages = (
            ("layer1", 64, 1, 2),
            ("layer2", 128, 2, 3),
            ("layer3", 256, 2, 4),
            ("layer4", 512, 2, 5),
        )
        for (attr, planes, stride, shift), n_blocks in zip(stages, layers):
            stage_size = (height // (2 ** shift), width // (2 ** shift))
            # `kargs` are not forwarded here: they are parent-constructor
            # options and `_make_attention_layer` does not use them.
            setattr(self, attr, self._make_attention_layer(
                stage_size, block, planes, n_blocks, stride=stride, dilate=False))

    def _make_attention_layer(self, input_size, block, planes, blocks, stride=1, dilate=False, **kargs):
        """Build one stage of ``blocks`` attention blocks.

        ``input_size`` is the spatial size handed to every block of the
        stage. Only the first block receives ``stride`` and the shortcut
        projection. Extra ``kargs`` are accepted and ignored.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation so the resolution is preserved.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = [block(input_size, self.inplanes, planes, stride, downsample,
                        groups=self.groups, base_width=self.base_width,
                        dilation=previous_dilation, norm_layer=norm_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(input_size, self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)


In [81]:
from torchvision.models.resnet import BasicBlock
# Smoke test: build the attention ResNet for 128x128 inputs with
# [1, 2, 2, 2] blocks per stage.
test = ResNet_Attention((128,128), ResAttentionBlock, [1,2,2,2], num_classes=1000)

# Forward one random RGB image and display the shape of the output.
result = test(torch.rand(1,3,128,128))
result.shape

64
64
128
256


torch.Size([1, 1000])

In [80]:
# Display the module tree of the attention ResNet built above.
test

ResNet_Attention(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): ResAttentionBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (attention): ConvolutionalAttentionBlock(
        (conv_h): Sequential(
          (0): Conv2d(64, 8, kernel_size=(1, 32), stride=(1, 1), groups=8)
          (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_runn

In [24]:
import numpy as np
# Scratch check of the stage-size arithmetic on a numpy array.
input_size = np.array((128, 128))
input_size

# Element-wise floor division by 2**5, the factor used for layer4.
input_size // 2**5

array([4, 4])

In [15]:
# Reference network: plain BasicBlock stages with two blocks each.
resnet_18 = ResNet(BasicBlock, [2, 2, 2, 2])

In [16]:

# Trace the feature-map shape through the reference network, stage by stage,
# for a random 1x3x128x128 input. The ReLU after bn1 is not applied in this
# trace.
a = torch.rand((1,3,128,128))
a = resnet_18.conv1(a)
a = resnet_18.bn1(a)
print(a.shape)
a = resnet_18.maxpool(a)
print(a.shape)
a = resnet_18.layer1(a)
# The "1" prefix marks the layer1 output in the printed log.
print("1", a.shape)
a = resnet_18.layer2(a)
print(a.shape)
a = resnet_18.layer3(a)
print(a.shape)
a = resnet_18.layer4(a)
print(a.shape)
a = resnet_18.avgpool(a)
print(a.shape)


torch.Size([1, 64, 64, 64])
torch.Size([1, 64, 32, 32])
1 torch.Size([1, 64, 32, 32])
torch.Size([1, 128, 16, 16])
torch.Size([1, 256, 8, 8])
torch.Size([1, 512, 4, 4])
torch.Size([1, 512, 1, 1])


In [17]:
# A 128-pixel side divided by 2**2, the factor used for layer1.
128 // (2**2)

32

In [None]:
# Display the module tree of the reference network.
resnet_18

In [None]:
# A 128-pixel side divided by 16 (2**4, the factor used for layer3).
128 // 16

In [None]:
# Display the module tree of the reference network again.
resnet_18

In [20]:
import numpy as np
# Check that the floor-divided numpy array converts to a plain tuple.
tuple(np.array((128,128)) // 2**2)

(32, 32)