In [2]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [3]:
# Depthwise separable convolution block (depthwise conv -> BN -> ReLU -> pointwise 1x1 conv -> BN -> ReLU)
# https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/6
class conv_dw(nn.Module):
    """Depthwise separable convolution block.

    A per-channel (depthwise) ``kernel_size`` x ``kernel_size`` convolution,
    followed by BatchNorm + ReLU, then a 1x1 (pointwise) convolution that mixes
    channels, followed by BatchNorm + ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels produced by the pointwise conv.
        stride: Stride of the depthwise convolution (controls downsampling).
        kernel_size: Spatial size of the depthwise kernel. Padding is
            ``kernel_size // 2``, so odd kernels preserve spatial size at stride 1.
        bias: Whether both convolutions use a bias term.
    """

    def __init__(self, in_channels, out_channels, stride, kernel_size=3, bias=True):
        super(conv_dw, self).__init__()
        # Depthwise: groups=in_channels gives one filter per input channel, so the
        # output must keep in_channels channels -- bn1 and the pointwise conv both
        # expect in_channels. (Using out_channels here broke every block with
        # in_channels != out_channels.)
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,  # equals 1 for the default kernel_size=3
            groups=in_channels,
            bias=bias,
        )
        # Pointwise: 1x1 conv that projects in_channels -> out_channels.
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=bias,
        )
        # ReLU has no parameters, so a single module is shared by both stages.
        self.ReLU = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.ReLU(self.bn1(self.depthwise(x)))
        out = self.ReLU(self.bn2(self.pointwise(out)))
        return out

In [6]:
class MobileNet(nn.Module):
    """MobileNet (v1) image classifier built from depthwise separable convolutions.

    Layout:
    - Stem: a stride-2 3x3 convolution.
    - Body: 13 ``conv_dw`` blocks whose strides downsample the input by a total
      factor of 32.
    - Head: global average pooling, then a linear layer with softmax.

    Args:
        in_channel: Number of channels in the input images (e.g. 3 for RGB).
        out_classes: Number of output classes.

    Input is a batch shaped ``(N, in_channel, H, W)``. Output is
    ``(N, out_classes)``, softmax-normalised over the class dimension.
    NOTE(review): if training with ``nn.CrossEntropyLoss``, that loss expects raw
    logits, so the final softmax should be dropped -- confirm against the
    training code.
    """

    def __init__(self, in_channel, out_classes):
        super(MobileNet, self).__init__()
        self.conv_model = nn.Sequential(
            # Stem: padding=1 so a 224x224 input becomes 112x112, as in the paper.
            nn.Conv2d(in_channel, 32, 3, 2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            conv_dw(32, 64, 1),
            conv_dw(64, 128, 2),
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            # Global average pool to 1x1. For 224x224 inputs the final feature map
            # is 7x7, so this matches AvgPool2d(7) exactly. Unlike a fixed 7x7
            # window, it also covers the whole map for other input sizes.
            nn.AdaptiveAvgPool2d(1),
        )

        # NOTE: the attribute name "classifer" (sic) is kept so that existing
        # references and state_dict keys keep working.
        self.classifer = nn.Sequential(
            nn.Linear(1024, out_classes),
            # Explicit dim=1 normalises over classes; the implicit-dim form is deprecated.
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        out = self.conv_model(x)
        # (N, 1024, 1, 1) -> (N, 1024) so that Linear(1024, ...) sees the channel axis.
        out = torch.flatten(out, 1)
        out = self.classifer(out)
        return out
