In [1]:
import torch
from torchvision import transforms, datasets
import models

In [2]:
"""
Configuration and Hyperparameters
"""
# Make CUDA float32 the default tensor type, so newly created floating-point
# tensors are allocated on the GPU unless a device is given explicitly.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Per-channel normalization constants, defined once so the train and test
# pipelines cannot drift apart (presumably in RGB channel order - TODO confirm).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop + horizontal flip augmentation, then
# tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # optional augmentation, can omit
    transforms.RandomHorizontalFlip(),  # optional augmentation, can omit
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Test pipeline: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

batch_size = 128            # samples per mini-batch for both data loaders
step_size = 0.1             # passed to SGD as the learning rate
random_seed = 0             # seed handed to torch.manual_seed below
epochs = 200                # number of passes over the training loader
L2_decay = 1e-4             # passed to SGD as weight_decay
gauss_vicinal_std = 0.25    # std of the Gaussian noise added by gauss_vicinal

# Seed torch's random number generator; kept as the last expression so the
# cell still displays the returned Generator.
torch.manual_seed(random_seed)

<torch._C.Generator at 0x1e8efab04d0>

In [3]:
"""
Data
"""
# CIFAR-10 splits, downloaded into ./data when not already present.
# The training split uses the augmenting transform, the test split does not.
train_set = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
test_set = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# Mini-batch loaders: the training loader reshuffles, the test loader keeps
# dataset order. Both use 8 worker processes.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Look the architecture up by name in the project-local `models` module
# and instantiate it with its default arguments.
model = vars(models)['ResNet18']()

# Cross-entropy loss over the model outputs.
criterion = torch.nn.CrossEntropyLoss()

# SGD with momentum; learning rate and weight decay come from the
# configuration cell (step_size, L2_decay).
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=step_size,
    momentum=0.9,
    weight_decay=L2_decay,
)

In [5]:
def gauss_vicinal(inputs, gauss_vicinal_std):
    """Sample a Gaussian-perturbed copy of ``inputs``.

    Each output element is drawn from a normal distribution whose mean is the
    corresponding element of ``inputs`` and whose standard deviation is
    ``gauss_vicinal_std``.

    Args:
        inputs: tensor of per-element means.
        gauss_vicinal_std: standard deviation shared by all elements.

    Returns:
        A new tensor with the same shape as ``inputs``.
    """
    return torch.normal(mean=inputs, std=gauss_vicinal_std)

In [6]:
"""
Training
"""
model.train()
for epoch in range(epochs):
    # Running sums (not averages) of the per-batch losses for this epoch.
    epoch_gauss_loss = 0.
    epoch_org_loss = 0.
    epoch_loss = 0.

    for inputs, labels in train_loader:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')

        # Noisy copy of the batch, drawn before any forward pass.
        inputs_gauss = gauss_vicinal(inputs, gauss_vicinal_std)

        optimizer.zero_grad()

        # Two forward passes per step: the perturbed batch first, then the
        # unperturbed batch. The training objective is the sum of both losses.
        gauss_loss = criterion(model(inputs_gauss), labels)
        loss_org = criterion(model(inputs), labels)
        total_loss = gauss_loss + loss_org

        epoch_gauss_loss += gauss_loss.item()
        epoch_org_loss += loss_org.item()
        epoch_loss += total_loss.item()

        total_loss.backward()
        optimizer.step()

    # One line per epoch: index, summed noisy loss, summed clean loss, summed total.
    print('{}: {} {} {}'.format(epoch, epoch_gauss_loss, epoch_org_loss, epoch_loss))

0: 1161.6744626760483 1161.4809753894806 2323.1554374694824
1: 850.4431853294373 849.4572459459305 1699.9004335403442
2: 764.6267102956772 764.16606092453 1528.7927684783936
3: 720.5175825357437 719.9069302082062 1440.4245092868805
4: 691.79572057724 691.005676984787 1382.8013954162598
5: 671.4562151432037 670.320422410965 1341.7766382694244
6: 650.6410090923309 649.0979682207108 1299.7389776706696
7: 629.0471322536469 627.2152560949326 1256.2623903751373
8: 597.5138113498688 595.733440041542 1193.2472529411316
9: 576.4861139059067 574.5415465831757 1151.0276608467102
10: 552.275195479393 549.5759927034378 1101.8511872291565
11: 518.8602335453033 515.2573970556259 1034.1176326274872
12: 500.80974209308624 496.4472051858902 997.2569470405579
13: 444.09115368127823 437.46394473314285 881.5550981760025
14: 403.1249688863754 395.17415457963943 798.2991243600845
15: 372.9325177669525 364.1077842116356 737.0403020381927
16: 347.1266422867775 336.9481421113014 684.074785232544
17: 326.1025654

139: 62.563542526215315 42.7742359302938 105.33777837455273
140: 65.2666614279151 45.998181488364935 111.26484271138906
141: 62.74090764671564 43.31332487612963 106.0542325079441
142: 59.116795010864735 40.93722117319703 100.05401629954576
143: 62.02079317346215 42.57574432902038 104.59653729945421
144: 59.88389694690704 40.82175459899008 100.70565160363913
145: 62.58954340219498 43.25301801599562 105.8425617069006
146: 62.0268337354064 42.347330309450626 104.37416391074657
147: 60.52692857384682 41.12804680131376 101.65497513860464
148: 61.92218919843435 43.2915018722415 105.21369085460901
149: 63.370417792350054 43.03940077871084 106.40981876850128
150: 61.979921743273735 43.551486445590854 105.53140823543072
151: 59.699112337082624 40.529290771111846 100.22840318828821
152: 60.14089797437191 41.1532640773803 101.29416233301163
153: 59.40754874423146 41.07672647200525 100.48427543789148
154: 60.20600155740976 40.631647162139416 100.83764889091253
155: 61.6229289509356 42.194377103820

In [7]:
# Persist the trained weights, then rebuild a fresh ResNet18 and restore the
# weights from disk so later cells evaluate the checkpoint as saved.
checkpoint_path = './gauss_model_pytorch_cifar10_augment'
trained_state = model.state_dict()
torch.save(trained_state, checkpoint_path)

model = vars(models)['ResNet18']()
restored_state = torch.load(checkpoint_path)
# Kept as the last expression so the cell shows the key-matching result.
model.load_state_dict(restored_state)

<All keys matched successfully>

In [8]:
# Accuracy on the test loader: fraction of samples whose highest-scoring
# class matches the label.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        outputs = model(inputs)
        # Index of the largest score along the class dimension.
        predicts = torch.max(outputs, 1)[1]
        total += labels.size(0)
        correct += torch.eq(predicts, labels).sum().item()
print(correct / total)

0.8887


In [9]:
# Accuracy on the training loader. train_loader applies the random crop/flip
# transform, so this is measured on augmented training samples.
model.eval()
total = 0
correct = 0
with torch.no_grad():
    for batch in train_loader:
        inputs = batch[0].to('cuda')
        labels = batch[1].to('cuda')
        outputs = model(inputs)
        # Predicted class = position of the maximum score in each row.
        predicts = outputs.max(1)[1]
        matches = (predicts == labels)
        total += labels.size(0)
        correct += matches.sum().item()
print(correct / total)

0.9714
