In [1]:
import os
import json
import time
import math
import random
import numpy as np
from PIL import Image, ImageFilter, ImageOps

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# Modified MobileNetV2 for DeepLabV3+

In [2]:
def _make_divisible(v, divisor, min_value=None):
    """
    This function makes sure that number of channels number is divisible by 8.
    Source: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBnReLU(nn.Module):
    """
    [CONV]-[BN]-[ReLU6]
    """

    def __init__(self, inCh, outCh, stride):
        super(ConvBnReLU, self).__init__()
        self.inCh = inCh  # Number of input channels
        self.outCh = outCh  # Number of output channels
        self.stride = stride  # Stride
        self.conv = nn.Sequential(
            nn.Conv2d(self.inCh, self.outCh, 3, stride=self.stride, padding=1, bias=False),
            nn.BatchNorm2d(outCh),
            nn.ReLU6(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class InvertedResidual(nn.Module):
    """
    [EXP:CONV_1x1-BN-ReLU6]-[DW:CONV_3x3-BN-ReLU6]-[PW:CONV_1x1-BN] with identity shortcut 
    and dilation.
    """

    def __init__(self, inCh, outCh, t, s, r):
        super(InvertedResidual, self).__init__()
        self.inCh = inCh
        self.outCh = outCh
        self.t = t  # t: expansion factor
        self.r = r  # r: dilation
        if self.r > 1:
            self.s = 1  # s: Stride
            self.padding = self.r  # Atrous Conv padding same as dilation rate
        else:
            self.s = s  # s: Stride
            self.padding = 1
        self.identity_shortcut = (self.inCh == self.outCh) and (self.s == 1)  # L:506 Keras official code

        # Bottleneck block
        self.block = nn.Sequential(
            # Expansition Conv
            nn.Conv2d(self.inCh, self.t * self.inCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.t * self.inCh),
            nn.ReLU6(inplace=True),

            # Depthwise Conv
            nn.Conv2d(self.t * self.inCh, self.t * self.inCh, kernel_size=3, stride=self.s, padding=self.padding, 
                      dilation=self.r, groups=self.t * self.inCh, bias=False),
            nn.BatchNorm2d(self.t * self.inCh),
            nn.ReLU6(inplace=True),

            # Pointwise Linear Conv (Projection): i.e. No non-linearity
            nn.Conv2d(self.t * self.inCh, self.outCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.outCh),
        )

    def forward(self, x):
        if self.identity_shortcut:
            return x + self.block(x)
        else:
            return self.block(x)


class PointwiseConv(nn.Module):
    def __init__(self, inCh, outCh):
        super(PointwiseConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inCh, outCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(outCh),
            nn.ReLU6(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


# MobileNetV2
class MobileNetV2(nn.Module):
    """
    MobileNetV2 feature extractor modified to include dilation for DeepLabV3+. 
    NOTE: Last conv Layer and classification layer removed.
    """

    def __init__(self, params):
        super(MobileNetV2, self).__init__()
        self.params = params
        self.first_inCh = 3

        self.c = [_make_divisible(c * self.params.alpha, 8) for c in self.params.c]
        
        # Layer-0
        self.layer0 = nn.Sequential(ConvBnReLU(self.first_inCh, self.c[0], self.params.s[0]))

        # Layer-1
        self.layer1 = self._make_layer(self.c[0], self.c[1], self.params.t[1], self.params.s[1], 
                                       self.params.n[1], self.params.r[1])

        # Layer-2: Image size: 512 -> [IRB-2] -> Output size: 128 (low level feature: 128 * 4 = 512)
        self.layer2 = self._make_layer(self.c[1], self.c[2], self.params.t[2], self.params.s[2], 
                                       self.params.n[2], self.params.r[2])

        # Layer-3
        self.layer3 = self._make_layer(self.c[2], self.c[3], self.params.t[3], self.params.s[3], 
                                       self.params.n[3], self.params.r[3])

        # Layer-4
        self.layer4 = self._make_layer(self.c[3], self.c[4], self.params.t[4], self.params.s[4], 
                                       self.params.n[4], self.params.r[4])

        # Layer-5: Image size: 512 -> [IRB-5] -> Output size: 32, so output stride = 16 achieved
        self.layer5 = self._make_layer(self.c[4], self.c[5], self.params.t[5], self.params.s[5], 
                                       self.params.n[5], self.params.r[5])

        # Layer-6: Apply dilation rate = 2
        self.layer6 = self._make_layer(self.c[5], self.c[6], self.params.t[6], self.params.s[6], 
                                       self.params.n[6], self.params.r[6])

        # Layer-7: Apply dilation rate = 2
        self.layer7 = self._make_layer(self.c[6], self.c[7], self.params.t[7], self.params.s[7], 
                                       self.params.n[7], self.params.r[7])
        
        # Initialize weights
        self._initialize_weights()

    def _make_layer(self, inCh, outCh, t, s, n, r):
        layers = []
        for i in range(n):
            # First layer of each sequence has a stride s and all others use stride 1
            if i == 0:
                layers.append(InvertedResidual(inCh, outCh, t, s, r))
            else:
                layers.append(InvertedResidual(inCh, outCh, t, 1, r))

            # Update input channel for next IRB layer in the block
            inCh = outCh
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        low_level_features = self.layer2(x)  # [512, 512]/4 = [128, 128] 
        x = self.layer3(low_level_features)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)
        return x, low_level_features
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                

def MobileNet(pretrained=True, **kwargs):
    """
    Constructs a MobileNet V2 model.
    
    Parameters
    ----------
    pretrained: bool, use ImageNet pretrained model or not.
    n_class: int, 1000 classes in ImageNet data.
    weight_file: str, path to pretrained weights
    """
    weight_file = kwargs.pop('weight_file', '')
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = torch.load(weight_file)
        model.load_state_dict(state_dict)
    return model

# Config

In [3]:
class Config():
    """
    Configuration for training DeepLabV3+
    """
    def __init__(self):
        # MobileNetV2 parameters
        # ----------------------
        self.pretrained = False
        # Conv and Inverted Residual Parameters: Table-2 (https://arxiv.org/pdf/1801.04381.pdf)
        self.t = [1, 1, 6, 6, 6, 6, 6, 6]  # t: expansion factor
        self.c = [32, 16, 24, 32, 64, 96, 160, 320]  # c: Output channels
        self.n = [1, 1, 2, 3, 4, 3, 3, 1]  # n: Number of times layer is repeated
        self.s = [2, 1, 2, 2, 2, 1, 2, 1]  # s: Stride
        self.r = [1, 1, 1, 1, 1, 1, 2, 2]  # r: Dilation (added to take care of dilation)
        # Width multiplier: Controls the width of the network
        self.alpha = 1 # Use multiples of 0.25, min=0.25, max=1.0
        
config = Config()

# MobileNet Pre-trained on ImageNet - Weight Transfer


[Example](https://github.com/lizhengwei1992/mobilenetv2_deeplabv3_pytorch/blob/master/utils.py)

In [4]:
# Model pre-trained on ImageNet weights
trained_state_dict = torch.load('./MobileNetV2.pth.tar')
trained_keys = list(trained_state_dict['state_dict'].keys())
print('Number of keys (pre-trained): ', len(trained_keys))
trained_weights = list(trained_state_dict['state_dict'].values())

Number of keys (pre-trained):  267


In [5]:
# Modified MobileNet
model = MobileNetV2(params=config)
model_keys = list(model.state_dict().keys())
print('Number of keys (modified): ', len(model_keys))
model_weights = list(model.state_dict().values())

Number of keys (modified):  312


In [6]:
# Tensors with size []
print(312 - sum([k.endswith('num_batches_tracked') for k in model_keys]))

260


In [7]:
ignore_indices = [i for i, k in enumerate(model_keys) if k.endswith('num_batches_tracked')]
model_keys[ignore_indices[9]]

'layer2.1.block.7.num_batches_tracked'

In [8]:
_model_keys = []
_model_key_ids = []
for i, w in enumerate(model_weights):
    if i not in ignore_indices:
        _model_keys.append(model_keys[i])
        _model_key_ids.append(i)
        print(i, w.size(), '---> ', model_keys[i])

print()
print('Number of weights: ', len(_model_keys))

0 torch.Size([32, 3, 3, 3]) --->  layer0.0.conv.0.weight
1 torch.Size([32]) --->  layer0.0.conv.1.weight
2 torch.Size([32]) --->  layer0.0.conv.1.bias
3 torch.Size([32]) --->  layer0.0.conv.1.running_mean
4 torch.Size([32]) --->  layer0.0.conv.1.running_var
6 torch.Size([32, 32, 1, 1]) --->  layer1.0.block.0.weight
7 torch.Size([32]) --->  layer1.0.block.1.weight
8 torch.Size([32]) --->  layer1.0.block.1.bias
9 torch.Size([32]) --->  layer1.0.block.1.running_mean
10 torch.Size([32]) --->  layer1.0.block.1.running_var
12 torch.Size([32, 1, 3, 3]) --->  layer1.0.block.3.weight
13 torch.Size([32]) --->  layer1.0.block.4.weight
14 torch.Size([32]) --->  layer1.0.block.4.bias
15 torch.Size([32]) --->  layer1.0.block.4.running_mean
16 torch.Size([32]) --->  layer1.0.block.4.running_var
18 torch.Size([16, 32, 1, 1]) --->  layer1.0.block.6.weight
19 torch.Size([16]) --->  layer1.0.block.7.weight
20 torch.Size([16]) --->  layer1.0.block.7.bias
21 torch.Size([16]) --->  layer1.0.block.7.running_

In [9]:
_trained_keys = []
_trained_key_ids = []
for i, w in enumerate(trained_weights[:-7]):
    _trained_keys.append(trained_keys[i])
    _trained_key_ids.append(i)
    print(i, w.size(), '---> ', trained_keys[i])
    
print()
print('Number of weights: ', len(_trained_keys))

0 torch.Size([32, 3, 3, 3]) --->  module.conv1.weight
1 torch.Size([32]) --->  module.bn1.weight
2 torch.Size([32]) --->  module.bn1.bias
3 torch.Size([32]) --->  module.bn1.running_mean
4 torch.Size([32]) --->  module.bn1.running_var
5 torch.Size([32, 32, 1, 1]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.conv1.weight
6 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn1.weight
7 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn1.bias
8 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn1.running_mean
9 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn1.running_var
10 torch.Size([32, 1, 3, 3]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.conv2.weight
11 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn2.weight
12 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn2.bias
13 torch.Size([32]) 

96 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn1.weight
97 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn1.bias
98 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn1.running_mean
99 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn1.running_var
100 torch.Size([192, 1, 3, 3]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.conv2.weight
101 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn2.weight
102 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn2.bias
103 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn2.running_mean
104 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn2.running_var
105 torch.Size([64, 192, 1, 1]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.conv3.weight
106 torch.Size([64]) --->  module.bottlenec

221 torch.Size([960]) --->  module.bottlenecks.Bottlenecks_5.LinearBottleneck5_1.bn2.weight
222 torch.Size([960]) --->  module.bottlenecks.Bottlenecks_5.LinearBottleneck5_1.bn2.bias
223 torch.Size([960]) --->  module.bottlenecks.Bottlenecks_5.LinearBottleneck5_1.bn2.running_mean
224 torch.Size([960]) --->  module.bottlenecks.Bottlenecks_5.LinearBottleneck5_1.bn2.running_var
225 torch.Size([160, 960, 1, 1]) --->  module.bottlenecks.Bottlenecks_5.LinearBottleneck5_1.conv3.weight
226 torch.Size([160]) --->  module.bottlenecks.Bottlenecks_5.LinearBottleneck5_1.bn3.weight
227 torch.Size([160]) --->  module.bottlenecks.Bottlenecks_5.LinearBottleneck5_1.bn3.bias
228 torch.Size([160]) --->  module.bottlenecks.Bottlenecks_5.LinearBottleneck5_1.bn3.running_mean
229 torch.Size([160]) --->  module.bottlenecks.Bottlenecks_5.LinearBottleneck5_1.bn3.running_var
230 torch.Size([960, 160, 1, 1]) --->  module.bottlenecks.Bottlenecks_5.LinearBottleneck5_2.conv1.weight
231 torch.Size([960]) --->  module.b

In [10]:
trained_key_map = dict(zip(_trained_keys, _model_keys))
trained_key_ids_map = dict(zip(_trained_key_ids, _model_key_ids))

for train_key_id, model_key_id in trained_key_ids_map.items():
    assert trained_weights[train_key_id].shape == model_weights[model_key_id].shape
    
for train_key, model_key in trained_key_map.items():
    assert trained_state_dict['state_dict'][train_key].shape == model.state_dict()[model_key].shape

In [11]:
trained_state = trained_state_dict['state_dict']
model_state = model.state_dict()

```python
model_state['layer0.0.conv.0.weight'][0, :, :, :]

tensor([[[ 0.0642,  0.0665, -0.0183],
         [ 0.1289, -0.0172, -0.0765],
         [-0.1314,  0.0287, -0.0152]],

        [[-0.0030,  0.0045, -0.0062],
         [-0.0078,  0.0550,  0.0823],
         [ 0.0269,  0.0041, -0.0504]],

        [[ 0.1453, -0.0798,  0.0570],
         [ 0.0030, -0.0672, -0.0656],
         [ 0.0377, -0.0117,  0.0147]]])
```

In [12]:
for train_key, model_key in trained_key_map.items():
    model_state[model_key] = trained_state[train_key]

In [13]:
trained_state['module.conv1.weight'][0, :, :, :]

tensor([[[-0.0328, -0.0461,  0.0477],
         [-0.0949, -0.0869,  0.1127],
         [ 0.0155, -0.0195,  0.0531]],

        [[ 0.0586,  0.0218,  0.0724],
         [-0.0116, -0.0063,  0.2373],
         [ 0.0562,  0.0030,  0.1027]],

        [[-0.0158, -0.0735,  0.0536],
         [-0.0307, -0.0113,  0.3229],
         [-0.0774, -0.1312,  0.0766]]], device='cuda:1')

In [14]:
model_state['layer0.0.conv.0.weight'][0, :, :, :]

tensor([[[-0.0328, -0.0461,  0.0477],
         [-0.0949, -0.0869,  0.1127],
         [ 0.0155, -0.0195,  0.0531]],

        [[ 0.0586,  0.0218,  0.0724],
         [-0.0116, -0.0063,  0.2373],
         [ 0.0562,  0.0030,  0.1027]],

        [[-0.0158, -0.0735,  0.0536],
         [-0.0307, -0.0113,  0.3229],
         [-0.0774, -0.1312,  0.0766]]], device='cuda:1')

In [15]:
model.load_state_dict(model_state)

In [16]:
model.state_dict()['layer0.0.conv.0.weight'][0, :, :, :]

tensor([[[-0.0328, -0.0461,  0.0477],
         [-0.0949, -0.0869,  0.1127],
         [ 0.0155, -0.0195,  0.0531]],

        [[ 0.0586,  0.0218,  0.0724],
         [-0.0116, -0.0063,  0.2373],
         [ 0.0562,  0.0030,  0.1027]],

        [[-0.0158, -0.0735,  0.0536],
         [-0.0307, -0.0113,  0.3229],
         [-0.0774, -0.1312,  0.0766]]])

# Transfer Weights and Save

In [17]:
# Model pre-trained on ImageNet weights
pre_trained_state = torch.load('./MobileNetV2.pth.tar')['state_dict']

# Modified MobileNet
mobile_net = MobileNetV2(params=config)
mobile_net_state = mobile_net.state_dict()

for train_key, model_key in trained_key_map.items():
    mobile_net_state[model_key] = pre_trained_state[train_key]
    
mobile_net.load_state_dict(mobile_net_state)
torch.save(mobile_net.state_dict(), 'MobileNetV2-Pretrained-Weights.pth.tar')

In [18]:
print(pre_trained_state['module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.conv2.weight'][0,...])
print()
print(mobile_net_state['layer1.0.block.3.weight'][0, ...])
print()
print(mobile_net.state_dict()['layer1.0.block.3.weight'][0, ...])

tensor([[[ 0.0000,  0.0881,  0.0684],
         [ 0.0612, -0.1843, -0.1351],
         [ 0.0600,  0.0133, -0.1209]]], device='cuda:1')

tensor([[[ 0.0000,  0.0881,  0.0684],
         [ 0.0612, -0.1843, -0.1351],
         [ 0.0600,  0.0133, -0.1209]]], device='cuda:1')

tensor([[[ 0.0000,  0.0881,  0.0684],
         [ 0.0612, -0.1843, -0.1351],
         [ 0.0600,  0.0133, -0.1209]]])


In [19]:
children = list(mobile_net.children())
len(children)

8