In [1]:
import torch
from torchvision import transforms, datasets
import models

In [2]:
"""
Configuration and Hyperparameters
"""
torch.set_default_tensor_type(torch.cuda.FloatTensor)  # default all in GPU

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # can omit
    transforms.RandomHorizontalFlip(),  # can omit
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2023, 0.1994, 0.2010)
    )
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2023, 0.1994, 0.2010)
    )
])

batch_size = 128
step_size = 0.1
random_seed = 0
epochs = 200
L2_decay = 1e-4
alpha = 1.
perturb_loss_weight = 0.25

torch.manual_seed(random_seed)

<torch._C.Generator at 0x12a0fd7c4d0>

In [3]:
"""
Data
"""
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
model = models.__dict__['ResNet18']()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=step_size, momentum=0.9, weight_decay=L2_decay)

In [5]:
def mixup_cifar10(inputs, labels, alpha):
    lmbda = torch.distributions.beta.Beta(alpha, alpha).sample()
    batch_size = labels.size(0)
    idx = torch.randperm(batch_size)
    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx]
    return mixup_inputs, labels, labels_b, lmbda

In [6]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    mixup_loss = lmbda * criterion(predicts, labels) + (1 - lmbda) * criterion(predicts, labels_b)
    return mixup_loss

In [7]:
"""
Training
"""
model.train()
for epoch in range(epochs):
    epoch_loss = 0.
    epoch_mixup_loss = 0.
    epoch_org_loss = 0.
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        mixup_inputs, labels, labels_b, lmbda = mixup_cifar10(inputs, labels, alpha)
        optimizer.zero_grad()
        outputs = model(mixup_inputs)
        mixup_loss = mixup_criterion(criterion, outputs, labels, labels_b, lmbda)
        
        ##
        outputs_org = model(inputs)
        loss_org = criterion(outputs_org, labels)
        weighted_total_loss = mixup_loss * perturb_loss_weight + loss_org * (1 - perturb_loss_weight)
        
        epoch_mixup_loss += mixup_loss.item()
        epoch_org_loss += loss_org.item()
        
        epoch_loss += (mixup_loss.item() + loss_org.item())
        
        weighted_total_loss.backward()
        ##
        
        optimizer.step()
    print('{}: {} {} {}'.format(epoch, epoch_mixup_loss, epoch_org_loss, epoch_loss))

0: 767.0674430131912 612.601745724678 1379.6691887378693
1: 658.5494607686996 397.94827073812485 1056.4977315068245
2: 589.4818803668022 301.91297778487206 891.3948581516743
3: 561.0910027623177 243.89913454651833 804.990137308836
4: 521.4682038724422 207.94182029366493 729.4100241661072
5: 520.1608279645443 181.24125941097736 701.4020873755217
6: 503.65187108516693 161.4518569111824 665.1037279963493
7: 495.77005141973495 146.46370485424995 642.2337562739849
8: 466.72070440649986 132.97498874366283 599.6956931501627
9: 477.62203177809715 121.72698210924864 599.3490138873458
10: 473.1333854794502 113.18511267006397 586.3184981495142
11: 464.9162220209837 104.51001756638288 569.4262395873666
12: 465.80310471355915 97.85699523240328 563.6600999459624
13: 482.4604536741972 90.89984782785177 573.360301502049
14: 459.62832494080067 85.40313965082169 545.0314645916224
15: 450.9483634084463 81.02963697910309 531.9780003875494
16: 439.08563935756683 77.13340177386999 516.2190411314368
17: 446.

139: 411.5562709271908 28.37197949644178 439.92825042363256
140: 402.50873450934887 26.660237081348896 429.16897159069777
141: 404.1802623048425 29.74114882014692 433.9214111249894
142: 411.6448399014771 29.10661331564188 440.751453217119
143: 422.92730287835 27.299012386240065 450.2263152645901
144: 404.4747318401933 31.469325029291213 435.9440568694845
145: 403.99895613640547 27.72436537919566 431.7233215156011
146: 413.38020070269704 28.55961978342384 441.9398204861209
147: 401.7484167739749 27.201019409112632 428.9494361830875
148: 405.14574658870697 29.737382613122463 434.88312920182943
149: 401.0727546662092 28.962763283401728 430.03551794961095
150: 396.15144269913435 29.290206218138337 425.4416489172727
151: 411.3299970626831 24.9108269023709 436.240823965054
152: 390.4558980539441 28.49078883137554 418.94668688531965
153: 401.4206244535744 29.28459222242236 430.7052166759968
154: 410.1003971397877 27.496427604928613 437.5968247447163
155: 404.8542650863528 29.287790559232235 4

In [8]:
torch.save(model.state_dict(), './mixup_model_pytorch_cifar10_augment')
model = models.__dict__['ResNet18']()
model.load_state_dict(torch.load('./mixup_model_pytorch_cifar10_augment'))

<All keys matched successfully>

In [9]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        outputs = model(inputs)
        _, predicts = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9174


In [10]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in train_loader:
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        outputs = model(inputs)
        _, predicts = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9776
