# DeepLabV3 and V3+


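Applying DCNNs to semantic segmentation faces two challenges:
1. Consecutive pooling or strided convolutions reduce the feature resolution, which loses spatial (location) information.
2. Objects appear at multiple scales.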

## DeepLabV3

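- DeepLabV3 tackles the multi-scale problem with atrous spatial pyramid pooling (`ASPP`, first proposed in DeepLabV2). Its main contributions are:
    - It revisits atrous convolution and shows that, in both cascaded modules and a spatial pyramid pooling framework, it enlarges the receptive field and captures multi-scale context.
    - It improves `ASPP`: the module consists of atrous convolutions with **different rates** plus BN layers. The paper studies both a cascaded and a parallel layout; the parallel one (ASPP) is the module DeepLabV3+ reuses and the one implemented below.
    - With a large rate, a $3 \times 3$ atrous convolution can no longer capture long-range context because of image boundary effects: as the rate approaches the feature-map size, more and more filter taps fall on the zero padding, until only the center weight is effective and the filter degenerates into a $1 \times 1$ convolution. Image-level features (global average pooling) are therefore added to ASPP to incorporate global context.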
    
    <center>
    
    ![](https://bbsmax.ikafan.com/static/L3Byb3h5L2h0dHBzL2ltZzIwMTguY25ibG9ncy5jb20vYmxvZy8xNTE5NTc4LzIwMTkwNS8xNTE5NTc4LTIwMTkwNTE5MTUzODU3ODc3LTExNDk4NjkxNjgucG5n.jpg)
    
    </center>
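
To see why large rates degenerate, the sketch below counts, for every output position of a $3 \times 3$ atrous convolution whose zero padding equals the rate (so the output keeps the input size), how many of the 9 taps land on real features rather than on padding. The function name `valid_tap_histogram` is hypothetical, and the 65×65 map is simply what a 513×513 input gives at output stride 8.

```python
def valid_tap_histogram(size: int, rate: int) -> dict[int, int]:
    """Map k -> number of output positions on a size x size feature map where
    exactly k of the 9 taps of a 3x3 atrous conv (with this rate) hit real features.
    Padding is assumed equal to the rate, so output size == input size."""
    offsets = (-rate, 0, rate)
    hist: dict[int, int] = {}
    for i in range(size):
        for j in range(size):
            k = sum(
                1
                for di in offsets
                for dj in offsets
                if 0 <= i + di < size and 0 <= j + dj < size
            )
            hist[k] = hist.get(k, 0) + 1
    return hist


size = 65  # e.g. a 513x513 input at output stride 8
for rate in (1, 12, 24, 36, 64, 65):
    hist = valid_tap_histogram(size, rate)
    full = hist.get(9, 0) / size**2          # share of positions using all 9 weights
    center_only = hist.get(1, 0) / size**2   # share using only the center weight
    print(f"rate={rate:2d}  all 9 valid: {full:.3f}  only center valid: {center_only:.3f}")
```

As the rate grows, the share of positions that use all 9 weights shrinks to zero, and at `rate=65` every position sees only the center tap, i.e. the layer behaves exactly like a $1 \times 1$ convolution.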

## DeepLabV3+

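DeepLabV3+ builds on V3 and mainly targets the loss of spatial detail caused by downsampling. There are two common remedies:
- Use atrous convolution in place of more of the downsampling (`pooling` or strided) layers to obtain higher-resolution high-level features; however, this is computationally very expensive.
- Use an encoder-decoder structure (as in SegNet, U-Net, etc.).

DeepLabV3+ **applies the ASPP module within an encoder-decoder structure**. The encoder extracts rich semantic information, while the decoder recovers sharp object boundaries. Atrous convolution lets the encoder extract features at an arbitrary resolution.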
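
"Arbitrary resolution" comes down to simple bookkeeping: once the accumulated stride reaches the desired output stride, each remaining stage drops its stride-2 downsampling and doubles its dilation instead, so the receptive field keeps growing while the resolution stays fixed. The sketch assumes a ResNet-like layout (a stride-4 stem followed by four stages, the first of which does not downsample); the function `stage_strides_and_dilations` is hypothetical.

```python
def stage_strides_and_dilations(output_stride: int, num_stages: int = 4,
                                stem_stride: int = 4) -> list[tuple[int, int]]:
    """Return (stride, dilation) per stage for a hypothetical ResNet-like backbone.

    Assumes a stride-4 stem and stages that would normally downsample by 2
    (except stage 1). Once the target output stride is reached, the stride is
    replaced by a doubled dilation so the resolution stays fixed.
    """
    current_stride = stem_stride
    dilation = 1
    config = []
    for idx in range(num_stages):
        nominal_stride = 1 if idx == 0 else 2
        if nominal_stride == 2 and current_stride >= output_stride:
            dilation *= 2  # keep resolution, enlarge the receptive field instead
            stride = 1
        else:
            stride = nominal_stride
            current_stride *= stride
        config.append((stride, dilation))
    return config


for target in (32, 16, 8):
    print(target, stage_strides_and_dilations(target))
```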

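The main contributions of DeepLabV3+ are therefore:

- An `Encoder-Decoder` structure that uses DeepLabV3 as the encoder, followed by a simple yet effective decoder module;
- Within this encoder-decoder structure, atrous convolution controls the resolution of the encoder features, trading off precision against runtime, which is not possible with existing encoder-decoder models;
- The Xception model is adapted to semantic segmentation, and depthwise separable convolution is applied to the ASPP and decoder modules, making the network faster and stronger;
- DeepLabV3 bilinearly upsamples its output by 16x in one step; V3+ first upsamples the encoder output by 4x and concatenates it with low-level features of the same spatial size. The low-level features are first reduced by a $1 \times 1$ convolution, because they usually have many channels, which would otherwise outweigh the richer encoder features (only 256 channels) in the concatenation.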
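
Concretely, for the 513×513 input used in the test cell at the end of this post, the spatial sizes work out as below, assuming each stride-2 layer maps a side of length $n$ to $\lfloor (n-1)/2 \rfloor + 1$ (true for a $3 \times 3$ convolution with padding 1 and a $7 \times 7$ one with padding 3). With `align_corners=True`, an $s\times$ bilinear upsample aligns the corner pixels, so a side of $n$ corresponds to $(n-1)s + 1$; an input side of the form $16k + 1$ therefore makes every level line up exactly. The helper names are illustrative only.

```python
def downsampled_side(n: int, stride: int) -> int:
    """Side length after log2(stride) stride-2 layers, each mapping n -> (n - 1) // 2 + 1."""
    while stride > 1:
        n = (n - 1) // 2 + 1
        stride //= 2
    return n


def upsampled_side(n: int, factor: int) -> int:
    """Side length whose corners align with an n-pixel map under align_corners=True."""
    return (n - 1) * factor + 1


side = 513
low = downsampled_side(side, 4)    # low-level feature, e.g. ResNet conv2 output
high = downsampled_side(side, 16)  # encoder / ASPP output
print(low, high)                                            # 129 33
print(upsampled_side(high, 4) == low, upsampled_side(low, 4) == side)  # True True
```

The implementation below passes the target size to `F.interpolate` explicitly, so it also runs for other input sizes; the $16k + 1$ choice just makes each 4x step exact.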

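For example, if the low-level feature is taken from the output of ResNet conv2 (stride 4), the encoder output is upsampled by 4x here. After the two features are concatenated, a few $3 \times 3$ convolutions refine the result (the implementation below uses two), and another 4x bilinear upsampling produces the pixel-level prediction. The experiments show that this design is both accurate and fast at output stride 16 (`stride=16`), whereas output stride 8 brings only a small accuracy gain at a considerably higher computational cost.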
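
A rough way to see the cost gap: the work of a convolution grows linearly with the number of output positions, so every layer that runs at output stride 8 instead of 16 does about $(65/33)^2 \approx 3.9$ times as much on a 513×513 input. The sketch counts multiply-accumulates of a single $3 \times 3$ convolution with 2048 input and 256 output channels (the shape of an ASPP branch on a ResNet backbone in the implementation later in this post); it ignores the rest of the network, so it only indicates the trend.

```python
def conv_macs(h: int, w: int, c_in: int, c_out: int, k: int) -> int:
    """Multiply-accumulates of a dense k x k convolution producing an h x w output."""
    return h * w * c_in * c_out * k * k


macs_os16 = conv_macs(33, 33, 2048, 256, 3)  # 513x513 input, output stride 16
macs_os8 = conv_macs(65, 65, 2048, 256, 3)   # same input, output stride 8
print(f"OS16: {macs_os16 / 1e9:.2f} GMACs, OS8: {macs_os8 / 1e9:.2f} GMACs, "
      f"ratio {macs_os8 / macs_os16:.2f}")
```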

<center>

![](https://bbsmax.ikafan.com/static/L3Byb3h5L2h0dHBzL2ltZzIwMTguY25ibG9ncy5jb20vYmxvZy8xNTE5NTc4LzIwMTkwNS8xNTE5NTc4LTIwMTkwNTE5MTUzODU3NDk2LTEzMjYzOTcyNjMucG5n.jpg)

![](https://bbsmax.ikafan.com/static/L3Byb3h5L2h0dHBzL2ltZzIwMTguY25ibG9ncy5jb20vYmxvZy8xNTE5NTc4LzIwMTkwNS8xNTE5NTc4LTIwMTkwNTE5MTUzODU2OTQ5LTExMTM0NzgwODgucG5n.jpg)

Figure: DeepLabV3+
</center>

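- Xception was designed for image classification and Aligned Xception for object detection; the authors modify Xception further so that it suits semantic segmentation:
    1. A deeper network, but the entry flow structure is left unmodified for fast computation and memory efficiency.
    2. All max pooling operations are replaced by strided depthwise separable convolutions, so that atrous separable convolution can be used to extract features at an arbitrary resolution.
    3. As in MobileNet, extra BN and ReLU are added after each $3 \times 3$ depthwise convolution.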
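
A minimal PyTorch sketch of an atrous separable convolution as described in the list above: a $3 \times 3$ depthwise convolution (`groups=in_ch`) with dilation, followed by BN and ReLU as in item 3, then a $1 \times 1$ pointwise convolution. The class name `AtrousSeparableConv` and the BN/ReLU after the pointwise convolution are assumptions for illustration, not the paper's released code. The parameter counts at the end show where the savings come from.

```python
import torch
import torch.nn as nn


class AtrousSeparableConv(nn.Module):
    """Depthwise 3x3 atrous conv -> BN -> ReLU -> pointwise 1x1 conv -> BN -> ReLU."""

    def __init__(self, in_ch: int, out_ch: int, dilation: int = 1, stride: int = 1):
        super().__init__()
        # groups=in_ch: every input channel gets its own 3x3 filter (depthwise)
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=stride,
                                   padding=dilation, dilation=dilation,
                                   groups=in_ch, bias=False)
        self.bn1 = nn.BatchNorm2d(in_ch)
        # 1x1 conv mixes information across channels (pointwise)
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.bn1(self.depthwise(x)))
        return self.relu(self.bn2(self.pointwise(x)))


def count_conv_params(module: nn.Module) -> int:
    """Number of parameters held by Conv2d layers inside `module`."""
    return sum(p.numel() for m in module.modules() if isinstance(m, nn.Conv2d)
               for p in m.parameters())


sep = AtrousSeparableConv(256, 256, dilation=12)
dense = nn.Conv2d(256, 256, kernel_size=3, padding=12, dilation=12, bias=False)
x = torch.rand(2, 256, 33, 33)
print(sep(x).shape)                                      # torch.Size([2, 256, 33, 33])
print(count_conv_params(sep), count_conv_params(dense))  # 67840 589824
```

The separable version needs $256 \cdot 9 + 256 \cdot 256$ weights instead of $256 \cdot 256 \cdot 9$, roughly 8.7 times fewer, while keeping the same dilation and output size.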

<center>

![](https://bbsmax.ikafan.com/static/L3Byb3h5L2h0dHBzL2ltZzIwMTguY25ibG9ncy5jb20vYmxvZy8xNTE5NTc4LzIwMTkwNS8xNTE5NTc4LTIwMTkwNTE5MTUzODU2MzgyLTE4MzA5MTEzNjgucG5n.jpg)

Modified Aligned Xception

</center>
    

# Implementing DeepLabV3+

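```python
import os

# Switch to the project root, which provides the `sync_batchnorm` and `backbone` modules used below
os.chdir('../../../ImageSegmentation_Review/')
```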

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from sync_batchnorm.batchnorm import SynchronizedBatchNorm2d  # BN synchronized across GPUs
from backbone import build_backbone  # backbone whose forward returns (high-level, low-level) features
```

## ASPP module

```python
class _ASPPModule(nn.Module):
    """One ASPP branch: atrous conv -> BN -> ReLU."""
    def __init__(self, in_ch, out_ch, kernel_size, padding, dilation, BatchNorm):
        super(_ASPPModule, self).__init__()

        self.atrous_conv = nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=1,
                                     padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm(out_ch)  # use the BN class passed in (synchronized or plain)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        return self.relu(self.bn(self.atrous_conv(x)))

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    def __init__(self, backbone, output_stride, BatchNorm):
        super(ASPP, self).__init__()
        # Number of channels of the backbone's last feature map
        if backbone == 'drn':
            inplanes = 512
        elif backbone == 'mobilenet':
            inplanes = 320
        else:
            inplanes = 2048

        if output_stride == 16:   # faster
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:  # more accurate; rates doubled to match the 2x larger map
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        # Four parallel branches: one 1x1 conv and three 3x3 atrous convs;
        # padding == dilation keeps the spatial size unchanged
        self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
        self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)

        # Image-level features: global average pooling -> 1x1 conv -> BN -> ReLU
        # (in training mode, plain BatchNorm2d on a 1x1 map needs batch size > 1)
        self.global_img_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, 256, kernel_size=1, stride=1, padding=0, bias=False),
            BatchNorm(256),
            nn.ReLU())

        # Fuse the 5 concatenated branches (5 * 256 channels) back to 256 channels
        self.conv1_1 = nn.Conv2d(256 * 5, 256, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = BatchNorm(256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        o1 = self.aspp1(x)
        o2 = self.aspp2(x)
        o3 = self.aspp3(x)
        o4 = self.aspp4(x)
        o5 = self.global_img_avg_pool(x)
        # Upsample the 1x1 image-level feature back to the branch resolution
        o5 = F.interpolate(o5, size=o4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((o1, o2, o3, o4, o5), dim=1)
        out = self.relu(self.bn1(self.conv1_1(x)))
        return self.dropout(out)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

# Factory function
def build_aspp(backbone, output_stride, BatchNorm):
    return ASPP(backbone, output_stride, BatchNorm)
```
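
For intuition about the rates chosen above: a $k \times k$ kernel with dilation $r$ spans $k + (k-1)(r-1)$ feature-map pixels per side, i.e. $2r + 1$ for $k = 3$. Multiplying by the output stride gives a rough extent in input pixels; this ignores the receptive field the backbone has already accumulated, so it is only an approximation. The helper `effective_kernel` is a hypothetical name.

```python
def effective_kernel(k: int, rate: int) -> int:
    """Side length covered by a k x k kernel with the given dilation rate."""
    return k + (k - 1) * (rate - 1)


for output_stride, rates in ((16, (6, 12, 18)), (8, (12, 24, 36))):
    for r in rates:
        span = effective_kernel(3, r)
        print(f"OS{output_stride:2d} rate {r:2d}: {span:3d} feature px "
              f"~ {span * output_stride:4d} input px")
```

Doubling the rates at output stride 8 keeps each branch covering roughly the same input extent as at stride 16. At rate 18 on the 33×33 map of a 513×513 input, the 37-pixel span already exceeds the map, which is exactly the regime where many taps fall on padding and the image-level branch becomes necessary.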

## Decoder module

```python
class Decoder(nn.Module):
    def __init__(self, num_classes, backbone, BatchNorm):
        super(Decoder, self).__init__()

        # Number of channels of the low-level (stride 4) feature of each backbone
        if backbone == 'resnet' or backbone == 'drn':
            low_level_inplanes = 256
        elif backbone == 'xception':
            low_level_inplanes = 128
        elif backbone == 'mobilenet':
            low_level_inplanes = 24
        else:
            raise NotImplementedError

        # 1x1 conv reduces the low-level feature to 48 channels
        self.conv1 = nn.Conv2d(low_level_inplanes, 48, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(48)
        self.relu = nn.ReLU()

        # Two 3x3 convs refine the concatenated (256 + 48)-channel feature, then a 1x1 classifier
        self.last_conv = nn.Sequential(
            nn.Conv2d(48 + 256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1)
        )

        self._init_weight()

    def forward(self, x, low_level_features):
        # conv -> BN -> ReLU on the low-level features
        low_level_features = self.relu(self.bn1(self.conv1(low_level_features)))

        # Upsample the ASPP output to the low-level resolution, then concatenate
        x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x, low_level_features), dim=1)
        return self.last_conv(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

def build_decoder(num_classes, backbone, BatchNorm):
    return Decoder(num_classes, backbone, BatchNorm)
```

## Assembling DeepLabV3+

```python
class DeepLabV3_Plus(nn.Module):
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        super(DeepLabV3_Plus, self).__init__()
        if backbone == 'drn':
            output_stride = 8  # DRN is always used with output stride 8 here

        # Synchronized BN for multi-GPU training, plain BN otherwise
        BatchNorm = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)  # encoder: backbone
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)          # encoder: ASPP
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        # Stored under its own name: `self.freeze_bn = freeze_bn` would shadow the freeze_bn() method
        self.freeze_bn_enabled = freeze_bn

    def forward(self, inputs):
        # High-level feature (stride 16 or 8) and low-level feature (stride 4)
        x, low_level_feat = self.backbone(inputs)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        # Final bilinear upsampling back to the input resolution
        return F.interpolate(x, size=inputs.size()[2:], mode='bilinear', align_corners=True)

    def freeze_bn(self):
        # Keep BN layers in eval mode so their running statistics are not updated
        for m in self.modules():
            if isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.eval()

    def _collect_params(self, modules):
        # Conv parameters always; BN affine parameters only when BN is not frozen
        if self.freeze_bn_enabled:
            layer_types = (nn.Conv2d,)
        else:
            layer_types = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for module in modules:
            for _, m in module.named_modules():
                if isinstance(m, layer_types):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def get_1x_lr_params(self):
        # Pretrained backbone: base learning rate
        return self._collect_params([self.backbone])

    def get_10x_lr_params(self):
        # Newly initialized ASPP and decoder: 10x the base learning rate
        return self._collect_params([self.aspp, self.decoder])
```
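
`get_1x_lr_params` and `get_10x_lr_params` are meant to become two optimizer parameter groups, so that the pretrained backbone is fine-tuned with the base learning rate while the freshly initialized ASPP and decoder learn 10 times faster. The sketch shows that wiring on a hypothetical toy model (`ToyNet`, whose `backbone` and `head` stand in for the real modules); the SGD hyperparameters are placeholders, not values from this post.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical stand-in: a 'pretrained' backbone plus a newly added head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
        self.head = nn.Conv2d(8, 2, 1)

    def get_1x_lr_params(self):
        return self.backbone.parameters()

    def get_10x_lr_params(self):
        return self.head.parameters()


base_lr = 0.01  # placeholder value
model = ToyNet()
optimizer = torch.optim.SGD(
    [
        {"params": model.get_1x_lr_params(), "lr": base_lr},        # backbone group
        {"params": model.get_10x_lr_params(), "lr": base_lr * 10},  # new-layer group
    ],
    momentum=0.9,
    weight_decay=5e-4,
)
for group in optimizer.param_groups:
    print(group["lr"], sum(p.numel() for p in group["params"]))
```

With the real model, `DeepLabV3_Plus.get_1x_lr_params()` and `get_10x_lr_params()` would be passed in place of the toy methods.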


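```python
model = DeepLabV3_Plus(backbone='mobilenet', output_stride=16)
model.eval()  # inference mode: BN uses running statistics, dropout is disabled
image = torch.rand(1, 3, 513, 513)
with torch.no_grad():
    output = model(image)
print(output.size())
```

Output:

```
torch.Size([1, 21, 513, 513])
```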
