# [03] COVID Image Classification: Building the MobileNet Architecture

In this exercise we implement MobileNet V2, a widely known lightweight model for image classification.

The paper is available here: https://arxiv.org/pdf/1801.04381.pdf

Let's briefly go over the architecture. All figures below are taken from that paper.

![MobileNetV2](./imgs/mobileNetV2.png)

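The figure compares the building blocks of several lightweight networks. The two to focus on here are MobileNet and MobileNetV2.

The biggest difference between the two networks is the `bottleneck` structure, usually called the `Inverted Residual Block`.

The figure below describes `MobileNet`.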

![overview_mobilenet](./imgs/overview_mobilenet.png)

MobileNetV2, in contrast, adds a bottleneck structure to the existing design, as shown below.

- A block is formed by chaining `1x1` => `3x3` => `1x1` convolutions.

![bottleneck](./imgs/bottleneck.png)

Based on this block, MobileNetV2 is structured as follows.

![overview](./imgs/overview_mobilenetv2.png)

The code below is a slightly modified version of the official PyTorch code.

In [1]:
import torch
from torch import nn
# load_state_dict_from_url lives in torch.hub; torchvision.models.utils was removed in newer torchvision releases
from torch.hub import load_state_dict_from_url

In [3]:
class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, bn_aff = True):
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            ?????, # Convolutional Layer
            ?????, # Batch Normalization
            ?????, # ReLU activation function
        )
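One possible way to fill in the three blanks, following the layout of torchvision's reference `ConvBNReLU`. Two points are assumptions rather than something this notebook states: `bn_aff` is mapped to the `affine` argument of `nn.BatchNorm2d`, and the activation is `nn.ReLU6` (as in the MobileNetV2 paper), although the comment only says ReLU.

```python
import torch
from torch import nn


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, bn_aff=True):
        # padding that keeps the spatial size at stride 1 for odd kernel sizes
        padding = (kernel_size - 1) // 2
        super().__init__(
            # bias is redundant because BatchNorm follows immediately
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            # assumption: bn_aff toggles the learnable scale/shift of BatchNorm
            nn.BatchNorm2d(out_planes, affine=bn_aff),
            # assumption: ReLU6, as used in the MobileNetV2 paper
            nn.ReLU6(inplace=True),
        )


# Shape check: a stride-2 3x3 block halves the spatial resolution
x = torch.randn(1, 3, 224, 224)
y = ConvBNReLU(3, 32, stride=2)(x)
print(y.shape)  # torch.Size([1, 32, 112, 112])
```

Setting `groups` equal to the number of input channels turns the same block into a depthwise convolution, which is how the inverted residual block can reuse it.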

## InvertedResidual Block: Quick Recap
- The idea is to split a standard convolution filter into two separate convolutions applied back to back.
    1. Depthwise convolution: a 3x3 convolution computed separately for each channel (spatial filtering only).
    2. Pointwise convolution: an ordinary convolution with a 1x1 kernel, which mixes values across channels.
- On top of this, the block follows an expansion - convolution - squeeze pattern. (Inverted Residual Block)
- Put simply:
    1. The input first passes through a pointwise convolution that expands its number of channels.
    2. The expanded tensor then passes through a depthwise convolution.
    3. Finally, a pointwise convolution reduces (squeezes) the number of channels again.
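To make the benefit of the depthwise/pointwise split concrete, the sketch below counts weights for one standard 3x3 convolution versus a depthwise 3x3 followed by a pointwise 1x1. The channel sizes (32 in, 64 out) are illustrative values, not taken from the notebook, and bias terms are ignored.

```python
def standard_conv_params(c_in, c_out, k=3):
    # every output channel has its own k x k filter over all input channels
    return k * k * c_in * c_out


def depthwise_separable_params(c_in, c_out, k=3):
    depthwise = k * k * c_in   # one k x k filter per input channel
    pointwise = c_in * c_out   # 1x1 convolution mixing channels
    return depthwise + pointwise


c_in, c_out = 32, 64  # illustrative channel sizes
std = standard_conv_params(c_in, c_out)
sep = depthwise_separable_params(c_in, c_out)
print(std, sep, round(std / sep, 1))  # 18432 2336 7.9
```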

In [4]:
class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, shortcut, bn_aff):
        super(InvertedResidual, self).__init__()
        self.shortcut = shortcut
        self.bn_aff = bn_aff
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(???????)
        layers.extend([
            # dw
            ?????????,
            # pw-linear
            ???????????,
            ???????????,
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            if self.shortcut:
                return ?????????
            else:
                return self.conv(x)
        else:
            return self.conv(x)
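A hedged sketch of how the blanks in `InvertedResidual` could be completed, modelled on torchvision's reference block. It is self-contained, so it repeats a compact `ConvBNReLU`; mapping `bn_aff` to BatchNorm's `affine` and using `ReLU6` are assumptions. For brevity, the `shortcut` flag is folded into `use_res_connect` and given a default value, which differs slightly from the cell above.

```python
import torch
from torch import nn


class ConvBNReLU(nn.Sequential):
    # Compact copy of the helper block; bn_aff -> affine and ReLU6 are assumptions
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, bn_aff=True):
        padding = (kernel_size - 1) // 2
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes, affine=bn_aff),
            nn.ReLU6(inplace=True),
        )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, shortcut=True, bn_aff=True):
        super().__init__()
        assert stride in [1, 2]
        hidden_dim = int(round(inp * expand_ratio))
        # a skip connection is only possible when input and output shapes match
        self.use_res_connect = shortcut and stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw: 1x1 convolution that expands the channels
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, bn_aff=bn_aff))
        layers.extend([
            # dw: 3x3 depthwise convolution (groups == channels)
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, bn_aff=bn_aff),
            # pw-linear: 1x1 projection with no activation afterwards
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup, affine=bn_aff),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        return self.conv(x)


x = torch.randn(2, 32, 56, 56)
same = InvertedResidual(32, 32, stride=1, expand_ratio=6)  # residual path active
down = InvertedResidual(32, 64, stride=2, expand_ratio=6)  # no residual path
print(same(x).shape, down(x).shape)
# torch.Size([2, 32, 56, 56]) torch.Size([2, 64, 28, 28])
```

The last 1x1 convolution is deliberately left without an activation: the paper calls this a linear bottleneck, and the residual sum is taken between these narrow, linear outputs.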

- Below is the network structure built from inverted residual blocks.

In [None]:
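model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


def _make_divisible(v, divisor, min_value=None):
    # Helper called by MobileNetV2 below but not defined anywhere in this notebook;
    # this version follows torchvision's reference implementation.
    # It rounds v to the nearest multiple of divisor, never dropping by more than 10%.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v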


class MobileNetV2(nn.Module):
    def __init__(self,
                 num_classes=1000,
                 width_mult=1.0,
                 inverted_residual_setting=None,
                 round_nearest=8,
                 block=None,
                 shortcut=True,
                 bn_aff=True):
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
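            round_nearest (int): Round the number of channels in each layer to be a multiple of this number.
                Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            shortcut (bool): If False, the residual (skip) connection is disabled in every block
            bn_aff (bool): Flag forwarded to every ConvBNReLU and block
                (the name suggests BatchNorm's affine setting)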

        """
        super(MobileNetV2, self).__init__()

        if block is None:
            block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        self.bn_aff = bn_aff
        self.shortcut = shortcut

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                ?????????,
                ?????????,
                ?????????,
                ?????????,
                ?????????,
                ?????????,
                ?????????,
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2, bn_aff=self.bn_aff)]

        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t,
                                      shortcut=self.shortcut, bn_aff=self.bn_aff))
                input_channel = output_channel

        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, bn_aff=self.bn_aff))
        
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # weight and bias are None when the layer is built with affine=False
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x):
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
        x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
        x = self.classifier(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)
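The seven blank rows of `inverted_residual_setting` correspond to the (t, c, n, s) table in the MobileNetV2 paper: expansion factor, output channels, number of repeats, and stride of the first repeat. The sketch below uses the paper's values with a width multiplier of 1.0 (so no channel rounding is needed) and mirrors the constructor's loop to list every block. `PAPER_SETTING` and `expand_setting` are hypothetical names introduced for illustration.

```python
# (t, c, n, s) rows as listed in the MobileNetV2 paper
PAPER_SETTING = [
    [1, 16, 1, 1],
    [6, 24, 2, 2],
    [6, 32, 3, 2],
    [6, 64, 4, 2],
    [6, 96, 3, 1],
    [6, 160, 3, 2],
    [6, 320, 1, 1],
]


def expand_setting(setting, input_channel=32):
    """Return one (in_ch, out_ch, stride, expand_ratio) tuple per block."""
    blocks = []
    for t, c, n, s in setting:
        for i in range(n):
            stride = s if i == 0 else 1  # only the first repeat downsamples
            blocks.append((input_channel, c, stride, t))
            input_channel = c
    return blocks


blocks = expand_setting(PAPER_SETTING)

# total downsampling: stride-2 stem convolution times every block stride
total_stride = 2
for _, _, stride, _ in blocks:
    total_stride *= stride

print(len(blocks), total_stride, 224 // total_stride)  # 17 32 7
```

With a 224x224 input, the feature map entering the global average pooling is therefore 7x7, which is why the head can pool to 1x1 and feed a single linear layer.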

In [5]:
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model