In [1]:
import torch
import torch.nn as nn
from torchsummary import summary

In [2]:
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> optional ReLU.

    The convolution is built with ``bias=False`` because the BatchNorm2d that
    follows it has its own learnable shift.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero-padding added to each side (default 0).
        activation: if True, a ReLU follows the batch norm; if False the block
            is conv + BN only and ``self.act`` is ``None``.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (e.g. ``dilation``, ``groups``).
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 kernel_size: int,
                 stride: int,
                 padding: int = 0,
                 activation: bool = True, **kwargs) -> None:
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            **kwargs,
        )
        self.bnorm = nn.BatchNorm2d(output_dim)
        # Keep ``act`` as None (not nn.Identity) when disabled so the module
        # tree contains only conv + bnorm.
        if activation:
            self.act = nn.ReLU()
        else:
            self.act = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.bnorm(self.conv(x))
        if self.act is None:
            return out
        return self.act(out)

In [3]:
# Sanity check of Conv: 3x3 kernel, stride 2, padding 1, activation disabled
# (6th positional arg = False). A 3x112x112 input should come out as 64x56x56
# with no ReLU layer in the summary.
summary(Conv(3, 64, 3, 2, 1, False), (3, 112,112), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 56, 56]           1,728
       BatchNorm2d-2           [-1, 64, 56, 56]             128
Total params: 1,856
Trainable params: 1,856
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.14
Forward/backward pass size (MB): 3.06
Params size (MB): 0.01
Estimated Total Size (MB): 3.21
----------------------------------------------------------------


In [12]:
class ResNetBottleNeck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.

    ``output_dim`` is the width of the inner convolutions; the block emits
    ``output_dim * expansion`` channels. Any spatial stride is applied in the
    3x3 convolution (and mirrored by the projection shortcut).

    Args:
        input_dim: channels of the incoming tensor.
        output_dim: bottleneck width (channels of c1 and c2).
        stride: stride of the 3x3 convolution and of the projection shortcut.
        **kwargs: accepted for signature compatibility; not used.
    """

    # Channel multiplier applied by the final 1x1 convolution.
    expansion = 4

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 stride: int = 1,
                 **kwargs) -> None:
        super().__init__()

        expanded_dim = output_dim * self.expansion
        needs_projection = stride != 1 or input_dim != expanded_dim

        # Identity when the input already has the output shape; otherwise a
        # 1x1 conv + BN (no ReLU) projects it. Built before c1..c3 so that
        # parameter initialisation order is unchanged.
        if needs_projection:
            self.shortcut = Conv(input_dim, expanded_dim, 1, stride, 0, False)
        else:
            self.shortcut = nn.Identity()

        self.c1 = Conv(input_dim, output_dim, 1, 1, 0)            # 1x1 reduce
        self.c2 = Conv(output_dim, output_dim, 3, stride, 1)      # 3x3, carries the stride
        self.c3 = Conv(output_dim, expanded_dim, 1, 1, 0, False)  # 1x1 expand, no ReLU yet

        # ReLU applied after the residual addition.
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.c3(self.c2(self.c1(x)))
        out = out + self.shortcut(x)
        return self.act(out)

In [14]:
# Bottleneck with stride=2. As in torchvision's ResNet ("v1.5"), the stride is
# applied in the 3x3 conv (c2) rather than in the first 1x1 conv.
# Because 64 != 256 * 4, the shortcut is a 1x1 stride-2 projection (Conv + BN, no ReLU).
# Note: output_dim=256 is the bottleneck width, so the block outputs 1024 channels.
summary(ResNetBottleNeck(64, 256, stride=2), (64, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1        [-1, 256, 224, 224]          16,384
       BatchNorm2d-2        [-1, 256, 224, 224]             512
              ReLU-3        [-1, 256, 224, 224]               0
              Conv-4        [-1, 256, 224, 224]               0
            Conv2d-5        [-1, 256, 112, 112]         589,824
       BatchNorm2d-6        [-1, 256, 112, 112]             512
              ReLU-7        [-1, 256, 112, 112]               0
              Conv-8        [-1, 256, 112, 112]               0
            Conv2d-9       [-1, 1024, 112, 112]         262,144
      BatchNorm2d-10       [-1, 1024, 112, 112]           2,048
             Conv-11       [-1, 1024, 112, 112]               0
           Conv2d-12       [-1, 1024, 112, 112]          65,536
      BatchNorm2d-13       [-1, 1024, 112, 112]           2,048
             Conv-14       [-1, 1024, 1

In [15]:
class ResNetLayer(nn.Module):
    """One ResNet stage: a stack of ``num_blocks`` bottleneck blocks.

    Only the first block may change resolution (stride 2 when ``downsample``)
    and channel count. Every following block maps
    ``output_dim * ResNetBottleNeck.expansion`` channels to the same number.
    At least one block is always created.

    Args:
        input_dim: channels entering the stage.
        output_dim: bottleneck width of every block; the stage emits
            ``output_dim * ResNetBottleNeck.expansion`` channels.
        num_blocks: number of bottleneck blocks in the stage.
        downsample: if True, the first block uses stride 2.
    """

    def __init__(self, input_dim, output_dim, num_blocks, downsample=False):
        super().__init__()

        first_stride = 2 if downsample else 1
        stage_dim = output_dim * ResNetBottleNeck.expansion

        head = ResNetBottleNeck(input_dim, output_dim, stride=first_stride)
        tail = [ResNetBottleNeck(stage_dim, output_dim) for _ in range(num_blocks - 1)]
        self.blocks = nn.Sequential(head, *tail)

    def forward(self, x):
        return self.blocks(x)

In [16]:
# Non-downsampling stage: 3 bottleneck blocks of width 128 on a 256-channel input.
# The first block projects 256 -> 512 (= 128 * 4) through a 1x1 shortcut;
# the remaining blocks see 512 == 128 * 4 and use an identity shortcut.
summary(ResNetLayer(256, 128, 3, False), (256, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1        [-1, 128, 224, 224]          32,768
       BatchNorm2d-2        [-1, 128, 224, 224]             256
              ReLU-3        [-1, 128, 224, 224]               0
              Conv-4        [-1, 128, 224, 224]               0
            Conv2d-5        [-1, 128, 224, 224]         147,456
       BatchNorm2d-6        [-1, 128, 224, 224]             256
              ReLU-7        [-1, 128, 224, 224]               0
              Conv-8        [-1, 128, 224, 224]               0
            Conv2d-9        [-1, 512, 224, 224]          65,536
      BatchNorm2d-10        [-1, 512, 224, 224]           1,024
             Conv-11        [-1, 512, 224, 224]               0
           Conv2d-12        [-1, 512, 224, 224]         131,072
      BatchNorm2d-13        [-1, 512, 224, 224]           1,024
             Conv-14        [-1, 512, 2

In [17]:
class ResNet(nn.Module):
    """Bottleneck ResNet backbone that returns globally pooled features.

    Stem: 7x7 stride-2 Conv (+BN+ReLU) followed by a 3x3 stride-2 max-pool.
    Then one ``ResNetLayer`` per entry of ``layer_sizes``; every stage after
    the first downsamples by 2. The output is the global-average-pooled,
    flattened feature vector of size
    ``layer_sizes[-1] * ResNetBottleNeck.expansion``. There is no classifier head.

    With the defaults (widths 64/128/256/512, blocks 3/4/6/3) this is the
    ResNet-50 layout.

    Args:
        input_dim: channels of the input image (e.g. 3 for RGB).
        layer_sizes: bottleneck width of each stage.
        num_blocks: number of bottleneck blocks in each stage.

    Raises:
        ValueError: if ``layer_sizes`` is empty or its length differs from
            ``num_blocks``.
    """

    def __init__(self, input_dim, layer_sizes=(64, 128, 256, 512), num_blocks=(3, 4, 6, 3)):
        super().__init__()

        if len(layer_sizes) == 0 or len(layer_sizes) != len(num_blocks):
            raise ValueError(
                "layer_sizes and num_blocks must be non-empty and of equal length, "
                f"got {len(layer_sizes)} and {len(num_blocks)}"
            )

        self.c1 = Conv(input_dim, layer_sizes[0], 7, 2, 3)
        self.pool = nn.MaxPool2d(3, 2, 1)

        self.layers = nn.ModuleList()
        in_dim = layer_sizes[0]
        for i, (out_dim, n_blocks) in enumerate(zip(layer_sizes, num_blocks)):
            # Stage 0 keeps resolution (the stem already reduced it 4x);
            # later stages halve it in their first block.
            downsample = i != 0
            # Pass the stage width explicitly. Previously it was omitted, which
            # shifted num_blocks/downsample into the wrong parameters.
            self.layers.append(ResNetLayer(in_dim, out_dim, n_blocks, downsample))
            # Each stage emits out_dim * expansion channels.
            in_dim = out_dim * ResNetBottleNeck.expansion

        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c1(x)
        x = self.pool(x)
        for layer in self.layers:
            x = layer(x)
        x = self.gap(x)
        return self.flatten(x)


In [18]:
# Full backbone with the default widths (64, 128, 256, 512) and block counts
# (3, 4, 6, 3), i.e. the ResNet-50 layout, on a 3x224x224 input.
# The model ends in global average pooling + flatten (no classifier head).
summary(ResNet(3), (3, 224,224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
              Conv-4         [-1, 64, 112, 112]               0
         MaxPool2d-5           [-1, 64, 56, 56]               0
            Conv2d-6           [-1, 64, 56, 56]           4,096
       BatchNorm2d-7           [-1, 64, 56, 56]             128
              ReLU-8           [-1, 64, 56, 56]               0
              Conv-9           [-1, 64, 56, 56]               0
           Conv2d-10           [-1, 64, 56, 56]          36,864
      BatchNorm2d-11           [-1, 64, 56, 56]             128
             ReLU-12           [-1, 64, 56, 56]               0
             Conv-13           [-1, 64, 56, 56]               0
           Conv2d-14          [-1, 256,