In [1]:
# Imports

import torch
import torch.nn as nn

In [2]:
# Utility functions

def depthwise_conv(in_ch, stride):
    """
    Build a 3x3 depthwise convolution followed by BatchNorm and ReLU6.

    The convolution uses ``groups=in_ch``, so each input channel is filtered
    on its own and the number of channels is left unchanged.

    Reference:
    https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315

    Args:
        in_ch (int): Number of channels in the input feature
        stride (int): Stride of the convolution

    Returns:
        nn.Sequential: conv -> batch norm -> ReLU6, in that order
    """
    per_channel_conv = nn.Conv2d(
        in_channels=in_ch,
        out_channels=in_ch,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=in_ch,  # one filter group per channel => depthwise
        bias=False,    # no conv bias; BatchNorm follows immediately
    )
    norm = nn.BatchNorm2d(num_features=in_ch)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(per_channel_conv, norm, activation)

def pointwise_conv(in_ch, out_ch):
    """
    Build a 1x1 (pointwise) convolution followed by BatchNorm and ReLU6.

    The 1x1 kernel with stride 1 only mixes channels, so the spatial size of
    the feature map is preserved.

    Args:
        in_ch (int): Number of channels in the input feature
        out_ch (int): Number of channels in the output feature

    Returns:
        nn.Sequential: conv -> batch norm -> ReLU6, in that order
    """
    channel_mixer = nn.Conv2d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=1,
        stride=1,
        bias=False,  # no conv bias; BatchNorm follows immediately
    )
    norm = nn.BatchNorm2d(num_features=out_ch)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(channel_mixer, norm, activation)

def depthwise_separable_conv(in_ch, out_ch, stride):
    """
    Build a depthwise separable convolution as one flat six-layer stack.

    Stage one is a 3x3 depthwise convolution (``groups=in_ch``) that applies
    the requested stride; stage two is a 1x1 pointwise convolution that maps
    ``in_ch`` channels to ``out_ch``. Each stage is followed by BatchNorm and
    ReLU6.

    Reference:
    https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315

    Args:
        in_ch (int): Number of channels in the input feature
        out_ch (int): Number of channels in the output feature
        stride (int): Stride of the convolution

    Returns:
        nn.Sequential: the six layers, indexed 0-5 in execution order
    """
    depthwise_stage = [
        nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=3,
                  stride=stride, padding=1, groups=in_ch, bias=False),
        nn.BatchNorm2d(num_features=in_ch),
        nn.ReLU6(inplace=True),
    ]
    pointwise_stage = [
        nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1,
                  stride=1, bias=False),
        nn.BatchNorm2d(num_features=out_ch),
        nn.ReLU6(inplace=True),
    ]
    # Kept flat (not nested) so the layers stay addressable as indices 0-5.
    return nn.Sequential(*depthwise_stage, *pointwise_stage)

def conv_bn_relu(in_ch, out_ch, stride):
    """
    Build a standard (dense) 3x3 convolution followed by BatchNorm and ReLU6.

    Args:
        in_ch (int): Number of channels in the input feature
        out_ch (int): Number of channels in the output feature
        stride (int): Stride of the convolution

    Returns:
        nn.Sequential: conv -> batch norm -> ReLU6, in that order
    """
    dense_conv = nn.Conv2d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,  # no conv bias; BatchNorm follows immediately
    )
    norm = nn.BatchNorm2d(num_features=out_ch)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(dense_conv, norm, activation)

In [3]:
# Model

# Class structure based on: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py


# Block configuration: one (in_ch, out_ch, stride) triple per depthwise
# separable block, listed in the order the blocks are applied.
layer_info = [
    (32, 64, 1),
    (64, 128, 2),
    (128, 128, 1),
    (128, 256, 2),
    (256, 256, 1),
    (256, 512, 2),
    # Five identical 512-channel blocks at constant resolution.
    (512, 512, 1),
    (512, 512, 1),
    (512, 512, 1),
    (512, 512, 1),
    (512, 512, 1),
    (512, 1024, 2),
    (1024, 1024, 2),
]

class DepthwiseSeparable(nn.Module):
    """
    Depthwise separable convolution block.

    A 3x3 depthwise convolution (one filter group per input channel, carrying
    the block's stride) is followed by a 1x1 pointwise convolution that maps
    ``in_planes`` channels to ``out_planes``. Each convolution is followed by
    BatchNorm and ReLU6.

    Args:
        in_planes (int): Number of channels in the input feature
        out_planes (int): Number of channels in the output feature
        stride (int): Stride of the depthwise convolution
    """

    def __init__(self, in_planes, out_planes, stride):
        super().__init__()

        # Stage 1 - depthwise: per-channel 3x3 filtering, applies the stride.
        self.dw_conv = nn.Conv2d(
            in_channels=in_planes,
            out_channels=in_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_planes,
            bias=False,
        )
        self.dw_bn = nn.BatchNorm2d(num_features=in_planes)
        self.dw_relu = nn.ReLU6(inplace=True)

        # Stage 2 - pointwise: 1x1 conv that mixes channels, spatial size kept.
        self.pw_conv = nn.Conv2d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.pw_bn = nn.BatchNorm2d(num_features=out_planes)
        self.pw_relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        """Run the depthwise stage, then the pointwise stage, on ``x``."""
        filtered = self.dw_relu(self.dw_bn(self.dw_conv(x)))
        return self.pw_relu(self.pw_bn(self.pw_conv(filtered)))

class MobileNetV1(nn.Module):
    """
    MobileNetV1-style image classifier.

    Pipeline: 3x3 stem convolution (+ BatchNorm + ReLU6) -> a stack of
    DepthwiseSeparable blocks -> global average pooling -> linear classifier.

    Args:
        num_classes (int): Number of logits produced by the classifier
        layer_info (iterable of (in_ch, out_ch, stride), optional): One entry
            per DepthwiseSeparable block, applied in order. When None, the
            13-block configuration in ``DEFAULT_LAYER_INFO`` is used. The stem
            width is taken from the first entry's ``in_ch`` and the classifier
            input width from the last entry's ``out_ch``.
    """

    # Fallback configuration used when layer_info is None.
    # NOTE(review): the stem below uses stride 1 and the last block here uses
    # stride 2; the MobileNetV1 paper appears to use a stride-2 stem instead.
    # Both give a 7x7 final map for 224x224 input - confirm this is intended.
    DEFAULT_LAYER_INFO = (
        (32, 64, 1), (64, 128, 2), (128, 128, 1), (128, 256, 2),
        (256, 256, 1), (256, 512, 2), (512, 512, 1), (512, 512, 1),
        (512, 512, 1), (512, 512, 1), (512, 512, 1), (512, 1024, 2),
        (1024, 1024, 2),
    )

    def __init__(self, num_classes=1000, layer_info=None):
        super(MobileNetV1, self).__init__()

        # The declared default (None) used to crash when iterated.
        if layer_info is None:
            layer_info = self.DEFAULT_LAYER_INFO
        # Materialise once: the spec is read here and again in _make_layers,
        # which would exhaust a generator.
        layer_info = list(layer_info)

        # Derive the widths that must agree with the block stack instead of
        # hard-coding 32 / 1024; an empty stack leaves just the 32-wide stem.
        stem_ch = layer_info[0][0] if layer_info else 32
        feat_ch = layer_info[-1][1] if layer_info else stem_ch

        self.conv1 = nn.Conv2d(3, stem_ch, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_ch)
        self.relu1 = nn.ReLU6(inplace=True)
        self.layers = self._make_layers(layer_info)
        # Global pooling to 1x1 for any input size; identical to the former
        # AvgPool2d(7) whenever the final feature map is 7x7.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(feat_ch, num_classes)

    def _make_layers(self, layer_info):
        """Build one DepthwiseSeparable block per (in_ch, out_ch, stride) entry."""
        layers = []
        for in_ch, out_ch, stride in layer_info:
            layers.append(DepthwiseSeparable(in_ch, out_ch, stride))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of 3-channel images to ``(batch, num_classes)`` logits."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.layers(out)
        out = self.avg_pool(out)
        # Collapse (N, C, 1, 1) to (N, C) for the linear classifier.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [4]:
# Smoke test: push one random 224x224 RGB image through the network and print
# the shape of the result (expected: torch.Size([1, 1000])).
net = MobileNetV1(num_classes=1000, layer_info=layer_info)
x = torch.randn((1, 3, 224, 224))
y = net(x)
print(y.shape)

torch.Size([1, 1000])
