In [21]:
from torchinfo import summary
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import ResNet, Bottleneck
import torchvision

class h_sigmoid(nn.Module):
    """Piecewise-linear ("hard") sigmoid: ``relu6(x + 3) / 6``.

    Inputs <= -3 map to 0, inputs >= 3 map to 1, and the function is linear
    in between.

    Args:
        inplace: Passed to the inner ``nn.ReLU6``. The ReLU6 acts on the
            temporary ``x + 3``, so the caller's tensor is never modified.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3
        return self.relu(shifted) / 6

class h_swish(nn.Module):
    """Hard swish activation: ``x * h_sigmoid(x)``.

    Args:
        inplace: Forwarded to the inner ``h_sigmoid`` (and from there to its
            ``nn.ReLU6``).
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        gate = self.sigmoid(x)
        return x * gate

class CoordAtt(nn.Module):
    """Coordinate Attention (Hou et al., CVPR 2021) that returns the attention map.

    The input is average-pooled separately along each spatial axis. The two 1-D
    descriptors are encoded jointly by a shared 1x1 conv + BatchNorm + h_swish.
    They are then split and projected to ``oup`` channels, and each projection
    goes through a sigmoid. This gives a per-row gate and a per-column gate,
    whose broadcast product is the attention map.

    Unlike the reference implementation, ``forward`` returns only the map (not
    ``x`` multiplied by it); the caller applies it.

    Args:
        inp: Number of input channels.
        oup: Number of channels of the returned attention map.
        reduction: Channel reduction ratio for the shared encoder; its width is
            ``max(8, inp // reduction)``.
    """

    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        # pool_h keeps the height axis: (n, c, h, w) -> (n, c, h, 1).
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        # pool_w keeps the width axis:  (n, c, h, w) -> (n, c, 1, w).
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Compute the coordinate-attention map for ``x``.

        Args:
            x: Tensor of shape (n, inp, h, w); ``h`` and ``w`` may differ.

        Returns:
            Tensor of shape (n, oup, h, w) with values in (0, 1).
        """
        _, _, h, w = x.size()
        x_h = self.pool_h(x)                      # (n, c, h, 1)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # (n, c, w, 1)

        # Stack both descriptors along dim 2 so they share one encoder. Only the
        # trailing dim (size 1) has to match, so non-square inputs are fine.
        y = torch.cat((x_h, x_w), dim=2)          # (n, c, h + w, 1)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        y_h, y_w = torch.split(y, [h, w], dim=2)
        y_w = y_w.permute(0, 1, 3, 2)             # (n, mip, 1, w)

        # Sigmoid bounds each gate to (0, 1) so the map rescales features
        # without flipping their sign.
        a_h = torch.sigmoid(self.conv_h(y_h))     # (n, oup, h, 1)
        a_w = torch.sigmoid(self.conv_w(y_w))     # (n, oup, 1, w)

        # Broadcasting the row and column gates yields the full (n, oup, h, w) map.
        return a_h * a_w

class BottleneckWithCA(Bottleneck):
    """torchvision ``Bottleneck`` with Coordinate Attention on the residual branch.

    The attention map is computed from the output of ``bn3`` and multiplied into
    that output before the skip connection is added. All constructor arguments
    are passed straight through to ``Bottleneck``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        branch_channels = self.conv3.out_channels
        self.ca = CoordAtt(branch_channels, branch_channels)

    def forward(self, x):
        # Residual branch: conv1/bn1/relu -> conv2/bn2/relu -> conv3/bn3.
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        # Reweight the branch with its coordinate-attention map.
        branch = self.ca(branch) * branch

        # Skip path: x itself, or x passed through `downsample` when one is set.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)

def resnet50_ca(pretrained=False, **kwargs):
    """Build a ResNet-50 with Coordinate Attention bottlenecks and a 2-class MLP head.

    Loads ``torchvision.models.ResNet50_Weights.DEFAULT`` non-strictly, then freezes
    every parameter that received a pretrained tensor. Only the Coordinate
    Attention modules and the new ``fc`` head stay trainable.

    Args:
        pretrained: Currently ignored; the pretrained weights are ALWAYS loaded.
            NOTE(review): honoring this flag would change what ``resnet50_ca()``
            returns today (its default is False), so behaviour is kept as-is.
        **kwargs: Currently ignored.

    Returns:
        The partially frozen ``ResNet`` model.

    Raises:
        RuntimeError: If the checkpoint contains keys the model cannot accept
            other than the replaced classifier's ``fc.weight``/``fc.bias``. That
            would mean the backbone layout no longer matches the checkpoint.
    """
    model = ResNet(block=BottleneckWithCA, layers=[3, 4, 6, 3])

    # Replace the default classifier with a small MLP ending in 2 logits. The
    # input width is read from the default head instead of hard-coding 2048.
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512),
        # NOTE(review): there is no activation between this Linear and the last
        # one, so (apart from Dropout) the two compose into a single affine map.
        nn.Dropout(),
        nn.Linear(in_features=512, out_features=2, bias=True),
    )

    # Load the checkpoint once and reuse it for both loading and freezing.
    # Calling get_state_dict() per parameter re-deserializes it every time.
    weights = torchvision.models.ResNet50_Weights.DEFAULT
    pretrained_state = weights.get_state_dict(progress=True)

    # strict=False is needed because the CA modules and the new head have no
    # pretrained counterparts. Anything unexpected beyond the replaced head is
    # still rejected, so a key-layout mismatch cannot silently leave the
    # backbone randomly initialized.
    result = model.load_state_dict(pretrained_state, strict=False)
    unexpected = sorted(set(result.unexpected_keys) - {"fc.weight", "fc.bias"})
    if unexpected:
        raise RuntimeError(
            f"Pretrained checkpoint has keys the model cannot load: {unexpected}"
        )

    # Freeze exactly the parameters that were initialized from the checkpoint.
    for name, param in model.named_parameters():
        if name in pretrained_state:
            param.requires_grad = False

    return model

In [None]:
# Build the model and show a per-layer summary for a batch of 32 three-channel
# 224x224 inputs, including kernel sizes, mult-adds and trainability.
model = resnet50_ca()
summary(
    model,
    input_size=(32, 3, 224, 224),
    depth=4,
    col_names=[
        "input_size",
        "output_size",
        "num_params",
        "kernel_size",
        "mult_adds",
        "trainable",
    ],
    row_settings=["var_names"],
    verbose=0,
)

Layer (type (var_name))                            Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds                 Trainable
ResNet (ResNet)                                    [32, 3, 224, 224]         [32, 2]                   --                        --                        --                        Partial
├─Conv2d (conv1)                                   [32, 3, 224, 224]         [32, 64, 112, 112]        (9,408)                   [7, 7]                    3,776,446,464             False
├─BatchNorm2d (bn1)                                [32, 64, 112, 112]        [32, 64, 112, 112]        (128)                     --                        4,096                     False
├─ReLU (relu)                                      [32, 64, 112, 112]        [32, 64, 112, 112]        --                        --                        --                        --
├─MaxPool2d (maxpool)                              [32, 64, 11

: 

In [6]:
# Parameter-count summary (no input_size given) of a freshly built model;
# note that resnet50_ca() loads the pretrained checkpoint again here.
summary(resnet50_ca(), depth=4)

Layer (type:depth-idx)                        Param #
ResNet                                        --
├─Conv2d: 1-1                                 9,408
├─BatchNorm2d: 1-2                            128
├─ReLU: 1-3                                   --
├─MaxPool2d: 1-4                              --
├─Sequential: 1-5                             --
│    └─BottleneckWithCA: 2-1                  --
│    │    └─Conv2d: 3-1                       4,096
│    │    └─BatchNorm2d: 3-2                  128
│    │    └─Conv2d: 3-3                       36,864
│    │    └─BatchNorm2d: 3-4                  128
│    │    └─Conv2d: 3-5                       16,384
│    │    └─BatchNorm2d: 3-6                  512
│    │    └─ReLU: 3-7                         --
│    │    └─Sequential: 3-8                   --
│    │    │    └─Conv2d: 4-1                  16,384
│    │    │    └─BatchNorm2d: 4-2             512
│    │    └─CoordAtt: 3-9                     --
│    │    │    └─AdaptiveAvgPool2d: 4-3  