In [1]:
from torch import nn
from torch.nn import functional as F
import numpy as np

import torch


In [2]:
# Residual_block is defined as its own nn.Module subclass so that forward() can be customised to add the skip connection.

class Residual_block(nn.Module):
    """Residual block: a stack of Conv2d+BatchNorm layers plus a 1x1-conv shortcut.

    The main path applies one convolution per entry of ``kernel_sizes``; every
    convolution except the last is followed by BatchNorm and ReLU, the last one
    by BatchNorm only.  The input is projected by a bias-free 1x1 convolution
    (the shortcut), added to the main-path output and passed through a final ReLU.

    Args:
        channels: channel counts along the main path, ``[in, mid..., out]``;
            expected to hold one more entry than ``kernel_sizes``.
        kernel_sizes: kernel size of each convolution on the main path.
        stride: spatial stride of the block when ``first_residual_block`` is False.
        padding: padding for convolutions whose kernel size is not 1
            (1x1 convolutions always use padding 0).
        bias: whether the main-path convolutions carry a bias term.
        first_residual_block: if True the block down-samples with stride 2,
            regardless of ``stride``.
    """

    def __init__(self, channels, kernel_sizes, stride, padding, bias, first_residual_block = False):
        super().__init__()

        # The block down-samples exactly once: in the first conv of the main
        # path and, identically, in the shortcut.  Applying `stride` to every
        # conv (as was done before) shrinks the main path several times while
        # the shortcut shrinks only once, so `x + out` fails for stride != 1.
        block_stride = 2 if first_residual_block else stride

        def conv_padding(kernel_size):
            # 1x1 convolutions never need padding to preserve the spatial size.
            return padding if kernel_size != 1 else 0

        # First conv layers (everything except the last conv): Conv -> BN -> ReLU.
        self.not_last_conv = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=block_stride if i == 0 else 1,
                              padding=conv_padding(kernel_size),
                              bias=bias),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                ) for i, (in_channels, out_channels, kernel_size) in enumerate(zip(channels[:-2], channels[1:-1], kernel_sizes[:-1]))
            ]
        )

        # If the block has a single conv, that conv is also the first one and
        # therefore has to carry the down-sampling stride itself.
        last_stride = block_stride if len(self.not_last_conv) == 0 else 1

        # Last conv layer in the block; its output is added to the shortcut,
        # so the ReLU is deferred until after the addition in forward().
        self.last_conv = nn.Sequential(
            nn.Conv2d(channels[-2], channels[-1], kernel_sizes[-1], last_stride,
                      conv_padding(kernel_sizes[-1]), bias=bias),
            nn.BatchNorm2d(channels[-1])
        )

        # The shortcut: 1x1 projection matching the main path's channels and stride.
        self.shortcut = nn.Conv2d(channels[0], channels[-1], kernel_size=1, stride=block_stride, padding=0, bias=False)

    def forward(self, x):
        """Return ``relu(shortcut(x) + main_path(x))``."""
        out = self.not_last_conv(x)
        out = self.last_conv(out)
        x = self.shortcut(x)
        return F.relu(x + out)
        
def make_layer(num_blocks, first_block_in_channels, channels, kernel_sizes, stride, padding, bias):
    """Build one stage of the network as an ``nn.Sequential`` of residual blocks.

    The first block maps ``first_block_in_channels`` to ``channels[-1]`` and is
    created with ``first_residual_block=True``; the remaining
    ``num_blocks - 1`` blocks take ``channels[-1]`` as input and are each
    wrapped in their own ``nn.Sequential``.

    Args:
        num_blocks: total number of residual blocks in the stage.
        first_block_in_channels: input channel count of the first block.
        channels: list of per-conv output channel counts inside each block.
        kernel_sizes, stride, padding, bias: forwarded unchanged to every
            ``Residual_block``.
    """
    entry_channels = [first_block_in_channels] + channels
    # Every later block is fed by the output of the block before it.
    repeat_channels = [channels[-1]] + channels
    print(entry_channels)

    blocks = [
        Residual_block(entry_channels, kernel_sizes, stride, padding, bias, first_residual_block=True)
    ]
    for _ in range(num_blocks - 1):
        follow_up = Residual_block(repeat_channels, kernel_sizes, stride, padding, bias, first_residual_block=False)
        blocks.append(nn.Sequential(follow_up))

    return nn.Sequential(*blocks)

class ResNet(nn.Module):
    """Configurable ResNet-style classifier with 1000 outputs.

    ``configs`` is a sequence of dicts, one per stage, each providing the keys
    ``'num_blocks'``, ``'first_block_in_channels'``, ``'channels'`` and
    ``'kernel_sizes'`` that are handed to ``make_layer``.

    The network is: 7x7 conv stem (3 -> 64 channels) -> 3x3 max-pool with
    stride 2 -> the configured stages -> global average pooling -> flatten ->
    linear layer -> softmax over dim 1.

    NOTE(review): because the model already ends in ``nn.Softmax`` it outputs
    probabilities, not logits -- confirm this is what the training loss expects.
    """

    def __init__(self, configs):
        super().__init__()

        # Stages are built first (before stem and head) with fixed
        # stride=1 / padding=1 / bias=False for every residual block.
        stages = []
        for cfg in configs:
            stages.append(
                make_layer(
                    cfg['num_blocks'],
                    cfg['first_block_in_channels'],
                    cfg['channels'],
                    cfg['kernel_sizes'],
                    stride=1,
                    padding=1,
                    bias=False,
                )
            )

        # Width of the feature vector reaching the classifier.
        head_in_features = configs[-1]['channels'][-1]

        stem = [
            nn.Conv2d(3, 64, 7, 1, 3, bias=False),
            nn.MaxPool2d(3, 2),
        ]
        head = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(head_in_features, 1000),
            nn.Softmax(dim=1),
        ]
        self.model = nn.Sequential(*stem, *stages, *head)

    def forward(self, x):
        """Run ``x`` through the whole network and return the softmax output."""
        return self.model(x)
        
def test_resnet():
    """Smoke-test a four-stage, two-blocks-per-stage (R18) configuration.

    Builds the model, pushes a random batch of four 3x224x224 inputs through
    it, and prints the output shape followed by the model structure.
    """
    # (first_block_in_channels, stage width) for each of the four stages.
    stage_specs = [(64, 64), (64, 128), (128, 256), (256, 512)]
    configs_R18 = [
        {
            'num_blocks': 2,
            'first_block_in_channels': in_channels,
            'channels': [width, width],
            'kernel_sizes': [3, 3],
        }
        for in_channels, width in stage_specs
    ]

    net = ResNet(configs=configs_R18)
    batch = torch.randn(4, 3, 224, 224)
    result = net(batch)
    print(result.shape)
    print(net)

# Run the R18 smoke test defined above; it prints the output shape and the model.
test_resnet()

[64, 64, 64]
[64, 128, 128]
[128, 256, 256]
[256, 512, 512]
torch.Size([4, 1000])
ResNet(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Sequential(
      (0): Residual_block(
        (not_last_conv): Sequential(
          (0): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
        (last_conv): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (shortcut): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
      )
      (1): Sequential(
        (0)

In [3]:
# Stage table for the R50 configuration: every block uses a 1x1 -> 3x3 -> 1x1
# kernel layout and widens to four times the stage width in its last conv.
configs_R50 = [
    {
        'num_blocks': num_blocks,
        'first_block_in_channels': in_channels,
        'channels': [width, width, 4 * width],
        'kernel_sizes': [1, 3, 1],
    }
    for num_blocks, in_channels, width in [
        (3, 64, 64),
        (4, 256, 128),
        (6, 512, 256),
        (3, 1024, 512),
    ]
]

# Show the arguments each stage would pass on to make_layer.
for config in configs_R50:
    print(*(config[key] for key in ('num_blocks', 'first_block_in_channels', 'channels', 'kernel_sizes')))

# Sanity check of how an input channel count is prepended to a channel list.
print([1, *configs_R50[0]['channels']])

3 64 [64, 64, 256] [1, 3, 1]
4 256 [128, 128, 512] [1, 3, 1]
6 512 [256, 256, 1024] [1, 3, 1]
3 1024 [512, 512, 2048] [1, 3, 1]
[1, 64, 64, 256]


In [4]:
def test_resnet50():
    """Smoke-test the four-stage bottleneck (R50) configuration.

    Builds the model, pushes a random batch of four 3x224x224 inputs through
    it, and prints the output shape followed by the model structure.
    """
    # (num_blocks, first_block_in_channels, stage width) per stage; the last
    # conv of every block widens to four times the stage width.
    stage_specs = [(3, 64, 64), (4, 256, 128), (6, 512, 256), (3, 1024, 512)]
    configs_R50 = [
        {
            'num_blocks': num_blocks,
            'first_block_in_channels': in_channels,
            'channels': [width, width, 4 * width],
            'kernel_sizes': [1, 3, 1],
        }
        for num_blocks, in_channels, width in stage_specs
    ]

    net = ResNet(configs=configs_R50)
    batch = torch.randn(4, 3, 224, 224)
    result = net(batch)
    print(result.shape)
    print(net)

# Run the R50 smoke test defined above; it prints the output shape and the model.
test_resnet50()

[64, 64, 64, 256]
[256, 128, 128, 512]
[512, 256, 256, 1024]
[1024, 512, 512, 2048]
torch.Size([4, 1000])
ResNet(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Sequential(
      (0): Residual_block(
        (not_last_conv): Sequential(
          (0): Sequential(
            (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (1): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
        )
        (last_conv): Sequential(
          (0): Conv2d(64, 256, kernel_size=(1, 