In [1]:
import torch
import torch.nn as nn
from torchinfo import summary

In [2]:
class Residual_Block(nn.Module):
    def __init__(self, num_channels):
        super().__init__()

        self.conv_residual1 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, 
                                        kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)
        self.bn_residual1 = nn.BatchNorm2d(num_features=num_channels)
        self.conv_residual2 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, 
                                        kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)
        self.bn_residual2 = nn.BatchNorm2d(num_features=num_channels)
        self.relu = nn.ReLU()
        
    def forward(self, x):

        identity = x

        x = self.conv_residual1(x)
        x = self.bn_residual1(x)
        x = self.relu(x)

        x = self.conv_residual2(x)
        x = self.bn_residual2(x)
        x = self.relu(x)

        x += identity
        x = self.relu(x)

        return x

In [3]:
residual_block_test = Residual_Block(num_channels=64)
random_input = torch.randn(16, 64, 100, 100)
output = residual_block_test(random_input)
output.size()

torch.Size([16, 64, 100, 100])

In [4]:
summary(residual_block_test, input_size=(16, 64, 100, 100))

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                   Output Shape              Param #
Residual_Block                           [16, 64, 100, 100]        --
├─Conv2d: 1-1                            [16, 64, 100, 100]        36,864
├─BatchNorm2d: 1-2                       [16, 64, 100, 100]        128
├─ReLU: 1-3                              [16, 64, 100, 100]        --
├─Conv2d: 1-4                            [16, 64, 100, 100]        36,864
├─BatchNorm2d: 1-5                       [16, 64, 100, 100]        128
├─ReLU: 1-6                              [16, 64, 100, 100]        --
├─ReLU: 1-7                              [16, 64, 100, 100]        --
Total params: 73,984
Trainable params: 73,984
Non-trainable params: 0
Total mult-adds (G): 11.80
Input size (MB): 40.96
Forward/backward pass size (MB): 327.68
Params size (MB): 0.30
Estimated Total Size (MB): 368.94

In [5]:
class Residual_Block_Trans(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_identity_adapt = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
                                             kernel_size=(1,1), stride=(2,2), padding=(0,0), bias=False)
        self.bn_identity_adapt = nn.BatchNorm2d(num_features=out_channels)
        
        self.conv_transition1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
                                          kernel_size=(3,3), stride=(2,2), padding=(1,1), bias=False)
        self.bn_transition1 = nn.BatchNorm2d(num_features=out_channels)
        
        self.conv_transition2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, 
                                          kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)
        self.bn_transition2 = nn.BatchNorm2d(num_features=out_channels)
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        identity = self.bn_identity_adapt(self.conv_identity_adapt(x))
        x = self.relu(self.bn_transition1(self.conv_transition1(x)))
        x = self.relu(self.bn_transition2(self.conv_transition2(x)))
        x += identity
        x = self.relu(x)
        return x

In [6]:
block_transition_test = Residual_Block_Trans(in_channels=128, out_channels=256)
random_input = torch.randn(16, 128, 100, 100)
output = block_transition_test(random_input)
output.size()

torch.Size([16, 256, 50, 50])

In [7]:
summary(block_transition_test, input_size=(16, 128, 100, 100))

Layer (type:depth-idx)                   Output Shape              Param #
Residual_Block_Trans                     [16, 256, 50, 50]         --
├─Conv2d: 1-1                            [16, 256, 50, 50]         32,768
├─BatchNorm2d: 1-2                       [16, 256, 50, 50]         512
├─Conv2d: 1-3                            [16, 256, 50, 50]         294,912
├─BatchNorm2d: 1-4                       [16, 256, 50, 50]         512
├─ReLU: 1-5                              [16, 256, 50, 50]         --
├─Conv2d: 1-6                            [16, 256, 50, 50]         589,824
├─BatchNorm2d: 1-7                       [16, 256, 50, 50]         512
├─ReLU: 1-8                              [16, 256, 50, 50]         --
├─ReLU: 1-9                              [16, 256, 50, 50]         --
Total params: 919,040
Trainable params: 919,040
Non-trainable params: 0
Total mult-adds (G): 36.70
Input size (MB): 81.92
Forward/backward pass size (MB): 491.52
Params size (MB): 3.68
Estimated Total Size (M

In [8]:
class ResNet(nn.Module):
    def __init__(self, num_repeats):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, 
                               kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.relu = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2), padding=(1,1))
        
        self.residual_blocks_conv2_x = nn.ModuleList()
        for _ in range(num_repeats[0]):
            self.residual_blocks_conv2_x.append(Residual_Block(num_channels=64))
            
        self.residual_block_trans3 = Residual_Block_Trans(in_channels=64, 
                                                          out_channels=128)
        
        self.residual_blocks_conv3_x = nn.ModuleList()
        for _ in range(num_repeats[1]-1):
            self.residual_blocks_conv3_x.append(Residual_Block(num_channels=128))
        
        self.residual_block_trans4 = Residual_Block_Trans(in_channels=128, 
                                                          out_channels=256)
        
        self.residual_blocks_conv4_x = nn.ModuleList()
        for _ in range(num_repeats[2]-1):
            self.residual_blocks_conv4_x.append(Residual_Block(num_channels=256))
            
        self.residual_block_trans5 = Residual_Block_Trans(in_channels=256, 
                                                          out_channels=512)
        
        self.residual_blocks_conv5_x = nn.ModuleList()
        for _ in range(num_repeats[3]-1):
            self.residual_blocks_conv5_x.append(Residual_Block(num_channels=512))
            
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))
        self.fc = nn.Linear(in_features=512, out_features=1000)
        
    def forward(self, x):
        
        # conv1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        
        # conv2_x
        for layer in self.residual_blocks_conv2_x:
            x = layer(x)
            
        # conv3_x
        x = self.residual_block_trans3(x)
        for layer in self.residual_blocks_conv3_x:
            x = layer(x)
                                         
        # conv4_x
        x = self.residual_block_trans4(x)
        for layer in self.residual_blocks_conv4_x:
            x = layer(x)
            
        # conv5_x
        x = self.residual_block_trans5(x)
        for layer in self.residual_blocks_conv5_x:
            x = layer(x)
        
        # output
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
                                                
        return x

In [9]:
resnet_test = ResNet(num_repeats=[3,4,6,3])
random_input = torch.randn(16, 3, 224, 224)
output = resnet_test(random_input)
output.size()

torch.Size([16, 1000])

In [10]:
summary(resnet_test, input_size=(4, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [4, 1000]                 --
├─Conv2d: 1-1                            [4, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [4, 64, 112, 112]         128
├─ReLU: 1-3                              [4, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [4, 64, 56, 56]           --
├─ModuleList: 1-5                        --                        --
│    └─Residual_Block: 2-1               [4, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [4, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [4, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [4, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [4, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [4, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [4, 64, 56, 56]           --
│