In [1]:
import torch
from torchvision import transforms, datasets, models

In [2]:
"""
Configuration and Hyperparameters
"""
# Make CUDA float32 the default tensor type so that tensors created without an
# explicit device land on the GPU.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# --- Optimisation hyperparameters -------------------------------------------
batch_size = 128   # mini-batch size used by both data loaders
step_size = 0.01   # learning rate passed to SGD
epochs = 100       # number of passes over the training loader
L2_decay = 1e-4    # weight_decay passed to SGD

# --- Mixup hyperparameters ---------------------------------------------------
alpha = 1.                  # both concentration parameters of Beta(alpha, alpha)
perturb_loss_weight = 0.25  # weight of the mixup loss; the clean loss gets 1 - this

# --- Input preprocessing -----------------------------------------------------
# Convert each image to a tensor, then normalise its single channel with
# mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# --- Reproducibility ---------------------------------------------------------
random_seed = 0
torch.manual_seed(random_seed)

<torch._C.Generator at 0x2690df6f4d0>

In [3]:
"""
Data
"""
# MNIST train / test splits, downloaded into ./data when missing; both share
# the preprocessing pipeline defined in the configuration cell.
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Only the training loader reshuffles its samples; the test loader keeps the
# dataset order.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
)

In [4]:
# Start from a pretrained ResNet-18 and keep every weight trainable
# (full fine-tuning, nothing is frozen).
model = models.resnet18(pretrained=True)
model.requires_grad_(True)

# Swap the stem for a 1-channel convolution and the head for a 10-way
# classifier. Both replacement layers are freshly initialised, so they do not
# carry pretrained weights.
model.conv1 = torch.nn.Conv2d(
    in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False
)
model.fc = torch.nn.Linear(in_features=512, out_features=10)

# Cross-entropy objective, optimised with SGD + momentum and L2 weight decay.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=step_size,
    momentum=0.9,
    weight_decay=L2_decay,
)

In [5]:
def mixup_MNIST(inputs, labels, alpha):
    """Blend every sample of a batch with a randomly chosen partner (mixup).

    One mixing coefficient is drawn from Beta(alpha, alpha) and shared by the
    whole batch. Partners are selected with a single random permutation of the
    batch indices.

    Args:
        inputs: Batch tensor; its first dimension is indexed by the permutation.
        labels: Target tensor; ``labels.size(0)`` is taken as the batch size.
        alpha: Concentration used for both Beta parameters. A value ``<= 0``
            disables mixing: the coefficient is fixed to 1, so the returned
            inputs equal ``inputs`` and the partner labels carry zero weight.

    Returns:
        A tuple ``(mixup_inputs, labels, labels_b, lmbda)`` where
        ``mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]``,
        ``labels`` is the argument returned unchanged, ``labels_b`` is
        ``labels[idx]`` and ``lmbda`` is a 0-dim tensor.
    """
    if alpha > 0:
        lmbda = torch.distributions.beta.Beta(alpha, alpha).sample()
    else:
        # Beta(alpha, alpha) is undefined for alpha <= 0; follow the usual
        # mixup convention and fall back to "no mixing" instead of raising.
        lmbda = torch.tensor(1.0)
    batch_size = labels.size(0)
    # One permutation pairs sample i with sample idx[i] for inputs and labels alike.
    idx = torch.randperm(batch_size)
    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx]
    return mixup_inputs, labels, labels_b, lmbda

In [6]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    """Combine the losses against both label sets of a mixed batch.

    Args:
        criterion: Callable ``criterion(predicts, targets)`` returning a loss.
        predicts: Model outputs for the mixed inputs.
        labels: Targets of the original samples (weighted by ``lmbda``).
        labels_b: Targets of the partner samples (weighted by ``1 - lmbda``).
        lmbda: Mixing coefficient.

    Returns:
        ``lmbda * criterion(predicts, labels)
        + (1 - lmbda) * criterion(predicts, labels_b)``.
    """
    loss_on_originals = criterion(predicts, labels)
    loss_on_partners = criterion(predicts, labels_b)
    return lmbda * loss_on_originals + (1 - lmbda) * loss_on_partners

In [7]:
"""
Training
"""
# Put the network in training mode once for the whole run.
model.train()
for epoch in range(epochs):
    # Per-epoch totals of the batch losses (summed over batches, not averaged).
    epoch_mixup_loss = 0.
    epoch_org_loss = 0.
    epoch_loss = 0.
    for inputs, labels in train_loader:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        mixup_inputs, labels, labels_b, lmbda = mixup_MNIST(inputs, labels, alpha)

        optimizer.zero_grad()

        # Forward pass 1: mixed images, loss blended across both label sets.
        outputs = model(mixup_inputs)
        mixup_loss = mixup_criterion(criterion, outputs, labels, labels_b, lmbda)

        # Forward pass 2: the unmodified batch against its own labels.
        outputs_org = model(inputs)
        loss_org = criterion(outputs_org, labels)

        # The optimised objective is a convex combination of the two losses.
        weighted_total_loss = mixup_loss * perturb_loss_weight + loss_org * (1 - perturb_loss_weight)
        weighted_total_loss.backward()
        optimizer.step()

        # Bookkeeping: read each scalar once. Note that epoch_loss accumulates
        # the plain (unweighted) sum of both losses, not weighted_total_loss.
        mixup_value = mixup_loss.item()
        org_value = loss_org.item()
        epoch_mixup_loss += mixup_value
        epoch_org_loss += org_value
        epoch_loss += (mixup_value + org_value)
    print('{}: {} {} {}'.format(epoch, epoch_mixup_loss, epoch_org_loss, epoch_loss))

0: 527.7763129249215 84.1710719903931 611.9473849153146
1: 414.74778794310987 23.981839226093143 438.729627169203
2: 384.7026916705072 15.21763153502252 399.9203232055297
3: 365.67381589859724 10.883181015669834 376.5569969142671
4: 348.25201917625964 8.565018788678572 356.8170379649382
5: 340.40702239610255 6.099915596598294 346.50693799270084
6: 329.48911689128727 4.827730320394039 334.3168472116813
7: 331.9413714101538 4.1328621644643135 336.0742335746181
8: 328.0972791975364 3.1402558732806938 331.2375350708171
9: 323.68782244529575 3.3666595634422265 327.054482008738
10: 317.5547550357878 2.434990468536853 319.9897455043247
11: 323.487574853003 2.6318596876371885 326.1194345406402
12: 312.2003298578784 2.0144832966761896 314.2148131545546
13: 311.4894963670522 2.0779541322845034 313.5674504993367
14: 307.92447229102254 1.5615098061898607 309.4859820972124
15: 306.7695453176275 1.615375899222272 308.38492121684976
16: 299.89626279845834 1.7575965531359543 301.6538593515943
17: 301.

In [8]:
# Write the trained weights to disk, then rebuild the same architecture from
# scratch (no pretrained weights) and restore the saved weights into it.
checkpoint_path = './mixup_model_pytorch_mnist_augment'
torch.save(model.state_dict(), checkpoint_path)

model = models.resnet18(pretrained=False)
# The replacement layers must match the ones used for training so that every
# key in the checkpoint lines up.
model.conv1 = torch.nn.Conv2d(
    in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False
)
model.fc = torch.nn.Linear(in_features=512, out_features=10)
model.load_state_dict(torch.load(checkpoint_path))

<All keys matched successfully>

In [9]:
def _count_correct(net, loader):
    """Return (number of correct top-1 predictions, number of samples seen).

    Batches are moved to 'cuda' and scored without gradient tracking.
    """
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            batch_inputs = batch_inputs.to('cuda')
            batch_labels = batch_labels.to('cuda')
            # torch.max over dim 1 returns (values, indices); the indices are
            # the predicted classes.
            predicted = torch.max(net(batch_inputs), 1)[1]
            seen += batch_labels.size(0)
            hits += (predicted == batch_labels).sum().item()
    return hits, seen


# Accuracy on the held-out test split, with the network in evaluation mode.
model.eval()
correct, total = _count_correct(model, test_loader)
print(correct / total)

0.9948


In [10]:
# Accuracy on the training split, measured in evaluation mode and without
# gradient tracking.
model.eval()
total = 0
correct = 0
with torch.no_grad():
    for inputs, labels in train_loader:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        outputs = model(inputs)
        # Index 1 of torch.max(..., dim) holds the argmax class per sample.
        predicts = torch.max(outputs, 1)[1]
        matches = (predicts == labels).sum().item()
        correct += matches
        total += labels.size(0)
print(correct / total)

0.9999666666666667
