![DeepLab V3+ overall architecture](../images/Deeplab_overall.png)

Code adapted from: https://github.com/f1recracker/pytorch-deeplab-v3-plus/tree/07141599f1a7466f484a8aa47593da59aad9bb4f/model

## Backbone (Xception)

![Xception architecture](../images/Exception_Architecture.png)

In [None]:
class Xception(BackboneModule):
    ''' Xception feature extractor backbone for DeepLab V3+.

    Entry flow (stem + 3 blocks), 16 middle-flow blocks and an exit flow.
    The entry flow is split in two so that ``forward`` can additionally
    return the low-level features consumed by the DeepLab V3+ decoder.

    Args:
        output_stride: ratio of the input resolution to the resolution of
            the final feature map; must be 8 or 16. Later blocks replace
            their stride-2 downsampling with dilation ("atrous" mode) to
            reach the requested value.

    Raises:
        ValueError: if ``output_stride`` is not 8 or 16.
    '''

    class EntryFlowBlock(nn_ext.SkipBlock):
        ''' Xception entry flow block.

        Main path: three 3x3 separable conv layers. Skip path: a 1x1 conv
        layer. The two outputs are summed.

        Args:
            in_channels: channels of the block input.
            out_channels: channels produced by every conv in the block.
            dilation: dilation of the first two separable convolutions.
            atrous: if True, the last separable conv and the skip conv keep
                stride 1 and use twice ``dilation`` instead of downsampling
                by 2.
        '''

        def __init__(self, in_channels, out_channels, dilation=1, atrous=False):
            # Atrous mode trades the stride-2 downsampling for doubled dilation.
            out_dilation = 2 * dilation if atrous else dilation
            # padding == dilation keeps a 3x3 conv "same"-sized (before stride).
            last_padding = out_dilation
            stride = 1 if atrous else 2
            super().__init__(
                main_path=nn.Sequential(
                    nn_ext.SeparableConv2dLayer(in_channels, out_channels, kernel_size=3,
                                                dilation=dilation, padding=1 * dilation),
                    nn_ext.SeparableConv2dLayer(out_channels, out_channels, kernel_size=3,
                                                dilation=dilation, padding=1 * dilation),
                    nn_ext.SeparableConv2dLayer(out_channels, out_channels, kernel_size=3,
                                                dilation=out_dilation, stride=stride,
                                                padding=last_padding)),
                skip_path=nn_ext.Conv2dLayer(in_channels, out_channels, kernel_size=1,
                                             dilation=out_dilation, stride=stride),
                aggregator=torch.add)

    class MiddleFlowBlock(nn_ext.SkipBlock):
        ''' Xception middle flow block.

        Three 3x3 separable conv layers whose output is added to the block
        input (skip path uses ``SkipBlock``'s default). Spatial size is
        preserved because padding equals dilation and stride is 1.
        '''

        def __init__(self, in_channels=728, out_channels=728, dilation=1):
            super().__init__(
                main_path=nn.Sequential(
                    nn_ext.SeparableConv2dLayer(in_channels, out_channels, kernel_size=3,
                                                dilation=dilation, padding=1 * dilation),
                    nn_ext.SeparableConv2dLayer(out_channels, out_channels, kernel_size=3,
                                                dilation=dilation, padding=1 * dilation),
                    nn_ext.SeparableConv2dLayer(out_channels, out_channels, kernel_size=3,
                                                dilation=dilation, padding=1 * dilation)),
                aggregator=torch.add)

    class ExitFlowBlock(nn_ext.SkipBlock):
        ''' Xception exit flow block (728 -> 1024 channels by default).

        Main path: separable convs in->in, in->out, out->out (the last one
        strided or dilated). Skip path: a 1x1 conv layer. Outputs are summed.

        Args:
            in_channels: channels of the block input.
            out_channels: channels of the block output.
            dilation: dilation of the first two separable convolutions.
            atrous: if True, the last conv and the skip conv keep stride 1
                and use twice ``dilation`` instead of downsampling by 2.
        '''

        def __init__(self, in_channels=728, out_channels=1024, dilation=1, atrous=False):
            # Atrous mode trades the stride-2 downsampling for doubled dilation.
            out_dilation = 2 * dilation if atrous else dilation
            last_padding = out_dilation
            stride = 1 if atrous else 2
            super().__init__(
                main_path=nn.Sequential(
                    nn_ext.SeparableConv2dLayer(in_channels, in_channels, kernel_size=3,
                                                dilation=dilation, padding=1 * dilation),
                    nn_ext.SeparableConv2dLayer(in_channels, out_channels, kernel_size=3,
                                                dilation=dilation, padding=1 * dilation),
                    nn_ext.SeparableConv2dLayer(out_channels, out_channels, kernel_size=3,
                                                dilation=out_dilation, stride=stride,
                                                padding=last_padding)),
                skip_path=nn_ext.Conv2dLayer(in_channels, out_channels, kernel_size=1,
                                             dilation=out_dilation, stride=stride),
                aggregator=torch.add)

    def __init__(self, output_stride=16):

        if output_stride not in {8, 16}:
            raise ValueError('Invalid output_stride; Supported values: {8, 16}')
        # Low-level features come out of the first entry-flow split:
        # after block_1 (128 channels) for OS=16, after block_0 (64) for OS=8.
        low_out_channels = 128 if output_stride == 16 else 64
        # Report the requested output_stride: DeepLab.ASPP reads
        # backbone.output_stride to choose its atrous rates.
        super().__init__(output_stride=output_stride, out_channels=2048,
                         low_out_channels=low_out_channels)

        # Per-block keyword overrides that set stride/dilation so the network
        # reaches output_stride. Blocks without an entry get {} (defaults).
        opts = defaultdict(dict)
        if output_stride == 16:
            opts['exit_flow_block_0'] = {'atrous': True}
            opts['exit_flow_block_1'] = {'dilation': 2, 'padding': 2}
        else:  # output_stride == 8 (validated above)
            opts['entry_flow_block_3'] = {'atrous': True}
            opts['middle_flow_block'] = {'dilation': 2}
            opts['exit_flow_block_0'] = {'dilation': 2, 'atrous': True}
            opts['exit_flow_block_1'] = {'dilation': 4, 'padding': 4}

        entry_flow = nn.Sequential(OrderedDict([
            ('block_0', nn.Sequential(
                nn_ext.Conv2dLayer(3, 32, kernel_size=3, stride=2, padding=1),
                nn_ext.Conv2dLayer(32, 64, kernel_size=3, padding=1))),
            ('block_1', Xception.EntryFlowBlock(64, 128, **opts['entry_flow_block_1'])),
            ('block_2', Xception.EntryFlowBlock(128, 256, **opts['entry_flow_block_2'])),
            ('block_3', Xception.EntryFlowBlock(256, 728, **opts['entry_flow_block_3'])),
        ]))

        # [DeepLabV3+ specific] Split the entry flow so forward() can return the
        # low-level features (at stride output_stride // 4) needed by the decoder.
        split = 2 if output_stride == 16 else 1
        self.entry_flow = nn.ModuleList([entry_flow[0:split], entry_flow[split:4]])

        # 16 identical middle-flow blocks named block_0 .. block_15.
        self.middle_flow = nn.Sequential(OrderedDict(
            ('block_%d' % i, Xception.MiddleFlowBlock(**opts['middle_flow_block']))
            for i in range(16)))

        self.exit_flow = nn.Sequential(OrderedDict([
            ('block_0', Xception.ExitFlowBlock(**opts['exit_flow_block_0'])),
            ('block_1', nn.Sequential(
                nn_ext.SeparableConv2dLayer(1024, 1536, kernel_size=3, **opts['exit_flow_block_1']),
                nn_ext.SeparableConv2dLayer(1536, 1536, kernel_size=3, **opts['exit_flow_block_1']),
                nn_ext.SeparableConv2dLayer(1536, 2048, kernel_size=3, **opts['exit_flow_block_1'])))
        ]))

    def forward(self, x):
        ''' Return ``(features, low_level_features)``.

        ``features`` has 2048 channels at stride ``output_stride``;
        ``low_level_features`` has ``low_out_channels`` channels at stride
        ``output_stride // 4``.
        '''
        # Stem (+ entry block_1 when output_stride == 16) -> low-level features
        x_low = self.entry_flow[0](x)
        # Remaining entry-flow blocks, ending at 728 channels
        x = self.entry_flow[1](x_low)
        # 16 middle-flow blocks (with skip connections)
        x = self.middle_flow(x)
        # Exit block (1024, with skip connection) -> separable convs 1536 -> 1536 -> 2048
        x = self.exit_flow(x)
        return x, x_low

## DeepLab V3+

![ASPP](../images/ASPP.png)

In [None]:
class DeepLab(nn.Module):
    ''' DeepLab V3+ semantic segmentation model.

    Pipeline: backbone -> ASPP on the high-level features -> decoder that
    fuses them with the backbone's low-level features -> bilinear upsampling
    of the logits to the input's spatial size.

    Args:
        backbone: a ``BackboneModule`` whose ``forward`` returns
            ``(features, low_level_features)`` and which exposes
            ``out_channels``, ``low_out_channels`` and ``output_stride``.
        num_classes: number of channels in the returned logits.

    Raises:
        RuntimeError: if ``backbone`` is not a ``BackboneModule``.
    '''

    class ASPP(nn.Module):
        ''' Atrous spatial pyramid pooling module.

        Four parallel conv branches (1x1, and 3x3 with dilation 6/12/18
        scaled by the output stride) plus an image-pooling branch. Their
        256-channel outputs are concatenated and projected back to 256
        channels by a 1x1 conv layer.

        Args:
            in_channels: channels of the input feature map.
            output_stride: backbone output stride; 8 doubles every atrous
                rate relative to 16.

        Raises:
            ValueError: if ``output_stride`` is not 8 or 16.
        '''

        def __init__(self, in_channels, output_stride=16):
            super().__init__()

            if output_stride not in {8, 16}:
                raise ValueError('Invalid output_stride; Supported values: {8, 16}')
            # A smaller output stride means a finer feature map, so the atrous
            # rates are doubled to cover the same image-space field of view.
            dilation_factor = 1 if output_stride == 16 else 2

            self.aspp = nn.ModuleList([
                nn_ext.Conv2dLayer(in_channels, 256, kernel_size=1, dilation=1),
                nn_ext.Conv2dLayer(in_channels, 256, kernel_size=3,
                                   dilation=6 * dilation_factor, padding=6 * dilation_factor),
                nn_ext.Conv2dLayer(in_channels, 256, kernel_size=3,
                                   dilation=12 * dilation_factor, padding=12 * dilation_factor),
                nn_ext.Conv2dLayer(in_channels, 256, kernel_size=3,
                                   dilation=18 * dilation_factor, padding=18 * dilation_factor)])

            # Image-level features: global average -> 1x1 conv -> ReLU.
            # NOTE(review): no batch norm here, presumably to avoid BN on 1x1
            # maps (unstable for small batches) — confirm intent.
            self.global_avg_pool = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Conv2d(in_channels, 256, kernel_size=1),
                nn.ReLU(inplace=True))

            # 4 conv branches + 1 pooling branch, 256 channels each.
            self.output_conv = nn_ext.Conv2dLayer(256 * 4 + 256, 256, kernel_size=1)

        def forward(self, x):
            x_aspp = (aspp(x) for aspp in self.aspp)
            x_pool = self.global_avg_pool(x)
            # x_pool is 1x1, so the default (nearest) interpolation simply
            # broadcasts it to the feature map's spatial size.
            x_pool = nn_func.interpolate(x_pool, size=x.shape[2:4])
            feats = torch.cat((*x_aspp, x_pool), dim=1)
            feats = self.output_conv(feats)
            return feats

    class Decoder(nn.Module):
        ''' DeepLab V3+ decoder module.

        Projects the low-level features to 48 channels, upsamples the
        256-channel ASPP output to their spatial size, concatenates both and
        predicts ``num_classes`` logits with a single 3x3 convolution.
        '''

        def __init__(self, low_in_channels, num_classes):
            super().__init__()

            self.conv_low = nn_ext.Conv2dLayer(low_in_channels, 48, kernel_size=1)
            self.conv_logit = nn.Conv2d(48 + 256, num_classes, kernel_size=3, padding=1)

        def forward(self, feats, low_feats):
            low_feats = self.conv_low(low_feats)
            feats = nn_func.interpolate(feats, size=low_feats.shape[2:4],
                                        mode='bilinear', align_corners=True)
            feats = torch.cat((feats, low_feats), dim=1)
            logits = self.conv_logit(feats)
            return logits

    # -- Deeplab part
    def __init__(self, backbone, num_classes):
        super().__init__()
        if not isinstance(backbone, BackboneModule):
            raise RuntimeError('Backbone must extend model.backbone.BackboneModule')

        self.backbone = backbone
        self.aspp = DeepLab.ASPP(in_channels=backbone.out_channels,
                                 output_stride=backbone.output_stride)
        self.decoder = DeepLab.Decoder(low_in_channels=backbone.low_out_channels,
                                       num_classes=num_classes)

        self._init_weights()

    def forward(self, x_in):
        ''' Return per-pixel logits with the same spatial size as ``x_in``. '''
        x, x_low = self.backbone(x_in)
        x = self.aspp(x)
        logits = self.decoder(x, x_low)
        logits = nn_func.interpolate(logits, size=x_in.shape[2:4],
                                     mode='bilinear', align_corners=True)
        return logits

    def _init_weights(self):
        ''' Initializes weights of the model.
            - Conv2d weights initialized using Kaiming normal (biases untouched)
            - Batchnorm affine parameters initialized as Identity

            ``self.modules()`` includes the backbone, so any weights it held
            before being wrapped are overwritten; load pretrained weights
            after constructing this module.
        '''
        for module in self.modules():
            if isinstance(module, torch.nn.modules.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            elif isinstance(module, torch.nn.modules.BatchNorm2d):
                torch.nn.init.constant_(module.weight, 1.0)
                torch.nn.init.constant_(module.bias, 0.0)

## Extension layers (`nn_ext`)

In [None]:
class SeparableConv2d(nn.Module):
    ''' Depthwise separable 2D convolution.

    A per-channel (depthwise, ``groups=in_channels``) convolution followed by
    a 1x1 (pointwise) convolution that mixes channels.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by the pointwise stage.
        kernel_size, stride, padding, dilation: applied to the depthwise
            stage; the pointwise stage is always 1x1.
        bias: whether the convolutions add a bias; applies to BOTH stages.
        _depthwise_conv, _pointwise_conv: factories used to build each stage
            (``nn.Conv2d`` by default). They must accept the keyword
            arguments passed below, including ``bias``.
    '''

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True,
                 _depthwise_conv=nn.Conv2d, _pointwise_conv=nn.Conv2d):
        super().__init__()
        self.depthwise_conv = _depthwise_conv(
            in_channels, in_channels, kernel_size=kernel_size,
            groups=in_channels, stride=stride, padding=padding,
            dilation=dilation, bias=bias)
        # Forward `bias` to the pointwise stage as well, so bias=False yields
        # a completely bias-free layer.
        self.pointwise_conv = _pointwise_conv(
            in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x


class Conv2dLayer(nn.Module):
    ''' 2D convolution followed by batch norm and an in-place ReLU.

    Args:
        in_channels, out_channels, kernel_size, groups, stride, padding,
        dilation, bias: forwarded to ``nn.Conv2d``.
        batchnorm_opts: keyword arguments for ``nn.BatchNorm2d``. When
            omitted (or None) defaults to ``{'eps': 1e-3, 'momentum': 3e-4}``.
    '''

    def __init__(self, in_channels, out_channels, kernel_size,
                 groups=1, stride=1, padding=0, dilation=1, bias=True,
                 batchnorm_opts=None):
        super().__init__()
        # Build the default per call instead of sharing one mutable dict
        # default across every instance. In PyTorch, BatchNorm momentum is the
        # weight of the current batch in the running statistics, so 3e-4
        # updates them very slowly.
        if batchnorm_opts is None:
            batchnorm_opts = {'eps': 1e-3, 'momentum': 3e-4}
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      groups=groups, stride=stride, padding=padding,
                      dilation=dilation, bias=bias),
            nn.BatchNorm2d(out_channels, **batchnorm_opts),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)


class SeparableConv2dLayer(SeparableConv2d):
    ''' Depthwise separable 2D convolution whose depthwise and pointwise
    stages are each a ``Conv2dLayer`` (conv -> batch norm -> ReLU). '''

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        # Reuse SeparableConv2d's wiring, but build both stages from the
        # conv + BN + ReLU layer instead of a bare convolution.
        stage_factory = Conv2dLayer
        conv_kwargs = {
            'stride': stride,
            'padding': padding,
            'dilation': dilation,
            'bias': bias,
        }
        super().__init__(in_channels, out_channels, kernel_size=kernel_size,
                         _depthwise_conv=stage_factory,
                         _pointwise_conv=stage_factory,
                         **conv_kwargs)


class SkipBlock(nn.Module):
    ''' Container for modules with a skip connection, followed by an aggregator.

    ``forward(x)`` returns ``aggregator(main_path(x), skip_path(x))``.

    Args:
        main_path: module applied on the main branch.
        skip_path: module applied on the skip branch; a fresh ``Identity()``
            per block when omitted.
        aggregator: callable combining the two branch outputs, e.g.
            ``torch.add``. Defaults to concatenation along dim 1 (the channel
            dim for NCHW tensors). ``torch.cat`` itself cannot be used here:
            it expects a sequence of tensors, so ``torch.cat(a, b)`` would
            treat ``b`` as ``dim`` and raise.
    '''

    def __init__(self, main_path, skip_path=None, aggregator=None):
        super().__init__()
        self.main_path = main_path
        # Instantiate per block rather than sharing one default-argument
        # instance across every SkipBlock.
        self.skip_path = Identity() if skip_path is None else skip_path
        self.aggregator = SkipBlock._concat_channels if aggregator is None else aggregator

    @staticmethod
    def _concat_channels(main_out, skip_out):
        ''' Default aggregator: concatenate the two outputs along dim 1. '''
        return torch.cat((main_out, skip_out), dim=1)

    def forward(self, x):
        return self.aggregator(self.main_path(x), self.skip_path(x))