# 1.模型架构

![img](./pic/deeplabv3plus.png)

- Encoder-Decoder架构
- 全卷积 (FCN) 结构，输出与输入的空间尺寸相同
- Encoder部分包含Backbone、ASPP两个模块
- Decoder部分结合Low-level Feature (Backbone中间输出) 与 High-Level Feature (ASPP后输出)

In [52]:
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.models import resnet
from torchvision.models.resnet import BasicBlock, Bottleneck

## 1.1 Backbone
- Backbone部分常用的为MobileNet、ResNet、Xception
- 这里仅实现ResNet50作为示例
- 为便于看出区别，这里直接继承torchvision中的标准ResNet类，重载前向函数，额外输出一个low-level feature
- 原始ResNet的下采样倍数为32，通过参数replace_stride_with_dilation将后几个stage中步长为2的下采样替换为膨胀（空洞）卷积，从而使下采样倍数变为8或16

In [53]:
class Resnet_astrous(resnet.ResNet):
    """torchvision ResNet whose forward also returns the layer1 feature map.

    Serves as the DeepLabV3+ encoder backbone: the ``layer1`` output is the
    low-level feature for the decoder and the ``layer4`` output feeds ASPP.
    The ``avgpool``/``fc`` head built by the parent class is never used in
    ``forward``.

    Every constructor argument is forwarded unchanged to
    ``torchvision.models.resnet.ResNet``.
    """

    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm_layer=None,
    ) -> None:
        super().__init__(
            block,
            layers,
            num_classes=num_classes,
            zero_init_residual=zero_init_residual,
            groups=groups,
            width_per_group=width_per_group,
            replace_stride_with_dilation=replace_stride_with_dilation,
            norm_layer=norm_layer,
        )

    def forward(self, x):
        """Return ``(low_level_features, high_level_features)``.

        The stem (conv1 + maxpool) reduces resolution to 1/4 of the input, and
        ``layer1`` keeps it there, so the low-level feature is at stride 4.
        The stride of the high-level output depends on
        ``replace_stride_with_dilation``. Each ``True`` entry makes the
        corresponding stage (layer2/3/4) dilate instead of halving resolution.
        This gives 1/16 for ``[False, False, True]`` and 1/8 for
        ``[False, True, True]``.
        """
        # Stem: conv1 (stride 2) -> BN -> ReLU -> maxpool (stride 2).
        for op in (self.conv1, self.bn1, self.relu, self.maxpool):
            x = op(x)

        low_level = self.layer1(x)  # stride 4

        high_level = low_level
        for stage in (self.layer2, self.layer3, self.layer4):
            high_level = stage(high_level)

        return low_level, high_level

def load_resnet50(downsamperFactor=16):
    """Build a dilated ResNet-50 (``Resnet_astrous``) with the given output stride.

    Args:
        downsamperFactor: total downsampling of the high-level output; 16 or 8.

    Returns:
        A freshly initialised ``Resnet_astrous`` with Bottleneck blocks and
        ``[3, 4, 6, 3]`` layers (the ResNet-50 layout).

    Raises:
        ValueError: for any other factor. Previously the function silently
            returned ``None``, which only failed later with an unrelated error.
    """
    # For each output stride, which of layer2/layer3/layer4 replace their
    # stride-2 downsampling with dilation.
    dilation_schedules = {
        16: [False, False, True],
        8: [False, True, True],
    }
    if downsamperFactor not in dilation_schedules:
        raise ValueError(
            f"downsamperFactor must be one of {sorted(dilation_schedules)}, "
            f"got {downsamperFactor!r}"
        )
    return Resnet_astrous(
        block=resnet.Bottleneck,
        layers=[3, 4, 6, 3],
        replace_stride_with_dilation=dilation_schedules[downsamperFactor],
    )
    

## 1.2 ASPP模块
![img](./pic/ASPP.png)
- 利用不同膨胀率（空洞率）的卷积并行提取特征，实现多尺度特征提取

In [54]:
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head (DeepLabV3 / V3+).

    Five parallel branches see the same input. Their outputs are concatenated
    along channels and fused by a 1x1 conv:

    * branch1: 1x1 conv.
    * branch2/3/4: 3x3 convs with dilation 6, 12 and 18, each multiplied by
      ``rate``. Padding equals dilation, so the spatial size is preserved.
    * branch5: global average pooling + 1x1 conv (image-level feature),
      upsampled back to the input's spatial size.

    Every branch is followed by BatchNorm + ReLU.

    Args:
        dim_in: number of input channels.
        dim_out: channels produced by each branch and by the fused output.
        rate: multiplier applied to the base dilation rates (6, 12, 18), e.g.
            ``16 // output_stride`` so that output stride 8 doubles them.

    Shape: ``(N, dim_in, H, W)`` -> ``(N, dim_out, H, W)``.

    NOTE: branch5 applies BatchNorm to a 1x1 map. In training mode a batch of
    size 1 therefore has a single value per channel, and BatchNorm raises an
    error. Train with batch size > 1.
    """

    def __init__(self, dim_in, dim_out, rate=1):
        super(ASPP, self).__init__()
        # Submodules are created in the same order as before, so parameter
        # order, state_dict keys and RNG-driven initialisation are unchanged.
        # A 1x1 kernel is unaffected by dilation, so none is set here.
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out),
            nn.ReLU(inplace=True),
        )
        self.branch2 = self._atrous_branch(dim_in, dim_out, 6 * rate)
        self.branch3 = self._atrous_branch(dim_in, dim_out, 12 * rate)
        self.branch4 = self._atrous_branch(dim_in, dim_out, 18 * rate)
        self.branch5 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim_in, dim_out, 1, bias=False),
            nn.BatchNorm2d(dim_out),
            nn.ReLU(),
        )

        # Fuse the five concatenated branch outputs back to dim_out channels.
        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _atrous_branch(dim_in, dim_out, dilation):
        """Build a 3x3 dilated conv + BN + ReLU that preserves spatial size."""
        return nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=dilation, dilation=dilation, bias=True),
            nn.BatchNorm2d(dim_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        row, col = x.shape[-2:]
        # Branches 1-4: the 1x1 conv and the three atrous 3x3 convs (all H x W).
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        # Branch 5: the 1x1 image-level feature, upsampled back to H x W.
        global_feature = self.branch5(x)
        global_feature = F.interpolate(
            global_feature, size=(row, col), mode="bilinear", align_corners=True
        )
        # Stack the five branches on the channel axis, then fuse with a 1x1 conv.
        feature_cat = torch.cat(
            [conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1
        )
        return self.conv_cat(feature_cat)

## 1.3 Encoder-Decoder
- High-level Feature通过双线性插值上采样，与Low-level Feature的空间尺寸对齐
- 经过卷积层与分类头获得每个像素的类别得分（logits，未经softmax）
- 最后双线性插值实现上采样
- 双线性插值可替换为转置卷积实现更高精度的边界分割

In [55]:
class DeepLab(nn.Module):
    """DeepLabV3+ semantic segmentation network with a ResNet-50 backbone.

    Encoder: a dilated ResNet-50 followed by ASPP.

    Decoder:
    1. The ASPP output is bilinearly upsampled to the resolution of the
       low-level (layer1) feature.
    2. It is concatenated with a 48-channel projection of that feature.
    3. Two 3x3 convs refine the result and a 1x1 conv classifies it.
    4. The class map is upsampled to the input resolution.

    Args:
        num_classes: number of output channels (one score map per class).
        downsample_factor: output stride of the backbone; 8 or 16.

    Raises:
        ValueError: if ``downsample_factor`` is not 8 or 16. Other values
            previously produced a ``None`` backbone and a zero-dilation ASPP,
            which only failed later with a confusing error.
    """

    def __init__(self, num_classes, downsample_factor=16):
        super(DeepLab, self).__init__()
        if downsample_factor not in (8, 16):
            raise ValueError(
                f"downsample_factor must be 8 or 16, got {downsample_factor!r}"
            )
        self.num_classes = num_classes
        #-----------------------------------------#
        #   Backbone
        #-----------------------------------------#
        self.backbone = load_resnet50(downsamperFactor=downsample_factor)
        in_channels = 2048        # channels of the ResNet-50 layer4 output
        low_level_channels = 256  # channels of the ResNet-50 layer1 output
        #-----------------------------------------#
        #   ASPP: output stride 8 doubles the atrous rates (16 // 8 == 2)
        #-----------------------------------------#
        self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16 // downsample_factor)
        #-----------------------------------------#
        #   Low-level branch: project to 48 channels (as in DeepLabV3+)
        #   so it does not dominate the 256-channel ASPP feature after concat
        #-----------------------------------------#
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )

        # Refine the concatenated (256 ASPP + 48 low-level) feature.
        self.cat_conv = nn.Sequential(
            nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Dropout(0.1),
        )
        # Per-pixel classifier producing one score map per class.
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)

    def forward(self, x):
        """Return per-pixel class scores for ``x``.

        ``x`` is assumed to be an image batch ``(N, 3, H, W)``; the backbone's
        first conv expects 3 input channels.

        Returns:
            Tensor ``(N, num_classes, H, W)`` of raw logits (no softmax).
        """
        H, W = x.size(2), x.size(3)
        #-----------------------------------------#
        #   Two feature maps from the backbone:
        #   low_level_features: shallow (layer1) feature, projected by a 1x1 conv
        #   x: deep feature, enhanced by ASPP
        #-----------------------------------------#
        low_level_features, x = self.backbone(x)
        x = self.aspp(x)
        low_level_features = self.shortcut_conv(low_level_features)
        #-----------------------------------------#
        #   Upsample the ASPP output to the low-level resolution,
        #   concatenate, refine, classify, then restore input size
        #-----------------------------------------#
        x = F.interpolate(
            x,
            size=(low_level_features.size(2), low_level_features.size(3)),
            mode='bilinear',
            align_corners=True,
        )
        x = self.cat_conv(torch.cat((x, low_level_features), dim=1))
        x = self.cls_conv(x)
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
        return x

In [56]:
# Smoke test: a 20-class model on one random 512x512 RGB image.
model = DeepLab(num_classes=20)
model.eval()
tensor = torch.randn((1, 3, 512, 512))
# Inference only: no_grad skips building the autograd graph and saves memory.
with torch.no_grad():
    print(model(tensor).shape)
# The output holds per-pixel class scores (logits, no softmax applied),
# with shape (N, num_classes, H, W).

torch.Size([1, 20, 512, 512])
