In [1]:
import torch
from torchvision import transforms, datasets, models

In [2]:
"""
Configuration and Hyperparameters
"""
# Make newly created float tensors CUDA tensors by default, so modules built
# later in the notebook are allocated on the GPU without an explicit move.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# --- Hyperparameters -------------------------------------------------------
batch_size = 128   # samples per mini-batch for both data loaders
epochs = 50        # passes over the training set
step_size = 0.01   # passed to the optimizer as its learning rate
L2_decay = 1e-4    # passed to the optimizer as its weight decay
alpha = 1.0        # concentration of the Beta(alpha, alpha) mixup distribution

# --- Preprocessing ---------------------------------------------------------
# Convert images to tensors, then standardise the single channel.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST pixel mean / std -- confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# --- Reproducibility -------------------------------------------------------
# Kept as the final expression so the cell still displays the Generator.
random_seed = 0
torch.manual_seed(random_seed)

<torch._C.Generator at 0x178832ef4d0>

In [3]:
"""
Data
"""
# Both MNIST splits share the same on-disk root and preprocessing pipeline;
# the files are downloaded on first use.
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Training batches are reshuffled on every pass over the loader.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)
# Evaluation batches are served in dataset order.
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
)

In [4]:
# Start from a pretrained ResNet-18 and keep every weight trainable.
model = models.resnet18(pretrained=True)
for weight in model.parameters():
    weight.requires_grad = True

# Swap the stem convolution for one that accepts a single input channel and
# the classifier head for a 10-way output. Both new layers are freshly
# initialised, so they do not carry pretrained weights.
model.conv1 = torch.nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
model.fc = torch.nn.Linear(in_features=512, out_features=10)

# Loss and optimizer are created after the layer swap so the optimizer
# tracks the replacement parameters.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=step_size,
    momentum=0.9,
    weight_decay=L2_decay,
)

In [5]:
def mixup_MNIST(inputs, labels, alpha):
    """Blend every sample in a batch with a randomly chosen partner (mixup).

    One mixing coefficient ``lmbda`` is drawn for the whole batch from
    ``Beta(alpha, alpha)``; each input becomes
    ``lmbda * inputs[i] + (1 - lmbda) * inputs[perm[i]]`` for a random
    permutation ``perm`` of the batch indices.

    Args:
        inputs: Batch tensor whose first dimension indexes samples.
        labels: Target tensor whose first dimension has the same length.
        alpha: Concentration of the Beta distribution. A value ``<= 0``
            disables mixing: ``lmbda`` is fixed to 1 and the inputs are
            returned unchanged, instead of failing inside ``Beta``.

    Returns:
        Tuple ``(mixup_inputs, labels, labels_b, lmbda)`` where ``labels``
        is the original target tensor, ``labels_b`` holds the partners'
        targets, and ``lmbda`` is a 0-dim tensor.
    """
    if alpha > 0:
        lmbda = torch.distributions.beta.Beta(alpha, alpha).sample()
    else:
        # Beta is undefined for a non-positive concentration; treat it as
        # "mixup switched off".
        lmbda = torch.tensor(1.0)

    num_samples = labels.size(0)
    # Build the permutation on the batch's own device so indexing works no
    # matter what the global default tensor type is.
    idx = torch.randperm(num_samples, device=inputs.device)
    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx.to(labels.device)]
    return mixup_inputs, labels, labels_b, lmbda

In [6]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    """Combine the loss against both label sets of a mixed batch.

    Args:
        criterion: Callable ``criterion(predicts, targets)`` returning a loss.
        predicts: Model outputs for the mixed inputs.
        labels: Targets of the original samples.
        labels_b: Targets of the partner samples mixed in.
        lmbda: Weight given to the loss against ``labels``; the loss against
            ``labels_b`` is weighted by ``1 - lmbda``.

    Returns:
        ``lmbda * criterion(predicts, labels)
        + (1 - lmbda) * criterion(predicts, labels_b)``.
    """
    loss_primary = criterion(predicts, labels)
    loss_partner = criterion(predicts, labels_b)
    return lmbda * loss_primary + (1 - lmbda) * loss_partner

In [7]:
"""
Training
"""
model.train()  # put the network in training mode before optimising
for epoch in range(epochs):
    epoch_loss = 0.
    for inputs, labels in train_loader:
        # Move the batch to the GPU, then build its mixed counterpart.
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        mixup_inputs, labels, labels_b, lmbda = mixup_MNIST(inputs, labels, alpha)

        # One optimisation step on the mixup loss.
        optimizer.zero_grad()
        outputs = model(mixup_inputs)
        mixup_loss = mixup_criterion(criterion, outputs, labels, labels_b, lmbda)
        mixup_loss.backward()
        optimizer.step()

        # Running sum of per-batch losses (a total, not a mean).
        epoch_loss += mixup_loss.item()
    print('{}: {}'.format(epoch, epoch_loss))

0: 473.769829839468
1: 361.91803466528654
2: 338.3119442462921
3: 320.7842643484473
4: 307.9343439415097
5: 301.80080972984433
6: 293.51977694779634
7: 294.8360937945545
8: 293.0151604488492
9: 288.97041629999876
10: 285.5374068170786
11: 290.3960692919791
12: 280.1688536927104
13: 279.02523909136653
14: 277.4751128666103
15: 277.6376750841737
16: 267.88091595843434
17: 271.8669027425349
18: 271.67082326859236
19: 268.0029690042138
20: 265.29177652671933
21: 270.3418270908296
22: 255.613528188318
23: 258.6215535029769
24: 265.19374030455947
25: 258.6382108181715
26: 257.89798752963543
27: 259.51313184015453
28: 256.09286912716925
29: 256.03573111630976
30: 251.01085766777396
31: 258.6566213481128
32: 252.85793170705438
33: 253.03797079995275
34: 252.71571597456932
35: 250.31602680683136
36: 246.26889471709728
37: 250.06285435333848
38: 245.5783874616027
39: 245.3623173078522
40: 250.20899488031864
41: 252.12295530550182
42: 247.78494908474386
43: 248.22035155445337
44: 246.107545693404

In [8]:
# Write the trained weights to disk.
torch.save(model.state_dict(), './mixup_model_pytorch_mnist')

# Rebuild the same architecture without pretrained weights (1-channel stem,
# 10-way head) and restore the checkpoint into it. The load call stays the
# final expression so the cell still displays the key-matching report.
model = models.resnet18(pretrained=False)
model.conv1 = torch.nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
model.fc = torch.nn.Linear(in_features=512, out_features=10)
model.load_state_dict(torch.load('./mixup_model_pytorch_mnist'))

<All keys matched successfully>

In [9]:
# Accuracy on the held-out test split.
model.eval()  # switch the network to evaluation mode
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for scoring
    for inputs, labels in test_loader:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        # Predicted class = index of the largest output along dim 1.
        predicts = model(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9897


In [10]:
# Accuracy on the training split, scored on the original (unmixed) images.
model.eval()  # switch the network to evaluation mode
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for scoring
    for inputs, labels in train_loader:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        # Predicted class = index of the largest output along dim 1.
        predicts = model(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.99595
