In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms, datasets

In [None]:
class MobileNet(nn.Module):
    """MobileNet-v1-style feature extractor built from depthwise-separable convolutions.

    Layout: a 3x3 stem convolution (stride 1) followed by eight depthwise-separable
    blocks (``conv2`` .. ``conv9``). Four of them (``conv3``, ``conv5``, ``conv7``,
    ``conv9``) use stride 2, and the channels widen from 32 to 512. There is no
    classifier head: ``forward`` returns a 2x2-average-pooled feature map.

    Args:
        in_channels: Number of channels of the input image (default 3, e.g. RGB).
    """

    def __init__(self, in_channels: int = 3) -> None:
        super(MobileNet, self).__init__()
        # Stem: ordinary (non-separable) 3x3 conv, stride 1, padding 1 -> spatial size kept.
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.ReLU())

        self.conv2 = self.conv_dw(32, 32, 1)
        self.conv3 = self.conv_dw(32, 64, 2)
        self.conv4 = self.conv_dw(64, 64, 1)
        self.conv5 = self.conv_dw(64, 128, 2)
        self.conv6 = self.conv_dw(128, 128, 1)
        self.conv7 = self.conv_dw(128, 256, 2)
        self.conv8 = self.conv_dw(256, 256, 1)
        self.conv9 = self.conv_dw(256, 512, 2)

    @staticmethod
    def conv_dw(in_channel, out_channel, stride):
        """Build one depthwise-separable block.

        3x3 depthwise conv (``groups=in_channel``) -> BN -> ReLU, then
        1x1 pointwise conv -> BN -> ReLU, as in the MobileNet (v1) paper.

        Args:
            in_channel: Number of input channels.
            out_channel: Number of channels produced by the pointwise conv.
            stride: Stride of the depthwise conv (2 halves the spatial size, rounding up).

        Returns:
            An ``nn.Sequential`` containing the six layers above.
        """
        return nn.Sequential(
            nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=stride, padding=1, groups=in_channel, bias=False),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channel),
            # Without this activation the pointwise conv and the next block's depthwise
            # conv are composed linearly. ReLU has no parameters, so state_dict keys
            # (indices 0, 1, 3, 4) are unchanged.
            nn.ReLU(),
        )

    def forward(self, x):
        """Compute the pooled feature map for a batch of images.

        Args:
            x: Input batch, presumably shaped (N, in_channels, H, W).

        Returns:
            A tensor with 512 channels. For a 32x32 input the result is (N, 512, 1, 1).
            The final ``avg_pool2d(., 2)`` needs the stride-16 feature map to be at
            least 2x2, so inputs with a side of 16 px or less raise an error.
        """
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out = self.conv7(out)
        out = self.conv8(out)
        out = self.conv9(out)
        out = F.avg_pool2d(out, 2)
        return out


In [None]:
# Inception block: four parallel conv/pool branches concatenated along the channel dimension
class Inception(nn.Module):
    """Four-branch Inception module whose branch outputs are concatenated on dim 1.

    Every branch emits ``out_channel`` channels and keeps the spatial size (all
    convolutions and the pool use stride 1 with "same" padding), so the result
    has ``4 * out_channel`` channels:

    * branch1: 1x1 conv
    * branch2: 1x1 conv -> 3x3 conv
    * branch3: 1x1 conv -> 3x3 conv -> 3x3 conv
    * branch4: 3x3 average pool -> 1x1 conv

    Each convolution is followed by BatchNorm and ReLU.
    """

    def __init__(self, in_channel, out_channel):
        super(Inception, self).__init__()
        unit = self._conv_bn_relu
        self.branch1 = nn.Sequential(*unit(in_channel, out_channel, 1))
        self.branch2 = nn.Sequential(
            *unit(in_channel, out_channel, 1),
            *unit(out_channel, out_channel, 3),
        )
        self.branch3 = nn.Sequential(
            *unit(in_channel, out_channel, 1),
            *unit(out_channel, out_channel, 3),
            *unit(out_channel, out_channel, 3),
        )
        self.branch4 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            *unit(in_channel, out_channel, 1),
        )

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel):
        """Return ``[Conv2d(stride 1, same padding), BatchNorm2d, ReLU]`` as a plain list."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=1, padding=kernel // 2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]

    def forward(self, x):
        """Apply all four branches to ``x`` and concatenate their outputs channel-wise."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)
