In [1]:
import torch
from torch import nn
from torchinfo import summary

In [2]:
class DepSepConv(nn.Module):
    """Depthwise-separable convolution block.

    Two sub-blocks are applied back to back, each of the form
    convolution -> batch norm -> ReLU:

    * ``depthwise``: a 3x3 convolution with ``groups == in_channels`` (one
      filter per input channel), which also carries the block's stride.
    * ``seperate``: a 1x1 convolution that maps ``in_channels`` to
      ``out_channels``. The attribute spelling is kept as-is because it is
      part of the module's ``state_dict`` key names.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the 1x1 convolution.
        stride: Stride of the 3x3 depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Per-channel 3x3 filtering; padding=1 keeps the spatial size when stride == 1.
        per_channel_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_channels,
            bias=False,  # the following BatchNorm supplies the shift term
        )
        self.depthwise = nn.Sequential(
            per_channel_conv,
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
        )

        # 1x1 convolution that mixes channels and sets the output width.
        channel_mixing_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.seperate = nn.Sequential(
            channel_mixing_conv,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the depthwise sub-block and then the 1x1 sub-block on ``x``."""
        return self.seperate(self.depthwise(x))

In [7]:
class MobileNetV1(nn.Module):
    """Image classifier built from a strided stem convolution and DepSepConv stages.

    Layers are applied in this order:

    * ``conv1``: 3x3 stride-2 convolution (3 -> 32 channels) + batch norm + ReLU.
    * ``conv2``: a single stride-1 ``DepSepConv`` (32 -> 64).
    * ``conv3`` .. ``conv6``: stages that each open with a stride-2
      ``DepSepConv`` changing the width (to 128, 256, 512, 1024) and continue
      with stride-1 ``DepSepConv`` blocks at that width (1, 1, 5 and 1 of them).
    * ``GlobalAvgPool``: adaptive average pooling to a 1x1 spatial map.
    * ``fc``: linear layer from 1024 features to ``num_classes`` outputs.

    Args:
        num_classes: Number of output features of the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        def make_stage(width_in, width_out, num_same_width):
            # One stride-2 block that changes the channel count, followed by
            # `num_same_width` stride-1 blocks that keep it.
            blocks = [DepSepConv(width_in, width_out, stride=2)]
            for _ in range(num_same_width):
                blocks.append(DepSepConv(width_out, width_out, stride=1))
            return nn.Sequential(*blocks)

        # Stem: the only plain (non depthwise-separable) convolution in the model.
        stem_layers = [
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        ]
        self.conv1 = nn.Sequential(*stem_layers)

        self.conv2 = DepSepConv(32, 64, stride=1)

        self.conv3 = make_stage(64, 128, 1)
        self.conv4 = make_stage(128, 256, 1)
        self.conv5 = make_stage(256, 512, 5)
        self.conv6 = make_stage(512, 1024, 1)

        # Classifier head: pool each of the 1024 channels to one value, then project.
        self.GlobalAvgPool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return the ``fc`` outputs (no softmax applied) for the input batch ``x``."""
        feature_stages = (
            self.conv1,
            self.conv2,
            self.conv3,
            self.conv4,
            self.conv5,
            self.conv6,
        )
        for stage in feature_stages:
            x = stage(x)

        pooled = self.GlobalAvgPool(x)
        # Collapse everything but the batch dimension before the linear layer.
        features = torch.flatten(pooled, start_dim=1)
        return self.fc(features)

# Instantiate with the default number of classes and print a per-layer summary
# for a batch of two 3-channel 224x224 inputs.
model = MobileNetV1()
summary(model, input_size=(2, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
MobileNetV1                              [2, 1000]                 --
├─Sequential: 1-1                        [2, 32, 112, 112]         --
│    └─Conv2d: 2-1                       [2, 32, 112, 112]         864
│    └─BatchNorm2d: 2-2                  [2, 32, 112, 112]         64
│    └─ReLU: 2-3                         [2, 32, 112, 112]         --
├─DepSepConv: 1-2                        [2, 64, 112, 112]         --
│    └─Sequential: 2-4                   [2, 32, 112, 112]         --
│    │    └─Conv2d: 3-1                  [2, 32, 112, 112]         288
│    │    └─BatchNorm2d: 3-2             [2, 32, 112, 112]         64
│    │    └─ReLU: 3-3                    [2, 32, 112, 112]         --
│    └─Sequential: 2-5                   [2, 64, 112, 112]         --
│    │    └─Conv2d: 3-4                  [2, 64, 112, 112]         2,048
│    │    └─BatchNorm2d: 3-5             [2, 64, 112, 112]         128
│    │   