In [3]:
import torch
from torch import nn
from torchinfo import summary

In [40]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: channel-wise recalibration of a feature map.

    "Squeeze" global-average-pools every channel to a single value. "Excite"
    feeds that per-channel descriptor through a two-layer bottleneck MLP that
    ends in a sigmoid, giving one gate in (0, 1) per channel. The input is
    then rescaled channel-wise by those gates.

    Args:
        in_channels: number of channels C of the input feature map.
        r: reduction ratio of the bottleneck MLP. The hidden width is
            ``in_channels // r``, clamped to at least 1.

    Shape:
        input (N, C, H, W) -> output (N, C, H, W), same shape as the input.
    """

    def __init__(self, in_channels, r=16):
        super().__init__()

        # Clamp so that in_channels < r does not yield a zero-width hidden
        # layer; with 0 hidden units the gates would ignore the input and
        # depend only on the second Linear's bias.
        hidden_channels = max(1, in_channels // r)

        self.squeeze = nn.AdaptiveAvgPool2d((1, 1))
        self.se_layer = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, in_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch_size, num_channels = x.shape[0], x.shape[1]

        # (N, C, H, W) -> (N, C, 1, 1) -> (N, C)
        SE = self.squeeze(x).reshape(batch_size, num_channels)
        # Per-channel gates in (0, 1), shape (N, C).
        SE = self.se_layer(SE)
        # (N, C) -> (N, C, 1, 1) so the gates broadcast over H and W.
        SE = SE.reshape(batch_size, num_channels, 1, 1)

        return x * SE

# model = SEBlock(32)
# summary(model, (2, 32, 100, 100))

In [47]:
class BottleNeck(nn.Module):
    """Bottleneck residual block with a Squeeze-and-Excitation stage.

    Residual branch:
        1x1 conv (in_channels -> inner_channels) -> BN -> ReLU
        3x3 conv (stride ``stride``)             -> BN -> ReLU
        1x1 conv (inner_channels -> inner_channels * expansion) -> BN
    The SE block rescales the residual's channels. The result is added to the
    shortcut (identity or ``projection``), and a final ReLU is applied.

    Args:
        in_channels: channels of the input feature map.
        inner_channels: bottleneck width. The block outputs
            ``inner_channels * expansion`` channels.
        stride: stride of the 3x3 convolution.
        projection: optional module applied to the shortcut so that it matches
            the residual branch's shape. If None while the shapes cannot match
            (``stride != 1`` or ``in_channels != inner_channels * expansion``),
            a 1x1 strided conv + BatchNorm projection is created automatically.
        reduction: reduction ratio passed to the SE block (default 16).
    """
    expansion = 4

    def __init__(self, in_channels, inner_channels, stride=1, projection=None, reduction=16):
        super().__init__()

        out_channels = inner_channels * self.expansion

        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, inner_channels, 1, bias=False),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(),

            nn.Conv2d(inner_channels, inner_channels, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(),

            nn.Conv2d(inner_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        self.se_block = SEBlock(out_channels, r=reduction)

        if projection is None and (stride != 1 or in_channels != out_channels):
            # Without a projection, residual + shortcut would generally fail
            # with a shape mismatch at forward time.
            projection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        self.projection = projection
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = self.se_block(self.residual(x))
        shortcut = x if self.projection is None else self.projection(x)
        return self.relu(residual + shortcut)

# model = BottleNeck(256, 64)
# summary(model, (2, 256, 56, 56))

In [54]:
class SE_ResNet(nn.Module):
    """ResNet-style backbone with a global-average-pool + linear classifier head.

    Args:
        block: residual block class. It must expose an integer class attribute
            ``expansion`` and be constructible both as
            ``block(in_channels, inner_channels, stride, projection)`` and as
            ``block(in_channels, inner_channels)``.
        num_block_list: number of blocks in each of the four stages, indexed
            0..3 (e.g. ``[3, 4, 6, 3]`` for the 50-layer configuration).
        num_classes: output size of the final linear layer.

    Takes a 3-channel image batch (N, 3, H, W) and returns (N, num_classes).
    """

    def __init__(self, block, num_block_list, num_classes=1000):
        super().__init__()

        # Running channel count: make_layers reads it and advances it per stage.
        self.in_channels = 64

        # Stem: 7x7 stride-2 conv + BN + ReLU, then 3x3 stride-2 max pooling.
        self.conv_blk = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (bottleneck width, stride of the stage's first block) for layer1..layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_idx, (width, first_stride) in enumerate(stage_specs):
            stage = self.make_layers(block, num_block_list[stage_idx], width, stride=first_stride)
            setattr(self, f"layer{stage_idx + 1}", stage)

        self.GlobalAvgPool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def forward(self, x):
        out = self.maxpool(self.conv_blk(x))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        # (N, C, 1, 1) -> (N, C) before the classifier.
        out = self.GlobalAvgPool(out).flatten(1)
        return self.fc(out)

    def make_layers(self, block, num_of_block, inner_channels, stride):
        """Build one stage and advance ``self.in_channels`` to its output width.

        Only the first block of the stage gets ``stride`` and, when the
        shortcut shape would not match (stride != 1 or a channel change), a
        1x1 conv + BatchNorm projection. The remaining ``num_of_block - 1``
        blocks are built with their default stride and no projection.
        """
        out_channels = inner_channels * block.expansion

        projection = None
        if stride != 1 or self.in_channels != out_channels:
            projection = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        blocks = [block(self.in_channels, inner_channels, stride, projection)]
        self.in_channels = out_channels
        blocks.extend(block(self.in_channels, inner_channels) for _ in range(num_of_block - 1))

        return nn.Sequential(*blocks)


# SE-ResNet-50: stage depths 3, 4, 6, 3 built from SE bottleneck blocks.
model = SE_ResNet(
    BottleNeck,
    [3, 4, 6, 3],
    num_classes=1000,
)
# Layer-by-layer summary for a dummy batch of two 3x224x224 images.
summary(model, input_size=(2, 3, 224, 224))

Layer (type:depth-idx)                        Output Shape              Param #
SE_ResNet                                     [2, 1000]                 --
├─Sequential: 1-1                             [2, 64, 112, 112]         --
│    └─Conv2d: 2-1                            [2, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                       [2, 64, 112, 112]         128
│    └─ReLU: 2-3                              [2, 64, 112, 112]         --
├─MaxPool2d: 1-2                              [2, 64, 56, 56]           --
├─Sequential: 1-3                             [2, 256, 56, 56]          --
│    └─BottleNeck: 2-4                        [2, 256, 56, 56]          --
│    │    └─Sequential: 3-1                   [2, 256, 56, 56]          58,112
│    │    └─SEBlock: 3-2                      [2, 256, 56, 56]          8,464
│    │    └─Sequential: 3-3                   [2, 256, 56, 56]          16,896
│    │    └─ReLU: 3-4                         [2, 256, 56, 56]          --
│    