<a href="https://colab.research.google.com/github/ShanmukhaManoj11/DL_experiments/blob/master/MobileNetV2_SSDLite.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.nn.init as init

import cv2
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import os
import time
import sys
import itertools
from math import sqrt
import xml.etree.ElementTree as ET

# MobileNetV2-SSDLite

Ref: https://github.com/qfgaohao/pytorch-ssd

## nn Module utils

In [2]:
def conv2d_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1, use_bias=False, use_batch_norm=True):
  '''
  @ combined util to apply conv2d -> batch norm 2d (No activation)
  input:
    1. in_channels: number of input channels
    2. out_channels: number of output channels
    3. kernel_size: kernel size for conv2d operation
    4. stride: stride for conv2d
    5. padding: padding value for conv2d operation
    6. groups: groups value for defining connection between input and output channels
               when groups = in_channels, performs depthwise operation
    7. use_bias: boolean variable to specify if bias is needed in the conv2d operation
    8. use_batch_norm: boolean variable to specify if batch norm is applied after the conv2d operation
  returns: nn.Sequential whose [0] is the nn.Conv2d, followed by nn.BatchNorm2d(out_channels)
           only when use_batch_norm is True
  '''
  layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=use_bias)]
  if use_batch_norm:
    layers.append(nn.BatchNorm2d(out_channels))
  # Always return a Sequential (also without batch norm) so block[0] is the conv in both cases,
  # matching the layout produced by conv2d_bn_relu.
  return nn.Sequential(*layers)

def conv2d_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups=1, use_bias=False, use_batch_norm=True, onnx_compatible=False):
  '''
  @ combined util building conv2d -> (optional) batch norm 2d -> in-place ReLU/ReLU6
  input:
    1. in_channels: channels of the incoming feature map
    2. out_channels: channels produced by the conv2d
    3. kernel_size: conv2d kernel size
    4. stride: conv2d stride
    5. padding: conv2d padding
    6. groups: conv2d groups; groups = in_channels gives a depthwise conv
    7. use_bias: whether the conv2d has a bias term
    8. use_batch_norm: insert nn.BatchNorm2d(out_channels) between the conv and the activation
    9. onnx_compatible: when True the activation is nn.ReLU, otherwise nn.ReLU6
  returns: nn.Sequential laid out as [conv, bn, act] or [conv, act]
  '''
  activation = nn.ReLU if onnx_compatible else nn.ReLU6

  layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=use_bias)]
  if use_batch_norm:
    layers.append(nn.BatchNorm2d(out_channels))
  layers.append(activation(inplace=True))
  return nn.Sequential(*layers)

### Separable Convolution

In [3]:
def SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, onnx_compatible=False):
  '''
  @ depthwise-separable stand-in for a standard conv2d
  stage 0: depthwise conv (groups=in_channels, with bias) -> batch norm -> ReLU/ReLU6
  stage 1: 1x1 pointwise conv mapping in_channels -> out_channels, with no activation afterwards
  '''
  # NOTE: the depthwise conv keeps its bias even though batch norm follows it (BN's shift makes it redundant).
  depthwise = conv2d_bn_relu(in_channels, in_channels, kernel_size, stride, padding,
                             groups=in_channels, use_bias=True, use_batch_norm=True,
                             onnx_compatible=onnx_compatible)
  pointwise = nn.Conv2d(in_channels, out_channels, 1)
  return nn.Sequential(depthwise, pointwise)

### Inverted residual bottleneck module

In [4]:
class InvertedResidualBottleneck(nn.Module):
  '''
  @ Inverted residual bottleneck module - basic block of the mobilenetV2 architecture
  input:
    1. in_channels: number of input channels
    2. out_channels: number of output channels
    3. stride: stride of the 3x3 depthwise conv
    4. expansion_ratio: hidden width = round(expansion_ratio*in_channels); when it equals 1 the
       1x1 expansion conv is skipped and the block starts with the depthwise conv
    5. use_batch_norm: apply batch norm after every conv, including the final linear projection
    6. onnx_compatible: when True ReLU is used instead of ReLU6
  The ops live in self.operation (nn.Sequential). When expansion_ratio != 1, operation[0] is the
  1x1 expansion stage. An identity shortcut is added only when stride == 1 and in_channels == out_channels.
  '''
  def __init__(self, in_channels, out_channels, stride, expansion_ratio, use_batch_norm=True, onnx_compatible=False):
    super(InvertedResidualBottleneck, self).__init__()
    hidden_channels = round(expansion_ratio*in_channels)
    self.residual_connection = (stride == 1 and in_channels == out_channels)

    if expansion_ratio == 1:
      self.operation = nn.Sequential(
          # depthwise operation: notice groups=in_channels
          conv2d_bn_relu(in_channels, in_channels, 3, stride, 1, groups=in_channels, use_bias=False, use_batch_norm=use_batch_norm, onnx_compatible=onnx_compatible),
          # pointwise "linear" operation: no ReLU activation
          self._linear_projection(in_channels, out_channels, use_batch_norm)
      )
    else:
      self.operation = nn.Sequential(
          # pointwise operation for expansion based on expansion_ratio
          conv2d_bn_relu(in_channels, hidden_channels, 1, 1, 0, use_bias=False, use_batch_norm=use_batch_norm, onnx_compatible=onnx_compatible),
          # depthwise operation
          conv2d_bn_relu(hidden_channels, hidden_channels, 3, stride, 1, groups=hidden_channels, use_bias=False, use_batch_norm=use_batch_norm, onnx_compatible=onnx_compatible),
          # pointwise "linear" operation
          self._linear_projection(hidden_channels, out_channels, use_batch_norm)
      )

  @staticmethod
  def _linear_projection(in_channels, out_channels, use_batch_norm):
    '''
    @ 1x1 "linear" projection conv (no activation), followed by batch norm only when use_batch_norm is True.
    Always an nn.Sequential whose [0] is the conv; with batch norm the layout (and state_dict keys)
    is the same as conv2d_bn(in_channels, out_channels, 1, 1, 0, use_bias=False).
    '''
    layers = [nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, groups=1, bias=False)]
    if use_batch_norm:
      layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)

  def forward(self, x):
    out = self.operation(x)
    # identity shortcut only when input and output shapes match (see residual_connection)
    if self.residual_connection:
      return x + out
    return out

## MobileNetV2

In [5]:
class MobileNetV2(nn.Module):
  '''
  @ MobileNetV2 architecture
  https://arxiv.org/pdf/1801.04381.pdf
  input:
    1. n_classes: number of outputs of the 1x1-conv classifier head
    2. width_multiplier: scales the channel count of the layers; the final 1280-channel conv
       is only scaled when width_multiplier > 1.0
    3. dropout_ratio: dropout probability applied before the classifier conv
    4. use_batch_norm: forwarded to every conv block of the feature extractor
    5. onnx_compatible: when True ReLU is used instead of ReLU6
  attributes:
    features: nn.ModuleList of the backbone layers (first conv, bottleneck blocks, final 1x1 conv)
    classifier: dropout -> 1x1 conv producing n_classes channels
  '''
  def __init__(self, n_classes=1000, width_multiplier=1.0, dropout_ratio=0.2, use_batch_norm=True, onnx_compatible=False):
    super(MobileNetV2, self).__init__()
    self.n_classes = n_classes
    # layout of the inverted residual bottleneck stages:
    # t = expansion ratio, c = output channels, n = number of repeated blocks, s = stride of the first block
    layout = [
              # t, c, n, s
              [1, 16, 1, 1],
              [6, 24, 2, 2],
              [6, 32, 3, 2],
              [6, 64, 4, 2],
              [6, 96, 3, 1],
              [6, 160, 3, 2],
              [6, 320, 1, 1]
    ]

    in_channels = int(32*width_multiplier)
    # first standard conv(3x3, stride 2)+bn+relu converting the 3 input channels to 32*width_multiplier channels
    features = [conv2d_bn_relu(3, in_channels, 3, 2, 1, groups=1, use_bias=False, use_batch_norm=use_batch_norm, onnx_compatible=onnx_compatible)]
    # inverted residual bottleneck blocks
    for t, c, n, s in layout:
      out_channels = int(c*width_multiplier)
      for i in range(n):
        # only the first block of a stage uses the stage stride; the repeats keep stride 1
        stride = s if i == 0 else 1
        features.append(InvertedResidualBottleneck(in_channels, out_channels, stride, t, use_batch_norm=use_batch_norm, onnx_compatible=onnx_compatible))
        in_channels = out_channels
    # final standard conv(1x1)+bn+relu
    out_channels = int(1280*width_multiplier) if width_multiplier > 1.0 else 1280
    features.append(conv2d_bn_relu(in_channels, out_channels, 1, 1, 0, groups=1, use_bias=False, use_batch_norm=use_batch_norm, onnx_compatible=onnx_compatible))
    self.features = nn.ModuleList(features)

    self.classifier = nn.Sequential(
        nn.Dropout(dropout_ratio),
        nn.Conv2d(out_channels, n_classes, 1, stride=1, padding=0)
    )

    self.initialize_weights()

  def forward(self, x):
    '''
    @ returns class scores of shape (batch, n_classes)
    '''
    for layer in self.features:
      x = layer(x)
    # Global average pooling over the spatial dims. For 224x224 inputs (7x7 final map) this equals
    # AvgPool2d(7); unlike a fixed 7x7 window it stays correct for any input resolution instead of
    # pooling only part of the map or leaking extra spatial positions into the batch dimension.
    x = x.mean(dim=(2, 3), keepdim=True)
    return self.classifier(x).flatten(1)

  def initialize_weights(self):
    '''
    @ in-place weight init:
      conv: normal(0, sqrt(2 / (k_h * k_w * out_channels))), bias zeroed
      batch norm: weight 1, bias 0
      linear: normal(0, 0.01), bias 0
    '''
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, sqrt(2.0/n))
        if m.bias is not None:
          m.bias.data.zero_()
      elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
      elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.01)
        m.bias.data.zero_()

## SSDLite with MobileNetV2 base

In [6]:
def mobilenetv2_base(width_multiplier=1.0, use_batch_norm=True, onnx_compatible=False):
  '''
  @ build a complete MobileNetV2 (default 1000-class head, weights initialized) and hand back
    only its "features" nn.ModuleList; the classifier head is discarded
  '''
  backbone = MobileNetV2(width_multiplier=width_multiplier,
                         use_batch_norm=use_batch_norm,
                         onnx_compatible=onnx_compatible)
  return backbone.features

def auxiliary_layers(in_channels):
  '''
  @ extra stride-2 inverted residual blocks chained after the last base layer
  input: in_channels = out_channels of the last layer of the base layers
  returns: nn.ModuleList of 4 blocks producing 512, 256, 256 and 64 channels respectively
  '''
  # (out_channels, expansion_ratio) of each block, in the order they are applied
  block_specs = [(512, 0.2), (256, 0.25), (256, 0.5), (64, 0.25)]

  blocks = []
  channels = in_channels
  for out_channels, ratio in block_specs:
    blocks.append(InvertedResidualBottleneck(channels, out_channels, 2, expansion_ratio=ratio))
    channels = out_channels
  return nn.ModuleList(blocks)

def predictors(base_layers, n_classes=21, onnx_compatible=False):
  '''
  @ build the prediction heads, returned as a tuple (loc_layers, conf_layers) of nn.ModuleList
  The heads consume, in order: the expansion output of base layer 14 (0-indexed, i.e. the 15th layer),
  the output of the last base layer, then the auxiliary outputs with 512, 256, 256 and 64 channels.
  Each head predicts 6 boxes per spatial location (hard-coded): loc heads emit 6*4 channels,
  conf heads emit 6*n_classes channels. The last (64-channel) head is a plain 1x1 conv,
  the others are SeparableConv2d(kernel 3, stride 1, padding 1).
  '''
  boxes_per_location = 6
  expansion_channels = base_layers[14].operation[0][0].out_channels
  last_base_channels = base_layers[-1][0].out_channels
  separable_inputs = [expansion_channels, last_base_channels, 512, 256, 256]

  def build_heads(values_per_box):
    width = boxes_per_location*values_per_box
    heads = [SeparableConv2d(channels, width, 3, 1, 1, onnx_compatible=onnx_compatible) for channels in separable_inputs]
    heads.append(nn.Conv2d(64, width, 1))
    return nn.ModuleList(heads)

  loc_layers = build_heads(4)
  conf_layers = build_heads(n_classes)
  return loc_layers, conf_layers

In [7]:
class SSDLite(nn.Module):
  '''
  @ SSD detection network that uses the "features" of MobileNetV2 as its base layers
  '''
  def __init__(self, base_layers, auxiliary_layers, loc_layers, conf_layers, n_classes=21):
    '''
    @ constructor
    inputs:
      1. base_layers: MobileNetV2 features as nn.ModuleList; element 14 must expose an `operation` Sequential
      2. auxiliary_layers: nn.ModuleList of extra blocks applied one after another to the last base output
      3. loc_layers: nn.ModuleList of heads predicting localization offsets, one per feature map
      4. conf_layers: nn.ModuleList of heads predicting class confidences, one per feature map
      5. n_classes: number of confidence values per box; used to reshape the conf output
    '''
    super(SSDLite, self).__init__()
    self.base_layers = base_layers
    self.auxiliary_layers = auxiliary_layers
    self.loc_layers = loc_layers
    self.conf_layers = conf_layers
    self.n_classes = n_classes

  def forward(self, x):
    '''
    Feature maps used for predicting localization offsets and class confidences:
      1. output of the expansion stage (operation[0]) of base layer 14 (0-indexed)
      2. output of the last base layer
      3. output of every auxiliary layer
    Each feature map i goes through loc_layers[i] and conf_layers[i] (extra maps or heads are ignored).
    returns: (loc, conf) with shapes (batch, -1, 4) and (batch, -1, n_classes)
    '''
    feature_maps = []

    # base layers 0..13 run unchanged
    for layer in self.base_layers[:14]:
      x = layer(x)

    # base layer 14 is run stage by stage so its expansion output can be tapped
    split_layer = self.base_layers[14]
    x = split_layer.operation[0](x)
    feature_maps.append(x)
    for stage in split_layer.operation[1:]:
      x = stage(x)

    # remaining base layers; the final base output is the second tap
    for layer in self.base_layers[15:]:
      x = layer(x)
    feature_maps.append(x)

    # every auxiliary output is tapped as well
    for extra in self.auxiliary_layers:
      x = extra(x)
      feature_maps.append(x)

    loc_parts, conf_parts = [], []
    for fmap, loc_head, conf_head in zip(feature_maps, self.loc_layers, self.conf_layers):
      batch = fmap.size(0)
      # NCHW -> NHWC so the values of one spatial location are adjacent, then flatten per sample
      loc_parts.append(loc_head(fmap).permute(0, 2, 3, 1).reshape(batch, -1))
      conf_parts.append(conf_head(fmap).permute(0, 2, 3, 1).reshape(batch, -1))

    loc = torch.cat(loc_parts, dim=1)
    conf = torch.cat(conf_parts, dim=1)
    return loc.view(loc.size(0), -1, 4), conf.view(conf.size(0), -1, self.n_classes)

In [8]:
def build_MobileNetV2_SSDLite(n_classes=21, width_multiplier=1.0, use_batch_norm=True, onnx_compatible=False):
  '''
  @ assemble an SSDLite detector on top of a MobileNetV2 backbone
  input:
    1. n_classes: number of confidence values predicted per box
    2. width_multiplier: channel multiplier of the MobileNetV2 backbone
    3. use_batch_norm: forwarded to the backbone
    4. onnx_compatible: when True ReLU is used instead of ReLU6 in the backbone and prediction heads
  returns: SSDLite module
  '''
  backbone = mobilenetv2_base(width_multiplier=width_multiplier, use_batch_norm=use_batch_norm, onnx_compatible=onnx_compatible)
  # the auxiliary chain starts from the channel count of the backbone's final conv
  extras = auxiliary_layers(backbone[-1][0].out_channels)
  loc_heads, conf_heads = predictors(backbone, n_classes=n_classes, onnx_compatible=onnx_compatible)
  return SSDLite(backbone, extras, loc_heads, conf_heads, n_classes=n_classes)

In [9]:
# build the detector with its default settings (21 classes, width multiplier 1.0)
ssd = build_MobileNetV2_SSDLite()

In [11]:
# smoke test: push a random batch of two 3x300x300 inputs through the detector and print the shapes
x = torch.rand(2, 3, 300, 300)
print(x.shape)
loc_pred, conf_pred = ssd(x)
print(loc_pred.shape, conf_pred.shape)

torch.Size([2, 3, 300, 300])
torch.Size([2, 3000, 4]) torch.Size([2, 3000, 21])
