In [3]:
from __future__ import print_function

import os
import math
import argparse

import torch
import torch.nn as nn
import torch.cuda as cuda
import torch.optim as optim
import torch.utils as utils
import torch.backends.cudnn as cudnn

from model.vgg import VGG

from torch.autograd import Variable
from torchvision import models, datasets, transforms

In [4]:
# Args
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--learning-rate', type=float, default=0.1, 
                    help='learning rate (default: 0.1)')
parser.add_argument('--train-batch-size', type=int, default=64,
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=100,
                    help='input batch size for testing (default: 100)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train (default: 100)')
parser.add_argument('--num-workers', type=int, default=4,
                    help='number of workers (default: 4)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10,
                    help='how many batches to wait before logging training status (default: 10)')
parser.add_argument('-r', '--resume', action='store_true', default=False,
                    help='resume from checkpoint')
args = parser.parse_args([])

In [5]:
# Init variables
print('==> Init variables..')
use_cuda = cuda.is_available()
best_accuracy = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

torch.manual_seed(args.seed)
if use_cuda:
    cuda.manual_seed(args.seed)

==> Init variables..


In [6]:
def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    mean = torch.zeros(3)
    std = torch.zeros(3)
    dataloader = utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=args.num_workers)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

In [7]:
# Data
print('==> Download data..')
dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=transforms.ToTensor())
data_mean, data_std = get_mean_and_std(dataset)

print(data_mean)
print(data_std)

transform_train = transforms.Compose([
    transforms.Scale(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(data_mean, data_std),
])

transform_test = transforms.Compose([
    transforms.Scale(224),
    transforms.ToTensor(),
    transforms.Normalize(data_mean, data_std),
])

print('==> Init dataloader..')
trainset = datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
trainloader = utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.num_workers)

testset = datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
testloader = utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.num_workers)

==> Download data..
Files already downloaded and verified
==> Computing mean and std..

 0.4914
 0.4822
 0.4465
[torch.FloatTensor of size 3]


 0.2023
 0.1994
 0.2010
[torch.FloatTensor of size 3]

==> Init dataloader..
Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Model
if args.resume:
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7')
    net = checkpoint['net']
    best_accuracy = checkpoint['accuracy']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    net = VGG('vgg16', num_classes=10)
    print(net)

if use_cuda:
    net.cuda()
    net = nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

==> Building model..
VGG (
  (features): Sequential (
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU (inplace)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU (inplace)
    (6): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (9): ReLU (inplace)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (12): ReLU (inplace)
    (13): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (16): ReLU (i

In [9]:
# Loss function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=5e-4)

In [None]:
def train(epoch):
    '''Training.'''
    print('Training Epoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.data
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

In [None]:
def test(epoch):
    '''Testing.'''
    global best_accuracy
    print('Testing Epoch: %d' % epoch)
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        test_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

    # Save checkpoint.
    accuracy = 100.*correct/total
    if accuracy > best_accuracy:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'accuracy': accuracy,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.t7')
        best_accuracy = accuracy

In [None]:
for epoch in range(start_epoch, start_epoch + args.epochs):
    train(epoch)
    test(epoch)