# Deeper Model

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

%load_ext autoreload
%autoreload 2
import utils

## Residual Block

In [28]:
class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 convolutions plus a skip connection.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The skip path is an
    empty ``nn.Sequential`` (identity) when the input already has the output
    shape, otherwise a strided 1x1 convolution followed by BatchNorm. The two
    paths are summed and passed through a final ReLU.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the projection
            shortcut, when one is needed).
    """

    def __init__(self, in_ch, out_ch, stride = 1):
        super().__init__()
        # Shared settings of both 3x3 convolutions; bias is disabled because
        # each convolution is immediately followed by BatchNorm.
        conv3x3_kwargs = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_ch, out_ch, stride=stride, **conv3x3_kwargs)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, stride=1, **conv3x3_kwargs)
        self.bn2 = nn.BatchNorm2d(out_ch)

        shape_is_preserved = stride == 1 and in_ch == out_ch
        if shape_is_preserved:
            # Nothing to adapt: an empty Sequential returns its input as-is.
            self.shortcut = nn.Sequential()
        else:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, X):
        """Return ``relu(main_path(X) + shortcut(X))``."""
        hidden = F.relu(self.bn1(self.conv1(X)))
        residual = self.bn2(self.conv2(hidden))
        skip = self.shortcut(X)
        return F.relu(residual + skip)


class PreResidualBlock(nn.Module):
    """Pre-activation residual block (BN -> ReLU -> conv, applied twice).

    The residual branch is ``conv2(relu(bn2(conv1(relu(bn1(X))))))`` and is
    added to the shortcut without any further activation, so the block output
    is ``shortcut(X) + residual``.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the projection
            shortcut, when one is needed).
    """

    def __init__(self, in_ch, out_ch, stride = 1):
        super().__init__()
        # The convolutions carry no bias: the BatchNorm layers already learn
        # a per-channel shift.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size = 3, stride = stride, padding = 1, bias=False)
        # bn1 normalises the block *input*, hence in_ch rather than out_ch.
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size = 3, stride = 1, padding = 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # Identity shortcut by default (an empty Sequential returns its input).
        self.shortcut = nn.Sequential()

        if stride != 1 or in_ch != out_ch:
            # Shape changes, so the input must be projected. The projection is
            # kept as close to an identity mapping as possible (the input
            # should be carried through to later layers), so no bias and no
            # normalisation are used here.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size = 1, stride = stride, bias=False)
            )

    def forward(self, X):
        """Return ``shortcut(X) + residual(X)`` with no post-addition activation."""
        out = self.conv1(F.relu(self.bn1(X)))
        out = self.conv2(F.relu(self.bn2(out)))
        # No ReLU here: in a pre-activation block the residual must be free to
        # take negative values, otherwise it can only ever increase the signal
        # carried by the shortcut.
        out += self.shortcut(X)
        return out
        

In [11]:
# Build a down-sampling residual block: 3 input channels -> 64 output channels, stride 2.
resblock = ResidualBlock(3, 64, stride = 2)

In [12]:
# Display the block's layer structure (notebook repr of the module).
resblock

ResidualBlock(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (shortcut): Sequential(
    (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [13]:
# Random input tensor of shape (1, 3, 32, 32).
X = torch.randn(1, 3, 32, 32)
# Forward pass through the residual block built above.
output = resblock(X)

In [14]:
# Inspect the result shape: channels go 3 -> 64 and stride 2 halves 32x32 to 16x16.
output.shape

torch.Size([1, 64, 16, 16])

In [24]:
# Build a down-sampling pre-activation block: 3 input channels -> 64 output channels, stride 2.
preresblock = PreResidualBlock(3, 64, stride = 2)

In [25]:
# Display the block's layer structure (notebook repr of the module).
preresblock

PreResidualBlock(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (shortcut): Sequential(
    (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
  )
)

In [26]:
# Forward the same (1, 3, 32, 32) input through the pre-activation block
# and inspect the resulting shape.
output = preresblock(X)
output.shape

torch.Size([1, 64, 16, 16])

### Point-Wise-ConvolutionとBottleneck構造

In [43]:
class Bottleneck_ResidualBlock(nn.Module):
    """Bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 expansion conv.

    Every convolution on the main path is followed by BatchNorm; ReLU is
    applied after the first two and once more after the skip addition. The
    block outputs ``out_ch * expansion_factor`` channels.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels used inside the bottleneck (output of the
            first 1x1 and of the 3x3 convolution).
        stride: Stride of the 3x3 convolution (and of the projection
            shortcut, when one is needed).
        expansion_factor: Multiplier applied to ``out_ch`` by the final 1x1
            convolution.
    """

    def __init__(self, in_ch, out_ch, stride = 1, expansion_factor = 4):
        super().__init__()
        # Point-wise convolution mapping the input to the bottleneck width.
        self.pw1 = nn.Conv2d(in_ch, out_ch, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        # Spatial 3x3 convolution; the only place where stride is applied on the main path.
        self.conv1 = nn.Conv2d(out_ch, out_ch, kernel_size = 3, stride = stride, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # Point-wise convolution expanding to the block's output width.
        self.pw2 = nn.Conv2d(out_ch, out_ch*expansion_factor, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn3 = nn.BatchNorm2d(out_ch*expansion_factor)
        # Identity shortcut by default (an empty Sequential returns its input).
        self.shortcut = nn.Sequential()

        if in_ch != out_ch*expansion_factor or stride != 1:
            # Shape changes, so project the input to the output shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch*expansion_factor, kernel_size = 1, stride = stride, padding = 0, bias = False)
                ,nn.BatchNorm2d(out_ch*expansion_factor)
            )

    def forward(self, X):
        """Return ``relu(main_path(X) + shortcut(X))``."""
        out = F.relu(self.bn1(self.pw1(X)))
        out = F.relu(self.bn2(self.conv1(out)))
        # Normalise the expanded features before the addition; bn3 was
        # previously constructed but never applied.
        out = self.bn3(self.pw2(out))
        out += self.shortcut(X)
        out = F.relu(out)
        return out

In [44]:
# Random input tensor of shape (1, 256, 28, 28).
X = torch.randn(1, 256, 28, 28)
# Bottleneck width 64 expanded by 4 gives 256 output channels; stride 2 down-samples.
bottleneck_resblock = Bottleneck_ResidualBlock(in_ch = 256, out_ch = 64, stride = 2, expansion_factor = 4)
output = bottleneck_resblock(X)
# Inspect the resulting shape.
output.shape

torch.Size([1, 256, 14, 14])

### Inceptionモジュール

In [73]:
class Inception(nn.Module):
    """Inception-style module with four parallel branches concatenated on channels.

    Branches:
        1. 1x1 convolution.
        2. 1x1 point-wise convolution followed by a 3x3 convolution.
        3. 1x1 point-wise convolution followed by a 5x5 convolution.
        4. 3x3 max-pooling (stride 1) followed by a 1x1 convolution.

    Each branch ends with BatchNorm and ReLU. All branches keep the spatial
    size, so the output has ``out_ch1 + out_ch3 + out_ch5 + out_ch_pool``
    channels.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch1: Output channels of the 1x1 branch.
        out_ch3: Output channels of the 3x3 branch.
        out_ch5: Output channels of the 5x5 branch.
        out_ch_pool: Output channels of the pooling branch.
    """

    def __init__(self, in_ch, out_ch1, out_ch3, out_ch5, out_ch_pool):
        super().__init__()
        # Branch 1: plain 1x1 convolution.
        self.conv1 = nn.Conv2d(in_ch, out_ch1, kernel_size = 1, stride = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(out_ch1)

        # Branch 2: point-wise (1x1) channel projection, then the 3x3 convolution.
        self.pw1 = nn.Conv2d(in_ch, out_ch3, kernel_size = 1, stride = 1, bias = False)
        self.conv2 = nn.Conv2d(out_ch3, out_ch3, kernel_size = 3, stride = 1, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(out_ch3)

        # Branch 3: point-wise (1x1) channel projection, then the 5x5 convolution.
        self.pw2 = nn.Conv2d(in_ch, out_ch5, kernel_size = 1, stride = 1, bias = False)
        self.conv3 = nn.Conv2d(out_ch5, out_ch5, kernel_size = 5, stride = 1, padding = 2, bias = False)
        self.bn3 = nn.BatchNorm2d(out_ch5)

        # Branch 4: size-preserving max-pooling, then a 1x1 convolution.
        self.max_pooling = nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1)
        self.pw3 = nn.Conv2d(in_ch, out_ch_pool, kernel_size = 1, stride = 1, bias = False)
        # Must match pw3's output width (out_ch_pool), not the 5x5 branch width.
        self.bn4 = nn.BatchNorm2d(out_ch_pool)

    def forward(self, X):
        """Run the four branches on ``X`` and concatenate them along dim 1."""
        out1 = F.relu(self.bn1(self.conv1(X)))
        out2 = F.relu(self.bn2(self.conv2(self.pw1(X))))
        out3 = F.relu(self.bn3(self.conv3(self.pw2(X))))
        out4 = F.relu(self.bn4(self.pw3(self.max_pooling(X))))
        return torch.cat((out1, out2, out3, out4), dim = 1)

In [74]:
# Random input tensor of shape (1, 192, 28, 28).
X = torch.randn(1, 192, 28, 28)

In [75]:
# Branch widths 64 + 128 + 32 + 32 sum to 256 output channels.
inception = Inception(in_ch = 192, out_ch1 = 64, out_ch3 =  128, out_ch5 =  32, out_ch_pool = 32)

In [76]:
# Forward pass through the Inception module, then inspect the resulting shape.
output = inception(X)
output.shape

torch.Size([1, 256, 28, 28])

In [77]:
class InceptionModule(nn.Module):
    """Inception module built from four ``nn.Sequential`` branches.

    Branches (each convolution is followed by BatchNorm and ReLU):
        * ``branch1``: 1x1 convolution.
        * ``branch2``: 1x1 point-wise convolution, then 3x3 convolution.
        * ``branch3``: 1x1 point-wise convolution, then 5x5 convolution.
        * ``branch4``: 3x3 max-pooling with stride 1, then 1x1 convolution.

    The branch outputs are concatenated along the channel dimension.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch1: Output channels of the 1x1 branch.
        out_ch3: Output channels of the 3x3 branch.
        out_ch5: Output channels of the 5x5 branch.
        out_ch_pool: Output channels of the pooling branch.
    """

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel_size, padding=0):
        """Return the layer triple [Conv2d, BatchNorm2d, ReLU] as a list."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]

    def __init__(self, in_ch, out_ch1, out_ch3, out_ch5, out_ch_pool):
        super().__init__()
        unit = self._conv_bn_relu

        # 1x1 branch.
        self.branch1 = nn.Sequential(*unit(in_ch, out_ch1, 1))

        # 3x3 branch: point-wise projection first, then the 3x3 convolution.
        self.branch2 = nn.Sequential(
            *unit(in_ch, out_ch3, 1),
            *unit(out_ch3, out_ch3, 3, padding=1),
        )

        # 5x5 branch: point-wise projection first, then the 5x5 convolution.
        self.branch3 = nn.Sequential(
            *unit(in_ch, out_ch5, 1),
            *unit(out_ch5, out_ch5, 5, padding=2),
        )

        # Pooling branch: size-preserving max-pool, then a 1x1 convolution.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            *unit(in_ch, out_ch_pool, 1),
        )

    def forward(self, X):
        """Apply every branch to ``X`` and concatenate the results on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(X) for branch in branches], dim=1)

In [78]:
# Random input tensor of shape (16, 192, 28, 28).
X = torch.randn(16, 192, 28, 28)
# Branch widths 64 + 128 + 32 + 32 sum to 256 output channels.
module = InceptionModule(192, 64, 128, 32, 32)
out = module(X)
# Print the output shape.
print(out.shape)

torch.Size([16, 256, 28, 28])
