In [1]:
import torch
from torchvision import transforms, datasets, models

In [2]:
"""
Configuration and Hyperparameters
"""
# Make CUDA float tensors the default tensor type ("default all in GPU").
# Later cells rely on this instead of moving the model to the GPU explicitly.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Per-image preprocessing: convert to a tensor, then normalise the single
# channel with the fixed mean / std constants below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# --- Optimisation settings ---
batch_size = 128   # samples per mini-batch for both data loaders
step_size = 0.01   # passed to SGD as the learning rate
epochs = 100       # number of passes over the training loader
L2_decay = 1e-4    # passed to SGD as weight_decay

# --- Mixup settings ---
alpha = 1.0                 # concentration of the Beta(alpha, alpha) mixing distribution
perturb_loss_weight = 0.75  # weight on the mixup loss; (1 - weight) goes to the clean loss

# Seed torch's global RNG. Kept as the final expression so the cell still
# displays the returned generator object.
random_seed = 0
torch.manual_seed(random_seed)

<torch._C.Generator at 0x209c90c04d0>

In [3]:
"""
Data
"""
# Training split (downloaded to ./data on first use), reshuffled by the loader.
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

# Test split, iterated in a fixed order.
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
)

In [4]:
# ResNet-18 initialised from torchvision's pretrained weights; every parameter
# is marked trainable so the whole network is fine-tuned.
model = models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad_(True)

# Swap the stem for a 1-input-channel convolution and the head for a 10-way
# classifier. Both replacement layers are new modules, so they start from
# fresh initialisation rather than pretrained weights.
model.conv1 = torch.nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
model.fc = torch.nn.Linear(in_features=512, out_features=10)

# Cross-entropy objective, optimised with momentum SGD plus weight decay.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=step_size,
    momentum=0.9,
    weight_decay=L2_decay,
)

In [5]:
def mixup_MNIST(inputs, labels, alpha):
    """Blend every sample of a batch with a randomly chosen partner (mixup).

    A single mixing coefficient ``lmbda`` is drawn per call from
    ``Beta(alpha, alpha)`` and one random permutation of the batch selects
    the partner of each sample.

    Args:
        inputs: Batch tensor whose first dimension indexes samples.
        labels: Targets aligned with ``inputs`` along the first dimension.
        alpha: Concentration of the symmetric Beta distribution. A value
            ``<= 0`` disables mixing: ``lmbda`` is fixed at 1 and the inputs
            are returned unchanged (Beta itself only accepts positive
            concentrations).

    Returns:
        Tuple ``(mixup_inputs, labels, labels_b, lmbda)`` where
        ``mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]``,
        ``labels`` is passed through untouched, ``labels_b = labels[idx]``
        holds the partners' targets and ``lmbda`` is a 0-dim tensor on the
        same device as ``inputs``.
    """
    if alpha > 0:
        lmbda = torch.distributions.beta.Beta(alpha, alpha).sample()
    else:
        # No mixing requested; keep the tensor return type for callers.
        lmbda = torch.tensor(1.0)
    # Keep the coefficient next to the data instead of on the default device.
    lmbda = lmbda.to(inputs.device)

    # Named num_samples so the notebook-level `batch_size` is not shadowed;
    # the last batch of an epoch can be smaller than the configured size.
    num_samples = labels.size(0)
    # Build the permutation on the data's device rather than the default one.
    idx = torch.randperm(num_samples, device=inputs.device)

    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx.to(labels.device)]
    return mixup_inputs, labels, labels_b, lmbda

In [6]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    """Combine a loss evaluated against both label sets of a mixed batch.

    Args:
        criterion: Callable ``criterion(predicts, targets)`` returning a loss.
        predicts: Model outputs for the mixed inputs.
        labels: Targets of the original samples.
        labels_b: Targets of the samples they were mixed with.
        lmbda: Mixing coefficient; weights the ``labels`` term, while
            ``1 - lmbda`` weights the ``labels_b`` term.

    Returns:
        ``lmbda * criterion(predicts, labels)
        + (1 - lmbda) * criterion(predicts, labels_b)``.
    """
    loss_a = criterion(predicts, labels)
    loss_b = criterion(predicts, labels_b)
    return lmbda * loss_a + (1 - lmbda) * loss_b

In [7]:
"""
Training
"""
# Put the network in training mode once; nothing in the loop switches it back.
model.train()
for epoch in range(epochs):
    # Per-epoch running totals. These are sums of per-batch losses, not
    # averages, so their scale depends on the number of batches.
    epoch_loss = 0.
    epoch_mixup_loss = 0.
    epoch_org_loss = 0.
    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        mixup_inputs, labels, labels_b, lmbda = mixup_MNIST(inputs, labels, alpha)

        optimizer.zero_grad()

        # Forward pass on the mixed batch, scored against both label sets.
        outputs = model(mixup_inputs)
        mixup_loss = mixup_criterion(criterion, outputs, labels, labels_b, lmbda)

        # Second forward pass on the clean batch with the ordinary loss.
        outputs_org = model(inputs)
        loss_org = criterion(outputs_org, labels)

        # The optimised objective is a convex combination of the two losses.
        weighted_total_loss = mixup_loss * perturb_loss_weight + loss_org * (1 - perturb_loss_weight)
        weighted_total_loss.backward()
        optimizer.step()

        # Read each scalar from the device exactly once per step and reuse it
        # for all three running totals.
        mixup_loss_value = mixup_loss.item()
        loss_org_value = loss_org.item()
        epoch_mixup_loss += mixup_loss_value
        epoch_org_loss += loss_org_value
        # Logged total is the plain (unweighted) sum of the two losses.
        epoch_loss += (mixup_loss_value + loss_org_value)
    print('{}: {} {} {}'.format(epoch, epoch_mixup_loss, epoch_org_loss, epoch_loss))

0: 476.9724363386631 104.4012790210545 581.3737153597176
1: 370.14227997139096 32.70804162602872 402.8503215974197
2: 346.73022370412946 22.982035391032696 369.71225909516215
3: 328.3877817951143 18.121711362036876 346.50949315715116
4: 314.4301942512393 14.184749023406766 328.61494327464607
5: 307.83533869870007 12.12002641451545 319.9553651132155
6: 300.1637667603791 10.610098465287592 310.77386522566667
7: 301.6818606071174 9.57962723413948 311.2614878412569
8: 299.1389101948589 8.738341974152718 307.8772521690116
9: 293.6350624281913 7.207841809053207 300.8429042372445
10: 289.52381127141416 6.463142732274719 295.9869540036889
11: 294.92984687164426 6.571158773556817 301.5010056452011
12: 284.7079503312707 5.906006788223749 290.61395711949444
13: 283.02307272702456 5.91314053957467 288.9362132665992
14: 281.046838786453 4.705114221636904 285.7519530080899
15: 281.31459284666926 4.547650308872107 285.86224315554136
16: 273.13646397180855 4.658730904688127 277.7951948764967
17: 275.3

In [8]:
# Persist the trained weights to disk.
torch.save(obj=model.state_dict(), f='./mixup_model_pytorch_mnist_augment')

# Rebuild the same architecture without pretrained weights (1-channel stem,
# 10-way head) so the saved state dict lines up key for key.
model = models.resnet18(pretrained=False)
model.conv1 = torch.nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
model.fc = torch.nn.Linear(in_features=512, out_features=10)

# Restore the weights written above. Kept as the final expression so the cell
# still displays the key-matching result.
model.load_state_dict(torch.load(f='./mixup_model_pytorch_mnist_augment'))

<All keys matched successfully>

In [9]:
# Top-1 accuracy over the test loader, in evaluation mode and without
# gradient tracking.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        # Predicted class = index of the largest output along dim 1.
        predicts = model(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9944


In [10]:
# Top-1 accuracy over the training loader, in evaluation mode and without
# gradient tracking. Inputs are the plain (unmixed) training images.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in train_loader:
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        # Predicted class = index of the largest output along dim 1.
        predicts = model(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9995833333333334
