In [1]:
import math
import torch
import torch.nn as nn

# MobileNetV2

In [2]:
def _make_divisible(v, divisor, min_value=None):
    """
    This function makes sure that number of channels number is divisible by 8.
    Source: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBnReLU(nn.Module):
    """
    [CONV]-[BN]-[ReLU6]
    """

    def __init__(self, inCh, outCh, stride):
        super(ConvBnReLU, self).__init__()
        self.inCh = inCh  # Number of input channels
        self.outCh = outCh  # Number of output channels
        self.stride = stride  # Stride
        self.conv = nn.Sequential(
            nn.Conv2d(self.inCh, self.outCh, 3, stride=self.stride, padding=1, bias=False),
            nn.BatchNorm2d(outCh),
            nn.ReLU6(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class InvertedResidual(nn.Module):
    """
    [CONV_1x1-BN-ReLU6]-[CONV_3x3-BN-ReLU6]-[CONV_1x1-BN] with identity shortcut.
    """

    def __init__(self, inCh, outCh, t, s):
        super(InvertedResidual, self).__init__()
        self.inCh = inCh
        self.outCh = outCh
        self.t = t  # t: expansion factor
        self.s = s  # s: Stride
        self.identity_shortcut = (self.inCh == self.outCh) and (self.s == 1)  # L:506 Keras official code

        # Bottleneck block
        self.block = nn.Sequential(
            # Expansition Conv
            nn.Conv2d(self.inCh, self.t * self.inCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.t * self.inCh),
            nn.ReLU6(inplace=True),

            # Depthwise Conv
            nn.Conv2d(self.t * self.inCh, self.t * self.inCh, kernel_size=3, stride=self.s, padding=1, 
                      groups=self.t * self.inCh, bias=False),
            nn.BatchNorm2d(self.t * self.inCh),
            nn.ReLU6(inplace=True),

            # Pointwise Linear Conv (Projection): i.e. No non-linearity
            nn.Conv2d(self.t * self.inCh, self.outCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.outCh),
        )

    def forward(self, x):
        if self.identity_shortcut:
            return x + self.block(x)
        else:
            return self.block(x)


class PointwiseConv(nn.Module):
    def __init__(self, inCh, outCh):
        super(PointwiseConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inCh, outCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(outCh),
            nn.ReLU6(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


# MobileNetV2
class MobileNetV2(nn.Module):
    """
    MobileNetV2 feature extractor for YOLOv3. NOTE: YOLOv3 uses convolutional layers only!

    Input: 416 x 416 x 3
    Last layer Pointwise conv output:13 x 13 x 1024 -> Large object detection
    5th layer Pointwise conv output: :26 x 26 x 512 -> Medium object detection
    3rd layer Pointwise conv output: 52 x 52 x 256 -> Small object detection
    """

    def __init__(self, params):
        super(MobileNetV2, self).__init__()
        self.params = params
        self.first_inCh = 3
        last_outCh = 1280

        self.c = [_make_divisible(c * self.params.alpha, 8) for c in self.params.c]
        # Last convolution has 1280 output channels for alpha <= 1
        self.last_outCh = _make_divisible(int(last_outCh * self.params.alpha),
                                          8) if self.params.alpha > 1.0 else last_outCh

        # NOTE: YOLOv3 makes predictions at 3 different scales: (1) In the last feature map layer: 13 x 13
        # (2) The feature map from 2 layers previous and upsample it by 2x: 26 x 26
        # (3) The feature map from 2 layers previous and upsample it by 2x: 52 x 52

        # Layer-0
        self.layer0 = nn.Sequential(ConvBnReLU(self.first_inCh, self.c[0], self.params.s[0]))

        # Layer-1
        self.layer1 = self._make_layer(self.c[0], self.c[1], self.params.t[1], self.params.s[1], self.params.n[1])

        # Layer-2
        self.layer2 = self._make_layer(self.c[1], self.c[2], self.params.t[2], self.params.s[2], self.params.n[2])

        # Layer-3
        self.layer3 = self._make_layer(self.c[2], self.c[3], self.params.t[3], self.params.s[3], self.params.n[3])
        self.layer3_out = nn.Sequential(PointwiseConv(self.c[3], 256))

        # Layer-4
        self.layer4 = self._make_layer(self.c[3], self.c[4], self.params.t[4], self.params.s[4], self.params.n[4])

        # Layer-5
        self.layer5 = self._make_layer(self.c[4], self.c[5], self.params.t[5], self.params.s[5], self.params.n[5])
        self.layer5_out = nn.Sequential(PointwiseConv(self.c[5], 512))

        # Layer-6
        self.layer6 = self._make_layer(self.c[5], self.c[6], self.params.t[6], self.params.s[6], self.params.n[6])

        # Layer-7
        self.layer7 = self._make_layer(self.c[6], self.c[7], self.params.t[7], self.params.s[7], self.params.n[7])

        # Layer-8
        self.layer8 = nn.Sequential(PointwiseConv(self.c[7], self.last_outCh))

        self.out_channels = [256, 512, 1280]
        
        # Initialize weights
        self._initialize_weights()

    def _make_layer(self, inCh, outCh, t, s, n):
        layers = []
        for i in range(n):
            # First layer of each sequence has a stride s and all others use stride 1
            if i == 0:
                layers.append(InvertedResidual(inCh, outCh, t, s))
            else:
                layers.append(InvertedResidual(inCh, outCh, t, 1))

            # Update input channel for next IRB layer in the block
            inCh = outCh
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        out52 = self.layer3_out(x)
        x = self.layer4(x)
        x = self.layer5(x)
        out26 = self.layer5_out(x)
        x = self.layer6(x)
        x = self.layer7(x)
        out13 = self.layer8(x)
        return out52, out26, out13
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
                
                
def MobileNet(pretrained=True, **kwargs):
    """
    Constructs a MobileNet V2 model.
    
    Parameters
    ----------
    pretrained: bool, use ImageNet pretrained model or not.
    n_class: int, 1000 classes in ImageNet data.
    weight_file: str, path to pretrained weights
    """
    weight_file = kwargs.pop('weight_file', '')
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = torch.load(weight_file)
        model.load_state_dict(state_dict)
    return model

# Config

In [3]:
class Config():
    """
    Configuration for training MobileNetV2-YOLOv3 model
    """
    def __init__(self):
        # MobileNetV2 parameters
        # ----------------------
        # Conv and Inverted Residual Parameters: Table-2 (https://arxiv.org/pdf/1801.04381.pdf)
        self.t = [1, 1, 6, 6, 6, 6, 6, 6]  # t: expansion factor
        self.c = [32, 16, 24, 32, 64, 96, 160, 320]  # c: Output channels
        self.n = [1, 1, 2, 3, 4, 3, 3, 1]  # n: Number of times layer is repeated
        self.s = [2, 1, 2, 2, 2, 1, 2, 1]  # s: Stride
        # Width multiplier: Controls the width of the network
        self.alpha = 1.0
        
config = Config()

# MobileNet Pre-trained on ImageNet - Weight Transfer


[Example](https://github.com/lizhengwei1992/mobilenetv2_deeplabv3_pytorch/blob/master/utils.py)

In [4]:
# Model pre-trained on ImageNet weights
trained_state_dict = torch.load('./MobileNetV2.pth.tar')
trained_keys = list(trained_state_dict['state_dict'].keys())
print('Number of keys (pre-trained): ', len(trained_keys))
trained_weights = list(trained_state_dict['state_dict'].values())

Number of keys (pre-trained):  267


In [5]:
# Modified MobileNet
model = MobileNetV2(params=config)
model_keys = list(model.state_dict().keys())
print('Number of keys (modified): ', len(model_keys))
model_weights = list(model.state_dict().values())

Number of keys (modified):  330


In [6]:
ignore_indices = [i for i, k in enumerate(model_keys) if k.endswith('num_batches_tracked')]
ignore_indices += [114, 115, 116, 117, 118]  # layer3_out weights not in MobileNetV2 pretrained on ImageNet
ignore_indices += [246, 247, 248, 249, 250]  # layer5_out weights not in MobileNetV2 pretrained on ImageNet

print(330 - len(ignore_indices))
model_keys[ignore_indices[0]]

265


'layer0.0.conv.1.num_batches_tracked'

In [7]:
# for i, w in enumerate(model_weights):
#     if i not in ignore_indices:
#         if len(list(w.size())) > 1:
#             print(i, w.size(), '---> ', model_keys[i])  # 114 to 118? 246 to 250

In [8]:
# for i, w in enumerate(trained_weights[:-2]):
#     if len(list(w.size())) > 1:
#         print(i, w.size(), '---> ', trained_keys[i])

In [9]:
_model_keys = []
_model_key_ids = []
for i, w in enumerate(model_weights):
    if i not in ignore_indices:
        _model_keys.append(model_keys[i])
        _model_key_ids.append(i)
        print(i, w.size(), '---> ', model_keys[i])

print()
print('Number of weights: ', len(_model_keys))

0 torch.Size([32, 3, 3, 3]) --->  layer0.0.conv.0.weight
1 torch.Size([32]) --->  layer0.0.conv.1.weight
2 torch.Size([32]) --->  layer0.0.conv.1.bias
3 torch.Size([32]) --->  layer0.0.conv.1.running_mean
4 torch.Size([32]) --->  layer0.0.conv.1.running_var
6 torch.Size([32, 32, 1, 1]) --->  layer1.0.block.0.weight
7 torch.Size([32]) --->  layer1.0.block.1.weight
8 torch.Size([32]) --->  layer1.0.block.1.bias
9 torch.Size([32]) --->  layer1.0.block.1.running_mean
10 torch.Size([32]) --->  layer1.0.block.1.running_var
12 torch.Size([32, 1, 3, 3]) --->  layer1.0.block.3.weight
13 torch.Size([32]) --->  layer1.0.block.4.weight
14 torch.Size([32]) --->  layer1.0.block.4.bias
15 torch.Size([32]) --->  layer1.0.block.4.running_mean
16 torch.Size([32]) --->  layer1.0.block.4.running_var
18 torch.Size([16, 32, 1, 1]) --->  layer1.0.block.6.weight
19 torch.Size([16]) --->  layer1.0.block.7.weight
20 torch.Size([16]) --->  layer1.0.block.7.bias
21 torch.Size([16]) --->  layer1.0.block.7.running_

291 torch.Size([960]) --->  layer6.2.block.1.running_mean
292 torch.Size([960]) --->  layer6.2.block.1.running_var
294 torch.Size([960, 1, 3, 3]) --->  layer6.2.block.3.weight
295 torch.Size([960]) --->  layer6.2.block.4.weight
296 torch.Size([960]) --->  layer6.2.block.4.bias
297 torch.Size([960]) --->  layer6.2.block.4.running_mean
298 torch.Size([960]) --->  layer6.2.block.4.running_var
300 torch.Size([160, 960, 1, 1]) --->  layer6.2.block.6.weight
301 torch.Size([160]) --->  layer6.2.block.7.weight
302 torch.Size([160]) --->  layer6.2.block.7.bias
303 torch.Size([160]) --->  layer6.2.block.7.running_mean
304 torch.Size([160]) --->  layer6.2.block.7.running_var
306 torch.Size([960, 160, 1, 1]) --->  layer7.0.block.0.weight
307 torch.Size([960]) --->  layer7.0.block.1.weight
308 torch.Size([960]) --->  layer7.0.block.1.bias
309 torch.Size([960]) --->  layer7.0.block.1.running_mean
310 torch.Size([960]) --->  layer7.0.block.1.running_var
312 torch.Size([960, 1, 3, 3]) --->  layer7.0.b

In [10]:
_trained_keys = []
_trained_key_ids = []
for i, w in enumerate(trained_weights[:-2]):
    _trained_keys.append(trained_keys[i])
    _trained_key_ids.append(i)
    print(i, w.size(), '---> ', trained_keys[i])
    
print()
print('Number of weights: ', len(_trained_keys))

0 torch.Size([32, 3, 3, 3]) --->  module.conv1.weight
1 torch.Size([32]) --->  module.bn1.weight
2 torch.Size([32]) --->  module.bn1.bias
3 torch.Size([32]) --->  module.bn1.running_mean
4 torch.Size([32]) --->  module.bn1.running_var
5 torch.Size([32, 32, 1, 1]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.conv1.weight
6 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn1.weight
7 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn1.bias
8 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn1.running_mean
9 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn1.running_var
10 torch.Size([32, 1, 3, 3]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.conv2.weight
11 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn2.weight
12 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.bn2.bias
13 torch.Size([32]) 

90 torch.Size([32, 192, 1, 1]) --->  module.bottlenecks.Bottlenecks_2.LinearBottleneck2_2.conv3.weight
91 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_2.LinearBottleneck2_2.bn3.weight
92 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_2.LinearBottleneck2_2.bn3.bias
93 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_2.LinearBottleneck2_2.bn3.running_mean
94 torch.Size([32]) --->  module.bottlenecks.Bottlenecks_2.LinearBottleneck2_2.bn3.running_var
95 torch.Size([192, 32, 1, 1]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.conv1.weight
96 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn1.weight
97 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn1.bias
98 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn1.running_mean
99 torch.Size([192]) --->  module.bottlenecks.Bottlenecks_3.LinearBottleneck3_0.bn1.running_var
100 torch.Size([192, 1, 3, 3]) --->  module.bottlene

In [11]:
trained_key_map = dict(zip(_trained_keys, _model_keys))
trained_key_ids_map = dict(zip(_trained_key_ids, _model_key_ids))

for train_key_id, model_key_id in trained_key_ids_map.items():
    assert trained_weights[train_key_id].shape == model_weights[model_key_id].shape
    
for train_key, model_key in trained_key_map.items():
    assert trained_state_dict['state_dict'][train_key].shape == model.state_dict()[model_key].shape

# Transfer Weights and Save

In [12]:
# Model pre-trained on ImageNet weights
pre_trained_state = torch.load('./MobileNetV2.pth.tar')['state_dict']

# Modified MobileNet
mobile_net = MobileNetV2(params=config)
mobile_net_state = mobile_net.state_dict()

for train_key, model_key in trained_key_map.items():
    mobile_net_state[model_key] = pre_trained_state[train_key]
    
mobile_net.load_state_dict(mobile_net_state)
torch.save(mobile_net.state_dict(), 'MobileNetV2-Pretrained-Weights.pth.tar')

In [13]:
print(pre_trained_state['module.bottlenecks.Bottlenecks_0.LinearBottleneck0_0.conv2.weight'][1,...])
print()
print(mobile_net_state['layer1.0.block.3.weight'][1, ...])
print()
print(mobile_net.state_dict()['layer1.0.block.3.weight'][1, ...])

tensor([[[ 0.0176,  0.0243, -0.0063],
         [ 0.0567, -0.8243, -0.0080],
         [-0.1055,  0.6830,  0.0980]]], device='cuda:1')

tensor([[[ 0.0176,  0.0243, -0.0063],
         [ 0.0567, -0.8243, -0.0080],
         [-0.1055,  0.6830,  0.0980]]], device='cuda:1')

tensor([[[ 0.0176,  0.0243, -0.0063],
         [ 0.0567, -0.8243, -0.0080],
         [-0.1055,  0.6830,  0.0980]]])


# Check if Weights load

In [14]:
net = MobileNetV2(params=config)
net_state = torch.load('./MobileNetV2-Pretrained-Weights.pth.tar')
print(net_state['layer1.0.block.3.weight'][1, ...])
net.load_state_dict(net_state)

tensor([[[ 0.0176,  0.0243, -0.0063],
         [ 0.0567, -0.8243, -0.0080],
         [-0.1055,  0.6830,  0.0980]]])
