In [1]:
# Codeblock 1
import torch
import torch.nn as nn
from torchinfo import summary

In [2]:
# Codeblock 2
BATCH_SIZE  = 1      # images per forward pass in the demo cells
IMAGE_SIZE  = 224    # height and width of the square demo inputs
IN_CHANNELS = 3      # channels of the input images (e.g. 3 for RGB)
NUM_CLASSES = 1000   # width of the classifier's final Linear layer
ALPHA       = 1      # width multiplier: every conv channel count is scaled by this
SEED        = 0      # fixed seed so random inputs and weight init are reproducible

# Seed torch's default generator, which drives both torch.randn and the
# default parameter initialisation of nn.Conv2d / nn.Linear.
torch.manual_seed(SEED)

In [3]:
# Codeblock 3
class FirstConv(nn.Module):
    """MobileNetV1 stem: a regular 3x3 stride-2 convolution, then BatchNorm and ReLU.

    Maps an image with `in_channels` channels to int(32 * ALPHA) feature maps
    and reduces each spatial dimension to ceil(size / 2) (e.g. 224 -> 112).

    Args:
        in_channels: number of channels in the input image. When omitted, the
            notebook-level IN_CHANNELS constant is used. This keeps the stem
            consistent with input tensors built from the config cell instead of
            relying on a hard-coded 3.
    """

    def __init__(self, in_channels=None):
        super().__init__()
        if in_channels is None:
            in_channels = IN_CHANNELS
        out_channels = int(32 * ALPHA)   # base width 32, scaled by the width multiplier

        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=3,
                              stride=2,      # downsample both spatial dims
                              padding=1,     # with k=3, s=2 the output size is ceil(H/2)
                              bias=False)    # redundant: BatchNorm re-centres and adds its own shift
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> BN -> ReLU to `x` (assumed N, C_in, H, W) and return the result."""
        return self.relu(self.bn(self.conv(x)))

In [4]:
# Codeblock 4
first_conv = FirstConv()
# Dummy image batch built from the config constants rather than literal numbers.
x = torch.randn((BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE))

out = first_conv(x)
# Stride-2 stem: channels become int(32*ALPHA), spatial size becomes ceil(IMAGE_SIZE/2).
assert out.shape == (BATCH_SIZE, int(32*ALPHA), (IMAGE_SIZE + 1) // 2, (IMAGE_SIZE + 1) // 2)
out.shape

torch.Size([1, 32, 112, 112])

In [8]:
# Codeblock 5
class DepthwiseSeparableConv(nn.Module):
    """Depthwise-separable block: a 3x3 depthwise conv followed by a 1x1 pointwise conv.

    Each convolution is followed by BatchNorm and ReLU. The channel counts
    passed in are scaled by the global width multiplier ALPHA and truncated
    with int().

    Args:
        in_channels: unscaled number of input channels.
        out_channels: unscaled number of output channels.
        downsample: when True the depthwise conv uses stride 2, which reduces
            H and W to ceil(H/2) and ceil(W/2). Otherwise stride 1 keeps the
            spatial size.
    """

    def __init__(self, in_channels, out_channels, downsample=False):
        super().__init__()

        c_in = int(in_channels * ALPHA)
        c_out = int(out_channels * ALPHA)
        dw_stride = 2 if downsample else 1

        # Depthwise stage: groups == channels, so each channel is filtered by
        # its own 3x3 kernel and the channel count does not change.
        self.dwconv = nn.Conv2d(c_in, c_in,
                                kernel_size=3, stride=dw_stride, padding=1,
                                groups=c_in, bias=False)
        self.bn0 = nn.BatchNorm2d(c_in)

        # Pointwise stage: a 1x1 conv mixes channels and sets the output width.
        self.pwconv = nn.Conv2d(c_in, c_out,
                                kernel_size=1, stride=1, padding=0,
                                groups=1, bias=False)
        self.bn1 = nn.BatchNorm2d(c_out)

        # One stateless ReLU module is reused after both BatchNorm layers.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the depthwise stage (conv -> BN -> ReLU), then the pointwise stage (conv -> BN -> ReLU)."""
        depthwise_out = self.relu(self.bn0(self.dwconv(x)))
        return self.relu(self.bn1(self.pwconv(depthwise_out)))

In [6]:
# Codeblock 6
depthwise_sep_conv = DepthwiseSeparableConv(in_channels=32,     #(1)
                                            out_channels=64,    #(2)
                                            downsample=False)   #(3)
x = torch.randn((1, int(32*ALPHA), 112, 112))                   #(4)

x = depthwise_sep_conv(x)
# Stride 1 keeps the 112x112 spatial size; the pointwise conv widens 32 -> 64 (both ALPHA-scaled).
assert x.shape == (1, int(64*ALPHA), 112, 112)
x.shape

original	: torch.Size([1, 32, 112, 112])
after dw conv	: torch.Size([1, 32, 112, 112])
after pw conv	: torch.Size([1, 64, 112, 112])


In [7]:
# Codeblock 7
depthwise_sep_conv = DepthwiseSeparableConv(in_channels=64, 
                                            out_channels=128,
                                            downsample=True)
# Build a fresh input instead of reusing (and overwriting) an earlier cell's
# `x`, so this cell gives the same result however many times it is run.
x = torch.randn((1, int(64*ALPHA), 112, 112))

out = depthwise_sep_conv(x)
# downsample=True -> stride 2: 112x112 becomes 56x56 and channels go 64 -> 128 (ALPHA-scaled).
assert out.shape == (1, int(128*ALPHA), 56, 56)
out.shape

original	: torch.Size([1, 64, 112, 112])
after dw conv	: torch.Size([1, 64, 56, 56])
after pw conv	: torch.Size([1, 128, 56, 56])


In [11]:
# Codeblock 8a
class MobileNetV1(nn.Module):
    """MobileNetV1 classifier assembled from FirstConv and DepthwiseSeparableConv blocks.

    Layout (channel counts before ALPHA scaling; "/2" marks a stride-2 block):
        FirstConv (-> 32, /2)
        DS 32->64, DS 64->128 /2, DS 128->128, DS 128->256 /2, DS 256->256,
        DS 256->512 /2, 5 x DS 512->512, DS 512->1024 /2, DS 1024->1024,
        global average pool -> flatten -> Linear -> Softmax

    Args:
        num_classes: size of the output layer. When omitted, the notebook-level
            NUM_CLASSES constant is used (read at construction time).

    NOTE(review): forward() returns softmax probabilities, not logits. If the
    model is trained with nn.CrossEntropyLoss, which applies log-softmax
    internally, use the output of self.fc instead to avoid a double softmax.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES

        # Stem: a regular (not depthwise) stride-2 convolution.
        self.first_conv = FirstConv()

        # Attribute names are kept as-is so state_dict keys and the
        # torchinfo summary stay unchanged.
        self.depthwise_sep_conv0 = DepthwiseSeparableConv(in_channels=32, 
                                                          out_channels=64)
        
        self.depthwise_sep_conv1 = DepthwiseSeparableConv(in_channels=64, 
                                                          out_channels=128, 
                                                          downsample=True)
        
        self.depthwise_sep_conv2 = DepthwiseSeparableConv(in_channels=128, 
                                                          out_channels=128)
        
        self.depthwise_sep_conv3 = DepthwiseSeparableConv(in_channels=128, 
                                                          out_channels=256, 
                                                          downsample=True)
        
        self.depthwise_sep_conv4 = DepthwiseSeparableConv(in_channels=256, 
                                                          out_channels=256)
        
        self.depthwise_sep_conv5 = DepthwiseSeparableConv(in_channels=256, 
                                                          out_channels=512, 
                                                          downsample=True)
        
        # Five identical 512->512 blocks at the same resolution.
        self.depthwise_sep_conv6 = nn.ModuleList(
            [DepthwiseSeparableConv(in_channels=512, out_channels=512) for _ in range(5)]
        )
        
        self.depthwise_sep_conv7 = DepthwiseSeparableConv(in_channels=512, 
                                                          out_channels=1024, 
                                                          downsample=True)
        
        self.depthwise_sep_conv8 = DepthwiseSeparableConv(in_channels=1024,
                                                          out_channels=1024)
        
        # Read the width from the last block so it already includes ALPHA scaling.
        num_out_channels = self.depthwise_sep_conv8.pwconv.out_channels
        
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))   # global average pool -> (N, C, 1, 1)
        self.fc = nn.Linear(in_features=num_out_channels,
                            out_features=num_classes)
        self.softmax = nn.Softmax(dim=1)                          # normalise over the class dimension
        
        
# Codeblock 8b
    def forward(self, x, verbose=False):
        """Classify a batch of images.

        Args:
            x: input tensor, assumed (N, C, H, W) with C matching FirstConv's
                input channels.
            verbose: if True, print the output shape after every stage.

        Returns:
            Tensor of shape (N, self.fc.out_features) with softmax probabilities.
        """
        stages = [
            ("first_conv", self.first_conv),
            ("depthwise_sep_conv0", self.depthwise_sep_conv0),
            ("depthwise_sep_conv1", self.depthwise_sep_conv1),
            ("depthwise_sep_conv2", self.depthwise_sep_conv2),
            ("depthwise_sep_conv3", self.depthwise_sep_conv3),
            ("depthwise_sep_conv4", self.depthwise_sep_conv4),
            ("depthwise_sep_conv5", self.depthwise_sep_conv5),
        ]
        stages += [(f"depthwise_sep_conv6 #{i}", layer)
                   for i, layer in enumerate(self.depthwise_sep_conv6)]
        stages += [
            ("depthwise_sep_conv7", self.depthwise_sep_conv7),
            ("depthwise_sep_conv8", self.depthwise_sep_conv8),
            ("avgpool", self.avgpool),
            ("flatten", lambda t: torch.flatten(t, start_dim=1)),   # (N, C, 1, 1) -> (N, C)
            ("fc", self.fc),
            ("softmax", self.softmax),
        ]

        # Stages run strictly in the order listed, same as the explicit call chain.
        for name, stage in stages:
            x = stage(x)
            if verbose:
                print(f"after {name:<28}: {x.shape}")

        return x

In [10]:
# Codeblock 9
mobilenetv1 = MobileNetV1()
x = torch.randn((BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE))

out = mobilenetv1(x)

# One probability distribution over the classes per image in the batch.
assert out.shape == (BATCH_SIZE, NUM_CLASSES)
assert torch.allclose(out.sum(dim=1), torch.ones(BATCH_SIZE))
out.shape

after first_conv		: torch.Size([1, 32, 112, 112])
after depthwise_sep_conv0	: torch.Size([1, 64, 112, 112])
after depthwise_sep_conv1	: torch.Size([1, 128, 56, 56])
after depthwise_sep_conv2	: torch.Size([1, 128, 56, 56])
after depthwise_sep_conv3	: torch.Size([1, 256, 28, 28])
after depthwise_sep_conv4	: torch.Size([1, 256, 28, 28])
after depthwise_sep_conv5	: torch.Size([1, 512, 14, 14])
after depthwise_sep_conv6 #0	: torch.Size([1, 512, 14, 14])
after depthwise_sep_conv6 #1	: torch.Size([1, 512, 14, 14])
after depthwise_sep_conv6 #2	: torch.Size([1, 512, 14, 14])
after depthwise_sep_conv6 #3	: torch.Size([1, 512, 14, 14])
after depthwise_sep_conv6 #4	: torch.Size([1, 512, 14, 14])
after depthwise_sep_conv7	: torch.Size([1, 1024, 7, 7])
after depthwise_sep_conv8	: torch.Size([1, 1024, 7, 7])
after avgpool			: torch.Size([1, 1024, 1, 1])
after flatten			: torch.Size([1, 1024])
after fc			: torch.Size([1, 1000])
after softmax			: torch.Size([1, 1000])


In [12]:
# Codeblock 10
# Layer-by-layer table of output shapes and parameter counts via torchinfo.
mobilenetv1 = MobileNetV1()
summary(
    model=mobilenetv1,
    input_size=(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE),
)

Layer (type:depth-idx)                   Output Shape              Param #
MobileNetV1                              [1, 1000]                 --
├─FirstConv: 1-1                         [1, 32, 112, 112]         --
│    └─Conv2d: 2-1                       [1, 32, 112, 112]         864
│    └─BatchNorm2d: 2-2                  [1, 32, 112, 112]         64
│    └─ReLU: 2-3                         [1, 32, 112, 112]         --
├─DepthwiseSeparableConv: 1-2            [1, 64, 112, 112]         --
│    └─Conv2d: 2-4                       [1, 32, 112, 112]         288
│    └─BatchNorm2d: 2-5                  [1, 32, 112, 112]         64
│    └─ReLU: 2-6                         [1, 32, 112, 112]         --
│    └─Conv2d: 2-7                       [1, 64, 112, 112]         2,048
│    └─BatchNorm2d: 2-8                  [1, 64, 112, 112]         128
│    └─ReLU: 2-9                         [1, 64, 112, 112]         --
├─DepthwiseSeparableConv: 1-3            [1, 128, 56, 56]          --
│    └─Co