# <center>MobileNet - Pytorch

# Step 1: Prepare data

In [1]:
# MobileNet-Pytorch
import argparse 
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.metrics import accuracy_score
from mobilenets import mobilenet

use_cuda = torch.cuda.is_available()
use_cudause_cud  = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

In [2]:
# Train, Validate, Test. Heavily inspired by Kevinzakka https://github.com/kevinzakka/DenseNet/blob/master/data_loader.py

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

valid_size=0.1

# define transforms
valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
])

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])


# load the dataset
train_dataset = datasets.CIFAR10(root="data", train=True, 
            download=True, transform=train_transform)

valid_dataset = datasets.CIFAR10(root="data", train=True, 
            download=True, transform=valid_transform)

num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train)) #5w张图片的10%用来当做验证集


np.random.seed(42)
np.random.shuffle(indices) # 随机乱序[0,1,...,49999]

train_idx, valid_idx = indices[split:], indices[:split]


train_sampler = SubsetRandomSampler(train_idx) # 这个很有意思
valid_sampler = SubsetRandomSampler(valid_idx)


# ------------------------- 使用不同的批次大小 ------------------------------------
show_step=10  # 批次大，show_step就小点
max_epoch=60  # 训练最大epoch数目
train_loader = torch.utils.data.DataLoader(train_dataset, 
                batch_size=256, sampler=train_sampler)

valid_loader = torch.utils.data.DataLoader(valid_dataset, 
                batch_size=256, sampler=valid_sampler)


test_transform = transforms.Compose([
    transforms.ToTensor(), normalize
])

test_dataset = datasets.CIFAR10(root="data", 
                                train=False, 
                                download=True,transform=test_transform)

test_loader = torch.utils.data.DataLoader(test_dataset, 
                                          batch_size=64, 
                                          shuffle=True)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


# Step 2: Model Config

In [3]:
# 
model = mobilenet(num_classes=10, large_img=False)
optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
criterion = nn.CrossEntropyLoss()

In [4]:
# Implement validation
def train(epoch):
    model.train()
    writer = SummaryWriter()
    for batch_idx, (data, target) in enumerate(train_loader):
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        correct = 0
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).sum()
        
        loss = criterion(output, target)
        loss.backward()
        accuracy = 100. * (correct.cpu().numpy()/ len(output))
        optimizer.step()
        if batch_idx % show_step == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accuracy: {:.2f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(), accuracy))
            writer.add_scalar('Loss/Loss', loss.item(), epoch)
            writer.add_scalar('Accuracy/Accuracy', accuracy, epoch)
    scheduler.step()

In [5]:
def validate(epoch):
    model.eval()
    writer = SummaryWriter()
    valid_loss = 0
    correct = 0
    for data, target in valid_loader:
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        valid_loss += F.cross_entropy(output, target, size_average=False).item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).sum()

    valid_loss /= len(valid_idx)
    accuracy = 100. * correct.cpu().numpy() / len(valid_idx)
    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        valid_loss, correct, len(valid_idx),
        100. * correct / len(valid_idx)))
    writer.add_scalar('Loss/Validation_Loss', valid_loss, epoch)
    writer.add_scalar('Accuracy/Validation_Accuracy', accuracy, epoch)
    return valid_loss, accuracy

In [9]:
# Fix best model

def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct.cpu().numpy() / len(test_loader.dataset)))

In [7]:
def save_best(loss, accuracy, best_loss, best_acc):
    if best_loss == None:
        best_loss = loss
        best_acc = accuracy
        file = 'saved_models/best_save_model.p'
        torch.save(model.state_dict(), file)
        
    elif loss < best_loss and accuracy > best_acc:
        best_loss = loss
        best_acc = accuracy
        file = 'saved_models/best_save_model.p'
        torch.save(model.state_dict(), file)
    return best_loss, best_acc

In [8]:
# Fantastic logger for tensorboard and pytorch, 
# run tensorboard by opening a new terminal and run "tensorboard --logdir runs"
# open tensorboard at http://localhost:6006/
from tensorboardX import SummaryWriter
best_loss = None
best_acc = None

import time 
SINCE=time.time()

for epoch in range(max_epoch):
    train(epoch)
    loss, accuracy = validate(epoch)
    best_loss, best_acc = save_best(loss, accuracy, best_loss, best_acc)
    
    NOW=time.time() 
    DURINGS=NOW-SINCE
    SINCE=NOW
    print("the time of this epoch:[{} s]".format(DURINGS))
    
writer = SummaryWriter()
writer.export_scalars_to_json("./all_scalars.json")

writer.close()


Validation set: Average loss: 1.9797, Accuracy: 1591/5000 (31.00%)

the time of this epoch:[53.92196607589722 s]

Validation set: Average loss: 1.6753, Accuracy: 1960/5000 (39.00%)

the time of this epoch:[53.870646476745605 s]

Validation set: Average loss: 1.4611, Accuracy: 2416/5000 (48.00%)

the time of this epoch:[53.453192472457886 s]

Validation set: Average loss: 2.0505, Accuracy: 1965/5000 (39.00%)

the time of this epoch:[53.80319929122925 s]



Validation set: Average loss: 1.4734, Accuracy: 2666/5000 (53.00%)

the time of this epoch:[53.764643907547 s]

Validation set: Average loss: 1.0665, Accuracy: 3064/5000 (61.00%)

the time of this epoch:[53.74446082115173 s]

Validation set: Average loss: 1.0236, Accuracy: 3210/5000 (64.00%)

the time of this epoch:[53.729058504104614 s]

Validation set: Average loss: 0.9515, Accuracy: 3318/5000 (66.00%)

the time of this epoch:[53.55753993988037 s]



Validation set: Average loss: 0.9672, Accuracy: 3305/5000 (66.00%)

the time of this epoch:[53.70943212509155 s]

Validation set: Average loss: 0.9706, Accuracy: 3364/5000 (67.00%)

the time of this epoch:[53.631659269332886 s]

Validation set: Average loss: 1.4479, Accuracy: 2777/5000 (55.00%)

the time of this epoch:[53.724318981170654 s]

Validation set: Average loss: 0.6227, Accuracy: 3902/5000 (78.00%)

the time of this epoch:[53.439478158950806 s]



Validation set: Average loss: 0.6651, Accuracy: 3831/5000 (76.00%)

the time of this epoch:[53.09663939476013 s]

Validation set: Average loss: 0.7508, Accuracy: 3691/5000 (73.00%)

the time of this epoch:[53.62254762649536 s]

Validation set: Average loss: 0.6433, Accuracy: 3829/5000 (76.00%)

the time of this epoch:[53.71054697036743 s]

Validation set: Average loss: 0.5579, Accuracy: 4017/5000 (80.00%)

the time of this epoch:[53.21430826187134 s]



Validation set: Average loss: 0.6044, Accuracy: 3951/5000 (79.00%)

the time of this epoch:[53.01625442504883 s]

Validation set: Average loss: 0.5631, Accuracy: 4005/5000 (80.00%)

the time of this epoch:[53.586265563964844 s]

Validation set: Average loss: 0.5718, Accuracy: 4010/5000 (80.00%)

the time of this epoch:[53.76464509963989 s]

Validation set: Average loss: 0.5535, Accuracy: 4033/5000 (80.00%)

the time of this epoch:[53.64075517654419 s]



Validation set: Average loss: 0.5671, Accuracy: 4041/5000 (80.00%)

the time of this epoch:[53.67682409286499 s]

Validation set: Average loss: 0.5036, Accuracy: 4120/5000 (82.00%)

the time of this epoch:[53.65208601951599 s]

Validation set: Average loss: 0.4959, Accuracy: 4170/5000 (83.00%)

the time of this epoch:[53.51461577415466 s]

Validation set: Average loss: 0.4777, Accuracy: 4194/5000 (83.00%)

the time of this epoch:[53.14698576927185 s]



Validation set: Average loss: 0.4990, Accuracy: 4151/5000 (83.00%)

the time of this epoch:[53.20006728172302 s]

Validation set: Average loss: 0.4949, Accuracy: 4187/5000 (83.00%)

the time of this epoch:[53.538750410079956 s]

Validation set: Average loss: 0.4889, Accuracy: 4195/5000 (83.00%)

the time of this epoch:[53.4344756603241 s]

Validation set: Average loss: 0.4824, Accuracy: 4205/5000 (84.00%)

the time of this epoch:[53.074204444885254 s]



Validation set: Average loss: 0.5768, Accuracy: 4082/5000 (81.00%)

the time of this epoch:[53.95497751235962 s]

Validation set: Average loss: 0.6305, Accuracy: 3988/5000 (79.00%)

the time of this epoch:[54.08571410179138 s]

Validation set: Average loss: 0.5032, Accuracy: 4141/5000 (82.00%)

the time of this epoch:[54.224751472473145 s]

Validation set: Average loss: 0.4658, Accuracy: 4223/5000 (84.00%)

the time of this epoch:[53.55865025520325 s]



Validation set: Average loss: 0.4685, Accuracy: 4239/5000 (84.00%)

the time of this epoch:[52.94324469566345 s]

Validation set: Average loss: 0.5048, Accuracy: 4178/5000 (83.00%)

the time of this epoch:[53.531903982162476 s]

Validation set: Average loss: 0.4669, Accuracy: 4245/5000 (84.00%)

the time of this epoch:[53.894577741622925 s]

Validation set: Average loss: 0.4820, Accuracy: 4239/5000 (84.00%)

the time of this epoch:[53.58134913444519 s]



Validation set: Average loss: 0.5005, Accuracy: 4207/5000 (84.00%)

the time of this epoch:[53.914472579956055 s]

Validation set: Average loss: 0.4812, Accuracy: 4240/5000 (84.00%)

the time of this epoch:[53.67718148231506 s]

Validation set: Average loss: 0.5025, Accuracy: 4208/5000 (84.00%)

the time of this epoch:[54.43834853172302 s]

Validation set: Average loss: 0.4993, Accuracy: 4215/5000 (84.00%)

the time of this epoch:[54.21666932106018 s]



Validation set: Average loss: 0.4936, Accuracy: 4235/5000 (84.00%)

the time of this epoch:[54.031473875045776 s]

Validation set: Average loss: 0.4689, Accuracy: 4268/5000 (85.00%)

the time of this epoch:[54.023372173309326 s]

Validation set: Average loss: 0.4792, Accuracy: 4251/5000 (85.00%)

the time of this epoch:[53.76105046272278 s]

Validation set: Average loss: 0.4851, Accuracy: 4250/5000 (85.00%)

the time of this epoch:[53.60430550575256 s]



Validation set: Average loss: 0.4724, Accuracy: 4271/5000 (85.00%)

the time of this epoch:[53.59726905822754 s]

Validation set: Average loss: 0.4931, Accuracy: 4248/5000 (84.00%)

the time of this epoch:[53.521843671798706 s]

Validation set: Average loss: 0.4832, Accuracy: 4244/5000 (84.00%)

the time of this epoch:[53.54103755950928 s]

Validation set: Average loss: 0.5078, Accuracy: 4237/5000 (84.00%)

the time of this epoch:[53.45686197280884 s]



Validation set: Average loss: 0.4979, Accuracy: 4249/5000 (84.00%)

the time of this epoch:[53.525235176086426 s]

Validation set: Average loss: 0.4972, Accuracy: 4249/5000 (84.00%)

the time of this epoch:[53.49142551422119 s]

Validation set: Average loss: 0.4845, Accuracy: 4259/5000 (85.00%)

the time of this epoch:[53.51525640487671 s]

Validation set: Average loss: 0.5137, Accuracy: 4242/5000 (84.00%)

the time of this epoch:[52.88517189025879 s]



Validation set: Average loss: 0.4951, Accuracy: 4261/5000 (85.00%)

the time of this epoch:[53.4054811000824 s]

Validation set: Average loss: 0.4983, Accuracy: 4257/5000 (85.00%)

the time of this epoch:[53.4133243560791 s]

Validation set: Average loss: 0.4943, Accuracy: 4261/5000 (85.00%)

the time of this epoch:[53.1641640663147 s]

Validation set: Average loss: 0.4986, Accuracy: 4247/5000 (84.00%)

the time of this epoch:[53.70726704597473 s]



Validation set: Average loss: 0.5064, Accuracy: 4245/5000 (84.00%)

the time of this epoch:[53.45646953582764 s]

Validation set: Average loss: 0.4952, Accuracy: 4267/5000 (85.00%)

the time of this epoch:[54.44821524620056 s]

Validation set: Average loss: 0.4960, Accuracy: 4256/5000 (85.00%)

the time of this epoch:[54.38080549240112 s]

Validation set: Average loss: 0.4997, Accuracy: 4266/5000 (85.00%)

the time of this epoch:[54.11194586753845 s]


# Step 3: Test

In [10]:
test(epoch)


Test set: Average loss: 0.5272, Accuracy: 8457/10000 (84.57%)

