In [1]:
import time

import torch
from torch.optim import lr_scheduler
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision.datasets as dsets
import torchvision.transforms as transforms

# Hyper-parameters and shared objects used by every cell below.
BATCH_SIZE = 100
SEED = 0  # fixed seed so weight initialisation and shuffling are reproducible
criterion = nn.CrossEntropyLoss()

torch.manual_seed(SEED)

# Prefer CUDA, then Apple's MPS backend, otherwise fall back to the CPU.
# getattr() guards against older torch builds that have no torch.backends.mps.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [2]:
# Per-channel normalisation statistics, given in 0-255 units and rescaled to [0, 1].
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]

# One root for every split, so the archive is downloaded (and verified) only once.
DATA_ROOT = './data'
NUM_WORKERS = 6

# Shared tail of every pipeline: PIL image -> tensor -> normalised tensor.
to_tensor_normalized = [transforms.ToTensor(), transforms.Normalize(mean, std)]
plain_transform = transforms.Compose(to_tensor_normalized)
augment_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    *to_tensor_normalized,
])

# Training set without data augmentation.
train_set = dsets.CIFAR10(root=DATA_ROOT, train=True, download=True,
                          transform=plain_transform)
train_dl = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=NUM_WORKERS)
tain_dl = train_dl  # legacy (misspelled) alias, still referenced by the training cell

# Training set with data augmentation (random horizontal flip + padded random crop).
# Uses the same root as the other splits; previously it pointed at
# '../data/cifar10' without download=True and failed on a fresh checkout.
train_set_ag = dsets.CIFAR10(root=DATA_ROOT, train=True, download=True,
                             transform=augment_transform)
train_dl_ag = DataLoader(train_set_ag, batch_size=BATCH_SIZE, shuffle=True,
                         num_workers=NUM_WORKERS)

# Test set: never augmented and never shuffled.
test_set = dsets.CIFAR10(root=DATA_ROOT, train=False, download=True,
                         transform=plain_transform)
test_dl = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified


In [3]:
def eval(model, criterion, dataloader):
    """Score `model` on every batch of `dataloader` without updating it.

    The model is switched to eval mode and autograd is disabled, so no
    computation graph is built while scoring.

    Args:
        model: network mapping a batch of inputs to class logits.
        criterion: loss called as ``criterion(logits, targets)``; assumed to
            return the mean over the batch (the default reduction of the
            ``nn.CrossEntropyLoss`` used in this notebook).
        dataloader: iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(loss, accuracy)``. ``loss`` is a Python float, the per-sample
        average loss. ``accuracy`` is a 0-d tensor on ``device`` holding the
        percentage of correctly classified samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.

    Note: the name shadows the builtin ``eval``; it is kept for existing callers.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    n_samples = 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            batch_size = batch_y.shape[0]
            # Weight by batch size so a smaller final batch is not over-counted.
            total_loss += criterion(logits, batch_y).item() * batch_size
            correct = correct + (logits.argmax(dim=1) == batch_y).sum()
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError('dataloader yielded no samples')

    loss = total_loss / n_samples
    accuracy = correct.float() * 100.0 / n_samples
    return loss, accuracy

def train_epoch(net, criterion, optimizer, dataloader):
    """Run a single optimisation pass over `dataloader`.

    For every ``(inputs, targets)`` batch: move it to ``device``, clear the
    gradients, compute ``criterion(net(inputs), targets)``, backpropagate and
    let ``optimizer`` update the parameters. Returns None.
    """
    net.train()  # training-mode behaviour for any layers that distinguish it
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(net(inputs), targets)
        batch_loss.backward()
        optimizer.step()

In [4]:
class LeNet(nn.Module):
    """LeNet-style CNN sized for 32x32 images (CIFAR-10 in this notebook).

    Shape trace for an ``in_channels x 32 x 32`` input:
        conv5x5 (pad 2) -> 50x32x32 -> avgpool 4 -> 50x8x8 -> ReLU
        conv5x5 (pad 2) -> 50x8x8   -> avgpool 2 -> 50x4x4 -> ReLU
        flatten -> 800 -> Linear 500 -> ReLU -> Linear num_classes

    Args:
        in_channels: number of input image channels (default 3, RGB).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(in_channels, 50, kernel_size=5, padding=2),
            nn.AvgPool2d(4),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=5, padding=2),
            nn.AvgPool2d(2),
            nn.ReLU())
        self.classifier = nn.Sequential(
            nn.Linear(50 * 4 * 4, 500),
            nn.ReLU(),
            nn.Linear(500, num_classes))

    def forward(self, x):
        """Map a batch of images to a ``(batch, num_classes)`` logit tensor."""
        x = self.feature(x)
        # Flatten per sample, keeping the batch dim. With view(-1, 800) a
        # wrong-sized input was silently folded into extra "samples"; now it
        # fails loudly in the first Linear layer instead.
        x = torch.flatten(x, 1)
        return self.classifier(x)

In [5]:
# Training budget and LR schedule, shared by both runs so that data
# augmentation is the only difference between them (previously the
# milestones differed: 45 vs 40).
nepochs = 60
LR = 0.001
LR_MILESTONES = [40]  # LR is multiplied by LR_GAMMA once the scheduler reaches this step
LR_GAMMA = 0.1


def _make_run():
    """Build a fresh LeNet on `device` plus its Adam optimizer and MultiStepLR scheduler."""
    model = LeNet().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    sched = lr_scheduler.MultiStepLR(opt, milestones=LR_MILESTONES, gamma=LR_GAMMA)
    return model, opt, sched


# Run 1: plain training data.
net, optimizer, scheduler = _make_run()
learn_hist = []

# Run 2: augmented training data.
net_ag, optimizer_ag, scheduler_ag = _make_run()
learn_hist_ag = []

In [6]:
def _fit(model, optimizer, scheduler, train_loader, history, label):
    """Train `model` for `nepochs` epochs and append per-epoch metrics to `history`.

    Each epoch runs one optimisation pass, steps the LR scheduler (after the
    optimizer updates, as PyTorch >= 1.1 requires), then evaluates on
    `train_loader` and `test_dl`. It appends
    ``(train_loss, train_acc, test_loss, test_acc)`` to `history` and prints
    a summary every 5 epochs.
    """
    print(label)
    for epoch in range(nepochs):
        start = time.time()
        train_epoch(model, criterion, optimizer, train_loader)
        scheduler.step()
        tr_loss, tr_acc = eval(model, criterion, train_loader)
        te_loss, te_acc = eval(model, criterion, test_dl)
        history.append((tr_loss, tr_acc, te_loss, te_acc))
        elapsed = time.time() - start
        if (epoch + 1) % 5 == 0:
            print('[%2d/%d], train error: %.1e, train acc: %.2f\t test error: %.1e, test acc: %.2f\t (%.1fs/epoch)'
                  % (epoch + 1, nepochs, tr_loss, float(tr_acc), te_loss, float(te_acc), elapsed))


# `tain_dl` is the (misspelled) name of the plain training loader from the data cell.
_fit(net, optimizer, scheduler, tain_dl, learn_hist,
     'Start training without data augmentation')
# Train metrics for this run are measured on augmented batches (train_dl_ag).
_fit(net_ag, optimizer_ag, scheduler_ag, train_dl_ag, learn_hist_ag,
     'Start training with data augmentation')
    

Start training without data augmentation


KeyboardInterrupt: 

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.plot([t[1].to('cpu') for t in learn_hist],'r',label='Train without data augmentation')
plt.plot([t[3].to('cpu') for t in learn_hist],'b',label='Test  without data augmentation')

plt.plot([t[1].to('cpu') for t in learn_hist_ag],'r--',label='Train with data augmentation')
plt.plot([t[3].to('cpu') for t in learn_hist_ag],'b--',label='Test with data augmentation')

plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.show()
