In [1]:
import torch
import torch.nn as nn
from torchsummary import summary

In [28]:
class Conv(nn.Module):
    def __init__(self, 
                 input_dim: int, 
                 output_dim: int, 
                 kernel_size: int, 
                 stride: int, 
                 padding: int=0, 
                 activation: bool=True, **kwargs) -> None:
        super().__init__()
        
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, padding , bias=False, **kwargs)
        self.bnorm = nn.BatchNorm2d(output_dim)
        self.act = nn.ReLU() if activation else None

    def forward(self, x):
        x = self.conv(x)
        x = self.bnorm(x)
        if self.act:
            x = self.act(x)

        return x

In [3]:
summary(Conv(3, 64, 3, 2, 1, False), (3, 112,112), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 56, 56]           1,728
       BatchNorm2d-2           [-1, 64, 56, 56]             128
Total params: 1,856
Trainable params: 1,856
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.14
Forward/backward pass size (MB): 3.06
Params size (MB): 0.01
Estimated Total Size (MB): 3.21
----------------------------------------------------------------


In [47]:
class ResNetBottleNeck(nn.Module):
    expansion = 4
    def __init__(self, 
                 input_dim: int, 
                 output_dim: int, 
                 stride: int = 1,
                 downsample: nn.Module =None, **kwargs) -> None:
        super().__init__()
        
        self.downsample = downsample
        
        self.c1 = Conv(input_dim, input_dim, 1, 1, 0)
        self.c2 = Conv(input_dim, input_dim, 3, stride, 1)
        self.c3 = Conv(input_dim, output_dim * self.expansion, 1, 1, 0, False)

        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = x
        x = self.c1(x)
        x = self.c2(x)
        x = self.c3(x)

        if self.downsample:
            res = self.downsample(res)

        x = x + res
        x = self.act(x)
        
        return x

In [60]:
## downsample and stride = 2 at 3x3 layer from pytorch documentation
summary(ResNetBottleNeck(64, 64, stride=2, downsample=Conv(64, 256, 1, 2), expansion=4), (64, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           4,096
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
              Conv-4         [-1, 64, 224, 224]               0
            Conv2d-5         [-1, 64, 112, 112]          36,864
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
              Conv-8         [-1, 64, 112, 112]               0
            Conv2d-9        [-1, 256, 112, 112]          16,384
      BatchNorm2d-10        [-1, 256, 112, 112]             512
             Conv-11        [-1, 256, 112, 112]               0
           Conv2d-12        [-1, 256, 112, 112]          16,384
      BatchNorm2d-13        [-1, 256, 112, 112]             512
             ReLU-14        [-1, 256, 1

In [72]:
class ResNetLayer(nn.Module):
    def __init__(self, input_dim, num_blocks, downsample=None, **kwargs):
        super().__init__()
        
        blocks = []
        if downsample:
            blocks.append(ResNetBottleNeck(input_dim, input_dim, stride=2, downsample=downsample))
            num_blocks = num_blocks - 1
        for _ in range(num_blocks):
            blocks.append(ResNetBottleNeck(input_dim * 4, input_dim))

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

In [73]:
summary(ResNetLayer(64, 4, Conv(64, 256, 1, 2, activation=False)), (64, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           4,096
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
              Conv-4         [-1, 64, 224, 224]               0
            Conv2d-5         [-1, 64, 112, 112]          36,864
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
              Conv-8         [-1, 64, 112, 112]               0
            Conv2d-9        [-1, 256, 112, 112]          16,384
      BatchNorm2d-10        [-1, 256, 112, 112]             512
             Conv-11        [-1, 256, 112, 112]               0
           Conv2d-12        [-1, 256, 112, 112]          16,384
      BatchNorm2d-13        [-1, 256, 112, 112]             512
             Conv-14        [-1, 256, 1

In [47]:
class ResNet(nn.Module):
    def __init__(self, input_dim, layer_sizes=[64,128,256,512], num_blocks=[3,4,6,3], **kwargs):
        super().__init__()
        self.c1 = Conv(input_dim, layer_sizes[0], 7, 2, 3)
        self.pool = nn.MaxPool2d(3, 2, 1)

        # self.layers = nn.ModuleList([
        #     for
        # ])

        self.gap = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.c1(x)
        x = self.pool(x)

        # x = self.first_layer(x)
        # x = self.rest_layers(x)

        for block in self.first_layer:
            x = block(x)

        for block in self.rest_layers:
            x = block(x)

        x = self.gap(x)

        return x

In [48]:
summary(ResNet(3), (3, 224,224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
              Conv-4         [-1, 64, 112, 112]               0
         MaxPool2d-5           [-1, 64, 56, 56]               0
            Conv2d-6           [-1, 64, 56, 56]           4,096
       BatchNorm2d-7           [-1, 64, 56, 56]             128
              ReLU-8           [-1, 64, 56, 56]               0
              Conv-9           [-1, 64, 56, 56]               0
           Conv2d-10           [-1, 64, 56, 56]          36,864
      BatchNorm2d-11           [-1, 64, 56, 56]             128
             ReLU-12           [-1, 64, 56, 56]               0
             Conv-13           [-1, 64, 56, 56]               0
           Conv2d-14           [-1, 64,