In [1]:
# Codeblock 1
import torch
import torch.nn as nn
from torchinfo import summary

In [2]:
# Codeblock 2
# Dimensions of the dummy input batch used to exercise the full network.
BATCH_SIZE = 1
IMAGE_SIZE = 224
IN_CHANNELS = 3

# Number of target classes.
NUM_CLASSES = 1000

# Factor applied to the channel counts inside the layer definitions.
WIDTH_MULTIPLIER = 1.0

In [3]:
# Codeblock 3
class Conv(nn.Module):
    """Standard convolution followed by batch normalization and ReLU6.

    Two fixed configurations are available, selected with ``first``:

    * ``first=True``: 3x3 convolution, stride 2, padding 1, mapping
      ``in_channels`` to ``int(32 * width_multiplier)`` channels.
    * ``first=False``: 1x1 convolution, stride 1, no padding, mapping
      ``int(320 * width_multiplier)`` to ``int(1280 * width_multiplier)``
      channels.

    Args:
        first: Selects the 3x3/stride-2 variant (``True``) or the 1x1 variant
            (``False``).
        in_channels: Input channels of the ``first=True`` variant. Defaults to
            the notebook-level ``IN_CHANNELS``. Ignored when ``first=False``.
        width_multiplier: Channel scaling factor. Defaults to the
            notebook-level ``WIDTH_MULTIPLIER``.
    """

    def __init__(self, first=False, in_channels=None, width_multiplier=None):      #(1)
        super().__init__()

        # Consult the notebook-level configuration only when the caller did
        # not pass an explicit value.
        if width_multiplier is None:
            width_multiplier = WIDTH_MULTIPLIER

        if first:
            if in_channels is None:
                in_channels = IN_CHANNELS                    #(2)
            out_channels = int(32*width_multiplier)          #(3)
            kernel_size = 3               #(4)
            stride = 2                    #(5)
            padding = 1                   #(6)
        else:
            # This variant always derives its input width from the multiplier;
            # an explicit in_channels argument is not used here.
            in_channels  = int(320*width_multiplier)         #(7)
            out_channels = int(1280*width_multiplier)        #(8)
            kernel_size = 1               #(9)
            stride = 1                    #(10)
            padding = 0                   #(11)

        # No conv bias: the convolution is immediately followed by BatchNorm2d.
        self.conv = nn.Conv2d(in_channels=in_channels,       #(12)
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              bias=False)
        self.bn = nn.BatchNorm2d(num_features=out_channels)  #(13)
        self.relu6 = nn.ReLU6()           #(14)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU6 to ``x`` and return the result."""
        x = self.relu6(self.bn(self.conv(x)))                #(15)
        return x

In [4]:
# Codeblock 4
# Run the first convolution on a dummy batch built from the notebook
# configuration (same construction as the full-model check further down).
conv = Conv(first=True)
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

out = conv(x)
out.shape

torch.Size([1, 32, 112, 112])

In [5]:
# Codeblock 5
# Run the 1x1 variant of Conv on a dummy 7x7 feature map whose channel count
# follows the width multiplier.
conv = Conv(first=False)
feature_map_shape = (1, int(320*WIDTH_MULTIPLIER), 7, 7)
x = torch.randn(feature_map_shape)

out = conv(x)
out.size()

torch.Size([1, 1280, 7, 7])

In [8]:
# Codeblock 6
class InvResidualS2(nn.Module):
    """Inverted residual block whose depthwise convolution uses stride 2.

    Layer order: 1x1 conv expanding to ``in_channels*t`` channels -> BN ->
    ReLU6 -> 3x3 depthwise conv (stride 2, padding 1) -> BN -> ReLU6 ->
    1x1 conv projecting to ``out_channels`` -> BN. There is no activation
    after the last batch norm and no skip connection.

    Example (width multiplier 1.0, ``in_channels=16, out_channels=24, t=6``):
    16x112x112 -> 96x112x112 -> 96x56x56 -> 24x56x56.

    Args:
        in_channels: Unscaled number of input channels.
        out_channels: Unscaled number of output channels.
        t: Expansion factor applied to the (scaled) input channels.
        width_multiplier: Channel scaling factor. Defaults to the
            notebook-level ``WIDTH_MULTIPLIER``.
    """

    def __init__(self, in_channels, out_channels, t, width_multiplier=None):   #(1)
        super().__init__()

        # Consult the notebook-level configuration only when the caller did
        # not pass an explicit value.
        if width_multiplier is None:
            width_multiplier = WIDTH_MULTIPLIER

        in_channels  = int(in_channels*width_multiplier)      #(2)
        out_channels = int(out_channels*width_multiplier)     #(3)
        expanded_channels = in_channels*t

        self.pwconv0 = nn.Conv2d(in_channels=in_channels,     #(4)
                                 out_channels=expanded_channels,
                                 kernel_size=1,
                                 stride=1,
                                 bias=False)

        self.bn_pwconv0 = nn.BatchNorm2d(num_features=expanded_channels)

        self.dwconv = nn.Conv2d(in_channels=expanded_channels,    #(5)
                                out_channels=expanded_channels,
                                kernel_size=3,                #(6)
                                stride=2,
                                padding=1,
                                groups=expanded_channels,     #(7) one filter per channel
                                bias=False)

        self.bn_dwconv = nn.BatchNorm2d(num_features=expanded_channels)

        self.pwconv1 = nn.Conv2d(in_channels=expanded_channels,   #(8)
                                 out_channels=out_channels,
                                 kernel_size=1,
                                 stride=1,
                                 bias=False)

        self.bn_pwconv1 = nn.BatchNorm2d(num_features=out_channels)

        self.relu6 = nn.ReLU6()

    def forward(self, x):
        """Run expand -> depthwise (stride 2) -> project and return the result."""
        # Expansion: 1x1 conv + BN + ReLU6
        x = self.relu6(self.bn_pwconv0(self.pwconv0(x)))

        # Depthwise 3x3, stride 2 + BN + ReLU6
        x = self.relu6(self.bn_dwconv(self.dwconv(x)))

        # Projection: 1x1 conv + BN, deliberately without an activation
        x = self.bn_pwconv1(self.pwconv1(x))

        return x

In [7]:
# Codeblock 7
inv_residual_s2 = InvResidualS2(in_channels=16, out_channels=24, t=6)
x = torch.randn(1, int(16*WIDTH_MULTIPLIER), 112, 112)

out = inv_residual_s2(x)
# End on an expression so the cell displays the resulting shape.
out.shape

original		: torch.Size([1, 16, 112, 112])
after pwconv0		: torch.Size([1, 96, 112, 112])
after bn0_pwconv0	: torch.Size([1, 96, 112, 112])
after relu		: torch.Size([1, 96, 112, 112])
after dwconv		: torch.Size([1, 96, 56, 56])
after bn_dwconv		: torch.Size([1, 96, 56, 56])
after relu		: torch.Size([1, 96, 56, 56])
after pwconv1		: torch.Size([1, 24, 56, 56])
after bn_pwconv1	: torch.Size([1, 24, 56, 56])


In [11]:
# Codeblock 8
class InvResidualS1(nn.Module):
    """Inverted residual block whose depthwise convolution uses stride 1.

    Layer order: 1x1 conv expanding to ``in_channels*t`` channels -> BN ->
    ReLU6 -> 3x3 depthwise conv (stride 1, padding 1) -> BN -> ReLU6 ->
    1x1 conv projecting to ``out_channels`` -> BN. When the scaled input and
    output channel counts are equal, the block input is added to the result.

    Args:
        in_channels: Unscaled number of input channels.
        out_channels: Unscaled number of output channels.
        t: Expansion factor applied to the (scaled) input channels.
        width_multiplier: Channel scaling factor. Defaults to the
            notebook-level ``WIDTH_MULTIPLIER``.

    Attributes:
        in_channels: Input channel count after scaling.
        out_channels: Output channel count after scaling.
    """

    def __init__(self, in_channels, out_channels, t, width_multiplier=None):
        super().__init__()

        # Consult the notebook-level configuration only when the caller did
        # not pass an explicit value.
        if width_multiplier is None:
            width_multiplier = WIDTH_MULTIPLIER

        in_channels  = int(in_channels*width_multiplier)    #(1)
        out_channels = int(out_channels*width_multiplier)   #(2)
        expanded_channels = in_channels*t

        # Kept on the instance because forward() uses them to decide whether
        # the skip connection applies.
        self.in_channels  = in_channels
        self.out_channels = out_channels

        self.pwconv0 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=expanded_channels,
                                 kernel_size=1,
                                 stride=1,
                                 bias=False)

        self.bn_pwconv0 = nn.BatchNorm2d(num_features=expanded_channels)

        self.dwconv = nn.Conv2d(in_channels=expanded_channels,
                                out_channels=expanded_channels,
                                kernel_size=3,
                                stride=1,            #(3)
                                padding=1,
                                groups=expanded_channels,
                                bias=False)

        self.bn_dwconv = nn.BatchNorm2d(num_features=expanded_channels)

        self.pwconv1 = nn.Conv2d(in_channels=expanded_channels,
                                 out_channels=out_channels,
                                 kernel_size=1,
                                 stride=1,
                                 bias=False)

        self.bn_pwconv1 = nn.BatchNorm2d(num_features=out_channels)

        self.relu6 = nn.ReLU6()

    def forward(self, x):
        """Run expand -> depthwise -> project, adding the input when shapes match."""
        use_residual = self.in_channels == self.out_channels    #(4)
        residual = x                                            #(5)

        # Expansion: 1x1 conv + BN + ReLU6
        out = self.relu6(self.bn_pwconv0(self.pwconv0(x)))

        # Depthwise 3x3, stride 1 + BN + ReLU6
        out = self.relu6(self.bn_dwconv(self.dwconv(out)))

        # Projection: 1x1 conv + BN, deliberately without an activation
        out = self.bn_pwconv1(self.pwconv1(out))

        if use_residual:
            out = out + residual      #(6)

        return out

In [10]:
# Codeblock 9
inv_residual_s1 = InvResidualS1(in_channels=24, out_channels=24, t=6)
x = torch.randn(1, int(24*WIDTH_MULTIPLIER), 56, 56)

out = inv_residual_s1(x)
# End on an expression so the cell displays the resulting shape.
out.shape

residual		: torch.Size([1, 24, 56, 56])
after pwconv0		: torch.Size([1, 144, 56, 56])
after bn_pwconv0	: torch.Size([1, 144, 56, 56])
after relu		: torch.Size([1, 144, 56, 56])
after dwconv		: torch.Size([1, 144, 56, 56])
after bn_dwconv		: torch.Size([1, 144, 56, 56])
after relu		: torch.Size([1, 144, 56, 56])
after pwconv1		: torch.Size([1, 24, 56, 56])
after bn_pwconv1	: torch.Size([1, 24, 56, 56])
after summation		: torch.Size([1, 24, 56, 56])


In [14]:
# Codeblock 10
class MobileNetV2(nn.Module):
    """MobileNetV2 classifier assembled from Conv, InvResidualS1 and InvResidualS2.

    Pipeline: first convolution -> seven groups of inverted residual blocks
    (``inv_residual0`` .. ``inv_residual6``) -> final 1x1 convolution ->
    global average pooling -> flatten -> dropout (p=0.2) -> linear classifier.

    The shape comments below are for a 3x224x224 input with
    ``WIDTH_MULTIPLIER = 1.0``.

    Args:
        num_classes: Size of the classifier output. Defaults to the
            notebook-level ``NUM_CLASSES``.
    """

    def __init__(self, num_classes=None):
        super().__init__()

        # Consult the notebook-level configuration only when the caller did
        # not pass an explicit value.
        if num_classes is None:
            num_classes = NUM_CLASSES

        # Input shape: 3x224x224
        self.first_conv = Conv(first=True)

        # Input shape: 32x112x112
        # NOTE(review): with t=1 this block still applies its 1x1 "expansion"
        # convolution (32 -> 32); confirm that this is intended.
        self.inv_residual0 = InvResidualS1(in_channels=32,
                                           out_channels=16,
                                           t=1)

        # Each multi-block stage starts with the block that changes the
        # channel count (and, for InvResidualS2, the resolution), followed by
        # stride-1 blocks that keep the shape.

        # Input shape: 16x112x112
        self.inv_residual1 = nn.ModuleList(
            [InvResidualS2(in_channels=16, out_channels=24, t=6)]
            + [InvResidualS1(in_channels=24, out_channels=24, t=6)
               for _ in range(1)])

        # Input shape: 24x56x56
        self.inv_residual2 = nn.ModuleList(
            [InvResidualS2(in_channels=24, out_channels=32, t=6)]
            + [InvResidualS1(in_channels=32, out_channels=32, t=6)
               for _ in range(2)])

        # Input shape: 32x28x28
        self.inv_residual3 = nn.ModuleList(
            [InvResidualS2(in_channels=32, out_channels=64, t=6)]
            + [InvResidualS1(in_channels=64, out_channels=64, t=6)
               for _ in range(3)])

        # Input shape: 64x14x14
        self.inv_residual4 = nn.ModuleList(
            [InvResidualS1(in_channels=64, out_channels=96, t=6)]
            + [InvResidualS1(in_channels=96, out_channels=96, t=6)
               for _ in range(2)])

        # Input shape: 96x14x14
        self.inv_residual5 = nn.ModuleList(
            [InvResidualS2(in_channels=96, out_channels=160, t=6)]
            + [InvResidualS1(in_channels=160, out_channels=160, t=6)
               for _ in range(2)])

        # Input shape: 160x7x7
        self.inv_residual6 = InvResidualS1(in_channels=160,
                                           out_channels=320,
                                           t=6)

        # Input shape: 320x7x7
        self.last_conv = Conv(first=False)

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))        #(1)
        self.dropout = nn.Dropout(p=0.2)                              #(2)
        self.fc = nn.Linear(in_features=int(1280*WIDTH_MULTIPLIER),   #(3)
                            out_features=num_classes)

    def forward(self, x):
        """Return the classifier output for ``x``; no softmax is applied."""
        x = self.first_conv(x)
        x = self.inv_residual0(x)

        # The five ModuleList stages are applied block by block, in order.
        for stage in (self.inv_residual1,
                      self.inv_residual2,
                      self.inv_residual3,
                      self.inv_residual4,
                      self.inv_residual5):
            for layer in stage:
                x = layer(x)

        x = self.inv_residual6(x)
        x = self.last_conv(x)

        # Pool to 1x1 and drop the spatial dimensions before the classifier.
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)

        x = self.dropout(x)
        x = self.fc(x)

        return x

In [13]:
# Codeblock 11
mobilenetv2 = MobileNetV2()
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

out = mobilenetv2(x)
# End on an expression so the cell displays the resulting shape.
out.shape

after first_conv	: torch.Size([1, 32, 112, 112])
after inv_residual0	: torch.Size([1, 16, 112, 112])
after inv_residual1 #0	: torch.Size([1, 24, 56, 56])
after inv_residual1 #1	: torch.Size([1, 24, 56, 56])
after inv_residual2 #0	: torch.Size([1, 32, 28, 28])
after inv_residual2 #1	: torch.Size([1, 32, 28, 28])
after inv_residual2 #2	: torch.Size([1, 32, 28, 28])
after inv_residual3 #0	: torch.Size([1, 64, 14, 14])
after inv_residual3 #1	: torch.Size([1, 64, 14, 14])
after inv_residual3 #2	: torch.Size([1, 64, 14, 14])
after inv_residual3 #3	: torch.Size([1, 64, 14, 14])
after inv_residual4 #0	: torch.Size([1, 96, 14, 14])
after inv_residual4 #1	: torch.Size([1, 96, 14, 14])
after inv_residual4 #2	: torch.Size([1, 96, 14, 14])
after inv_residual5 #0	: torch.Size([1, 160, 7, 7])
after inv_residual5 #1	: torch.Size([1, 160, 7, 7])
after inv_residual5 #2	: torch.Size([1, 160, 7, 7])
after inv_residual6	: torch.Size([1, 320, 7, 7])
after last_conv		: torch.Size([1, 1280, 7, 7])
after avgpo

In [15]:
# Codeblock 12
# Layer-by-layer summary of a freshly built model for the configured input size.
mobilenetv2 = MobileNetV2()
input_size = (BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
summary(mobilenetv2, input_size=input_size)

Layer (type:depth-idx)                   Output Shape              Param #
MobileNetV2                              [1, 1000]                 --
├─Conv: 1-1                              [1, 32, 112, 112]         --
│    └─Conv2d: 2-1                       [1, 32, 112, 112]         864
│    └─BatchNorm2d: 2-2                  [1, 32, 112, 112]         64
│    └─ReLU6: 2-3                        [1, 32, 112, 112]         --
├─InvResidualS1: 1-2                     [1, 16, 112, 112]         --
│    └─Conv2d: 2-4                       [1, 32, 112, 112]         1,024
│    └─BatchNorm2d: 2-5                  [1, 32, 112, 112]         64
│    └─ReLU6: 2-6                        [1, 32, 112, 112]         --
│    └─Conv2d: 2-7                       [1, 32, 112, 112]         288
│    └─BatchNorm2d: 2-8                  [1, 32, 112, 112]         64
│    └─ReLU6: 2-9                        [1, 32, 112, 112]         --
│    └─Conv2d: 2-10                      [1, 16, 112, 112]         512
│    └─Ba