In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Modified MobileNetV2 for DeepLabV3+

In [2]:
def _make_divisible(v, divisor, min_value=None):
    """
    This function makes sure that number of channels number is divisible by 8.
    Source: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBnReLU(nn.Module):
    """
    3x3 convolution (padding 1, no bias) followed by batch normalisation and ReLU6.
    """

    def __init__(self, inCh, outCh, stride):
        super().__init__()
        # Constructor arguments are kept on the module for later inspection
        self.inCh, self.outCh, self.stride = inCh, outCh, stride

        conv3x3 = nn.Conv2d(inCh, outCh, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv = nn.Sequential(conv3x3, nn.BatchNorm2d(outCh), nn.ReLU6(inplace=True))

    def forward(self, x):
        out = self.conv(x)
        return out


class InvertedResidual(nn.Module):
    """
    MobileNetV2 bottleneck with optional dilation:
    1x1 expansion (BN, ReLU6) -> 3x3 depthwise (BN, ReLU6) -> 1x1 linear projection (BN).
    The input is added to the output when the block keeps both the channel count and
    the stride at 1.
    """

    def __init__(self, inCh, outCh, t, s, r):
        super().__init__()
        self.inCh = inCh
        self.outCh = outCh
        self.t = t  # t: expansion factor
        self.r = r  # r: dilation

        # A dilated block never downsamples: its stride is forced to 1 and the padding
        # follows the dilation rate so the 3x3 depthwise conv keeps the spatial size.
        dilated = self.r > 1
        self.s = 1 if dilated else s  # s: stride actually used
        self.padding = self.r if dilated else 1
        self.identity_shortcut = (self.s == 1) and (self.inCh == self.outCh)

        hidden = self.t * self.inCh  # width of the expanded representation

        expansion = [
            nn.Conv2d(self.inCh, hidden, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        ]
        depthwise = [
            nn.Conv2d(hidden, hidden, kernel_size=3, stride=self.s, padding=self.padding,
                      dilation=self.r, groups=hidden, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        ]
        # Linear projection: deliberately no non-linearity after the BN
        projection = [
            nn.Conv2d(hidden, self.outCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.outCh),
        ]
        self.block = nn.Sequential(*(expansion + depthwise + projection))

    def forward(self, x):
        out = self.block(x)
        return x + out if self.identity_shortcut else out


class PointwiseConv(nn.Module):
    """
    1x1 convolution (no bias) followed by batch normalisation and ReLU6.
    """

    def __init__(self, inCh, outCh):
        super().__init__()
        stages = [
            nn.Conv2d(inCh, outCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(outCh),
            nn.ReLU6(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return out


# MobileNetV2
class MobileNetV2(nn.Module):
    """
    Dilated MobileNetV2 backbone used as the feature extractor for DeepLabV3+.
    NOTE: The last conv layer and the classification layer are left out.

    params: object exposing the per-stage lists `t`, `c`, `n`, `s`, `r` (indexed 0..7)
        and the width multiplier `alpha`.
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.first_inCh = 3

        # Stage widths after the width multiplier, rounded to multiples of 8
        self.c = [_make_divisible(width * self.params.alpha, 8) for width in self.params.c]

        def build_stage(idx):
            # Stage `idx` consumes the output of stage `idx - 1`
            cfg = self.params
            return self._make_layer(self.c[idx - 1], self.c[idx], cfg.t[idx], cfg.s[idx],
                                    cfg.n[idx], cfg.r[idx])

        # Layer-0: plain 3x3 conv stem
        self.layer0 = nn.Sequential(ConvBnReLU(self.first_inCh, self.c[0], self.params.s[0]))

        # Layer-1
        self.layer1 = build_stage(1)

        # Layer-2: Image size: 512 -> [IRB-2] -> Output size: 128 (low level feature: 128 * 4 = 512)
        self.layer2 = build_stage(2)

        # Layer-3 and Layer-4
        self.layer3 = build_stage(3)
        self.layer4 = build_stage(4)

        # Layer-5: Image size: 512 -> [IRB-5] -> Output size: 32, so output stride = 16 achieved
        self.layer5 = build_stage(5)

        # Layer-6 and Layer-7: the stages meant to run with dilation (rate comes from params.r)
        self.layer6 = build_stage(6)
        self.layer7 = build_stage(7)

        # Initialize weights
        self._initialize_weights()

    def _make_layer(self, inCh, outCh, t, s, n, r):
        # Only the first block of a stage sees the incoming width and the stage stride;
        # the remaining blocks map outCh -> outCh with stride 1.
        blocks = [
            InvertedResidual(inCh if i == 0 else outCh, outCh, t, s if i == 0 else 1, r)
            for i in range(n)
        ]
        return nn.Sequential(*blocks)

    def forward(self, x):
        stem = self.layer1(self.layer0(x))
        # Output of layer2 is also returned as the low level features
        # ([512, 512]/4 = [128, 128] for a 512 x 512 input)
        low_level_features = self.layer2(stem)

        deep = low_level_features
        for stage in (self.layer3, self.layer4, self.layer5, self.layer6, self.layer7):
            deep = stage(deep)
        return deep, low_level_features

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if not isinstance(module, nn.Conv2d):
                continue
            # Normal init with std = sqrt(2 / (kernel_h * kernel_w * out_channels))
            fan = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
            nn.init.normal_(module.weight, 0, math.sqrt(2. / fan))
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
                
                
def MobileNet():
    """Placeholder with no implementation yet; calling it returns None."""
    return None

# Atrous Spatial Pyramid Pooling (ASPP)

In [3]:
class AtrousConvBnRelu(nn.Module):
    """
    Convolution -> BatchNorm -> ReLU branch.

    With dilation == 1 the convolution is a plain 1x1 conv without padding; for any other
    dilation it is a 3x3 atrous conv whose padding equals the dilation.
    """
    def __init__(self, inCh, outCh, dilation=1):
        super().__init__()
        self.inCh = inCh
        self.outCh = outCh
        self.dilation = dilation

        if self.dilation == 1:
            self.kernel, self.padding = 1, 0
        else:
            self.kernel, self.padding = 3, self.dilation

        conv = nn.Conv2d(self.inCh, self.outCh, self.kernel, stride=1,
                         padding=self.padding, dilation=self.dilation, bias=False)
        self.atrous_conv = nn.Sequential(conv, nn.BatchNorm2d(self.outCh), nn.ReLU(inplace=True))

    def forward(self, x):
        out = self.atrous_conv(x)
        return out
    

class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling: four parallel convolution branches plus an
    image-level (globally pooled) branch, concatenated and fused by a 1x1 convolution.

    Ref(s): https://github.com/rishizek/tensorflow-deeplab-v3-plus/blob/master/deeplab_model.py
    and https://github.com/chenxi116/DeepLabv3.pytorch/blob/master/deeplab.py
    """
    def __init__(self, inCh, outCh):
        super().__init__()
        self.rates = [1, 6, 12, 18]  # for output stride 16
        self.inCh = inCh
        self.outCh = outCh

        def conv_bn_relu(in_channels):
            # 1x1 CONV-BN-ReLU producing `outCh` channels
            return nn.Sequential(
                nn.Conv2d(in_channels, self.outCh, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(self.outCh),
                nn.ReLU(inplace=True)
            )

        # (a) One 1x1 convolution and three 3x3 convolutions with rates = (6, 12, 18)
        rate0, rate1, rate2, rate3 = self.rates
        self.conv_1x1_0 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, dilation=rate0)
        self.conv_3x3_1 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, dilation=rate1)
        self.conv_3x3_2 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, dilation=rate2)
        self.conv_3x3_3 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, dilation=rate3)

        # (b) The image-level features: global average pooling, then CONV-BN-ReLU
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.conv_bn_relu_4 = conv_bn_relu(self.inCh)

        # Fusion after concatenation. NOTE: 5 branches are concatenated
        self.conv_bn_relu_5 = conv_bn_relu(5 * self.outCh)

    def forward(self, x):
        height, width = x.size(2), x.size(3)

        # Convolution branches, each of size [N, outCh, height, width]
        pyramid = (self.conv_1x1_0, self.conv_3x3_1, self.conv_3x3_2, self.conv_3x3_3)
        features = [branch(x) for branch in pyramid]

        # Image-level branch: pool to 1x1, CONV-BN-ReLU, then stretch back to the input grid
        pooled = self.conv_bn_relu_4(self.global_avg_pooling(x))
        features.append(F.interpolate(pooled, size=(height, width), mode='bilinear',
                                      align_corners=True))

        # Concatenate along channels ([N, 5 * outCh, height, width]) and fuse
        return self.conv_bn_relu_5(torch.cat(features, dim=1))

# Decoder

In [4]:
class Decoder(nn.Module):
    """
    Decoder for DeepLabV3+

    Fuses the ASPP output with the backbone's low level features and produces per-class
    logits at the resolution of the low level features.

    low_level_inch: channels of the low level feature map coming from the backbone
    low_level_outch: channels the low level features are reduced to (48, or lower for speed)
    inCh: channels of the ASPP output
    outCh: channels of the 3x3 fusion convolution
    n_classes: number of output channels (one logit map per class)
    """
    def __init__(self, low_level_inch, low_level_outch, inCh, outCh, n_classes):
        super(Decoder, self).__init__()
        self.low_level_inch = low_level_inch
        self.low_level_outch = low_level_outch # 48 (or lower for speed)
        self.inCh = inCh
        self.outCh = outCh
        self.n_classes = n_classes
        
        # 1x1 Conv with BN and ReLU for low level features
        self.conv_1x1_bn_relu = nn.Sequential(
            nn.Conv2d(self.low_level_inch, self.low_level_outch, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.low_level_outch),
            nn.ReLU(inplace=True)
        )
        
        # Conv block with BN and ReLU (paper suggests to use a few 3x3 Convs, but using only 1
        # for speed improvement) and final Conv 1x1 
        self.conv_block = nn.Sequential(
            nn.Conv2d(self.inCh + self.low_level_outch, self.outCh, kernel_size=3, stride=1, padding=1, 
                      bias=False),
            nn.BatchNorm2d(self.outCh),
            nn.ReLU(inplace=True),
            
            # For reducing number of channels
            nn.Conv2d(self.outCh, self.n_classes, kernel_size=1, stride=1, bias=False)
        )
    
    def forward(self, x, low_level_features):
        """
        x: ASPP output, size [N, inCh, h, w]
        low_level_features: backbone features, size [N, low_level_inch, H, W]

        Returns logits of size [N, n_classes, H, W].
        """
        # Low level features from MobileNetV2
        low_level_features = self.conv_1x1_bn_relu(low_level_features)
        
        # Upsample the ASPP features onto the grid of the low level features. A fixed x4
        # factor only lines up when H == 4 * h (and W == 4 * w); strided convolutions round
        # odd sizes up, so other input sizes would make the concatenation below fail.
        # With align_corners=True this is identical to x4 whenever the sizes do line up.
        x = F.interpolate(x, size=low_level_features.shape[2:], mode='bilinear', 
                          align_corners=True)
        
        # Concatenate along the channel axis
        x_concat = torch.cat([x, low_level_features], dim=1)
        
        # Final Convolution
        out = self.conv_block(x_concat)
        
        return out

# DeepLabV3+

In [5]:
class DeepLabV3plus(nn.Module):
    """
    DeepLabV3+ segmentation network: MobileNetV2 backbone -> ASPP -> Decoder, with the
    decoder logits resized to the spatial size of the input image.

    config: object providing the backbone parameters read by MobileNetV2 plus
        `aspp_inch`, `aspp_outch`, `low_level_inCh`, `low_level_outCh`, `in_channels`,
        `out_channels` and `n_classes`.
    """
    def __init__(self, config):
        super(DeepLabV3plus, self).__init__()
        self.config = config
        
        # Base Network
        self.base = MobileNetV2(params=self.config)
        
        # ASPP Module
        self.aspp = ASPP(inCh=self.config.aspp_inch, 
                         outCh=self.config.aspp_outch)
        
        # Decoder
        self.decoder = Decoder(low_level_inch=self.config.low_level_inCh, 
                               low_level_outch=self.config.low_level_outCh, 
                               inCh=self.config.in_channels, 
                               outCh=self.config.out_channels,
                               n_classes=self.config.n_classes)
        
    def forward(self, x):
        """
        x: input batch, size [N, C, H, W]

        Returns the decoder output resized to [N, <decoder channels>, H, W].
        """
        # Extract features from base network
        base_out, low_level_features = self.base(x)
        
        # Pool base network output using Atrous Spatial Pyramid Pooling
        aspp_out = self.aspp(base_out)
        
        # Use decoder to obtain object boundaries
        decoder_out = self.decoder(aspp_out, low_level_features)
        
        # Resize the logits to the input's own height and width instead of a fixed x4
        # factor, so the output always matches the image. When the decoder output is
        # exactly a quarter of the input size this gives the same result as x4
        # (align_corners=True).
        out = F.interpolate(decoder_out, size=x.shape[2:], mode='bilinear', align_corners=True)
        
        return out

# Config

In [6]:
class Config():
    """
    Configuration for training DeepLabV3+

    alpha: width multiplier applied to the channel counts (default 0.5).
    """
    def __init__(self, alpha=0.5):
        # MobileNetV2 parameters
        # ----------------------
        # Conv and Inverted Residual Parameters: Table-2 (https://arxiv.org/pdf/1801.04381.pdf)
        self.t = [1, 1, 6, 6, 6, 6, 6, 6]  # t: expansion factor
        self.c = [32, 16, 24, 32, 64, 96, 160, 320]  # c: Output channels
        self.n = [1, 1, 2, 3, 4, 3, 3, 1]  # n: Number of times layer is repeated
        self.s = [2, 1, 2, 2, 2, 1, 2, 1]  # s: Stride
        self.r = [1, 1, 1, 1, 1, 1, 2, 2]  # r: Dilation (added to take care of dilation)
        # Width multiplier: Controls the width of the network
        self.alpha = alpha
        
        # ASPP Parameters
        # ---------------
        # The ASPP consumes the output of the last backbone stage, so its input width must
        # be rounded exactly like the backbone rounds its own stage widths.
        self.aspp_inch = self._scaled_channels(self.c[-1])
        self.aspp_outch = int(self.alpha * 256)  # Width multiplier * 256
        
        # Decoder Parameters
        # ------------------
        self.n_classes = 20
        # The low level features are the output of backbone layer 2, i.e. c[2] channels
        # after width scaling and rounding. (Using c[3] with plain int() only matched by
        # coincidence for alpha = 0.5.)
        self.low_level_inCh = self._scaled_channels(self.c[2])
        self.low_level_outCh = int(2 * self.low_level_inCh)  # 2 * low level features channels
        self.in_channels = int(self.alpha * 256) # Width multiplier * 256 (must equal aspp_outch)
        self.out_channels = int(self.alpha * 256) # Width multiplier * 256

    def _scaled_channels(self, channels, divisor=8):
        """
        Apply the width multiplier to `channels` using the same rounding rule as the
        backbone's `_make_divisible(c * alpha, 8)`: nearest multiple of `divisor`, at least
        `divisor`, and never more than 10% below the scaled value.
        """
        scaled = channels * self.alpha
        rounded = max(divisor, int(scaled + divisor / 2) // divisor * divisor)
        if rounded < 0.9 * scaled:
            rounded += divisor
        return rounded
        
config = Config()

# Check DeepLabV3+

In [7]:
deeplab = DeepLabV3plus(config=config)
deeplab.eval()

# Input Image
x = torch.randn([1, 3, 512, 512])

# Inference only: no gradients are needed, so skip building the autograd graph
with torch.no_grad():
    y = deeplab(x)
print('output shape: ', y.shape)

# One logit map per class at the resolution of the input image
assert y.shape == (x.shape[0], config.n_classes) + tuple(x.shape[2:])

output shape:  torch.Size([1, 20, 512, 512])


### Transform DeepLabV3+ Output to Class Map (for Cityscapes Data) 

- DeepLabV3+ Output size: `[1, 20, 512, 512]` (batch of one image, 20 class channels)
- Transformed Class Map size: `[1024, 2048]`

#### Note: in the app, upsample to the input image's size instead of the hard-coded `[1024, 2048]`

In [8]:
def logits_to_class_map(logits, size=(1024, 2048)):
    """
    Transform DeepLabV3+ Output to trainId map (class map)
    
    logits: PyTorch tensor, size: [N, n_classes, h, w] (e.g. [1, 20, h, w]). It is the
        output of DeepLabV3+ model.
    size: (height, width) of the returned class map. Defaults to (1024, 2048); pass the
        original image size when it differs.

    Returns an integer tensor of class indices: size [height, width] for a single image
    (N == 1), [N, height, width] for a larger batch.
    """
    # Resize the logits (not the class ids) so the interpolation stays meaningful
    upsample = F.interpolate(logits, size=size, mode='bilinear', 
                             align_corners=False)
    
    # Pick the highest scoring class per pixel. Reducing over dim=1 (the class axis) keeps
    # batches larger than one correct instead of reducing over the batch axis.
    class_map = torch.argmax(upsample, dim=1)
    
    # [1, height, width] -> [height, width]; a no-op when N > 1
    out = class_map.squeeze(0)
    
    return out

### Development

In [9]:
# Stand-in for a network output: random logits for 20 classes on a 512 x 512 grid
out = torch.randn(1, 20, 512, 512)

# Resize the logits to the target label resolution
u_out = F.interpolate(out, size=(1024, 2048), mode='bilinear', align_corners=False)
print('Upsampled output shape: ', u_out.shape)

# Drop the singleton batch axis
u_out = torch.squeeze(u_out, dim=0)
print('Upsampled output shape: ', u_out.shape)

# Index of the highest scoring class at every pixel
c_map = u_out.argmax(dim=0)
print('Class map shape: ', c_map.shape)

Upsampled output shape:  torch.Size([1, 20, 1024, 2048])
Upsampled output shape:  torch.Size([20, 1024, 2048])
Class map shape:  torch.Size([1024, 2048])


In [10]:
# Show which class ids occur in the class map computed in the previous cell
print(torch.unique(c_map))

tensor([ 2,  6, 18,  4, 11, 13,  3, 10,  7,  8, 17, 14,  5, 19,  0, 12,  1,  9,
        15, 16])
