In [1]:
import torch
from torchvision import transforms, datasets
import models

In [2]:
"""
Configuration and Hyperparameters
"""
# Make newly created float tensors CUDA tensors by default, so objects built
# later in the notebook land on the GPU without explicit .to('cuda') calls.
# NOTE(review): set_default_tensor_type appears to be deprecated in recent
# PyTorch releases — confirm before upgrading the environment.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Per-channel normalisation constants shared by the train and test pipelines.
channel_mean = (0.4914, 0.4822, 0.4465)
channel_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: optional augmentation (crop + flip), then tensor
# conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # augmentation, can be omitted
    transforms.RandomHorizontalFlip(),  # augmentation, can be omitted
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Test pipeline: no augmentation, same normalisation as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Hyperparameters consumed by the cells below.
batch_size = 128    # samples per mini-batch for both data loaders
step_size = 0.1     # passed to SGD as the learning rate
random_seed = 0     # seed for torch's global RNG
epochs = 100        # number of passes over the training set
L2_decay = 1e-4     # passed to SGD as weight_decay
alpha = 1.0         # concentration of the Beta distribution used by mixup

torch.manual_seed(random_seed)

<torch._C.Generator at 0x22c8aab04d0>

In [3]:
"""
Data
"""
# CIFAR-10 is fetched into ./data on first use; both loaders share the same
# batch size and worker count, and only the training loader shuffles.
data_root = './data'
loader_options = dict(batch_size=batch_size, num_workers=8)

train_set = datasets.CIFAR10(
    root=data_root,
    train=True,
    download=True,
    transform=transform_train,
)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_options)

test_set = datasets.CIFAR10(
    root=data_root,
    train=False,
    download=True,
    transform=transform_test,
)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_options)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Look the architecture up by name in the local `models` module and build it.
architecture = 'ResNet18'
model = models.__dict__[architecture]()

# Cross-entropy objective, optimised by SGD with momentum and weight decay.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=step_size,
    momentum=0.9,
    weight_decay=L2_decay,
)

In [5]:
def mixup_cifar10(inputs, labels, alpha):
    """Blend each sample of a batch with a randomly chosen partner (mixup).

    A single mixing coefficient ``lmbda`` is drawn from ``Beta(alpha, alpha)``
    and applied to the whole batch:
    ``mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]`` where ``idx``
    is a random permutation of the batch indices.

    Args:
        inputs: batch tensor whose first dimension indexes the samples.
        labels: tensor of targets whose first dimension matches ``inputs``.
        alpha: concentration of the Beta distribution. A non-positive value
            disables mixing (``lmbda`` is fixed to 1) instead of raising
            inside ``Beta``.

    Returns:
        Tuple ``(mixup_inputs, labels, labels_b, lmbda)`` where ``labels`` is
        returned unchanged, ``labels_b`` holds the partner labels
        (``labels[idx]``) and ``lmbda`` is a 0-dim tensor.
    """
    if alpha > 0:
        lmbda = torch.distributions.beta.Beta(alpha, alpha).sample()
    else:
        # Beta is undefined for non-positive concentrations; fall back to
        # "no mixing" so the caller can switch mixup off with alpha=0.
        lmbda = torch.tensor(1.)

    # Build the permutation on the same device as the data rather than
    # relying on the process-wide default tensor type.
    idx = torch.randperm(labels.size(0), device=inputs.device)

    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx]
    return mixup_inputs, labels, labels_b, lmbda

In [6]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    """Combine the losses against both label sets of a mixed-up batch.

    Args:
        criterion: callable ``criterion(predicts, targets)`` returning a loss.
        predicts: model outputs for the mixed inputs.
        labels: targets of the original samples.
        labels_b: targets of the partner samples that were mixed in.
        lmbda: mixing coefficient; weights ``labels`` by ``lmbda`` and
            ``labels_b`` by ``1 - lmbda``.

    Returns:
        The weighted sum of the two losses.
    """
    loss_a = criterion(predicts, labels)
    loss_b = criterion(predicts, labels_b)
    return lmbda * loss_a + (1 - lmbda) * loss_b

In [7]:
"""
Training
"""
# Put the network in training mode once; nothing below switches it back.
model.train()
for epoch in range(epochs):
    # Running sum of the per-batch mixup losses for this epoch.
    epoch_loss = 0.
    for inputs, labels in train_loader:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')

        # Blend the batch with a shuffled copy of itself and keep both label sets.
        mixup_inputs, labels, labels_b, lmbda = mixup_cifar10(inputs, labels, alpha)

        optimizer.zero_grad()
        outputs = model(mixup_inputs)
        mixup_loss = mixup_criterion(criterion, outputs, labels, labels_b, lmbda)
        mixup_loss.backward()
        optimizer.step()

        epoch_loss += mixup_loss.item()
    print('{}: {}'.format(epoch, epoch_loss))

0: 788.8078310489655
1: 686.7634365558624
2: 616.3969045877457
3: 588.2611461281776
4: 538.7559841275215
5: 529.1499211788177
6: 507.4149586260319
7: 497.6203532218933
8: 469.34546449780464
9: 476.28263008594513
10: 468.4674711227417
11: 460.5329996943474
12: 458.1390701830387
13: 470.8426288664341
14: 449.1597504019737
15: 441.47181321680546
16: 430.4203013330698
17: 433.1848093420267
18: 419.7524399161339
19: 439.73905485868454
20: 431.22776083648205
21: 436.2507918328047
22: 417.0537496507168
23: 418.8140291571617
24: 415.00898200273514
25: 402.5538599193096
26: 409.8948245048523
27: 406.50630354881287
28: 402.128609418869
29: 404.8351937830448
30: 405.0635282546282
31: 418.59043857455254
32: 405.0280885845423
33: 406.9426494538784
34: 403.2458438426256
35: 402.3958051651716
36: 401.69816586375237
37: 397.0042129009962
38: 403.1644684225321
39: 398.1817757189274
40: 395.95755212008953
41: 394.3507607281208
42: 390.5644769370556
43: 377.47967006266117
44: 388.4166259765625
45: 391.33

In [8]:
# Write the trained weights to disk, then rebuild a fresh network and restore
# the weights from that file so evaluation runs on the saved checkpoint.
checkpoint_path = './mixup_model_pytorch_cifar10'
torch.save(model.state_dict(), checkpoint_path)

model = models.__dict__['ResNet18']()
model.load_state_dict(torch.load(checkpoint_path))

<All keys matched successfully>

In [9]:
# Switch to inference mode and measure top-1 accuracy on the test set.
model.eval()
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for inputs, labels in test_loader:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        outputs = model(inputs)
        # Predicted class = index of the largest score along dimension 1.
        predicts = outputs.max(dim=1).indices
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9194
