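# How to Raise CIFAR-10 Test Accuracy from 40% to 90%
This post trains and evaluates CIFAR-10 on four models in turn: a two-layer fully connected network, a three-layer convolutional network, a custom multi-layer convolutional network, and a classic architecture (ResNet). Along the way, test accuracy climbs from roughly 40% to roughly 90%. The post should give a reasonably clear picture of how to build, train, and tune networks in PyTorch.<br>
Note: the ResNet models in the PyTorch model zoo are built for ImageNet-sized inputs and raised errors when used directly here, so a ResNet has to be implemented by hand. This post uses the implementation by [以梦为马_Sun](https://blog.csdn.net/sunqiande88/article/details/80100891) on CSDN.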

In [47]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
import copy
from resnet import ResNet18

## 准备数据集

In [40]:
# Of the 50000 training samples, use 49000 for training and the remaining 1000 for validation
NUM_TRAIN = 49000

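# Preprocessing: normalize each channel with the CIFAR-10 mean and standard deviation
transform_normal = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010))
])
# Data augmentation: random crop after 4-pixel padding, plus random horizontal flip
transform_aug = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    # NOTE: these are the ImageNet mean/std, whereas transform_normal (used for
    # validation and test) uses the CIFAR-10 values; the results below were obtained as written
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])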

# Load the training set (first NUM_TRAIN samples of the training split)
cifar10_train = dset.CIFAR10('./dataset', train=True, download=True, transform=transform_normal)
loader_train = DataLoader(cifar10_train, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

# Load the validation set (last 1000 samples of the training split)
cifar10_val = dset.CIFAR10('./dataset', train=True, download=True, transform=transform_normal)
loader_val = DataLoader(cifar10_val, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)))

# Load the test set
cifar10_test = dset.CIFAR10('./dataset', train=False, download=True, transform=transform_normal)
loader_test = DataLoader(cifar10_test, batch_size=64)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
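
To make the split concrete, here is a small sketch in plain Python (no PyTorch needed) of how the two `SubsetRandomSampler` index ranges partition the 50000 training images, and how many mini-batches each loader yields per epoch. The helper name `num_batches` is made up for illustration, and the sketch assumes the `DataLoader` default of keeping the last, smaller batch (`drop_last=False`).

```python
import math

NUM_TRAIN = 49000
TOTAL_TRAIN = 50000   # size of the CIFAR-10 training split
NUM_TEST = 10000      # size of the CIFAR-10 test split
BATCH_SIZE = 64

train_idx = range(NUM_TRAIN)               # indices 0 .. 48999
val_idx = range(NUM_TRAIN, TOTAL_TRAIN)    # indices 49000 .. 49999

def num_batches(num_samples, batch_size):
    """Hypothetical helper: batches per epoch when the last partial batch is kept."""
    return math.ceil(num_samples / batch_size)

# The two index ranges are disjoint and together cover the whole training split.
assert set(train_idx).isdisjoint(val_idx)
assert len(train_idx) + len(val_idx) == TOTAL_TRAIN

print(num_batches(len(train_idx), BATCH_SIZE))  # 766
print(num_batches(len(val_idx), BATCH_SIZE))    # 16
print(num_batches(NUM_TEST, BATCH_SIZE))        # 157
```

With only 1000 validation images, each validation accuracy printed later moves in steps of 0.1 percentage points, so small differences between epochs should be read with some caution.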


## Select the device

In [3]:
USE_GPU = True
dtype = torch.float32
print_every = 100

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cuda


## Pipeline
Define a pipeline that can train any given model with any given optimizer.

In [9]:
# Check the model's accuracy on the validation set or the test set
def check_accuracy(loader, model):
    if loader.dataset.train:
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')
    num_correct = 0
    num_samples = 0
    model.eval()   # set model to evaluation mode
    with torch.no_grad():
        for x,y in loader:
            x = x.to(device=device, dtype=dtype)
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _,preds = scores.max(1)
            num_correct += (preds==y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 *acc ))
        return acc
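
The core of `check_accuracy` is simple: take the highest-scoring class of each row as the prediction and count how many predictions match the labels. The sketch below replays that logic in plain Python on a hand-made batch. The scores and labels are invented, and `argmax` and `accuracy` are hypothetical helper names, not part of the code above.

```python
def argmax(row):
    """Index of the largest score in one row (the first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])

def accuracy(scores, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    preds = [argmax(row) for row in scores]
    num_correct = sum(int(p == y) for p, y in zip(preds, labels))
    return num_correct / len(labels)

# Invented scores for 4 samples and 3 classes.
scores = [
    [0.1, 2.0, -1.0],   # predicted class 1
    [1.5, 0.2, 0.3],    # predicted class 0
    [0.0, 0.1, 0.9],    # predicted class 2
    [0.4, 0.3, 0.2],    # predicted class 0
]
labels = [1, 0, 1, 0]

acc = accuracy(scores, labels)
print('Got %d / %d correct (%.2f)' % (acc * len(labels), len(labels), 100 * acc))
# Got 3 / 4 correct (75.00)
```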

In [50]:
def train_model(model, optimizer, epochs=1, scheduler=None):
    '''
    Parameters:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An optimizer object we will use to train the model
    - epochs: A Python integer giving the number of epochs to train
    - scheduler: An optional learning-rate scheduler, stepped once per epoch
    Returns: the model, loaded with the weights that scored best on the validation set
    '''
    model = model.to(device=device) # move the model parameters to CPU/GPU
    # start from the initial weights so there is always something to restore
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for e in range(epochs):
        model.train()   # check_accuracy() switches to eval mode, so reset every epoch
        for t,(x,y) in enumerate(loader_train):
            x = x.to(device, dtype=dtype)
            y = y.to(device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # loss of the last mini-batch in this epoch, not an epoch average
        print('Epoch %d, loss=%.4f' % (e, loss.item()))
        acc = check_accuracy(loader_val, model)
        if acc > best_acc:
            best_model_wts = copy.deepcopy(model.state_dict())
            best_acc = acc
        # PyTorch >= 1.1 expects scheduler.step() after the epoch's optimizer.step() calls
        if scheduler:
            scheduler.step()
    print('best_acc:',best_acc)
    model.load_state_dict(best_model_wts)
    return model
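
`train_model` keeps the best weights with `copy.deepcopy(model.state_dict())` rather than a plain reference. The reason is that the state dict hands back the live parameter tensors, which the optimizer keeps updating in place. The sketch below is only an analogy in plain Python: a dict of lists stands in for the state dict, and `sgd_like_update` is a hypothetical stand-in for an optimizer step.

```python
import copy

# Stand-in for a model's parameters: a dict of mutable lists (invented values).
weights = {'fc.weight': [0.10, 0.20], 'fc.bias': [0.00]}

def sgd_like_update(w):
    """Hypothetical in-place update, mimicking how an optimizer mutates parameters."""
    for values in w.values():
        for i in range(len(values)):
            values[i] += 1.0

alias = weights                      # same object
shallow = dict(weights)              # new dict, but the lists inside are shared
snapshot = copy.deepcopy(weights)    # fully independent copy

sgd_like_update(weights)

print(alias['fc.bias'])     # [1.0]  changed
print(shallow['fc.bias'])   # [1.0]  changed as well
print(snapshot['fc.bias'])  # [0.0]  preserved
```

Only the deep copy still holds the values from the moment the snapshot was taken, which is what restoring the best epoch requires.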

## Building the models
PyTorch offers many ways to build a network; here we use the simplest one, `nn.Sequential`.

### Model 1: two-layer fully connected network

In [6]:
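# To feed a fully connected layer, the 3-D image data has to be turned into a vector, so we define a flatten function and then wrap it in a Module
def flatten(x):
    N = x.shape[0] # read in N, C, H, W
    return x.view(N, -1)  # "flatten" the C * H * W values into a single vector per image
# Module wrapper so that flatten can be used as a layer inside nn.Sequential
class Flatten(nn.Module):
    def forward(self, x):
        return flatten(x)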
def test_flatten():
    x = torch.arange(12).view(2, 1, 3, 2)
    print('Before flattening: ', x)
    print('After flattening: ', flatten(x))

test_flatten()

Before flattening:  tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]]],


        [[[ 6,  7],
          [ 8,  9],
          [10, 11]]]])
After flattening:  tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])
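
The output above shows that flattening keeps each sample on its own row and lays out its values in row-major order (channel, then row, then column). The plain-Python sketch below reproduces the same result on nested lists and makes the index arithmetic explicit. `flatten_sample` and `flat_index` are hypothetical helper names used only for this illustration.

```python
def flatten_sample(sample):
    """Flatten one C x H x W nested list into a flat list in row-major order."""
    return [v for channel in sample for row in channel for v in row]

def flat_index(c, h, w, H, W):
    """Position of element (c, h, w) inside the flattened vector."""
    return c * H * W + h * W + w

# Same data as test_flatten(): 2 samples, 1 channel, 3 rows, 2 columns.
N, C, H, W = 2, 1, 3, 2
values = iter(range(N * C * H * W))
x = [[[[next(values) for _ in range(W)] for _ in range(H)] for _ in range(C)]
     for _ in range(N)]

flat = [flatten_sample(sample) for sample in x]
print(flat)  # [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]

# Element at channel 0, row 2, column 1 of sample 1 lands at position 5.
assert flat[1][flat_index(0, 2, 1, H, W)] == x[1][0][2][1] == 11
```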


In [15]:
hidden_layer_size = 4000
learning_rate = 1e-2
two_fc_model = nn.Sequential(
    Flatten(),
    nn.Linear(3*32*32, hidden_layer_size),
    nn.ReLU(),
    nn.Linear(hidden_layer_size, 10)
)
optimizer_two_fc = optim.SGD(two_fc_model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
best_two_fc = train_model(two_fc_model, optimizer_two_fc, 10)
check_accuracy(loader_test, best_two_fc)

Epoch 0, loss=1.4780
Checking accuracy on validation set
Got 416 / 1000 correct (41.60)
Epoch 1, loss=1.5226
Checking accuracy on validation set
Got 446 / 1000 correct (44.60)
Epoch 2, loss=1.4634
Checking accuracy on validation set
Got 487 / 1000 correct (48.70)
Epoch 3, loss=1.8989
Checking accuracy on validation set
Got 432 / 1000 correct (43.20)
Epoch 4, loss=1.1913
Checking accuracy on validation set
Got 463 / 1000 correct (46.30)
Epoch 5, loss=1.6002
Checking accuracy on validation set
Got 461 / 1000 correct (46.10)
Epoch 6, loss=2.2539
Checking accuracy on validation set
Got 419 / 1000 correct (41.90)
Epoch 7, loss=1.7929
Checking accuracy on validation set
Got 474 / 1000 correct (47.40)
Epoch 8, loss=1.7705
Checking accuracy on validation set
Got 465 / 1000 correct (46.50)
Epoch 9, loss=2.1885
Checking accuracy on validation set
Got 453 / 1000 correct (45.30)
best_acc: 0.487
Checking accuracy on test set
Got 4685 / 10000 correct (46.85)


0.4685

### After 10 epochs on the two-layer fully connected network, the best validation accuracy is 48.7% and the test accuracy is 46.85%.
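
For a sense of scale, the parameter count of this two-layer network can be worked out by hand. The sketch below does the arithmetic in plain Python; `linear_params` is a hypothetical helper that counts the weights and biases of one fully connected layer.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

input_size = 3 * 32 * 32      # flattened CIFAR-10 image
hidden_layer_size = 4000
num_classes = 10

fc1 = linear_params(input_size, hidden_layer_size)
fc2 = linear_params(hidden_layer_size, num_classes)

print(input_size)   # 3072
print(fc1)          # 12292000
print(fc2)          # 40010
print(fc1 + fc2)    # 12332010
```

Almost all of the roughly 12.3 million parameters sit in the first layer, which connects every pixel to every hidden unit and ignores the spatial structure of the image.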

### Model 2: three-layer convolutional network

In [17]:
learning_rate = 1e-2
three_conv_model = nn.Sequential(
    nn.Conv2d(3, 32, 5, padding=2),
    nn.ReLU(),
    nn.Conv2d(32, 16, 3, padding=1),
    nn.ReLU(),
    Flatten(),
    nn.Linear(16*32*32, 10)
)
optimizer_three_conv = optim.SGD(three_conv_model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
best_three_conv = train_model(three_conv_model, optimizer_three_conv, 10)
check_accuracy(loader_test, best_three_conv)

Epoch 0, loss=1.2659
Checking accuracy on validation set
Got 581 / 1000 correct (58.10)
Epoch 1, loss=1.0249
Checking accuracy on validation set
Got 564 / 1000 correct (56.40)
Epoch 2, loss=1.0673
Checking accuracy on validation set
Got 598 / 1000 correct (59.80)
Epoch 3, loss=1.1520
Checking accuracy on validation set
Got 608 / 1000 correct (60.80)
Epoch 4, loss=1.0127
Checking accuracy on validation set
Got 600 / 1000 correct (60.00)
Epoch 5, loss=0.7969
Checking accuracy on validation set
Got 598 / 1000 correct (59.80)
Epoch 6, loss=0.3082
Checking accuracy on validation set
Got 584 / 1000 correct (58.40)
Epoch 7, loss=0.5596
Checking accuracy on validation set
Got 581 / 1000 correct (58.10)
Epoch 8, loss=0.4902
Checking accuracy on validation set
Got 574 / 1000 correct (57.40)
Epoch 9, loss=0.2449
Checking accuracy on validation set
Got 581 / 1000 correct (58.10)
best_acc: 0.608
Checking accuracy on test set
Got 5991 / 10000 correct (59.91)


0.5991

### After 10 epochs on the network with two convolutional layers, the best validation accuracy is 60.8% and the test accuracy is 59.91%.
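
Two details of this model are easy to check by hand: the chosen paddings keep the 32×32 spatial size, which is why the final layer is `nn.Linear(16*32*32, 10)`, and the model is far smaller than the fully connected one. The sketch below works through both in plain Python. The helper names are hypothetical, and the size formula assumes no dilation.

```python
def conv_out_size(size, kernel, padding=0, stride=1):
    """Spatial output size of a convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_params(in_ch, out_ch, kernel):
    """Weights plus biases of a square-kernel Conv2d layer."""
    return in_ch * out_ch * kernel * kernel + out_ch

def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

size = 32
size = conv_out_size(size, kernel=5, padding=2)   # first conv keeps 32
size = conv_out_size(size, kernel=3, padding=1)   # second conv keeps 32
print(size)  # 32

flat_features = 16 * size * size
print(flat_features)  # 16384, matches nn.Linear(16*32*32, 10)

total = (conv_params(3, 32, 5) + conv_params(32, 16, 3)
         + linear_params(flat_features, 10))
print(total)  # 170906
```

That is roughly 70 times fewer parameters than the two-layer fully connected network, yet the accuracy is clearly higher.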

### Model 3: custom multi-layer convolutional network

In [18]:
learning_rate = 1e-2
model_customize = nn.Sequential(
    nn.Conv2d(3,16,3,padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2,stride=2),
    nn.Conv2d(16,32,3,padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2,stride=2),
    nn.Conv2d(32,32,3,padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2,stride=2),
    Flatten(),
    nn.Linear(32*4*4,32*4*4),
    nn.Linear(32*4*4,32*2*2),
    nn.Linear(32*2*2,10)
)
optimizer_customize = optim.SGD(model_customize.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
best_customize = train_model(model_customize, optimizer_customize, 10)
check_accuracy(loader_test, best_customize)

Epoch 0, loss=0.9281
Checking accuracy on validation set
Got 578 / 1000 correct (57.80)
Epoch 1, loss=1.0054
Checking accuracy on validation set
Got 625 / 1000 correct (62.50)
Epoch 2, loss=0.8731
Checking accuracy on validation set
Got 690 / 1000 correct (69.00)
Epoch 3, loss=0.5354
Checking accuracy on validation set
Got 672 / 1000 correct (67.20)
Epoch 4, loss=0.9772
Checking accuracy on validation set
Got 718 / 1000 correct (71.80)
Epoch 5, loss=0.6140
Checking accuracy on validation set
Got 708 / 1000 correct (70.80)
Epoch 6, loss=0.5491
Checking accuracy on validation set
Got 707 / 1000 correct (70.70)
Epoch 7, loss=1.1112
Checking accuracy on validation set
Got 729 / 1000 correct (72.90)
Epoch 8, loss=0.4567
Checking accuracy on validation set
Got 723 / 1000 correct (72.30)
Epoch 9, loss=0.7924
Checking accuracy on validation set
Got 721 / 1000 correct (72.10)
best_acc: 0.729
Checking accuracy on test set
Got 7223 / 10000 correct (72.23)


0.7223

### After 10 epochs on the network with more convolutional and fully connected layers, the best validation accuracy is 72.9% and the test accuracy is 72.23%.
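
The `32*4*4` input size of the first `nn.Linear` in this model comes from the three 2×2 max-pooling layers, each of which halves the spatial size. The sketch below traces the shape through the convolution and pooling stack in plain Python. The `layers` list is a hand-written mirror of `model_customize`, not something read from the model, and the size formula assumes no dilation.

```python
def conv_out_size(size, kernel, padding=0, stride=1):
    """Spatial output size of a convolution or pooling window (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kind, output channels or None, kernel, padding, stride), mirroring model_customize
layers = [
    ('conv', 16, 3, 1, 1), ('pool', None, 2, 0, 2),
    ('conv', 32, 3, 1, 1), ('pool', None, 2, 0, 2),
    ('conv', 32, 3, 1, 1), ('pool', None, 2, 0, 2),
]

channels, size = 3, 32
shapes = []
for kind, out_ch, kernel, padding, stride in layers:
    size = conv_out_size(size, kernel, padding, stride)
    if kind == 'conv':
        channels = out_ch          # pooling leaves the channel count unchanged
    shapes.append((kind, channels, size, size))
    print(kind, (channels, size, size))

print(channels * size * size)  # 512
```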

### Model 4: the classic ResNet

In [41]:
learning_rate = 1e-2
resnet = ResNet18()
optimizer_resnet = optim.SGD(resnet.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
best_resnet = train_model(resnet, optimizer_resnet,10)
check_accuracy(loader_test, best_resnet)

Epoch 0, loss=0.7911
Checking accuracy on validation set
Got 629 / 1000 correct (62.90)
Epoch 1, loss=0.8354
Checking accuracy on validation set
Got 738 / 1000 correct (73.80)
Epoch 2, loss=0.7350
Checking accuracy on validation set
Got 777 / 1000 correct (77.70)
Epoch 3, loss=0.2774
Checking accuracy on validation set
Got 791 / 1000 correct (79.10)
Epoch 4, loss=0.2839
Checking accuracy on validation set
Got 816 / 1000 correct (81.60)
Epoch 5, loss=0.2602
Checking accuracy on validation set
Got 841 / 1000 correct (84.10)
Epoch 6, loss=0.1178
Checking accuracy on validation set
Got 813 / 1000 correct (81.30)
Epoch 7, loss=0.1170
Checking accuracy on validation set
Got 802 / 1000 correct (80.20)
Epoch 8, loss=0.1597
Checking accuracy on validation set
Got 829 / 1000 correct (82.90)
Epoch 9, loss=0.2146
Checking accuracy on validation set
Got 827 / 1000 correct (82.70)
best_acc: 0.841
Checking accuracy on test set
Got 8189 / 10000 correct (81.89)


0.8189

### After 10 epochs on ResNet18, the best validation accuracy is 84.1% and the test accuracy is 81.89%.

In [43]:
# Apply data augmentation
cifar10_train = dset.CIFAR10('./dataset', train=True, download=True, transform=transform_aug)
loader_train = DataLoader(cifar10_train, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))
learning_rate = 1e-2
resnet = ResNet18()
optimizer_resnet = optim.SGD(resnet.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
best_resnet = train_model(resnet, optimizer_resnet,10)
check_accuracy(loader_test, best_resnet)

Files already downloaded and verified
Epoch 0, loss=1.3632
Checking accuracy on validation set
Got 609 / 1000 correct (60.90)
Epoch 1, loss=0.6613
Checking accuracy on validation set
Got 715 / 1000 correct (71.50)
Epoch 2, loss=0.7622
Checking accuracy on validation set
Got 801 / 1000 correct (80.10)
Epoch 3, loss=0.5876
Checking accuracy on validation set
Got 814 / 1000 correct (81.40)
Epoch 4, loss=0.6013
Checking accuracy on validation set
Got 826 / 1000 correct (82.60)
Epoch 5, loss=0.5516
Checking accuracy on validation set
Got 824 / 1000 correct (82.40)
Epoch 6, loss=0.2881
Checking accuracy on validation set
Got 862 / 1000 correct (86.20)
Epoch 7, loss=0.2404
Checking accuracy on validation set
Got 847 / 1000 correct (84.70)
Epoch 8, loss=0.4172
Checking accuracy on validation set
Got 879 / 1000 correct (87.90)
Epoch 9, loss=0.1719
Checking accuracy on validation set
Got 883 / 1000 correct (88.30)
best_acc: 0.883
Checking accuracy on test set
Got 8627 / 10000 correct (86.27)


0.8627

### With data augmentation, after 10 epochs on ResNet18 the best validation accuracy is 88.3% and the test accuracy is 86.27%.
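
To illustrate what the two augmentations do, here is a rough plain-Python sketch of "pad by 4, crop 32×32 at a random offset" and "flip horizontally with probability 0.5". It works on a single invented channel stored as a nested list, whereas the torchvision transforms operate on whole images. All function names here are hypothetical.

```python
import random

def pad_image(img, padding, fill=0):
    """Pad a 2-D list (H x W) with a constant value on all four sides."""
    width = len(img[0]) + 2 * padding
    blank = [[fill] * width for _ in range(padding)]
    middle = [[fill] * padding + list(row) + [fill] * padding for row in img]
    return blank + middle + [list(r) for r in blank]

def random_crop(img, size, padding, rng):
    """Pad, then cut a size x size window at a random offset."""
    padded = pad_image(img, padding)
    top = rng.randint(0, len(padded) - size)
    left = rng.randint(0, len(padded[0]) - size)
    return [row[left:left + size] for row in padded[top:top + size]]

def random_hflip(img, rng, p=0.5):
    """Mirror the image left-to-right with probability p."""
    return [row[::-1] for row in img] if rng.random() < p else img

rng = random.Random(0)   # fixed seed, only to make the sketch repeatable
image = [[r * 32 + c for c in range(32)] for r in range(32)]   # one invented channel

augmented = random_hflip(random_crop(image, 32, 4, rng), rng)
print(len(augmented), len(augmented[0]))  # 32 32
```

The output keeps the 32×32 size, but the content is shifted by up to 4 pixels in each direction and possibly mirrored, so the network rarely sees exactly the same training image twice.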

In [46]:
# Increase the learning rate
cifar10_train = dset.CIFAR10('./dataset', train=True, download=True, transform=transform_aug)
loader_train = DataLoader(cifar10_train, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))
learning_rate = 1e-1
resnet = ResNet18()
optimizer_resnet = optim.SGD(resnet.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
best_resnet = train_model(resnet, optimizer_resnet,10)
check_accuracy(loader_test, best_resnet)

Files already downloaded and verified
Epoch 0, loss=1.0704
Checking accuracy on validation set
Got 520 / 1000 correct (52.00)
Epoch 1, loss=0.7277
Checking accuracy on validation set
Got 696 / 1000 correct (69.60)
Epoch 2, loss=0.9168
Checking accuracy on validation set
Got 746 / 1000 correct (74.60)
Epoch 3, loss=0.5203
Checking accuracy on validation set
Got 814 / 1000 correct (81.40)
Epoch 4, loss=0.5422
Checking accuracy on validation set
Got 823 / 1000 correct (82.30)
Epoch 5, loss=0.5257
Checking accuracy on validation set
Got 823 / 1000 correct (82.30)
Epoch 6, loss=0.5322
Checking accuracy on validation set
Got 869 / 1000 correct (86.90)
Epoch 7, loss=0.2560
Checking accuracy on validation set
Got 847 / 1000 correct (84.70)
Epoch 8, loss=0.2628
Checking accuracy on validation set
Got 868 / 1000 correct (86.80)
Epoch 9, loss=0.1678
Checking accuracy on validation set
Got 865 / 1000 correct (86.50)
best_acc: 0.869
Checking accuracy on test set
Got 8418 / 10000 correct (84.18)


0.8418

### A larger learning rate does not do better after 10 epochs: the best validation accuracy is 86.9% and the test accuracy is 84.18%.

In [51]:
# Train for more epochs and apply learning-rate decay
cifar10_train = dset.CIFAR10('./dataset', train=True, download=True, transform=transform_aug)
loader_train = DataLoader(cifar10_train, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))
learning_rate = 1e-2
resnet = ResNet18()
optimizer_resnet = optim.SGD(resnet.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
scheduler = lr_scheduler.StepLR(optimizer_resnet, step_size=15,gamma=0.1)
best_resnet = train_model(resnet, optimizer_resnet,50, scheduler)
check_accuracy(loader_test, best_resnet)

Files already downloaded and verified
Epoch 0, loss=0.8779
Checking accuracy on validation set
Got 637 / 1000 correct (63.70)
Epoch 1, loss=0.8793
Checking accuracy on validation set
Got 778 / 1000 correct (77.80)
Epoch 2, loss=0.5103
Checking accuracy on validation set
Got 764 / 1000 correct (76.40)
Epoch 3, loss=0.5943
Checking accuracy on validation set
Got 811 / 1000 correct (81.10)
Epoch 4, loss=0.2430
Checking accuracy on validation set
Got 833 / 1000 correct (83.30)
Epoch 5, loss=0.4615
Checking accuracy on validation set
Got 855 / 1000 correct (85.50)
Epoch 6, loss=0.5494
Checking accuracy on validation set
Got 861 / 1000 correct (86.10)
Epoch 7, loss=0.4116
Checking accuracy on validation set
Got 875 / 1000 correct (87.50)
Epoch 8, loss=0.2203
Checking accuracy on validation set
Got 867 / 1000 correct (86.70)
Epoch 9, loss=0.3387
Checking accuracy on validation set
Got 868 / 1000 correct (86.80)
Epoch 10, loss=0.2616
Checking accuracy on validation set
Got 865 / 1000 correct (86.50)
...

0.916

### With more epochs and learning-rate decay, the best validation accuracy is 93.4% and the test accuracy is 91.6%.
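
The schedule used in this run, `StepLR(step_size=15, gamma=0.1)` over 50 epochs, can be written out explicitly. The sketch below computes it in plain Python, assuming the scheduler is stepped exactly once after each epoch. `step_lr` is a hypothetical helper, not a PyTorch function.

```python
import math

def step_lr(base_lr, epoch, step_size, gamma):
    """Learning rate under step decay: multiply by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)

base_lr, step_size, gamma, epochs = 1e-2, 15, 0.1, 50
schedule = [step_lr(base_lr, e, step_size, gamma) for e in range(epochs)]

# Print the learning rate on both sides of each decay boundary.
for e in (0, 14, 15, 29, 30, 44, 45, 49):
    print('epoch %2d: lr = %.0e' % (e, schedule[e]))
```

Under that assumption, the learning rate is 1e-2 for epochs 0 to 14, 1e-3 for epochs 15 to 29, 1e-4 for epochs 30 to 44, and 1e-5 for the last five epochs.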

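# Conclusion
By building networks of increasing complexity, test accuracy on CIFAR-10 rose from 46.85% to 91.6%, with data augmentation and learning-rate adjustment introduced along the way. To push performance further, consider adding regularization, training for more epochs, trying finer-grained learning-rate schedules, or trying other classic architectures.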