## Model Definition

In [23]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

In [24]:
def activation_function(activation):
    return nn.ModuleDict([
        ['relu', nn.ReLU(inplace=True)],
        ['leaky_relu', nn.LeakyReLU(negative_slope=0.01, inplace=True)],
        ['selu', nn.SELU(inplace=True)],
        ['none', nn.Identity()]
    ])[activation]

def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
    return nn.Sequential(conv(in_channels, out_channels, *args, **kwargs),
                        nn.BatchNorm2d(out_channels))

In [25]:
class Conv2dAuto(nn.Conv2d):
    
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.padding = (self.kernel_size[0] // 2,
                        self.kernel_size[1] // 2)

In [27]:
conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False)

In [28]:
class BackBone(nn.ModuleList):
    
    def __init__(self, in_channels, out_channels, activation):
        super().__init__()
        
        

In [29]:
class ResidualBlock(nn.Module):
    
    def __init__(self, in_channels, out_channels, activation='relu'):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = activation_function(activation)
        self.blocks = nn.Identity()
        self.shortcut = nn.Identity()
        
    @property
    def is_shortcut(self):
        return self.in_channels != self.out_channels
    
    def forward(self, x):
        residual = self.shortcut(x) if self.is_shortcut else x
        out = self.blocks(x)
        out += residual
        out = self.activation(x)
        return out

In [45]:
class ResNetResidualBlock(ResidualBlock):
    
    def __init__(self, in_channels, out_channels, expansion=1, downsample=1, conv=conv3x3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        self.expansion, self.downsample, self.conv = expansion, downsample, conv
        self.shortcut = nn.Sequential(nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1,
                                               stride=self.downsample, bias=False),
                                     nn.BatchNorm2d(self.expanded_channels)) if self.is_shortcut else None
        
    @property
    def expanded_channels(self):
        return self.out_channels * self.expansion

    @property
    def is_shortcut(self):
        return self.in_channels != self.out_channels

In [34]:
class ResNetBasicBlock(ResNetResidualBlock):
    
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        self.blocks = nn.Sequential(conv_bn(self.in_channels, self.out_channels, conv=self.conv,
                                           bias=False, stride=self.downsample),
                                   self.activation,
                                   conv_bn(self.out_channels, self.expanded_channels, conv=self.conv,
                                          bias=False))

In [46]:
class ResNetBottleNeckBlock(ResNetResidualBlock):
    
    expansion = 4
    
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__(in_channels, out_channels, expansion=4, *args, **kwargs)
        nn.Sequential(conv_bn(self.in_channels, self.out_channels, self.conv, kernel_size=1),
                     self.activation,
                     conv_bn(self.out_channels, self.out_channels, self.conv, kernel_size=3, stride=self.downsample),
                     self.activation,
                     conv_bn(self.out_channels, self.expanded_channels, self.conv, kernel_size=1))

In [47]:
dummy = torch.ones((1, 32, 10, 10))
block = ResNetBottleNeckBlock(32, 64)
block(dummy).shape
print(block)

RuntimeError: The size of tensor a (32) must match the size of tensor b (256) at non-singleton dimension 1

In [11]:
class BasicBlock(nn.ModuleList):
    
    def __init__(self, in_channels, out_channels, activation, downsample=1, expansion=1, conv=conv3x3, *args, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        self.activation = activation_function(activation)
        self.expansion = expansion
        self.conv = conv
        self.block = nn.Sequential(
            conv_bn(self.in_channels, self.out_channels, self.conv,
                   stride=self.downsample, bias=False),
            self.activation)
        self.shortcut = nn.Sequential(
            nn.Conv2d(self.in_channels, self.expanded_channels,
                     kernel_size=1, stride = self.downsample),
            nn.BatchNorm2d(self.expanded_channels))
        
    def forward(x):
        residual = self.shortcut(x) if self.is_shortcut else x
        out = self.blocks(x)
        out += residual
        out = self.activation(out)
        return out
    
    @property
    def is_shortcut(self):
        return self.in_channels != slef.out_channels
        
    @property
    def expanded_channels(self):
        return self.out_channels * self.expansion

In [12]:
BasicBlock(32, 64, 'relu')

BasicBlock(
  (blocks): Identity()
  (activation): ReLU(inplace=True)
  (shortcut): Sequential(
    (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)