In [1]:
import torch
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.transforms.transforms as T

USE_GPU = True

TRAIN_SET_NUM = 49000
BATCH_SIZE = 64
EPOCH_NUM = 15

# 数据预处理
transform_normal = T.Compose([
    T.ToTensor(), 
    T.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010))
])

# 数据增强
transform_aug = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])


# 加载训练集
train_dataset = torchvision.datasets.CIFAR10(root='./', train=True, transform=transform_aug, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler.SubsetRandomSampler(range(TRAIN_SET_NUM)))

# 加载验证集
val_dataset = torchvision.datasets.CIFAR10(root='./', train=True, transform=transform_normal, download=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=sampler.SubsetRandomSampler(range(TRAIN_SET_NUM, 50000)))

# 加载测试集
test_dataset = torchvision.datasets.CIFAR10(root='./', train=False, transform=transform_normal, download=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [2]:
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cuda


In [4]:
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, ic, oc, stride=1):
        super(ResidualBlock, self).__init__()
        # 残差
        self.left = nn.Sequential(
            nn.Conv2d(ic, oc, kernel_size = 3, stride=stride, padding=1),
            nn.BatchNorm2d(oc),
            nn.ReLU(),
            nn.Conv2d(oc, oc, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(oc)
        )
        self.shortcut = nn.Sequential()
        # residual block 经过一次降采样 通道数翻倍
        # 若通道未翻倍，则为普通卷积，不降采样，也不翻倍通道数
        # stride == 1 || ic == oc未降采样
        if stride != 1 or ic != oc:
            self.shortcut = nn.Sequential(
                nn.Conv2d(ic, oc, kernel_size=1, stride=stride),
                nn.BatchNorm2d(oc)
            )

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet18(nn.Module):
    def __init__(self, ResidualBlock, num_classes=10):
        super(ResNet18, self).__init__()
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64,  2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        """
        每个 layer 由多个 residual block 组成
        """
        layers = []
        for i in range(num_blocks):
            if i == 0:
                layers.append(block(self.inchannel, channels, stride))
            else:
                layers.append(block(channels, channels, 1))
            self.inchannel = channels
            
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out




net = ResNet18(ResidualBlock)
print(net)

ResNet18(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer1): Sequential(
    (0): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (1): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 

In [5]:
import copy

train_loss_hist = []
test_loss_hist = []

# 验证模型在验证集或者测试集上的准确率
def check_accuracy(loader, model):
    if loader.dataset.train:
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')
    num_correct = 0
    num_samples = 0
    model.eval()   # set model to evaluation mode
    with torch.no_grad():
        for x,y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            scores = model(x)
            _,preds = scores.max(1)
            num_correct += (preds==y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 *acc ))
        return acc

def train_model(model, optimizer, epochs=1, scheduler=None):
    '''
    Parameters:
    - model: A Pytorch Module giving the model to train.
    - optimizer: An optimizer object we will use to train the model
    - epochs: A Python integer giving the number of epochs to train
    Returns: best model
    '''
    best_model_wts = None
    best_acc = 0.0
    model = model.to(device=device) # move the model parameters to CPU/GPU
    for e in range(epochs):
        if scheduler:
            scheduler.step()
        for t,(x,y) in enumerate(train_dataloader):
            model.train()   # set model to training mode
            x = x.to(device)
            y = y.to(device)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch %d, loss=%.4f' % (e, loss.item()))
        acc = check_accuracy(val_dataloader, model)
        if acc > best_acc:
            best_model_wts = copy.deepcopy(model.state_dict())
            best_acc = acc
    print('best_acc:',best_acc)
    model.load_state_dict(best_model_wts)
    return model

In [9]:
from torch.optim import lr_scheduler

optimizer = optim.SGD(net.parameters(), lr=1e-2, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=15,gamma=0.1)
best_model = train_model(net, optimizer, epochs=EPOCH_NUM, scheduler=scheduler)

Epoch 0, loss=0.8930
Checking accuracy on validation set
Got 547 / 1000 correct (54.70)
Epoch 1, loss=0.7152
Checking accuracy on validation set
Got 726 / 1000 correct (72.60)
Epoch 2, loss=0.4884
Checking accuracy on validation set
Got 781 / 1000 correct (78.10)
Epoch 3, loss=0.4800
Checking accuracy on validation set
Got 792 / 1000 correct (79.20)
Epoch 4, loss=0.6150
Checking accuracy on validation set
Got 834 / 1000 correct (83.40)
Epoch 5, loss=0.4651
Checking accuracy on validation set
Got 837 / 1000 correct (83.70)
Epoch 6, loss=0.2091
Checking accuracy on validation set
Got 852 / 1000 correct (85.20)
Epoch 7, loss=0.4268
Checking accuracy on validation set
Got 861 / 1000 correct (86.10)
Epoch 8, loss=0.1468
Checking accuracy on validation set
Got 866 / 1000 correct (86.60)
Epoch 9, loss=0.1507
Checking accuracy on validation set
Got 882 / 1000 correct (88.20)
Epoch 10, loss=0.4157
Checking accuracy on validation set
Got 868 / 1000 correct (86.80)
Epoch 11, loss=0.1216
Checking 

In [10]:
check_accuracy(test_dataloader, best_model)

Checking accuracy on test set
Got 9103 / 10000 correct (91.03)


0.9103

In [None]:
# 91.03%