In [None]:
################################################################################
#
# FILE
#
#   CIFAR_MobileNetV2_CNN.ipynb
#
# DESCRIPTION
#
#   Creates a convolutional neural network model designed for the CIFAR-10 data
#   set based on MobileNetV2 using PyTorch.
#
#   Two Python classes are defined in this file:
#      1. MobileNetV2Bottleneck: A building block for the model based on the
#         MobileNetV2 bottleneck block.
#      2. Model: Creates the convolutional neural network model using the
#         building block class MobileNetV2Bottleneck.
#
#   There are five distinct layers in the model: the encoder tail, encoder
#   level 0, encoder level 1, encoder level 2, and the decoder. The layer
#   details and number of MACs per block are detailed below:
#      T.  Conv 1x1 stride=2  | 21600     MACs
#      ---------------------------------------
#      0.1 Bottleneck S=1 t=1 | 3225600   MACs
#      0.2 Bottleneck S=2 t=6 | 368148480 MACs
#      0.3 Bottleneck S=2 t=6 | 92798976  MACs
#      ---------------------------------------
#      1.1 Bottleneck S=1 t=6 | 89653248  MACs
#      1.2 Bottleneck S=1 t=6 | 92798976  MACs
#      1.3 Bottleneck S=1 t=6 | 23199744  MACs
#      ---------------------------------------
#      2.1 Bottleneck S=2 t=6 | 22413312  MACs
#      2.2 Bottleneck S=2 t=6 | 88080384  MACs
#      2.3 Bottleneck S=2 t=6 | 88080384  MACs
#      2.4 Bottleneck S=1 t=6 | 88080384  MACs
#      ---------------------------------------
#      D.1 Average Pool       | 512       MACs
#      D.2 Conv 1x1 stride=1  | 262144    MACs
#      D.3 Flatten            | 0         MACs
#      D.4 Linear             | 5120      MACs
#
#   After being trained for 60 epochs with Adam as the optimizer and a learning
#   rate schedule of linear warmup followed by cosine decay, the final accuracy
#   achieved on CIFAR-10 is 86.10% (note: the training code is not provided in
#   this file). The per-epoch results are shown below:
#      Epoch  0 lr = 0.000010 avg loss = 0.063959 accuracy = 32.32
#      Epoch  1 lr = 0.000208 avg loss = 0.049835 accuracy = 52.11
#      Epoch  2 lr = 0.000406 avg loss = 0.042698 accuracy = 57.54
#      Epoch  3 lr = 0.000604 avg loss = 0.038423 accuracy = 60.94
#      Epoch  4 lr = 0.000802 avg loss = 0.035498 accuracy = 64.32
#      Epoch  5 lr = 0.001000 avg loss = 0.033053 accuracy = 67.69
#      Epoch  6 lr = 0.001000 avg loss = 0.030139 accuracy = 70.57
#      Epoch  7 lr = 0.000998 avg loss = 0.027921 accuracy = 72.55
#      Epoch  8 lr = 0.000996 avg loss = 0.026018 accuracy = 73.01
#      Epoch  9 lr = 0.000993 avg loss = 0.024296 accuracy = 74.49
#      Epoch 10 lr = 0.000990 avg loss = 0.023007 accuracy = 75.05
#      Epoch 11 lr = 0.000985 avg loss = 0.021833 accuracy = 77.13
#      Epoch 12 lr = 0.000980 avg loss = 0.020620 accuracy = 77.21
#      Epoch 13 lr = 0.000973 avg loss = 0.019535 accuracy = 78.58
#      Epoch 14 lr = 0.000966 avg loss = 0.018739 accuracy = 78.76
#      Epoch 15 lr = 0.000958 avg loss = 0.017912 accuracy = 79.28
#      Epoch 16 lr = 0.000950 avg loss = 0.017341 accuracy = 80.45
#      Epoch 17 lr = 0.000940 avg loss = 0.016679 accuracy = 81.21
#      Epoch 18 lr = 0.000930 avg loss = 0.015876 accuracy = 80.71
#      Epoch 19 lr = 0.000919 avg loss = 0.015329 accuracy = 81.05
#      Epoch 20 lr = 0.000907 avg loss = 0.014723 accuracy = 80.93
#      Epoch 21 lr = 0.000895 avg loss = 0.014145 accuracy = 80.89
#      Epoch 22 lr = 0.000881 avg loss = 0.013655 accuracy = 82.03
#      Epoch 23 lr = 0.000867 avg loss = 0.013055 accuracy = 81.71
#      Epoch 24 lr = 0.000853 avg loss = 0.012706 accuracy = 81.84
#      Epoch 25 lr = 0.000837 avg loss = 0.012210 accuracy = 82.34
#      Epoch 26 lr = 0.000821 avg loss = 0.011783 accuracy = 82.92
#      Epoch 27 lr = 0.000804 avg loss = 0.011280 accuracy = 82.51
#      Epoch 28 lr = 0.000787 avg loss = 0.010752 accuracy = 83.07
#      Epoch 29 lr = 0.000768 avg loss = 0.010456 accuracy = 82.33
#      Epoch 30 lr = 0.000750 avg loss = 0.009876 accuracy = 83.29
#      Epoch 31 lr = 0.000730 avg loss = 0.009717 accuracy = 83.57
#      Epoch 32 lr = 0.000710 avg loss = 0.009173 accuracy = 83.59
#      Epoch 33 lr = 0.000689 avg loss = 0.008730 accuracy = 83.83
#      Epoch 34 lr = 0.000668 avg loss = 0.008331 accuracy = 83.81
#      Epoch 35 lr = 0.000646 avg loss = 0.008141 accuracy = 83.51
#      Epoch 36 lr = 0.000624 avg loss = 0.007546 accuracy = 84.13
#      Epoch 37 lr = 0.000601 avg loss = 0.007341 accuracy = 83.79
#      Epoch 38 lr = 0.000578 avg loss = 0.006916 accuracy = 83.77
#      Epoch 39 lr = 0.000554 avg loss = 0.006608 accuracy = 83.32
#      Epoch 40 lr = 0.000530 avg loss = 0.006328 accuracy = 83.73
#      Epoch 41 lr = 0.000505 avg loss = 0.005971 accuracy = 84.50
#      Epoch 42 lr = 0.000480 avg loss = 0.005560 accuracy = 84.61
#      Epoch 43 lr = 0.000454 avg loss = 0.005245 accuracy = 84.25
#      Epoch 44 lr = 0.000428 avg loss = 0.004837 accuracy = 84.26
#      Epoch 45 lr = 0.000402 avg loss = 0.004562 accuracy = 84.81
#      Epoch 46 lr = 0.000376 avg loss = 0.004474 accuracy = 84.59
#      Epoch 47 lr = 0.000349 avg loss = 0.004081 accuracy = 84.68
#      Epoch 48 lr = 0.000321 avg loss = 0.003792 accuracy = 85.29
#      Epoch 49 lr = 0.000294 avg loss = 0.003449 accuracy = 85.08
#      Epoch 50 lr = 0.000266 avg loss = 0.003152 accuracy = 85.23
#      Epoch 51 lr = 0.000238 avg loss = 0.002954 accuracy = 84.97
#      Epoch 52 lr = 0.000210 avg loss = 0.002787 accuracy = 85.38
#      Epoch 53 lr = 0.000182 avg loss = 0.002564 accuracy = 85.23
#      Epoch 54 lr = 0.000153 avg loss = 0.002220 accuracy = 85.74
#      Epoch 55 lr = 0.000125 avg loss = 0.002080 accuracy = 85.73
#      Epoch 56 lr = 0.000096 avg loss = 0.001954 accuracy = 85.83
#      Epoch 57 lr = 0.000068 avg loss = 0.001744 accuracy = 85.95
#      Epoch 58 lr = 0.000039 avg loss = 0.001576 accuracy = 86.00
#      Epoch 59 lr = 0.000010 avg loss = 0.001587 accuracy = 86.10
#
################################################################################

In [None]:
################################################################################
#
# IMPORT
#
################################################################################

# PyTorch
import torch
import torch.nn    as nn

# version check: print the version string of the imported PyTorch package
print(torch.__version__)

In [None]:
################################################################################
#
# PARAMETERS
#
################################################################################

# data set properties handed to the model
DATA_NUM_CHANNELS = 3  # channels of the input images
DATA_NUM_CLASSES = 10  # size of the final classifier output

# number of bottleneck blocks in each encoder level
MODEL_LEVEL_0_BLOCKS = 3
MODEL_LEVEL_1_BLOCKS = 3
MODEL_LEVEL_2_BLOCKS = 4

# channel widths: encoder tail output, then the output of each encoder level
MODEL_TAIL_END_CHANNELS = 32
MODEL_LEVEL_0_IDENTITY_CHANNELS = 128
MODEL_LEVEL_1_IDENTITY_CHANNELS = 256
MODEL_LEVEL_2_IDENTITY_CHANNELS = 512

In [None]:
################################################################################
#
# NETWORK BUILDING BLOCKS
#
################################################################################

# mobile net v2 bottleneck
class MobileNetV2Bottleneck(nn.Module):
    """Bottleneck block with batch norm and ReLU6 applied before each conv.

    Residual branch:
        BN -> ReLU6 -> 1x1 conv            (C_in     -> C_in * t)
        BN -> ReLU6 -> 3x3 depthwise conv  (stride S, one group per channel)
        BN -> ReLU6 -> 1x1 conv            (C_in * t -> C_out)

    Identity branch: the block input itself, or a bias-free 1x1 convolution
    with stride S of the block input when C_in != C_out or S > 1.

    The block output is the sum of the two branches.

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        S: stride of the depthwise convolution and of the identity projection.
        t: expansion factor; the residual branch works on C_in * t channels.
    """

    def __init__(self, C_in, C_out, S, t):
        super().__init__()

        # Submodule names (conv0, bn1, ...) and their creation order are kept
        # as they are: they define the state_dict keys and the order in which
        # the random initialisation is drawn.  All hyperparameters that are
        # not spelled out below are the PyTorch defaults.

        # identity: a projection is only needed when the shape changes
        self.conv0_present = (C_in != C_out) or (S > 1)
        if self.conv0_present:
            self.conv0 = nn.Conv2d(C_in, C_out, 1, stride=S, bias=False)

        # width of the expanded representation inside the residual branch
        C_res = t * C_in

        # residual: 1x1 expansion
        self.bn1 = nn.BatchNorm2d(C_in)
        self.relu1 = nn.ReLU6()
        self.conv1 = nn.Conv2d(C_in, C_res, 1, bias=False)

        # residual: 3x3 depthwise (groups == channels), carries the stride
        self.bn2 = nn.BatchNorm2d(C_res)
        self.relu2 = nn.ReLU6()
        self.conv2 = nn.Conv2d(C_res, C_res, 3, stride=S, padding=1, groups=C_res, bias=False)

        # residual: 1x1 projection to the output width
        self.bn3 = nn.BatchNorm2d(C_res)
        self.relu3 = nn.ReLU6()
        self.conv3 = nn.Conv2d(C_res, C_out, 1, bias=False)

    def forward(self, x):
        """Return identity(x) + residual(x)."""

        # residual branch, one BN -> ReLU6 -> conv group per line
        branch = self.conv1(self.relu1(self.bn1(x)))
        branch = self.conv2(self.relu2(self.bn2(branch)))
        branch = self.conv3(self.relu3(self.bn3(branch)))

        # identity branch, projected only when the block changes the shape
        shortcut = self.conv0(x) if self.conv0_present else x

        return shortcut + branch

In [None]:
################################################################################
#
# NETWORK
#
################################################################################

# define
class Model(nn.Module):
    """CIFAR classifier assembled from MobileNetV2Bottleneck blocks.

    The network is a chain of five stages that are applied one after another:
    encoder tail, encoder level 0, encoder level 1, encoder level 2, decoder.

    Args:
        data_num_channels: input channels of the encoder tail convolution.
        data_num_classes: output features of the final linear layer.
        model_level_0_blocks: number of bottleneck blocks in encoder level 0.
        model_level_1_blocks: number of bottleneck blocks in encoder level 1.
        model_level_2_blocks: number of bottleneck blocks in encoder level 2
            (the first and the last block are always created, so this level
            never has fewer than two bottleneck blocks).
        model_tail_end_channels: output channels of the encoder tail.
        model_level_0_identity_channels: output channels of encoder level 0.
        model_level_1_identity_channels: output channels of encoder level 1.
        model_level_2_identity_channels: output channels of encoder level 2.
    """

    def __init__(
        self,
        data_num_channels,
        data_num_classes,
        model_level_0_blocks,
        model_level_1_blocks,
        model_level_2_blocks,
        model_tail_end_channels,
        model_level_0_identity_channels,
        model_level_1_identity_channels,
        model_level_2_identity_channels,
    ):
        super().__init__()

        # short local names for the channel widths
        c_tail = model_tail_end_channels
        c_0 = model_level_0_identity_channels
        c_1 = model_level_1_identity_channels
        c_2 = model_level_2_identity_channels

        # Every stage is built as a plain list, in execution order, and then
        # wrapped in a ModuleList so its layers are registered as "0", "1", ...

        # encoder tail: bias-free 1x1 convolution, stride 2, padding 1
        # NOTE(review): with a 1x1 kernel the zero padding is itself sampled
        # by the strided convolution - confirm that this is intended.
        tail = [nn.Conv2d(data_num_channels, c_tail, 1, stride=2, padding=1, bias=False)]
        self.enc_tail = nn.ModuleList(tail)

        # encoder level 0: widening block (S=1, t=1), then S=2, t=6 blocks
        level_0 = [MobileNetV2Bottleneck(c_tail, c_0, 1, 1)]
        level_0 += [MobileNetV2Bottleneck(c_0, c_0, 2, 6) for _ in range(model_level_0_blocks - 1)]
        self.enc_0 = nn.ModuleList(level_0)

        # encoder level 1: every block uses S=1, t=6
        level_1 = [MobileNetV2Bottleneck(c_0, c_1, 1, 6)]
        level_1 += [MobileNetV2Bottleneck(c_1, c_1, 1, 6) for _ in range(model_level_1_blocks - 1)]
        self.enc_1 = nn.ModuleList(level_1)

        # encoder level 2: S=2, t=6 blocks closed by a single S=1, t=6 block
        level_2 = [MobileNetV2Bottleneck(c_1, c_2, 2, 6)]
        level_2 += [MobileNetV2Bottleneck(c_2, c_2, 2, 6) for _ in range(model_level_2_blocks - 2)]
        level_2.append(MobileNetV2Bottleneck(c_2, c_2, 1, 6))

        # the bottleneck blocks end on a convolution, so the encoder is
        # closed with one more bn - relu6 pair
        level_2.append(nn.BatchNorm2d(c_2))
        level_2.append(nn.ReLU6())
        self.enc_2 = nn.ModuleList(level_2)

        # decoder: global average pool, 1x1 conv (with bias), flatten, linear
        head = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(c_2, c_2, 1, bias=True),
            nn.Flatten(),
            nn.Linear(c_2, data_num_classes, bias=True),
        ]
        self.dec = nn.ModuleList(head)

    def forward(self, x):
        """Run x through all five stages and return the final linear output."""

        # the stages, and the layers inside each stage, are strictly sequential
        stages = (self.enc_tail, self.enc_0, self.enc_1, self.enc_2, self.dec)
        for stage in stages:
            for layer in stage:
                x = layer(x)

        return x

# build the network from the parameters defined above
model = Model(
    data_num_channels=DATA_NUM_CHANNELS,
    data_num_classes=DATA_NUM_CLASSES,
    model_level_0_blocks=MODEL_LEVEL_0_BLOCKS,
    model_level_1_blocks=MODEL_LEVEL_1_BLOCKS,
    model_level_2_blocks=MODEL_LEVEL_2_BLOCKS,
    model_tail_end_channels=MODEL_TAIL_END_CHANNELS,
    model_level_0_identity_channels=MODEL_LEVEL_0_IDENTITY_CHANNELS,
    model_level_1_identity_channels=MODEL_LEVEL_1_IDENTITY_CHANNELS,
    model_level_2_identity_channels=MODEL_LEVEL_2_IDENTITY_CHANNELS,
)

# print the layer structure
print(model)