<a href="https://colab.research.google.com/github/OUCTheoryGroup/colab_demo/blob/master/11_MixUp_ICLR2018.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable

In [2]:
# LeNet model
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/pool stages followed by three fully connected layers.

    Expects 3-channel 32x32 inputs — the spatial size for which two
    (5x5 conv -> 2x2 max-pool) stages leave 16 x 5 x 5 features for ``fc1``.
    Produces 10 unnormalized class scores (logits) per sample.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 6 -> 16 channels, 5x5 kernels, no padding.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head: flattened 16*5*5 features -> 120 -> 84 -> 10 logits.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of logits."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        x = x.flatten(1)  # keep the batch dim, merge channels and spatial dims
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

In [3]:
def mixup_data(x, y, alpha=1.0):
    '''Returns mixed inputs, pairs of targets, and lambda.

    MixUp: each sample in the batch is blended with another sample of the same
    batch, chosen by a random permutation, using a coefficient ``lam`` drawn
    from Beta(alpha, alpha).

    Args:
        x: input batch tensor; dimension 0 is the batch dimension.
        y: targets for ``x``, indexable along dimension 0 (expected to live on
            the same device as ``x``).
        alpha: Beta distribution parameter. ``alpha <= 0`` disables mixing
            (``lam = 1.0``, so ``mixed_x`` equals ``x``).

    Returns:
        (mixed_x, y_a, y_b, lam): the mixed inputs, the original targets, the
        targets of the partner samples, and the mixing coefficient as a float.
        The intended loss is ``lam * L(pred, y_a) + (1 - lam) * L(pred, y_b)``.
    '''
    if alpha > 0:
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = 1.0

    batch_size = x.size(0)
    # Build the permutation on x's own device instead of hard-coding .cuda(),
    # so this works for CPU tensors / machines without a GPU as well.
    index = torch.randperm(batch_size, device=x.device)

    # Index only the batch dimension so inputs of any rank (including 1-D) work.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    '''Combine the loss against both target sets of a MixUp batch.

    Args:
        criterion: callable ``criterion(pred, target)`` returning a loss.
        pred: model outputs for the mixed inputs.
        y_a: targets of the original samples.
        y_b: targets of the partner samples they were mixed with.
        lam: mixing coefficient used to build the inputs.

    Returns:
        ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    '''
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    weight_b = 1 - lam
    return lam * loss_a + weight_b * loss_b

In [4]:
# Per-channel normalization statistics, shared by the train and test pipelines.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def _make_transform():
    """Build the image pipeline: PIL image -> tensor -> channel-wise normalization."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])


transform_train = _make_transform()
transform_test = _make_transform()

# Only the train split requests a download; the test split is read from the
# same ./data directory (presumably populated by that download).
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)

loader_kwargs = dict(batch_size=64, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data


In [5]:
# Key MixUp hyper-parameter alpha: lambda is drawn from Beta(alpha, alpha);
# alpha <= 0 disables mixing inside mixup_data.
alpha = 0.5

# Seed the RNGs used by weight init, data shuffling (torch) and the MixUp
# Beta draw (numpy) so runs are repeatable.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when present, otherwise fall back to CPU instead of failing on .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = LeNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

In [7]:
# Network training
def train_one_epoch(model, loader, loss_fn, opt, mix_alpha):
    """Train ``model`` for one epoch on MixUp-augmented batches.

    Args:
        model: the network; each batch is moved to the device of its parameters.
        loader: iterable of ``(inputs, targets)`` batches.
        loss_fn: loss callable ``loss_fn(outputs, targets)``.
        opt: optimizer over ``model``'s parameters.
        mix_alpha: Beta(alpha, alpha) parameter forwarded to ``mixup_data``.

    Returns:
        (mean_loss, correct, total): mean MixUp loss per batch (0.0 for an empty
        loader), the lam-weighted count of predictions matching ``y_a``/``y_b``,
        and the number of samples seen.
    """
    model.train()
    device = next(model.parameters()).device
    loss_sum, correct, total, n_batches = 0.0, 0.0, 0, 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, mix_alpha)
        outputs = model(inputs)
        loss = mixup_criterion(loss_fn, outputs, targets_a, targets_b, lam)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Accumulate plain Python numbers (.item()) rather than 0-d tensors via .data.
        loss_sum += loss.item()
        predicted = outputs.argmax(dim=1)
        # On mixed inputs a prediction is credited in proportion to lam / (1 - lam).
        correct += (lam * predicted.eq(targets_a).sum().item()
                    + (1 - lam) * predicted.eq(targets_b).sum().item())
        total += targets.size(0)
        n_batches += 1

    return loss_sum / max(n_batches, 1), correct, total


def evaluate(model, loader, loss_fn):
    """Evaluate ``model`` on ``loader`` without MixUp.

    Returns:
        (mean_loss, correct, total): mean loss per batch (0.0 for an empty
        loader), number of correct top-1 predictions, and number of samples.
    """
    model.eval()
    device = next(model.parameters()).device
    loss_sum, correct, total, n_batches = 0.0, 0, 0, 0
    # Gradients are not needed here; skipping autograd graph construction saves memory.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss_sum += loss_fn(outputs, targets).item()
            correct += outputs.argmax(dim=1).eq(targets).sum().item()
            total += targets.size(0)
            n_batches += 1
    return loss_sum / max(n_batches, 1), correct, total


NUM_EPOCHS = 30

for epoch in range(NUM_EPOCHS):
    train_loss, correct, total = train_one_epoch(net, trainloader, criterion, optimizer, alpha)
    print('Epoch: %d | Loss: %.3f | Acc: %.3f%% (%d/%d)'
          % (epoch + 1, train_loss, 100. * correct / max(total, 1), correct, total))

# Network testing
test_loss, correct, total = evaluate(net, testloader, criterion)
print(' Test Loss: %.3f | Test Accuracy: %.3f %%' % (test_loss, 100. * correct / max(total, 1)))

Epoch: 1 | Loss: 1.645 | Acc: 43.547% (21773/50000)
Epoch: 2 | Loss: 1.627 | Acc: 44.433% (22216/50000)
Epoch: 3 | Loss: 1.609 | Acc: 45.128% (22563/50000)
Epoch: 4 | Loss: 1.590 | Acc: 45.716% (22858/50000)
Epoch: 5 | Loss: 1.591 | Acc: 46.050% (23025/50000)
Epoch: 6 | Loss: 1.572 | Acc: 46.512% (23255/50000)
Epoch: 7 | Loss: 1.568 | Acc: 47.038% (23519/50000)
Epoch: 8 | Loss: 1.537 | Acc: 48.195% (24097/50000)
Epoch: 9 | Loss: 1.541 | Acc: 48.427% (24213/50000)
Epoch: 10 | Loss: 1.531 | Acc: 48.837% (24418/50000)
Epoch: 11 | Loss: 1.527 | Acc: 48.919% (24459/50000)
Epoch: 12 | Loss: 1.517 | Acc: 49.181% (24590/50000)
Epoch: 13 | Loss: 1.519 | Acc: 49.474% (24737/50000)
Epoch: 14 | Loss: 1.502 | Acc: 49.934% (24967/50000)
Epoch: 15 | Loss: 1.472 | Acc: 51.131% (25565/50000)
Epoch: 16 | Loss: 1.474 | Acc: 51.206% (25602/50000)
Epoch: 17 | Loss: 1.464 | Acc: 51.350% (25674/50000)
Epoch: 18 | Loss: 1.461 | Acc: 51.635% (25817/50000)
Epoch: 19 | Loss: 1.450 | Acc: 52.072% (26036/50000)
Ep