# MobileNetV3 in PyTorch.

See the following papers for details:

[Searching for MobileNetV3](https://arxiv.org/abs/1905.02244)

[MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381)

In [1]:
# python: 3.8.10 64-bit
# os: ubuntu 20.04.3

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import init
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset

import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

from sklearn.model_selection import ShuffleSplit

import numpy as np
import matplotlib.pyplot as plt

In [3]:
print('pytorch:', torch.__version__)

# Use the first CUDA device when one is visible to PyTorch, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    print('cuda device:', torch.cuda.get_device_properties('cuda:0'))
else:
    print('cuda device unavailable, using cpu')

pytorch: 1.10.2
cuda device: _CudaDeviceProperties(name='NVIDIA GeForce RTX 3080 Ti', major=8, minor=6, total_memory=12050MB, multi_processor_count=80)


In [4]:
# Number of full passes over the training split made by the training loop.
num_epochs = 15
# Mini-batch size used by the training DataLoader.
train_batch_size = 256
# Mini-batch size used by the test DataLoader (evaluation and visualisation).
test_batch_size = 8

### Network

In [5]:

class hswish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``."""

    def forward(self, x):
        # `x + 3` is a fresh tensor, so the in-place ReLU6 never touches `x`.
        gate = F.relu6(x + 3, inplace=True)
        return x * gate / 6


class hsigmoid(nn.Module):
    """Hard-sigmoid activation: ``relu6(x + 3) / 6``, with values in [0, 1]."""

    def forward(self, x):
        # `x + 3` is a fresh tensor, so the in-place ReLU6 never touches `x`.
        shifted = x + 3
        return F.relu6(shifted, inplace=True) / 6


class SeModule(nn.Module):
    """Squeeze-and-excitation gate that rescales each channel of its input.

    The input is pooled to 1x1, passed through a 1x1-conv bottleneck
    (``in_size -> in_size // reduction -> in_size``) with batch norm after
    each convolution, and squashed by ``hsigmoid``. The resulting
    per-channel factors multiply the original input.

    Args:
        in_size: number of channels of the input (and of the output).
        reduction: divisor applied to ``in_size`` for the bottleneck width.
    """

    def __init__(self, in_size, reduction=4):
        super().__init__()
        squeezed = in_size // reduction
        layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_size, squeezed, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(squeezed),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size),
            hsigmoid(),
        ]
        self.se = nn.Sequential(*layers)

    def forward(self, x):
        # The 1x1 spatial gate broadcasts over the height and width of `x`.
        scale = self.se(x)
        return x * scale


class Block(nn.Module):
    """Bottleneck block: 1x1 expand -> depthwise conv -> 1x1 project.

    Each convolution is followed by batch norm; the activation ``nolinear``
    is applied after the first two stages only. When ``semodule`` is given it
    is applied to the projected output (after ``bn3``). A residual connection
    is added only when ``stride == 1``; if the channel count changes, the
    shortcut is a 1x1 conv + batch norm, otherwise it is the identity.

    Args:
        kernel_size: kernel size of the depthwise convolution (padding is
            ``kernel_size // 2``).
        in_size: number of input channels.
        expand_size: number of channels inside the block.
        out_size: number of output channels.
        nolinear: activation module; the same instance is used for both
            activations.
        semodule: optional module applied to the projected output, or
            ``None`` to skip it. It must accept ``out_size`` channels.
        stride: stride of the depthwise convolution.
    """

    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
        super().__init__()
        self.stride = stride
        self.se = semodule

        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        # groups == channels makes this a depthwise convolution.
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        # An empty Sequential acts as the identity shortcut.
        self.shortcut = nn.Sequential()

        if stride == 1 and in_size != out_size:
            # Project the input so it can be added to the block output.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_size),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.se is not None:
            out = self.se(out)
        # Strided blocks change the spatial size, so they get no residual.
        if self.stride == 1:
            out = out + self.shortcut(x)
        return out


class MobileNetV3_Large(nn.Module):
    """MobileNetV3-Large style image classifier.

    Layout: stride-2 3x3 stem -> 15 bottleneck ``Block``s -> 1x1 conv to 960
    channels -> global average pool -> ``Linear(960, 1280)`` ->
    ``Linear(1280, num_classes)``. The network reduces the spatial size by a
    factor of 32 in total.

    NOTE(review): the last three ``bneck`` rows (expand sizes 672/672/960
    with the stride-2 step on the second 160-channel block) look different
    from Table 1 of the MobileNetV3 paper. Confirm against the paper before
    relying on exact parity; they are left unchanged here so existing
    checkpoints stay loadable.

    Args:
        num_classes: size of the final classification layer.
    """

    def __init__(self, num_classes=256):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        # Block(kernel_size, in, expand, out, activation, SE module, stride)
        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
            Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(3, 40, 240, 80, hswish(), None, 2),
            Block(3, 80, 200, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
            Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
            Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
            Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
            Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
        )

        self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(960, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                # BatchNorm1d (bn3) is listed too so every norm layer is
                # initialised explicitly; the values equal PyTorch's defaults.
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        # Global average pooling. For 224x224 inputs the feature map is 7x7,
        # so this equals the former fixed avg_pool2d(out, 7); other input
        # sizes no longer break linear3 or drop border pixels.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out


class MobileNetV3_Small(nn.Module):
    """MobileNetV3-Small style image classifier.

    Layout: stride-2 3x3 stem -> 11 bottleneck ``Block``s -> 1x1 conv to 576
    channels -> global average pool -> ``Linear(576, 1280)`` ->
    ``Linear(1280, num_classes)``. The network reduces the spatial size by a
    factor of 32 in total.

    Args:
        num_classes: size of the final classification layer.
    """

    def __init__(self, num_classes=256):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()

        # Block(kernel_size, in, expand, out, activation, SE module, stride)
        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
            Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
            Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
            Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
        )

        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(576)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(576, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                # BatchNorm1d (bn3) is listed too so every norm layer is
                # initialised explicitly; the values equal PyTorch's defaults.
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        # Global average pooling. For 224x224 inputs the feature map is 7x7,
        # so this equals the former fixed avg_pool2d(out, 7); other input
        # sizes no longer break linear3 or drop border pixels.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out


### Load dataset

In [6]:
datasetRoot = "./dataset"

# Preprocessing applied to every image: resize to the 224x224 network input,
# convert to a tensor, then normalise each channel with mean 0.5 / std 0.5.
_steps = [
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_steps)

# ImageFolder derives the class names from the sub-directories of the root.
dataset = ImageFolder(root=datasetRoot, transform=transform)
classes = dataset.classes

print('labels:', classes)
print('label index mapping:', dataset.class_to_idx)

labels: ['0', '1', '2', '3', '4', 'B', 'C', 'E', 'M', 'P', 'R', 'S', 'T', '_']
label index mapping: {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, 'B': 5, 'C': 6, 'E': 7, 'M': 8, 'P': 9, 'R': 10, 'S': 11, 'T': 12, '_': 13}


### Train and test

In [7]:
# Take the first random train/test partition of the sample indices produced
# by ShuffleSplit with its default settings.
split_iter = ShuffleSplit().split(dataset)
train_ids, test_ids = next(split_iter)

# Both loaders read from the same dataset; the samplers restrict each one to
# its own subset of indices.
train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

trainloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=train_batch_size,
    num_workers=8,
    sampler=train_subsampler,
)
testloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=test_batch_size,
    num_workers=8,
    sampler=test_subsampler,
)

net = MobileNetV3_Small(len(classes))
net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

print('Started Training')

# Report the mean loss once every `statistics` mini-batches.
statistics = 10

for epoch in range(num_epochs):

    # ---- training pass over the train split ----
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % statistics == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / statistics))
            running_loss = 0.0

    # ---- evaluation pass over the test split ----
    correct = 0
    total = 0

    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            # Index of the largest logit is the predicted class.
            predicted = torch.max(outputs.data, 1)[1]
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Back to training mode for the next epoch.
    net.train()

    print('Accuracy of the network on the %d test images: %d %%' % (total, 100 * correct / total))

print('Finished Training')

Started Training
[1,    10] loss: 2.610
[1,    20] loss: 2.512
[1,    30] loss: 2.352
[1,    40] loss: 2.123
[1,    50] loss: 1.826
[1,    60] loss: 1.474
[1,    70] loss: 1.225
[1,    80] loss: 1.151
Accuracy of the network on the 2457 test images: 52 %
[2,    10] loss: 0.989
[2,    20] loss: 0.955
[2,    30] loss: 0.894
[2,    40] loss: 0.833
[2,    50] loss: 0.776
[2,    60] loss: 0.708
[2,    70] loss: 0.707
[2,    80] loss: 0.718
Accuracy of the network on the 2457 test images: 59 %
[3,    10] loss: 0.627
[3,    20] loss: 0.588
[3,    30] loss: 0.576
[3,    40] loss: 0.541
[3,    50] loss: 0.498
[3,    60] loss: 0.499
[3,    70] loss: 0.470
[3,    80] loss: 0.445
Accuracy of the network on the 2457 test images: 82 %
[4,    10] loss: 0.425
[4,    20] loss: 0.385
[4,    30] loss: 0.347
[4,    40] loss: 0.351
[4,    50] loss: 0.330
[4,    60] loss: 0.346
[4,    70] loss: 0.326
[4,    80] loss: 0.304
Accuracy of the network on the 2457 test images: 89 %
[5,    10] loss: 0.276
[5,    2

### Save model

In [8]:
# Destination file for the trained network.
PATH = './mobilenet.pth'
# Serialises the whole module object (not just a state_dict), so loading it
# back needs the model class definitions to be available.
torch.save(net, PATH)

### Load model

In [9]:
# File written by the save cell.
PATH = './mobilenet.pth'
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
# map_location places the restored tensors on the selected device.
net = torch.load(PATH, map_location=device)
# Switch to inference mode. As the last expression of the cell, the returned
# module is also what the notebook displays.
net.eval()

MobileNetV3_Small(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (hs1): hswish()
  (bneck): Sequential(
    (0): Block(
      (se): SeModule(
        (se): Sequential(
          (0): AdaptiveAvgPool2d(output_size=1)
          (1): Conv2d(16, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): ReLU(inplace=True)
          (4): Conv2d(4, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): hsigmoid()
        )
      )
      (conv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (nolinear1): ReLU(inplace=True)
      (conv2): Conv2d(16

### Result visualization

In [10]:
dataiter = iter(testloader)

# Make sure batch norm uses its running statistics even if the previous cell
# left the network in training mode.
net.eval()

# Inference only: no autograd graph is needed.
with torch.no_grad():
    # Show the first sample of up to 20 test batches. zip() stops cleanly if
    # the loader has fewer batches, and the iterator protocol replaces the
    # `.next()` method that current DataLoader iterators no longer provide.
    for _, (images, labels) in zip(range(20), dataiter):
        print('input_size:', images.size())
        outputs = net(images.to(device))
        # Index of the largest logit is the predicted class.
        predicted = torch.max(outputs, 1)[1]

        # Optional: show the first image of the batch.
        # images[0] = images[0] / 2 + 0.5     # unnormalize
        # npimg = images[0].numpy()
        # plt.imshow(np.transpose(npimg, (1, 2, 0)))
        # plt.show()

        print('predicted:', classes[predicted[0]], ',  label_id: ', labels[0], ',  label:', classes[labels[0]])


input_size: torch.Size([8, 3, 224, 224])
predicted: P ,  label_id:  tensor(9) ,  label: P
input_size: torch.Size([8, 3, 224, 224])
predicted: C ,  label_id:  tensor(6) ,  label: C
input_size: torch.Size([8, 3, 224, 224])
predicted: 4 ,  label_id:  tensor(4) ,  label: 4
input_size: torch.Size([8, 3, 224, 224])
predicted: M ,  label_id:  tensor(8) ,  label: M
input_size: torch.Size([8, 3, 224, 224])
predicted: C ,  label_id:  tensor(6) ,  label: C
input_size: torch.Size([8, 3, 224, 224])
predicted: 2 ,  label_id:  tensor(2) ,  label: 2
input_size: torch.Size([8, 3, 224, 224])
predicted: S ,  label_id:  tensor(11) ,  label: S
input_size: torch.Size([8, 3, 224, 224])
predicted: T ,  label_id:  tensor(12) ,  label: T
input_size: torch.Size([8, 3, 224, 224])
predicted: C ,  label_id:  tensor(6) ,  label: C
input_size: torch.Size([8, 3, 224, 224])
predicted: _ ,  label_id:  tensor(13) ,  label: _
input_size: torch.Size([8, 3, 224, 224])
predicted: 3 ,  label_id:  tensor(3) ,  label: 3
input_s