- [PyTorch Segmentation (different models)](https://github.com/nyoki-mtl/pytorch-segmentation)
- [TensorFlow DeepLabV3+](https://github.com/rishizek/tensorflow-deeplab-v3-plus/blob/master/deeplab_model.py)

- [PyTorch Good Implementation](https://github.com/jfzhang95/pytorch-deeplab-xception)

[PyTorch Copy Weights Only](https://discuss.pytorch.org/t/copy-weights-only-from-a-networks-parameters/5841/2)

[Dilated Convolutions Blog Post](https://towardsdatascience.com/understanding-2d-dilated-convolution-operation-with-examples-in-numpy-and-tensorflow-with-d376b3972b25)

# MobileNetV2

In [1]:
"""
Creates a MobileNetV2 model as defined in the paper: M. Sandler, 
A. Howard, M. Zhu, A. Zhmoginov, L.-C. Chen. "MobileNetV2: Inverted 
Residuals and Linear Bottlenecks.", arXiv:1801.04381, 2018."

Code reference: https://github.com/tonylins/pytorch-mobilenet-v2
ImageNet pretrained weights: https://drive.google.com/file/d/1jlto6HRVD3ipNkAl1lNhDbkBp7HylaqR
"""
import math
import torch
import torch.nn as nn



def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2Original(nn.Module):
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2Original, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
                
def MobileNetOriginal(pretrained=True, **kwargs):
    """
    Constructs a MobileNet V2 model.
    
    Parameters
    ----------
    pretrained: bool, use ImageNet pretrained model or not.
    n_class: int, 1000 classes in ImageNet data.
    weight_file: str, path to pretrained weights
    """
    weight_file = kwargs.pop('weight_file', '')
    model = MobileNetV2Original(**kwargs)
    if pretrained:
        state_dict = torch.load(weight_file)
        model.load_state_dict(state_dict)
    return model

In [2]:
# Load weights pretrained on ImageNet data using function
model = MobileNetOriginal(pretrained=True, n_class=1000, weight_file='./MobileNetV2.pth.tar')

In [3]:
list(model.children())[1]
keys = list(model.state_dict().keys())
values = list(model.state_dict().values())

### Output stride is the ratio of input image spatial resolution to final output resolution

In [4]:
x = torch.randn([1, 3, 512, 512])
for i in range(19):
    x = model.features[i](x)
    print('x size at {}: {}'.format(i, x.shape))
    
print(512/32.)

x size at 0: torch.Size([1, 32, 256, 256])
x size at 1: torch.Size([1, 16, 256, 256])
x size at 2: torch.Size([1, 24, 128, 128])
x size at 3: torch.Size([1, 24, 128, 128])
x size at 4: torch.Size([1, 32, 64, 64])
x size at 5: torch.Size([1, 32, 64, 64])
x size at 6: torch.Size([1, 32, 64, 64])
x size at 7: torch.Size([1, 64, 32, 32])
x size at 8: torch.Size([1, 64, 32, 32])
x size at 9: torch.Size([1, 64, 32, 32])
x size at 10: torch.Size([1, 64, 32, 32])
x size at 11: torch.Size([1, 96, 32, 32])
x size at 12: torch.Size([1, 96, 32, 32])
x size at 13: torch.Size([1, 96, 32, 32])
x size at 14: torch.Size([1, 160, 16, 16])
x size at 15: torch.Size([1, 160, 16, 16])
x size at 16: torch.Size([1, 160, 16, 16])
x size at 17: torch.Size([1, 320, 16, 16])
x size at 18: torch.Size([1, 1280, 16, 16])
16.0


### Modify MobileNetV2 for DeepLabV3

In [5]:
def _make_divisible(v, divisor, min_value=None):
    """
    This function makes sure that number of channels number is divisible by 8.
    Source: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBnReLU(nn.Module):
    """
    [CONV]-[BN]-[ReLU6]
    """

    def __init__(self, inCh, outCh, stride):
        super(ConvBnReLU, self).__init__()
        self.inCh = inCh  # Number of input channels
        self.outCh = outCh  # Number of output channels
        self.stride = stride  # Stride
        self.conv = nn.Sequential(
            nn.Conv2d(self.inCh, self.outCh, 3, stride=self.stride, padding=1, bias=False),
            nn.BatchNorm2d(outCh),
            nn.ReLU6(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class InvertedResidual(nn.Module):
    """
    [EXP:CONV_1x1-BN-ReLU6]-[DW:CONV_3x3-BN-ReLU6]-[PW:CONV_1x1-BN] with identity shortcut 
    and dilation.
    """

    def __init__(self, inCh, outCh, t, s, r):
        super(InvertedResidual, self).__init__()
        self.inCh = inCh
        self.outCh = outCh
        self.t = t  # t: expansion factor
        self.r = r  # r: dilation
        if self.r > 1:
            self.s = 1  # s: Stride
            self.padding = self.r  # Atrous Conv padding same as dilation rate
        else:
            self.s = s  # s: Stride
            self.padding = 1
        self.identity_shortcut = (self.inCh == self.outCh) and (self.s == 1)  # L:506 Keras official code

        # Bottleneck block
        self.block = nn.Sequential(
            # Expansition Conv
            nn.Conv2d(self.inCh, self.t * self.inCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.t * self.inCh),
            nn.ReLU6(inplace=True),

            # Depthwise Conv
            nn.Conv2d(self.t * self.inCh, self.t * self.inCh, kernel_size=3, stride=self.s, padding=self.padding, 
                      dilation=self.r, groups=self.t * self.inCh, bias=False),
            nn.BatchNorm2d(self.t * self.inCh),
            nn.ReLU6(inplace=True),

            # Pointwise Linear Conv (Projection): i.e. No non-linearity
            nn.Conv2d(self.t * self.inCh, self.outCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.outCh),
        )

    def forward(self, x):
        if self.identity_shortcut:
            return x + self.block(x)
        else:
            return self.block(x)


class PointwiseConv(nn.Module):
    def __init__(self, inCh, outCh):
        super(PointwiseConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inCh, outCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(outCh),
            nn.ReLU6(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


# MobileNetV2
class MobileNetV2(nn.Module):
    """
    MobileNetV2 feature extractor modified to include dilation for DeepLabV3+
    """

    def __init__(self, params):
        super(MobileNetV2, self).__init__()
        self.params = params
        self.first_inCh = 3

        self.c = [_make_divisible(c * self.params.alpha, 8) for c in self.params.c]
        
        # Layer-0
        self.layer0 = nn.Sequential(ConvBnReLU(self.first_inCh, self.c[0], self.params.s[0]))

        # Layer-1
        self.layer1 = self._make_layer(self.c[0], self.c[1], self.params.t[1], self.params.s[1], 
                                       self.params.n[1], self.params.r[1])

        # Layer-2: Image size: 512 -> [IRB-2] -> Output size: 128 (low level feature: 128 * 4 = 512)
        self.layer2 = self._make_layer(self.c[1], self.c[2], self.params.t[2], self.params.s[2], 
                                       self.params.n[2], self.params.r[2])

        # Layer-3
        self.layer3 = self._make_layer(self.c[2], self.c[3], self.params.t[3], self.params.s[3], 
                                       self.params.n[3], self.params.r[3])

        # Layer-4
        self.layer4 = self._make_layer(self.c[3], self.c[4], self.params.t[4], self.params.s[4], 
                                       self.params.n[4], self.params.r[4])

        # Layer-5: Image size: 512 -> [IRB-5] -> Output size: 32, so output stride = 16 achieved
        self.layer5 = self._make_layer(self.c[4], self.c[5], self.params.t[5], self.params.s[5], 
                                       self.params.n[5], self.params.r[5])

        # Layer-6: Apply dilation rate = 2
        self.layer6 = self._make_layer(self.c[5], self.c[6], self.params.t[6], self.params.s[6], 
                                       self.params.n[6], self.params.r[6])

        # Layer-7: Apply dilation rate = 2
        self.layer7 = self._make_layer(self.c[6], self.c[7], self.params.t[7], self.params.s[7], 
                                       self.params.n[7], self.params.r[7])
        
        # Initialize weights
        self._initialize_weights()

    def _make_layer(self, inCh, outCh, t, s, n, r):
        layers = []
        for i in range(n):
            # First layer of each sequence has a stride s and all others use stride 1
            if i == 0:
                layers.append(InvertedResidual(inCh, outCh, t, s, r))
            else:
                layers.append(InvertedResidual(inCh, outCh, t, 1, r))

            # Update input channel for next IRB layer in the block
            inCh = outCh
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer0(x)
        print('0 shape: ', x.shape)
        x = self.layer1(x)
        print('1 shape: ', x.shape)
        low_level_features = self.layer2(x)
        print('low_level_features shape: ', low_level_features.shape)
        x = self.layer3(low_level_features)
        print('3 shape: ', x.shape)
        x = self.layer4(x)
        print('4 shape: ', x.shape)
        x = self.layer5(x)
        print('5 shape: ', x.shape)
        x = self.layer6(x)
        print('6 shape: ', x.shape)
        x = self.layer7(x)
        print('7 shape: ', x.shape)
        return x, low_level_features
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

In [6]:
class Config():
    """
    Configuration for training MobileNetV2 model
    """
    def __init__(self):
        # MobileNetV2 parameters
        # ----------------------
        # Conv and Inverted Residual Parameters: Table-2 (https://arxiv.org/pdf/1801.04381.pdf)
        self.t = [1, 1, 6, 6, 6, 6, 6, 6]  # t: expansion factor
        self.c = [32, 16, 24, 32, 64, 96, 160, 320]  # c: Output channels
        self.n = [1, 1, 2, 3, 4, 3, 3, 1]  # n: Number of times layer is repeated
        self.s = [2, 1, 2, 2, 2, 1, 2, 1]  # s: Stride
        self.r = [1, 1, 1, 1, 1, 1, 2, 2]  # r: Dilation
        # Width multiplier: Controls the width of the network
        self.alpha = 1
        
config = Config()

In [7]:
# Test base
net = MobileNetV2(config)

x = torch.randn([1, 3, 512, 512])
y, _ = net(x)

0 shape:  torch.Size([1, 32, 256, 256])
1 shape:  torch.Size([1, 16, 256, 256])
low_level_features shape:  torch.Size([1, 24, 128, 128])
3 shape:  torch.Size([1, 32, 64, 64])
4 shape:  torch.Size([1, 64, 32, 32])
5 shape:  torch.Size([1, 96, 32, 32])
6 shape:  torch.Size([1, 160, 32, 32])
7 shape:  torch.Size([1, 320, 32, 32])


In [8]:
# Inverted Block (Conv only with dialation)
t = 6
inCh = 96 # 160 
outCh = 160 # 320 
conv_ex = nn.Conv2d(inCh, t*inCh, kernel_size=1, stride=1, padding=0, bias=False)
conv_dw = nn.Conv2d(t*inCh, t*inCh, kernel_size=3, stride=1, padding=2, dilation=2, 
                    groups=t*inCh, bias=False)
conv_pw = nn.Conv2d(t*inCh, outCh, kernel_size=1, stride=1, padding=0, bias=False)

w = torch.randn([1, inCh, 32, 32])
print('input shape: ', w.shape)
x = conv_ex(w)
print(x.shape)
y = conv_dw(x)
print(y.shape)
z = conv_pw(y)
print('output shape: ', z.shape)

input shape:  torch.Size([1, 96, 32, 32])
torch.Size([1, 576, 32, 32])
torch.Size([1, 576, 32, 32])
output shape:  torch.Size([1, 160, 32, 32])


# Atrous Spatial Pyramid Pooling

In [9]:
class AtrousConvBnRelu(nn.Module):
    """
    [Atrous CONV]-[BN]-[ReLU]
    """
    def __init__(self, inCh, outCh, dilation=1):
        super(AtrousConvBnRelu, self).__init__()
        self.inCh = inCh
        self.outCh = outCh
        self.dilation = dilation
        self.kernel = 1 if self.dilation == 1 else 3
        self.padding = 0 if self.dilation == 1 else self.dilation
        self.atrous_conv = nn.Sequential(
            nn.Conv2d(self.inCh, self.outCh, self.kernel, stride=1, 
                      padding=self.padding, dilation=self.dilation, bias=False), 
            nn.BatchNorm2d(self.outCh),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        return self.atrous_conv(x)

In [10]:
# Check input and output shapes
atrous = AtrousConvBnRelu(3, 32, dilation=18)

x = torch.randn([1, 3, 224, 224])
y = atrous(x)

assert x.shape[-2:] == y.shape[-2:]  # Input [W, H] matches output [W, H]
print('y shape: ', y.shape)

y shape:  torch.Size([1, 32, 224, 224])


In [11]:
print(list(atrous.children()))
print(atrous.state_dict().keys())

[Sequential(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(18, 18), dilation=(18, 18), bias=False)
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace)
)]
odict_keys(['atrous_conv.0.weight', 'atrous_conv.1.weight', 'atrous_conv.1.bias', 'atrous_conv.1.running_mean', 'atrous_conv.1.running_var', 'atrous_conv.1.num_batches_tracked'])


### ASPP Module

In [12]:
import torch.nn.functional as F
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling
    
    Ref(s): https://github.com/rishizek/tensorflow-deeplab-v3-plus/blob/master/deeplab_model.py
    and https://github.com/chenxi116/DeepLabv3.pytorch/blob/master/deeplab.py
    """
    def __init__(self, inCh, outCh):
        super(ASPP, self).__init__()
        self.rates = [1, 6, 12, 18] # for output stride 16
        self.inCh = inCh
        self.outCh = outCh
        
        # ASPP layers
        # (a) One 1x1 convolution and three 3x3 convolutions with rates = (6, 12, 18)
        self.conv_1x1_0 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, 
                                           dilation=self.rates[0])
        self.conv_3x3_1 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, 
                                           dilation=self.rates[1])
        self.conv_3x3_2 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, 
                                           dilation=self.rates[2])
        self.conv_3x3_3 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, 
                                           dilation=self.rates[3])
        
        # (b) The image-level features
        # Global Average Pooling
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        
        # CONV-BN-ReLU after Global Average Pooling
        self.conv_bn_relu_4 = nn.Sequential(
            nn.Conv2d(self.inCh, self.outCh, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(self.outCh),
            nn.ReLU(inplace=True)
        )
        
        # CONV-BN-ReLU after Concatenation. NOTE: 5 Layers are concatenated
        self.conv_bn_relu_5 = nn.Sequential(
            nn.Conv2d(self.outCh * 5, self.outCh, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(self.outCh),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x0 = self.conv_1x1_0(x)  # size: [1, outCh, fs, fs]
        print('aspp-0 shape: ', x0.shape)
        x1 = self.conv_3x3_1(x)  # size: [1, outCh, fs, fs]
        print('aspp-1 shape: ', x1.shape)
        x2 = self.conv_3x3_2(x)  # size: [1, outCh, fs, fs]
        print('aspp-2 shape: ', x2.shape)
        x3 = self.conv_3x3_3(x)  # size: [1, outCh, fs, fs]
        print('aspp-3 shape: ', x3.shape)
        
        # Global Average Pooling, CONV-BN-ReLU and upsample
        global_avg_pool = self.global_avg_pooling(x)
        
        x4 = self.conv_bn_relu_4(global_avg_pool)
        print('aspp x4 shape: ', x4.shape)
        
        upsample = F.interpolate(x4, size=(x.size(2), x.size(3)), mode='bilinear', 
                                 align_corners=True)
        
        print('aspp upsample shape: ', upsample.shape)
        
        # Concatinate
        x_concat = torch.cat([x0, x1, x2, x3, upsample], dim=1) # size: [1, 5 * outCh, fs, fs]
        print('aspp concat shape: ', x_concat.shape)
        
        # CONV-BN-ReLU after concatination
        out = self.conv_bn_relu_5(x_concat)
        print('aspp out shape: ', out.shape)
        
        return out

In [13]:
# Check input and output shapes. ASPP seems slow, may be remove ASPP 18?
outCh = 256 # Use 128 or 64?
aspp = ASPP(320, outCh)
aspp.eval()

x = torch.randn([1, 320, 32, 32]) # Batch Size > 1 for training!
%time y = aspp(x)  # Reduce inCh size by using 1x1 conv on MobileNetV2 output?

aspp-0 shape:  torch.Size([1, 256, 32, 32])
aspp-1 shape:  torch.Size([1, 256, 32, 32])
aspp-2 shape:  torch.Size([1, 256, 32, 32])
aspp-3 shape:  torch.Size([1, 256, 32, 32])
aspp x4 shape:  torch.Size([1, 256, 1, 1])
aspp upsample shape:  torch.Size([1, 256, 32, 32])
aspp concat shape:  torch.Size([1, 1280, 32, 32])
aspp out shape:  torch.Size([1, 256, 32, 32])
CPU times: user 372 ms, sys: 0 ns, total: 372 ms
Wall time: 53.9 ms


In [14]:
y.shape

torch.Size([1, 256, 32, 32])

# Decoder

In [15]:
class Decoder(nn.Module):
    """
    Decoder for DeepLabV3+
    """
    def __init__(self, low_level_inch, low_level_outch, inCh, outCh, n_classes):
        super(Decoder, self).__init__()
        self.low_level_inch = low_level_inch
        self.low_level_outch = low_level_outch # 48 (or lower for speed)
        self.inCh = inCh
        self.outCh = outCh
        self.n_classes = n_classes
        
        # 1x1 Conv with BN and ReLU for low level features
        self.conv_1x1_bn_relu = nn.Sequential(
            nn.Conv2d(self.low_level_inch, self.low_level_outch, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.low_level_outch),
            nn.ReLU(inplace=True)
        )
        
        # Conv block with BN and ReLU (paper suggests to use a few 3x3 Convs, but using only 1
        # for speed improvement) and final Conv 1x1 
        self.conv_block = nn.Sequential(
            nn.Conv2d(self.inCh + self.low_level_outch, self.outCh, kernel_size=3, stride=1, padding=1, 
                      bias=False),
            nn.BatchNorm2d(self.outCh),
            nn.ReLU(inplace=True),
            
            # For reducing number of channels
            nn.Conv2d(self.outCh, self.n_classes, kernel_size=1, stride=1, bias=False)
        )
    
    def forward(self, x, low_level_features):
        
        # Low level features from MobileNetV2
        low_level_features = self.conv_1x1_bn_relu(low_level_features)
        print('decoder low level feat shape: ', low_level_features.shape)
        
        # Upsample features from ASPP by 4
        x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=True)
        print('decoder upsample shape: ', x.shape)
        
        # Concatinate
        x_concat = torch.cat([x, low_level_features], dim=1)
        print('decoder concat shape: ', x_concat.shape)
        
        # Final Convolution
        out = self.conv_block(x_concat)
        print('decoder out shape: ', out.shape)
        
        return out

In [16]:
decoder = Decoder(24, 48, 256, 256, 21)
decoder.eval()

x = torch.randn([1, 24, 128, 128]) # MobileNetV2 low level features
z = decoder.conv_1x1_bn_relu(x)
print('z: ', z.shape)

y = torch.randn([1, 256, 32, 32]) # ASPP output
u = F.interpolate(y, scale_factor=4, mode='bilinear', align_corners=True)
print('u: ', u.shape)

assert z.shape[2:] == u.shape[2:]

out = decoder(y, x)
print('out: ', out.shape)

z:  torch.Size([1, 48, 128, 128])
u:  torch.Size([1, 256, 128, 128])
decoder low level feat shape:  torch.Size([1, 48, 128, 128])
decoder upsample shape:  torch.Size([1, 256, 128, 128])
decoder concat shape:  torch.Size([1, 304, 128, 128])
decoder out shape:  torch.Size([1, 21, 128, 128])
out:  torch.Size([1, 21, 128, 128])


# DeepLabV3+

In [17]:
class DeepLabV3Plus(nn.Module):
    def __init__(self, config):
        super(DeepLabV3Plus, self).__init__()
        self.config = config
        
        self.base = MobileNetV2(self.config)
        self.aspp = ASPP(320, 256)
        self.decoder = Decoder(24, 48, 256, 256, 21)
        
    def forward(self, x):
        # Extract features from base network
        base_out, low_level_features = self.base(x)
        print('base_out shape: ', base_out.shape)
        print('low_level_features shape: ', low_level_features.shape)
        
        # Pool base network output using Atrous Spatial Pyramid Pooling
        aspp_out = self.aspp(base_out)
        print('dlv3 aspp out shape: ', aspp_out.shape)
        
        # Use decoder to obtain object boundaries
        decoder_out = self.decoder(aspp_out, low_level_features)
        print('dlv3 decoder out shape: ', decoder_out.shape)
        
        # Upsample features from decoder by 4
        out = F.interpolate(decoder_out, scale_factor=4, mode='bilinear', align_corners=True)
        print('dlv3 out shape: ', out.shape)
        
        return out

```python
RuntimeError: Given groups=1, weight of size [256, 320, 1, 1], expected input[4, 160, 32, 32] to have 320 channels, but got 160 channels instead
```

**The above error is caused because in the low level feature channels from MobileNet does not match decoders 1x1 conv input channels - Solution `_make_divisible(inCh, 8)`**

In [18]:
dl = DeepLabV3Plus(config)
dl.eval()

x = torch.randn([4, 3, 512, 512])
y = dl(x)
print(y.shape)

0 shape:  torch.Size([4, 32, 256, 256])
1 shape:  torch.Size([4, 16, 256, 256])
low_level_features shape:  torch.Size([4, 24, 128, 128])
3 shape:  torch.Size([4, 32, 64, 64])
4 shape:  torch.Size([4, 64, 32, 32])
5 shape:  torch.Size([4, 96, 32, 32])
6 shape:  torch.Size([4, 160, 32, 32])
7 shape:  torch.Size([4, 320, 32, 32])
base_out shape:  torch.Size([4, 320, 32, 32])
low_level_features shape:  torch.Size([4, 24, 128, 128])
aspp-0 shape:  torch.Size([4, 256, 32, 32])
aspp-1 shape:  torch.Size([4, 256, 32, 32])
aspp-2 shape:  torch.Size([4, 256, 32, 32])
aspp-3 shape:  torch.Size([4, 256, 32, 32])
aspp x4 shape:  torch.Size([4, 256, 1, 1])
aspp upsample shape:  torch.Size([4, 256, 32, 32])
aspp concat shape:  torch.Size([4, 1280, 32, 32])
aspp out shape:  torch.Size([4, 256, 32, 32])
dlv3 aspp out shape:  torch.Size([4, 256, 32, 32])
decoder low level feat shape:  torch.Size([4, 48, 128, 128])
decoder upsample shape:  torch.Size([4, 256, 128, 128])
decoder concat shape:  torch.Size(

# Accuracy

In [19]:
# Predictions
preds = torch.argmax(y, dim=1)
print('preds shape: ', preds.shape)

# Ground Truth
masks = torch.randint(0, 21, (4, 512, 512)).long()
print('masks shape: ', masks.shape)

# Check equality
correct = torch.eq(preds, masks)

tot_correct = correct.sum()
num_elements = correct.numel()

print(tot_correct.float().item() * 100.0 / num_elements)

preds shape:  torch.Size([4, 512, 512])
masks shape:  torch.Size([4, 512, 512])
4.718875885009766


### Adaptive Average Pooling

[What is Adaptive Average Pooling](https://discuss.pytorch.org/t/what-is-adaptiveavgpool2d/26897/2)

What happens is that the pooling stencil size (aka kernel size) is determined to be `(input_size+target_size-1) // target_size`, i.e. rounded up. With this Then the positions of where to apply the stencil are computed as rounded equidistant points between `0` and `input_size - stencil_size`.

> Let’s have a 1d example: Say you have an input size of 14 and a target size of 4. Then the stencil size is 4.
The four equidistant points would be 0, 3.3333, 6.6666, 10 and get rounded to 0, 3, 7, 10. And so the four items would be the mean of the slices 0:4, 3:7, 7:11, 10:14 (in Python manner, so including lower bound, excluding upper bound). You see that the first two and last two slices overlap by one. Something like - occasional overlaps of 1 - this will generally be the case when the input size is not divisible by the target size.

In [20]:
a = torch.arange(0.0, 14.0, requires_grad=True)
print(a)
print(a.shape, a[None, None].shape)

AAP = nn.AdaptiveAvgPool1d(output_size=4)

b = AAP(a[None, None])
print(b)
print()
print(torch.sum(a[0:4])/4), print(torch.sum(a[3:7])/4), print(torch.sum(a[7:11])/4), print(torch.sum(a[10:14])/4)
print()

b.backward(torch.arange(1.0, 1 + b.size(-1))[None, None])

print(a.grad)

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.],
       requires_grad=True)
torch.Size([14]) torch.Size([1, 1, 14])
tensor([[[ 1.5000,  4.5000,  8.5000, 11.5000]]], grad_fn=<SqueezeBackward1>)

tensor(1.5000, grad_fn=<DivBackward0>)
tensor(4.5000, grad_fn=<DivBackward0>)
tensor(8.5000, grad_fn=<DivBackward0>)
tensor(11.5000, grad_fn=<DivBackward0>)

tensor([0.2500, 0.2500, 0.2500, 0.7500, 0.5000, 0.5000, 0.5000, 0.7500, 0.7500,
        0.7500, 1.7500, 1.0000, 1.0000, 1.0000])
