In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torch.optim as optim
import matplotlib.pyplot as plt

In [None]:
# Pick the compute device: the first CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Hyperparameters and run configuration.
batch_size = 16                 # mini-batch size for the train and test DataLoaders
EPOCHS = 50                     # number of passes over the training set
drop_out_rate = 0.3             # NOTE(review): not referenced by MobileNet or the training loop
NUM_CLASS = 10                  # number of output classes (CIFAR-10)
schedule_step = 20              # StepLR: decay the learning rate every `schedule_step` epochs
learning_rate_initial = 0.001   # initial Adam learning rate
gamma_val = 0.1                 # StepLR: multiplicative learning-rate decay factor

# Seed PyTorch's RNGs (CPU and CUDA) so weight initialisation, data shuffling and
# augmentation should repeat across Restart & Run All. cuDNN kernels may still be
# non-deterministic, so results are not guaranteed to be bit-identical on GPU.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Per-channel normalisation statistics (the widely used ImageNet mean/std),
# shared by the train and test pipelines so the two cannot drift apart.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
# transforms.Resize(int) scales the shorter image side to this size; CIFAR-10
# images are square, so they become IMAGE_SIZE x IMAGE_SIZE.
IMAGE_SIZE = 224
NUM_WORKERS = 4
# Page-locked host memory speeds up host-to-GPU copies; it only helps with CUDA.
PIN_MEMORY = torch.cuda.is_available()

# data augmentation (training only): small rotations plus random horizontal/vertical flips
train_transform = transforms.Compose([transforms.RandomRotation(5),
                                      transforms.RandomHorizontalFlip(0.3),
                                      transforms.RandomVerticalFlip(0.05),
                                      transforms.Resize(IMAGE_SIZE),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean = NORM_MEAN, std = NORM_STD)
                                     ])

# evaluation pipeline: same resize and normalisation as training, no randomness
test_transform = transforms.Compose([transforms.Resize(IMAGE_SIZE),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean = NORM_MEAN, std = NORM_STD)
                                     ])

# load data (downloaded into ./data on first run)
train_set = torchvision.datasets.CIFAR10(root = './data', train = True,
                                         download = True, transform = train_transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size = batch_size,
                                           shuffle = True, num_workers = NUM_WORKERS,
                                           pin_memory = PIN_MEMORY)

test_set = torchvision.datasets.CIFAR10(root = './data', train = False,
                                        download = True, transform = test_transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size = batch_size,
                                          shuffle = False, num_workers = NUM_WORKERS,
                                          pin_memory = PIN_MEMORY)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Build up MobileNet 
class MobileNet(nn.Module):
    def __init__(self):
        super(MobileNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 32, kernel_size = 3, stride = 2, padding = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(num_features = 32)
        self.relu1 = nn.ReLU(inplace = True)

        self.conv_dw1 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 3, stride = 1, padding = 1, groups = 32, bias = False)
        self.bn_dw1 = nn.BatchNorm2d(num_features = 32)
        self.relu_dw1 = nn.ReLU(inplace = True)

        self.conv_pw1 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw1 = nn.BatchNorm2d(num_features = 64)
        self.relu_pw1 = nn.ReLU(inplace = True)

        self.conv_dw2 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride = 2, padding = 1, groups = 64, bias = False)
        self.bn_dw2 = nn.BatchNorm2d(num_features = 64)
        self.relu_dw2 = nn.ReLU(inplace = True)

        self.conv_pw2 = nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw2 = nn.BatchNorm2d(num_features = 128)
        self.relu_pw2 = nn.ReLU(inplace = True)

        self.conv_dw3 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 1, padding = 1, groups = 128, bias = False)
        self.bn_dw3 = nn.BatchNorm2d(num_features = 128)
        self.relu_dw3 = nn.ReLU(inplace = True)

        self.conv_pw3 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw3 = nn.BatchNorm2d(num_features = 128)
        self.relu_pw3 = nn.ReLU(inplace = True)

        self.conv_dw4 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 2, padding = 1, groups = 128, bias = False)
        self.bn_dw4 = nn.BatchNorm2d(num_features = 128)
        self.relu_dw4 = nn.ReLU(inplace = True)

        self.conv_pw4 = nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw4 = nn.BatchNorm2d(num_features = 256)
        self.relu_pw4 = nn.ReLU(inplace = True)

        self.conv_dw5 = nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 1, padding = 1, groups = 256, bias = False)
        self.bn_dw5 = nn.BatchNorm2d(num_features = 256)
        self.relu_dw5 = nn.ReLU(inplace = True)

        self.conv_pw5 = nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw5 = nn.BatchNorm2d(num_features = 256)
        self.relu_pw5 = nn.ReLU(inplace = True)

        self.conv_dw6 = nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 2, padding = 1, groups = 256, bias = False)
        self.bn_dw6 = nn.BatchNorm2d(num_features = 256)
        self.relu_dw6 = nn.ReLU(inplace = True)

        self.conv_pw6 = nn.Conv2d(in_channels = 256, out_channels = 512, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw6 = nn.BatchNorm2d(num_features = 512)
        self.relu_pw6 = nn.ReLU(inplace = True)

        # 5x
        # 7
        self.conv_dw7 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1, groups = 512, bias = False)
        self.bn_dw7 = nn.BatchNorm2d(num_features = 512)
        self.relu_dw7 = nn.ReLU(inplace = True)

        self.conv_pw7 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw7 = nn.BatchNorm2d(num_features = 512)
        self.relu_pw7 = nn.ReLU(inplace = True)

        # 8
        self.conv_dw8 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1, groups = 512, bias = False)
        self.bn_dw8 = nn.BatchNorm2d(num_features = 512)
        self.relu_dw8 = nn.ReLU(inplace = True)

        self.conv_pw8 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw8 = nn.BatchNorm2d(num_features = 512)
        self.relu_pw8 = nn.ReLU(inplace = True)

        # 9
        self.conv_dw9 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1, groups = 512, bias = False)
        self.bn_dw9 = nn.BatchNorm2d(num_features = 512)
        self.relu_dw9 = nn.ReLU(inplace = True)

        self.conv_pw9 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw9 = nn.BatchNorm2d(num_features = 512)
        self.relu_pw9 = nn.ReLU(inplace = True)

        # 10
        self.conv_dw10 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1, groups = 512, bias = False)
        self.bn_dw10 = nn.BatchNorm2d(num_features = 512)
        self.relu_dw10 = nn.ReLU(inplace = True)

        self.conv_pw10 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw10 = nn.BatchNorm2d(num_features = 512)
        self.relu_pw10 = nn.ReLU(inplace = True)

        # 11
        self.conv_dw11 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1, groups = 512, bias = False)
        self.bn_dw11 = nn.BatchNorm2d(num_features = 512)
        self.relu_dw11 = nn.ReLU(inplace = True)

        self.conv_pw11 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw11 = nn.BatchNorm2d(num_features = 512)
        self.relu_pw11 = nn.ReLU(inplace = True)

        self.conv_dw12 = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 2, padding = 1, groups = 512, bias = False)
        self.bn_dw12 = nn.BatchNorm2d(num_features = 512)
        self.relu_dw12 = nn.ReLU(inplace = True)

        self.conv_pw12 = nn.Conv2d(in_channels = 512, out_channels = 1024, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw12 = nn.BatchNorm2d(num_features = 1024)
        self.relu_pw12 = nn.ReLU(inplace = True)

        # paper is wrong here
        self.conv_dw13 = nn.Conv2d(in_channels = 1024, out_channels = 1024, kernel_size = 3, stride = 1, padding = 1, groups = 1024, bias = False)
        self.bn_dw13 = nn.BatchNorm2d(num_features = 1024)
        self.relu_dw13 = nn.ReLU(inplace = True)

        self.conv_pw13 = nn.Conv2d(in_channels = 1024, out_channels = 1024, kernel_size = 1, stride = 1, padding = 0, bias = False)
        self.bn_pw13 = nn.BatchNorm2d(num_features = 1024)
        self.relu_pw13 = nn.ReLU(inplace = True)

        self.avg_pool = nn.AvgPool2d(kernel_size = 7, padding = 0)

        self.linear = nn.Linear(1024, NUM_CLASS)
            
    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu_pw1(self.bn_pw1(self.conv_pw1(self.relu_dw1(self.bn_dw1(self.conv_dw1(x))))))
        x = self.relu_pw2(self.bn_pw2(self.conv_pw2(self.relu_dw2(self.bn_dw2(self.conv_dw2(x))))))
        x = self.relu_pw3(self.bn_pw3(self.conv_pw3(self.relu_dw3(self.bn_dw3(self.conv_dw3(x))))))
        x = self.relu_pw4(self.bn_pw4(self.conv_pw4(self.relu_dw4(self.bn_dw4(self.conv_dw4(x))))))
        x = self.relu_pw5(self.bn_pw5(self.conv_pw5(self.relu_dw5(self.bn_dw5(self.conv_dw5(x))))))
        x = self.relu_pw6(self.bn_pw6(self.conv_pw6(self.relu_dw6(self.bn_dw6(self.conv_dw6(x))))))
        x = self.relu_pw7(self.bn_pw7(self.conv_pw7(self.relu_dw7(self.bn_dw7(self.conv_dw7(x))))))
        x = self.relu_pw8(self.bn_pw8(self.conv_pw8(self.relu_dw8(self.bn_dw8(self.conv_dw8(x))))))
        x = self.relu_pw9(self.bn_pw9(self.conv_pw9(self.relu_dw9(self.bn_dw9(self.conv_dw9(x))))))
        x = self.relu_pw10(self.bn_pw10(self.conv_pw10(self.relu_dw10(self.bn_dw10(self.conv_dw10(x))))))
        x = self.relu_pw11(self.bn_pw11(self.conv_pw11(self.relu_dw11(self.bn_dw11(self.conv_dw11(x))))))
        x = self.relu_pw12(self.bn_pw12(self.conv_pw12(self.relu_dw12(self.bn_dw12(self.conv_dw12(x))))))
        x = self.relu_pw13(self.bn_pw13(self.conv_pw13(self.relu_dw13(self.bn_dw13(self.conv_dw13(x))))))

        x = self.avg_pool(x)

        x = x.view(x.shape[0], -1)
        
        x = self.linear(x)
        return x
        
        

In [None]:
# Create the neural network and move it to the selected device. `.to(device)` works on
# both GPU and CPU-only machines, unlike an unconditional `.cuda()`.
mobilenet = MobileNet().to(device)
mobilenet

In [None]:
# Loss, optimiser and learning-rate schedule.
# Cross-entropy is applied to the raw logits returned by the network.
criterion = nn.CrossEntropyLoss()
# Adam over every trainable parameter of the network, starting at `learning_rate_initial`.
optimizer = optim.Adam(params = mobilenet.parameters(), lr = learning_rate_initial)
# Multiply the learning rate by `gamma_val` every `schedule_step` calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer = optimizer, step_size = schedule_step, gamma = gamma_val)

In [None]:
train_loss_list = []
train_acc_list = []
test_acc_list = []


def evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over all batches of ``loader``.

    Switches the model to eval mode, so BatchNorm uses its running statistics,
    disables gradient tracking, and moves each batch to ``device``.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim = 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


# Training accuracy is measured on un-augmented training images (test_transform), so it
# is not distorted by the random rotations/flips that train_loader applies.
train_eval_set = torchvision.datasets.CIFAR10(root = train_set.root, train = True,
                                              download = True, transform = test_transform)
train_eval_loader = torch.utils.data.DataLoader(train_eval_set, batch_size = batch_size,
                                                shuffle = False, num_workers = 4)

# Train the data
for epoch in range(EPOCHS):
    print("Epoch: ", epoch + 1)
    mobilenet.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = mobilenet(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch mean losses over the epoch.
    running_loss /= len(train_loader)
    train_loss_list.append(running_loss)
    print('[%d] loss: %.3f' % (epoch + 1, running_loss))

    train_acc = evaluate_accuracy(mobilenet, train_eval_loader, device)
    test_acc = evaluate_accuracy(mobilenet, test_loader, device)
    train_acc_list.append(train_acc)
    test_acc_list.append(test_acc)
    print('Training Accuracy of current epoch: %.3f %%' % train_acc)
    print('Testing Accuracy of current epoch: %.3f %%' % test_acc)

    # One scheduler step per epoch.
    scheduler.step()

In [None]:
def plot_epoch_curve(values, title, ylabel):
    """Plot one recorded value per epoch and show the figure.

    ``values[k]`` is the metric recorded after epoch ``k + 1`` (the training loop
    numbers epochs from 1), so the x-axis runs 1..len(values) rather than from 0.
    """
    fig, ax = plt.subplots()
    ax.plot(range(1, len(values) + 1), values)
    ax.set_title(title)
    ax.set_xlabel('Epoch number')
    ax.set_ylabel(ylabel)
    plt.show()


plot_epoch_curve(train_loss_list, 'Training Loss Plot of Each Epoch', 'Training Loss')
plot_epoch_curve(train_acc_list, 'Training Accuracy Plot of Each Epoch', 'Training Accuracy (%)')
plot_epoch_curve(test_acc_list, 'Test Accuracy Plot of Each Epoch', 'Test Accuracy (%)')