In [1]:
import torch
from torchvision import transforms, datasets
import models

In [2]:
"""
Configuration and Hyperparameters
"""
torch.set_default_tensor_type(torch.cuda.FloatTensor)  # default all in GPU

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # can omit
    transforms.RandomHorizontalFlip(),  # can omit
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2023, 0.1994, 0.2010)
    )
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2023, 0.1994, 0.2010)
    )
])

batch_size = 128
step_size = 0.1
random_seed = 0
epochs = 200
L2_decay = 1e-4
alpha = 1.

torch.manual_seed(random_seed)

<torch._C.Generator at 0x1438294f4d0>

In [3]:
"""
Data
"""
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
model = models.__dict__['ResNet18']()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=step_size, momentum=0.9, weight_decay=L2_decay)

In [5]:
def mixup_cifar10(inputs, labels, alpha):
    lmbda = torch.distributions.beta.Beta(alpha, alpha).sample()
    batch_size = labels.size(0)
    idx = torch.randperm(batch_size)
    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx]
    return mixup_inputs, labels, labels_b, lmbda

In [6]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    mixup_loss = lmbda * criterion(predicts, labels) + (1 - lmbda) * criterion(predicts, labels_b)
    return mixup_loss

In [7]:
"""
Training
"""
model.train()
for epoch in range(epochs):
    epoch_loss = 0.
    epoch_mixup_loss = 0.
    epoch_org_loss = 0.
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        mixup_inputs, labels, labels_b, lmbda = mixup_cifar10(inputs, labels, alpha)
        optimizer.zero_grad()
        outputs = model(mixup_inputs)
        mixup_loss = mixup_criterion(criterion, outputs, labels, labels_b, lmbda)
        
        ##
        outputs_org = model(inputs)
        loss_org = criterion(outputs_org, labels)
        total_loss = mixup_loss + loss_org
        
        epoch_mixup_loss += mixup_loss.item()
        epoch_org_loss += loss_org.item()
        
        epoch_loss += total_loss.item()
        total_loss.backward()
        ##
        
        optimizer.step()
    print('{}: {} {} {}'.format(epoch, epoch_mixup_loss, epoch_org_loss, epoch_loss))

0: 1046.9220373630524 1033.1639646291733 2080.086000919342
1: 899.0625511407852 889.7913702726364 1788.85391664505
2: 886.5929219722748 875.3143594264984 1761.907283782959
3: 834.3424344062805 771.0263348817825 1605.3687677383423
4: 812.0760315656662 736.9753563404083 1549.0513889789581
5: 800.7928663492203 709.0497913360596 1509.8426570892334
6: 774.7856444120407 668.9724642038345 1443.7581074237823
7: 756.2956650257111 635.4324868917465 1391.728154182434
8: 731.7959481477737 606.5924001932144 1338.3883481025696
9: 726.9354506731033 579.6769610643387 1306.612412929535
10: 709.7506102323532 545.1766381263733 1254.9272515773773
11: 686.1230628490448 502.7082750797272 1188.8313403129578
12: 667.2529697418213 462.5438187122345 1129.7967891693115
13: 661.6548303365707 429.3269785642624 1090.981811761856
14: 632.4234305024147 395.2798002958298 1027.7032333612442
15: 609.4051832556725 361.0895655155182 970.4947484731674
16: 586.1013903021812 328.14155757427216 914.2429468631744
17: 578.51544

139: 400.83112689107656 36.14966180920601 436.9807878732681
140: 393.3728911951184 31.382340705022216 424.7552315965295
141: 396.978108279407 32.6730083366856 429.6511160880327
142: 403.46090932935476 31.491242913529277 434.9521514028311
143: 415.3072191141546 29.15646169986576 444.463681653142
144: 393.80555418878794 32.60920046363026 426.4147554785013
145: 394.51234447211027 31.394308095797896 425.9066520780325
146: 403.4901583790779 29.589088609442115 433.07924646139145
147: 398.0671839043498 33.24901222810149 431.3161955177784
148: 399.2701533064246 33.7746922057122 433.0448445677757
149: 391.10661690309644 30.530025801621377 421.6366432160139
150: 390.6029792651534 32.49053769744933 423.09351728856564
151: 406.7510558851063 29.99967616610229 436.75073243677616
152: 385.8974779769778 32.325360112823546 418.22283759713173
153: 393.64088344573975 30.050944171845913 423.69182761758566
154: 402.7366585060954 31.421455931849778 434.1581141650677
155: 397.35986759513617 33.38932224363088

In [8]:
torch.save(model.state_dict(), './mixup_model_pytorch_cifar10_augment')
model = models.__dict__['ResNet18']()
model.load_state_dict(torch.load('./mixup_model_pytorch_cifar10_augment'))

<All keys matched successfully>

In [9]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        outputs = model(inputs)
        _, predicts = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.92


In [10]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in train_loader:
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        outputs = model(inputs)
        _, predicts = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.97558
