In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Depthwise(nn.Module):
  """Depthwise-separable convolution: a per-channel 3x3 conv followed by a 1x1 conv.

  Args:
    iin: number of input channels; the depthwise step keeps this channel count.
    out: number of output channels produced by the pointwise (1x1) step.
    stride: stride of the 3x3 depthwise conv; the 1x1 conv always uses stride 1.

  No normalization or activation is applied between or after the two convs.
  """
  def __init__(self,iin,out,stride):
    super().__init__()
    # groups == in_channels gives every input channel its own 3x3 filter;
    # padding=1 keeps H/W unchanged when stride is 1.
    self.depth = nn.Conv2d(in_channels=iin, out_channels=iin, kernel_size=3,
                           stride=stride, padding=1, groups=iin)
    # 1x1 conv mixes channels iin -> out (stride=1, padding=0 are the defaults).
    self.point = nn.Conv2d(in_channels=iin, out_channels=out, kernel_size=1)

  def forward(self,x):
    per_channel = self.depth(x)
    return self.point(per_channel)

In [3]:
class MObileNet(nn.Module):
  """MobileNet-v1-style image classifier built from depthwise-separable blocks.

  Layout: a 3x3 stride-2 stem conv (3 -> 32 channels), thirteen
  ``Depthwise`` blocks (``dw2`` .. ``dw14``) widening to 1024 channels,
  global average pooling, and a linear classifier. Every conv stage is
  followed by ReLU. Five layers use stride 2 (conv1, dw3, dw5, dw7, dw13),
  so spatial resolution is reduced 32x before pooling.

  Args:
    num: number of output classes (out_features of the final linear layer).
  """
  def __init__(self,num=1000):
    super().__init__()
    # Stem: ordinary 3x3 conv, 3 -> 32 channels, halves H and W.
    self.conv1=nn.Conv2d(3,32,kernel_size=3,stride=2,padding=1)

    self.dw2=Depthwise(32,64,1)
    self.dw3=Depthwise(64,128,2)
    self.dw4=Depthwise(128,128,1)
    self.dw5=Depthwise(128,256,2)
    self.dw6=Depthwise(256,256,1)
    self.dw7=Depthwise(256,512,2)

    # Five 512 -> 512 blocks at constant resolution.
    self.dw8=Depthwise(512,512,1)
    self.dw9=Depthwise(512,512,1)
    self.dw10=Depthwise(512,512,1)
    self.dw11=Depthwise(512,512,1)
    self.dw12=Depthwise(512,512,1)

    self.dw13=Depthwise(512,1024,2)
    self.dw14=Depthwise(1024,1024,1)

    # Adaptive pooling to 1x1 makes the head independent of input H/W.
    self.avgpool=nn.AdaptiveAvgPool2d(1)
    self.fc=nn.Linear(1024,num)

  def forward(self,x):
    """Map an image batch to unnormalized class scores of shape (N, num).

    Assumes a 4-D NCHW tensor with 3 channels (required by ``conv1``).
    No softmax is applied.
    """
    # Bug fix: the stem conv was previously skipped, so dw2 (expecting 32
    # channels) would have received the raw 3-channel input.
    x = F.relu(self.conv1(x))
    # Bug fix: use the attribute names actually registered in __init__
    # (dw2..dw14, avgpool); dw_conv*/avg_pool do not exist.
    for block in (self.dw2, self.dw3, self.dw4, self.dw5, self.dw6, self.dw7,
                  self.dw8, self.dw9, self.dw10, self.dw11, self.dw12,
                  self.dw13, self.dw14):
      x = F.relu(block(x))

    x = self.avgpool(x)        # (N, 1024, 1, 1)
    x = torch.flatten(x, 1)    # (N, 1024)
    return self.fc(x)

In [4]:
mobilenet = MObileNet(1000)  # 1000 output classes from the final linear layer

In [5]:
mobilenet  # bare last expression: renders the module tree as this cell's output

MObileNet(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (dw2): Depthwise(
    (depth): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
    (point): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
  )
  (dw3): Depthwise(
    (depth): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64)
    (point): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
  )
  (dw4): Depthwise(
    (depth): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
    (point): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
  )
  (dw5): Depthwise(
    (depth): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
    (point): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (dw6): Depthwise(
    (depth): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
    (point): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (dw7): Dep