In [90]:
import os, sys
import torch
import torchvision
import torch.nn as nn
import numpy as np

### Setting Residual block

In [98]:
# Reference: https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/deep_residual_network/main.py
class Basic_conv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwarg):
        super(Basic_conv2d, self).__init__()
        
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **kwarg)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        
        return out
class ResidualV1_Block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualV1_Block, self).__init__()
        
        self.conv_1x1 = Basic_conv2d(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=(1,1),
                                     stride=(1,1),
                                     padding=(0,0))
        self.conv_3x3 = Basic_conv2d(in_channels=out_channels,
                                     out_channels=out_channels,
                                     kernel_size=(3,3), 
                                     stride=(1,1), 
                                     padding=(1,1))
        self.conv_1x1_2 = Basic_conv2d(in_channels=out_channels, 
                                       out_channels=out_channels,
                                       kernel_size=(1,1), 
                                       stride=(1,1),
                                       padding=(0,0))
        self.relu = nn.ReLU(inplace=True)
        self.conv_match = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(1,1), stride=(1,1), padding=(0,0)),
            nn.BatchNorm2d(num_features=out_channels))
        
    def forward(self, x):
        x_shortcut = x
        out = self.conv_1x1(x)
        out = self.conv_3x3(out)
        out = self.conv_1x1_2(out)
        # make sure residual channel and out channel is the same
        if out.size()[1] != x_shortcut.size()[1]:
            x_shortcut = self.conv_match(x_shortcut)
        
        out += x_shortcut
        out = self.relu(out)
        
        return out

In [83]:
class ResidualV2_Block(nn.Module):
    def __init__(self, in_channels, com_channels, out_channels, strides=(1,1)):
        super(ResidualV2_Block, self).__init__()
        
        self.strides = strides
        
        if self.strides != (1,1):
            self.downsample_conv = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, (1,1) ,self.strides)
            )
            
        self.conv_1x1_com = Basic_conv2d(in_channels,
                                     com_channels,
                                     kernel_size=(1,1),
                                     stride=(1,1),
                                     padding=(0,0))
        
        self.conv_3x3 = Basic_conv2d(com_channels,
                                     com_channels, 
                                     kernel_size=(3,3),
                                     stride=self.strides,
                                     padding=(1,1))
        
        self.conv_1x1_exp = nn.Conv2d(com_channels,
                                      out_channels,
                                      kernel_size=(1,1),
                                      stride=(1,1),
                                      padding=(0,0))
        
    def forward(self, x):
        x_shortcut = x
        out = self.conv_1x1_com(x)
        out = self.conv_3x3(out)
        out = self.conv_1x1_exp(out)
        # make sure residual channel and out channel is the same
        if out.size()[1] != x_shortcut.size()[1] or self.strides != (1,1): #dealing with size or channel incnsistency
            x_shortcut = self.downsample_conv(x_shortcut)
        
        out += x_shortcut
        
        return out

### Testing on Mnist

In [87]:
mnist_trainset = torchvision.datasets.MNIST(root='./data',
                                            train=True,
                                            download=True,
                                            transform=torchvision.transforms.ToTensor())
mnist_loader = torch.utils.data.DataLoader(dataset=mnist_trainset, batch_size=64)
x, y = mnist_loader.__iter__().__next__()

In [96]:
### Residual V1 ###
model = ResidualV1_Block(in_channels=1, out_channels=64)
model

ResidualV1_Block(
  (conv_1x1): Basic_conv2d(
    (conv): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1))
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (conv_3x3): Basic_conv2d(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (conv_1x1_2): Basic_conv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (relu): ReLU(inplace)
  (conv_match): Sequential(
    (0): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [97]:
out = model(x)
out.size()

tensor([[[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         ...,

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 

torch.Size([64, 64, 28, 28])

In [84]:
### Residual V2 ###
model_v2 = ResidualV2_Block(in_channels=64, com_channels=32, out_channels=64, strides=(2,2))
model_v2

ResidualV2_Block(
  (downsample_conv): Sequential(
    (0): Conv1d(64, 64, kernel_size=(1, 1), stride=(2, 2))
  )
  (conv_1x1_com): Basic_conv2d(
    (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (conv_3x3): Basic_conv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (conv_1x1_exp): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
)

In [85]:
out_v2 = model_v2(out)
out_v2.size()

torch.Size([64, 64, 14, 14])

### Reference:
1. [Pytorch Tutorial: ResNet](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/deep_residual_network/main.py)

2. [ResNet_V1:Paper](https://arxiv.org/pdf/1512.03385.pdf)

3. [ResNet_V2](https://zhuanlan.zhihu.com/p/28413039)
4. [ResNet_V2:Paper](https://arxiv.org/pdf/1603.05027.pdf)