In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from torchsummaryM import summary

In [3]:
def activation_fun(activ):
    return nn.ModuleDict([
        ['relu', nn.ReLU(True)],
        ['leaky_relu', nn.LeakyReLU(0.02, True)],
        ['selu', nn.SELU(True)],
        ['none', nn.Identity()]
    ])[activ]

In [4]:
class BasicBlock(nn.Module):
    def __init__(self, in_f, out_f, a='relu'):
        super(BasicBlock, self).__init__()
        if in_f != out_f:
            self.conv1 = nn.Conv2d(in_f, out_f, kernel_size=3, stride=2, padding=1, bias=False)
        else:
            self.conv1 = nn.Conv2d(in_f, out_f, kernel_size=3, stride=1, padding=1, bias=False)
            
        self.conv2 = nn.Conv2d(out_f, out_f, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_f)
        self.bn2   = nn.BatchNorm2d(out_f)
        self.activ = activation_fun(a)
        
        if in_f != out_f:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_f, out_f, kernel_size=1, stride=2, padding=0, bias=False),
                nn.BatchNorm2d(out_f)
            )
        else:
            self.shortcut = False
            
    def forward(self, x):
        residual = x
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activ(x)
        x = self.conv2(x)
        x = self.bn2(x)
        
        if self.shortcut:
            residual = self.shortcut(residual)
            
        x += residual
        x = self.activ(x)
            
        return x

In [5]:
class BottleBlock(nn.Module):
    def __init__(self, in_f, out_f, a='relu'):
        super(BottleBlock, self).__init__()
        
        if in_f == out_f: last = out_f; middle = in_f//4
        if in_f > out_f: last = 4*out_f; middle = out_f
        if in_f < out_f: last = out_f; middle = in_f
            
        if in_f > out_f: self.conv1 = nn.Conv2d(in_f, middle, kernel_size=1, stride=2, bias=False)
        else: self.conv1 = nn.Conv2d(in_f, middle, kernel_size=1, stride=1, bias=False)
        self.conv2 = nn.Conv2d(middle, middle, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(middle, last, kernel_size=1, stride=1, bias=False)
        
        self.bn1   = nn.BatchNorm2d(middle)
        self.bn2   = nn.BatchNorm2d(middle)
        self.bn3   = nn.BatchNorm2d(last)
        self.activ = activation_fun(a)
        
        if in_f > out_f:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_f, last, kernel_size=1, stride=2, padding=0, bias=False),
                nn.BatchNorm2d(last)
            )
        elif in_f < out_f:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_f, last, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(last)
            )
        else:
            self.shortcut = False
            
    def forward(self, x):
        residual = x
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activ(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activ(x)
        x = self.conv3(x)
        x = self.bn3(x)
        
        if self.shortcut:
            residual = self.shortcut(residual)    
        x += residual
        x = self.activ(x)
        return x

In [6]:
class ResNet(nn.Module):
    def __init__(self, in_c, n_cls, depths:list, block = BottleBlock):
        super(ResNet, self).__init__()
        self.in_c, self.n_cls = in_c, n_cls

        blocks = []
        channels = [64, 128, 256, 512]

        if block == BasicBlock:
            c = 1
            for _ in range(depths[0]): blocks.append(block(channels[0], channels[0]))
            for depth in depths[1:]:
                for n in range(depth):
                    if n == 0: blocks.append(block(channels[c-1], channels[c]))
                    else: blocks.append(block(channels[c], channels[c]))
                c += 1
        else:
            c = 0
            in_c  = 64
            out_c = 256
            for depth in depths:
                for n in range(depth):
                    blocks.append(block(in_c, out_c))
                    if n==0 and c !=0: out_c *=4
                    if n==0: in_c = out_c
                out_c //= 2
                c += 1
        
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.in_c, 64, kernel_size=7, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        self.conv2 = nn.Sequential(
            *blocks[:depths[0]]
        )
        self.conv3 = nn.Sequential(
            *blocks[depths[0]:sum(depths[:1])]
        )
        self.conv4 = nn.Sequential(
            *blocks[sum(depths[:1]):sum(depths[:2])]
        )
        self.conv5 = nn.Sequential(
            *blocks[sum(depths[:2]):sum(depths)]
        )
        
        self.last = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(512*4, 1000) if block == BottleBlock else nn.Linear(512, 1000)
        )
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.last(x)
        return x

In [7]:
def resnet18(in_c, n_cls):
    return ResNet(in_c, n_cls, [2, 2, 2, 2], BasicBlock)

def resnet34(in_c, n_cls):
    return ResNet(in_c, n_cls, [3, 4, 6, 3], BasicBlock)

def resnet50(in_c, n_cls):
    return ResNet(in_c, n_cls, [3, 4, 6, 3])

def resnet101(in_c, n_cls):
    return ResNet(in_c, n_cls, [3, 4, 23, 3])

def resnet152(in_c, n_cls):
    return ResNet(in_c, n_cls, [3, 8, 36, 3])

In [8]:
model18 = resnet18(3, 10)
model34 = resnet34(3, 10)
model50 = resnet50(3, 10)
model101 = resnet101(3, 10)
model152 = resnet152(3, 10)

In [9]:
summary18 = summary(model18, torch.ones(1, 3, 224, 224))

----------------------------------------------------------------------------------------------------
Layer(type)                             ||        Kernel Shape         Output Shape         Param #
ResNet Inputs                           ||                   -     [1, 3, 224, 224]               -
                                        ||                                                         
01> ResNet-Conv1-Conv2d                 ||       [3, 64, 7, 7]    [1, 64, 111, 111]           9,408
02> ResNet-Conv1-BatchNorm2d            ||                [64]    [1, 64, 111, 111]             128
03> ResNet-Conv1-ReLU                   ||                   -    [1, 64, 111, 111]               0
04> ResNet-Conv1-MaxPool2d              ||                   -      [1, 64, 56, 56]               0
05> ResNet-Conv2-1-Conv2d               ||      [64, 64, 3, 3]      [1, 64, 56, 56]          36,864
06> ResNet-Conv2-1-BatchNorm2d          ||                [64]      [1, 64, 56, 56]             128

In [10]:
summary34 = summary(model34, torch.ones(1, 3, 224, 224))

----------------------------------------------------------------------------------------------------
Layer(type)                             ||        Kernel Shape         Output Shape         Param #
ResNet Inputs                           ||                   -     [1, 3, 224, 224]               -
                                        ||                                                         
001> ResNet-Conv1-Conv2d                ||       [3, 64, 7, 7]    [1, 64, 111, 111]           9,408
002> ResNet-Conv1-BatchNorm2d           ||                [64]    [1, 64, 111, 111]             128
003> ResNet-Conv1-ReLU                  ||                   -    [1, 64, 111, 111]               0
004> ResNet-Conv1-MaxPool2d             ||                   -      [1, 64, 56, 56]               0
005> ResNet-Conv2-1-Conv2d              ||      [64, 64, 3, 3]      [1, 64, 56, 56]          36,864
006> ResNet-Conv2-1-BatchNorm2d         ||                [64]      [1, 64, 56, 56]             128

In [11]:
summary50 = summary(model50, torch.ones(1, 3, 224, 224))

----------------------------------------------------------------------------------------------------
Layer(type)                             ||        Kernel Shape         Output Shape         Param #
ResNet Inputs                           ||                   -     [1, 3, 224, 224]               -
                                        ||                                                         
001> ResNet-Conv1-Conv2d                ||       [3, 64, 7, 7]    [1, 64, 111, 111]           9,408
002> ResNet-Conv1-BatchNorm2d           ||                [64]    [1, 64, 111, 111]             128
003> ResNet-Conv1-ReLU                  ||                   -    [1, 64, 111, 111]               0
004> ResNet-Conv1-MaxPool2d             ||                   -      [1, 64, 56, 56]               0
005> ResNet-Conv2-1-Conv2d              ||      [64, 64, 1, 1]      [1, 64, 56, 56]           4,096
006> ResNet-Conv2-1-BatchNorm2d         ||                [64]      [1, 64, 56, 56]             128

In [12]:
summary101 = summary(model101, torch.ones(1, 3, 224, 224))

-----------------------------------------------------------------------------------------------------
Layer(type)                              ||        Kernel Shape         Output Shape         Param #
ResNet Inputs                            ||                   -     [1, 3, 224, 224]               -
                                         ||                                                         
001> ResNet-Conv1-Conv2d                 ||       [3, 64, 7, 7]    [1, 64, 111, 111]           9,408
002> ResNet-Conv1-BatchNorm2d            ||                [64]    [1, 64, 111, 111]             128
003> ResNet-Conv1-ReLU                   ||                   -    [1, 64, 111, 111]               0
004> ResNet-Conv1-MaxPool2d              ||                   -      [1, 64, 56, 56]               0
005> ResNet-Conv2-1-Conv2d               ||      [64, 64, 1, 1]      [1, 64, 56, 56]           4,096
006> ResNet-Conv2-1-BatchNorm2d          ||                [64]      [1, 64, 56, 56]      

In [13]:
summary152 = summary(model152, torch.ones(1, 3, 224, 224))

-----------------------------------------------------------------------------------------------------
Layer(type)                              ||        Kernel Shape         Output Shape         Param #
ResNet Inputs                            ||                   -     [1, 3, 224, 224]               -
                                         ||                                                         
001> ResNet-Conv1-Conv2d                 ||       [3, 64, 7, 7]    [1, 64, 111, 111]           9,408
002> ResNet-Conv1-BatchNorm2d            ||                [64]    [1, 64, 111, 111]             128
003> ResNet-Conv1-ReLU                   ||                   -    [1, 64, 111, 111]               0
004> ResNet-Conv1-MaxPool2d              ||                   -      [1, 64, 56, 56]               0
005> ResNet-Conv2-1-Conv2d               ||      [64, 64, 1, 1]      [1, 64, 56, 56]           4,096
006> ResNet-Conv2-1-BatchNorm2d          ||                [64]      [1, 64, 56, 56]      

In [15]:
from torchvision import models

official152 = models.resnet152()
summary152 = summary(official152, torch.ones(1, 3, 224, 224))

-------------------------------------------------------------------------------------------------------
Layer(type)                                ||        Kernel Shape         Output Shape         Param #
ResNet Inputs                              ||                   -     [1, 3, 224, 224]               -
                                           ||                                                         
001> ResNet-Conv2d                         ||       [3, 64, 7, 7]    [1, 64, 112, 112]           9,408
002> ResNet-BatchNorm2d                    ||                [64]    [1, 64, 112, 112]             128
003> ResNet-ReLU                           ||                   -    [1, 64, 112, 112]               0
004> ResNet-MaxPool2d                      ||                   -      [1, 64, 56, 56]               0
005> ResNet-Layer1-1-Conv2d                ||      [64, 64, 1, 1]      [1, 64, 56, 56]           4,096
006> ResNet-Layer1-1-BatchNorm2d           ||                [64]      [

In [16]:
official101 = models.resnet101()
summary101 = summary(official101, torch.ones(1, 3, 224, 224))

-------------------------------------------------------------------------------------------------------
Layer(type)                                ||        Kernel Shape         Output Shape         Param #
ResNet Inputs                              ||                   -     [1, 3, 224, 224]               -
                                           ||                                                         
001> ResNet-Conv2d                         ||       [3, 64, 7, 7]    [1, 64, 112, 112]           9,408
002> ResNet-BatchNorm2d                    ||                [64]    [1, 64, 112, 112]             128
003> ResNet-ReLU                           ||                   -    [1, 64, 112, 112]               0
004> ResNet-MaxPool2d                      ||                   -      [1, 64, 56, 56]               0
005> ResNet-Layer1-1-Conv2d                ||      [64, 64, 1, 1]      [1, 64, 56, 56]           4,096
006> ResNet-Layer1-1-BatchNorm2d           ||                [64]      [